Skip to content

Commit

Permalink
Merge pull request #429 from Alex-Fischman/type_info_getters
Browse files Browse the repository at this point in the history
Remove useless type_info getters
  • Loading branch information
Alex-Fischman authored Oct 5, 2024
2 parents 606459d + 357a625 commit 0d4e688
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ impl EGraph {
let (cost, term) = self.extract(
values[0],
&mut termdag,
self.type_info().sorts.get(&values[0].tag).unwrap(),
self.type_info.sorts.get(&values[0].tag).unwrap(),
);
let extracted = termdag.to_string(&term);
log::info!("extracted with cost {cost}: {extracted}");
Expand Down
10 changes: 5 additions & 5 deletions src/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl Function {
pub(crate) fn new(egraph: &EGraph, decl: &ResolvedFunctionDecl) -> Result<Self, Error> {
let mut input = Vec::with_capacity(decl.schema.input.len());
for s in &decl.schema.input {
input.push(match egraph.type_info().sorts.get(s) {
input.push(match egraph.type_info.sorts.get(s) {
Some(sort) => sort.clone(),
None => {
return Err(Error::TypeError(TypeError::UndefinedSort(
Expand All @@ -97,7 +97,7 @@ impl Function {
})
}

let output = match egraph.type_info().sorts.get(&decl.schema.output) {
let output = match egraph.type_info.sorts.get(&decl.schema.output) {
Some(sort) => sort.clone(),
None => {
return Err(Error::TypeError(TypeError::UndefinedSort(
Expand All @@ -123,11 +123,11 @@ impl Function {
// Invariant: the last element in the stack is the return value.
let merge_vals = if let Some(merge_expr) = &decl.merge {
let (actions, mapped_expr) = merge_expr.to_core_actions(
egraph.type_info(),
&egraph.type_info,
&mut binding.clone(),
&mut ResolvedGen::new("$".to_string()),
)?;
let target = mapped_expr.get_corresponding_var_or_lit(egraph.type_info());
let target = mapped_expr.get_corresponding_var_or_lit(&egraph.type_info);
let program = egraph
.compile_expr(&binding, &actions, &target)
.map_err(Error::TypeErrors)?;
Expand All @@ -142,7 +142,7 @@ impl Function {
None
} else {
let (merge_action, _) = decl.merge_action.to_core_actions(
egraph.type_info(),
&egraph.type_info,
&mut binding.clone(),
&mut ResolvedGen::new("$".to_string()),
)?;
Expand Down
49 changes: 20 additions & 29 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -691,11 +691,11 @@ impl EGraph {

pub fn eval_lit(&self, lit: &Literal) -> Value {
match lit {
Literal::Int(i) => i.store(&self.type_info().get_sort_nofail()).unwrap(),
Literal::F64(f) => f.store(&self.type_info().get_sort_nofail()).unwrap(),
Literal::String(s) => s.store(&self.type_info().get_sort_nofail()).unwrap(),
Literal::Unit => ().store(&self.type_info().get_sort_nofail()).unwrap(),
Literal::Bool(b) => b.store(&self.type_info().get_sort_nofail()).unwrap(),
Literal::Int(i) => i.store(&self.type_info.get_sort_nofail()).unwrap(),
Literal::F64(f) => f.store(&self.type_info.get_sort_nofail()).unwrap(),
Literal::String(s) => s.store(&self.type_info.get_sort_nofail()).unwrap(),
Literal::Unit => ().store(&self.type_info.get_sort_nofail()).unwrap(),
Literal::Bool(b) => b.store(&self.type_info.get_sort_nofail()).unwrap(),
}
}

Expand Down Expand Up @@ -855,7 +855,7 @@ impl EGraph {
/// See also extract_value_to_string for convenience.
pub fn extract_value(&self, value: Value) -> (TermDag, Term) {
let mut termdag = TermDag::default();
let sort = self.type_info().sorts.get(&value.tag).unwrap();
let sort = self.type_info.sorts.get(&value.tag).unwrap();
let term = self.extract(value, &mut termdag, sort).1;
(termdag, term)
}
Expand Down Expand Up @@ -1054,7 +1054,7 @@ impl EGraph {
ruleset: Symbol,
) -> Result<Symbol, Error> {
let name = Symbol::from(name);
let core_rule = rule.to_canonicalized_core_rule(self.type_info())?;
let core_rule = rule.to_canonicalized_core_rule(&self.type_info)?;
let (query, actions) = (core_rule.body, core_rule.head);

let vars = query.get_vars();
Expand Down Expand Up @@ -1093,7 +1093,7 @@ impl EGraph {

fn eval_actions(&mut self, actions: &ResolvedActions) -> Result<(), Error> {
let (actions, _) = actions.to_core_actions(
self.type_info(),
&self.type_info,
&mut Default::default(),
&mut ResolvedGen::new("$".to_string()),
)?;
Expand All @@ -1120,11 +1120,11 @@ impl EGraph {
// then returns the value at the end.
fn eval_resolved_expr(&mut self, expr: &ResolvedExpr) -> Result<Value, Error> {
let (actions, mapped_expr) = expr.to_core_actions(
self.type_info(),
&self.type_info,
&mut Default::default(),
&mut ResolvedGen::new("$".to_string()),
)?;
let target = mapped_expr.get_corresponding_var_or_lit(self.type_info());
let target = mapped_expr.get_corresponding_var_or_lit(&self.type_info);
let program = self
.compile_expr(&Default::default(), &actions, &target)
.map_err(Error::TypeErrors)?;
Expand Down Expand Up @@ -1169,7 +1169,7 @@ impl EGraph {
head: ResolvedActions::default(),
body: facts.to_vec(),
};
let core_rule = rule.to_canonicalized_core_rule(self.type_info())?;
let core_rule = rule.to_canonicalized_core_rule(&self.type_info)?;
let query = core_rule.body;
let ordering = &query.get_vars();
let query = self.compile_gj_query(query, ordering);
Expand Down Expand Up @@ -1318,7 +1318,7 @@ impl EGraph {
let mut termdag = TermDag::default();
for expr in exprs {
let value = self.eval_resolved_expr(&expr)?;
let expr_type = expr.output_type(self.type_info());
let expr_type = expr.output_type(&self.type_info);
let term = self.extract(value, &mut termdag, &expr_type).1;
use std::io::Write;
writeln!(f, "{}", termdag.to_string(&term))
Expand All @@ -1333,7 +1333,7 @@ impl EGraph {

fn input_file(&mut self, func_name: Symbol, file: String) -> Result<(), Error> {
let function_type = self
.type_info()
.type_info
.lookup_user_func(func_name)
.unwrap_or_else(|| panic!("Unrecognized function name {}", func_name));
let func = self.functions.get_mut(&func_name).unwrap();
Expand Down Expand Up @@ -1398,7 +1398,7 @@ impl EGraph {
.into_iter()
.map(NCommand::CoreAction)
.collect::<Vec<_>>();
let commands: Vec<_> = self.type_info_mut().typecheck_program(&commands)?;
let commands: Vec<_> = self.type_info.typecheck_program(&commands)?;
for command in commands {
self.run_command(command)?;
}
Expand All @@ -1425,7 +1425,7 @@ impl EGraph {
self.desugar
.desugar_program(vec![command], self.test_proofs, self.seminaive)?;

let program = self.type_info_mut().typecheck_program(&program)?;
let program = self.type_info.typecheck_program(&program)?;

let program = remove_globals(&self.type_info, program, &mut self.desugar.fresh_gen);

Expand Down Expand Up @@ -1487,7 +1487,7 @@ impl EGraph {
}

pub(crate) fn get_sort_from_value(&self, value: &Value) -> Option<&ArcSort> {
self.type_info().sorts.get(&value.tag)
self.type_info.sorts.get(&value.tag)
}

/// Returns the first sort that satisfies the type and predicate if there's one.
Expand All @@ -1496,23 +1496,22 @@ impl EGraph {
&self,
pred: impl Fn(&Arc<S>) -> bool,
) -> Option<Arc<S>> {
self.type_info().get_sort_by(pred)
self.type_info.get_sort_by(pred)
}

/// Returns a sort based on the type
pub fn get_sort<S: Sort + Send + Sync>(&self) -> Option<Arc<S>> {
self.type_info().get_sort_by(|_| true)
self.type_info.get_sort_by(|_| true)
}

/// Add a user-defined sort
pub fn add_arcsort(&mut self, arcsort: ArcSort) -> Result<(), TypeError> {
self.type_info_mut()
.add_arcsort(arcsort, DUMMY_SPAN.clone())
self.type_info.add_arcsort(arcsort, DUMMY_SPAN.clone())
}

/// Add a user-defined primitive
pub fn add_primitive(&mut self, prim: impl Into<Primitive>) {
self.type_info_mut().add_primitive(prim)
self.type_info.add_primitive(prim)
}

/// Gets the last extract report and returns it, if the last command saved it.
Expand All @@ -1538,14 +1537,6 @@ impl EGraph {
self.msgs.dedup_by(|a, b| a.is_empty() && b.is_empty());
std::mem::take(&mut self.msgs)
}

pub(crate) fn type_info(&self) -> &TypeInfo {
&self.type_info
}

pub(crate) fn type_info_mut(&mut self) -> &mut TypeInfo {
&mut self.type_info
}
}

// Currently, only the following errors can thrown without location information:
Expand Down
6 changes: 3 additions & 3 deletions src/sort/fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ impl PrimitiveLike for Apply {
/// so that we can re-use the logic for primitive and regular functions.
fn call_fn(egraph: &mut EGraph, name: &Symbol, types: Vec<ArcSort>, args: Vec<Value>) -> Value {
// Make a call with temp vars as each of the args
let resolved_call = ResolvedCall::from_resolution(name, types.as_slice(), egraph.type_info());
let resolved_call = ResolvedCall::from_resolution(name, types.as_slice(), &egraph.type_info);
let arg_vars: Vec<_> = types
.into_iter()
// Skip last sort which is the output sort
Expand All @@ -410,12 +410,12 @@ fn call_fn(egraph: &mut EGraph, name: &Symbol, types: Vec<ArcSort>, args: Vec<Va
// Similar to how the merge function is created in `Function::new`
let (actions, mapped_expr) = expr
.to_core_actions(
egraph.type_info(),
&egraph.type_info,
&mut binding.clone(),
&mut ResolvedGen::new("$".to_string()),
)
.unwrap();
let target = mapped_expr.get_corresponding_var_or_lit(egraph.type_info());
let target = mapped_expr.get_corresponding_var_or_lit(&egraph.type_info);
let program = egraph.compile_expr(&binding, &actions, &target).unwrap();
// Similar to how the `MergeFn::Expr` case is handled in `Egraph::perform_set`
// egraph.rebuild().unwrap();
Expand Down

0 comments on commit 0d4e688

Please sign in to comment.