package org.maachang.crypto ;

import java.math.BigInteger;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.Map;

/**
 * RSA key material used for PKCS#1 v1.5 encryption/decryption and signing/verification.
 *
 * @version 2012/08/09
 * @author  masahito suzuki
 * @since   Crypto 1.00
 */
public class CryptoRsaKey {
    /** Algorithm name of the per-thread SecureRandom. */
    private static final String RANDOM_TYPE = "SHA1PRNG" ;
    /** Per-thread SecureRandom cache, populated lazily by {@link #getRand()}. */
    private static final ThreadLocal<SecureRandom> RAND = new ThreadLocal<SecureRandom>() ;
    
    /**
     * Returns the SecureRandom bound to the calling thread, creating it on first use.
     * @return SecureRandom the per-thread random generator.
     * @exception Exception if the SHA1PRNG algorithm is not available.
     */
    public static final SecureRandom getRand()
        throws Exception {
        SecureRandom ret = RAND.get() ;
        if( ret == null ) {
            ret = SecureRandom.getInstance( RANDOM_TYPE ) ;
            RAND.set( ret ) ;
        }
        return ret ;
    }
    
    /** RSA modulus (p * q when produced by {@link #generate}). */
    protected BigInteger n = null ;
    /** Public exponent. */
    protected int e = 0 ;
    /** Private exponent. */
    protected BigInteger d = null ;
    /** Larger prime factor (CRT parameter). */
    protected BigInteger p = null ;
    /** Smaller prime factor (CRT parameter). */
    protected BigInteger q = null ;
    /** d mod (p - 1) (CRT parameter). */
    protected BigInteger dmp1 = null ;
    /** d mod (q - 1) (CRT parameter). */
    protected BigInteger dmq1 = null ;
    /** q^-1 mod p (CRT parameter). */
    protected BigInteger coeff = null ;
    
    /**
     * Constructor. Creates an empty key; populate it with one of the setters,
     * {@link #generate}, or {@link #decodeJSON}.
     */
    public CryptoRsaKey() {
        n = null ;
        e = 0 ;
        d = null ;
        p = null ;
        q = null ;
        dmp1 = null ;
        dmq1 = null ;
        coeff = null ;
    }
    
    /**
     * Returns the number of bytes {@link #pkcs1pad2} emits for s.
     * Each UTF-16 char is encoded independently (1, 2 or 3 bytes), so a surrogate
     * pair takes 6 bytes; this mirrors the decoder in {@link #pkcs1unpad2}.
     */
    private static int encodedLength( String s ) {
        int len = 0 ;
        for( int i = 0 ; i < s.length() ; i ++ ) {
            char c = s.charAt( i ) ;
            if( c < 128 ) {
                len += 1 ;
            }
            else if( c < 2048 ) {
                len += 2 ;
            }
            else {
                len += 3 ;
            }
        }
        return len ;
    }
    
    /**
     * Builds a PKCS#1 v1.5 type 2 (encryption) block of n bytes:
     * {@code 00 02 <non-zero random bytes> 00 <encoded message>}.
     * @param s message to encode (each UTF-16 char becomes 1-3 bytes).
     * @param n block length in bytes (size of the modulus).
     * @return BigInteger the padded block as a non-negative integer.
     * @exception Exception CryptoRsaException if the encoded message does not fit
     *            with at least 8 bytes of random padding.
     */
    protected static final BigInteger pkcs1pad2( String s,int n )
        throws Exception {
        // Check the ENCODED byte length, not the char count: multi-byte chars would
        // otherwise overflow the buffer or leave fewer than 8 random padding bytes.
        int byteLen = encodedLength( s ) ;
        if( n < byteLen + 11 ) {
            throw new CryptoRsaException( "Message too long for RSA (n=" + n + ", l=" + byteLen + ")" ) ;
        }
        byte[] ba = new byte[ n ] ;
        // The message is written right-aligned, from its last char backwards.
        for( int i = s.length() - 1 ; i >= 0 ; i -- ) {
            int c = s.charAt( i ) ;
            if( c < 128 ) {
                ba[--n] = (byte)c ;
            }
            else if( c < 2048 ) {
                ba[--n] = (byte)( (c & 63) | 128 ) ;
                ba[--n] = (byte)( (c >> 6) | 192 ) ;
            }
            else {
                ba[--n] = (byte)( (c & 63) | 128 ) ;
                ba[--n] = (byte)( ((c >> 6) & 63) | 128 ) ;
                ba[--n] = (byte)( (c >> 12) | 224 ) ;
            }
        }
        ba[--n] = 0 ;
        SecureRandom rng = getRand() ;
        byte[] x = new byte[ 1 ] ;
        // Padding bytes must be non-zero so the 00 separator is unambiguous.
        while( n > 2 ) {
            do {
                rng.nextBytes( x ) ;
            } while( x[0] == 0 ) ;
            ba[--n] = x[0] ;
        }
        ba[--n] = 2 ;
        ba[--n] = 0 ;
        return new BigInteger( 1,ba ) ;
    }
    
