package net.osdn.util.ssdp.server;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.DatagramPacket;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.InterfaceAddress;
import java.net.MulticastSocket;
import java.net.NetworkInterface;
import java.net.SocketAddress;
import java.net.SocketException;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import net.osdn.util.ssdp.Device;

/**
 * Advertises a single UPnP root device on the local network.
 *
 * <p>Two services run side by side:
 * <ul>
 *   <li>the inherited {@link NanoHTTPD} HTTP server on TCP port {@code 1980}, whose
 *       {@link #serve} returns a device description document built from the {@link Device}; and</li>
 *   <li>a background "SSDP Listener" thread bound to UDP port {@code 1900} that answers
 *       {@code M-SEARCH} requests for {@code upnp:rootdevice} or {@code ssdp:all} with the
 *       URL of that description.</li>
 * </ul>
 */
public class SsdpServer extends NanoHTTPD {
	/** TCP port of the HTTP server that serves the device description. */
	private static final int PORT = 1980;
	private static final Charset UTF_8 = Charset.forName("UTF-8");

	/** IPv4 SSDP multicast group. */
	private static final String SSDP_IPV4_GROUP = "239.255.255.250";
	/** IPv6 link-local SSDP multicast group. */
	private static final String SSDP_IPV6_GROUP = "FF02::C";
	/** UDP port the SSDP listener binds to. */
	private static final int SSDP_PORT = 1900;
	/** How long the listener waits before re-opening its socket after an unexpected I/O error. */
	private static final long SSDP_RETRY_DELAY_MILLIS = 1000L;
	/** Lifetime (seconds) advertised in the {@code CACHE-CONTROL} header of search responses. */
	private static final int SSDP_MAX_AGE_SECONDS = 1800;
	/** Substituted for device properties that are {@code null} or empty. */
	private static final String UNKNOWN = "Unknown";

	private final Device device;
	private final String server;
	/** Raised by {@link #start} before the listener thread starts; cleared by {@link #stop}. */
	private volatile boolean isRunning;
	private Thread ssdpThread;
	/** Socket of the current listening session; read by {@link #stop} to unblock the listener. */
	private volatile MulticastSocket ssdpSocket;

	/**
	 * Creates a server for {@code device} that identifies itself with the default
	 * {@code Server} header value {@code "UPnP/1.0 UPnP-Device-Host/1.0"}.
	 *
	 * @param device the device to describe and advertise
	 */
	public SsdpServer(Device device) {
		this(device, null);
	}

	/**
	 * Creates a server for {@code device}.
	 *
	 * @param device the device to describe and advertise
	 * @param server value of the {@code Server} header in search responses, or {@code null}
	 *               for {@code "UPnP/1.0 UPnP-Device-Host/1.0"}
	 */
	public SsdpServer(Device device, String server) {
		super(PORT);

		this.device = device;
		this.server = (server != null) ? server : "UPnP/1.0 UPnP-Device-Host/1.0";
	}

	/**
	 * Starts the HTTP server, then the daemon "SSDP Listener" thread.
	 *
	 * <p>The listener repeatedly calls {@link #ssdpRun()}. After an unexpected I/O error it
	 * waits {@value #SSDP_RETRY_DELAY_MILLIS} ms and retries until {@link #stop()} is called.
	 */
	@Override
	public void start(int timeout, boolean daemon) throws IOException {
		//Start NanoHTTPD
		super.start(timeout, daemon);

		//Start SSDP. The flag is raised here, before the thread exists, so that a stop()
		//issued immediately after start() always sees it and shuts the listener down.
		isRunning = true;
		ssdpThread = new Thread(new Runnable() {
			@Override
			public void run() {
				while(isRunning) {
					try {
						ssdpRun();
					} catch (IOException e) {
						if(!isRunning) {
							// stop() closed the socket; the resulting exception is expected.
							break;
						}
						e.printStackTrace();
						// Back off so that a persistent failure (e.g. the socket cannot be
						// created) does not turn into a busy loop.
						try {
							Thread.sleep(SSDP_RETRY_DELAY_MILLIS);
						} catch (InterruptedException ie) {
							Thread.currentThread().interrupt();
							break;
						}
					}
				}
			}
		});
		ssdpThread.setDaemon(true);
		ssdpThread.setName("SSDP Listener");
		ssdpThread.start();
	}

