"""
Some simple examples of using FuPy based on the factorial function.

Also see: https://willamette.edu/~fruehr/haskell/evolution.html

Copyright (c) 2024 - Eindhoven University of Technology, The Netherlands

This software is made available under the terms of the MIT License.
"""
from FuPy.core import *
from FuPy.prelude import *
import sys

# `_` is a short alias for FuPy's placeholder `x_`: arithmetic on it builds a
# one-argument function (e.g. `_ + 1` below is the increment function).
# Note: this rebinds the conventional throwaway name `_` for the whole module.
_ = x_


inc = (_ + 1)  # increment by one
dbl = (2 * _)  # double
sqr = la('x: x * x')  # square, written as a FuPy lambda string (`la`)


@func
def factorial(n: int) -> int:
    """Iterative factorial function: factorial(n) = n!.

    Multiplies the factors n, n - 1, ..., 1 into an accumulator. For any
    n < 1 there are no factors, so the result is the empty product 1.
    """
    product_so_far = 1
    for factor in range(n, 0, -1):
        product_so_far *= factor
    return product_so_far


@func
def fac_while(n: int) -> int:
    """Iterative factorial function: fac_while(n) = factorial(n).

    Counts n down towards 0 while accumulating the product. For n <= 0 the
    loop body never runs and the result is the empty product 1, matching
    factorial(n) for every integer n.
    """
    result = 1

    # `n > 0` rather than `n != 0`: for a negative n the counter moves away
    # from 0, so an `!= 0` test would never terminate.
    while n > 0:
        result, n = result * n, n - 1

    return result


@func
def fac(n: int) -> int:
    """Recursive factorial function: fac(n) = factorial(n).

    n is expected to be non-negative: n == 0 is the only base case, so a
    negative n keeps recursing until Python raises RecursionError. Each step
    is a nested Python call, so large n (e.g. fac(1000) in __main__) may also
    exceed the recursion limit.
    """
    # Equivalent, more conventional formulation:
    # if n == 0:
    #     return 1
    # else:
    #     return n * fac(n - 1)
    # traceable one-liner: op.mul is used instead of `*`, presumably so the
    # multiplication shows up as a step in FuPy's trace output.
    return 1 if n == 0 else op.mul(n, fac(n - 1))


# pre_fac(x) is one unfolding of the factorial recursion, with the recursive
# call abstracted as the function argument x:
#   - guard(_ == 0) splits the input n into two cases (presumably the left
#     case when n == 0 and the right case otherwise -- TODO confirm in FuPy);
#   - the left case yields const(1);
#   - the right case pairs n with x(n - 1) via (id_ & x @ (_ - 1)) and
#     multiplies the pair with op.mul.
# define_as/doc attach a printable name and docstring to the resulting Func.
pre_fac = la('x: (const(1) | op.mul @ (id_ & x @ (_ - 1))) @ guard(_ == 0)').define_as("pre_fac").doc(
    """Pre-factorial function using function combinators.
    Generalizes fac with function argument x for recursive call."""
)

# Alternative definition of the same function as a decorated Python function
# (kept for comparison with the la/define_as form above):
# @func
# def pre_fac(x: Func[int, int]) -> Func[int, int]:
#     """Pre-factorial function.
#     """
#     return (const(1) | op.mul @ (id_ & x @ (_ - 1))) @ guard(_ == 0)


@func
def fac_pf(n: int) -> int:
    """Factorial obtained by tying pre_fac's recursive knot by hand.

    The direct definition fac_pf = pre_fac(fac_pf) is not possible in Python,
    because the name would be read before it is bound. Spelled out, we want
        fac_pf = (const(1) | op.mul @ (id_ & fac_pf @ (_ - 1))) @ guard(_ == 0)
    Putting the self-reference inside the function body defers the lookup of
    fac_pf until the function is actually applied.
    """
    one_unfolding = pre_fac(fac_pf)
    return one_unfolding(n)


# Factorial as the fixpoint of pre_fac: fac_f plays the role of the equation
# fac_f = pre_fac(fac_f), which fac_pf above has to emulate by hand.
# define_as/doc attach a printable name and docstring to the resulting Func.
fac_f = fix(pre_fac).define_as("fac_f").doc(
    """Factorial function as fixpoint of pre_fac.
    """
)