    /**
     * Removes PKCS#1 v1.5 type 2 padding and decodes the message.
     * @param d decrypted block as an integer.
     * @param n block length in bytes (size of the modulus).
     * @return String the decoded message, or null if the block is malformed.
     * @exception Exception never thrown by this implementation; kept for compatibility.
     */
    protected static final String pkcs1unpad2( BigInteger d,int n )
        throws Exception {
        byte[] b = d.toByteArray() ;
        int bLen = b.length ;
        int i = 0 ;
        // Skip the sign byte / leading zero that toByteArray() may produce.
        while( i < bLen && b[ i ] == 0 ) {
            i ++ ;
        }
        if( i >= bLen || bLen - i != n - 1 || b[ i ] != 2 ) {
            return null ;
        }
        // Skip the random padding up to the 00 separator.
        i ++ ;
        while( i < bLen && b[ i ] != 0 ) {
            i ++ ;
        }
        if( i >= bLen ) {
            return null ;
        }
        StringBuilder buf = new StringBuilder( bLen ) ;
        int c ;
        while( ( i = i + 1 ) < bLen ) {
            c = b[ i ] & 0x000000ff ;
            if( c < 128 ) {
                buf.append( (char)c ) ;
            }
            else if( c > 191 && c < 224 ) {
                if( i + 1 >= bLen ) {
                    return null ;
                }
                buf.append( (char)( ((c & 31) << 6) | (b[i + 1] & 63) ) ) ;
                i = i + 1 ;
            }
            else {
                if( i + 2 >= bLen ) {
                    return null ;
                }
                buf.append( (char)( ((c & 15) << 12) | ((b[i + 1] & 63) << 6) | (b[i + 2] & 63) ) ) ;
                i = i + 2 ;
            }
        }
        return buf.toString() ;
    }
    
    /**
     * Sets the public key.
     * @param n modulus as unsigned big-endian bytes.
     * @param e public exponent in hexadecimal.
     * @exception Exception if an argument is null/empty or cannot be parsed.
     */
    public void setPublic( byte[] n,String e )
        throws Exception {
        if( n == null || n.length <= 0 || e == null || e.length() <= 0 ) {
            throw new IllegalArgumentException( "引数は不正です" ) ;
        }
        this.n = new BigInteger( 1,n ) ;
        this.e = Integer.parseInt( e,16 ) ;
    }
    
    /**
     * Sets the public key.
     * @param n modulus in hexadecimal.
     * @param e public exponent in hexadecimal.
     * @exception Exception if an argument is null/empty or cannot be parsed.
     */
    public void setPublic( String n,String e )
        throws Exception {
        if( n == null || n.length() <= 0 || e == null || e.length() <= 0 ) {
            throw new IllegalArgumentException( "引数は不正です" ) ;
        }
        this.n = new BigInteger( n,16 ) ;
        this.e = Integer.parseInt( e,16 ) ;
    }
    
