#include "nbd_processor.h"

#include <cerrno>
#include <cstring>
#include <vector>

#include "interface/interface.h"
#include "stream/autosock.h"
#include "log.h"
#include "exception.h"

namespace VFIELD {
namespace NBD{


namespace {
////
// NBD settings
//
// Constants written by Negotiate() at the start of a connection.
// (Internal linkage is already provided by the enclosing anonymous
// namespace, so no `static` is needed.)
const uint64_t cliserv_magic = 0x00420281861253ULL;	// handshake magic number
const char* const INIT_PASSWD = "NBDMAGIC";		// 8-byte greeting string


/////
// Exceptions
//
// Error type for NBD negotiation failures: an errno value plus a message,
// both forwarded unchanged to SystemCallException.
struct NegotiationException : public SystemCallException {
	NegotiationException(int err, const std::string& msg)
		: SystemCallException(err, msg)
	{
	}
};
// Thrown when an incoming NBD request is malformed (e.g. wrong magic).
// The constructor is explicit so a bare string cannot silently convert
// into an exception object.
struct BrokenRequestException : public std::runtime_error {
	explicit BrokenRequestException(const std::string& message) :
		std::runtime_error(message) {}
};


////
// AutoSock finalizer
//
// AutoSock finalizer that does nothing with the descriptor it is given.
// Presumably used so that wrapping a borrowed socket in an AutoSock does
// not close it when the AutoSock goes away -- verify against autosock.h.
struct DoNothing : public AutoSockFinalizer {
	void operator()(int /* sock */) throw()
	{
		// intentionally empty
	}
};


////
// Helper functions
//
// Send an NBD reply for `req` whose error field is 0xffffffff (non-zero),
// echoing the request's handle so the client can match it.
// Short writes and EINTR are retried so a partial reply never leaves the
// stream desynchronized.
// Returns sizeof(nbd_reply) once the whole reply is written, -1 on failure.
int send_reply_error(int sock, const nbd_request& req) throw()
{
	nbd_reply reply;
	reply.magic = htonl(NBD_REPLY_MAGIC);
	reply.error = htonl(0xffffffff);
	memcpy(&reply.handle, req.handle, sizeof(reply.handle));

	const char* p = reinterpret_cast<const char*>(&reply);
	size_t left = sizeof(reply);
	while( left > 0 ) {
		const ssize_t wl = ::write(sock, p, left);
		if( wl <= 0 ) {
			if( wl < 0 && errno == EINTR ) continue;	// interrupted: retry
			return -1;	// real error (or no progress): give up
		}
		p    += wl;
		left -= static_cast<size_t>(wl);
	}
	return static_cast<int>(sizeof(reply));
}
}  // anonymous namespace


// Send the NBD server greeting on `sock`, holding `sock_mutex` for the
// whole sequence. The fields are sent in this order:
//   "NBDMAGIC" (8 bytes), cliserv_magic (htonll), image_size (htonll),
//   and 128 zero bytes.
// Every field is written completely: short writes and EINTR are retried.
// Returns 0 on success, or -1 / -2 / -3 / -4 identifying the field that
// failed to send.
int Negotiate(int sock, boost::mutex* sock_mutex, uint64_t image_size) throw()
{
	// Must really be zeros: the original left this array uninitialized,
	// which leaked 128 bytes of stack memory to the client.
	static const char zeros[128] = { 0 };
	const uint64_t net_magic = htonll(cliserv_magic);
	const uint64_t net_size  = htonll(image_size);

	// Handshake fields in wire order; failure on field i returns -(i+1).
	const struct {
		const void* data;
		size_t      size;
	} fields[] = {
		{ INIT_PASSWD, 8                 },
		{ &net_magic,  sizeof(net_magic) },
		{ &net_size,   sizeof(net_size)  },
		{ zeros,       sizeof(zeros)     },
	};

	boost::mutex::scoped_lock lk(*sock_mutex);
	for( size_t i = 0; i < sizeof(fields) / sizeof(fields[0]); ++i ) {
		const char* p = static_cast<const char*>(fields[i].data);
		size_t left = fields[i].size;
		while( left > 0 ) {
			const ssize_t wl = ::write(sock, p, left);
			if( wl <= 0 ) {
				if( wl < 0 && errno == EINTR ) continue;	// interrupted: retry
				return -static_cast<int>(i + 1);
			}
			p    += wl;
			left -= static_cast<size_t>(wl);
		}
	}

	return 0;
}


// Handle an NBD write request. The shared image is not writable, so the
// request's `len`-byte payload is read from `sock` and discarded, and then
// an error reply is sent. This keeps the stream in sync for the next request.
// The whole exchange runs with `sock_mutex` held.
void WriteProcessor(int sock, boost::mutex* sock_mutex, const struct nbd_request& req) throw()
{
	LogWarn("Receive NBD write request, shared image is un-writable");

	const uint32_t len = ntohl(req.len);

	// Fixed-size scratch buffer: `len` comes straight off the wire and must
	// never size a stack array.
	char buf[4096];
	size_t rl = 0;

	boost::mutex::scoped_lock lk(*sock_mutex);
	while( rl < len ) {
		// Never ask for more than the remaining payload; over-reading would
		// swallow the header of the next request.
		size_t want = len - rl;
		if( want > sizeof(buf) ) want = sizeof(buf);

		const ssize_t rl_tmp = ::read(sock, buf, want);	// data is discarded
		if( rl_tmp < 0 ) {
			if( errno == EINTR ) continue;	// interrupted: retry
			LogWarn("NBD read error");
			break;
		}
		if( rl_tmp == 0 ) {
			// Peer closed the connection; looping would never terminate.
			LogWarn("NBD connection closed while discarding write payload");
			break;
		}
		rl += static_cast<size_t>(rl_tmp);
	}
	send_reply_error(sock, req);	// result deliberately ignored
}

// Handle an NBD disconnect request. This only logs the event; it does not
// touch the socket -- presumably the caller closes it (verify at call site).
void DisconnectProcessor(int /* sock */, boost::mutex* /* sock_mutex */, const struct nbd_request& /* req */) throw()
{
	LogInfo("Closing NBD connection");
}

void ReadProcessor(int sock, boost::mutex* sock_mutex, const struct nbd_request& req, Interface* vfif)
{
	try {
		if( unlikely(ntohl(req.magic) != NBD_REQUEST_MAGIC) ) {
			throw BrokenRequestException("Receive broken request");
		}

		AutoSock asock = makeAutoSock(sock, DoNothing());
		uint64_t from = ntohll(req.from);
		uint32_t len  = ntohl(req.len);

		LogDebug0( Log::format("Receive NBD read reaquest, from = %1% size = %2%") % from % len );

		boost::mutex::scoped_lock lk(*sock_mutex);

		if( unlikely(from + len > vfif->getImageSize()) ) {
			// リクエストがイメージサイズを超過
			LogWarn("NBD read request range exceeds image size");
			send_reply_error(sock, req);	// エラー無視
			return;
		}

		// 返答を作成
		uint32_t rep_len = len + sizeof(struct nbd_reply);
		char buf[ rep_len ];
		((struct nbd_reply*)buf)->magic = htonl( NBD_REPLY_MAGIC );
		((struct nbd_reply*)buf)->error = 0;
		::memcpy( &((struct nbd_reply*)buf)->handle, &req.handle, sizeof(((struct nbd_reply*)buf)->handle) );

		// データ取得
		vfif->getData(&buf[sizeof(struct nbd_reply)], from, len);

		uint32_t rl = 0;
		do {
			rl += asock.write(buf + rl, rep_len - rl);
		} while(rl < rep_len);

	} catch (const std::runtime_error& ex) {
		LogWarn( Log::format("Failed to process a NBD request: %1%") % ex.what());
		send_reply_error(sock, req);
	} catch (...) {
		LogWarn("Failed to process a NBD request");
		send_reply_error(sock, req);
	}
}

}  // namespace NBD
}  // namespace VFIELD
