/// SPDX-License-Identifier: GPL-3.0-or-later
/// SPDX-FileCopyrightText: Copyright © 2024 Tony Garnock-Jones <tonyg@leastfixedpoint.com>

import * as Ast from './ast';
import { definedNames } from './values';
import './sets';
import { exhaustive } from './never';

/**
 * Outcome of analysing a piece of syntax for variable usage.
 *
 * The `*Names` fields hold variable names as strings; the `*References`
 * fields hold the `Ast.Var` nodes themselves.
 */
export type Result = {
    /** Names of variables that occur free. */
    freeNames: Set<string>;
    /** The `Ast.Var` nodes corresponding to free occurrences. */
    freeReferences: Set<Ast.Var>;
    /** Names of free variables that are mutated. */
    mutatedNames: Set<string>;
    /** The `Ast.Var` nodes corresponding to mutations of free variables. */
    mutatedReferences: Set<Ast.Var>;
};

export namespace Result {
    /**
     * A result with all four sets empty.
     *
     * This is a single shared object; its sets should be treated as
     * read-only, since any mutation would be visible to every user.
     */
    export const EMPTY: Result = {
        freeNames: new Set(),
        freeReferences: new Set(),
        mutatedNames: new Set(),
        mutatedReferences: new Set(),
    };

    /**
     * Builds the result for a single free occurrence of a variable.
     *
     * @param r the variable node that occurs free
     * @returns a result whose free sets contain just `r` (and its name) and
     *          whose mutated sets are empty
     */
    export function free(r: Ast.Var): Result {
        const freeNames = new Set([r.name]);
        const freeReferences = new Set([r]);
        return {
            freeNames,
            freeReferences,
            mutatedNames: new Set(),
            mutatedReferences: new Set(),
        };
    }

    /**
     * Builds the result for a single mutation of a variable. A mutated
     * variable is recorded as free as well as mutated.
     *
     * @param r the variable node being assigned to
     * @returns a result in which all four sets contain just `r` (or its name)
     */
    export function mutated(r: Ast.Var): Result {
        // Start from the "free" result and replace the two (empty) mutated
        // sets with fresh singleton sets; no set instance is shared
        // between fields.
        return {
            ... free(r),
            mutatedNames: new Set([r.name]),
            mutatedReferences: new Set([r]),
        };
    }
}

/**
 * Merges any number of results into one, field by field.
 *
 * Each field of the output is `Set.union` applied to the corresponding
 * field of every input, in argument order.
 *
 * @param rs the results to merge
 * @returns the merged result
 */
export function combine(... rs: Result[]): Result {
    const freeNameSets = rs.map(({ freeNames }) => freeNames);
    const freeReferenceSets = rs.map(({ freeReferences }) => freeReferences);
    const mutatedNameSets = rs.map(({ mutatedNames }) => mutatedNames);
    const mutatedReferenceSets = rs.map(({ mutatedReferences }) => mutatedReferences);

    return {
        freeNames: Set.union(... freeNameSets),
        freeReferences: Set.union(... freeReferenceSets),
        mutatedNames: Set.union(... mutatedNameSets),
        mutatedReferences: Set.union(... mutatedReferenceSets),
    };
}

/**
 * The set of variable names considered bound at the point being analysed.
 * It is threaded through every analysis function as `bn`, and `_scope`
 * extends it with the names that a scope introduces.
 */
export type BoundNames = Set<string>;

/**
 * Analyses a body as a scope of its own that introduces no names beyond
 * those found among the body's own terms.
 *
 * @param bn names bound in the enclosing context
 * @param b the body to analyse
 * @returns the variable-usage result for `b`
 */
export function body(bn: BoundNames, b: Ast.Body): Result {
    const noExtraNames: string[] = [];
    return scope(bn, noExtraNames, b);
}

/**
 * Analyses each term of a sequence against the same bound names and merges
 * the individual results.
 *
 * @param bn names bound in the enclosing context
 * @param ts the terms to analyse
 * @returns the combined result for all of `ts`
 */
export function terms(bn: BoundNames, ts: Ast.Term[]): Result {
    const perTerm: Result[] = [];
    for (const t of ts) {
        perTerm.push(term(bn, t));
    }
    return combine(... perTerm);
}

/**
 * Analyses a single term, dispatching on its `type` tag.
 *
 * - `const` and `defrec` contribute nothing.
 * - `var` is a free occurrence; `set` is a mutation of its target plus
 *   whatever its right-hand side uses.
 * - `defvar`, `check`, `begin` and `call` combine the results of their
 *   sub-expressions / bodies.
 * - `deffun` and `lambda` analyse their body in a scope that introduces the
 *   formal parameter names.
 * - `let` analyses the binding expressions against `bn`, and the body in a
 *   scope introducing the bound variable names.
 * - `cond` and `matchrec` combine the results of every clause and of the
 *   optional else-clause; `matchrec` clause bodies are scopes that
 *   introduce the clause's field names.
 *
 * @param bn names bound in the enclosing context
 * @param t the term to analyse
 * @returns the variable-usage result for `t`
 */
