#pragma once

#include <sspi.h>
#include <wincrypt.h>
#include <schnlsp.h>
#include <winsock2.h>
#include <wincrypt.h>
#include <security.h>
#include <assert.h>

// Overall status carried by an SSLEngineResult (see SSLEngineResult::getStatus()).
// The numeric values are the ones the compiler assigned implicitly before;
// they are only spelled out here.
// NOTE(review): the names mirror javax.net.ssl.SSLEngineResult.Status;
// presumably the meanings match - confirm against the implementation.
enum SSLEngineStatus
{
  BUFFER_OVERFLOW  = 0,
  BUFFER_UNDERFLOW = 1,
  CLOSED           = 2,
  OK               = 3
};

/**
 * Handshake status reported by SSLEngine and SSLEngineResult.
 *
 * The numeric values are the ones the compiler assigned implicitly before;
 * they are only spelled out here.
 */
enum SSLEngineHandshakeStatus
{
  FINISHED        = 0,
  NEED_TASK       = 1,
  NEED_UNWRAP     = 2,
  NEED_WRAP       = 3,
  NOT_HANDSHAKING = 4
};

/**
 * Outcome of a delegated task (return type of SSLEngine::executeDelegatedTask()).
 *
 * The numeric values are the ones the compiler assigned implicitly before;
 * they are only spelled out here.
 */
enum SSLEngineTaskStatus
{
  TASK_OK                       = 0,
  TASK_CERTIFICATE_VERIFY_ERROR = 1,
  TASK_ILLEGAL_STATE_ERROR      = 2
};

/**
 * Event that triggered the most recent handshake
 * (stored by SSLEngine as lastHandshakeTrigger).
 *
 * The numeric values are the ones the compiler assigned implicitly before;
 * they are only spelled out here.
 */
enum SSLEngineHandshakeTrigger
{
  HELLO_REQUEST         = 0,
  CLOSE_NOTIFY_RECEIVED = 1,
  OUTBOUND_CLOSED       = 2
};

/**
 * Result of one SSLEngine operation.
 *
 * Bundles the engine status, the handshake status, the number of bytes
 * consumed and produced, and the HRESULT of the underlying SSPI call.
 * A failure is signalled only through getResult(): setError() resets the
 * status fields to their neutral values (OK / NOT_HANDSHAKING / 0 / 0), so
 * callers must test FAILED(getResult()) before trusting the other fields.
 */
class SSLEngineResult
{
private:
  /** Engine status of the operation. */
  SSLEngineStatus status;

  /** Handshake status reported together with the operation. */
  SSLEngineHandshakeStatus handshakeStatus;

  /**
   * Bytes consumed.
   */
  int consumed;

  /**
   * Bytes produced.
   */
  int produced;

  /**
   * Result of SSPI (S_OK unless setError() was called).
   */
  HRESULT result;

public:
  /** Creates a neutral result: OK, NOT_HANDSHAKING, 0 bytes, S_OK. */
  SSLEngineResult()
    : status(OK),
      handshakeStatus(NOT_HANDSHAKING),
      consumed(0),
      produced(0),
      result(S_OK)
  {
  }

  /**
   * Records a successful outcome and clears any previous error code.
   *
   * @param status           engine status to report
   * @param handshakeStatus  handshake status to report
   * @param bytesConsumed    number of source bytes consumed
   * @param bytesProduced    number of destination bytes produced
   */
  inline void setValues(SSLEngineStatus status, SSLEngineHandshakeStatus handshakeStatus, int bytesConsumed, int bytesProduced)
  {
    this->status = status;
    this->handshakeStatus = handshakeStatus;
    this->consumed = bytesConsumed;
    this->produced = bytesProduced;
    this->result = S_OK;
  }

  /**
   * Records a failed outcome.
   *
   * Stores the error code and resets every other field to its neutral value.
   * A failing SSPI call is a runtime condition, not a programming error, so
   * this method no longer ends with an unconditional assert(false): that
   * aborted every debug build before the caller could inspect getResult().
   *
   * @param errorResult  failure HRESULT; passing a success code is a
   *                     programming error (checked by assert in debug builds)
   */
  inline void setError(HRESULT errorResult)
  {
    assert(FAILED(errorResult));
    this->result = errorResult;

    this->status = OK;
    this->handshakeStatus = NOT_HANDSHAKING;
    this->consumed = 0;
    this->produced = 0;
  }

  /** @return the stored HRESULT (S_OK after setValues(), the error code after setError()). */
  inline HRESULT getResult() const
  {
    return this->result;
  }

  /** @return the stored engine status. */
  inline SSLEngineStatus getStatus() const
  {
    return this->status;
  }

  /** @return the stored handshake status. */
  inline SSLEngineHandshakeStatus getHandshakeStatus() const
  {
    return this->handshakeStatus;
  }

  /** @return number of bytes consumed. */
  inline int bytesConsumed() const
  {
    return this->consumed;
  }

  /** @return number of bytes produced. */
  inline int bytesProduced() const
  {
    return this->produced;
  }
};

/**
 * SSLEngine
 *
 * SSL/TLS engine built on the SSPI security-context calls
 * (InitializeSecurityContext() / AcceptSecurityContext()).
 * The method names mirror Java's javax.net.ssl.SSLEngine; whether the
 * semantics match in every detail is not visible from this header.
 *
 * Instances are not copyable: the object embeds a CRITICAL_SECTION, which
 * must not be copied, and holds raw credential/context/certificate handles.
 */
class SSLEngine
{
private:
  /** Lock used by acquireLock()/releaseLock(). */
  CRITICAL_SECTION          criticalSection;

  /** Peer host name as returned by getPeerHost(). */
  const _TCHAR*             peerHost;
  /** SSPI credential handle. */
  CredHandle*               phCredential;
  /** SSPI security context handle. */
  CtxtHandle*               phContext;

  /** SSPI function table shared by all instances (static member, defined outside this header). */
  static PSecurityFunctionTable    pFunctions;

  /**
   * true when the engine operates in client mode.
   */
  bool                      clientMode;


  /**
   * Whether the outbound side is closed.
   */
  bool                      outboundDone;

  /**
   * Whether the inbound side is closed.
   */
  bool                      inboundDone;

  /**
   * Current handshake status.
   */
  SSLEngineHandshakeStatus  handshakeStatus;

  /**
   * Trigger of the most recently performed handshake.
   */
  SSLEngineHandshakeTrigger lastHandshakeTrigger;

  /**
   * Result of the previous InitializeSecurityContext()/AcceptSecurityContext() call.
   */
  HRESULT                   securityStatus;

  /**
   * Output buffer of InitializeSecurityContext()/AcceptSecurityContext().
   */
  SecBuffer securityContextOutput[1];

  /**
   * Result of wrap().
   */
  SSLEngineResult wrapResult;

  /**
   * Result of unwrap().
   */
  SSLEngineResult unwrapResult;

  /**
   * Enabled protocols.
   * NOTE(review): presumably a bit mask of Schannel SP_PROT_* flags - confirm.
   */
  DWORD enabledProtocols;

  /**
   * Version of the protocol in use.
   */
  unsigned short protocolVersion;

  /**
   * Length of the session ID.
   */
  unsigned char sessionIDLength;

  /**
   * Session ID bytes.
   */
  unsigned char* sessionID;

  /**
   * Code identifying the cipher suite in use.
   */
  unsigned short cipherSuiteCode;

  /**
   * Local (own) certificate.
   */
  PCCERT_CONTEXT pLocalCertificate;

  /**
   * Local (own) certificate chain.
   */
  PCCERT_CHAIN_CONTEXT pLocalCertificateChain;

  /**
   * Peer certificate.
   */
  PCCERT_CONTEXT pPeerCertificate;

  /**
   * Peer certificate chain.
   */
  PCCERT_CHAIN_CONTEXT pPeerCertificateChain;

  /**
   * Client authentication level.
   *
   * 0 - no client authentication
   * 1 - client authentication requested (wantClientAuth)
   * 2 - client authentication required  (needClientAuth)
   */
  int clientAuthLevel;

  /**
   * Stream buffer sizes (header / maximum message / trailer).
   */
  SecPkgContext_StreamSizes streamSizes;

  /**
   * Whether the peer certificate chain has already been loaded.
   */
  bool peerCertificateChainLoaded;

  /**
   * Number of delegated tasks.
   */
  int delegatedTaskCount;

  //-----------------------------------------------------------
  // Member functions
  //-----------------------------------------------------------

