Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
yamaguchi1024 committed Oct 11, 2024
1 parent 359e127 commit 6aa96bd
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 32 deletions.
36 changes: 20 additions & 16 deletions src/exo/LoopIR_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,39 +628,43 @@ def _print_cursor_proc(
def _print_cursor_block(
cur: Block, target: Cursor, env: PrintEnv, indent: str
) -> list[str]:
def while_cursor(c, move, k):
def while_next(c):
s = []
while True:
try:
c = move(c)
s.extend(k(c))
c = c.next()
s.extend(local_stmt(c))
except:
return s

def if_cursor(c, move, k):
try:
return k(move(c))
except InvalidCursorError:
return []
def while_prev(c):
s = []
while True:
try:
c = c.prev()
s.append(local_stmt(c))
except:
s.reverse()
return [x for xs in s for x in xs]

def local_stmt(c):
return _print_cursor_stmt(c, target, env, indent)

if isinstance(target, Gap) and target in cur:
if target._type == GapType.Before:
return [
*while_cursor(target.anchor(), lambda g: g.prev(), local_stmt),
*while_prev(target.anchor()),
f"{indent}[GAP - Before]",
*if_cursor(target, lambda g: g.anchor(), local_stmt),
*while_cursor(target.anchor(), lambda g: g.next(), local_stmt),
*local_stmt(target.anchor()),
*while_next(target.anchor()),
]
else:
assert target._type == GapType.After
return [
*while_cursor(target.anchor(), lambda g: g.prev(), local_stmt),
*if_cursor(target, lambda g: g.anchor(), local_stmt),
*while_prev(target.anchor()),
*local_stmt(target.anchor()),
f"{indent}[GAP - After]",
*while_cursor(target.anchor(), lambda g: g.next(), local_stmt),
*while_next(target.anchor()),
]

elif isinstance(target, Block) and target in cur:
Expand All @@ -669,9 +673,9 @@ def local_stmt(c):
block.extend(local_stmt(stmt))
block.append(f"{indent}# BLOCK END")
return [
*while_cursor(target[0], lambda g: g.prev(), local_stmt),
*while_prev(target[0]),
*block,
*while_cursor(target[-1], lambda g: g.next(), local_stmt),
*while_next(target[-1]),
]

else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def baz(n: size, m: size):
def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
pass
pass
x: f32 @ DRAM
# BLOCK START
for k in seq(0, n):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
x = 3.0
x = 2.0
x = 1.0
x = 0.0
x = 1.0
x = 2.0
x = 3.0
# BLOCK START
x = 4.0
x = 5.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
x = 1.0
x = 0.0
x = 1.0
[GAP - Before]
x = 2.0
x = 3.0
Expand All @@ -62,10 +62,10 @@ def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
x = 4.0
x = 3.0
x = 2.0
x = 1.0
x = 0.0
x = 1.0
x = 2.0
x = 3.0
x = 4.0
x = 5.0
[GAP - After]
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def baz(n: size, m: size):
def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
pass
x = 0.0
x: f32 @ DRAM
x = 0.0
pass
# BLOCK START
y: f32 @ DRAM
y = 1.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ def baz(n: size, m: size):
def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x = 0.0
y = 1.1
y: f32 @ DRAM
x: f32 @ DRAM
y: f32 @ DRAM
y = 1.1
x = 0.0
# BLOCK START
for k in seq(0, n):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def baz(n: size, m: size):
def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
x = 0.0
for k in seq(0, 8):
y: f32 @ DRAM
y = 1.1
x = 0.0
x: f32 @ DRAM
# BLOCK START
for k in seq(0, n):
pass
Expand Down

0 comments on commit 6aa96bd

Please sign in to comment.