package net.osdn.util.ssdp.client;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.ConnectException;
import java.net.DatagramPacket;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InterfaceAddress;
import java.net.MulticastSocket;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.net.URL;
import java.net.URLConnection;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import net.osdn.util.ssdp.Device;

/**
 * Minimal SSDP (UPnP discovery) client.
 *
 * <p>{@link #discover(int, DiscoveryCallback)} starts two background threads: a sender that
 * multicasts an {@code M-SEARCH} request for {@code upnp:rootdevice} to
 * {@code 239.255.255.250:1900} (IPv4) and {@code [FF02::C]:1900} (IPv6) on every network
 * interface that is up and neither loopback nor virtual, and a receiver that collects the
 * responses on the same socket until the timeout elapses. For each {@code 200} response the
 * description document at its {@code Location} URL is fetched and parsed on a worker pool.
 * {@link #get()} blocks until the discovery has finished and returns the devices found;
 * {@link #cancel()} aborts a running discovery.
 *
 * <p>At most one discovery runs at a time per instance; {@code discover} is ignored while one is
 * in progress. NOTE(review): apart from {@code cancel()} racing a running discovery, the public
 * methods appear to be intended for use from a single caller thread — confirm with callers.
 */
public class SsdpClient {
	
	private static final Charset UTF_8 = Charset.forName("UTF-8");

	/** Default and minimum receive window, in milliseconds. */
	private static final int RECV_TIMEOUT = 5000;

	/** Delay between repeated M-SEARCH transmissions, in milliseconds. */
	private static final int SEND_INTERVAL = 1200;

	/**
	 * Connect and read timeout, in milliseconds, for fetching a device description. Without it
	 * a single unresponsive device would block {@link #get()} indefinitely.
	 */
	private static final int HTTP_TIMEOUT = RECV_TIMEOUT;

	private static final String IPV4_MULTICAST_ADDRESS = "239.255.255.250";
	private static final String IPV6_MULTICAST_ADDRESS = "FF02::C";
	private static final int SSDP_PORT = 1900;

	// State shared between the caller and the sender/receiver threads is volatile so that
	// changes to flags, socket and thread handles are visible across threads.
	private volatile int timeout;
	private volatile DiscoveryCallback callback;
	private volatile MulticastSocket socket;
	private volatile Thread sendThread;
	private volatile Thread recvThread;
	private volatile boolean isSendRunning;
	private volatile boolean isRecvRunning;
	private volatile boolean isCancelRequested;

	/** Devices already reported in the current discovery; guarded by its own monitor. */
	private final Set<Device> foundDevices = new HashSet<Device>();

	/** Pool on which device descriptions are fetched; (re)created by {@link #recv()}. */
	private ExecutorService executor;

	/** Pending description fetches; filled by the receiver thread, drained by get() after joining it. */
	private final List<Future<Device>> futures = new ArrayList<Future<Device>>();

	/** Result cached by get(); reset by discover(). */
	private List<Device> list;
	
	public SsdpClient() {
	}
	

	/**
	 * Cancels a running discovery.
	 *
	 * <p>Closes the socket and interrupts the sender and receiver threads. Does nothing when no
	 * discovery is running or a cancellation has already been requested. Descriptions that were
	 * already being fetched are not aborted and can still be returned by {@link #get()}.
	 */
	public void cancel() {
		if(isCancelRequested || !(isSendRunning || isRecvRunning)) {
			return;
		}
		isCancelRequested = true;

		closeSocket();

		// Read each field once: get() may set it to null concurrently.
		Thread sender = sendThread;
		if(sender != null) {
			sender.interrupt();
		}
		Thread receiver = recvThread;
		if(receiver != null) {
			receiver.interrupt();
		}
	}
	
	/**
	 * Waits for the discovery to finish and returns the devices found.
	 *
	 * <p>The result is cached until the next {@code discover} call. Descriptions that could not
	 * be fetched are skipped. If the calling thread is interrupted while waiting, the method still
	 * completes and the thread's interrupt status is restored before it returns.
	 *
	 * @return the devices found; empty if none were found or no discovery was started
	 */
	public List<Device> get() {
		if(list != null) {
			return list;
		}

		boolean interrupted = awaitTermination(sendThread);
		sendThread = null;

		Thread receiver = recvThread;
		interrupted |= awaitTermination(receiver);

		closeSocket();

		if(receiver != null && receiver.isAlive()) {
			// The wait above was interrupted. Closing the socket makes the receiver's blocking
			// receive() fail, so it exits promptly; wait for it so that futures is no longer
			// being modified while it is iterated below.
			interrupted |= awaitTermination(receiver);
		}
		recvThread = null;

		List<Device> result = new ArrayList<Device>();
		for(Future<Device> future : futures) {
			try {
				Device device = future.get();
				if(device != null) {
					result.add(device);
				}
			} catch(InterruptedException e) {
				interrupted = true;
			} catch(ExecutionException e) {
				Throwable cause = e.getCause();
				// Unreachable or unresponsive devices are expected on a real network; stay quiet.
				if(!(cause instanceof ConnectException) && !(cause instanceof SocketTimeoutException)) {
					e.printStackTrace();
				}
			}
		}

		// No further tasks will be submitted for this discovery; let the pool threads terminate
		// (tasks still running are allowed to finish). recv() creates a fresh pool next time.
		if(executor != null) {
			executor.shutdown();
		}

		if(interrupted) {
			Thread.currentThread().interrupt();
		}

		list = result;
		return list;
	}
	
	/**
	 * Starts a discovery with the default timeout and no callback.
	 *
	 * @throws IOException if the multicast socket cannot be created
	 */
	public void discover() throws IOException {
		discover(RECV_TIMEOUT, null);
	}
	
	/**
	 * Starts a discovery with the given timeout and no callback.
	 *
	 * @param timeout receive window in milliseconds (raised to the default if smaller)
	 * @throws IOException if the multicast socket cannot be created
	 */
	public void discover(int timeout) throws IOException {
		discover(timeout, null);
	}
	
	/**
	 * Starts a discovery with the default timeout.
	 *
	 * @param callback notified once per newly discovered device; may be {@code null}
	 * @throws IOException if the multicast socket cannot be created
	 */
	public void discover(DiscoveryCallback callback) throws IOException {
		discover(RECV_TIMEOUT, callback);
	}
	
	/**
	 * Starts an asynchronous discovery.
	 *
	 * <p>Returns immediately; use {@link #get()} to wait for the result or {@link #cancel()} to
	 * abort. Ignored if a discovery is already running on this instance.
	 *
	 * @param timeout  how long to collect responses, in milliseconds; values below the default of
	 *                 5000 ms are raised to it
	 * @param callback notified once per newly discovered device, on a worker-pool thread; may be
	 *                 {@code null}
	 * @throws IOException if the multicast socket cannot be created
	 */
	public void discover(int timeout, DiscoveryCallback callback) throws IOException {
		if(isSendRunning || isRecvRunning) {
			return;
		}
		isCancelRequested = false;

		this.timeout = Math.max(timeout, RECV_TIMEOUT);
		this.callback = callback;
		this.list = null;
		// Each discovery reports its own devices; otherwise devices seen in an earlier
		// discovery would be filtered out of every later result.
		synchronized(foundDevices) {
			foundDevices.clear();
		}

		socket = new MulticastSocket(0);

		Thread sender = new Thread() {
			@Override
			public void run() {
				try {
					send();
				} catch(Exception e) {
					if(!isCancelRequested) {
						throw new RuntimeException(e);
					}
				} finally {
					// Must run even when send() fails, or discover() would refuse to run again.
					isSendRunning = false;
				}
			}
		};
		Thread receiver = new Thread() {
			@Override
			public void run() {
				try {
					recv();
				} catch(Exception e) {
					if(!isCancelRequested) {
						throw new RuntimeException(e);
					}
				} finally {
					// Must run even when recv() fails, or discover() would refuse to run again.
					isRecvRunning = false;
				}
			}
		};
		sendThread = sender;
		recvThread = receiver;

		// Mark both as running before starting them so that a thread finishing quickly cannot
		// have its "finished" state overwritten afterwards.
		isSendRunning = true;
		isRecvRunning = true;
		sender.start();
		receiver.start();
	}
	
	/**
	 * Repeatedly multicasts the M-SEARCH request on every eligible interface address.
	 *
	 * <p>The request is sent {@code (timeout - 3000) / SEND_INTERVAL} times (at least once),
	 * {@value #SEND_INTERVAL} ms apart, so transmissions stop roughly 3 s or more before the
	 * receive window ends. A failure on one interface address is reported and the remaining
	 * addresses are still tried. Returns early once the discovery is cancelled.
	 *
	 * @throws InterruptedException if interrupted while waiting between transmissions (for
	 *                              example by {@link #cancel()})
	 */
	protected void send() throws InterruptedException {
		MulticastSocket s = socket;
		if(s == null) {
			// Socket already closed, e.g. by cancel().
			return;
		}

		int count = Math.max(1, (timeout - 3000) / SEND_INTERVAL);

		DatagramPacket ipv4packet = createSearchPacket(IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_ADDRESS);
		DatagramPacket ipv6packet = createSearchPacket(IPV6_MULTICAST_ADDRESS, "[" + IPV6_MULTICAST_ADDRESS + "]");

		for(int i = 0; i < count; i++) {
			if(isCancelRequested || s.isClosed()) {
				return;
			}
			sendOnAllInterfaces(s, ipv4packet, ipv6packet);
			if(i + 1 < count) {
				Thread.sleep(SEND_INTERVAL);
			}
		}
	}

