diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
new file mode 100644
index 00000000..a0ed54b7
--- /dev/null
+++ b/.github/CODEOWNERS
@@ -0,0 +1 @@
+* @egraphs-good/egglog-reviewers
\ No newline at end of file
diff --git a/README.md b/README.md
index 9637c54a..6c0ff993 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# egglog
-
+
diff --git a/src/ast/desugar.rs b/src/ast/desugar.rs
index 5e0b6d9d..7f7f11ba 100644
--- a/src/ast/desugar.rs
+++ b/src/ast/desugar.rs
@@ -125,7 +125,7 @@ fn normalize_expr(
panic!("handled above");
}
Expr::Call(f, children) => {
- let is_compute = TypeInfo::default().is_primitive(*f);
+ let is_compute = desugar.type_info.is_primitive(*f);
let mut new_children = vec![];
for child in children {
match child {
@@ -418,6 +418,7 @@ pub struct Desugar {
// TODO fix getting fresh names using modules
pub(crate) number_underscores: usize,
pub(crate) global_variables: HashSet,
+ pub(crate) type_info: TypeInfo,
}
impl Default for Desugar {
@@ -429,6 +430,7 @@ impl Default for Desugar {
parser: ast::parse::ProgramParser::new(),
number_underscores: 3,
global_variables: Default::default(),
+ type_info: TypeInfo::default(),
}
}
}
@@ -689,6 +691,7 @@ impl Clone for Desugar {
parser: ast::parse::ProgramParser::new(),
number_underscores: self.number_underscores,
global_variables: self.global_variables.clone(),
+ type_info: self.type_info.clone(),
}
}
}
diff --git a/src/ast/expr.rs b/src/ast/expr.rs
index dc7ea582..36129091 100644
--- a/src/ast/expr.rs
+++ b/src/ast/expr.rs
@@ -193,3 +193,24 @@ impl Display for Expr {
write!(f, "{}", self.to_sexp())
}
}
+
+// currently only used for testing, but no reason it couldn't be used elsewhere later
+#[cfg(test)]
+pub(crate) fn parse_expr(s: &str) -> Result> {
+ let parser = ast::parse::ExprParser::new();
+ parser
+ .parse(s)
+ .map_err(|e| e.map_token(|tok| tok.to_string()))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_parser_display_roundtrip() {
+ let s = r#"(f (g a 3) 4.0 (H "hello"))"#;
+ let e = parse_expr(s).unwrap();
+ assert_eq!(format!("{}", e), s);
+ }
+}
diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index 8612590c..0461dd6d 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -395,7 +395,9 @@ impl ToSexp for Command {
} => rule.to_sexp(*ruleset, *name),
Command::RunSchedule(sched) => list!("run-schedule", sched),
Command::Calc(args, exprs) => list!("calc", list!(++ args), ++ exprs),
- Command::Extract { variants, fact } => list!("extract", ":variants", variants, fact),
+ Command::Extract { variants, fact } => {
+ list!("query-extract", ":variants", variants, fact)
+ }
Command::Check(facts) => list!("check", ++ facts),
Command::CheckProof => list!("check-proof"),
Command::Push(n) => list!("push", n),
diff --git a/src/ast/parse.lalrpop b/src/ast/parse.lalrpop
index e6ecb115..f2d1b533 100644
--- a/src/ast/parse.lalrpop
+++ b/src/ast/parse.lalrpop
@@ -133,7 +133,7 @@ Schema: Schema = {
> => Schema { input: types, output }
}
-Expr: Expr = {
+pub Expr: Expr = {
=> Expr::Lit(<>),
=> Expr::Var(<>),
=> <>,
diff --git a/src/extract.rs b/src/extract.rs
index 5d6a9298..d98f2328 100644
--- a/src/extract.rs
+++ b/src/extract.rs
@@ -120,7 +120,7 @@ impl<'a> Extractor<'a> {
children.push(self.find_best(*value, termdag, arcsort)?.1)
}
- Some(termdag.make(node.sym, children))
+ Some(termdag.app(node.sym, children))
}
pub fn find_best(
@@ -172,7 +172,7 @@ impl<'a> Extractor<'a> {
if let Some((term_inputs, new_cost)) =
self.node_total_cost(func, inputs, termdag)
{
- let make_new_pair = || (new_cost, termdag.make(sym, term_inputs));
+ let make_new_pair = || (new_cost, termdag.app(sym, term_inputs));
let id = self.find(&output.value);
match self.costs.entry(id) {
diff --git a/src/function/mod.rs b/src/function/mod.rs
index 1e5c9f76..93c32c8c 100644
--- a/src/function/mod.rs
+++ b/src/function/mod.rs
@@ -69,13 +69,13 @@ impl Function {
pub fn new(egraph: &EGraph, decl: &FunctionDecl) -> Result {
let mut input = Vec::with_capacity(decl.schema.input.len());
for s in &decl.schema.input {
- input.push(match egraph.proof_state.type_info.sorts.get(s) {
+ input.push(match egraph.type_info().sorts.get(s) {
Some(sort) => sort.clone(),
None => return Err(Error::TypeError(TypeError::Unbound(*s))),
})
}
- let output = match egraph.proof_state.type_info.sorts.get(&decl.schema.output) {
+ let output = match egraph.type_info().sorts.get(&decl.schema.output) {
Some(sort) => sort.clone(),
None => return Err(Error::TypeError(TypeError::Unbound(decl.schema.output))),
};
diff --git a/src/lib.rs b/src/lib.rs
index 6ea8e453..e4c23dc2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,7 +2,6 @@ pub mod ast;
mod extract;
mod function;
mod gj;
-mod proofs;
mod serialize;
pub mod sort;
mod termdag;
@@ -12,17 +11,16 @@ mod unionfind;
pub mod util;
mod value;
+use ast::desugar::Desugar;
use extract::Extractor;
use hashbrown::hash_map::Entry;
use index::ColumnIndex;
use instant::{Duration, Instant};
pub use serialize::SerializeConfig;
use sort::*;
-pub use termdag::{Term, TermDag};
+pub use termdag::{Term, TermDag, TermId};
use thiserror::Error;
-use proofs::ProofState;
-
use symbolic_expressions::Sexp;
use ast::*;
@@ -73,11 +71,11 @@ pub enum ExtractReport {
Best {
termdag: TermDag,
cost: usize,
- expr: Term,
+ term: Term,
},
Variants {
termdag: TermDag,
- variants: Vec,
+ terms: Vec,
},
}
@@ -207,7 +205,7 @@ impl FromStr for CompilerPassStop {
pub struct EGraph {
egraphs: Vec,
unionfind: UnionFind,
- pub(crate) proof_state: ProofState,
+ pub(crate) desugar: Desugar,
functions: HashMap,
rulesets: HashMap>,
ruleset_iteration: HashMap,
@@ -246,7 +244,7 @@ impl Default for EGraph {
functions: Default::default(),
rulesets: Default::default(),
ruleset_iteration: Default::default(),
- proof_state: ProofState::default(),
+ desugar: Desugar::default(),
global_bindings: Default::default(),
match_limit: usize::MAX,
node_limit: usize::MAX,
@@ -486,10 +484,10 @@ impl EGraph {
pub fn eval_lit(&self, lit: &Literal) -> Value {
match lit {
- Literal::Int(i) => i.store(&self.proof_state.type_info.get_sort()).unwrap(),
- Literal::F64(f) => f.store(&self.proof_state.type_info.get_sort()).unwrap(),
- Literal::String(s) => s.store(&self.proof_state.type_info.get_sort()).unwrap(),
- Literal::Unit => ().store(&self.proof_state.type_info.get_sort()).unwrap(),
+ Literal::Int(i) => i.store(&self.type_info().get_sort()).unwrap(),
+ Literal::F64(f) => f.store(&self.type_info().get_sort()).unwrap(),
+ Literal::String(s) => s.store(&self.type_info().get_sort()).unwrap(),
+ Literal::Unit => ().store(&self.type_info().get_sort()).unwrap(),
}
}
@@ -528,7 +526,7 @@ impl EGraph {
} else {
termdag.expr_to_term(&schema.output.make_expr(self, out.value).1)
};
- terms.push((termdag.make(sym, children), out));
+ terms.push((termdag.app(sym, children), out));
}
drop(extractor);
@@ -998,7 +996,7 @@ impl EGraph {
}
NormAction::LetLit(var, lit) => {
let value = self.eval_lit(lit);
- let etype = self.proof_state.type_info.infer_literal(lit);
+ let etype = self.type_info().infer_literal(lit);
let present = self
.global_bindings
.insert(*var, (etype, value, self.timestamp));
@@ -1105,9 +1103,9 @@ impl EGraph {
let mut termdag = TermDag::default();
for expr in exprs {
let (t, value) = self.eval_expr(&expr, None, true)?;
- let expr = self.extract(value, &mut termdag, &t).1;
+ let term = self.extract(value, &mut termdag, &t).1;
use std::io::Write;
- writeln!(f, "{}", termdag.to_string(&expr))
+ writeln!(f, "{}", termdag.to_string(&term))
.map_err(|e| Error::IoError(filename.clone(), e))?;
}
@@ -1151,7 +1149,7 @@ impl EGraph {
}
pub fn set_underscores_for_desugaring(&mut self, underscores: usize) {
- self.proof_state.desugar.number_underscores = underscores;
+ self.desugar.number_underscores = underscores;
}
fn process_command(
@@ -1159,25 +1157,23 @@ impl EGraph {
command: Command,
stop: CompilerPassStop,
) -> Result, Error> {
- let program = self.proof_state.desugar.desugar_program(
- vec![command],
- self.test_proofs,
- self.seminaive,
- )?;
+ let program =
+ self.desugar
+ .desugar_program(vec![command], self.test_proofs, self.seminaive)?;
if stop == CompilerPassStop::Desugar {
return Ok(program);
}
- let type_info_before = self.proof_state.type_info.clone();
+ let type_info_before = self.type_info().clone();
- self.proof_state.type_info.typecheck_program(&program)?;
+ self.desugar.type_info.typecheck_program(&program)?;
if stop == CompilerPassStop::TypecheckDesugared {
return Ok(program);
}
// reset type info
- self.proof_state.type_info = type_info_before;
- self.proof_state.type_info.typecheck_program(&program)?;
+ self.desugar.type_info = type_info_before;
+ self.desugar.type_info.typecheck_program(&program)?;
if stop == CompilerPassStop::TypecheckTermEncoding {
return Ok(program);
}
@@ -1211,11 +1207,11 @@ impl EGraph {
}
pub fn parse_program(&self, input: &str) -> Result, Error> {
- self.proof_state.desugar.parse_program(input)
+ self.desugar.parse_program(input)
}
pub fn parse_and_run_program(&mut self, input: &str) -> Result, Error> {
- let parsed = self.proof_state.desugar.parse_program(input)?;
+ let parsed = self.desugar.parse_program(input)?;
self.run_program(parsed)
}
@@ -1224,11 +1220,11 @@ impl EGraph {
}
pub(crate) fn get_sort(&self, value: &Value) -> Option<&ArcSort> {
- self.proof_state.type_info.sorts.get(&value.tag)
+ self.type_info().sorts.get(&value.tag)
}
pub fn add_arcsort(&mut self, arcsort: ArcSort) -> Result<(), TypeError> {
- self.proof_state.type_info.add_arcsort(arcsort)
+ self.desugar.type_info.add_arcsort(arcsort)
}
/// Gets the last extract report and returns it, if the last command saved it.
@@ -1256,6 +1252,10 @@ impl EGraph {
self.msgs.dedup_by(|a, b| a.is_empty() && b.is_empty());
std::mem::take(&mut self.msgs)
}
+
+ pub(crate) fn type_info(&self) -> &TypeInfo {
+ &self.desugar.type_info
+ }
}
#[derive(Debug, Error)]
diff --git a/src/proofs.rs b/src/proofs.rs
deleted file mode 100644
index beb2b13a..00000000
--- a/src/proofs.rs
+++ /dev/null
@@ -1,11 +0,0 @@
-use crate::*;
-
-use crate::ast::desugar::Desugar;
-
-pub const RULE_PROOF_KEYWORD: &str = "rule-proof";
-
-#[derive(Default, Clone)]
-pub(crate) struct ProofState {
- pub(crate) desugar: Desugar,
- pub(crate) type_info: TypeInfo,
-}
diff --git a/src/serialize.rs b/src/serialize.rs
index b4e9f886..fa75f583 100644
--- a/src/serialize.rs
+++ b/src/serialize.rs
@@ -184,7 +184,7 @@ impl EGraph {
///
/// Checks for pattern created by Desugar.get_fresh
fn is_temp_name(&self, name: String) -> bool {
- let number_underscores = self.proof_state.desugar.number_underscores;
+ let number_underscores = self.desugar.number_underscores;
let res = name.starts_with('v')
&& name.ends_with("_".repeat(number_underscores).as_str())
&& name[1..name.len() - number_underscores]
diff --git a/src/termdag.rs b/src/termdag.rs
index 47594d13..59cf791a 100644
--- a/src/termdag.rs
+++ b/src/termdag.rs
@@ -4,14 +4,17 @@ use crate::{
Symbol,
};
+pub type TermId = usize;
+
/// Like [`Expr`]s but with sharing and deduplication.
///
-/// Terms refer to their children indirectly as indexes into an ambient [`TermDag`].
+/// Terms refer to their children indirectly via opaque [TermId]s (internally
+/// these are just `usize`s) that map into an ambient [`TermDag`].
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum Term {
Lit(Literal),
Var(Symbol),
- App(Symbol, Vec),
+ App(Symbol, Vec),
}
/// A hashconsing arena for [`Term`]s.
@@ -24,7 +27,7 @@ pub struct TermDag {
// - every element of node is a key in hashcons
// - every key of hashcons is in nodes
pub nodes: Vec,
- pub hashcons: HashMap,
+ pub hashcons: HashMap,
}
#[macro_export]
@@ -56,25 +59,25 @@ impl TermDag {
self.nodes.len()
}
- /// Convert the given term to its index.
+ /// Convert the given term to its id.
///
/// Panics if the term does not already exist in this [TermDag].
- pub fn lookup(&self, node: &Term) -> usize {
+ pub fn lookup(&self, node: &Term) -> TermId {
*self.hashcons.get(node).unwrap()
}
- /// Convert the given index to the corresponding term.
+ /// Convert the given id to the corresponding term.
///
- /// Panics if the index is not valid.
- pub fn get(&self, idx: usize) -> Term {
- self.nodes[idx].clone()
+ /// Panics if the id is not valid.
+ pub fn get(&self, id: TermId) -> Term {
+ self.nodes[id].clone()
}
/// Make and return a [`Term::App`] with the given head symbol and children,
/// and insert into the DAG if it is not already present.
///
/// Panics if any of the children are not already in the DAG.
- pub fn make(&mut self, sym: Symbol, children: Vec) -> Term {
+ pub fn app(&mut self, sym: Symbol, children: Vec) -> Term {
let node = Term::App(sym, children.iter().map(|c| self.lookup(c)).collect());
self.add_node(&node);
@@ -82,6 +85,26 @@ impl TermDag {
node
}
+ /// Make and return a [`Term::Lit`] with the given literal, and insert into
+ /// the DAG if it is not already present.
+ pub fn lit(&mut self, lit: Literal) -> Term {
+ let node = Term::Lit(lit);
+
+ self.add_node(&node);
+
+ node
+ }
+
+ /// Make and return a [`Term::Var`] with the given symbol, and insert into
+ /// the DAG if it is not already present.
+ pub fn var(&mut self, sym: Symbol) -> Term {
+ let node = Term::Var(sym);
+
+ self.add_node(&node);
+
+ node
+ }
+
fn add_node(&mut self, node: &Term) {
if self.hashcons.get(node).is_none() {
let idx = self.nodes.len();
@@ -138,8 +161,8 @@ impl TermDag {
///
/// Panics if the term or any of its subterms are not in the DAG.
pub fn to_string(&self, term: &Term) -> String {
- let mut stored = HashMap::::default();
- let mut seen = HashSet::::default();
+ let mut stored = HashMap::::default();
+ let mut seen = HashSet::::default();
let id = self.lookup(term);
// use a stack to avoid stack overflow
let mut stack = vec![id];
@@ -174,3 +197,78 @@ impl TermDag {
stored.get(&id).unwrap().clone()
}
}
+
+#[cfg(test)]
+mod tests {
+ use crate::ast;
+
+ use super::*;
+
+ fn parse_term(s: &str) -> (TermDag, Term) {
+ let e = crate::ast::parse_expr(s).unwrap();
+ let mut td = TermDag::default();
+ let t = td.expr_to_term(&e);
+ (td, t)
+ }
+
+ #[test]
+ fn test_to_from_expr() {
+ let s = r#"(f (g x y) x y (g x y))"#;
+ let e = crate::ast::parse_expr(s).unwrap();
+ let mut td = TermDag::default();
+ assert_eq!(td.size(), 0);
+ let t = td.expr_to_term(&e);
+ assert_eq!(td.size(), 4);
+ // the expression above has 4 distinct subterms.
+ // in left-to-right, depth-first order, they are:
+ // x, y, (g x y), and the root call to f
+ // so we can compute expected answer by hand:
+ assert_eq!(
+ td.nodes,
+ vec![
+ Term::Var("x".into()),
+ Term::Var("y".into()),
+ Term::App("g".into(), vec![0, 1]),
+ Term::App("f".into(), vec![2, 0, 1, 2]),
+ ]
+ );
+ let e2 = td.term_to_expr(&t);
+ assert_eq!(e, e2); // roundtrip
+ }
+
+ #[test]
+ fn test_match_term_app() {
+ let s = r#"(f (g x y) x y (g x y))"#;
+ let (td, t) = parse_term(s);
+ match_term_app!(t; {
+ ("f", [_, x, _, _]) =>
+ assert_eq!(td.term_to_expr(&td.get(*x)), ast::Expr::Var(Symbol::new("x")))
+ })
+ }
+
+ #[test]
+ fn test_to_string() {
+ let s = r#"(f (g x y) x y (g x y))"#;
+ let (td, t) = parse_term(s);
+ assert_eq!(td.to_string(&t), s);
+ }
+
+ #[test]
+ fn test_lookup() {
+ let s = r#"(f (g x y) x y (g x y))"#;
+ let (td, t) = parse_term(s);
+ assert_eq!(td.lookup(&t), td.size() - 1);
+ }
+
+ #[test]
+ fn test_app_var_lit() {
+ let s = r#"(f (g x y) x 7 (g x y))"#;
+ let (mut td, t) = parse_term(s);
+ let x = td.var("x".into());
+ let y = td.var("y".into());
+ let seven = td.lit(7.into());
+ let g = td.app("g".into(), vec![x.clone(), y.clone()]);
+ let t2 = td.app("f".into(), vec![g.clone(), x, seven, g]);
+ assert_eq!(t, t2);
+ }
+}
diff --git a/src/typecheck.rs b/src/typecheck.rs
index 5e4d91c9..9b294ea9 100644
--- a/src/typecheck.rs
+++ b/src/typecheck.rs
@@ -112,7 +112,7 @@ impl<'a> Context<'a> {
pub fn new(egraph: &'a EGraph) -> Self {
Self {
egraph,
- unit: egraph.proof_state.type_info.sorts[&Symbol::from(UNIT_SYM)].clone(),
+ unit: egraph.type_info().sorts[&Symbol::from(UNIT_SYM)].clone(),
types: Default::default(),
errors: Vec::default(),
unionfind: UnionFind::default(),
@@ -396,7 +396,7 @@ impl<'a> Context<'a> {
(self.add_node(ENode::Var(*sym)), ty)
}
Expr::Lit(lit) => {
- let t = self.egraph.proof_state.type_info.infer_literal(lit);
+ let t = self.egraph.type_info().infer_literal(lit);
(self.add_node(ENode::Literal(lit.clone())), Some(t))
}
Expr::Call(sym, args) => {
@@ -415,7 +415,7 @@ impl<'a> Context<'a> {
.collect();
let t = f.schema.output.clone();
(self.add_node(ENode::Func(*sym, ids)), Some(t))
- } else if let Some(prims) = self.egraph.proof_state.type_info.primitives.get(sym) {
+ } else if let Some(prims) = self.egraph.type_info().primitives.get(sym) {
let (ids, arg_tys): (Vec, Vec