Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of terms ruleset greatly #723

Merged
merged 9 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions dag_in_context/src/optimizations/select.egg
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
(ruleset select_opt)


;; NOTE: inlining (Get thn i) rather than binding it to a variable makes the query faster, unfortunately
(rule
(
(= if_e (If pred inputs thn els))
(ContextOf if_e ctx)

(= thn_out (Get thn i))
(= els_out (Get els i))
(ExprIsPure thn_out)
(ExprIsPure els_out)
(ExprIsPure (Get thn i))
(ExprIsPure (Get els i))

(> 10 (Expr-size thn_out)) ; TODO: Tune these size limits
(> 10 (Expr-size els_out))
(= (TCPair t1 c1) (ExtractedExpr thn_out))
(= (TCPair t2 c2) (ExtractedExpr els_out))
(> 10 (Expr-size (Get thn i))) ; TODO: Tune these size limits
(> 10 (Expr-size (Get els i)))
(= (TCPair t1 c1) (ExtractedExpr (Get thn i)))
(= (TCPair t2 c2) (ExtractedExpr (Get els i)))

(ContextOf if_e ctx)
)
(
(union (Get if_e i)
Expand Down
6 changes: 5 additions & 1 deletion dag_in_context/src/schedule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ pub(crate) fn helpers() -> String {

(saturate canon)
(saturate interval-analysis)
(saturate terms)
(saturate
terms
(saturate
terms-helpers
(saturate terms-helpers-helpers)))
;; memory-helpers TODO run memory helpers for memory optimizations

;; finally, subsume now that helpers are done
Expand Down
24 changes: 24 additions & 0 deletions dag_in_context/src/type_analysis.egg
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,30 @@
((panic "(Load) expected pointer, received tuple"))
:ruleset error-checking)

;; Select: for every ternary (Top (Select) pred v1 v2), record the expectation
;; that the predicate operand `pred` has type (Base (BoolT)).
;; NOTE(review): the "(Select)" string is presumably the operator name used in
;; ExpectType's error reporting — confirm against the ExpectType definition.
(rule (
(= lhs (Top (Select) pred v1 v2))
)
((ExpectType pred (Base (BoolT)) "(Select)"))
:ruleset type-analysis)

;; Select: when both value operands v1 and v2 have the same type `ty`
;; (the same variable is used in both HasType atoms), the Select expression
;; itself is given that type.
(rule (
(= lhs (Top (Select) pred v1 v2))
(HasType v1 ty)
(HasType v2 ty)
)
((HasType lhs ty))
:ruleset type-analysis)

;; Select error check: panic when the two value operands of a Select have
;; been assigned different types (ty1 != ty2).
(rule (
(= lhs (Top (Select) pred v1 v2))
(HasType v1 ty1)
(HasType v2 ty2)
(!= ty1 ty2)
)
((panic "(Select) branches had different types"))
:ruleset error-checking)


; Binary ops

;; Operators that have type Type -> Type -> Type
Expand Down
61 changes: 49 additions & 12 deletions dag_in_context/src/utility/terms.egg
Original file line number Diff line number Diff line change
@@ -1,43 +1,80 @@
(ruleset terms)
;; terms-helpers keeps track of the new best extracted terms
(ruleset terms-helpers)
;; terms-helpers-helpers runs the `Smaller` rules, resolving the merge function used by terms-helpers
(ruleset terms-helpers-helpers)
oflatt marked this conversation as resolved.
Show resolved Hide resolved

(sort TermAndCost)
(function Smaller (TermAndCost TermAndCost) TermAndCost)

(function ExtractedExpr (Expr) TermAndCost
:merge (Smaller old new))
;; Potential extractions: candidate (term, cost) pairs for an expr. Using this relation means that
;; when a candidate's cost equals the current one, the term stored in ExtractedExpr is not changed.
;; This preserves egglog's timestamp of the last time ExtractedExpr changed, fixing a big performance problem.
(relation PotentialExtractedExpr (Expr TermAndCost))

(function TCPair (Term i64) TermAndCost)

(function NoTerm () Term)

;; Set ExtractedExpr to a default value for every expr that has at least one
;; potential extraction: a (NoTerm) placeholder paired with a very large cost.
;; NOTE(review): presumably this guarantees that rules matching on the current
;; (ExtractedExpr expr) entry always have something to compare against — confirm.
(rule ((PotentialExtractedExpr expr termandcost))
((set (ExtractedExpr expr) (TCPair (NoTerm) 10000000000000000)))
:ruleset terms-helpers)

;; Set ExtractedExpr to a candidate only when the candidate is strictly cheaper
;; than the currently stored cost (cost < oldcost). Equal-cost candidates do not
;; fire this rule, so the stored term is left untouched.
(rule ((PotentialExtractedExpr expr (TCPair term cost))
(= (ExtractedExpr expr) (TCPair oldterm oldcost))
(< cost oldcost))
((set (ExtractedExpr expr) (TCPair term cost)))
:ruleset terms-helpers)

;; Panic if any candidate extraction has a negative cost.
;; NOTE(review): a negative cost presumably means the i64 cost sum overflowed
;; because the terms got too big — confirm.
(rule ((PotentialExtractedExpr expr (TCPair term cost))
(< cost 0))
((panic "Negative cost"))
:ruleset terms-helpers)

;; Resolve Smaller
(rule (
(= lhs (Smaller (TCPair t1 cost1) (TCPair t2 cost2)))
(<= cost1 cost2)
(< cost1 cost2)
)
((union lhs (TCPair t1 cost1)))
:ruleset terms)
:ruleset terms-helpers-helpers)

(rule (
(= lhs (Smaller (TCPair t1 cost1) (TCPair t2 cost2)))
(> cost1 cost2)
)
((union lhs (TCPair t2 cost2)))
:ruleset terms)
:ruleset terms-helpers-helpers)


;; Resolve Smaller when both costs are equal by unioning with the first pair.
;; ExtractedExpr's merge is (Smaller old new), so when Smaller comes from that
;; merge the first pair is the previously stored value.
(rule (
(= lhs (Smaller (TCPair t1 cost1) (TCPair t2 cost2)))
(= cost1 cost2)
)
;; arbitrarily pick first one
((union lhs (TCPair t1 cost1)))
:ruleset terms-helpers-helpers)


; Compute smallest Expr bottom-up
(rule ((= lhs (Const c ty ass)))
((set (ExtractedExpr lhs) (TCPair (TermConst c) 1)))
((PotentialExtractedExpr lhs (TCPair (TermConst c) 1)))
:ruleset terms)

(rule ((= lhs (Arg ty ass)))
((set (ExtractedExpr lhs) (TCPair (TermArg) 1)))
((PotentialExtractedExpr lhs (TCPair (TermArg) 1)))
:ruleset terms)

(rule (
(= lhs (Bop o e1 e2))
(= (TCPair t1 c1) (ExtractedExpr e1))
(= (TCPair t2 c2) (ExtractedExpr e2))
)
((set (ExtractedExpr lhs) (TCPair (TermBop o t1 t2) (+ 1 (+ c1 c2)))))
((PotentialExtractedExpr lhs (TCPair (TermBop o t1 t2) (+ 1 (+ c1 c2)))))
:ruleset terms)

(rule (
Expand All @@ -46,22 +83,22 @@
(= (TCPair t2 c2) (ExtractedExpr e2))
(= (TCPair t3 c3) (ExtractedExpr e3))
)
((set (ExtractedExpr lhs) (TCPair (TermTop o t1 t2 t3) (+ (+ 1 c1) (+ c2 c3)))))
((PotentialExtractedExpr lhs (TCPair (TermTop o t1 t2 t3) (+ (+ 1 c1) (+ c2 c3)))))
:ruleset terms)

(rule (
(= lhs (Uop o e1))
(= (TCPair t1 c1) (ExtractedExpr e1))
)
((set (ExtractedExpr lhs) (TCPair (TermUop o t1) (+ 1 c1))))
((PotentialExtractedExpr lhs (TCPair (TermUop o t1) (+ 1 c1))))
:ruleset terms)

(rule (
(= lhs (Get tup i))
(= (TCPair t1 c1) (ExtractedExpr tup))
)
; cost of the get is the same as the cost of the whole tuple
((set (ExtractedExpr lhs) (TCPair (TermGet t1 i) c1)))
((PotentialExtractedExpr lhs (TCPair (TermGet t1 i) c1)))
:ruleset terms)

; todo Alloc
Expand All @@ -73,7 +110,7 @@
(= (TCPair t1 c1) (ExtractedExpr e1))
)
; cost of single is same as cost of the element
((set (ExtractedExpr lhs) (TCPair (TermSingle t1) c1)))
((PotentialExtractedExpr lhs (TCPair (TermSingle t1) c1)))
:ruleset terms)