	/**
	 * Builds the M-SEARCH datagram addressed to {@code multicastAddress}:{@value #SSDP_PORT}.
	 *
	 * @param multicastAddress literal multicast group address
	 * @param hostHeader       host part of the {@code Host} header (IPv6 literals in brackets)
	 * @return the packet, or {@code null} if the address could not be resolved
	 */
	private static DatagramPacket createSearchPacket(String multicastAddress, String hostHeader) {
		String query = "M-SEARCH * HTTP/1.1\r\n"
						   + "Host:" + hostHeader + ":" + SSDP_PORT + "\r\n"
						   + "ST:upnp:rootdevice\r\n"
						   + "Man:\"ssdp:discover\"\r\n"
						   + "MX:2\r\n"
						   + "\r\n";
		try {
			byte[] buf = query.getBytes(UTF_8);
			InetAddress group = InetAddress.getByName(multicastAddress);
			return new DatagramPacket(buf, 0, buf.length, group, SSDP_PORT);
		} catch(IOException e) {
			e.printStackTrace();
			return null;
		}
	}

	/**
	 * Sends the matching packet once through every address of every interface that is up and
	 * neither loopback nor virtual. IPv4 addresses get {@code ipv4packet}, IPv6 addresses get
	 * {@code ipv6packet}; a {@code null} packet disables that family.
	 */
	private void sendOnAllInterfaces(MulticastSocket s, DatagramPacket ipv4packet, DatagramPacket ipv6packet) {
		Enumeration<NetworkInterface> interfaces;
		try {
			interfaces = NetworkInterface.getNetworkInterfaces();
		} catch(SocketException e) {
			e.printStackTrace();
			return;
		}
		if(interfaces == null) {
			return;
		}

		while(interfaces.hasMoreElements()) {
			NetworkInterface netIf = interfaces.nextElement();
			try {
				if(netIf.isLoopback() || netIf.isVirtual() || !netIf.isUp()) {
					continue;
				}
			} catch(SocketException e) {
				e.printStackTrace();
				continue;
			}
			for(InterfaceAddress ifAddr : netIf.getInterfaceAddresses()) {
				InetAddress addr = ifAddr != null ? ifAddr.getAddress() : null;
				DatagramPacket packet;
				String family;
				if(addr instanceof Inet4Address) {
					packet = ipv4packet;
					family = "IPv4";
				} else if(addr instanceof Inet6Address) {
					packet = ipv6packet;
					family = "IPv6";
				} else {
					continue;
				}
				if(packet == null) {
					continue;
				}
				try {
					System.out.println("# SSDP: send: " + family + ": netIf=" + netIf + ", ifAddr=" + ifAddr + ", addr=" + addr + ", isLoopback=" + netIf.isLoopback() + ", isVirtual=" + netIf.isVirtual() + ", netIf.isUp=" + netIf.isUp());
					s.setInterface(addr);
					s.send(packet);
				} catch(IOException e) {
					if(isCancelRequested || s.isClosed()) {
						return;
					}
					// Report and carry on with the remaining addresses.
					e.printStackTrace();
				}
			}
		}
	}
	
