#include "sdusb_impl.h"

#include <stdint.h>
#include <sys/mount.h>
#include <unistd.h>

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iterator>
#include <ostream>
#include <stdexcept>
#include <string>

#include "sdusb_known.h"

namespace SimpleDetect {


// Constructor.
//   ids_path : passed to the m_ids table (vendor/device/class name lookup)
//   map_path : passed to the m_map table (vendor/device id -> module names)
// initialize() is what calls load() on both tables.
// NOTE(review): map_path is taken by value, unlike ids_path; switching it to a
// reference would also require changing the declaration in sdusb_impl.h.
sdUSBIMPL::sdUSBIMPL(const std::string& ids_path, const std::string map_path)
	: m_ids(ids_path)
	, m_map(map_path)
{
}

// Destructor. Empty: this class performs no explicit cleanup of its own.
sdUSBIMPL::~sdUSBIMPL()
{
}


// public:
void sdUSBIMPL::initialize(void)
{
	m_ids.load();
	m_map.load();
}

void sdUSBIMPL::setBaseClassBasedLoadConfig  (const std::string& class_substr, load_config_t config)
{
	m_base_class_config.push_back( pattern_config_t(class_substr, config) );
}

void sdUSBIMPL::setSubClassBasedLoadConfig   (const std::string& class_substr, load_config_t config)
{
	m_base_class_config.push_back( pattern_config_t(class_substr, config) );
}

// Set (or overwrite) the load policy for one known device class.
// loadModules() looks this up for devices that isKnown() classifies, and also
// under the key sdUSB::SPECIFIC for device-specific drivers.
void sdUSBIMPL::setKnownClassBasedLoadConfig (known_class_t known_class, load_config_t config)
{
	known_class_config_map_t::mapped_type& slot = m_known_class_config[ known_class ];
	slot = config;
}


// Detect USB devices and hand the modules for each one to `loader`.
//
// Every pass: load the bus drivers, optionally rescan the device list, call
// loadModules() for each device in m_devices, then rescan. When the rescan finds
// more devices than existed before the pass (e.g. a USB-USB bridge driver was
// just loaded and exposed new devices), another pass runs. That pass skips the
// initial rescan because the list was refreshed at the end of the previous pass.
//
// devices_loaded: true if m_devices is already current, so the first pass does
//                 not call loadDevices() before loading modules.
void sdUSBIMPL::detect(ModuleLoader& loader, bool devices_loaded)
{
	using namespace USBKnown;

	bool rescan_first = ! devices_loaded;
	for( ;; ) {
		loadBusDrivers(loader);
		if( rescan_first ) {
			loadDevices();
		}
		// Later passes reuse the scan taken at the end of the previous pass
		rescan_first = false;

		const devices_t::size_type count_before = m_devices.size();

		const devices_t::const_iterator last = m_devices.end();
		for( devices_t::const_iterator it = m_devices.begin(); it != last; ++it ) {
			loadModules(loader, it->second);
		}

		loadDevices();
		const devices_t::size_type count_after = m_devices.size();

		// Finished once a pass uncovers no new devices
		if( count_after <= count_before ) {
			break;
		}
	}
}


// private:

namespace {
// Scoped usbfs mount on a given directory.
// mount() mounts usbfs there, retrying on failure. If (and only if) that
// succeeded, the destructor unmounts it again. Copying is disabled: two copies
// of a mounted guard would both try to unmount the same mount point.
class USBFS {
public:
	explicit USBFS(const std::string& path) :
		m_path(path),
		m_mounted(false)
	{}

	// Mount usbfs on m_path: one attempt plus up to 3 retries, sleeping
	// retry * 2^19 microseconds (about 0.5 s, 1 s, 1.5 s) before retry 1..3.
	// Returns 0 on success, otherwise the last non-zero ::mount() result.
	int mount(void) {
		int retval;
		unsigned short retry = 1;
		while( (retval = ::mount("usbfs", m_path.c_str(), "usbfs", 0, NULL)) != 0 ) {
			if( retry > 3 ) return retval;
			::usleep(retry << 19);
			++retry;
		}
		m_mounted = true;
		return retval;
	}