    /**
     * Raw public-key operation: x^e mod n.
     * @param x input integer.
     * @return BigInteger the result.
     * @exception Exception never thrown by this implementation; kept for compatibility.
     */
    protected BigInteger doPublic( BigInteger x )
        throws Exception {
        return x.modPow( BigInteger.valueOf( this.e ),this.n ) ;
    }
    
    /**
     * Sets the private key without CRT parameters.
     * @param n modulus in hexadecimal.
     * @param e public exponent in hexadecimal.
     * @param d private exponent in hexadecimal.
     * @exception Exception if an argument is null/empty or cannot be parsed.
     */
    public void setPrivate( String n,String e,String d )
        throws Exception {
        if( n == null || n.length() <= 0 || e == null || e.length() <= 0 ||
            d == null || d.length() <= 0 ) {
            throw new IllegalArgumentException( "引数は不正です" ) ;
        }
        this.n = new BigInteger( n,16 ) ;
        this.e = Integer.parseInt( e,16 ) ;
        this.d = new BigInteger( d,16 ) ;
    }
    
    /**
     * Sets the private key including CRT parameters.
     * @param n modulus in hexadecimal.
     * @param e public exponent in hexadecimal.
     * @param d private exponent in hexadecimal.
     * @param p prime factor p in hexadecimal.
     * @param q prime factor q in hexadecimal.
     * @param dp d mod (p-1) in hexadecimal.
     * @param dq d mod (q-1) in hexadecimal.
     * @param c q^-1 mod p in hexadecimal.
     * @exception Exception if an argument is null/empty or cannot be parsed.
     */
    public void setPrivateEx( String n,String e,String d,String p,String q,String dp,String dq,String c )
        throws Exception {
        if( n == null || n.length() <= 0 || e == null || e.length() <= 0 ||
            d == null || d.length() <= 0 || p == null || p.length() <= 0 ||
            q == null || q.length() <= 0 || dp == null || dp.length() <= 0 ||
            dq == null || dq.length() <= 0 || c == null || c.length() <= 0 ) {
            throw new IllegalArgumentException( "引数は不正です" ) ;
        }
        this.n = new BigInteger( n,16 ) ;
        this.e = Integer.parseInt( e,16 ) ;
        this.d = new BigInteger( d,16 ) ;
        this.p = new BigInteger( p,16 ) ;
        this.q = new BigInteger( q,16 ) ;
        this.dmp1 = new BigInteger( dp,16 ) ;
        this.dmq1 = new BigInteger( dq,16 ) ;
        this.coeff = new BigInteger( c,16 ) ;
    }
    
    /**
     * Raw private-key operation: x^d mod n.
     * Uses the CRT (Garner's formula) when every CRT parameter is present,
     * otherwise falls back to a plain modPow with d.
     * @param x input integer.
     * @return BigInteger the result.
     * @exception Exception never thrown by this implementation; kept for compatibility.
     */
    protected BigInteger doPrivate( BigInteger x )
        throws Exception {
        // A key decoded from partial JSON may have p/q without dmp1/dmq1/coeff.
        if( this.p == null || this.q == null || this.dmp1 == null ||
            this.dmq1 == null || this.coeff == null ) {
            return x.modPow( this.d,this.n ) ;
        }
        BigInteger xp = x.mod( this.p ).modPow( this.dmp1,this.p ) ;
        BigInteger xq = x.mod( this.q ).modPow( this.dmq1,this.q ) ;
        // BigInteger.mod always returns a value in [0, p), so no pre-adjustment of xp is needed.
        BigInteger h = xp.subtract( xq ).multiply( this.coeff ).mod( this.p ) ;
        return h.multiply( this.q ).add( xq ) ;
    }
    
    /**
     * Draws a random probable prime of the given bit length whose (prime - 1)
     * is coprime to the public exponent.
     */
    private static BigInteger randomPrime( int bits,BigInteger ee,SecureRandom rng ) {
        while( true ) {
            BigInteger c = new BigInteger( bits,1,rng ) ;
            if( c.subtract( BigInteger.ONE ).gcd( ee ).equals( BigInteger.ONE ) && c.isProbablePrime( 10 ) ) {
                return c ;
            }
        }
    }
    
    /**
     * Generates a new key pair, including CRT parameters.
     * @param b target modulus bit length (p gets b - b/2 bits, q gets b/2 bits).
     * @param e public exponent in hexadecimal.
     * @exception Exception if e cannot be parsed or the random generator is unavailable.
     */
    protected void generate( int b,String e )
        throws Exception {
        SecureRandom rng = getRand() ;
        int qs = b >> 1 ;
        this.e = Integer.parseInt( e,16 ) ;
        BigInteger ee = new BigInteger( e,16 ) ;
        BigInteger t,p1,q1,phi ;
        while( true ) {
            this.p = randomPrime( b - qs,ee,rng ) ;
            this.q = randomPrime( qs,ee,rng ) ;
            // Keep p > q so coeff = q^-1 mod p matches doPrivate's CRT formula.
            if( this.p.compareTo( this.q ) <= 0 ) {
                t = this.p ;
                this.p = this.q ;
                this.q = t ;
            }
            // p == q would make n a perfect square and trivially factorable.
            if( this.p.equals( this.q ) ) {
                continue ;
            }
            p1 = this.p.subtract( BigInteger.ONE ) ;
            q1 = this.q.subtract( BigInteger.ONE ) ;
            phi = p1.multiply( q1 ) ;
            if( phi.gcd( ee ).equals( BigInteger.ONE ) ) {
                this.n = this.p.multiply( this.q ) ;
                this.d = ee.modInverse( phi ) ;
                this.dmp1 = this.d.mod( p1 ) ;
                this.dmq1 = this.d.mod( q1 ) ;
                this.coeff = this.q.modInverse( this.p ) ;
                break ;
            }
        }
    }
    
    /**
     * Encrypts a string with PKCS#1 v1.5 type 2 padding.
     * The result comes from BigInteger.toByteArray(), so it may carry a leading
     * 0x00 sign byte or be shorter than the modulus length.
     * @param target plain text.
     * @return byte[] cipher text.
     * @exception Exception if the message is too long for the key.
     */
    protected byte[] encrypt( String target )
        throws Exception {
        BigInteger m = CryptoRsaKey.pkcs1pad2( target,( this.n.bitLength() + 7 ) >> 3 ) ;
        BigInteger c = this.doPublic( m ) ;
        return c.toByteArray() ;
    }
    
    /**
     * Decrypts cipher text produced by {@link #encrypt}.
     * @param target cipher text (interpreted as an unsigned big-endian integer).
     * @return String the plain text, or null if the padding is malformed.
     * @exception Exception on key errors.
     */
    protected String decrypt( byte[] target )
        throws Exception {
        BigInteger c = new BigInteger( 1,target ) ;
        BigInteger m = this.doPrivate( c ) ;
        return CryptoRsaKey.pkcs1unpad2( m,( this.n.bitLength() + 7 ) >> 3 ) ;
    }
    
    
    /*
     * Signature (PKCS#1 v1.5, EMSA-PKCS1-v1_5).
     */
    
    /** DER DigestInfo prefix for SHA-1. */
    private static final byte[] _RSASIGN_DIHEAD_SHA1 = { (byte)0x30,(byte)0x21,(byte)0x30,(byte)0x09,
        (byte)0x06,(byte)0x05,(byte)0x2b,(byte)0x0e,(byte)0x03,(byte)0x02,(byte)0x1a,(byte)0x05,(byte)0x00,(byte)0x04,(byte)0x14 } ;
    /** DER DigestInfo prefix for SHA-256. */
    private static final byte[] _RSASIGN_DIHEAD_SHA256 = { (byte)0x30,(byte)0x31,(byte)0x30,(byte)0x0d,
        (byte)0x06,(byte)0x09,(byte)0x60,(byte)0x86,(byte)0x48,(byte)0x01,(byte)0x65,(byte)0x03,(byte)0x04,
        (byte)0x02,(byte)0x01,(byte)0x05,(byte)0x00,(byte)0x04,(byte)0x20 } ;
    /** Leading bytes of a signature block: 00 01. */
    private static final byte[] _S_HEAD = { (byte)0x00,(byte)0x01 } ;
    
