From 8fc012f784cc810f79e8d6907f362955e46b559f Mon Sep 17 00:00:00 2001 From: Oliver Flatt Date: Wed, 2 Aug 2023 15:17:35 -0700 Subject: [PATCH] Cleanup from terms PR (#176) * add age counter * working on it * it builds lol * trying out new terms proposal * not working due to cycles lol * fix up terms * implement function extraction * oops * significant proof checker progress * some progress * manual congruence * Revert "manual congruence" This reverts commit c578a77306ade0ef370d1b91ec1695560ffc0f35. * proof caching * fix global variable desugaring * fixed up primitive computation * small cleanup * working on resugar * resugar in files * resugar in main * weird perf behavior * add resugaring * fix perf problems with proofs * fix more prim stuff * fix proof bug, big refactor * fix another bug * magic iteration action * rebuilding fixes * simplify terms.rkt * oops * working on terms * builds * fix desugaring bug * terms example much simpler * saturate parent again * progress on terms desugaring * more progress * running into terrible assert failure * fix another desugar bug * better term desugaring * revert rebuilding * remove unecessary thing * Fix subtleties in rebuilding * fix another bug * bug in function extract * desugar simplify * remove old print * add compiler passes * a bunch of cleanup * we need to fix extraction * tweaks for tutorial example * oh no bad bug * fix up fibonacci-demand for terms * another bug * let bindings correctly * add optimization but without canonicalization" * actually i already did that * get rid of define, fix bug in terms * revert let binding encoding * Revert "revert let binding encoding" This reverts commit b56cc58eb5c8562254cc1822bf6b8528dd3a3c2f. * filter equality with global variables * fix saturation bug * extract based on parent relationship * don't find value in debug * extract as an action or query * desugar set when it should be union * desugar set to union * this doesn't work anymore * some desugaring fixes * remove a print * use print-table * fix another benchmark * more bug fixes * clean up * fix up primitive queries * another normal form bug * small test * fix desugar global variables * fix bug in vec * vec tests * fix bug in filtering for primitives * rebuild non-eq tables too * add cache to fact desugaring * working on rebuild for custom containers" * fix vec rebuild! * rebuild for set * more cleanup * better error for unbound func * fix bug in query-extract * working on fixing extract variants * fix extract variants actually * fix bug in term encoding with initialization * remove some prints * fix defaults * better resugaring * remove prints again * horrible bug with fail * fix another global var bug ugh * another desugar bug * eq-able containers work better * working on vec again * fix find * unecessary dot * more desugar fixes for unbound vars * new desugar tests * proofs tests no longer relevant * no longer need presort table * canonical vars don't need to be updated * gj intersection sizes from Eli * resugar checks * trying to fix parent table * fix another parent table bug * working on computable functions in queries * fix bug in computed funcs * remove prints * some cleanup * revert all the computable stuff * delete slow test for now * all tests pass * remove rebuild and cleanup * delete random file * another random file * remove iteration * no longer using racket script * fix minimize script XD * add yihong's microbenchmark * Add some query compilation logging Co-authored-by: Oliver Flatt * Fix naive saturation Co-authored-by: Oliver Flatt * Convert parent tuples to filters Also push them up in queries. Co-authored-by: Oliver Flatt * Fix bug in check GJ instructions Co-authored-by: Oliver Flatt * Further refine the query compiler Co-authored-by: Oliver Flatt * disable query compiler changes * more cleanup * tell query compiler to do top-down * try more heuristics * less crazy heuristic * math micro change to not using match limit for consistency * simplify * generate queries as if parent table wasn't there * remove panic * print variable costs * clean up a bit * cleanup * re-add rebuilding * fix a bug with globals * globals broke seminaive * timestamps for globals * fix seminaive for globals * oops * some fixes from @ezrosent * backing out query compiler changes * more gj back out * refactor fact desugaring * remove unneeded arg to primitive apply --------- Co-authored-by: Eli Rosenthal Co-authored-by: Max Willsey --- Makefile | 2 +- README.md | 6 +- scripts/minimize.rkt | 17 +- src/ast/desugar.rs | 397 ++++---- src/ast/expr.rs | 14 +- src/ast/mod.rs | 361 +++++-- src/ast/parse.lalrpop | 29 +- src/extract.rs | 120 ++- src/function/mod.rs | 6 +- src/gj.rs | 302 ++++-- src/lib.rs | 510 +++++----- src/main.rs | 55 +- src/proofheader.egg | 1 - src/proofs.rs | 889 ------------------ src/sort/i64.rs | 43 +- src/sort/macros.rs | 2 +- src/sort/map.rs | 89 +- src/sort/set.rs | 40 +- src/sort/vec.rs | 21 +- src/termdag.rs | 148 +++ src/typecheck.rs | 252 +++-- src/typechecking.rs | 129 ++- tests/antiunify.egg | 8 +- tests/array.egg | 14 +- tests/bdd.egg | 22 +- tests/before-proofs.egg | 13 +- tests/birewrite.egg | 16 +- tests/bitwise.egg | 6 +- tests/calc.egg | 6 +- tests/combinators.egg | 8 +- tests/cyk.egg | 15 +- tests/cykjson.egg | 2 +- tests/eqsat-basic.egg | 4 +- tests/eqsolve.egg | 6 +- tests/extraction-cost.egg | 8 - tests/f64.egg | 2 + .../repro-containers-disallowed.egg | 4 + tests/fail-typecheck/repro-duplicated-var.egg | 3 + tests/fibonacci-demand.egg | 11 +- tests/files.rs | 57 +- tests/fusion.egg | 6 +- tests/herbie-tutorial.egg | 18 +- tests/herbie.egg | 54 +- tests/intersection.egg | 16 +- tests/interval.egg | 6 +- tests/knapsack.egg | 17 +- tests/lambda.egg | 20 +- tests/levenshtein-distance.egg | 16 +- tests/map.egg | 6 +- tests/math-microbenchmark.egg | 69 ++ tests/math.egg | 46 +- tests/matrix.egg | 22 +- tests/merge-during-rebuild.egg | 8 +- tests/merge-saturates.egg | 3 + tests/name-resolution.egg | 8 +- tests/path.egg | 2 +- tests/pathproof.egg | 4 +- tests/points-to.egg | 18 +- tests/prims.egg | 6 +- tests/proofs.egg | 371 -------- tests/repro-constraineq.egg | 2 + tests/repro-constraineq2.egg | 2 + tests/repro-define.egg | 2 +- tests/repro-desugar-143.egg | 40 + tests/repro-primitive-query.egg | 12 + tests/repro-querybug.egg | 2 +- tests/repro-should-saturate.egg | 2 +- tests/repro-silly-panic.egg | 2 +- tests/repro-unsound-htutorial.egg | 16 + tests/repro-unsound.egg | 14 +- tests/repro-vec-unequal.egg | 17 + tests/resolution.egg | 12 +- tests/semi_naive_set_function.egg | 6 +- tests/typecheck.egg | 28 +- tests/typeinfer.egg | 51 +- tests/unification-points-to.egg | 4 +- tests/until.egg | 6 +- 77 files changed, 2225 insertions(+), 2347 deletions(-) create mode 100644 src/termdag.rs delete mode 100644 tests/extraction-cost.egg create mode 100644 tests/fail-typecheck/repro-containers-disallowed.egg create mode 100644 tests/fail-typecheck/repro-duplicated-var.egg create mode 100644 tests/math-microbenchmark.egg delete mode 100644 tests/proofs.egg create mode 100644 tests/repro-constraineq.egg create mode 100644 tests/repro-constraineq2.egg create mode 100644 tests/repro-desugar-143.egg create mode 100644 tests/repro-primitive-query.egg create mode 100644 tests/repro-unsound-htutorial.egg create mode 100644 tests/repro-vec-unequal.egg diff --git a/Makefile b/Makefile index 241741f6..be765806 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .PHONY: all web test nits docs serve -RUST_SRC=$(shell find -type f -wholename '*/src/*.rs' -or -name 'Cargo.toml') +RUST_SRC=$(shell find . -type f -wholename '*/src/*.rs' -or -name 'Cargo.toml') TESTS=$(shell find tests/ -type f -name '*.egg' -not -name '*repro-*') WWW=${PWD}/target/www/ diff --git a/README.md b/README.md index b1de9522..0014fed7 100644 --- a/README.md +++ b/README.md @@ -180,9 +180,9 @@ defines a named value. This is the same as a 0-arity function with a given, sing Example: ``` -(define one 1) -(define two 2) -(define three (+ one two)) +(let one 1) +(let two 2) +(let three (+ one two)) (extract three); extracts 3 as a i64 ``` diff --git a/scripts/minimize.rkt b/scripts/minimize.rkt index c933e399..6493270d 100644 --- a/scripts/minimize.rkt +++ b/scripts/minimize.rkt @@ -26,7 +26,7 @@ (define ITERATIONS 1) (define RANDOM-SAMPLE-FACTOR 1) (define MUST-NOT-STRINGS `()) -(define TARGET-STRINGS `("invalid default for")) +(define TARGET-STRINGS `("src/lib.rs:250")) (define (desugar line) (match line @@ -40,10 +40,8 @@ (define-values (egglog-process egglog-output egglog-in err) (subprocess (current-output-port) #f #f egglog-binary)) - (displayln "(" egglog-in) (for ([line program]) (writeln (desugar line) egglog-in)) - (displayln ")" egglog-in) (close-output-port egglog-in) (when (not (sync/timeout TIMEOUT egglog-process)) @@ -51,7 +49,7 @@ (subprocess-kill egglog-process #t) (displayln "checking output") (flush-output) - (define err-str (read-string 800 err)) + (define err-str (read-string 10000 err)) (close-input-port err) (define still-unsound (and (string? err-str) (for/and ([must-not-string MUST-NOT-STRINGS]) @@ -117,7 +115,18 @@ (random-and-sequential program))) (first (sort programs (lambda (a b) (< (length a) (length b)))))) + + (define (minimize port-in port-out) + #;((define-values (process out in err) (subprocess #f #f #f "cargo")) + (define err-str (read-string 800 err)) + (when (not (string=? err-str "")) + (error err-str)) + (close-input-port out) + (close-output-port in) + (close-input-port err) + (subprocess-wait process)) + (define egglog (read-lines port-in)) (pretty-print egglog) diff --git a/src/ast/desugar.rs b/src/ast/desugar.rs index 05607b08..54fffdf9 100644 --- a/src/ast/desugar.rs +++ b/src/ast/desugar.rs @@ -1,4 +1,4 @@ -use crate::{proofs::RULE_PROOF_KEYWORD, *}; +use crate::*; fn desugar_datatype(name: Symbol, variants: Vec) -> Vec { vec![NCommand::Sort(name, None)] @@ -14,6 +14,7 @@ fn desugar_datatype(name: Symbol, variants: Vec) -> Vec { merge_action: vec![], default: None, cost: variant.cost, + unextractable: false, }) })) .collect() @@ -67,38 +68,123 @@ fn desugar_birewrite( .collect() } -fn expr_to_ssa(lhs: Symbol, expr: &Expr, desugar: &mut Desugar, res: &mut Vec) { - match expr { - Expr::Lit(l) => { - res.push(NormFact::AssignLit(lhs, l.clone())); +fn normalize_expr( + lhs_in: Symbol, + expr: &Expr, + desugar: &mut Desugar, + res: &mut Vec, + constraints: &mut Vec<(Symbol, Symbol)>, + bound: &mut HashSet, + cache: &mut HashMap, +) { + let is_bound = |var, desugar: &Desugar, bound_variables: &HashSet| { + desugar.global_variables.contains(&var) || bound_variables.contains(&var) + }; + if let Some(var) = cache.get(expr) { + if is_bound(lhs_in, desugar, bound) { + constraints.push((lhs_in, *var)); + } else { + bound.insert(lhs_in); + res.push(NormFact::AssignVar(lhs_in, *var)); + } + return; + } + + if let Expr::Var(v) = expr { + if *v == lhs_in { + return; + } + if is_bound(lhs_in, desugar, bound) && is_bound(*v, desugar, bound) { + constraints.push((lhs_in, *v)); + } else if is_bound(lhs_in, desugar, bound) { + bound.insert(*v); + res.push(NormFact::AssignVar(*v, lhs_in)); + } else if is_bound(*v, desugar, bound) { + bound.insert(lhs_in); + res.push(NormFact::AssignVar(lhs_in, *v)); + } else { + // TODO give proper error message and handle + // a wider variety of queries + panic!("Unbound variable {v}"); } - Expr::Var(v) => { - res.push(NormFact::ConstrainEq(lhs, *v)); + return; + } + + let lhs = if !is_bound(lhs_in, desugar, bound) { + bound.insert(lhs_in); + lhs_in + } else { + let fresh = desugar.get_fresh(); + constraints.push((fresh, lhs_in)); + fresh + }; + + match expr { + Expr::Lit(l) => res.push(NormFact::AssignLit(lhs, l.clone())), + Expr::Var(_v) => { + panic!("handled above"); } Expr::Call(f, children) => { + let is_compute = TypeInfo::default().is_primitive(*f); let mut new_children = vec![]; for child in children { match child { Expr::Var(v) => { - new_children.push(*v); + if is_compute { + if !is_bound(*v, desugar, bound) { + panic!("Unbound variable {v} in primitive computation"); + } + new_children.push(*v); + } else if is_bound(*v, desugar, bound) { + let fresh = desugar.get_fresh(); + new_children.push(fresh); + constraints.push((fresh, *v)); + } else { + bound.insert(*v); + new_children.push(*v); + } } _ => { let fresh = desugar.get_fresh(); - expr_to_ssa(fresh, child, desugar, res); + if !is_compute { + bound.insert(fresh); + } + normalize_expr(fresh, child, desugar, res, constraints, bound, cache); new_children.push(fresh); } } } - res.push(NormFact::Assign(lhs, NormExpr::Call(*f, new_children))); + + if is_compute { + res.push(NormFact::Compute(lhs, NormExpr::Call(*f, new_children))); + } else { + res.push(NormFact::Assign(lhs, NormExpr::Call(*f, new_children))); + } } - } + }; + cache.insert(expr.clone(), lhs); } fn flatten_equalities(equalities: Vec<(Symbol, Expr)>, desugar: &mut Desugar) -> Vec { let mut res = vec![]; + let mut bound_variables: HashSet = Default::default(); + let mut constraints: Vec<(Symbol, Symbol)> = Default::default(); + let mut cache = Default::default(); for (lhs, rhs) in equalities { - expr_to_ssa(lhs, &rhs, desugar, &mut res); + normalize_expr( + lhs, + &rhs, + desugar, + &mut res, + &mut constraints, + &mut bound_variables, + &mut cache, + ); + } + + for (lhs, rhs) in constraints { + res.push(NormFact::ConstrainEq(lhs, rhs)); } res @@ -123,7 +209,12 @@ fn flatten_facts(facts: &Vec, desugar: &mut Desugar) -> Vec { } } Fact::Fact(expr) => { - equalities.push((desugar.get_fresh(), expr.clone())); + // we can drop facts that are + // just a variable + if let Expr::Var(_v) = expr { + } else { + equalities.push((desugar.get_fresh(), expr.clone())); + } } } } @@ -146,7 +237,7 @@ fn flatten_actions(actions: &Vec, desugar: &mut Desugar) -> Vec { + Action::Set(symbol, exprs, rhs) => { let set = NormAction::Set( NormExpr::Call( *symbol, @@ -160,6 +251,11 @@ fn flatten_actions(actions: &Vec, desugar: &mut Desugar) -> Vec { + let added = add_expr(expr.clone(), &mut res); + let added_variants = add_expr(variants.clone(), &mut res); + res.push(NormAction::Extract(added, added_variants)); + } Action::Delete(symbol, exprs) => { let del = NormAction::Delete(NormExpr::Call( *symbol, @@ -260,14 +356,9 @@ fn desugar_schedule(desugar: &mut Desugar, schedule: &Schedule) -> NormSchedule } fn desugar_run_config(desugar: &mut Desugar, run_config: &RunConfig) -> NormRunConfig { - let RunConfig { - ruleset, - limit, - until, - } = run_config; + let RunConfig { ruleset, until } = run_config; NormRunConfig { ruleset: *ruleset, - limit: *limit, until: until.clone().map(|facts| flatten_facts(&facts, desugar)), } } @@ -318,9 +409,9 @@ pub struct Desugar { next_fresh: usize, next_command_id: usize, pub(crate) parser: ast::parse::ProgramParser, - pub(crate) action_parser: ast::parse::ActionParser, // TODO fix getting fresh names using modules pub(crate) number_underscores: usize, + pub(crate) global_variables: HashSet, } impl Default for Desugar { @@ -330,60 +421,84 @@ impl Default for Desugar { next_command_id: Default::default(), // these come from lalrpop and don't have default impls parser: ast::parse::ProgramParser::new(), - action_parser: ast::parse::ActionParser::new(), number_underscores: 3, + global_variables: Default::default(), } } } +pub(crate) fn desugar_simplify( + desugar: &mut Desugar, + expr: &Expr, + schedule: &Schedule, +) -> Vec { + let mut res = vec![NCommand::Push(1)]; + let lhs = desugar.get_fresh(); + res.extend( + flatten_actions(&vec![Action::Let(lhs, expr.clone())], desugar) + .into_iter() + .map(NCommand::NormAction), + ); + res.push(NCommand::RunSchedule(desugar_schedule(desugar, schedule))); + res.extend( + desugar_command( + Command::Extract { + variants: 0, + fact: Fact::Fact(Expr::Var(lhs)), + }, + desugar, + false, + false, + ) + .unwrap() + .into_iter() + .map(|c| c.command), + ); + + res.push(NCommand::Pop(1)); + res +} + pub(crate) fn desugar_calc( desugar: &mut Desugar, idents: Vec, exprs: Vec, - seminaive: bool, -) -> Vec { + seminaive_transform: bool, +) -> Result, Error> { let mut res = vec![]; // first, push all the idents for IdentSort { ident, sort } in idents { - res.extend(desugar.declare(ident, sort)); + res.push(Command::Declare { name: ident, sort }); } // now, for every pair of exprs we need to prove them equal for expr1and2 in exprs.windows(2) { let expr1 = &expr1and2[0]; let expr2 = &expr1and2[1]; - res.push(NCommand::Push(1)); - let mut new_memo = Default::default(); + res.push(Command::Push(1)); // add the two exprs - let mut actions = vec![]; - let v1 = desugar.expr_to_flat_actions(expr1, &mut actions, &mut new_memo); - let v2 = desugar.expr_to_flat_actions(expr2, &mut actions, &mut new_memo); - res.extend(actions.into_iter().map(NCommand::NormAction)); - - res.extend( - desugar_command( - Command::Run(RunConfig { - ruleset: "".into(), - limit: 1000000, - until: Some(vec![Fact::Eq(vec![expr1.clone(), expr2.clone()])]), - }), - desugar, - false, - seminaive, - ) - .unwrap() - .into_iter() - .map(|c| c.command), - ); + res.push(Command::Action(Action::Expr(expr1.clone()))); + res.push(Command::Action(Action::Expr(expr2.clone()))); + + res.push(Command::RunSchedule(Schedule::Saturate(Box::new( + Schedule::Run(RunConfig { + ruleset: "".into(), + until: Some(vec![Fact::Eq(vec![expr1.clone(), expr2.clone()])]), + }), + )))); - res.push(NCommand::Check(vec![NormFact::ConstrainEq(v1, v2)])); + res.push(Command::Check(vec![Fact::Eq(vec![ + expr1.clone(), + expr2.clone(), + ])])); - res.push(NCommand::Pop(1)); + res.push(Command::Pop(1)); } - res + desugar_commands(res, desugar, false, seminaive_transform) + .map(|cmds| cmds.into_iter().map(|cmd| cmd.command).collect()) } pub(crate) fn rewrite_name(rewrite: &Rewrite) -> String { @@ -394,7 +509,7 @@ pub(crate) fn desugar_command( command: Command, desugar: &mut Desugar, get_all_proofs: bool, - seminaive: bool, + seminaive_transform: bool, ) -> Result, Error> { let res = match command { Command::SetOption { name, value } => { @@ -418,7 +533,7 @@ pub(crate) fn desugar_command( desugar.parse_program(&s)?, desugar, get_all_proofs, - seminaive, + seminaive_transform, ); } Command::Rule { @@ -429,13 +544,14 @@ pub(crate) fn desugar_command( if name == "".into() { name = rule.to_string().replace('\"', "'").into(); } + let mut result = vec![NCommand::NormRule { ruleset, name, rule: flatten_rule(rule.clone(), desugar), }]; - if seminaive { + if seminaive_transform { if let Some(new_rule) = add_semi_naive_rule(desugar, rule) { result.push(NCommand::NormRule { ruleset, @@ -449,124 +565,57 @@ pub(crate) fn desugar_command( } Command::Sort(sort, option) => vec![NCommand::Sort(sort, option)], // TODO ignoring cost for now - Command::Define { - name, - expr, - cost: _cost, - } => { - let mut commands = vec![]; - - let mut actions = vec![]; - let mut temp = Default::default(); - let fresh = desugar.expr_to_flat_actions(&expr, &mut actions, &mut temp); - actions.push(NormAction::LetVar(name, fresh)); - for action in actions { - commands.push(NCommand::NormAction(action)); - } - commands - } Command::AddRuleset(name) => vec![NCommand::AddRuleset(name)], Command::Action(action) => flatten_actions(&vec![action], desugar) .into_iter() .map(NCommand::NormAction) .collect(), - Command::Run(config) => { - vec![NCommand::RunSchedule(NormSchedule::Run( - desugar_run_config(desugar, &config), - ))] - } - Command::Simplify { expr, config } => { - let fresh = desugar.get_fresh(); - flatten_actions(&vec![Action::Let(fresh, expr)], desugar) - .into_iter() - .map(NCommand::NormAction) - .chain( - vec![NCommand::Simplify { - var: fresh, - config: desugar_run_config(desugar, &config), - }] - .into_iter(), - ) - .collect() - } - Command::Calc(idents, exprs) => desugar_calc(desugar, idents, exprs, seminaive), + Command::Simplify { expr, schedule } => desugar_simplify(desugar, &expr, &schedule), + Command::Calc(idents, exprs) => desugar_calc(desugar, idents, exprs, seminaive_transform)?, Command::RunSchedule(sched) => { vec![NCommand::RunSchedule(desugar_schedule(desugar, &sched))] } - Command::Extract { variants, e } => { + // TODO add variants to extract action + Command::Extract { + variants: _variants, + fact, + } => { let fresh = desugar.get_fresh(); - flatten_actions(&vec![Action::Let(fresh, e)], desugar) - .into_iter() - .map(NCommand::NormAction) - .chain( - vec![NCommand::Extract { - variants, - var: fresh, - }] - .into_iter(), + let fresh_ruleset = desugar.get_fresh(); + let desugaring = if let Fact::Fact(Expr::Var(v)) = fact { + format!("(extract {v})") + } else { + format!( + "(check {fact}) + (ruleset {fresh_ruleset}) + (rule ((= {fresh} {fact})) + ((extract {fresh})) + :ruleset {fresh_ruleset}) + (run {fresh_ruleset} 1)" ) + }; + + desugar + .desugar_program( + desugar.parse_program(&desugaring).unwrap(), + get_all_proofs, + seminaive_transform, + )? + .into_iter() + .map(|cmd| cmd.command) .collect() } Command::Check(facts) => { - let mut res = vec![NCommand::Check(flatten_facts(&facts, desugar))]; + let res = vec![NCommand::Check(flatten_facts(&facts, desugar))]; if get_all_proofs { - let proofvar = desugar.get_fresh(); - // declare a variable for the resulting proof - // TODO using constant high cost - res.extend(desugar.declare(proofvar, "Proof__".into())); - - // make a dummy rule so that we get a proof for this check - let dummyrule = Rule { - body: facts.clone(), - head: vec![Action::Union( - Expr::Var(proofvar), - Expr::Var(RULE_PROOF_KEYWORD.into()), - )], - }; - let ruleset = desugar.get_fresh(); - res.push(NCommand::AddRuleset(ruleset)); - res.extend( - desugar_command( - Command::Rule { - ruleset, - name: "".into(), - rule: dummyrule, - }, - desugar, - get_all_proofs, - seminaive, - )? - .into_iter() - .map(|cmd| cmd.command), - ); - - // now run the dummy rule and get the proof - res.push(NCommand::RunSchedule(NormSchedule::Run(NormRunConfig { - ruleset, - limit: 1, - until: None, - }))); - - // we need to run proof extraction rules again - res.push(NCommand::RunSchedule(NormSchedule::Saturate(Box::new( - NormSchedule::Run(NormRunConfig { - ruleset: "proof-extract__".into(), - limit: 1, - until: None, - }), - )))); - - // extract the proof - res.push(NCommand::Extract { - variants: 0, - var: proofvar, - }); + // TODO check proofs } res } - Command::Print(symbol, size) => vec![NCommand::Print(symbol, size)], + Command::CheckProof => vec![NCommand::CheckProof], + Command::PrintTable(symbol, size) => vec![NCommand::PrintTable(symbol, size)], Command::PrintSize(symbol) => vec![NCommand::PrintSize(symbol)], Command::Output { file, exprs } => vec![NCommand::Output { file, exprs }], Command::Push(num) => { @@ -576,7 +625,7 @@ pub(crate) fn desugar_command( vec![NCommand::Pop(num)] } Command::Fail(cmd) => { - let mut desugared = desugar_command(*cmd, desugar, false, seminaive)?; + let mut desugared = desugar_command(*cmd, desugar, false, seminaive_transform)?; let last = desugared.pop().unwrap(); desugared.push(NormCommand { @@ -590,6 +639,17 @@ pub(crate) fn desugar_command( } }; + for cmd in &res { + if let NCommand::NormAction(action) = cmd { + action.map_def_use(&mut |var, is_def| { + if is_def { + desugar.global_variables.insert(var); + } + var + }); + } + } + Ok(res .into_iter() .map(|c| NormCommand { @@ -605,11 +665,11 @@ pub(crate) fn desugar_commands( program: Vec, desugar: &mut Desugar, get_all_proofs: bool, - seminaive: bool, + seminaive_transform: bool, ) -> Result, Error> { let mut res = vec![]; for command in program { - let desugared = desugar_command(command, desugar, get_all_proofs, seminaive)?; + let desugared = desugar_command(command, desugar, get_all_proofs, seminaive_transform)?; res.extend(desugared); } Ok(res) @@ -621,13 +681,20 @@ impl Clone for Desugar { next_fresh: self.next_fresh, next_command_id: self.next_command_id, parser: ast::parse::ProgramParser::new(), - action_parser: ast::parse::ActionParser::new(), number_underscores: self.number_underscores, + global_variables: self.global_variables.clone(), } } } impl Desugar { + pub fn merge_ruleset_name(&self) -> Symbol { + Symbol::from(format!( + "merge_ruleset{}", + "_".repeat(self.number_underscores) + )) + } + pub fn get_fresh(&mut self) -> Symbol { self.next_fresh += 1; format!( @@ -648,9 +715,9 @@ impl Desugar { &mut self, program: Vec, get_all_proofs: bool, - seminaive: bool, + seminaive_transform: bool, ) -> Result, Error> { - let res = desugar_commands(program, self, get_all_proofs, seminaive)?; + let res = desugar_commands(program, self, get_all_proofs, seminaive_transform)?; Ok(res) } @@ -684,8 +751,15 @@ impl Desugar { } } } - res.push(NormAction::Let(assign, NormExpr::Call(*f, new_children))); - assign + let result = NormExpr::Call(*f, new_children); + let result_expr = result.to_expr(); + if let Some(existing) = memo.get(&result_expr) { + *existing + } else { + memo.insert(result_expr.clone(), assign); + res.push(NormAction::Let(assign, result)); + assign + } } }; memo.insert(expr.clone(), res); @@ -711,7 +785,8 @@ impl Desugar { default: None, merge: None, merge_action: vec![], - cost: Some(HIGH_COST), + cost: None, + unextractable: false, }), NCommand::NormAction(NormAction::Let(name, NormExpr::Call(fresh, vec![]))), ] diff --git a/src/ast/expr.rs b/src/ast/expr.rs index 38dd1bee..63e45c20 100644 --- a/src/ast/expr.rs +++ b/src/ast/expr.rs @@ -91,6 +91,10 @@ impl NormExpr { } impl Expr { + pub fn is_var(&self) -> bool { + matches!(self, Expr::Var(_)) + } + pub fn call(op: impl Into, children: impl IntoIterator) -> Self { Self::Call(op.into(), children.into_iter().collect()) } @@ -113,6 +117,12 @@ impl Expr { } } + pub fn ast_size(&self) -> usize { + let mut size = 0; + self.walk(&mut |_e| size += 1, &mut |_| {}); + size + } + pub fn walk(&self, pre: &mut impl FnMut(&Self), post: &mut impl FnMut(&Self)) { pre(self); self.children() @@ -151,12 +161,12 @@ impl Expr { res } - pub fn replace_canon(&self, canon: &HashMap) -> Self { + pub fn subst(&self, canon: &HashMap) -> Self { match self { Expr::Lit(_lit) => self.clone(), Expr::Var(v) => canon.get(v).cloned().unwrap_or_else(|| self.clone()), Expr::Call(op, children) => { - let children = children.iter().map(|c| c.replace_canon(canon)).collect(); + let children = children.iter().map(|c| c.subst(canon)).collect(); Expr::Call(*op, children) } } diff --git a/src/ast/mod.rs b/src/ast/mod.rs index ac3acdb0..8612590c 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -64,6 +64,24 @@ impl NormCommand { }) .collect() } + + pub fn resugar(&self) -> Command { + match &self.command { + NCommand::NormRule { + name, + ruleset, + rule, + } => Command::Rule { + name: *name, + ruleset: *ruleset, + rule: rule.resugar(), + }, + NCommand::Check(facts) => { + Command::Check(NormRule::resugar_facts(facts, &mut Default::default())) + } + _ => self.command.to_command(), + } + } } #[derive(Debug, Clone, Eq, PartialEq, Hash)] @@ -82,16 +100,9 @@ pub enum NCommand { }, NormAction(NormAction), RunSchedule(NormSchedule), - Simplify { - var: Symbol, - config: NormRunConfig, - }, - Extract { - variants: usize, - var: Symbol, - }, Check(Vec), - Print(Symbol, usize), + CheckProof, + PrintTable(Symbol, usize), PrintSize(Symbol), Output { file: String, @@ -134,18 +145,11 @@ impl NCommand { }, NCommand::RunSchedule(schedule) => Command::RunSchedule(schedule.to_schedule()), NCommand::NormAction(action) => Command::Action(action.to_action()), - NCommand::Simplify { var, config } => Command::Simplify { - expr: Expr::Var(*var), - config: config.to_run_config(), - }, - NCommand::Extract { variants, var } => Command::Extract { - variants: *variants, - e: Expr::Var(*var), - }, NCommand::Check(facts) => { Command::Check(facts.iter().map(|fact| fact.to_fact()).collect()) } - NCommand::Print(name, n) => Command::Print(*name, *n), + NCommand::CheckProof => Command::CheckProof, + NCommand::PrintTable(name, n) => Command::PrintTable(*name, *n), NCommand::PrintSize(name) => Command::PrintSize(*name), NCommand::Output { file, exprs } => Command::Output { file: file.to_string(), @@ -182,15 +186,11 @@ impl NCommand { rule: rule.map_exprs(f), }, NCommand::NormAction(action) => NCommand::NormAction(action.map_exprs(f)), - NCommand::Simplify { .. } => self.clone(), - NCommand::Extract { variants, var } => NCommand::Extract { - variants: *variants, - var: *var, - }, NCommand::Check(facts) => { NCommand::Check(facts.iter().map(|fact| fact.map_exprs(f)).collect()) } - NCommand::Print(name, n) => NCommand::Print(*name, *n), + NCommand::CheckProof => NCommand::CheckProof, + NCommand::PrintTable(name, n) => NCommand::PrintTable(*name, *n), NCommand::PrintSize(name) => NCommand::PrintSize(*name), NCommand::Output { file, exprs } => NCommand::Output { file: file.to_string(), @@ -236,6 +236,24 @@ impl NormSchedule { } } } + + pub fn map_run_commands(&self, f: &mut impl FnMut(&NormRunConfig) -> Schedule) -> Schedule { + match self { + NormSchedule::Run(config) => f(config), + NormSchedule::Saturate(sched) => { + Schedule::Saturate(Box::new(sched.map_run_commands(f))) + } + NormSchedule::Repeat(size, sched) => { + Schedule::Repeat(*size, Box::new(sched.map_run_commands(f))) + } + NormSchedule::Sequence(scheds) => Schedule::Sequence( + scheds + .iter() + .map(|sched| sched.map_run_commands(f)) + .collect(), + ), + } + } } trait ToSexp { @@ -318,11 +336,6 @@ pub enum Command { }, Sort(Symbol, Option<(Symbol, Vec)>), Function(FunctionDecl), - Define { - name: Symbol, - expr: Expr, - cost: Option, - }, AddRuleset(Symbol), Rule { name: Symbol, @@ -332,20 +345,20 @@ pub enum Command { Rewrite(Symbol, Rewrite), BiRewrite(Symbol, Rewrite), Action(Action), - Run(RunConfig), RunSchedule(Schedule), Simplify { expr: Expr, - config: RunConfig, + schedule: Schedule, }, Calc(Vec, Vec), Extract { variants: usize, - e: Expr, + fact: Fact, }, // TODO: this could just become an empty query Check(Vec), - Print(Symbol, usize), + CheckProof, + PrintTable(Symbol, usize), PrintSize(Symbol), Input { name: Symbol, @@ -380,27 +393,20 @@ impl ToSexp for Command { ruleset, rule, } => rule.to_sexp(*ruleset, *name), - Command::Define { name, expr, cost } => match cost { - None => list!("define", name, expr), - Some(cost) => list!("define", name, expr, ":cost", cost), - }, - Command::Run(config) => config.to_sexp(), Command::RunSchedule(sched) => list!("run-schedule", sched), Command::Calc(args, exprs) => list!("calc", list!(++ args), ++ exprs), - Command::Extract { variants, e } => list!("extract", ":variants", variants, e), + Command::Extract { variants, fact } => list!("extract", ":variants", variants, fact), Command::Check(facts) => list!("check", ++ facts), + Command::CheckProof => list!("check-proof"), Command::Push(n) => list!("push", n), Command::Pop(n) => list!("pop", n), - Command::Print(name, n) => list!("print", name, n), + Command::PrintTable(name, n) => list!("print-table", name, n), Command::PrintSize(name) => list!("print-size", name), Command::Input { name, file } => list!("input", name, format!("\"{}\"", file)), Command::Output { file, exprs } => list!("output", format!("\"{}\"", file), ++ exprs), Command::Fail(cmd) => list!("fail", cmd), Command::Include(file) => list!("include", format!("\"{}\"", file)), - Command::Simplify { expr, config } => match &config.until { - Some(until) => list!("simplify", config.limit, expr, ":until", ++ until), - None => list!("simplify", config.limit, expr), - }, + Command::Simplify { expr, schedule } => list!("simplify", schedule, expr), } } } @@ -425,6 +431,9 @@ impl Display for Command { name, rule, } => rule.fmt_with_ruleset(f, *ruleset, *name), + Command::Check(facts) => { + write!(f, "(check {})", ListDisplay(facts, "\n")) + } _ => write!(f, "{}", self.to_sexp()), } } @@ -451,7 +460,6 @@ impl Display for IdentSort { #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct RunConfig { pub ruleset: Symbol, - pub limit: usize, pub until: Option>, } @@ -461,7 +469,6 @@ impl ToSexp for RunConfig { if self.ruleset != "".into() { res.push(Sexp::String(self.ruleset.to_string())); } - res.push(Sexp::String(self.limit.to_string())); if let Some(until) = &self.until { res.push(Sexp::String(":until".into())); res.extend(until.iter().map(|fact| fact.to_sexp())); @@ -471,11 +478,9 @@ impl ToSexp for RunConfig { } } -// TODO get rid of limit, just use Repeat #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct NormRunConfig { pub ruleset: Symbol, - pub limit: usize, pub until: Option>, } @@ -483,7 +488,6 @@ impl NormRunConfig { pub fn to_run_config(&self) -> RunConfig { RunConfig { ruleset: self.ruleset, - limit: self.limit, until: self .until .as_ref() @@ -497,9 +501,11 @@ pub struct FunctionDecl { pub name: Symbol, pub schema: Schema, pub default: Option, + // TODO we should desugar merge and merge action pub merge: Option, pub merge_action: Vec, pub cost: Option, + pub unextractable: bool, } #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -553,6 +559,7 @@ impl FunctionDecl { merge_action: vec![], default: None, cost: None, + unextractable: false, } } } @@ -577,6 +584,10 @@ impl ToSexp for FunctionDecl { ]); } + if self.unextractable { + res.push(Sexp::String(":unextractable".into())); + } + if !self.merge_action.is_empty() { res.push(Sexp::String(":on_merge".into())); res.push(Sexp::List( @@ -608,6 +619,8 @@ pub enum Fact { #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum NormFact { Assign(Symbol, NormExpr), // assign symbol to a tuple + AssignVar(Symbol, Symbol), + Compute(Symbol, NormExpr), // compute a primative AssignLit(Symbol, Literal), ConstrainEq(Symbol, Symbol), } @@ -615,7 +628,10 @@ pub enum NormFact { impl NormFact { pub fn to_fact(&self) -> Fact { match self { - NormFact::Assign(symbol, expr) => Fact::Eq(vec![Expr::Var(*symbol), expr.to_expr()]), + NormFact::Assign(symbol, expr) | NormFact::Compute(symbol, expr) => { + Fact::Eq(vec![Expr::Var(*symbol), expr.to_expr()]) + } + NormFact::AssignVar(lhs, rhs) => Fact::Eq(vec![Expr::Var(*lhs), Expr::Var(*rhs)]), NormFact::ConstrainEq(lhs, rhs) => Fact::Eq(vec![Expr::Var(*lhs), Expr::Var(*rhs)]), NormFact::AssignLit(symbol, lit) => { Fact::Eq(vec![Expr::Var(*symbol), Expr::Lit(lit.clone())]) @@ -625,7 +641,10 @@ impl NormFact { pub fn map_exprs(&self, f: &mut impl FnMut(&NormExpr) -> NormExpr) -> NormFact { match self { - NormFact::Assign(symbol, expr) => NormFact::Assign(*symbol, f(expr)), + NormFact::Assign(symbol, expr) | NormFact::Compute(symbol, expr) => { + NormFact::Assign(*symbol, f(expr)) + } + NormFact::AssignVar(lhs, rhs) => NormFact::AssignVar(*lhs, *rhs), NormFact::ConstrainEq(lhs, rhs) => NormFact::ConstrainEq(*lhs, *rhs), NormFact::AssignLit(symbol, lit) => NormFact::AssignLit(*symbol, lit.clone()), } @@ -636,6 +655,12 @@ impl NormFact { NormFact::Assign(symbol, expr) => { NormFact::Assign(fvar(*symbol, true), expr.map_def_use(fvar, true)) } + NormFact::AssignVar(lhs, rhs) => { + NormFact::AssignVar(fvar(*lhs, true), fvar(*rhs, false)) + } + NormFact::Compute(symbol, expr) => { + NormFact::Compute(fvar(*symbol, true), expr.map_def_use(fvar, false)) + } NormFact::AssignLit(symbol, lit) => { NormFact::AssignLit(fvar(*symbol, true), lit.clone()) } @@ -662,6 +687,10 @@ impl Fact { Fact::Fact(expr) => Fact::Fact(f(expr)), } } + + pub fn subst(&self, subst: &HashMap) -> Fact { + self.map_exprs(&mut |e| e.subst(subst)) + } } impl Display for NormFact { @@ -680,9 +709,9 @@ impl Display for Fact { pub enum Action { Let(Symbol, Expr), Set(Symbol, Vec, Expr), - SetNoTrack(Symbol, Vec, Expr), Delete(Symbol, Vec), Union(Expr, Expr), + Extract(Expr, Expr), Panic(String), Expr(Expr), // If(Expr, Action, Action), @@ -693,6 +722,7 @@ pub enum NormAction { Let(Symbol, NormExpr), LetVar(Symbol, Symbol), LetLit(Symbol, Literal), + Extract(Symbol, Symbol), Set(NormExpr, Symbol), Delete(NormExpr), Union(Symbol, Symbol), @@ -705,11 +735,14 @@ impl NormAction { NormAction::Let(symbol, expr) => Action::Let(*symbol, expr.to_expr()), NormAction::LetVar(symbol, other) => Action::Let(*symbol, Expr::Var(*other)), NormAction::LetLit(symbol, lit) => Action::Let(*symbol, Expr::Lit(lit.clone())), - NormAction::Set(NormExpr::Call(head, body), other) => Action::SetNoTrack( + NormAction::Set(NormExpr::Call(head, body), other) => Action::Set( *head, body.iter().map(|s| Expr::Var(*s)).collect(), Expr::Var(*other), ), + NormAction::Extract(symbol, variants) => { + Action::Extract(Expr::Var(*symbol), Expr::Var(*variants)) + } NormAction::Delete(NormExpr::Call(symbol, args)) => { Action::Delete(*symbol, args.iter().map(|s| Expr::Var(*s)).collect()) } @@ -724,6 +757,7 @@ impl NormAction { NormAction::LetVar(symbol, other) => NormAction::LetVar(*symbol, *other), NormAction::LetLit(symbol, lit) => NormAction::LetLit(*symbol, lit.clone()), NormAction::Set(expr, other) => NormAction::Set(f(expr), *other), + NormAction::Extract(var, variants) => NormAction::Extract(*var, *variants), NormAction::Delete(expr) => NormAction::Delete(f(expr)), NormAction::Union(lhs, rhs) => NormAction::Union(*lhs, *rhs), NormAction::Panic(msg) => NormAction::Panic(msg.clone()), @@ -743,6 +777,9 @@ impl NormAction { NormAction::Set(expr, other) => { NormAction::Set(expr.map_def_use(fvar, false), fvar(*other, false)) } + NormAction::Extract(var, variants) => { + NormAction::Extract(fvar(*var, false), fvar(*variants, false)) + } NormAction::Delete(expr) => NormAction::Delete(expr.map_def_use(fvar, false)), NormAction::Union(lhs, rhs) => NormAction::Union(fvar(*lhs, false), fvar(*rhs, false)), NormAction::Panic(msg) => NormAction::Panic(msg.clone()), @@ -755,11 +792,9 @@ impl ToSexp for Action { match self { Action::Let(lhs, rhs) => list!("let", lhs, rhs), Action::Set(lhs, args, rhs) => list!("set", list!(lhs, ++ args), rhs), - Action::SetNoTrack(lhs, args, rhs) => { - list!("set-no-track", list!(lhs, ++ args), rhs) - } Action::Union(lhs, rhs) => list!("union", lhs, rhs), Action::Delete(lhs, args) => list!("delete", list!(lhs, ++ args)), + Action::Extract(expr, variants) => list!("extract", expr, variants), Action::Panic(msg) => list!("panic", format!("\"{}\"", msg.clone())), Action::Expr(e) => e.to_sexp(), } @@ -774,12 +809,9 @@ impl Action { let right = f(rhs); Action::Set(*lhs, args.iter().map(f).collect(), right) } - Action::SetNoTrack(lhs, args, rhs) => { - let right = f(rhs); - Action::SetNoTrack(*lhs, args.iter().map(f).collect(), right) - } Action::Delete(lhs, args) => Action::Delete(*lhs, args.iter().map(f).collect()), Action::Union(lhs, rhs) => Action::Union(f(lhs), f(rhs)), + Action::Extract(expr, variants) => Action::Extract(f(expr), f(variants)), Action::Panic(msg) => Action::Panic(msg.clone()), Action::Expr(e) => Action::Expr(f(e)), } @@ -787,25 +819,21 @@ impl Action { pub fn replace_canon(&self, canon: &HashMap) -> Self { match self { - Action::Let(lhs, rhs) => Action::Let(*lhs, rhs.replace_canon(canon)), + Action::Let(lhs, rhs) => Action::Let(*lhs, rhs.subst(canon)), Action::Set(lhs, args, rhs) => Action::Set( *lhs, - args.iter().map(|e| e.replace_canon(canon)).collect(), - rhs.replace_canon(canon), - ), - Action::SetNoTrack(lhs, args, rhs) => Action::SetNoTrack( - *lhs, - args.iter().map(|e| e.replace_canon(canon)).collect(), - rhs.replace_canon(canon), + args.iter().map(|e| e.subst(canon)).collect(), + rhs.subst(canon), ), Action::Delete(lhs, args) => { - Action::Delete(*lhs, args.iter().map(|e| e.replace_canon(canon)).collect()) + Action::Delete(*lhs, args.iter().map(|e| e.subst(canon)).collect()) } - Action::Union(lhs, rhs) => { - Action::Union(lhs.replace_canon(canon), rhs.replace_canon(canon)) + Action::Union(lhs, rhs) => Action::Union(lhs.subst(canon), rhs.subst(canon)), + Action::Extract(expr, variants) => { + Action::Extract(expr.subst(canon), variants.subst(canon)) } Action::Panic(msg) => Action::Panic(msg.clone()), - Action::Expr(e) => Action::Expr(e.replace_canon(canon)), + Action::Expr(e) => Action::Expr(e.subst(canon)), } } } @@ -844,6 +872,195 @@ impl NormRule { } } + pub fn globals_used_in_matcher(facts: &Vec) -> HashSet { + let mut bound_vars = HashSet::::default(); + for fact in facts { + fact.map_def_use(&mut |var, def| { + if def { + bound_vars.insert(var); + } + var + }); + } + + let mut unbound_vars = HashSet::::default(); + for fact in facts { + fact.map_def_use(&mut |var, def| { + if !def && !bound_vars.contains(&var) { + unbound_vars.insert(var); + } + var + }); + } + unbound_vars + } + + // just get rid of all the equality constraints for now + pub fn resugar_facts(facts: &Vec, subst: &mut HashMap) -> Vec { + let unbound = NormRule::globals_used_in_matcher(facts); + let mut unionfind = UnionFind::default(); + let mut var_to_id = HashMap::::default(); + let mut id_to_var = HashMap::::default(); + let mut get_id = |var: Symbol, uf: &mut UnionFind| -> Id { + if let Some(id) = var_to_id.get(&var) { + *id + } else { + let id = uf.make_set(); + var_to_id.insert(var, id); + id_to_var.insert(id, var); + id + } + }; + for norm_fact in facts { + if let NormFact::ConstrainEq(v1, v2) = norm_fact { + let id1 = get_id(*v1, &mut unionfind); + let id2 = get_id(*v2, &mut unionfind); + unionfind.union_raw(id1, id2); + } else if let NormFact::AssignVar(v1, v2) = norm_fact { + let id1 = get_id(*v1, &mut unionfind); + let id2 = get_id(*v2, &mut unionfind); + unionfind.union_raw(id1, id2); + } + } + + for (var, id) in &var_to_id { + let leader = id_to_var.get(&unionfind.find(*id)).unwrap(); + if leader != var { + subst.insert(*var, Expr::Var(*leader)); + } + } + + let mut res = vec![]; + for fact in facts { + match fact { + NormFact::ConstrainEq(..) => (), + NormFact::AssignVar(..) => (), + _ => res.push(fact.to_fact().subst(subst)), + } + } + + // add back contraints on unbound variables + for var in unbound { + if let Some(id) = var_to_id.get(&var) { + let leader = id_to_var.get(&unionfind.find(*id)).unwrap(); + if leader != &var { + res.push(Fact::Eq(vec![Expr::Var(var), Expr::Var(*leader)])); + } + } + } + + res + } + + pub fn resugar_actions(&self, subst: &mut HashMap) -> Vec { + let mut used = HashSet::::default(); + let mut head = Vec::::default(); + for a in &self.head { + match a { + NormAction::Let(symbol, expr) => { + let new_expr = expr.to_expr(); + new_expr.map(&mut |subexpr| { + if let Expr::Var(v) = subexpr { + used.insert(*v); + } + subexpr.clone() + }); + let substituted = new_expr.subst(subst); + + // TODO sometimes re-arranging actions is bad + if substituted.ast_size() > 1 { + head.push(Action::Let(*symbol, substituted)); + } else { + subst.insert(*symbol, substituted); + } + } + NormAction::LetVar(symbol, other) => { + let new_expr = subst.get(other).unwrap_or(&Expr::Var(*other)).clone(); + used.insert(*other); + subst.insert(*symbol, new_expr); + } + NormAction::Extract(symbol, variants) => { + let new_expr = subst.get(symbol).cloned().unwrap_or(Expr::Var(*symbol)); + used.insert(*symbol); + let new_expr2 = subst.get(variants).cloned().unwrap_or(Expr::Var(*variants)); + used.insert(*variants); + head.push(Action::Extract(new_expr, new_expr2)); + } + NormAction::LetLit(symbol, lit) => { + subst.insert(*symbol, Expr::Lit(lit.clone())); + } + NormAction::Set(expr, other) => { + let new_expr = expr.to_expr(); + new_expr.map(&mut |subexpr| { + if let Expr::Var(v) = subexpr { + used.insert(*v); + } + subexpr.clone() + }); + let other_expr = subst.get(other).unwrap_or(&Expr::Var(*other)).clone(); + used.insert(*other); + let substituted = new_expr.subst(subst); + match substituted { + Expr::Call(op, children) => { + head.push(Action::Set(op, children, other_expr)); + } + _ => panic!("Expected call in set"), + } + } + NormAction::Delete(expr) => { + let new_expr = expr.to_expr(); + new_expr.map(&mut |subexpr| { + if let Expr::Var(v) = subexpr { + used.insert(*v); + } + subexpr.clone() + }); + match new_expr.subst(subst) { + Expr::Call(op, children) => { + head.push(Action::Delete(op, children)); + } + _ => panic!("Expected call in delete"), + } + } + NormAction::Union(lhs, rhs) => { + let new_lhs = subst.get(lhs).unwrap_or(&Expr::Var(*lhs)).clone(); + let new_rhs = subst.get(rhs).unwrap_or(&Expr::Var(*rhs)).clone(); + used.insert(*lhs); + used.insert(*rhs); + head.push(Action::Union(new_lhs, new_rhs)); + } + NormAction::Panic(msg) => { + head.push(Action::Panic(msg.clone())); + } + } + } + + // unused substitutions need to be added + // to the action, since they have the side-effect + // of adding to the database + for (var, expr) in subst { + if !used.contains(var) { + match expr { + Expr::Var(..) => (), + Expr::Lit(..) => (), + Expr::Call(..) => head.push(Action::Expr(expr.clone())), + }; + } + } + head + } + + pub fn resugar(&self) -> Rule { + let mut subst = HashMap::::default(); + + let facts_resugared = NormRule::resugar_facts(&self.body, &mut subst); + + Rule { + head: self.resugar_actions(&mut subst), + body: facts_resugared, + } + } + pub fn map_exprs(&self, f: &mut impl FnMut(&NormExpr) -> NormExpr) -> Self { NormRule { head: self.head.iter().map(|a| a.map_exprs(f)).collect(), diff --git a/src/ast/parse.lalrpop b/src/ast/parse.lalrpop index 52e45caf..e6ecb115 100644 --- a/src/ast/parse.lalrpop +++ b/src/ast/parse.lalrpop @@ -50,9 +50,10 @@ Command: Command = { LParen "sort" LParen RParen RParen => Command::Sort (name, Some((head, tail))), LParen "sort" RParen => Command::Sort (name, None), LParen "function" + >)?> )?> )?> RParen => { - Command::Function(FunctionDecl { name, schema, merge, merge_action: merge_action.unwrap_or_default(), default, cost }) + Command::Function(FunctionDecl { name, schema, merge, merge_action: merge_action.unwrap_or_default(), default, cost, unextractable: unextractable.is_some() }) }, LParen "declare" RParen => Command::Declare{name, sort}, LParen "relation" > RParen => Command::Function(FunctionDecl::relation(name, types)), @@ -66,21 +67,21 @@ Command: Command = { >)?> )?> RParen => Command::BiRewrite(ruleset.unwrap_or("".into()), Rewrite { lhs, rhs, conditions: conditions.unwrap_or_default() }), - LParen "define" RParen => Command::Define { name, expr, cost }, LParen "let" RParen => Command::Action(Action::Let(name, expr)), => Command::Action(<>), - LParen "run" )?> RParen => Command::Run(RunConfig { ruleset: "".into(), limit, until }), - LParen "run" )?> RParen => Command::Run(RunConfig { ruleset, limit, until }), - LParen "simplify" )?> RParen - => Command::Simplify { expr, config : RunConfig { ruleset: "".into(), limit, until } }, + LParen "run" )?> RParen => Command::RunSchedule(Schedule::Repeat(limit, Box::new(Schedule::Run(RunConfig { ruleset : "".into(), until })))), + LParen "run" )?> RParen => Command::RunSchedule(Schedule::Repeat(limit, Box::new(Schedule::Run(RunConfig { ruleset, until })))), + LParen "simplify" RParen + => Command::Simplify { expr, schedule }, LParen "add-ruleset" RParen => Command::AddRuleset(name), LParen "calc" LParen RParen RParen => Command::Calc(idents, exprs), - LParen "extract" )?> RParen => Command::Extract { e, variants: variants.unwrap_or(0) }, + LParen "query-extract" )?> RParen => Command::Extract { fact, variants: variants.unwrap_or(0) }, LParen "check" <(Fact)*> RParen => Command::Check(<>), + LParen "check-proof" RParen => Command::CheckProof, LParen "run-schedule" RParen => Command::RunSchedule(Schedule::Sequence(<>)), LParen "push" RParen => Command::Push(<>.unwrap_or(1)), LParen "pop" RParen => Command::Pop(<>.unwrap_or(1)), - LParen "print" RParen => Command::Print(sym, n.unwrap_or(10)), + LParen "print-table" RParen => Command::PrintTable(sym, n.unwrap_or(10)), LParen "print-size" RParen => Command::PrintSize(sym), LParen "input" RParen => Command::Input { name, file }, LParen "output" RParen => Command::Output { file, exprs }, @@ -92,9 +93,10 @@ Schedule: Schedule = { LParen "saturate" RParen => Schedule::Saturate(Box::new(Schedule::Sequence(<>))), LParen "seq" RParen => Schedule::Sequence(<>), LParen "repeat" RParen => Schedule::Repeat(limit, Box::new(Schedule::Sequence(scheds))), - LParen "run" )?> RParen => Schedule::Run(RunConfig { ruleset: "".into(), limit, until }), - LParen "run" )?> RParen => Schedule::Run(RunConfig { ruleset, limit, until }), - => Schedule::Run(RunConfig { ruleset: ident, limit: 1, until: None }), + LParen "run" )?> RParen => + Schedule::Run(RunConfig { ruleset: "".into(), until }), + LParen "run" )?> RParen => Schedule::Run(RunConfig { ruleset, until }), + => Schedule::Run(RunConfig { ruleset: ident, until: None }), } Cost: Option = { @@ -104,10 +106,11 @@ Cost: Option = { NonLetAction: Action = { LParen "set" LParen RParen RParen => Action::Set ( f, args, v ), - LParen "set-no-track" LParen RParen RParen => Action::SetNoTrack ( f, args, v ), LParen "delete" LParen RParen RParen => Action::Delete ( f, args), LParen "union" RParen => Action::Union(<>), LParen "panic" RParen => Action::Panic(msg), + LParen "extract" RParen => Action::Extract(expr, Expr::Lit(Literal::Int(0))), + LParen "extract" RParen => Action::Extract(expr, variants), => Action::Expr(e), } @@ -118,7 +121,7 @@ pub Action: Action = { Name: Symbol = { "[" "]" => <> } -Fact: Fact = { +pub Fact: Fact = { LParen "=" RParen => { es.push(e); Fact::Eq(es) diff --git a/src/extract.rs b/src/extract.rs index fa055eba..baba88c5 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -1,19 +1,20 @@ use hashbrown::hash_map::Entry; use crate::ast::Symbol; +use crate::termdag::{Term, TermDag}; use crate::util::HashMap; -use crate::{ArcSort, EGraph, Expr, Function, Id, Value}; +use crate::{ArcSort, EGraph, Function, Id, Value}; type Cost = usize; #[derive(Debug)] -struct Node<'a> { +pub(crate) struct Node<'a> { sym: Symbol, inputs: &'a [Value], } -struct Extractor<'a> { - costs: HashMap)>, +pub(crate) struct Extractor<'a> { + costs: HashMap, ctors: Vec, egraph: &'a EGraph, } @@ -29,14 +30,19 @@ impl EGraph { None } - pub fn extract(&self, value: Value, arcsort: &ArcSort) -> (Cost, Expr) { - Extractor::new(self).find_best(value, arcsort) + pub fn extract(&self, value: Value, termdag: &mut TermDag, arcsort: &ArcSort) -> (Cost, Term) { + Extractor::new(self, termdag).find_best(value, termdag, arcsort) } - pub fn extract_variants(&mut self, value: Value, limit: usize) -> Vec { + pub fn extract_variants( + &mut self, + value: Value, + limit: usize, + termdag: &mut TermDag, + ) -> Vec { let (tag, id) = self.value_to_id(value).unwrap(); let output_value = &Value::from_id(tag, id); - let ext = &Extractor::new(self); + let ext = &Extractor::new(self, termdag); ext.ctors .iter() .flat_map(|&sym| { @@ -45,12 +51,13 @@ impl EGraph { return vec![]; } assert!(func.schema.output.is_eq_sort()); + func.nodes .iter() - .filter_map(move |(inputs, output)| { + .filter_map(|(inputs, output)| { (&output.value == output_value).then(|| { let node = Node { sym, inputs }; - ext.expr_from_node(&node) + ext.expr_from_node(&node, termdag) }) }) .collect() @@ -61,71 +68,114 @@ impl EGraph { } impl<'a> Extractor<'a> { - fn new(egraph: &'a EGraph) -> Self { + pub fn new(egraph: &'a EGraph, termdag: &mut TermDag) -> Self { let mut extractor = Extractor { costs: HashMap::default(), egraph, ctors: vec![], }; - // HACK - // just consider all functions constructors for now... - extractor.ctors.extend(egraph.functions.keys().cloned()); + // only consider "extractable" functions + extractor.ctors.extend( + egraph + .functions + .keys() + .filter(|func| !egraph.functions.get(*func).unwrap().decl.unextractable) + .cloned(), + ); log::debug!("Extracting from ctors: {:?}", extractor.ctors); - extractor.find_costs(); + extractor.find_costs(termdag); extractor } - fn expr_from_node(&self, node: &Node) -> Expr { - let children = node.inputs.iter().map(|&value| { - let arcsort = self.egraph.get_sort(&value).unwrap(); - self.find_best(value, arcsort).1 - }); - Expr::call(node.sym, children) + fn expr_from_node(&self, node: &Node, termdag: &mut TermDag) -> Term { + let mut children = vec![]; + for value in node.inputs { + let arcsort = self.egraph.get_sort(value).unwrap(); + children.push(self.find_best(*value, termdag, arcsort).1) + } + + termdag.make(node.sym, children) } - fn find_best(&self, value: Value, sort: &ArcSort) -> (Cost, Expr) { + pub fn find_best(&self, value: Value, termdag: &mut TermDag, sort: &ArcSort) -> (Cost, Term) { if sort.is_eq_sort() { - let id = self.egraph.find(Id::from(value.bits as usize)); - let (cost, node) = &self + let id = self.find(&value); + let (cost, node) = self .costs .get(&id) - .unwrap_or_else(|| panic!("No cost for {:?}", value)); - (*cost, self.expr_from_node(node)) + .unwrap_or_else(|| { + log::error!("No cost for {:?}", value); + for func in self.egraph.functions.values() { + for (inputs, output) in func.nodes.iter() { + if output.value == value { + log::error!("Found unextractable function: {:?}", func.decl.name); + log::error!("Inputs: {:?}", inputs); + log::error!( + "{:?}", + inputs + .iter() + .map(|input| self.costs.get(&self.find(input))) + .collect::>() + ); + } + } + } + + panic!("No cost for {:?}", value) + }) + .clone(); + (cost, node) } else { - (0, sort.make_expr(self.egraph, value)) + (0, termdag.expr_to_term(&sort.make_expr(self.egraph, value))) } } - fn node_total_cost(&self, function: &Function, children: &[Value]) -> Option { + fn node_total_cost( + &mut self, + function: &Function, + children: &[Value], + termdag: &mut TermDag, + ) -> Option<(Vec, Cost)> { let mut cost = function.decl.cost.unwrap_or(1); let types = &function.schema.input; + let mut terms: Vec = vec![]; for (ty, value) in types.iter().zip(children) { cost = cost.saturating_add(if ty.is_eq_sort() { let id = self.egraph.find(Id::from(value.bits as usize)); // TODO costs should probably map values? - self.costs.get(&id)?.0 + let (cost, term) = self.costs.get(&id)?; + terms.push(term.clone()); + *cost } else { + let term = termdag.expr_to_term(&ty.make_expr(self.egraph, *value)); + terms.push(term); 1 }); } - Some(cost) + Some((terms, cost)) + } + + fn find(&self, value: &Value) -> Id { + self.egraph.find(Id::from(value.bits as usize)) } - fn find_costs(&mut self) { + fn find_costs(&mut self, termdag: &mut TermDag) { let mut did_something = true; while did_something { did_something = false; - for &sym in &self.ctors { + for sym in self.ctors.clone() { let func = &self.egraph.functions[&sym]; if func.schema.output.is_eq_sort() { for (inputs, output) in func.nodes.iter() { - if let Some(new_cost) = self.node_total_cost(func, inputs) { - let make_new_pair = || (new_cost, Node { sym, inputs }); + if let Some((term_inputs, new_cost)) = + self.node_total_cost(func, inputs, termdag) + { + let make_new_pair = || (new_cost, termdag.make(sym, term_inputs)); - let id = self.egraph.find(Id::from(output.value.bits as usize)); + let id = self.find(&output.value); match self.costs.entry(id) { Entry::Vacant(e) => { did_something = true; diff --git a/src/function/mod.rs b/src/function/mod.rs index a91c7109..b4809e81 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -1,3 +1,5 @@ +use std::mem; + use crate::*; use index::*; use smallvec::SmallVec; @@ -12,7 +14,6 @@ pub type ValueVec = SmallVec<[Value; 3]>; pub struct Function { pub decl: FunctionDecl, pub schema: ResolvedSchema, - pub(crate) is_variable: bool, pub merge: MergeAction, pub(crate) nodes: table::Table, sorts: HashSet, @@ -65,7 +66,7 @@ impl ResolvedSchema { pub(crate) type DeferredMerge = (ValueVec, Value, Value); impl Function { - pub fn new(egraph: &EGraph, decl: &FunctionDecl, is_variable: bool) -> Result { + pub fn new(egraph: &EGraph, decl: &FunctionDecl) -> Result { let mut input = Vec::with_capacity(decl.schema.input.len()); for s in &decl.schema.input { input.push(match egraph.proof_state.type_info.sorts.get(s) { @@ -129,7 +130,6 @@ impl Function { Ok(Function { decl: decl.clone(), schema: ResolvedSchema { input, output }, - is_variable, nodes: Default::default(), scratch: Default::default(), sorts, diff --git a/src/gj.rs b/src/gj.rs index cccd66b6..bdc69f63 100644 --- a/src/gj.rs +++ b/src/gj.rs @@ -1,5 +1,6 @@ use hashbrown::hash_map::Entry as HEntry; use indexmap::map::Entry; +use log::log_enabled; use smallvec::SmallVec; use crate::{ @@ -13,9 +14,12 @@ use std::{ ops::Range, }; +#[derive(Clone)] enum Instr<'a> { Intersect { value_idx: usize, + variable_name: Symbol, + info: VarInfo2, trie_accesses: Vec<(usize, TrieAccess<'a>)>, }, ConstrainConstant { @@ -30,44 +34,90 @@ enum Instr<'a> { }, } +// FIXME @mwillsey awful name, bad bad bad +#[derive(Default, Debug, Clone)] +struct VarInfo2 { + occurences: Vec, + intersected_on: usize, + size_guess: usize, +} + +struct InputSizes<'a> { + cur_stage: usize, + // a map from from stage to vector of costs for each stage, + // where 'cost' is the largest relation being intersected + stage_sizes: &'a mut HashMap>, +} + +impl<'a> InputSizes<'a> { + fn add_measurement(&mut self, max_size: usize) { + self.stage_sizes + .entry(self.cur_stage) + .or_default() + .push(max_size); + } + + fn next(&mut self) -> InputSizes { + InputSizes { + cur_stage: self.cur_stage + 1, + stage_sizes: self.stage_sizes, + } + } +} + type Result = std::result::Result<(), ()>; struct Program<'a>(Vec>); -impl<'a> std::fmt::Display for Program<'a> { +impl<'a> std::fmt::Display for Instr<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - for instr in &self.0 { - match instr { - Instr::Intersect { - value_idx, - trie_accesses, - } => { - write!(f, " Intersect @ {} ", value_idx)?; - for (trie_idx, a) in trie_accesses { - write!(f, " {}: {}", trie_idx, a)?; - } - writeln!(f)? - } - Instr::ConstrainConstant { - index, - val, - trie_access, - } => { - writeln!(f, " ConstrainConstant {index} {trie_access} = {val:?}")?; - } - Instr::Call { prim, args, check } => { - writeln!(f, " Call {:?} {:?} {:?}", prim, args, check)?; + match self { + Instr::Intersect { + value_idx, + trie_accesses, + variable_name, + info, + } => { + write!( + f, + " Intersect @ {value_idx} sg={sg:6} {variable_name:15}", + sg = info.size_guess + )?; + for (trie_idx, a) in trie_accesses { + write!(f, " {}: {}", trie_idx, a)?; } + writeln!(f)? + } + Instr::ConstrainConstant { + index, + val, + trie_access, + } => { + writeln!(f, " ConstrainConstant {index} {trie_access} = {val:?}")?; + } + Instr::Call { prim, args, check } => { + writeln!(f, " Call {:?} {:?} {:?}", prim, args, check)?; } } Ok(()) } } +impl<'a> std::fmt::Display for Program<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for (i, instr) in self.0.iter().enumerate() { + write!(f, "{i:2}. {}", instr)?; + } + Ok(()) + } +} + struct Context<'b> { query: &'b CompiledQuery, + join_var_ordering: Vec, tuple: Vec, matches: usize, + egraph: &'b EGraph, } impl<'b> Context<'b> { @@ -76,18 +126,27 @@ impl<'b> Context<'b> { cq: &'b CompiledQuery, timestamp_ranges: &[Range], ) -> Option<(Self, Program<'b>, Vec>)> { + let (program, join_var_ordering, intersections) = + egraph.compile_program(cq, timestamp_ranges)?; + let ctx = Context { query: cq, tuple: vec![Value::fake(); cq.vars.len()], + join_var_ordering, matches: 0, + egraph, }; - let (program, _vars, intersections) = egraph.compile_program(cq, timestamp_ranges)?; - Some((ctx, program, intersections)) } - fn eval(&mut self, tries: &mut [&LazyTrie], program: &[Instr], f: &mut F) -> Result + fn eval( + &mut self, + tries: &mut [&LazyTrie], + program: &[Instr], + mut stage: InputSizes, + f: &mut F, + ) -> Result where F: FnMut(&[Value]) -> Result, { @@ -108,7 +167,7 @@ impl<'b> Context<'b> { if let Some(next) = tries[*index].get(trie_access, *val) { let old = tries[*index]; tries[*index] = next; - self.eval(tries, program, f)?; + self.eval(tries, program, stage.next(), f)?; tries[*index] = old; } Ok(()) @@ -116,12 +175,21 @@ impl<'b> Context<'b> { Instr::Intersect { value_idx, trie_accesses, + .. } => { + if let Some(x) = trie_accesses + .iter() + .map(|(atom, _)| tries[*atom].len()) + .max() + { + stage.add_measurement(x); + } + match trie_accesses.as_slice() { [(j, access)] => tries[*j].for_each(access, |value, trie| { let old_trie = std::mem::replace(&mut tries[*j], trie); self.tuple[*value_idx] = value; - self.eval(tries, program, f)?; + self.eval(tries, program, stage.next(), f)?; tries[*j] = old_trie; Ok(()) }), @@ -136,7 +204,7 @@ impl<'b> Context<'b> { let old_ta = std::mem::replace(&mut tries[a.0], ta); let old_tb = std::mem::replace(&mut tries[b.0], tb); self.tuple[*value_idx] = value; - self.eval(tries, program, f)?; + self.eval(tries, program, stage.next(), f)?; tries[a.0] = old_ta; tries[b.0] = old_tb; } @@ -165,7 +233,7 @@ impl<'b> Context<'b> { // at this point, new_tries is ready to go self.tuple[*value_idx] = value; - self.eval(&mut new_tries, program, f) + self.eval(&mut new_tries, program, stage.next(), f) }) } } @@ -180,6 +248,7 @@ impl<'b> Context<'b> { self.tuple[i] } AtomTerm::Value(val) => *val, + AtomTerm::Global(g) => self.egraph.global_bindings.get(g).unwrap().1, }) } @@ -187,9 +256,14 @@ impl<'b> Context<'b> { match out { AtomTerm::Var(v) => { let i = self.query.vars.get_index_of(v).unwrap(); - if *check && self.tuple[i] != res { - return Ok(()); + + if *check { + assert_ne!(self.tuple[i], Value::fake()); + if self.tuple[i] != res { + return Ok(()); + } } + self.tuple[i] = res; } AtomTerm::Value(val) => { @@ -198,8 +272,16 @@ impl<'b> Context<'b> { return Ok(()); } } + AtomTerm::Global(g) => { + assert!(check); + let (sort, val, _ts) = self.egraph.global_bindings.get(g).unwrap(); + assert!(sort.name() == res.tag); + if val.bits != res.bits { + return Ok(()); + } + } } - self.eval(tries, program, f)?; + self.eval(tries, program, stage.next(), f)?; } Ok(()) @@ -208,6 +290,7 @@ impl<'b> Context<'b> { } } +#[derive(Clone)] enum Constraint { Eq(usize, usize), Const(usize, Value), @@ -239,6 +322,8 @@ pub struct VarInfo { #[derive(Debug, Clone)] pub struct CompiledQuery { query: Query, + // Ordering is used for the tuple + // The GJ variable ordering is stored in the context pub vars: IndexMap, } @@ -248,7 +333,7 @@ impl EGraph { query: Query, types: &IndexMap, ) -> CompiledQuery { - // NOTE: this vars order only used for ordering the tuple, + // NOTE: this vars order only used for ordering the tuple storing the resulting match // It is not the GJ variable order. let mut vars: IndexMap = Default::default(); @@ -263,6 +348,7 @@ impl EGraph { } } + // make sure everyone has an entry in the vars table for prim in &query.filters { for v in prim.vars() { vars.entry(v).or_default(); @@ -284,6 +370,9 @@ impl EGraph { for (i, t) in atom.args.iter().enumerate() { match t { AtomTerm::Value(val) => constraints.push(Constraint::Const(i, *val)), + AtomTerm::Global(g) => { + constraints.push(Constraint::Const(i, self.global_bindings.get(g).unwrap().1)) + } AtomTerm::Var(_v) => { if let Some(j) = atom.args[..i].iter().position(|t2| t == t2) { constraints.push(Constraint::Eq(j, i)); @@ -323,13 +412,6 @@ impl EGraph { Vec, /* variable ordering */ Vec>, /* the first column accessed per-atom */ )> { - #[derive(Default)] - struct VarInfo2 { - occurences: Vec, - intersected_on: usize, - size_guess: usize, - } - let atoms = &query.query.atoms; let mut vars: IndexMap = Default::default(); let mut constants = @@ -342,6 +424,10 @@ impl EGraph { AtomTerm::Value(val) => { constants.entry(i).or_default().push((col, *val)); } + AtomTerm::Global(g) => { + let val = self.global_bindings.get(g).unwrap().1; + constants.entry(i).or_default().push((col, val)); + } } } } @@ -372,16 +458,23 @@ impl EGraph { // info.size_guess >>= info.occurences.len() - 1; } + // here we are picking the variable ordering let mut ordered_vars = IndexMap::default(); while !vars.is_empty() { - let (&var, _info) = vars + let mut var_cost = vars .iter() - .max_by_key(|(_v, info)| { + .map(|(v, info)| { let size = info.size_guess as isize; - (info.occurences.len(), info.intersected_on, -size) + let cost = (info.occurences.len(), info.intersected_on, -size); + (cost, v) }) - .unwrap(); + .collect::>(); + var_cost.sort(); + var_cost.reverse(); + + log::debug!("Variable costs: {:?}", ListDebug(&var_cost, "\n")); + let var = *var_cost[0].1; let info = vars.remove(&var).unwrap(); for &i in &info.occurences { for v in atoms[i].vars() { @@ -420,6 +513,8 @@ impl EGraph { }); Instr::Intersect { value_idx, + variable_name: v, + info: info.clone(), trie_accesses: info .occurences .iter() @@ -439,7 +534,6 @@ impl EGraph { program.extend(var_instrs); // now we can try to add primitives - // TODO this is very inefficient, since primitives all at the end let mut extra = query.query.filters.clone(); while !extra.is_empty() { let next = extra.iter().position(|p| { @@ -447,6 +541,7 @@ impl EGraph { p.args[..p.args.len() - 1].iter().all(|a| match a { AtomTerm::Var(v) => vars.contains_key(v), AtomTerm::Value(_) => true, + AtomTerm::Global(_) => true, }) }); @@ -461,6 +556,7 @@ impl EGraph { } }, AtomTerm::Value(_) => true, + AtomTerm::Global(_) => true, }; program.push(Instr::Call { prim: p.head.clone(), @@ -468,17 +564,62 @@ impl EGraph { check, }); } else { - panic!("cycle") + panic!("cycle {:#?}", query) } } + let resulting_program = Program(program); + self.sanity_check_program(&resulting_program, query); + Some(( - Program(program), + resulting_program, vars.into_keys().collect(), initial_columns, )) } + fn sanity_check_program(&self, program: &Program, query: &CompiledQuery) { + // sanity check the program + let mut tuple_valid = vec![false; query.vars.len()]; + for instr in &program.0 { + match instr { + Instr::Intersect { value_idx, .. } => { + assert!(!tuple_valid[*value_idx]); + tuple_valid[*value_idx] = true; + } + Instr::ConstrainConstant { .. } => {} + Instr::Call { check, args, .. } => { + let Some((last, args)) = args.split_last() else { + continue + }; + + for a in args { + if let AtomTerm::Var(v) = a { + let i = query.vars.get_index_of(v).unwrap(); + assert!(tuple_valid[i]); + } + } + + match last { + AtomTerm::Var(v) => { + let i = query.vars.get_index_of(v).unwrap(); + assert_eq!(*check, tuple_valid[i], "{instr}"); + if !*check { + tuple_valid[i] = true; + } + } + AtomTerm::Value(_) => { + assert!(*check); + } + AtomTerm::Global(_) => { + assert!(*check); + } + } + } + } + } + } + pub(crate) fn run_query(&self, cq: &CompiledQuery, timestamp: u32, mut f: F) where F: FnMut(&[Value]) -> Result, @@ -486,7 +627,19 @@ impl EGraph { let has_atoms = !cq.query.atoms.is_empty(); if has_atoms { - let do_seminaive = self.seminaive; + // check if any globals updated + let mut global_updated = false; + for atom in &cq.query.atoms { + for arg in &atom.args { + if let AtomTerm::Global(g) = arg { + if self.global_bindings.get(g).unwrap().2 > timestamp { + global_updated = true; + } + } + } + } + + let do_seminaive = self.seminaive && !global_updated; // for the later atoms, we consider everything let mut timestamp_ranges = vec![0..u32::MAX; cq.query.atoms.len()]; for (atom_i, atom) in cq.query.atoms.iter().enumerate() { @@ -496,14 +649,14 @@ impl EGraph { } // do the gj + if let Some((mut ctx, program, cols)) = Context::new(self, cq, ×tamp_ranges) { let start = Instant::now(); log::debug!( - "Query: {}\nNew atom: {}\nVars: {}\nProgram\n{}", - cq.query, - atom, - ListDisplay(cq.vars.keys(), " "), - program + "Query:\n{q}\nNew atom: {atom}\nTuple: {tuple}\nJoin order: {order}\nProgram\n{program}", + q = cq.query, + order = ListDisplay(&ctx.join_var_ordering, " "), + tuple = ListDisplay(cq.vars.keys(), " "), ); let mut tries = Vec::with_capacity(cq.query.atoms.len()); for ((atom, ts), col) in cq @@ -526,12 +679,36 @@ impl EGraph { } } let mut trie_refs = tries.iter().collect::>(); - ctx.eval(&mut trie_refs, &program.0, &mut f).unwrap_or(()); - log::debug!( - "Matched {} times (took {:?})", - ctx.matches, - Instant::now().duration_since(start) + let mut meausrements = HashMap::>::default(); + let stages = InputSizes { + stage_sizes: &mut meausrements, + cur_stage: 0, + }; + ctx.eval(&mut trie_refs, &program.0, stages, &mut f) + .unwrap_or(()); + let mut sums = Vec::from_iter( + meausrements + .iter() + .map(|(x, y)| (*x, y.iter().copied().sum::())), ); + sums.sort_by_key(|(i, _sum)| *i); + if log_enabled!(log::Level::Debug) { + for (i, sum) in sums { + log::debug!("stage {i} total cost {sum}"); + } + } + let duration = start.elapsed(); + log::debug!("Matched {} times (took {:?})", ctx.matches, duration,); + let iteration = self + .ruleset_iteration + .get::(&"".into()) + .unwrap_or(&0); + if duration.as_millis() > 1000 { + log::warn!( + "Query took a long time at iter {iteration} : {:?}", + duration + ); + } } if !do_seminaive { @@ -543,9 +720,15 @@ impl EGraph { timestamp_ranges[atom_i] = 0..timestamp; } } else if let Some((mut ctx, program, _)) = Context::new(self, cq, &[]) { + let mut meausrements = HashMap::>::default(); + let stages = InputSizes { + stage_sizes: &mut meausrements, + cur_stage: 0, + }; let tries = LazyTrie::make_initial_vec(cq.query.atoms.len()); let mut trie_refs = tries.iter().collect::>(); - ctx.eval(&mut trie_refs, &program.0, &mut f).unwrap_or(()); + ctx.eval(&mut trie_refs, &program.0, stages, &mut f) + .unwrap_or(()); } } } @@ -665,6 +848,7 @@ impl LazyTrie { } } +#[derive(Clone)] struct TrieAccess<'a> { function: &'a Function, timestamp_range: Range, diff --git a/src/lib.rs b/src/lib.rs index c20e31be..ae806f8d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,17 +5,20 @@ mod gj; mod proofs; mod serialize; pub mod sort; +mod termdag; mod typecheck; mod typechecking; mod unionfind; pub mod util; mod value; +use extract::Extractor; use hashbrown::hash_map::Entry; use index::ColumnIndex; use instant::{Duration, Instant}; pub use serialize::SerializeConfig; use sort::*; +use termdag::{Term, TermDag}; use thiserror::Error; use proofs::ProofState; @@ -25,15 +28,15 @@ use symbolic_expressions::Sexp; use ast::*; pub use typechecking::{TypeInfo, UNIT_SYM}; -use std::fmt::{Formatter, Write}; +use std::fmt::{Display, Formatter, Write}; use std::fs::File; use std::hash::Hash; use std::io::Read; use std::iter::once; -use std::mem; use std::ops::{Deref, Range}; use std::path::PathBuf; use std::rc::Rc; +use std::str::FromStr; use std::{fmt::Debug, sync::Arc}; use typecheck::Program; @@ -67,8 +70,9 @@ pub struct RunReport { #[derive(Debug, Clone)] pub struct ExtractReport { pub cost: usize, - pub expr: Expr, - pub variants: Vec, + pub expr: Term, + pub variants: Vec, + pub termdag: TermDag, } impl RunReport { @@ -151,6 +155,48 @@ impl PrimitiveLike for SimplePrimitive { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)] +pub enum CompilerPassStop { + Desugar, + TypecheckDesugared, + TermEncoding, + TypecheckTermEncoding, + Proofs, + TypecheckProofs, + All, +} + +impl Display for CompilerPassStop { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + CompilerPassStop::Desugar => write!(f, "desugar"), + CompilerPassStop::TypecheckDesugared => write!(f, "typecheck_desugared"), + CompilerPassStop::TermEncoding => write!(f, "term_encoding"), + CompilerPassStop::TypecheckTermEncoding => write!(f, "typecheck_term_encoding"), + CompilerPassStop::Proofs => write!(f, "proofs"), + CompilerPassStop::TypecheckProofs => write!(f, "typecheck_proofs"), + CompilerPassStop::All => write!(f, "all"), + } + } +} + +impl FromStr for CompilerPassStop { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "desugar" => Ok(CompilerPassStop::Desugar), + "typecheck_desugared" => Ok(CompilerPassStop::TypecheckDesugared), + "term_encoding" => Ok(CompilerPassStop::TermEncoding), + "typecheck_term_encoding" => Ok(CompilerPassStop::TypecheckTermEncoding), + "proofs" => Ok(CompilerPassStop::Proofs), + "typecheck_proofs" => Ok(CompilerPassStop::TypecheckProofs), + "all" => Ok(CompilerPassStop::All), + _ => Err(format!("Unknown compiler pass stop: {}", s)), + } + } +} + #[derive(Clone)] pub struct EGraph { egraphs: Vec, @@ -158,6 +204,7 @@ pub struct EGraph { pub(crate) proof_state: ProofState, functions: HashMap, rulesets: HashMap>, + ruleset_iteration: HashMap, proofs_enabled: bool, interactive_mode: bool, timestamp: u32, @@ -166,6 +213,8 @@ pub struct EGraph { pub node_limit: usize, pub fact_directory: Option, pub seminaive: bool, + // sort, value, and timestamp + pub global_bindings: HashMap, extract_report: Option, run_report: Option, } @@ -189,7 +238,9 @@ impl Default for EGraph { unionfind: Default::default(), functions: Default::default(), rulesets: Default::default(), + ruleset_iteration: Default::default(), proof_state: ProofState::default(), + global_bindings: Default::default(), match_limit: usize::MAX, node_limit: usize::MAX, timestamp: 0, @@ -289,6 +340,8 @@ impl EGraph { } } + // find the leader term for this term + // in the corresponding table pub fn find(&self, id: Id) -> Id { self.unionfind.find(id) } @@ -314,6 +367,15 @@ impl EGraph { break; } } + + // now update global bindings + let mut new_global_bindings = self.global_bindings.clone(); + for (_sym, (_sort, value, ts)) in new_global_bindings.iter_mut() { + *value = self.bad_find_value(*value); + *ts = self.timestamp; + } + self.global_bindings = new_global_bindings; + self.debug_assert_invariants(); Ok(updates) } @@ -331,6 +393,7 @@ impl EGraph { for (func, merges) in deferred_merges { new_unions += self.apply_merges(func, &merges); } + Ok(new_unions) } @@ -363,8 +426,8 @@ impl EGraph { self.unionfind.n_unions() - n_unions + function.clear_updates() } - pub fn declare_function(&mut self, decl: &FunctionDecl, is_var: bool) -> Result<(), Error> { - let function = Function::new(self, decl, is_var)?; + pub fn declare_function(&mut self, decl: &FunctionDecl) -> Result<(), Error> { + let function = Function::new(self, decl)?; let old = self.functions.insert(decl.name, function); if old.is_some() { panic!( @@ -383,20 +446,18 @@ impl EGraph { ) -> Result<(), Error> { let name = variant.name; let sort = sort.into(); - self.declare_function( - &FunctionDecl { - name, - schema: Schema { - input: variant.types, - output: sort, - }, - merge: None, - merge_action: vec![], - default: None, - cost: variant.cost, + self.declare_function(&FunctionDecl { + name, + schema: Schema { + input: variant.types, + output: sort, }, - false, - )?; + merge: None, + merge_action: vec![], + default: None, + cost: variant.cost, + unextractable: false, + })?; // if let Some(ctors) = self.sorts.get_mut(&sort) { // ctors.push(name); // } @@ -412,7 +473,11 @@ impl EGraph { } } - pub fn print_function(&mut self, sym: Symbol, n: usize) -> Result { + pub fn function_to_dag( + &mut self, + sym: Symbol, + n: usize, + ) -> Result<(Vec<(Term, Term)>, TermDag), Error> { let f = self.functions.get(&sym).ok_or(TypeError::Unbound(sym))?; let schema = f.schema.clone(); let nodes = f @@ -422,34 +487,49 @@ impl EGraph { .map(|(k, v)| (ValueVec::from(k), v.clone())) .collect::>(); - let out_is_unit = f.schema.output.name() == UNIT_SYM.into(); - - let mut buf = String::new(); - let s = &mut buf; + let mut termdag = TermDag::default(); + let extractor = Extractor::new(self, &mut termdag); + let mut terms = Vec::new(); for (ins, out) in nodes { - write!(s, "({}", sym).unwrap(); - for (a, t) in ins.iter().copied().zip(&schema.input) { - s.push(' '); - let e = self.extract(a, t).1; - write!(s, "{}", e).unwrap(); + let mut children = Vec::new(); + for (a, a_type) in ins.iter().copied().zip(&schema.input) { + if a_type.is_eq_sort() { + children.push(extractor.find_best(a, &mut termdag, a_type).1); + } else { + children.push(termdag.expr_to_term(&a_type.make_expr(self, a))); + }; } - if out_is_unit { - s.push(')'); + let out = if schema.output.is_eq_sort() { + extractor + .find_best(out.value, &mut termdag, &schema.output) + .1 } else { - let e = self.extract(out.value, &schema.output).1; - write!(s, ") -> {}", e).unwrap(); + termdag.expr_to_term(&schema.output.make_expr(self, out.value)) + }; + terms.push((termdag.make(sym, children), out)); + } + drop(extractor); + + Ok((terms, termdag)) + } + + pub fn print_function(&mut self, sym: Symbol, n: usize) -> Result { + let (terms_with_outputs, termdag) = self.function_to_dag(sym, n)?; + let f = self + .functions + .get(&sym) + .ok_or(TypeError::UnboundFunction(sym))?; + let out_is_unit = f.schema.output.name() == UNIT_SYM.into(); + + let mut buf = String::new(); + let s = &mut buf; + for (term, output) in terms_with_outputs { + write!(s, "{}", termdag.to_string(&term)).unwrap(); + if !out_is_unit { + write!(s, " -> {}", termdag.to_string(&output)).unwrap(); } s.push('\n'); - // write!(s, "{}(", self.decl.name)?; - // for (i, arg) in args.iter().enumerate() { - // if i > 0 { - // write!(s, ", ")?; - // } - // write!(s, "{}", arg)?; - // } - // write!(s, ") = {}", value)?; - // println!("{}", s); } Ok(buf) @@ -462,13 +542,12 @@ impl EGraph { // returns whether the egraph was updated pub fn run_schedule(&mut self, sched: &NormSchedule) -> RunReport { - log::info!("Running {}", sched); match sched { NormSchedule::Run(config) => self.run_rules(config), NormSchedule::Repeat(limit, sched) => { let mut report = RunReport::default(); for _i in 0..*limit { - let rec = report.union(&self.run_schedule(sched)); + let rec = self.run_schedule(sched); report = report.union(&rec); if !rec.updated { break; @@ -497,50 +576,42 @@ impl EGraph { } } - pub fn run_rules(&mut self, config: &NormRunConfig) -> RunReport { - let NormRunConfig { - ruleset, - limit, - until, - } = config; - let mut report: RunReport = Default::default(); - - // we rebuild on every command so we are in a valid state at this point - for i in 0..*limit { - if let Some(facts) = until { - if self.check_facts(facts).is_ok() { - log::info!( - "Breaking early at iteration {} because of facts:\n {}!", - i, - ListDisplay(facts, "\n") - ); - break; - } + pub fn run_rules_once(&mut self, config: &NormRunConfig, report: &mut RunReport) { + // first rebuild + let rebuild_start = Instant::now(); + let updates = self.rebuild_nofail(); + log::debug!("database size: {}", self.num_tuples()); + log::debug!("Made {updates} updates"); + report.rebuild_time += rebuild_start.elapsed(); + self.timestamp += 1; + + let NormRunConfig { ruleset, until } = config; + + if let Some(facts) = until { + if self.check_facts(facts).is_ok() { + log::info!( + "Breaking early because of facts:\n {}!", + ListDisplay(facts, "\n") + ); + return; } + } - let subreport = self.step_rules(i, *ruleset); - report = report.union(&subreport); - - let rebuild_start = Instant::now(); - let updates = self.rebuild_nofail(); - log::debug!("database size: {}", self.num_tuples()); - log::debug!("Made {updates} updates (iteration {i})"); - report.rebuild_time += rebuild_start.elapsed(); - self.timestamp += 1; - if !subreport.updated { - log::info!("Breaking early at iteration {}!", i); - break; - } + let subreport = self.step_rules(*ruleset); + *report = report.union(&subreport); - if self.num_tuples() > self.node_limit { - log::warn!( - "Node limit reached at iteration {}, {} nodes. Stopping!", - i, - self.num_tuples() - ); - break; - } + log::debug!("database size: {}", self.num_tuples()); + self.timestamp += 1; + + if self.num_tuples() > self.node_limit { + log::warn!("Node limit reached, {} nodes. Stopping!", self.num_tuples()); } + } + + pub fn run_rules(&mut self, config: &NormRunConfig) -> RunReport { + let mut report: RunReport = Default::default(); + + self.run_rules_once(config, &mut report); // Report the worst offenders log::debug!("Slowest rules:\n{}", { @@ -573,7 +644,15 @@ impl EGraph { report } - fn step_rules(&mut self, iteration: usize, ruleset: Symbol) -> RunReport { + fn step_rules(&mut self, ruleset: Symbol) -> RunReport { + let n_unions_before = self.unionfind.n_unions(); + // don't ban parent or rebuilding + let match_limit = + if ruleset.as_str().contains("parent_") || ruleset.as_str().contains("rebuilding_") { + usize::MAX + } else { + self.match_limit + }; let mut report = RunReport::default(); let ban_length = 5; @@ -583,6 +662,8 @@ impl EGraph { } let mut rules: HashMap = std::mem::take(self.rulesets.get_mut(&ruleset).unwrap()); + let iteration = *self.ruleset_iteration.entry(ruleset).or_default(); + self.ruleset_iteration.insert(ruleset, iteration + 1); // TODO why did I have to copy the rules here for the first for loop? let copy_rules = rules.clone(); let search_start = Instant::now(); @@ -590,7 +671,7 @@ impl EGraph { for (name, rule) in copy_rules.iter() { let mut all_values = vec![]; if rule.banned_until <= iteration { - let mut fuel = safe_shl(self.match_limit, rule.times_banned); + let mut fuel = safe_shl(match_limit, rule.times_banned); let rule_search_start = Instant::now(); self.run_query(&rule.query, rule.todo_timestamp, |values| { assert_eq!(values.len(), rule.query.vars.len()); @@ -608,10 +689,7 @@ impl EGraph { rule_search_time.as_secs_f64(), all_values.len() ); - report.updated |= !all_values.is_empty(); searched.push((name, all_values, rule_search_time)); - } else { - report.updated = true; } } @@ -628,7 +706,7 @@ impl EGraph { if num_vars != 0 { // backoff logic let len = all_values.len() / num_vars; - let threshold = safe_shl(self.match_limit, rule.times_banned); + let threshold = safe_shl(match_limit, rule.times_banned); if len > threshold { let ban_length = safe_shl(ban_length, rule.times_banned); rule.times_banned = rule.times_banned.saturating_add(1); @@ -663,9 +741,21 @@ impl EGraph { self.rulesets.insert(ruleset, rules); let apply_elapsed = apply_start.elapsed(); report.apply_time += apply_elapsed; + report.updated |= self.did_change_tables() || n_unions_before != self.unionfind.n_unions(); + report } + fn did_change_tables(&self) -> bool { + for (_name, function) in &self.functions { + if function.nodes.max_ts() >= self.timestamp { + return true; + } + } + + false + } + fn add_rule_with_name( &mut self, name: String, @@ -747,7 +837,7 @@ impl EGraph { pub fn set_option(&mut self, name: &str, value: Expr) { match name { "enable_proofs" => { - panic!("enable_proofs must be set as the first line of the file"); + self.proofs_enabled = true; } "interactive_mode" => { if let Expr::Lit(Literal::Int(i)) = value { @@ -809,6 +899,11 @@ impl EGraph { pre_rebuild.elapsed().as_millis() ); } + + self.debug_assert_invariants(); + + self.extract_report = None; + self.run_report = None; let res = Ok(match command { NCommand::SetOption { name, value } => { let str = format!("Set option {} to {}", name, value); @@ -818,7 +913,7 @@ impl EGraph { // Sorts are already declared during typechecking NCommand::Sort(name, _presort_and_args) => format!("Declared sort {}.", name), NCommand::Function(fdecl) => { - self.declare_function(&fdecl, false)?; + self.declare_function(&fdecl)?; format!("Declared function {}.", fdecl.name) } NCommand::AddRuleset(name) => { @@ -841,23 +936,6 @@ impl EGraph { "Skipping schedule.".to_string() } } - NCommand::Extract { var, variants } => { - let expr = Expr::Var(var); - if should_run { - // TODO typecheck - let report = self.extract_expr(expr, variants)?; - let mut msg = format!("Extracted with cost {}: {}", report.cost, report.expr); - if variants > 0 { - let line = "\n "; - let v_exprs = ListDisplay(&report.variants, line); - write!(msg, "\nVariants of {}:{line}{v_exprs}", report.expr).unwrap(); - } - self.extract_report = Some(report); - msg - } else { - "Skipping extraction.".into() - } - } NCommand::Check(facts) => { if should_run { self.check_facts(&facts)?; @@ -866,28 +944,36 @@ impl EGraph { "Skipping check.".into() } } - NCommand::Simplify { var, config } => { - if should_run { - let report = self.simplify(Expr::Var(var), &config)?; - let res = format!("Simplified with cost {} to {}", report.cost, report.expr); - self.extract_report = Some(report); - res - } else { - "Skipping simplify.".into() - } - } + NCommand::CheckProof => "TODO implement proofs".into(), NCommand::NormAction(action) => { if should_run { match &action { NormAction::Let(name, contents) => { - // define with high cost - self.define(*name, &contents.to_expr(), Some(HIGH_COST))?; + let (etype, value) = self.eval_expr(&contents.to_expr(), None, true)?; + let present = self + .global_bindings + .insert(*name, (etype, value, self.timestamp)); + if present.is_some() { + panic!("Variable {name} was already present in global bindings"); + } } NormAction::LetVar(var1, var2) => { - self.define(*var1, &Expr::Var(*var2), Some(HIGH_COST))?; + let value = self.global_bindings.get(var2).unwrap(); + let present = self.global_bindings.insert(*var1, value.clone()); + if present.is_some() { + panic!("Variable {var1} was already present in global bindings"); + } } NormAction::LetLit(var, lit) => { - self.define(*var, &Expr::Lit(lit.clone()), Some(HIGH_COST))?; + let value = self.eval_lit(lit); + let etype = self.proof_state.type_info.infer_literal(lit); + let present = self + .global_bindings + .insert(*var, (etype, value, self.timestamp)); + + if present.is_some() { + panic!("Variable {var} was already present in global bindings"); + } } _ => { self.eval_actions(std::slice::from_ref(&action.to_action()))?; @@ -908,7 +994,7 @@ impl EGraph { } format!("Popped {n} levels.") } - NCommand::Print(f, n) => { + NCommand::PrintTable(f, n) => { let msg = self.print_function(f, n)?; println!("{}", msg); msg @@ -919,10 +1005,13 @@ impl EGraph { msg } NCommand::Fail(c) => { - if self.run_command(*c, should_run).is_ok() { + let result = self.run_command(*c, should_run); + if let Err(e) = result { + eprintln!("Expect failure: {}", e); + "Command failed as expected".into() + } else { return Err(Error::ExpectFail); } - "Command failed as expected.".into() } NCommand::Input { name, file } => { let func = self.functions.get_mut(&name).unwrap(); @@ -991,7 +1080,8 @@ impl EGraph { for expr in exprs { use std::io::Write; let res = self.extract_expr(expr, 1)?; - writeln!(f, "{}", res.expr).map_err(|e| Error::IoError(filename.clone(), e))?; + writeln!(f, "{}", res.termdag.to_string(&res.expr)) + .map_err(|e| Error::IoError(filename.clone(), e))?; } format!("Output to '{filename:?}'.") @@ -1007,29 +1097,18 @@ impl EGraph { } } - fn simplify(&mut self, expr: Expr, config: &NormRunConfig) -> Result { - self.push(); - let (t, value) = self.eval_expr(&expr, None, true).unwrap(); - self.run_report = Some(self.run_rules(config)); - let (cost, expr) = self.extract(value, &t); - self.pop().unwrap(); - Ok(ExtractReport { - cost, - expr, - variants: vec![], - }) - } // Extract an expression from the current state, returning the cost, the extracted expression and some number // of other variants, if variants is not zero. - pub fn extract_expr(&mut self, e: Expr, variants: usize) -> Result { + pub fn extract_expr(&mut self, e: Expr, num_variants: usize) -> Result { let (t, value) = self.eval_expr(&e, None, true)?; - let (cost, expr) = self.extract(value, &t); - let exprs = match variants { + let mut termdag = TermDag::default(); + let (cost, expr) = self.extract(value, &mut termdag, &t); + let variants = match num_variants { 0 => vec![], 1 => vec![expr.clone()], _ => { if self.get_sort(&value).is_some_and(|sort| sort.is_eq_sort()) { - self.extract_variants(value, variants) + self.extract_variants(value, num_variants, &mut termdag) } else { vec![expr.clone()] } @@ -1038,77 +1117,17 @@ impl EGraph { Ok(ExtractReport { cost, expr, - variants: exprs, + variants, + termdag, }) } - pub fn declare_const(&mut self, name: Symbol, sort: &ArcSort) -> Result<(), Error> { - assert!(sort.is_eq_sort()); - self.declare_function( - &FunctionDecl { - name, - schema: Schema { - input: vec![], - output: sort.name(), - }, - default: None, - merge: None, - merge_action: vec![], - cost: None, - }, - true, - )?; - let f = self.functions.get_mut(&name).unwrap(); - let id = self.unionfind.make_set(); - let value = Value::from_id(sort.name(), id); - f.insert(&[], value, self.timestamp); - Ok(()) - } - pub fn define( - &mut self, - name: Symbol, - expr: &Expr, - cost: Option, - ) -> Result { - let (sort, value) = self.eval_expr(expr, None, true)?; - self.declare_function( - &FunctionDecl { - name, - schema: Schema { - input: vec![], - output: value.tag, - }, - default: None, - merge: None, - merge_action: vec![], - cost, - }, - true, - )?; - let f = self.functions.get_mut(&name).unwrap(); - f.insert(&[], value, self.timestamp); - Ok(sort) - } - - // process the commands but don't run them pub fn process_commands( &mut self, - mut program: Vec, + program: Vec, + stop: CompilerPassStop, ) -> Result, Error> { let mut result = vec![]; - if let Some(Command::SetOption { - name, - value: Expr::Lit(Literal::Int(1)), - }) = program.first() - { - if name == &"enable_proofs".into() { - program = program.split_off(1); - for step in self.proof_state.proof_header() { - result.extend(self.process_command(step)?); - } - self.proofs_enabled = true; - } - } for command in program { match command { @@ -1125,7 +1144,7 @@ impl EGraph { } _ => {} } - result.extend(self.process_command(command)?); + result.extend(self.process_command(command, stop)?); } Ok(result) } @@ -1134,71 +1153,45 @@ impl EGraph { self.proof_state.desugar.number_underscores = underscores; } - fn process_command(&mut self, command: Command) -> Result, Error> { - let program_desugared = self.proof_state.desugar.desugar_program( + fn process_command( + &mut self, + command: Command, + stop: CompilerPassStop, + ) -> Result, Error> { + let program = self.proof_state.desugar.desugar_program( vec![command], self.test_proofs, self.seminaive, )?; + if stop == CompilerPassStop::Desugar { + return Ok(program); + } let type_info_before = self.proof_state.type_info.clone(); - self.proof_state - .type_info - .typecheck_program(&program_desugared)?; - - let program = if self.proofs_enabled { - // proofs require type info, so - // we need to pass in the desugar - let proofs = self.proof_state.add_proofs(program_desugared); - - let final_desugared = - self.proof_state - .desugar - .desugar_program(proofs, false, self.seminaive)?; - - // revert back to the type info before - // proofs were added, typecheck again - self.proof_state.type_info = type_info_before; - self.proof_state - .type_info - .typecheck_program(&final_desugared)?; - final_desugared - } else { - program_desugared - }; - Ok(program) - } + self.proof_state.type_info.typecheck_program(&program)?; + if stop == CompilerPassStop::TypecheckDesugared { + return Ok(program); + } - pub fn enable_proofs(&mut self) { - let proofs_already_enabled = self.proofs_enabled; - self.proofs_enabled = true; - if !proofs_already_enabled && self.proofs_enabled { - self.proofs_enabled = false; - self.run_program(self.proof_state.proof_header()).unwrap(); - self.proofs_enabled = true; + // reset type info + self.proof_state.type_info = type_info_before; + self.proof_state.type_info.typecheck_program(&program)?; + if stop == CompilerPassStop::TypecheckTermEncoding { + return Ok(program); } + + Ok(program) } - pub fn run_program(&mut self, mut program: Vec) -> Result, Error> { + pub fn run_program(&mut self, program: Vec) -> Result, Error> { let mut msgs = vec![]; let should_run = true; - if let Some(Command::SetOption { - name, - value: Expr::Lit(Literal::Int(1)), - }) = program.first() - { - if name == &"enable_proofs".into() { - self.enable_proofs(); - program = program.split_off(1); - } - } - for command in program { // Important to process each command individually // because push and pop create new scopes - for processed in self.process_command(command)? { + for processed in self.process_command(command, CompilerPassStop::All)? { let msg = self.run_command(processed.command, should_run)?; if !msg.is_empty() { log::info!("{}", msg); @@ -1214,7 +1207,6 @@ impl EGraph { } // this is bad because we shouldn't inspect values like this, we should use type information - #[cfg(debug_assertions)] fn bad_find_value(&self, value: Value) -> Value { if let Some((tag, id)) = self.value_to_id(value) { Value::from_id(tag, self.find(id)) @@ -1224,11 +1216,11 @@ impl EGraph { } pub fn parse_program(&self, input: &str) -> Result, Error> { - self.proof_state.parse_program(input) + self.proof_state.desugar.parse_program(input) } pub fn parse_and_run_program(&mut self, input: &str) -> Result, Error> { - let parsed = self.proof_state.parse_program(input)?; + let parsed = self.proof_state.desugar.parse_program(input)?; self.run_program(parsed) } diff --git a/src/main.rs b/src/main.rs index 429a2b07..01c28154 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,5 @@ use clap::Parser; -use egglog::{EGraph, SerializeConfig}; +use egglog::{CompilerPassStop, EGraph, SerializeConfig}; use std::io::{self, BufRead, BufReader}; use std::path::PathBuf; @@ -9,9 +9,18 @@ struct Args { fact_directory: Option, #[clap(long)] naive: bool, - inputs: Vec, + #[clap(long)] + desugar: bool, + #[clap(long)] + resugar: bool, #[clap(long)] proofs: bool, + #[clap(long, default_value_t = CompilerPassStop::All)] + stop: CompilerPassStop, + // TODO remove this evil hack + #[clap(long, default_value_t = 3)] + num_underscores: usize, + inputs: Vec, #[clap(long)] to_json: bool, } @@ -28,10 +37,13 @@ fn main() { let mk_egraph = || { let mut egraph = EGraph::default(); + egraph.set_underscores_for_desugaring(args.num_underscores); egraph.fact_directory = args.fact_directory.clone(); egraph.seminaive = !args.naive; if args.proofs { - egraph.enable_proofs(); + egraph + .parse_and_run_program("(set-option enable_proofs 1)") + .unwrap(); } egraph }; @@ -64,16 +76,41 @@ fn main() { } for (idx, input) in args.inputs.iter().enumerate() { - let s = std::fs::read_to_string(input).unwrap_or_else(|_| { + let program_read = std::fs::read_to_string(input).unwrap_or_else(|_| { let arg = input.to_string_lossy(); panic!("Failed to read file {arg}") }); let mut egraph = mk_egraph(); - match egraph.parse_and_run_program(&s) { - Ok(_msgs) => {} - Err(err) => { - log::error!("{}", err); - std::process::exit(1) + let already_enables = program_read.starts_with("(set-option enable_proofs 1)"); + let program = if args.proofs && !already_enables { + format!("(set-option enable_proofs 1)\n{}", program_read) + } else { + program_read + }; + + if args.desugar || args.resugar { + let parsed = egraph.parse_program(&program).unwrap(); + let desugared_str = egraph + .process_commands(parsed, args.stop) + .unwrap() + .into_iter() + .map(|x| { + if args.resugar { + x.resugar().to_string() + } else { + x.to_string() + } + }) + .collect::>() + .join("\n"); + println!("{}", desugared_str); + } else { + match egraph.parse_and_run_program(&program) { + Ok(_msgs) => {} + Err(err) => { + log::error!("{}", err); + std::process::exit(1) + } } } diff --git a/src/proofheader.egg b/src/proofheader.egg index 968cbf51..6c0b8842 100644 --- a/src/proofheader.egg +++ b/src/proofheader.egg @@ -156,4 +156,3 @@ :ruleset proof-extract__) - diff --git a/src/proofs.rs b/src/proofs.rs index ff661c15..beb2b13a 100644 --- a/src/proofs.rs +++ b/src/proofs.rs @@ -1,900 +1,11 @@ use crate::*; use crate::ast::desugar::Desugar; -use crate::typechecking::FuncType; - -use symbolic_expressions::Sexp; pub const RULE_PROOF_KEYWORD: &str = "rule-proof"; -// primitives don't need type info -fn make_ast_version_prim(name: Symbol) -> Symbol { - Symbol::from(format!("Ast{}__", name)) -} - -fn make_ast_version(proof_state: &mut ProofState, expr: &NormExpr) -> Symbol { - let NormExpr::Call(name, _) = expr; - let types = proof_state - .type_info - .typecheck_expr(proof_state.current_ctx, expr, true) - .unwrap(); - Symbol::from(format!( - "Ast{}_{}__", - name, - ListDisplay(types.input.iter().map(|sort| sort.name()), "_"), - )) -} - -fn make_rep_version(proof_state: &mut ProofState, expr: &NormExpr) -> Symbol { - let NormExpr::Call(name, _) = expr; - let types = proof_state - .type_info - .typecheck_expr(proof_state.current_ctx, expr, true) - .unwrap(); - Symbol::from(format!( - "Rep{}_{}__", - name, - ListDisplay(types.input.iter().map(|sort| sort.name()), "_"), - )) -} - -// representatives for primitive values -fn make_rep_version_prim(name: &Symbol) -> Symbol { - Symbol::from(format!("Rep{}__", name)) -} - -fn setup_primitives() -> Vec { - let mut commands = vec![]; - let fresh_types = TypeInfo::default(); - commands.extend(make_ast_primitives_sorts(&fresh_types)); - commands.extend(make_rep_primitive_sorts(&fresh_types)); - commands -} - -fn make_rep_primitive_sorts(type_info: &TypeInfo) -> Vec { - type_info - .sorts - .iter() - .map(|(name, _)| { - Command::Function(FunctionDecl { - name: make_rep_version_prim(name), - schema: Schema { - input: vec![*name], - output: "TrmPrf__".into(), - }, - // Right now we just union every proof of some primitive. - merge: None, - merge_action: vec![], - default: None, - cost: None, - }) - }) - .collect() -} - -fn make_ast_primitives_sorts(type_info: &TypeInfo) -> Vec { - type_info - .sorts - .iter() - .map(|(name, _)| { - Command::Function(FunctionDecl { - name: make_ast_version_prim(*name), - schema: Schema { - input: vec![*name], - output: "Ast__".into(), - }, - merge: None, - merge_action: vec![], - default: None, - cost: None, - }) - }) - .collect() -} -fn make_ast_function(proof_state: &mut ProofState, expr: &NormExpr) -> FunctionDecl { - let NormExpr::Call(_head, body) = expr; - FunctionDecl { - name: make_ast_version(proof_state, expr), - schema: Schema { - input: body.iter().map(|_sort| "Ast__".into()).collect(), - output: "Ast__".into(), - }, - merge: None, - merge_action: vec![], - default: None, - cost: None, - } -} - -fn merge_action(proof_state: &mut ProofState, types: FuncType) -> Vec { - let child1 = |i| Symbol::from(format!("c1_{}__", i)); - let child2 = |i| Symbol::from(format!("c2_{}__", i)); - - let mut congr_prf = Sexp::String("Null__".to_string()); - for i in 0..types.input.len() { - let current = types.input.len() - i - 1; - congr_prf = Sexp::List(vec![ - Sexp::String("Cons__".to_string()), - Sexp::List(vec![ - Sexp::String("DemandEq__".to_string()), - Sexp::String(child1(current).to_string()), - Sexp::String(child2(current).to_string()), - ]), - congr_prf, - ]); - } - let t1 = proof_state.get_fresh(); - let t2 = proof_state.get_fresh(); - let p1 = proof_state.get_fresh(); - - vec![ - format!("(let {t1} (TrmOf__ old))"), - format!("(let {t2} (TrmOf__ new))"), - format!("(let {p1} (PrfOf__ old))"), - ] - .into_iter() - .chain(types.input.iter().enumerate().flat_map(|(i, _sort)| { - vec![ - format!("(let {} (GetChild__ {t1} {}))", child1(i), i), - format!("(let {} (GetChild__ {t2} {}))", child2(i), i), - ] - })) - .chain(vec![ - format!("(let congr_prf__ (Congruence__ {p1} {}))", congr_prf), - format!("(let age__ (currentAge))"), - format!("(set (currentAge) (+ age__ 1))"), - format!("(set (EqGraph__ {t1} {t2}) (MakeProofWithAge__ congr_prf__ age__))"), - format!("(set (EqGraph__ {t2} {t1}) (MakeProofWithAge__ (Flip__ congr_prf__) age__))"), - ]) - .map(|s| proof_state.desugar.action_parser.parse(&s).unwrap()) - .collect() -} - -#[derive(Clone, Debug)] -pub(crate) struct ProofInfo { - // proof for each variable bound in an assignment (lhs or rhs) - pub var_term: HashMap, - // proofs for each variable - pub var_proof: HashMap, - pub rule_proof: Option, - pub rule_proof_ast: Option, -} - -// This function makes use of the property that the body is Norm -// variables appear at most once (including the rhs of assignments) -// besides when they appear in constraints -fn instrument_facts( - body: &Vec, - proof_state: &mut ProofState, - actions: &mut Vec, -) -> ProofInfo { - let mut info: ProofInfo = ProofInfo { - var_term: Default::default(), - var_proof: Default::default(), - rule_proof: None, - rule_proof_ast: None, - }; - - for fact in body { - match fact { - NormFact::AssignLit(lhs, rhs) => { - let literal_name = proof_state.literal_name(rhs); - let rep_trm = proof_state.get_fresh(); - let rep_prf = proof_state.get_fresh(); - actions.push(NormAction::Let( - rep_trm, - NormExpr::Call(make_ast_version_prim(literal_name), vec![*lhs]), - )); - actions.push(NormAction::Let( - rep_prf, - NormExpr::Call("ComputePrim__".into(), vec![rep_trm]), - )); - - info.var_term.insert(*lhs, rep_trm); - assert!(info.var_proof.insert(*lhs, rep_prf).is_none()); - } - NormFact::Assign(lhs, NormExpr::Call(head, body)) - if proof_state.type_info.is_primitive(*head) => - { - // child terms should already exist if we are computing something - let rep_trm = proof_state.get_fresh(); - actions.push(NormAction::Let( - rep_trm, - NormExpr::Call( - make_ast_version(proof_state, &NormExpr::Call(*head, body.clone())), - body.iter() - .map(|v| get_var_term(*v, proof_state, &info)) - .collect(), - ), - )); - - let rep_prf = proof_state.get_fresh(); - - actions.push(NormAction::Let( - rep_prf, - NormExpr::Call("ComputePrim__".into(), vec![rep_trm]), - )); - info.var_term.insert(*lhs, rep_trm); - info.var_proof.insert(*lhs, rep_prf); - } - NormFact::Assign(lhs, NormExpr::Call(head, body)) => { - let rep = proof_state.get_fresh(); - let rep_trm = proof_state.get_fresh(); - let rep_prf = proof_state.get_fresh(); - - actions.push(NormAction::Let( - rep, - NormExpr::Call( - make_rep_version(proof_state, &NormExpr::Call(*head, body.clone())), - body.clone(), - ), - )); - actions.push(NormAction::Let( - rep_trm, - NormExpr::Call("TrmOf__".into(), vec![rep]), - )); - actions.push(NormAction::Let( - rep_prf, - NormExpr::Call("PrfOf__".into(), vec![rep]), - )); - - info.var_term.insert(*lhs, rep_trm); - assert!(info.var_proof.insert(*lhs, rep_prf).is_none()); - - for (i, child) in body.iter().enumerate() { - let child_trm = proof_state.get_fresh(); - let const_var = proof_state.get_fresh(); - actions.push(NormAction::LetLit(const_var, Literal::Int(i as i64))); - actions.push(NormAction::Let( - child_trm, - NormExpr::Call("GetChild__".into(), vec![rep_trm, const_var]), - )); - info.var_term.insert(*child, child_trm); - } - } - NormFact::ConstrainEq(lhs, rhs) => { - // variables need to have an ast so they can be computed on - // but if they are used in a binding instead of a computation, this proof and term is overridden in the hashmap - if let Some(term) = get_var_term_option(*rhs, proof_state, &info) { - if get_var_term_option(*lhs, proof_state, &info).is_none() { - assert!(info.var_term.insert(*lhs, term).is_none()); - } - } else if let Some(term) = get_var_term_option(*lhs, proof_state, &info) { - if get_var_term_option(*rhs, proof_state, &info).is_none() { - assert!(info.var_term.insert(*rhs, term).is_none()); - } - } else { - panic!( - "Contraint without representative term for at least one side {} = {}", - lhs, rhs - ); - } - } - } - } - - // now fill in representitive terms for any aliases - for fact in body { - if let NormFact::ConstrainEq(lhs, rhs) = fact { - let lhsterm = get_var_term_option(*lhs, proof_state, &info); - let rhsterm = get_var_term_option(*rhs, proof_state, &info); - if let Some(rep_term) = lhsterm { - if rhsterm.is_none() { - info.var_term.insert(*rhs, rep_term); - } - } else if let Some(rep_term) = rhsterm { - info.var_term.insert(*lhs, rep_term); - } else { - panic!( - "Contraint without representative term for at least one side {} = {}", - lhs, rhs - ); - } - } - } - - info -} - -fn get_var_term_option( - var: Symbol, - proof_state: &ProofState, - proof_info: &ProofInfo, -) -> Option { - if var == RULE_PROOF_KEYWORD.into() { - return Some(proof_info.rule_proof_ast.unwrap()); - } - proof_info - .var_term - .get(&var) - .or_else(|| proof_state.global_var_ast.get(&var)) - .cloned() -} - -fn get_var_term(var: Symbol, proof_state: &ProofState, proof_info: &ProofInfo) -> Symbol { - get_var_term_option(var, proof_state, proof_info).unwrap() -} - -fn add_eqgraph_equality( - proof_state: &mut ProofState, - proof_info: &mut ProofInfo, - astvar1: Symbol, - astvar2: Symbol, - res: &mut Vec, -) { - let rule_proof = proof_info.rule_proof.unwrap(); - let proof_with_age = proof_state.get_fresh(); - res.push(NormAction::Let( - proof_with_age, - NormExpr::Call( - "MakeProofWithAge__".into(), - vec![rule_proof, "age__".into()], - ), - )); - res.push(NormAction::Set( - NormExpr::Call("EqGraph__".into(), vec![astvar1, astvar2]), - proof_with_age, - )); - res.push(NormAction::Set( - NormExpr::Call("EqGraph__".into(), vec![astvar2, astvar1]), - proof_with_age, - )); -} - -fn make_expr_ast( - proof_state: &mut ProofState, - proof_info: &ProofInfo, - expr: &NormExpr, - res: &mut Vec, -) -> Symbol { - let NormExpr::Call(head, body) = expr; - let newterm = proof_state.get_fresh(); - // make the term for this variable - res.push(NormAction::Let( - newterm, - NormExpr::Call( - make_ast_version(proof_state, &NormExpr::Call(*head, body.clone())), - body.iter() - .map(|v| get_var_term(*v, proof_state, proof_info)) - .collect(), - ), - )); - - newterm -} - -fn make_expr_rep( - proof_state: &mut ProofState, - proof_info: &ProofInfo, - expr: &NormExpr, - res: &mut Vec, -) -> Symbol { - let NormExpr::Call(head, body) = expr; - let newterm = make_expr_ast(proof_state, proof_info, expr, res); - - let ruletrm = proof_state.get_fresh(); - res.push(NormAction::Let( - ruletrm, - NormExpr::Call( - "RuleTerm__".into(), - vec![proof_info.rule_proof.unwrap(), newterm], - ), - )); - - let trmprf = proof_state.get_fresh(); - res.push(NormAction::Let( - trmprf, - NormExpr::Call("MakeTrmPrf__".into(), vec![newterm, ruletrm]), - )); - - res.push(NormAction::Set( - NormExpr::Call( - make_rep_version(proof_state, &NormExpr::Call(*head, body.clone())), - body.clone(), - ), - trmprf, - )); - newterm -} - -fn add_action_proof( - proof_info: &mut ProofInfo, - action: &NormAction, - res: &mut Vec, - proof_state: &mut ProofState, -) { - match action { - NormAction::LetVar(var1, var2) => { - // update var1's term - proof_info - .var_term - .insert(*var1, get_var_term(*var2, proof_state, proof_info)); - } - NormAction::Delete(..) | NormAction::Panic(..) => (), - NormAction::Union(var1, var2) => { - add_eqgraph_equality( - proof_state, - proof_info, - get_var_term(*var1, proof_state, proof_info), - get_var_term(*var2, proof_state, proof_info), - res, - ); - } - NormAction::Set(expr, rhs) => { - let new_term = make_expr_rep(proof_state, proof_info, expr, res); - // add to the equality graph when we set things equal to each other - add_eqgraph_equality( - proof_state, - proof_info, - new_term, - get_var_term(*rhs, proof_state, proof_info), - res, - ) - } - NormAction::Let(lhs, expr) => { - let ast = make_expr_rep(proof_state, proof_info, expr, res); - proof_info.var_term.insert(*lhs, ast); - } - // very similar to let case - NormAction::LetLit(lhs, lit) => { - let newterm = proof_state.get_fresh(); - // make the term for this variable - res.push(NormAction::Let( - newterm, - NormExpr::Call( - make_ast_version_prim(proof_state.literal_name(lit)), - vec![*lhs], - ), - )); - proof_info.var_term.insert(*lhs, newterm); - - let ruletrm = proof_state.get_fresh(); - res.push(NormAction::Let( - ruletrm, - NormExpr::Call( - "RuleTerm__".into(), - vec![proof_info.rule_proof.unwrap(), newterm], - ), - )); - - let trmprf = proof_state.get_fresh(); - res.push(NormAction::Let( - trmprf, - NormExpr::Call("MakeTrmPrf__".into(), vec![newterm, ruletrm]), - )); - - res.push(NormAction::Set( - NormExpr::Call( - make_rep_version_prim(&proof_state.literal_name(lit)), - vec![*lhs], - ), - trmprf, - )); - } - } -} - -fn add_rule_proof( - rule_name: Symbol, - proof_info: &ProofInfo, - facts: &Vec, - res: &mut Vec, - proof_state: &mut ProofState, -) -> Symbol { - let mut current_proof = proof_state.get_fresh(); - res.push(NormAction::LetVar(current_proof, "Null__".into())); - - for fact in facts { - match fact { - NormFact::Assign(lhs, _rhs) => { - let fresh = proof_state.get_fresh(); - res.push(NormAction::Let( - fresh, - NormExpr::Call( - "Cons__".into(), - vec![proof_info.var_proof[lhs], current_proof], - ), - )); - current_proof = fresh; - } - // same as Assign case - NormFact::AssignLit(lhs, _rhs) => { - let fresh = proof_state.get_fresh(); - res.push(NormAction::Let( - fresh, - NormExpr::Call( - "Cons__".into(), - vec![proof_info.var_proof[lhs], current_proof], - ), - )); - current_proof = fresh; - } - NormFact::ConstrainEq(lhs, rhs) => { - let pfresh = proof_state.get_fresh(); - res.push(NormAction::Let( - pfresh, - NormExpr::Call( - "DemandEq__".into(), - vec![ - get_var_term(*lhs, proof_state, proof_info), - get_var_term(*rhs, proof_state, proof_info), - ], - ), - )); - - let fresh = proof_state.get_fresh(); - res.push(NormAction::Let( - fresh, - NormExpr::Call("Cons__".into(), vec![pfresh, current_proof]), - )); - current_proof = fresh; - } - } - } - - let name_const = proof_state.get_fresh(); - res.push(NormAction::LetLit(name_const, Literal::String(rule_name))); - let rule_proof = proof_state.get_fresh(); - res.push(NormAction::Let( - rule_proof, - NormExpr::Call("Rule__".into(), vec![current_proof, name_const]), - )); - rule_proof -} - -// replace the rule-proof keyword with the proof of the rule -fn replace_rule_proof(actions: &[NormAction], rule_proof: Symbol) -> Vec { - actions - .iter() - .map(|action| { - action.map_def_use(&mut |var, _isdef| { - if var == RULE_PROOF_KEYWORD.into() { - rule_proof - } else { - var - } - }) - }) - .collect() -} - -fn add_age_variable(proof_state: &mut ProofState) -> Vec { - let mut res = vec![]; - let age_var = Symbol::from("age__"); - res.push(NormAction::Let( - age_var, - NormExpr::Call("currentAge".into(), vec![]), - )); - let one = proof_state.get_fresh(); - res.push(NormAction::LetLit(one, Literal::Int(1))); - let next_age = proof_state.get_fresh(); - res.push(NormAction::Let( - next_age, - NormExpr::Call("+".into(), vec![age_var, one]), - )); - res.push(NormAction::Set( - NormExpr::Call("currentAge".into(), vec![]), - next_age, - )); - - res -} - -fn instrument_rule(rule: &NormRule, rule_name: Symbol, proof_state: &mut ProofState) -> Rule { - let mut actions = vec![]; - actions.extend(add_age_variable(proof_state)); - let info = instrument_facts(&rule.body, proof_state, &mut actions); - let rule_proof = add_rule_proof(rule_name, &info, &rule.body, &mut actions, proof_state); - - let rule_proof_ast = proof_state.get_fresh(); - actions.push(NormAction::Let( - rule_proof_ast, - NormExpr::Call("AstProof__".into(), vec![rule_proof]), - )); - - actions.extend(replace_rule_proof(&rule.head, rule_proof)); - - // make a new proofinfo with the rule_proof symbol added - let mut proof_info = ProofInfo { - var_term: info.var_term, - var_proof: info.var_proof, - rule_proof: Some(rule_proof), - rule_proof_ast: Some(rule_proof_ast), - }; - - for action in &rule.head { - add_action_proof(&mut proof_info, action, &mut actions, proof_state); - } - - NormRule { - head: actions, - body: rule.body.clone(), - } - .to_rule() -} - -fn make_rep_function(proof_state: &mut ProofState, expr: &NormExpr) -> FunctionDecl { - let types = proof_state - .type_info - .typecheck_expr(proof_state.current_ctx, expr, true) - .unwrap(); - FunctionDecl { - name: make_rep_version(proof_state, expr), - schema: Schema { - input: types.input.iter().map(|sort| sort.name()).collect(), - output: "TrmPrf__".into(), - }, - merge: Some(Expr::Var("old".into())), - // Merge action is only needed if the merge function is union - merge_action: if types.has_merge { - vec![] - } else { - merge_action(proof_state, types) - }, - default: None, - cost: None, - } -} - -fn make_getchild_rule(proof_state: &mut ProofState, expr: &NormExpr) -> Command { - let NormExpr::Call(_name, body) = expr; - let getchild = |i| Symbol::from(format!("c{}__", i)); - Command::Rule { - ruleset: "proofrules__".into(), - name: "".into(), - rule: Rule { - body: vec![Fact::Eq(vec![ - Expr::Var("ast__".into()), - Expr::Call( - make_ast_version(proof_state, expr), - body.iter() - .enumerate() - .map(|(i, _)| Expr::Var(getchild(i))) - .collect(), - ), - ])], - head: body - .iter() - .enumerate() - .map(|(i, _s)| { - Action::Set( - "GetChild__".into(), - vec![Expr::Var("ast__".into()), Expr::Lit(Literal::Int(i as i64))], - Expr::Var(getchild(i)), - ) - }) - .collect(), - }, - } -} - #[derive(Default, Clone)] pub(crate) struct ProofState { - pub(crate) global_var_ast: HashMap, - pub(crate) ast_funcs_created: HashSet, - pub(crate) current_ctx: CommandId, pub(crate) desugar: Desugar, pub(crate) type_info: TypeInfo, } - -fn make_rep_command(proof_state: &mut ProofState, lhs: Symbol, expr: &NormExpr) -> Vec { - let NormExpr::Call(head, body) = expr; - let ast_var = proof_state.get_fresh(); - let ast_action = format!( - "(let {} ({} {}))", - ast_var, - make_ast_version(proof_state, &NormExpr::Call(*head, body.clone())), - ListDisplay(body.iter().map(|e| { proof_state.global_var_ast[e] }), " ") - ); - proof_state.global_var_ast.insert(lhs, ast_var); - let rep = make_rep_version(proof_state, expr); - vec![ - Command::Action( - proof_state - .desugar - .action_parser - .parse(&ast_action) - .unwrap(), - ), - Command::Action( - proof_state - .desugar - .action_parser - .parse(&format!( - "(set ({} {}) - (MakeTrmPrf__ {} (Original__ {})))", - rep, - ListDisplay(body, " "), - ast_var, - ast_var - )) - .unwrap(), - ), - ] -} - -fn proof_original_action(action: &NormAction, proof_state: &mut ProofState) -> Vec { - match action { - NormAction::Let(lhs, expr) => make_rep_command(proof_state, *lhs, expr), - NormAction::LetVar(var1, var2) => { - proof_state - .global_var_ast - .insert(*var1, proof_state.global_var_ast[var2]); - vec![] - } - NormAction::LetLit(lhs, literal) => { - let ast_var = proof_state.get_fresh(); - proof_state.global_var_ast.insert(*lhs, ast_var); - vec![ - Command::Action( - proof_state - .desugar - .action_parser - .parse(&format!( - "(let {} ({} {}))", - ast_var, - make_ast_version_prim(proof_state.literal_name(literal)), - literal - )) - .unwrap(), - ), - Command::Action( - proof_state - .desugar - .action_parser - .parse(&format!( - "(set ({} {}) - (MakeTrmPrf__ {} (Original__ {})))", - make_rep_version_prim(&proof_state.literal_name(literal)), - literal, - ast_var, - ast_var - )) - .unwrap(), - ), - ] - } - NormAction::Set(expr, var) => { - let fresh = proof_state.get_fresh(); - let mut rep_commands = make_rep_command(proof_state, fresh, expr); - - rep_commands.push(Command::Action( - proof_state - .desugar - .action_parser - .parse(&format!( - "(set (EqGraph__ {} {}) (MakeProofWithAge__ (OriginalEq__ {} {}) 0))", - proof_state.global_var_ast[&fresh], - proof_state.global_var_ast[var], - proof_state.global_var_ast[&fresh], - proof_state.global_var_ast[var] - )) - .unwrap(), - )); - rep_commands - } - NormAction::Union(var1, var2) => { - vec![Command::Action( - proof_state - .desugar - .action_parser - .parse(&format!( - "(set (EqGraph__ {} {}) (MakeProofWithAge__ (OriginalEq__ {} {}) 0))", - proof_state.global_var_ast[var1], - proof_state.global_var_ast[var2], - proof_state.global_var_ast[var1], - proof_state.global_var_ast[var2] - )) - .unwrap(), - )] - } - NormAction::Delete(..) | NormAction::Panic(..) => vec![], - } -} - -fn instrument_schedule(schedule: &NormSchedule) -> Schedule { - match schedule { - NormSchedule::Saturate(schedule) => { - Schedule::Saturate(Box::new(instrument_schedule(schedule))) - } - NormSchedule::Repeat(times, schedule) => { - Schedule::Repeat(*times, Box::new(instrument_schedule(schedule))) - } - // We only do anything in the run case - NormSchedule::Run(run_config) => Schedule::Sequence(vec![ - Schedule::Saturate(Box::new(Schedule::Run(RunConfig { - ruleset: "proofrules__".into(), - until: None, - limit: 1, - }))), - Schedule::Run(run_config.to_run_config()), - ]), - NormSchedule::Sequence(schedules) => { - Schedule::Sequence(schedules.iter().map(instrument_schedule).collect()) - } - } -} - -impl ProofState { - pub fn parse_program(&self, input: &str) -> Result, Error> { - self.desugar.parse_program(input) - } - - // TODO we need to also instrument merge actions and merge because they can add new terms that need representatives - // the egraph is the initial egraph with only default sorts - pub(crate) fn add_proofs(&mut self, program: Vec) -> Vec { - let mut res = vec![]; - - for command in program { - self.current_ctx = command.metadata.id; - - // first, set up any rep functions that we need - command.command.map_exprs(&mut |expr| { - let ast_name = make_ast_version(self, expr); - if self.ast_funcs_created.insert(ast_name) { - let commands = vec![ - Command::Function(make_ast_function(self, expr)), - Command::Function(make_rep_function(self, expr)), - make_getchild_rule(self, expr), - ]; - res.extend(commands); - } - expr.clone() - }); - - match &command.command { - NCommand::Push(_num) => { - res.push(command.to_command()); - } - NCommand::Sort(_name, _presort_and_args) => { - res.push(command.to_command()); - } - NCommand::Function(_fdecl) => { - res.push(command.to_command()); - } - NCommand::NormRule { - ruleset, - name, - rule, - } => { - res.push(Command::Rule { - ruleset: *ruleset, - name: *name, - rule: instrument_rule(rule, *name, self), - }); - } - NCommand::NormAction(action) => { - res.push(Command::Action(action.to_action())); - res.extend(proof_original_action(action, self)); - } - NCommand::Check(_facts) => { - res.push(command.to_command()); - } - NCommand::RunSchedule(schedule) => { - res.push(Command::RunSchedule(instrument_schedule(schedule))); - } - _ => res.push(command.to_command()), - } - } - - res - } - - pub(crate) fn get_fresh(&mut self) -> Symbol { - self.desugar.get_fresh() - } - - pub(crate) fn proof_header(&self) -> Vec { - let str = include_str!("proofheader.egg"); - let rest_of_header = setup_primitives(); - self.parse_program(str) - .unwrap() - .into_iter() - .chain(rest_of_header) - .collect() - } - - pub(crate) fn literal_name(&self, lit: &Literal) -> Symbol { - self.type_info.infer_literal(lit).name() - } -} diff --git a/src/sort/i64.rs b/src/sort/i64.rs index 314bbd03..9754f50b 100644 --- a/src/sort/i64.rs +++ b/src/sort/i64.rs @@ -26,31 +26,36 @@ impl Sort for I64Sort { // We need the closure for division and mod operations, as they can panic. // cf https://github.com/rust-lang/rust-clippy/issues/9422 #[allow(clippy::unnecessary_lazy_evaluations)] - fn register_primitives(self: Arc, eg: &mut TypeInfo) { + fn register_primitives(self: Arc, typeinfo: &mut TypeInfo) { + typeinfo.add_primitive(TermOrderingMin { + }); + typeinfo.add_primitive(TermOrderingMax { + }); + type Opt = Option; - add_primitives!(eg, "+" = |a: i64, b: i64| -> i64 { a + b }); - add_primitives!(eg, "-" = |a: i64, b: i64| -> i64 { a - b }); - add_primitives!(eg, "*" = |a: i64, b: i64| -> i64 { a * b }); - add_primitives!(eg, "/" = |a: i64, b: i64| -> Opt { (b != 0).then(|| a / b) }); - add_primitives!(eg, "%" = |a: i64, b: i64| -> Opt { (b != 0).then(|| a % b) }); + add_primitives!(typeinfo, "+" = |a: i64, b: i64| -> i64 { a + b }); + add_primitives!(typeinfo, "-" = |a: i64, b: i64| -> i64 { a - b }); + add_primitives!(typeinfo, "*" = |a: i64, b: i64| -> i64 { a * b }); + add_primitives!(typeinfo, "/" = |a: i64, b: i64| -> Opt { (b != 0).then(|| a / b) }); + add_primitives!(typeinfo, "%" = |a: i64, b: i64| -> Opt { (b != 0).then(|| a % b) }); - add_primitives!(eg, "&" = |a: i64, b: i64| -> i64 { a & b }); - add_primitives!(eg, "|" = |a: i64, b: i64| -> i64 { a | b }); - add_primitives!(eg, "^" = |a: i64, b: i64| -> i64 { a ^ b }); - add_primitives!(eg, "<<" = |a: i64, b: i64| -> Opt { b.try_into().ok().and_then(|b| a.checked_shl(b)) }); - add_primitives!(eg, ">>" = |a: i64, b: i64| -> Opt { b.try_into().ok().and_then(|b| a.checked_shr(b)) }); - add_primitives!(eg, "not-i64" = |a: i64| -> i64 { !a }); + add_primitives!(typeinfo, "&" = |a: i64, b: i64| -> i64 { a & b }); + add_primitives!(typeinfo, "|" = |a: i64, b: i64| -> i64 { a | b }); + add_primitives!(typeinfo, "^" = |a: i64, b: i64| -> i64 { a ^ b }); + add_primitives!(typeinfo, "<<" = |a: i64, b: i64| -> Opt { b.try_into().ok().and_then(|b| a.checked_shl(b)) }); + add_primitives!(typeinfo, ">>" = |a: i64, b: i64| -> Opt { b.try_into().ok().and_then(|b| a.checked_shr(b)) }); + add_primitives!(typeinfo, "not-i64" = |a: i64| -> i64 { !a }); - add_primitives!(eg, "log2" = |a: i64| -> i64 { (a as i64).ilog2() as i64 }); + add_primitives!(typeinfo, "log2" = |a: i64| -> i64 { (a as i64).ilog2() as i64 }); - add_primitives!(eg, "<" = |a: i64, b: i64| -> Opt { (a < b).then(|| ()) }); - add_primitives!(eg, ">" = |a: i64, b: i64| -> Opt { (a > b).then(|| ()) }); - add_primitives!(eg, "<=" = |a: i64, b: i64| -> Opt { (a <= b).then(|| ()) }); - add_primitives!(eg, ">=" = |a: i64, b: i64| -> Opt { (a >= b).then(|| ()) }); + add_primitives!(typeinfo, "<" = |a: i64, b: i64| -> Opt { (a < b).then(|| ()) }); + add_primitives!(typeinfo, ">" = |a: i64, b: i64| -> Opt { (a > b).then(|| ()) }); + add_primitives!(typeinfo, "<=" = |a: i64, b: i64| -> Opt { (a <= b).then(|| ()) }); + add_primitives!(typeinfo, ">=" = |a: i64, b: i64| -> Opt { (a >= b).then(|| ()) }); - add_primitives!(eg, "min" = |a: i64, b: i64| -> i64 { a.min(b) }); - add_primitives!(eg, "max" = |a: i64, b: i64| -> i64 { a.max(b) }); + add_primitives!(typeinfo, "min" = |a: i64, b: i64| -> i64 { a.min(b) }); + add_primitives!(typeinfo, "max" = |a: i64, b: i64| -> i64 { a.max(b) }); } fn make_expr(&self, _egraph: &EGraph, value: Value) -> Expr { diff --git a/src/sort/macros.rs b/src/sort/macros.rs index c6cc5ca5..a892264c 100644 --- a/src/sort/macros.rs +++ b/src/sort/macros.rs @@ -63,7 +63,7 @@ macro_rules! add_primitives { // println!(") = {result:?}"); result.store(&self.__out) } else { - panic!() + panic!("wrong number of arguments") } } } diff --git a/src/sort/map.rs b/src/sort/map.rs index d85ec3c0..b9c86206 100644 --- a/src/sort/map.rs +++ b/src/sort/map.rs @@ -45,6 +45,19 @@ impl MapSort { } } +impl MapSort { + pub fn presort_names() -> Vec { + vec![ + "map-empty".into(), + "map-insert".into(), + "map-get".into(), + "map-not-contains".into(), + "map-contains".into(), + "map-remove".into(), + ] + } +} + impl Sort for MapSort { fn name(&self) -> Symbol { self.name @@ -73,22 +86,8 @@ impl Sort for MapSort { result } - fn canonicalize(&self, value: &mut Value, unionfind: &UnionFind) -> bool { - let maps = self.maps.lock().unwrap(); - let map = maps.get_index(value.bits as usize).unwrap(); - let mut changed = false; - let new_map: ValueMap = map - .iter() - .map(|(k, v)| { - let (mut k, mut v) = (*k, *v); - changed |= self.key.canonicalize(&mut k, unionfind); - changed |= self.value.canonicalize(&mut v, unionfind); - (k, v) - }) - .collect(); - drop(maps); - *value = new_map.store(self).unwrap(); - changed + fn canonicalize(&self, _value: &mut Value, _unionfind: &UnionFind) -> bool { + false } fn register_primitives(self: Arc, typeinfo: &mut TypeInfo) { @@ -123,10 +122,14 @@ impl Sort for MapSort { fn make_expr(&self, egraph: &EGraph, value: Value) -> Expr { let map = ValueMap::load(self, &value); let mut expr = Expr::call("map-empty", []); + let mut termdag = TermDag::default(); for (k, v) in map.iter().rev() { - let k = egraph.extract(*k, &self.key).1; - let v = egraph.extract(*v, &self.value).1; - expr = Expr::call("map-insert", [expr, k, v]) + let k = egraph.extract(*k, &mut termdag, &self.key).1; + let v = egraph.extract(*v, &mut termdag, &self.value).1; + expr = Expr::call( + "map-insert", + [expr, termdag.term_to_expr(&k), termdag.term_to_expr(&v)], + ) } expr } @@ -157,6 +160,54 @@ struct Ctor { map: Arc, } +pub(crate) struct TermOrderingMin {} + +impl PrimitiveLike for TermOrderingMin { + fn name(&self) -> Symbol { + "ordering-min".into() + } + + fn accept(&self, types: &[ArcSort]) -> Option { + match types { + [a, b] if a.name() == b.name() => Some(a.clone()), + _ => None, + } + } + + fn apply(&self, values: &[Value]) -> Option { + assert_eq!(values.len(), 2); + if values[0] < values[1] { + Some(values[0]) + } else { + Some(values[1]) + } + } +} + +pub(crate) struct TermOrderingMax {} + +impl PrimitiveLike for TermOrderingMax { + fn name(&self) -> Symbol { + "ordering-max".into() + } + + fn accept(&self, types: &[ArcSort]) -> Option { + match types { + [a, b] if a.name() == b.name() => Some(a.clone()), + _ => None, + } + } + + fn apply(&self, values: &[Value]) -> Option { + assert_eq!(values.len(), 2); + if values[0] > values[1] { + Some(values[0]) + } else { + Some(values[1]) + } + } +} + impl PrimitiveLike for Ctor { fn name(&self) -> Symbol { self.name diff --git a/src/sort/set.rs b/src/sort/set.rs index 2126acb1..7f82c8e9 100644 --- a/src/sort/set.rs +++ b/src/sort/set.rs @@ -42,6 +42,22 @@ impl SetSort { } } +impl SetSort { + pub fn presort_names() -> Vec { + vec![ + "set-of".into(), + "set-empty".into(), + "set-insert".into(), + "set-not-contains".into(), + "set-contains".into(), + "set-remove".into(), + "set-union".into(), + "set-diff".into(), + "set-intersect".into(), + ] + } +} + impl Sort for SetSort { fn name(&self) -> Symbol { self.name @@ -70,21 +86,8 @@ impl Sort for SetSort { result } - fn canonicalize(&self, value: &mut Value, unionfind: &UnionFind) -> bool { - let sets = self.sets.lock().unwrap(); - let set = sets.get_index(value.bits as usize).unwrap(); - let mut changed = false; - let new_set: ValueSet = set - .iter() - .map(|e| { - let mut e = *e; - changed |= self.element.canonicalize(&mut e, unionfind); - e - }) - .collect(); - drop(sets); - *value = new_set.store(self).unwrap(); - changed + fn canonicalize(&self, _value: &mut Value, _unionfind: &UnionFind) -> bool { + false } fn register_primitives(self: Arc, typeinfo: &mut TypeInfo) { @@ -131,9 +134,10 @@ impl Sort for SetSort { fn make_expr(&self, egraph: &EGraph, value: Value) -> Expr { let set = ValueSet::load(self, &value); let mut expr = Expr::call("set-empty", []); + let mut termdag = TermDag::default(); for e in set.iter().rev() { - let e = egraph.extract(*e, &self.element).1; - expr = Expr::call("set-insert", [expr, e]) + let e = egraph.extract(*e, &mut termdag, &self.element).1; + expr = Expr::call("set-insert", [expr, termdag.term_to_expr(&e)]) } expr } @@ -179,7 +183,7 @@ impl PrimitiveLike for SetOf { fn apply(&self, values: &[Value]) -> Option { let set = ValueSet::from_iter(values.iter().copied()); - set.store(&self.set) + Some(set.store(&self.set).unwrap()) } } diff --git a/src/sort/vec.rs b/src/sort/vec.rs index ce1baf00..60c3ec3c 100644 --- a/src/sort/vec.rs +++ b/src/sort/vec.rs @@ -16,6 +16,20 @@ impl VecSort { self.element.name() } + pub fn presort_names() -> Vec { + vec![ + "vec-of".into(), + "vec-append".into(), + "vec-empty".into(), + "vec-push".into(), + "vec-pop".into(), + "vec-not-contains".into(), + "vec-contains".into(), + "vec-length".into(), + "vec-get".into(), + ] + } + pub fn make_sort( typeinfo: &mut TypeInfo, name: Symbol, @@ -36,7 +50,7 @@ impl VecSort { vecs: Default::default(), })) } else { - panic!() + panic!("Vec sort must have sort as argument. Got {:?}", args) } } } @@ -132,9 +146,10 @@ impl Sort for VecSort { fn make_expr(&self, egraph: &EGraph, value: Value) -> Expr { let vec = ValueVec::load(self, &value); let mut expr = Expr::call("vec-empty", []); + let mut termdag = TermDag::default(); for e in vec.iter().rev() { - let e = egraph.extract(*e, &self.element).1; - expr = Expr::call("vec-push", [expr, e]) + let e = egraph.extract(*e, &mut termdag, &self.element).1; + expr = Expr::call("vec-push", [expr, termdag.term_to_expr(&e)]) } expr } diff --git a/src/termdag.rs b/src/termdag.rs new file mode 100644 index 00000000..ab49d654 --- /dev/null +++ b/src/termdag.rs @@ -0,0 +1,148 @@ +use crate::{ + ast::{Expr, Literal}, + util::{HashMap, HashSet}, + Symbol, +}; + +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub enum Term { + Lit(Literal), + Var(Symbol), + App(Symbol, Vec), +} + +#[derive(Clone, PartialEq, Eq, Debug, Default)] +pub struct TermDag { + nodes: Vec, + hashcons: HashMap, +} + +#[macro_export] +macro_rules! match_term_app { + ($e:expr; { $( + ($head:expr, $args:pat) => $body:expr $(,)? + ),*}) => { + match $e { + Term::App(head, args) => { + $( + if head.as_str() == $head { + match args.as_slice() { + $args => $body, + _ => panic!("arg mismatch"), + } + } else + )* { + panic!("Failed to match any of the heads of the patterns. Got: {}", head); + } + } + _ => panic!("not an app") + } + } +} + +impl TermDag { + pub fn size(&self) -> usize { + self.nodes.len() + } + + // users can't construct a termnode, so just + // look it up + pub fn lookup(&self, node: &Term) -> usize { + *self.hashcons.get(node).unwrap() + } + + pub fn get(&self, idx: usize) -> Term { + self.nodes[idx].clone() + } + + pub fn make(&mut self, sym: Symbol, children: Vec) -> Term { + let node = Term::App(sym, children.iter().map(|c| self.lookup(c)).collect()); + + self.add_node(&node); + + node + } + + fn add_node(&mut self, node: &Term) { + if self.hashcons.get(node).is_none() { + let idx = self.nodes.len(); + self.nodes.push(node.clone()); + self.hashcons.insert(node.clone(), idx); + } + } + + pub fn expr_to_term(&mut self, expr: &Expr) -> Term { + let res = match expr { + Expr::Lit(lit) => Term::Lit(lit.clone()), + Expr::Var(v) => Term::Var(*v), + Expr::Call(op, args) => { + let args = args + .iter() + .map(|a| { + let term = self.expr_to_term(a); + self.lookup(&term) + }) + .collect(); + Term::App(*op, args) + } + }; + self.add_node(&res); + res + } + + pub fn term_to_expr(&mut self, term: &Term) -> Expr { + match term { + Term::Lit(lit) => Expr::Lit(lit.clone()), + Term::Var(v) => Expr::Var(*v), + Term::App(op, args) => { + let args = args + .iter() + .map(|a| { + let term = self.get(*a); + self.term_to_expr(&term) + }) + .collect(); + Expr::Call(*op, args) + } + } + } + + pub fn to_string(&self, term: &Term) -> String { + let mut stored = HashMap::::default(); + let mut seen = HashSet::::default(); + let id = self.lookup(term); + // use a stack to avoid stack overflow + let mut stack = vec![id]; + while !stack.is_empty() { + let next = stack.pop().unwrap(); + + match self.nodes[next].clone() { + Term::App(name, children) => { + if seen.contains(&next) { + let mut str = String::new(); + str.push_str(&format!("({}", name)); + for c in children.iter() { + str.push_str(&format!(" {}", stored[c])); + } + str.push(')'); + stored.insert(next, str); + } else { + seen.insert(next); + stack.push(next); + for c in children.iter().rev() { + stack.push(*c); + } + } + } + Term::Lit(lit) => { + stored.insert(next, format!("{}", lit)); + } + Term::Var(v) => { + stored.insert(next, format!("{}", v)); + } + } + } + + stored.get(&id).unwrap().clone() + } +} diff --git a/src/typecheck.rs b/src/typecheck.rs index 25533cb7..21ad7173 100644 --- a/src/typecheck.rs +++ b/src/typecheck.rs @@ -11,7 +11,7 @@ pub struct Context<'a> { nodes: HashMap, } -#[derive(Hash, Eq, PartialEq)] +#[derive(Hash, Eq, PartialEq, Clone)] enum ENode { Func(Symbol, Vec), Prim(Primitive, Vec), @@ -23,6 +23,7 @@ enum ENode { pub enum AtomTerm { Var(Symbol), Value(Value), + Global(Symbol), } impl std::fmt::Display for AtomTerm { @@ -30,6 +31,7 @@ impl std::fmt::Display for AtomTerm { match self { AtomTerm::Var(v) => write!(f, "{}", v), AtomTerm::Value(_) => write!(f, ""), + AtomTerm::Global(g) => write!(f, "{}", g), } } } @@ -55,14 +57,14 @@ pub struct Query { impl std::fmt::Display for Query { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { for atom in &self.atoms { - write!(f, "{atom}")?; + writeln!(f, "{atom}")?; } if !self.filters.is_empty() { - write!(f, "where ")?; + writeln!(f, "where ")?; for filter in &self.filters { - write!( + writeln!( f, - "({} {}) ", + "({} {})", filter.head.name(), ListDisplay(&filter.args, " ") )?; @@ -77,10 +79,35 @@ impl Atom { self.args.iter().filter_map(|t| match t { AtomTerm::Var(v) => Some(*v), AtomTerm::Value(_) => None, + AtomTerm::Global(_) => None, }) } } +pub(crate) struct ValueEq {} + +impl PrimitiveLike for ValueEq { + fn name(&self) -> Symbol { + "value-eq".into() + } + + fn accept(&self, types: &[ArcSort]) -> Option { + match types { + [a, b] if a.name() == b.name() => Some(a.clone()), + _ => None, + } + } + + fn apply(&self, values: &[Value]) -> Option { + assert_eq!(values.len(), 2); + if values[0] == values[1] { + Some(values[0]) + } else { + None + } + } +} + impl<'a> Context<'a> { pub fn new(egraph: &'a EGraph) -> Self { Self { @@ -126,6 +153,25 @@ impl<'a> Context<'a> { _ => continue, } } + // Globally bound variables first + for (node, &id) in &self.nodes { + match node { + ENode::Var(var) => { + if self.egraph.global_bindings.get(var).is_some() { + match leaves.entry(id) { + Entry::Occupied(existing) => { + canon.insert(*var, existing.get().clone()); + } + Entry::Vacant(v) => { + v.insert(Expr::Var(*var)); + } + } + } + } + _ => continue, + } + } + // Now do variables for (node, &id) in &self.nodes { debug_assert_eq!(id, self.unionfind.find(id)); @@ -151,22 +197,44 @@ impl<'a> Context<'a> { let get_leaf = |id: &Id| -> AtomTerm { let mk = || AtomTerm::Var(Symbol::from(format!("?__{}", id))); match leaves.get(id) { - Some(Expr::Var(v)) => AtomTerm::Var(*v), + Some(Expr::Var(v)) => { + if let Some((_ty, _value, _ts)) = self.egraph.global_bindings.get(v) { + AtomTerm::Global(*v) + } else { + AtomTerm::Var(*v) + } + } Some(Expr::Lit(l)) => AtomTerm::Value(self.egraph.eval_lit(l)), _ => mk(), } }; let mut query = Query::default(); + let mut query_eclasses = HashSet::::default(); // Now we can fill in the nodes with the canonical leaves for (node, id) in &self.nodes { match node { ENode::Func(f, ids) => { let args = ids.iter().chain([id]).map(get_leaf).collect(); + for id in ids { + query_eclasses.insert(*id); + } query.atoms.push(Atom { head: *f, args }); } ENode::Prim(p, ids) => { - let args = ids.iter().chain([id]).map(get_leaf).collect(); + let mut args = vec![]; + for child in ids { + let leaf = get_leaf(child); + if let AtomTerm::Var(v) = leaf { + if self.egraph.global_bindings.contains_key(&v) { + args.push(AtomTerm::Value(self.egraph.global_bindings[&v].1)); + continue; + } + } + args.push(get_leaf(child)); + query_eclasses.insert(*child); + } + args.push(get_leaf(id)); query.filters.push(Atom { head: p.clone(), args, @@ -176,6 +244,22 @@ impl<'a> Context<'a> { } } + // filter for global variables + for node in &self.nodes { + if let ENode::Var(var) = node.0 { + if self.egraph.global_bindings.contains_key(var) { + let canon = get_leaf(node.1); + if canon != AtomTerm::Global(*var) { + // compare global to canon + query.filters.push(Atom { + head: Primitive(Arc::new(ValueEq {})), + args: vec![canon.clone(), AtomTerm::Global(*var), canon], + }); + } + } + } + } + if self.errors.is_empty() { Ok((query, res_actions)) } else { @@ -198,7 +282,7 @@ impl<'a> Context<'a> { } // reinsert and handle hit - if let Some(old) = self.nodes.insert(node, id) { + if let Some(old) = self.nodes.insert(node.clone(), id) { keep_going = true; self.unionfind.union_raw(old, id); } @@ -209,6 +293,7 @@ impl<'a> Context<'a> { fn typecheck_fact(&mut self, fact: &Fact) { match fact { Fact::Eq(exprs) => { + assert!(exprs.len() == 2); let mut later = vec![]; let mut ty: Option = None; let mut ids = Vec::with_capacity(exprs.len()); @@ -221,7 +306,7 @@ impl<'a> Context<'a> { // so we'll try again later when we can check its type (Expr::Var(v), None) if !self.types.contains_key(v) - && !self.egraph.functions.contains_key(v) => + && !self.egraph.global_bindings.contains_key(v) => { later.push(expr) } @@ -256,7 +341,7 @@ impl<'a> Context<'a> { fn check_query_expr(&mut self, expr: &Expr, expected: ArcSort) -> Id { match expr { - Expr::Var(sym) if !self.egraph.functions.contains_key(sym) => { + Expr::Var(sym) => { match self.types.entry(*sym) { IEntry::Occupied(ty) => { // TODO name comparison?? @@ -299,8 +384,11 @@ impl<'a> Context<'a> { if self.egraph.functions.contains_key(sym) { return self.infer_query_expr(&Expr::call(*sym, [])); } + let ty = if let Some(ty) = self.types.get(sym) { Some(ty.clone()) + } else if let Some(ty) = self.egraph.global_bindings.get(sym) { + Some(ty.0.clone()) } else { self.errors.push(TypeError::Unbound(*sym)); None @@ -373,7 +461,7 @@ impl<'a> ActionChecker<'a> { self.locals.insert(*v, ty); Ok(()) } - Action::Set(f, args, val) | Action::SetNoTrack(f, args, val) => { + Action::Set(f, args, val) => { let fake_call = Expr::Call(*f, args.clone()); let (_, ty) = self.infer_expr(&fake_call)?; let fake_instr = self.instructions.pop().unwrap(); @@ -382,6 +470,12 @@ impl<'a> ActionChecker<'a> { self.instructions.push(Instruction::Set(*f)); Ok(()) } + Action::Extract(variable, variants) => { + let (_, _ty) = self.infer_expr(variable)?; + let (_, _ty2) = self.infer_expr(variants)?; + self.instructions.push(Instruction::Extract(2)); + Ok(()) + } Action::Delete(f, args) => { let fake_call = Expr::Call(*f, args.clone()); let (_, _ty) = self.infer_expr(&fake_call)?; @@ -424,7 +518,10 @@ impl<'a> ExprChecker<'a> for ActionChecker<'a> { } fn infer_var(&mut self, sym: Symbol) -> Result<(Self::T, ArcSort), TypeError> { - if let Some((i, _, ty)) = self.locals.get_full(&sym) { + if let Some((sort, _v, _ts)) = self.egraph().global_bindings.get(&sym) { + self.instructions.push(Instruction::Global(sym)); + Ok(((), sort.clone())) + } else if let Some((i, _, ty)) = self.locals.get_full(&sym) { self.instructions.push(Instruction::Load(Load::Stack(i))); Ok(((), ty.clone())) } else if let Some((i, _, ty)) = self.types.get_full(&sym) { @@ -436,7 +533,17 @@ impl<'a> ExprChecker<'a> for ActionChecker<'a> { } fn do_function(&mut self, f: Symbol, _args: Vec) -> Self::T { - self.instructions.push(Instruction::CallFunction(f)); + let func_type = self + .egraph + .proof_state + .type_info + .func_types + .get(&f) + .unwrap(); + self.instructions.push(Instruction::CallFunction( + f, + func_type.has_default || !func_type.has_merge, + )); } fn do_prim(&mut self, prim: Primitive, args: Vec) -> Self::T { @@ -467,17 +574,9 @@ trait ExprChecker<'a> { } } - fn variable_function(&self, var: Symbol) -> bool { - if let Some(func) = self.egraph().functions.get(&var) { - func.is_variable - } else { - false - } - } - fn check_expr(&mut self, expr: &Expr, ty: ArcSort) -> Result { match expr { - Expr::Var(v) if !self.variable_function(*v) => self.check_var(*v, ty), + Expr::Var(v) if !self.is_variable(*v) => self.check_var(*v, ty), _ => { let (t, actual) = self.infer_expr(expr)?; if actual.name() != ty.name() { @@ -494,34 +593,28 @@ trait ExprChecker<'a> { } } + fn is_variable(&self, sym: Symbol) -> bool { + self.egraph().global_bindings.contains_key(&sym) + } + fn infer_expr(&mut self, expr: &Expr) -> Result<(Self::T, ArcSort), TypeError> { match expr { Expr::Lit(lit) => { let t = self.do_lit(lit); Ok((t, self.egraph().proof_state.type_info.infer_literal(lit))) } - Expr::Var(sym) => { - if self.variable_function(*sym) { - return self.infer_expr(&Expr::call(*sym, [])); - } - self.infer_var(*sym) - } + Expr::Var(sym) => self.infer_var(*sym), Expr::Call(sym, args) => { - if let Some(f) = self.egraph().functions.get(sym) { - if f.schema.input.len() != args.len() { - return Err(TypeError::Arity { - expr: expr.clone(), - expected: f.schema.input.len(), - }); - } + if let Some(functype) = self.egraph().proof_state.type_info.func_types.get(sym) { + assert!(functype.input.len() == args.len()); let mut ts = vec![]; - for (expected, arg) in f.schema.input.iter().zip(args) { + for (expected, arg) in functype.input.iter().zip(args) { ts.push(self.check_expr(arg, expected.clone())?); } let t = self.do_function(*sym, ts); - Ok((t, f.schema.output.clone())) + Ok((t, functype.output.clone())) } else if let Some(prims) = self.egraph().proof_state.type_info.primitives.get(sym) { let mut ts = Vec::with_capacity(args.len()); @@ -544,7 +637,7 @@ trait ExprChecker<'a> { inputs: tys.into_iter().map(|t| t.name()).collect(), }) } else { - Err(TypeError::Unbound(*sym)) + panic!("Unbound function {}", sym); } } } @@ -561,11 +654,14 @@ enum Load { enum Instruction { Literal(Literal), Load(Load), - CallFunction(Symbol), + Global(Symbol), + // function to call, and whether to make defaults + CallFunction(Symbol, bool), CallPrimitive(Primitive, usize), DeleteRow(Symbol), Set(Symbol), Union(usize), + Extract(usize), Panic(String), Pop, } @@ -634,11 +730,16 @@ impl EGraph { ) -> Result<(), Error> { for instr in &program.0 { match instr { + Instruction::Global(sym) => { + let (_ty, value, _ts) = self.global_bindings.get(sym).unwrap(); + stack.push(*value); + } Instruction::Load(load) => match load { Load::Stack(idx) => stack.push(stack[*idx]), Load::Subst(idx) => stack.push(subst[*idx]), }, - Instruction::CallFunction(f) => { + Instruction::CallFunction(f, make_defaults_func) => { + let make_defaults = make_defaults && *make_defaults_func; let function = self.functions.get_mut(f).unwrap(); let output_tag = function.schema.output.name(); let new_len = stack.len() - function.schema.input.len(); @@ -653,6 +754,9 @@ impl EGraph { let value = if let Some(out) = function.nodes.get(values) { out.value } else if make_defaults { + if function.merge.on_merge.is_some() { + panic!("No value found for function {} with values {:?}", f, values); + } let ts = self.timestamp; let out = &function.schema.output; match function.decl.default.as_ref() { @@ -676,13 +780,13 @@ impl EGraph { } _ => { return Err(Error::NotFoundError(NotFoundError(Expr::Var( - format!("fake expression {f} {:?}", values).into(), + format!("No value found for {f} {:?}", values).into(), )))) } } } else { return Err(Error::NotFoundError(NotFoundError(Expr::Var( - format!("fake expression {f} {:?}", values).into(), + format!("No value found for {f} {:?}", values).into(), )))); }; @@ -703,6 +807,9 @@ impl EGraph { Instruction::Set(f) => { assert!(make_defaults); let function = self.functions.get_mut(f).unwrap(); + // desugaring should have desugared + // set to union + // except for setting the parent relation let new_value = stack.pop().unwrap(); let new_len = stack.len() - function.schema.input.len(); let args = &stack[new_len..]; @@ -712,22 +819,13 @@ impl EGraph { if let Some(old_value) = old_value { if new_value != old_value { - let tag = old_value.tag; - if let Some(prog) = function.merge.on_merge.clone() { - let values = [old_value, new_value]; - // XXX: we get an error if we pass the current - // stack and then truncate it to the old length. - // Why? - self.run_actions(&mut Vec::new(), &values, &prog, true)?; - } - // re-borrow - let function = self.functions.get_mut(f).unwrap(); let merged: Value = match function.merge.merge_vals.clone() { MergeFn::AssertEq => { return Err(Error::MergeError(*f, new_value, old_value)); } MergeFn::Union => { - self.unionfind.union_values(old_value, new_value, tag) + self.unionfind + .union_values(old_value, new_value, old_value.tag) } MergeFn::Expr(merge_prog) => { let values = [old_value, new_value]; @@ -738,10 +836,20 @@ impl EGraph { result } }; + if merged != old_value { + let args = &stack[new_len..]; + let function = self.functions.get_mut(f).unwrap(); + function.insert(args, merged, self.timestamp); + } // re-borrow - let args = &stack[new_len..]; let function = self.functions.get_mut(f).unwrap(); - function.insert(args, merged, self.timestamp); + if let Some(prog) = function.merge.on_merge.clone() { + let values = [old_value, new_value]; + // XXX: we get an error if we pass the current + // stack and then truncate it to the old length. + // Why? + self.run_actions(&mut Vec::new(), &values, &prog, true)?; + } } } else { function.insert(args, new_value, self.timestamp); @@ -759,6 +867,40 @@ impl EGraph { }); stack.truncate(new_len); } + Instruction::Extract(arity) => { + let new_len = stack.len() - arity; + let values = &stack[new_len..]; + let new_len = stack.len() - arity; + let mut termdag = TermDag::default(); + let num_sort = values[1].tag; + assert!(num_sort.to_string() == "i64"); + + let variants = values[1].bits as i64; + if variants == 0 { + let (cost, expr) = self.extract( + values[0], + &mut termdag, + self.proof_state + .type_info + .sorts + .get(&values[0].tag) + .unwrap(), + ); + log::info!("extracted with cost {cost}: {}", termdag.to_string(&expr)); + } else { + if variants < 0 { + panic!("Cannot extract negative number of variants"); + } + let extracted = + self.extract_variants(values[0], variants as usize, &mut termdag); + log::info!("extracted variants:"); + for expr in extracted { + log::info!(" {}", termdag.to_string(&expr)); + } + } + + stack.truncate(new_len); + } Instruction::Panic(msg) => panic!("Panic: {}", msg), Instruction::Literal(lit) => match lit { Literal::Int(i) => stack.push(Value::from(*i)), diff --git a/src/typechecking.rs b/src/typechecking.rs index 40f89422..c8f048b6 100644 --- a/src/typechecking.rs +++ b/src/typechecking.rs @@ -5,14 +5,16 @@ pub struct FuncType { pub input: Vec, pub output: ArcSort, pub has_merge: bool, + pub has_default: bool, } impl FuncType { - pub fn new(input: Vec, output: ArcSort, has_merge: bool) -> Self { + pub fn new(input: Vec, output: ArcSort, has_merge: bool, has_default: bool) -> Self { Self { input, output, has_merge, + has_default, } } } @@ -21,6 +23,7 @@ impl FuncType { pub struct TypeInfo { // get the sort from the sorts name() pub presorts: HashMap, + pub presort_names: HashSet, pub sorts: HashMap>, pub primitives: HashMap>, pub func_types: HashMap, @@ -32,6 +35,7 @@ impl Default for TypeInfo { fn default() -> Self { let mut res = Self { presorts: Default::default(), + presort_names: Default::default(), sorts: Default::default(), primitives: Default::default(), func_types: Default::default(), @@ -44,9 +48,15 @@ impl Default for TypeInfo { res.add_sort(I64Sort::new("i64".into())); res.add_sort(F64Sort::new("f64".into())); res.add_sort(RationalSort::new("Rational".into())); + + res.presort_names.extend(MapSort::presort_names()); + res.presort_names.extend(SetSort::presort_names()); + res.presort_names.extend(VecSort::presort_names()); + res.presorts.insert("Map".into(), MapSort::make_sort); res.presorts.insert("Set".into(), SetSort::make_sort); res.presorts.insert("Vec".into(), VecSort::make_sort); + res } } @@ -129,7 +139,12 @@ impl TypeInfo { } else { Err(TypeError::Unbound(func.schema.output)) }?; - Ok(FuncType::new(input, output, func.merge.is_some())) + Ok(FuncType::new( + input, + output, + func.merge.is_some(), + func.default.is_some(), + )) } fn typecheck_ncommand(&mut self, command: &NCommand, id: CommandId) -> Result<(), TypeError> { @@ -138,7 +153,7 @@ impl TypeInfo { if self.sorts.contains_key(&fdecl.name) { return Err(TypeError::SortAlreadyBound(fdecl.name)); } - if self.primitives.contains_key(&fdecl.name) { + if self.is_primitive(fdecl.name) { return Err(TypeError::PrimitiveAlreadyBound(fdecl.name)); } let ftype = self.function_to_functype(fdecl)?; @@ -161,10 +176,14 @@ impl TypeInfo { } NCommand::Check(facts) => { self.typecheck_facts(id, facts)?; + self.verify_normal_form_facts(facts); } NCommand::Fail(cmd) => { self.typecheck_ncommand(cmd, id)?; } + NCommand::RunSchedule(schedule) => { + self.typecheck_schedule(id, schedule)?; + } // TODO cover all cases in typechecking _ => (), @@ -172,6 +191,34 @@ impl TypeInfo { Ok(()) } + fn typecheck_schedule( + &mut self, + ctx: CommandId, + schedule: &NormSchedule, + ) -> Result<(), TypeError> { + match schedule { + NormSchedule::Repeat(_times, schedule) => { + self.typecheck_schedule(ctx, schedule)?; + } + NormSchedule::Sequence(schedules) => { + for schedule in schedules { + self.typecheck_schedule(ctx, schedule)?; + } + } + NormSchedule::Saturate(schedule) => { + self.typecheck_schedule(ctx, schedule)?; + } + NormSchedule::Run(run_config) => { + if let Some(facts) = &run_config.until { + self.typecheck_facts(ctx, facts)?; + self.verify_normal_form_facts(facts); + } + } + } + + Result::Ok(()) + } + pub(crate) fn typecheck_command(&mut self, command: &NormCommand) -> Result<(), TypeError> { assert!(self .local_types @@ -232,18 +279,30 @@ impl TypeInfo { fn verify_normal_form_facts(&self, facts: &Vec) -> HashSet { let mut let_bound: HashSet = Default::default(); - let mut bound_in_constraint = vec![]; for fact in facts { match fact { - NormFact::Assign(var, NormExpr::Call(_head, body)) => { + NormFact::Compute(var, NormExpr::Call(_head, body)) => { + assert!(!self.global_types.contains_key(var)); assert!(let_bound.insert(*var)); body.iter().for_each(|bvar| { if !self.global_types.contains_key(bvar) { - assert!(let_bound.insert(*bvar)); + assert!(let_bound.contains(bvar)); } }); } + NormFact::Assign(var, NormExpr::Call(_head, body)) => { + assert!(!self.global_types.contains_key(var)); + assert!(let_bound.insert(*var)); + body.iter().for_each(|bvar| { + assert!(!self.global_types.contains_key(bvar)); + assert!(let_bound.insert(*bvar), "Expected {} to be bound", bvar); + }); + } + NormFact::AssignVar(lhs, _rhs) => { + assert!(!self.global_types.contains_key(lhs)); + assert!(let_bound.insert(*lhs)); + } NormFact::AssignLit(var, _lit) => { assert!(let_bound.insert(*var)); } @@ -255,12 +314,9 @@ impl TypeInfo { { panic!("ConstrainEq on unbound variables"); } - bound_in_constraint.push(*var1); - bound_in_constraint.push(*var2); } } } - let_bound.extend(bound_in_constraint); let_bound } @@ -273,7 +329,8 @@ impl TypeInfo { assert!( let_bound.contains(var) || self.global_types.contains_key(var) - || self.reserved_type(*var).is_some() + || self.reserved_type(*var).is_some(), + "Expected {var} to be let bound in body of rule", ) }; @@ -303,6 +360,10 @@ impl TypeInfo { }); assert_bound(var, let_bound); } + NormAction::Extract(var, variants) => { + assert_bound(var, let_bound); + assert_bound(variants, let_bound); + } NormAction::Union(v1, v2) => { assert_bound(v1, let_bound); assert_bound(v2, let_bound); @@ -368,6 +429,7 @@ impl TypeInfo { return Err(TypeError::TypeMismatch(var1_type, var2_type)); } } + NormAction::Extract(_var, _variants) => {} NormAction::LetVar(var1, var2) => { let var2_type = self.lookup(ctx, *var2)?; self.introduce_binding(ctx, *var1, var2_type, is_global)?; @@ -379,17 +441,37 @@ impl TypeInfo { fn typecheck_fact(&mut self, ctx: CommandId, fact: &NormFact) -> Result<(), TypeError> { match fact { + NormFact::Compute(var, expr) => { + let expr_type = self.typecheck_expr(ctx, expr, true)?; + if let Some(_existing) = self + .local_types + .get_mut(&ctx) + .unwrap() + .insert(*var, expr_type.output.clone()) + { + return Err(TypeError::AlreadyDefined(*var)); + } + } NormFact::Assign(var, expr) => { let expr_type = self.typecheck_expr(ctx, expr, false)?; - if let Some(existing) = self + if let Some(_existing) = self .local_types .get_mut(&ctx) .unwrap() .insert(*var, expr_type.output.clone()) { - if expr_type.output.name() != existing.name() { - return Err(TypeError::TypeMismatch(expr_type.output, existing)); - } + return Err(TypeError::AlreadyDefined(*var)); + } + } + NormFact::AssignVar(lhs, rhs) => { + let rhs_type = self.lookup(ctx, *rhs)?; + if let Some(_existing) = self + .local_types + .get_mut(&ctx) + .unwrap() + .insert(*lhs, rhs_type.clone()) + { + return Err(TypeError::AlreadyDefined(*lhs)); } } NormFact::AssignLit(var, lit) => { @@ -478,7 +560,7 @@ impl TypeInfo { } pub(crate) fn is_primitive(&self, sym: Symbol) -> bool { - self.primitives.contains_key(&sym) + self.primitives.contains_key(&sym) || self.presort_names.contains(&sym) } fn lookup_func( @@ -493,7 +575,7 @@ impl TypeInfo { if let Some(prims) = self.primitives.get(&sym) { for prim in prims { if let Some(return_type) = prim.accept(&input_types) { - return Ok(FuncType::new(input_types, return_type, false)); + return Ok(FuncType::new(input_types, return_type, false, true)); } } } @@ -516,9 +598,18 @@ impl TypeInfo { let child_types = if let Some(found) = self.func_types.get(head) { found.input.clone() } else { - body.iter() + let types = body + .iter() .map(|var| self.lookup(ctx, *var)) - .collect::, _>>()? + .collect::, _>>(); + if let Ok(types) = types { + types + } else if expect_lookup { + // return the error + types? + } else { + return Err(TypeError::UnboundFunction(*head)); + } }; for (child_type, var) in child_types.iter().zip(body.iter()) { if expect_lookup { @@ -554,6 +645,8 @@ pub enum TypeError { Unbound(Symbol), #[error("Undefined sort {0}")] UndefinedSort(Symbol), + #[error("Unbound function {0}")] + UnboundFunction(Symbol), #[error("Function already bound {0}")] FunctionAlreadyBound(Symbol), #[error("Function declarations are not allowed after a push.")] diff --git a/tests/antiunify.egg b/tests/antiunify.egg index 2891f3b5..b9e5554b 100644 --- a/tests/antiunify.egg +++ b/tests/antiunify.egg @@ -16,11 +16,11 @@ (AU (Add a b) (Add c d)) (Add (AU a c) (AU b d))) -(define e1 (Add (Var "x") (Add (Num 1) (Num 2)))) -(define e2 (Add (Num 3) (Var "y"))) +(let e1 (Add (Var "x") (Add (Num 1) (Num 2)))) +(let e2 (Add (Num 3) (Var "y"))) -(define au12 (AU e1 e2)) +(let au12 (AU e1 e2)) (run 4) (check (= au12 (Add (Num 3) (AU (Var "x") (Var "y"))))) -(extract au12) +(query-extract au12) diff --git a/tests/array.egg b/tests/array.egg index a84a84f2..713568eb 100644 --- a/tests/array.egg +++ b/tests/array.egg @@ -52,17 +52,17 @@ (rewrite (add (add x y) z) (add x (add y z))) (rewrite (add (Num x) (Num y)) (Num (+ x y))) -(define r1 (Var "r1")) -(define r2 (Var "r2")) -(define r3 (Var "r3")) -(define mem1 (AVar "mem1")) +(let r1 (Var "r1")) +(let r2 (Var "r2")) +(let r3 (Var "r3")) +(let mem1 (AVar "mem1")) (neq r1 r2) (neq r2 r3) (neq r1 r3) -(define test1 (select (store mem1 r1 (Num 42)) r1)) -(define test2 (select (store mem1 r1 (Num 42)) (add r1 (Num 17)))) -(define test3 (select (store (store mem1 (add r1 r2) (Num 1)) (add r2 r1) (Num 2)) (add r1 r3))) +(let test1 (select (store mem1 r1 (Num 42)) r1)) +(let test2 (select (store mem1 r1 (Num 42)) (add r1 (Num 17)))) +(let test3 (select (store (store mem1 (add r1 r2) (Num 1)) (add r2 r1) (Num 2)) (add r1 r3))) (run 4) (check (= test1 (Num 42))) diff --git a/tests/bdd.egg b/tests/bdd.egg index aca2a34a..dea5d783 100644 --- a/tests/bdd.egg +++ b/tests/bdd.egg @@ -34,17 +34,17 @@ (ITE n (and a1 b1) (and a2 b2)) ) -(define b0 (ITE 0 True False)) -(define b1 (ITE 1 True False)) -(define b2 (ITE 2 True False)) +(let b0 (ITE 0 True False)) +(let b1 (ITE 1 True False)) +(let b2 (ITE 2 True False)) -(define b123 (and b2 (and b0 b1))) -(define b11 (and b1 b1)) -(define b12 (and b1 b2)) +(let b123 (and b2 (and b0 b1))) +(let b11 (and b1 b1)) +(let b12 (and b1 b2)) (run 5) -(extract b11) -(extract b12) -(extract b123) +(query-extract b11) +(query-extract b12) +(query-extract b123) (check (= (and (ITE 1 True False) (ITE 2 True False)) (ITE 1 (ITE 2 True False) False)) ) @@ -67,9 +67,9 @@ (ITE n (or a1 b1) (or a2 b2)) ) -(define or121 (or b1 (or b2 b1))) +(let or121 (or b1 (or b2 b1))) (run 5) -(extract or121) +(query-extract or121) (function not (BDD) BDD) (rewrite (not True) False) diff --git a/tests/before-proofs.egg b/tests/before-proofs.egg index 52641a42..ab38a52c 100644 --- a/tests/before-proofs.egg +++ b/tests/before-proofs.egg @@ -14,8 +14,8 @@ (rewrite (Add a (Add b c)) (Add (Add a b) c)) -(define two (rational 2 1)) -(define start1 (Add (Var "x") (Const two))) +(let two (rational 2 1)) +(let start1 (Add (Var "x") (Const two))) ;; add original proofs (run 3) @@ -32,12 +32,3 @@ (check (= addx2 addx20)) - -(function p1 () Proof__ :cost 100000000) - -(rule ((= addx2 addzerofront)) - ((union rule-proof (p1)))) -(run 1) - -(run proof-extract__ 10000 :until (< (ProofCost__ (p1)) 1000000)) -(extract (p1)) diff --git a/tests/birewrite.egg b/tests/birewrite.egg index 01be8f3f..a97bb637 100644 --- a/tests/birewrite.egg +++ b/tests/birewrite.egg @@ -2,16 +2,16 @@ (birewrite (Add (Add x y) z) (Add x (Add y z))) -(define a (Lit 1)) -(define b (Lit 2)) -(define c (Lit 3)) +(let a (Lit 1)) +(let b (Lit 2)) +(let c (Lit 3)) -(define d (Lit 4)) -(define e (Lit 5)) -(define f (Lit 6)) +(let d (Lit 4)) +(let e (Lit 5)) +(let f (Lit 6)) -(define ex1 (Add (Add a b) c)) -(define ex2 (Add d (Add e f))) +(let ex1 (Add (Add a b) c)) +(let ex2 (Add d (Add e f))) (run 10) (check (= ex1 (Add a (Add b c)))) diff --git a/tests/bitwise.egg b/tests/bitwise.egg index d0abf9b0..e84fc1aa 100644 --- a/tests/bitwise.egg +++ b/tests/bitwise.egg @@ -28,10 +28,10 @@ ;(function bs-diff (i64 i64) i64) ;(rewrite (bs-diff a b) (^ a (bs-inter a b)) -;(define bs-empty 0) +;(let bs-empty 0) -;(define bs-subset (i64 i64) bool) +;(let bs-subset (i64 i64) bool) ;(rewrite (bs-subset x y) (is-zero (bs-diff x y))) -;(define bs-is-elem (i64 i64) bool) +;(let bs-is-elem (i64 i64) bool) ;(rewrite (bs-is-elem s x) (not (is-zero (bs-inter s (sing x))))) diff --git a/tests/calc.egg b/tests/calc.egg index 66e4d880..1d84cc0d 100644 --- a/tests/calc.egg +++ b/tests/calc.egg @@ -13,9 +13,9 @@ ; A is cyclic of period 4 (rewrite (g* A (g* A (g* A A))) I) -(define A2 (g* A A)) -(define A4 (g* A2 A2)) -(define A8 (g* A4 A4)) +(let A2 (g* A A)) +(let A4 (g* A2 A2)) +(let A8 (g* A4 A4)) (calc () diff --git a/tests/combinators.egg b/tests/combinators.egg index 630071de..06919284 100644 --- a/tests/combinators.egg +++ b/tests/combinators.egg @@ -17,7 +17,7 @@ ; (\x. (if x then 0 else 1) + 2) false -(define test +(let test (App (Abs "x" (Add (If (Var "x") (N 0) (N 1)) (N 2))) F)) @@ -77,17 +77,17 @@ ; May be needed for multiple nested variables (rewrite (CAbs v (CApp K (CVar v))) K) -;;;; Primitive Evaluation rules (defined on "surface syntax") +;;;; Primitive Evaluation rules (letd on "surface syntax") (rewrite (If T t f) t) (rewrite (If F t f) f) (rewrite (Add (N n) (N m)) (N (+ n m))) -;;;; Substitution Rules (defined on the combinator representation) +;;;; Substitution Rules (letd on the combinator representation) (rewrite (CApp I cx) cx) (rewrite (CApp (CApp K cx) cy) cx) ; Without demand, this can cause an explosion in DB size. (rewrite (CApp (CApp (CApp S cx) cy) cz) (CApp (CApp cx cz) (CApp cy cz))) (run 11) -(extract (Comb test)) +(query-extract (Comb test)) (check (= test (N 3))) \ No newline at end of file diff --git a/tests/cyk.egg b/tests/cyk.egg index 196efc1d..cab8d093 100644 --- a/tests/cyk.egg +++ b/tests/cyk.egg @@ -56,7 +56,7 @@ (End (NonTerm "DET") "a") -(define test1 (B 7 1 (NonTerm "S")):cost 1000) +(let test1 (B 7 1 (NonTerm "S"))) (run 100) @@ -64,8 +64,7 @@ (fail (check (P 7 1 (NonTerm "VP")))) (fail (check (P 7 1 (NonTerm "")))) -(extract test1) -(print test1) +(query-extract test1) (pop) @@ -91,9 +90,8 @@ (run 100) (check (P 5 1 (NonTerm "S"))) (fail (check (P 5 1 (NonTerm "B")))) -(define test2 (B 5 1 (NonTerm "S")):cost 1000) -(extract:variants 10 test2) -(print test2) +(let test2 (B 5 1 (NonTerm "S"))) +(query-extract:variants 10 test2) (pop) @@ -111,8 +109,7 @@ (fail (check (P 5 1 (NonTerm "B")))) (fail (check (P 5 1 (NonTerm "")))) (fail (check (P 5 1 (NonTerm "unrelated")))) -(define test3 (B 5 1 (NonTerm "S")):cost 1000) -(extract:variants 10 test3) -(print test3) +(let test3 (B 5 1 (NonTerm "S"))) +(query-extract :variants 10 test3) (pop) \ No newline at end of file diff --git a/tests/cykjson.egg b/tests/cykjson.egg index 696a1b07..599097e5 100644 --- a/tests/cykjson.egg +++ b/tests/cykjson.egg @@ -37,7 +37,7 @@ ; medium size 7821 but runs for 2 min. ;(input getString "./tests/cykjson_medium_token.csv") -(define test1 (B 801 1 "VAL"):cost 100000) +(let test1 (B 801 1 "VAL")) (run 10000) diff --git a/tests/eqsat-basic.egg b/tests/eqsat-basic.egg index 0f88e015..f6bead84 100644 --- a/tests/eqsat-basic.egg +++ b/tests/eqsat-basic.egg @@ -5,9 +5,9 @@ (Mul Math Math)) ;; expr1 = 2 * (x + 3) -(define expr1 (Mul (Num 2) (Add (Var "x") (Num 3)))) +(let expr1 (Mul (Num 2) (Add (Var "x") (Num 3)))) ;; expr2 = 6 + 2 * x -(define expr2 (Add (Num 6) (Mul (Num 2) (Var "x")))) +(let expr2 (Add (Num 6) (Mul (Num 2) (Var "x")))) ;; (rule ((= __root (Add a b))) diff --git a/tests/eqsolve.egg b/tests/eqsolve.egg index 1314e0d0..46c7a937 100644 --- a/tests/eqsolve.egg +++ b/tests/eqsolve.egg @@ -29,9 +29,9 @@ (set (Add (Var "z") (Var "z")) (Var "y")) (run 5) -(extract (Var "x")) -(extract (Var "y")) -(extract (Var "z")) +(query-extract (Var "x")) +(query-extract (Var "y")) +(query-extract (Var "z")) (check (= (Var "z") (Add (Num 6) (Neg (Var "y"))))) (check (= (Var "y") (Add (Add (Num 6) (Neg (Var "y"))) (Add (Num 6) (Neg (Var "y")))))) (check (= (Var "y") (Add (Add (Num 12) (Neg (Var "y"))) (Neg (Var "y"))))) diff --git a/tests/extraction-cost.egg b/tests/extraction-cost.egg deleted file mode 100644 index d413dfe3..00000000 --- a/tests/extraction-cost.egg +++ /dev/null @@ -1,8 +0,0 @@ -(datatype Expr - (Num i64 :cost 5)) - -(define x (Num 1) :cost 10) -(define y (Num 2) :cost 1) - -(extract x) ;; (Num 1) -(extract y) ;; (y) \ No newline at end of file diff --git a/tests/f64.egg b/tests/f64.egg index b3ed0dd4..cb28bda3 100644 --- a/tests/f64.egg +++ b/tests/f64.egg @@ -3,5 +3,7 @@ (check (= (/ 12.5 2.0) 6.25)) (check (< 1.5 9.2)) (check (>= 9.2 1.5)) +(fail (check (< 9.2 1.5))) +(fail (check (= (+ 1.5 9.2) 10.6))) (check (= (to-f64 1) 1.0)) (check (= (to-i64 1.0) 1)) diff --git a/tests/fail-typecheck/repro-containers-disallowed.egg b/tests/fail-typecheck/repro-containers-disallowed.egg new file mode 100644 index 00000000..ed76d3a1 --- /dev/null +++ b/tests/fail-typecheck/repro-containers-disallowed.egg @@ -0,0 +1,4 @@ +(sort IVec (Vec i64)) + +; Test vec-of +(fail (check (= (vec-of 1 2) (vec-push (vec-push (vec-empty) 1) 2)))) diff --git a/tests/fail-typecheck/repro-duplicated-var.egg b/tests/fail-typecheck/repro-duplicated-var.egg new file mode 100644 index 00000000..2340aa7f --- /dev/null +++ b/tests/fail-typecheck/repro-duplicated-var.egg @@ -0,0 +1,3 @@ +(function f (i64) i64) +;; the let's y should fail checking +(rule ((= x 1) (= y x)) ((let y (f 1)) (set (f 0) 0))) \ No newline at end of file diff --git a/tests/fibonacci-demand.egg b/tests/fibonacci-demand.egg index a03999de..a6b00b02 100644 --- a/tests/fibonacci-demand.egg +++ b/tests/fibonacci-demand.egg @@ -1,8 +1,8 @@ (datatype Expr - (Num i64) - (Add Expr Expr)) + (Num i64 :cost 1) + (Add Expr Expr :cost 5)) -(function Fib (i64) Expr) +(function Fib (i64) Expr :cost 10) (rewrite (Add (Num a) (Num b)) (Num (+ a b))) (rewrite (Fib x) (Add (Fib (- x 1)) (Fib (- x 2))) @@ -10,9 +10,10 @@ (rewrite (Fib x) (Num x) :when ((<= x 1))) -(define f7 (Fib 7)) +(let f7 (Fib 7)) (run 1000) -(extract f7) +(print-table Fib) +(query-extract f7) (check (= f7 (Num 13))) \ No newline at end of file diff --git a/tests/files.rs b/tests/files.rs index 004fe207..b5d64996 100644 --- a/tests/files.rs +++ b/tests/files.rs @@ -7,45 +7,51 @@ use libtest_mimic::Trial; struct Run { path: PathBuf, test_proofs: bool, + resugar: bool, } impl Run { fn run(&self) { let _ = env_logger::builder().is_test(true).try_init(); - let program = std::fs::read_to_string(&self.path) + let program_read = std::fs::read_to_string(&self.path) .unwrap_or_else(|err| panic!("Couldn't read {:?}: {:?}", self.path, err)); - self.test_program(&program, "Top level error"); - if !self.should_fail() { + let already_enables = program_read.starts_with("(set-option enable_proofs 1)"); + let program = if self.test_proofs && !already_enables { + format!("(set-option enable_proofs 1)\n{}", program_read) + } else { + program_read + }; + + if !self.resugar { + self.test_program(&program, "Top level error"); + } else if self.resugar { let mut egraph = EGraph::default(); - egraph.set_underscores_for_desugaring(4); + egraph.set_underscores_for_desugaring(3); let parsed = egraph.parse_program(&program).unwrap(); + // TODO can we test after term encoding instead? + // last time I tried it spun out becuase + // it adds term encoding to term encoding let desugared_str = egraph - .process_commands(parsed) + .process_commands(parsed, CompilerPassStop::TypecheckDesugared) .unwrap() .into_iter() - .map(|x| x.to_string()) + .map(|x| x.resugar().to_string()) .collect::>() .join("\n"); - println!("{}", desugared_str); - self.test_program( &desugared_str, - &format!( - "Program:\n{}\n ERROR after parse, to_string, and parse again.", - desugared_str - ), + "ERROR after parse, to_string, and parse again.", ); } } fn test_program(&self, program: &str, message: &str) { let mut egraph = EGraph::default(); - egraph.set_underscores_for_desugaring(5); if self.test_proofs { - egraph.enable_proofs(); egraph.test_proofs = true; } + egraph.set_underscores_for_desugaring(5); match egraph.parse_and_run_program(program) { Ok(msgs) => { if self.should_fail() { @@ -82,8 +88,8 @@ impl Run { let stem = self.0.path.file_stem().unwrap(); let stem_str = stem.to_string_lossy().replace(['.', '-', ' '], "_"); write!(f, "{stem_str}")?; - if self.0.test_proofs { - write!(f, "_with_proofs")?; + if self.0.resugar { + write!(f, "_resugar")?; } Ok(()) } @@ -104,24 +110,15 @@ fn generate_tests(glob: &str) -> Vec { let run = Run { path: entry.unwrap().clone(), test_proofs: false, + resugar: false, }; - let name = run.name().to_string(); + let should_fail = run.should_fail(); push_trial(run.clone()); - - // make a test with proofs enabled - // TODO: re-enable herbie, unsound, and eqsolve when proof extraction is faster - let banned = [ - "herbie", - "repro_unsound", - "eqsolve", - "before_proofs", - "lambda", - ]; - if !banned.contains(&name.as_str()) { + if !should_fail { push_trial(Run { - test_proofs: true, - ..run + resugar: true, + ..run.clone() }); } } diff --git a/tests/fusion.egg b/tests/fusion.egg index e652d79f..a53949eb 100644 --- a/tests/fusion.egg +++ b/tests/fusion.egg @@ -131,14 +131,14 @@ (set (freer (sum)) (set-empty)) (set (freer (mapf)) (set-empty)) -(define expr (App (sum) (App (mapf) (TVar (V "expr"))))) +(let expr (App (sum) (App (mapf) (TVar (V "expr"))))) (run 100) -(extract (freer expr)) +(query-extract (freer expr)) -(define my-output +(let my-output (CaseSplit (TVar (V "expr")) (Num 0) (Lam (V "x") (Lam (V "xs'") (Add (Add (TVar (V "x")) (Num 1)) diff --git a/tests/herbie-tutorial.egg b/tests/herbie-tutorial.egg index 8fd7235e..8163d74c 100644 --- a/tests/herbie-tutorial.egg +++ b/tests/herbie-tutorial.egg @@ -14,7 +14,7 @@ (rewrite (Add (Num r1) (Num r2)) (Num (+ r1 r2))) -(define one-two (Add one two)) +(let one-two (Add one two)) (push) (run 1) @@ -72,8 +72,8 @@ (run 3) -(extract (lower-bound x1)) -(extract (upper-bound x1)) +(query-extract (lower-bound x1)) +(query-extract (upper-bound x1)) (check (= one (Div x1 x1))) (pop) @@ -85,8 +85,8 @@ (run 3) -(extract (lower-bound x1)) -(extract (upper-bound x1)) +(query-extract (lower-bound x1)) +(query-extract (upper-bound x1)) (function true-value (Math) f64) @@ -96,7 +96,7 @@ (to-f64 (lower-bound e))))) (run 1) -(extract (true-value x1)) +(query-extract (true-value x1)) (function best-error (Math) f64 :merge new :default (to-f64 (rational 10000 1))) @@ -125,13 +125,15 @@ (run 1) +;; set a default +(best-error target) ;; error is bad, constant folding hasn't fired enough -(extract (best-error target)) +(query-extract (best-error target)) (run 1) ;; error is good, constant folding has fired enough -(extract (best-error target)) +(query-extract (best-error target)) (pop) \ No newline at end of file diff --git a/tests/herbie.egg b/tests/herbie.egg index 24d95491..6b19c319 100644 --- a/tests/herbie.egg +++ b/tests/herbie.egg @@ -30,14 +30,14 @@ (Round Math) (Log Math)) -(define r-zero (rational 0 1)) -(define r-one (rational 1 1)) -(define r-two (rational 2 1)) -(define zero (Num r-zero)) -(define one (Num r-one)) -(define two (Num r-two)) -(define three (Num (rational 3 1))) -(define neg-one (Neg one)) +(let r-zero (rational 0 1)) +(let r-one (rational 1 1)) +(let r-two (rational 2 1)) +(let zero (Num r-zero)) +(let one (Num r-one)) +(let two (Num r-two)) +(let three (Num (rational 3 1))) +(let neg-one (Neg one)) ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; @@ -471,64 +471,64 @@ ;; (src/core/simplify.rkt) (push) -(define e (Add one zero)) +(let e (Add one zero)) (run 1) (check (= e one)) (pop) (push) -(define five (Num (rational 5 1))) -(define six (Num (rational 6 1))) -(define e2 (Add one five)) +(let five (Num (rational 5 1))) +(let six (Num (rational 6 1))) +(let e2 (Add one five)) (run 1) (check (= e2 six)) (pop) -(define x (Var "x")) +(let x (Var "x")) (push) -(define e3 (Add x zero)) +(let e3 (Add x zero)) (run 1) (check (= e3 x)) (pop) (push) -(define e4 (Sub x zero)) +(let e4 (Sub x zero)) (run 1) (check (= e4 x)) (pop) (push) -(define e5 (Mul x one)) +(let e5 (Mul x one)) (run 1) (check (= e5 x)) (pop) (push) -(define e6 (Div x one)) +(let e6 (Div x one)) (run 1) (check (= e6 x)) (pop) (push) -(define e7 (Sub (Mul one x) (Mul (Add x one) one))) +(let e7 (Sub (Mul one x) (Mul (Add x one) one))) (run 3) (check (= e7 (Num (rational -1 1)))) (pop) (push) -(define e8 (Sub (Add x one) x)) +(let e8 (Sub (Add x one) x)) (run 4) (check (= e8 one)) (pop) (push) -(define e9 (Sub (Add x one) one)) +(let e9 (Sub (Add x one) one)) (run 4) (check (= e9 x)) (pop) @@ -536,35 +536,35 @@ (push) (set (lo x) r-one) -(define e10 (Div (Mul x three) x)) +(let e10 (Div (Mul x three) x)) (run 3) (check (= e10 three)) (pop) (push) -(define e11 (Sub (Mul (Sqrt (Add x one)) (Sqrt (Add x one))) (Mul (Sqrt x) (Sqrt x)))) +(let e11 (Sub (Mul (Sqrt (Add x one)) (Sqrt (Add x one))) (Mul (Sqrt x) (Sqrt x)))) (run 5) (check (= one e11)) (pop) (push) -(define e12 (Add (Num (rational 1 5)) (Num (rational 3 10)))) +(let e12 (Add (Num (rational 1 5)) (Num (rational 3 10)))) (run 1) (check (= e12 (Num (rational 1 2)))) (pop) (push) -(define e13 (Unary "cos" (Const "PI"))) +(let e13 (Unary "cos" (Const "PI"))) (run 1) (check (= e13 (Num (rational -1 1)))) (pop) (push) -(define sqrt5 (Sqrt (Num (rational 5 1)))) -(define e14 +(let sqrt5 (Sqrt (Num (rational 5 1)))) +(let e14 (Div one (Sub (Div (Add one sqrt5) two) (Div (Sub one sqrt5) two)))) -(define tgt (Div one sqrt5)) +(let tgt (Div one sqrt5)) (run 6) (check (= e14 tgt)) (pop) \ No newline at end of file diff --git a/tests/intersection.egg b/tests/intersection.egg index 6e8eb800..49a5d7ba 100644 --- a/tests/intersection.egg +++ b/tests/intersection.egg @@ -14,17 +14,17 @@ (set (intersect f1 f2) (f x3)) )) -(define a1 (Var "a1")) (define a2 (Var "a2")) (define a3 (Var "a3")) -(define b1 (Var "b1")) (define b2 (Var "b2")) (define b3 (Var "b3")) +(let a1 (Var "a1")) (let a2 (Var "a2")) (let a3 (Var "a3")) +(let b1 (Var "b1")) (let b2 (Var "b2")) (let b3 (Var "b3")) ;; e-graph 1: f(a) = f(b), f(f(a)) -(define t1 (f (f a1))) -(define fb1 (f b1)) +(let t1 (f (f a1))) +(let fb1 (f b1)) (union (f a1) fb1) ;; e-graph 2: f(f(a)) = f(f(b)) -(define t2 (f (f a2))) -(define t2p (f (f b2))) +(let t2 (f (f a2))) +(let t2p (f (f b2))) (union t2 t2p) (set (intersect a1 a2) a3) @@ -32,8 +32,8 @@ (run 100) -(define t3 (f (f a3))) -(extract :variants 5 t3) +(let t3 (f (f a3))) +(query-extract :variants 5 t3) ;; f(f(a)) = f(f(b)) is preserved (check (= (f (f a3)) (f (f b3)))) diff --git a/tests/interval.egg b/tests/interval.egg index ef5c428a..2c82611d 100644 --- a/tests/interval.egg +++ b/tests/interval.egg @@ -11,8 +11,8 @@ (min (min (* (lo a) (lo b)) (* (lo a) (hi b))) (min (* (hi a) (lo b)) (* (hi a) (hi b))))))) -(define x (Var "x")) -(define e (Mul x x)) +(let x (Var "x")) +(let e (Mul x x)) (set (lo x) (rational -10 1)) (set (hi x) (rational 10 1)) @@ -28,4 +28,4 @@ (check (= (lo e) (rational 100 1))) ;; testing extraction of rationals -(extract (lo e)) +(query-extract (lo e)) diff --git a/tests/knapsack.egg b/tests/knapsack.egg index 4a9a30e4..40122e4e 100644 --- a/tests/knapsack.egg +++ b/tests/knapsack.egg @@ -27,13 +27,13 @@ (rule ((= f (Knap capacity Nil))) ((set (Knap capacity Nil) (Num 0)))) -(define test1 (Knap 13 (Cons 5 5 (Cons 3 3 (Cons 12 12 (Cons 5 5 Nil)))))) +(let test1 (Knap 13 (Cons 5 5 (Cons 3 3 (Cons 12 12 (Cons 5 5 Nil)))))) -(define test2 (Knap 5 (Cons 6 6 Nil))) +(let test2 (Knap 5 (Cons 6 6 Nil))) -(define test3 (Knap 5 (Cons 1 1 (Cons 1 1 (Cons 1 1 Nil))))) +(let test3 (Knap 5 (Cons 1 1 (Cons 1 1 (Cons 1 1 Nil))))) -(define test4 (Knap 15 (Cons 12 40 (Cons 2 20 (Cons 1 20 (Cons 1 10 (Cons 4 100 Nil))))))) +(let test4 (Knap 15 (Cons 12 40 (Cons 2 20 (Cons 1 20 (Cons 1 10 (Cons 4 100 Nil))))))) ; turn a (Num n) into n (function Unwrap (expr) i64) @@ -41,14 +41,5 @@ (run 100) -(extract (Unwrap test1)) (check (= test1 (Num 13))) -(extract (Unwrap test2)) -(check (= test2 (Num 0))) - -(extract (Unwrap test3)) -(check (= test3 (Num 3))) - -(extract (Unwrap test4)) -(check (= test4 (Num 150))) diff --git a/tests/lambda.egg b/tests/lambda.egg index d16fe630..9d93dab0 100644 --- a/tests/lambda.egg +++ b/tests/lambda.egg @@ -132,7 +132,7 @@ ;; lambda_under (push) -(define e +(let e (Lam (V "x") (Add (Val (Num 4)) (App (Lam (V "y") (Var (V "y"))) (Val (Num 4)))))) @@ -142,7 +142,7 @@ ;; lambda_if_elim (push) -(define e2 (If (Eq (Var (V "a")) (Var (V "b"))) +(let e2 (If (Eq (Var (V "a")) (Var (V "b"))) (Add (Var (V "a")) (Var (V "a"))) (Add (Var (V "a")) (Var (V "b"))))) (run 10) @@ -151,7 +151,7 @@ ;; lambda_let_simple (push) -(define e3 (Let (V "x") (Val (Num 0)) +(let e3 (Let (V "x") (Val (Num 0)) (Let (V "y") (Val (Num 1)) (Add (Var (V "x")) (Var (V "y")))))) (run 10) @@ -160,7 +160,7 @@ ;; lambda_capture (push) -(define e4 (Let (V "x") (Val (Num 1)) +(let e4 (Let (V "x") (Val (Num 1)) (Lam (V "x") (Var (V "x"))))) (run 10) (fail (check (= e4 (Lam (V "x") (Val (Num 1)))))) @@ -168,7 +168,7 @@ ;; lambda_capture_free (push) -(define e5 (Let (V "y") (Add (Var (V "x")) (Var (V "x"))) +(let e5 (Let (V "y") (Add (Var (V "x")) (Var (V "x"))) (Lam (V "x") (Var (V "y"))))) (run 10) (check (set-contains (freer (Lam (V "x") (Var (V "y")))) (V "y"))) @@ -177,7 +177,7 @@ ;; lambda_closure_not_seven (push) -(define e6 +(let e6 (Let (V "five") (Val (Num 5)) (Let (V "add-five") (Lam (V "x") (Add (Var (V "x")) (Var (V "five")))) (Let (V "five") (Val (Num 6)) @@ -190,7 +190,7 @@ ;; lambda_compose (push) -(define e7 +(let e7 (Let (V "compose") (Lam (V "f") (Lam (V "g") (Lam (V "x") (App (Var (V "f")) @@ -214,14 +214,14 @@ ;; lambda_if_simple (push) -(define e10 (If (Eq (Val (Num 1)) (Val (Num 1))) (Val (Num 7)) (Val (Num 9)))) +(let e10 (If (Eq (Val (Num 1)) (Val (Num 1))) (Val (Num 7)) (Val (Num 9)))) (run 4) (check (= e10 (Val (Num 7)))) (pop) ;; lambda_compose_many (push) -(define e11 +(let e11 (Let (V "compose") (Lam (V "f") (Lam (V "g") (Lam (V "x") (App (Var (V "f")) (App (Var (V "g")) (Var (V "x"))))))) (Let (V "add1") (Lam (V "y") (Add (Var (V "y")) (Val (Num 1)))) @@ -245,7 +245,7 @@ ;; lambda_if (push) -(define e8 +(let e8 (Let (V "zeroone") (Lam (V "x") (If (Eq (Var (V "x")) (Val (Num 0))) (Val (Num 0)) diff --git a/tests/levenshtein-distance.egg b/tests/levenshtein-distance.egg index 41345cc2..cd45a8bb 100644 --- a/tests/levenshtein-distance.egg +++ b/tests/levenshtein-distance.egg @@ -49,14 +49,14 @@ (rule ((= x (Num n))) ((set (Unwrap (Num n)) n))) ; Tests -(define HorseStr (Cons "h" (Cons "o" (Cons "r" (Cons "s" (Cons "e" Empty)))))) -(define RosStr (Cons "r" (Cons "o" (Cons "s" Empty)))) -(define IntentionStr (Cons "i" (Cons "n" (Cons "t" (Cons "e" (Cons "n" (Cons "t" (Cons "i" (Cons "o" (Cons "n" Empty)))))))))) -(define ExecutionStr (Cons "e" (Cons "x" (Cons "e" (Cons "c" (Cons "u" (Cons "t" (Cons "i" (Cons "o" (Cons "n" Empty)))))))))) +(let HorseStr (Cons "h" (Cons "o" (Cons "r" (Cons "s" (Cons "e" Empty)))))) +(let RosStr (Cons "r" (Cons "o" (Cons "s" Empty)))) +(let IntentionStr (Cons "i" (Cons "n" (Cons "t" (Cons "e" (Cons "n" (Cons "t" (Cons "i" (Cons "o" (Cons "n" Empty)))))))))) +(let ExecutionStr (Cons "e" (Cons "x" (Cons "e" (Cons "c" (Cons "u" (Cons "t" (Cons "i" (Cons "o" (Cons "n" Empty)))))))))) -(define Test1 (EditDist HorseStr RosStr)) -(define Test2 (EditDist IntentionStr ExecutionStr)) -(define Test3 (EditDist HorseStr Empty)) +(let Test1 (EditDist HorseStr RosStr)) +(let Test2 (EditDist IntentionStr ExecutionStr)) +(let Test3 (EditDist HorseStr Empty)) (run 100) @@ -67,4 +67,4 @@ (check (= Test2 (Num 5))) (extract (Unwrap Test3)) -(check (= Test3 (Num 5))) +(check (= Test3 (Num 5))) \ No newline at end of file diff --git a/tests/map.egg b/tests/map.egg index 04026f35..51e9897f 100644 --- a/tests/map.egg +++ b/tests/map.egg @@ -1,7 +1,7 @@ (sort MyMap (Map i64 String)) -(define my_map1 (map-insert (map-empty) 1 "one")) -(define my_map2 (map-insert my_map1 2 "two")) +(let my_map1 (map-insert (map-empty) 1 "one")) +(let my_map2 (map-insert my_map1 2 "two")) (check (= "one" (map-get my_map1 1))) -(extract my_map2) \ No newline at end of file +(query-extract my_map2) \ No newline at end of file diff --git a/tests/math-microbenchmark.egg b/tests/math-microbenchmark.egg new file mode 100644 index 00000000..8b3ef92b --- /dev/null +++ b/tests/math-microbenchmark.egg @@ -0,0 +1,69 @@ +(datatype Math + (Diff Math Math) + (Integral Math Math) + + (Add Math Math) + (Sub Math Math) + (Mul Math Math) + (Div Math Math) + (Pow Math Math) + (Ln Math) + (Sqrt Math) + + (Sin Math) + (Cos Math) + + (Const Rational) + (Var String)) + +(rewrite (Add a b) (Add b a)) +(rewrite (Mul a b) (Mul b a)) +(rewrite (Add a (Add b c)) (Add (Add a b) c)) +(rewrite (Mul a (Mul b c)) (Mul (Mul a b) c)) + +(rewrite (Sub a b) (Add a (Mul (Const (rational -1 1)) b))) +;; (rewrite (Div a b) (Mul a (Pow b (Const (rational -1 1)))) :when ((is-not-zero b))) + +(rewrite (Add a (Const (rational 0 1))) a) +(rewrite (Mul a (Const (rational 0 1))) (Const (rational 0 1))) +(rewrite (Mul a (Const (rational 1 1))) a) + +(rewrite (Sub a a) (Const (rational 0 1))) + +(rewrite (Mul a (Add b c)) (Add (Mul a b) (Mul a c))) +(rewrite (Add (Mul a b) (Mul a c)) (Mul a (Add b c))) + +(rewrite (Mul (Pow a b) (Pow a c)) (Pow a (Add b c))) +(rewrite (Pow x (Const (rational 1 1))) x) +(rewrite (Pow x (Const (rational 2 1))) (Mul x x)) + +(rewrite (Diff x (Add a b)) (Add (Diff x a) (Diff x b))) +(rewrite (Diff x (Mul a b)) (Add (Mul a (Diff x b)) (Mul b (Diff x a)))) + +(rewrite (Diff x (Sin x)) (Cos x)) +(rewrite (Diff x (Cos x)) (Mul (Const (rational -1 1)) (Sin x))) + +(rewrite (Integral (Const (rational 1 1)) x) x) +(rewrite (Integral (Cos x) x) (Sin x)) +(rewrite (Integral (Sin x) x) (Mul (Const (rational -1 1)) (Cos x))) +(rewrite (Integral (Add f g) x) (Add (Integral f x) (Integral g x))) +(rewrite (Integral (Sub f g) x) (Sub (Integral f x) (Integral g x))) +(rewrite (Integral (Mul a b) x) +(Sub (Mul a (Integral b x)) + (Integral (Mul (Diff x a) (Integral b x)) x))) +(Integral (Ln (Var "x")) (Var "x")) +(Integral (Add (Var "x") (Cos (Var "x"))) (Var "x")) +(Integral (Mul (Cos (Var "x")) (Var "x")) (Var "x")) +(Diff (Var "x") (Add (Const (rational 1 1)) (Mul (Const (rational 2 1)) (Var "x")))) +(Diff (Var "x") (Sub (Pow (Var "x") (Const (rational 3 1))) (Mul (Const (rational 7 1)) (Pow (Var "x") (Const (rational 2 1)))))) +(Add (Mul (Var "y") (Add (Var "x") (Var "y"))) (Sub (Add (Var "x") (Const (rational 2 1))) (Add (Var "x") (Var "x")))) +(Div (Const (rational 1 1)) + (Sub (Div (Add (Const (rational 1 1)) + (Sqrt (Var "five"))) + (Const (rational 2 1))) + (Div (Sub (Const (rational 1 1)) + (Sqrt (Var "five"))) + (Const (rational 2 1))))) +(run 10) +(print-size Add) +(print-size Mul) diff --git a/tests/math.egg b/tests/math.egg index 7190d235..127b0549 100644 --- a/tests/math.egg +++ b/tests/math.egg @@ -128,27 +128,8 @@ (Sub (Mul a (Integral b x)) (Integral (Mul (Diff x a) (Integral b x)) x))) -;; math_simplify_root -(push) -(define start-expr - (Div (Const (rational 1 1)) - (Sub (Div (Add (Const (rational 1 1)) - (Sqrt (Var "five"))) - (Const (rational 2 1))) - (Div (Sub (Const (rational 1 1)) - (Sqrt (Var "five"))) - (Const (rational 2 1)))))) -(run 11) -(define end-expr - (Div (Const (rational 1 1)) - (Sqrt (Var "five")))) -(check (= start-expr end-expr)) - -(pop) - -;; math_simplify_const -(push) -(define start-expr2 (Add (Const (rational 1 1)) + +(let start-expr2 (Add (Const (rational 1 1)) (Sub (Var "a") (Mul (Sub (Const (rational 2 1)) (Const (rational 1 1))) @@ -156,23 +137,8 @@ (run 6) -(define end-expr2 (Const (rational 1 1))) +(let end-expr2 (Const (rational 1 1))) + (check (= start-expr2 end-expr2)) -(pop) - -;; math_simplify_factor -(push) -(define start-expr3 (Mul (Add (Var "x") (Const (rational 3 1))) - (Add (Var "x") (Const (rational 1 1))))) -(run 8) -(define end-expr3 (Add (Add (Mul (Var "x") (Var "x")) - (Mul (Const (rational 4 1)) (Var "x"))) - (Const (rational 3 1)))) -(check (= start-expr3 end-expr3)) -(pop) - -(simplify 5 (Add (Const (rational 1 1)) - (Sub (Var "a") - (Mul (Sub (Const (rational 2 1)) - (Const (rational 1 1))) - (Var "a"))))) + +(query-extract start-expr2) \ No newline at end of file diff --git a/tests/matrix.egg b/tests/matrix.egg index a583df41..4dc038e6 100644 --- a/tests/matrix.egg +++ b/tests/matrix.egg @@ -68,13 +68,13 @@ ) -(define n (NamedDim "n")) -(define m (NamedDim "m")) -(define p (NamedDim "p")) +(let n (NamedDim "n")) +(let m (NamedDim "m")) +(let p (NamedDim "p")) -(define A (NamedMat "A")) -(define B (NamedMat "B")) -(define C (NamedMat "C")) +(let A (NamedMat "A")) +(let B (NamedMat "B")) +(let C (NamedMat "C")) (set (nrows A) n) (set (ncols A) n) @@ -82,17 +82,17 @@ (set (ncols B) m) (set (nrows C) p) (set (ncols C) p) -(define ex1 (MMul (Kron (Id n) B) (Kron A (Id m)))) -(define rows1 (nrows ex1)) -(define cols1 (ncols ex1)) +(let ex1 (MMul (Kron (Id n) B) (Kron A (Id m)))) +(let rows1 (nrows ex1)) +(let cols1 (ncols ex1)) (run 20) (check (= (nrows B) m)) (check (= (nrows (Kron (Id n) B)) (Times n m))) -(define simple_ex1 (Kron A B)) +(let simple_ex1 (Kron A B)) (check (= ex1 simple_ex1)) -(define ex2 (MMul (Kron (Id p) C) (Kron A (Id m)))) +(let ex2 (MMul (Kron (Id p) C) (Kron A (Id m)))) (run 10) (fail (check (= ex2 (Kron A C)))) diff --git a/tests/merge-during-rebuild.egg b/tests/merge-during-rebuild.egg index 9ff661e2..326705b8 100644 --- a/tests/merge-during-rebuild.egg +++ b/tests/merge-during-rebuild.egg @@ -4,10 +4,10 @@ (datatype N (Node i64)) (function distance (N N) i64 :merge (min old new)) -(define a (Node 0)) -(define b (Node 1)) -(define x (Node 2)) -(define y (Node 3)) +(let a (Node 0)) +(let b (Node 1)) +(let x (Node 2)) +(let y (Node 3)) (set (distance x y) 1) (set (distance a b) 2) diff --git a/tests/merge-saturates.egg b/tests/merge-saturates.egg index 46f0d10d..1a5dff88 100644 --- a/tests/merge-saturates.egg +++ b/tests/merge-saturates.egg @@ -1,3 +1,6 @@ +;; SKIP_PROOFS +;; doesn't work with proofs because of the side effect in +;; the merge function (function foo () i64 :merge (min old new)) (set (foo) 0) diff --git a/tests/name-resolution.egg b/tests/name-resolution.egg index adf1db3f..2ef292ce 100644 --- a/tests/name-resolution.egg +++ b/tests/name-resolution.egg @@ -2,15 +2,15 @@ (Add Math Math) (Num i64)) -(define zero (Num 0)) +(let zero (Num 0)) ;; zero here refers to the function/constant zero, not a free variable (rewrite (Add zero x) x) -(define a (Add (Num 0) (Num 3))) -(define b (Add (Num 7) (Num 9))) -(define c (Num 16)) +(let a (Add (Num 0) (Num 3))) +(let b (Add (Num 7) (Num 9))) +(let c (Num 16)) (union b c) ;; crash if we merge two numbers diff --git a/tests/path.egg b/tests/path.egg index 812564f5..ffcd069e 100644 --- a/tests/path.egg +++ b/tests/path.egg @@ -14,6 +14,6 @@ (fail (check (path 1 2))) (run 3) -(print path) +(print-table path) (check (path 1 4)) (fail (check (path 4 1))) diff --git a/tests/pathproof.egg b/tests/pathproof.egg index 9a50558e..fafeb062 100644 --- a/tests/pathproof.egg +++ b/tests/pathproof.egg @@ -26,5 +26,5 @@ ; Would prefer being able to check ;(check (path 1 2 _)) ; or extract -;(extract (path 1 4 ?p)) -(print path) \ No newline at end of file +;(query-extract (path 1 4 ?p)) +(print-table path) \ No newline at end of file diff --git a/tests/points-to.egg b/tests/points-to.egg index a61a294c..a4401f3a 100644 --- a/tests/points-to.egg +++ b/tests/points-to.egg @@ -43,15 +43,15 @@ ; l4: o2.f = o1; ; l5: Object r = o3.f; -(define A (Class "A")) -(define B (Class "B")) -(define f (Field "f")) - -(define l1 (New "o1" A)) -(define l2 (New "o2" B)) -(define l3 (Assign "o3" "o2")) -(define l4 (Store "o2" f "o1")) -(define l5 (Load "r" "o3" f)) +(let A (Class "A")) +(let B (Class "B")) +(let f (Field "f")) + +(let l1 (New "o1" A)) +(let l2 (New "o2" B)) +(let l3 (Assign "o3" "o2")) +(let l4 (Store "o2" f "o1")) +(let l5 (Load "r" "o3" f)) (run 3) diff --git a/tests/prims.egg b/tests/prims.egg index fdc5fb2d..f7331056 100644 --- a/tests/prims.egg +++ b/tests/prims.egg @@ -7,7 +7,7 @@ (relation true ()) (true) -(define infinity 99999999) ; close enough +(let infinity 99999999) ; close enough ; ==== PROBLEM INSTANCES ==== @@ -110,7 +110,7 @@ ; === PRINT RESULTS === -; (print edge-in-mst) ; this is not very helpful +; (print-table edge-in-mst) ; this is not very helpful ; Just copy canonical edges to solution (relation solution (i64 i64 i64)) @@ -121,4 +121,4 @@ :ruleset finalize) (run-schedule (saturate finalize)) -(print solution) ; this is better +(print-table solution) ; this is better diff --git a/tests/proofs.egg b/tests/proofs.egg deleted file mode 100644 index 0929de20..00000000 --- a/tests/proofs.egg +++ /dev/null @@ -1,371 +0,0 @@ -(datatype Math - (Add Math Math) - (Sub Math Math) - (Const Rational) - (Var String)) - -(datatype AstMath - (AstAdd AstMath AstMath) - (AstSub AstMath AstMath) - (AstConst Rational) - (AstVar String)) - -(datatype ProofList) - -;; There are two types of proofs: -;; 1) Provenance proofs justify the existance of a term -;; 2) Equality proofs prove two terms are equal -;; Equality proofs `a = b` also double as -;; provenance proofs `a` and `b` -;; When a proof `a = b` is used as a provenance proof, it is a proof of `b` -(datatype Proof - ;; proves that a term exists in the database - (Original AstMath) - ;; proves two terms were set equal in the database - (OriginalEq AstMath AstMath) - ;; justifies the fact that a rule fired - ;; the proof list justifies each of the premises in order - ;; it then justifies all of the equality constraints in order - (Rule ProofList String) - ;; using a rule justification, - ;; proves a term exists - (RuleTerm Proof AstMath) - ;; using a rule justification, proves two terms are equal - (RuleEquality Proof AstMath AstMath) - - ;; given proofs for x1 = x2, x2 = x3, ..., xn = xn+1 - ;; proves that x1 = xn+1 and x1 and xn+1 exist - (Transitivity ProofList) - ;; given x1 = x2, proves x2 = x1 - (Flip Proof) - ;; given a proof for a term t1, proves that t1 is equal - ;; to another term t2 via equality proofs on their children - (Congruence Proof ProofList) - ;; a placeholder for a proof of equality between two - ;; terms which can be proven equal using the graph - ;; stored in the `Eq` relation - (DemandEq AstMath AstMath :cost 10000000)) - -;; ProofList definitions -(function Cons (Proof ProofList) ProofList) -(declare Null ProofList) - -;; prove two terms equal -(function EqGraph (AstMath AstMath) Proof :cost 100000 :merge old) - -(datatype TrmPrf - (MakeTrmPrf AstMath Proof)) - -;; get child terms, proofs, and child terms -(function TrmOf (TrmPrf) AstMath :cost 10000) -(function PrfOf (TrmPrf) Proof :cost 100000) -(function Child1 (AstMath) AstMath :cost 100000) -(function Child2 (AstMath) AstMath :cost 100000) - - -;; For every Add in the database, -;; store an AstAdd representative -;; and a proof of that representative -(function AddRep (Math Math) TrmPrf - :on_merge ( - (let t1 (TrmOf old)) - (let t2 (TrmOf new)) - (let p1 (PrfOf old)) - (let x1 (Child1 t1)) - (let y1 (Child2 t1)) - (let x2 (Child1 t2)) - (let y2 (Child2 t2)) - (let cong-prf - (Congruence p1 - (Cons (DemandEq x1 x2) - (Cons (DemandEq y1 y2) - Null)))) - (set (EqGraph t1 t2) cong-prf) - (set (EqGraph t2 t1) (Flip cong-prf)) -) - :merge old) - -(function SubRep (Math Math) TrmPrf - :on_merge ( - (let t1 (TrmOf old)) - (let t2 (TrmOf new)) - (let p1 (PrfOf old)) - (let x1 (Child1 t1)) - (let y1 (Child2 t1)) - (let x2 (Child1 t2)) - (let y2 (Child2 t2)) - (let cong-prf - (Congruence p1 - (Cons (DemandEq x1 x2) - (Cons (DemandEq y1 y2) - Null)))) - (set (EqGraph t1 t2) cong-prf) - (set (EqGraph t2 t1) (Flip cong-prf)) -) :merge old) -(function ConstRep (Rational) TrmPrf :merge old) -(function VarRep (String) TrmPrf :merge old) - - -;; ############################## NORMAL RULES - -(rule ((= t (Add a b)) - (= tp (AddRep a b)) - (= term (TrmOf tp)) - (= proof (PrfOf tp))) - ((union (Add a b) (Add (Add a b) (Const (rational 0 1)))) - (let rhs (AstAdd term (AstConst (rational 0 1)))) - (let ruleprf (Rule (Cons proof - Null) - "add-identity")) - (let trmprf (MakeTrmPrf rhs (RuleTerm ruleprf rhs))) - (set (AddRep (Add a b) (Const (rational 0 1))) - trmprf) - (set (TrmOf trmprf) rhs) - (set (PrfOf trmprf) (RuleTerm ruleprf rhs)) - - (set (EqGraph term rhs) - (RuleEquality ruleprf term rhs)) - (set (EqGraph rhs term) - (Flip (RuleEquality ruleprf term rhs))))) - -(rule ((= t (Add a b)) ;; (rewrite (Add a b) (Add b c)) - (= tp (AddRep a b)) - (= lhs (TrmOf tp)) - (= proof (PrfOf tp)) - (= c1 (Child1 lhs)) - (= c2 (Child2 lhs))) - ((union (Add a b) (Add b a)) ;; normal rhs - (let ruleprf (Rule (Cons proof - Null) - "add-commute")) - (let rhs (AstAdd c2 c1)) - (let trmprf (MakeTrmPrf rhs (RuleTerm ruleprf rhs))) - (set (AddRep b a) - trmprf) - (set (TrmOf trmprf) rhs) - (set (PrfOf trmprf) (RuleTerm ruleprf rhs)) - - (set (EqGraph lhs rhs) ;; equality edge -> - (RuleEquality ruleprf lhs rhs)) - (set (EqGraph rhs lhs) ;; equality edge -> - (Flip (RuleEquality ruleprf lhs rhs))))) - -; (rewrite (Add a (Add b c)) (Add (Add a b) c)) -(rule ((= t (Add a (Add b c))) - (= tp (AddRep a (Add b c))) - (= term (TrmOf tp)) - (= proof (PrfOf tp)) - (= c1 (Child1 term)) - (= c2 (Child2 term)) - (= tp2 (AddRep b c)) - (= c2term (TrmOf tp2)) - (= proof-right (PrfOf tp2)) - (= c2termc1 (Child1 c2term)) - (= c2termc2 (Child2 c2term))) - ((union (Add a (Add b c)) (Add (Add a b) c)) - (let newrep (AstAdd c1 c2term)) - (let newrepproof - (Congruence - proof - (Cons - (DemandEq c1 c1) - (Cons - (DemandEq c2 c2term) - Null)))) - (let ruleproof - (Rule - (Cons - newrepproof - Null) - "add-assoc")) - ;; first, add our new representative to the graph via congruence - (set (EqGraph term newrep) newrepproof) - (set (EqGraph newrep term) newrepproof) - (let rhs (AstAdd (AstAdd c1 c2termc1) - c2termc2)) - (let trmprf - (MakeTrmPrf rhs (RuleTerm ruleproof rhs))) - ;; Add a proof for our RHS - (set (AddRep (Add a b) c) - trmprf) - (set (TrmOf trmprf) rhs) - (set (PrfOf trmprf) (RuleTerm ruleproof rhs)) - - ;; prove equality between new representative and the RHS - (set (EqGraph newrep rhs) - (RuleEquality ruleproof - newrep rhs)) - (set (EqGraph newrep rhs) - (Flip (RuleEquality ruleproof - newrep rhs))))) - - -;; ########################## PROOF RULES - -(ruleset proof-rules) -;; children -(rule ((= a (AstAdd c1 c2))) - ((set (Child1 a) c1) - (set (Child2 a) c2)) - :ruleset proof-rules) -(rule ((= a (AstSub c1 c2))) - ((set (Child1 a) c1) - (set (Child2 a) c2)) - :ruleset proof-rules) - -(rule ((= trmprf (MakeTrmPrf t p))) - ((set (TrmOf trmprf) t) - (set (PrfOf trmprf) p)) - :ruleset proof-rules) - - -(ruleset proof-extract) - -;; Silly function to get the proof -(function GetProof (AstMath) Proof :cost 1000000000) -(rule ((= tp (AddRep a b)) - (= term (TrmOf tp)) - (= proof (PrfOf tp))) - ((set (GetProof term) proof)) - :ruleset proof-extract) -(rule ((= tp (SubRep a b)) - (= term (TrmOf tp)) - (= proof (PrfOf tp))) - ((set (GetProof term) proof)) - :ruleset proof-extract) -(rule ((= tp (ConstRep a)) - (= term (TrmOf tp)) - (= proof (PrfOf tp))) - ((set (GetProof term) proof)) - :ruleset proof-extract) -(rule ((= tp (VarRep a)) - (= term (TrmOf tp)) - (= proof (PrfOf tp))) - ((set (GetProof term) proof)) - :ruleset proof-extract) - -;; start term, end term, current progress term, proof -(function ProofBetween (AstMath AstMath AstMath) ProofList :cost 100000 :merge old) - -;; start proof search for equalities -(rule ((= e (DemandEq t1 t2))) - ((set (ProofBetween t1 t2 t1) Null)) - :ruleset proof-extract) -;; do one step of proof search to find a path -(rule ((= proof (ProofBetween t1 t2 tmid)) - (= eproof (EqGraph tmid next))) - ((set (ProofBetween t1 t2 next) - (Cons eproof proof))) - :ruleset proof-extract) -;; when you find a path, union it with the equality proof -(rule ((= e (DemandEq t1 t2)) - (= prooflist (ProofBetween t1 t2 t2))) - ((set (DemandEq t1 t2) (Transitivity prooflist))) - :ruleset proof-extract) - - - -(define two (rational 2 1)) -(define start1 (Add (Var "x") (Const two))) -;; add original proofs -(set (VarRep "x") - (MakeTrmPrf (AstVar "x") (Original (AstVar "x")))) -(set (ConstRep two) - (MakeTrmPrf (AstConst two) (Original (AstConst two)))) -(define addx2 (AstAdd (AstVar "x") (AstConst two))) -(define add2x (AstAdd (AstConst two) (AstVar "x"))) -(set (AddRep (Var "x") (Const two)) - (MakeTrmPrf - addx2 - (Original (AstAdd (AstVar "x") (AstConst two))))) - -(run proof-rules 1000) -(run 1) -(run proof-rules 1000) -(run 1) -(run proof-rules 1000) -(run 1) -(run proof-rules 1000) - -(define zero (AstConst (rational 0 1))) -(define addzero (AstAdd addx2 zero)) -(define addzerofront (AstAdd (AstAdd zero (AstVar "x")) (AstConst two))) - -(DemandEq addx2 add2x) -(DemandEq addx2 addzerofront) - -(run proof-extract 100) - - -(check (!= (Var "x") (Const two))) -(check (= (Add (Var "x") (Const two)) - (Add (Const two) (Var "x")))) -(check - (= (GetProof (AstAdd (AstVar "x") (AstConst two))) - (Original (AstAdd (AstVar "x") (AstConst two))))) - -(check - (= (GetProof add2x) - (RuleTerm - (Rule (Cons (Original addx2) Null)"add-commute") - add2x - ))) - -(check (= (DemandEq addx2 add2x) - (Transitivity - (Cons - (RuleEquality - (Rule (Cons (Original addx2) Null)"add-commute") - addx2 - add2x) - Null)))) - - -(check (= (DemandEq addx2 addzero) - (Transitivity - (Cons - (RuleEquality - (Rule - (Cons - (Original addx2) Null) "add-identity") - addx2 - addzero) - Null)))) - -(check (= (DemandEq addx2 addzerofront) - (Transitivity - (Cons - ;; 0+(x+2) -> (0+x)+2 - (RuleEquality - ;; proof that the assoc rule fires - (Rule - (Cons - (Congruence - ;; proof of 0+(x+2) - (RuleTerm - (Rule (Cons (RuleTerm (Rule (Cons (Original addx2) Null) "add-identity") addzero) Null) "add-commute") - (AstAdd zero addx2)) - - ;; children already equal - (Cons (Transitivity Null) (Cons (Transitivity Null) Null))) - Null) "add-assoc") - (AstAdd zero addx2) addzerofront) - - - ;; (x+2)+0 -> 0+(x+2) - (Cons - (RuleEquality - (Rule (Cons (RuleTerm (Rule (Cons (Original addx2) Null) "add-identity") addzero) Null) "add-commute") - - addzero - - (AstAdd zero addx2)) - - ;; x+2 -> (x+2)+0 - (Cons (RuleEquality - (Rule (Cons (Original addx2) Null) "add-identity") - addx2 - addzero) - - Null))) - ))) \ No newline at end of file diff --git a/tests/repro-constraineq.egg b/tests/repro-constraineq.egg new file mode 100644 index 00000000..6ac52883 --- /dev/null +++ b/tests/repro-constraineq.egg @@ -0,0 +1,2 @@ +(rule ((= x 1) (= y x) (= z y)) ()) +(run 1) \ No newline at end of file diff --git a/tests/repro-constraineq2.egg b/tests/repro-constraineq2.egg new file mode 100644 index 00000000..522bb43d --- /dev/null +++ b/tests/repro-constraineq2.egg @@ -0,0 +1,2 @@ +(rule ((= x 1) (= y x)) ()) +(run 1) \ No newline at end of file diff --git a/tests/repro-define.egg b/tests/repro-define.egg index f908bda6..4fce3ac5 100644 --- a/tests/repro-define.egg +++ b/tests/repro-define.egg @@ -2,7 +2,7 @@ (S Nat)) (declare Zero Nat) -(define two (S (S Zero))) +(let two (S (S Zero))) (union two (S (S (S Zero)))) (check (= two (S (S (S Zero))))) diff --git a/tests/repro-desugar-143.egg b/tests/repro-desugar-143.egg new file mode 100644 index 00000000..dde3e826 --- /dev/null +++ b/tests/repro-desugar-143.egg @@ -0,0 +1,40 @@ +;; To test on issue #143 +(rule ((= x 1) (= y x)) ()) +(rule ((= x 1) (= y x) (= z y)) ()) + +(function f (i64) i64) + +(set (f 0) 0) + +;; a funky id rule +(rule ((f x) (= x y) (= z y)) ((let a (f z)) (set (f (+ z 1)) (+ a 1)))) + +(run 20) + +(print-table f) +(check (= (f 10) 10)) + +(datatype Value (Num i64)) +(function fib (Value) Value) + +;; a funky fibonacci that test on more complex case and user defined datatype +(rule ((= (Num a) (fib (Num x))) + (= (Num b) (fib (Num y))) + (= x1 x) + (= y1 y) + (= a1 a) + (= b1 b) + (= x1 x2) + (= y1 y2) + (= a1 a2) + (= b1 b2) + (= 1 (- x2 y2))) + ((let n (+ x 1)) (let sum (+ a1 b2)) (set (fib (Num n)) (Num sum)))) + +(set (fib (Num 1)) (Num 1)) +(set (fib (Num 2)) (Num 1)) + +(run 20) + +(print-table fib) +(check (= (fib (Num 10)) (Num 55))) diff --git a/tests/repro-primitive-query.egg b/tests/repro-primitive-query.egg new file mode 100644 index 00000000..ebe44430 --- /dev/null +++ b/tests/repro-primitive-query.egg @@ -0,0 +1,12 @@ +(datatype Math + (Num i64)) + +(Num 1) +(Num 2) + +(rule ((Num ?a) + (Num ?b) + (= (+ ?a ?b) 5)) + ((panic "should not have matched"))) + +(run 100) \ No newline at end of file diff --git a/tests/repro-querybug.egg b/tests/repro-querybug.egg index 8a5b6434..f21e3582 100644 --- a/tests/repro-querybug.egg +++ b/tests/repro-querybug.egg @@ -11,6 +11,6 @@ (rule ((= x (Cons x1 rest1)) (= y (Cons x2 rest2)) (= x1 x2) (eq rest1 rest2)) ((eq (Cons x1 rest1) (Cons x2 rest2)))) -(define mylist (Cons 1 Empty)) +(let mylist (Cons 1 Empty)) (run 100) diff --git a/tests/repro-should-saturate.egg b/tests/repro-should-saturate.egg index af9901d2..74c38a14 100644 --- a/tests/repro-should-saturate.egg +++ b/tests/repro-should-saturate.egg @@ -7,4 +7,4 @@ ((set (MyMap) 1) (set (MyMap) 2))) -(run-schedule (saturate (run 1))) \ No newline at end of file +(run-schedule (saturate (run))) diff --git a/tests/repro-silly-panic.egg b/tests/repro-silly-panic.egg index be1082ab..4f4db418 100644 --- a/tests/repro-silly-panic.egg +++ b/tests/repro-silly-panic.egg @@ -7,5 +7,5 @@ (rule ((= r (par q r)) (= q (par q r))) ((union r q))) ; tests -(define q (par A A)) +(let q (par A A)) (run 10) \ No newline at end of file diff --git a/tests/repro-unsound-htutorial.egg b/tests/repro-unsound-htutorial.egg new file mode 100644 index 00000000..fcf1031d --- /dev/null +++ b/tests/repro-unsound-htutorial.egg @@ -0,0 +1,16 @@ +(datatype Math + (Num Rational) + (Var String) + (Add Math Math) + (Div Math Math) + (Mul Math Math)) + +(let z (Var "z")) + +(Add (Var "x") (Var "y")) + +(rewrite (Add a z) a) + +(run 2) + +(fail (check (= (Var "x") (Add (Var "x") (Var "y"))))) \ No newline at end of file diff --git a/tests/repro-unsound.egg b/tests/repro-unsound.egg index f50befbf..db3236d6 100644 --- a/tests/repro-unsound.egg +++ b/tests/repro-unsound.egg @@ -2,12 +2,12 @@ (datatype HerbieType (Type String)) (datatype Math (Num HerbieType Rational) (Var HerbieType String) (Fma HerbieType Math Math Math) (If HerbieType Math Math Math) (Less HerbieType Math Math) (LessEq HerbieType Math Math) (Greater HerbieType Math Math) (GreaterEq HerbieType Math Math) (Eq HerbieType Math Math) (NotEq HerbieType Math Math) (Add HerbieType Math Math) (Sub HerbieType Math Math) (Mul HerbieType Math Math) (Div HerbieType Math Math) (Pow HerbieType Math Math) (Atan2 HerbieType Math Math) (Hypot HerbieType Math Math) (And HerbieType Math Math) (Or HerbieType Math Math) (Not HerbieType Math) (Neg HerbieType Math) (Sqrt HerbieType Math) (Cbrt HerbieType Math) (Fabs HerbieType Math) (Ceil HerbieType Math) (Floor HerbieType Math) (Round HerbieType Math) (Log HerbieType Math) (Exp HerbieType Math) (Sin HerbieType Math) (Cos HerbieType Math) (Tan HerbieType Math) (Atan HerbieType Math) (Asin HerbieType Math) (Acos HerbieType Math) (Expm1 HerbieType Math) (Log1p HerbieType Math) (Sinh HerbieType Math) (Cosh HerbieType Math) (Tanh HerbieType Math) (PI HerbieType) (E HerbieType) (INFINITY HerbieType) (TRUE HerbieType) (FALSE HerbieType)) -(define r-zero (rational 0 1)) -(define r-one (rational 1 1)) -(define r-two (rational 2 1)) -(define r-three (rational 3 1)) -(define r-four (rational 4 1)) -(define r-neg-one (rational -1 1)) +(let r-zero (rational 0 1)) +(let r-one (rational 1 1)) +(let r-two (rational 2 1)) +(let r-three (rational 3 1)) +(let r-four (rational 4 1)) +(let r-neg-one (rational -1 1)) (relation universe (Math HerbieType)) (rule ((= t (Expm1 ty a))) ((universe t ty))) (rewrite (Mul ty a b) (Mul ty b a)) @@ -57,7 +57,7 @@ -(define eggvar1 (Div (Type "binary64") (Expm1 (Type "binary64") (Add (Type "binary64") (Var (Type "binary64") "h0") (Var (Type "binary64") "h0"))) (Expm1 (Type "binary64") (Var (Type "binary64") "h0"))) :cost 10000000) +(let eggvar1 (Div (Type "binary64") (Expm1 (Type "binary64") (Add (Type "binary64") (Var (Type "binary64") "h0") (Var (Type "binary64") "h0"))) (Expm1 (Type "binary64") (Var (Type "binary64") "h0")))) (run 10) diff --git a/tests/repro-vec-unequal.egg b/tests/repro-vec-unequal.egg new file mode 100644 index 00000000..1ffaab62 --- /dev/null +++ b/tests/repro-vec-unequal.egg @@ -0,0 +1,17 @@ +(datatype Math + (Num i64)) + +(sort MathVec (Vec Math)) + +(let v1 (vec-of (Num 1) (Num 2))) +(let v2 (vec-of (Num 2) (Num 2))) + +(fail (check (= v1 v2))) + + +(sort IVec (Vec i64)) + +(let v3 (vec-of 1 2)) +(let v4 (vec-of 2 2)) + +(fail (check (= v3 v4))) \ No newline at end of file diff --git a/tests/resolution.egg b/tests/resolution.egg index 2a478026..56ba5637 100644 --- a/tests/resolution.egg +++ b/tests/resolution.egg @@ -74,9 +74,9 @@ ; example predicate (function p (i64) Bool) -(define p0 (p 0)) -(define p1 (p 1)) -(define p2 (p 2)) +(let p0 (p 0)) +(let p1 (p 1)) +(let p2 (p 2)) ;(set (or p0 (or p1 (or p2 False))) True) ;(set (or (negate p0) (or p1 (or (negate p2) False))) True) (set (or p1 (or (negate p2) False)) True) @@ -85,10 +85,8 @@ (union p1 False) (set (or (negate p0) (or p1 (or p2 False))) True) (run 10) -(print or) -(print p) -(print True) -(print False) + + (check (!= True False)) (check (= p0 False)) (check (= p2 False)) diff --git a/tests/semi_naive_set_function.egg b/tests/semi_naive_set_function.egg index 01c6009a..cd6caba2 100644 --- a/tests/semi_naive_set_function.egg +++ b/tests/semi_naive_set_function.egg @@ -18,7 +18,7 @@ (rule ((= f0 (f 0))) ((set (f 0) (f 3)))) (run 100) -(print f) ;; f0 is expected to have value 3, but has 0 in reality. +(print-table f) ;; f0 is expected to have value 3, but has 0 in reality. (check (= (f 0) 3)) (check (= (f 1) 3)) @@ -84,7 +84,7 @@ (run 100) -(print f) +(print-table f) (check (!= 0 (f 0))) (check (!= 0 (f 1))) (check (!= 0 (f 2))) @@ -99,6 +99,6 @@ (rule ((= x (g x))) ((set (g (+ x 1)) (+ (g (- x 1)) 2)))) (run 100) -(print g) +(print-table g) (check (= 20 (g 20))) \ No newline at end of file diff --git a/tests/typecheck.egg b/tests/typecheck.egg index 13e5ad51..70d869c8 100644 --- a/tests/typecheck.egg +++ b/tests/typecheck.egg @@ -65,35 +65,35 @@ ; ---- ; lam x : unit, f : unit -> unit . f x -(define e +(let e (Lam "x" TUnit (Lam "f" (TArr TUnit TUnit) (App (Var "f") (Var "x"))))) ; lam x : unit . x -(define id (Lam "x" TUnit (Var "x"))) -(define t-id (typeof Nil id)) +(let id (Lam "x" TUnit (Var "x"))) +(let t-id (typeof Nil id)) ; (e () id) = () -(define app-unit-id (App (App e MyUnit) id)) -(define t-app (typeof Nil app-unit-id)) +(let app-unit-id (App (App e MyUnit) id)) +(let t-app (typeof Nil app-unit-id)) -(define free (Lam "x" TUnit (Var "y"))) -(define t-free-ill (typeof Nil free)) -(define t-free-1 (typeof (Cons "y" TUnit Nil) free)) -(define t-free-2 (typeof (Cons "y" (TArr (TArr TUnit TUnit) TUnit) Nil) free)) +(let free (Lam "x" TUnit (Var "y"))) +(let t-free-ill (typeof Nil free)) +(let t-free-1 (typeof (Cons "y" TUnit Nil) free)) +(let t-free-2 (typeof (Cons "y" (TArr (TArr TUnit TUnit) TUnit) Nil) free)) (run 15) -(extract t-id) +(query-extract t-id) (check (= t-id (TArr TUnit TUnit))) -(extract t-app) +(query-extract t-app) (check (= t-app TUnit)) -(extract t-free-1) +(query-extract t-free-1) (check (= t-free-1 (TArr TUnit TUnit))) -(extract t-free-2) +(query-extract t-free-2) (check (= t-free-2 (TArr TUnit (TArr (TArr TUnit TUnit) TUnit)))) ; this will err -; (extract t-free-ill) +; (query-extract t-free-ill) diff --git a/tests/typeinfer.egg b/tests/typeinfer.egg index f41472aa..d7d7e42c 100644 --- a/tests/typeinfer.egg +++ b/tests/typeinfer.egg @@ -226,86 +226,86 @@ ;;;;;;;;;;;;;;;;;;;;;; (push) -(define id (Abs (V "x") (Var (V "x")))) -(define t-id (typeof (Nil) id 0)) +(let id (Abs (V "x") (Var (V "x")))) +(let t-id (typeof (Nil) id 0)) (run 100) (check (= t-id (TArr (TVar (Fresh (V "x") 0)) (TVar (Fresh (V "x") 0))))) (pop) (push) -(define let-poly (Let (V "id") (Abs (V "x") (Var (V "x"))) +(let let-poly (Let (V "id") (Abs (V "x") (Var (V "x"))) (App (App (Var (V "id")) (Var (V "id"))) (App (Var (V "id")) (True))))) -(define t-let-poly (typeof (Nil) let-poly 0)) +(let t-let-poly (typeof (Nil) let-poly 0)) (run 100) (check (= t-let-poly (TBool))) (pop) (push) -(define id-id (App (Abs (V "x") (Var (V "x"))) +(let id-id (App (Abs (V "x") (Var (V "x"))) (Abs (V "y") (Var (V "y"))))) -(define t-id-id (typeof (Nil) id-id 0)) +(let t-id-id (typeof (Nil) id-id 0)) (run 100) (check (= t-id-id (TArr (TVar (Fresh (V "y") 3)) (TVar (Fresh (V "y") 3))))) (pop) (push) -(define let-true (Let (V "x") (True) (True))) -(define t-let-true (typeof (Nil) let-true 0)) +(let let-true (Let (V "x") (True) (True))) +(let t-let-true (typeof (Nil) let-true 0)) (run 100) (check (= t-let-true (TBool))) (pop) (push) -(define let-var-true (Let (V "x") (True) (Var (V "x")))) -(define t-let-var-true (typeof (Nil) let-var-true 0)) +(let let-var-true (Let (V "x") (True) (Var (V "x")))) +(let t-let-var-true (typeof (Nil) let-var-true 0)) (run 100) (check (= t-let-var-true (TBool))) (pop) (push) -(define abs-id (Abs (V "x") +(let abs-id (Abs (V "x") (Let (V "y") (Abs (V "z") (Var (V "z"))) (Var (V "y"))))) -(define t-abs-id (typeof (Nil) abs-id 0)) +(let t-abs-id (typeof (Nil) abs-id 0)) (run 100) -(define x (Fresh (V "x") 0)) -(define z (Fresh (Fresh (V "z") 2) 4)) +(let x (Fresh (V "x") 0)) +(let z (Fresh (Fresh (V "z") 2) 4)) (check (= t-abs-id (TArr (TVar x) (TArr (TVar z) (TVar z))))) (pop) (push) -(define let-env (Let (V "x") (True) +(let let-env (Let (V "x") (True) (Let (V "f") (Abs (V "a") (Var (V "a"))) (Let (V "x") (MyUnit) (App (Var (V "f")) (Var (V "x"))) )))) -(define t-let-env (typeof (Nil) let-env 0)) +(let t-let-env (typeof (Nil) let-env 0)) (run 100) (check (= t-let-env (TUnit))) (pop) (push) -(define let-env-2a (Let (V "x") (MyUnit) +(let let-env-2a (Let (V "x") (MyUnit) (Let (V "f") (Abs (V "y") (Var (V "x"))) (Let (V "x") (True) (App (Var (V "f")) (Var (V "x"))))))) -(define t-let-env-2a (typeof (Nil) let-env-2a 0)) +(let t-let-env-2a (typeof (Nil) let-env-2a 0)) (run 100) (check (= t-let-env-2a (TUnit))) (pop) (push) -(define let-env-2b (App (Abs (V "x") +(let let-env-2b (App (Abs (V "x") (Let (V "f") (Abs (V "y") (Var (V "x"))) (Let (V "x") (True) (App (Var (V "f")) (Var (V "x")))))) (MyUnit))) -(define t-let-env-2b (typeof (Nil) let-env-2b 0)) +(let t-let-env-2b (typeof (Nil) let-env-2b 0)) (run 100) (check (= t-let-env-2b (TUnit))) (pop) @@ -313,25 +313,24 @@ (push) ;; ((lambda (x) ((lambda (f) ((lambda (x) (f x)) #t)) (lambda (y) x))) 5) -(define let-env-hard (App (Abs (V "x") +(let let-env-hard (App (Abs (V "x") (App (Abs (V "f") (App (Abs (V "x") (App (Var (V "f")) (Var (V "x")))) (True))) (Abs (V "y") (Var (V "x"))))) (MyUnit))) -(define t-let-env-hard (typeof (Nil) let-env-hard 0)) +(let t-let-env-hard (typeof (Nil) let-env-hard 0)) (run 100) (check (= t-let-env-hard (TUnit))) (pop) (push) -(define let-inst (Let (V "id") (Abs (V "x") (Var (V "x"))) +(let let-inst (Let (V "id") (Abs (V "x") (Var (V "x"))) (Let (V "iid") (Abs (V "y") (Var (V "id"))) (App (Var (V "iid")) - (App (Var (V "id")) (True))))) - :cost 1000) -(define t-let-inst (typeof (Nil) let-inst 0) :cost 1000) + (App (Var (V "id")) (True))))) ) +(let t-let-inst (typeof (Nil) let-inst 0)) (run 100) (check (= t-let-inst (TArr (TVar (Fresh (Fresh (Fresh (V "x") 1) 5) 7)) (TVar (Fresh (Fresh (Fresh (V "x") 1) 5) 7))))) (pop) diff --git a/tests/unification-points-to.egg b/tests/unification-points-to.egg index d33438c0..95267354 100644 --- a/tests/unification-points-to.egg +++ b/tests/unification-points-to.egg @@ -239,5 +239,5 @@ (check (= (AllocVar (Expr "v")) (AllocVar (Expr "u")))) (check (!= (AllocVar (Expr "v")) (AllocVar (Expr "sp")))) -(extract :variants 100 (AllocVar (Expr "u"))) -(extract :variants 100 (AllocVar (Expr "sp"))) +(query-extract :variants 100 (AllocVar (Expr "u"))) +(query-extract :variants 100 (AllocVar (Expr "sp"))) diff --git a/tests/until.egg b/tests/until.egg index 6b93eb37..f94c9cfc 100644 --- a/tests/until.egg +++ b/tests/until.egg @@ -13,9 +13,9 @@ ; A is cyclic of period 4 (rewrite (g* A (g* A (g* A A))) I) -(define A2 (g* A A)) -(define A4 (g* A2 A2)) -(define A8 (g* A4 A4)) +(let A2 (g* A A)) +(let A4 (g* A2 A2)) +(let A8 (g* A4 A4)) ; non terminating rule (relation allgs (G))