Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix various small bugs in new_eff that were failing to handle window #761

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 13 additions & 35 deletions src/exo/rewrite/new_eff.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,27 +1100,6 @@ def change(x_old, x_new):
# Extraction of Effects from programs


def window_effs(e):
eff_access = []
syms = {}
for i, w in enumerate(e.idx):
if isinstance(w, LoopIR.Interval):
syms[i] = Sym(f"EXO_EFFECTS_WINDOW_TEMP_INDEX_{i}")
eff_access.append(lift_e(syms[i]))
else:
eff_access.append(lift_e(w.pt))

eff = [E.Read(e.name, [idx for idx in eff_access])]

for i, w in enumerate(e.idx):
if isinstance(w, LoopIR.Interval):
sym = syms[i]
bds = AAnd(lift_e(w.lo) <= AInt(sym), AInt(sym) < lift_e(w.hi))
eff = E.Loop(syms[i][E.Guard(bds, eff)])

return eff


def expr_effs(e):
if isinstance(e, LoopIR.Read):
if e.type.is_numeric():
Expand Down Expand Up @@ -1620,10 +1599,10 @@ def Check_ReorderStmts(proc, s1, s2):
slv.push()
slv.assume(AMay(p))

a1 = stmts_effs([s1])
a2 = stmts_effs([s2])
a1 = G(stmts_effs([s1]))
a2 = G(stmts_effs([s2]))

pred = G(AAnd(Commutes(a1, a2), AllocCommutes(a1, a2)))
pred = AAnd(Commutes(a1, a2), AllocCommutes(a1, a2))
is_ok = slv.verify(pred)
slv.pop()
if not is_ok:
Expand Down Expand Up @@ -1662,8 +1641,8 @@ def Check_ReorderLoops(proc, s):
+ expr_effs(y_loop.lo)
+ expr_effs(y_loop.hi)
)
a = stmts_effs(body)
a2 = stmts_effs(body2)
a = G(stmts_effs(body))
a2 = G(stmts_effs(body2))

def bds(x, lo, hi):
return AAnd(lift_e(lo) <= AInt(x), AInt(x) < lift_e(hi))
Expand Down Expand Up @@ -1694,8 +1673,7 @@ def bds(x, lo, hi):
),
)

pred = G(reorder_is_safe)
is_ok = slv.verify(pred)
is_ok = slv.verify(reorder_is_safe)
slv.pop()
if not is_ok:
raise SchedulingError(f"Loops {x} and {y} at {s.srcinfo} cannot be reordered.")
Expand Down Expand Up @@ -1729,8 +1707,8 @@ def Check_ParallelizeLoop(proc, s):
body2 = SubstArgs(body, subenv).result()

a_bd = expr_effs(s.lo) + expr_effs(s.hi)
a = stmts_effs(body)
a2 = stmts_effs(body2)
a = G(stmts_effs(body))
a2 = G(stmts_effs(body2))

def bds(x, lo, hi):
return AAnd(lift_e(lo) <= AInt(x), AInt(x) < lift_e(hi))
Expand All @@ -1747,7 +1725,7 @@ def bds(x, lo, hi):
),
)

pred = G(AAnd(no_bound_change, bodies_commute))
pred = AAnd(no_bound_change, bodies_commute)
is_ok = slv.verify(pred)
slv.pop()
if not is_ok:
Expand Down Expand Up @@ -1792,9 +1770,9 @@ def Check_FissionLoop(proc, loop, stmts1, stmts2, no_loop_var_1=False):
# print(Gloop)

a_bd = expr_effs(lo) + expr_effs(hi)
a1 = stmts_effs(stmts1)
a1_j = stmts_effs(stmts1_j)
a2 = stmts_effs(stmts2)
a1 = G(stmts_effs(stmts1))
a1_j = G(stmts_effs(stmts1_j))
a2 = G(stmts_effs(stmts2))

def bds(x, lo, hi):
return AAnd(lift_e(lo) <= AInt(x), AInt(x) < lift_e(hi))
Expand All @@ -1817,7 +1795,7 @@ def bds(x, lo, hi):
),
)

pred = filter_reals(G(AAnd(no_bound_change, stmts_commute)), chgG)
pred = filter_reals(AAnd(no_bound_change, stmts_commute), chgG)
# pred = G(AAnd(no_bound_change, stmts_commute))
is_ok = slv.verify(pred)
slv.pop()
Expand Down
24 changes: 24 additions & 0 deletions tests/test_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4716,3 +4716,27 @@ def bar(n: size, A: i8[n]):

bar = autolift_alloc(bar, "tmp_a : _", keep_dims=True)
assert str(bar) == golden


def test_fission_window1():
@proc
def foo(t: f32[3]):
tw = t[:]
x: f32[3]
for i in seq(0, 2):
t[i] = 1.0
x[i] = tw[i + 1]

with pytest.raises(SchedulingError, match="Cannot fission loop"):
fission(foo, foo.find("t[_] = 1.0").after())


def test_reorder_stmts_window1():
@proc
def foo(t: f32[3]):
tw = t[:]
t[0] = 1.0
tw[0] = 3.0

with pytest.raises(SchedulingError, match="do not commute"):
reorder_stmts(foo, foo.find("t[_] = 1.0").expand(0, 1))
Loading