Skip to content

Commit

Permalink
Merge pull request #28 from yihozhang/rvsdg-to-egglog
Browse files Browse the repository at this point in the history
RVSDG <-> egglog
  • Loading branch information
oflatt authored Sep 2, 2023
2 parents 2c5f2c7 + 7bba911 commit 99ee831
Show file tree
Hide file tree
Showing 6 changed files with 418 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/cfg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl FromStr for BlockName {
/// kinds of name.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Identifier {
Name(Box<str>),
Name(String),
Num(usize),
}

Expand Down
12 changes: 6 additions & 6 deletions src/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ impl Optimizer {
dest: dest.clone(),
args: args_vars,
funcs: vec![],
op: self.egglog_op_to_bril(*op),
op: egglog_op_to_bril(*op),
labels: vec![],
pos: None,
op_type: etype,
Expand Down Expand Up @@ -570,11 +570,6 @@ impl Optimizer {
Expr::Call("Assign".into(), vec![self.string_to_expr(dest), expr])
}

pub(crate) fn egglog_op_to_bril(&mut self, op: Symbol) -> ValueOps {
let with_quotes = "\"".to_owned() + &op.to_string() + "\"";
serde_json::from_str(&with_quotes).unwrap()
}

pub(crate) fn effect_op_to_egglog(&mut self, op: EffectOps) -> Symbol {
let opstr = op.to_string();
if opstr == "print" {
Expand Down Expand Up @@ -720,3 +715,8 @@ impl Optimizer {
Ok(())
}
}

pub(crate) fn egglog_op_to_bril(op: Symbol) -> ValueOps {
let with_quotes = "\"".to_owned() + &op.to_string() + "\"";
serde_json::from_str(&with_quotes).unwrap()
}
284 changes: 284 additions & 0 deletions src/rvsdg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,13 @@ pub(crate) mod rvsdg2svg;
use std::fmt;

use bril_rs::{ConstOps, Literal, Type, ValueOps};
use egglog::EGraph;
use ordered_float::OrderedFloat;
use thiserror::Error;

use crate::{
cfg::{CfgProgram, Identifier},
conversions::egglog_op_to_bril,
EggCCError,
};

Expand Down Expand Up @@ -166,3 +169,284 @@ pub(crate) fn cfg_to_rvsdg(cfg: &CfgProgram) -> std::result::Result<RvsdgProgram
}
Ok(RvsdgProgram { functions })
}