@func
def gfac(a: int, n: int) -> int:
    """Generalized (accumulating) factorial: gfac(a, n) = a * fac(n).

    The plain factorial is fac(n) = gfac(1, n). The recursive call is in tail
    position: all pending multiplications are folded into the accumulator a.
    n is expected to be non-negative; a negative n never reaches the base case.
    """
    return a if n == 0 else gfac(a * n, n - 1)


@func
def gfac_lazy(a: int, n: int) -> IterLazy[int]:
    """Lazy generalized tail-recursive factorial: evaluate(gfac_lazy(a, n)) = a * fac(n).
    fac(n) = evaluate(gfac_lazy(1, n))

    Instead of recursing directly, each non-base step returns a Lazy thunk for
    the next call, and `evaluate` drives the thunks to the final value. In
    __main__ this is used for gfac_lazy(1, 1000), where the direct recursion
    fac(1000) may hit the recursion limit. Presumably `evaluate` steps through
    the thunks iteratively (trampoline style) -- confirm in FuPy.
    """
    if n == 0:
        return a
    else:
        # Alternative with a plain Python lambda and an explicit thunk name:
        # return Lazy(lambda: gfac_lazy(a * n, n - 1), name=f"gfac_lazy({a} * {n}, {n} - 1)")
        # `la(': ...')` builds a zero-argument FuPy lambda from the string;
        # presumably it resolves a and n from this call's scope -- TODO confirm.
        return Lazy(la(': gfac_lazy(a * n, n - 1)'))


@func
def gfac_ctr(a: int) -> Func[int, int]:
    """Curried tail recursive generalized factorial: gfac_ctr(a)(n) = a * fac(n).
    fac(n) = gfac_ctr(1)(n)

    Returns a one-argument function of n. For n == 0 it returns the
    accumulator a; otherwise it continues with gfac_ctr(a * n) applied to n - 1.
    """
    return la('n: a if n == 0 else gfac_ctr(a * n)(n - 1)')


@func
def gfac_cps(g: Func[int, int], n: int) -> int:
    """Continuation-Passing Style generalized factorial: gfac_cps(g) = g @ fac.

    Uncurried: gfac_cps(g, n) = g(fac(n)), so fac = gfac_cps(id_).
    Tail recursive: each step composes "multiply by n" into the continuation g
    instead of leaving a pending multiplication on the stack.
    """
    if n != 0:
        return gfac_cps(g @ (n * _), n - 1)
    # Base case: fac(0) = 1 is handed to the accumulated continuation.
    return g(1)


@func
def gfac_fcps(n: int) -> Func[Func[int, int], int]:
    """Flipped Continuation-Passing Style generalized factorial: gfac_fcps(n)(g) = g(fac(n)).
    fac(n) = gfac_fcps(n)(id_)

    gfac_fcps(n) is a function of the continuation g. One unfolding step:
    gfac_fcps(n)(g) = gfac_fcps(n - 1)(g @ (n * _)), i.e. the continuation is
    pre-composed with "multiply by n" before being passed to the smaller case.
    """
    if n == 0:
        # Base case: feed fac(0) = 1 to the continuation.
        return la('g: g(1)')
    else:
        # `_ @ (n * _)` is meant as the map g -> g @ (n * _) (assumes the
        # outer `_` is the placeholder for g -- TODO confirm in FuPy).
        return gfac_fcps(n - 1) @ (_ @ (n * _))


@func
def gfac_fc(n: int) -> Func[int, int]:
    """Flipped curried generalized factorial: gfac_fc(n)(a) = a * fac(n).
    fac(n) = gfac_fc(n)(1)

    gfac_fc(n) builds the pipeline id_ @ (_ * 1) @ (_ * 2) @ ... @ (_ * n)
    without needing a; applying it to a multiplies a by n, n - 1, ..., 1.
    """
    return id_ if n == 0 else gfac_fc(n - 1) @ (_ * n)


@func
def tfac(n: int) -> tuple[int, int]:
    """Tupled factorial: tfac(n) = (factorial(n), n + 1).

    Starting from the pair (1, 1), the step (acc, k) -> (acc * k, k + 1),
    written as op.mul & ((_ + 1) @ second), is applied n times via ``** n``.
    The result is the pair described above, not a bare int. __main__ passes it
    through force(), so it may be returned lazily.
    """
    return ((op.mul & ((_ + 1) @ second)) ** n)(1, 1)


