#pragma once

#include "stdafx.h"
#include "SocketX.h"

//////////////////////////////////////////////////////////////////////////

// MFC run-time class information: registers CMcastSocket as a class derived
// from CPeerSocket (pairs with DECLARE_DYNAMIC in the class declaration).
IMPLEMENT_DYNAMIC(CMcastSocket, CPeerSocket)

// Constructs a socket that is not yet associated with any multicast group:
// the multicast flag is cleared and the destination address is zero-filled
// until JoinMulticastGroup() records one.
CMcastSocket::CMcastSocket()
{
	m_fMcast = FALSE;
	ZeroMemory(&m_McastAddr, sizeof(m_McastAddr));
}

// Destructor. Performs no cleanup of its own; in particular it does not drop
// any multicast group membership.
// NOTE(review): presumably the CPeerSocket base releases m_hSocket — confirm.
CMcastSocket::~CMcastSocket()
{
	// intentionally empty
}

// Subscribes this socket to a multicast group so it can receive datagrams
// sent to it (IP_ADD_MEMBERSHIP).
//
// dwMcastAddress - group address, stored verbatim into imr_multiaddr.s_addr
//                  (so presumably already in network byte order — confirm
//                  against callers).
// The membership is bound to the interface address held in m_LocalAddr.
//
// Returns TRUE on success; on failure traces the last error and returns FALSE.
BOOL CMcastSocket::SetReceiveMulticast(DWORD dwMcastAddress)
{
	// View the local address as an IPv4 sockaddr to pull out the interface.
	SOCKADDR_IN localAddr;
	CopyMemory(&localAddr, &m_LocalAddr, sizeof(localAddr));

	IP_MREQ mreq;
	ZeroMemory(&mreq, sizeof(mreq));
	mreq.imr_multiaddr.s_addr = dwMcastAddress;
	mreq.imr_interface.s_addr = localAddr.sin_addr.s_addr;

	if (doSetSockOpt(m_hSocket, IPPROTO_IP, IP_ADD_MEMBERSHIP,
			reinterpret_cast<LPSTR>(&mreq), sizeof(mreq)))
	{
		return TRUE;
	}

	TRACE1("***** ERROR: Multicast Socket(%d) *****\n", GetLastError());
	return FALSE;
}

// Prepares this socket for sending to a multicast group.
//
// Selects the local interface (from m_LocalAddr) as the outgoing interface for
// multicast datagrams (IP_MULTICAST_IF), then records the group address/port
// in m_McastAddr and marks the socket as multicast-enabled so Multicast() can
// send to it. Note that this does not itself add group membership for
// receiving; SetReceiveMulticast() does that.
//
// dwMcastAddress - group address, handed to MakeSockAddrIN.
// wPort          - destination port, handed to MakeSockAddrIN.
//
// Returns TRUE on success; on failure traces the last error, leaves the
// multicast state untouched and returns FALSE.
BOOL CMcastSocket::JoinMulticastGroup(DWORD dwMcastAddress, WORD wPort)
{
	NETADDR localNet;
	ConvertSockAddrToNetAddr(&m_LocalAddr, localNet);

	if (doSetSockOpt(m_hSocket, IPPROTO_IP, IP_MULTICAST_IF,
			reinterpret_cast<LPSTR>(&localNet.dwAddress), sizeof(localNet.dwAddress)))
	{
		// Outgoing interface accepted: remember where Multicast() should send.
		ZeroMemory(&m_McastAddr, sizeof(m_McastAddr));
		MakeSockAddrIN(reinterpret_cast<LPSOCKADDR_IN>(&m_McastAddr), dwMcastAddress, wPort);
		m_fMcast = TRUE;
		return TRUE;
	}

	TRACE1("***** ERROR: Multicast JoinGroup(%d) *****\n", GetLastError());
	return FALSE;
}

// Drops this socket's membership in the multicast group recorded in
// m_McastAddr on the local interface taken from m_LocalAddr
// (IP_DROP_MEMBERSHIP), then clears the multicast state.
//
// IP_DROP_MEMBERSHIP takes an IP_MREQ (group address + interface address),
// the same structure SetReceiveMulticast() passes to IP_ADD_MEMBERSHIP. The
// previous implementation passed only a 4-byte interface address with
// optlen == sizeof(DWORD). That is too small for an IP_MREQ, so the call was
// rejected, leaving never succeeded, and m_fMcast/m_McastAddr were never reset.
//
// Returns TRUE on success; on failure traces the last error, leaves the
// multicast state untouched and returns FALSE.
BOOL CMcastSocket::LeaveMulticastGroup()
{
	// Interface: taken from m_LocalAddr exactly as SetReceiveMulticast() does,
	// so the (group, interface) pair matches the membership that was added.
	SOCKADDR_IN siLocal;
	CopyMemory(&siLocal, &m_LocalAddr, sizeof(siLocal));

	// Group: m_McastAddr was filled as a SOCKADDR_IN by JoinMulticastGroup().
	SOCKADDR_IN siGroup;
	CopyMemory(&siGroup, &m_McastAddr, sizeof(siGroup));

	IP_MREQ im;
	ZeroMemory(&im, sizeof(im));
	// NOTE(review): assumes MakeSockAddrIN stores the group address in
	// sin_addr in the same byte order that callers pass to
	// SetReceiveMulticast() — confirm.
	im.imr_multiaddr.s_addr = siGroup.sin_addr.s_addr;
	im.imr_interface.s_addr = siLocal.sin_addr.s_addr;

	if (!doSetSockOpt(m_hSocket, IPPROTO_IP, IP_DROP_MEMBERSHIP, reinterpret_cast<LPSTR>(&im), sizeof(im)))
	{
		TRACE1("***** ERROR: Multicast LeaveGroup(%d) *****\n", GetLastError());
		return FALSE;
	}

	ZeroMemory(&m_McastAddr, sizeof(m_McastAddr));
	m_fMcast = FALSE;

	return TRUE;
}

// Sets the time-to-live (hop limit) applied to outgoing multicast datagrams
// on this socket (IP_MULTICAST_TTL, passed as a DWORD).
//
// dwTTL - TTL value handed directly to the socket option.
//
// Returns TRUE on success; on failure traces the last error and returns FALSE.
BOOL CMcastSocket::SetTTL(DWORD dwTTL)
{
	if (doSetSockOpt(m_hSocket, IPPROTO_IP, IP_MULTICAST_TTL,
			reinterpret_cast<LPSTR>(&dwTTL), sizeof(dwTTL)))
	{
		return TRUE;
	}

	TRACE1("***** ERROR: Multicast TTL(%d) *****\n", GetLastError());
	return FALSE;
}

// Sends a datagram to the multicast destination recorded by
// JoinMulticastGroup().
//
// data - payload forwarded to SendTo().
//
// Returns FALSE (with a trace) if no multicast destination has been set up,
// or if SendTo() reports failure; TRUE otherwise.
BOOL CMcastSocket::Multicast(const CByteArray &data)
{
	if (m_fMcast)
	{
		if (SendTo(data, &m_McastAddr, sizeof(m_McastAddr)))
			return TRUE;

		TRACE1("***** ERROR: Multicast(%d) *****\n", GetLastError());
		return FALSE;
	}

	TRACE("***** ERROR: Multicast(Socket isn't Multicast socket) *****\n");
	return FALSE;
}

//////////////////////////////////////////////////////////////////////////