import math, sys

from rpython.rlib.jit import JitDriver, set_param, unroll_safe
from rpython.rlib.rfloat import string_to_float
from rpython.rlib.rstring import StringBuilder, split

from cammylib.arrows import F, P, TypeFail
from cammylib.hax import patch_ctypes_for_ffi
from cammylib.parser import parse
from cammylib import stb


# Borrowed from stub.scm and Typhon's colors.mt

def scale(bot, top, x):
    """Linearly interpolate between bot (at x == 0) and top (at x == 1)."""
    return bot + (top - bot) * x

def linear2sRGB(u):
    """Encode a linear-light channel value with the sRGB transfer curve.

    Values at or below 0.0031308 use the linear segment, u * 12.92
    (written as the exact rational 323 / 25). Larger values use
    1.055 * u ** (1 / 2.4) - 0.055, written with the exact rationals
    211 / 200, 11 / 200 and 5 / 12.

    All constants are float literals so that RPython / Python 2 integer
    division cannot truncate the exponent. Returns 1.0 if the power
    computation overflows. The result is not clamped here.
    """
    if u <= 0.0031308:
        return u * 323.0 / 25.0
    try:
        return (211.0 * math.pow(u, 5.0 / 12.0) - 11.0) / 200.0
    except OverflowError:
        return 1.0

def finishChannel(c):
    """Convert a linear-light channel value to an 8-bit sRGB byte value.

    The value is encoded with linear2sRGB, clamped to [0, 1], and scaled
    to [0, 255]. It is rounded to the nearest integer rather than
    truncated. Truncation biased every channel down by half a step and
    only produced 255 for an input of exactly 1.0.
    """
    return int(255.0 * max(0.0, min(1.0, linear2sRGB(c))) + 0.5)

# For a window of area 4 (e.g. [-1, 1] x [-1, 1]) split into w * h pixels,
# each pixel covers area 4 / (w * h). A disc of that area has radius
# sqrt(area / pi) = 2 / sqrt(w * h * pi) = (2 / sqrt(pi)) / sqrt(w * h).
# This constant is the 2 / sqrt(pi) factor of that radius.
# (Kept ASCII-only: the file has no coding declaration, so Python 2 rejects
# non-ASCII bytes even inside comments.)
TWOSQRTPI = 2.0 / math.sqrt(math.pi)

class Window(object):
    """A rectangle of the plane rasterised into width x height pixels.

    `corners` holds four numbers (x0, y0, x1, y1). Pixel columns span x0..x1
    and pixel rows span y0..y1.
    """
    # RPython hint: instances are never mutated after construction.
    _immutable_ = True

    def __init__(self, corners, width, height):
        self._corners = corners[0], corners[1], corners[2], corners[3]
        self._w = width
        self._h = height
        area = abs((corners[0] - corners[2]) * (corners[1] - corners[3]))
        # Radius of a disc with the same area as one pixel:
        # sqrt(pixel area / pi) = sqrt(area / (pi * w * h)).
        # For a window of area 4 this is (2 / sqrt(pi)) / sqrt(w * h).
        self.pixelRadius = math.sqrt(area / (math.pi * width * height))

    def coordsForPixel(self, i):
        """Return the plane coordinates (c1, c2) of the centre of pixel `i`.

        Pixels are numbered in row-major order: i % width is the column and
        i // width is the row. The half-pixel offsets (dw, dh) move from a
        pixel's corner to its centre.
        """
        w = i % self._w
        h = i // self._w
        iw = 1.0 / self._w
        dw = 0.5 * iw
        ih = 1.0 / self._h
        dh = 0.5 * ih
        c1 = scale(self._corners[0], self._corners[2], dw + iw * w)
        c2 = scale(self._corners[1], self._corners[3], dh + ih * h)
        return c1, c2

