#pragma once

#include "stdafx.h"
#include "SocketX.h"

//////////////////////////////////////////////////////////////////////////

// MFC run-time class information: registers CServerSocket as derived from
// IBaseSocket (presumably paired with DECLARE_DYNAMIC in SocketX.h).
IMPLEMENT_DYNAMIC(CServerSocket, IBaseSocket)

// Defaults: incoming connections are accepted (m_fAccept) and clients are
// not restricted to the local network (m_fInternet; see Accept()).
CServerSocket::CServerSocket()
	: m_fAccept(TRUE)
	, m_fInternet(TRUE)
{
}

// Shuts down and closes every client connection still held in the
// connection list before the object goes away.
CServerSocket::~CServerSocket()
{
	this->CloseAllSocketConnection();
}

// Send-ready notification: forwards it to the registered target object.
// hSocket is not used. A non-zero nErrorCode is only traced; the target is
// not notified in that case.
void CServerSocket::OnSend(SOCKET hSocket, int nErrorCode)
{
	if (nErrorCode == 0)
	{
		m_SockMsg.pTgtObj->OnSocketSendMessage(m_SockMsg.dwSocketID);
		return;
	}

	TRACE1("***** ERROR: OnSend(%d) *****\n", nErrorCode);
}

// Receive notification for a client socket.
//
// If hSocket is not a tracked client connection, any queued data is read and
// discarded and the socket is closed. Otherwise the pending data is read and
// forwarded, together with the peer address, to the target object.
// Errors (notification error, getpeername/receive failure) are not reported
// to the target; notification errors are traced.
void CServerSocket::OnReceive(SOCKET hSocket, int nErrorCode)
{
	CByteArray data;
	SOCKADDR   SockAddr;
	NETADDR	   NetAddr;

	if (nErrorCode != 0)
	{
		TRACE1("***** ERROR: OnReceive(%d) *****\n", nErrorCode);
		return;
	}

	if (SearchSocketConnection(hSocket) == -1)
	{
		// Unknown socket: drain and discard whatever is queued, then close it.
		// recv() returns 0 after a graceful close by the peer, and keeps
		// returning 0 on every later call, so the loop must stop on 0 as well
		// as on SOCKET_ERROR (e.g. WSAEWOULDBLOCK once a non-blocking socket
		// is drained). Looping only until SOCKET_ERROR could spin forever.
		CByteArray buf;
		buf.SetSize(SX_TCP_MINBUFSIZE);

		int nRet;
		do
		{
			nRet = recv(hSocket, reinterpret_cast<LPSTR>(buf.GetData()),
						static_cast<int>(buf.GetSize()), 0);
		} while (nRet > 0);

		closesocket(hSocket);
		TRACE0("***** ERROR: Receive(Socket not found) *****\n");
		return;
	}

	if (!doGetPeerName(hSocket, &SockAddr) || !doReceive(hSocket, data))
	{
		return;
	}

	ConvertSockAddrToNetAddr(&SockAddr, NetAddr);
	m_SockMsg.pTgtObj->OnSocketReceiveMessage(m_SockMsg.dwSocketID, NetAddr, data);
}

// Accept notification.
//
// hSocket is the socket the notification was raised for, i.e. the listening
// socket (Accept() accepts on m_hSocket). Accept() creates the client socket
// and registers it in the connection list but does not return its handle,
// so the handle is looked up by peer address in order to apply keep-alive
// to the new connection, not to the listening socket. If the client cannot
// be found or keep-alive cannot be enabled, the client connection is closed
// and removed; the listening socket is left untouched.
// On success the target object is notified with the client's address.
void CServerSocket::OnAccept(SOCKET hSocket, int nErrorCode)
{
	SOCKADDR SockAddr;
	NETADDR  NetAddr;
	int nSockLen = sizeof(SOCKADDR);

	if (nErrorCode != 0)
	{
		TRACE1("***** ERROR: OnAccept(%d) *****\n", nErrorCode);
		return;
	}

	if (!Accept(&SockAddr, &nSockLen))
	{
		return;
	}

	ConvertSockAddrToNetAddr(&SockAddr, NetAddr);

	// AddConnection() stores the address produced by the same conversion,
	// so this lookup finds the connection Accept() just registered.
	SOCKET hClient = GetSocketHandle(NetAddr.dwAddress, NetAddr.wPort);

	if (hClient == INVALID_SOCKET)
	{
		TRACE0("***** ERROR: OnAccept(Socket not found) *****\n");
		return;
	}

	if (!doSetKeepAlive(hClient, SX_KA_TIME, SX_KA_INTERVAL))
	{
		NETADDR ClosedAddr;
		CloseSocketConnection(hClient, &ClosedAddr);
		return;
	}

	m_SockMsg.pTgtObj->OnSocketAcceptMessage(m_SockMsg.dwSocketID, NetAddr);
}

