diff --git a/src/exo/rewrite/new_eff.py b/src/exo/rewrite/new_eff.py index cf8c5d73..df3a82c5 100644 --- a/src/exo/rewrite/new_eff.py +++ b/src/exo/rewrite/new_eff.py @@ -1100,27 +1100,6 @@ def change(x_old, x_new): # Extraction of Effects from programs -def window_effs(e): - eff_access = [] - syms = {} - for i, w in enumerate(e.idx): - if isinstance(w, LoopIR.Interval): - syms[i] = Sym(f"EXO_EFFECTS_WINDOW_TEMP_INDEX_{i}") - eff_access.append(lift_e(syms[i])) - else: - eff_access.append(lift_e(w.pt)) - - eff = [E.Read(e.name, [idx for idx in eff_access])] - - for i, w in enumerate(e.idx): - if isinstance(w, LoopIR.Interval): - sym = syms[i] - bds = AAnd(lift_e(w.lo) <= AInt(sym), AInt(sym) < lift_e(w.hi)) - eff = E.Loop(syms[i][E.Guard(bds, eff)]) - - return eff - - def expr_effs(e): if isinstance(e, LoopIR.Read): if e.type.is_numeric(): @@ -1620,10 +1599,10 @@ def Check_ReorderStmts(proc, s1, s2): slv.push() slv.assume(AMay(p)) - a1 = stmts_effs([s1]) - a2 = stmts_effs([s2]) + a1 = G(stmts_effs([s1])) + a2 = G(stmts_effs([s2])) - pred = G(AAnd(Commutes(a1, a2), AllocCommutes(a1, a2))) + pred = AAnd(Commutes(a1, a2), AllocCommutes(a1, a2)) is_ok = slv.verify(pred) slv.pop() if not is_ok: @@ -1662,8 +1641,8 @@ def Check_ReorderLoops(proc, s): + expr_effs(y_loop.lo) + expr_effs(y_loop.hi) ) - a = stmts_effs(body) - a2 = stmts_effs(body2) + a = G(stmts_effs(body)) + a2 = G(stmts_effs(body2)) def bds(x, lo, hi): return AAnd(lift_e(lo) <= AInt(x), AInt(x) < lift_e(hi)) @@ -1694,8 +1673,7 @@ def bds(x, lo, hi): ), ) - pred = G(reorder_is_safe) - is_ok = slv.verify(pred) + is_ok = slv.verify(reorder_is_safe) slv.pop() if not is_ok: raise SchedulingError(f"Loops {x} and {y} at {s.srcinfo} cannot be reordered.") @@ -1729,8 +1707,8 @@ def Check_ParallelizeLoop(proc, s): body2 = SubstArgs(body, subenv).result() a_bd = expr_effs(s.lo) + expr_effs(s.hi) - a = stmts_effs(body) - a2 = stmts_effs(body2) + a = G(stmts_effs(body)) + a2 = G(stmts_effs(body2)) def bds(x, lo, hi): return AAnd(lift_e(lo) <= AInt(x), AInt(x) < lift_e(hi)) @@ -1747,7 +1725,7 @@ def bds(x, lo, hi): ), ) - pred = G(AAnd(no_bound_change, bodies_commute)) + pred = AAnd(no_bound_change, bodies_commute) is_ok = slv.verify(pred) slv.pop() if not is_ok: @@ -1792,9 +1770,9 @@ def Check_FissionLoop(proc, loop, stmts1, stmts2, no_loop_var_1=False): # print(Gloop) a_bd = expr_effs(lo) + expr_effs(hi) - a1 = stmts_effs(stmts1) - a1_j = stmts_effs(stmts1_j) - a2 = stmts_effs(stmts2) + a1 = G(stmts_effs(stmts1)) + a1_j = G(stmts_effs(stmts1_j)) + a2 = G(stmts_effs(stmts2)) def bds(x, lo, hi): return AAnd(lift_e(lo) <= AInt(x), AInt(x) < lift_e(hi)) @@ -1817,7 +1795,7 @@ def bds(x, lo, hi): ), ) - pred = filter_reals(G(AAnd(no_bound_change, stmts_commute)), chgG) + pred = filter_reals(AAnd(no_bound_change, stmts_commute), chgG) # pred = G(AAnd(no_bound_change, stmts_commute)) is_ok = slv.verify(pred) slv.pop() diff --git a/tests/test_schedules.py b/tests/test_schedules.py index f8be4728..a7a8bfa5 100644 --- a/tests/test_schedules.py +++ b/tests/test_schedules.py @@ -4716,3 +4716,27 @@ def bar(n: size, A: i8[n]): bar = autolift_alloc(bar, "tmp_a : _", keep_dims=True) assert str(bar) == golden + + +def test_fission_window1(): + @proc + def foo(t: f32[3]): + tw = t[:] + x: f32[3] + for i in seq(0, 2): + t[i] = 1.0 + x[i] = tw[i + 1] + + with pytest.raises(SchedulingError, match="Cannot fission loop"): + fission(foo, foo.find("t[_] = 1.0").after()) + + +def test_reorder_stmts_window1(): + @proc + def foo(t: f32[3]): + tw = t[:] + t[0] = 1.0 + tw[0] = 3.0 + + with pytest.raises(SchedulingError, match="do not commute"): + reorder_stmts(foo, foo.find("t[_] = 1.0").expand(0, 1))