    /**
     * Hashes b with the named algorithm.
     * @param hashAlg "sha1" or "sha256".
     * @param b data to hash.
     * @return byte[] the digest.
     * @exception Exception CryptoRsaException for an unsupported algorithm name.
     */
    protected static final byte[] convertHash( String hashAlg,byte[] b )
        throws Exception {
        MessageDigest md ;
        if( "sha1".equals( hashAlg ) ) {
            md = MessageDigest.getInstance( "sha-1" ) ;
        }
        else if( "sha256".equals( hashAlg ) ) {
            md = MessageDigest.getInstance( "sha-256" ) ;
        }
        else {
            throw new CryptoRsaException( "非対応のハッシュアルゴリズムです:" + hashAlg ) ;
        }
        return md.digest( b ) ;
    }
    
    /**
     * Returns the DER DigestInfo prefix for the named algorithm.
     * @param hashAlg "sha1" or "sha256".
     * @return byte[] the shared prefix array (callers must not modify it).
     * @exception Exception CryptoRsaException for an unsupported algorithm name.
     */
    protected static final byte[] getDihead( String hashAlg )
        throws Exception {
        if( "sha1".equals( hashAlg ) ) {
            return _RSASIGN_DIHEAD_SHA1 ;
        }
        else if( "sha256".equals( hashAlg ) ) {
            return _RSASIGN_DIHEAD_SHA256 ;
        }
        else {
            throw new CryptoRsaException( "非対応のハッシュアルゴリズムです:" + hashAlg ) ;
        }
    }
    
    /**
     * Builds {@code 00 01 FF..FF 00 <dihead> <hash>} of exactly emLen bytes.
     * @return byte[] the encoded block, or null if fewer than 8 FF bytes would fit.
     */
    private static byte[] encodeDigestInfo( byte[] hash,byte[] dihead,int emLen ) {
        int psLen = emLen - _S_HEAD.length - 1 - dihead.length - hash.length ;
        if( psLen < 8 ) {
            return null ;
        }
        byte[] em = new byte[ emLen ] ;
        int pos = 0 ;
        System.arraycopy( _S_HEAD,0,em,pos,_S_HEAD.length ) ;
        pos += _S_HEAD.length ;
        for( int i = 0 ; i < psLen ; i ++ ) {
            em[ pos ++ ] = (byte)0xff ;
        }
        em[ pos ++ ] = 0 ;
        System.arraycopy( dihead,0,em,pos,dihead.length ) ;
        pos += dihead.length ;
        System.arraycopy( hash,0,em,pos,hash.length ) ;
        return em ;
    }
    
    /**
     * Builds the EMSA-PKCS1-v1_5 encoded block for b (despite the name, raw bytes are returned).
     * @param b message to sign.
     * @param keySize modulus length in bits; the block is ceil(keySize / 8) bytes long.
     * @param hashAlg "sha1" or "sha256".
     * @return byte[] the encoded block.
     * @exception Exception CryptoRsaException for an unsupported algorithm or a key too short for the digest.
     */
    protected static final byte[] getHexPaddedDigestInfo( byte[] b,int keySize,String hashAlg )
        throws Exception {
        byte[] hash = convertHash( hashAlg,b ) ;
        byte[] dihead = getDihead( hashAlg ) ;
        // Ceil, consistent with encrypt/decrypt and RFC 8017 (k = ceil(modBits / 8)).
        int emLen = ( keySize + 7 ) >> 3 ;
        byte[] ret = encodeDigestInfo( hash,dihead,emLen ) ;
        if( ret == null ) {
            throw new CryptoRsaException( "Key too short for " + hashAlg + " signature (keySize=" + keySize + ")" ) ;
        }
        return ret ;
    }
    
