#include "AsyncTCPSocket.h"

#if !defined(WIN32)
#include <sys/errno.h>
#else
#define EINPROGRESS WSAEINPROGRESS
#endif // !WIN32

/**
 * Creates an unconnected asynchronous socket with no selector attached and
 * no pending read/write interest; Listen() or Connect() enable interest later.
 */
AsyncTCPSocket::AsyncTCPSocket()
    : TCPSocket()
    , mSelector(NULL)
    , mWaitRead(false)
    , mWaitWrite(false)
{
}

/**
 * Adopts an existing descriptor (presumably one produced by accept() — confirm
 * against callers) together with its remote address.
 *
 * The socket starts with read interest enabled and no write interest, and is
 * not yet attached to any selector.
 */
AsyncTCPSocket::AsyncTCPSocket(int inDescriptor, struct sockaddr_in* inRemoteAddress)
    : TCPSocket(inDescriptor, inRemoteAddress)
    , mSelector(NULL)
    , mWaitRead(true)
    , mWaitWrite(false)
{
}

/**
 * Tears the socket down: shutdown (which also detaches from the selector)
 * must happen before the descriptor is closed.
 *
 * NOTE(review): Close() calls OnClose(). If OnClose is virtual, a call made
 * from this destructor dispatches to this class's (or a base's) version, never
 * to a derived override — derived classes should not rely on it here.
 */
AsyncTCPSocket::~AsyncTCPSocket()
{
    this->Shutdown();
    this->Close();
}

/**
 * Closes the descriptor (if still open), clears all pending I/O interest and
 * then notifies via OnClose(). Calling it on an already-closed socket is a
 * no-op, so OnClose() fires at most once per open descriptor.
 */
void AsyncTCPSocket::Close()
{
    // Already closed: nothing to release and no notification to send.
    if (mDescriptor == SOCKET_INVALID_DESCRIPTOR)
        return;

    // Drop read/write interest before the descriptor goes away.
    mWaitWrite = false;
    mWaitRead = false;

    TCPSocket::Close();
    OnClose();
}

/**
 * Binds to the given local port/address via TCPSocket::Bind, then switches the
 * socket to non-blocking mode.
 *
 * NOTE(review): the return value of SetNonBlock() is ignored, as before.
 */
void AsyncTCPSocket::Bind(u_int16_t inPort, u_int32_t inAddress)
{
    TCPSocket::Bind(inPort, inAddress);
    this->SetNonBlock(true);
}

/**
 * Puts the socket in the listening state and enables read interest; with
 * select()-style polling, pending incoming connections show up as
 * readability on a listening socket.
 *
 * @param inMaxPendingConnections backlog passed through to TCPSocket::Listen
 */
void AsyncTCPSocket::Listen(u_int32_t inMaxPendingConnections)
{
    TCPSocket::Listen(inMaxPendingConnections);
    this->mWaitRead = true;
}

int AsyncTCPSocket::Send(const void *inBytes, u_int32_t inCount) throw (socket_error)
{
	int sent = 0;
    if (!mWaitWrite)
    {
        // Send the data
        sent = TCPSocket::Send(inBytes, inCount);
        if (sent == -1)
        {
            if (errno == EAGAIN)
            {
                sent = 0;
                DEBUG_CALL(printf("send would block\n"));
                mWaitWrite = true;
            }
            else
            {
                throw socket_error(errno);
            }
        }
        else if ((u_int32_t)sent != inCount)
        {
            DEBUG_CALL(printf("assuming send would block\n"));
            mWaitWrite = true;
        }
    }
    return sent;
}

/**
 * Shuts the connection down via TCPSocket::Shutdown, clears all pending I/O
 * interest and, if attached to a selector, unregisters from it and forgets it.
 */
void AsyncTCPSocket::Shutdown()
{
    TCPSocket::Shutdown();
    mWaitWrite = false;
    mWaitRead = false;

    if (mSelector == NULL)
        return;

    mSelector->RemoveSocket(*this);
    mSelector = NULL;
}