@func
def tfac_left(n: int) -> tuple[int, int]:
    """Tupled factorial: tfac_left(n) = (factorial(n), n + 1).

    Same step function as tfac, i.e. (acc, k) -> (acc * k, k + 1) applied n
    times to (1, 1), but the n-fold power is built with fpower_left instead of
    the ``**`` operator. The result is a pair, not a bare int.
    """
    return fpower_left(op.mul & ((_ + 1) @ second), n)(1, 1)


# Factorial as a fold over the naturals that computes the pair (n!, n + 1) and
# keeps the first component. The name presumably refers to zygomorphisms:
# folds that need an auxiliary result, here the running counter n + 1.
# By operator precedence the algebra reads
#   const(1, 1) | (op.mul & ((_ + 1) @ second))
# i.e. zero -> (1, 1) and a successor of (acc, k) -> (acc * k, k + 1)
# (assumes const(1, 1) yields the pair (1, 1) -- TODO confirm in FuPy).
fac_zygo = first @ cata_nat(const(1, 1) | op.mul & (_ + 1) @ second)

# Factorial as a hylomorphism: unfold n into the list [n, n - 1, ..., 1]
# (anamorphism), then fold that list with multiplication (catamorphism).

# List algebra: the empty list maps to 1, a cons cell (x, acc) to x * acc
# (assumes the left summand of the list functor is the empty-list case).
product_alg = const(1) | op.mul
product = cata_list(product_alg)
# List coalgebra: 0 stops the unfolding (const(unit), the empty case). Any
# other n emits n and continues with seed n - 1 (id_ & (_ - 1)). As in
# pre_fac, guard(_ == 0) presumably selects the left summand when n == 0.
# A negative seed never reaches 0, so the unfolding would not stop.
down_from_coalg = (const(unit) + (id_ & (_ - 1))) @ guard(_ == 0)
down_from = ana_list(down_from_coalg)
# Two-stage version: build the intermediate list, then multiply it out.
fac_hylo = product @ down_from
# Fused version combining algebra and coalgebra directly via hylo(listf),
# presumably without materializing the intermediate list.
fac_hylo_fused = hylo(listf)(product_alg, down_from_coalg)

