diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..a0ed54b7 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @egraphs-good/egglog-reviewers \ No newline at end of file diff --git a/README.md b/README.md index 9637c54a..6c0ff993 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # egglog - + Web Demo Main Branch Documentation diff --git a/src/ast/desugar.rs b/src/ast/desugar.rs index 5e0b6d9d..7f7f11ba 100644 --- a/src/ast/desugar.rs +++ b/src/ast/desugar.rs @@ -125,7 +125,7 @@ fn normalize_expr( panic!("handled above"); } Expr::Call(f, children) => { - let is_compute = TypeInfo::default().is_primitive(*f); + let is_compute = desugar.type_info.is_primitive(*f); let mut new_children = vec![]; for child in children { match child { @@ -418,6 +418,7 @@ pub struct Desugar { // TODO fix getting fresh names using modules pub(crate) number_underscores: usize, pub(crate) global_variables: HashSet, + pub(crate) type_info: TypeInfo, } impl Default for Desugar { @@ -429,6 +430,7 @@ impl Default for Desugar { parser: ast::parse::ProgramParser::new(), number_underscores: 3, global_variables: Default::default(), + type_info: TypeInfo::default(), } } } @@ -689,6 +691,7 @@ impl Clone for Desugar { parser: ast::parse::ProgramParser::new(), number_underscores: self.number_underscores, global_variables: self.global_variables.clone(), + type_info: self.type_info.clone(), } } } diff --git a/src/ast/expr.rs b/src/ast/expr.rs index dc7ea582..36129091 100644 --- a/src/ast/expr.rs +++ b/src/ast/expr.rs @@ -193,3 +193,24 @@ impl Display for Expr { write!(f, "{}", self.to_sexp()) } } + +// currently only used for testing, but no reason it couldn't be used elsewhere later +#[cfg(test)] +pub(crate) fn parse_expr(s: &str) -> Result> { + let parser = ast::parse::ExprParser::new(); + parser + .parse(s) + .map_err(|e| e.map_token(|tok| tok.to_string())) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parser_display_roundtrip() { + let s = r#"(f (g a 3) 4.0 (H "hello"))"#; + let e = parse_expr(s).unwrap(); + assert_eq!(format!("{}", e), s); + } +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 8612590c..0461dd6d 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -395,7 +395,9 @@ impl ToSexp for Command { } => rule.to_sexp(*ruleset, *name), Command::RunSchedule(sched) => list!("run-schedule", sched), Command::Calc(args, exprs) => list!("calc", list!(++ args), ++ exprs), - Command::Extract { variants, fact } => list!("extract", ":variants", variants, fact), + Command::Extract { variants, fact } => { + list!("query-extract", ":variants", variants, fact) + } Command::Check(facts) => list!("check", ++ facts), Command::CheckProof => list!("check-proof"), Command::Push(n) => list!("push", n), diff --git a/src/ast/parse.lalrpop b/src/ast/parse.lalrpop index e6ecb115..f2d1b533 100644 --- a/src/ast/parse.lalrpop +++ b/src/ast/parse.lalrpop @@ -133,7 +133,7 @@ Schema: Schema = { > => Schema { input: types, output } } -Expr: Expr = { +pub Expr: Expr = { => Expr::Lit(<>), => Expr::Var(<>), => <>, diff --git a/src/extract.rs b/src/extract.rs index 5d6a9298..d98f2328 100644 --- a/src/extract.rs +++ b/src/extract.rs @@ -120,7 +120,7 @@ impl<'a> Extractor<'a> { children.push(self.find_best(*value, termdag, arcsort)?.1) } - Some(termdag.make(node.sym, children)) + Some(termdag.app(node.sym, children)) } pub fn find_best( @@ -172,7 +172,7 @@ impl<'a> Extractor<'a> { if let Some((term_inputs, new_cost)) = self.node_total_cost(func, inputs, termdag) { - let make_new_pair = || (new_cost, termdag.make(sym, term_inputs)); + let make_new_pair = || (new_cost, termdag.app(sym, term_inputs)); let id = self.find(&output.value); match self.costs.entry(id) { diff --git a/src/function/mod.rs b/src/function/mod.rs index 1e5c9f76..93c32c8c 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -69,13 +69,13 @@ impl Function { pub fn new(egraph: &EGraph, decl: &FunctionDecl) -> Result { let mut input = Vec::with_capacity(decl.schema.input.len()); for s in &decl.schema.input { - input.push(match egraph.proof_state.type_info.sorts.get(s) { + input.push(match egraph.type_info().sorts.get(s) { Some(sort) => sort.clone(), None => return Err(Error::TypeError(TypeError::Unbound(*s))), }) } - let output = match egraph.proof_state.type_info.sorts.get(&decl.schema.output) { + let output = match egraph.type_info().sorts.get(&decl.schema.output) { Some(sort) => sort.clone(), None => return Err(Error::TypeError(TypeError::Unbound(decl.schema.output))), }; diff --git a/src/lib.rs b/src/lib.rs index 6ea8e453..e4c23dc2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,6 @@ pub mod ast; mod extract; mod function; mod gj; -mod proofs; mod serialize; pub mod sort; mod termdag; @@ -12,17 +11,16 @@ mod unionfind; pub mod util; mod value; +use ast::desugar::Desugar; use extract::Extractor; use hashbrown::hash_map::Entry; use index::ColumnIndex; use instant::{Duration, Instant}; pub use serialize::SerializeConfig; use sort::*; -pub use termdag::{Term, TermDag}; +pub use termdag::{Term, TermDag, TermId}; use thiserror::Error; -use proofs::ProofState; - use symbolic_expressions::Sexp; use ast::*; @@ -73,11 +71,11 @@ pub enum ExtractReport { Best { termdag: TermDag, cost: usize, - expr: Term, + term: Term, }, Variants { termdag: TermDag, - variants: Vec, + terms: Vec, }, } @@ -207,7 +205,7 @@ impl FromStr for CompilerPassStop { pub struct EGraph { egraphs: Vec, unionfind: UnionFind, - pub(crate) proof_state: ProofState, + pub(crate) desugar: Desugar, functions: HashMap, rulesets: HashMap>, ruleset_iteration: HashMap, @@ -246,7 +244,7 @@ impl Default for EGraph { functions: Default::default(), rulesets: Default::default(), ruleset_iteration: Default::default(), - proof_state: ProofState::default(), + desugar: Desugar::default(), global_bindings: Default::default(), match_limit: usize::MAX, node_limit: usize::MAX, @@ -486,10 +484,10 @@ impl EGraph { pub fn eval_lit(&self, lit: &Literal) -> Value { match lit { - Literal::Int(i) => i.store(&self.proof_state.type_info.get_sort()).unwrap(), - Literal::F64(f) => f.store(&self.proof_state.type_info.get_sort()).unwrap(), - Literal::String(s) => s.store(&self.proof_state.type_info.get_sort()).unwrap(), - Literal::Unit => ().store(&self.proof_state.type_info.get_sort()).unwrap(), + Literal::Int(i) => i.store(&self.type_info().get_sort()).unwrap(), + Literal::F64(f) => f.store(&self.type_info().get_sort()).unwrap(), + Literal::String(s) => s.store(&self.type_info().get_sort()).unwrap(), + Literal::Unit => ().store(&self.type_info().get_sort()).unwrap(), } } @@ -528,7 +526,7 @@ impl EGraph { } else { termdag.expr_to_term(&schema.output.make_expr(self, out.value).1) }; - terms.push((termdag.make(sym, children), out)); + terms.push((termdag.app(sym, children), out)); } drop(extractor); @@ -998,7 +996,7 @@ impl EGraph { } NormAction::LetLit(var, lit) => { let value = self.eval_lit(lit); - let etype = self.proof_state.type_info.infer_literal(lit); + let etype = self.type_info().infer_literal(lit); let present = self .global_bindings .insert(*var, (etype, value, self.timestamp)); @@ -1105,9 +1103,9 @@ impl EGraph { let mut termdag = TermDag::default(); for expr in exprs { let (t, value) = self.eval_expr(&expr, None, true)?; - let expr = self.extract(value, &mut termdag, &t).1; + let term = self.extract(value, &mut termdag, &t).1; use std::io::Write; - writeln!(f, "{}", termdag.to_string(&expr)) + writeln!(f, "{}", termdag.to_string(&term)) .map_err(|e| Error::IoError(filename.clone(), e))?; } @@ -1151,7 +1149,7 @@ impl EGraph { } pub fn set_underscores_for_desugaring(&mut self, underscores: usize) { - self.proof_state.desugar.number_underscores = underscores; + self.desugar.number_underscores = underscores; } fn process_command( @@ -1159,25 +1157,23 @@ impl EGraph { command: Command, stop: CompilerPassStop, ) -> Result, Error> { - let program = self.proof_state.desugar.desugar_program( - vec![command], - self.test_proofs, - self.seminaive, - )?; + let program = + self.desugar + .desugar_program(vec![command], self.test_proofs, self.seminaive)?; if stop == CompilerPassStop::Desugar { return Ok(program); } - let type_info_before = self.proof_state.type_info.clone(); + let type_info_before = self.type_info().clone(); - self.proof_state.type_info.typecheck_program(&program)?; + self.desugar.type_info.typecheck_program(&program)?; if stop == CompilerPassStop::TypecheckDesugared { return Ok(program); } // reset type info - self.proof_state.type_info = type_info_before; - self.proof_state.type_info.typecheck_program(&program)?; + self.desugar.type_info = type_info_before; + self.desugar.type_info.typecheck_program(&program)?; if stop == CompilerPassStop::TypecheckTermEncoding { return Ok(program); } @@ -1211,11 +1207,11 @@ impl EGraph { } pub fn parse_program(&self, input: &str) -> Result, Error> { - self.proof_state.desugar.parse_program(input) + self.desugar.parse_program(input) } pub fn parse_and_run_program(&mut self, input: &str) -> Result, Error> { - let parsed = self.proof_state.desugar.parse_program(input)?; + let parsed = self.desugar.parse_program(input)?; self.run_program(parsed) } @@ -1224,11 +1220,11 @@ impl EGraph { } pub(crate) fn get_sort(&self, value: &Value) -> Option<&ArcSort> { - self.proof_state.type_info.sorts.get(&value.tag) + self.type_info().sorts.get(&value.tag) } pub fn add_arcsort(&mut self, arcsort: ArcSort) -> Result<(), TypeError> { - self.proof_state.type_info.add_arcsort(arcsort) + self.desugar.type_info.add_arcsort(arcsort) } /// Gets the last extract report and returns it, if the last command saved it. @@ -1256,6 +1252,10 @@ impl EGraph { self.msgs.dedup_by(|a, b| a.is_empty() && b.is_empty()); std::mem::take(&mut self.msgs) } + + pub(crate) fn type_info(&self) -> &TypeInfo { + &self.desugar.type_info + } } #[derive(Debug, Error)] diff --git a/src/proofs.rs b/src/proofs.rs deleted file mode 100644 index beb2b13a..00000000 --- a/src/proofs.rs +++ /dev/null @@ -1,11 +0,0 @@ -use crate::*; - -use crate::ast::desugar::Desugar; - -pub const RULE_PROOF_KEYWORD: &str = "rule-proof"; - -#[derive(Default, Clone)] -pub(crate) struct ProofState { - pub(crate) desugar: Desugar, - pub(crate) type_info: TypeInfo, -} diff --git a/src/serialize.rs b/src/serialize.rs index b4e9f886..fa75f583 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -184,7 +184,7 @@ impl EGraph { /// /// Checks for pattern created by Desugar.get_fresh fn is_temp_name(&self, name: String) -> bool { - let number_underscores = self.proof_state.desugar.number_underscores; + let number_underscores = self.desugar.number_underscores; let res = name.starts_with('v') && name.ends_with("_".repeat(number_underscores).as_str()) && name[1..name.len() - number_underscores] diff --git a/src/termdag.rs b/src/termdag.rs index 47594d13..59cf791a 100644 --- a/src/termdag.rs +++ b/src/termdag.rs @@ -4,14 +4,17 @@ use crate::{ Symbol, }; +pub type TermId = usize; + /// Like [`Expr`]s but with sharing and deduplication. /// -/// Terms refer to their children indirectly as indexes into an ambient [`TermDag`]. +/// Terms refer to their children indirectly via opaque [TermId]s (internally +/// these are just `usize`s) that map into an ambient [`TermDag`]. #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub enum Term { Lit(Literal), Var(Symbol), - App(Symbol, Vec), + App(Symbol, Vec), } /// A hashconsing arena for [`Term`]s. @@ -24,7 +27,7 @@ pub struct TermDag { // - every element of node is a key in hashcons // - every key of hashcons is in nodes pub nodes: Vec, - pub hashcons: HashMap, + pub hashcons: HashMap, } #[macro_export] @@ -56,25 +59,25 @@ impl TermDag { self.nodes.len() } - /// Convert the given term to its index. + /// Convert the given term to its id. /// /// Panics if the term does not already exist in this [TermDag]. - pub fn lookup(&self, node: &Term) -> usize { + pub fn lookup(&self, node: &Term) -> TermId { *self.hashcons.get(node).unwrap() } - /// Convert the given index to the corresponding term. + /// Convert the given id to the corresponding term. /// - /// Panics if the index is not valid. - pub fn get(&self, idx: usize) -> Term { - self.nodes[idx].clone() + /// Panics if the id is not valid. + pub fn get(&self, id: TermId) -> Term { + self.nodes[id].clone() } /// Make and return a [`Term::App`] with the given head symbol and children, /// and insert into the DAG if it is not already present. /// /// Panics if any of the children are not already in the DAG. - pub fn make(&mut self, sym: Symbol, children: Vec) -> Term { + pub fn app(&mut self, sym: Symbol, children: Vec) -> Term { let node = Term::App(sym, children.iter().map(|c| self.lookup(c)).collect()); self.add_node(&node); @@ -82,6 +85,26 @@ impl TermDag { node } + /// Make and return a [`Term::Lit`] with the given literal, and insert into + /// the DAG if it is not already present. + pub fn lit(&mut self, lit: Literal) -> Term { + let node = Term::Lit(lit); + + self.add_node(&node); + + node + } + + /// Make and return a [`Term::Var`] with the given symbol, and insert into + /// the DAG if it is not already present. + pub fn var(&mut self, sym: Symbol) -> Term { + let node = Term::Var(sym); + + self.add_node(&node); + + node + } + fn add_node(&mut self, node: &Term) { if self.hashcons.get(node).is_none() { let idx = self.nodes.len(); @@ -138,8 +161,8 @@ impl TermDag { /// /// Panics if the term or any of its subterms are not in the DAG. pub fn to_string(&self, term: &Term) -> String { - let mut stored = HashMap::::default(); - let mut seen = HashSet::::default(); + let mut stored = HashMap::::default(); + let mut seen = HashSet::::default(); let id = self.lookup(term); // use a stack to avoid stack overflow let mut stack = vec![id]; @@ -174,3 +197,78 @@ impl TermDag { stored.get(&id).unwrap().clone() } } + +#[cfg(test)] +mod tests { + use crate::ast; + + use super::*; + + fn parse_term(s: &str) -> (TermDag, Term) { + let e = crate::ast::parse_expr(s).unwrap(); + let mut td = TermDag::default(); + let t = td.expr_to_term(&e); + (td, t) + } + + #[test] + fn test_to_from_expr() { + let s = r#"(f (g x y) x y (g x y))"#; + let e = crate::ast::parse_expr(s).unwrap(); + let mut td = TermDag::default(); + assert_eq!(td.size(), 0); + let t = td.expr_to_term(&e); + assert_eq!(td.size(), 4); + // the expression above has 4 distinct subterms. + // in left-to-right, depth-first order, they are: + // x, y, (g x y), and the root call to f + // so we can compute expected answer by hand: + assert_eq!( + td.nodes, + vec![ + Term::Var("x".into()), + Term::Var("y".into()), + Term::App("g".into(), vec![0, 1]), + Term::App("f".into(), vec![2, 0, 1, 2]), + ] + ); + let e2 = td.term_to_expr(&t); + assert_eq!(e, e2); // roundtrip + } + + #[test] + fn test_match_term_app() { + let s = r#"(f (g x y) x y (g x y))"#; + let (td, t) = parse_term(s); + match_term_app!(t; { + ("f", [_, x, _, _]) => + assert_eq!(td.term_to_expr(&td.get(*x)), ast::Expr::Var(Symbol::new("x"))) + }) + } + + #[test] + fn test_to_string() { + let s = r#"(f (g x y) x y (g x y))"#; + let (td, t) = parse_term(s); + assert_eq!(td.to_string(&t), s); + } + + #[test] + fn test_lookup() { + let s = r#"(f (g x y) x y (g x y))"#; + let (td, t) = parse_term(s); + assert_eq!(td.lookup(&t), td.size() - 1); + } + + #[test] + fn test_app_var_lit() { + let s = r#"(f (g x y) x 7 (g x y))"#; + let (mut td, t) = parse_term(s); + let x = td.var("x".into()); + let y = td.var("y".into()); + let seven = td.lit(7.into()); + let g = td.app("g".into(), vec![x.clone(), y.clone()]); + let t2 = td.app("f".into(), vec![g.clone(), x, seven, g]); + assert_eq!(t, t2); + } +} diff --git a/src/typecheck.rs b/src/typecheck.rs index 5e4d91c9..9b294ea9 100644 --- a/src/typecheck.rs +++ b/src/typecheck.rs @@ -112,7 +112,7 @@ impl<'a> Context<'a> { pub fn new(egraph: &'a EGraph) -> Self { Self { egraph, - unit: egraph.proof_state.type_info.sorts[&Symbol::from(UNIT_SYM)].clone(), + unit: egraph.type_info().sorts[&Symbol::from(UNIT_SYM)].clone(), types: Default::default(), errors: Vec::default(), unionfind: UnionFind::default(), @@ -396,7 +396,7 @@ impl<'a> Context<'a> { (self.add_node(ENode::Var(*sym)), ty) } Expr::Lit(lit) => { - let t = self.egraph.proof_state.type_info.infer_literal(lit); + let t = self.egraph.type_info().infer_literal(lit); (self.add_node(ENode::Literal(lit.clone())), Some(t)) } Expr::Call(sym, args) => { @@ -415,7 +415,7 @@ impl<'a> Context<'a> { .collect(); let t = f.schema.output.clone(); (self.add_node(ENode::Func(*sym, ids)), Some(t)) - } else if let Some(prims) = self.egraph.proof_state.type_info.primitives.get(sym) { + } else if let Some(prims) = self.egraph.type_info().primitives.get(sym) { let (ids, arg_tys): (Vec, Vec>) = args.iter().map(|arg| self.infer_query_expr(arg)).unzip(); @@ -533,13 +533,7 @@ impl<'a> ExprChecker<'a> for ActionChecker<'a> { } fn do_function(&mut self, f: Symbol, _args: Vec) -> Self::T { - let func_type = self - .egraph - .proof_state - .type_info - .func_types - .get(&f) - .unwrap(); + let func_type = self.egraph.type_info().func_types.get(&f).unwrap(); self.instructions.push(Instruction::CallFunction( f, func_type.has_default || !func_type.has_merge, @@ -601,11 +595,11 @@ trait ExprChecker<'a> { match expr { Expr::Lit(lit) => { let t = self.do_lit(lit); - Ok((t, self.egraph().proof_state.type_info.infer_literal(lit))) + Ok((t, self.egraph().type_info().infer_literal(lit))) } Expr::Var(sym) => self.infer_var(*sym), Expr::Call(sym, args) => { - if let Some(functype) = self.egraph().proof_state.type_info.func_types.get(sym) { + if let Some(functype) = self.egraph().type_info().func_types.get(sym) { assert!(functype.input.len() == args.len()); let mut ts = vec![]; @@ -615,8 +609,7 @@ trait ExprChecker<'a> { let t = self.do_function(*sym, ts); Ok((t, functype.output.clone())) - } else if let Some(prims) = self.egraph().proof_state.type_info.primitives.get(sym) - { + } else if let Some(prims) = self.egraph().type_info().primitives.get(sym) { let mut ts = Vec::with_capacity(args.len()); let mut tys = Vec::with_capacity(args.len()); for arg in args { @@ -877,44 +870,37 @@ impl EGraph { let variants = values[1].bits as i64; if variants == 0 { - let (cost, expr) = self.extract( + let (cost, term) = self.extract( values[0], &mut termdag, - self.proof_state - .type_info - .sorts - .get(&values[0].tag) - .unwrap(), + self.type_info().sorts.get(&values[0].tag).unwrap(), ); - let extracted = termdag.to_string(&expr); + let extracted = termdag.to_string(&term); log::info!("extracted with cost {cost}: {}", extracted); self.print_msg(extracted); self.extract_report = Some(ExtractReport::Best { termdag, cost, - expr, + term, }); } else { if variants < 0 { panic!("Cannot extract negative number of variants"); } - let extracted = + let terms = self.extract_variants(values[0], variants as usize, &mut termdag); log::info!("extracted variants:"); let mut msg = String::default(); msg += "(\n"; - assert!(!extracted.is_empty()); - for expr in &extracted { + assert!(!terms.is_empty()); + for expr in &terms { let str = termdag.to_string(expr); log::info!(" {}", str); msg += &format!(" {}\n", str); } msg += ")"; self.print_msg(msg); - self.extract_report = Some(ExtractReport::Variants { - termdag, - variants: extracted, - }); + self.extract_report = Some(ExtractReport::Variants { termdag, terms }); } stack.truncate(new_len); diff --git a/src/typechecking.rs b/src/typechecking.rs index c8f048b6..0e1fa00c 100644 --- a/src/typechecking.rs +++ b/src/typechecking.rs @@ -1,4 +1,6 @@ -use crate::{proofs::RULE_PROOF_KEYWORD, *}; +use crate::*; + +pub const RULE_PROOF_KEYWORD: &str = "rule-proof"; #[derive(Clone, Debug)] pub struct FuncType { @@ -630,7 +632,7 @@ pub enum TypeError { #[error("Arity mismatch, expected {expected} args: {expr}")] Arity { expr: Expr, expected: usize }, #[error( - "Type mismatch: expr = {expr}, expected = {}, actual = {}, reason: {reason}", + "Type mismatch: expr = {expr}, expected = {}, actual = {}, reason: {reason}", .expected.name(), .actual.name(), )] Mismatch { diff --git a/tests/matrix.egg b/tests/matrix.egg index 4dc038e6..2131a1f4 100644 --- a/tests/matrix.egg +++ b/tests/matrix.egg @@ -6,9 +6,9 @@ (rewrite (Times (Lit i) (Lit j)) (Lit (* i j))) (rewrite (Times a b) (Times b a)) -(datatype MExpr - (MMul MExpr MExpr) - (Kron MExpr MExpr) +(datatype MExpr + (MMul MExpr MExpr) + (Kron MExpr MExpr) (NamedMat String) (Id Dim) ; DSum @@ -17,7 +17,7 @@ ; Transpose ; Inverse ; Zero Math Math - ; ScalarMul + ; ScalarMul ) ; alternative encoding (type A) = (Matrix n m) may be more useful for "large story example" @@ -45,26 +45,26 @@ (rewrite (Kron (MMul A C) (MMul B D)) (MMul (Kron A B) (Kron C D))) -(rewrite (MMul (Kron A B) (Kron C D)) +(rewrite (MMul (Kron A B) (Kron C D)) (Kron (MMul A C) (MMul B D)) - :when + :when ((= (ncols A) (nrows C)) (= (ncols B) (nrows D))) ) ; demand (rule ((= e (MMul A B))) -((let demand1 (ncols A)) -(let demand2 (nrows A)) -(let demand3 (ncols B)) -(let demand4 (nrows B))) +((ncols A) +(nrows A) +(ncols B) +(nrows B)) ) (rule ((= e (Kron A B))) -((let demand1 (ncols A)) -(let demand2 (nrows A)) -(let demand3 (ncols B)) -(let demand4 (nrows B))) +((ncols A) +(nrows A) +(ncols B) +(nrows B)) ) diff --git a/tests/terms.rs b/tests/terms.rs new file mode 100644 index 00000000..dc810a24 --- /dev/null +++ b/tests/terms.rs @@ -0,0 +1,33 @@ +use egglog::*; + +// This file tests the public API to terms. + +#[test] +fn test_termdag_public() { + let mut td = TermDag::default(); + let x = td.var("x".into()); + let seven = td.lit(7.into()); + let f = td.app("f".into(), vec![x, seven]); + assert_eq!(td.to_string(&f), "(f x 7)"); +} + +#[test] +#[should_panic] +fn test_termdag_malicious_client() { + // here is an example of how TermIds can be misused by passing + // them into the wrong DAG. + + let mut td = TermDag::default(); + let x = td.var("x".into()); + // at this point, td = [0 |-> x] + // snapshot the current td + let td2 = td.clone(); + let y = td.var("y".into()); + // now td = [0 |-> x, 1 |-> y] + let f = td.app("f".into(), vec![x.clone(), y.clone()]); + // f is Term::App("f", [0, 1]) + assert_eq!(td.to_string(&f), "(f x y)"); + // recall that td2 = [0 |-> x] + // notice that f refers to index 1, so this crashes: + td2.to_string(&f); +}