/// SPDX-License-Identifier: GPL-3.0-or-later
/// SPDX-FileCopyrightText: Copyright © 2024 Tony Garnock-Jones <tonyg@leastfixedpoint.com>

import { Span } from './span';
import * as Ast from './ast';
import { Environment as GenericEnvironment, Scope, Closure, Value as GenericValue, lookup, makeDeepClosure, Primitive, definedNames, RecType, PrimitiveFunction, isRecord } from './values';
import { defPrim, makeTop, mkPrim, N } from './topenv';
import { exhaustive } from 'never';

/**
 * Runtime values of this interpreter: the generic value universe from
 * './values', instantiated with a `null` context argument for primitives,
 * user-defined closures represented as `{closure}` objects, and vectors
 * represented as plain JS arrays of values.
 */
export type Value = GenericValue<null, {closure: Closure<Environment>}, Value[]>;
/** Link from a frame to its enclosing environment (the `next` of a frame). */
export type EnvLink = { env: Environment };
/**
 * A chain of frames. A binding of `null` marks a name that has been declared
 * but not yet initialized; the outermost link points at a `null` environment.
 */
export type Environment = GenericEnvironment<Value | null, EnvLink> | null;

/**
 * Error raised for faults in the program being interpreted (unbound or
 * uninitialized variables, bad calls, unmatched clauses, arity mismatches),
 * tagged with the source span where the fault was detected.
 */
export class InterpreterError extends Error {
    constructor(public span: Span, message: string) {
        super(message);
        // When compiled to ES5, subclassing built-ins like Error loses the
        // prototype chain and breaks `instanceof InterpreterError` (which
        // `check` handling relies on). Restoring it is a no-op on ES2015+.
        Object.setPrototypeOf(this, new.target.prototype);
        // Make uncaught errors identify themselves in stack traces and logs.
        this.name = 'InterpreterError';
    }
}

/**
 * Build the interpreter's top-level environment: whatever primitives
 * `makeTop` provides, plus vector primitives backed by plain JS arrays.
 *
 * - `vec` (variadic, zero or more args) returns a new vector of its arguments.
 * - `vec-ref v i` returns element `i` of vector `v`.
 * - `vec-set! v i w` stores `w` at index `i` of `v` and returns `w`.
 *
 * Indices must be integers in `[0, length)`. Anything else raises an
 * InterpreterError instead of yielding `undefined` (not a valid Value) or
 * silently growing the array with holes.
 */
export function makeInterpreterTop(): Environment {
    /** Narrow `v` to a vector, or fail at `span`. */
    function V(span: Span, v: Value): Value[] {
        if (Array.isArray(v)) return v;
        throw new InterpreterError(span, 'expected a vector');
    }
    /** Convert `i` to a number (via `N`) and check it indexes an existing element of `vec`. */
    function index(span: Span, vec: Value[], i: Value): number {
        const k = N(span, i);
        if (!Number.isInteger(k) || k < 0 || k >= vec.length) {
            throw new InterpreterError(span, `vector index ${k} out of range (length ${vec.length})`);
        }
        return k;
    }
    const inj: (f: Primitive<null, {closure: Closure<Environment>}, Value[]>) => Value = f => f;
    const top = makeTop<Value, EnvLink, null, {closure: Closure<Environment>}, Value[]>(inj, { env: null });
    defPrim(top, inj, -1, { 'vec': (_ctx, _s, ...args) => args });
    defPrim(top, inj, 3, {
        'vec-set!': (_ctx, s, v, i, w) => {
            const vec = V(s, v);
            vec[index(s, vec, i)] = w;
            return w;
        },
    });
    defPrim(top, inj, 2, {
        'vec-ref': (_ctx, s, v, i) => {
            const vec = V(s, v);
            return vec[index(s, vec, i)];
        },
    });
    return top;
}

/**
 * Run a whole program in a fresh frame chained onto `top`.
 *
 * Every name the program defines is pre-declared as `null` (uninitialized)
 * before any term runs, so references to those names resolve to this frame.
 * Each term's non-null result is logged to the console as it is produced.
 *
 * @returns the result of the last term, or `null` if the program is empty or
 *   its last term produced no value (e.g. it was a definition).
 */