if __name__ == '__main__':
    # --- Basic combinators ---------------------------------------------------
    # trace(thunk, live=True, ...) evaluates the thunk while showing its
    # evaluation steps; skip_steps={DefinitionStep} presumably leaves
    # definition-unfolding steps out of the output. Where the value is needed,
    # trace(...)[0] is used as the computed result.
    print(inc)
    print(dbl)
    print(sqr)

    # Composition: (inc @ dbl)(x) applies dbl first, then inc.
    print(inc @ dbl)
    print("tracing (inc @ dbl)(3):")
    trace(lambda: (inc @ dbl)(3), live=True, skip_steps={DefinitionStep})

    # Case analysis on left/right values.
    print(inc | dbl)
    print("tracing (inc | dbl)(left(3)):")
    trace(lambda: (inc | dbl)(left(3)), live=True, skip_steps={DefinitionStep})
    print("tracing (inc | dbl)(right(3)):")
    trace(lambda: (inc | dbl)(right(3)), live=True, skip_steps={DefinitionStep})

    # Split one input into a pair of results, then project with first/second.
    print(inc & dbl)
    print("tracing pair = (inc & dbl)(3)")
    pair = trace(lambda: (inc & dbl)(3), live=True, skip_steps={DefinitionStep})[0]
    print("tracing first(pair):")
    trace(lambda: first(pair), live=True, skip_steps={DefinitionStep})
    print("tracing second(pair):")
    trace(lambda: second(pair), live=True, skip_steps={DefinitionStep})

    # Map over a sum: inc on left values, dbl on right values.
    print(inc + dbl)
    print("tracing (inc + dbl)(left(3)):")
    trace(lambda: (inc + dbl)(left(3)), live=True, skip_steps={DefinitionStep})
    print("tracing (inc + dbl)(right(3)):")
    trace(lambda: (inc + dbl)(right(3)), live=True, skip_steps={DefinitionStep})

    # Map over a pair: inc on the first component, dbl on the second.
    print(inc * dbl)
    print("tracing force((inc * dbl)(1, 3)):")
    print(trace(lambda: force((inc * dbl)(1, 3)), live=True, skip_steps={DefinitionStep})[0])

    # Repeated composition.
    print(sqr ** 3)
    print("tracing (sqr ** 3)(2):")
    trace(lambda: (sqr ** 3)(2), live=True, skip_steps={DefinitionStep})

    # --- Factorial variants on small inputs ------------------------------------
    print()
    print(f"Limit on recursion depth: {sys.getrecursionlimit()}")
    print(f"factorial(5) = {factorial(5)}")
    print(f"fac_while(5) = {fac_while(5)}")
    print("tracing fac(3):")
    trace(lambda: fac(3), live=True)
    print(f"fac(5) = {fac(5)}")
    print("tracing pre_fac(const(2)):")
    trace(lambda: pre_fac(const(2)), live=True, skip_steps={DefinitionStep})
    print("tracing force(pre_fac(const(2))(3)):")
    print(trace(lambda: force(pre_fac(const(2))(3)), live=True, skip_steps={DefinitionStep})[0])
    print(f"fac_pf(5) = {force(fac_pf(5))}")
    print(fac_f)
    print("tracing force(fac_f(3)):")
    print(trace(lambda: force(fac_f(3)), live=True)[0])
    print(f"gfac(1, 5) = {gfac(1, 5)}")
    print(f"evaluate(gfac_lazy(1, 5)) = {evaluate(gfac_lazy(1, 5))}")
    print(f"gfac_ctr(1) = {gfac_ctr(1)}")
    print(f"gfac_ctr(1)(5) = {gfac_ctr(1)(5)}")
    print(f"gfac_fc(0) = {gfac_fc(0)}")
    print(f"gfac_fc(5) = {gfac_fc(5)}")
    print(f"gfac_fc(5)(1) = {gfac_fc(5)(1)}")
    print(f"gfac_cps(id_, 5) = {gfac_cps(id_, 5)}")
    print("tracing gfac_cps(id_, 3):")
    trace(lambda: gfac_cps(id_, 3), live=True)
    print(f"gfac_fcps(0) = {gfac_fcps(0)}")
    print(f"gfac_fcps(5) = {gfac_fcps(5)}")
    print(f"gfac_fcps(5)(id_) = {gfac_fcps(5)(id_)}")
    print(f"tfac(5) = {force(tfac(5))}")
    print("tracing force(tfac(3)):")
    print(trace(lambda: force(tfac(3)), live=True, skip_steps={DefinitionStep})[0])
    print(f"tfac_left(5) = {force(tfac_left(5))}")
    print("tracing force(tfac_left(3)):")
    print(trace(lambda: force(tfac_left(3)), live=True, skip_steps={DefinitionStep})[0])
    print(f"fac_zygo = {fac_zygo}")
    print(f"fac_zygo(5) = {fac_zygo(5)}")
    print(f"fac_hylo = {fac_hylo}")
    print("tracing product([2, 3]):")
    trace(lambda: product([2, 3]), live=True)
    print(f"product([1, 2, 3, 4, 5]) = {force(product([1, 2, 3, 4, 5]))}")
    print("tracing down_from(2):")
    trace(lambda: down_from(2), live=True)
    print(f"down_from(5) = {down_from(5)}")
    print(f"fac_hylo(5) = {force(fac_hylo(5))}")
    print(f"fac_hylo_fused = {fac_hylo_fused}")
    print(f"fac_hylo_fused(5) = {force(fac_hylo_fused(5))}")

    # --- Large input: iterative vs. recursive vs. lazy -------------------------
    # fac(1000) needs one nested Python call per step and may exceed the
    # recursion limit printed above; the RecursionError is reported, not raised.
    print()
    print(f"factorial(1000) = {factorial(1000)}")
    try:
        print(f"fac(1000) = {fac(1000)}")
    except RecursionError as e:
        print(e)
    print("tracing evaluate(gfac_lazy(1, 3)):")
    trace(lambda: evaluate(gfac_lazy(1, 3)), live=True)
    print(f"evaluate(gfac_lazy(1, 1000)) = {evaluate(gfac_lazy(1, 1000))}")