/**
 * Switches blocking mode via TCPSocket::SetNonBlock.
 *
 * When switching to blocking mode, the socket first unregisters from its
 * selector (if any), since presumably a blocking socket should not be driven
 * by the selector.
 *
 * @param nonBlock true for non-blocking, false for blocking
 * @return whatever TCPSocket::SetNonBlock returns
 */
bool AsyncTCPSocket::SetNonBlock(bool nonBlock)
{
    const bool detachFromSelector = !nonBlock && mSelector != NULL;
    if (detachFromSelector)
    {
        mSelector->RemoveSocket(*this);
        mSelector = NULL;
    }
    return TCPSocket::SetNonBlock(nonBlock);
}

void AsyncTCPSocket::Connect(const string &inAddress, u_int16_t inPort)
{
	AsyncDNS dns;
	mSavedPort = inPort;
	dns.HostByName(inAddress, this);
	// this resumes in HostLookup callback
}

/**
 * Starts a non-blocking connect to an IPv4 address.
 *
 * @param inAddress IPv4 address in host byte order (converted with htonl)
 * @param inPort    port in host byte order (converted with htons)
 * @throws socket_error if connect() fails with anything other than the
 *         platform's "connection in progress" code
 *
 * On return the connection is pending: mConnected is false and write interest
 * is enabled, so completion is signalled by the socket becoming writable.
 */
void AsyncTCPSocket::Connect(const u_int32_t inAddress, const u_int16_t inPort)
{
	SetNonBlock(true);

	// Value-initialize so sin_zero (and sin_len on BSD-derived systems) are
	// zero instead of uninitialized stack contents.
	struct sockaddr_in addr = {};
	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = htonl(inAddress);
	addr.sin_port = htons(inPort);
	if (connect(mDescriptor, reinterpret_cast<struct sockaddr *>(&addr), sizeof(addr)) < 0)
	{
#ifdef WIN32
		// A non-blocking Winsock connect reports "in progress" as WSAEWOULDBLOCK.
		const int err = WSAGetLastError();
		if (err != WSAEWOULDBLOCK)
			throw socket_error(err);
#else
		const int err = errno;
		if (err != EINPROGRESS)
			throw socket_error(err);
#endif
	}
	mConnected = false;
	mRemoteHost = inAddress;
	mRemotePort = inPort;
	//mWaitRead = true;
	mWaitWrite = true;
}

void AsyncTCPSocket::HostLookup(const string &inName, const struct in_addr *inAddr)
{
	(void)inName;
	try
	{
		if (inAddr)
			AsyncTCPSocket::Connect(ntohl(inAddr->s_addr), mSavedPort);
		else
		{
			DEBUG_CALL(printf("host lookup failed: %s\n", inName.c_str()));
		}
	}
	catch (socket_error &err)
	{
		DEBUG_CALL(printf("connect error: %s\n", err.what()));
	}
}

#if 0
/**
 * Returns the local IPv4 address this socket is bound to.
 *
 * The value is returned exactly as stored in sin_addr.s_addr, i.e. in network
 * byte order (note that Connect() takes a host-order address). Returns 0 if
 * getsockname() fails.
 */
u_int32_t AsyncTCPSocket::LocalHost()
{
	struct sockaddr_in myaddr = {};
#ifdef WIN32
	int len = sizeof(myaddr);
#else
	// POSIX getsockname() takes a socklen_t*, not an int*.
	socklen_t len = sizeof(myaddr);
#endif
	if (getsockname(mDescriptor, (sockaddr *)&myaddr, &len) != 0)
		return 0;
	return myaddr.sin_addr.s_addr;
}

/**
 * Returns the local IPv4 address this socket is bound to, in dotted-quad
 * form. Returns an empty string if getsockname() fails.
 */
string AsyncTCPSocket::LocalHostName()
{
	struct sockaddr_in myaddr = {};
#ifdef WIN32
	int len = sizeof(myaddr);
#else
	// POSIX getsockname() takes a socklen_t*, not an int*.
	socklen_t len = sizeof(myaddr);
#endif
	if (getsockname(mDescriptor, (sockaddr *)&myaddr, &len) != 0)
		return string();
	// inet_ntoa isn't thread safe!!! (it returns a pointer to a static buffer)
	return string(inet_ntoa(myaddr.sin_addr));
}
#endif

