Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Partial window of window support (second-order window) #762

Merged
merged 3 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 105 additions & 6 deletions src/exo/core/LoopIR.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from collections import ChainMap, defaultdict
from typing import Type
from typing import List, Type

from asdl_adt import ADT, validators

Expand Down Expand Up @@ -118,13 +118,25 @@
| Stride()
| Error()
| Tensor( expr* hi, bool is_window, type type )
-- src - type of the tensor from which the window was created
-- as_tensor - tensor type as if this window were simply a tensor
-- src_type - type of the Tensor from which the window was created
-- as_tensor - tensor type as if this window were simply a tensor
-- itself
-- window - the expression that created this window
-- src_buf - sym for the Tensor from which the window was created
-- idx - the expression that created this window
-- NB: when creating a derived window from another derived window,
-- we must "chain" the two window exprs so that src_type, src_buf
-- still refer to the original Tensor
| WindowType( type src_type, type as_tensor,
sym src_buf, w_access *idx )

-- Dense tensor: Tensor(is_window = False)
-- Window parameter (of proc): Tensor(is_window = True)
-- Derived window (from WindowExpr): WindowType / T.Window

-- First two are both "tensors" although imprecisely sometimes "tensor"
-- refers only to "dense tensor" -- we should be more clear about that.
-- Latter two are both "windows" (allows strides), but have separate
-- types since derived windows (WindowType) requires aliasing reasoning
}""",
ext_types={
"name": validators.instance_of(Identifier, convert=True),
Expand Down Expand Up @@ -477,12 +489,22 @@

@extclass(LoopIR.type)
def is_win(t):
    """Report whether ``t`` is a window type.

    Two kinds of type count as windows:
      * a derived window (``T.Window``), and
      * a window parameter (``T.Tensor`` whose ``is_window`` flag is set).
    """
    # Derived window
    if isinstance(t, T.Window):
        return True
    # Window parameter
    return isinstance(t, T.Tensor) and t.is_window


del is_win


@extclass(LoopIR.type)
def is_dense_tensor(t):
    """Report whether ``t`` is a ``T.Tensor`` whose ``is_window`` flag is unset."""
    if not isinstance(t, T.Tensor):
        return False
    return not t.is_window

Check warning on line 502 in src/exo/core/LoopIR.py

View check run for this annotation

Codecov / codecov/patch

src/exo/core/LoopIR.py#L502

Added line #L502 was not covered by tests


del is_dense_tensor


@extclass(LoopIR.type)
def is_numeric(t):
return t.is_real_scalar() or isinstance(t, (T.Tensor, T.Window))
Expand Down Expand Up @@ -525,6 +547,74 @@

del basetype


def chain_window_idx(idx0, idx1):
    """Compose two window index lists into one relative to the base tensor.

    Given

        window_0 = tensor[idx0]
        window_1 = window_0[idx1]

    return chained_idx such that window_1 = tensor[chained_idx].

    Every ``LoopIR.Point`` in ``idx0`` is carried over unchanged (that
    dimension was already collapsed by the first window).  Every
    ``LoopIR.Interval`` in ``idx0`` consumes the next entry of ``idx1``,
    shifted by the interval's lower bound.  ``idx1`` must therefore have
    exactly one entry per interval in ``idx0``.
    """

    def offset(base, delta):
        # Skip the addition entirely when either operand is a literal zero.
        if isinstance(base, LoopIR.Const) and base.val == 0:
            return delta
        if isinstance(delta, LoopIR.Const) and delta.val == 0:
            return base
        return LoopIR.BinOp("+", base, delta, T.index, delta.srcinfo)

    n_intervals = sum(isinstance(outer, LoopIR.Interval) for outer in idx0)
    assert n_intervals == len(idx1)

    chained_idx = []
    consumed = 0  # number of idx1 entries used so far
    for outer in idx0:
        if isinstance(outer, LoopIR.Point):
            chained_idx.append(outer)
            continue

        assert isinstance(outer, LoopIR.Interval)
        inner = idx1[consumed]
        consumed += 1
        srcinfo = inner.srcinfo  # newer srcinfo likely more relevant
        if isinstance(inner, LoopIR.Point):
            chained = LoopIR.Point(offset(outer.lo, inner.pt), srcinfo)
        else:
            # Note outer.hi is unused ... it is not the responsibility of
            # this function to do bounds checking.
            chained = LoopIR.Interval(
                offset(outer.lo, inner.lo), offset(outer.lo, inner.hi), srcinfo
            )
        chained_idx.append(chained)
    return chained_idx


def build_window_shape(ws: List[LoopIR.w_access]):
    """Return one extent expression per ``LoopIR.Interval`` access in ``ws``.

    The extent of an interval is ``hi - lo``; when ``lo`` is the constant 0
    the subtraction is elided and ``hi`` is used directly.  Accesses that are
    not intervals contribute no dimension to the result.
    """
    shape = []
    for w in ws:
        if not isinstance(w, LoopIR.Interval):
            continue
        lo, hi = w.lo, w.hi
        if isinstance(lo, LoopIR.Const) and lo.val == 0:
            shape.append(hi)
        else:
            shape.append(LoopIR.BinOp("-", hi, lo, T.index, hi.srcinfo))
    return shape


def create_window_type(in_name: Sym, in_typ: LoopIR.type, idx):
    """Construct a derived window type from any tensor or window type.

    ``in_name`` / ``in_typ`` name and type the buffer being windowed, and
    ``idx`` is the list of window accesses applied to it.  The returned
    ``T.Window`` always refers to an underlying ``T.Tensor``: if ``in_typ`` is
    itself a derived window, its indices are chained with ``idx`` and its
    ``src_type`` / ``src_buf`` are reused.
    """
    assert isinstance(in_name, Sym)
    as_tensor = T.Tensor(build_window_shape(idx), True, in_typ.basetype())

    if not isinstance(in_typ, T.Tensor):
        # in_typ is another derived window; "inline" through it to reach
        # the underlying Tensor.
        assert isinstance(in_typ, T.Window)
        through_idx = chain_window_idx(in_typ.idx, idx)
        return T.Window(in_typ.src_type, as_tensor, in_typ.src_buf, through_idx)

    # in_typ is a dense tensor or a window parameter
    return T.Window(in_typ, as_tensor, in_name, idx)


# --------------------------------------------------------------------------- #
# --------------------------------------------------------------------------- #

Expand Down Expand Up @@ -998,18 +1088,27 @@
class GetWrites(LoopIR_Do):
def __init__(self):
self.writes = []
# Translates access through T.Window to underlying T.Tensor
self.window_dict = {}

def do_s(self, s):
if isinstance(s, (LoopIR.Assign, LoopIR.Reduce)):
self.writes.append((s.name, s.type))
sym = s.name
self.writes.append((self.window_dict.get(sym, sym), s.type))
elif isinstance(s, LoopIR.Call):
writes_in_subproc = [a for a, _ in get_writes_of_stmts(s.f.body)]
for arg, call_arg in zip(s.args, s.f.args):
if call_arg.name in writes_in_subproc:
if isinstance(
arg, (LoopIR.Read, LoopIR.WindowExpr, LoopIR.StrideExpr)
):
self.writes.append((arg.name, arg.type))
sym = arg.name
self.writes.append((self.window_dict.get(sym, sym), arg.type))
elif isinstance(s, LoopIR.WindowStmt):
w_sym, base_sym = s.name, s.rhs.name
while base_sym in self.window_dict:
base_sym = self.window_dict[base_sym]
self.window_dict[w_sym] = base_sym

super().do_s(s)

Expand Down
29 changes: 8 additions & 21 deletions src/exo/frontend/typecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
LoopIR_Dependencies,
get_writeconfigs,
get_loop_iters,
create_window_type,
)
from ..core.extern import Extern_Typecheck_Error
from ..core.memory import *
Expand Down Expand Up @@ -376,15 +377,6 @@ def check_w_access(self, e, orig_hi):

return LoopIR.Interval(lo, hi, e.srcinfo)

def build_window_shape(self, ws):
def subtract(hi, lo):
if isinstance(lo, LoopIR.Const) and lo.val == 0:
return hi
else:
return LoopIR.BinOp("-", hi, lo, T.index, hi.srcinfo)

return [subtract(w.hi, w.lo) for w in ws if isinstance(w, LoopIR.Interval)]

def check_e(self, e, is_index=False):
if isinstance(e, UAST.Read):
typ = self.env[e.name]
Expand All @@ -408,30 +400,25 @@ def check_e(self, e, is_index=False):
return LoopIR.Read(e.name, idx, typ, e.srcinfo)

elif isinstance(e, UAST.WindowExpr):
typ = self.env[e.name]
if not typ.is_tensor_or_window():
in_typ = self.env[e.name]
if not in_typ.is_tensor_or_window():
self.err(
e,
f"cannot perform windowing on non-tensor, "
f"non-window type {e.base}",
)
return LoopIR.WindowExpr(e.name, [], T.err, e.srcinfo)

shape = typ.shape()
if len(shape) != len(e.idx):
in_shape = in_typ.shape()
if len(in_shape) != len(e.idx):
self.err(
e,
f"expected {len(shape)} indices for window "
f"expected {len(in_shape)} indices for window "
f"but got {len(e.idx)}",
)

idx = [self.check_w_access(w, t) for w, t in zip(e.idx, shape)]

# TODO: Construct as_tensor...
window_shape = self.build_window_shape(idx)
as_tensor = T.Tensor(window_shape, True, typ.type)

w_typ = T.Window(typ, as_tensor, e.name, idx)
idx = [self.check_w_access(w, t) for w, t in zip(e.idx, in_shape)]
w_typ = create_window_type(e.name, in_typ, idx)
return LoopIR.WindowExpr(e.name, idx, w_typ, e.srcinfo)

elif isinstance(e, UAST.Const):
Expand Down
48 changes: 13 additions & 35 deletions src/exo/rewrite/new_eff.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,27 +1100,6 @@ def change(x_old, x_new):
# Extraction of Effects from programs


def window_effs(e):
eff_access = []
syms = {}
for i, w in enumerate(e.idx):
if isinstance(w, LoopIR.Interval):
syms[i] = Sym(f"EXO_EFFECTS_WINDOW_TEMP_INDEX_{i}")
eff_access.append(lift_e(syms[i]))
else:
eff_access.append(lift_e(w.pt))

eff = [E.Read(e.name, [idx for idx in eff_access])]

for i, w in enumerate(e.idx):
if isinstance(w, LoopIR.Interval):
sym = syms[i]
bds = AAnd(lift_e(w.lo) <= AInt(sym), AInt(sym) < lift_e(w.hi))
eff = E.Loop(syms[i][E.Guard(bds, eff)])

return eff


def expr_effs(e):
if isinstance(e, LoopIR.Read):
if e.type.is_numeric():
Expand Down Expand Up @@ -1620,10 +1599,10 @@ def Check_ReorderStmts(proc, s1, s2):
slv.push()
slv.assume(AMay(p))

a1 = stmts_effs([s1])
a2 = stmts_effs([s2])
a1 = G(stmts_effs([s1]))
a2 = G(stmts_effs([s2]))

pred = G(AAnd(Commutes(a1, a2), AllocCommutes(a1, a2)))
pred = AAnd(Commutes(a1, a2), AllocCommutes(a1, a2))
is_ok = slv.verify(pred)
slv.pop()
if not is_ok:
Expand Down Expand Up @@ -1662,8 +1641,8 @@ def Check_ReorderLoops(proc, s):
+ expr_effs(y_loop.lo)
+ expr_effs(y_loop.hi)
)
a = stmts_effs(body)
a2 = stmts_effs(body2)
a = G(stmts_effs(body))
a2 = G(stmts_effs(body2))

def bds(x, lo, hi):
return AAnd(lift_e(lo) <= AInt(x), AInt(x) < lift_e(hi))
Expand Down Expand Up @@ -1694,8 +1673,7 @@ def bds(x, lo, hi):
),
)

pred = G(reorder_is_safe)
is_ok = slv.verify(pred)
is_ok = slv.verify(reorder_is_safe)
slv.pop()
if not is_ok:
raise SchedulingError(f"Loops {x} and {y} at {s.srcinfo} cannot be reordered.")
Expand Down Expand Up @@ -1729,8 +1707,8 @@ def Check_ParallelizeLoop(proc, s):
body2 = SubstArgs(body, subenv).result()

a_bd = expr_effs(s.lo) + expr_effs(s.hi)
a = stmts_effs(body)
a2 = stmts_effs(body2)
a = G(stmts_effs(body))
a2 = G(stmts_effs(body2))

def bds(x, lo, hi):
return AAnd(lift_e(lo) <= AInt(x), AInt(x) < lift_e(hi))
Expand All @@ -1747,7 +1725,7 @@ def bds(x, lo, hi):
),
)

pred = G(AAnd(no_bound_change, bodies_commute))
pred = AAnd(no_bound_change, bodies_commute)
is_ok = slv.verify(pred)
slv.pop()
if not is_ok:
Expand Down Expand Up @@ -1792,9 +1770,9 @@ def Check_FissionLoop(proc, loop, stmts1, stmts2, no_loop_var_1=False):
# print(Gloop)

a_bd = expr_effs(lo) + expr_effs(hi)
a1 = stmts_effs(stmts1)
a1_j = stmts_effs(stmts1_j)
a2 = stmts_effs(stmts2)
a1 = G(stmts_effs(stmts1))
a1_j = G(stmts_effs(stmts1_j))
a2 = G(stmts_effs(stmts2))

def bds(x, lo, hi):
return AAnd(lift_e(lo) <= AInt(x), AInt(x) < lift_e(hi))
Expand All @@ -1817,7 +1795,7 @@ def bds(x, lo, hi):
),
)

pred = filter_reals(G(AAnd(no_bound_change, stmts_commute)), chgG)
pred = filter_reals(AAnd(no_bound_change, stmts_commute), chgG)
# pred = G(AAnd(no_bound_change, stmts_commute))
is_ok = slv.verify(pred)
slv.pop()
Expand Down
27 changes: 27 additions & 0 deletions tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,3 +700,30 @@ def bar(n: size, dst: f32[n], src: f32[n]):

np.testing.assert_almost_equal(dst, expected)
np.testing.assert_almost_equal(src, expected)


def test_window_of_window_codegen(compiler):
    """Windows taken from windows must write to the right cells of ``dst``."""

    @proc
    def bar(n: size, dst: f32[n, n, n]):
        assert n >= 8
        w1 = dst[2:, 3, 1:]
        w2 = w1[0:, 4]  # dst[2:, 3, 5]
        for n in seq(0, 4):
            # Set dst[2:6, 3, 5] to 42
            w2[n] = 42.0
        w3 = w1[1, 4:]  # dst[3, 3, 5:]
        for n in seq(0, 2):
            # Set dst[3, 3, 5:7] to 137
            w3[n] = 137.0

    fn = compiler.compile(bar)

    for extent in (8, 10):
        cube = (extent, extent, extent)
        actual = np.zeros(shape=cube, dtype=np.float32)
        fn(None, extent, actual)

        # Mirror the proc's writes in the same order: the 137 store lands
        # after the 42 store and overwrites the shared cell [3, 3, 5].
        want = np.zeros(shape=cube, dtype=np.float32)
        want[2:6, 3, 5] = 42.0
        want[3, 3, 5:7] = 137.0
        np.testing.assert_almost_equal(actual, want)
24 changes: 24 additions & 0 deletions tests/test_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4716,3 +4716,27 @@ def bar(n: size, A: i8[n]):

bar = autolift_alloc(bar, "tmp_a : _", keep_dims=True)
assert str(bar) == golden


def test_fission_window1():
    """Fission must be rejected when a window aliases the written tensor.

    ``tw`` is a window over ``t``, so the read ``tw[i + 1]`` touches a cell
    that the statement ``t[i] = 1.0`` writes on a different iteration.
    """

    @proc
    def foo(t: f32[3]):
        tw = t[:]
        x: f32[3]
        for i in seq(0, 2):
            t[i] = 1.0
            x[i] = tw[i + 1]

    write_pattern = "t[_] = 1.0"
    with pytest.raises(SchedulingError, match="Cannot fission loop"):
        fission(foo, foo.find(write_pattern).after())


def test_reorder_stmts_window1():
    """Reordering must be rejected when a window aliases the written tensor.

    ``t[0]`` and ``tw[0]`` name the same cell, so the two stores cannot be
    swapped.
    """

    @proc
    def foo(t: f32[3]):
        tw = t[:]
        t[0] = 1.0
        tw[0] = 3.0

    write_pattern = "t[_] = 1.0"
    with pytest.raises(SchedulingError, match="do not commute"):
        reorder_stmts(foo, foo.find(write_pattern).expand(0, 1))
Loading
Loading