diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 1f98a95ff..bf6abcc8c 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -953,8 +953,7 @@ impl NormRule { } pub fn resugar_actions(&self, subst: &mut HashMap) -> Vec { - // TODO doesn't work because re-ordering actions can be bad - /*let mut used = HashSet::::default(); + let mut used = HashSet::::default(); let mut head = Vec::::default(); for a in &self.head { match a { @@ -1046,8 +1045,8 @@ impl NormRule { Expr::Call(..) => head.push(Action::Expr(expr.clone())), }; } - }*/ - self.head.iter().map(|a| a.to_action()).collect() + } + head } pub fn resugar(&self) -> Rule { diff --git a/src/gj.rs b/src/gj.rs index 3b2cdc562..dec544b9a 100644 --- a/src/gj.rs +++ b/src/gj.rs @@ -5,7 +5,7 @@ use smallvec::SmallVec; use crate::{ function::index::Offset, - typecheck::{Atom, AtomTerm, Filter, Query}, + typecheck::{Atom, AtomTerm, Query}, *, }; use std::{ @@ -14,14 +14,12 @@ use std::{ ops::Range, }; -#[derive(Clone)] enum Instr<'a> { Intersect { value_idx: usize, variable_name: Symbol, info: VarInfo2, trie_accesses: Vec<(usize, TrieAccess<'a>)>, - check: bool, // check or assign to output variable }, ConstrainConstant { index: usize, @@ -29,7 +27,7 @@ enum Instr<'a> { trie_access: TrieAccess<'a>, }, Call { - prim: Filter, + prim: Primitive, args: Vec, check: bool, // check or assign to output variable }, @@ -70,46 +68,37 @@ type Result = std::result::Result<(), ()>; struct Program<'a>(Vec>); -impl<'a> std::fmt::Display for Instr<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Instr::Intersect { - value_idx, - trie_accesses, - variable_name, - info, - check, - } => { - let name = if *check { "Check " } else { "Intersect" }; - write!( - f, - " {name} @ {value_idx} sg={sg:6} {variable_name:15}", - sg = info.size_guess - )?; - for (trie_idx, a) in trie_accesses { - write!(f, " {}: {}", trie_idx, a)?; - } - writeln!(f)? - } - Instr::ConstrainConstant { - index, - val, - trie_access, - } => { - writeln!(f, " ConstrainConstant {index} {trie_access} = {val:?}")?; - } - Instr::Call { prim, args, check } => { - writeln!(f, " Call {:?} {:?} {:?}", prim, args, check)?; - } - } - Ok(()) - } -} - impl<'a> std::fmt::Display for Program<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { for instr in &self.0 { - write!(f, "{}", instr)?; + match instr { + Instr::Intersect { + value_idx, + trie_accesses, + variable_name, + info, + } => { + write!( + f, + " Intersect @ {value_idx} sg={sg:6} {variable_name:15}", + sg = info.size_guess + )?; + for (trie_idx, a) in trie_accesses { + write!(f, " {}: {}", trie_idx, a)?; + } + writeln!(f)? + } + Instr::ConstrainConstant { + index, + val, + trie_access, + } => { + writeln!(f, " ConstrainConstant {index} {trie_access} = {val:?}")?; + } + Instr::Call { prim, args, check } => { + writeln!(f, " Call {:?} {:?} {:?}", prim, args, check)?; + } + } } Ok(()) } @@ -175,7 +164,6 @@ impl<'b> Context<'b> { Instr::Intersect { value_idx, trie_accesses, - check, .. } => { if let Some(x) = trie_accesses @@ -185,20 +173,6 @@ impl<'b> Context<'b> { { stage.add_measurement(x); } - - if *check { - let mut new_tries = tries.to_vec(); - let value = self.tuple[*value_idx]; - for (j, access) in trie_accesses { - let Some(t) = tries[*j].get(access, value) else { - return Ok(()); - }; - new_tries[*j] = t; - } - return self.eval(&mut new_tries, program, stage.next(), f); - } - - assert!(!*check); match trie_accesses.as_slice() { [(j, access)] => tries[*j].for_each(access, |value, trie| { let old_trie = std::mem::replace(&mut tries[*j], trie); @@ -265,27 +239,13 @@ impl<'b> Context<'b> { }) } - let result = match prim { - Filter::Primitive(p) => p.apply(&values, self.egraph), - Filter::Function(f) => { - let func_table = self.egraph.functions.get(f).unwrap(); - // TODO check the timestamp - func_table.nodes.get(&values).map(|output| output.value) - } - }; - - if let Some(res) = result { + if let Some(res) = prim.apply(&values, self.egraph) { match out { AtomTerm::Var(v) => { let i = self.query.vars.get_index_of(v).unwrap(); - - if *check { - assert_ne!(self.tuple[i], Value::fake()); - if self.tuple[i] != res { - return Ok(()); - } + if *check && self.tuple[i] != res { + return Ok(()); } - self.tuple[i] = res; } AtomTerm::Value(val) => { @@ -304,7 +264,6 @@ impl<'b> Context<'b> { } } -#[derive(Clone)] enum Constraint { Eq(usize, usize), Const(usize, Value), @@ -353,31 +312,19 @@ impl EGraph { vars.entry(*var).or_default(); } - let mut atom_filters = vec![]; for (i, atom) in query.atoms.iter().enumerate() { - if atom.head.as_str().contains("_Parent_") { - atom_filters.push(atom.clone()); - } else { - for v in atom.vars() { - // only count grounded occurrences - vars.entry(v).or_default().occurences.push(i) - } + for v in atom.vars() { + // only count grounded occurrences + vars.entry(v).or_default().occurences.push(i) } } - // make sure everyone has an entry in the vars table for prim in &query.filters { for v in prim.vars() { vars.entry(v).or_default(); } } - for atom_filter in atom_filters { - for v in atom_filter.vars() { - vars.entry(v).or_default(); - } - } - CompiledQuery { query, vars } } @@ -539,13 +486,12 @@ impl EGraph { (atom_idx, access) }) .collect(), - check: false, } }); program.extend(var_instrs); // now we can try to add primitives - let mut calls = vec![]; + // TODO this is very inefficient, since primitives all at the end let mut extra = query.query.filters.clone(); while !extra.is_empty() { let next = extra.iter().position(|p| { @@ -568,137 +514,13 @@ impl EGraph { }, AtomTerm::Value(_) => true, }; - calls.push(Instr::Call { + program.push(Instr::Call { prim: p.head.clone(), args: p.args.clone(), check, }); } else { - panic!("cycle {:#?}", query) - } - } - - if true { - // now we have to actually place them in the program, as high as they can go. - // note, the calls should be topo sorted already at this point - 'call_loop: for mut call in calls { - let mut bound_symbols: HashSet = Default::default(); - let arg_symbols: HashSet; - let last_arg: Option; - let Instr::Call { - args, - check, - .. - } = &mut call else { - panic!("Should be a call at this point"); - }; - - if let Some((last, args)) = args.split_last() { - arg_symbols = args - .iter() - .filter_map(|a| match a { - AtomTerm::Var(v) => Some(*v), - AtomTerm::Value(_) => None, - }) - .collect(); - if let AtomTerm::Var(v) = last { - last_arg = Some(*v); - } else { - last_arg = None - } - } else { - panic!("Zero-arg primitive not supported"); - } - - for (position, instr) in program.iter().enumerate() { - match instr { - Instr::Intersect { variable_name, .. } => { - bound_symbols.insert(*variable_name); - } - Instr::Call { args, .. } => { - if let Some(AtomTerm::Var(v)) = args.last() { - bound_symbols.insert(*v); - } - } - _ => (), - } - if arg_symbols.is_subset(&bound_symbols) { - program.insert(position + 1, call); - continue 'call_loop; - } - } - *check = if let Some(last_var) = last_arg { - bound_symbols.contains(&last_var) - } else { - true - }; - program.push(call); - } - } else { - program.extend(calls); - } - - // now some intersections might already have been bound - // we need to replace these with constrain constant - let mut bound_symbols = HashSet::default(); - for instr in &mut program { - match instr { - Instr::Intersect { - variable_name, - check, - .. - } => { - *check = !bound_symbols.insert(*variable_name); - } - Instr::Call { args, check, .. } => { - if let Some(AtomTerm::Var(variable_name)) = args.last() { - *check = !bound_symbols.insert(*variable_name); - } else { - *check = true - } - } - _ => (), - } - } - - // sanity check the program - let mut tuple_valid = vec![false; query.vars.len()]; - for instr in &program { - match instr { - Instr::Intersect { - value_idx, check, .. - } => { - assert_eq!(*check, tuple_valid[*value_idx]); - if !*check { - tuple_valid[*value_idx] = true; - } - } - Instr::ConstrainConstant { .. } => {} - Instr::Call { check, args, .. } => { - let Some((last, args)) = args.split_last() else { - continue - }; - - for a in args { - if let AtomTerm::Var(v) = a { - let i = query.vars.get_index_of(v).unwrap(); - assert!(tuple_valid[i]); - } - } - - match last { - AtomTerm::Var(v) => { - let i = query.vars.get_index_of(v).unwrap(); - assert_eq!(*check, tuple_valid[i], "{instr}"); - if !*check { - tuple_valid[i] = true; - } - } - AtomTerm::Value(_) => { - assert!(*check); - } - } - } + panic!("cycle") } } @@ -716,7 +538,7 @@ impl EGraph { let has_atoms = !cq.query.atoms.is_empty(); if has_atoms { - let do_seminaive = false; + let do_seminaive = self.seminaive; // for the later atoms, we consider everything let mut timestamp_ranges = vec![0..u32::MAX; cq.query.atoms.len()]; for (atom_i, atom) in cq.query.atoms.iter().enumerate() { @@ -919,7 +741,6 @@ impl LazyTrie { } } -#[derive(Clone)] struct TrieAccess<'a> { function: &'a Function, timestamp_range: Range, diff --git a/src/lib.rs b/src/lib.rs index 39f6786d2..00dbdaa70 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -633,6 +633,7 @@ impl EGraph { let _ = self.run_actions(stack, &[], &rule.program, true); } else { for values in all_values.chunks(num_vars) { + eprintln!("running actions for rule {} with values {:?}", name, values); rule.matches += 1; // we can ignore results here stack.clear(); @@ -646,6 +647,7 @@ impl EGraph { let apply_elapsed = apply_start.elapsed(); report.apply_time += apply_elapsed; report.updated |= self.did_change(); + eprintln!("updated: {}", report.updated); report } diff --git a/src/typecheck.rs b/src/typecheck.rs index f104ad032..7e64105af 100644 --- a/src/typecheck.rs +++ b/src/typecheck.rs @@ -49,13 +49,7 @@ impl std::fmt::Display for Atom { #[derive(Default, Debug, Clone)] pub struct Query { pub atoms: Vec>, - pub filters: Vec>, -} - -#[derive(Debug, Clone)] -pub enum Filter { - Primitive(Primitive), - Function(Symbol), + pub filters: Vec>, } impl std::fmt::Display for Query { @@ -66,14 +60,12 @@ impl std::fmt::Display for Query { if !self.filters.is_empty() { writeln!(f, "where ")?; for filter in &self.filters { - match &filter.head { - Filter::Primitive(p) => { - writeln!(f, "({} {})", p.name(), ListDisplay(&filter.args, " "))?; - } - Filter::Function(fun) => { - writeln!(f, "({fun} {})", ListDisplay(&filter.args, " "))?; - } - } + writeln!( + f, + "({} {})", + filter.head.name(), + ListDisplay(&filter.args, " ") + )?; } } Ok(()) @@ -219,25 +211,6 @@ impl<'a> Context<'a> { // Now we can fill in the nodes with the canonical leaves for (node, id) in &self.nodes { match node { - // ENode::Func(f, ids) if f.as_str().contains("_Parent_") => { - // let mut args = vec![]; - // for child in ids { - // let leaf = get_leaf(child); - // if let AtomTerm::Var(v) = leaf { - // if self.egraph.global_bindings.contains_key(&v) { - // args.push(AtomTerm::Value(self.egraph.global_bindings[&v].1)); - // continue; - // } - // } - // args.push(get_leaf(child)); - // query_eclasses.insert(*child); - // } - // args.push(get_leaf(id)); - // query.filters.push(Atom { - // head: Filter::Function(*f), - // args, - // }); - // } ENode::Func(f, ids) => { let args = ids.iter().chain([id]).map(get_leaf).collect(); for id in ids { @@ -260,7 +233,7 @@ impl<'a> Context<'a> { } args.push(get_leaf(id)); query.filters.push(Atom { - head: Filter::Primitive(p.clone()), + head: p.clone(), args, }); } @@ -282,7 +255,7 @@ impl<'a> Context<'a> { // we actually know the query won't fire if canon_value != *value { query.filters.push(Atom { - head: Filter::Primitive(Primitive(Arc::new(ValueEq {}))), + head: Primitive(Arc::new(ValueEq {})), args: vec![ AtomTerm::Value(canon_value), AtomTerm::Value(*value), @@ -294,22 +267,6 @@ impl<'a> Context<'a> { } } - if query.atoms.len() > 2 { - // move the parent "atoms" to the filters - query.atoms.retain(|atom| { - let f = atom.head; - if f.as_str().contains("_Parent_") { - query.filters.push(Atom { - head: Filter::Function(f), - args: atom.args.clone(), - }); - false - } else { - true - } - }); - } - if self.errors.is_empty() { Ok((query, res_actions)) } else { @@ -865,6 +822,7 @@ impl EGraph { let old_value = function.get(args); if let Some(old_value) = old_value { + eprintln!("old value: {:?}, new value: {:?}", old_value, new_value); if new_value != old_value { let merged: Value = match function.merge.merge_vals.clone() { MergeFn::AssertEq => { @@ -883,20 +841,20 @@ impl EGraph { } }; if merged != old_value { - let args = &stack[new_len..]; - let function = self.functions.get_mut(f).unwrap(); - function.insert(args, merged, self.timestamp); - } - // re-borrow + let args = &stack[new_len..]; let function = self.functions.get_mut(f).unwrap(); - if let Some(prog) = function.merge.on_merge.clone() { - let values = [old_value, new_value]; - // XXX: we get an error if we pass the current - // stack and then truncate it to the old length. - // Why? - self.run_actions(&mut Vec::new(), &values, &prog, true)?; + function.insert(args, merged, self.timestamp); + } + // re-borrow + let function = self.functions.get_mut(f).unwrap(); + if let Some(prog) = function.merge.on_merge.clone() { + let values = [old_value, new_value]; + // XXX: we get an error if we pass the current + // stack and then truncate it to the old length. + // Why? + self.run_actions(&mut Vec::new(), &values, &prog, true)?; + } } - } } else { function.insert(args, new_value, self.timestamp); } diff --git a/tests/files.rs b/tests/files.rs index 3ccc8d26e..35e0ecceb 100644 --- a/tests/files.rs +++ b/tests/files.rs @@ -41,7 +41,10 @@ impl Run { self.test_program( &desugared_str, - "ERROR after parse, to_string, and parse again.", + &format!( + "Program:\n{}\n ERROR after parse, to_string, and parse again.", + desugared_str + ), ); } } diff --git a/tests/levenshtein-distance.egg b/tests/levenshtein-distance.egg deleted file mode 100644 index cd45a8bbc..000000000 --- a/tests/levenshtein-distance.egg +++ /dev/null @@ -1,70 +0,0 @@ -; Datatypes - -(datatype expr - (Num i64) - (Add expr expr) - (Min expr expr expr)) -(rewrite (Add (Num a) (Num b)) (Num (+ a b))) -(rewrite (Min (Num a) (Num b) (Num c)) (Num (min (min a b) c))) - -; `String` supports limited operations, let's just use it as a char type -(datatype str - (Cons String str)) -(declare Empty str) - -; Length function - -(function Length (str) expr) - -(rule ((= f (Length Empty))) - ((set (Length Empty) (Num 0)))) - -(rule ((= f (Length (Cons c cs)))) - ((set (Length (Cons c cs)) (Add (Num 1) (Length cs))))) - -; EditDist function - -(function EditDist (str str) expr) - -(rule ((= f (EditDist Empty s))) - ((set (EditDist Empty s) (Length s)))) - -(rule ((= f (EditDist s Empty))) - ((set (EditDist s Empty) (Length s)))) - -(rule ((= f (EditDist (Cons head rest1) (Cons head rest2)))) - ((set (EditDist (Cons head rest1) (Cons head rest2)) - (EditDist rest1 rest2)))) - -(rule ((= f (EditDist (Cons head1 rest1) (Cons head2 rest2))) (!= head1 head2)) - ((set (EditDist (Cons head1 rest1) (Cons head2 rest2)) - (Add (Num 1) - (Min (EditDist rest1 rest2) - (EditDist (Cons head1 rest1) rest2) - (EditDist rest1 (Cons head2 rest2))))))) - -; Unwrap function - turn a (Num n) into n - -(function Unwrap (expr) i64) -(rule ((= x (Num n))) ((set (Unwrap (Num n)) n))) - -; Tests -(let HorseStr (Cons "h" (Cons "o" (Cons "r" (Cons "s" (Cons "e" Empty)))))) -(let RosStr (Cons "r" (Cons "o" (Cons "s" Empty)))) -(let IntentionStr (Cons "i" (Cons "n" (Cons "t" (Cons "e" (Cons "n" (Cons "t" (Cons "i" (Cons "o" (Cons "n" Empty)))))))))) -(let ExecutionStr (Cons "e" (Cons "x" (Cons "e" (Cons "c" (Cons "u" (Cons "t" (Cons "i" (Cons "o" (Cons "n" Empty)))))))))) - -(let Test1 (EditDist HorseStr RosStr)) -(let Test2 (EditDist IntentionStr ExecutionStr)) -(let Test3 (EditDist HorseStr Empty)) - -(run 100) - -(extract (Unwrap Test1)) -(check (= Test1 (Num 3))) - -(extract (Unwrap Test2)) -(check (= Test2 (Num 5))) - -(extract (Unwrap Test3)) -(check (= Test3 (Num 5))) \ No newline at end of file