	/**
	 * Stops the SSDP listener, closing its socket and waiting for its thread to exit, and then
	 * stops the HTTP server.
	 */
	@Override
	public void stop() {
		//Stop SSDP
		if(isRunning) {
			isRunning = false;
			MulticastSocket socket = ssdpSocket;
			if(socket != null) {
				// Closing the socket makes a pending receive() in the listener throw.
				socket.close();
				ssdpSocket = null;
			}
			Thread thread = ssdpThread;
			if(thread != null) {
				thread.interrupt();
				try {
					thread.join();
				} catch(InterruptedException e) {
					// Preserve the caller's interrupt status instead of swallowing it.
					Thread.currentThread().interrupt();
				}
				ssdpThread = null;
			}
		}

		// Stop NanoHTTPD
		super.stop();
	}

	/**
	 * Returns the UPnP device description document for every request, regardless of path.
	 *
	 * <p>{@code null} or empty device properties are reported as {@code "Unknown"}. All values
	 * are XML-escaped so that characters such as {@code &} or {@code <} keep the document
	 * well-formed.
	 *
	 * @param session the HTTP request (not inspected)
	 * @return a {@code 200 OK} {@code text/xml} response
	 */
	@Override
	public Response serve(IHTTPSession session) {
		String xml =
				"<?xml version=\"1.0\"?>\r\n" +
				"<root xmlns=\"urn:schemas-upnp-org:device-1-0\">\r\n" +
				"\t<specVersion>\r\n" +
				"\t\t<major>1</major>\r\n" +
				"\t\t<minor>0</minor>\r\n" +
				"\t</specVersion>\r\n" +
				"\t<device>\r\n" +
				"\t\t<UDN>" + xmlValue(device.getUdn()) + "</UDN>\r\n" +
				"\t\t<friendlyName>" + xmlValue(device.getFriendlyName()) + "</friendlyName>\r\n" +
				"\t\t<deviceType>" + xmlValue(device.getDeviceType()) + "</deviceType>\r\n" +
				"\t\t<manufacturer>" + xmlValue(device.getManufacturer()) + "</manufacturer>\r\n" +
				"\t\t<modelName>" + xmlValue(device.getModelName()) + "</modelName>\r\n" +
				"\t\t<modelNumber>" + xmlValue(device.getModelNumber()) + "</modelNumber>\r\n" +
				"\t\t<serialNumber>" + xmlValue(device.getSerialNumber()) + "</serialNumber>\r\n" +
				"\t\t<iconList>\r\n" +
				"\t\t</iconList>\r\n" +
				"\t\t<serviceList>\r\n" +
				"\t\t</serviceList>\r\n" +
				"\t</device>\r\n" +
				"</root>\r\n";

		return newFixedLengthResponse(Response.Status.OK, "text/xml", xml);
	}

	/** Returns {@code value}, or {@code "Unknown"} when it is {@code null} or empty. */
	private static String orUnknown(String value) {
		return (value == null || value.length() == 0) ? UNKNOWN : value;
	}

	/** Returns {@code orUnknown(value)} escaped for use as XML character data. */
	private static String xmlValue(String value) {
		return escapeXml(orUnknown(value));
	}

	/** Replaces the five XML special characters in {@code s} with their predefined entities. */
	private static String escapeXml(String s) {
		StringBuilder sb = new StringBuilder(s.length());
		for(int i = 0; i < s.length(); i++) {
			char c = s.charAt(i);
			switch(c) {
			case '&':  sb.append("&amp;");  break;
			case '<':  sb.append("&lt;");   break;
			case '>':  sb.append("&gt;");   break;
			case '"':  sb.append("&quot;"); break;
			case '\'': sb.append("&apos;"); break;
			default:   sb.append(c);
			}
		}
		return sb.toString();
	}

	/**
	 * Runs one SSDP listening session.
	 *
	 * <p>Opens a multicast socket on UDP port 1900 and joins the SSDP groups on every network
	 * interface. It then answers matching {@code M-SEARCH} requests until {@link #stop()}
	 * closes the socket or clears the running flag. The socket is always closed before this
	 * method returns or throws.
	 *
	 * @throws IOException if the socket cannot be created, or receiving/sending fails
	 *                     (including the expected failure after {@link #stop()} closes it)
	 */
	protected void ssdpRun() throws IOException {
		SocketAddress ipv4Group = resolveGroup(SSDP_IPV4_GROUP);
		SocketAddress ipv6Group = resolveGroup(SSDP_IPV6_GROUP);

		MulticastSocket socket = new MulticastSocket(SSDP_PORT);
		ssdpSocket = socket;
		try {
			// stop() may have run between the caller's isRunning check and the assignment
			// above, without seeing this socket. Re-check so we never block in receive() on a
			// socket nobody will close.
			if(!isRunning) {
				return;
			}
			joinGroups(socket, ipv4Group, ipv6Group);

			byte[] buf = new byte[1024];
			DatagramPacket request = new DatagramPacket(buf, buf.length);
			while(isRunning) {
				// Restore the full buffer size; receive() shrinks the length to the last datagram.
				request.setLength(buf.length);
				socket.receive(request);
				DatagramPacket response = buildSearchResponse(request);
				if(response != null) {
					socket.send(response);
				}
			}
		} finally {
			socket.close();
		}
	}