	// Unmount, but only if mount() succeeded on this object.
	~USBFS()
	{
		if( m_mounted ) ::umount(m_path.c_str());
	}

private:
	// Non-copyable (declared, never defined): a copy would unmount a second time.
	USBFS(const USBFS&);
	USBFS& operator=(const USBFS&);

	std::string m_path;
	bool m_mounted;
};
}  // noname namespace


namespace {
// Build a 64-bit device key from a "T:" (topology) line:
//   T:  Bus=00 Lev=00 Prnt=00 Port=00 Cnt=00 Dev#=  1 Spd=480 MxCh= 0
// Layout: Bus<<40 | Lev<<32 | Prnt<<16 | Port<<8 | Cnt.
// Each field is found by its "Name=" tag, so the line's spacing cannot shift
// the fields and no fixed-size scratch buffer is involved. The old
// `%s`-into-char[10] scan could overflow on a long token. The two characters
// after each tag are read as hex, exactly as before, so keys keep their previous
// values. A missing or unparsable field contributes 0.
void loadTLine(const char* line, uint64_t* ret_usb_id)
{
	struct Field {
		// Two hex digits following `tag` in `text`, or 0 when absent/unparsable.
		static uint64_t read(const char* text, const char* tag) {
			uint8_t value = 0;
			const char* pos = std::strstr(text, tag);
			if( pos != NULL ) {
				::sscanf(pos + std::strlen(tag), "%2hhx", &value);
			}
			return static_cast<uint64_t>(value);
		}
	};

	*ret_usb_id =
		Field::read(line, "Bus=")  << 40 |
		Field::read(line, "Lev=")  << 32 |
		Field::read(line, "Prnt=") << 16 |
		Field::read(line, "Port=") << 8  |
		Field::read(line, "Cnt=");
}

// Parse a "P:" (product) line:
//   P:  Vendor=0000 ProdID=0000 Rev= 2.06
// Vendor and ProdID are read as up to 4 hex digits. Each value is located by
// its tag, so there is no whitespace-token scan into a fixed 10-byte buffer
// (the old unbounded `%s` could overflow it). An output whose tag is missing
// or unparsable is left unchanged.
void loadPLine(const char* line, unsigned short* ret_vendor_id, unsigned short* ret_device_id)
{
	static const char vendor_tag[]  = "Vendor=";
	static const char product_tag[] = "ProdID=";

	const char* vendor = std::strstr(line, vendor_tag);
	if( vendor != NULL ) {
		::sscanf(vendor + sizeof(vendor_tag) - 1, "%4hx", ret_vendor_id);
	}
	const char* product = std::strstr(line, product_tag);
	if( product != NULL ) {
		::sscanf(product + sizeof(product_tag) - 1, "%4hx", ret_device_id);
	}
}

// Parse the class ids from an "I:" (interface) line:
//   I:  If#= 0 Alt= 0 #EPs= 1 Cls=09(hub  ) Sub=00 Prot=00 Driver=hub
// Cls and Sub are read as 2 hex digits located by their tags.
// The previous whitespace-token scan broke whenever the class name in
// parentheses contained padding spaces ("(hub  )", "(HID  )"): it read ")" where
// "Sub=" was expected, so the subclass was never parsed. It also mis-counted
// tokens once If# reached two digits ("If#=10"). An output whose tag is missing
// or unparsable is left unchanged.
void loadILine(const char* line, uint8_t* ret_baseclass_id, uint8_t* ret_subclass_id)
{
	static const char class_tag[]    = "Cls=";
	static const char subclass_tag[] = "Sub=";

	const char* cls = std::strstr(line, class_tag);
	if( cls != NULL ) {
		::sscanf(cls + sizeof(class_tag) - 1, "%2hhx", ret_baseclass_id);
	}
	const char* sub = std::strstr(line, subclass_tag);
	if( sub != NULL ) {
		::sscanf(sub + sizeof(subclass_tag) - 1, "%2hhx", ret_subclass_id);
	}
}

// Parse an "S:" (string descriptor) line; one of:
//   S:  Manufacturer=Linux 2.6.17-5mdv ohci_hcd
//   S:  Product=OHCI Host Controller
//   S:  SerialNumber=0000:00:13.0
// The key is the text between the 4-character "S:  " prefix and the first '='.
// Only Manufacturer and Product are stored (everything after '='); other keys,
// and lines without a '=' past the prefix, leave both outputs untouched.
void loadSLine(const char* line, std::string* ret_manufacturer, std::string* ret_product_name)
{
	const std::string text(line);
	const std::string::size_type eq = text.find('=');
	if( eq == std::string::npos ) {
		return;
	}
	if( eq <= 4 ) {
		return;     // '=' inside the prefix: no key
	}

	const std::string key = text.substr(4, eq - 4);
	if( key == "Manufacturer" ) {
		ret_manufacturer->assign(text, eq + 1, std::string::npos);
	} else if( key == "Product" ) {
		ret_product_name->assign(text, eq + 1, std::string::npos);
	}
}

}  // noname namespace

// Scan /proc/bus/usb/devices and register every device found in m_devices,
// keyed by the id loadTLine() builds from the device's "T:" line.
//
// - If the file cannot be opened, usbfs is mounted on /proc/bus/usb for the
//   duration of this call; the USBFS destructor unmounts it on return.
// - A device section starts at a "T:" line and ends at a blank line, at the
//   next "T:" line, or at end of file. "P:", "I:" and "S:" lines inside it are
//   parsed; with several "I:" lines, the last one's class ids are kept.
// - Devices with vendor id 0 are skipped (host controllers, per the original note).
// - Existing entries are overwritten with loaded = false; none are removed.
void sdUSBIMPL::loadDevices(void)
{
	USBFS usbfs("/proc/bus/usb");

	std::ifstream proc_usb("/proc/bus/usb/devices");
	if( ! proc_usb.is_open() ) {
		// usbfs is not mounted: mount it (USBFS's destructor unmounts it again)
		usbfs.mount();
		proc_usb.open("/proc/bus/usb/devices");
		if( ! proc_usb.is_open() ) {
			// No devices
			return;
		}
	}

	// std::getline into a std::string has no length limit and fails (ending the
	// loops below) at end of file or on any stream error. The previous 512-byte
	// istream::getline set failbit on an over-long line without ever reaching
	// eof(), so the old `while( !eof() )` loop could spin forever.
	std::string line;
	// Set when the inner loop already read the next section's "T:" line; the
	// outer loop then reuses it instead of seeking back in the file.
	bool line_pending = false;

	for( ;; ) {
		if( ! line_pending ) {
			if( ! std::getline(proc_usb, line) ) break;
		}
		line_pending = false;

		// Skip lines until a device section begins
		// T:  Bus=00 Lev=00 Prnt=00 Port=00 Cnt=00 ....
		if( line.empty() || line[0] != 'T' ) continue;

		uint64_t usb_id = 0;
		unsigned short vendor_id = 0;
		unsigned short device_id = 0;

		uint8_t baseclass_id = 0;
		uint8_t subclass_id = 0;

		std::string manufacturer;
		std::string product_name;

		loadTLine(line.c_str(), &usb_id);

		while( std::getline(proc_usb, line) ) {
			if( line.empty() ) break;	// blank line: end of this section

			const char tag = line[0];
			if( tag == 'T' ) {
				// Next section started without a blank line: end this one and
				// let the outer loop process the "T:" line.
				line_pending = true;
				break;
			}

			switch( tag ) {
			case 'P':
				// P:  Vendor=0000 ProdID=0000 Rev= 2.06
				loadPLine(line.c_str(), &vendor_id, &device_id);
				break;
			case 'I':
				// I:  If#= 0 Alt= 0 #EPs= 1 Cls=09(hub  ) Sub=00 Prot=00 Driver=hub
				loadILine(line.c_str(), &baseclass_id, &subclass_id);
				break;
			case 'S':
				// One of:
				// S:  Manufacturer=Linux 2.6.17-5mdv ohci_hcd
				// S:  Product=OHCI Host Controller
				// S:  SerialNumber=0000:00:13.0
				loadSLine(line.c_str(), &manufacturer, &product_name);
				break;
			default:
				break;
			}
		}

		if( vendor_id == 0 ) continue;	// skip host controllers

		// Register the device
		USBDevice device;
		device.vendor_id = vendor_id;
		device.device_id = device_id;
		device.baseclass_id = baseclass_id;
		device.subclass_id  = subclass_id;

		device.manufacturer = manufacturer;
		device.product_name = product_name;

		device.modules = m_map.getModuleName(vendor_id, device_id);

		device.loaded = false;
		m_devices[ usb_id ] = device;
	}
}


// Return the policy of the first entry in configured_list whose pattern
// matches *p_pattern (as decided by isMatchPattern), or sdUSB::AUTO when
// p_pattern is NULL or no entry matches.
sdUSBIMPL::load_config_t sdUSBIMPL::detLoadConfig(
		const configured_list_t& configured_list,
		const std::string* p_pattern
		) const
{
	if( p_pattern == NULL ) {
		return sdUSB::AUTO;	// no name to match against
	}

	const configured_list_t::const_iterator found =
		std::find_if(configured_list.begin(), configured_list.end(), isMatchPattern(*p_pattern));
	if( found == configured_list.end() ) {
		return sdUSB::AUTO;
	}
	return found->second;
}


// Hand `loader` the modules for one device, honouring the configured load
// policies (sdUSB::NEVER / FORCE / AUTO). Steps, in order:
//   1. Known class (isKnown): NEVER returns immediately; FORCE calls
//      loadKnownDriversForce() then loadKnownDriversAuto(); AUTO, or no
//      setting, calls loadKnownDriversAuto().
//   2. Base-class-name policy (detLoadConfig): NEVER returns immediately.
//   3. Sub-class-name policy (detLoadConfig): NEVER returns immediately.
//   4. loadSpecificDrivers(), when sdUSB::SPECIFIC is unconfigured or is
//      configured as AUTO or FORCE.
//   5. loader(name) for every name in device.modules.
// Does nothing when device.loaded is set.
void sdUSBIMPL::loadModules(ModuleLoader& loader, const USBDevice& device) const
{
	using namespace USBKnown;

	if( device.loaded ) return;

	known_class_t known_class;
	if( isKnown(device.baseclass_id, device.subclass_id, &known_class) ) {
		known_class_config_map_t::const_iterator configured_class(
				m_known_class_config.find(known_class)
				);
		if( configured_class != m_known_class_config.end() ) {
			// Explicitly configured
			switch(configured_class->second) {
			case sdUSB::NEVER:
				return;		// load nothing for this device
			case sdUSB::FORCE:
				loadKnownDriversForce(loader, known_class);	// force load, then auto load
				// fall through
			case sdUSB::AUTO:
				loadKnownDriversAuto(loader, known_class);
				break;
			}
		} else {
			// Not configured: auto load
			loadKnownDriversAuto(loader, known_class);
		}
	}

	// Policy by base class *name*
	const std::string* baseclass_name( m_ids.getBaseClassName(device.baseclass_id) );
	switch( detLoadConfig(m_base_class_config, baseclass_name) ) {
	case sdUSB::NEVER:
		return;		// load nothing for this device
	case sdUSB::FORCE:
		// fall through
	case sdUSB::AUTO:
		break;
	}

	// Policy by sub class *name*
	// NOTE(review): this consults m_base_class_config, which is also the list
	// setSubClassBasedLoadConfig() appends to, so base- and sub-class patterns are
	// matched against both names. Kept as-is (no separate sub-class list is visible).
	const std::string* subclass_name( m_ids.getSubClassName(device.baseclass_id, device.subclass_id) );
	switch( detLoadConfig(m_base_class_config, subclass_name) ) {
	case sdUSB::NEVER:
		return;		// load nothing for this device
	case sdUSB::FORCE:
		// fall through
	case sdUSB::AUTO:
		break;
	}

	{  // Device-specific drivers
		known_class_config_map_t::const_iterator specific_config(
				m_known_class_config.find(sdUSB::SPECIFIC)
				);
		// Load them when SPECIFIC is not configured at all, or is configured as
		// AUTO or FORCE. The previous test used `!= end()` here: it dereferenced
		// end() when SPECIFIC was unconfigured and loaded the drivers even when
		// SPECIFIC was configured as NEVER.
		if( specific_config == m_known_class_config.end() ||
				specific_config->second == sdUSB::AUTO ||
				specific_config->second == sdUSB::FORCE ) {
			loadSpecificDrivers(loader, device);
		}
	}

	// Explicit loop instead of std::for_each: for_each takes its functor by value,
	// so it would call a copy of `loader` rather than the caller's object.
	for( std::list<std::string>::const_iterator mod(device.modules.begin()), end(device.modules.end());
			mod != end;
			++mod ) {
		loader(*mod);
	}
}


// Write a human-readable report of every device in m_devices to `stream`:
// vendor/device ids and names, base/sub class ids and names, manufacturer,
// product name, and the modules loadModules() would hand to a loader
// (gathered with a ModuleCollector instead of a real loader).
// The stream's format flags and fill character are restored before returning.
void sdUSBIMPL::showDetectedDevices(std::ostream& stream) const
{
	// The report switches the stream to std::hex with '0' fill. This guard puts the
	// caller's formatting state back on every exit path, exceptions included.
	struct FormatGuard {
		std::ostream& target;
		std::ios_base::fmtflags saved_flags;
		char saved_fill;
		explicit FormatGuard(std::ostream& s) :
			target(s), saved_flags(s.flags()), saved_fill(s.fill()) {}
		~FormatGuard() { target.flags(saved_flags); target.fill(saved_fill); }
	} format_guard(stream);

	for( devices_t::const_iterator dev_pair(m_devices.begin()), end(m_devices.end());
			dev_pair != end;
			++dev_pair ) {
		const USBDevice& dev( dev_pair->second );
		const std::string* vendor_name( m_ids.getVendorName(dev.vendor_id) );
		const std::string* device_name( m_ids.getDeviceName(dev.vendor_id, dev.device_id) );
		const std::string* baseclass_name( m_ids.getBaseClassName(dev.baseclass_id) );
		const std::string* subclass_name ( m_ids.getSubClassName (dev.baseclass_id, dev.subclass_id) );
		// Class ids are widened to unsigned int: streamed as uint8_t they would be
		// printed as raw characters instead of hex numbers.
		stream	<< "----"
			<< std::endl << std::setfill('0')
			<< "Vendor      0x" << std::hex << std::setw(4) << dev.vendor_id
			<< " : " << (vendor_name != NULL ? *vendor_name : "(unknown)")
			<< std::endl
			<< "Device      0x" << std::hex << std::setw(4) << dev.device_id
			<< " : " << (device_name != NULL ? *device_name : "(unknown)")
			<< std::endl
			<< "BaseClass   0x" << std::hex << std::setw(2) << static_cast<unsigned int>(dev.baseclass_id)
			<< "   : " << (baseclass_name != NULL ? *baseclass_name : "(unknown)")
			<< std::endl
			<< "SubClass    0x" << std::hex << std::setw(2) << static_cast<unsigned int>(dev.subclass_id)
			<< "   : " << (subclass_name  != NULL ? *subclass_name  : "(unknown)")
			<< std::endl
			<< "Manufacturer       : " << (!dev.manufacturer.empty() ? dev.manufacturer : "(unknown)")
			<< std::endl
			<< "Productname        : " << (!dev.product_name.empty() ? dev.product_name : "(unknown)")
			<< std::endl
			<< "Modules            : ";
		ModuleCollector collector;
		loadModules(collector, dev);
		const std::list<std::string> modules( collector.getList() );
		std::copy(modules.begin(), modules.end(), std::ostream_iterator<std::string>(stream, " "));
		stream << std::endl;
	}
}


}  // namespace SimpleDetect
