diff --git a/Cargo.lock b/Cargo.lock index 0c6b8e65..f223bb64 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -441,6 +441,7 @@ dependencies = [ "getrandom", "glob", "hashbrown 0.15.0", + "im", "im-rc", "indexmap", "instant", @@ -636,6 +637,20 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "im" +version = "15.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0acd33ff0285af998aaf9b57342af478078f53492322fafc47450e09397e0e9" +dependencies = [ + "bitmaps", + "rand_core", + "rand_xoshiro", + "sized-chunks", + "typenum", + "version_check", +] + [[package]] name = "im-rc" version = "15.1.0" diff --git a/Cargo.toml b/Cargo.toml index a8f3f0bb..60f64fd4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,6 +62,7 @@ ordered-float = { version = "3.7" } getrandom = { version = "0.2.10", features = ["js"], optional = true } im-rc = "15.1.0" +im = "15.1.0" [build-dependencies] diff --git a/src/core.rs b/src/core.rs index 5d3345ac..432d9680 100644 --- a/src/core.rs +++ b/src/core.rs @@ -107,8 +107,11 @@ impl ResolvedCall { } } } - - assert!(resolved_call.len() == 1); + assert!( + resolved_call.len() == 1, + "Ambiguous resolution for {:?}", + head, + ); resolved_call.pop().unwrap() } } diff --git a/src/sort/fn.rs b/src/sort/fn.rs index 05f42a92..7f84588f 100644 --- a/src/sort/fn.rs +++ b/src/sort/fn.rs @@ -48,8 +48,9 @@ impl Eq for ValueFunction {} #[derive(Debug)] pub struct FunctionSort { name: Symbol, - inputs: Vec, - output: ArcSort, + // Public so that other primitive sorts (external or internal) can find a function sort by the sorts of its inputs/output + pub inputs: Vec, + pub output: ArcSort, functions: Mutex>, } @@ -58,6 +59,25 @@ impl FunctionSort { let functions = self.functions.lock().unwrap(); functions.get_index(value.bits as usize).unwrap().clone() } + + /// Apply the function to the values + /// + /// Public so that other primitive sorts (external or internal) can use this to apply functions + pub fn apply(&self, fn_value: &Value, arg_values: &[Value], egraph: &mut EGraph) -> Value { + let ValueFunction(name, args) = self.get_value(fn_value); + let types: Vec<_> = args + .iter() + .map(|(sort, _)| sort.clone()) + .chain(self.inputs.clone()) + .chain(once(self.output.clone())) + .collect(); + let values = args + .iter() + .map(|(_, v)| *v) + .chain(arg_values.iter().cloned()) + .collect(); + call_fn(egraph, &name, types, values) + } } impl Presort for FunctionSort { @@ -368,21 +388,7 @@ impl PrimitiveLike for Apply { fn apply(&self, values: &[Value], egraph: Option<&mut EGraph>) -> Option { let egraph = egraph.expect("`unstable-app` is not supported yet in facts."); - let ValueFunction(name, args) = ValueFunction::load(&self.function, &values[0]); - let types: Vec<_> = args - .iter() - // get the sorts of partially applied args - .map(|(sort, _)| sort.clone()) - // combine with the args for the function call and then the output - .chain(self.function.inputs.clone()) - .chain(once(self.function.output.clone())) - .collect(); - let values = args - .iter() - .map(|(_, v)| *v) - .chain(values[1..].iter().copied()) - .collect(); - Some(call_fn(egraph, &name, types, values)) + Some(self.function.apply(&values[0], &values[1..], egraph)) } } diff --git a/src/sort/mod.rs b/src/sort/mod.rs index 9f361cc4..e43a1f43 100644 --- a/src/sort/mod.rs +++ b/src/sort/mod.rs @@ -24,6 +24,8 @@ mod vec; pub use vec::*; mod r#fn; pub use r#fn::*; +mod multiset; +pub use multiset::*; use crate::constraint::AllEqualTypeConstraint; use crate::extract::{Cost, Extractor}; diff --git a/src/sort/multiset.rs b/src/sort/multiset.rs new file mode 100644 index 00000000..f1c276e0 --- /dev/null +++ b/src/sort/multiset.rs @@ -0,0 +1,533 @@ +use std::sync::Mutex; + +use inner::MultiSet; + +use super::*; +use crate::constraint::{AllEqualTypeConstraint, SimpleTypeConstraint}; + +// Place multiset in its own module to keep implementation details private from sort +mod inner { + use im::OrdMap; + use std::hash::Hash; + /// Immutable multiset implementation, which is threadsafe and hash stable, regardless of insertion order. + /// + /// All methods that return a new multiset take ownership of the old multiset. + #[derive(Debug, Hash, Eq, PartialEq, Clone)] + pub(crate) struct MultiSet( + /// All values should be > 0 + OrdMap, + /// cached length + usize, + ); + + impl MultiSet { + /// Create a new empty multiset. + pub(crate) fn new() -> Self { + MultiSet(OrdMap::new(), 0) + } + + /// Check if the multiset contains a key. + pub(crate) fn contains(&self, value: &T) -> bool { + self.0.contains_key(value) + } + + /// Return the total number of elements in the multiset. + pub(crate) fn len(&self) -> usize { + self.1 + } + + /// Return an iterator over all elements in the multiset. + pub(crate) fn iter(&self) -> impl Iterator { + self.0 + .iter() + .flat_map(|(k, v)| std::iter::repeat(k).take(*v)) + } + + /// Return an arbitrary element from the multiset. + pub(crate) fn pick(&self) -> Option<&T> { + self.0.keys().next() + } + + /// Map a function over all elements in the multiset, taking ownership of it and returning a new multiset. + pub(crate) fn map(self, mut f: impl FnMut(&T) -> T) -> MultiSet { + let mut new = MultiSet::new(); + for (k, v) in self.0.into_iter() { + new.insert_multiple_mut(f(&k), v); + } + new + } + + /// Insert a value into the multiset, taking ownership of it and returning a new multiset. + pub(crate) fn insert(mut self, value: T) -> MultiSet { + self.insert_multiple_mut(value, 1); + self + } + + /// Remove a value from the multiset, taking ownership of it and returning a new multiset. + pub(crate) fn remove(mut self, value: &T) -> Option> { + if let Some(v) = self.0.get(value) { + self.1 -= 1; + if *v == 1 { + self.0.remove(value); + } else { + self.0.insert(value.clone(), v - 1); + } + Some(self) + } else { + None + } + } + + fn insert_multiple_mut(&mut self, value: T, n: usize) { + self.1 += n; + if let Some(v) = self.0.get(&value) { + self.0.insert(value, v + n); + } else { + self.0.insert(value, n); + } + } + + /// Create a multiset from an iterator. + pub(crate) fn from_iter(iter: impl IntoIterator) -> Self { + let mut multiset = MultiSet::new(); + for value in iter { + multiset.insert_multiple_mut(value, 1); + } + multiset + } + } +} + +type ValueMultiSet = MultiSet; + +#[derive(Debug)] +pub struct MultiSetSort { + name: Symbol, + element: ArcSort, + multisets: Mutex>, +} + +impl MultiSetSort { + pub fn element(&self) -> ArcSort { + self.element.clone() + } + + pub fn element_name(&self) -> Symbol { + self.element.name() + } +} + +impl Presort for MultiSetSort { + fn presort_name() -> Symbol { + "MultiSet".into() + } + + fn reserved_primitives() -> Vec { + vec![ + "multiset-of".into(), + "multiset-insert".into(), + "multiset-contains".into(), + "multiset-not-contains".into(), + "multiset-remove".into(), + "multiset-length".into(), + "unstable-multiset-map".into(), + ] + } + + fn make_sort( + typeinfo: &mut TypeInfo, + name: Symbol, + args: &[Expr], + ) -> Result { + if let [Expr::Var(span, e)] = args { + let e = typeinfo + .sorts + .get(e) + .ok_or(TypeError::UndefinedSort(*e, span.clone()))?; + + if e.is_eq_container_sort() { + return Err(TypeError::DisallowedSort( + name, + "Multisets nested with other EqSort containers are not allowed".into(), + span.clone(), + )); + } + + Ok(Arc::new(Self { + name, + element: e.clone(), + multisets: Default::default(), + })) + } else { + panic!() + } + } +} + +impl Sort for MultiSetSort { + fn name(&self) -> Symbol { + self.name + } + + fn as_arc_any(self: Arc) -> Arc { + self + } + + fn is_container_sort(&self) -> bool { + true + } + + fn is_eq_container_sort(&self) -> bool { + self.element.is_eq_sort() + } + + fn inner_values(&self, value: &Value) -> Vec<(ArcSort, Value)> { + let multisets = self.multisets.lock().unwrap(); + let multiset = multisets.get_index(value.bits as usize).unwrap(); + multiset + .iter() + .map(|k| (self.element.clone(), *k)) + .collect() + } + + fn canonicalize(&self, value: &mut Value, unionfind: &UnionFind) -> bool { + let multisets = self.multisets.lock().unwrap(); + let multiset = multisets.get_index(value.bits as usize).unwrap().clone(); + let mut changed = false; + let new_multiset = multiset.map(|e| { + let mut e = *e; + changed |= self.element.canonicalize(&mut e, unionfind); + e + }); + drop(multisets); + *value = new_multiset.store(self).unwrap(); + changed + } + + fn register_primitives(self: Arc, typeinfo: &mut TypeInfo) { + typeinfo.add_primitive(MultiSetOf { + name: "multiset-of".into(), + multiset: self.clone(), + }); + typeinfo.add_primitive(Insert { + name: "multiset-insert".into(), + multiset: self.clone(), + }); + typeinfo.add_primitive(Contains { + name: "multiset-contains".into(), + multiset: self.clone(), + }); + typeinfo.add_primitive(NotContains { + name: "multiset-not-contains".into(), + multiset: self.clone(), + }); + typeinfo.add_primitive(Remove { + name: "multiset-remove".into(), + multiset: self.clone(), + }); + typeinfo.add_primitive(Length { + name: "multiset-length".into(), + multiset: self.clone(), + }); + typeinfo.add_primitive(Pick { + name: "multiset-pick".into(), + multiset: self.clone(), + }); + let inner_name = self.element.name(); + let fn_sort = typeinfo.get_sort_by(|s: &Arc| { + (s.output.name() == inner_name) + && s.inputs.len() == 1 + && (s.inputs[0].name() == inner_name) + }); + // Only include map function if we already declared a function sort with the correct signature + if let Some(fn_sort) = fn_sort { + typeinfo.add_primitive(Map { + name: "unstable-multiset-map".into(), + multiset: self.clone(), + fn_: fn_sort, + }); + } + } + + fn make_expr(&self, egraph: &EGraph, value: Value) -> (Cost, Expr) { + let mut termdag = TermDag::default(); + let extractor = Extractor::new(egraph, &mut termdag); + self.extract_expr(egraph, value, &extractor, &mut termdag) + .expect("Extraction should be successful since extractor has been fully initialized") + } + + fn extract_expr( + &self, + _egraph: &EGraph, + value: Value, + extractor: &Extractor, + termdag: &mut TermDag, + ) -> Option<(Cost, Expr)> { + let multiset = ValueMultiSet::load(self, &value); + let mut children = vec![]; + let mut cost = 0usize; + for e in multiset.iter() { + let (child_cost, child_term) = extractor.find_best(*e, termdag, &self.element)?; + cost = cost.saturating_add(child_cost); + children.push(termdag.term_to_expr(&child_term)); + } + let expr = Expr::call_no_span("multiset-of", children); + Some((cost, expr)) + } + + fn serialized_name(&self, _value: &Value) -> Symbol { + "multiset-of".into() + } +} + +impl IntoSort for ValueMultiSet { + type Sort = MultiSetSort; + fn store(self, sort: &Self::Sort) -> Option { + let mut multisets = sort.multisets.lock().unwrap(); + let (i, _) = multisets.insert_full(self); + Some(Value { + tag: sort.name, + bits: i as u64, + }) + } +} + +impl FromSort for ValueMultiSet { + type Sort = MultiSetSort; + fn load(sort: &Self::Sort, value: &Value) -> Self { + let sets = sort.multisets.lock().unwrap(); + sets.get_index(value.bits as usize).unwrap().clone() + } +} + +struct MultiSetOf { + name: Symbol, + multiset: Arc, +} + +impl PrimitiveLike for MultiSetOf { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + AllEqualTypeConstraint::new(self.name(), span.clone()) + .with_all_arguments_sort(self.multiset.element()) + .with_output_sort(self.multiset.clone()) + .into_box() + } + + fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + let multiset = MultiSet::from_iter(values.iter().copied()); + Some(multiset.store(&self.multiset).unwrap()) + } +} + +struct Insert { + name: Symbol, + multiset: Arc, +} + +impl PrimitiveLike for Insert { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + SimpleTypeConstraint::new( + self.name(), + vec![ + self.multiset.clone(), + self.multiset.element(), + self.multiset.clone(), + ], + span.clone(), + ) + .into_box() + } + + fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + let multiset = ValueMultiSet::load(&self.multiset, &values[0]); + let multiset = multiset.insert(values[1]); + multiset.store(&self.multiset) + } +} + +struct Contains { + name: Symbol, + multiset: Arc, +} + +impl PrimitiveLike for Contains { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + SimpleTypeConstraint::new( + self.name(), + vec![ + self.multiset.clone(), + self.multiset.element(), + Arc::new(UnitSort), + ], + span.clone(), + ) + .into_box() + } + + fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + let multiset = ValueMultiSet::load(&self.multiset, &values[0]); + if multiset.contains(&values[1]) { + Some(Value::unit()) + } else { + None + } + } +} + +struct NotContains { + name: Symbol, + multiset: Arc, +} + +impl PrimitiveLike for NotContains { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + SimpleTypeConstraint::new( + self.name(), + vec![ + self.multiset.clone(), + self.multiset.element(), + Arc::new(UnitSort), + ], + span.clone(), + ) + .into_box() + } + + fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + let multiset = ValueMultiSet::load(&self.multiset, &values[0]); + if !multiset.contains(&values[1]) { + Some(Value::unit()) + } else { + None + } + } +} + +struct Length { + name: Symbol, + multiset: Arc, +} + +impl PrimitiveLike for Length { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + SimpleTypeConstraint::new( + self.name(), + vec![self.multiset.clone(), Arc::new(I64Sort)], + span.clone(), + ) + .into_box() + } + + fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + let multiset = ValueMultiSet::load(&self.multiset, &values[0]); + Some(Value::from(multiset.len() as i64)) + } +} + +struct Remove { + name: Symbol, + multiset: Arc, +} + +impl PrimitiveLike for Remove { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + SimpleTypeConstraint::new( + self.name(), + vec![ + self.multiset.clone(), + self.multiset.element(), + self.multiset.clone(), + ], + span.clone(), + ) + .into_box() + } + + fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + let multiset = ValueMultiSet::load(&self.multiset, &values[0]); + let multiset = multiset.remove(&values[1]); + multiset.store(&self.multiset) + } +} + +struct Pick { + name: Symbol, + multiset: Arc, +} + +impl PrimitiveLike for Pick { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + SimpleTypeConstraint::new( + self.name(), + vec![self.multiset.clone(), self.multiset.element()], + span.clone(), + ) + .into_box() + } + + fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + let multiset = ValueMultiSet::load(&self.multiset, &values[0]); + Some(*multiset.pick().expect("Cannot pick from an empty multiset")) + } +} + +struct Map { + name: Symbol, + multiset: Arc, + fn_: Arc, +} + +impl PrimitiveLike for Map { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + SimpleTypeConstraint::new( + self.name(), + vec![ + self.fn_.clone(), + self.multiset.clone(), + self.multiset.clone(), + ], + span.clone(), + ) + .into_box() + } + + fn apply(&self, values: &[Value], egraph: Option<&mut EGraph>) -> Option { + let egraph = + egraph.unwrap_or_else(|| panic!("`{}` is not supported yet in facts.", self.name)); + let multiset = ValueMultiSet::load(&self.multiset, &values[1]); + let new_multiset = multiset.map(|e| self.fn_.apply(&values[0], &[*e], egraph)); + new_multiset.store(&self.multiset) + } +} diff --git a/src/typechecking.rs b/src/typechecking.rs index 80f0e464..3696abb0 100644 --- a/src/typechecking.rs +++ b/src/typechecking.rs @@ -47,6 +47,7 @@ impl Default for TypeInfo { res.add_presort::(DUMMY_SPAN.clone()).unwrap(); res.add_presort::(DUMMY_SPAN.clone()).unwrap(); res.add_presort::(DUMMY_SPAN.clone()).unwrap(); + res.add_presort::(DUMMY_SPAN.clone()).unwrap(); res.add_primitive(ValueEq); diff --git a/tests/eqsat-basic-multiset.egg b/tests/eqsat-basic-multiset.egg new file mode 100644 index 00000000..e4a9fa0c --- /dev/null +++ b/tests/eqsat-basic-multiset.egg @@ -0,0 +1,124 @@ +;; Example showing how to use multisets to hold associative & commutative operations + +(datatype* + (Math + (Num i64) + (Var String) + (Add Math Math) + (Mul Math Math) + (Product MathMultiSet) + (Sum MathMultiSet)) + (sort MathToMath (UnstableFn (Math) Math)) + (sort MathMultiSet (MultiSet Math))) + +;; expr1 = 2 * (x + 3) +(let expr1 (Mul (Num 2) (Add (Var "x") (Num 3)))) +;; expr2 = 6 + 2 * x +(let expr2 (Add (Num 6) (Mul (Num 2) (Var "x")))) + +(rewrite (Add a b) (Sum (multiset-of a b))) +(rewrite (Mul a b) (Product (multiset-of a b))) + +;; 0 or 1 elements sums/products also can be extracted back to numbers +(rule + ( + (= sum (Sum sum-inner)) + (= 0 (multiset-length sum-inner)) + ) + ((union sum (Num 0))) +) +(rule + ( + (= sum (Sum sum-inner)) + (= 1 (multiset-length sum-inner)) + ) + ((union sum (multiset-pick sum-inner))) +) + +(rule + ( + (= product (Product product-inner)) + (= 0 (multiset-length product-inner)) + ) + ((union product (Num 1))) +) +(rule + ( + (= product (Product product-inner)) + (= 1 (multiset-length product-inner)) + ) + ((union product (multiset-pick product-inner))) +) + +; (rewrite (Mul a (Add b c)) +; (Add (Mul a b) (Mul a c))) + +; -> we would like to write it like this, but cannot (yet) bc we can't match on the inner structure of the multisets +; and we don't support anonymous functions + +; (rewrite (Product (multiset-insert a (Sum bc))) +; (Sum (multiset-map (lambda (x) (Product (multiset-insert a x))) bc))) + + +;; so instead we can define a function and partially apply it to get the same function as the lambda +(function tmp-fn (MathMultiSet Math) Math) +(rewrite (tmp-fn xs x) (Product (multiset-insert xs x))) + +(rule + ( + ;; and we can do a cross product search of all possible pairs of products/sums to find one we want + (= sum (Sum bc)) + (= product (Product product-inner)) + (multiset-contains product-inner sum) + (> (multiset-length product-inner) 1) + (= a (multiset-remove product-inner sum)) + ) + ( + (union product (Sum + (unstable-multiset-map + (unstable-fn "tmp-fn" a) + bc) + )) + ) +) + +; (rewrite (Add (Num a) (Num b)) +; (Num (+ a b))) + +(rule + ( + (= sum (Sum sum-inner)) + (= num-a (Num a)) + (multiset-contains sum-inner num-a) + (= without-a (multiset-remove sum-inner num-a)) + (= num-b (Num b)) + (multiset-contains without-a num-b) + ) + ( + (union sum + (Sum (multiset-insert (multiset-remove without-a num-b) (Num (+ a b)))) + ) + ) +) + +; (rewrite (Mul (Num a) (Num b)) +; (Num (* a b))) + +(rule + ( + (= product (Product product-inner)) + (= num-a (Num a)) + (multiset-contains product-inner num-a) + (= without-a (multiset-remove product-inner num-a)) + (= num-b (Num b)) + (multiset-contains without-a num-b) + ) + ( + (union product + (Product (multiset-insert (multiset-remove without-a num-b) (Num (* a b)))) + ) + ) +) + +(run 100) +(check (= expr1 expr2)) diff --git a/tests/multiset.egg b/tests/multiset.egg new file mode 100644 index 00000000..795281cd --- /dev/null +++ b/tests/multiset.egg @@ -0,0 +1,58 @@ +(datatype Math (Num i64)) +(sort MathToMath (UnstableFn (Math) Math)) +(sort Maths (MultiSet Math)) + +(let xs (multiset-of (Num 1) (Num 2) (Num 3))) + +;; verify equal to other ordering +(check (= + (multiset-of (Num 3) (Num 2) (Num 1)) + xs +)) + +;; verify not equal to different counts +(check (!= + (multiset-of (Num 3) (Num 2) (Num 1) (Num 1)) + xs +)) + +;; Unclear why check won't work if this is defined inline +(let inserted (multiset-insert xs (Num 4))) +;; insert +(check (= + (multiset-of (Num 1) (Num 2) (Num 3) (Num 4)) + inserted +)) + + +;; contains and not contains +(check (multiset-contains xs (Num 1))) +(check (multiset-not-contains xs (Num 4))) + +;; remove last +(check (= + (multiset-of (Num 1) (Num 3)) + (multiset-remove xs (Num 2)) +)) +;; remove one of +(check (= (multiset-of (Num 1)) (multiset-remove (multiset-of (Num 1) (Num 1)) (Num 1)))) + + +;; length +(check (= 3 (multiset-length xs))) +;; length repeated +(check (= 3 (multiset-length (multiset-of (Num 1) (Num 1) (Num 1))))) + +;; pick +(check (= (Num 1) (multiset-pick (multiset-of (Num 1))))) + +;; map +(function square (Math) Math) +(rewrite (square (Num x)) (Num (* x x))) + +(let squared-xs (unstable-multiset-map (unstable-fn "square") xs)) +(run 1) +(check (= + (multiset-of (Num 1) (Num 4) (Num 9)) + squared-xs +))