export function term(bn: BoundNames, t: Ast.Term): Result {
    switch (t.type) {
        // Terms with no variable usage of their own.
        case 'const':
        case 'defrec':
            return Result.EMPTY;

        // Direct variable usage.
        case 'var':
            return Result.free(t);
        case 'set':
            return combine(Result.mutated(t.varName), expr(bn, t.expr));

        // Simple compounds: merge the results of the parts.
        case 'defvar':
            return expr(bn, t.expr);
        case 'check':
            return body(bn, t.body);
        case 'begin':
            return combine(exprs(bn, t.prelude), expr(bn, t.expr));
        case 'call':
            return combine(expr(bn, t.operator), exprs(bn, t.actuals));

        // Function-like forms: the formals are introduced in the body's scope.
        case 'deffun':
            return scope(bn, t.formals.map(formal => formal.name), t.body);
        case 'lambda':
            return scope(bn, t.formals.map(formal => formal.name), t.body);

        case 'let': {
            // Binding expressions are analysed against `bn`, not the new scope.
            const initResults = t.bindings.map(binding => expr(bn, binding.expr));
            const introduced = t.bindings.map(binding => binding.varName.name);
            return combine(... initResults, scope(bn, introduced, t.body));
        }

        case 'cond': {
            const parts: Result[] = [];
            for (const clause of t.clauses) {
                parts.push(combine(expr(bn, clause.test), body(bn, clause.body)));
            }
            if (t.elseClause !== void 0) {
                parts.push(body(bn, t.elseClause));
            }
            return combine(... parts);
        }

        case 'matchrec': {
            const parts: Result[] = [expr(bn, t.expr)];
            for (const clause of t.clauses) {
                const fields = clause.fieldNames.map(field => field.name);
                parts.push(scope(bn, fields, clause.body));
            }
            if (t.elseClause !== void 0) {
                parts.push(body(bn, t.elseClause));
            }
            return combine(... parts);
        }

        default: exhaustive(t);
    }
}

/**
 * Analyses each expression of a sequence against the same bound names and
 * merges the individual results.
 *
 * @param bn names bound in the enclosing context
 * @param es the expressions to analyse
 * @returns the combined result for all of `es`
 */
export function exprs(bn: BoundNames, es: Ast.Expr[]): Result {
    const analyseOne = (e: typeof es[number]): Result => expr(bn, e);
    const perExpr = es.map(e => analyseOne(e));
    return combine(... perExpr);
}

/**
 * Analyses a single expression. Expressions are handled exactly like
 * terms, so this simply delegates to `term`.
 *
 * @param bn names bound in the enclosing context
 * @param e the expression to analyse
 * @returns the variable-usage result for `e`
 */
export function expr(bn: BoundNames, e: Ast.Expr): Result {
    const analysed = term(bn, e);
    return analysed;
}

/**
 * Analyses a whole program: its terms form one scope with no explicitly
 * supplied new names and no trailing expression.
 *
 * @param bn names bound before the program starts
 * @param p the program to analyse
 * @returns the variable-usage result for `p`
 */
export function program(bn: BoundNames, p: Ast.Program): Result {
    const { terms: topLevelTerms } = p;
    return _scope(bn, [], topLevelTerms, null);
}

/**
 * Analyses a body as a scope that introduces `newNames` in addition to
 * whatever the body's own terms define.
 *
 * @param bn names bound in the enclosing context
 * @param newNames names introduced by the construct that owns the body
 * @param body the body to analyse (its terms followed by its expression)
 * @returns the variable-usage result for the scope
 */
export function scope(bn: BoundNames, newNames: string[], body: Ast.Body): Result {
    // Renamed on destructuring so the module-level `terms` / `expr`
    // functions are not shadowed.
    const { terms: bodyTerms, expr: bodyExpr } = body;
    return _scope(bn, newNames, bodyTerms, bodyExpr);
}

/**
 * Shared implementation of scope analysis.
 *
 * The names introduced by the scope are `newNames` together with the names
 * of whatever `definedNames(ts)` reports. The terms and the optional final
 * expression are analysed with those names added to `bn`, and the
 * introduced names are then taken out of the result via `Set.difference`
 * (name sets) and `Set.differenceMap` keyed on each reference's name
 * (reference sets). Assuming the usual set-difference semantics of those
 * helpers, only usages of variables from outside the scope remain.
 *
 * @param bn names bound in the enclosing context
 * @param newNames names explicitly introduced by the scope
 * @param ts the terms making up the scope
 * @param maybeExpr the scope's final expression, or `null` if it has none
 * @returns the variable-usage result for the scope
 */
export function _scope(bn: BoundNames, newNames: string[], ts: Ast.Term[], maybeExpr: Ast.Expr | null): Result {
    // Names this scope introduces: explicit ones plus those defined by its terms.
    const explicitNames = new Set(newNames);
    const definedHere = new Set(definedNames(ts).map(defined => defined.name));
    const introduced = Set.union(explicitNames, definedHere);

    // Analyse the contents with the introduced names counted as bound.
    const innerBn = Set.union(bn, introduced);
    const fromTerms = terms(innerBn, ts);
    const fromExpr = maybeExpr === null ? Result.EMPTY : expr(innerBn, maybeExpr);
    const inner = combine(fromTerms, fromExpr);

    // Remove everything that refers to a name introduced by this scope.
    const freeNames = Set.difference(inner.freeNames, introduced);
    const freeReferences = Set.differenceMap(ref => ref.name, inner.freeReferences, introduced);
    const mutatedNames = Set.difference(inner.mutatedNames, introduced);
    const mutatedReferences = Set.differenceMap(ref => ref.name, inner.mutatedReferences, introduced);

    return { freeNames, freeReferences, mutatedNames, mutatedReferences };
}
