From cfbdd3cd646b16847e6ac28cdc1561114b9c6cf8 Mon Sep 17 00:00:00 2001 From: David Zhao Akeley Date: Thu, 17 Oct 2024 08:25:41 -0400 Subject: [PATCH] Fix fissioning of if/else (#726) Fix incorrect fission of if/else statements. See test_if_fission for repro. (Old behavior: fissioning the body of an if statement causes the orelse part to be duplicated. Fissioning within the orelse causes an exception) --- src/exo/LoopIR_scheduling.py | 6 ++++-- tests/test_schedules.py | 42 ++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/src/exo/LoopIR_scheduling.py b/src/exo/LoopIR_scheduling.py index 0c65580d7..317d0f6ed 100644 --- a/src/exo/LoopIR_scheduling.py +++ b/src/exo/LoopIR_scheduling.py @@ -2410,7 +2410,7 @@ def wrapper(body): if cur_c._node in par_s.body: def wrapper(body): - return par_s.update(body=body) + return par_s.update(body=body, orelse=[]) ir, fwd_wrap = pre_c._wrap(wrapper, "body") fwd = _compose(fwd_wrap, fwd) @@ -2424,7 +2424,9 @@ def wrapper(body): assert cur_c._node in par_s.orelse def wrapper(orelse): - return par_s.update(body=None, orelse=orelse) + return par_s.update( + body=[LoopIR.Pass(par_s.srcinfo)], orelse=orelse + ) ir, fwd_wrap = post_c._wrap(wrapper, "orelse") fwd = _compose(fwd_wrap, fwd) diff --git a/tests/test_schedules.py b/tests/test_schedules.py index a0e42ebb3..f8be47285 100644 --- a/tests/test_schedules.py +++ b/tests/test_schedules.py @@ -481,6 +481,48 @@ def foo(): fission(foo, foo.find("x = 0.0").after(), n_lifts=2) +def test_if_fission(): + @proc + def before(x: size, y: f32): + if x < 10: + y += 1 + y += 2 + else: + y += 3 + y += 4 + + @proc + def fission_if(x: size, y: f32): + if x < 10: + y += 1 + if x < 10: + y += 2 + else: + y += 3 + y += 4 + + @proc + def fission_else(x: size, y: f32): + if x < 10: + y += 1 + y += 2 + else: + y += 3 + if x < 10: + pass + else: + y += 4 + + test_fission_if = rename(before, "fission_if") + test_fission_if = fission(test_fission_if, test_fission_if.find("y += 1").after()) + assert str(fission_if) == str(test_fission_if) + test_fission_else = rename(before, "fission_else") + test_fission_else = fission( + test_fission_else, test_fission_else.find("y += 3").after() + ) + assert str(fission_else) == str(test_fission_else) + + def test_resize_dim(golden): @proc def foo():