package net.y3n20u.rfc2898;

import java.util.Arrays;
import java.util.Objects;

import net.y3n20u.util.ByteHelper;

/**
 * PBKDF2 (Password-Based Key Derivation Function 2) as defined in
 * <a href="http://www.ietf.org/rfc/rfc2898.txt">RFC2898</a>, section 5.2.
 * <p>
 * The pseudorandom function applied in each iteration is supplied through the
 * constructor; the no-argument constructor uses {@link PrfHmacSha1}.
 * 
 * @author y3n20u@gmail.com
 * 
 */
public class Pbkdf2 {

	// messages
	private static final String MESSAGE_DKLEN_TOO_LONG = "derived key too long";
	private static final String MESSAGE_DKLEN_NOT_POSITIVE = "derived key length must be positive";
	private static final String MESSAGE_ITERATION_COUNT_NOT_POSITIVE = "iteration count must be positive";

	/**
	 * Maximum number of hLen-octet blocks, (2^32 - 1), from step 1 of RFC 2898
	 * section 5.2. Written as a literal because '^' is XOR in Java, not
	 * exponentiation.
	 */
	private static final long MAX_BLOCK_COUNT = 0xFFFFFFFFL;

	/**
	 * underlying pseudorandom function. hLen denotes the length in octets of
	 * the pseudorandom function output.
	 */
	private final PseudorandomFunction _prf;

	/**
	 * Creates an instance that uses {@link PrfHmacSha1} as the pseudorandom
	 * function.
	 */
	public Pbkdf2() {
		this(new PrfHmacSha1());
	}

	/**
	 * Creates an instance that uses the given pseudorandom function.
	 * 
	 * @param prf
	 *            underlying pseudorandom function; must not be {@code null}.
	 * @throws NullPointerException
	 *             if {@code prf} is {@code null}.
	 */
	public Pbkdf2(PseudorandomFunction prf) {
		// fail fast here instead of with an NPE on the first deriveKey call.
		_prf = Objects.requireNonNull(prf, "prf");
	}

	/**
	 * Derives a key from a password and a salt (RFC 2898 section 5.2).
	 * 
	 * @param p
	 *            password, an octet string.
	 * @param s
	 *            salt, an octet string.
	 * @param c
	 *            iteration count, a positive integer.
	 * @param dkLen
	 *            intended length in octets of the derived key, a positive
	 *            integer, at most (2^32-1) * hLen.
	 * @return the derived key DK, built by
	 *         {@code ByteHelper.concatByteArrays(dkLen, t)} from the blocks
	 *         T_1 .. T_l (expected to be the first dkLen octets of their
	 *         concatenation).
	 * @throws IllegalArgumentException
	 *             if {@code c} or {@code dkLen} is not positive, or if
	 *             {@code dkLen} exceeds (2^32-1) * hLen.
	 */
	public byte[] deriveKey(byte[] p, byte[] s, int c, int dkLen) {
		// c == 0 would skip the PRF loop entirely and yield an all-zero key
		// that does not depend on the password.
		if (c < 1) {
			throw new IllegalArgumentException(MESSAGE_ITERATION_COUNT_NOT_POSITIVE);
		}
		if (dkLen < 1) {
			throw new IllegalArgumentException(MESSAGE_DKLEN_NOT_POSITIVE);
		}

		// === step1: check the length of the derived key. ===
		// The bound is computed in long arithmetic so it cannot overflow. For any
		// hLen >= 1 it exceeds Integer.MAX_VALUE. It therefore only trips when
		// the PRF reports a non-positive output length, which would otherwise
		// break the division in step 2.
		int hLen = _prf.getLengthOfOutput();
		if (dkLen > MAX_BLOCK_COUNT * hLen) {
			throw new IllegalArgumentException(MESSAGE_DKLEN_TOO_LONG);
		}

		// === step2: calculate the number of blocks. ===
		// l = ceil(dkLen / hLen), in integer arithmetic; dkLen >= 1 here, so
		// (dkLen - 1) cannot underflow.
		int l = (dkLen - 1) / hLen + 1;

		// === step3: calculate each blocks. ===
		// Block indices are 1-based in RFC 2898.
		byte[][] t = new byte[l][];
		for (int i = 1; i <= l; i++) {
			t[i - 1] = _calculateF(p, s, c, i);
		}

		// === step4, 5: concatenate the blocks and output the derived key. ===
		return ByteHelper.concatByteArrays(dkLen, t);
	}

	/**
	 * Computes block T_i = F(P, S, c, i) = U_1 xor U_2 xor ... xor U_c, where
	 * U_1 = PRF(P, S || INT(i)) and U_j = PRF(P, U_{j-1}).
	 * 
	 * @param p
	 *            password, passed as the first argument of the PRF.
	 * @param s
	 *            salt.
	 * @param c
	 *            iteration count.
	 * @param i
	 *            1-based block index.
	 * @return the block T_i.
	 */
	private byte[] _calculateF(byte[] p, byte[] s, int c, int i) {
		// Java zero-initializes arrays, so this starts as the XOR identity.
		byte[] u = new byte[_prf.getLengthOfOutput()];
		byte[] ui = ByteHelper.concatByteArrays(s, _getFourOctetEncoding(i));
		for (int counter = 0; counter < c; counter++) {
			ui = _prf.getPseudorandomBytes(p, ui);
			u = ByteHelper.xorTwoByteArrays(u, ui);
		}
		return u;
	}

	/**
	 * Encodes an int as INT(i) from RFC 2898: a four-octet big-endian value.
	 * 
	 * @param intValue
	 *            value to encode.
	 * @return four octet encoding (most significant octet first)
	 */
	private byte[] _getFourOctetEncoding(int intValue) {
		byte[] r = new byte[4];
		r[0] = (byte) (intValue >>> 24);
		r[1] = (byte) (intValue >>> 16);
		r[2] = (byte) (intValue >>> 8);
		r[3] = (byte) intValue;
		return r;
	}
}