impl RvsdgFunction {
fn expr_to_egglog_expr(&self, expr: &Expr<Operand>) -> egglog::ast::Expr {
use egglog::ast::{Expr::*, Literal::*};
let f = |operands: &Vec<Operand>| {
operands
.iter()
.map(|op| self.operand_to_egglog_expr(op))
.collect()
};
fn from_ty(ty: &Type) -> egglog::ast::Expr {
match ty {
Type::Int => Call("IntT".into(), vec![]),
Type::Bool => Call("BoolT".into(), vec![]),
Type::Float => Call("FloatT".into(), vec![]),
Type::Char => Call("CharT".into(), vec![]),
Type::Pointer(ty) => Call("PointerT".into(), vec![from_ty(ty.as_ref())]),
}
}
match expr {
Expr::Op(op, operands) => Call(op.to_string().into(), f(operands)),
Expr::Call(ident, operands) => Call(ident.to_string().into(), f(operands)),
Expr::Const(ConstOps::Const, ty, lit) => {
let lit = match (ty, lit) {
(Type::Int, Literal::Int(n)) => Call("Num".into(), vec![Lit(Int(*n))]),
(Type::Bool, Literal::Bool(b)) => {
Call("Bool".into(), vec![Lit(Int(*b as i64))])
}
(Type::Float, Literal::Float(f)) => Call(
"Float".into(),
vec![Lit(F64(OrderedFloat::<f64>::from(*f)))],
),
(Type::Char, Literal::Char(c)) => {
Call("Char".into(), vec![Lit(String(c.to_string().into()))])
}
(Type::Pointer(ty), Literal::Int(p)) => {
Call("Ptr".into(), vec![from_ty(ty.as_ref()), Lit(Int(*p))])
}
_ => panic!("type mismatch"),
};
Call(
"Const".into(),
vec![Call("const".into(), vec![]), from_ty(ty), lit],
)
}
}
}

fn body_to_egglog_expr(&self, body: &RvsdgBody) -> egglog::ast::Expr {
use egglog::ast::Expr::*;
match body {
RvsdgBody::PureOp(expr) => Call("PureOp".into(), vec![self.expr_to_egglog_expr(expr)]),
RvsdgBody::Gamma {
pred,
inputs,
outputs,
} => {
let pred = self.operand_to_egglog_expr(pred);
let inputs = inputs
.iter()
.map(|input| self.operand_to_egglog_expr(input));
let inputs = Call("vec-of".into(), inputs.collect());
let outputs = outputs.iter().map(|region| {
let region = region
.iter()
.map(|output| self.operand_to_egglog_expr(output));
Call("VO".into(), vec![Call("vec-of".into(), region.collect())])
});
let outputs = Call("vec-of".into(), outputs.collect());
Call("Gamma".into(), vec![pred, inputs, outputs])
}
RvsdgBody::Theta {
pred,
inputs,
outputs,
} => {
let pred = self.operand_to_egglog_expr(pred);
let inputs = inputs
.iter()
.map(|input| self.operand_to_egglog_expr(input));
let inputs = Call("vec-of".into(), inputs.collect());
let outputs = outputs
.iter()
.map(|output| self.operand_to_egglog_expr(output));
let outputs = Call("vec-of".into(), outputs.collect());
Call("Theta".into(), vec![pred, inputs, outputs])
}
}
}

fn operand_to_egglog_expr(&self, op: &Operand) -> egglog::ast::Expr {
use egglog::ast::{Expr::*, Literal::*};
match op {
Operand::Arg(p) => Call("Arg".into(), vec![Lit(Int(i64::try_from(*p).unwrap()))]),
Operand::Id(id) => Call(
"Node".into(),
vec![self.body_to_egglog_expr(&self.nodes[*id])],
),
Operand::Project(i, id) => {
let body = self.body_to_egglog_expr(&self.nodes[*id]);
Call(
"Project".into(),
vec![Lit(Int(i64::try_from(*i).unwrap())), body],
)
}
}
}

pub fn to_egglog_expr(&self) -> egglog::ast::Expr {
// There might be multiple results in the future,
// e.g., one for return value and one for effect
if let Some(result) = &self.result {
self.operand_to_egglog_expr(result)
} else {
panic!("A function with no output is a noop")
}
}

fn egglog_expr_to_operand(op: &egglog::ast::Expr, bodies: &mut Vec<RvsdgBody>) -> Operand {
use egglog::ast::{Expr::*, Literal::*};
if let Call(func, args) = op {
match (func.as_str(), &args.as_slice()) {
("Arg", [Lit(Int(n))]) => Operand::Arg(*n as usize),
("Node", [body]) => Operand::Id(Self::egglog_expr_to_body(body, bodies)),
("Project", [Lit(Int(n)), body]) => {
Operand::Project(*n as usize, Self::egglog_expr_to_body(body, bodies))
}
_ => panic!("expect an operand, got {op}"),
}
} else {
panic!("expect an operand, got {op}")
}
}

fn egglog_expr_to_body(body: &egglog::ast::Expr, bodies: &mut Vec<RvsdgBody>) -> Id {
use egglog::ast::Expr::*;
if let Call(func, args) = body {
let body = match (func.as_str(), &args.as_slice()) {
("PureOp", [expr]) => RvsdgBody::PureOp(Self::egglog_expr_to_expr(expr, bodies)),
("Gamma", [pred, inputs, outputs]) => {
let pred = Self::egglog_expr_to_operand(pred, bodies);
let inputs = vec_map(inputs, |e| Self::egglog_expr_to_operand(e, bodies));
let outputs = vec_map(outputs, |es| {
if let Call(func, args) = es {
assert_eq!(func.as_str(), "VO");
assert_eq!(args.len(), 1);
let es = &args[0];
vec_map(es, |e| Self::egglog_expr_to_operand(e, bodies))
} else {
panic!("expect VecOperandWrapper")
}
});
RvsdgBody::Gamma {
pred,
inputs,
outputs,
}
}
("Theta", [pred, inputs, outputs]) => {
let pred = Self::egglog_expr_to_operand(pred, bodies);
let inputs = vec_map(inputs, |e| Self::egglog_expr_to_operand(e, bodies));
let outputs = vec_map(outputs, |e| Self::egglog_expr_to_operand(e, bodies));
RvsdgBody::Theta {
pred,
inputs,
outputs,
}
}
_ => panic!("expect an operand, got {body}"),
};
bodies.push(body);
bodies.len() - 1
} else {
panic!("expect an operand, got {body}")
}
}

fn egglog_expr_to_expr(expr: &egglog::ast::Expr, bodies: &mut Vec<RvsdgBody>) -> Expr<Operand> {
use egglog::ast::Literal;
if let egglog::ast::Expr::Call(func, args) = expr {
match (func.as_str(), &args.as_slice()) {
("Call", [egglog::ast::Expr::Lit(Literal::String(ident)), args]) => {
let args = vec_map(args, |e| Self::egglog_expr_to_operand(e, bodies));
Expr::Call(Identifier::Name(ident.to_string()), args)
}
("Const", [_const_op, ty, lit]) => Expr::Const(
ConstOps::Const,
Self::egglog_expr_to_ty(ty),
Self::egglog_expr_to_literal(lit),
),
(binop, [opr1, opr2]) => {
let opr1 = Self::egglog_expr_to_operand(opr1, bodies);
let opr2 = Self::egglog_expr_to_operand(opr2, bodies);
Expr::Op(egglog_op_to_bril(binop.into()), vec![opr1, opr2])
}
_ => panic!("expect an operand, got {expr}"),
}
} else {
panic!("expect an operand, got {expr}")
}
}
fn egglog_expr_to_ty(ty: &egglog::ast::Expr) -> Type {
use egglog::ast::Expr::*;
if let Call(func, args) = ty {
match (func.as_str(), &args.as_slice()) {
("IntT", []) => Type::Int,
("BoolT", []) => Type::Bool,
("FloatT", []) => Type::Float,
("CharT", []) => Type::Char,
("PointerT", [inner]) => Type::Pointer(Box::new(Self::egglog_expr_to_ty(inner))),
_ => panic!("expect a list, got {ty}"),
}
} else {
panic!("expect a list, got {ty}")
}
}

fn egglog_expr_to_literal(lit: &egglog::ast::Expr) -> Literal {
use egglog::ast::{Expr::*, Literal::*};
if let Call(func, args) = lit {
match (func.as_str(), &args.as_slice()) {
("Num", [Lit(Int(n))]) => Literal::Int(*n),
("Float", [Lit(F64(n))]) => Literal::Float(f64::from(*n)),
("Char", [Lit(String(s))]) => {
assert_eq!(s.as_str().len(), 1);
Literal::Char(s.as_str().chars().next().unwrap())
}
_ => panic!("expect a list, got {lit}"),
}
} else {
panic!("expect a list, got {lit}")
}
}

pub fn egglog_expr_to_function(func: &egglog::ast::Expr, n_args: usize) -> RvsdgFunction {
let mut nodes = vec![];
let result = Self::egglog_expr_to_operand(func, &mut nodes);
let result = Some(result);
RvsdgFunction {
n_args,
nodes,
result,
}
}
}