// Close notification for a client connection.
//
// nErrorCode 0 is a normal close; 10053 (WSAECONNABORTED) is treated as a
// keep-alive timeout. Any other error code is traced, but the connection is
// still closed and removed from the list and the target is notified: after
// a close notification no further event is expected for this socket, so
// returning early would leak the handle and leave a stale list entry.
void CServerSocket::OnClose(SOCKET hSocket, int nErrorCode)
{
	NETADDR NetAddr;

	// WinSock Error(10053) Keep-Alive timeout
	if (nErrorCode != 0 && nErrorCode != 10053)
	{
		TRACE1("***** ERROR: OnClose(%d) *****\n", nErrorCode);
	}

	if (!CloseSocketConnection(hSocket, &NetAddr))
	{
		TRACE0("***** ERROR: OnClose(Socket not found) *****\n");
		return;
	}

	m_SockMsg.pTgtObj->OnSocketCloseMessage(m_SockMsg.dwSocketID, NetAddr);
}

// Creates the notification window, records the target from SockMsg, creates
// the IPv4 TCP listening socket in m_hSocket, registers it for FD_SERVER
// events and applies the configured send/receive timeouts.
// Returns FALSE on failure; on socket() failure the error is stored in
// m_dwError, and if event registration fails the new socket is closed so
// it does not leak.
BOOL CServerSocket::CreateSocket(SOCKMSG SockMsg)
{
	if (!IBaseSocket::CreateSocketWindow())
	{
		return FALSE;
	}

	SetTargetWnd(SockMsg);

	m_hSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);

	if (m_hSocket == INVALID_SOCKET)
	{
		// Capture the error first: tracing may change the thread's last error.
		m_dwError = GetLastError();
		TRACE1("***** ERROR: socket(%d) *****\n", m_dwError);
		return FALSE;
	}

	if (!AsyncSelect(FD_SERVER))
	{
		// Release the handle created above; callers only see FALSE.
		CloseSocket();
		return FALSE;
	}

	SetSendTimeout(m_dwSendTimeout);
	SetRecieveTimeout(m_dwReceiveTimeout);

	return TRUE;
}

// Creates the listening socket, enables address reuse, binds it to
// dwBindAddress/wPort, starts listening with nBacklog and disables Nagle.
// It then tries send/receive buffer sizes starting at SX_TCP_MAXBUFSIZE,
// halving until each is accepted; failing to set a buffer size is not fatal.
// Returns FALSE (closing the socket) if any mandatory step fails; on success
// marks the object initialised (m_fInit).
BOOL CServerSocket::Initialize(SOCKMSG SockMsg, WORD wPort, DWORD dwBindAddress, int nBacklog)
{
	BOOL flag     = TRUE;
	BOOL fSend    = FALSE;
	BOOL fReceive = FALSE;

	if (!CreateSocket(SockMsg))
	{
		return FALSE;
	}

	// SO_REUSEADDR only influences bind(), so it must be set before Bind();
	// set afterwards it has no effect.
	// NOTE(review): on Windows SO_REUSEADDR also allows binding a port already
	// in use by another socket; consider SO_EXCLUSIVEADDRUSE if unintended.
	if (!(SetReuseAddr(TRUE) && Bind(wPort, dwBindAddress) && Listen(nBacklog)))
	{
		CloseSocket();
		return FALSE;
	}

	// TCP Option(no delay): TCP_NODELAY is an IPPROTO_TCP-level option.
	if (!doSetSockOpt(m_hSocket, IPPROTO_TCP, TCP_NODELAY, (LPSTR)&flag, sizeof(flag)))
	{
		CloseSocket();
		return FALSE;
	}

	// buffer size: largest size (power-of-two steps down) the stack accepts
	for (UINT nBufSize = SX_TCP_MAXBUFSIZE; nBufSize > 0; nBufSize /= 2)
	{
		if (!fSend && SetSendBufferSize(nBufSize))
		{
			fSend = TRUE;
		}

		if (!fReceive && SetReceiveBufferSize(nBufSize))
		{
			fReceive = TRUE;
		}

		if (fSend && fReceive)
		{
			break;
		}
	}

	m_fInit = TRUE;

	return TRUE;
}

