/* Copyright 2013 Akira Ohta (akohta001@gmail.com)
    This file is part of ntch.

    The ntch is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    The ntch is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with ntch.  If not, see <http://www.gnu.org/licenses/>.
    
*/
#include <string.h>
#include <stdio.h>
#include <wchar.h>
#include <assert.h>

#include "utils/nt_std_t.h"
#include "utils/base64.h"

/*
 * Bit masks used when splitting octets into 6-bit groups and back.
 * Each macro name spells out the binary pattern of its value.
 */
#define B00000011 0x03
#define B00001111 0x0F
#define B00110000 0x30
#define B00111100 0x3C
#define B00111111 0x3F


/*
 * Maps a 6-bit value (0..63) to its character in the standard base64
 * alphabet. The table is read-only lookup data.
 */
static const char base64encode_tbl[] = {
	'A','B','C','D','E','F','G','H','I','J','K','L','M','N',
	'O','P','Q','R','S','T','U','V','W','X','Y','Z',
	'a','b','c','d','e','f','g','h','i','j','k','l','m','n',
	'o','p','q','r','s','t','u','v','w','x','y','z',
	'0','1','2','3','4','5','6','7','8','9','+','/',
};
/* Marker values stored in base64decode_tbl for characters that are not data. */
enum {
	B64_INVALID = -1,	/* character is not accepted in the input */
	B64_PERCENT = -2,	/* '%': introduces a two-digit hex escape */
	B64_PAD     = -3	/* '=': padding */
};

/*
 * Maps a 7-bit ASCII code to its 6-bit base64 value (0..63), or to one of
 * the negative B64_* markers above.
 * The element type is explicitly signed: plain char may be unsigned, in
 * which case the negative markers could never compare equal.
 */
static const signed char base64decode_tbl[] = {
	-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
	-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
	-1,-1,-1,-1,-1,-2,-1,-1,-1,-1,-1,62,-1,-1,-1,63,
	52,53,54,55,56,57,58,59,60,61,-1,-1,-1,-3,-1,-1,
	-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,
	15,16,17,18,19,20,21,22,23,24,25,-1,-1,-1,-1,-1,
	-1,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,
	41,42,43,44,45,46,47,48,49,50,51,-1,-1,-1,-1,-1,
};


/* Returns the value (0..15) of one hexadecimal digit, or -1 if ch is not one. */
static int hex_digit_value(unsigned char ch)
{
	if(ch >= '0' && ch <= '9')
		return ch - '0';
	if(ch >= 'A' && ch <= 'F')
		return ch - 'A' + 10;
	if(ch >= 'a' && ch <= 'f')
		return ch - 'a' + 10;
	return -1;
}

/*
 * Converts the two hexadecimal digits at p[0] and p[1] into a byte value
 * (0..255). Returns -1 if either character is not a hex digit.
 * The caller must guarantee that both characters are readable.
 */
static int hex_pair_value(const unsigned char *p)
{
	int hi, lo;

	hi = hex_digit_value(p[0]);
	lo = hex_digit_value(p[1]);
	if(hi < 0 || lo < 0)
		return -1;
	return (hi << 4) | lo;
}


/*
 * Decodes percent-escaped base64 text.
 *
 * Accepted input characters are the base64 alphabet (A-Z, a-z, 0-9, '+',
 * '/'), the padding character '=', and the percent-escapes of '+', '/'
 * and '=' (hex digits in either case). Any other byte, any other escape,
 * a truncated escape, a non-ASCII byte, or data following padding makes
 * the call fail.
 *
 * in_buf       input text; need not be NUL-terminated
 * in_buf_len   number of bytes to read from in_buf
 * out_buf      receives the decoded bytes (no terminator is appended)
 * out_buf_len  capacity of out_buf in bytes
 *
 * Returns the number of decoded bytes, or -1 on invalid input or when
 * out_buf is too small. The capacity checks are deliberately conservative:
 * a partially assembled byte is stored at the write position before it is
 * complete, and one further byte is always kept free, so up to two bytes
 * beyond the decoded length may be required. As a consequence the return
 * value is always smaller than out_buf_len. On failure the contents of
 * out_buf are unspecified.
 */