	/**
	 * Collects M-SEARCH responses until the timeout elapses and schedules a description fetch for
	 * each usable one.
	 *
	 * <p>Only datagrams that start with {@code 'H'} (an HTTP status line), have status
	 * {@code 200} and carry a non-null {@code Location} header are used. Each fetch runs on the
	 * executor and its {@link Future} is recorded for {@link #get()} to collect.
	 *
	 * @throws IOException if receiving fails, e.g. because the socket was closed by
	 *                     {@link #cancel()}
	 */
	protected void recv() throws InterruptedException, IOException {
		if(executor == null || executor.isShutdown()) {
			executor = Executors.newCachedThreadPool();
		}
		futures.clear();

		MulticastSocket s = socket;
		if(s == null) {
			// Socket already closed, e.g. by cancel().
			return;
		}

		List<InetAddress> localAddresses = null;
		try {
			localAddresses = getLocalAddresses();
		} catch(SocketException e) {
			e.printStackTrace();
		}

		long startTime = System.currentTimeMillis();
		byte[] buf = new byte[1024];
		DatagramPacket packet = new DatagramPacket(buf, buf.length);
		for(;;) {
			int remaining = timeout - (int)(System.currentTimeMillis() - startTime);
			if(remaining <= 0) {
				break;
			}
			try {
				packet.setLength(buf.length);
				s.setSoTimeout(remaining);
				s.receive(packet);
			} catch(SocketTimeoutException e) {
				break;
			}
			if(packet.getLength() == 0 || packet.getData()[0] != 'H') {
				continue;
			}
			SearchResponse response = new SearchResponse(packet);
			if(response.getStatus() != 200) {
				continue;
			}
			String location = response.getHeader("Location");
			if(location == null) {
				continue;
			}
			InetAddress deviceAddr = packet.getAddress();
			boolean isSelf = localAddresses != null && localAddresses.contains(deviceAddr);
			futures.add(submitFetch(location, deviceAddr, isSelf, callback));
		}
	}

	/** Submits a task that fetches and registers the device description at {@code location}. */
	private Future<Device> submitFetch(final String location, final InetAddress deviceAddr,
			final boolean isSelf, final DiscoveryCallback dc) {
		return executor.submit(new Callable<Device>() {
			@Override
			public Device call() throws Exception {
				return fetchDevice(location, deviceAddr, isSelf, dc);
			}
		});
	}

	/**
	 * Downloads and parses the device description at {@code location}.
	 *
	 * @return the device, or {@code null} if it was already reported in this discovery
	 * @throws Exception if the description cannot be fetched or parsed
	 */
	private Device fetchDevice(String location, InetAddress deviceAddr, boolean isSelf,
			DiscoveryCallback dc) throws Exception {
		URL url = new URL(location);
		URLConnection connection = url.openConnection();
		connection.setConnectTimeout(HTTP_TIMEOUT);
		connection.setReadTimeout(HTTP_TIMEOUT);
		InputStream in = null;
		BufferedReader r = null;
		try {
			in = connection.getInputStream();
			r = new BufferedReader(new InputStreamReader(in, UTF_8));
			Device device = Device.parse(r);
			if(device == null) {
				return null;
			}
			device.setAddress(deviceAddr);
			device.setSelf(isSelf);
			// Several responses for the same device are normal (repeated M-SEARCH, several
			// interfaces); check-and-add atomically so only one task reports it.
			synchronized(foundDevices) {
				if(!foundDevices.add(device)) {
					return null;
				}
			}
			if(dc != null) {
				dc.discovered(device);
			}
			return device;
		} finally {
			if(r != null) {
				try { r.close(); } catch(IOException ignored) {}
			}
			if(in != null) {
				try { in.close(); } catch(IOException ignored) {}
			}
		}
	}

	/** Closes and clears the socket, if any; close failures are only reported. */
	private void closeSocket() {
		MulticastSocket s = socket;
		socket = null;
		if(s != null) {
			try {
				s.close();
			} catch(Exception e) {
				e.printStackTrace();
			}
		}
	}

	/**
	 * Waits for {@code thread} to terminate.
	 *
	 * @return {@code true} if the wait was interrupted, {@code false} otherwise (including when
	 *         {@code thread} is {@code null})
	 */
	private static boolean awaitTermination(Thread thread) {
		if(thread == null) {
			return false;
		}
		try {
			thread.join();
			return false;
		} catch(InterruptedException e) {
			return true;
		}
	}
	
	/** Hook for subclasses; does nothing and is not called from this class. */
	protected void shutdown() {
		
	}
	
	
	/**
	 * Returns every address assigned to any network interface of this host, loopback included.
	 *
	 * @return local addresses (possibly empty)
	 * @throws SocketException if the interfaces cannot be enumerated
	 */
	private static List<InetAddress> getLocalAddresses() throws SocketException {
		List<InetAddress> addresses = new ArrayList<InetAddress>();
		
		Enumeration<NetworkInterface> interfaces = NetworkInterface.getNetworkInterfaces();
		if(interfaces == null) {
			return addresses;
		}
		while(interfaces.hasMoreElements()) {
			NetworkInterface network = interfaces.nextElement();
			for(InterfaceAddress ifAddr : network.getInterfaceAddresses()) {
				InetAddress addr = ifAddr != null ? ifAddr.getAddress() : null;
				if(addr != null) {
					addresses.add(addr);
				}
			}
		}
		
		return addresses;
	}
}