// Calls listen() on m_hSocket with the given backlog.
// On failure stores the error in m_dwError and returns FALSE.
BOOL CServerSocket::Listen(int nBacklog)
{
	if (listen(m_hSocket, nBacklog) == SOCKET_ERROR)
	{
		// Capture the error first: tracing may change the thread's last error.
		m_dwError = GetLastError();
		TRACE1("***** ERROR: listen(%d) *****\n", m_dwError);
		return FALSE;
	}

	return TRUE;
}

// Accepts one pending connection on m_hSocket and prepares it.
//
// The new socket is rejected (closed, FALSE) when m_fInternet is FALSE and the
// peer is not a LAN address. Otherwise Nagle is disabled and send/receive
// buffer sizes are negotiated (starting at SX_TCP_MAXBUFSIZE, halving). The
// socket is then registered for FD_CLIENT events and added to the connection
// list. lpSockAddr/lpSockAddrLen receive the peer address as for accept().
// Returns FALSE on failure, storing the error in m_dwError for socket-API
// failures.
BOOL CServerSocket::Accept(LPSOCKADDR lpSockAddr, int *lpSockAddrLen)
{
	SOCKET hSocket;
	BOOL   flag     = TRUE;
	BOOL   fSend    = FALSE;
	BOOL   fReceive = FALSE;

	hSocket = accept(m_hSocket, lpSockAddr, lpSockAddrLen);

	if (hSocket == INVALID_SOCKET)
	{
		// Capture the error first: tracing may change the thread's last error.
		m_dwError = GetLastError();
		TRACE1("***** ERROR: accept(%d) *****\n", m_dwError);
		return FALSE;
	}

	if (!m_fInternet && !IsLANConnection(lpSockAddr))
	{
		doCloseSocket(hSocket);
		return FALSE;
	}

	// TCP Option(no delay). TCP_NODELAY must be set at IPPROTO_TCP level; at
	// SOL_SOCKET level the same numeric value selects SO_DEBUG instead.
	if (!doSetSockOpt(hSocket, IPPROTO_TCP, TCP_NODELAY, (LPSTR)&flag, sizeof(flag)))
	{
		m_dwError = GetLastError();
		doCloseSocket(hSocket);
		return FALSE;
	}

	// buffer size: largest size (power-of-two steps down) the stack accepts
	for (UINT nBufSize = SX_TCP_MAXBUFSIZE; nBufSize > 0; nBufSize /= 2)
	{
		if (!fSend && doSetSendBufferSize(hSocket, nBufSize))
		{
			fSend = TRUE;
		}

		if (!fReceive && doSetReceiveBufferSize(hSocket, nBufSize))
		{
			fReceive = TRUE;
		}

		if (fSend && fReceive)
		{
			break;
		}
	}

	if (!doAsyncSelect(hSocket, SM_EVENT, FD_CLIENT))
	{
		m_dwError = GetLastError();
		doCloseSocket(hSocket);
		return FALSE;
	}

	AddConnection(hSocket, lpSockAddr, *lpSockAddrLen);

	return TRUE;
}

// Appends a client socket and its converted peer address to the connection
// list. The list is modified under its lock. nSockAddrLen is not used.
void CServerSocket::AddConnection(SOCKET hSocket, const LPSOCKADDR lpSockAddr, int nSockAddrLen)
{
	SOCKETDATA entry;
	ZeroMemory(&entry, sizeof(entry));

	ConvertSockAddrToNetAddr(lpSockAddr, entry.NetAddr);
	entry.hSocket = hSocket;

	CSingleLock lock(&m_sdList.cs, TRUE);
	m_sdList.list.Add(entry);
}

// Closes the tracked client connection hSocket and removes it from the list.
// On success copies its peer address to *lpNetAddr (if lpNetAddr is non-NULL)
// and returns TRUE; returns FALSE (traced) if hSocket is not tracked.
BOOL CServerSocket::CloseSocketConnection(SOCKET hSocket, LPNETADDR lpNetAddr)
{
	// Hold the list lock across lookup AND removal so the index cannot go
	// stale in between. SearchSocketConnection() locks the same object again;
	// this assumes m_sdList.cs is re-entrant for the owning thread (as an MFC
	// CCriticalSection is) — NOTE(review): confirm its type in SocketX.h.
	CSingleLock sl(&m_sdList.cs, TRUE);
	int nIndex = SearchSocketConnection(hSocket);

	if (nIndex == -1)
	{
		TRACE0("***** ERROR: CloseSocketConnection(Socket not found) *****\n");
		return FALSE;
	}

	doCloseSocket(m_sdList.list[nIndex].hSocket);

	if (lpNetAddr != NULL)
	{
		*lpNetAddr = m_sdList.list[nIndex].NetAddr;
	}

	m_sdList.list.RemoveAt(nIndex);

	return TRUE;
}