int nt_base64_url_decode(
		const char* in_buf, size_t in_buf_len, 
		unsigned char* out_buf, size_t out_buf_len) 
{
	const unsigned char *cptr, *end_ptr;
	size_t out_idx;
	int state;
	int padded;
	
	cptr = (const unsigned char*)in_buf;
	end_ptr = cptr + in_buf_len;
	out_idx = 0;
	state = 0;
	padded = 0;
	while(cptr < end_ptr){
		unsigned char ch;
		int v;

		ch = *cptr++;
		/* Bytes outside the table (non-ASCII) are rejected, not asserted,
		 * so release builds never index past the table. */
		if(ch >= sizeof base64decode_tbl)
			return -1;
		v = base64decode_tbl[ch];
		if(v == B64_INVALID)
			return -1;
		if(v == B64_PERCENT){
			int byte;

			if(end_ptr - cptr < 2)
				return -1;
			byte = hex_pair_value(cptr);
			if(byte < 0)
				return -1;
			cptr += 2;
			/* Only the escapes of '=', '+' and '/' are meaningful here. */
			if(byte == '=')
				v = B64_PAD;
			else if(byte == '+' || byte == '/')
				v = base64decode_tbl[byte];
			else
				return -1;
		}
		if(v == B64_PAD){
			padded = 1;
			continue;
		}
		/* Data characters are not allowed once padding has been seen. */
		if(padded)
			return -1;

		/* state counts the sextets consumed in the current 4-character group */
		switch(state){
		case 0:
			if(out_buf_len - out_idx <= 1)
				return -1;
			out_buf[out_idx] = (unsigned char)(v << 2);
			state = 1;
			break;
		case 1:
			if(out_buf_len - out_idx <= 2)
				return -1;
			out_buf[out_idx] |= (unsigned char)((v & 0x30) >> 4);
			out_idx++;
			out_buf[out_idx] = (unsigned char)((v & 0x0F) << 4);
			state = 2;
			break;
		case 2:
			if(out_buf_len - out_idx <= 2)
				return -1;
			out_buf[out_idx] |= (unsigned char)((v & 0x3C) >> 2);
			out_idx++;
			out_buf[out_idx] = (unsigned char)((v & 0x03) << 6);
			state = 3;
			break;
		default: /* 3: room for this byte was verified in state 2 */
			out_buf[out_idx] |= (unsigned char)(v & 0x3F);
			out_idx++;
			state = 0;
			break;
		}
	}
	return (int)out_idx;
}


/*
 * Number of output characters needed for one 6-bit value: one alphabet
 * character for 0..61, or a three-character percent-escape for 62 ('+')
 * and 63 ('/').
 */
static size_t sextet_text_len(unsigned char sextet)
{
	return (sextet < 62) ? 1 : 3;
}

/*
 * Writes the text for one 6-bit value at dst and returns the number of
 * characters written (see sextet_text_len). 62 becomes "%2B" and 63
 * becomes "%2F". The caller must already have verified that dst has room.
 */
static size_t put_sextet(char *dst, unsigned char sextet)
{
	if(sextet < 62){
		dst[0] = base64encode_tbl[sextet];
		return 1;
	}
	dst[0] = '%';
	dst[1] = '2';
	dst[2] = (sextet == 62) ? 'B' : 'F';
	return 3;
}

/*
 * Encodes binary data as percent-escaped base64 text.
 *
 * The output uses the standard base64 alphabet, except that '+', '/' and
 * the padding character '=' are always written as "%2B", "%2F" and "%3D".
 * This also applies to the final, partial group.
 *
 * in_buf       bytes to encode
 * in_buf_len   number of bytes in in_buf
 * out_buf      receives the NUL-terminated encoded text
 * out_buf_len  capacity of out_buf in bytes, including the terminator
 *
 * Returns TRUE on success. Returns FALSE when out_buf is too small (this
 * includes out_buf_len == 0); in that case out_buf may hold partial,
 * unterminated output.
 */
BOOL nt_base64_url_encode(
		const unsigned char* in_buf, size_t in_buf_len,
		char* out_buf, size_t out_buf_len) 
{
	unsigned char block[4];
	const unsigned char *cptr;
	size_t blocks, out_idx, i, j;
	size_t mod;
	
	/* Even empty input needs one byte for the terminating NUL. */
	if(out_buf_len == 0)
		return FALSE;

	/* Invariant from here on: out_idx < out_buf_len, so the subtraction
	 * in the capacity checks cannot wrap around. */
	out_idx = 0;
	cptr = in_buf;
	blocks = in_buf_len / 3;
	for(i = 0; i < blocks; i++){
		block[0] = (cptr[0]>>2)&B00111111;
		block[1] = (cptr[1]>>4)&B00001111;
		block[1] |= (cptr[0]&B00000011)<<4;
		block[2] = (cptr[2]>>6)&B00000011;
		block[2] |= (cptr[1]&B00001111)<<2;
		block[3] = cptr[2]&B00111111;
		for(j = 0; j < 4; j++){
			/* "<=" keeps one byte in reserve for the terminator */
			if(out_buf_len - out_idx <= sextet_text_len(block[j]))
				return FALSE;
			out_idx += put_sextet(out_buf + out_idx, block[j]);
		}
		cptr += 3;
	}
	mod = in_buf_len % 3;
	if(mod != 0){
		/* One or two bytes remain: they yield mod+1 sextets, followed by
		 * 3-mod escaped padding characters. */
		size_t n_sextets = mod + 1;
		size_t n_pad = 3 - mod;
		size_t need = 3 * n_pad;

		block[0] = (cptr[0]>>2)&B00111111;
		block[1] = (cptr[0]&B00000011)<<4;
		if(mod == 2){
			block[1] |= (cptr[1]>>4)&B00001111;
			block[2] = (cptr[1]&B00001111)<<2;
		}
		/* Check the whole tail up front so nothing is written on failure. */
		for(j = 0; j < n_sextets; j++)
			need += sextet_text_len(block[j]);
		if(out_buf_len - out_idx <= need)
			return FALSE;
		for(j = 0; j < n_sextets; j++)
			out_idx += put_sextet(out_buf + out_idx, block[j]);
		for(j = 0; j < n_pad; j++){
			out_buf[out_idx++] = '%';
			out_buf[out_idx++] = '3';
			out_buf[out_idx++] = 'D';
		}
	}
	out_buf[out_idx] = '\0';
	return TRUE;
}