(rule (
Expand All @@ -82,7 +119,7 @@
(= (TCPair t2 c2) (ExtractedExpr e2))
)
; cost of concat is sum of the costs
((set (ExtractedExpr lhs) (TCPair (TermConcat t1 t2) (+ c1 c2))))
((PotentialExtractedExpr lhs (TCPair (TermConcat t1 t2) (+ c1 c2))))
:ruleset terms)


Expand All @@ -95,7 +132,7 @@
; (= (TCPair t4 c4) (ExtractedExpr els))
; )
; ; cost of if is 10 + cost of pred + cost of input + max of branch costs
; ((set (ExtractedExpr lhs) (TCPair (TermIf t1 t2 t3 t4) (+ 10 (+ (+ c1 c2) (max c3 c4))))))
; ((PotentialExtractedExpr lhs (TCPair (TermIf t1 t2 t3 t4) (+ 10 (+ (+ c1 c2) (max c3 c4))))))
; :ruleset terms)

(sort Node)
Expand Down
20 changes: 20 additions & 0 deletions tests/passing/small/simple_select_after_block_diamond.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# ARGS: 1
# Test program: a select (v5_) is computed before a two-way branch on the same
# condition (v3_) that the select uses.
# The branch targets .b8_ and .b9_; .b9_ overwrites v6_ with c2_ + v5_ and then
# falls through into .b8_, which prints c1_ + v6_.
@main(v0: int) {
c1_: int = const 1;
c2_: int = const 2;
v3_: bool = lt v0 c2_;
c4_: int = const 4;
v5_: int = select v3_ c4_ c1_;
v6_: int = id v5_;
v7_: int = id c1_;
br v3_ .b8_ .b9_;
.b9_:
v10_: int = add c2_ v5_;
v6_: int = id v10_;
v7_: int = id c1_;
.b8_:
v11_: int = add c1_ v6_;
print v11_;
ret;
}

