From 23cfd61b65cb47e803364056d1894c4dfbe0a91e Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Fri, 18 Oct 2024 15:50:46 -0400 Subject: [PATCH 1/6] Add multisets This PR adds a multiset sort. It is based on a data structure that implements functional sharing. Using that sort, an example is added to show how you can use it to express associative & commutative operations like addition and multiplication with multisets, so that their canonical forms don't need to be re-encoded for every ordering. See [these](https://egraphs.zulipchat.com/#narrow/channel/328972-general/topic/Reducing.20node.20explosion.20through.20algebraic.20representations.3F) [threads](https://egraphs.zulipchat.com/#narrow/channel/328972-general/topic/Linear.20and.20Polynomial.20Equations) on zulip for some more background. --- Cargo.lock | 25 ++ Cargo.toml | 1 + src/core.rs | 7 +- src/lib.rs | 1 + src/multiset.rs | 86 +++++++ src/sort/fn.rs | 39 +-- src/sort/mod.rs | 2 + src/sort/multiset.rs | 440 +++++++++++++++++++++++++++++++++ src/typechecking.rs | 1 + tests/eqsat-basic-multiset.egg | 124 ++++++++++ tests/multiset.egg | 58 +++++ 11 files changed, 765 insertions(+), 19 deletions(-) create mode 100644 src/multiset.rs create mode 100644 src/sort/multiset.rs create mode 100644 tests/eqsat-basic-multiset.egg create mode 100644 tests/multiset.egg diff --git a/Cargo.lock b/Cargo.lock index bcfa1b29..a3a0d013 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -94,6 +94,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "archery" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eae2ed21cd55021f05707a807a5fc85695dafb98832921f6cfa06db67ca5b869" +dependencies = [ + "triomphe", +] + [[package]] name = "ascii-canvas" version = "3.0.0" @@ -466,6 +475,7 @@ dependencies = [ "num-traits", "ordered-float", "regex", + "rpds", "rustc-hash", "serde_json", "smallvec", @@ -1187,6 +1197,15 @@ version = "0.8.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +[[package]] +name = "rpds" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0e15515d3ce3313324d842629ea4905c25a13f81953eadb88f85516f59290a4" +dependencies = [ + "archery", +] + [[package]] name = "rustc-hash" version = "1.1.0" @@ -1431,6 +1450,12 @@ dependencies = [ "serde_json", ] +[[package]] +name = "triomphe" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef8f7726da4807b58ea5c96fdc122f80702030edc33b35aff9190a51148ccc85" + [[package]] name = "typenum" version = "1.17.0" diff --git a/Cargo.toml b/Cargo.toml index c1664496..364fb25a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,6 +62,7 @@ ordered-float = { version = "3.7" } getrandom = { version = "0.2.10", features = ["js"], optional = true } im-rc = "15.1.0" +rpds = "1.1.0" [build-dependencies] diff --git a/src/core.rs b/src/core.rs index 5d3345ac..432d9680 100644 --- a/src/core.rs +++ b/src/core.rs @@ -107,8 +107,11 @@ impl ResolvedCall { } } } - - assert!(resolved_call.len() == 1); + assert!( + resolved_call.len() == 1, + "Ambiguous resolution for {:?}", + head, + ); resolved_call.pop().unwrap() } } diff --git a/src/lib.rs b/src/lib.rs index c5e477c6..34c7f097 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,7 @@ mod core; mod extract; mod function; mod gj; +mod multiset; mod serialize; pub mod sort; mod termdag; diff --git a/src/multiset.rs b/src/multiset.rs new file mode 100644 index 00000000..9f680c5c --- /dev/null +++ b/src/multiset.rs @@ -0,0 +1,86 @@ +use rpds::RedBlackTreeMapSync; +use std::hash::Hash; + +/// Immutable multiset implementation, which is threadsafe and hash stable, regardless of insertion order. +/// +/// All methods that return a new multiset take ownership of the old multiset. 
+#[derive(Debug, Hash, Eq, PartialEq, Clone)] +pub(crate) struct MultiSet( + /// All values should be > 0 + RedBlackTreeMapSync, +); + +impl MultiSet { + /// Create a new empty multiset. + pub(crate) fn new() -> Self { + MultiSet(RedBlackTreeMapSync::new_sync()) + } + + /// Check if the multiset contains a key. + pub(crate) fn contains(&self, value: &T) -> bool { + self.0.contains_key(value) + } + + /// Return the total number of elements in the multiset. + pub(crate) fn len(&self) -> usize { + self.0.iter().map(|(_, v)| *v).sum() + } + + /// Return an iterator over all elements in the multiset. + pub(crate) fn iter(&self) -> impl Iterator { + self.0 + .iter() + .flat_map(|(k, v)| std::iter::repeat(k).take(*v)) + } + + /// Return an arbitrary element from the multiset. + pub(crate) fn pick(&self) -> Option<&T> { + self.0.first().map(|(k, _)| k) + } + + /// Map a function over all elements in the multiset, taking ownership of it and returning a new multiset. + pub(crate) fn map(self, mut f: impl FnMut(&T) -> T) -> MultiSet { + let mut new = MultiSet::new(); + for (k, v) in self.0.into_iter() { + new.insert_multiple_mut(f(k), *v); + } + new + } + + /// Insert a value into the multiset, taking ownership of it and returning a new multiset. + pub(crate) fn insert(mut self, value: T) -> MultiSet { + self.insert_multiple_mut(value, 1); + self + } + + /// Remove a value from the multiset, taking ownership of it and returning a new multiset. + pub(crate) fn remove(mut self, value: &T) -> Option> { + if let Some(v) = self.0.get(value) { + if *v == 1 { + self.0.remove_mut(value); + } else { + self.0.insert_mut(value.clone(), v - 1); + } + Some(self) + } else { + None + } + } + + fn insert_multiple_mut(&mut self, value: T, n: usize) { + if let Some(v) = self.0.get(&value) { + self.0.insert_mut(value, v + n); + } else { + self.0.insert_mut(value, n); + } + } + + /// Create a multiset from an iterator. 
+ pub(crate) fn from_iter(iter: impl IntoIterator) -> Self { + let mut multiset = MultiSet::new(); + for value in iter { + multiset.insert_multiple_mut(value, 1); + } + multiset + } +} diff --git a/src/sort/fn.rs b/src/sort/fn.rs index 05f42a92..b9c70ef0 100644 --- a/src/sort/fn.rs +++ b/src/sort/fn.rs @@ -48,8 +48,8 @@ impl Eq for ValueFunction {} #[derive(Debug)] pub struct FunctionSort { name: Symbol, - inputs: Vec, - output: ArcSort, + pub(crate) inputs: Vec, + pub(crate) output: ArcSort, functions: Mutex>, } @@ -58,6 +58,25 @@ impl FunctionSort { let functions = self.functions.lock().unwrap(); functions.get_index(value.bits as usize).unwrap().clone() } + + /// Apply the function to the values + /// + /// Public so that other primitive sorts (external or internal) can use this to apply functions + pub fn apply(&self, fn_value: &Value, arg_values: &[Value], egraph: &mut EGraph) -> Value { + let ValueFunction(name, args) = self.get_value(fn_value); + let types: Vec<_> = args + .iter() + .map(|(sort, _)| sort.clone()) + .chain(self.inputs.clone()) + .chain(once(self.output.clone())) + .collect(); + let values = args + .iter() + .map(|(_, v)| *v) + .chain(arg_values.iter().cloned()) + .collect(); + call_fn(egraph, &name, types, values) + } } impl Presort for FunctionSort { @@ -368,21 +387,7 @@ impl PrimitiveLike for Apply { fn apply(&self, values: &[Value], egraph: Option<&mut EGraph>) -> Option { let egraph = egraph.expect("`unstable-app` is not supported yet in facts."); - let ValueFunction(name, args) = ValueFunction::load(&self.function, &values[0]); - let types: Vec<_> = args - .iter() - // get the sorts of partially applied args - .map(|(sort, _)| sort.clone()) - // combine with the args for the function call and then the output - .chain(self.function.inputs.clone()) - .chain(once(self.function.output.clone())) - .collect(); - let values = args - .iter() - .map(|(_, v)| *v) - .chain(values[1..].iter().copied()) - .collect(); - Some(call_fn(egraph, &name, 
types, values)) + Some(self.function.apply(&values[0], &values[1..], egraph)) } } diff --git a/src/sort/mod.rs b/src/sort/mod.rs index 9f361cc4..e43a1f43 100644 --- a/src/sort/mod.rs +++ b/src/sort/mod.rs @@ -24,6 +24,8 @@ mod vec; pub use vec::*; mod r#fn; pub use r#fn::*; +mod multiset; +pub use multiset::*; use crate::constraint::AllEqualTypeConstraint; use crate::extract::{Cost, Extractor}; diff --git a/src/sort/multiset.rs b/src/sort/multiset.rs new file mode 100644 index 00000000..5b7c8982 --- /dev/null +++ b/src/sort/multiset.rs @@ -0,0 +1,440 @@ +use std::sync::Mutex; + +use super::*; +use crate::constraint::{AllEqualTypeConstraint, SimpleTypeConstraint}; +use crate::multiset::MultiSet; + +type ValueMultiSet = MultiSet; + +#[derive(Debug)] +pub struct MultiSetSort { + name: Symbol, + element: ArcSort, + multisets: Mutex>, +} + +impl MultiSetSort { + pub fn element(&self) -> ArcSort { + self.element.clone() + } + + pub fn element_name(&self) -> Symbol { + self.element.name() + } +} + +impl Presort for MultiSetSort { + fn presort_name() -> Symbol { + "MultiSet".into() + } + + fn reserved_primitives() -> Vec { + vec![ + "multiset-of".into(), + "multiset-insert".into(), + "multiset-contains".into(), + "multiset-not-contains".into(), + "multiset-remove".into(), + "multiset-length".into(), + "unstable-multiset-map".into(), + ] + } + + fn make_sort( + typeinfo: &mut TypeInfo, + name: Symbol, + args: &[Expr], + ) -> Result { + if let [Expr::Var(span, e)] = args { + let e = typeinfo + .sorts + .get(e) + .ok_or(TypeError::UndefinedSort(*e, span.clone()))?; + + if e.is_eq_container_sort() { + return Err(TypeError::DisallowedSort( + name, + "Multisets nested with other EqSort containers are not allowed".into(), + span.clone(), + )); + } + + Ok(Arc::new(Self { + name, + element: e.clone(), + multisets: Default::default(), + })) + } else { + panic!() + } + } +} + +impl Sort for MultiSetSort { + fn name(&self) -> Symbol { + self.name + } + + fn as_arc_any(self: Arc) -> 
Arc { + self + } + + fn is_container_sort(&self) -> bool { + true + } + + fn is_eq_container_sort(&self) -> bool { + self.element.is_eq_sort() + } + + fn inner_values(&self, value: &Value) -> Vec<(ArcSort, Value)> { + let multisets = self.multisets.lock().unwrap(); + let multiset = multisets.get_index(value.bits as usize).unwrap(); + multiset + .iter() + .map(|k| (self.element.clone(), *k)) + .collect() + } + + fn canonicalize(&self, value: &mut Value, unionfind: &UnionFind) -> bool { + let multisets = self.multisets.lock().unwrap(); + let multiset = multisets.get_index(value.bits as usize).unwrap().clone(); + let mut changed = false; + let new_multiset = multiset.map(|e| { + let mut e = *e; + changed |= self.element.canonicalize(&mut e, unionfind); + e + }); + drop(multisets); + *value = new_multiset.store(self).unwrap(); + changed + } + + fn register_primitives(self: Arc, typeinfo: &mut TypeInfo) { + typeinfo.add_primitive(MultiSetOf { + name: "multiset-of".into(), + multiset: self.clone(), + }); + typeinfo.add_primitive(Insert { + name: "multiset-insert".into(), + multiset: self.clone(), + }); + typeinfo.add_primitive(Contains { + name: "multiset-contains".into(), + multiset: self.clone(), + }); + typeinfo.add_primitive(NotContains { + name: "multiset-not-contains".into(), + multiset: self.clone(), + }); + typeinfo.add_primitive(Remove { + name: "multiset-remove".into(), + multiset: self.clone(), + }); + typeinfo.add_primitive(Length { + name: "multiset-length".into(), + multiset: self.clone(), + }); + typeinfo.add_primitive(Pick { + name: "multiset-pick".into(), + multiset: self.clone(), + }); + let inner_name = self.element.name(); + let fn_sort = typeinfo.get_sort_by(|s: &Arc| { + (s.output.name() == inner_name) + && s.inputs.len() == 1 + && (s.inputs[0].name() == inner_name) + }); + // Only include map function if we already declared a function sort with the correct signature + if let Some(fn_sort) = fn_sort { + typeinfo.add_primitive(Map { + name: 
"unstable-multiset-map".into(), + multiset: self.clone(), + fn_: fn_sort, + }); + } + } + + fn make_expr(&self, egraph: &EGraph, value: Value) -> (Cost, Expr) { + let mut termdag = TermDag::default(); + let extractor = Extractor::new(egraph, &mut termdag); + self.extract_expr(egraph, value, &extractor, &mut termdag) + .expect("Extraction should be successful since extractor has been fully initialized") + } + + fn extract_expr( + &self, + _egraph: &EGraph, + value: Value, + extractor: &Extractor, + termdag: &mut TermDag, + ) -> Option<(Cost, Expr)> { + let multiset = ValueMultiSet::load(self, &value); + let mut children = vec![]; + // let mut expr = Expr::call_no_span("set-empty", []); + let mut cost = 0usize; + for e in multiset.iter() { + let (child_cost, child_term) = extractor.find_best(*e, termdag, &self.element)?; + cost = cost.saturating_add(child_cost); + children.push(termdag.term_to_expr(&child_term)); + } + let expr = Expr::call_no_span("multiset-of", children); + Some((cost, expr)) + } + + fn serialized_name(&self, _value: &Value) -> Symbol { + "multiset-of".into() + } +} + +impl IntoSort for ValueMultiSet { + type Sort = MultiSetSort; + fn store(self, sort: &Self::Sort) -> Option { + let mut multisets = sort.multisets.lock().unwrap(); + let (i, _) = multisets.insert_full(self); + Some(Value { + tag: sort.name, + bits: i as u64, + }) + } +} + +impl FromSort for ValueMultiSet { + type Sort = MultiSetSort; + fn load(sort: &Self::Sort, value: &Value) -> Self { + let sets = sort.multisets.lock().unwrap(); + sets.get_index(value.bits as usize).unwrap().clone() + } +} + +struct MultiSetOf { + name: Symbol, + multiset: Arc, +} + +impl PrimitiveLike for MultiSetOf { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + AllEqualTypeConstraint::new(self.name(), span.clone()) + .with_all_arguments_sort(self.multiset.element()) + .with_output_sort(self.multiset.clone()) + .into_box() + } + + fn apply(&self, values: 
&[Value], _egraph: Option<&mut EGraph>) -> Option { + let multiset = MultiSet::from_iter(values.iter().copied()); + Some(multiset.store(&self.multiset).unwrap()) + } +} + +struct Insert { + name: Symbol, + multiset: Arc, +} + +impl PrimitiveLike for Insert { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + SimpleTypeConstraint::new( + self.name(), + vec![ + self.multiset.clone(), + self.multiset.element(), + self.multiset.clone(), + ], + span.clone(), + ) + .into_box() + } + + fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + let multiset = ValueMultiSet::load(&self.multiset, &values[0]); + let multiset = multiset.insert(values[1]); + multiset.store(&self.multiset) + } +} + +struct Contains { + name: Symbol, + multiset: Arc, +} + +impl PrimitiveLike for Contains { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + SimpleTypeConstraint::new( + self.name(), + vec![ + self.multiset.clone(), + self.multiset.element(), + Arc::new(UnitSort), + ], + span.clone(), + ) + .into_box() + } + + fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + let multiset = ValueMultiSet::load(&self.multiset, &values[0]); + if multiset.contains(&values[1]) { + Some(Value::unit()) + } else { + None + } + } +} + +struct NotContains { + name: Symbol, + multiset: Arc, +} + +impl PrimitiveLike for NotContains { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + SimpleTypeConstraint::new( + self.name(), + vec![ + self.multiset.clone(), + self.multiset.element(), + Arc::new(UnitSort), + ], + span.clone(), + ) + .into_box() + } + + fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + let multiset = ValueMultiSet::load(&self.multiset, &values[0]); + if !multiset.contains(&values[1]) { + Some(Value::unit()) + } else { + None + } + } +} + +struct Length { + name: 
Symbol, + multiset: Arc, +} + +impl PrimitiveLike for Length { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + SimpleTypeConstraint::new( + self.name(), + vec![self.multiset.clone(), Arc::new(I64Sort)], + span.clone(), + ) + .into_box() + } + + fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + let multiset = ValueMultiSet::load(&self.multiset, &values[0]); + Some(Value::from(multiset.len() as i64)) + } +} + +struct Remove { + name: Symbol, + multiset: Arc, +} + +impl PrimitiveLike for Remove { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + SimpleTypeConstraint::new( + self.name(), + vec![ + self.multiset.clone(), + self.multiset.element(), + self.multiset.clone(), + ], + span.clone(), + ) + .into_box() + } + + fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + let multiset = ValueMultiSet::load(&self.multiset, &values[0]); + let multiset = multiset.remove(&values[1]); + multiset.store(&self.multiset) + } +} + +struct Pick { + name: Symbol, + multiset: Arc, +} + +impl PrimitiveLike for Pick { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + SimpleTypeConstraint::new( + self.name(), + vec![self.multiset.clone(), self.multiset.element()], + span.clone(), + ) + .into_box() + } + + fn apply(&self, values: &[Value], _egraph: Option<&mut EGraph>) -> Option { + let multiset = ValueMultiSet::load(&self.multiset, &values[0]); + Some(*multiset.pick().expect("Cannot pick from an empty multiset")) + } +} + +struct Map { + name: Symbol, + multiset: Arc, + fn_: Arc, +} + +impl PrimitiveLike for Map { + fn name(&self) -> Symbol { + self.name + } + + fn get_type_constraints(&self, span: &Span) -> Box { + SimpleTypeConstraint::new( + self.name(), + vec![ + self.fn_.clone(), + self.multiset.clone(), + self.multiset.clone(), + ], + span.clone(), + ) + .into_box() + 
} + + fn apply(&self, values: &[Value], egraph: Option<&mut EGraph>) -> Option { + let egraph = + egraph.unwrap_or_else(|| panic!("`{}` is not supported yet in facts.", self.name)); + let multiset = ValueMultiSet::load(&self.multiset, &values[1]); + let new_multiset = multiset.map(|e| self.fn_.apply(&values[0], &[*e], egraph)); + new_multiset.store(&self.multiset) + } +} diff --git a/src/typechecking.rs b/src/typechecking.rs index 80f0e464..3696abb0 100644 --- a/src/typechecking.rs +++ b/src/typechecking.rs @@ -47,6 +47,7 @@ impl Default for TypeInfo { res.add_presort::(DUMMY_SPAN.clone()).unwrap(); res.add_presort::(DUMMY_SPAN.clone()).unwrap(); res.add_presort::(DUMMY_SPAN.clone()).unwrap(); + res.add_presort::(DUMMY_SPAN.clone()).unwrap(); res.add_primitive(ValueEq); diff --git a/tests/eqsat-basic-multiset.egg b/tests/eqsat-basic-multiset.egg new file mode 100644 index 00000000..e4a9fa0c --- /dev/null +++ b/tests/eqsat-basic-multiset.egg @@ -0,0 +1,124 @@ +;; Example showing how to use multisets to hold associative & commutative operations + +(datatype* + (Math + (Num i64) + (Var String) + (Add Math Math) + (Mul Math Math) + (Product MathMultiSet) + (Sum MathMultiSet)) + (sort MathToMath (UnstableFn (Math) Math)) + (sort MathMultiSet (MultiSet Math))) + +;; expr1 = 2 * (x + 3) +(let expr1 (Mul (Num 2) (Add (Var "x") (Num 3)))) +;; expr2 = 6 + 2 * x +(let expr2 (Add (Num 6) (Mul (Num 2) (Var "x")))) + +(rewrite (Add a b) (Sum (multiset-of a b))) +(rewrite (Mul a b) (Product (multiset-of a b))) + +;; 0 or 1 elements sums/products also can be extracted back to numbers +(rule + ( + (= sum (Sum sum-inner)) + (= 0 (multiset-length sum-inner)) + ) + ((union sum (Num 0))) +) +(rule + ( + (= sum (Sum sum-inner)) + (= 1 (multiset-length sum-inner)) + ) + ((union sum (multiset-pick sum-inner))) +) + +(rule + ( + (= product (Product product-inner)) + (= 0 (multiset-length product-inner)) + ) + ((union product (Num 1))) +) +(rule + ( + (= product (Product product-inner)) + 
(= 1 (multiset-length product-inner)) + ) + ((union product (multiset-pick product-inner))) +) + +; (rewrite (Mul a (Add b c)) +; (Add (Mul a b) (Mul a c))) + +; -> we would like to write it like this, but cannot (yet) bc we can't match on the inner structure of the multisets +; and we don't support anonymous functions + +; (rewrite (Product (multiset-insert a (Sum bc))) +; (Sum (multiset-map (lambda (x) (Product (multiset-insert a x))) bc))) + + +;; so instead we can define a function and partially apply it to get the same function as the lambda +(function tmp-fn (MathMultiSet Math) Math) +(rewrite (tmp-fn xs x) (Product (multiset-insert xs x))) + +(rule + ( + ;; and we can do a cross product search of all possible pairs of products/sums to find one we want + (= sum (Sum bc)) + (= product (Product product-inner)) + (multiset-contains product-inner sum) + (> (multiset-length product-inner) 1) + (= a (multiset-remove product-inner sum)) + ) + ( + (union product (Sum + (unstable-multiset-map + (unstable-fn "tmp-fn" a) + bc) + )) + ) +) + +; (rewrite (Add (Num a) (Num b)) +; (Num (+ a b))) + +(rule + ( + (= sum (Sum sum-inner)) + (= num-a (Num a)) + (multiset-contains sum-inner num-a) + (= without-a (multiset-remove sum-inner num-a)) + (= num-b (Num b)) + (multiset-contains without-a num-b) + ) + ( + (union sum + (Sum (multiset-insert (multiset-remove without-a num-b) (Num (+ a b)))) + ) + ) +) + +; (rewrite (Mul (Num a) (Num b)) +; (Num (* a b))) + +(rule + ( + (= product (Product product-inner)) + (= num-a (Num a)) + (multiset-contains product-inner num-a) + (= without-a (multiset-remove product-inner num-a)) + (= num-b (Num b)) + (multiset-contains without-a num-b) + ) + ( + (union product + (Product (multiset-insert (multiset-remove without-a num-b) (Num (* a b)))) + ) + ) +) + +(run 100) +(check (= expr1 expr2)) diff --git a/tests/multiset.egg b/tests/multiset.egg new file mode 100644 index 00000000..795281cd --- /dev/null +++ b/tests/multiset.egg @@ -0,0 +1,58 
@@ +(datatype Math (Num i64)) +(sort MathToMath (UnstableFn (Math) Math)) +(sort Maths (MultiSet Math)) + +(let xs (multiset-of (Num 1) (Num 2) (Num 3))) + +;; verify equal to other ordering +(check (= + (multiset-of (Num 3) (Num 2) (Num 1)) + xs +)) + +;; verify not equal to different counts +(check (!= + (multiset-of (Num 3) (Num 2) (Num 1) (Num 1)) + xs +)) + +;; Unclear why check won't work if this is defined inline +(let inserted (multiset-insert xs (Num 4))) +;; insert +(check (= + (multiset-of (Num 1) (Num 2) (Num 3) (Num 4)) + inserted +)) + + +;; contains and not contains +(check (multiset-contains xs (Num 1))) +(check (multiset-not-contains xs (Num 4))) + +;; remove last +(check (= + (multiset-of (Num 1) (Num 3)) + (multiset-remove xs (Num 2)) +)) +;; remove one of +(check (= (multiset-of (Num 1)) (multiset-remove (multiset-of (Num 1) (Num 1)) (Num 1)))) + + +;; length +(check (= 3 (multiset-length xs))) +;; length repeated +(check (= 3 (multiset-length (multiset-of (Num 1) (Num 1) (Num 1))))) + +;; pick +(check (= (Num 1) (multiset-pick (multiset-of (Num 1))))) + +;; map +(function square (Math) Math) +(rewrite (square (Num x)) (Num (* x x))) + +(let squared-xs (unstable-multiset-map (unstable-fn "square") xs)) +(run 1) +(check (= + (multiset-of (Num 1) (Num 4) (Num 9)) + squared-xs +)) From 76b88cf4b2d1e14c3b15e6c01e7107d6ebe450c5 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Tue, 22 Oct 2024 11:39:26 -0400 Subject: [PATCH 2/6] Make function sorts fields pub --- src/sort/fn.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/sort/fn.rs b/src/sort/fn.rs index b9c70ef0..7f84588f 100644 --- a/src/sort/fn.rs +++ b/src/sort/fn.rs @@ -48,8 +48,9 @@ impl Eq for ValueFunction {} #[derive(Debug)] pub struct FunctionSort { name: Symbol, - pub(crate) inputs: Vec, - pub(crate) output: ArcSort, + // Public so that other primitive sorts (external or internal) can find a function sort by the sorts of its inputs/output + pub inputs: 
Vec, + pub output: ArcSort, functions: Mutex>, } From 23564a4e6db1eedbc7d365c54ff232b55ec09e23 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Tue, 22 Oct 2024 11:39:46 -0400 Subject: [PATCH 3/6] Remove comment --- src/sort/multiset.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sort/multiset.rs b/src/sort/multiset.rs index 5b7c8982..1da4e4e6 100644 --- a/src/sort/multiset.rs +++ b/src/sort/multiset.rs @@ -171,7 +171,6 @@ impl Sort for MultiSetSort { ) -> Option<(Cost, Expr)> { let multiset = ValueMultiSet::load(self, &value); let mut children = vec![]; - // let mut expr = Expr::call_no_span("set-empty", []); let mut cost = 0usize; for e in multiset.iter() { let (child_cost, child_term) = extractor.find_best(*e, termdag, &self.element)?; From b6fb41cf8ab6e1f082feaeb028b9a64417bf87d4 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Tue, 22 Oct 2024 11:50:24 -0400 Subject: [PATCH 4/6] Combine multiset files --- src/lib.rs | 1 - src/multiset.rs | 86 ----------------------------------------- src/sort/multiset.rs | 92 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 91 insertions(+), 88 deletions(-) delete mode 100644 src/multiset.rs diff --git a/src/lib.rs b/src/lib.rs index 34c7f097..c5e477c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,7 +18,6 @@ mod core; mod extract; mod function; mod gj; -mod multiset; mod serialize; pub mod sort; mod termdag; diff --git a/src/multiset.rs b/src/multiset.rs deleted file mode 100644 index 9f680c5c..00000000 --- a/src/multiset.rs +++ /dev/null @@ -1,86 +0,0 @@ -use rpds::RedBlackTreeMapSync; -use std::hash::Hash; - -/// Immutable multiset implementation, which is threadsafe and hash stable, regardless of insertion order. -/// -/// All methods that return a new multiset take ownership of the old multiset. -#[derive(Debug, Hash, Eq, PartialEq, Clone)] -pub(crate) struct MultiSet( - /// All values should be > 0 - RedBlackTreeMapSync, -); - -impl MultiSet { - /// Create a new empty multiset. 
- pub(crate) fn new() -> Self { - MultiSet(RedBlackTreeMapSync::new_sync()) - } - - /// Check if the multiset contains a key. - pub(crate) fn contains(&self, value: &T) -> bool { - self.0.contains_key(value) - } - - /// Return the total number of elements in the multiset. - pub(crate) fn len(&self) -> usize { - self.0.iter().map(|(_, v)| *v).sum() - } - - /// Return an iterator over all elements in the multiset. - pub(crate) fn iter(&self) -> impl Iterator { - self.0 - .iter() - .flat_map(|(k, v)| std::iter::repeat(k).take(*v)) - } - - /// Return an arbitrary element from the multiset. - pub(crate) fn pick(&self) -> Option<&T> { - self.0.first().map(|(k, _)| k) - } - - /// Map a function over all elements in the multiset, taking ownership of it and returning a new multiset. - pub(crate) fn map(self, mut f: impl FnMut(&T) -> T) -> MultiSet { - let mut new = MultiSet::new(); - for (k, v) in self.0.into_iter() { - new.insert_multiple_mut(f(k), *v); - } - new - } - - /// Insert a value into the multiset, taking ownership of it and returning a new multiset. - pub(crate) fn insert(mut self, value: T) -> MultiSet { - self.insert_multiple_mut(value, 1); - self - } - - /// Remove a value from the multiset, taking ownership of it and returning a new multiset. - pub(crate) fn remove(mut self, value: &T) -> Option> { - if let Some(v) = self.0.get(value) { - if *v == 1 { - self.0.remove_mut(value); - } else { - self.0.insert_mut(value.clone(), v - 1); - } - Some(self) - } else { - None - } - } - - fn insert_multiple_mut(&mut self, value: T, n: usize) { - if let Some(v) = self.0.get(&value) { - self.0.insert_mut(value, v + n); - } else { - self.0.insert_mut(value, n); - } - } - - /// Create a multiset from an iterator. 
- pub(crate) fn from_iter(iter: impl IntoIterator) -> Self { - let mut multiset = MultiSet::new(); - for value in iter { - multiset.insert_multiple_mut(value, 1); - } - multiset - } -} diff --git a/src/sort/multiset.rs b/src/sort/multiset.rs index 1da4e4e6..e49be8c4 100644 --- a/src/sort/multiset.rs +++ b/src/sort/multiset.rs @@ -1,8 +1,98 @@ use std::sync::Mutex; +use inner::MultiSet; + use super::*; use crate::constraint::{AllEqualTypeConstraint, SimpleTypeConstraint}; -use crate::multiset::MultiSet; + +// Place multiset in its own module to keep implementation details private from sort +mod inner { + use rpds::RedBlackTreeMapSync; + use std::hash::Hash; + /// Immutable multiset implementation, which is threadsafe and hash stable, regardless of insertion order. + /// + /// All methods that return a new multiset take ownership of the old multiset. + #[derive(Debug, Hash, Eq, PartialEq, Clone)] + pub(crate) struct MultiSet( + /// All values should be > 0 + RedBlackTreeMapSync, + ); + + impl MultiSet { + /// Create a new empty multiset. + pub(crate) fn new() -> Self { + MultiSet(RedBlackTreeMapSync::new_sync()) + } + + /// Check if the multiset contains a key. + pub(crate) fn contains(&self, value: &T) -> bool { + self.0.contains_key(value) + } + + /// Return the total number of elements in the multiset. + pub(crate) fn len(&self) -> usize { + self.0.iter().map(|(_, v)| *v).sum() + } + + /// Return an iterator over all elements in the multiset. + pub(crate) fn iter(&self) -> impl Iterator { + self.0 + .iter() + .flat_map(|(k, v)| std::iter::repeat(k).take(*v)) + } + + /// Return an arbitrary element from the multiset. + pub(crate) fn pick(&self) -> Option<&T> { + self.0.first().map(|(k, _)| k) + } + + /// Map a function over all elements in the multiset, taking ownership of it and returning a new multiset. 
+ pub(crate) fn map(self, mut f: impl FnMut(&T) -> T) -> MultiSet { + let mut new = MultiSet::new(); + for (k, v) in self.0.into_iter() { + new.insert_multiple_mut(f(k), *v); + } + new + } + + /// Insert a value into the multiset, taking ownership of it and returning a new multiset. + pub(crate) fn insert(mut self, value: T) -> MultiSet { + self.insert_multiple_mut(value, 1); + self + } + + /// Remove a value from the multiset, taking ownership of it and returning a new multiset. + pub(crate) fn remove(mut self, value: &T) -> Option> { + if let Some(v) = self.0.get(value) { + if *v == 1 { + self.0.remove_mut(value); + } else { + self.0.insert_mut(value.clone(), v - 1); + } + Some(self) + } else { + None + } + } + + fn insert_multiple_mut(&mut self, value: T, n: usize) { + if let Some(v) = self.0.get(&value) { + self.0.insert_mut(value, v + n); + } else { + self.0.insert_mut(value, n); + } + } + + /// Create a multiset from an iterator. + pub(crate) fn from_iter(iter: impl IntoIterator) -> Self { + let mut multiset = MultiSet::new(); + for value in iter { + multiset.insert_multiple_mut(value, 1); + } + multiset + } + } +} type ValueMultiSet = MultiSet; From b987f899bd26b886d8c29ea096ce545bd6e404f6 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Tue, 22 Oct 2024 12:11:42 -0400 Subject: [PATCH 5/6] Switch to im from rpds --- Cargo.lock | 40 +++++++++++++++------------------------- Cargo.toml | 2 +- src/sort/multiset.rs | 18 +++++++++--------- 3 files changed, 25 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a3a0d013..aec4f84d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -94,15 +94,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "archery" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae2ed21cd55021f05707a807a5fc85695dafb98832921f6cfa06db67ca5b869" -dependencies = [ - "triomphe", -] - [[package]] name = "ascii-canvas" version = "3.0.0" @@ -462,6 +453,7 @@ dependencies 
= [ "getrandom", "glob", "hashbrown 0.14.3", + "im", "im-rc", "indexmap", "instant", @@ -475,7 +467,6 @@ dependencies = [ "num-traits", "ordered-float", "regex", - "rpds", "rustc-hash", "serde_json", "smallvec", @@ -651,6 +642,20 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "im" +version = "15.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0acd33ff0285af998aaf9b57342af478078f53492322fafc47450e09397e0e9" +dependencies = [ + "bitmaps", + "rand_core", + "rand_xoshiro", + "sized-chunks", + "typenum", + "version_check", +] + [[package]] name = "im-rc" version = "15.1.0" @@ -1197,15 +1202,6 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" -[[package]] -name = "rpds" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0e15515d3ce3313324d842629ea4905c25a13f81953eadb88f85516f59290a4" -dependencies = [ - "archery", -] - [[package]] name = "rustc-hash" version = "1.1.0" @@ -1450,12 +1446,6 @@ dependencies = [ "serde_json", ] -[[package]] -name = "triomphe" -version = "0.1.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef8f7726da4807b58ea5c96fdc122f80702030edc33b35aff9190a51148ccc85" - [[package]] name = "typenum" version = "1.17.0" diff --git a/Cargo.toml b/Cargo.toml index 364fb25a..e2dde686 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,7 +62,7 @@ ordered-float = { version = "3.7" } getrandom = { version = "0.2.10", features = ["js"], optional = true } im-rc = "15.1.0" -rpds = "1.1.0" +im = "15.1.0" [build-dependencies] diff --git a/src/sort/multiset.rs b/src/sort/multiset.rs index e49be8c4..a2595864 100644 --- a/src/sort/multiset.rs +++ b/src/sort/multiset.rs @@ -7,7 +7,7 @@ use 
crate::constraint::{AllEqualTypeConstraint, SimpleTypeConstraint}; // Place multiset in its own module to keep implementation details private from sort mod inner { - use rpds::RedBlackTreeMapSync; + use im::OrdMap; use std::hash::Hash; /// Immutable multiset implementation, which is threadsafe and hash stable, regardless of insertion order. /// @@ -15,13 +15,13 @@ mod inner { #[derive(Debug, Hash, Eq, PartialEq, Clone)] pub(crate) struct MultiSet( /// All values should be > 0 - RedBlackTreeMapSync, + OrdMap, ); impl MultiSet { /// Create a new empty multiset. pub(crate) fn new() -> Self { - MultiSet(RedBlackTreeMapSync::new_sync()) + MultiSet(OrdMap::new()) } /// Check if the multiset contains a key. @@ -43,14 +43,14 @@ mod inner { /// Return an arbitrary element from the multiset. pub(crate) fn pick(&self) -> Option<&T> { - self.0.first().map(|(k, _)| k) + self.0.keys().next() } /// Map a function over all elements in the multiset, taking ownership of it and returning a new multiset. pub(crate) fn map(self, mut f: impl FnMut(&T) -> T) -> MultiSet { let mut new = MultiSet::new(); for (k, v) in self.0.into_iter() { - new.insert_multiple_mut(f(k), *v); + new.insert_multiple_mut(f(&k), v); } new } @@ -65,9 +65,9 @@ mod inner { pub(crate) fn remove(mut self, value: &T) -> Option> { if let Some(v) = self.0.get(value) { if *v == 1 { - self.0.remove_mut(value); + self.0.remove(value); } else { - self.0.insert_mut(value.clone(), v - 1); + self.0.insert(value.clone(), v - 1); } Some(self) } else { @@ -77,9 +77,9 @@ mod inner { fn insert_multiple_mut(&mut self, value: T, n: usize) { if let Some(v) = self.0.get(&value) { - self.0.insert_mut(value, v + n); + self.0.insert(value, v + n); } else { - self.0.insert_mut(value, n); + self.0.insert(value, n); } } From f090848e5e0b5a52b52e2c1a9f19088cfa4a50b9 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Tue, 22 Oct 2024 13:17:51 -0400 Subject: [PATCH 6/6] Cache length of multiset --- src/sort/multiset.rs | 8 ++++++-- 1 file 
changed, 6 insertions(+), 2 deletions(-) diff --git a/src/sort/multiset.rs b/src/sort/multiset.rs index a2595864..f1c276e0 100644 --- a/src/sort/multiset.rs +++ b/src/sort/multiset.rs @@ -16,12 +16,14 @@ mod inner { pub(crate) struct MultiSet( /// All values should be > 0 OrdMap, + /// cached length + usize, ); impl MultiSet { /// Create a new empty multiset. pub(crate) fn new() -> Self { - MultiSet(OrdMap::new()) + MultiSet(OrdMap::new(), 0) } /// Check if the multiset contains a key. @@ -31,7 +33,7 @@ mod inner { /// Return the total number of elements in the multiset. pub(crate) fn len(&self) -> usize { - self.0.iter().map(|(_, v)| *v).sum() + self.1 } /// Return an iterator over all elements in the multiset. @@ -64,6 +66,7 @@ mod inner { /// Remove a value from the multiset, taking ownership of it and returning a new multiset. pub(crate) fn remove(mut self, value: &T) -> Option> { if let Some(v) = self.0.get(value) { + self.1 -= 1; if *v == 1 { self.0.remove(value); } else { @@ -76,6 +79,7 @@ mod inner { } fn insert_multiple_mut(&mut self, value: T, n: usize) { + self.1 += n; if let Some(v) = self.0.get(&value) { self.0.insert(value, v + n); } else {