diff --git a/src/cfg/mod.rs b/src/cfg/mod.rs index 7dcafb96d..dcf048f8d 100644 --- a/src/cfg/mod.rs +++ b/src/cfg/mod.rs @@ -105,7 +105,7 @@ impl FromStr for BlockName { /// kinds of name. #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum Identifier { - Name(Box), + Name(String), Num(usize), } diff --git a/src/conversions.rs b/src/conversions.rs index 45ca545c7..9e4336fba 100644 --- a/src/conversions.rs +++ b/src/conversions.rs @@ -326,7 +326,7 @@ impl Optimizer { dest: dest.clone(), args: args_vars, funcs: vec![], - op: self.egglog_op_to_bril(*op), + op: egglog_op_to_bril(*op), labels: vec![], pos: None, op_type: etype, @@ -570,11 +570,6 @@ impl Optimizer { Expr::Call("Assign".into(), vec![self.string_to_expr(dest), expr]) } - pub(crate) fn egglog_op_to_bril(&mut self, op: Symbol) -> ValueOps { - let with_quotes = "\"".to_owned() + &op.to_string() + "\""; - serde_json::from_str(&with_quotes).unwrap() - } - pub(crate) fn effect_op_to_egglog(&mut self, op: EffectOps) -> Symbol { let opstr = op.to_string(); if opstr == "print" { @@ -720,3 +715,8 @@ impl Optimizer { Ok(()) } } + +pub(crate) fn egglog_op_to_bril(op: Symbol) -> ValueOps { + let with_quotes = "\"".to_owned() + &op.to_string() + "\""; + serde_json::from_str(&with_quotes).unwrap() +} diff --git a/src/rvsdg/mod.rs b/src/rvsdg/mod.rs index 942cefe3b..1a6a21d20 100644 --- a/src/rvsdg/mod.rs +++ b/src/rvsdg/mod.rs @@ -37,10 +37,13 @@ pub(crate) mod rvsdg2svg; use std::fmt; use bril_rs::{ConstOps, Literal, Type, ValueOps}; +use egglog::EGraph; +use ordered_float::OrderedFloat; use thiserror::Error; use crate::{ cfg::{CfgProgram, Identifier}, + conversions::egglog_op_to_bril, EggCCError, }; @@ -166,3 +169,284 @@ pub(crate) fn cfg_to_rvsdg(cfg: &CfgProgram) -> std::result::Result) -> egglog::ast::Expr { + use egglog::ast::{Expr::*, Literal::*}; + let f = |operands: &Vec| { + operands + .iter() + .map(|op| self.operand_to_egglog_expr(op)) + .collect() + }; + fn from_ty(ty: &Type) -> egglog::ast::Expr { + match 
ty { + Type::Int => Call("IntT".into(), vec![]), + Type::Bool => Call("BoolT".into(), vec![]), + Type::Float => Call("FloatT".into(), vec![]), + Type::Char => Call("CharT".into(), vec![]), + Type::Pointer(ty) => Call("PointerT".into(), vec![from_ty(ty.as_ref())]), + } + } + match expr { + Expr::Op(op, operands) => Call(op.to_string().into(), f(operands)), + Expr::Call(ident, operands) => Call(ident.to_string().into(), f(operands)), + Expr::Const(ConstOps::Const, ty, lit) => { + let lit = match (ty, lit) { + (Type::Int, Literal::Int(n)) => Call("Num".into(), vec![Lit(Int(*n))]), + (Type::Bool, Literal::Bool(b)) => { + Call("Bool".into(), vec![Lit(Int(*b as i64))]) + } + (Type::Float, Literal::Float(f)) => Call( + "Float".into(), + vec![Lit(F64(OrderedFloat::::from(*f)))], + ), + (Type::Char, Literal::Char(c)) => { + Call("Char".into(), vec![Lit(String(c.to_string().into()))]) + } + (Type::Pointer(ty), Literal::Int(p)) => { + Call("Ptr".into(), vec![from_ty(ty.as_ref()), Lit(Int(*p))]) + } + _ => panic!("type mismatch"), + }; + Call( + "Const".into(), + vec![Call("const".into(), vec![]), from_ty(ty), lit], + ) + } + } + } + + fn body_to_egglog_expr(&self, body: &RvsdgBody) -> egglog::ast::Expr { + use egglog::ast::Expr::*; + match body { + RvsdgBody::PureOp(expr) => Call("PureOp".into(), vec![self.expr_to_egglog_expr(expr)]), + RvsdgBody::Gamma { + pred, + inputs, + outputs, + } => { + let pred = self.operand_to_egglog_expr(pred); + let inputs = inputs + .iter() + .map(|input| self.operand_to_egglog_expr(input)); + let inputs = Call("vec-of".into(), inputs.collect()); + let outputs = outputs.iter().map(|region| { + let region = region + .iter() + .map(|output| self.operand_to_egglog_expr(output)); + Call("VO".into(), vec![Call("vec-of".into(), region.collect())]) + }); + let outputs = Call("vec-of".into(), outputs.collect()); + Call("Gamma".into(), vec![pred, inputs, outputs]) + } + RvsdgBody::Theta { + pred, + inputs, + outputs, + } => { + let pred = 
self.operand_to_egglog_expr(pred); + let inputs = inputs + .iter() + .map(|input| self.operand_to_egglog_expr(input)); + let inputs = Call("vec-of".into(), inputs.collect()); + let outputs = outputs + .iter() + .map(|output| self.operand_to_egglog_expr(output)); + let outputs = Call("vec-of".into(), outputs.collect()); + Call("Theta".into(), vec![pred, inputs, outputs]) + } + } + } + + fn operand_to_egglog_expr(&self, op: &Operand) -> egglog::ast::Expr { + use egglog::ast::{Expr::*, Literal::*}; + match op { + Operand::Arg(p) => Call("Arg".into(), vec![Lit(Int(i64::try_from(*p).unwrap()))]), + Operand::Id(id) => Call( + "Node".into(), + vec![self.body_to_egglog_expr(&self.nodes[*id])], + ), + Operand::Project(i, id) => { + let body = self.body_to_egglog_expr(&self.nodes[*id]); + Call( + "Project".into(), + vec![Lit(Int(i64::try_from(*i).unwrap())), body], + ) + } + } + } + + pub fn to_egglog_expr(&self) -> egglog::ast::Expr { + // There might be multiple results in the future, + // e.g., one for return value and one for effect + if let Some(result) = &self.result { + self.operand_to_egglog_expr(result) + } else { + panic!("A function with no output is a noop") + } + } + + fn egglog_expr_to_operand(op: &egglog::ast::Expr, bodies: &mut Vec) -> Operand { + use egglog::ast::{Expr::*, Literal::*}; + if let Call(func, args) = op { + match (func.as_str(), &args.as_slice()) { + ("Arg", [Lit(Int(n))]) => Operand::Arg(*n as usize), + ("Node", [body]) => Operand::Id(Self::egglog_expr_to_body(body, bodies)), + ("Project", [Lit(Int(n)), body]) => { + Operand::Project(*n as usize, Self::egglog_expr_to_body(body, bodies)) + } + _ => panic!("expect an operand, got {op}"), + } + } else { + panic!("expect an operand, got {op}") + } + } + + fn egglog_expr_to_body(body: &egglog::ast::Expr, bodies: &mut Vec) -> Id { + use egglog::ast::Expr::*; + if let Call(func, args) = body { + let body = match (func.as_str(), &args.as_slice()) { + ("PureOp", [expr]) => 
RvsdgBody::PureOp(Self::egglog_expr_to_expr(expr, bodies)), + ("Gamma", [pred, inputs, outputs]) => { + let pred = Self::egglog_expr_to_operand(pred, bodies); + let inputs = vec_map(inputs, |e| Self::egglog_expr_to_operand(e, bodies)); + let outputs = vec_map(outputs, |es| { + if let Call(func, args) = es { + assert_eq!(func.as_str(), "VO"); + assert_eq!(args.len(), 1); + let es = &args[0]; + vec_map(es, |e| Self::egglog_expr_to_operand(e, bodies)) + } else { + panic!("expect VecOperandWrapper") + } + }); + RvsdgBody::Gamma { + pred, + inputs, + outputs, + } + } + ("Theta", [pred, inputs, outputs]) => { + let pred = Self::egglog_expr_to_operand(pred, bodies); + let inputs = vec_map(inputs, |e| Self::egglog_expr_to_operand(e, bodies)); + let outputs = vec_map(outputs, |e| Self::egglog_expr_to_operand(e, bodies)); + RvsdgBody::Theta { + pred, + inputs, + outputs, + } + } + _ => panic!("expect an operand, got {body}"), + }; + bodies.push(body); + bodies.len() - 1 + } else { + panic!("expect an operand, got {body}") + } + } + + fn egglog_expr_to_expr(expr: &egglog::ast::Expr, bodies: &mut Vec) -> Expr { + use egglog::ast::Literal; + if let egglog::ast::Expr::Call(func, args) = expr { + match (func.as_str(), &args.as_slice()) { + ("Call", [egglog::ast::Expr::Lit(Literal::String(ident)), args]) => { + let args = vec_map(args, |e| Self::egglog_expr_to_operand(e, bodies)); + Expr::Call(Identifier::Name(ident.to_string()), args) + } + ("Const", [_const_op, ty, lit]) => Expr::Const( + ConstOps::Const, + Self::egglog_expr_to_ty(ty), + Self::egglog_expr_to_literal(lit), + ), + (binop, [opr1, opr2]) => { + let opr1 = Self::egglog_expr_to_operand(opr1, bodies); + let opr2 = Self::egglog_expr_to_operand(opr2, bodies); + Expr::Op(egglog_op_to_bril(binop.into()), vec![opr1, opr2]) + } + _ => panic!("expect an operand, got {expr}"), + } + } else { + panic!("expect an operand, got {expr}") + } + } + fn egglog_expr_to_ty(ty: &egglog::ast::Expr) -> Type { + use egglog::ast::Expr::*; + 
if let Call(func, args) = ty { + match (func.as_str(), &args.as_slice()) { + ("IntT", []) => Type::Int, + ("BoolT", []) => Type::Bool, + ("FloatT", []) => Type::Float, + ("CharT", []) => Type::Char, + ("PointerT", [inner]) => Type::Pointer(Box::new(Self::egglog_expr_to_ty(inner))), + _ => panic!("expect a list, got {ty}"), + } + } else { + panic!("expect a list, got {ty}") + } + } + + fn egglog_expr_to_literal(lit: &egglog::ast::Expr) -> Literal { + use egglog::ast::{Expr::*, Literal::*}; + if let Call(func, args) = lit { + match (func.as_str(), &args.as_slice()) { + ("Num", [Lit(Int(n))]) => Literal::Int(*n), + ("Float", [Lit(F64(n))]) => Literal::Float(f64::from(*n)), + ("Char", [Lit(String(s))]) => { + assert_eq!(s.as_str().len(), 1); + Literal::Char(s.as_str().chars().next().unwrap()) + } + _ => panic!("expect a list, got {lit}"), + } + } else { + panic!("expect a list, got {lit}") + } + } + + pub fn egglog_expr_to_function(func: &egglog::ast::Expr, n_args: usize) -> RvsdgFunction { + let mut nodes = vec![]; + let result = Self::egglog_expr_to_operand(func, &mut nodes); + let result = Some(result); + RvsdgFunction { + n_args, + nodes, + result, + } + } +} + +fn vec_map(inputs: &egglog::ast::Expr, mut f: impl FnMut(&egglog::ast::Expr) -> T) -> Vec { + use egglog::ast::Expr::*; + let mut inputs: &egglog::ast::Expr = inputs; + let mut results = vec![]; + if let Call(func, args) = inputs { + if func.as_str() == "vec-of" { + return args.iter().map(&mut f).collect(); + } + } + loop { + if let Call(func, args) = inputs { + match (func.as_str(), &args.as_slice()) { + ("vec-push", [head, tail]) => { + results.push(f(head)); + inputs = tail; + } + ("vec-empty", []) => { + break; + } + _ => panic!("expect a list, got {inputs}"), + } + } else { + panic!("expect a list, got {inputs}") + } + } + results.reverse(); + results +} + +pub fn new_rvsdg_egraph() -> EGraph { + let mut egraph = EGraph::default(); + let schema = 
std::fs::read_to_string("src/rvsdg/schema.egg").unwrap(); + egraph.parse_and_run_program(schema.as_str()).unwrap(); + egraph +} diff --git a/src/rvsdg/rvsdg2svg.rs b/src/rvsdg/rvsdg2svg.rs index 969cf0215..f81a3dd29 100644 --- a/src/rvsdg/rvsdg2svg.rs +++ b/src/rvsdg/rvsdg2svg.rs @@ -3,8 +3,6 @@ use std::iter::once; use bril_rs::ConstOps; -use crate::cfg::Identifier; - use super::{Expr, Id, Operand, RvsdgBody, RvsdgFunction, RvsdgProgram}; const SIMPLE_NODE_SIZE: f32 = 100.0; @@ -514,17 +512,9 @@ fn mk_node_and_input_edges(index: Id, nodes: &[RvsdgBody]) -> (Node, Vec) RvsdgBody::PureOp(Expr::Op(f, xs)) => { (Node::Unit(format!("{f}"), xs.len(), 1), xs.to_vec()) } - RvsdgBody::PureOp(Expr::Call(f, xs)) => ( - Node::Unit( - match f { - Identifier::Name(s) => (**s).to_owned(), - Identifier::Num(x) => format!("{x}"), - }, - xs.len(), - 1, - ), - xs.to_vec(), - ), + RvsdgBody::PureOp(Expr::Call(f, xs)) => { + (Node::Unit(f.to_string(), xs.len(), 1), xs.to_vec()) + } RvsdgBody::PureOp(Expr::Const(ConstOps::Const, _, v)) => { (Node::Unit(format!("{v}"), 0, 1), vec![]) } diff --git a/src/rvsdg/schema.egg b/src/rvsdg/schema.egg new file mode 100644 index 000000000..9d1504fec --- /dev/null +++ b/src/rvsdg/schema.egg @@ -0,0 +1,76 @@ +(datatype Literal) +(datatype Expr) +(datatype Operand) +(datatype Body) + +(sort VecOperand (Vec Operand)) +(datatype VecOperandWrapper + (VO VecOperand)) +(sort VecVecOperand (Vec VecOperandWrapper)) + +;; Type +(datatype Type + (IntT) + (BoolT) + (FloatT) + (CharT) + (PointerT Type)) + +;; Literal +(function Num (i64) Literal) +(function Float (f64) Literal) +(function Char (String) Literal) + +;; Expr +(datatype ConstOps (const)) +(function Const (ConstOps Type Literal) Expr) +(function Call (String VecOperand) Expr) +(function add (Operand Operand) Expr) +(function sub (Operand Operand) Expr) +(function mul (Operand Operand) Expr) +(function div (Operand Operand) Expr) +(function eq (Operand Operand) Expr) +(function lt (Operand Operand) 
Expr) +(function gt (Operand Operand) Expr) +(function le (Operand Operand) Expr) +(function ge (Operand Operand) Expr) +(function not (Operand Operand) Expr) +(function and (Operand Operand) Expr) +(function or (Operand Operand) Expr) + +;; Operand +(function Arg (i64) Operand) +(function Node (Body) Operand) +(function Project (i64 Body) Operand) + +;; Body +(function PureOp (Expr) Body) +(function Gamma (Operand VecOperand VecVecOperand) Body) ;; branching +(function Theta (Operand VecOperand VecOperand) Body) ;; loop + + +;; procedure f(n): +;; i = 0; ans = 0 +;; while i < n +;; ans += i * 5 +;; i += 1 +;; return ans + +;; ;; inputs: [n]; (Const k) abbreviates (Const (const) (IntT) (Num k)) +;; (Project 0 +;; (Theta +;; (Node (PureOp (lt (Arg 1) (Arg 2)))) ;; pred +;; (vec-of ;; inputs +;; (Node (PureOp (Const 0))) ;; accumulator +;; (Node (PureOp (Const 0))) ;; loop var +;; (Arg 0) ;; n +;; ) +;; (vec-of ;; outputs +;; (Node (PureOp (add (Arg 0) +;; (Node (PureOp (mul +;; (Arg 1) +;; (Node (PureOp (Const 5))))))))) +;; (Node (PureOp (add (Arg 1) (Node (PureOp (Const 1)))))) +;; (Arg 2) +;; )) +;; ) \ No newline at end of file diff --git a/src/rvsdg/tests.rs b/src/rvsdg/tests.rs index ebf0ec39f..643bced7f 100644 --- a/src/rvsdg/tests.rs +++ b/src/rvsdg/tests.rs @@ -2,7 +2,7 @@ use bril_rs::{ConstOps, Literal, Type, ValueOps}; use crate::{ cfg::to_cfg, - rvsdg::{from_cfg::cfg_func_to_rvsdg, Expr, Id, Operand, RvsdgBody}, + rvsdg::{from_cfg::cfg_func_to_rvsdg, new_rvsdg_egraph, Expr, Id, Operand, RvsdgBody}, util::parse_from_string, }; @@ -174,7 +174,7 @@ fn rvsdg_unstructured() { #[test] fn rvsdg_basic_odd_branch() { // Bril program summing the numbers from 1 to n, multiplying by 2 if that - // value is larger is larger than 5. This gives us a theta node and a gamma + // value is larger than 5. This gives us a theta node and a gamma // node, with the gamma requiring branch restructuring. 
const PROGRAM: &str = r#" @main(n: int): int { @@ -197,6 +197,7 @@ fn rvsdg_basic_odd_branch() { ret res; }"#; + // construct expected program let mut expected = RvsdgTest::default(); let zero = expected.lit_int(0); let one = expected.lit_int(1); @@ -225,13 +226,53 @@ fn rvsdg_basic_odd_branch() { let pred = expected.lt(res, five); let mul2 = expected.mul(Operand::Arg(0), two); let gamma = expected.gamma(pred, &[res], &[&[Operand::Arg(0)], &[mul2]]); + let expected = expected.into_function(1, Operand::Project(0, gamma)); + + // test correctness of RVSDGs converted from CFG let prog = parse_from_string(PROGRAM); let mut cfg = to_cfg(&prog.functions[0]); - let got = cfg_func_to_rvsdg(&mut cfg).unwrap(); - assert!(deep_equal( - &expected.into_function(1, Operand::Project(0, gamma)), - &got - )); + let actual = cfg_func_to_rvsdg(&mut cfg).unwrap(); + assert!(deep_equal(&expected, &actual)); + + // test equalities of egglog programs generated by RVSDG + let actual: egglog::ast::Expr = actual.to_egglog_expr(); + let actual_command = + egglog::ast::Command::Action(egglog::ast::Action::Let("actual".into(), actual.clone())); + const EGGLOG_PROGRAM: &str = r#" +(let expected (Project 0 + (Gamma + (Node (PureOp (lt + (Project 0 + (Theta + (Node (PureOp (lt (Node (PureOp (add (Arg 1) (Node (PureOp (Const (const) (IntT) (Num 1))))))) + (Arg 2)))) + (vec-of (Node (PureOp (Const (const) (IntT) (Num 0)))) + (Node (PureOp (Const (const) (IntT) (Num 0)))) + (Arg 0)) + (vec-of (Node (PureOp (add (Arg 0) (Arg 1)))) + (Node (PureOp (add (Arg 1) (Node (PureOp (Const (const) (IntT) (Num 1))))))) + (Arg 2)))) + (Node (PureOp (Const (const) (IntT) (Num 5))))))) + (vec-of + (Project 0 + (Theta (Node (PureOp (lt (Node (PureOp (add (Arg 1) (Node (PureOp (Const (const) (IntT) (Num 1))))))) (Arg 2)))) + (vec-of (Node (PureOp (Const (const) (IntT) (Num 0)))) (Node (PureOp (Const (const) (IntT) (Num 0)))) (Arg 0)) + (vec-of (Node (PureOp (add (Arg 0) (Arg 1)))) (Node (PureOp (add (Arg 1) (Node 
(PureOp (Const (const) (IntT) (Num 1))))))) (Arg 2))))) + (vec-of (VO (vec-of (Arg 0))) + (VO (vec-of (Node (PureOp (mul (Arg 0) (Node (PureOp (Const (const) (IntT) (Num 2)))))))))))))"#; + let mut egraph = new_rvsdg_egraph(); + egraph.parse_and_run_program(EGGLOG_PROGRAM).unwrap(); + // this is weird; shouldn't `stop` be an optional argument? + egraph + .process_commands(vec![actual_command], egglog::CompilerPassStop::All) + .unwrap(); + egraph + .parse_and_run_program("(check (= expected actual))") + .unwrap(); + + // test correctness of RVSDG from egglog + let actual = RvsdgFunction::egglog_expr_to_function(&actual, 1); + assert!(deep_equal(&expected, &actual)); } /// We don't want to commit to the order in which nodes are laid out, so we do a