void CServerSocket::CloseAllSocketConnection()
{
	UINT nSize;

	CSingleLock sl(&m_sdList.cs, TRUE);
	nSize = m_sdList.list.GetSize();

	for (UINT i = 0; i < nSize; i++)
	{
		doShutdown(m_sdList.list[i].hSocket, SD_BOTH);
		doCloseSocket(m_sdList.list[i].hSocket);
	}

	m_sdList.list.RemoveAll();
}

// Returns the index of hSocket in the connection list, or -1 when it is not
// tracked. The lookup runs under the list lock.
int CServerSocket::SearchSocketConnection(SOCKET hSocket)
{
	CSingleLock lock(&m_sdList.cs, TRUE);
	const UINT nCount = static_cast<UINT>(m_sdList.list.GetSize());

	for (UINT n = 0; n < nCount; ++n)
	{
		if (m_sdList.list[n].hSocket == hSocket)
		{
			return static_cast<int>(n);
		}
	}

	return -1;
}

// Sends data to every tracked client connection while holding the list lock.
// A failed send is traced and does not stop delivery to the remaining
// clients. Returns FALSE if there are no connections or if any send failed,
// TRUE otherwise.
BOOL CServerSocket::Broadcast(const CByteArray &data)
{
	CSingleLock sl(&m_sdList.cs, TRUE);

	BOOL fResult   = TRUE;
	UINT nListSize = static_cast<UINT>(m_sdList.list.GetSize());

	if (nListSize < 1)
	{
		return FALSE;
	}

	for (UINT i = 0; i < nListSize; i++)
	{
		if (!doSend(m_sdList.list[i].hSocket, data))
		{
			fResult = FALSE;
			TRACE1("***** ERROR: Broadcast(Dst: %s) *****\n", 
				   DwToIPAddress(m_sdList.list[i].NetAddr.dwAddress));
		}
	}

	return fResult;
}

// Returns the socket of the first tracked connection whose stored peer
// address and port equal dwAddress/wPort, or INVALID_SOCKET if none does.
// The search runs under the list lock.
SOCKET CServerSocket::GetSocketHandle(DWORD dwAddress, WORD wPort)
{
	CSingleLock lock(&m_sdList.cs, TRUE);
	const UINT nCount = static_cast<UINT>(m_sdList.list.GetSize());

	for (UINT n = 0; n < nCount; ++n)
	{
		const SOCKETDATA &entry = m_sdList.list[n];

		if (entry.NetAddr.dwAddress == dwAddress && entry.NetAddr.wPort == wPort)
		{
			return entry.hSocket;
		}
	}

	return INVALID_SOCKET;
}

// Sends data to the tracked client at dwAddress/wPort.
// Returns the result of doSend, or FALSE (traced) if no such client exists.
BOOL CServerSocket::SendToClient(const CByteArray &data, DWORD dwAddress, WORD wPort)
{
	SOCKET hTarget = GetSocketHandle(dwAddress, wPort);

	if (hTarget != INVALID_SOCKET)
	{
		return doSend(hTarget, data);
	}

	TRACE0("***** ERROR: SendToClient(Socket not found) *****\n");
	return FALSE;
}

// Enables or disables accept notifications: adds or removes FD_ACCEPT in the
// event mask m_lEvent and re-registers the mask via AsyncSelect. Does nothing
// if fAccept equals the current setting.
void CServerSocket::SetAccept(BOOL fAccept)
{
	if (m_fAccept == fAccept)
	{
		return;
	}

	m_fAccept = fAccept;

	// Clear the bit explicitly; XOR would toggle FD_ACCEPT back ON if the
	// mask did not contain it.
	if (m_fAccept)
	{
		m_lEvent |= FD_ACCEPT;
	}
	else
	{
		m_lEvent &= ~FD_ACCEPT;
	}

	AsyncSelect(m_lEvent);
}

//////////////////////////////////////////////////////////////////////////
