/* WCESSLSocket.java -- 
   Copyright (C) 2009 Mysaifu.com

This file is a part of GNU Classpath.

Mysaifu JVM is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or (at
your option) any later version.

Mysaifu JVM is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
General Public License for more details.

You should have received a copy of the GNU General Public License
along with Mysaifu JVM; if not, write to the Free Software
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
USA

Linking this library statically or dynamically with other modules is
making a combined work based on this library.  Thus, the terms and
conditions of the GNU General Public License cover the whole
combination.

As a special exception, the copyright holders of this library give you
permission to link this library with independent modules to produce an
executable, regardless of the license terms of these independent
modules, and to copy and distribute the resulting executable under
terms of your choice, provided that you also meet, for each linked
independent module, the terms and conditions of the license of that
module.  An independent module is a module which is not derived from
or based on this library.  If you modify this library, you may extend
this exception to your version of the library, but you are not
obligated to do so.  If you do not wish to do so, delete this
exception statement from your version. */


package com.mysaifu.jvm.java.security.provider;

import java.io.DataInputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.HashSet;
import java.util.Set;

import javax.net.ssl.HandshakeCompletedEvent;
import javax.net.ssl.HandshakeCompletedListener;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;

/**
 * SSL socket implementation for Windows CE device.
 */
public class WCESSLSocket
    extends SSLSocket
{

  /**
   * Length of a TLS record header: content type (1 byte), protocol version
   * (2 bytes) and fragment length (2 bytes, big-endian).
   */
  private static final int RECORD_HEADER_LENGTH = 5;

  /**
   * OutputStream that encrypts application data with the engine and writes
   * the resulting records to the underlying socket.
   */
  private class SocketOutputStream
      extends OutputStream
  {
    /** Receives the wrapped (encrypted) record before it is sent. */
    private ByteBuffer buffer;

    /** Raw output stream of the underlying socket. */
    private final OutputStream out;

    SocketOutputStream() throws IOException
    {
      buffer = ByteBuffer.wrap(new byte[getSession().getPacketBufferSize()]);
      out = underlyingSocket.getOutputStream();
    }

    /**
     * Encrypts and sends {@code len} bytes of {@code buf} starting at
     * {@code off}, running the handshake first if it is still pending.
     * The data is wrapped in chunks no larger than the session's application
     * buffer size. If the engine reports {@code CLOSED}, the remaining data
     * is dropped and the method returns normally.
     *
     * @throws SSLException if the engine reports a status other than
     *         {@code OK} or {@code CLOSED}, or the handshake failed
     */
    @Override
    public void write(byte[] buf, int off, int len) throws IOException
    {
      handshakeIfNeeded();

      int k = 0;
      while (k < len)
        {
          synchronized (engine)
            {
              SSLSession session = getSession();
              // The session sizes may change once the handshake has
              // negotiated the real session; keep the buffer big enough.
              int packetSize = session.getPacketBufferSize();
              if (buffer.capacity() < packetSize)
                buffer = ByteBuffer.wrap(new byte[packetSize]);

              int l = Math.min(len - k, session.getApplicationBufferSize());
              ByteBuffer in = ByteBuffer.wrap(buf, off + k, l);
              // Clear before wrapping so an earlier aborted write cannot
              // leave stale bytes in front of this record.
              buffer.clear();
              SSLEngineResult result = engine.wrap(in, buffer);
              if (result.getStatus() == Status.CLOSED)
                return;
              if (result.getStatus() != Status.OK)
                throw new SSLException("unexpected SSL state "
                                       + result.getStatus());
              buffer.flip();
              out.write(buffer.array(), 0, buffer.limit());
              k += result.bytesConsumed();
            }
        }
    }

    @Override
    public void write(int b) throws IOException
    {
      write(new byte[] { (byte) b });
    }

    /**
     * Closes the enclosing socket unless the engine's outbound side is
     * already done.
     */
    @Override
    public void close() throws IOException
    {
      if (! engine.isOutboundDone())
        {
          WCESSLSocket.this.close();
        }
    }
  }

  /**
   * InputStream that reads whole TLS records from the underlying socket and
   * returns the decrypted application data.
   */
  private class SocketInputStream
      extends InputStream
  {
    /** Holds one raw TLS record read from the socket. */
    private ByteBuffer inBuffer;

    /** Decrypted application data, kept in "read" (flipped) mode. */
    private ByteBuffer appBuffer;

    /** Raw input stream of the underlying socket. */
    private final DataInputStream in;

    /** Set once end of stream has been reported; later reads return -1. */
    private boolean eof;

    SocketInputStream() throws IOException
    {
      inBuffer = ByteBuffer.wrap(new byte[getSession().getPacketBufferSize()]);
      inBuffer.limit(0);
      appBuffer = ByteBuffer.allocate(getSession().getApplicationBufferSize());
      appBuffer.flip();
      in = new DataInputStream(underlyingSocket.getInputStream());
    }

    /**
     * Reads up to {@code len} bytes of decrypted application data into
     * {@code buf}. The method blocks until at least one byte is available,
     * and returns 0 only when {@code len} is 0, as {@link InputStream}
     * requires.
     *
     * @return the number of bytes read, or -1 when the underlying stream
     *         ended or the peer closed the TLS connection
     * @throws SSLException if the data is not TLS, a record is too large,
     *         the engine reports an unexpected status, or the handshake failed
     */
    @Override
    public int read(byte[] buf, int off, int len) throws IOException
    {
      if (len == 0)
        return 0;
      if (eof)
        return - 1;

      handshakeIfNeeded();

      while (! appBuffer.hasRemaining())
        {
          if (! fill())
            {
              eof = true;
              return - 1;
            }
          // A record can decrypt to zero application bytes (empty fragment,
          // alert or handshake message); keep reading instead of returning 0,
          // and let the engine finish any handshake it now requests.
          if (! appBuffer.hasRemaining())
            handshakeIfNeeded();
        }
      int l = Math.min(len, appBuffer.remaining());
      appBuffer.get(buf, off, l);
      return l;
    }

    /**
     * Reads one TLS record from the socket and unwraps it into
     * {@code appBuffer}, which is left in read mode.
     *
     * @return false at end of stream, or when the engine reports CLOSED
     *         without producing any data
     */
    private boolean fill() throws IOException
    {
      SSLSession session = getSession();
      // Grow the buffers if the negotiated session needs more room.
      // appBuffer is empty here, so replacing it loses nothing.
      int packetSize = session.getPacketBufferSize();
      if (inBuffer.capacity() < packetSize)
        inBuffer = ByteBuffer.wrap(new byte[packetSize]);
      int appSize = session.getApplicationBufferSize();
      if (appBuffer.capacity() < appSize)
        {
          appBuffer = ByteBuffer.allocate(appSize);
          appBuffer.flip();
        }

      // Read the first byte (the record's content type) from the
      // underlying socket.
      int type = in.read();
      if (type == - 1)
        return false;

      // Fail fast if the peer seems to be speaking something other than
      // SSL/TLS.
      checkSSLType(type);
      readRecord(in, inBuffer, type);

      SSLEngineResult result;
      synchronized (engine)
        {
          appBuffer.clear();
          try
            {
              result = engine.unwrap(inBuffer, appBuffer);
            }
          finally
            {
              // Always return to read mode so a failed unwrap cannot expose
              // uninitialised buffer contents to the next read.
              appBuffer.flip();
            }
        }
      if (result.getStatus() == Status.CLOSED && result.bytesProduced() == 0)
        return false;
      checkResult(result);
      return true;
    }

    @Override
    public int read() throws IOException
    {
      byte[] one = new byte[1];
      if (read(one, 0, 1) == - 1)
        return - 1;
      return one[0] & 0xFF;
    }

    /**
     * Closes the enclosing socket unless the engine's inbound side is
     * already done.
     */
    @Override
    public void close() throws IOException
    {
      if (! engine.isInboundDone())
        {
          WCESSLSocket.this.close();
        }
    }
  }

  /**
   * SSLEngine that performs all TLS processing for this socket.
   */
  private final WCESSLEngine engine;

  /**
   * Registered HandshakeCompletedListeners.
   */
  private final Set<HandshakeCompletedListener> listeners;

  /**
   * Underlying socket object.
   */
  Socket underlyingSocket;

  /**
   * Lifecycle states of this socket.
   */
  enum SocketStatus
  {
    CREATION, INITIAL_HANDSHAKE, APPLICATION_DATA, REHANDSHAKE, CLOSING, CLOSED
  }

  /**
   * True while a thread is running the handshake loop. Written only while
   * holding the engine's monitor; volatile because startHandshake() reads it
   * without the lock.
   */
  private volatile boolean isHandshaking;

  /**
   * Exception raised during the handshake. It is kept in a field because
   * the handshake may run on a separate thread (see startHandshake()).
   */
  private volatile IOException handshakeException;

  /**
   * Current state of this socket.
   */
  private volatile SocketStatus socketStatus = SocketStatus.CREATION;

  /**
   * Whether closing this socket also closes the underlying socket.
   */
  private final boolean autoClose;

  /**
   * Set by the first call to close(); guarded by {@code this}.
   */
  private boolean closed;

  public WCESSLSocket(WCESSLContextSpi contextSpi, String host, int port)
  {
    this(contextSpi, host, port, new Socket(), true);
  }

  /**
   * Creates a client-mode SSL socket layered over {@code underlyingSocket}.
   *
   * @param autoClose whether {@link #close()} also closes the underlying
   *        socket
   * @throws NullPointerException if {@code underlyingSocket} is null
   */
  public WCESSLSocket(WCESSLContextSpi contextSpi, String host, int port,
                      Socket underlyingSocket, boolean autoClose)
  {
    // Validate before creating the engine so that nothing is allocated for
    // an unusable socket.
    if (underlyingSocket == null)
      {
        throw new NullPointerException("underlyingSocket==null");
      }
    this.underlyingSocket = underlyingSocket;
    this.autoClose = autoClose;
    listeners = new HashSet<HandshakeCompletedListener>();
    engine = new WCESSLEngine(contextSpi, this, host, port);
    engine.setUseClientMode(true); // default to client mode
  }

  // Listener registration and configuration; the configuration methods
  // delegate to the engine.

  @Override
  public void addHandshakeCompletedListener(HandshakeCompletedListener listener)
  {
    listeners.add(listener);
  }

  @Override
  public boolean getEnableSessionCreation()
  {
    return engine.getEnableSessionCreation();
  }

  @Override
  public String[] getEnabledCipherSuites()
  {
    return engine.getEnabledCipherSuites();
  }

  @Override
  public String[] getEnabledProtocols()
  {
    return engine.getEnabledProtocols();
  }

  @Override
  public boolean getNeedClientAuth()
  {
    return engine.getNeedClientAuth();
  }

  @Override
  public SSLSession getSession()
  {
    return engine.getSession();
  }

  @Override
  public String[] getSupportedCipherSuites()
  {
    return engine.getSupportedCipherSuites();
  }

  @Override
  public String[] getSupportedProtocols()
  {
    return engine.getSupportedProtocols();
  }

  @Override
  public boolean getUseClientMode()
  {
    return engine.getUseClientMode();
  }

  @Override
  public boolean getWantClientAuth()
  {
    return engine.getWantClientAuth();
  }

  @Override
  public void removeHandshakeCompletedListener(
                                               HandshakeCompletedListener listener)
  {
    listeners.remove(listener);
  }

  @Override
  public void setEnableSessionCreation(boolean enable)
  {
    engine.setEnableSessionCreation(enable);
  }

  @Override
  public void setEnabledCipherSuites(String[] suites)
  {
    engine.setEnabledCipherSuites(suites);
  }

  @Override
  public void setEnabledProtocols(String[] protocols)
  {
    engine.setEnabledProtocols(protocols);
  }

  @Override
  public void setNeedClientAuth(boolean needAuth)
  {
    engine.setNeedClientAuth(needAuth);
  }

  @Override
  public void setUseClientMode(boolean clientMode)
  {
    engine.setUseClientMode(clientMode);
  }

  @Override
  public void setWantClientAuth(boolean wantAuth)
  {
    engine.setWantClientAuth(wantAuth);
  }

  /**
   * Runs the handshake on a separate thread and waits for it to finish.
   * Returns immediately if a handshake is already in progress.
   *
   * @throws IOException the failure recorded by this or an earlier handshake
   * @throws InterruptedIOException if the calling thread is interrupted
   *         while waiting (the handshake keeps running in the background)
   */
  @Override
  public void startHandshake() throws IOException
  {
    if (isHandshaking)
      return;

    IOException pending = handshakeException;
    if (pending != null)
      throw pending;

    // NOTE(review): the handshake runs on its own thread, presumably for a
    // larger stack on the target device; the failure is passed back
    // through handshakeException.
    Thread t = new Thread(new Runnable()
    {
      public void run()
      {
        try
          {
            doHandshake();
          }
        catch (IOException ioe)
          {
            handshakeException = ioe;
          }
      }
    }, "HandshakeThread@" + System.identityHashCode(this));
    t.start();

    // Wait for the handshake thread to finish.
    try
      {
        t.join();
      }
    catch (InterruptedException ie)
      {
        Thread.currentThread().interrupt();
        throw new InterruptedIOException("interrupted while waiting for handshake");
      }

    // Report a failure instead of returning as if the handshake succeeded.
    pending = handshakeException;
    if (pending != null)
      throw pending;
  }

  /**
   * Runs the handshake if it has not started yet or the engine still
   * reports handshake work, then rethrows any recorded handshake failure.
   */
  private void handshakeIfNeeded() throws IOException
  {
    if (socketStatus == SocketStatus.CREATION
        || engine.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING)
      {
        doHandshake();
        IOException pending = handshakeException;
        if (pending != null)
          throw pending;
      }
  }

  /**
   * Performs the handshake. If another thread is already handshaking, this
   * waits for it to finish instead of reading the socket concurrently.
   * SSLExceptions are recorded in handshakeException so that later calls
   * fail the same way.
   *
   * @throws InterruptedIOException if interrupted while waiting for another
   *         thread's handshake
   */
  void doHandshake() throws IOException
  {
    IOException pending = handshakeException;
    if (pending != null)
      throw pending;

    synchronized (engine)
      {
        if (isHandshaking)
          {
            // Loop to tolerate spurious wakeups.
            while (isHandshaking)
              {
                try
                  {
                    engine.wait();
                  }
                catch (InterruptedException ie)
                  {
                    Thread.currentThread().interrupt();
                    throw new InterruptedIOException(
                                                     "interrupted while waiting for handshake");
                  }
              }
            pending = handshakeException;
            if (pending != null)
              throw pending;
            return;
          }
        isHandshaking = true;
      }

    // Everything after the flag is set runs inside try/finally so that
    // isHandshaking is always reset and any waiters are always woken.
    try
      {
        performHandshake();
      }
    catch (SSLException ssle)
      {
        handshakeException = ssle;
        throw ssle;
      }
    finally
      {
        synchronized (engine)
          {
            isHandshaking = false;
            engine.notifyAll();
          }
      }
  }

  /**
   * Drives the engine's handshake state machine over the underlying socket
   * until the engine reports NOT_HANDSHAKING or FINISHED. It then updates
   * socketStatus and notifies listeners. The caller must own the
   * isHandshaking flag.
   */
  private void performHandshake() throws IOException
  {
    if (this.socketStatus == SocketStatus.REHANDSHAKE)
      {
        throw new SSLException("rehandshaking not supported");
      }

    // Start the handshake.
    if (this.socketStatus == SocketStatus.CREATION)
      {
        engine.beginHandshake();
        this.socketStatus = SocketStatus.INITIAL_HANDSHAKE;
      }

    HandshakeStatus status = engine.getHandshakeStatus();
    assert (status != HandshakeStatus.NOT_HANDSHAKING);

    // Buffers used for the handshake.
    ByteBuffer inBuffer = ByteBuffer.wrap(new byte[getSession().getPacketBufferSize()]);
    ByteBuffer outBuffer = ByteBuffer.wrap(new byte[getSession().getPacketBufferSize()]);
    ByteBuffer emptyBuffer = ByteBuffer.allocate(0);

    DataInputStream sockIn = new DataInputStream(
                                                 underlyingSocket.getInputStream());
    OutputStream sockOut = underlyingSocket.getOutputStream();

    while (status != HandshakeStatus.NOT_HANDSHAKING
           && status != HandshakeStatus.FINISHED)
      {
        // The packet size can change while the session is negotiated.
        // inBuffer is always refilled from scratch before use, so replacing
        // it loses nothing.
        int packetSize = getSession().getPacketBufferSize();
        if (inBuffer.capacity() != packetSize)
          inBuffer = ByteBuffer.wrap(new byte[packetSize]);
        if (outBuffer.capacity() != packetSize)
          outBuffer = ByteBuffer.wrap(new byte[packetSize]);

        switch (status)
          {
          case NEED_UNWRAP:
          {
            if (this.socketStatus == SocketStatus.CLOSING && this.autoClose)
              {
                // The peer may already have disconnected, and the underlying
                // socket is about to be closed anyway, so do not wait for
                // its close_notify.
                status = HandshakeStatus.NOT_HANDSHAKING;
                break;
              }

            readHandshakeRecord(sockIn, inBuffer);
            SSLEngineResult result = engine.unwrap(inBuffer, emptyBuffer);
            status = result.getHandshakeStatus();
            checkResult(result);
          }
            break;

          case NEED_WRAP:
          {
            outBuffer.clear();
            SSLEngineResult result = engine.wrap(emptyBuffer, outBuffer);
            status = result.getHandshakeStatus();
            checkResult(result);
            outBuffer.flip();
            sockOut.write(outBuffer.array(), outBuffer.position(),
                          outBuffer.remaining());
          }
            break;

          case NEED_TASK:
          {
            Runnable task;
            while ((task = engine.getDelegatedTask()) != null)
              {
                task.run();
              }
            status = engine.getHandshakeStatus();
          }
            break;

          case FINISHED:
          case NOT_HANDSHAKING:
            break;
          }
      }

    // Update the state. NOTE(review): any state other than
    // INITIAL_HANDSHAKE (CLOSING, but also a handshake driven from
    // APPLICATION_DATA) ends in CLOSED here; confirm that this is intended.
    this.socketStatus = (this.socketStatus == SocketStatus.INITIAL_HANDSHAKE) ? SocketStatus.APPLICATION_DATA
                                                                             : SocketStatus.CLOSED;

    if (this.socketStatus != SocketStatus.CLOSED)
      {
        notifyHandshakeListeners();
      }
  }

  /**
   * Delivers a HandshakeCompletedEvent to every registered listener. A
   * listener that throws does not affect the others or the connection.
   */
  private void notifyHandshakeListeners()
  {
    HandshakeCompletedEvent hce = new HandshakeCompletedEvent(this,
                                                              getSession());
    // Iterate over a snapshot so a listener can (un)register listeners,
    // including itself, from its callback without a
    // ConcurrentModificationException.
    HandshakeCompletedListener[] snapshot = listeners.toArray(new HandshakeCompletedListener[0]);
    for (HandshakeCompletedListener l : snapshot)
      {
        try
          {
            l.handshakeCompleted(hce);
          }
        catch (ThreadDeath td)
          {
            throw td;
          }
        catch (Throwable ignored)
          {
            // Deliberately ignored: a misbehaving listener must not break
            // the connection.
          }
      }
  }

  /**
   * Reads one record for the handshake into {@code dst}. This can be either
   * a TLS record or an SSLv2-compatible ClientHello (a 2-byte header with
   * the high bit set and a 15-bit length).
   *
   * @throws EOFException if the stream ends before the record is complete
   * @throws SSLException if the data is not TLS or the record does not fit
   */
  private static void readHandshakeRecord(DataInputStream in, ByteBuffer dst)
      throws IOException
  {
    int first = in.read();
    if (first == - 1)
      {
        throw new EOFException();
      }
    if ((first & 0x80) == 0x80) // SSLv2 client hello.
      {
        // readUnsignedByte throws EOFException rather than returning -1.
        int second = in.readUnsignedByte();
        int v2len = ((first & 0x7F) << 8) | second;
        byte[] a = dst.array();
        if (v2len > a.length - 2)
          {
            throw new SSLException("SSLv2 record too large: " + v2len
                                   + " bytes");
          }
        a[0] = (byte) first;
        a[1] = (byte) second;
        in.readFully(a, 2, v2len);
        dst.clear();
        dst.limit(v2len + 2);
      }
    else
      {
        checkSSLType(first);
        readRecord(in, dst, first);
      }
  }

  /**
   * Reads the rest of a TLS record whose content-type byte ({@code type})
   * has already been consumed from {@code in}. On return, {@code dst} holds
   * the whole record, header included, from position 0 to its limit.
   * {@code dst} must be array-backed with an array offset of 0.
   *
   * @throws EOFException if the stream ends before the record is complete
   * @throws SSLException if the record length exceeds {@code dst}'s capacity
   */
  private static void readRecord(DataInputStream in, ByteBuffer dst, int type)
      throws IOException
  {
    byte[] a = dst.array();
    a[0] = (byte) type;
    in.readFully(a, 1, RECORD_HEADER_LENGTH - 1);
    int reclen = ((a[3] & 0xFF) << 8) | (a[4] & 0xFF);
    if (reclen > a.length - RECORD_HEADER_LENGTH)
      {
        throw new SSLException("SSL record too large: " + reclen + " bytes");
      }
    in.readFully(a, RECORD_HEADER_LENGTH, reclen);
    dst.clear();
    dst.limit(RECORD_HEADER_LENGTH + reclen);
  }

  /**
   * Rejects engine results whose status is neither OK nor CLOSED.
   */
  private static void checkResult(SSLEngineResult result) throws SSLException
  {
    Status st = result.getStatus();
    if (st != Status.OK && st != Status.CLOSED)
      {
        throw new SSLException("unexpected SSL status " + st);
      }
  }

  // Methods overriding Socket.

  @Override
  public void bind(SocketAddress bindpoint) throws IOException
  {
    underlyingSocket.bind(bindpoint);
  }

  @Override
  public void connect(SocketAddress endpoint) throws IOException
  {
    underlyingSocket.connect(endpoint);
  }

  @Override
  public void connect(SocketAddress endpoint, int timeout) throws IOException
  {
    underlyingSocket.connect(endpoint, timeout);
  }

  @Override
  public InetAddress getInetAddress()
  {
    return underlyingSocket.getInetAddress();
  }

  @Override
  public InetAddress getLocalAddress()
  {
    return underlyingSocket.getLocalAddress();
  }

  @Override
  public int getPort()
  {
    return underlyingSocket.getPort();
  }

  @Override
  public int getLocalPort()
  {
    return underlyingSocket.getLocalPort();
  }

  @Override
  public SocketAddress getRemoteSocketAddress()
  {
    return underlyingSocket.getRemoteSocketAddress();
  }

  @Override
  public SocketAddress getLocalSocketAddress()
  {
    return underlyingSocket.getLocalSocketAddress();
  }

  @Override
  public SocketChannel getChannel()
  {
    throw new UnsupportedOperationException(
                                            "use javax.net.ssl.SSLEngine for NIO");
  }

  /**
   * Returns a new stream that decrypts data from the underlying socket.
   * Each call creates an independent stream with its own buffers.
   */
  @Override
  public InputStream getInputStream() throws IOException
  {
    return new SocketInputStream();
  }

  /**
   * Returns a new stream that encrypts data written to it before sending it
   * on the underlying socket.
   */
  @Override
  public OutputStream getOutputStream() throws IOException
  {
    return new SocketOutputStream();
  }

  @Override
  public void setTcpNoDelay(boolean on) throws SocketException
  {
    underlyingSocket.setTcpNoDelay(on);
  }

  @Override
  public boolean getTcpNoDelay() throws SocketException
  {
    return underlyingSocket.getTcpNoDelay();
  }

  @Override
  public void setSoLinger(boolean on, int linger) throws SocketException
  {
    underlyingSocket.setSoLinger(on, linger);
  }

  @Override
  public int getSoLinger() throws SocketException
  {
    return underlyingSocket.getSoLinger();
  }

  @Override
  public void sendUrgentData(int x) throws IOException
  {
    throw new UnsupportedOperationException("not supported");
  }

  @Override
  public void setOOBInline(boolean on) throws SocketException
  {
    underlyingSocket.setOOBInline(on);
  }

  @Override
  public boolean getOOBInline() throws SocketException
  {
    return underlyingSocket.getOOBInline();
  }

  @Override
  public void setSoTimeout(int timeout) throws SocketException
  {
    underlyingSocket.setSoTimeout(timeout);
  }

  @Override
  public int getSoTimeout() throws SocketException
  {
    return underlyingSocket.getSoTimeout();
  }

  @Override
  public void setSendBufferSize(int size) throws SocketException
  {
    underlyingSocket.setSendBufferSize(size);
  }

  @Override
  public int getSendBufferSize() throws SocketException
  {
    return underlyingSocket.getSendBufferSize();
  }

  @Override
  public void setReceiveBufferSize(int size) throws SocketException
  {
    underlyingSocket.setReceiveBufferSize(size);
  }

  @Override
  public int getReceiveBufferSize() throws SocketException
  {
    return underlyingSocket.getReceiveBufferSize();
  }

  @Override
  public void setKeepAlive(boolean on) throws SocketException
  {
    underlyingSocket.setKeepAlive(on);
  }

  @Override
  public boolean getKeepAlive() throws SocketException
  {
    return underlyingSocket.getKeepAlive();
  }

  @Override
  public void setTrafficClass(int tc) throws SocketException
  {
    underlyingSocket.setTrafficClass(tc);
  }

  @Override
  public int getTrafficClass() throws SocketException
  {
    return underlyingSocket.getTrafficClass();
  }

  @Override
  public void setReuseAddress(boolean reuseAddress) throws SocketException
  {
    underlyingSocket.setReuseAddress(reuseAddress);
  }

  @Override
  public boolean getReuseAddress() throws SocketException
  {
    return underlyingSocket.getReuseAddress();
  }

  /**
   * Sends close_notify, closes the underlying socket when
   * {@code autoClose} is set, and releases the engine's native resources.
   * Only the first call does anything; later calls return immediately.
   *
   * @throws IOException if the closing handshake fails and {@code autoClose}
   *         is false (the engine is still released)
   */
  @Override
  public void close() throws IOException
  {
    synchronized (this)
      {
        if (closed)
          return;
        closed = true;
      }

    try
      {
        // Close the outbound side first so that close_notify is sent.
        engine.closeOutbound();

        // Run the handshake loop again to flush close_notify.
        this.socketStatus = SocketStatus.CLOSING;

        try
          {
            doHandshake();
          }
        catch (IOException e)
          {
            // The peer may end the connection without sending close_notify.
            // If the connection is being closed anyway, that is harmless,
            // so only report the error when the underlying socket stays open.
            if (! autoClose)
              {
                throw e;
              }
          }

        if (autoClose)
          {
            underlyingSocket.close();
          }
      }
    finally
      {
        // Release native resources even if closing failed.
        engine.close();
      }
  }

  @Override
  public void shutdownInput() throws IOException
  {
    underlyingSocket.shutdownInput();
  }

  @Override
  public void shutdownOutput() throws IOException
  {
    underlyingSocket.shutdownOutput();
  }

  @Override
  public boolean isConnected()
  {
    return underlyingSocket.isConnected();
  }

  @Override
  public boolean isBound()
  {
    return underlyingSocket.isBound();
  }

  @Override
  public boolean isClosed()
  {
    return underlyingSocket.isClosed();
  }

  @Override
  public boolean isInputShutdown()
  {
    return underlyingSocket.isInputShutdown();
  }

  @Override
  public boolean isOutputShutdown()
  {
    return underlyingSocket.isOutputShutdown();
  }

  /**
   * Checks that {@code b} is a valid SSL/TLS record content type.
   *
   * @throws SSLException if it is not, which suggests a plaintext peer
   */
  private static void checkSSLType(int b) throws SSLException
  {
    // change_cipher_spec 0x14
    // alert              0x15
    // handshake          0x16
    // application_data   0x17
    if (! (b >= 0x14 && b <= 0x17))
      {
        throw new SSLException(
                               "Unrecognized SSL message, plaintext connection?");
      }
  }
}