  // Copying is disabled (declared here, intentionally never defined):
  // a CRITICAL_SECTION must not be copied, and a member-wise copy would
  // duplicate the raw handles and the session ID pointer.
  SSLEngine(const SSLEngine&);
  SSLEngine& operator=(const SSLEngine&);

  /**
   * Enters the engine's critical section.
   */
  inline void acquireLock()
  {
    EnterCriticalSection(&this->criticalSection);
  }

  /**
   * Leaves the engine's critical section.
   */
  inline void releaseLock()
  {
    LeaveCriticalSection(&this->criticalSection);
  }

  /**
   * Generates the close_notify alert.
   */
  void wrapCloseNotify();

  /** Overwrites the current handshake status. */
  inline void setHandshakeStatus(SSLEngineHandshakeStatus stat)
  {
    this->handshakeStatus = stat;
  }

  /**
   * Performs the first InitializeSecurityContext() call.
   */
  void firstInitializeSecurityContext(void* destbuff, int destlen);

  /**
   * Performs the first AcceptSecurityContext() call.
   */
  void firstAcceptSecurityContext(void* srcbuff, int srclen);

  /**
   * Second and subsequent InitializeSecurityContext()/AcceptSecurityContext() calls.
   */
  void wrapSecurityContext(void* destbuff, int destlen);
  void unwrapSecurityContext(void* srcbuff, int srclen);

  /**
   * Updates the handshake state.
   */
  bool updateHandshakeStatus();

  /**
   * Returns the trigger that caused the last handshake.
   */
  inline SSLEngineHandshakeTrigger getLastHandshakeTrigger() const
  {
    return this->lastHandshakeTrigger;
  }

  /** Records the trigger that caused the last handshake. */
  inline void setLastHandshakeTrigger(SSLEngineHandshakeTrigger trigger)
  {
    this->lastHandshakeTrigger = trigger;
  }

  /** Sets the outbound-closed flag. */
  inline void setOutboundDone(bool done)
  {
    this->outboundDone = done;
  }

  /** Sets the inbound-closed flag. */
  inline void setInboundDone(bool done)
  {
    this->inboundDone = done;
  }

  /**
   * Encrypts application data.
   */
  void encryptApplicationData(void* srcbuff, int srcdata, void* destbuff, int destlen);

  /**
   * Decrypts network data.
   */
  void decryptPacketData(void* srcbuff, int srcdata, void* destbuff, int destlen);

  /**
   * If the given data is a Server Hello, extracts the protocol version and cipher suite.
   */
  bool parseServerHello(void* buff, int len);

  /**
   * Handles a close_notify received from the peer.
   */
  void onCloseNotifyReceived(int bytesConsumed);

public:
  SSLEngine(const _TCHAR* host);
  ~SSLEngine();

  /** @return the stored peer host pointer. */
  inline const _TCHAR* getPeerHost() const
  {
    return this->peerHost;
  }

  /**
   * Encrypts application data.
   *
   * NOTE(review): the returned reference presumably refers to the wrapResult
   * member and would then be overwritten by the next wrap() - confirm in the
   * implementation.
   */
  const SSLEngineResult& wrap(void* srcbuff, int srclen, void* destbuff, int destlen);

  /**
   * Decrypts encrypted data.
   *
   * NOTE(review): the returned reference presumably refers to the
   * unwrapResult member - confirm in the implementation.
   */
  const SSLEngineResult& unwrap(void* srcbuff, int srclen, void* destbuff, int destlen);

  /**
   * Returns the current handshake status.
   *
   * Not const on purpose: FINISHED is reported once, after which the stored
   * status moves to NOT_HANDSHAKING.
   */
  inline SSLEngineHandshakeStatus getHandshakeStatus()
  {
    SSLEngineHandshakeStatus hs = this->handshakeStatus;
    if (FINISHED == hs)
    {
      // Transition from FINISHED to NOT_HANDSHAKING.
      this->handshakeStatus = NOT_HANDSHAKING;
    }
    return hs;
  }

  /**
   * Sets whether the handshake is performed as a client.
   *
   * @param  cm  true for client mode
   */
  inline void setUseClientMode(bool cm)
  {
    this->clientMode = cm;
  }

  /**
   * Returns whether the handshake is performed as a client.
   *
   * @return true in client mode
   */
  inline bool getUseClientMode() const
  {
    return this->clientMode;
  }

