(* Scheme numeric tower; see R5RS section 6.2 (Numbers). *)

module C = Complex
open C

(* Exactness flag carried by every number (exact / inexact). *)
type number_exactness = Exact | Inexact

(* Underlying representation of a number's value. There is no rational
   case: exact division with a non-zero remainder produces a [Float]
   (see st_div). *)
type number_value = 
  | Int of int 
  | Float of float
  | Complex of C.t

(* A number: its representation paired with its exactness flag. The flag
   is tracked independently of the representation, so an exact result may
   still be stored as a [Float]. *)
type t = {
  value : number_value;
  exactness : number_exactness;
}

(* Exact constants: zero, one and the imaginary unit. *)
let zero = { exactness = Exact; value = Int 0 }
let one = { exactness = Exact; value = Int 1 }
let i = { exactness = Exact; value = Complex C.i }

(* Injections from OCaml numbers: ints become exact, floats inexact. *)
let of_int n = { exactness = Exact; value = Int n }
let of_float x = { exactness = Inexact; value = Float x }


(* The result of combining two numbers is exact only when both are. *)
let combine_exactness e1 e2 =
  match e1, e2 with
  | Exact, Exact -> Exact
  | _, _ -> Inexact

(* Widen a value to the [Complex] representation with a zero imaginary
   part; complex values are returned as is. *)
let convert_to_complex v =
  let real_axis x = Complex { re = x; im = 0.0 } in
  match v with
  | Complex _ -> v
  | Float x -> real_axis x
  | Int n -> real_axis (float_of_int n)

(* Convert a value to [Float]; a complex value is replaced by its modulus. *)
let convert_to_float v =
  match v with
  | Float _ -> v
  | Int n -> Float (float_of_int n)
  | Complex z -> Float (C.norm z)

(* Convert a value to [Int] with [int_of_float]; a complex value is
   replaced by its (truncated) modulus. *)
let convert_to_int v =
  match v with
  | Int _ -> v
  | Float x -> Int (int_of_float x)
  | Complex z -> Int (int_of_float (C.norm z))

(* [same_type a b] holds when both values use the same representation. *)
let same_type a b =
  match a, b with
  | Int _, Int _ -> true
  | Float _, Float _ -> true
  | Complex _, Complex _ -> true
  | (Int _ | Float _ | Complex _), _ -> false

(* Coerce two values to a common representation by widening along
   Int < Float < Complex. *)
let convert_to_same_type a b =
  match a, b with
  | Complex _, _ | _, Complex _ -> convert_to_complex a, convert_to_complex b
  | Float _, _ | _, Float _ -> convert_to_float a, convert_to_float b
  | Int _, Int _ -> a, b

(* [do_op op a b] lifts the binary value operation [op] to [t]. The
   operands are coerced to a common representation when they differ, and
   the result is exact only when both arguments are. *)
let do_op op a b =
  let operands =
    if same_type a.value b.value then (a.value, b.value)
    else convert_to_same_type a.value b.value
  in
  { value = op operands;
    exactness = combine_exactness a.exactness b.exactness }

(* [do_unary_op op a] applies [op] to the value of [a] and keeps [a]'s
   exactness flag. *)
let do_unary_op op a = { value = op a.value; exactness = a.exactness }



(* Binary arithmetic on operands already in a common representation (see
   convert_to_same_type). Mixed representations raise
   [Failure "type error"]. *)
let st_add, st_sub, st_mul =
  let lift on_int on_float on_complex = function
    | Int x, Int y -> Int (on_int x y)
    | Float x, Float y -> Float (on_float x y)
    | Complex x, Complex y -> Complex (on_complex x y)
    | _ -> failwith "type error"
  in
  (lift ( + ) ( +. ) C.add,
   lift ( - ) ( -. ) C.sub,
   lift ( * ) ( *. ) C.mul)


(* Division. Two integers give an [Int] when the division is exact and a
   [Float] otherwise (there is no rational representation). An integer
   division by zero raises [Division_by_zero] from [mod]. *)
let st_div operands =
  match operands with
  | Int n, Int d when n mod d <> 0 ->
    Float (float_of_int n /. float_of_int d)
  | Int n, Int d -> Int (n / d)
  | Float x, Float y -> Float (x /. y)
  | Complex z, Complex w -> Complex (C.div z w)
  | _ -> failwith "type error"

