diff --git a/src/conversions.rs b/src/conversions.rs index f918bf100..408bb46f5 100644 --- a/src/conversions.rs +++ b/src/conversions.rs @@ -114,10 +114,10 @@ impl TermConverter<'_> { let name = self.string_term_to_string(name); let code_vec = self.term_conslist_to_vec(code, "Code"); let mut instrs = vec![]; - // let mut memo = HashMap::::new(); + let mut memo = HashMap::::new(); for t in code_vec { - self.term_to_instructions(&t, &mut instrs); + self.term_to_instructions(&t, &mut instrs, &mut memo); } BasicBlock { @@ -147,10 +147,10 @@ impl TermConverter<'_> { res } - fn term_to_instructions(&mut self, id: &TermId, res: &mut Vec) { + fn term_to_instructions(&mut self, id: &TermId, res: &mut Vec, memo: &mut HashMap) { match_term_app!(self.get(id); { ("Print", [arg]) => { - let arg = self.term_to_code(arg, res, None); + let arg = self.term_to_code(arg, res, None, memo); res.push(Instruction::Effect { op: EffectOps::Print, @@ -163,12 +163,12 @@ impl TermConverter<'_> { ("End", []) => {}, ("Assign", [dest, src]) => { let dest = self.string_term_to_string(dest); - self.term_to_code(src, res, Some(dest.to_string())); + self.term_to_code(src, res, Some(dest.to_string()), memo); }, (op @ ("store" | "free"), args) => { let args = args .iter() - .map(|arg| self.term_to_code(arg, res, None)) + .map(|arg| self.term_to_code(arg, res, None, memo)) .collect::>(); @@ -183,7 +183,7 @@ impl TermConverter<'_> { ("alloc", [atype, dest, arg]) => { let atype = self.term_to_type(atype); let dest = self.string_term_to_string(dest); - let arg = self.term_to_code(arg, res, None); + let arg = self.term_to_code(arg, res, None, memo); res.push(Instruction::Value { dest, args: vec![arg], @@ -203,109 +203,111 @@ impl TermConverter<'_> { id: &TermId, res: &mut Vec, assign_to: Option, + memo: &mut HashMap ) -> String { + if memo.contains_key(id) && assign_to.is_none() { + return memo[id].clone(); + } + let dest = match &assign_to { Some(dest) => dest.clone(), None => self.optimizer.fresh_var(), }; - match self.get(id) { - Term::Lit(literal) => { - res.push(Instruction::Constant { - dest: dest.clone(), - op: bril_rs::ConstOps::Const, - value: self.optimizer.literal_to_bril(&literal), - pos: None, - const_type: self.optimizer.literal_to_type(&literal), - }); - dest - } - Term::Var(var) => { - if let Some(_output) = assign_to { - panic!("Cannot assign var to var") - } else { - var.to_string() + let ret = + match self.get(id) { + Term::Lit(literal) => { + res.push(Instruction::Constant { + dest: dest.clone(), + op: bril_rs::ConstOps::Const, + value: self.optimizer.literal_to_bril(&literal), + pos: None, + const_type: self.optimizer.literal_to_type(&literal), + }); + dest } - } - t => { - match_term_app!(t; { - ("Var", [arg]) => { - match self.get(arg) { - Term::Lit(egglog::ast::Literal::String(var)) => var.to_string(), - _ => panic!("expected string literal for var"), - } - }, - ("ReturnValue", [arg]) => self.term_to_code(arg, res, assign_to), - (op @ ("True" | "False" | "Int" | "Float" | "Char"), [ty, args @ ..]) => { - let lit = match (op, args) { - ("True", []) => Literal::Bool(true), - ("False", []) => Literal::Bool(false), - ("Int", [arg]) => { - let arg = self.get(arg); - let arg_s = self.termdag.to_string(&arg); - Literal::Int(arg_s.parse::().unwrap()) - } - ("Float", [arg]) => { - let arg = self.get(arg); - let arg_s = self.termdag.to_string(&arg); - Literal::Float(arg_s.parse::().unwrap()) + t => { + match_term_app!(t; { + ("Var", [arg]) => { + match self.get(arg) { + Term::Lit(egglog::ast::Literal::String(var)) => var.to_string(), + _ => panic!("expected string literal for var"), } - ("Char", [arg]) => { - let arg = self.get(arg); - let arg_s = self.termdag.to_string(&arg); - assert_eq!(arg_s.len(), 1); - Literal::Char(arg_s.chars().next().unwrap()) - } - _ => panic!("unexpected args to literal in term_to_code") - }; - res.push(Instruction::Constant { - dest: dest.clone(), - op: bril_rs::ConstOps::Const, - value: lit, - pos: None, - const_type: self.term_to_type(ty), - }); - dest - }, - ("phi", [etype, arg1, arg2, label1, label2]) => { - let etype = self.term_to_type(etype); - let arg1 = self.term_to_code(arg1, res, None); - let arg2 = self.term_to_code(arg2, res, None); - let label1 = self.string_term_to_string(label1); - let label2 = self.string_term_to_string(label2); - res.push(Instruction::Value { - dest: dest.clone(), - args: vec![arg1, arg2], - funcs: vec![], - op: ValueOps::Phi, - labels: vec![label1, label2], - pos: None, - op_type: etype, - }); - dest - }, - (op, args) => { - assert!(op != "Void"); - let etype = self.term_to_type(&args[0]); - let args_vars = args - .iter() - .skip(1) - .map(|arg| self.term_to_code(arg, res, None)) - .collect::>(); - res.push(Instruction::Value { - dest: dest.clone(), - args: args_vars, - funcs: vec![], - op: self.optimizer.egglog_op_to_bril(op), - labels: vec![], - pos: None, - op_type: etype, - }); - dest - } - }) - } - } + }, + ("ReturnValue", [arg]) => self.term_to_code(arg, res, assign_to, memo), + (op @ ("True" | "False" | "Int" | "Float" | "Char"), [ty, args @ ..]) => { + let lit = match (op, args) { + ("True", []) => Literal::Bool(true), + ("False", []) => Literal::Bool(false), + ("Int", [arg]) => { + let arg = self.get(arg); + let arg_s = self.termdag.to_string(&arg); + Literal::Int(arg_s.parse::().unwrap()) + } + ("Float", [arg]) => { + let arg = self.get(arg); + let arg_s = self.termdag.to_string(&arg); + Literal::Float(arg_s.parse::().unwrap()) + } + ("Char", [arg]) => { + let arg = self.get(arg); + let arg_s = self.termdag.to_string(&arg); + assert_eq!(arg_s.len(), 1); + Literal::Char(arg_s.chars().next().unwrap()) + } + _ => panic!("unexpected args to literal in term_to_code") + }; + res.push(Instruction::Constant { + dest: dest.clone(), + op: bril_rs::ConstOps::Const, + value: lit, + pos: None, + const_type: self.term_to_type(ty), + }); + dest + }, + ("phi", [etype, arg1, arg2, label1, label2]) => { + let etype = self.term_to_type(etype); + let arg1 = self.term_to_code(arg1, res, None, memo); + let arg2 = self.term_to_code(arg2, res, None, memo); + let label1 = self.string_term_to_string(label1); + let label2 = self.string_term_to_string(label2); + res.push(Instruction::Value { + dest: dest.clone(), + args: vec![arg1, arg2], + funcs: vec![], + op: ValueOps::Phi, + labels: vec![label1, label2], + pos: None, + op_type: etype, + }); + dest + }, + (op, args) => { + assert!(op != "Void"); + let etype = self.term_to_type(&args[0]); + let args_vars = args + .iter() + .skip(1) + .map(|arg| self.term_to_code(arg, res, None, memo)) + .collect::>(); + res.push(Instruction::Value { + dest: dest.clone(), + args: args_vars, + funcs: vec![], + op: self.optimizer.egglog_op_to_bril(op), + labels: vec![], + pos: None, + op_type: etype, + }); + dest + } + }) + } + }; + + memo.insert(*id, ret.clone()); + ret } pub(crate) fn term_to_type(&self, id: &TermId) -> Type { diff --git a/tests/snapshots/files__add.snap b/tests/snapshots/files__add.snap index a3e74871d..6543f4067 100644 --- a/tests/snapshots/files__add.snap +++ b/tests/snapshots/files__add.snap @@ -7,8 +7,7 @@ expression: "format!(\"{}\", res)" v0: int = const 1; v1: int = const 2; v2: int = const 3; - v0_: int = const 3; - print v0_; + print v2; ret; .sblock___0: .exit___: diff --git a/tests/snapshots/files__add_no_opt.snap b/tests/snapshots/files__add_no_opt.snap index 4932bfc19..f000bd986 100644 --- a/tests/snapshots/files__add_no_opt.snap +++ b/tests/snapshots/files__add_no_opt.snap @@ -6,13 +6,8 @@ expression: "format!(\"{}\", res)" .entry___: v0: int = const 1; v1: int = const 2; - v0_: int = const 1; - v1_: int = const 2; - v2: int = add v0_ v1_; - v3_: int = const 1; - v4_: int = const 2; - v2_: int = add v3_ v4_; - print v2_; + v2: int = add v0 v1; + print v2; ret; .sblock___0: .exit___: diff --git a/tests/snapshots/files__block-diamond.snap b/tests/snapshots/files__block-diamond.snap index 98de39b23..ae6f858d5 100644 --- a/tests/snapshots/files__block-diamond.snap +++ b/tests/snapshots/files__block-diamond.snap @@ -7,9 +7,7 @@ expression: "format!(\"{}\", res)" one: int = const 1; two: int = const 2; x: int = const 0; - v0_: int = const 1; - v1_: int = const 2; - a_cond: bool = lt v0_ v1_; + a_cond: bool = lt one two; br a_cond .sblock___4 .sblock___5; .sblock___4: jmp .sblock___2; diff --git a/tests/snapshots/files__block-diamond_no_opt.snap b/tests/snapshots/files__block-diamond_no_opt.snap index 98de39b23..ae6f858d5 100644 --- a/tests/snapshots/files__block-diamond_no_opt.snap +++ b/tests/snapshots/files__block-diamond_no_opt.snap @@ -7,9 +7,7 @@ expression: "format!(\"{}\", res)" one: int = const 1; two: int = const 2; x: int = const 0; - v0_: int = const 1; - v1_: int = const 2; - a_cond: bool = lt v0_ v1_; + a_cond: bool = lt one two; br a_cond .sblock___4 .sblock___5; .sblock___4: jmp .sblock___2; diff --git a/tests/snapshots/files__diamond.snap b/tests/snapshots/files__diamond.snap index 0514fdb8b..b5d444ac6 100644 --- a/tests/snapshots/files__diamond.snap +++ b/tests/snapshots/files__diamond.snap @@ -5,9 +5,7 @@ expression: "format!(\"{}\", res)" @main { .entry___: x: int = const 4; - v0_: int = const 4; - v1_: int = const 4; - cond: bool = lt v0_ v1_; + cond: bool = lt x x; br cond .sblock___3 .sblock___4; .sblock___3: jmp .sblock___1; diff --git a/tests/snapshots/files__diamond_no_opt.snap b/tests/snapshots/files__diamond_no_opt.snap index 0514fdb8b..b5d444ac6 100644 --- a/tests/snapshots/files__diamond_no_opt.snap +++ b/tests/snapshots/files__diamond_no_opt.snap @@ -5,9 +5,7 @@ expression: "format!(\"{}\", res)" @main { .entry___: x: int = const 4; - v0_: int = const 4; - v1_: int = const 4; - cond: bool = lt v0_ v1_; + cond: bool = lt x x; br cond .sblock___3 .sblock___4; .sblock___3: jmp .sblock___1; diff --git a/tests/snapshots/files__two_fns_no_opt.snap b/tests/snapshots/files__two_fns_no_opt.snap index 013fb98db..13fc24f68 100644 --- a/tests/snapshots/files__two_fns_no_opt.snap +++ b/tests/snapshots/files__two_fns_no_opt.snap @@ -6,9 +6,7 @@ expression: "format!(\"{}\", res)" .entry___: v0: int = const 1; v1: int = const 2; - v0_: int = const 1; - v1_: int = const 2; - v2: int = add v0_ v1_; + v2: int = add v0 v1; ret; .sblock___0: .exit___: @@ -17,9 +15,7 @@ expression: "format!(\"{}\", res)" .entry___: v0: int = const 1; v1: int = const 2; - v2_: int = const 1; - v3_: int = const 2; - v2: int = sub v2_ v3_; + v2: int = sub v0 v1; ret; .sblock___0: .exit___: