Skip to content

Commit

Permalink
Merge pull request #644 from egraphs-good/yihongzhang-cse
Browse files Browse the repository at this point in the history
Eliminate common subexpressions in our queries
  • Loading branch information
yihozhang authored Oct 30, 2024
2 parents 6527a14 + ed7a15d commit fb4db93
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
7 changes: 3 additions & 4 deletions dag_in_context/src/optimizations/loop_unroll.egg
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
;; TODO: we could make it work for decrementing loops
(rule
((= lhs (DoWhile inputs outputs))
(= num-inputs (tuple-length inputs))
(= pred (Get outputs 0))
;; iteration counter starts at start_const
(= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i))
Expand All @@ -50,18 +49,18 @@
;; and i is updated after checking pred
(rule
((= lhs (DoWhile inputs outputs))
(= num-inputs (tuple-length inputs))
(= pred (Get outputs 0))
;; iteration counter starts at start_const
(= (Const (Int start_const) _ty1 _ctx1) (Get inputs counter_i))
(= body-arg (Get (Arg _ty _ctx) counter_i))
;; updated counter at counter_i
(= next_counter (Get outputs (+ counter_i 1)))
;; increments by a constant each loop
(= next_counter (Bop (Add) (Get (Arg _ty _ctx) counter_i)
(= next_counter (Bop (Add) body-arg
(Const (Int increment) _ty2 _ctx2)))
(> increment 0)
;; while this counter less than end_constant
(= pred (Bop (LessThan) (Get (Arg _ty _ctx) counter_i)
(= pred (Bop (LessThan) body-arg
(Const (Int end_constant) _ty3 _ctx3)))
;; end constant is at least start constant
(>= end_constant start_const)
Expand Down
10 changes: 6 additions & 4 deletions dag_in_context/src/optimizations/memory.egg
Original file line number Diff line number Diff line change
Expand Up @@ -249,16 +249,18 @@
:ruleset memory-helpers)

; Compute and propagate PointsToCells
(rewrite (PointsToCells (Concat x y) aps)
(rewrite (PointsToCells concat-x-y aps)
(TuplePointsTo (Concat-List<PtrPointees>
(UnwrapTuplePointsTo (PointsToCells x aps))
(UnwrapTuplePointsTo (PointsToCells y aps))))
:when ((HasType (Concat x y) ty) (PointerishType ty))
:when ((= concat-x-y (Concat x y))
(HasType concat-x-y ty) (PointerishType ty))
:ruleset memory-helpers)

(rewrite (PointsToCells (Get x i) aps)
(rewrite (PointsToCells get-x-i aps)
(GetPointees (PointsToCells x aps) i)
:when ((HasType (Get x i) ty) (PointerishType ty))
:when ((= get-x-i (Get x i))
(HasType get-x-i ty) (PointerishType ty))
:ruleset memory-helpers)

(rewrite (PointsToCells (Single x) aps)
Expand Down
10 changes: 6 additions & 4 deletions dag_in_context/src/optimizations/passthrough.egg
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
(= (Get pred-outputs (+ i 1)) (Get (Arg _ty _ctx) i))
;; only pass through pure types, since some loops don't terminate
;; so the state edge must pass through them
(HasType (Get loop i) lhs_ty)
(HasType lhs lhs_ty)
(PureType lhs_ty)
)
((union lhs (Get inputs i)))
Expand Down Expand Up @@ -40,9 +40,11 @@

;; Pass through if arguments
(rule ((= if (If pred inputs then_ else_))
(= (Get then_ i) (Get (Arg arg_ty _then_ctx) j))
(= (Get else_ i) (Get (Arg arg_ty _else_ctx) j))
(HasType (Get then_ i) lhs_ty)
(= then-branch (Get then_ i))
(= else-branch (Get else_ i))
(= then-branch (Get (Arg arg_ty _then_ctx) j))
(= else-branch (Get (Arg arg_ty _else_ctx) j))
(HasType then-branch lhs_ty)
(!= lhs_ty (Base (StateT))))
((union (Get if i) (Get inputs j)))
:ruleset passthrough)
Expand Down

0 comments on commit fb4db93

Please sign in to comment.