from cammylib.arrows import Given, buildUnary, buildCompound

# Atom symbols that Atom.canonicalize returns unchanged; any other atom
# symbol is handed to hive.load.
BASIS_ATOMS = (
    "id", "ignore",
    "fst", "snd",
    "left", "right",
    "zero", "succ",
    "nil", "cons",
    "t", "f", "not", "conj", "disj", "either",
    "f-zero", "f-one", "f-pi",
    "f-sign", "f-floor", "f-negate", "f-recip",
    "f-lt", "f-add", "f-mul",
    "f-sqrt", "f-sin", "f-cos", "f-atan2",
)

# Functor heads that Functor.canonicalize rebuilds in place (around their
# canonicalized arguments); any other head is handed to hive.load.
BASIS_FUNCTORS = (
    "comp",
    "pair",
    "case",
    "curry",
    "uncurry",
    "pr",
    "fold",
)


class SExp(object):
    """An S-expression.

    Abstract base of Atom, Functor and Hole. The methods below are the
    protocol those subclasses implement and call on each other recursively;
    the stubs here only name that protocol and fail loudly if a subclass
    does not provide one of them.
    """

    def asStr(self):
        "Render this S-expression as text."
        raise NotImplementedError("asStr")

    def substitute(self, args):
        "Return a copy with each Hole replaced by the matching entry of args."
        raise NotImplementedError("substitute")

    def canonicalize(self, hive):
        "Return a copy with non-basis names expanded through hive.load."
        raise NotImplementedError("canonicalize")

    def occurs(self, index):
        "Return True if a Hole with the given index appears in this S-exp."
        raise NotImplementedError("occurs")

    def extractType(self, extractor):
        "Render this S-expression as a type string, consulting extractor."
        raise NotImplementedError("extractType")

    def buildArrow(self):
        "Build the arrow object this S-expression denotes."
        raise NotImplementedError("buildArrow")


class Atom(SExp):
    """An S-expression atom: a bare symbol with no arguments.

    An atom contains no holes, so substitution and hole queries are trivial.
    """

    _immutable_fields_ = ("symbol",)

    def __init__(self, symbol):
        self.symbol = symbol

    def asStr(self):
        "Render the atom as its symbol."
        return self.symbol

    def substitute(self, args):
        "Atoms contain no holes, so the atom itself is returned."
        return self

    def canonicalize(self, hive):
        """Return self for a basis atom, else whatever hive.load returns.

        Basis atoms are those listed in BASIS_ATOMS.
        """
        if self.symbol not in BASIS_ATOMS:
            return hive.load(self.symbol)
        return self

    def occurs(self, index):
        "No hole can occur inside an atom."
        return False

    def extractType(self, extractor):
        "An atomic type is rendered as its symbol; extractor is not used."
        return self.symbol

    def buildArrow(self):
        "Build this atom's arrow with buildUnary."
        return buildUnary(self.symbol)


class Functor(SExp):
    """A list of S-expressions with a distinguished head.

    ``constructor`` is the head symbol and ``arguments`` is the list of
    child S-expressions.
    """

    _immutable_fields_ = "constructor", "arguments[:]"

    def __init__(self, constructor, arguments):
        self.constructor = constructor
        self.arguments = arguments

    def asStr(self):
        "Render as ``(constructor arg1 arg2 ...)``."
        args = " ".join([arg.asStr() for arg in self.arguments])
        return "(%s %s)" % (self.constructor, args)

    def substitute(self, args):
        "Return a new Functor with every argument substituted with args."
        return Functor(self.constructor,
                [arg.substitute(args) for arg in self.arguments])

    def canonicalize(self, hive):
        """Canonicalize the arguments, then expand a non-basis head.

        A head listed in BASIS_FUNCTORS is rebuilt around the canonicalized
        arguments. Any other head is loaded with hive.load, and the loaded
        expression is substituted with the canonicalized arguments.
        """
        args = [arg.canonicalize(hive) for arg in self.arguments]
        if self.constructor in BASIS_FUNCTORS:
            return Functor(self.constructor, args)
        else:
            functor = hive.load(self.constructor)
            return functor.substitute(args)

    def occurs(self, index):
        "Return True if any argument contains the hole numbered index."
        for arg in self.arguments:
            if arg.occurs(index):
                return True
        return False

    def extractType(self, extractor):
        """Render this type expression as a string.

        Every argument must be a Hole (enforced by unhole); each hole's
        index is passed to extractor.extract to get the argument's text.
        Recognized constructors are hom, pair, sum and list.

        Raises ValueError for any other constructor. This replaces the
        former ``assert False``, which vanishes when assertions are
        disabled and would let the method silently return None.
        """
        args = [extractor.extract(unhole(arg)) for arg in self.arguments]
        if self.constructor == "hom":
            return "[%s, %s]" % (args[0], args[1])
        elif self.constructor == "pair":
            return "(%s x %s)" % (args[0], args[1])
        elif self.constructor == "sum":
            return "(%s + %s)" % (args[0], args[1])
        elif self.constructor == "list":
            return "[%s]" % args[0]
        else:
            raise ValueError("unknown type constructor: %s"
                             % self.constructor)

    def buildArrow(self):
        "Build the arrow via buildCompound from the arguments' arrows."
        args = [arg.buildArrow() for arg in self.arguments]
        return buildCompound(self.constructor, args)


class Hole(SExp):
    """A hole where an S-expression could be, identified by its index.

    substitute() fills the hole from a positional list of arguments.
    """

    _immutable_fields_ = ("index",)

    def __init__(self, index):
        self.index = index

    def asStr(self):
        "Render as ``@`` followed by the index."
        return "@%d" % (self.index,)

    def substitute(self, args):
        "Fill this hole with the argument at its index."
        filler = args[self.index]
        return filler

    def canonicalize(self, hive):
        "Holes are left as they are; hive is not consulted."
        return self

    def occurs(self, index):
        "True exactly when this is the hole numbered index."
        return index == self.index

    def extractType(self, extractor):
        "Delegate to extractor.findTypeAlias with this hole's index."
        return extractor.findTypeAlias(self.index)

    def buildArrow(self):
        "Build a Given arrow referring to this hole's index."
        return Given(self.index)

def unhole(sexp):
    "Return the index of sexp, which must be a Hole."
    # Receiving anything other than a Hole is an internal bug, hence the
    # assert with an "implementation error" message rather than user-facing
    # validation.
    assert isinstance(sexp, Hole), "implementation error"
    index = sexp.index
    return index