export function interpretProgram(top: Environment, p: Ast.Program): Value | null {
    // `const` (not `let`) keeps `env` narrowed to non-null inside the callback
    // below on every TypeScript version; it is never reassigned.
    const env: Environment = {scope: {}, next: { env: top }};
    definedNames(p.terms).forEach(v => env.scope[v.name] = null);
    let result: Value | null = null;
    for (const t of p.terms) {
        result = interpretTerm(env, t);
        if (result !== null) console.log(result);
    }
    return result;
}

/**
 * Execute one term in `env`: a definition, a `check`, or a bare expression.
 *
 * - `defvar` / `deffun` / `defrec` store into the innermost frame of `env`
 *   (`env.scope`) and yield `null`. `defrec` binds the record name to a
 *   constructor primitive of the record's arity.
 * - `check` yields the symbol `pass`; `fail` if its body evaluated to `false`;
 *   or `error` if its body raised an InterpreterError. The last two are also
 *   logged via `console.error`. Any other exception propagates.
 * - Anything else is an expression, whose value is returned.
 */
export function interpretTerm(env: Environment, t: Ast.Term): Value | null {
    if (t.type === 'defvar') {
        const frame = env!.scope;
        frame[t.name.name] = interpretExpr(env, t.expr);
        return null;
    }

    if (t.type === 'deffun') {
        const frame = env!.scope;
        frame[t.name.name] = { closure: makeDeepClosure(env, t.formals, t.body) };
        return null;
    }

    if (t.type === 'defrec') {
        const typeName = t.name.name;
        const recType: RecType = {
            name: Symbol.for(typeName),
            fields: t.fields.map(field => Symbol.for(field.name)),
        };
        // The constructor primitive's own name wraps the record name in angle brackets.
        const ctorName = `〈${typeName}〉`;
        const prims = mkPrim(t.fields.length, {
            [ctorName]: (_ctx: null, _s: Span, ...args: Value[]) => ({ record: recType, fields: args }),
        });
        env!.scope[typeName] = prims[ctorName];
        return null;
    }

    if (t.type === 'check') {
        let failed = false;
        try {
            failed = interpretBody(env, t.body) === false;
        } catch (err) {
            // Only interpreter-level faults count as a check "error"; anything else is a real bug.
            if (!(err instanceof InterpreterError)) throw err;
            console.error('Failed check: exception: ' + err.message);
            return Symbol.for('error');
        }
        if (failed) {
            console.error('Failed check: outcome false');
            return Symbol.for('fail');
        }
        return Symbol.for('pass');
    }

    return interpretExpr(env, t);
}

/**
 * Validate the number of arguments supplied to a call.
 *
 * A non-negative `expected` demands exactly that many arguments. A negative
 * `expected` encodes a variadic callee needing at least `-1 - expected`
 * arguments (so `-1` means "zero or more", `-2` "one or more", ...).
 *
 * @throws InterpreterError at `span` when `actual` does not satisfy the arity.
 */
export function checkArgCount(span: Span, actual: number, expected: number): void {
    if (expected >= 0) {
        if (actual === expected) return;
        throw new InterpreterError(span, `Expected ${expected} arguments, got ${actual}`);
    }
    const atLeast = -1 - expected;
    if (actual < atLeast) {
        throw new InterpreterError(span, `Expected at least ${atLeast} arguments, got ${actual}`);
    }
}

/**
 * Evaluate expression `e` in environment `env` and return its value.
 *
 * @throws InterpreterError when a variable is unbound or still uninitialized,
 *   when no `cond`/`matchrec` clause applies and there is no else-clause, when
 *   calling a value that is neither a primitive nor a closure, and on arity
 *   mismatches.
 */
