Skip to content

Commit

Permalink
Merge egraphs-good/main into add-
Browse files Browse the repository at this point in the history
  • Loading branch information
saulshanabrook committed Aug 29, 2023
2 parents d860ca5 + af0a493 commit 938c6bc
Show file tree
Hide file tree
Showing 16 changed files with 242 additions and 107 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* @egraphs-good/egglog-reviewers
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# egglog

<a href="https://egraphs-good.github.io/egglog/docs/egglog">
<a href="https://egraphs-good.github.io/egglog/">
<img alt="Web Demo" src="https://img.shields.io/badge/-web demo-blue"></a>
<a href="https://egraphs-good.github.io/egglog/docs/egglog">
<img alt="Main Branch Documentation" src="https://img.shields.io/badge/docs-main-blue"></a>
Expand Down
5 changes: 4 additions & 1 deletion src/ast/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ fn normalize_expr(
panic!("handled above");
}
Expr::Call(f, children) => {
let is_compute = TypeInfo::default().is_primitive(*f);
let is_compute = desugar.type_info.is_primitive(*f);
let mut new_children = vec![];
for child in children {
match child {
Expand Down Expand Up @@ -418,6 +418,7 @@ pub struct Desugar {
// TODO fix getting fresh names using modules
pub(crate) number_underscores: usize,
pub(crate) global_variables: HashSet<Symbol>,
pub(crate) type_info: TypeInfo,
}

impl Default for Desugar {
Expand All @@ -429,6 +430,7 @@ impl Default for Desugar {
parser: ast::parse::ProgramParser::new(),
number_underscores: 3,
global_variables: Default::default(),
type_info: TypeInfo::default(),
}
}
}
Expand Down Expand Up @@ -689,6 +691,7 @@ impl Clone for Desugar {
parser: ast::parse::ProgramParser::new(),
number_underscores: self.number_underscores,
global_variables: self.global_variables.clone(),
type_info: self.type_info.clone(),
}
}
}
Expand Down
21 changes: 21 additions & 0 deletions src/ast/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,24 @@ impl Display for Expr {
write!(f, "{}", self.to_sexp())
}
}

// currently only used for testing, but no reason it couldn't be used elsewhere later
#[cfg(test)]
pub(crate) fn parse_expr(s: &str) -> Result<Expr, lalrpop_util::ParseError<usize, String, String>> {
let parser = ast::parse::ExprParser::new();
parser
.parse(s)
.map_err(|e| e.map_token(|tok| tok.to_string()))
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_parser_display_roundtrip() {
let s = r#"(f (g a 3) 4.0 (H "hello"))"#;
let e = parse_expr(s).unwrap();
assert_eq!(format!("{}", e), s);
}
}
4 changes: 3 additions & 1 deletion src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,9 @@ impl ToSexp for Command {
} => rule.to_sexp(*ruleset, *name),
Command::RunSchedule(sched) => list!("run-schedule", sched),
Command::Calc(args, exprs) => list!("calc", list!(++ args), ++ exprs),
Command::Extract { variants, fact } => list!("extract", ":variants", variants, fact),
Command::Extract { variants, fact } => {
list!("query-extract", ":variants", variants, fact)
}
Command::Check(facts) => list!("check", ++ facts),
Command::CheckProof => list!("check-proof"),
Command::Push(n) => list!("push", n),
Expand Down
2 changes: 1 addition & 1 deletion src/ast/parse.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ Schema: Schema = {
<types:List<Type>> <output:Type> => Schema { input: types, output }
}