32 changes: 16 additions & 16 deletions tests/snapshots/files__block-diamond-optimize-sequential.snap
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,30 @@ expression: visualization.result
---
# ARGS: 1
@main(v0: int) {
c1_: int = const 2;
v2_: bool = lt v0 c1_;
c3_: int = const 1;
v4_: int = id c3_;
v5_: int = id c3_;
v6_: int = id c1_;
br v2_ .b7_ .b8_;
c1_: int = const 1;
c2_: int = const 2;
v3_: bool = lt v0 c2_;
v4_: int = id c1_;
v5_: int = id c1_;
v6_: int = id c2_;
br v3_ .b7_ .b8_;
.b7_:
c9_: bool = const true;
c10_: int = const 4;
v11_: int = select c9_ c10_ c1_;
v11_: int = select c9_ c10_ c2_;
v4_: int = id v11_;
v5_: int = id c3_;
v6_: int = id c1_;
v12_: int = add c1_ v4_;
v13_: int = select v2_ v4_ v12_;
v14_: int = add c3_ v13_;
v5_: int = id c1_;
v6_: int = id c2_;
v12_: int = add c2_ v4_;
v13_: int = select v3_ v4_ v12_;
v14_: int = add c1_ v13_;
print v14_;
ret;
jmp .b15_;
.b8_:
v12_: int = add c1_ v4_;
v13_: int = select v2_ v4_ v12_;
v14_: int = add c3_ v13_;
v12_: int = add c2_ v4_;
v13_: int = select v3_ v4_ v12_;
v14_: int = add c1_ v13_;
print v14_;
ret;
.b15_:
Expand Down
45 changes: 25 additions & 20 deletions tests/snapshots/files__branch_hoisting-optimize-sequential.snap
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,33 @@ expression: visualization.result
@main(v0: int) {
c1_: int = const 0;
c2_: int = const 500;
v3_: int = id c1_;
v3_: bool = eq c1_ v0;
v4_: int = id c1_;
v5_: int = id v0;
v6_: int = id c1_;
v7_: int = id c2_;
.b8_:
v9_: bool = eq v5_ v6_;
c10_: int = const 2;
v11_: int = mul c10_ v4_;
c12_: int = const 3;
v13_: int = mul c12_ v4_;
v14_: int = select v9_ v11_ v13_;
c15_: int = const 1;
v16_: int = add c15_ v4_;
v17_: bool = lt v16_ v7_;
v3_: int = id v14_;
v4_: int = id v16_;
v5_: int = id v5_;
v5_: int = id c1_;
v6_: int = id v0;
v7_: int = id c1_;
v8_: int = id c2_;
v9_: bool = id v3_;
.b10_:
c11_: int = const 1;
v12_: int = add c11_ v5_;
v13_: int = add c11_ v12_;
v14_: int = add c11_ v13_;
c15_: int = const 2;
v16_: int = mul c15_ v14_;
c17_: int = const 3;
v18_: int = mul c17_ v14_;
v19_: int = select v9_ v16_ v18_;
v20_: int = add c11_ v14_;
v21_: bool = lt v20_ v8_;
v4_: int = id v19_;
v5_: int = id v20_;
v6_: int = id v6_;
v7_: int = id v7_;
br v17_ .b8_ .b18_;
.b18_:
print v3_;
v8_: int = id v8_;
v9_: bool = id v9_;
br v21_ .b10_ .b22_;
.b22_:
print v4_;
ret;
}
16 changes: 8 additions & 8 deletions tests/snapshots/files__branch_hoisting-optimize.snap
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ expression: visualization.result
v13_: int = id c3_;
.b14_:
v15_: bool = eq v11_ v12_;
c16_: int = const 2;
c17_: int = const 1;
v18_: int = add c17_ v10_;
v19_: int = add c17_ v18_;
v20_: int = add c17_ v19_;
v21_: int = mul c16_ v20_;
c16_: int = const 1;
v17_: int = add c16_ v10_;
v18_: int = add c16_ v17_;
v19_: int = add c16_ v18_;
c20_: int = const 2;
v21_: int = mul c20_ v19_;
c22_: int = const 3;
v23_: int = mul c22_ v20_;
v23_: int = mul c22_ v19_;
v24_: int = select v15_ v21_ v23_;
v25_: int = add c17_ v20_;
v25_: int = add c16_ v19_;
v26_: bool = lt v25_ v13_;
v9_: int = id v24_;
v10_: int = id v25_;
Expand Down
Loading