
#include <sched.h>

#include <algorithm>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

#include "gmock/gmock.h"

#include "cm_thread.h"
#include "mt_auto_ptr.h"
#include "cm_socket_if.h"
#include "cm_socket_server.h"
#include "cm_socket_client.h"
#include "cm_event.h"
#include "mt_range.h"
#include "cm_mutex.h"

namespace {

struct FindSocket
{
    FindSocket(cm::SocketIf& socket_if)
        : socket_if_(socket_if)
    {}

    template <typename InputIterator>
    InputIterator operator()(InputIterator begin, const InputIterator& end)
    {
        for (; begin != end; ++begin) {
            if (*begin == &socket_if_) {
                return begin;
            }
        }
        return end;
    }

    const cm::SocketIf& socket_if_;
};

// Payloads the server expects, one per connection, in the order the
// connections are made; SocketIfMock::doHandleRead compares the data of each
// read against the next entry.
// NOTE(review): the trailing "invalid" entry looks like a deliberate mismatch
// guard so an unexpected extra read fails instead of indexing past the end —
// confirm.
const char* const read_data_table[] = {
    "hello world!",
    "good bye world!",
    "invalid"
};

class SocketIfMock : public cm::SocketIf
{
public:
    SocketIfMock(cm::SocketIf* socket_if)
        : socket_if_(socket_if)
    {
        EXPECT_CALL(*this, read(::testing::_, ::testing::_, ::testing::_))
            .WillOnce(::testing::Invoke(this, &SocketIfMock::doHandleRead))   // data read
            .WillOnce(::testing::Invoke(this, &SocketIfMock::doHandleDisconnection));
    }

    bool doHandleRead(size_t& bytes_read, void* buf, size_t size_to_read)
    {
        bool ret = socket_if_->read(bytes_read, buf, size_to_read);
        EXPECT_EQ(ret, true);
        EXPECT_STREQ(reinterpret_cast<char*>(buf), read_data_table[mock_creation_count_++]);
        return ret;
    }

    bool doHandleDisconnection(size_t& bytes_read, void* buf, size_t size_to_read)
    {
        bool ret = socket_if_->read(bytes_read, buf, size_to_read);
        EXPECT_EQ(ret, false);
        return ret;
    }

    MOCK_METHOD3(read, bool(size_t& bytes_read, void* buf, size_t size_to_read));
    MOCK_METHOD3(write, bool(size_t& bytes_read, const void* buf, size_t size_to_write));
    MOCK_METHOD0(release, int());

    virtual int getFD() const
    {
        return socket_if_->getFD();
    }

    ~SocketIfMock()
    {
        delete socket_if_;
    }

private:
    SocketIf* socket_if_;
    static unsigned int mock_creation_count_;

    MOCK_CONST_METHOD1(doClone, SocketIf*(int fd));
};

unsigned int SocketIfMock::mock_creation_count_ = 0;

class CmSocketServerThread
{
public:
    CmSocketServerThread(cm::Mutex* mutex)
        : server_(cm::SOCKET_TYPE_INET_STREAM, "0.0.0.0", 8888),
          event_(cm::Event::tsdInstance()),
          accepted_sockets_()
    {
        cm::Mutex::Lock lock(*mutex);
        event_.addHandlerRead(*this, &CmSocketServerThread::accept, server_);
    }

    bool accept(cm::SocketServer& server)
    {
        std::cout << __FUNCTION__ << " entry" << std::endl;
        mt::AutoPtr<cm::SocketIf> sock = server.accept();
        cm::SocketIf* socket_ptr = sock.get();
        SocketIfMock* socket_if_mock_ptr = new SocketIfMock(socket_ptr);
        event_.addHandlerRead(*this, &CmSocketServerThread::readSocket, *socket_if_mock_ptr);
        accepted_sockets_.push_back(socket_if_mock_ptr);
        sock.release();
        return false;
    }

    ~CmSocketServerThread()
    {
        EXPECT_EQ(accepted_sockets_.size(), 0u);
    }

    bool readSocket(cm::SocketIf& socket)
    {
        std::cout << __FUNCTION__ << " entry" << std::endl;

        char buffer[20];
        size_t bytes_read = 0u;
        if (socket.read(bytes_read, buffer, sizeof(buffer))) {
            std::cout << "received " << buffer << std::endl;
        }
        else {
            std::cout << "closing socket " << socket.getFD() << std::endl;

            event_.delHandlerRead(socket);
            FindSocket socket_finder(socket);
            std::vector<cm::SocketIf*>::iterator it = socket_finder(mt::begin(accepted_sockets_),
                                                                    mt::end(accepted_sockets_));
            delete *it;
            accepted_sockets_.erase(it);
        }
        return false;
    }

    void run()
    {
        bool is_started = false;
        while (!is_started || accepted_sockets_.size() > 0) {
            event_.pend();
            is_started = true;
        }
    }

private:
    cm::SocketServer server_;
    cm::Event& event_;
    std::vector<cm::SocketIf*> accepted_sockets_;
};

// End-to-end check of a cm::SocketServer bound to "0.0.0.0":
// - The client connects over loopback once per message in read_data_table.
// - On each connection it sends one NUL-padded 20-byte message and then
//   closes the connection.
// - The server thread checks the payloads through SocketIfMock and stops
//   once every accepted socket has been closed.
TEST(CmSocketServerTest, inet_anyaddr)
{
    // Entries of read_data_table that are actually sent; the last entry is
    // not a message.
    static const size_t kMessageCount = 2u;
    // sched_yield() does not guarantee that the server thread has created its
    // listening socket yet, so a failed connect (null socket) is retried.
    static const int kMaxConnectAttempts = 1000;

    cm::Mutex mutex;
    cm::Thread<CmSocketServerThread, cm::Mutex> server_thread("server_thread");
    // NOTE(review): presumably constructs CmSocketServerThread(&mutex) on the
    // new thread and then calls run() — confirm against cm_thread.h.
    server_thread.create(&mutex);

    {
        sched_yield();
        cm::Mutex::Lock lock(mutex);
        cm::SocketClient client(cm::SOCKET_TYPE_INET_STREAM);

        for (size_t i = 0; i < kMessageCount; ++i) {
            mt::AutoPtr<cm::SocketIf> sock = client.connect("127.0.0.1", 8888);
            for (int attempt = 1; sock.get() == 0 && attempt < kMaxConnectAttempts; ++attempt) {
                sched_yield();
                sock = client.connect("127.0.0.1", 8888);
            }
            // Fatal: the write below would dereference a null socket.
            // NOTE(review): the early return skips server_thread.join().
            ASSERT_NE(sock.get(), static_cast<cm::SocketIf*>(0));

            // The server reads a fixed 20-byte buffer, so send the full,
            // NUL-padded buffer rather than just the string.
            char buffer[20] = {0};
            std::strncpy(buffer, read_data_table[i], sizeof(buffer) - 1);
            size_t bytes_written = 0u;
            EXPECT_TRUE(sock->write(bytes_written, buffer, sizeof(buffer)));

            // Close the connection so the server sees the disconnection.
            sock.reset();
        }
    }

    server_thread.join();
}

} // namespace