  /**
   * Begins the handshake.
   */
  HRESULT beginHandshake();

  /**
   * Returns whether a delegated task exists.
   */
  bool hasDelegatedTask();

  /**
   * Executes the delegated task if one exists.
   */
  SSLEngineTaskStatus executeDelegatedTask();

  /**
   * Loads the peer certificate chain.
   */
  bool loadPeerCertificateChain();

  /**
   * Returns the peer certificate chain.
   */
  inline PCCERT_CHAIN_CONTEXT getPeerCertificateChain() const
  {
    return this->pPeerCertificateChain;
  }

  /**
   * Returns the local certificate.
   */
  inline PCCERT_CONTEXT  getLocalCertificate() const
  {
    return this->pLocalCertificate;
  }

  /**
   * Returns the local certificate chain.
   */
  PCCERT_CHAIN_CONTEXT getLocalCertificateChain();

  /**
   * Closes the outbound side.
   */
  void closeOutbound();

  /**
   * Closes the inbound side.
   */
  bool closeInbound();

  /** @return the outbound-closed flag. */
  inline bool isOutboundDone() const
  {
    return this->outboundDone;
  }

  /** @return the inbound-closed flag. */
  inline bool isInboundDone() const
  {
    return this->inboundDone;
  }

  /**
   * Sets the protocols that may be used.
   */
  inline void setEnabledProtocols(DWORD protocols)
  {
    this->enabledProtocols = protocols;
  }

  /** @return the enabled protocols value. */
  inline DWORD getEnabledProtocols() const
  {
    return this->enabledProtocols;
  }

  /** @return the version of the protocol in use. */
  inline unsigned short getProtocolVersion() const
  {
    return this->protocolVersion;
  }

  /** @return the code of the cipher suite in use. */
  inline unsigned short getCipherSuiteCode() const
  {
    return this->cipherSuiteCode;
  }

  /** @return the length of the session ID. */
  inline unsigned char getSessionIDLength() const
  {
    return this->sessionIDLength;
  }

  /** @return pointer to the session ID bytes. */
  inline unsigned char* getSessionID() const
  {
    return this->sessionID;
  }

  /**
   * Sets the local certificate.
   */
  void setLocalCertificate(PCCERT_CONTEXT localCert);

  /**
   * Requests (true) or disables (false) client authentication.
   * Overrides any level previously set, including one set by setNeedClientAuth().
   */
  inline void setWantClientAuth(bool auth)
  {
    this->clientAuthLevel = auth ? 1 : 0;
  }

  /** @return true only when the level is exactly "requested" (1). */
  inline bool getWantClientAuth() const
  {
    return this->clientAuthLevel == 1;
  }

  /**
   * Requires (true) or disables (false) client authentication.
   * Overrides any level previously set, including one set by setWantClientAuth().
   */
  inline void setNeedClientAuth(bool auth)
  {
    this->clientAuthLevel = auth ? 2 : 0;
  }

  /** @return true only when the level is exactly "required" (2). */
  inline bool getNeedClientAuth() const
  {
    return this->clientAuthLevel == 2;
  }

  /**
   * Returns the packet (network) buffer size:
   * header + corrected maximum message + trailer.
   * Yields 0 while streamSizes is all-zero (the original comment says 0 is
   * returned when the sizes are unknown; this assumes the constructor
   * zero-initializes streamSizes - not visible here).
   */
  inline unsigned long getPacketBufferSize() const
  {
    // Shares the cbMaximumMessage correction with getApplicationBufferSize().
    return this->streamSizes.cbHeader + this->getApplicationBufferSize() + this->streamSizes.cbTrailer;
  }

  /**
   * Returns the application data buffer size.
   */
  inline unsigned long getApplicationBufferSize() const
  {
    unsigned long message = this->streamSizes.cbMaximumMessage;
    if (0 < message && message < 16384)
    {
      // http://support.microsoft.com/?scid=kb%3Ben-us%3B300562&x=13&y=13
      // According to the article above, some platforms report cbMaximumMessage
      // as 16379, i.e. 5 bytes less than the 16384 of the SSL specification.
      // Any non-zero value below 16384 is therefore rounded up to 16384.
      message = 16384;
    }
    return message;
  }
};