Skip to content

Commit

Permalink
Merge pull request #631 from egraphs-good/ajpal-select
Browse files Browse the repository at this point in the history
Add select rules
  • Loading branch information
ajpal authored Oct 31, 2024
2 parents 774b7f7 + 9037b8b commit 271406f
Show file tree
Hide file tree
Showing 14 changed files with 343 additions and 161 deletions.
74 changes: 74 additions & 0 deletions dag_in_context/src/optimizations/switch_rewrites.egg
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
(ruleset switch_rewrite)
(ruleset always-switch-rewrite)

; if a < b then a else b ~~> (min a b)
(rule (
(= pred (Bop (LessThan) a b))
(= if_e (If pred inputs thn els))
Expand All @@ -15,6 +16,7 @@
((union (Get if_e k) (Bop (Smin) a b)))
:ruleset switch_rewrite)

; if a < b then b else a ~~> (max a b)
(rule (
(= pred (Bop (LessThan) a b))
(= if_e (If pred inputs thn els))
Expand All @@ -29,6 +31,78 @@
((union (Get if_e k) (Bop (Smax) a b)))
:ruleset switch_rewrite)

; if pred then a else b ~~> (select pred a b)
; where a and b are inputs to the region
(rule (
(= if_e (If pred inputs thn els))
(= a (Get inputs i))
(= b (Get inputs j))

; if pred then a else b
(= (Get thn k) (Get (Arg ty (InIf true pred inputs)) i))
(= (Get els k) (Get (Arg ty (InIf false pred inputs)) j))

; If i = j, then the arg is just passed through the if, and we
; don't need a select. This will get handled by the passthrough rules.
(!= i j)
)
(
(union (Get if_e k) (Top (Select) pred a b))
)
:ruleset switch_rewrite)

(rule (
(= if_e (If pred inputs thn els))
(ContextOf if_e ctx)
(HasArgType if_e ty)
(= (Get thn i) (Const x _ty (InIf true pred inputs)))
(= (Get els i) (Const y _ty (InIf false pred inputs)))
)
((union (Get if_e i) (Top (Select) pred (Const x ty ctx) (Const y ty ctx))))
:ruleset switch_rewrite)

; if pred then A else Const -> select pred A Const
; where A is an input to the region
(rule (
(= if_e (If pred inputs thn els))
(ContextOf if_e ctx)
(HasArgType if_e ty)

; input to the if
(= a (Get inputs i))
(= (Get thn k) (Get (Arg _ty (InIf true pred inputs)) i))

(= els_out (Get els k))
(= (IntB y) (lo-bound els_out))
(= (IntB y) (hi-bound els_out))
)
(
(union (Get if_e k) (Top (Select) pred a (Const (Int y) ty ctx)))
)
:ruleset switch_rewrite
)

; if pred then Const else B -> select pred Const B
; where B is an input to the region
(rule (
(= if_e (If pred inputs thn els))
(ContextOf if_e ctx)
(HasArgType if_e ty)

(= thn_out (Get thn k))
(= (IntB y) (lo-bound thn_out))
(= (IntB y) (hi-bound thn_out))

; input to the if
(= b (Get inputs i))
(= (Get els k) (Get (Arg _ty (InIf false pred inputs)) i))
)
(
(union (Get if_e k) (Top (Select) pred (Const (Int y) ty ctx) b))
)
:ruleset switch_rewrite
)

; if (a and b) X Y ~~> if a (if b X Y) Y
(rule ((= lhs (If (Bop (And) a b) ins X Y))
(HasType ins (TupleT ins_ty))
Expand Down
15 changes: 15 additions & 0 deletions tests/passing/small/select.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// ARGS: 20
fn main(x: i64) {
let ten: i64 = 10;
let done: i64 = ten;
let i: i64 = 0;
let res: i64 = 0;
while !(done == 5) {
i += 1;
res += i;
if i == x {
done = 5;
}
}
println!("{}", res);
}
10 changes: 10 additions & 0 deletions tests/passing/small/select_simple.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// ARGS: 20 30
fn main(x: i64, y: i64) {
let res: i64 = 0;
if (x * y < 20) {
res = x;
} else {
res = y;
}
println!("{}", res);
}
37 changes: 14 additions & 23 deletions tests/snapshots/files__block-diamond-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,22 @@ expression: visualization.result
c2_: int = const 1;
c3_: int = const 2;
v4_: bool = lt v0 c3_;
c5_: int = const 0;
c6_: int = const 5;
v7_: int = id c2_;
c5_: int = const 4;
v6_: int = select v4_ c5_ c2_;
v7_: int = id v6_;
v8_: int = id c2_;
v9_: int = id c3_;
br v4_ .b10_ .b11_;
br v4_ .b9_ .b10_;
.b9_:
v11_: int = add c2_ v7_;
print v11_;
ret;
jmp .b12_;
.b10_:
c12_: int = const 4;
v7_: int = id c12_;
v13_: int = add c3_ v6_;
v7_: int = id v13_;
v8_: int = id c2_;
v9_: int = id c3_;
v13_: int = id v7_;
v14_: int = id v8_;
v15_: int = add c2_ v13_;
print v15_;
ret;
jmp .b16_;
.b11_:
v13_: int = id v7_;
v14_: int = id v8_;
v17_: int = add v7_ v9_;
v13_: int = id v17_;
v14_: int = id v8_;
v15_: int = add c2_ v13_;
print v15_;
v11_: int = add c2_ v7_;
print v11_;
ret;
.b16_:
.b12_:
}
35 changes: 5 additions & 30 deletions tests/snapshots/files__collatz_redundant_computation-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -32,48 +32,23 @@ expression: visualization.result
v25_: bool = eq v11_ v24_;
v26_: int = mul v7_ v9_;
v27_: int = add v26_ v8_;
c28_: bool = const true;
v29_: int = id v6_;
v30_: bool = id c28_;
v31_: int = id v27_;
v32_: int = id v8_;
v33_: int = id v9_;
v34_: int = id v10_;
v35_: int = id v11_;
br v25_ .b36_ .b37_;
.b36_:
c38_: bool = const true;
v29_: int = id v6_;
v30_: bool = id c38_;
v31_: int = id v22_;
v32_: int = id v8_;
v33_: int = id v9_;
v34_: int = id v10_;
v35_: int = id v11_;
v28_: int = select v25_ v22_ v27_;
v14_: int = id v6_;
v15_: int = id v31_;
v15_: int = id v28_;
v16_: int = id v8_;
v17_: int = id v9_;
v18_: int = id v10_;
v19_: int = id v11_;
.b20_:
v39_: bool = not v13_;
v29_: bool = not v13_;
v6_: int = id v14_;
v7_: int = id v15_;
v8_: int = id v16_;
v9_: int = id v17_;
v10_: int = id v18_;
v11_: int = id v19_;
br v39_ .b12_ .b40_;
.b37_:
v14_: int = id v6_;
v15_: int = id v31_;
v16_: int = id v8_;
v17_: int = id v9_;
v18_: int = id v10_;
v19_: int = id v11_;
jmp .b20_;
.b40_:
br v29_ .b12_ .b30_;
.b30_:
print v0;
ret;
}
19 changes: 5 additions & 14 deletions tests/snapshots/files__gamma_pull_in-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,10 @@ expression: visualization.result
.b1_:
c2_: int = const 10;
v3_: bool = lt v0 c2_;
c4_: int = const 3;
v5_: int = id c4_;
br v3_ .b6_ .b7_;
.b6_:
c8_: int = const 2;
v5_: int = id c8_;
v9_: int = add v5_ v5_;
print v9_;
c4_: int = const 2;
c5_: int = const 3;
v6_: int = select v3_ c4_ c5_;
v7_: int = add v6_ v6_;
print v7_;
ret;
jmp .b10_;
.b7_:
v9_: int = add v5_ v5_;
print v9_;
ret;
.b10_:
}
14 changes: 2 additions & 12 deletions tests/snapshots/files__if_dead_code-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,9 @@ expression: visualization.result
.b1_:
c2_: int = const 0;
v3_: bool = lt v0 c2_;
c4_: int = const 0;
v5_: int = id c4_;
br v3_ .b6_ .b7_;
.b6_:
c8_: int = const 1;
v5_: int = id c8_;
c4_: int = const 1;
v5_: int = select v3_ c4_ c2_;
print v5_;
print v3_;
ret;
jmp .b9_;
.b7_:
print v5_;
print v3_;
ret;
.b9_:
}
86 changes: 48 additions & 38 deletions tests/snapshots/files__if_dead_code_nested-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -25,66 +25,76 @@ expression: visualization.result
v14_: int = id c10_;
v17_: int = id v12_;
v18_: int = id c10_;
print v18_;
c19_: int = const 1;
v20_: int = select v3_ c19_ c2_;
print v20_;
print v3_;
print v17_;
ret;
jmp .b19_;
jmp .b21_;
.b16_:
v17_: int = id v12_;
v18_: int = id c10_;
print v18_;
c19_: int = const 1;
v20_: int = select v3_ c19_ c2_;
print v20_;
print v3_;
print v17_;
ret;
jmp .b19_;
jmp .b21_;
.b7_:
v20_: bool = gt v0 c5_;
c21_: bool = const false;
c22_: int = const 2;
v23_: int = id c22_;
v24_: bool = id c21_;
v25_: int = id c2_;
br v20_ .b26_ .b27_;
.b26_:
v28_: bool = gt v0 c4_;
c29_: int = const 4;
v30_: int = id c29_;
v31_: bool = id c21_;
v32_: int = id c2_;
br v28_ .b33_ .b34_;
.b33_:
c35_: int = const 3;
v30_: int = id c35_;
v31_: bool = id c21_;
v32_: int = id c2_;
v23_: int = id v30_;
v24_: bool = id v31_;
v22_: bool = gt v0 c5_;
c23_: bool = const false;
c24_: int = const 2;
v25_: int = id c24_;
v26_: bool = id c23_;
v27_: int = id c2_;
br v22_ .b28_ .b29_;
.b28_:
v30_: bool = gt v0 c4_;
c31_: int = const 4;
v32_: int = id c31_;
v33_: bool = id c23_;
v34_: int = id c2_;
br v30_ .b35_ .b36_;
.b35_:
c37_: int = const 3;
v32_: int = id c37_;
v33_: bool = id c23_;
v34_: int = id c2_;
v25_: int = id v32_;
v17_: int = id v23_;
v26_: bool = id v33_;
v27_: int = id v34_;
v17_: int = id v25_;
v18_: int = id c2_;
print v18_;
c19_: int = const 1;
v20_: int = select v3_ c19_ c2_;
print v20_;
print v3_;
print v17_;
ret;
jmp .b19_;
.b34_:
v23_: int = id v30_;
v24_: bool = id v31_;
jmp .b21_;
.b36_:
v25_: int = id v32_;
v17_: int = id v23_;
v26_: bool = id v33_;
v27_: int = id v34_;
v17_: int = id v25_;
v18_: int = id c2_;
print v18_;
c19_: int = const 1;
v20_: int = select v3_ c19_ c2_;
print v20_;
print v3_;
print v17_;
ret;
jmp .b19_;
.b27_:
v17_: int = id v23_;
jmp .b21_;
.b29_:
v17_: int = id v25_;
v18_: int = id c2_;
print v18_;
c19_: int = const 1;
v20_: int = select v3_ c19_ c2_;
print v20_;
print v3_;
print v17_;
ret;
.b19_:
.b21_:
}
Loading

0 comments on commit 271406f

Please sign in to comment.