diff --git a/src/conversions.rs b/src/conversions.rs index a32fbb43c..14db2a061 100644 --- a/src/conversions.rs +++ b/src/conversions.rs @@ -9,7 +9,7 @@ use crate::{ }; use bril_rs::{Argument, Code, EffectOps, Instruction, Literal, Program, Type, ValueOps}; use egglog::ast::{Expr, Symbol}; -use egglog::{match_term_app, Term, TermDag}; +use egglog::{match_term_app, Term, TermDag, TermId}; use ordered_float::OrderedFloat; pub(crate) struct TermConverter<'a> { @@ -18,73 +18,77 @@ pub(crate) struct TermConverter<'a> { } impl TermConverter<'_> { - pub(crate) fn term_to_structured_func(&mut self, t: &Term) -> StructuredFunction { - match_term_app!(t; { + pub(crate) fn get(&self, id: &TermId) -> Term { + self.termdag.get(*id) + } + + pub(crate) fn term_to_structured_func(&mut self, id: &TermId) -> StructuredFunction { + match_term_app!(self.get(id); { ("Func", [func_name, argslist, body]) => { - let args = self.term_conslist_to_vec(&self.termdag.get(*argslist), "Arg") + let args = self.term_conslist_to_vec(argslist, "Arg") .into_iter() .map(|arg| self.term_to_argument(&arg)) .collect(); - let fname = Optimizer::string_term_to_string(&self.termdag.get(*func_name)); + let fname = self.string_term_to_string(func_name); StructuredFunction { name: fname.to_string(), args, - block: self.term_to_structured_block(&self.termdag.get(*body)), + block: self.term_to_structured_block(body), } } (head, _) => panic!("unexpected head {}, in {}:{}:{}", head, file!(), line!(), column!()) }) } - fn term_to_argument(&self, term: &Term) -> Argument { - match_term_app!(term; { + fn term_to_argument(&self, id: &TermId) -> Argument { + match_term_app!(self.get(id); { ("Arg", [name, ty]) => { - let name = Optimizer::string_term_to_string(&self.termdag.get(*name)); + let name = self.string_term_to_string(name); Argument { name: name.to_string(), - arg_type: self.term_to_type(&self.termdag.get(*ty)), + arg_type: self.term_to_type(ty), } } (head, _) => panic!("unexpected head {}, in {}:{}:{}", head, file!(), line!(), column!()) }) } - pub(crate) fn term_to_structured_block(&mut self, term: &Term) -> StructuredBlock { - match_term_app!(term; { + pub(crate) fn term_to_structured_block(&mut self, id: &TermId) -> StructuredBlock { + match_term_app!(self.get(id); { ("Block", [block]) => { - StructuredBlock::Block(Box::new(self.term_to_structured_block(&self.termdag.get(*block)))) + StructuredBlock::Block(Box::new(self.term_to_structured_block(block))) }, ("Basic", [basic_block]) => { - StructuredBlock::Basic(Box::new(self.term_to_basic_block(&self.termdag.get(*basic_block)))) + StructuredBlock::Basic(Box::new(self.term_to_basic_block(basic_block))) }, ("Ite", [name, then_branch, else_branch]) => { - let string = Optimizer::string_term_to_string(&self.termdag.get(*name)); + let string = self.string_term_to_string(name); StructuredBlock::Ite( string.to_string(), - Box::new(self.term_to_structured_block(&self.termdag.get(*then_branch))), - Box::new(self.term_to_structured_block(&self.termdag.get(*else_branch))), + Box::new(self.term_to_structured_block(then_branch)), + Box::new(self.term_to_structured_block(else_branch)), ) }, ("Loop", [block]) => { - StructuredBlock::Loop(Box::new(self.term_to_structured_block(&self.termdag.get(*block)))) + StructuredBlock::Loop(Box::new(self.term_to_structured_block(block))) }, ("Sequence", [block, rest]) => StructuredBlock::Sequence(vec![ - self.term_to_structured_block(&self.termdag.get(*block)), - self.term_to_structured_block(&self.termdag.get(*rest)), + self.term_to_structured_block(block), + self.term_to_structured_block(rest), ]), ("Break", [n]) => { - if let Term::Lit(egglog::ast::Literal::Int(n)) = self.termdag.get(*n) { + if let Term::Lit(egglog::ast::Literal::Int(n)) = self.get(n) { StructuredBlock::Break(n.try_into().unwrap()) } else { panic!("expected int literal for break"); } }, ("Return", [val]) => { - match_term_app!(self.termdag.get(*val); { + match_term_app!(self.get(val); { ("Void", _) => StructuredBlock::Return(None), ("ReturnValue", [arg]) => { - match self.termdag.get(*arg) { + match self.get(arg) { Term::Lit(egglog::ast::Literal::String(s)) => { StructuredBlock::Return(Some(s.to_string())) } @@ -98,11 +102,11 @@ impl TermConverter<'_> { }) } - pub(crate) fn term_to_basic_block(&mut self, term: &Term) -> BasicBlock { - match_term_app!(term; { + pub(crate) fn term_to_basic_block(&mut self, id: &TermId) -> BasicBlock { + match_term_app!(self.get(id); { ("BlockNamed", [name, code]) => { - let name = Optimizer::string_term_to_string(&self.termdag.get(*name)); - let code_vec = self.term_conslist_to_vec(&self.termdag.get(*code), "Code"); + let name = self.string_term_to_string(name); + let code_vec = self.term_conslist_to_vec(code, "Code"); let mut instrs = vec![]; // let mut memo = HashMap::::new(); @@ -121,27 +125,27 @@ impl TermConverter<'_> { }) } - fn term_conslist_to_vec_helper(&self, term: &Term, res: &mut Vec, prefix: &str) { - match_term_app!(term; { + fn term_conslist_to_vec_helper(&self, id: &TermId, res: &mut Vec, prefix: &str) { + match_term_app!(self.get(id); { (op, [head, tail]) if op == prefix.to_string() + "Cons" => { - res.push(self.termdag.get(*head)); - self.term_conslist_to_vec_helper(&self.termdag.get(*tail), res, prefix); + res.push(*head); + self.term_conslist_to_vec_helper(tail, res, prefix); }, (op, []) if op == prefix.to_string() + "Nil" => {} (head, _) => panic!("unexpected head {}, in {}:{}:{}", head, file!(), line!(), column!()) }) } - fn term_conslist_to_vec(&self, term: &Term, prefix: &str) -> Vec { + fn term_conslist_to_vec(&self, id: &TermId, prefix: &str) -> Vec { let mut res = vec![]; - self.term_conslist_to_vec_helper(term, &mut res, prefix); + self.term_conslist_to_vec_helper(id, &mut res, prefix); res } - fn term_to_instructions(&mut self, term: &Term, res: &mut Vec) { - match_term_app!(term; { + fn term_to_instructions(&mut self, id: &TermId, res: &mut Vec) { + match_term_app!(self.get(id); { ("Print", [arg]) => { - let arg = self.term_to_code(&self.termdag.get(*arg), res, None); + let arg = self.term_to_code(arg, res, None); res.push(Instruction::Effect { op: EffectOps::Print, @@ -153,13 +157,13 @@ impl TermConverter<'_> { }, ("End", []) => {}, ("Assign", [dest, src]) => { - let dest = Optimizer::string_term_to_string(&self.termdag.get(*dest)); - self.term_to_code(&self.termdag.get(*src), res, Some(dest.to_string())); + let dest = self.string_term_to_string(dest); + self.term_to_code(src, res, Some(dest.to_string())); }, (op @ ("store" | "free"), args) => { let args = args .iter() - .map(|arg| self.term_to_code(&self.termdag.get(*arg), res, None)) + .map(|arg| self.term_to_code(arg, res, None)) .collect::>(); @@ -172,9 +176,9 @@ impl TermConverter<'_> { }); }, ("alloc", [atype, dest, arg]) => { - let atype = self.term_to_type(&self.termdag.get(*atype)); - let dest = Optimizer::string_term_to_string(&self.termdag.get(*dest)); - let arg = self.term_to_code(&self.termdag.get(*arg), res, None); + let atype = self.term_to_type(atype); + let dest = self.string_term_to_string(dest); + let arg = self.term_to_code(arg, res, None); res.push(Instruction::Value { dest, args: vec![arg], @@ -191,7 +195,7 @@ impl TermConverter<'_> { pub(crate) fn term_to_code( &mut self, - term: &Term, + id: &TermId, res: &mut Vec, assign_to: Option, ) -> String { @@ -200,14 +204,14 @@ impl TermConverter<'_> { None => self.optimizer.fresh_var(), }; - match term { + match self.get(id) { Term::Lit(literal) => { res.push(Instruction::Constant { dest: dest.clone(), op: bril_rs::ConstOps::Const, - value: self.optimizer.literal_to_bril(literal), + value: self.optimizer.literal_to_bril(&literal), pos: None, - const_type: self.optimizer.literal_to_type(literal), + const_type: self.optimizer.literal_to_type(&literal), }); dest } @@ -218,31 +222,31 @@ impl TermConverter<'_> { var.to_string() } } - _ => { - match_term_app!(term; { + t => { + match_term_app!(t; { ("Var", [arg]) => { - match self.termdag.get(*arg) { + match self.get(arg) { Term::Lit(egglog::ast::Literal::String(var)) => var.to_string(), _ => panic!("expected string literal for var"), } }, - ("ReturnValue", [arg]) => self.term_to_code(&self.termdag.get(*arg), res, assign_to), + ("ReturnValue", [arg]) => self.term_to_code(arg, res, assign_to), (op @ ("True" | "False" | "Int" | "Float" | "Char"), [ty, args @ ..]) => { let lit = match (op, args) { ("True", []) => Literal::Bool(true), ("False", []) => Literal::Bool(false), ("Int", [arg]) => { - let arg = self.termdag.get(*arg); + let arg = self.get(arg); let arg_s = self.termdag.to_string(&arg); Literal::Int(arg_s.parse::().unwrap()) } ("Float", [arg]) => { - let arg = self.termdag.get(*arg); + let arg = self.get(arg); let arg_s = self.termdag.to_string(&arg); Literal::Float(arg_s.parse::().unwrap()) } ("Char", [arg]) => { - let arg = self.termdag.get(*arg); + let arg = self.get(arg); let arg_s = self.termdag.to_string(&arg); assert_eq!(arg_s.len(), 1); Literal::Char(arg_s.chars().next().unwrap()) @@ -254,16 +258,16 @@ impl TermConverter<'_> { op: bril_rs::ConstOps::Const, value: lit, pos: None, - const_type: self.term_to_type(&self.termdag.get(*ty)), + const_type: self.term_to_type(ty), }); dest }, ("phi", [etype, arg1, arg2, label1, label2]) => { - let etype = self.term_to_type(&self.termdag.get(*etype)); - let arg1 = self.term_to_code(&self.termdag.get(*arg1), res, None); - let arg2 = self.term_to_code(&self.termdag.get(*arg2), res, None); - let label1 = Optimizer::string_term_to_string(&self.termdag.get(*label1)); - let label2 = Optimizer::string_term_to_string(&self.termdag.get(*label2)); + let etype = self.term_to_type(etype); + let arg1 = self.term_to_code(arg1, res, None); + let arg2 = self.term_to_code(arg2, res, None); + let label1 = self.string_term_to_string(label1); + let label2 = self.string_term_to_string(label2); res.push(Instruction::Value { dest: dest.clone(), args: vec![arg1, arg2], @@ -277,11 +281,11 @@ impl TermConverter<'_> { }, (op, args) => { assert!(op != "Void"); - let etype = self.term_to_type(&self.termdag.get(args[0])); + let etype = self.term_to_type(&args[0]); let args_vars = args .iter() .skip(1) - .map(|arg| self.term_to_code(&self.termdag.get(*arg), res, None)) + .map(|arg| self.term_to_code(arg, res, None)) .collect::>(); res.push(Instruction::Value { dest: dest.clone(), @@ -299,16 +303,24 @@ impl TermConverter<'_> { } } - pub(crate) fn term_to_type(&self, term: &Term) -> Type { - match_term_app!(term; { + pub(crate) fn term_to_type(&self, id: &TermId) -> Type { + match_term_app!(self.get(id); { ("IntT", []) => Type::Int, ("BoolT", []) => Type::Bool, ("FloatT", []) => Type::Float, ("CharT", []) => Type::Char, - ("PointerT", [child]) => Type::Pointer(Box::new(self.term_to_type(&self.termdag.get(*child)))), + ("PointerT", [child]) => Type::Pointer(Box::new(self.term_to_type(child))), (head, _) => panic!("unexpected head {}, in {}:{}:{}", head, file!(), line!(), column!()) }) } + + fn string_term_to_string(&self, id: &TermId) -> String { + if let Term::Lit(egglog::ast::Literal::String(string)) = self.get(id) { + string.to_string() + } else { + panic!("expected string literal"); + } + } } impl Optimizer { @@ -321,7 +333,7 @@ impl Optimizer { optimizer: self, termdag, }; - converter.term_to_structured_func(term) + converter.term_to_structured_func(&termdag.lookup(term)) } pub(crate) fn func_to_expr(&mut self, func: &StructuredFunction) -> Expr { @@ -417,14 +429,6 @@ impl Optimizer { Expr::Call("Var".into(), vec![self.string_to_expr(string)]) } - fn string_term_to_string(term: &Term) -> String { - if let Term::Lit(egglog::ast::Literal::String(string)) = term { - string.to_string() - } else { - panic!("expected string literal"); - } - } - pub(crate) fn convert_basic_block(&mut self, block: &BasicBlock) -> Expr { // leave prints in order // leave any effects in order