export function interpretExpr(env: Environment, e: Ast.Expr): Value {
    switch (e.type) {
        case 'const': return e.literal;
        case 'var': {
            const v = lookup(l => l.env, env, e.name);
            if (v === 'missing') throw new InterpreterError(e.span, `Variable ${e.name} missing`);
            // Names pre-declared by their enclosing body/program hold `null` until defined.
            if (v.value === null) throw new InterpreterError(e.span, `Variable ${e.name} uninitialized`);
            return v.value;
        }
        case 'begin':
            // The prelude runs for effect only; the final expression is the result.
            e.prelude.forEach(f => interpretExpr(env, f));
            return interpretExpr(env, e.expr);
        case 'matchrec': {
            const v = interpretExpr(env, e.expr);
            if (isRecord(v)) {
                for (const c of e.clauses) {
                    // A clause matches on the record type's name and its field count.
                    if (v.record.name.description! === c.typeName.name
                        && v.fields.length === c.fieldNames.length)
                    {
                        const inner: Environment = {scope: {}, next: { env }};
                        c.fieldNames.forEach((n, i) => inner.scope[n.name] = v.fields[i]);
                        return interpretBody(inner, c.body);
                    }
                }
            }
            if (e.elseClause) {
                return interpretBody(env, e.elseClause);
            }
            throw new InterpreterError(e.span, `matchrec had no clause that matched`);
        }
        case 'set': {
            const v = lookup(l => l.env, env, e.varName.name);
            if (v === 'missing') throw new InterpreterError(e.span, `Variable ${e.varName.name} missing`);
            if (v.value === null) throw new InterpreterError(e.span, `Variable ${e.varName.name} uninitialized`);
            const w = interpretExpr(env, e.expr);
            v.value = w;
            return w;
        }
        case 'lambda': return {closure: makeDeepClosure(env, e.formals, e.body)};
        case 'let': {
            // Initializers are evaluated in the outer `env`, so bindings cannot see each other.
            const inner: Environment = {scope: {}, next: { env }};
            e.bindings.forEach(b => inner.scope[b.varName.name] = interpretExpr(env, b.expr));
            return interpretBody(inner, e.body);
        }
        case 'cond':
            for (const c of e.clauses) {
                // Every value other than `false` counts as true.
                if (interpretExpr(env, c.test) !== false) {
                    return interpretBody(env, c.body);
                }
            }
            if (e.elseClause) {
                return interpretBody(env, e.elseClause);
            }
            throw new InterpreterError(e.span, `cond had no clause that matched`);
        case 'call':
            return interpretCall(env, e);
        default: exhaustive(e);
    }
}

/**
 * Evaluate a call expression: the operator first, then (only if it is
 * callable) the actuals left to right, then apply.
 *
 * Primitives (JS functions) receive a `null` context and the call's span.
 * Closures run their body in a fresh frame that binds the formals and is
 * chained onto the closure's captured environment. The "not callable" error
 * is thrown explicitly rather than reached by switch fall-through.
 */
function interpretCall(env: Environment, e: Extract<Ast.Expr, { type: 'call' }>): Value {
    const op = interpretExpr(env, e.operator);
    if (typeof op === 'function') {
        const args = e.actuals.map(f => interpretExpr(env, f));
        checkArgCount(e.span, args.length, op.argv);
        return op(null, e.span, ...args);
    }
    // Records and vectors are objects too, but they are not callable.
    if (typeof op === 'object' && !('record' in op) && !Array.isArray(op)) {
        const { closure } = op;
        const args = e.actuals.map(f => interpretExpr(env, f));
        checkArgCount(e.span, args.length, closure.formals.length);
        const inner: Environment = {scope: {}, next: { env: closure.env }};
        closure.formals.forEach((formal, i) => inner.scope[formal.name] = args[i]);
        return interpretBody(inner, closure.body);
    }
    throw new InterpreterError(e.operator.span, 'attempted to call something not callable');
}

/**
 * Evaluate a body in a fresh frame chained onto `env`: run its terms in
 * order, then return the value of its trailing expression.
 *
 * Every name the body defines is pre-declared as `null` first, so references
 * to those names resolve to this frame. Evaluating one before its definition
 * has run reports it as uninitialized.
 */
export function interpretBody(env: Environment, body: Ast.Body): Value {
    const frame: Environment = { scope: {}, next: { env } };
    definedNames(body.terms).forEach(declared => frame.scope[declared.name] = null);
    for (const term of body.terms) {
        interpretTerm(frame, term);
    }
    return interpretExpr(frame, body.expr);
}
