From 6aa96bdd2d2e24a638ed13b06ca8bec6fdcb027e Mon Sep 17 00:00:00 2001 From: Yuka Ikarashi Date: Fri, 11 Oct 2024 15:39:13 -0400 Subject: [PATCH] fix --- src/exo/LoopIR_pprint.py | 36 ++++++++++--------- ...st_block_replace_forwarding_for_blocks.txt | 2 +- .../test_cursor_pretty_print_blocks.txt | 6 ++-- .../test_cursor_pretty_print_gaps.txt | 10 +++--- .../test_insert_forwarding_for_blocks.txt | 4 +-- .../test_move_forwarding_for_blocks.txt | 6 ++-- .../test_wrap_forwarding_for_blocks.txt | 4 +-- 7 files changed, 36 insertions(+), 32 deletions(-) diff --git a/src/exo/LoopIR_pprint.py b/src/exo/LoopIR_pprint.py index 50aa18043..b40762511 100644 --- a/src/exo/LoopIR_pprint.py +++ b/src/exo/LoopIR_pprint.py @@ -628,20 +628,24 @@ def _print_cursor_proc( def _print_cursor_block( cur: Block, target: Cursor, env: PrintEnv, indent: str ) -> list[str]: - def while_cursor(c, move, k): + def while_next(c): s = [] while True: try: - c = move(c) - s.extend(k(c)) + c = c.next() + s.extend(local_stmt(c)) except: return s - def if_cursor(c, move, k): - try: - return k(move(c)) - except InvalidCursorError: - return [] + def while_prev(c): + s = [] + while True: + try: + c = c.prev() + s.append(local_stmt(c)) + except: + s.reverse() + return [x for xs in s for x in xs] def local_stmt(c): return _print_cursor_stmt(c, target, env, indent) @@ -649,18 +653,18 @@ def local_stmt(c): if isinstance(target, Gap) and target in cur: if target._type == GapType.Before: return [ - *while_cursor(target.anchor(), lambda g: g.prev(), local_stmt), + *while_prev(target.anchor()), f"{indent}[GAP - Before]", - *if_cursor(target, lambda g: g.anchor(), local_stmt), - *while_cursor(target.anchor(), lambda g: g.next(), local_stmt), + *local_stmt(target.anchor()), + *while_next(target.anchor()), ] else: assert target._type == GapType.After return [ - *while_cursor(target.anchor(), lambda g: g.prev(), local_stmt), - *if_cursor(target, lambda g: g.anchor(), local_stmt), + *while_prev(target.anchor()), + *local_stmt(target.anchor()), f"{indent}[GAP - After]", - *while_cursor(target.anchor(), lambda g: g.next(), local_stmt), + *while_next(target.anchor()), ] elif isinstance(target, Block) and target in cur: @@ -669,9 +673,9 @@ def local_stmt(c): block.extend(local_stmt(stmt)) block.append(f"{indent}# BLOCK END") return [ - *while_cursor(target[0], lambda g: g.prev(), local_stmt), + *while_prev(target[0]), *block, - *while_cursor(target[-1], lambda g: g.next(), local_stmt), + *while_next(target[-1]), ] else: diff --git a/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt index 034193e1d..7667657dd 100644 --- a/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt @@ -35,9 +35,9 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM pass pass - x: f32 @ DRAM # BLOCK START for k in seq(0, n): pass diff --git a/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt b/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt index 08f8cac13..cacb83a11 100644 --- a/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt +++ b/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt @@ -41,10 +41,10 @@ def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): - x = 3.0 - x = 2.0 - x = 1.0 x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 # BLOCK START x = 4.0 x = 5.0 diff --git a/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt b/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt index 31c69fc62..bca928ff4 100644 --- a/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt +++ b/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt @@ -50,8 +50,8 @@ def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): - x = 1.0 x = 0.0 + x = 1.0 [GAP - Before] x = 2.0 x = 3.0 @@ -62,10 +62,10 @@ def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): - x = 4.0 - x = 3.0 - x = 2.0 - x = 1.0 x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 x = 5.0 [GAP - After] \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt index c76411e8d..df3f44091 100644 --- a/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt @@ -41,9 +41,9 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): - pass - x = 0.0 x: f32 @ DRAM + x = 0.0 + pass # BLOCK START y: f32 @ DRAM y = 1.1 diff --git a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt index c7311b674..9db487954 100644 --- a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt @@ -38,10 +38,10 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): - x = 0.0 - y = 1.1 - y: f32 @ DRAM x: f32 @ DRAM + y: f32 @ DRAM + y = 1.1 + x = 0.0 # BLOCK START for k in seq(0, n): pass diff --git a/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt index fea2ab41c..161b23e34 100644 --- a/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt @@ -41,11 +41,11 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM + x = 0.0 for k in seq(0, 8): y: f32 @ DRAM y = 1.1 - x = 0.0 - x: f32 @ DRAM # BLOCK START for k in seq(0, n): pass