Expr: Expr = {
pub Expr: Expr = {
<Literal> => Expr::Lit(<>),
<Ident> => Expr::Var(<>),
<CallExpr> => <>,
Expand Down
4 changes: 2 additions & 2 deletions src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl<'a> Extractor<'a> {
children.push(self.find_best(*value, termdag, arcsort)?.1)
}

Some(termdag.make(node.sym, children))
Some(termdag.app(node.sym, children))
}

pub fn find_best(
Expand Down Expand Up @@ -172,7 +172,7 @@ impl<'a> Extractor<'a> {
if let Some((term_inputs, new_cost)) =
self.node_total_cost(func, inputs, termdag)
{
let make_new_pair = || (new_cost, termdag.make(sym, term_inputs));
let make_new_pair = || (new_cost, termdag.app(sym, term_inputs));

let id = self.find(&output.value);
match self.costs.entry(id) {
Expand Down
4 changes: 2 additions & 2 deletions src/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ impl Function {
pub fn new(egraph: &EGraph, decl: &FunctionDecl) -> Result<Self, Error> {
let mut input = Vec::with_capacity(decl.schema.input.len());
for s in &decl.schema.input {
input.push(match egraph.proof_state.type_info.sorts.get(s) {
input.push(match egraph.type_info().sorts.get(s) {
Some(sort) => sort.clone(),
None => return Err(Error::TypeError(TypeError::Unbound(*s))),
})
}

let output = match egraph.proof_state.type_info.sorts.get(&decl.schema.output) {
let output = match egraph.type_info().sorts.get(&decl.schema.output) {
Some(sort) => sort.clone(),
None => return Err(Error::TypeError(TypeError::Unbound(decl.schema.output))),
};
Expand Down
60 changes: 30 additions & 30 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ pub mod ast;
mod extract;
mod function;
mod gj;
mod proofs;
mod serialize;
pub mod sort;
mod termdag;
Expand All @@ -12,17 +11,16 @@ mod unionfind;
pub mod util;
mod value;

use ast::desugar::Desugar;
use extract::Extractor;
use hashbrown::hash_map::Entry;
use index::ColumnIndex;
use instant::{Duration, Instant};
pub use serialize::SerializeConfig;
use sort::*;
pub use termdag::{Term, TermDag};
pub use termdag::{Term, TermDag, TermId};
use thiserror::Error;

use proofs::ProofState;

use symbolic_expressions::Sexp;

use ast::*;
Expand Down Expand Up @@ -73,11 +71,11 @@ pub enum ExtractReport {
Best {
termdag: TermDag,
cost: usize,
expr: Term,
term: Term,
},
Variants {
termdag: TermDag,
variants: Vec<Term>,
terms: Vec<Term>,
},
}

Expand Down Expand Up @@ -207,7 +205,7 @@ impl FromStr for CompilerPassStop {
pub struct EGraph {
egraphs: Vec<Self>,
unionfind: UnionFind,
pub(crate) proof_state: ProofState,
pub(crate) desugar: Desugar,
functions: HashMap<Symbol, Function>,
rulesets: HashMap<Symbol, HashMap<Symbol, Rule>>,
ruleset_iteration: HashMap<Symbol, usize>,
Expand Down Expand Up @@ -246,7 +244,7 @@ impl Default for EGraph {
functions: Default::default(),
rulesets: Default::default(),
ruleset_iteration: Default::default(),
proof_state: ProofState::default(),
desugar: Desugar::default(),
global_bindings: Default::default(),
match_limit: usize::MAX,
node_limit: usize::MAX,
Expand Down Expand Up @@ -486,10 +484,10 @@ impl EGraph {

pub fn eval_lit(&self, lit: &Literal) -> Value {
match lit {
Literal::Int(i) => i.store(&self.proof_state.type_info.get_sort()).unwrap(),
Literal::F64(f) => f.store(&self.proof_state.type_info.get_sort()).unwrap(),
Literal::String(s) => s.store(&self.proof_state.type_info.get_sort()).unwrap(),
Literal::Unit => ().store(&self.proof_state.type_info.get_sort()).unwrap(),
Literal::Int(i) => i.store(&self.type_info().get_sort()).unwrap(),
Literal::F64(f) => f.store(&self.type_info().get_sort()).unwrap(),
Literal::String(s) => s.store(&self.type_info().get_sort()).unwrap(),
Literal::Unit => ().store(&self.type_info().get_sort()).unwrap(),
}
}

Expand Down Expand Up @@ -528,7 +526,7 @@ impl EGraph {
} else {
termdag.expr_to_term(&schema.output.make_expr(self, out.value).1)
};
terms.push((termdag.make(sym, children), out));
terms.push((termdag.app(sym, children), out));
}
drop(extractor);

Expand Down Expand Up @@ -998,7 +996,7 @@ impl EGraph {
}
NormAction::LetLit(var, lit) => {
let value = self.eval_lit(lit);
let etype = self.proof_state.type_info.infer_literal(lit);
let etype = self.type_info().infer_literal(lit);
let present = self
.global_bindings
.insert(*var, (etype, value, self.timestamp));
Expand Down Expand Up @@ -1105,9 +1103,9 @@ impl EGraph {
let mut termdag = TermDag::default();
for expr in exprs {
let (t, value) = self.eval_expr(&expr, None, true)?;
let expr = self.extract(value, &mut termdag, &t).1;
let term = self.extract(value, &mut termdag, &t).1;
use std::io::Write;
writeln!(f, "{}", termdag.to_string(&expr))
writeln!(f, "{}", termdag.to_string(&term))
.map_err(|e| Error::IoError(filename.clone(), e))?;
}

Expand Down Expand Up @@ -1151,33 +1149,31 @@ impl EGraph {
}

pub fn set_underscores_for_desugaring(&mut self, underscores: usize) {
self.proof_state.desugar.number_underscores = underscores;
self.desugar.number_underscores = underscores;
}

fn process_command(
&mut self,
command: Command,
stop: CompilerPassStop,
) -> Result<Vec<NormCommand>, Error> {
let program = self.proof_state.desugar.desugar_program(
vec![command],
self.test_proofs,
self.seminaive,
)?;
let program =
self.desugar
.desugar_program(vec![command], self.test_proofs, self.seminaive)?;
if stop == CompilerPassStop::Desugar {
return Ok(program);
}

let type_info_before = self.proof_state.type_info.clone();
let type_info_before = self.type_info().clone();

self.proof_state.type_info.typecheck_program(&program)?;
self.desugar.type_info.typecheck_program(&program)?;
if stop == CompilerPassStop::TypecheckDesugared {
return Ok(program);
}

// reset type info
self.proof_state.type_info = type_info_before;
self.proof_state.type_info.typecheck_program(&program)?;
self.desugar.type_info = type_info_before;
self.desugar.type_info.typecheck_program(&program)?;
if stop == CompilerPassStop::TypecheckTermEncoding {
return Ok(program);
}
Expand Down Expand Up @@ -1211,11 +1207,11 @@ impl EGraph {
}

pub fn parse_program(&self, input: &str) -> Result<Vec<Command>, Error> {
self.proof_state.desugar.parse_program(input)
self.desugar.parse_program(input)
}

pub fn parse_and_run_program(&mut self, input: &str) -> Result<Vec<String>, Error> {
let parsed = self.proof_state.desugar.parse_program(input)?;
let parsed = self.desugar.parse_program(input)?;
self.run_program(parsed)
}

Expand All @@ -1224,11 +1220,11 @@ impl EGraph {
}

pub(crate) fn get_sort(&self, value: &Value) -> Option<&ArcSort> {
self.proof_state.type_info.sorts.get(&value.tag)
self.type_info().sorts.get(&value.tag)
}

pub fn add_arcsort(&mut self, arcsort: ArcSort) -> Result<(), TypeError> {
self.proof_state.type_info.add_arcsort(arcsort)
self.desugar.type_info.add_arcsort(arcsort)
}

/// Gets the last extract report and returns it, if the last command saved it.
Expand Down Expand Up @@ -1256,6 +1252,10 @@ impl EGraph {
self.msgs.dedup_by(|a, b| a.is_empty() && b.is_empty());
std::mem::take(&mut self.msgs)
}

pub(crate) fn type_info(&self) -> &TypeInfo {
&self.desugar.type_info
}
}

#[derive(Debug, Error)]
Expand Down
11 changes: 0 additions & 11 deletions src/proofs.rs

This file was deleted.

2 changes: 1 addition & 1 deletion src/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ impl EGraph {
///
/// Checks for pattern created by Desugar.get_fresh
fn is_temp_name(&self, name: String) -> bool {
let number_underscores = self.proof_state.desugar.number_underscores;
let number_underscores = self.desugar.number_underscores;
let res = name.starts_with('v')
&& name.ends_with("_".repeat(number_underscores).as_str())
&& name[1..name.len() - number_underscores]
Expand Down
Loading

0 comments on commit 938c6bc

Please sign in to comment.