From 0fafc19bd8ab26f0c31c5e65fd68a207c6ffdeef Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 24 Aug 2023 10:31:37 -0400 Subject: [PATCH 1/2] Don't create new TypeInfo when desugaring --- src/ast/desugar.rs | 5 ++++- src/function/mod.rs | 4 ++-- src/lib.rs | 44 ++++++++++++++++++++------------------------ src/proofs.rs | 11 ----------- src/serialize.rs | 2 +- src/typecheck.rs | 27 ++++++++------------------- src/typechecking.rs | 6 ++++-- 7 files changed, 39 insertions(+), 60 deletions(-) delete mode 100644 src/proofs.rs diff --git a/src/ast/desugar.rs b/src/ast/desugar.rs index 5e0b6d9d..7f7f11ba 100644 --- a/src/ast/desugar.rs +++ b/src/ast/desugar.rs @@ -125,7 +125,7 @@ fn normalize_expr( panic!("handled above"); } Expr::Call(f, children) => { - let is_compute = TypeInfo::default().is_primitive(*f); + let is_compute = desugar.type_info.is_primitive(*f); let mut new_children = vec![]; for child in children { match child { @@ -418,6 +418,7 @@ pub struct Desugar { // TODO fix getting fresh names using modules pub(crate) number_underscores: usize, pub(crate) global_variables: HashSet, + pub(crate) type_info: TypeInfo, } impl Default for Desugar { @@ -429,6 +430,7 @@ impl Default for Desugar { parser: ast::parse::ProgramParser::new(), number_underscores: 3, global_variables: Default::default(), + type_info: TypeInfo::default(), } } } @@ -689,6 +691,7 @@ impl Clone for Desugar { parser: ast::parse::ProgramParser::new(), number_underscores: self.number_underscores, global_variables: self.global_variables.clone(), + type_info: self.type_info.clone(), } } } diff --git a/src/function/mod.rs b/src/function/mod.rs index 1e5c9f76..051971a1 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -69,13 +69,13 @@ impl Function { pub fn new(egraph: &EGraph, decl: &FunctionDecl) -> Result { let mut input = Vec::with_capacity(decl.schema.input.len()); for s in &decl.schema.input { - input.push(match egraph.proof_state.type_info.sorts.get(s) { + input.push(match egraph.desugar.type_info.sorts.get(s) { Some(sort) => sort.clone(), None => return Err(Error::TypeError(TypeError::Unbound(*s))), }) } - let output = match egraph.proof_state.type_info.sorts.get(&decl.schema.output) { + let output = match egraph.desugar.type_info.sorts.get(&decl.schema.output) { Some(sort) => sort.clone(), None => return Err(Error::TypeError(TypeError::Unbound(decl.schema.output))), }; diff --git a/src/lib.rs b/src/lib.rs index 85c2ab7c..86538d15 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,6 @@ pub mod ast; mod extract; mod function; mod gj; -mod proofs; mod serialize; pub mod sort; mod termdag; @@ -12,6 +11,7 @@ mod unionfind; pub mod util; mod value; +use ast::desugar::Desugar; use extract::Extractor; use hashbrown::hash_map::Entry; use index::ColumnIndex; @@ -21,8 +21,6 @@ use sort::*; pub use termdag::{Term, TermDag}; use thiserror::Error; -use proofs::ProofState; - use symbolic_expressions::Sexp; use ast::*; @@ -201,7 +199,7 @@ impl FromStr for CompilerPassStop { pub struct EGraph { egraphs: Vec, unionfind: UnionFind, - pub(crate) proof_state: ProofState, + pub(crate) desugar: Desugar, functions: HashMap, rulesets: HashMap>, ruleset_iteration: HashMap, @@ -240,7 +238,7 @@ impl Default for EGraph { functions: Default::default(), rulesets: Default::default(), ruleset_iteration: Default::default(), - proof_state: ProofState::default(), + desugar: Desugar::default(), global_bindings: Default::default(), match_limit: usize::MAX, node_limit: usize::MAX, @@ -469,10 +467,10 @@ impl EGraph { pub fn eval_lit(&self, lit: &Literal) -> Value { match lit { - Literal::Int(i) => i.store(&self.proof_state.type_info.get_sort()).unwrap(), - Literal::F64(f) => f.store(&self.proof_state.type_info.get_sort()).unwrap(), - Literal::String(s) => s.store(&self.proof_state.type_info.get_sort()).unwrap(), - Literal::Unit => ().store(&self.proof_state.type_info.get_sort()).unwrap(), + Literal::Int(i) => i.store(&self.desugar.type_info.get_sort()).unwrap(), + Literal::F64(f) => f.store(&self.desugar.type_info.get_sort()).unwrap(), + Literal::String(s) => s.store(&self.desugar.type_info.get_sort()).unwrap(), + Literal::Unit => ().store(&self.desugar.type_info.get_sort()).unwrap(), } } @@ -985,7 +983,7 @@ impl EGraph { } NormAction::LetLit(var, lit) => { let value = self.eval_lit(lit); - let etype = self.proof_state.type_info.infer_literal(lit); + let etype = self.desugar.type_info.infer_literal(lit); let present = self .global_bindings .insert(*var, (etype, value, self.timestamp)); @@ -1162,7 +1160,7 @@ impl EGraph { } pub fn set_underscores_for_desugaring(&mut self, underscores: usize) { - self.proof_state.desugar.number_underscores = underscores; + self.desugar.number_underscores = underscores; } fn process_command( @@ -1170,25 +1168,23 @@ impl EGraph { command: Command, stop: CompilerPassStop, ) -> Result, Error> { - let program = self.proof_state.desugar.desugar_program( - vec![command], - self.test_proofs, - self.seminaive, - )?; + let program = + self.desugar + .desugar_program(vec![command], self.test_proofs, self.seminaive)?; if stop == CompilerPassStop::Desugar { return Ok(program); } - let type_info_before = self.proof_state.type_info.clone(); + let type_info_before = self.desugar.type_info.clone(); - self.proof_state.type_info.typecheck_program(&program)?; + self.desugar.type_info.typecheck_program(&program)?; if stop == CompilerPassStop::TypecheckDesugared { return Ok(program); } // reset type info - self.proof_state.type_info = type_info_before; - self.proof_state.type_info.typecheck_program(&program)?; + self.desugar.type_info = type_info_before; + self.desugar.type_info.typecheck_program(&program)?; if stop == CompilerPassStop::TypecheckTermEncoding { return Ok(program); } @@ -1222,11 +1218,11 @@ impl EGraph { } pub fn parse_program(&self, input: &str) -> Result, Error> { - self.proof_state.desugar.parse_program(input) + self.desugar.parse_program(input) } pub fn parse_and_run_program(&mut self, input: &str) -> Result, Error> { - let parsed = self.proof_state.desugar.parse_program(input)?; + let parsed = self.desugar.parse_program(input)?; self.run_program(parsed) } @@ -1235,11 +1231,11 @@ impl EGraph { } pub(crate) fn get_sort(&self, value: &Value) -> Option<&ArcSort> { - self.proof_state.type_info.sorts.get(&value.tag) + self.desugar.type_info.sorts.get(&value.tag) } pub fn add_arcsort(&mut self, arcsort: ArcSort) -> Result<(), TypeError> { - self.proof_state.type_info.add_arcsort(arcsort) + self.desugar.type_info.add_arcsort(arcsort) } /// Gets the last extract report and returns it, if the last command saved it. diff --git a/src/proofs.rs b/src/proofs.rs deleted file mode 100644 index beb2b13a..00000000 --- a/src/proofs.rs +++ /dev/null @@ -1,11 +0,0 @@ -use crate::*; - -use crate::ast::desugar::Desugar; - -pub const RULE_PROOF_KEYWORD: &str = "rule-proof"; - -#[derive(Default, Clone)] -pub(crate) struct ProofState { - pub(crate) desugar: Desugar, - pub(crate) type_info: TypeInfo, -} diff --git a/src/serialize.rs b/src/serialize.rs index b4e9f886..fa75f583 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -184,7 +184,7 @@ impl EGraph { /// /// Checks for pattern created by Desugar.get_fresh fn is_temp_name(&self, name: String) -> bool { - let number_underscores = self.proof_state.desugar.number_underscores; + let number_underscores = self.desugar.number_underscores; let res = name.starts_with('v') && name.ends_with("_".repeat(number_underscores).as_str()) && name[1..name.len() - number_underscores] diff --git a/src/typecheck.rs b/src/typecheck.rs index 6a46fd3e..389f814c 100644 --- a/src/typecheck.rs +++ b/src/typecheck.rs @@ -112,7 +112,7 @@ impl<'a> Context<'a> { pub fn new(egraph: &'a EGraph) -> Self { Self { egraph, - unit: egraph.proof_state.type_info.sorts[&Symbol::from(UNIT_SYM)].clone(), + unit: egraph.desugar.type_info.sorts[&Symbol::from(UNIT_SYM)].clone(), types: Default::default(), errors: Vec::default(), unionfind: UnionFind::default(), @@ -396,7 +396,7 @@ impl<'a> Context<'a> { (self.add_node(ENode::Var(*sym)), ty) } Expr::Lit(lit) => { - let t = self.egraph.proof_state.type_info.infer_literal(lit); + let t = self.egraph.desugar.type_info.infer_literal(lit); (self.add_node(ENode::Literal(lit.clone())), Some(t)) } Expr::Call(sym, args) => { @@ -415,7 +415,7 @@ impl<'a> Context<'a> { .collect(); let t = f.schema.output.clone(); (self.add_node(ENode::Func(*sym, ids)), Some(t)) - } else if let Some(prims) = self.egraph.proof_state.type_info.primitives.get(sym) { + } else if let Some(prims) = self.egraph.desugar.type_info.primitives.get(sym) { let (ids, arg_tys): (Vec, Vec>) = args.iter().map(|arg| self.infer_query_expr(arg)).unzip(); @@ -533,13 +533,7 @@ impl<'a> ExprChecker<'a> for ActionChecker<'a> { } fn do_function(&mut self, f: Symbol, _args: Vec) -> Self::T { - let func_type = self - .egraph - .proof_state - .type_info - .func_types - .get(&f) - .unwrap(); + let func_type = self.egraph.desugar.type_info.func_types.get(&f).unwrap(); self.instructions.push(Instruction::CallFunction( f, func_type.has_default || !func_type.has_merge, @@ -601,11 +595,11 @@ trait ExprChecker<'a> { match expr { Expr::Lit(lit) => { let t = self.do_lit(lit); - Ok((t, self.egraph().proof_state.type_info.infer_literal(lit))) + Ok((t, self.egraph().desugar.type_info.infer_literal(lit))) } Expr::Var(sym) => self.infer_var(*sym), Expr::Call(sym, args) => { - if let Some(functype) = self.egraph().proof_state.type_info.func_types.get(sym) { + if let Some(functype) = self.egraph().desugar.type_info.func_types.get(sym) { assert!(functype.input.len() == args.len()); let mut ts = vec![]; @@ -615,8 +609,7 @@ trait ExprChecker<'a> { let t = self.do_function(*sym, ts); Ok((t, functype.output.clone())) - } else if let Some(prims) = self.egraph().proof_state.type_info.primitives.get(sym) - { + } else if let Some(prims) = self.egraph().desugar.type_info.primitives.get(sym) { let mut ts = Vec::with_capacity(args.len()); let mut tys = Vec::with_capacity(args.len()); for arg in args { @@ -880,11 +873,7 @@ impl EGraph { let (cost, expr) = self.extract( values[0], &mut termdag, - self.proof_state - .type_info - .sorts - .get(&values[0].tag) - .unwrap(), + self.desugar.type_info.sorts.get(&values[0].tag).unwrap(), ); let extracted = termdag.to_string(&expr); log::info!("extracted with cost {cost}: {}", extracted); diff --git a/src/typechecking.rs b/src/typechecking.rs index c8f048b6..0e1fa00c 100644 --- a/src/typechecking.rs +++ b/src/typechecking.rs @@ -1,4 +1,6 @@ -use crate::{proofs::RULE_PROOF_KEYWORD, *}; +use crate::*; + +pub const RULE_PROOF_KEYWORD: &str = "rule-proof"; #[derive(Clone, Debug)] pub struct FuncType { @@ -630,7 +632,7 @@ pub enum TypeError { #[error("Arity mismatch, expected {expected} args: {expr}")] Arity { expr: Expr, expected: usize }, #[error( - "Type mismatch: expr = {expr}, expected = {}, actual = {}, reason: {reason}", + "Type mismatch: expr = {expr}, expected = {}, actual = {}, reason: {reason}", .expected.name(), .actual.name(), )] Mismatch { From 7ec920ecb149d699b1a128aa98fd361e97b1dfc6 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Sat, 26 Aug 2023 07:38:43 -0400 Subject: [PATCH 2/2] Add type_info alias method --- src/function/mod.rs | 4 ++-- src/lib.rs | 18 +++++++++++------- src/typecheck.rs | 16 ++++++++-------- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/src/function/mod.rs b/src/function/mod.rs index 051971a1..93c32c8c 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -69,13 +69,13 @@ impl Function { pub fn new(egraph: &EGraph, decl: &FunctionDecl) -> Result { let mut input = Vec::with_capacity(decl.schema.input.len()); for s in &decl.schema.input { - input.push(match egraph.desugar.type_info.sorts.get(s) { + input.push(match egraph.type_info().sorts.get(s) { Some(sort) => sort.clone(), None => return Err(Error::TypeError(TypeError::Unbound(*s))), }) } - let output = match egraph.desugar.type_info.sorts.get(&decl.schema.output) { + let output = match egraph.type_info().sorts.get(&decl.schema.output) { Some(sort) => sort.clone(), None => return Err(Error::TypeError(TypeError::Unbound(decl.schema.output))), }; diff --git a/src/lib.rs b/src/lib.rs index 86538d15..e961c909 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -467,10 +467,10 @@ impl EGraph { pub fn eval_lit(&self, lit: &Literal) -> Value { match lit { - Literal::Int(i) => i.store(&self.desugar.type_info.get_sort()).unwrap(), - Literal::F64(f) => f.store(&self.desugar.type_info.get_sort()).unwrap(), - Literal::String(s) => s.store(&self.desugar.type_info.get_sort()).unwrap(), - Literal::Unit => ().store(&self.desugar.type_info.get_sort()).unwrap(), + Literal::Int(i) => i.store(&self.type_info().get_sort()).unwrap(), + Literal::F64(f) => f.store(&self.type_info().get_sort()).unwrap(), + Literal::String(s) => s.store(&self.type_info().get_sort()).unwrap(), + Literal::Unit => ().store(&self.type_info().get_sort()).unwrap(), } } @@ -983,7 +983,7 @@ impl EGraph { } NormAction::LetLit(var, lit) => { let value = self.eval_lit(lit); - let etype = self.desugar.type_info.infer_literal(lit); + let etype = self.type_info().infer_literal(lit); let present = self .global_bindings .insert(*var, (etype, value, self.timestamp)); @@ -1175,7 +1175,7 @@ impl EGraph { return Ok(program); } - let type_info_before = self.desugar.type_info.clone(); + let type_info_before = self.type_info().clone(); self.desugar.type_info.typecheck_program(&program)?; if stop == CompilerPassStop::TypecheckDesugared { @@ -1231,7 +1231,7 @@ impl EGraph { } pub(crate) fn get_sort(&self, value: &Value) -> Option<&ArcSort> { - self.desugar.type_info.sorts.get(&value.tag) + self.type_info().sorts.get(&value.tag) } pub fn add_arcsort(&mut self, arcsort: ArcSort) -> Result<(), TypeError> { @@ -1263,6 +1263,10 @@ impl EGraph { self.msgs.dedup_by(|a, b| a.is_empty() && b.is_empty()); std::mem::take(&mut self.msgs) } + + pub(crate) fn type_info(&self) -> &TypeInfo { + &self.desugar.type_info + } } #[derive(Debug, Error)] diff --git a/src/typecheck.rs b/src/typecheck.rs index 389f814c..a0f89e88 100644 --- a/src/typecheck.rs +++ b/src/typecheck.rs @@ -112,7 +112,7 @@ impl<'a> Context<'a> { pub fn new(egraph: &'a EGraph) -> Self { Self { egraph, - unit: egraph.desugar.type_info.sorts[&Symbol::from(UNIT_SYM)].clone(), + unit: egraph.type_info().sorts[&Symbol::from(UNIT_SYM)].clone(), types: Default::default(), errors: Vec::default(), unionfind: UnionFind::default(), @@ -396,7 +396,7 @@ impl<'a> Context<'a> { (self.add_node(ENode::Var(*sym)), ty) } Expr::Lit(lit) => { - let t = self.egraph.desugar.type_info.infer_literal(lit); + let t = self.egraph.type_info().infer_literal(lit); (self.add_node(ENode::Literal(lit.clone())), Some(t)) } Expr::Call(sym, args) => { @@ -415,7 +415,7 @@ impl<'a> Context<'a> { .collect(); let t = f.schema.output.clone(); (self.add_node(ENode::Func(*sym, ids)), Some(t)) - } else if let Some(prims) = self.egraph.desugar.type_info.primitives.get(sym) { + } else if let Some(prims) = self.egraph.type_info().primitives.get(sym) { let (ids, arg_tys): (Vec, Vec>) = args.iter().map(|arg| self.infer_query_expr(arg)).unzip(); @@ -533,7 +533,7 @@ impl<'a> ExprChecker<'a> for ActionChecker<'a> { } fn do_function(&mut self, f: Symbol, _args: Vec) -> Self::T { - let func_type = self.egraph.desugar.type_info.func_types.get(&f).unwrap(); + let func_type = self.egraph.type_info().func_types.get(&f).unwrap(); self.instructions.push(Instruction::CallFunction( f, func_type.has_default || !func_type.has_merge, @@ -595,11 +595,11 @@ trait ExprChecker<'a> { match expr { Expr::Lit(lit) => { let t = self.do_lit(lit); - Ok((t, self.egraph().desugar.type_info.infer_literal(lit))) + Ok((t, self.egraph().type_info().infer_literal(lit))) } Expr::Var(sym) => self.infer_var(*sym), Expr::Call(sym, args) => { - if let Some(functype) = self.egraph().desugar.type_info.func_types.get(sym) { + if let Some(functype) = self.egraph().type_info().func_types.get(sym) { assert!(functype.input.len() == args.len()); let mut ts = vec![]; @@ -609,7 +609,7 @@ trait ExprChecker<'a> { let t = self.do_function(*sym, ts); Ok((t, functype.output.clone())) - } else if let Some(prims) = self.egraph().desugar.type_info.primitives.get(sym) { + } else if let Some(prims) = self.egraph().type_info().primitives.get(sym) { let mut ts = Vec::with_capacity(args.len()); let mut tys = Vec::with_capacity(args.len()); for arg in args { @@ -873,7 +873,7 @@ impl EGraph { let (cost, expr) = self.extract( values[0], &mut termdag, - self.desugar.type_info.sorts.get(&values[0].tag).unwrap(), + self.type_info().sorts.get(&values[0].tag).unwrap(), ); let extracted = termdag.to_string(&expr); log::info!("extracted with cost {cost}: {}", extracted);