    /**
     * Signs b with PKCS#1 v1.5.
     * @param b message to sign.
     * @param hashAlg "sha1" or "sha256".
     * @return BigInteger the signature value.
     * @exception Exception on unsupported algorithm or key errors.
     */
    protected BigInteger sign( byte[] b,String hashAlg )
        throws Exception {
        byte[] hPm = getHexPaddedDigestInfo( b,this.n.bitLength(),hashAlg ) ;
        return doPrivate( new BigInteger( 1,hPm ) ) ;
    }
    
    /** Supported DigestInfo prefixes, index-aligned with _ALG_NAME_LIST. */
    private static final Object[] _ALG_BIN_LIST = new Object[]{
            _RSASIGN_DIHEAD_SHA1,_RSASIGN_DIHEAD_SHA256 } ;
    /** Supported hash algorithm names, index-aligned with _ALG_BIN_LIST. */
    private static final String[] _ALG_NAME_LIST = new String[]{
            "sha1","sha256" } ;
    
    /**
     * Finds the first supported DigestInfo prefix anywhere inside info.
     * Note: this is an unanchored substring search; it is NOT sufficient on its
     * own to validate a signature block.
     * @param info decoded signature block.
     * @return Object[] {algorithm name (String), index just past the prefix (Integer)}, or null.
     * @exception Exception on search failure.
     */
    protected static final Object[] getAlgNameAndHashFromHexDisgestInfo( byte[] info )
        throws Exception {
        int len = _ALG_BIN_LIST.length ;
        byte[] alg ;
        int p ;
        for( int i = 0 ; i < len ; i ++ ) {
            alg = (byte[])_ALG_BIN_LIST[ i ] ;
            
            if( ( p = CryptoUtils.binaryIndexOf( info,alg,0 ) ) != -1 ) {
                return new Object[]{
                    _ALG_NAME_LIST[ i ],p+alg.length } ;
            }
        }
        return null ;
    }
    
    /** True when m equals the integer value of the expected encoding of length emLen. */
    private static boolean isEncodingOf( BigInteger m,byte[] hash,byte[] dihead,int emLen ) {
        byte[] expected = encodeDigestInfo( hash,dihead,emLen ) ;
        return expected != null && m.equals( new BigInteger( 1,expected ) ) ;
    }
    
    /**
     * Verifies a PKCS#1 v1.5 signature made with any supported hash.
     * The recovered block must match the exact expected encoding; a loose search
     * for the hash would allow forged signatures.
     * @param sMsg signed message.
     * @param hSig signature as unsigned big-endian bytes.
     * @return boolean true if the signature is valid.
     * @exception Exception on key errors.
     */
    protected boolean verify( byte[] sMsg,byte[] hSig )
        throws Exception {
        if( hSig == null || hSig.length == 0 ) {
            return false ;
        }
        BigInteger biSig = new BigInteger( 1,hSig ) ;
        // The signature representative must be smaller than the modulus.
        if( biSig.compareTo( this.n ) >= 0 ) {
            return false ;
        }
        BigInteger biDecryptedSig = this.doPublic( biSig ) ;
        int bitLen = this.n.bitLength() ;
        int emLen = ( bitLen + 7 ) >> 3 ;
        // Earlier versions of sign() used floor(bitLen / 8); keep accepting that exact encoding.
        int legacyEmLen = bitLen >> 3 ;
        for( int i = 0 ; i < _ALG_NAME_LIST.length ; i ++ ) {
            String alg = _ALG_NAME_LIST[ i ] ;
            byte[] hash = convertHash( alg,sMsg ) ;
            byte[] dihead = getDihead( alg ) ;
            if( isEncodingOf( biDecryptedSig,hash,dihead,emLen ) ) {
                return true ;
            }
            if( legacyEmLen != emLen && isEncodingOf( biDecryptedSig,hash,dihead,legacyEmLen ) ) {
                return true ;
            }
        }
        return false ;
    }
    
