Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable Value.tag in release mode #448

Merged
merged 13 commits into from
Oct 26, 2024
87 changes: 46 additions & 41 deletions src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ impl<'a> ActionCompiler<'a> {
self.locals.insert(v.clone());
}
GenericCoreAction::Extract(_ann, e, b) => {
self.do_atom_term(e);
let sort = self.do_atom_term(e);
self.do_atom_term(b);
self.instructions.push(Instruction::Extract(2));
self.instructions.push(Instruction::Extract(2, sort));
}
GenericCoreAction::Set(_ann, f, args, e) => {
let ResolvedCall::Func(func) = f else {
Expand All @@ -50,9 +50,9 @@ impl<'a> ActionCompiler<'a> {
.push(Instruction::Change(*change, func.name));
}
GenericCoreAction::Union(_ann, arg1, arg2) => {
self.do_atom_term(arg1);
let sort = self.do_atom_term(arg1);
self.do_atom_term(arg2);
self.instructions.push(Instruction::Union(2));
self.instructions.push(Instruction::Union(2, sort));
}
GenericCoreAction::Panic(_ann, msg) => {
self.instructions.push(Instruction::Panic(msg.clone()));
Expand All @@ -70,18 +70,21 @@ impl<'a> ActionCompiler<'a> {
}
}

fn do_atom_term(&mut self, at: &ResolvedAtomTerm) {
fn do_atom_term(&mut self, at: &ResolvedAtomTerm) -> ArcSort {
match at {
ResolvedAtomTerm::Var(_ann, var) => {
if let Some((i, _ty)) = self.locals.get_full(var) {
if let Some((i, ty)) = self.locals.get_full(var) {
self.instructions.push(Instruction::Load(Load::Stack(i)));
ty.sort.clone()
} else {
let (i, _, _ty) = self.types.get_full(&var.name).unwrap();
let (i, _, ty) = self.types.get_full(&var.name).unwrap();
self.instructions.push(Instruction::Load(Load::Subst(i)));
ty.clone()
}
}
ResolvedAtomTerm::Literal(_ann, lit) => {
self.instructions.push(Instruction::Literal(lit.clone()));
crate::sort::literal_sort(lit)
}
ResolvedAtomTerm::Global(_ann, _var) => {
panic!("Global variables should have been desugared");
Expand All @@ -97,10 +100,8 @@ impl<'a> ActionCompiler<'a> {
}

fn do_prim(&mut self, prim: &SpecializedPrimitive) {
self.instructions.push(Instruction::CallPrimitive(
prim.primitive.clone(),
prim.input.len(),
));
self.instructions
.push(Instruction::CallPrimitive(prim.clone(), prim.input.len()));
}
}

Expand All @@ -126,19 +127,19 @@ enum Instruction {
CallFunction(Symbol, bool),
/// Pop primitive arguments off the stack, calls the primitive,
/// and push the result onto the stack.
CallPrimitive(Primitive, usize),
CallPrimitive(SpecializedPrimitive, usize),
/// Pop function arguments off the stack and either deletes or subsumes the corresponding row
/// in the function.
Change(Change, Symbol),
/// Pop the value to be set and the function arguments off the stack.
/// Set the function at the given arguments to the new value.
Set(Symbol),
/// Union the last `n` values on the stack.
Union(usize),
Union(usize, ArcSort),
/// Extract the best expression. `n` is always 2.
/// The first value on the stack is the expression to extract,
/// and the second value is the number of variants to extract.
Extract(usize),
Extract(usize, ArcSort),
/// Panic with the given message.
Panic(String),
}
Expand Down Expand Up @@ -223,10 +224,11 @@ impl EGraph {
MergeFn::AssertEq => {
return Err(Error::MergeError(table, new_value, old_value));
}
MergeFn::Union => {
self.unionfind
.union_values(old_value, new_value, old_value.tag)
}
MergeFn::Union => self.unionfind.union_values(
old_value,
new_value,
function.decl.schema.output,
),
MergeFn::Expr(merge_prog) => {
let values = [old_value, new_value];
let mut stack = vec![];
Expand Down Expand Up @@ -268,14 +270,15 @@ impl EGraph {
},
Instruction::CallFunction(f, make_defaults) => {
let function = self.functions.get_mut(f).unwrap();
let output_tag = function.schema.output.name();
let new_len = stack.len() - function.schema.input.len();
let values = &stack[new_len..];

if cfg!(debug_assertions) {
for (ty, val) in function.schema.input.iter().zip(values) {
assert_eq!(ty.name(), val.tag,);
}
#[cfg(debug_assertions)]
let output_tag = function.schema.output.name();

#[cfg(debug_assertions)]
for (ty, val) in function.schema.input.iter().zip(values) {
assert_eq!(ty.name(), val.tag);
}

let value = if let Some(out) = function.nodes.get(values) {
Expand All @@ -289,8 +292,11 @@ impl EGraph {
Value::unit()
}
None if out.is_eq_sort() => {
let id = self.unionfind.make_set();
let value = Value::from_id(out.name(), id);
let value = Value {
#[cfg(debug_assertions)]
tag: out.name(),
bits: self.unionfind.make_set(),
};
function.insert(values, value, ts);
value
}
Expand All @@ -314,18 +320,24 @@ impl EGraph {
))));
};

// cfg is necessary because debug_assert_eq still evaluates its
// arguments in release mode (is has to because of side effects)
#[cfg(debug_assertions)]
debug_assert_eq!(output_tag, value.tag);

stack.truncate(new_len);
stack.push(value);
}
Instruction::CallPrimitive(p, arity) => {
let new_len = stack.len() - arity;
let values = &stack[new_len..];
if let Some(value) = p.apply(values, Some(self)) {
if let Some(value) =
p.primitive.apply(values, (&p.input, &p.output), Some(self))
{
stack.truncate(new_len);
stack.push(value);
} else {
return Err(Error::PrimitiveError(p.clone(), values.to_vec()));
return Err(Error::PrimitiveError(p.primitive.clone(), values.to_vec()));
}
}
Instruction::Set(f) => {
Expand All @@ -338,32 +350,25 @@ impl EGraph {
self.perform_set(*f, new_value, stack)?;
stack.truncate(new_len)
}
Instruction::Union(arity) => {
Instruction::Union(arity, sort) => {
let new_len = stack.len() - arity;
let values = &stack[new_len..];
let sort = values[0].tag;
let first = self.unionfind.find(Id::from(values[0].bits as usize));
let first = self.unionfind.find(values[0].bits);
values[1..].iter().fold(first, |a, b| {
let b = self.unionfind.find(Id::from(b.bits as usize));
self.unionfind.union(a, b, sort)
let b = self.unionfind.find(b.bits);
self.unionfind.union(a, b, sort.name())
});
stack.truncate(new_len);
}
Instruction::Extract(arity) => {
Instruction::Extract(arity, sort) => {
let new_len = stack.len() - arity;
let values = &stack[new_len..];
let new_len = stack.len() - arity;
let mut termdag = TermDag::default();
let num_sort = values[1].tag;
assert!(num_sort.to_string() == "i64");

let variants = values[1].bits as i64;
if variants == 0 {
let (cost, term) = self.extract(
values[0],
&mut termdag,
self.type_info.sorts.get(&values[0].tag).unwrap(),
);
let (cost, term) = self.extract(values[0], &mut termdag, sort);
let extracted = termdag.to_string(&term);
log::info!("extracted with cost {cost}: {extracted}");
self.print_msg(extracted);
Expand All @@ -377,7 +382,7 @@ impl EGraph {
panic!("Cannot extract negative number of variants");
}
let terms =
self.extract_variants(values[0], variants as usize, &mut termdag);
self.extract_variants(sort, values[0], variants as usize, &mut termdag);
log::info!("extracted variants:");
let mut msg = String::default();
msg += "(\n";
Expand Down
21 changes: 0 additions & 21 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,6 @@ pub use expr::*;
pub mod desugar;
pub(crate) mod remove_globals;

#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct Id(usize);

impl From<usize> for Id {
fn from(n: usize) -> Self {
Id(n)
}
}

impl From<Id> for usize {
fn from(id: Id) -> Self {
id.0
}
}

impl Display for Id {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "id{}", self.0)
}
}

#[derive(Clone, Debug)]
/// The egglog internal representation of already compiled rules
pub(crate) enum Ruleset {
Expand Down
6 changes: 3 additions & 3 deletions src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ impl std::fmt::Display for Query<ResolvedCall, Symbol> {
writeln!(
f,
"({} {})",
filter.head.name(),
filter.head.primitive.name(),
ListDisplay(&filter.args, " ")
)?;
}
Expand All @@ -347,12 +347,12 @@ impl std::fmt::Display for Query<ResolvedCall, Symbol> {
}

impl<Leaf: Clone> Query<ResolvedCall, Leaf> {
pub fn filters(&self) -> impl Iterator<Item = GenericAtom<Primitive, Leaf>> + '_ {
pub fn filters(&self) -> impl Iterator<Item = GenericAtom<SpecializedPrimitive, Leaf>> + '_ {
self.atoms.iter().filter_map(|atom| match &atom.head {
ResolvedCall::Func(_) => None,
ResolvedCall::Primitive(head) => Some(GenericAtom {
span: atom.span.clone(),
head: head.primitive.clone(),
head: head.clone(),
args: atom.args.clone(),
}),
})
Expand Down
38 changes: 22 additions & 16 deletions src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub type Cost = usize;
#[derive(Debug)]
pub(crate) struct Node<'a> {
sym: Symbol,
func: &'a Function,
inputs: &'a [Value],
}

Expand Down Expand Up @@ -51,11 +52,16 @@ impl EGraph {
if output.value == value {
log::error!("Found unextractable function: {:?}", func.decl.name);
log::error!("Inputs: {:?}", inputs);

assert_eq!(inputs.len(), func.schema.input.len());
log::error!(
"{:?}",
inputs
.iter()
.map(|input| extractor.costs.get(&extractor.find_id(*input)))
.zip(&func.schema.input)
.map(|(input, sort)| extractor
.costs
.get(&extractor.egraph.find(sort, *input).bits))
.collect::<Vec<_>>()
);
}
Expand All @@ -68,11 +74,13 @@ impl EGraph {

pub fn extract_variants(
&mut self,
sort: &ArcSort,
value: Value,
limit: usize,
termdag: &mut TermDag,
) -> Vec<Term> {
let output_value = self.find(value);
let output_sort = sort.name();
let output_value = self.find(sort, value);
let ext = &Extractor::new(self, termdag);
ext.ctors
.iter()
Expand All @@ -85,9 +93,11 @@ impl EGraph {

func.nodes
.iter(false)
.filter(|&(_, output)| (output.value == output_value))
.filter(|&(_, output)| {
func.schema.output.name() == output_sort && output.value == output_value
})
.map(|(inputs, _output)| {
let node = Node { sym, inputs };
let node = Node { sym, func, inputs };
ext.expr_from_node(&node, termdag).expect(
"extract_variants should be called after extractor initialization",
)
Expand Down Expand Up @@ -123,8 +133,12 @@ impl<'a> Extractor<'a> {

fn expr_from_node(&self, node: &Node, termdag: &mut TermDag) -> Option<Term> {
let mut children = vec![];
for value in node.inputs {
let arcsort = self.egraph.get_sort_from_value(value).unwrap();

let values = node.inputs;
let arcsorts = &node.func.schema.input;
assert_eq!(values.len(), arcsorts.len());

for (value, arcsort) in values.iter().zip(arcsorts) {
children.push(self.find_best(*value, termdag, arcsort)?.1)
}

Expand All @@ -138,7 +152,7 @@ impl<'a> Extractor<'a> {
sort: &ArcSort,
) -> Option<(Cost, Term)> {
if sort.is_eq_sort() {
let id = self.find_id(value);
let id = self.egraph.find(sort, value).bits;
let (cost, node) = self.costs.get(&id)?.clone();
Some((cost, node))
} else {
Expand All @@ -164,14 +178,6 @@ impl<'a> Extractor<'a> {
Some((terms, cost))
}

fn find(&self, value: Value) -> Value {
self.egraph.find(value)
}

fn find_id(&self, value: Value) -> Id {
Id::from(self.find(value).bits as usize)
}

fn find_costs(&mut self, termdag: &mut TermDag) {
let mut did_something = true;
while did_something {
Expand All @@ -186,7 +192,7 @@ impl<'a> Extractor<'a> {
{
let make_new_pair = || (new_cost, termdag.app(sym, term_inputs));

let id = self.find_id(output.value);
let id = self.egraph.find(&func.schema.output, output.value).bits;
match self.costs.entry(id) {
Entry::Vacant(e) => {
did_something = true;
Expand Down
1 change: 1 addition & 0 deletions src/function/binary_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ mod tests {

fn make_value(bits: u32) -> Value {
Value {
#[cfg(debug_assertions)]
tag: "testing".into(),
bits: bits as u64,
}
Expand Down
Loading
Loading