Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix desugaring foreign primitives #215

Merged
merged 2 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/ast/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ fn normalize_expr(
panic!("handled above");
}
Expr::Call(f, children) => {
let is_compute = TypeInfo::default().is_primitive(*f);
let is_compute = desugar.type_info.is_primitive(*f);
let mut new_children = vec![];
for child in children {
match child {
Expand Down Expand Up @@ -418,6 +418,7 @@ pub struct Desugar {
// TODO fix getting fresh names using modules
pub(crate) number_underscores: usize,
pub(crate) global_variables: HashSet<Symbol>,
pub(crate) type_info: TypeInfo,
}

impl Default for Desugar {
Expand All @@ -429,6 +430,7 @@ impl Default for Desugar {
parser: ast::parse::ProgramParser::new(),
number_underscores: 3,
global_variables: Default::default(),
type_info: TypeInfo::default(),
}
}
}
Expand Down Expand Up @@ -689,6 +691,7 @@ impl Clone for Desugar {
parser: ast::parse::ProgramParser::new(),
number_underscores: self.number_underscores,
global_variables: self.global_variables.clone(),
type_info: self.type_info.clone(),
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ impl Function {
pub fn new(egraph: &EGraph, decl: &FunctionDecl) -> Result<Self, Error> {
let mut input = Vec::with_capacity(decl.schema.input.len());
for s in &decl.schema.input {
input.push(match egraph.proof_state.type_info.sorts.get(s) {
input.push(match egraph.desugar.type_info.sorts.get(s) {
Some(sort) => sort.clone(),
None => return Err(Error::TypeError(TypeError::Unbound(*s))),
})
}

let output = match egraph.proof_state.type_info.sorts.get(&decl.schema.output) {
let output = match egraph.desugar.type_info.sorts.get(&decl.schema.output) {
Some(sort) => sort.clone(),
None => return Err(Error::TypeError(TypeError::Unbound(decl.schema.output))),
};
Expand Down
44 changes: 20 additions & 24 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ pub mod ast;
mod extract;
mod function;
mod gj;
mod proofs;
mod serialize;
pub mod sort;
mod termdag;
Expand All @@ -12,6 +11,7 @@ mod unionfind;
pub mod util;
mod value;

use ast::desugar::Desugar;
use extract::Extractor;
use hashbrown::hash_map::Entry;
use index::ColumnIndex;
Expand All @@ -21,8 +21,6 @@ use sort::*;
pub use termdag::{Term, TermDag};
use thiserror::Error;

use proofs::ProofState;

use symbolic_expressions::Sexp;

use ast::*;
Expand Down Expand Up @@ -201,7 +199,7 @@ impl FromStr for CompilerPassStop {
pub struct EGraph {
egraphs: Vec<Self>,
unionfind: UnionFind,
pub(crate) proof_state: ProofState,
pub(crate) desugar: Desugar,
functions: HashMap<Symbol, Function>,
rulesets: HashMap<Symbol, HashMap<Symbol, Rule>>,
ruleset_iteration: HashMap<Symbol, usize>,
Expand Down Expand Up @@ -240,7 +238,7 @@ impl Default for EGraph {
functions: Default::default(),
rulesets: Default::default(),
ruleset_iteration: Default::default(),
proof_state: ProofState::default(),
desugar: Desugar::default(),
global_bindings: Default::default(),
match_limit: usize::MAX,
node_limit: usize::MAX,
Expand Down Expand Up @@ -469,10 +467,10 @@ impl EGraph {

pub fn eval_lit(&self, lit: &Literal) -> Value {
match lit {
Literal::Int(i) => i.store(&self.proof_state.type_info.get_sort()).unwrap(),
Literal::F64(f) => f.store(&self.proof_state.type_info.get_sort()).unwrap(),
Literal::String(s) => s.store(&self.proof_state.type_info.get_sort()).unwrap(),
Literal::Unit => ().store(&self.proof_state.type_info.get_sort()).unwrap(),
Literal::Int(i) => i.store(&self.desugar.type_info.get_sort()).unwrap(),
Literal::F64(f) => f.store(&self.desugar.type_info.get_sort()).unwrap(),
Literal::String(s) => s.store(&self.desugar.type_info.get_sort()).unwrap(),
Literal::Unit => ().store(&self.desugar.type_info.get_sort()).unwrap(),
}
}

Expand Down Expand Up @@ -985,7 +983,7 @@ impl EGraph {
}
NormAction::LetLit(var, lit) => {
let value = self.eval_lit(lit);
let etype = self.proof_state.type_info.infer_literal(lit);
let etype = self.desugar.type_info.infer_literal(lit);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about making a type_info method on the EGraph, to not have to go through desugar every time, but opted to leave it as is for a minimal change.

Happy to refactor to that though, if that is preferred.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the type_info method idea!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I gave that a go, refactoring all non-mutable usages of egraph.desugar.type_info.

let present = self
.global_bindings
.insert(*var, (etype, value, self.timestamp));
Expand Down Expand Up @@ -1162,33 +1160,31 @@ impl EGraph {
}

pub fn set_underscores_for_desugaring(&mut self, underscores: usize) {
self.proof_state.desugar.number_underscores = underscores;
self.desugar.number_underscores = underscores;
}

fn process_command(
&mut self,
command: Command,
stop: CompilerPassStop,
) -> Result<Vec<NormCommand>, Error> {
let program = self.proof_state.desugar.desugar_program(
vec![command],
self.test_proofs,
self.seminaive,
)?;
let program =
self.desugar
.desugar_program(vec![command], self.test_proofs, self.seminaive)?;
if stop == CompilerPassStop::Desugar {
return Ok(program);
}

let type_info_before = self.proof_state.type_info.clone();
let type_info_before = self.desugar.type_info.clone();

self.proof_state.type_info.typecheck_program(&program)?;
self.desugar.type_info.typecheck_program(&program)?;
if stop == CompilerPassStop::TypecheckDesugared {
return Ok(program);
}

// reset type info
self.proof_state.type_info = type_info_before;
self.proof_state.type_info.typecheck_program(&program)?;
self.desugar.type_info = type_info_before;
self.desugar.type_info.typecheck_program(&program)?;
if stop == CompilerPassStop::TypecheckTermEncoding {
return Ok(program);
}
Expand Down Expand Up @@ -1222,11 +1218,11 @@ impl EGraph {
}

pub fn parse_program(&self, input: &str) -> Result<Vec<Command>, Error> {
self.proof_state.desugar.parse_program(input)
self.desugar.parse_program(input)
}

pub fn parse_and_run_program(&mut self, input: &str) -> Result<Vec<String>, Error> {
let parsed = self.proof_state.desugar.parse_program(input)?;
let parsed = self.desugar.parse_program(input)?;
self.run_program(parsed)
}

Expand All @@ -1235,11 +1231,11 @@ impl EGraph {
}

pub(crate) fn get_sort(&self, value: &Value) -> Option<&ArcSort> {
self.proof_state.type_info.sorts.get(&value.tag)
self.desugar.type_info.sorts.get(&value.tag)
}

pub fn add_arcsort(&mut self, arcsort: ArcSort) -> Result<(), TypeError> {
self.proof_state.type_info.add_arcsort(arcsort)
self.desugar.type_info.add_arcsort(arcsort)
}

/// Gets the last extract report and returns it, if the last command saved it.
Expand Down
11 changes: 0 additions & 11 deletions src/proofs.rs

This file was deleted.

2 changes: 1 addition & 1 deletion src/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ impl EGraph {
///
/// Checks for pattern created by Desugar.get_fresh
fn is_temp_name(&self, name: String) -> bool {
let number_underscores = self.proof_state.desugar.number_underscores;
let number_underscores = self.desugar.number_underscores;
let res = name.starts_with('v')
&& name.ends_with("_".repeat(number_underscores).as_str())
&& name[1..name.len() - number_underscores]
Expand Down
27 changes: 8 additions & 19 deletions src/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ impl<'a> Context<'a> {
pub fn new(egraph: &'a EGraph) -> Self {
Self {
egraph,
unit: egraph.proof_state.type_info.sorts[&Symbol::from(UNIT_SYM)].clone(),
unit: egraph.desugar.type_info.sorts[&Symbol::from(UNIT_SYM)].clone(),
types: Default::default(),
errors: Vec::default(),
unionfind: UnionFind::default(),
Expand Down Expand Up @@ -396,7 +396,7 @@ impl<'a> Context<'a> {
(self.add_node(ENode::Var(*sym)), ty)
}
Expr::Lit(lit) => {
let t = self.egraph.proof_state.type_info.infer_literal(lit);
let t = self.egraph.desugar.type_info.infer_literal(lit);
(self.add_node(ENode::Literal(lit.clone())), Some(t))
}
Expr::Call(sym, args) => {
Expand All @@ -415,7 +415,7 @@ impl<'a> Context<'a> {
.collect();
let t = f.schema.output.clone();
(self.add_node(ENode::Func(*sym, ids)), Some(t))
} else if let Some(prims) = self.egraph.proof_state.type_info.primitives.get(sym) {
} else if let Some(prims) = self.egraph.desugar.type_info.primitives.get(sym) {
let (ids, arg_tys): (Vec<Id>, Vec<Option<ArcSort>>) =
args.iter().map(|arg| self.infer_query_expr(arg)).unzip();

Expand Down Expand Up @@ -533,13 +533,7 @@ impl<'a> ExprChecker<'a> for ActionChecker<'a> {
}

fn do_function(&mut self, f: Symbol, _args: Vec<Self::T>) -> Self::T {
let func_type = self
.egraph
.proof_state
.type_info
.func_types
.get(&f)
.unwrap();
let func_type = self.egraph.desugar.type_info.func_types.get(&f).unwrap();
self.instructions.push(Instruction::CallFunction(
f,
func_type.has_default || !func_type.has_merge,
Expand Down Expand Up @@ -601,11 +595,11 @@ trait ExprChecker<'a> {
match expr {
Expr::Lit(lit) => {
let t = self.do_lit(lit);
Ok((t, self.egraph().proof_state.type_info.infer_literal(lit)))
Ok((t, self.egraph().desugar.type_info.infer_literal(lit)))
}
Expr::Var(sym) => self.infer_var(*sym),
Expr::Call(sym, args) => {
if let Some(functype) = self.egraph().proof_state.type_info.func_types.get(sym) {
if let Some(functype) = self.egraph().desugar.type_info.func_types.get(sym) {
assert!(functype.input.len() == args.len());

let mut ts = vec![];
Expand All @@ -615,8 +609,7 @@ trait ExprChecker<'a> {

let t = self.do_function(*sym, ts);
Ok((t, functype.output.clone()))
} else if let Some(prims) = self.egraph().proof_state.type_info.primitives.get(sym)
{
} else if let Some(prims) = self.egraph().desugar.type_info.primitives.get(sym) {
let mut ts = Vec::with_capacity(args.len());
let mut tys = Vec::with_capacity(args.len());
for arg in args {
Expand Down Expand Up @@ -880,11 +873,7 @@ impl EGraph {
let (cost, expr) = self.extract(
values[0],
&mut termdag,
self.proof_state
.type_info
.sorts
.get(&values[0].tag)
.unwrap(),
self.desugar.type_info.sorts.get(&values[0].tag).unwrap(),
);
let extracted = termdag.to_string(&expr);
log::info!("extracted with cost {cost}: {}", extracted);
Expand Down
6 changes: 4 additions & 2 deletions src/typechecking.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::{proofs::RULE_PROOF_KEYWORD, *};
use crate::*;

pub const RULE_PROOF_KEYWORD: &str = "rule-proof";

#[derive(Clone, Debug)]
pub struct FuncType {
Expand Down Expand Up @@ -630,7 +632,7 @@ pub enum TypeError {
#[error("Arity mismatch, expected {expected} args: {expr}")]
Arity { expr: Expr, expected: usize },
#[error(
"Type mismatch: expr = {expr}, expected = {}, actual = {}, reason: {reason}",
"Type mismatch: expr = {expr}, expected = {}, actual = {}, reason: {reason}",
.expected.name(), .actual.name(),
)]
Mismatch {
Expand Down