# JIT driver used by `sample`. `program` is the green variable (part of the
# trace key); x and y are red (values that vary between calls).
# NOTE(review): is_recursive=True is presumably required because
# program.run may re-enter JIT-ed code -- TODO confirm.
sample_driver = JitDriver(name="sample",
        greens=["program"], reds=["x", "y"],
        is_recursive=True)

def sample(program, x, y):
    """Run `program` on the point (x, y) and return its colour as (r, g, b).

    The point is passed to the program as P(F(x), F(y)). The result is
    read as a nested pair (r, (g, b)), and each component is unwrapped with
    .f() (presumably to a float -- callers average and clamp them as floats).
    """
    # Merge point for sample_driver: `program` is green; x and y are red.
    sample_driver.jit_merge_point(program=program, x=x, y=y)
    rgb = program.run(P(F(x), F(y)))
    r = rgb.first().f()
    g = rgb.second().first().f()
    b = rgb.second().second().f()
    return r, g, b

# Sample offsets, in units of the pixel radius used by multisample: the
# centre first, then the four diagonal neighbours (each at distance sqrt(2)).
offsets = [(0.0, 0.0)] + [(sx, sy) for sx in (1.0, -1.0) for sy in (1.0, -1.0)]

@unroll_safe
def multisample(program, radius, x, y):
    """Average `sample` over the points of `offsets` around (x, y).

    Each offset is scaled by `radius` before sampling. The averaged channel
    values are passed through finishChannel, so the result is a triple of
    ints.
    """
    # @unroll_safe: the loop runs over the fixed-length `offsets` list, so the
    # JIT may unroll it.
    totalR = 0.0
    totalG = 0.0
    totalB = 0.0
    for dx, dy in offsets:
        sr, sg, sb = sample(program, x + radius * dx, y + radius * dy)
        totalR += sr
        totalG += sg
        totalB += sb
    count = len(offsets)
    return (finishChannel(totalR / count),
            finishChannel(totalG / count),
            finishChannel(totalB / count))

def drawPixels(size, program, window):
    """Render pixels 0 .. size-1 of `window` and return them as raw bytes.

    The result holds three bytes (r, g, b) per pixel, in pixel-index order.
    """
    radius = window.pixelRadius
    out = StringBuilder()
    for index in range(size):
        c1, c2 = window.coordsForPixel(index)
        red, green, blue = multisample(program, radius, c1, c2)
        out.append(chr(red))
        out.append(chr(green))
        out.append(chr(blue))
    return out.build()

def drawPNG(program, filename, corners, width, height):
    """Render `program` over `corners` at width x height and save it as a PNG.

    Pixels are packed as 3-channel RGB rows of width * 3 bytes and handed
    to stb's PNG writer.
    """
    channels = 3  # r, g, b; no alpha
    rowStride = width * channels
    pixels = drawPixels(width * height, program, Window(corners, width, height))
    stb.i_write_png(filename, width, height, channels, pixels, rowStride)


def main(argv):
    set_param(None, "trace_limit", 50001)

    prog = argv[1]
    window = [string_to_float(s) for s in split(argv[2])]
    width = int(argv[3])
    height = int(argv[4])
    out = argv[5]
    with open(prog) as handle:
        sexp, trail = parse(handle.read())
    func = sexp.buildArrow()
    try:
        drawPNG(func, out, window, width, height)
    except TypeFail as tf:
        print "Type failure:", tf.reason
        raise
    return 0


def target(driver, *args):
    """Translation hook: configure the driver and return the entry point.

    Presumably invoked by the RPython translator. It applies the ctypes FFI
    patch, names the produced executable "cammy-draw", and returns `main`
    with no argument-annotation override (None).
    """
    patch_ctypes_for_ffi()
    driver.exe_name = "cammy-draw"
    entryPoint = main
    return entryPoint, None


if __name__ == "__main__":
    # Running the module directly (untranslated) uses main's return value as
    # the process exit status.
    sys.exit(main(sys.argv))
