Skip to content

Commit

Permalink
simplify unsatisfiable branches
Browse files Browse the repository at this point in the history
  • Loading branch information
yamaguchi1024 committed Feb 8, 2025
1 parent 885b447 commit b389181
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 568 deletions.
33 changes: 28 additions & 5 deletions src/exo/rewrite/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,10 @@ def map_tree(tree: D.node):
return tree

elif isinstance(tree, D.AffineSplit):
# When assumptions are unsatisfiable, this is a cell without integer points. We Bottom such cases.
if not slv.satisfy((A.Const(True, T.bool, null_srcinfo()))):
return D.Leaf(D.SubVal(V.Bot()))

# we can collapse the tree when all values are the same
if (
isinstance(tree.ltz, D.Leaf)
Expand All @@ -1082,15 +1086,32 @@ def map_tree(tree: D.node):
return tree.ltz

pred = lift_to_smt_a(tree.ae)

# If ltz branch is unsatisfiable and the values for eqz and gtz branches are equivalent, we can collapse this node.
if (
isinstance(tree.eqz, D.Leaf)
and isinstance(tree.gtz, D.Leaf)
and (type(tree.eqz) == type(tree.gtz))
and (tree.eqz.v == tree.gtz.v)
):
if not slv.satisfy(mk_aexpr("<", pred)):
return tree.eqz

# If gtz branch is unsatisfiable and the values for eqz and ltz branches are equivalent, we can collapse this node.
if (
isinstance(tree.eqz, D.Leaf)
and isinstance(tree.ltz, D.Leaf)
and (type(tree.eqz) == type(tree.ltz))
and (tree.eqz.v == tree.ltz.v)
):
if not slv.satisfy(mk_aexpr(">", pred)):
return tree.eqz

# check if anything is simplifiable
ltz_eq = mk_aexpr("<", pred)
eqz_eq = mk_aexpr("==", pred)
gtz_eq = mk_aexpr(">", pred)

# When assumptions are unsatisfiable, this is a cell without integer points. We can Bottom such cases.
if not slv.satisfy((A.Const(True, T.bool, null_srcinfo()))):
return D.Leaf(D.SubVal(V.Bot()))

if slv.verify(eqz_eq):
return map_tree(tree.eqz)
elif slv.verify(gtz_eq):
Expand Down Expand Up @@ -1781,7 +1802,9 @@ def widening(a1: D.abs, a2: D.abs) -> D.abs:
reconstructed_tree = dict_tree_to_node(dict_tree)

print("\nReconstructed Abstract Domain Tree:")
a = abs_simplify(abs_simplify(D.abs(a2.iterators, reconstructed_tree)))
a = abs_simplify(
abs_simplify(abs_simplify(D.abs(a2.iterators, reconstructed_tree)))
)
print(a)

return a
Expand Down
88 changes: 17 additions & 71 deletions tests/golden/test_dataflow/test_reverse_x_10.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ def foo(x : R[11]):
- 3.0
- x[d0]
- 3.0
- ((i-1)-0)
- ⊥
- x[d0]
- x[d0]
- x[d0]
- 1.0
- ((10-(i-1))-d0)
- (2*d0-10)
Expand All @@ -46,31 +43,22 @@ def foo(x : R[11]):
- 3.0
- 3.0
- x[d0]
- (d0-10)
- 3.0
- 3.0
- ⊥
- 3.0
- ⊥
- 1.0
- ((10-(i-1))-d0)
- (2*d0-9)
- 3.0
- 3.0
- (2*d0-10)
- 3.0
- 1.0
- 1.0
- 1.0
- 3.0
- ⊥
- 3.0
- (i-0)
- ⊥
- x[d0]
- ((i-1)-d0)
- ((i-1)-0)
- ⊥
- x[d0]
- x[d0]
- x[d0]
- 1.0
- 1.0
x_2 : \i. \d0
Expand All @@ -90,19 +78,13 @@ def foo(x : R[11]):
- 3.0
- 3.0
- x[d0]
- (d0-10)
- 3.0
- 3.0
- ⊥
- 3.0
- ⊥
- 3.0
- (i-0)
- ⊥
- x[d0]
- ((i-1)-0)
- ⊥
- x[d0]
- x[d0]
- x[d0]
- 1.0
- ((10-i)-d0)
- ((i-1)-d0)
Expand All @@ -112,17 +94,11 @@ def foo(x : R[11]):
- (2*d0-9)
- 3.0
- 3.0
- (2*d0-10)
- 3.0
- 1.0
- 1.0
- 1.0
- 3.0
- ⊥
- 3.0
- ((i-1)-d0)
- ⊥
- 1.0
- 1.0
- 1.0
x_4 =\d0 \phi(10 > 0 ? x_2[10 - 1, d0] : x[d0]) # LoopExit
------------------------ x_1 : \i. \d0
- (i-0)
Expand All @@ -138,10 +114,7 @@ def foo(x : R[11]):
- 3.0
- x[d0]
- 3.0
- ((i-1)-0)
- ⊥
- x[d0]
- x[d0]
- x[d0]
- 1.0
- ((10-(i-1))-d0)
- (2*d0-10)
Expand All @@ -167,31 +140,22 @@ def foo(x : R[11]):
- 3.0
- 3.0
- x[d0]
- (d0-10)
- 3.0
- 3.0
- ⊥
- 3.0
- ⊥
- 1.0
- ((10-(i-1))-d0)
- (2*d0-9)
- 3.0
- 3.0
- (2*d0-10)
- 3.0
- 1.0
- 1.0
- 1.0
- 3.0
- ⊥
- 3.0
- (i-0)
- ⊥
- x[d0]
- ((i-1)-d0)
- ((i-1)-0)
- ⊥
- x[d0]
- x[d0]
- x[d0]
- 1.0
- 1.0
x_2 : \i. \d0
Expand All @@ -211,19 +175,10 @@ def foo(x : R[11]):
- 3.0
- 3.0
- x[d0]
- (d0-10)
- 3.0
- 3.0
- ⊥
- 3.0
- ⊥
- 3.0
- (i-0)
- ⊥
- x[d0]
- ((i-1)-0)
- ⊥
- x[d0]
- x[d0]
- x[d0]
- 1.0
- ((10-i)-d0)
- ((i-1)-d0)
Expand All @@ -233,17 +188,11 @@ def foo(x : R[11]):
- (2*d0-9)
- 3.0
- 3.0
- (2*d0-10)
- 3.0
- 1.0
- 1.0
- 1.0
- 3.0
- ⊥
- 3.0
- ((i-1)-d0)
- ⊥
- 1.0
- 1.0
- 1.0
x_4 : \d0
- ((10-1)-d0)
- (d0-10)
Expand All @@ -259,10 +208,7 @@ def foo(x : R[11]):
- (2*d0-9)
- 3.0
- 3.0
- (2*d0-10)
- 3.0
- 1.0
- 1.0
- 1.0
- 3.0
- ⊥
- 3.0
Expand Down
Loading

0 comments on commit b389181

Please sign in to comment.