fn vec_map<T>(inputs: &egglog::ast::Expr, mut f: impl FnMut(&egglog::ast::Expr) -> T) -> Vec<T> {
use egglog::ast::Expr::*;
let mut inputs: &egglog::ast::Expr = inputs;
let mut results = vec![];
if let Call(func, args) = inputs {
if func.as_str() == "vec-of" {
return args.iter().map(&mut f).collect();
}
}
loop {
if let Call(func, args) = inputs {
match (func.as_str(), &args.as_slice()) {
("vec-push", [head, tail]) => {
results.push(f(head));
inputs = tail;
}
("vec-empty", []) => {
break;
}
_ => panic!("expect a list, got {inputs}"),
}
} else {
panic!("expect a list, got {inputs}")
}
}
results.reverse();
results
}

pub fn new_rvsdg_egraph() -> EGraph {
let mut egraph = EGraph::default();
let schema = std::fs::read_to_string("src/rvsdg/schema.egg").unwrap();
egraph.parse_and_run_program(schema.as_str()).unwrap();
egraph
}
16 changes: 3 additions & 13 deletions src/rvsdg/rvsdg2svg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ use std::iter::once;

use bril_rs::ConstOps;

use crate::cfg::Identifier;

use super::{Expr, Id, Operand, RvsdgBody, RvsdgFunction, RvsdgProgram};

const SIMPLE_NODE_SIZE: f32 = 100.0;
Expand Down Expand Up @@ -514,17 +512,9 @@ fn mk_node_and_input_edges(index: Id, nodes: &[RvsdgBody]) -> (Node, Vec<Edge>)
RvsdgBody::PureOp(Expr::Op(f, xs)) => {
(Node::Unit(format!("{f}"), xs.len(), 1), xs.to_vec())
}
RvsdgBody::PureOp(Expr::Call(f, xs)) => (
Node::Unit(
match f {
Identifier::Name(s) => (**s).to_owned(),
Identifier::Num(x) => format!("{x}"),
},
xs.len(),
1,
),
xs.to_vec(),
),
RvsdgBody::PureOp(Expr::Call(f, xs)) => {
(Node::Unit(f.to_string(), xs.len(), 1), xs.to_vec())
}
RvsdgBody::PureOp(Expr::Const(ConstOps::Const, _, v)) => {
(Node::Unit(format!("{v}"), 0, 1), vec![])
}
Expand Down
Loading

0 comments on commit 99ee831

Please sign in to comment.