	/** Resolves an SSDP multicast group literal to a socket address, or returns {@code null} on failure. */
	private static SocketAddress resolveGroup(String group) {
		try {
			return new InetSocketAddress(InetAddress.getByName(group), SSDP_PORT);
		} catch(IOException e) {
			e.printStackTrace();
			return null;
		}
	}

	/**
	 * Joins the IPv4 group on every interface that has an IPv4 address, and the IPv6 group on
	 * every interface that has an IPv6 address. Each group is joined at most once per interface.
	 * A failed join is reported and skipped so the remaining interfaces still join.
	 */
	private static void joinGroups(MulticastSocket socket, SocketAddress ipv4Group, SocketAddress ipv6Group) throws SocketException {
		Enumeration<NetworkInterface> interfaces = NetworkInterface.getNetworkInterfaces();
		if(interfaces == null) {
			return;
		}
		while(interfaces.hasMoreElements()) {
			NetworkInterface netIf = interfaces.nextElement();
			boolean hasIpv4 = false;
			boolean hasIpv6 = false;
			for(InterfaceAddress ifAddr : netIf.getInterfaceAddresses()) {
				InetAddress addr = ifAddr != null ? ifAddr.getAddress() : null;
				if(addr instanceof Inet4Address) {
					hasIpv4 = true;
				} else if(addr instanceof Inet6Address) {
					hasIpv6 = true;
				}
			}
			// An interface often carries several addresses of one family (e.g. link-local and
			// global IPv6). The OS typically rejects a duplicate membership, so join per
			// interface rather than per address.
			if(hasIpv4 && ipv4Group != null) {
				joinGroup(socket, ipv4Group, netIf);
			}
			if(hasIpv6 && ipv6Group != null) {
				joinGroup(socket, ipv6Group, netIf);
			}
		}
	}

	/** Joins {@code group} on {@code netIf}, reporting (not propagating) any failure. */
	private static void joinGroup(MulticastSocket socket, SocketAddress group, NetworkInterface netIf) {
		try {
			socket.joinGroup(group, netIf);
		} catch(IOException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Builds the unicast reply to an SSDP search.
	 *
	 * <p>Returns {@code null} (send nothing) unless all of the following hold:
	 * <ul>
	 *   <li>{@code request} is an {@code M-SEARCH} with {@code MAN: "ssdp:discover"};</li>
	 *   <li>its {@code ST} is {@code upnp:rootdevice} or {@code ssdp:all};</li>
	 *   <li>it carries a {@code HOST} header;</li>
	 *   <li>a local address on the requester's subnet is found.</li>
	 * </ul>
	 */
	private DatagramPacket buildSearchResponse(DatagramPacket request) throws IOException {
		Map<String, String> searchHeaders = parseSearchHeaders(request);
		if(searchHeaders == null) {
			return null;
		}
		String man = searchHeaders.get("MAN");
		if(man == null || !man.equalsIgnoreCase("\"ssdp:discover\"")) {
			return null;
		}
		String st = searchHeaders.get("ST");
		if(st == null || (!st.equalsIgnoreCase("upnp:rootdevice") && !st.equalsIgnoreCase("ssdp:all"))) {
			return null;
		}
		if(searchHeaders.get("HOST") == null) {
			return null;
		}

		String location = toLocation(getLocalAddress(request.getAddress()));
		if(location == null) {
			return null;
		}

		// Same UDN fallback as serve(), so the USN matches the description document.
		String usn = orUnknown(device.getUdn()) + "::upnp:rootdevice";
		String response = "HTTP/1.1 200 OK\r\n"
				+ "CACHE-CONTROL:max-age=" + SSDP_MAX_AGE_SECONDS + "\r\n"
				+ "EXT:\r\n"
				+ "ST:upnp:rootdevice\r\n"
				+ "USN:" + usn + "\r\n"
				+ "Location:" + location + "\r\n"
				+ "OPT:\"http://schemas.upnp.org/upnp/1/0/\"; ns=01\r\n"
				+ "Server:" + server + "\r\n"
				+ "\r\n";
		byte[] data = response.getBytes(UTF_8);
		return new DatagramPacket(data, data.length, request.getAddress(), request.getPort());
	}

	/**
	 * Returns the description URL on {@code localAddr}, or {@code null} if it is neither IPv4
	 * nor IPv6. Any IPv6 zone suffix ({@code %...}) is stripped.
	 */
	private static String toLocation(InetAddress localAddr) {
		if(localAddr instanceof Inet4Address) {
			return "http://" + localAddr.getHostAddress() + ":" + PORT + "/";
		}
		if(localAddr instanceof Inet6Address) {
			String a = localAddr.getHostAddress();
			int j = a.lastIndexOf('%');
			if(j > 0) {
				a = a.substring(0, j);
			}
			return "http://[" + a + "]:" + PORT + "/";
		}
		return null;
	}

	/**
	 * Parses an SSDP datagram as UTF-8 text.
	 *
	 * @param packet the received datagram
	 * @return a map of header names (upper-cased) to trimmed values, or {@code null} if the
	 *         packet is empty or its first line does not start with {@code M-SEARCH}
	 *         (case-insensitive)
	 * @throws IOException if reading the packet contents fails
	 */
	protected Map<String, String> parseSearchHeaders(DatagramPacket packet) throws IOException {
		BufferedReader r = new BufferedReader(new InputStreamReader(
				new ByteArrayInputStream(packet.getData(), packet.getOffset(), packet.getLength()), UTF_8));
		try {
			String requestLine = r.readLine();
			if(requestLine == null || !requestLine.toUpperCase(Locale.ENGLISH).startsWith("M-SEARCH")) {
				return null;
			}
			Map<String, String> headers = new HashMap<String, String>();
			String line;
			while((line = r.readLine()) != null) {
				int i = line.indexOf(':');
				if(i > 0) {
					String key = line.substring(0, i).trim();
					String value = line.substring(i + 1).trim();
					headers.put(key.toUpperCase(Locale.ENGLISH), value);
				}
			}
			return headers;
		} finally {
			r.close();
		}
	}

	/**
	 * Picks the local interface address whose subnet contains {@code remoteAddr}.
	 *
	 * <p>Only addresses of the same family (same address length) are considered. When several
	 * match, the one with the longest network prefix wins. An out-of-range prefix length is
	 * treated as a full-length (exact-match) prefix.
	 *
	 * @param remoteAddr the requester's address; may be {@code null}
	 * @return the best matching local address, or {@code null} if none matches
	 * @throws SocketException if the network interfaces cannot be enumerated
	 */
	private InetAddress getLocalAddress(InetAddress remoteAddr) throws SocketException {
		if(remoteAddr == null) {
			return null;
		}
		byte[] remote = remoteAddr.getAddress();
		int maxBits = remote.length * 8;

		InetAddress best = null;
		int bestPrefix = -1;
		for(InterfaceAddress ifAddr : getInterfaceAddresses()) {
			InetAddress localAddr = ifAddr != null ? ifAddr.getAddress() : null;
			if(localAddr == null) {
				continue;
			}
			byte[] local = localAddr.getAddress();
			if(local.length != remote.length) {
				continue; // different address family
			}
			int prefix = ifAddr.getNetworkPrefixLength();
			if(prefix < 0 || prefix > maxBits) {
				prefix = maxBits;
			}
			if(prefix > bestPrefix && commonPrefixLength(remote, local) >= prefix) {
				best = localAddr;
				bestPrefix = prefix;
			}
		}
		return best;
	}

	/** Returns the number of leading bits that {@code a} and {@code b} (same length) share. */
	private static int commonPrefixLength(byte[] a, byte[] b) {
		int bits = 0;
		for(int i = 0; i < a.length; i++) {
			int diff = (a[i] ^ b[i]) & 0xFF;
			if(diff != 0) {
				// Leading zeros of the 8-bit XOR = number of equal high bits in this byte.
				return bits + Integer.numberOfLeadingZeros(diff) - 24;
			}
			bits += 8;
		}
		return bits;
	}

	/**
	 * Returns the {@link InterfaceAddress}es of all network interfaces on this machine.
	 *
	 * <p>Returns an empty list if {@link NetworkInterface#getNetworkInterfaces()} returns
	 * {@code null}, which some platforms do.
	 *
	 * @throws SocketException if the interfaces cannot be enumerated
	 */
	public static List<InterfaceAddress> getInterfaceAddresses() throws SocketException {
		List<InterfaceAddress> list = new ArrayList<InterfaceAddress>();

		Enumeration<NetworkInterface> interfaces = NetworkInterface.getNetworkInterfaces();
		if(interfaces != null) {
			while(interfaces.hasMoreElements()) {
				list.addAll(interfaces.nextElement().getInterfaceAddresses());
			}
		}

		return list;
	}
}