    /**
     * Converts bytes to lower-case hexadecimal, two characters per byte.
     * @param b bytes to convert.
     * @return String the hex string.
     */
    protected static final String hexByte( byte[] b ) {
        StringBuilder sb= new StringBuilder();
        int cnt= b.length;
        for(int i= 0; i< cnt; i++){
            sb.append(Integer.toHexString( (b[i]>> 4) & 0x0F ) );
            sb.append(Integer.toHexString( b[i] & 0x0F ) );
        }
        return sb.toString();
    }
    
    /** Appends {@code null} or the value as a quoted hexadecimal string. */
    private static void appendHexValue( StringBuilder buf,BigInteger v ) {
        if( v == null ) {
            buf.append( "null" ) ;
        }
        else {
            buf.append( "\"" ).append( v.toString( 16 ) ).append( "\"" ) ;
        }
    }
    
    /**
     * Serializes this key to a JSON-like string.
     * Keys are unquoted; big integers are quoted hex strings; e is a decimal number.
     * @return String the serialized key.
     */
    public String encodeJSON() {
        StringBuilder buf = new StringBuilder( 1024 ) ;
        buf.append( "{n:" ) ;
        appendHexValue( buf,n ) ;
        buf.append( "," ) ;
        buf.append( " e:" ).append( e ).append( "," ) ;
        buf.append( "d:" ) ;
        appendHexValue( buf,d ) ;
        buf.append( "," ) ;
        buf.append( "p:" ) ;
        appendHexValue( buf,p ) ;
        buf.append( "," ) ;
        buf.append( "q:" ) ;
        appendHexValue( buf,q ) ;
        buf.append( "," ) ;
        buf.append( "dmp1:" ) ;
        appendHexValue( buf,dmp1 ) ;
        buf.append( "," ) ;
        buf.append( "dmq1:" ) ;
        appendHexValue( buf,dmq1 ) ;
        buf.append( "," ) ;
        buf.append( "coeff:" ) ;
        appendHexValue( buf,coeff ) ;
        buf.append( "}" ) ;
        return buf.toString() ;
    }
    
    /** Reads an optional hexadecimal BigInteger field from a decoded JSON map. */
    private static BigInteger hexField( Map<?,?> m,String key ) {
        Object v = m.get( key ) ;
        return ( v == null ) ? null : new BigInteger( (String)v,16 ) ;
    }
    
    /**
     * Restores a key from the output of {@link #encodeJSON()}.
     * @param json serialized key.
     * @return CryptoRsaKey the restored key.
     * @exception Exception if json is empty, is not an object, or has malformed fields.
     */
    public static final CryptoRsaKey decodeJSON( String json )
        throws Exception {
        if( json == null || json.length() <= 0 ) {
            throw new IllegalArgumentException( "jsonパラメータが設定されていません" ) ;
        }
        Object decoded = CryptoUtils_DecodeJSON.execution( json ) ;
        if( !( decoded instanceof Map ) ) {
            // Do not echo json: it may contain private key material.
            throw new IllegalArgumentException( "json must be an object" ) ;
        }
        Map<?,?> m = ( Map<?,?> )decoded ;
        
        CryptoRsaKey ret = new CryptoRsaKey() ;
        ret.n = hexField( m,"n" ) ;
        Object ev = m.get( "e" ) ;
        if( ev instanceof Number ) {
            // Accept any numeric type the JSON decoder may produce (Integer, Long, ...).
            ret.e = ( (Number)ev ).intValue() ;
        }
        else {
            // encodeJSON writes e in decimal; a missing e fails with NumberFormatException.
            ret.e = Integer.parseInt( (String)ev ) ;
        }
        ret.d = hexField( m,"d" ) ;
        ret.p = hexField( m,"p" ) ;
        ret.q = hexField( m,"q" ) ;
        ret.dmp1 = hexField( m,"dmp1" ) ;
        ret.dmq1 = hexField( m,"dmq1" ) ;
        ret.coeff = hexField( m,"coeff" ) ;
        return ret ;
    }
    
}