(* Integer quotient, truncating toward zero (OCaml's [/]); integers only. *)
let st_quotient operands =
  match operands with
  | Int n, Int d -> Int (n / d)
  | _ -> invalid_arg "quotient"

(* Remainder with the sign of the dividend ([mod] / [mod_float]). *)
let st_remainder operands =
  match operands with
  | Int n, Int d -> Int (n mod d)
  | Float x, Float y -> Float (mod_float x y)
  | _ -> invalid_arg "remainder"

(* [modulo] as in R5RS: like [remainder], but the result takes the sign of
   the divisor. A non-zero remainder whose sign differs from the divisor's
   is shifted by one divisor. An integer divisor of zero raises
   [Division_by_zero] from [mod].
   @raise Invalid_argument "modulo" for complex or mixed operands. *)
let st_modulo = function
  | Int i, Int j ->
    let rem = i mod j in
    if (j < 0 && rem > 0) || (j >= 0 && rem < 0) then Int (rem + j)
    else Int rem
  | Float f, Float g ->
    let rem = mod_float f g in
    if (g < 0.0 && rem > 0.0) || (g >= 0.0 && rem < 0.0) then
      Float (rem +. g)
    else Float rem
  | _, _ -> invalid_arg "modulo"


(* Absolute value; a complex number maps to its modulus as a [Float]. *)
let st_abs = function
  | Complex z -> Float (C.norm z)
  | Float x -> Float (abs_float x)
  | Int n -> Int (abs n)


(* [check_pred p a b] applies the binary predicate [p] to the values of [a]
   and [b], coercing them to a common representation first when they
   differ. Exactness flags are ignored. *)
let check_pred p a b =
  let x = a.value and y = b.value in
  let operands = if same_type x y then (x, y) else convert_to_same_type x y in
  p operands

(* [predicate p n] applies the unary predicate [p] to the value of [n]. *)
let predicate p { value; _ } = p value
 

(* Integrality test.
   A float is integral when it equals its [floor] and is finite. Comparing
   with [floor] instead of round-tripping through [int_of_float] matters
   because [int_of_float] is unspecified outside the native int range, so
   large integral floats such as 1e30 were rejected. Infinities satisfy
   [floor f = f] and are excluded explicitly; NaN fails the equality.
   A complex value is integral when its imaginary part is zero and its
   real part is integral. *)
let rec st_is_integer = function
  | Int _ -> true
  | Float f -> floor f = f && classify_float f <> FP_infinite
  | Complex c -> c.im = 0.0 && st_is_integer (Float c.re)

(* Ints and floats are real; a complex value is real when its imaginary
   part is exactly zero. *)
let st_is_real v =
  match v with
  | Complex z -> z.im = 0.0
  | Int _ | Float _ -> true

(* Every value is a complex number. *)
let st_is_complex _ = true

(* Rationals have no representation of their own: same as st_is_real. *)
let st_is_rational v = st_is_real v (* not implemented *)

(* Strict ordering on same-representation values. Complex operands are
   compared by real part, and only when both imaginary parts are zero. *)
let st_greater operands =
  match operands with
  | Int x, Int y -> x > y
  | Float x, Float y -> x > y
  | Complex z, Complex w when z.im = 0.0 && w.im = 0.0 -> z.re > w.re
  | _ -> invalid_arg "st_greater"

(* Non-strict ordering: structural equality or [st_greater]. *)
let st_greater_or_equal operands =
  match operands with
  | Int x, Int y -> x >= y
  | Float x, Float y -> x >= y
  | c, d -> c = d || st_greater (c, d)

(* Structural (in)equality on values. Note the differing conventions:
   [st_equal] takes a pair, [st_unequal] takes two arguments. *)
let st_equal pair = fst pair = snd pair
let st_unequal x y = not (x = y)

(* "Less than" and "less than or equal", obtained by swapping the operands
   of the corresponding "greater" predicate: a < b iff b > a, and
   a <= b iff b >= a. *)
let st_less (a, b) = st_greater (b, a)
let st_less_or_equal (a, b) = st_greater_or_equal (b, a)

(* Zero test on a number's value (both components for a complex). *)
let is_zero n =
  match n.value with
  | Complex z -> z.re = 0.0 && z.im = 0.
  | Float x -> x = 0.0
  | Int k -> k = 0

(* Sign tests. Complex values are compared with zero through st_greater,
   which rejects non-real complex numbers. *)
let is_positive n =
  match n.value with
  | Int k -> k > 0
  | Float x -> x > 0.0
  | Complex _ as v -> st_greater (v, Complex C.zero)

let is_negative n =
  match n.value with
  | Int k -> k < 0
  | Float x -> x < 0.0
  | Complex _ as v -> st_greater (Complex C.zero, v)

      

(* Convert a number to an OCaml int with [int_of_float]; a complex value
   is replaced by its modulus first. *)
let int_of n =
  match n.value with
  | Complex z -> int_of_float (C.norm z)
  | Float x -> int_of_float x
  | Int k -> k

(* Parity of [int_of n], read from the lowest bit. *)
let is_odd n = int_of n land 1 <> 0
let is_even n = not (is_odd n)

(* Larger / smaller of two same-representation values. Complex operands
   are accepted only when both are real, and the result is then a
   [Float]. *)
let st_max operands =
  match operands with
  | Int x, Int y -> Int (max x y)
  | Float x, Float y -> Float (max x y)
  | Complex z, Complex w when z.im = 0.0 && w.im = 0.0 ->
    Float (max z.re w.re)
  | _ -> invalid_arg "st_max"

let st_min operands =
  match operands with
  | Int x, Int y -> Int (min x y)
  | Float x, Float y -> Float (min x y)
  | Complex z, Complex w when z.im = 0.0 && w.im = 0.0 ->
    Float (min z.re w.re)
  | _ -> invalid_arg "st_min"

(* Arithmetic negation of a value. Complex numbers are negated
   componentwise with [C.neg]; the caller (do_unary_op) keeps the
   exactness flag. *)
let st_neg = function
  | Int i -> Int (- i)
  | Float f -> Float (-. f)
  | Complex c -> Complex (C.neg c)

(* [int_pow i j] computes i^j for j >= 0 by binary exponentiation
   (squaring the base on even exponents). As with R5RS [expt], 0^0 is 1;
   0^j is 0 for j > 0. Overflow wraps like ordinary int multiplication.
   @raise Assert_failure if [j] is negative. *)
let rec int_pow i j =
  assert (j >= 0);
  if j = 0 then 1
  else if j mod 2 = 0 then int_pow (i * i) (j lsr 1)
  else i * int_pow i (j - 1)


(* Exponentiation. A negative integer exponent falls back to float [**];
   otherwise integers use int_pow. *)
let st_pow operands =
  match operands with
  | Int base, Int e when e < 0 ->
    Float (float_of_int base ** float_of_int e)
  | Int base, Int e -> Int (int_pow base e)
  | Float x, Float y -> Float (x ** y)
  | Complex z, Complex w -> Complex (C.pow z w)
  | _ -> invalid_arg "st_pow"

(* [st_float_to_int conv] rounds floats with [conv], keeping them as
   floats; integers pass through and complex values are rejected. *)
let st_float_to_int conv v =
  match v with
  | Float x -> Float (conv x)
  | Int _ -> v
  | Complex _ -> invalid_arg "st_floor"

(* Rounding operations: integers pass through unchanged, floats are
   rounded but stay floats. *)
let st_floor = st_float_to_int floor
let st_ceil = st_float_to_int ceil
(* Truncate toward zero without going through [int]: [truncate] is
   unspecified for floats outside the native int range (e.g. 1e30), so
   round with [ceil]/[floor] instead. A negative input in (-1, 0) yields
   -0. *)
let st_truncate =
  st_float_to_int (fun f -> if f < 0.0 then ceil f else floor f)

(* [invalid msg] is a function that ignores its argument and always raises
   [Invalid_argument msg]. *)
let invalid msg = fun _ -> invalid_arg msg

(* [float_op op cop] builds a unary value operation. Reals are converted to
   float and mapped with [op]; complex values are mapped with [cop]. *)
let float_op op cop v =
  match v with
  | Complex z -> Complex (cop z)
  | Float x -> Float (op x)
  | Int n -> Float (op (float_of_int n))

  
(* Value-level transcendental functions. Reals are computed as floats.
   Complex arguments are supported for sqrt/exp/log only; the others raise
   Invalid_argument. The float versions of sqrt/exp/log are qualified with
   Pervasives because [open C] brings Complex's versions into scope. *)
let st_sqrt = float_op Pervasives.sqrt C.sqrt
let st_exp = float_op Pervasives.exp C.exp
let st_log = float_op Pervasives.log C.log
let st_sin = float_op sin (invalid "complex sin")
let st_cos = float_op cos (invalid "complex cos")
let st_tan = float_op tan (invalid "complex tan")
let st_asin = float_op asin (invalid "complex asin")
let st_acos = float_op acos (invalid "complex acos")
(* Distinct name: binding arctangent as [st_tan] would shadow the tangent
   defined above. *)
let st_atan = float_op atan (invalid "complex atan")

(* Imaginary part: exact 0 for reals, the [im] component for complex. *)
let st_im_part v =
  match v with
  | Complex z -> Float z.im
  | Int _ | Float _ -> Int 0

(* Real part: the [re] component of a complex value; reals are returned
   unchanged. *)
let st_re_part = function
  | Complex c -> Float c.re
  | real -> real


(* Numerator: integers are their own numerator, floats are delegated to
   [ScmUtil.numerator], and complex values are rejected. *)
let st_numerator v =
  match v with
  | Complex _ -> invalid_arg "complex numerator"
  | Float x -> Int (ScmUtil.numerator x)
  | Int _ -> v
  

(* Denominator: 1 for integers, delegated to [ScmUtil.denominator] for
   floats; complex values are rejected.
   @raise Invalid_argument "complex denominator" on a complex value. *)
let st_denominator = function
  | Int _ -> Int 1
  | Float f -> Int (ScmUtil.denominator f)
  | Complex _ -> invalid_arg "complex denominator"

(* Arithmetic on [t], lifted from the value-level operations with do_op.
   [max] and [min] shadow the Pervasives versions from here on. *)
let add a b = do_op st_add a b
let sub a b = do_op st_sub a b
let mul a b = do_op st_mul a b
let div a b = do_op st_div a b

let max a b = do_op st_max a b
let min a b = do_op st_min a b

(* Comparisons on [t]; operands are coerced to a common representation. *)
let greater a b = check_pred st_greater a b
let less a b = check_pred st_less a b
let greater_or_equal a b = check_pred st_greater_or_equal a b
let less_or_equal a b = check_pred st_less_or_equal a b
let equal a b = check_pred st_equal a b

(* Type predicates on [t], applied to the number's value. *)
let is_complex n = st_is_complex n.value
let is_rational n = st_is_rational n.value
let is_integer n = st_is_integer n.value
let is_real n = st_is_real n.value

(* Integer division family and exponentiation on [t]. *)
let quotient a b = do_op st_quotient a b
let remainder a b = do_op st_remainder a b
let modulo a b = do_op st_modulo a b

let pow a b = do_op st_pow a b

(* Negation and absolute value on [t]; exactness is preserved. *)
let neg n = do_unary_op st_neg n
let abs n = do_unary_op st_abs n

(* Transcendental functions on [t]; do_unary_op keeps the argument's
   exactness flag. *)
let sqrt = do_unary_op st_sqrt
let sin = do_unary_op st_sin
let cos = do_unary_op st_cos
let tan = do_unary_op st_tan
let asin = do_unary_op st_asin
let acos = do_unary_op st_acos
(* Arctangent, built directly from [Pervasives.atan] (not from the asin
   helper). *)
let atan = do_unary_op (float_op Pervasives.atan (invalid "complex atan"))
let exp = do_unary_op st_exp
let log = do_unary_op st_log


(* Truncation toward zero on [t]. *)
let truncate n = do_unary_op st_truncate n
(* let round = do_unary_op st_round *)


(* Real and imaginary parts of a number. *)
let im_part n = do_unary_op st_im_part n
let re_part n = do_unary_op st_re_part n

(* Replace the exactness flag of [n]. Only the flag changes; the stored
   value is not converted. *)
let inexact n = { n with exactness = Inexact }
let exact n = { n with exactness = Exact }

(* Query the exactness flag. *)
let is_exact n = n.exactness = Exact
let is_inexact n = n.exactness = Inexact


(* [do_int_op op a b] applies the int function [op] to [int_of a] and
   [int_of b]. The result is exact only when both arguments are. *)
let do_int_op op a b =
  { exactness = combine_exactness a.exactness b.exactness;
    value = Int (op (int_of a) (int_of b)) }

(* GCD / LCM of the integer parts, delegated to ScmUtil. *)
let gcd a b = do_int_op ScmUtil.gcd a b
let lcm a b = do_int_op ScmUtil.lcm a b

(* Numerator / denominator of a number; exactness is preserved. *)
let numerator n = do_unary_op st_numerator n
let denominator n = do_unary_op st_denominator n

(* let inexact_one = { value = Float 1.0; exactness = Inexact } *)

(* Character codes used to map between digit characters and values. *)
let zero_code = int_of_char '0'
let a_code = int_of_char 'a'

(* Radix used by of_string and string_of when none is given: exact 10. *)
let default_radix = { value = Int 10; exactness = Exact }
	

(* [of_string ?radix s] parses the Scheme numeric literal [s] into a [t].
   [radix] is itself a [t] (default [default_radix]). It is used
   arithmetically as the digit base and compared with 16 to enable hex
   digits.
   Syntax accepted while scanning left to right in [iter]:
   - prefixes #e #i #b #o #d #x (letter case ignored), then a sign;
   - digits 0-9, plus a-f/A-F when the radix is 16;
   - '#' after the first digit counts as a 0 digit and makes the result
     inexact;
   - '.' fraction (result inexact), 'e' exponent (result inexact), and
     '/' ratio;
   - after digits, '+' or '-' starts an imaginary part, which is
     multiplied by [i];
   - precision markers f/F/s/S/l/L end parsing.
   @raise Invalid_argument "parse_number" on an empty string, an unknown
   prefix letter or an unexpected character.
   NOTE(review): suspected defects, confirm before relying on this:
   - "#e" is accepted but does not make the result exact;
   - a '#' as the last character makes [String.get] raise (index out of
     bounds);
   - [is_digit] accepts '.' and a-f in every radix, so [extract_digits]
     can swallow an exponent: "1.5e3" appears to yield #i6, not 1500;
   - the 'e' branch takes no signed exponent ("1e-3" yields 1). Like the
     '+'/'-' imaginary branches, it stops scanning, so trailing
     characters (including the final 'i') are never checked;
   - [Char.lowercase] was removed in OCaml 5 ([Char.lowercase_ascii]). *)
let rec of_string ?(radix = default_radix) s =
  let is_digit =
    function
      | '.'
      | 'a' .. 'f'
      | 'A' .. 'F'
      | '0' .. '9' -> true
      | _ -> false in
  let invalid () = invalid_arg "parse_number" in
    if s = "" then invalid ()
    else
      let length = String.length s in
      (* Scan the maximal run of [is_digit] characters from [start] and parse
         it with a recursive call; returns (value, run length), or (zero, 0)
         when the run is empty. *)
      let extract_digits radix start = 
	let stop = ref start in
	  while !stop < length && is_digit (String.get s !stop) do
	    incr stop
	    done;
	  let len = !stop - start in
	  let sub = String.sub s start len in
	    if len = 0 then zero, 0
	    else 
	      of_string ~radix sub, len in
	(* [prefix] stays true until the first sign or digit is consumed;
	   [number] accumulates the value read so far; [index] is the next
	   character to examine. *)
	let rec iter prefix radix number index =
	  if index = length then number
	  else
	    let nindex = succ index in
	    let add_next_digit digit_value = 
	      iter false radix (add (mul number radix) (of_int digit_value)) nindex 
	    in
	      match String.get s index with
		  '#' ->
		    if prefix then
		      let nnindex = succ nindex in
			match Char.lowercase (String.get s nindex) with
			  | 'e' -> iter prefix radix number nnindex
			  | 'i' -> iter prefix radix (inexact number) nnindex			    
			  | 'b' -> iter prefix (of_int 2) number nnindex
			  | 'o' -> iter prefix (of_int 8) number nnindex 
			  | 'd' -> iter prefix (of_int 10) number nnindex
			  | 'x' -> iter prefix (of_int 16) number nnindex
			  | _ -> invalid ()
		    else
		      inexact (add_next_digit 0)
		| '0' .. '9' as c ->
		    let digit_value = Char.code c - zero_code in
		      add_next_digit digit_value
		(* Hex letters are tried before the 'e'/'f' cases below, so in
		   radix 16 'e' and 'f' are always digits. *)
		| 'a' .. 'f' 
		| 'A' .. 'F' as c when int_of radix = 16 ->
		    let digit_value = Char.code (Char.lowercase c) - a_code + 10 in
		      add_next_digit digit_value
		| '+' -> 
		    if prefix then iter false radix number nindex
		    else 
		      let value, len = extract_digits radix nindex in
			add number (mul value i)
		| '-' -> 
		    if prefix then neg (iter false radix number nindex)
		    else 
		      let value, len = extract_digits radix nindex in
			sub number (mul value i)
		| 'e' ->
		    let value, len = extract_digits radix nindex in
		      mul (inexact number) (pow radix value)
		| '.' -> 
		    let value, len = extract_digits radix nindex in
		    let divisor = pow radix (of_int len) in
		    let nnumber = add number (div value divisor) in
			iter false radix (inexact nnumber) (nindex + len)
		| '/' -> let value, len = extract_digits radix nindex in
		  let nnumber = div number value in
		    iter false radix nnumber (nindex + len)
		| 'F' | 'f' | 's' | 'S' | 'l' | 'L' -> 
		    (* ignore precision *)
		    number
		| _ -> invalid ()
	in iter true radix zero 0
	     
		      
(* [string_of ?radix n] renders [n] as Scheme external syntax.
   - In a radix other than 10, only integral values are supported. They
     are converted with [int_of] and written with digits 0-9 then a, b, ...
   - In radix 10, integers use the same digit writer, floats use
     [string_of_float], and complex values are written "<re>+<im>i" or
     "<re>-<im>i".
   - Inexact numbers get a "#i" prefix.
   @raise Invalid_argument "string_of (radix)" for a non-integral value
   in a radix other than 10.
   @raise Assert_failure if the radix is below 2. *)
let string_of ?(radix = default_radix) n =
  let iradix = int_of radix in
  assert (iradix > 1);
  let char_of_digit digit =
    if digit < 10 then Char.chr (zero_code + digit)
    else Char.chr (a_code + digit - 10) in
  (* Write an int in [iradix]. Digits are taken from the non-positive
     magnitude, so [min_int] (whose negation overflows) is handled. *)
  let string_of_int i =
    let buf = Buffer.create 16 in
    let rec digits m =
      (* invariant: m <= 0 *)
      if m > - iradix then Buffer.add_char buf (char_of_digit (- m))
      else begin
        digits (m / iradix);
        Buffer.add_char buf (char_of_digit (- (m mod iradix)))
      end in
    if i < 0 then Buffer.add_char buf '-';
    digits (if i < 0 then i else - i);
    Buffer.contents buf in
  let v_str =
    if iradix <> 10 then
      if is_integer n then string_of_int (int_of n)
      else invalid_arg "string_of (radix)"
    else
      match n.value with
      | Int i -> string_of_int i
      | Float f -> string_of_float f
      | Complex c ->
        (* The sign is written separately, so print only the magnitude of
           the imaginary part (otherwise 1-2i would come out "1.--2.i"). *)
        string_of_float c.re
        ^ (if c.im >= 0.0 then "+" else "-")
        ^ string_of_float (abs_float c.im)
        ^ "i"
  in
  if n.exactness = Inexact then "#i" ^ v_str else v_str


(* 
let sqrt = 
  function
      Int i -> Float (Pervasives.sqrt (float_of_int i))
    | Float f -> Float (Pervasives.sqrt f)
    | _ -> invalid_arg "sqrt"

*)
