From ec862cd1787b7c8a68678296b7f749ac95c7f016 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Wed, 17 Apr 2024 01:33:08 -0400 Subject: [PATCH 01/15] Finish working prototype of multistage programming --- src/exo/API.py | 16 +- src/exo/frontend/pyparser.py | 502 +++++++++++++++++++++++++++++------ tests/test_uast.py | 4 +- 3 files changed, 428 insertions(+), 94 deletions(-) diff --git a/src/exo/API.py b/src/exo/API.py index 3a690ca3..d00ff5de 100644 --- a/src/exo/API.py +++ b/src/exo/API.py @@ -21,7 +21,7 @@ # Moved to new file from .core.proc_eqv import decl_new_proc, derive_proc, assert_eqv_proc, check_eqv_proc -from .frontend.pyparser import get_ast_from_python, Parser, get_src_locals +from .frontend.pyparser import get_ast_from_python, Parser, get_parent_scope from .frontend.typecheck import TypeChecker from . import API_cursors as C @@ -36,14 +36,13 @@ def proc(f, _instr=None) -> "Procedure": if not isinstance(f, types.FunctionType): raise TypeError("@proc decorator must be applied to a function") - body, getsrcinfo = get_ast_from_python(f) + body, src_info = get_ast_from_python(f) assert isinstance(body, pyast.FunctionDef) parser = Parser( body, - getsrcinfo, - func_globals=f.__globals__, - srclocals=get_src_locals(depth=3 if _instr else 2), + src_info, + parent_scope=get_parent_scope(depth=3 if _instr else 2), instr=_instr, as_func=True, ) @@ -68,14 +67,13 @@ def parse_config(cls): if not inspect.isclass(cls): raise TypeError("@config decorator must be applied to a class") - body, getsrcinfo = get_ast_from_python(cls) + body, src_info = get_ast_from_python(cls) assert isinstance(body, pyast.ClassDef) parser = Parser( body, - getsrcinfo, - func_globals={}, - srclocals=get_src_locals(depth=2), + src_info, + parent_scope=get_parent_scope(depth=2), as_config=True, ) return Config(*parser.result(), not readwrite) diff --git a/src/exo/frontend/pyparser.py b/src/exo/frontend/pyparser.py index b341b42e..e2391975 100644 --- a/src/exo/frontend/pyparser.py +++ b/src/exo/frontend/pyparser.py @@ -15,6 +15,9 @@ from ..core.prelude import * from ..core.extern import Extern +from typing import Any, Callable, Union, NoReturn, Optional +import copy +from dataclasses import dataclass # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # @@ -35,54 +38,106 @@ def str_to_mem(name): return getattr(sys.modules[__name__], name) +@dataclass +class SourceInfo: + src_file: str + src_line_offset: int + src_col_offset: int + + def get_src_info(self, node: pyast.AST): + return SrcInfo( + filename=self.src_file, + lineno=node.lineno + self.src_line_offset, + col_offset=node.col_offset + self.src_col_offset, + end_lineno=( + None + if node.end_lineno is None + else node.end_lineno + self.src_line_offset + ), + end_col_offset=( + None + if node.end_col_offset is None + else node.end_col_offset + self.src_col_offset + ), + ) + + # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # # Top-level decorator -def get_ast_from_python(f): +def get_ast_from_python(f: Callable[..., Any]) -> tuple[pyast.stmt, SourceInfo]: # note that we must dedent in case the function is defined # inside of a local scope - rawsrc = inspect.getsource(f) src = textwrap.dedent(rawsrc) n_dedent = len(re.match("^(.*)", rawsrc).group()) - len( re.match("^(.*)", src).group() ) - srcfilename = inspect.getsourcefile(f) - _, srclineno = inspect.getsourcelines(f) - srclineno -= 1 # adjust for decorator line - - # create way to query for src-code information - def getsrcinfo(node): - return SrcInfo( - filename=srcfilename, - lineno=node.lineno + srclineno, - col_offset=node.col_offset + n_dedent, - end_lineno=( - None if node.end_lineno is None else node.end_lineno + srclineno - ), - end_col_offset=( - None if node.end_col_offset is None else node.end_col_offset + n_dedent - ), - ) # convert into AST nodes; which should be a module with a single node module = pyast.parse(src) assert len(module.body) == 1 - return module.body[0], getsrcinfo + return module.body[0], SourceInfo( + src_file=inspect.getsourcefile(f), + src_line_offset=inspect.getsourcelines(f)[1] - 1, + src_col_offset=n_dedent, + ) + + +@dataclass +class BoundLocal: + val: Any + + +Local = Optional[BoundLocal] + + +@dataclass +class FrameScope: + frame: inspect.frame + def get_globals(self) -> dict[str, Any]: + return self.frame.f_globals -def get_src_locals(*, depth): + def read_locals(self) -> dict[str, Local]: + return { + var: ( + BoundLocal(self.frame.f_locals[var]) + if var in self.frame.f_locals + else None + ) + for var in self.frame.f_code.co_varnames + + self.frame.f_code.co_cellvars + + self.frame.f_code.co_freevars + } + + +@dataclass +class DummyScope: + global_dict: dict[str, Any] + local_dict: dict[str, Any] + + def get_globals(self) -> dict[str, Any]: + return self.global_dict + + def read_locals(self) -> dict[str, Any]: + return self.local_dict.copy() + + +Scope = Union[DummyScope, FrameScope] + + +def get_parent_scope(*, depth) -> Scope: """ Get global and local environments for context capture purposes """ - stack_frames: [inspect.FrameInfo] = inspect.stack() + stack_frames = inspect.stack() assert len(stack_frames) >= depth - func_locals = stack_frames[depth].frame.f_locals - assert isinstance(func_locals, dict) - return ChainMap(func_locals) + frame = stack_frames[depth].frame + return FrameScope(frame) # --------------------------------------------------------------------------- # @@ -105,30 +160,230 @@ def pattern(s, filename=None, lineno=None, srclocals=None, srcglobals=None): module = pyast.parse(src) assert isinstance(module, pyast.Module) - # create way to query for src-code information - def getsrcinfo(node): - return SrcInfo( - filename=srcfilename, - lineno=node.lineno + srclineno, - col_offset=node.col_offset + n_dedent, - end_lineno=( - None if node.end_lineno is None else node.end_lineno + srclineno - ), - end_col_offset=( - None if node.end_col_offset is None else node.end_col_offset + n_dedent - ), - ) - parser = Parser( module.body, - getsrcinfo, + SourceInfo( + src_file=srcfilename, src_line_offset=srclineno, src_col_offset=n_dedent + ), + parent_scope=DummyScope({}, {}), # add globals from enclosing scope is_fragment=True, - func_globals=srcglobals, - srclocals=srclocals, ) return parser.result() +class QuoteReplacer(pyast.NodeTransformer): + def __init__( + self, + parser_parent: "Parser", + unquote_env: "UnquoteEnv", + stmt_collector: Optional[list[pyast.stmt]] = None, + ): + self.stmt_collector = stmt_collector + self.unquote_env = unquote_env + self.parser_parent = parser_parent + + def visit_With(self, node: pyast.With) -> pyast.Any: + if ( + len(node.items) == 1 + and isinstance(node.items[0].context_expr, pyast.Name) + and node.items[0].context_expr.id == "quote" + and isinstance(node.items[0].context_expr.ctx, pyast.Load) + ): + assert ( + self.stmt_collector != None + ), "Reached quote block with no buffer to place quoted statements" + + def quote_callback(_f): + self.stmt_collector.extend( + Parser( + node.body, + self.parser_parent.src_info, + parent_scope=get_parent_scope(depth=2), + is_quote_stmt=True, + parent_exo_locals=self.parser_parent.exo_locals, + ).result() + ) + + callback_name = self.unquote_env.register_quote_callback(quote_callback) + return pyast.FunctionDef( + name=self.unquote_env.mangle_name(QUOTE_BLOCK_PLACEHOLDER_PREFIX), + args=pyast.arguments( + posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[] + ), + body=node.body, + decorator_list=[pyast.Name(id=callback_name, ctx=pyast.Load())], + ) + else: + return super().generic_visit(node) + + def visit_Call(self, node: pyast.Call) -> Any: + if ( + isinstance(node.func, pyast.Name) + and node.func.id == "quote" + and len(node.keywords) == 0 + and len(node.args) == 1 + ): + + def quote_callback(_e): + return Parser( + node.args[0], + self.parser_parent.src_info, + parent_scope=get_parent_scope(depth=2), + is_quote_expr=True, + parent_exo_locals=self.parser_parent.exo_locals, + ).result() + + callback_name = self.unquote_env.register_quote_callback(quote_callback) + return pyast.Call( + func=pyast.Name(id=callback_name, ctx=pyast.Load()), + args=[ + pyast.Lambda( + args=pyast.arguments( + posonlyargs=[], + args=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=node, + ) + ], + keywords=[], + ) + else: + return super().generic_visit(node) + + +QUOTE_CALLBACK_PREFIX = "__quote_callback" +QUOTE_BLOCK_PLACEHOLDER_PREFIX = "__quote_block" +OUTER_SCOPE_HELPER = "__outer_scope" +NESTED_SCOPE_HELPER = "__nested_scope" +UNQUOTE_RETURN_HELPER = "__unquote_val" + + +@dataclass +class UnquoteEnv: + parent_globals: dict[str, Any] + parent_locals: dict[str, Local] + + def mangle_name(self, prefix: str) -> str: + index = 0 + while True: + mangled_name = f"{prefix}{index}" + if ( + mangled_name not in self.parent_locals + and mangled_name not in self.parent_globals + ): + return mangled_name + index += 1 + + def register_quote_callback(self, quote_callback: Callable[[], None]) -> str: + mangled_name = self.mangle_name(QUOTE_CALLBACK_PREFIX) + self.parent_locals[mangled_name] = BoundLocal(quote_callback) + return mangled_name + + def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: + bound_locals = { + name: val.val for name, val in self.parent_locals.items() if val is not None + } + unbound_names = { + name for name, val in self.parent_locals.items() if val is None + } + exec( + compile( + pyast.fix_missing_locations( + pyast.Module( + body=[ + pyast.FunctionDef( + name=OUTER_SCOPE_HELPER, + args=pyast.arguments( + posonlyargs=[], + args=[ + pyast.arg(arg=arg) for arg in self.parent_locals + ], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=[ + *( + [ + pyast.Delete( + targets=[ + pyast.Name( + id=name, + ctx=pyast.Del(), + ) + for name in unbound_names + ] + ) + ] + if len(unbound_names) != 0 + else [] + ), + pyast.FunctionDef( + name=NESTED_SCOPE_HELPER, + args=pyast.arguments( + posonlyargs=[], + args=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=stmts, + decorator_list=[], + ), + pyast.Return( + value=pyast.Call( + func=pyast.Name( + id=NESTED_SCOPE_HELPER, + ctx=pyast.Load(), + ), + args=[], + keywords=[], + ) + ), + ], + decorator_list=[], + ), + pyast.Assign( + targets=[ + pyast.Name( + id=UNQUOTE_RETURN_HELPER, ctx=pyast.Store() + ) + ], + value=pyast.Call( + func=pyast.Name( + id=OUTER_SCOPE_HELPER, + ctx=pyast.Load(), + ), + args=[ + ( + pyast.Constant(value=None) + if val is None + else pyast.Name(id=name, ctx=pyast.Load()) + ) + for name, val in self.parent_locals.items() + ], + keywords=[], + ), + ), + ], + type_ignores=[], + ) + ), + "", + "exec", + ), + self.parent_globals, + bound_locals, + ) + return bound_locals[UNQUOTE_RETURN_HELPER] + + def interpret_quote_expr(self, expr: pyast.expr): + return self.interpret_quote_block([pyast.Return(value=expr)]) + + # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # # Parser Pass object @@ -156,19 +411,20 @@ class Parser: def __init__( self, module_ast, - getsrcinfo, + src_info, + parent_scope=None, is_fragment=False, - func_globals=None, - srclocals=None, as_func=False, as_config=False, instr=None, + is_quote_stmt=False, + is_quote_expr=False, + parent_exo_locals=None, ): - self.module_ast = module_ast - self.globals = func_globals - self.locals = srclocals or ChainMap() - self.getsrcinfo = getsrcinfo + self.parent_scope = parent_scope + self.exo_locals = ChainMap() if parent_exo_locals is None else parent_exo_locals + self.src_info = src_info self.is_fragment = is_fragment self.push() @@ -203,18 +459,25 @@ def __init__( self._cached_result = self.parse_expr(s.value) else: self._cached_result = self.parse_stmt_block(module_ast) + elif is_quote_expr: + self._cached_result = self.parse_expr(module_ast) + elif is_quote_stmt: + self._cached_result = self.parse_stmt_block(module_ast) else: assert False, "parser mode configuration unsupported" self.pop() + def getsrcinfo(self, ast): + return self.src_info.get_src_info(ast) + def result(self): return self._cached_result def push(self): - self.locals = self.locals.new_child() + self.exo_locals = self.exo_locals.new_child() def pop(self): - self.locals = self.locals.parents + self.exo_locals = self.exo_locals.parents # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - # # parser helper routines @@ -224,9 +487,13 @@ def err(self, node, errstr, origin=None): def eval_expr(self, expr): assert isinstance(expr, pyast.expr) - code = compile(pyast.Expression(expr), "", "eval") - e_obj = eval(code, self.globals, self.locals) - return e_obj + return UnquoteEnv( + self.parent_scope.get_globals(), + { + **self.parent_scope.read_locals(), + **{k: BoundLocal(v) for k, v in self.exo_locals.items()}, + }, + ).interpret_quote_expr(expr) # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - # # structural parsing rules... @@ -268,10 +535,10 @@ def parse_fdef(self, fdef, instr=None): names.add(a.arg) nm = Sym(a.arg) if isinstance(typ, UAST.Size): - self.locals[a.arg] = SizeStub(nm) + self.exo_locals[a.arg] = SizeStub(nm) else: # note we don't need to stub the index variables - self.locals[a.arg] = nm + self.exo_locals[a.arg] = nm args.append(UAST.fnarg(nm, typ, mem, self.getsrcinfo(a))) # return types are non-sensical for Exo, b/c it models procedures @@ -453,11 +720,8 @@ def parse_num_type(self, node, is_arg=False): typ = _prim_types[node.value.id] is_window = False else: - self.err( - node, - "expected tensor type to be " - "of the form 'R[...]', 'f32[...]', etc.", - ) + typ = self.parse_num_type(node.value) + is_window = False if sys.version_info[:3] >= (3, 9): # unpack single or multi-arg indexing to list of slices/indices @@ -484,8 +748,29 @@ def parse_num_type(self, node, is_arg=False): return typ - elif isinstance(node, pyast.Name) and node.id in _prim_types: - return _prim_types[node.id] + elif isinstance(node, pyast.Name) and node.id in Parser._prim_types: + return Parser._prim_types[node.id] + elif ( + isinstance(node, pyast.Call) + and isinstance(node.func, pyast.Name) + and node.func.id == "unquote" + ): + if len(node.keywords) != 0: + self.err(node, "Unquote must take non-keyword argument") + elif len(node.args) != 1: + self.err(node, "Unquote must take 1 argument") + else: + unquote_env = UnquoteEnv( + self.parent_scope.get_globals(), self.parent_scope.read_locals() + ) + quote_replacer = QuoteReplacer(self, unquote_env) + unquoted = unquote_env.interpret_quote_expr( + quote_replacer.visit(copy.deepcopy(node.args[0])) + ) + if isinstance(unquoted, str) and unquoted in Parser._prim_types: + return Parser._prim_types[unquoted] + else: + self.err(node, "Unquote computation did not yield valid type") elif isinstance(node, pyast.Name) and ( _is_size(node) or _is_stride(node) or _is_index(node) or _is_bool(node) ): @@ -501,8 +786,29 @@ def parse_stmt_block(self, stmts): rstmts = [] for s in stmts: + if isinstance(s, pyast.With): + if ( + len(s.items) == 1 + and isinstance(s.items[0].context_expr, pyast.Name) + and s.items[0].context_expr.id == "unquote" + and isinstance(s.items[0].context_expr.ctx, pyast.Load) + ): + unquote_env = UnquoteEnv( + self.parent_scope.get_globals(), self.parent_scope.read_locals() + ) + quoted_stmts = [] + quote_stmt_replacer = QuoteReplacer(self, unquote_env, quoted_stmts) + unquote_env.interpret_quote_block( + [ + quote_stmt_replacer.visit(copy.deepcopy(python_s)) + for python_s in s.body + ], + ) + rstmts.extend(quoted_stmts) + else: + self.err(s.id, "Expected unquote") # ----- Assginment, Reduction, Var Declaration/Allocation parsing - if isinstance(s, (pyast.Assign, pyast.AnnAssign, pyast.AugAssign)): + elif isinstance(s, (pyast.Assign, pyast.AnnAssign, pyast.AugAssign)): # parse the rhs first, if it's present rhs = None if isinstance(s, pyast.AnnAssign): @@ -601,7 +907,7 @@ def parse_stmt_block(self, stmts): # insert any needed Allocs if isinstance(s, pyast.AnnAssign): nm = Sym(name_node.id) - self.locals[name_node.id] = nm + self.exo_locals[name_node.id] = nm typ, mem = self.parse_alloc_typmem(s.annotation) rstmts.append(UAST.Alloc(nm, typ, mem, self.getsrcinfo(s))) @@ -610,10 +916,10 @@ def parse_stmt_block(self, stmts): if ( isinstance(s, pyast.Assign) and len(idxs) == 0 - and name_node.id not in self.locals + and name_node.id not in self.exo_locals ): nm = Sym(name_node.id) - self.locals[name_node.id] = nm + self.exo_locals[name_node.id] = nm do_fresh_assignment = True else: do_fresh_assignment = False @@ -621,9 +927,9 @@ def parse_stmt_block(self, stmts): # get the symbol corresponding to the name on the # left-hand-side if isinstance(s, (pyast.Assign, pyast.AugAssign)): - if name_node.id not in self.locals: + if name_node.id not in self.exo_locals: self.err(name_node, f"variable '{name_node.id}' undefined") - nm = self.locals[name_node.id] + nm = self.exo_locals[name_node.id] if isinstance(nm, SizeStub): self.err( name_node, @@ -660,7 +966,7 @@ def parse_stmt_block(self, stmts): itr = s.target.id else: itr = Sym(s.target.id) - self.locals[s.target.id] = itr + self.exo_locals[s.target.id] = itr cond = self.parse_loop_cond(s.iter) body = self.parse_stmt_block(s.body) @@ -879,11 +1185,18 @@ def parse_expr(self, e): else: return PAST.Read(nm, idxs, self.getsrcinfo(e)) else: - if nm_node.id in self.locals: - nm = self.locals[nm_node.id] - elif nm_node.id in self.globals: - nm = self.globals[nm_node.id] - else: # could not resolve name to anything + parent_globals = self.parent_scope.get_globals() + parent_locals = self.parent_scope.read_locals() + if nm_node.id in self.exo_locals: + nm = self.exo_locals[nm_node.id] + elif ( + nm_node.id in parent_locals + and parent_locals[nm_node.id] is not None + ): + nm = parent_locals[nm_node.id].val + elif nm_node.id in parent_globals: + nm = parent_globals[nm_node.id] + else: self.err(nm_node, f"variable '{nm_node.id}' undefined") if isinstance(nm, SizeStub): @@ -937,11 +1250,15 @@ def parse_expr(self, e): opnm = ( "+" if isinstance(e.op, pyast.UAdd) - else "not" - if isinstance(e.op, pyast.Not) - else "~" - if isinstance(e.op, pyast.Invert) - else "ERROR-BAD-OP-CASE" + else ( + "not" + if isinstance(e.op, pyast.Not) + else ( + "~" + if isinstance(e.op, pyast.Invert) + else "ERROR-BAD-OP-CASE" + ) + ) ) self.err(e, f"unsupported unary operator: {opnm}") @@ -1039,8 +1356,27 @@ def parse_expr(self, e): return res elif isinstance(e, pyast.Call): + if isinstance(e.func, pyast.Name) and e.func.id == "unquote": + if len(e.keywords) != 0: + self.err(e, "Unquote must take non-keyword argument") + elif len(e.args) != 1: + self.err(e, "Unquote must take 1 argument") + else: + unquote_env = UnquoteEnv( + self.parent_scope.get_globals(), self.parent_scope.read_locals() + ) + quote_replacer = QuoteReplacer(self, unquote_env) + unquoted = unquote_env.interpret_quote_expr( + quote_replacer.visit(copy.deepcopy(e.args[0])) + ) + if isinstance(unquoted, (int, float)): + return self.AST.Const(unquoted, self.getsrcinfo(e)) + elif isinstance(unquoted, self.AST.expr): + return unquoted + else: + self.err(e, "Unquote received input that couldn't be unquoted") # handle stride expression - if isinstance(e.func, pyast.Name) and e.func.id == "stride": + elif isinstance(e.func, pyast.Name) and e.func.id == "stride": if ( len(e.keywords) > 0 or len(e.args) != 2 @@ -1067,9 +1403,9 @@ def parse_expr(self, e): dim = int(e.args[1].value) if not self.is_fragment: - if name not in self.locals: + if name not in self.exo_locals: self.err(e.args[0], f"variable '{name}' undefined") - name = self.locals[name] + name = self.exo_locals[name] return self.AST.StrideExpr(name, dim, self.getsrcinfo(e)) diff --git a/tests/test_uast.py b/tests/test_uast.py index 08f771e8..aebc2f95 100644 --- a/tests/test_uast.py +++ b/tests/test_uast.py @@ -5,7 +5,7 @@ from exo import DRAM from exo.frontend.pyparser import ( Parser, - get_src_locals, + get_parent_scope, get_ast_from_python, ParseError, ) @@ -17,7 +17,7 @@ def to_uast(f): body, getsrcinfo, func_globals=f.__globals__, - srclocals=get_src_locals(depth=2), + srclocals=get_parent_scope(depth=2), instr=("TEST", ""), as_func=True, ) From 2119ef529e864eaf2199cb100f83150cfa7c4161 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Wed, 17 Apr 2024 14:39:32 -0400 Subject: [PATCH 02/15] Add some tests --- .../test_captured_closure.txt | 25 ++++ .../test_metaprogramming/test_conditional.txt | 25 ++++ .../test_constant_lifting.txt | 16 +++ .../test_scope_nesting.txt | 17 +++ .../test_metaprogramming/test_scoping.txt | 16 +++ .../test_metaprogramming/test_type_params.txt | 39 ++++++ .../test_metaprogramming/test_unrolling.txt | 27 ++++ tests/test_metaprogramming.py | 128 ++++++++++++++++++ 8 files changed, 293 insertions(+) create mode 100644 tests/golden/test_metaprogramming/test_captured_closure.txt create mode 100644 tests/golden/test_metaprogramming/test_conditional.txt create mode 100644 tests/golden/test_metaprogramming/test_constant_lifting.txt create mode 100644 tests/golden/test_metaprogramming/test_scope_nesting.txt create mode 100644 tests/golden/test_metaprogramming/test_scoping.txt create mode 100644 tests/golden/test_metaprogramming/test_type_params.txt create mode 100644 tests/golden/test_metaprogramming/test_unrolling.txt create mode 100644 tests/test_metaprogramming.py diff --git a/tests/golden/test_metaprogramming/test_captured_closure.txt b/tests/golden/test_metaprogramming/test_captured_closure.txt new file mode 100644 index 00000000..af22a740 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_captured_closure.txt @@ -0,0 +1,25 @@ +#include "test.h" + + + +#include +#include + + + +// bar( +// a : i32 @DRAM +// ) +void bar( void *ctxt, int32_t* a ) { +*a += ((int32_t) 1); +*a += ((int32_t) 2); +*a += ((int32_t) 3); +*a += ((int32_t) 4); +*a += ((int32_t) 5); +*a += ((int32_t) 6); +*a += ((int32_t) 7); +*a += ((int32_t) 8); +*a += ((int32_t) 9); +*a += ((int32_t) 10); +} + diff --git a/tests/golden/test_metaprogramming/test_conditional.txt b/tests/golden/test_metaprogramming/test_conditional.txt new file mode 100644 index 00000000..d1cac99e --- /dev/null +++ b/tests/golden/test_metaprogramming/test_conditional.txt @@ -0,0 +1,25 @@ +#include "test.h" + + + +#include +#include + + + +// bar1( +// a : i8 @DRAM +// ) +void bar1( void *ctxt, const int8_t* a ) { +int8_t b; +b += ((int8_t) 1); +} + +// bar2( +// a : i8 @DRAM +// ) +void bar2( void *ctxt, const int8_t* a ) { +int8_t b; +b = ((int8_t) 0); +} + diff --git a/tests/golden/test_metaprogramming/test_constant_lifting.txt b/tests/golden/test_metaprogramming/test_constant_lifting.txt new file mode 100644 index 00000000..e98d98e3 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_constant_lifting.txt @@ -0,0 +1,16 @@ +#include "test.h" + + + +#include +#include + + + +// foo( +// a : f64 @DRAM +// ) +void foo( void *ctxt, double* a ) { +*a = 2.0818897486445276; +} + diff --git a/tests/golden/test_metaprogramming/test_scope_nesting.txt b/tests/golden/test_metaprogramming/test_scope_nesting.txt new file mode 100644 index 00000000..54d8c8ba --- /dev/null +++ b/tests/golden/test_metaprogramming/test_scope_nesting.txt @@ -0,0 +1,17 @@ +#include "test.h" + + + +#include +#include + + + +// foo( +// a : i8 @DRAM, +// b : i8 @DRAM +// ) +void foo( void *ctxt, int8_t* a, const int8_t* b ) { +*a = *b; +} + diff --git a/tests/golden/test_metaprogramming/test_scoping.txt b/tests/golden/test_metaprogramming/test_scoping.txt new file mode 100644 index 00000000..91ff8af8 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_scoping.txt @@ -0,0 +1,16 @@ +#include "test.h" + + + +#include +#include + + + +// foo( +// a : i8 @DRAM +// ) +void foo( void *ctxt, int8_t* a ) { +*a = ((int8_t) 3); +} + diff --git a/tests/golden/test_metaprogramming/test_type_params.txt b/tests/golden/test_metaprogramming/test_type_params.txt new file mode 100644 index 00000000..488510d7 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_type_params.txt @@ -0,0 +1,39 @@ +#include "test.h" + + + +#include +#include + + + +// bar1( +// a : i32 @DRAM, +// b : i8 @DRAM +// ) +void bar1( void *ctxt, int32_t* a, const int8_t* b ) { +int32_t *c = (int32_t*) malloc(4 * sizeof(*c)); +for (int_fast32_t i = 0; i < 3; i++) { + int32_t d; + d = (int32_t)(*b); + c[i + 1] = *a + c[i] * d; +} +*a = c[3]; +free(c); +} + +// bar2( +// a : f64 @DRAM, +// b : f64 @DRAM +// ) +void bar2( void *ctxt, double* a, const double* b ) { +double *c = (double*) malloc(4 * sizeof(*c)); +for (int_fast32_t i = 0; i < 3; i++) { + double d; + d = *b; + c[i + 1] = *a + c[i] * d; +} +*a = c[3]; +free(c); +} + diff --git a/tests/golden/test_metaprogramming/test_unrolling.txt b/tests/golden/test_metaprogramming/test_unrolling.txt new file mode 100644 index 00000000..57f613bb --- /dev/null +++ b/tests/golden/test_metaprogramming/test_unrolling.txt @@ -0,0 +1,27 @@ +#include "test.h" + + + +#include +#include + + + +// foo( +// a : i8 @DRAM +// ) +void foo( void *ctxt, const int8_t* a ) { +int8_t b; +b = ((int8_t) 0); +b += *a; +b += *a; +b += *a; +b += *a; +b += *a; +b += *a; +b += *a; +b += *a; +b += *a; +b += *a; +} + diff --git a/tests/test_metaprogramming.py b/tests/test_metaprogramming.py new file mode 100644 index 00000000..2c35eef0 --- /dev/null +++ b/tests/test_metaprogramming.py @@ -0,0 +1,128 @@ +from __future__ import annotations +from exo import proc, compile_procs_to_strings +from exo.API_scheduling import rename + + +def test_unrolling(golden): + @proc + def foo(a: i8): + b: i8 + b = 0 + with unquote: + for _ in range(10): + with quote: + b += a + + c_file, _ = compile_procs_to_strings([foo], "test.h") + + assert c_file == golden + + +def test_conditional(golden): + def foo(cond: bool): + @proc + def bar(a: i8): + b: i8 + with unquote: + if cond: + with quote: + b = 0 + else: + with quote: + b += 1 + + return bar + + c_file, _ = compile_procs_to_strings( + [rename(foo(False), "bar1"), rename(foo(True), "bar2")], "test.h" + ) + assert c_file == golden + + +def test_scoping(golden): + a = 3 + + @proc + def foo(a: i8): + a = unquote(a) + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert c_file == golden + + +def test_scope_nesting(golden): + x = 3 + + @proc + def foo(a: i8, b: i8): + with unquote: + y = 2 + with quote: + a = unquote(quote(b) if x == 3 and y == 2 else quote(a)) + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert c_file == golden + + +def test_global_scope(): + cell = [0] + + @proc + def foo(a: i8): + a = 0 + with unquote: + with quote: + with unquote: + global dict + cell[0] = dict + dict = None + + assert cell[0] == dict + + +def test_constant_lifting(golden): + x = 1.3 + + @proc + def foo(a: f64): + a = unquote((x**x + x) / x) + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert c_file == golden + + +def test_type_params(golden): + def foo(T: str, U: str): + @proc + def bar(a: unquote(T), b: unquote(U)): + c: unquote(T)[4] + for i in seq(0, 3): + d: unquote(T) + d = b + c[i + 1] = a + c[i] * d + a = c[3] + + return bar + + c_file, _ = compile_procs_to_strings( + [rename(foo("i32", "i8"), "bar1"), rename(foo("f64", "f64"), "bar2")], "test.h" + ) + assert c_file == golden + + +def test_captured_closure(golden): + cell = [0] + + def foo(): + cell[0] += 1 + + @proc + def bar(a: i32): + with unquote: + for _ in range(10): + foo() + with quote: + a += unquote(cell[0]) + + c_file, _ = compile_procs_to_strings([bar], "test.h") + assert c_file == golden From 3d07c52f69488f97ba9e5ecf59300e6f2cb92ca7 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Wed, 17 Apr 2024 15:22:45 -0400 Subject: [PATCH 03/15] Fix namespace overlap issue --- src/exo/frontend/pyparser.py | 54 +++++++++++-------- .../test_capture_nested_quote.txt | 18 +++++++ tests/test_metaprogramming.py | 14 +++++ 3 files changed, 64 insertions(+), 22 deletions(-) create mode 100644 tests/golden/test_metaprogramming/test_capture_nested_quote.txt diff --git a/src/exo/frontend/pyparser.py b/src/exo/frontend/pyparser.py index e2391975..2dcba4da 100644 --- a/src/exo/frontend/pyparser.py +++ b/src/exo/frontend/pyparser.py @@ -193,7 +193,7 @@ def visit_With(self, node: pyast.With) -> pyast.Any: self.stmt_collector != None ), "Reached quote block with no buffer to place quoted statements" - def quote_callback(_f): + def quote_callback(): self.stmt_collector.extend( Parser( node.body, @@ -205,13 +205,12 @@ def quote_callback(_f): ) callback_name = self.unquote_env.register_quote_callback(quote_callback) - return pyast.FunctionDef( - name=self.unquote_env.mangle_name(QUOTE_BLOCK_PLACEHOLDER_PREFIX), - args=pyast.arguments( - posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[] - ), - body=node.body, - decorator_list=[pyast.Name(id=callback_name, ctx=pyast.Load())], + return pyast.Expr( + value=pyast.Call( + func=pyast.Name(id=callback_name, ctx=pyast.Load()), + args=[], + keywords=[], + ) ) else: return super().generic_visit(node) @@ -224,7 +223,7 @@ def visit_Call(self, node: pyast.Call) -> Any: and len(node.args) == 1 ): - def quote_callback(_e): + def quote_callback(): return Parser( node.args[0], self.parser_parent.src_info, @@ -236,18 +235,7 @@ def quote_callback(_e): callback_name = self.unquote_env.register_quote_callback(quote_callback) return pyast.Call( func=pyast.Name(id=callback_name, ctx=pyast.Load()), - args=[ - pyast.Lambda( - args=pyast.arguments( - posonlyargs=[], - args=[], - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ), - body=node, - ) - ], + args=[], keywords=[], ) else: @@ -330,7 +318,29 @@ def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: kw_defaults=[], defaults=[], ), - body=stmts, + body=[ + pyast.Expr( + value=pyast.Lambda( + args=pyast.arguments( + posonlyargs=[], + args=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=pyast.Tuple( + elts=[ + pyast.Name( + id=arg, ctx=pyast.Load() + ) + for arg in self.parent_locals + ], + ctx=pyast.Load(), + ), + ) + ), + *stmts, + ], decorator_list=[], ), pyast.Return( diff --git a/tests/golden/test_metaprogramming/test_capture_nested_quote.txt b/tests/golden/test_metaprogramming/test_capture_nested_quote.txt new file mode 100644 index 00000000..bcb2ebb0 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_capture_nested_quote.txt @@ -0,0 +1,18 @@ +#include "test.h" + + + +#include +#include + + + +// foo( +// a : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a ) { +*a = ((int32_t) 2); +*a = ((int32_t) 2); +*a = ((int32_t) 2); +} + diff --git a/tests/test_metaprogramming.py b/tests/test_metaprogramming.py index 2c35eef0..ec73e198 100644 --- a/tests/test_metaprogramming.py +++ b/tests/test_metaprogramming.py @@ -126,3 +126,17 @@ def bar(a: i32): c_file, _ = compile_procs_to_strings([bar], "test.h") assert c_file == golden + + +def test_capture_nested_quote(golden): + a = 2 + + @proc + def foo(a: i32): + with unquote: + for _ in range(3): + with quote: + a = unquote(a) + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert c_file == golden From 9bef8fbc599e23766b1d730d4c6278cfed239a99 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Wed, 15 May 2024 10:39:14 -0400 Subject: [PATCH 04/15] Add better syntax for metaprogramming --- src/exo/frontend/pyparser.py | 366 ++++++++++++++++++++++++---------- tests/test_metaprogramming.py | 44 ++-- 2 files changed, 288 insertions(+), 122 deletions(-) diff --git a/src/exo/frontend/pyparser.py b/src/exo/frontend/pyparser.py index 2dcba4da..04f8c47a 100644 --- a/src/exo/frontend/pyparser.py +++ b/src/exo/frontend/pyparser.py @@ -171,6 +171,24 @@ def pattern(s, filename=None, lineno=None, srclocals=None, srcglobals=None): return parser.result() +QUOTE_CALLBACK_PREFIX = "__quote_callback" +QUOTE_BLOCK_PLACEHOLDER_PREFIX = "__quote_block" +OUTER_SCOPE_HELPER = "__outer_scope" +NESTED_SCOPE_HELPER = "__nested_scope" +UNQUOTE_RETURN_HELPER = "__unquote_val" +UNQUOTE_BLOCK_KEYWORD = "meta" + + +@dataclass +class ExoExpression: + _inner: Any # note: strict typing is not possible as long as PAST/UAST grammar definition is not static + + +@dataclass +class ExoStatementList: + _inner: tuple[Any, ...] + + class QuoteReplacer(pyast.NodeTransformer): def __init__( self, @@ -185,52 +203,72 @@ def __init__( def visit_With(self, node: pyast.With) -> pyast.Any: if ( len(node.items) == 1 - and isinstance(node.items[0].context_expr, pyast.Name) - and node.items[0].context_expr.id == "quote" - and isinstance(node.items[0].context_expr.ctx, pyast.Load) + and isinstance(node.items[0].context_expr, pyast.UnaryOp) + and isinstance(node.items[0].context_expr.op, pyast.Invert) + and isinstance(node.items[0].context_expr.operand, pyast.Name) + and node.items[0].context_expr.operand.id == UNQUOTE_BLOCK_KEYWORD + and isinstance(node.items[0].context_expr.operand.ctx, pyast.Load) + and ( + isinstance(node.items[0].optional_vars, pyast.Name) + or node.items[0].optional_vars is None + ) ): assert ( self.stmt_collector != None ), "Reached quote block with no buffer to place quoted statements" + should_append = node.items[0].optional_vars is None def quote_callback(): - self.stmt_collector.extend( - Parser( - node.body, - self.parser_parent.src_info, - parent_scope=get_parent_scope(depth=2), - is_quote_stmt=True, - parent_exo_locals=self.parser_parent.exo_locals, - ).result() - ) + stmts = Parser( + node.body, + self.parser_parent.src_info, + parent_scope=get_parent_scope(depth=2), + is_quote_stmt=True, + parent_exo_locals=self.parser_parent.exo_locals, + ).result() + if should_append: + self.stmt_collector.extend(stmts) + else: + return ExoStatementList(tuple(stmts)) callback_name = self.unquote_env.register_quote_callback(quote_callback) - return pyast.Expr( - value=pyast.Call( - func=pyast.Name(id=callback_name, ctx=pyast.Load()), - args=[], - keywords=[], + if should_append: + return pyast.Expr( + value=pyast.Call( + func=pyast.Name(id=callback_name, ctx=pyast.Load()), + args=[], + keywords=[], + ) + ) + else: + return pyast.Assign( + targets=[node.items[0].optional_vars], + value=pyast.Call( + func=pyast.Name(id=callback_name, ctx=pyast.Load()), + args=[], + keywords=[], + ), ) - ) else: return super().generic_visit(node) - def visit_Call(self, node: pyast.Call) -> Any: + def visit_UnaryOp(self, node: pyast.UnaryOp) -> Any: if ( - isinstance(node.func, pyast.Name) - and node.func.id == "quote" - and len(node.keywords) == 0 - and len(node.args) == 1 + isinstance(node.op, pyast.Invert) + and isinstance(node.operand, pyast.Set) + and len(node.operand.elts) == 1 ): def quote_callback(): - return Parser( - node.args[0], - self.parser_parent.src_info, - parent_scope=get_parent_scope(depth=2), - is_quote_expr=True, - parent_exo_locals=self.parser_parent.exo_locals, - ).result() + return ExoExpression( + Parser( + node.operand.elts[0], + self.parser_parent.src_info, + parent_scope=get_parent_scope(depth=2), + is_quote_expr=True, + parent_exo_locals=self.parser_parent.exo_locals, + ).result() + ) callback_name = self.unquote_env.register_quote_callback(quote_callback) return pyast.Call( @@ -242,17 +280,11 @@ def quote_callback(): return super().generic_visit(node) -QUOTE_CALLBACK_PREFIX = "__quote_callback" -QUOTE_BLOCK_PLACEHOLDER_PREFIX = "__quote_block" -OUTER_SCOPE_HELPER = "__outer_scope" -NESTED_SCOPE_HELPER = "__nested_scope" -UNQUOTE_RETURN_HELPER = "__unquote_val" - - @dataclass class UnquoteEnv: parent_globals: dict[str, Any] parent_locals: dict[str, Local] + exo_local_vars: dict[str, Any] def mangle_name(self, prefix: str) -> str: index = 0 @@ -277,6 +309,12 @@ def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: unbound_names = { name for name, val in self.parent_locals.items() if val is None } + quote_locals = { + name: ExoExpression(val) + for name, val in self.exo_local_vars.items() + if name not in self.parent_locals + } + env_locals = {**quote_locals, **bound_locals} exec( compile( pyast.fix_missing_locations( @@ -287,7 +325,9 @@ def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: args=pyast.arguments( posonlyargs=[], args=[ - pyast.arg(arg=arg) for arg in self.parent_locals + *[pyast.arg(arg=arg) for arg in bound_locals], + *[pyast.arg(arg=arg) for arg in unbound_names], + *[pyast.arg(arg=arg) for arg in quote_locals], ], kwonlyargs=[], kw_defaults=[], @@ -330,10 +370,20 @@ def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: ), body=pyast.Tuple( elts=[ - pyast.Name( - id=arg, ctx=pyast.Load() - ) - for arg in self.parent_locals + *[ + pyast.Name( + id=arg, + ctx=pyast.Load(), + ) + for arg in bound_locals + ], + *[ + pyast.Name( + id=arg, + ctx=pyast.Load(), + ) + for arg in unbound_names + ], ], ctx=pyast.Load(), ), @@ -368,12 +418,18 @@ def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: ctx=pyast.Load(), ), args=[ - ( + *[ + pyast.Name(id=name, ctx=pyast.Load()) + for name in bound_locals + ], + *[ pyast.Constant(value=None) - if val is None - else pyast.Name(id=name, ctx=pyast.Load()) - ) - for name, val in self.parent_locals.items() + for _ in unbound_names + ], + *[ + pyast.Name(id=name, ctx=pyast.Load()) + for name in quote_locals + ], ], keywords=[], ), @@ -386,9 +442,9 @@ def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: "exec", ), self.parent_globals, - bound_locals, + env_locals, ) - return bound_locals[UNQUOTE_RETURN_HELPER] + return env_locals[UNQUOTE_RETURN_HELPER] def interpret_quote_expr(self, expr: pyast.expr): return self.interpret_quote_block([pyast.Return(value=expr)]) @@ -495,6 +551,51 @@ def pop(self): def err(self, node, errstr, origin=None): raise ParseError(f"{self.getsrcinfo(node)}: {errstr}") from origin + def make_exo_var_asts(self, srcinfo): + return { + name: self.AST.Read(val, [], srcinfo) + for name, val in self.exo_locals.items() + if isinstance(val, Sym) + } + + def try_eval_unquote( + self, unquote_node: pyast.expr + ) -> Union[tuple[()], tuple[Any]]: + if isinstance(unquote_node, pyast.Set): + if len(unquote_node.elts) != 1: + self.err(unquote_node, "Unquote must take 1 argument") + else: + unquote_env = UnquoteEnv( + self.parent_scope.get_globals(), + self.parent_scope.read_locals(), + self.make_exo_var_asts(self.getsrcinfo(unquote_node)), + ) + quote_replacer = QuoteReplacer(self, unquote_env) + unquoted = unquote_env.interpret_quote_expr( + quote_replacer.visit(copy.deepcopy(unquote_node.elts[0])) + ) + return (unquoted,) + elif ( + isinstance(unquote_node, pyast.Name) + and isinstance(unquote_node.ctx, pyast.Load) + and unquote_node.id not in self.exo_locals + ): + cur_globals = self.parent_scope.get_globals() + cur_locals = self.parent_scope.read_locals() + return ( + ( + UnquoteEnv( + cur_globals, + cur_locals, + self.make_exo_var_asts(self.getsrcinfo(unquote_node)), + ).interpret_quote_expr(unquote_node), + ) + if unquote_node.id in cur_locals or unquote_node.id in cur_globals + else tuple() + ) + else: + return tuple() + def eval_expr(self, expr): assert isinstance(expr, pyast.expr) return UnquoteEnv( @@ -503,6 +604,7 @@ def eval_expr(self, expr): **self.parent_scope.read_locals(), **{k: BoundLocal(v) for k, v in self.exo_locals.items()}, }, + self.make_exo_var_asts(self.getsrcinfo(expr)), ).interpret_quote_expr(expr) # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - # @@ -760,27 +862,6 @@ def parse_num_type(self, node, is_arg=False): elif isinstance(node, pyast.Name) and node.id in Parser._prim_types: return Parser._prim_types[node.id] - elif ( - isinstance(node, pyast.Call) - and isinstance(node.func, pyast.Name) - and node.func.id == "unquote" - ): - if len(node.keywords) != 0: - self.err(node, "Unquote must take non-keyword argument") - elif len(node.args) != 1: - self.err(node, "Unquote must take 1 argument") - else: - unquote_env = UnquoteEnv( - self.parent_scope.get_globals(), self.parent_scope.read_locals() - ) - quote_replacer = QuoteReplacer(self, unquote_env) - unquoted = unquote_env.interpret_quote_expr( - quote_replacer.visit(copy.deepcopy(node.args[0])) - ) - if isinstance(unquoted, str) and unquoted in Parser._prim_types: - return Parser._prim_types[unquoted] - else: - self.err(node, "Unquote computation did not yield valid type") elif isinstance(node, pyast.Name) and ( _is_size(node) or _is_stride(node) or _is_index(node) or _is_bool(node) ): @@ -788,7 +869,13 @@ def parse_num_type(self, node, is_arg=False): node, f"Cannot allocate an intermediate value of type {node.id}" ) else: - self.err(node, "unrecognized type: " + pyast.dump(node)) + unquote_eval_result = self.try_eval_unquote(node) + if len(unquote_eval_result) == 1: + unquoted = unquote_eval_result[0] + if isinstance(unquoted, str) and unquoted in Parser._prim_types: + return Parser._prim_types[unquoted] + else: + self.err(node, "Unquote computation did not yield valid type") def parse_stmt_block(self, stmts): assert isinstance(stmts, list) @@ -800,11 +887,14 @@ def parse_stmt_block(self, stmts): if ( len(s.items) == 1 and isinstance(s.items[0].context_expr, pyast.Name) - and s.items[0].context_expr.id == "unquote" + and s.items[0].context_expr.id == UNQUOTE_BLOCK_KEYWORD and isinstance(s.items[0].context_expr.ctx, pyast.Load) + and s.items[0].optional_vars is None ): unquote_env = UnquoteEnv( - self.parent_scope.get_globals(), self.parent_scope.read_locals() + self.parent_scope.get_globals(), + self.parent_scope.read_locals(), + self.make_exo_var_asts(self.getsrcinfo(s)), ) quoted_stmts = [] quote_stmt_replacer = QuoteReplacer(self, unquote_env, quoted_stmts) @@ -816,7 +906,28 @@ def parse_stmt_block(self, stmts): ) rstmts.extend(quoted_stmts) else: - self.err(s.id, "Expected unquote") + self.err(s, "Expected unquote") + elif isinstance(s, pyast.Expr) and isinstance(s.value, pyast.Set): + if len(s.value.elts) != 1: + self.err(s, "Unquote must take 1 argument") + else: + unquoted = self.try_eval_unquote(s.value)[0] + if ( + isinstance(unquoted, ExoStatementList) + and isinstance(unquoted._inner, tuple) + and all( + map( + lambda inner_s: isinstance(inner_s, self.AST.stmt), + unquoted._inner, + ) + ) + ): + rstmts.extend(unquoted._inner) + else: + self.err( + s, + "Statement-level unquote expression must return Exo statements", + ) # ----- Assginment, Reduction, Var Declaration/Allocation parsing elif isinstance(s, (pyast.Assign, pyast.AnnAssign, pyast.AugAssign)): # parse the rhs first, if it's present @@ -1147,12 +1258,75 @@ def parse_array_indexing(self, node): if not isinstance(node.value, pyast.Name): self.err(node, "expected access to have form 'x' or 'x[...]'") - is_window = any(isinstance(e, pyast.Slice) for e in dims) - idxs = [ - (self.parse_slice(e, node) if is_window else self.parse_expr(e)) - for e in dims - ] + def unquote_to_index(unquoted, ref_node, srcinfo, top_level): + if isinstance(unquoted, (int, float)): + return self.AST.Const(unquoted, self.getsrcinfo(e)) + elif isinstance(unquoted, ExoExpression) and isinstance( + unquoted._inner, self.AST.expr + ): + return unquoted._inner + elif isinstance(unquoted, slice) and top_level: + if unquoted.step is None: + return UAST.Interval( + ( + None + if unquoted.start is None + else unquote_to_index(unquoted.start, False) + ), + ( + None + if unquoted.stop is None + else unquote_to_index(unquoted.stop, False) + ), + srcinfo, + ) + else: + self.err(ref_node, "Unquote returned slice index with step") + else: + self.err( + ref_node, "Unquote received input that couldn't be unquoted" + ) + idxs = [] + srcinfo_for_idxs = [] + for e in dims: + if sys.version_info[:3] >= (3, 9): + srcinfo = self.getsrcinfo(e) + else: + if isinstance(e, pyast.Index): + e = e.value + srcinfo = self.getsrcinfo(e) + else: + srcinfo = self.getsrcinfo(node) + if isinstance(e, pyast.Slice): + idxs.append(self.parse_slice(e, node)) + srcinfo_for_idxs.append(srcinfo) + unquote_eval_result = self.try_eval_unquote(e) + if len(unquote_eval_result) == 1: + unquoted = unquote_eval_result[0] + + else: + unquote_eval_result = self.try_eval_unquote(e) + if len(unquote_eval_result) == 1: + unquoted = unquote_eval_result[0] + if isinstance(unquoted, tuple): + for unquoted_val in unquoted: + idxs.append( + unquote_to_index(unquoted_val, e, srcinfo, True) + ) + srcinfo_for_idxs.append(srcinfo) + else: + idxs.append(unquote_to_index(unquoted, e, srcinfo, True)) + srcinfo_for_idxs.append(srcinfo) + else: + idxs.append(self.parse_expr(e)) + srcinfo_for_idxs.append(srcinfo) + + is_window = any(map(lambda idx: isinstance(idx, UAST.Interval), idxs)) + if is_window: + for i in range(len(idxs)): + if not isinstance(idxs[i], UAST.Interval): + idxs[i] = UAST.Point(idxs[i], srcinfo_for_idxs[i]) return node.value, idxs, is_window else: assert False, "bad case" @@ -1185,7 +1359,18 @@ def parse_slice(self, e, node): # parse expressions, including values, indices, and booleans def parse_expr(self, e): - if isinstance(e, (pyast.Name, pyast.Subscript)): + unquote_eval_result = self.try_eval_unquote(e) + if len(unquote_eval_result) == 1: + unquoted = unquote_eval_result[0] + if isinstance(unquoted, (int, float)): + return self.AST.Const(unquoted, self.getsrcinfo(e)) + elif isinstance(unquoted, ExoExpression) and isinstance( + unquoted._inner, self.AST.expr + ): + return unquoted._inner + else: + self.err(e, "Unquote received input that couldn't be unquoted") + elif isinstance(e, (pyast.Name, pyast.Subscript)): nm_node, idxs, is_window = self.parse_array_indexing(e) if self.is_fragment: @@ -1366,27 +1551,8 @@ def parse_expr(self, e): return res elif isinstance(e, pyast.Call): - if isinstance(e.func, pyast.Name) and e.func.id == "unquote": - if len(e.keywords) != 0: - self.err(e, "Unquote must take non-keyword argument") - elif len(e.args) != 1: - self.err(e, "Unquote must take 1 argument") - else: - unquote_env = UnquoteEnv( - self.parent_scope.get_globals(), self.parent_scope.read_locals() - ) - quote_replacer = QuoteReplacer(self, unquote_env) - unquoted = unquote_env.interpret_quote_expr( - quote_replacer.visit(copy.deepcopy(e.args[0])) - ) - if isinstance(unquoted, (int, float)): - return self.AST.Const(unquoted, self.getsrcinfo(e)) - elif isinstance(unquoted, self.AST.expr): - return unquoted - else: - self.err(e, "Unquote received input that couldn't be unquoted") # handle stride expression - elif isinstance(e.func, pyast.Name) and e.func.id == "stride": + if isinstance(e.func, pyast.Name) and e.func.id == "stride": if ( len(e.keywords) > 0 or len(e.args) != 2 diff --git a/tests/test_metaprogramming.py b/tests/test_metaprogramming.py index ec73e198..93bfe745 100644 --- a/tests/test_metaprogramming.py +++ b/tests/test_metaprogramming.py @@ -8,9 +8,9 @@ def test_unrolling(golden): def foo(a: i8): b: i8 b = 0 - with unquote: + with meta: for _ in range(10): - with quote: + with ~meta: b += a c_file, _ = compile_procs_to_strings([foo], "test.h") @@ -23,12 +23,12 @@ def foo(cond: bool): @proc def bar(a: i8): b: i8 - with unquote: + with meta: if cond: - with quote: + with ~meta: b = 0 else: - with quote: + with ~meta: b += 1 return bar @@ -44,7 +44,7 @@ def test_scoping(golden): @proc def foo(a: i8): - a = unquote(a) + a = {a} c_file, _ = compile_procs_to_strings([foo], "test.h") assert c_file == golden @@ -55,10 +55,10 @@ def test_scope_nesting(golden): @proc def foo(a: i8, b: i8): - with unquote: + with meta: y = 2 - with quote: - a = unquote(quote(b) if x == 3 and y == 2 else quote(a)) + with ~meta: + a = {~{b} if x == 3 and y == 2 else ~{a}} c_file, _ = compile_procs_to_strings([foo], "test.h") assert c_file == golden @@ -70,9 +70,9 @@ def test_global_scope(): @proc def foo(a: i8): a = 0 - with unquote: - with quote: - with unquote: + with meta: + with ~meta: + with meta: global dict cell[0] = dict dict = None @@ -85,7 +85,7 @@ def test_constant_lifting(golden): @proc def foo(a: f64): - a = unquote((x**x + x) / x) + a = {(x**x + x) / x} c_file, _ = compile_procs_to_strings([foo], "test.h") assert c_file == golden @@ -94,10 +94,10 @@ def foo(a: f64): def test_type_params(golden): def foo(T: str, U: str): @proc - def bar(a: unquote(T), b: unquote(U)): - c: unquote(T)[4] + def bar(a: {T}, b: {U}): + c: {T}[4] for i in seq(0, 3): - d: unquote(T) + d: {T} d = b c[i + 1] = a + c[i] * d a = c[3] @@ -118,11 +118,11 @@ def foo(): @proc def bar(a: i32): - with unquote: + with meta: for _ in range(10): foo() - with quote: - a += unquote(cell[0]) + with ~meta: + a += {cell[0]} c_file, _ = compile_procs_to_strings([bar], "test.h") assert c_file == golden @@ -133,10 +133,10 @@ def test_capture_nested_quote(golden): @proc def foo(a: i32): - with unquote: + with meta: for _ in range(3): - with quote: - a = unquote(a) + with ~meta: + a = {a} c_file, _ = compile_procs_to_strings([foo], "test.h") assert c_file == golden From 8319306c588a1945110d0e33bc0daa86bb819d90 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Thu, 16 May 2024 15:47:00 -0400 Subject: [PATCH 05/15] Add new metaprogramming tests --- src/exo/frontend/pyparser.py | 10 +- .../test_quote_complex_expr.txt | 16 ++ .../test_quote_elision.txt | 17 ++ .../test_scope_collision1.txt | 18 ++ .../test_scope_collision2.txt | 17 ++ .../test_statement_assignment.txt | 19 ++ .../test_type_quote_elision.txt | 18 ++ .../test_unquote_elision.txt | 16 ++ .../test_unquote_in_slice.txt | 25 +++ .../test_unquote_index_tuple.txt | 26 +++ .../test_unquote_slice_object1.txt | 31 +++ tests/test_metaprogramming.py | 198 ++++++++++++++++++ 12 files changed, 408 insertions(+), 3 deletions(-) create mode 100644 tests/golden/test_metaprogramming/test_quote_complex_expr.txt create mode 100644 tests/golden/test_metaprogramming/test_quote_elision.txt create mode 100644 tests/golden/test_metaprogramming/test_scope_collision1.txt create mode 100644 tests/golden/test_metaprogramming/test_scope_collision2.txt create mode 100644 tests/golden/test_metaprogramming/test_statement_assignment.txt create mode 100644 tests/golden/test_metaprogramming/test_type_quote_elision.txt create mode 100644 tests/golden/test_metaprogramming/test_unquote_elision.txt create mode 100644 tests/golden/test_metaprogramming/test_unquote_in_slice.txt create mode 100644 tests/golden/test_metaprogramming/test_unquote_index_tuple.txt create mode 100644 tests/golden/test_metaprogramming/test_unquote_slice_object1.txt diff --git a/src/exo/frontend/pyparser.py b/src/exo/frontend/pyparser.py index 04f8c47a..7593192c 100644 --- a/src/exo/frontend/pyparser.py +++ b/src/exo/frontend/pyparser.py @@ -1260,7 +1260,7 @@ def parse_array_indexing(self, node): def unquote_to_index(unquoted, ref_node, srcinfo, top_level): if isinstance(unquoted, (int, float)): - return self.AST.Const(unquoted, self.getsrcinfo(e)) + return self.AST.Const(unquoted, srcinfo) elif isinstance(unquoted, ExoExpression) and isinstance( unquoted._inner, self.AST.expr ): @@ -1271,12 +1271,16 @@ def unquote_to_index(unquoted, ref_node, srcinfo, top_level): ( None if unquoted.start is None - else unquote_to_index(unquoted.start, False) + else unquote_to_index( + unquoted.start, ref_node, srcinfo, False + ) ), ( None if unquoted.stop is None - else unquote_to_index(unquoted.stop, False) + else unquote_to_index( + unquoted.stop, ref_node, srcinfo, False + ) ), srcinfo, ) diff --git a/tests/golden/test_metaprogramming/test_quote_complex_expr.txt b/tests/golden/test_metaprogramming/test_quote_complex_expr.txt new file mode 100644 index 00000000..f9245640 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_quote_complex_expr.txt @@ -0,0 +1,16 @@ +#include "test.h" + + + +#include +#include + + + +// foo( +// a : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a ) { +*a = *a + ((int32_t) 1) + ((int32_t) 1); +} + diff --git a/tests/golden/test_metaprogramming/test_quote_elision.txt b/tests/golden/test_metaprogramming/test_quote_elision.txt new file mode 100644 index 00000000..90a35203 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_quote_elision.txt @@ -0,0 +1,17 @@ +#include "test.h" + + + +#include +#include + + + +// foo( +// a : i32 @DRAM, +// b : i32 @DRAM +// ) +void foo( void *ctxt, const int32_t* a, int32_t* b ) { +*b = *a; +} + diff --git a/tests/golden/test_metaprogramming/test_scope_collision1.txt b/tests/golden/test_metaprogramming/test_scope_collision1.txt new file mode 100644 index 00000000..78ac1e83 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_scope_collision1.txt @@ -0,0 +1,18 @@ +#include "test.h" + + + +#include +#include + + + +// foo( +// a : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a ) { +int32_t b; +b = ((int32_t) 2); +*a = ((int32_t) 1); +} + diff --git a/tests/golden/test_metaprogramming/test_scope_collision2.txt b/tests/golden/test_metaprogramming/test_scope_collision2.txt new file mode 100644 index 00000000..90a35203 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_scope_collision2.txt @@ -0,0 +1,17 @@ +#include "test.h" + + + +#include +#include + + + +// foo( +// a : i32 @DRAM, +// b : i32 @DRAM +// ) +void foo( void *ctxt, const int32_t* a, int32_t* b ) { +*b = *a; +} + diff --git a/tests/golden/test_metaprogramming/test_statement_assignment.txt b/tests/golden/test_metaprogramming/test_statement_assignment.txt new file mode 100644 index 00000000..a289e8c8 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_statement_assignment.txt @@ -0,0 +1,19 @@ +#include "test.h" + + + +#include +#include + + + +// foo( +// a : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a ) { +*a += ((int32_t) 1); +*a += ((int32_t) 2); +*a += ((int32_t) 1); +*a += ((int32_t) 2); +} + diff --git a/tests/golden/test_metaprogramming/test_type_quote_elision.txt b/tests/golden/test_metaprogramming/test_type_quote_elision.txt new file mode 100644 index 00000000..70514d6c --- /dev/null +++ b/tests/golden/test_metaprogramming/test_type_quote_elision.txt @@ -0,0 +1,18 @@ +#include "test.h" + + + +#include +#include + + + +// foo( +// a : i8 @DRAM, +// x : i8[2] @DRAM +// ) +void foo( void *ctxt, int8_t* a, const int8_t* x ) { +*a += x[0]; +*a += x[1]; +} + diff --git a/tests/golden/test_metaprogramming/test_unquote_elision.txt b/tests/golden/test_metaprogramming/test_unquote_elision.txt new file mode 100644 index 00000000..96f1bfed --- /dev/null +++ b/tests/golden/test_metaprogramming/test_unquote_elision.txt @@ -0,0 +1,16 @@ +#include "test.h" + + + +#include +#include + + + +// foo( +// a : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a ) { +*a = *a * ((int32_t) 2); +} + diff --git a/tests/golden/test_metaprogramming/test_unquote_in_slice.txt b/tests/golden/test_metaprogramming/test_unquote_in_slice.txt new file mode 100644 index 00000000..ef1b0745 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_unquote_in_slice.txt @@ -0,0 +1,25 @@ +#include "test.h" + + + +#include +#include + + + +// bar( +// a : i8[10, 10] @DRAM +// ) +void bar( void *ctxt, int8_t* a ) { +for (int_fast32_t i = 0; i < 5; i++) { + foo(ctxt,(struct exo_win_1i8){ &a[(i) * (10) + 2], { 1 } }); +} +} + +// foo( +// a : [i8][2] @DRAM +// ) +void foo( void *ctxt, struct exo_win_1i8 a ) { +a.data[0] += a.data[a.strides[0]]; +} + diff --git a/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt b/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt new file mode 100644 index 00000000..0e7a2353 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt @@ -0,0 +1,26 @@ +#include "test.h" + + + +#include +#include + + + +// bar( +// a : i8[10, 10, 10] @DRAM +// ) +void bar( void *ctxt, int8_t* a ) { +for (int_fast32_t i = 0; i < 7; i++) { + foo(ctxt,(struct exo_win_2i8){ &a[(i) * (100) + (i) * (10) + i + 1], { 10, 1 } }); +} +} + +// foo( +// a : [i8][2, 2] @DRAM +// ) +void foo( void *ctxt, struct exo_win_2i8 a ) { +a.data[0] += a.data[a.strides[1]]; +a.data[a.strides[0]] += a.data[a.strides[0] + a.strides[1]]; +} + diff --git a/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt b/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt new file mode 100644 index 00000000..e58a2981 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt @@ -0,0 +1,31 @@ +#include "test.h" + + + +#include +#include + + + +// bar( +// a : i8[10, 10] @DRAM +// ) +void bar( void *ctxt, int8_t* a ) { +for (int_fast32_t i = 0; i < 10; i++) { + foo(ctxt,(struct exo_win_1i8){ &a[(i) * (10) + 1], { 1 } }); +} +for (int_fast32_t i = 0; i < 10; i++) { + foo(ctxt,(struct exo_win_1i8){ &a[(i) * (10) + 5], { 1 } }); +} +for (int_fast32_t i = 0; i < 10; i++) { + foo(ctxt,(struct exo_win_1i8){ &a[(i) * (10) + 2], { 1 } }); +} +} + +// foo( +// a : [i8][2] @DRAM +// ) +void foo( void *ctxt, struct exo_win_1i8 a ) { +a.data[0] += a.data[a.strides[0]]; +} + diff --git a/tests/test_metaprogramming.py b/tests/test_metaprogramming.py index 93bfe745..8157f6d6 100644 --- a/tests/test_metaprogramming.py +++ b/tests/test_metaprogramming.py @@ -1,6 +1,8 @@ from __future__ import annotations from exo import proc, compile_procs_to_strings from exo.API_scheduling import rename +from exo.pyparser import ParseError +import pytest def test_unrolling(golden): @@ -140,3 +142,199 @@ def foo(a: i32): c_file, _ = compile_procs_to_strings([foo], "test.h") assert c_file == golden + + +def test_quote_elision(golden): + @proc + def foo(a: i32, b: i32): + with meta: + + def bar(): + return a + + with ~meta: + b = {bar()} + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert c_file == golden + + +def test_unquote_elision(golden): + @proc + def foo(a: i32): + with meta: + x = 2 + with ~meta: + a = a * x + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert c_file == golden + + +def test_scope_collision1(golden): + @proc + def foo(a: i32): + with meta: + b = 1 + with ~meta: + b: i32 + b = 2 + with meta: + c = b + with ~meta: + a = c + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert c_file == golden + + +def test_scope_collision2(golden): + @proc + def foo(a: i32, b: i32): + with meta: + a = 1 + with ~meta: + b = a + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert c_file == golden + + +def test_scope_collision3(): + with pytest.raises( + NameError, + match="free variable 'x' referenced before assignment in enclosing scope", + ): + + @proc + def foo(a: i32, b: i32): + with meta: + with ~meta: + a = b * x + x = 1 + + +def test_type_quote_elision(golden): + T = "i8" + + @proc + def foo(a: T, x: T[2]): + a += x[0] + a += x[1] + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert c_file == golden + + +def test_unquote_in_slice(golden): + @proc + def foo(a: [i8][2]): + a[0] += a[1] + + @proc + def bar(a: i8[10, 10]): + with meta: + x = 2 + with ~meta: + for i in seq(0, 5): + foo(a[i, {x} : {2 * x}]) + + c_file, _ = compile_procs_to_strings([foo, bar], "test.h") + assert c_file == golden + + +def test_unquote_slice_object1(golden): + @proc + def foo(a: [i8][2]): + a[0] += a[1] + + @proc + def bar(a: i8[10, 10]): + with meta: + for s in [slice(1, 3), slice(5, 7), slice(2, 4)]: + with ~meta: + for i in seq(0, 10): + foo(a[i, s]) + + c_file, _ = compile_procs_to_strings([foo, bar], "test.h") + assert c_file == golden + + +def test_unquote_slice_object2(): + with pytest.raises( + ParseError, match="cannot perform windowing on left-hand-side of an assignment" + ): + + @proc + def foo(a: i8[10, 10]): + with meta: + for s in [slice(1, 3), slice(5, 7), slice(2, 4)]: + with ~meta: + for i in seq(0, 10): + a[i, s] = 2 + + +def test_unquote_index_tuple(golden): + @proc + def foo(a: [i8][2, 2]): + a[0, 0] += a[0, 1] + a[1, 0] += a[1, 1] + + @proc + def bar(a: i8[10, 10, 10]): + with meta: + + def get_index(i): + return slice(i, ~{i + 2}), slice(~{i + 1}, ~{i + 3}) + + with ~meta: + for i in seq(0, 7): + foo(a[i, {get_index(i)}]) + + c_file, _ = compile_procs_to_strings([foo, bar], "test.h") + assert c_file == golden + + +def test_unquote_err(): + with pytest.raises( + ParseError, match="Unquote computation did not yield valid type" + ): + T = 1 + + @proc + def foo(a: T): + a += 1 + + +def test_quote_complex_expr(golden): + @proc + def foo(a: i32): + with meta: + + def bar(x): + return ~{x + 1} + + with ~meta: + a = {bar(~{a + 1})} + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert c_file == golden + + +def test_statement_assignment(golden): + @proc + def foo(a: i32): + with meta: + with ~meta as s1: + a += 1 + a += 2 + with ~meta as s2: + a += 3 + a += 4 + s = s1 if True else s2 + with ~meta: + {s} + {s} + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert c_file == golden From 0c8b758b5f31fc5c96714767ea59f829a1c75dce Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Tue, 29 Oct 2024 11:33:10 -0400 Subject: [PATCH 06/15] Fix tests --- src/exo/frontend/pattern_match.py | 2 +- src/exo/frontend/pyparser.py | 22 +++++++++++++------ .../test_capture_nested_quote.txt | 4 ---- .../test_captured_closure.txt | 4 ---- .../test_metaprogramming/test_conditional.txt | 4 ---- .../test_constant_lifting.txt | 4 ---- .../test_quote_complex_expr.txt | 4 ---- .../test_quote_elision.txt | 4 ---- .../test_scope_collision1.txt | 4 ---- .../test_scope_collision2.txt | 4 ---- .../test_scope_nesting.txt | 4 ---- .../test_metaprogramming/test_scoping.txt | 4 ---- .../test_statement_assignment.txt | 4 ---- .../test_metaprogramming/test_statements.txt | 21 ++++++++++++++++++ .../test_metaprogramming/test_type_params.txt | 4 ---- .../test_type_quote_elision.txt | 4 ---- .../test_unquote_elision.txt | 4 ---- .../test_unquote_in_slice.txt | 4 ---- .../test_unquote_index_tuple.txt | 4 ---- .../test_unquote_slice_object1.txt | 4 ---- .../test_metaprogramming/test_unrolling.txt | 4 ---- tests/test_metaprogramming.py | 19 +++++++++++++++- tests/test_typecheck.py | 4 ++-- tests/test_uast.py | 11 ++++++---- 24 files changed, 64 insertions(+), 87 deletions(-) create mode 100644 tests/golden/test_metaprogramming/test_statements.txt diff --git a/src/exo/frontend/pattern_match.py b/src/exo/frontend/pattern_match.py index 55eca676..32b71704 100644 --- a/src/exo/frontend/pattern_match.py +++ b/src/exo/frontend/pattern_match.py @@ -83,7 +83,7 @@ def match_pattern( # get source location where this is getting called from caller = inspect.getframeinfo(stack_frames[call_depth][0]) func_locals = ChainMap(stack_frames[call_depth].frame.f_locals) - func_globals = ChainMap(stack_frames[call_depth].frame.f_globals) + func_globals = stack_frames[call_depth].frame.f_globals # parse the pattern we're going to use to match p_ast = pyparser.pattern( diff --git a/src/exo/frontend/pyparser.py b/src/exo/frontend/pyparser.py index 7593192c..36f962df 100644 --- a/src/exo/frontend/pyparser.py +++ b/src/exo/frontend/pyparser.py @@ -165,7 +165,14 @@ def pattern(s, filename=None, lineno=None, srclocals=None, srcglobals=None): SourceInfo( src_file=srcfilename, src_line_offset=srclineno, src_col_offset=n_dedent ), - parent_scope=DummyScope({}, {}), # add globals from enclosing scope + parent_scope=DummyScope( + srcglobals if srcglobals is not None else {}, + ( + {k: BoundLocal(v) for k, v in srclocals.items()} + if srclocals is not None + else {} + ), + ), # add globals from enclosing scope is_fragment=True, ) return parser.result() @@ -495,10 +502,10 @@ def __init__( self.push() special_cases = ["stride"] - for key, val in self.globals.items(): + for key, val in parent_scope.get_globals().items(): if isinstance(val, Extern): special_cases.append(key) - for key, val in self.locals.items(): + for key, val in parent_scope.read_locals().items(): if isinstance(val, Extern): special_cases.append(key) @@ -579,6 +586,7 @@ def try_eval_unquote( isinstance(unquote_node, pyast.Name) and isinstance(unquote_node.ctx, pyast.Load) and unquote_node.id not in self.exo_locals + and not self.is_fragment ): cur_globals = self.parent_scope.get_globals() cur_locals = self.parent_scope.read_locals() @@ -860,8 +868,8 @@ def parse_num_type(self, node, is_arg=False): return typ - elif isinstance(node, pyast.Name) and node.id in Parser._prim_types: - return Parser._prim_types[node.id] + elif isinstance(node, pyast.Name) and node.id in _prim_types: + return _prim_types[node.id] elif isinstance(node, pyast.Name) and ( _is_size(node) or _is_stride(node) or _is_index(node) or _is_bool(node) ): @@ -872,8 +880,8 @@ def parse_num_type(self, node, is_arg=False): unquote_eval_result = self.try_eval_unquote(node) if len(unquote_eval_result) == 1: unquoted = unquote_eval_result[0] - if isinstance(unquoted, str) and unquoted in Parser._prim_types: - return Parser._prim_types[unquoted] + if isinstance(unquoted, str) and unquoted in _prim_types: + return _prim_types[unquoted] else: self.err(node, "Unquote computation did not yield valid type") diff --git a/tests/golden/test_metaprogramming/test_capture_nested_quote.txt b/tests/golden/test_metaprogramming/test_capture_nested_quote.txt index bcb2ebb0..ca9b81a5 100644 --- a/tests/golden/test_metaprogramming/test_capture_nested_quote.txt +++ b/tests/golden/test_metaprogramming/test_capture_nested_quote.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // foo( // a : i32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_captured_closure.txt b/tests/golden/test_metaprogramming/test_captured_closure.txt index af22a740..20390796 100644 --- a/tests/golden/test_metaprogramming/test_captured_closure.txt +++ b/tests/golden/test_metaprogramming/test_captured_closure.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // bar( // a : i32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_conditional.txt b/tests/golden/test_metaprogramming/test_conditional.txt index d1cac99e..8f2b476b 100644 --- a/tests/golden/test_metaprogramming/test_conditional.txt +++ b/tests/golden/test_metaprogramming/test_conditional.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // bar1( // a : i8 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_constant_lifting.txt b/tests/golden/test_metaprogramming/test_constant_lifting.txt index e98d98e3..0f25fad1 100644 --- a/tests/golden/test_metaprogramming/test_constant_lifting.txt +++ b/tests/golden/test_metaprogramming/test_constant_lifting.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // foo( // a : f64 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_quote_complex_expr.txt b/tests/golden/test_metaprogramming/test_quote_complex_expr.txt index f9245640..3f3c8626 100644 --- a/tests/golden/test_metaprogramming/test_quote_complex_expr.txt +++ b/tests/golden/test_metaprogramming/test_quote_complex_expr.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // foo( // a : i32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_quote_elision.txt b/tests/golden/test_metaprogramming/test_quote_elision.txt index 90a35203..da671d39 100644 --- a/tests/golden/test_metaprogramming/test_quote_elision.txt +++ b/tests/golden/test_metaprogramming/test_quote_elision.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // foo( // a : i32 @DRAM, // b : i32 @DRAM diff --git a/tests/golden/test_metaprogramming/test_scope_collision1.txt b/tests/golden/test_metaprogramming/test_scope_collision1.txt index 78ac1e83..89ba4b00 100644 --- a/tests/golden/test_metaprogramming/test_scope_collision1.txt +++ b/tests/golden/test_metaprogramming/test_scope_collision1.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // foo( // a : i32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_scope_collision2.txt b/tests/golden/test_metaprogramming/test_scope_collision2.txt index 90a35203..da671d39 100644 --- a/tests/golden/test_metaprogramming/test_scope_collision2.txt +++ b/tests/golden/test_metaprogramming/test_scope_collision2.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // foo( // a : i32 @DRAM, // b : i32 @DRAM diff --git a/tests/golden/test_metaprogramming/test_scope_nesting.txt b/tests/golden/test_metaprogramming/test_scope_nesting.txt index 54d8c8ba..db2f5260 100644 --- a/tests/golden/test_metaprogramming/test_scope_nesting.txt +++ b/tests/golden/test_metaprogramming/test_scope_nesting.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // foo( // a : i8 @DRAM, // b : i8 @DRAM diff --git a/tests/golden/test_metaprogramming/test_scoping.txt b/tests/golden/test_metaprogramming/test_scoping.txt index 91ff8af8..8679fce5 100644 --- a/tests/golden/test_metaprogramming/test_scoping.txt +++ b/tests/golden/test_metaprogramming/test_scoping.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // foo( // a : i8 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_statement_assignment.txt b/tests/golden/test_metaprogramming/test_statement_assignment.txt index a289e8c8..71f64950 100644 --- a/tests/golden/test_metaprogramming/test_statement_assignment.txt +++ b/tests/golden/test_metaprogramming/test_statement_assignment.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // foo( // a : i32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_statements.txt b/tests/golden/test_metaprogramming/test_statements.txt new file mode 100644 index 00000000..cf5820b4 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_statements.txt @@ -0,0 +1,21 @@ +#include "test.h" + + + +#include +#include + + + +// foo( +// a : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a ) { +*a += ((int32_t) 1); +*a += ((int32_t) 1); +for (int_fast32_t i = 0; i < 2; i++) { + *a += ((int32_t) 1); + *a += ((int32_t) 1); +} +} + diff --git a/tests/golden/test_metaprogramming/test_type_params.txt b/tests/golden/test_metaprogramming/test_type_params.txt index 488510d7..23b4b196 100644 --- a/tests/golden/test_metaprogramming/test_type_params.txt +++ b/tests/golden/test_metaprogramming/test_type_params.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // bar1( // a : i32 @DRAM, // b : i8 @DRAM diff --git a/tests/golden/test_metaprogramming/test_type_quote_elision.txt b/tests/golden/test_metaprogramming/test_type_quote_elision.txt index 70514d6c..5db02aca 100644 --- a/tests/golden/test_metaprogramming/test_type_quote_elision.txt +++ b/tests/golden/test_metaprogramming/test_type_quote_elision.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // foo( // a : i8 @DRAM, // x : i8[2] @DRAM diff --git a/tests/golden/test_metaprogramming/test_unquote_elision.txt b/tests/golden/test_metaprogramming/test_unquote_elision.txt index 96f1bfed..da220cec 100644 --- a/tests/golden/test_metaprogramming/test_unquote_elision.txt +++ b/tests/golden/test_metaprogramming/test_unquote_elision.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // foo( // a : i32 @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_unquote_in_slice.txt b/tests/golden/test_metaprogramming/test_unquote_in_slice.txt index ef1b0745..bc7554eb 100644 --- a/tests/golden/test_metaprogramming/test_unquote_in_slice.txt +++ b/tests/golden/test_metaprogramming/test_unquote_in_slice.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // bar( // a : i8[10, 10] @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt b/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt index 0e7a2353..ead0c0db 100644 --- a/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt +++ b/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // bar( // a : i8[10, 10, 10] @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt b/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt index e58a2981..37da11d2 100644 --- a/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt +++ b/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // bar( // a : i8[10, 10] @DRAM // ) diff --git a/tests/golden/test_metaprogramming/test_unrolling.txt b/tests/golden/test_metaprogramming/test_unrolling.txt index 57f613bb..f556b8d5 100644 --- a/tests/golden/test_metaprogramming/test_unrolling.txt +++ b/tests/golden/test_metaprogramming/test_unrolling.txt @@ -1,12 +1,8 @@ #include "test.h" - - #include #include - - // foo( // a : i8 @DRAM // ) diff --git a/tests/test_metaprogramming.py b/tests/test_metaprogramming.py index 8157f6d6..bd31a50c 100644 --- a/tests/test_metaprogramming.py +++ b/tests/test_metaprogramming.py @@ -1,7 +1,7 @@ from __future__ import annotations from exo import proc, compile_procs_to_strings from exo.API_scheduling import rename -from exo.pyparser import ParseError +from exo.frontend.pyparser import ParseError import pytest @@ -338,3 +338,20 @@ def foo(a: i32): c_file, _ = compile_procs_to_strings([foo], "test.h") assert c_file == golden + + +def test_statement_in_expr(): + with pytest.raises(ParseError): + + @proc + def foo(a: i32): + with meta: + + def bar(): + with ~meta: + a += 1 + return 2 + + with ~meta: + a += {bar()} + a += {bar()} diff --git a/tests/test_typecheck.py b/tests/test_typecheck.py index fe9f86d0..84a8b21c 100644 --- a/tests/test_typecheck.py +++ b/tests/test_typecheck.py @@ -80,14 +80,14 @@ def foo(n: size, A: R[n] @ GEMM_SCRATCH): def test_sin1(): @proc - def sin(x: f32): + def sin_proc(x: f32): y: f32 y = sin(x) def test_sin2(): @proc - def sin(x: f32): + def sin_proc(x: f32): y: f32 if False: y = sin(x) diff --git a/tests/test_uast.py b/tests/test_uast.py index aebc2f95..4bf8b5ab 100644 --- a/tests/test_uast.py +++ b/tests/test_uast.py @@ -16,8 +16,7 @@ def to_uast(f): parser = Parser( body, getsrcinfo, - func_globals=f.__globals__, - srclocals=get_parent_scope(depth=2), + parent_scope=get_parent_scope(depth=2), instr=("TEST", ""), as_func=True, ) @@ -99,7 +98,9 @@ def func(f: f32): for i in seq(0, global_str): f += 1 - with pytest.raises(ParseError, match="type "): + with pytest.raises( + ParseError, match="Unquote received input that couldn't be unquoted" + ): to_uast(func) local_str = "xyzzy" @@ -108,7 +109,9 @@ def func(f: f32): for i in seq(0, local_str): f += 1 - with pytest.raises(ParseError, match="type "): + with pytest.raises( + ParseError, match="Unquote received input that couldn't be unquoted" + ): to_uast(func) From 52ebf8746d98c43ec8f144a3e7e1af53834fbf86 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Tue, 29 Oct 2024 23:06:16 -0400 Subject: [PATCH 07/15] Fix scoping issues --- src/exo/frontend/pyparser.py | 122 ++++++++++++++++++++++------------ tests/test_metaprogramming.py | 23 ++++++- 2 files changed, 103 insertions(+), 42 deletions(-) diff --git a/src/exo/frontend/pyparser.py b/src/exo/frontend/pyparser.py index 36f962df..381aac9b 100644 --- a/src/exo/frontend/pyparser.py +++ b/src/exo/frontend/pyparser.py @@ -179,10 +179,10 @@ def pattern(s, filename=None, lineno=None, srclocals=None, srcglobals=None): QUOTE_CALLBACK_PREFIX = "__quote_callback" -QUOTE_BLOCK_PLACEHOLDER_PREFIX = "__quote_block" OUTER_SCOPE_HELPER = "__outer_scope" NESTED_SCOPE_HELPER = "__nested_scope" UNQUOTE_RETURN_HELPER = "__unquote_val" +QUOTE_STMT_PROCESSOR = "__process_quote_stmt" UNQUOTE_BLOCK_KEYWORD = "meta" @@ -196,16 +196,12 @@ class ExoStatementList: _inner: tuple[Any, ...] +@dataclass class QuoteReplacer(pyast.NodeTransformer): - def __init__( - self, - parser_parent: "Parser", - unquote_env: "UnquoteEnv", - stmt_collector: Optional[list[pyast.stmt]] = None, - ): - self.stmt_collector = stmt_collector - self.unquote_env = unquote_env - self.parser_parent = parser_parent + src_info: SourceInfo + exo_locals: dict[str, Any] + unquote_env: "UnquoteEnv" + inside_function: bool = False def visit_With(self, node: pyast.With) -> pyast.Any: if ( @@ -220,36 +216,43 @@ def visit_With(self, node: pyast.With) -> pyast.Any: or node.items[0].optional_vars is None ) ): - assert ( - self.stmt_collector != None - ), "Reached quote block with no buffer to place quoted statements" - should_append = node.items[0].optional_vars is None + stmt_destination = node.items[0].optional_vars - def quote_callback(): - stmts = Parser( + def parse_quote_block(): + return Parser( node.body, - self.parser_parent.src_info, - parent_scope=get_parent_scope(depth=2), + self.src_info, + parent_scope=get_parent_scope(depth=3), is_quote_stmt=True, - parent_exo_locals=self.parser_parent.exo_locals, + parent_exo_locals=self.exo_locals, ).result() - if should_append: - self.stmt_collector.extend(stmts) - else: - return ExoStatementList(tuple(stmts)) - callback_name = self.unquote_env.register_quote_callback(quote_callback) - if should_append: + if stmt_destination is None: + + def quote_callback( + quote_stmt_processor: Optional[Callable[[Any], None]] + ): + if quote_stmt_processor is None: + raise TypeError( + "Cannot unquote Exo statements in this context. You are likely trying to unquote Exo statements while inside an Exo expression." + ) + quote_stmt_processor(parse_quote_block()) + + callback_name = self.unquote_env.register_quote_callback(quote_callback) + return pyast.Expr( value=pyast.Call( func=pyast.Name(id=callback_name, ctx=pyast.Load()), - args=[], + args=[pyast.Name(id=QUOTE_STMT_PROCESSOR, ctx=pyast.Load())], keywords=[], ) ) else: + callback_name = self.unquote_env.register_quote_callback( + lambda: ExoStatementList(tuple(parse_quote_block())) + ) return pyast.Assign( - targets=[node.items[0].optional_vars], + targets=[stmt_destination], value=pyast.Call( func=pyast.Name(id=callback_name, ctx=pyast.Load()), args=[], @@ -270,10 +273,10 @@ def quote_callback(): return ExoExpression( Parser( node.operand.elts[0], - self.parser_parent.src_info, + self.src_info, parent_scope=get_parent_scope(depth=2), is_quote_expr=True, - parent_exo_locals=self.parser_parent.exo_locals, + parent_exo_locals=self.exo_locals, ).result() ) @@ -286,6 +289,33 @@ def quote_callback(): else: return super().generic_visit(node) + def visit_Nonlocal(self, node: pyast.Nonlocal) -> Any: + raise ParseError( + f"{self.src_info.get_src_info(node)}: nonlocal is not supported in metalanguage" + ) + + def visit_FunctionDef(self, node: pyast.FunctionDef): + was_inside_function = self.inside_function + self.inside_function = True + result = super().generic_visit(node) + self.inside_function = was_inside_function + return result + + def visit_AsyncFunctionDef(self, node): + was_inside_function = self.inside_function + self.inside_function = True + result = super().generic_visit(node) + self.inside_function = was_inside_function + return result + + def visit_Return(self, node): + if not self.inside_function: + raise ParseError( + f"{self.src_info.get_src_info(node)}: cannot return from metalanguage fragment" + ) + + return super().generic_visit(node) + @dataclass class UnquoteEnv: @@ -304,12 +334,16 @@ def mangle_name(self, prefix: str) -> str: return mangled_name index += 1 - def register_quote_callback(self, quote_callback: Callable[[], None]) -> str: + def register_quote_callback(self, quote_callback: Callable[..., Any]) -> str: mangled_name = self.mangle_name(QUOTE_CALLBACK_PREFIX) self.parent_locals[mangled_name] = BoundLocal(quote_callback) return mangled_name - def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: + def interpret_unquote_block( + self, + stmts: list[pyast.stmt], + quote_stmt_processor: Optional[Callable[[Any], None]], + ) -> Any: bound_locals = { name: val.val for name, val in self.parent_locals.items() if val is not None } @@ -322,6 +356,7 @@ def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: if name not in self.parent_locals } env_locals = {**quote_locals, **bound_locals} + self.parent_globals[QUOTE_STMT_PROCESSOR] = quote_stmt_processor exec( compile( pyast.fix_missing_locations( @@ -453,8 +488,8 @@ def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: ) return env_locals[UNQUOTE_RETURN_HELPER] - def interpret_quote_expr(self, expr: pyast.expr): - return self.interpret_quote_block([pyast.Return(value=expr)]) + def interpret_unquote_expr(self, expr: pyast.expr): + return self.interpret_unquote_block([pyast.Return(value=expr)], None) # --------------------------------------------------------------------------- # @@ -577,8 +612,10 @@ def try_eval_unquote( self.parent_scope.read_locals(), self.make_exo_var_asts(self.getsrcinfo(unquote_node)), ) - quote_replacer = QuoteReplacer(self, unquote_env) - unquoted = unquote_env.interpret_quote_expr( + quote_replacer = QuoteReplacer( + self.src_info, self.exo_locals, unquote_env + ) + unquoted = unquote_env.interpret_unquote_expr( quote_replacer.visit(copy.deepcopy(unquote_node.elts[0])) ) return (unquoted,) @@ -596,7 +633,7 @@ def try_eval_unquote( cur_globals, cur_locals, self.make_exo_var_asts(self.getsrcinfo(unquote_node)), - ).interpret_quote_expr(unquote_node), + ).interpret_unquote_expr(unquote_node), ) if unquote_node.id in cur_locals or unquote_node.id in cur_globals else tuple() @@ -613,7 +650,7 @@ def eval_expr(self, expr): **{k: BoundLocal(v) for k, v in self.exo_locals.items()}, }, self.make_exo_var_asts(self.getsrcinfo(expr)), - ).interpret_quote_expr(expr) + ).interpret_unquote_expr(expr) # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - # # structural parsing rules... @@ -904,15 +941,18 @@ def parse_stmt_block(self, stmts): self.parent_scope.read_locals(), self.make_exo_var_asts(self.getsrcinfo(s)), ) - quoted_stmts = [] - quote_stmt_replacer = QuoteReplacer(self, unquote_env, quoted_stmts) - unquote_env.interpret_quote_block( + quote_stmt_replacer = QuoteReplacer( + self.src_info, + self.exo_locals, + unquote_env, + ) + unquote_env.interpret_unquote_block( [ quote_stmt_replacer.visit(copy.deepcopy(python_s)) for python_s in s.body ], + lambda stmts: rstmts.extend(stmts), ) - rstmts.extend(quoted_stmts) else: self.err(s, "Expected unquote") elif isinstance(s, pyast.Expr) and isinstance(s.value, pyast.Set): diff --git a/tests/test_metaprogramming.py b/tests/test_metaprogramming.py index bd31a50c..762e0919 100644 --- a/tests/test_metaprogramming.py +++ b/tests/test_metaprogramming.py @@ -341,7 +341,9 @@ def foo(a: i32): def test_statement_in_expr(): - with pytest.raises(ParseError): + with pytest.raises( + TypeError, match="Cannot unquote Exo statements in this context." + ): @proc def foo(a: i32): @@ -355,3 +357,22 @@ def bar(): with ~meta: a += {bar()} a += {bar()} + + +def test_nonlocal_disallowed(): + with pytest.raises(ParseError, match="nonlocal is not supported"): + x = 0 + + @proc + def foo(a: i32): + with meta: + nonlocal x + + +def test_outer_return_disallowed(): + with pytest.raises(ParseError, match="cannot return from metalanguage fragment"): + + @proc + def foo(a: i32): + with meta: + return From 26868b808a6f1e03683de15df7b5937bd29a842f Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Wed, 30 Oct 2024 00:39:09 -0400 Subject: [PATCH 08/15] Fix quote statement placement --- src/exo/frontend/pyparser.py | 21 ++++---- tests/test_metaprogramming.py | 90 +++++++++++++++++------------------ 2 files changed, 56 insertions(+), 55 deletions(-) diff --git a/src/exo/frontend/pyparser.py b/src/exo/frontend/pyparser.py index 381aac9b..2dff51d2 100644 --- a/src/exo/frontend/pyparser.py +++ b/src/exo/frontend/pyparser.py @@ -183,7 +183,8 @@ def pattern(s, filename=None, lineno=None, srclocals=None, srcglobals=None): NESTED_SCOPE_HELPER = "__nested_scope" UNQUOTE_RETURN_HELPER = "__unquote_val" QUOTE_STMT_PROCESSOR = "__process_quote_stmt" -UNQUOTE_BLOCK_KEYWORD = "meta" +QUOTE_BLOCK_KEYWORD = "exo" +UNQUOTE_BLOCK_KEYWORD = "python" @dataclass @@ -206,15 +207,9 @@ class QuoteReplacer(pyast.NodeTransformer): def visit_With(self, node: pyast.With) -> pyast.Any: if ( len(node.items) == 1 - and isinstance(node.items[0].context_expr, pyast.UnaryOp) - and isinstance(node.items[0].context_expr.op, pyast.Invert) - and isinstance(node.items[0].context_expr.operand, pyast.Name) - and node.items[0].context_expr.operand.id == UNQUOTE_BLOCK_KEYWORD - and isinstance(node.items[0].context_expr.operand.ctx, pyast.Load) - and ( - isinstance(node.items[0].optional_vars, pyast.Name) - or node.items[0].optional_vars is None - ) + and isinstance(node.items[0].context_expr, pyast.Name) + and node.items[0].context_expr.id == QUOTE_BLOCK_KEYWORD + and isinstance(node.items[0].context_expr.ctx, pyast.Load) ): stmt_destination = node.items[0].optional_vars @@ -356,6 +351,11 @@ def interpret_unquote_block( if name not in self.parent_locals } env_locals = {**quote_locals, **bound_locals} + old_stmt_processor = ( + self.parent_globals[QUOTE_STMT_PROCESSOR] + if QUOTE_STMT_PROCESSOR in self.parent_globals + else None + ) self.parent_globals[QUOTE_STMT_PROCESSOR] = quote_stmt_processor exec( compile( @@ -486,6 +486,7 @@ def interpret_unquote_block( self.parent_globals, env_locals, ) + self.parent_globals[QUOTE_STMT_PROCESSOR] = old_stmt_processor return env_locals[UNQUOTE_RETURN_HELPER] def interpret_unquote_expr(self, expr: pyast.expr): diff --git a/tests/test_metaprogramming.py b/tests/test_metaprogramming.py index 762e0919..f9a80f99 100644 --- a/tests/test_metaprogramming.py +++ b/tests/test_metaprogramming.py @@ -10,9 +10,9 @@ def test_unrolling(golden): def foo(a: i8): b: i8 b = 0 - with meta: + with python: for _ in range(10): - with ~meta: + with exo: b += a c_file, _ = compile_procs_to_strings([foo], "test.h") @@ -25,12 +25,12 @@ def foo(cond: bool): @proc def bar(a: i8): b: i8 - with meta: + with python: if cond: - with ~meta: + with exo: b = 0 else: - with ~meta: + with exo: b += 1 return bar @@ -57,9 +57,9 @@ def test_scope_nesting(golden): @proc def foo(a: i8, b: i8): - with meta: + with python: y = 2 - with ~meta: + with exo: a = {~{b} if x == 3 and y == 2 else ~{a}} c_file, _ = compile_procs_to_strings([foo], "test.h") @@ -72,9 +72,9 @@ def test_global_scope(): @proc def foo(a: i8): a = 0 - with meta: - with ~meta: - with meta: + with python: + with exo: + with python: global dict cell[0] = dict dict = None @@ -120,10 +120,10 @@ def foo(): @proc def bar(a: i32): - with meta: + with python: for _ in range(10): foo() - with ~meta: + with exo: a += {cell[0]} c_file, _ = compile_procs_to_strings([bar], "test.h") @@ -135,9 +135,9 @@ def test_capture_nested_quote(golden): @proc def foo(a: i32): - with meta: + with python: for _ in range(3): - with ~meta: + with exo: a = {a} c_file, _ = compile_procs_to_strings([foo], "test.h") @@ -147,12 +147,12 @@ def foo(a: i32): def test_quote_elision(golden): @proc def foo(a: i32, b: i32): - with meta: + with python: def bar(): return a - with ~meta: + with exo: b = {bar()} c_file, _ = compile_procs_to_strings([foo], "test.h") @@ -162,9 +162,9 @@ def bar(): def test_unquote_elision(golden): @proc def foo(a: i32): - with meta: + with python: x = 2 - with ~meta: + with exo: a = a * x c_file, _ = compile_procs_to_strings([foo], "test.h") @@ -174,14 +174,14 @@ def foo(a: i32): def test_scope_collision1(golden): @proc def foo(a: i32): - with meta: + with python: b = 1 - with ~meta: + with exo: b: i32 b = 2 - with meta: + with python: c = b - with ~meta: + with exo: a = c c_file, _ = compile_procs_to_strings([foo], "test.h") @@ -191,9 +191,9 @@ def foo(a: i32): def test_scope_collision2(golden): @proc def foo(a: i32, b: i32): - with meta: + with python: a = 1 - with ~meta: + with exo: b = a c_file, _ = compile_procs_to_strings([foo], "test.h") @@ -208,8 +208,8 @@ def test_scope_collision3(): @proc def foo(a: i32, b: i32): - with meta: - with ~meta: + with python: + with exo: a = b * x x = 1 @@ -233,9 +233,9 @@ def foo(a: [i8][2]): @proc def bar(a: i8[10, 10]): - with meta: + with python: x = 2 - with ~meta: + with exo: for i in seq(0, 5): foo(a[i, {x} : {2 * x}]) @@ -250,9 +250,9 @@ def foo(a: [i8][2]): @proc def bar(a: i8[10, 10]): - with meta: + with python: for s in [slice(1, 3), slice(5, 7), slice(2, 4)]: - with ~meta: + with exo: for i in seq(0, 10): foo(a[i, s]) @@ -267,9 +267,9 @@ def test_unquote_slice_object2(): @proc def foo(a: i8[10, 10]): - with meta: + with python: for s in [slice(1, 3), slice(5, 7), slice(2, 4)]: - with ~meta: + with exo: for i in seq(0, 10): a[i, s] = 2 @@ -282,12 +282,12 @@ def foo(a: [i8][2, 2]): @proc def bar(a: i8[10, 10, 10]): - with meta: + with python: def get_index(i): return slice(i, ~{i + 2}), slice(~{i + 1}, ~{i + 3}) - with ~meta: + with exo: for i in seq(0, 7): foo(a[i, {get_index(i)}]) @@ -309,12 +309,12 @@ def foo(a: T): def test_quote_complex_expr(golden): @proc def foo(a: i32): - with meta: + with python: def bar(x): return ~{x + 1} - with ~meta: + with exo: a = {bar(~{a + 1})} c_file, _ = compile_procs_to_strings([foo], "test.h") @@ -324,15 +324,15 @@ def bar(x): def test_statement_assignment(golden): @proc def foo(a: i32): - with meta: - with ~meta as s1: + with python: + with exo as s1: a += 1 a += 2 - with ~meta as s2: + with exo as s2: a += 3 a += 4 s = s1 if True else s2 - with ~meta: + with exo: {s} {s} @@ -347,14 +347,14 @@ def test_statement_in_expr(): @proc def foo(a: i32): - with meta: + with python: def bar(): - with ~meta: + with exo: a += 1 return 2 - with ~meta: + with exo: a += {bar()} a += {bar()} @@ -365,7 +365,7 @@ def test_nonlocal_disallowed(): @proc def foo(a: i32): - with meta: + with python: nonlocal x @@ -374,5 +374,5 @@ def test_outer_return_disallowed(): @proc def foo(a: i32): - with meta: + with python: return From 151d26e6ba3f6bd0dd56afae2f4ad927a88f836d Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Thu, 31 Oct 2024 00:04:52 -0400 Subject: [PATCH 09/15] Increase code coverage --- src/exo/frontend/pyparser.py | 38 ++--- .../test_local_externs.txt | 13 ++ .../test_metaprogramming/test_unary_ops.txt | 12 ++ tests/test_metaprogramming.py | 144 ++++++++++++++++++ 4 files changed, 178 insertions(+), 29 deletions(-) create mode 100644 tests/golden/test_metaprogramming/test_local_externs.txt create mode 100644 tests/golden/test_metaprogramming/test_unary_ops.txt diff --git a/src/exo/frontend/pyparser.py b/src/exo/frontend/pyparser.py index 2dff51d2..6b523a87 100644 --- a/src/exo/frontend/pyparser.py +++ b/src/exo/frontend/pyparser.py @@ -34,10 +34,6 @@ def __init__(self, nm): self.nm = nm -def str_to_mem(name): - return getattr(sys.modules[__name__], name) - - @dataclass class SourceInfo: src_file: str @@ -1354,10 +1350,6 @@ def unquote_to_index(unquoted, ref_node, srcinfo, top_level): if isinstance(e, pyast.Slice): idxs.append(self.parse_slice(e, node)) srcinfo_for_idxs.append(srcinfo) - unquote_eval_result = self.try_eval_unquote(e) - if len(unquote_eval_result) == 1: - unquoted = unquote_eval_result[0] - else: unquote_eval_result = self.try_eval_unquote(e) if len(unquote_eval_result) == 1: @@ -1396,19 +1388,16 @@ def parse_slice(self, e, node): else: srcinfo = self.getsrcinfo(node) - if isinstance(e, pyast.Slice): - lo = None if e.lower is None else self.parse_expr(e.lower) - hi = None if e.upper is None else self.parse_expr(e.upper) - if e.step is not None: - self.err( - e, - "expected windowing to have the form x[:], " - "x[i:], x[:j], or x[i:j], but not x[i:j:k]", - ) + lo = None if e.lower is None else self.parse_expr(e.lower) + hi = None if e.upper is None else self.parse_expr(e.upper) + if e.step is not None: + self.err( + e, + "expected windowing to have the form x[:], " + "x[i:], x[:j], or x[i:j], but not x[i:j:k]", + ) - return UAST.Interval(lo, hi, srcinfo) - else: - return UAST.Point(self.parse_expr(e), srcinfo) + return UAST.Interval(lo, hi, srcinfo) # parse expressions, including values, indices, and booleans def parse_expr(self, e): @@ -1433,17 +1422,8 @@ def parse_expr(self, e): else: return PAST.Read(nm, idxs, self.getsrcinfo(e)) else: - parent_globals = self.parent_scope.get_globals() - parent_locals = self.parent_scope.read_locals() if nm_node.id in self.exo_locals: nm = self.exo_locals[nm_node.id] - elif ( - nm_node.id in parent_locals - and parent_locals[nm_node.id] is not None - ): - nm = parent_locals[nm_node.id].val - elif nm_node.id in parent_globals: - nm = parent_globals[nm_node.id] else: self.err(nm_node, f"variable '{nm_node.id}' undefined") diff --git a/tests/golden/test_metaprogramming/test_local_externs.txt b/tests/golden/test_metaprogramming/test_local_externs.txt new file mode 100644 index 00000000..14a3b2c6 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_local_externs.txt @@ -0,0 +1,13 @@ +#include "test.h" + +#include +#include + +#include +// foo( +// a : f64 @DRAM +// ) +void foo( void *ctxt, double* a ) { +*a = log((double)*a); +} + diff --git a/tests/golden/test_metaprogramming/test_unary_ops.txt b/tests/golden/test_metaprogramming/test_unary_ops.txt new file mode 100644 index 00000000..456b67fc --- /dev/null +++ b/tests/golden/test_metaprogramming/test_unary_ops.txt @@ -0,0 +1,12 @@ +#include "test.h" + +#include +#include + +// foo( +// a : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a ) { +*a = ((int32_t) -2); +} + diff --git a/tests/test_metaprogramming.py b/tests/test_metaprogramming.py index f9a80f99..250efc78 100644 --- a/tests/test_metaprogramming.py +++ b/tests/test_metaprogramming.py @@ -3,6 +3,8 @@ from exo.API_scheduling import rename from exo.frontend.pyparser import ParseError import pytest +import warnings +from exo.core.extern import Extern, _EErr def test_unrolling(golden): @@ -376,3 +378,145 @@ def test_outer_return_disallowed(): def foo(a: i32): with python: return + + +def test_with_block(): + @proc + def foo(a: i32): + with python: + + def issue_warning(): + warnings.warn("deprecated", DeprecationWarning) + + with warnings.catch_warnings(record=True) as recorded_warnings: + issue_warning() + assert len(recorded_warnings) == 1 + pass + + +def test_unary_ops(golden): + @proc + def foo(a: i32): + with python: + x = ~1 + with exo: + a = x + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert c_file == golden + + +def test_return_in_async(): + @proc + def foo(a: i32): + with python: + + async def bar(): + return 1 + + pass + + +def test_local_externs(golden): + class _Log(Extern): + def __init__(self): + super().__init__("log") + + def typecheck(self, args): + if len(args) != 1: + raise _EErr(f"expected 1 argument, got {len(args)}") + + atyp = args[0].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument 1 to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + return "#include " + + def compile(self, args, prim_type): + return f"log(({prim_type}){args[0]})" + + log = _Log() + + @proc + def foo(a: f64): + a = log(a) + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert c_file == golden + + +def test_unquote_multiple_exprs(): + with pytest.raises(ParseError, match="Unquote must take 1 argument"): + x = 0 + + @proc + def foo(a: i32): + a = {x, x} + + +def test_disallow_with_in_exo(): + with pytest.raises(ParseError, match="Expected unquote"): + + @proc + def foo(a: i32): + with a: + pass + + +def test_unquote_multiple_stmts(): + with pytest.raises(ParseError, match="Unquote must take 1 argument"): + + @proc + def foo(a: i32): + with python: + with exo as s: + a += 1 + with exo: + {s, s} + + +def test_unquote_non_statement(): + with pytest.raises( + ParseError, + match="Statement-level unquote expression must return Exo statements", + ): + + @proc + def foo(a: i32): + with python: + x = ~{a} + with exo: + {x} + + +def test_unquote_slice_with_step(): + with pytest.raises(ParseError, match="Unquote returned slice index with step"): + + @proc + def bar(a: [i32][10]): + a[0] = 0 + + @proc + def foo(a: i32[20]): + with python: + x = slice(0, 20, 2) + with exo: + bar(a[x]) + + +def test_typecheck_unquote_index(): + with pytest.raises( + ParseError, match="Unquote received input that couldn't be unquoted" + ): + + @proc + def foo(a: i32[20]): + with python: + x = "0" + with exo: + a[x] = 0 From 6c10524299c9a773a0734c048abfa69d15a72872 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Thu, 31 Oct 2024 00:47:25 -0400 Subject: [PATCH 10/15] Better local extern test --- .../test_local_externs.txt | 2 +- tests/test_metaprogramming.py | 28 ++----------------- 2 files changed, 4 insertions(+), 26 deletions(-) diff --git a/tests/golden/test_metaprogramming/test_local_externs.txt b/tests/golden/test_metaprogramming/test_local_externs.txt index 14a3b2c6..1c9d31d3 100644 --- a/tests/golden/test_metaprogramming/test_local_externs.txt +++ b/tests/golden/test_metaprogramming/test_local_externs.txt @@ -8,6 +8,6 @@ // a : f64 @DRAM // ) void foo( void *ctxt, double* a ) { -*a = log((double)*a); +*a = sin((double)*a); } diff --git a/tests/test_metaprogramming.py b/tests/test_metaprogramming.py index 250efc78..2f98796e 100644 --- a/tests/test_metaprogramming.py +++ b/tests/test_metaprogramming.py @@ -4,7 +4,7 @@ from exo.frontend.pyparser import ParseError import pytest import warnings -from exo.core.extern import Extern, _EErr +from exo.libs.externs import * def test_unrolling(golden): @@ -418,33 +418,11 @@ async def bar(): def test_local_externs(golden): - class _Log(Extern): - def __init__(self): - super().__init__("log") - - def typecheck(self, args): - if len(args) != 1: - raise _EErr(f"expected 1 argument, got {len(args)}") - - atyp = args[0].type - if not atyp.is_real_scalar(): - raise _EErr( - f"expected argument 1 to be a real scalar value, but " - f"got type {atyp}" - ) - return atyp - - def globl(self, prim_type): - return "#include " - - def compile(self, args, prim_type): - return f"log(({prim_type}){args[0]})" - - log = _Log() + my_sin = sin @proc def foo(a: f64): - a = log(a) + a = my_sin(a) c_file, _ = compile_procs_to_strings([foo], "test.h") assert c_file == golden From 72cc9cec89f17562b92c6d096548a48521e97ecf Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Wed, 6 Nov 2024 11:44:22 -0500 Subject: [PATCH 11/15] Update tests to include ir --- .../test_capture_nested_quote.txt | 6 ++ .../test_captured_closure.txt | 13 ++++ .../test_metaprogramming/test_conditional.txt | 8 +++ .../test_constant_lifting.txt | 4 ++ .../test_local_externs.txt | 4 ++ .../test_proc_shadowing.txt | 28 ++++++++ .../test_quote_complex_expr.txt | 4 ++ .../test_quote_elision.txt | 4 ++ .../test_scope_collision1.txt | 6 ++ .../test_scope_collision2.txt | 4 ++ .../test_scope_nesting.txt | 4 ++ .../test_metaprogramming/test_scoping.txt | 4 ++ .../test_statement_assignment.txt | 7 ++ .../test_metaprogramming/test_type_params.txt | 16 +++++ .../test_type_quote_elision.txt | 5 ++ .../test_metaprogramming/test_unary_ops.txt | 4 ++ .../test_unquote_elision.txt | 4 ++ .../test_unquote_in_slice.txt | 7 ++ .../test_unquote_index_tuple.txt | 8 +++ .../test_unquote_slice_object1.txt | 11 +++ .../test_metaprogramming/test_unrolling.txt | 15 +++++ tests/test_metaprogramming.py | 67 ++++++++++++------- 22 files changed, 207 insertions(+), 26 deletions(-) create mode 100644 tests/golden/test_metaprogramming/test_proc_shadowing.txt diff --git a/tests/golden/test_metaprogramming/test_capture_nested_quote.txt b/tests/golden/test_metaprogramming/test_capture_nested_quote.txt index ca9b81a5..56a54d49 100644 --- a/tests/golden/test_metaprogramming/test_capture_nested_quote.txt +++ b/tests/golden/test_metaprogramming/test_capture_nested_quote.txt @@ -1,3 +1,9 @@ +EXO IR: +def foo(a: i32 @ DRAM): + a = 2 + a = 2 + a = 2 +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_captured_closure.txt b/tests/golden/test_metaprogramming/test_captured_closure.txt index 20390796..569653d3 100644 --- a/tests/golden/test_metaprogramming/test_captured_closure.txt +++ b/tests/golden/test_metaprogramming/test_captured_closure.txt @@ -1,3 +1,16 @@ +EXO IR: +def bar(a: i32 @ DRAM): + a += 1 + a += 2 + a += 3 + a += 4 + a += 5 + a += 6 + a += 7 + a += 8 + a += 9 + a += 10 +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_conditional.txt b/tests/golden/test_metaprogramming/test_conditional.txt index 8f2b476b..7e3473e5 100644 --- a/tests/golden/test_metaprogramming/test_conditional.txt +++ b/tests/golden/test_metaprogramming/test_conditional.txt @@ -1,3 +1,11 @@ +EXO IR: +def bar1(a: i8 @ DRAM): + b: i8 @ DRAM + b += 1 +def bar2(a: i8 @ DRAM): + b: i8 @ DRAM + b = 0 +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_constant_lifting.txt b/tests/golden/test_metaprogramming/test_constant_lifting.txt index 0f25fad1..5ac001ad 100644 --- a/tests/golden/test_metaprogramming/test_constant_lifting.txt +++ b/tests/golden/test_metaprogramming/test_constant_lifting.txt @@ -1,3 +1,7 @@ +EXO IR: +def foo(a: f64 @ DRAM): + a = 2.0818897486445276 +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_local_externs.txt b/tests/golden/test_metaprogramming/test_local_externs.txt index 1c9d31d3..504175e7 100644 --- a/tests/golden/test_metaprogramming/test_local_externs.txt +++ b/tests/golden/test_metaprogramming/test_local_externs.txt @@ -1,3 +1,7 @@ +EXO IR: +def foo(a: f64 @ DRAM): + a = sin(a) +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_proc_shadowing.txt b/tests/golden/test_metaprogramming/test_proc_shadowing.txt new file mode 100644 index 00000000..5a3d3670 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_proc_shadowing.txt @@ -0,0 +1,28 @@ +EXO IR: +def foo(a: f32 @ DRAM): + sin(a) +C: +#include "test.h" + +#include +#include + +// sin( +// a : f32 @DRAM +// ) +static void sin( void *ctxt, float* a ); + +// foo( +// a : f32 @DRAM +// ) +void foo( void *ctxt, float* a ) { +sin(ctxt,a); +} + +// sin( +// a : f32 @DRAM +// ) +static void sin( void *ctxt, float* a ) { +*a = 0.0f; +} + diff --git a/tests/golden/test_metaprogramming/test_quote_complex_expr.txt b/tests/golden/test_metaprogramming/test_quote_complex_expr.txt index 3f3c8626..b111df4f 100644 --- a/tests/golden/test_metaprogramming/test_quote_complex_expr.txt +++ b/tests/golden/test_metaprogramming/test_quote_complex_expr.txt @@ -1,3 +1,7 @@ +EXO IR: +def foo(a: i32 @ DRAM): + a = a + 1 + 1 +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_quote_elision.txt b/tests/golden/test_metaprogramming/test_quote_elision.txt index da671d39..a22821c7 100644 --- a/tests/golden/test_metaprogramming/test_quote_elision.txt +++ b/tests/golden/test_metaprogramming/test_quote_elision.txt @@ -1,3 +1,7 @@ +EXO IR: +def foo(a: i32 @ DRAM, b: i32 @ DRAM): + b = a +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_scope_collision1.txt b/tests/golden/test_metaprogramming/test_scope_collision1.txt index 89ba4b00..c2d6b20c 100644 --- a/tests/golden/test_metaprogramming/test_scope_collision1.txt +++ b/tests/golden/test_metaprogramming/test_scope_collision1.txt @@ -1,3 +1,9 @@ +EXO IR: +def foo(a: i32 @ DRAM): + b: i32 @ DRAM + b = 2 + a = 1 +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_scope_collision2.txt b/tests/golden/test_metaprogramming/test_scope_collision2.txt index da671d39..a22821c7 100644 --- a/tests/golden/test_metaprogramming/test_scope_collision2.txt +++ b/tests/golden/test_metaprogramming/test_scope_collision2.txt @@ -1,3 +1,7 @@ +EXO IR: +def foo(a: i32 @ DRAM, b: i32 @ DRAM): + b = a +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_scope_nesting.txt b/tests/golden/test_metaprogramming/test_scope_nesting.txt index db2f5260..0ae39ca1 100644 --- a/tests/golden/test_metaprogramming/test_scope_nesting.txt +++ b/tests/golden/test_metaprogramming/test_scope_nesting.txt @@ -1,3 +1,7 @@ +EXO IR: +def foo(a: i8 @ DRAM, b: i8 @ DRAM): + a = b +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_scoping.txt b/tests/golden/test_metaprogramming/test_scoping.txt index 8679fce5..331db00a 100644 --- a/tests/golden/test_metaprogramming/test_scoping.txt +++ b/tests/golden/test_metaprogramming/test_scoping.txt @@ -1,3 +1,7 @@ +EXO IR: +def foo(a: i8 @ DRAM): + a = 3 +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_statement_assignment.txt b/tests/golden/test_metaprogramming/test_statement_assignment.txt index 71f64950..a8ea5b1a 100644 --- a/tests/golden/test_metaprogramming/test_statement_assignment.txt +++ b/tests/golden/test_metaprogramming/test_statement_assignment.txt @@ -1,3 +1,10 @@ +EXO IR: +def foo(a: i32 @ DRAM): + a += 1 + a += 2 + a += 1 + a += 2 +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_type_params.txt b/tests/golden/test_metaprogramming/test_type_params.txt index 23b4b196..98c6282a 100644 --- a/tests/golden/test_metaprogramming/test_type_params.txt +++ b/tests/golden/test_metaprogramming/test_type_params.txt @@ -1,3 +1,19 @@ +EXO IR: +def bar1(a: i32 @ DRAM, b: i8 @ DRAM): + c: i32[4] @ DRAM + for i in seq(0, 3): + d: i32 @ DRAM + d = b + c[i + 1] = a + c[i] * d + a = c[3] +def bar2(a: f64 @ DRAM, b: f64 @ DRAM): + c: f64[4] @ DRAM + for i in seq(0, 3): + d: f64 @ DRAM + d = b + c[i + 1] = a + c[i] * d + a = c[3] +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_type_quote_elision.txt b/tests/golden/test_metaprogramming/test_type_quote_elision.txt index 5db02aca..d9173f3d 100644 --- a/tests/golden/test_metaprogramming/test_type_quote_elision.txt +++ b/tests/golden/test_metaprogramming/test_type_quote_elision.txt @@ -1,3 +1,8 @@ +EXO IR: +def foo(a: i8 @ DRAM, x: i8[2] @ DRAM): + a += x[0] + a += x[1] +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_unary_ops.txt b/tests/golden/test_metaprogramming/test_unary_ops.txt index 456b67fc..028ac6f3 100644 --- a/tests/golden/test_metaprogramming/test_unary_ops.txt +++ b/tests/golden/test_metaprogramming/test_unary_ops.txt @@ -1,3 +1,7 @@ +EXO IR: +def foo(a: i32 @ DRAM): + a = -2 +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_unquote_elision.txt b/tests/golden/test_metaprogramming/test_unquote_elision.txt index da220cec..71079913 100644 --- a/tests/golden/test_metaprogramming/test_unquote_elision.txt +++ b/tests/golden/test_metaprogramming/test_unquote_elision.txt @@ -1,3 +1,7 @@ +EXO IR: +def foo(a: i32 @ DRAM): + a = a * 2 +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_unquote_in_slice.txt b/tests/golden/test_metaprogramming/test_unquote_in_slice.txt index bc7554eb..de0fc0e9 100644 --- a/tests/golden/test_metaprogramming/test_unquote_in_slice.txt +++ b/tests/golden/test_metaprogramming/test_unquote_in_slice.txt @@ -1,3 +1,10 @@ +EXO IR: +def foo(a: [i8][2] @ DRAM): + a[0] += a[1] +def bar(a: i8[10, 10] @ DRAM): + for i in seq(0, 5): + foo(a[i, 2:4]) +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt b/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt index ead0c0db..49abf306 100644 --- a/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt +++ b/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt @@ -1,3 +1,11 @@ +EXO IR: +def foo(a: [i8][2, 2] @ DRAM): + a[0, 0] += a[0, 1] + a[1, 0] += a[1, 1] +def bar(a: i8[10, 10, 10] @ DRAM): + for i in seq(0, 7): + foo(a[i, i:i + 2, i + 1:i + 3]) +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt b/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt index 37da11d2..ea4f9798 100644 --- a/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt +++ b/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt @@ -1,3 +1,14 @@ +EXO IR: +def foo(a: [i8][2] @ DRAM): + a[0] += a[1] +def bar(a: i8[10, 10] @ DRAM): + for i in seq(0, 10): + foo(a[i, 1:3]) + for i in seq(0, 10): + foo(a[i, 5:7]) + for i in seq(0, 10): + foo(a[i, 2:4]) +C: #include "test.h" #include diff --git a/tests/golden/test_metaprogramming/test_unrolling.txt b/tests/golden/test_metaprogramming/test_unrolling.txt index f556b8d5..136c770c 100644 --- a/tests/golden/test_metaprogramming/test_unrolling.txt +++ b/tests/golden/test_metaprogramming/test_unrolling.txt @@ -1,3 +1,18 @@ +EXO IR: +def foo(a: i8 @ DRAM): + b: i8 @ DRAM + b = 0 + b += a + b += a + b += a + b += a + b += a + b += a + b += a + b += a + b += a + b += a +C: #include "test.h" #include diff --git a/tests/test_metaprogramming.py b/tests/test_metaprogramming.py index 2f98796e..d869480e 100644 --- a/tests/test_metaprogramming.py +++ b/tests/test_metaprogramming.py @@ -19,7 +19,7 @@ def foo(a: i8): c_file, _ = compile_procs_to_strings([foo], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden def test_conditional(golden): @@ -37,10 +37,11 @@ def bar(a: i8): return bar - c_file, _ = compile_procs_to_strings( - [rename(foo(False), "bar1"), rename(foo(True), "bar2")], "test.h" - ) - assert c_file == golden + bar1 = rename(foo(False), "bar1") + bar2 = rename(foo(True), "bar2") + + c_file, _ = compile_procs_to_strings([bar1, bar2], "test.h") + assert f"EXO IR:\n{str(bar1)}\n{str(bar2)}\nC:\n{c_file}" == golden def test_scoping(golden): @@ -51,7 +52,7 @@ def foo(a: i8): a = {a} c_file, _ = compile_procs_to_strings([foo], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden def test_scope_nesting(golden): @@ -65,7 +66,7 @@ def foo(a: i8, b: i8): a = {~{b} if x == 3 and y == 2 else ~{a}} c_file, _ = compile_procs_to_strings([foo], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden def test_global_scope(): @@ -92,7 +93,7 @@ def foo(a: f64): a = {(x**x + x) / x} c_file, _ = compile_procs_to_strings([foo], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden def test_type_params(golden): @@ -108,10 +109,11 @@ def bar(a: {T}, b: {U}): return bar - c_file, _ = compile_procs_to_strings( - [rename(foo("i32", "i8"), "bar1"), rename(foo("f64", "f64"), "bar2")], "test.h" - ) - assert c_file == golden + bar1 = rename(foo("i32", "i8"), "bar1") + bar2 = rename(foo("f64", "f64"), "bar2") + + c_file, _ = compile_procs_to_strings([bar1, bar2], "test.h") + assert f"EXO IR:\n{str(bar1)}\n{str(bar2)}\nC:\n{c_file}" == golden def test_captured_closure(golden): @@ -129,7 +131,7 @@ def bar(a: i32): a += {cell[0]} c_file, _ = compile_procs_to_strings([bar], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(bar)}\nC:\n{c_file}" == golden def test_capture_nested_quote(golden): @@ -143,7 +145,7 @@ def foo(a: i32): a = {a} c_file, _ = compile_procs_to_strings([foo], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden def test_quote_elision(golden): @@ -158,7 +160,7 @@ def bar(): b = {bar()} c_file, _ = compile_procs_to_strings([foo], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden def test_unquote_elision(golden): @@ -170,7 +172,7 @@ def foo(a: i32): a = a * x c_file, _ = compile_procs_to_strings([foo], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden def test_scope_collision1(golden): @@ -187,7 +189,7 @@ def foo(a: i32): a = c c_file, _ = compile_procs_to_strings([foo], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden def test_scope_collision2(golden): @@ -199,7 +201,7 @@ def foo(a: i32, b: i32): b = a c_file, _ = compile_procs_to_strings([foo], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden def test_scope_collision3(): @@ -225,7 +227,7 @@ def foo(a: T, x: T[2]): a += x[1] c_file, _ = compile_procs_to_strings([foo], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden def test_unquote_in_slice(golden): @@ -242,7 +244,7 @@ def bar(a: i8[10, 10]): foo(a[i, {x} : {2 * x}]) c_file, _ = compile_procs_to_strings([foo, bar], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\n{str(bar)}\nC:\n{c_file}" == golden def test_unquote_slice_object1(golden): @@ -259,7 +261,7 @@ def bar(a: i8[10, 10]): foo(a[i, s]) c_file, _ = compile_procs_to_strings([foo, bar], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\n{str(bar)}\nC:\n{c_file}" == golden def test_unquote_slice_object2(): @@ -294,7 +296,7 @@ def get_index(i): foo(a[i, {get_index(i)}]) c_file, _ = compile_procs_to_strings([foo, bar], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\n{str(bar)}\nC:\n{c_file}" == golden def test_unquote_err(): @@ -320,7 +322,7 @@ def bar(x): a = {bar(~{a + 1})} c_file, _ = compile_procs_to_strings([foo], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden def test_statement_assignment(golden): @@ -339,7 +341,7 @@ def foo(a: i32): {s} c_file, _ = compile_procs_to_strings([foo], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden def test_statement_in_expr(): @@ -403,7 +405,7 @@ def foo(a: i32): a = x c_file, _ = compile_procs_to_strings([foo], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden def test_return_in_async(): @@ -425,7 +427,7 @@ def foo(a: f64): a = my_sin(a) c_file, _ = compile_procs_to_strings([foo], "test.h") - assert c_file == golden + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden def test_unquote_multiple_exprs(): @@ -498,3 +500,16 @@ def foo(a: i32[20]): x = "0" with exo: a[x] = 0 + + +def test_proc_shadowing(golden): + @proc + def sin(a: f32): + a = 0 + + @proc + def foo(a: f32): + sin(a) + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden From 89a628114baae1bba1348c3efbbc29b27c8771a3 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Wed, 6 Nov 2024 12:38:10 -0500 Subject: [PATCH 12/15] Document helpers used by metalanguage --- src/exo/frontend/pyparser.py | 87 +++++++++++++++++++++++++++++++++++- 1 file changed, 85 insertions(+), 2 deletions(-) diff --git a/src/exo/frontend/pyparser.py b/src/exo/frontend/pyparser.py index 6b523a87..c2023737 100644 --- a/src/exo/frontend/pyparser.py +++ b/src/exo/frontend/pyparser.py @@ -36,11 +36,18 @@ def __init__(self, nm): @dataclass class SourceInfo: + """ + Source code locations that are needed to compute the location of AST nodes. + """ + src_file: str src_line_offset: int src_col_offset: int def get_src_info(self, node: pyast.AST): + """ + Computes the location of the given AST node based on line and column offsets. + """ return SrcInfo( filename=self.src_file, lineno=node.lineno + self.src_line_offset, @@ -85,20 +92,37 @@ def get_ast_from_python(f: Callable[..., Any]) -> tuple[pyast.stmt, SourceInfo]: @dataclass class BoundLocal: + """ + Wrapper class that represents locals that have been assigned a value. + """ + val: Any -Local = Optional[BoundLocal] +Local = Optional[BoundLocal] # Locals that are unassigned will be represesnted as None @dataclass class FrameScope: + """ + Wrapper around frame object to read local and global variables. + """ + frame: inspect.frame def get_globals(self) -> dict[str, Any]: + """ + Get globals dictionary for the frame. The globals dictionary is not a copy. If the + returned dictionary is modified, the globals of the scope will be changed. + """ return self.frame.f_globals def read_locals(self) -> dict[str, Local]: + """ + Return a copy of the local variables held by the scope. In contrast to globals, it is + not possible to add new local variables or modify the local variables by modifying + the returned dictionary. + """ return { var: ( BoundLocal(self.frame.f_locals[var]) @@ -113,6 +137,11 @@ def read_locals(self) -> dict[str, Local]: @dataclass class DummyScope: + """ + Wrapper for emulating a scope with a set of global and local variables. + Used for parsing patterns, which should not be able to capture local variables from the enclosing scope. + """ + global_dict: dict[str, Any] local_dict: dict[str, Any] @@ -123,7 +152,9 @@ def read_locals(self) -> dict[str, Any]: return self.local_dict.copy() -Scope = Union[DummyScope, FrameScope] +Scope = Union[ + DummyScope, FrameScope +] # Type to represent scopes, which have an API for getting global and local variables. def get_parent_scope(*, depth) -> Scope: @@ -174,33 +205,53 @@ def pattern(s, filename=None, lineno=None, srclocals=None, srcglobals=None): return parser.result() +# These constants are used to name helper variables that allow the metalanguage to be parsed and evaluated. +# All of them start with two underscores, so there is not collision in names if the user avoids using names +# with two underscores. QUOTE_CALLBACK_PREFIX = "__quote_callback" OUTER_SCOPE_HELPER = "__outer_scope" NESTED_SCOPE_HELPER = "__nested_scope" UNQUOTE_RETURN_HELPER = "__unquote_val" QUOTE_STMT_PROCESSOR = "__process_quote_stmt" + QUOTE_BLOCK_KEYWORD = "exo" UNQUOTE_BLOCK_KEYWORD = "python" @dataclass class ExoExpression: + """ + Opaque wrapper class for representing expressions in object code. Can be unquoted. + """ + _inner: Any # note: strict typing is not possible as long as PAST/UAST grammar definition is not static @dataclass class ExoStatementList: + """ + Opaque wrapper class for representing a list of statements in object code. Can be unquoted. + """ + _inner: tuple[Any, ...] @dataclass class QuoteReplacer(pyast.NodeTransformer): + """ + Replace quotes (Exo object code statements/expressions) in the metalanguage with calls to + helper functions that will parse and return the quoted code. + """ + src_info: SourceInfo exo_locals: dict[str, Any] unquote_env: "UnquoteEnv" inside_function: bool = False def visit_With(self, node: pyast.With) -> pyast.Any: + """ + Replace quoted statements. These will begin with "with exo:". + """ if ( len(node.items) == 1 and isinstance(node.items[0].context_expr, pyast.Name) @@ -254,6 +305,9 @@ def quote_callback( return super().generic_visit(node) def visit_UnaryOp(self, node: pyast.UnaryOp) -> Any: + """ + Replace quoted expressions. These will look like "~{...}". + """ if ( isinstance(node.op, pyast.Invert) and isinstance(node.operand, pyast.Set) @@ -286,6 +340,10 @@ def visit_Nonlocal(self, node: pyast.Nonlocal) -> Any: ) def visit_FunctionDef(self, node: pyast.FunctionDef): + """ + Record whether we are inside a function definition in the metalanguage, so that we can + prevent return statements that occur outside a function. + """ was_inside_function = self.inside_function self.inside_function = True result = super().generic_visit(node) @@ -310,11 +368,21 @@ def visit_Return(self, node): @dataclass class UnquoteEnv: + """ + Record of all the context needed to interpret a block of metalanguage code. + This includes the local and global variables of the scope that the metalanguage code will be evaluated in + and the Exo variables of the surrounding object code. + """ + parent_globals: dict[str, Any] parent_locals: dict[str, Local] exo_local_vars: dict[str, Any] def mangle_name(self, prefix: str) -> str: + """ + Create unique names for helper functions that are used to parse object code + (see QuoteReplacer). + """ index = 0 while True: mangled_name = f"{prefix}{index}" @@ -326,6 +394,10 @@ def mangle_name(self, prefix: str) -> str: index += 1 def register_quote_callback(self, quote_callback: Callable[..., Any]) -> str: + """ + Store helper functions that are used to parse object code so that they may be referenced + when we interpret the metalanguage code. + """ mangled_name = self.mangle_name(QUOTE_CALLBACK_PREFIX) self.parent_locals[mangled_name] = BoundLocal(quote_callback) return mangled_name @@ -335,6 +407,14 @@ def interpret_unquote_block( stmts: list[pyast.stmt], quote_stmt_processor: Optional[Callable[[Any], None]], ) -> Any: + """ + Interpret a metalanguage block of code. This is done by pasting the AST of the metalanguage code + into a helper function that sets up the local variables that need to be referenced in the metalanguage code, + and then calling that helper function. + + This function is also used to parse metalanguage expressions by representing the expressions as return statements + and saving the output returned by the helper function. + """ bound_locals = { name: val.val for name, val in self.parent_locals.items() if val is not None } @@ -486,6 +566,9 @@ def interpret_unquote_block( return env_locals[UNQUOTE_RETURN_HELPER] def interpret_unquote_expr(self, expr: pyast.expr): + """ + Parse a metalanguage expression using the machinery provided by interpret_unquote_block. + """ return self.interpret_unquote_block([pyast.Return(value=expr)], None) From 44e59a57d31041645666bbe967cdf3f1e38d03fb Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Wed, 15 Jan 2025 11:39:21 -0800 Subject: [PATCH 13/15] write doc --- docs/Metaprogramming.md | 128 ++++++++++++++++++ .../test_eval_expr_in_mem.txt | 16 +++ tests/test_metaprogramming.py | 12 ++ 3 files changed, 156 insertions(+) create mode 100644 docs/Metaprogramming.md create mode 100644 tests/golden/test_metaprogramming/test_eval_expr_in_mem.txt diff --git a/docs/Metaprogramming.md b/docs/Metaprogramming.md new file mode 100644 index 00000000..6be41d43 --- /dev/null +++ b/docs/Metaprogramming.md @@ -0,0 +1,128 @@ +# Metaprogramming + +In the context of Exo, metaprogramming refers to the composition of [object code](object_code.md) fragments, similar to macros in languages like C. Unlike scheduling operations, metaprogramming does not seek to preserve equivalence as it transforms the object code - instead, it stitches together Exo code fragments, allowing the user to make code more concise or parametrizable. + +The user can get a reference to one of these Exo code fragments through *quoting*, which produces a Python reference to the code fragment. After manipulating this code fragment as a Python object, the user can then paste in a code fragment from Python through *unquoting*. + +## Quoting and Unquoting Statements + +An unquote statement composes any quoted fragments that are executed within it. Syntactically, it is a block of *Python* code which is wrapped in a `with python:` block. Within this block, there may be multiple quoted *Exo* fragments which get executed, which are represented as `with exo:` blocks. + +Note that we are carefully distinguishing *Python* code from *Exo* code here. The Python code inside the `with python:` block does not describe any operations in Exo. Instead, it describes how the Exo fragments within it are composed. Thus, this code can use familiar Python constructs, such as `range(...)` loops (as opposed to Exo's `seq(...)` loops). + +An unquote statement will only read a quoted fragment when its corresponding `with exo:` block gets executed in the Python code. So, the following example results in an empty Exo procedure: +```python +@proc +def foo(a: i32): + with python: + if False: + with exo: + a += 1 +``` + +A `with exo:` may also be executed multiple times. The following example compiles to 10 `a += 1` statements in a row: +```python +@proc +def foo(a: i32): + with python: + for i in range(10): + with exo: + a += 1 +``` + +## Quoting and Unquoting Expressions + +An unquote expression reads the Exo expression that is referred to by a Python object. This is syntactically represented as `{...}`, where the insides of the braces are interpreted as a Python object. To obtain a Python object that refers to an Exo expression, one can use an unquote expression, represented as `~{...}`. + +As a simple example, we can try iterating through a list of Exo expressions. The following example should be equivalent to `a += a; a += b * 2`: +```python +@proc +def foo(a: i32, b: i32): + with python: + exprs = [~{a}, ~{b * 2}] + for expr in exprs: + with exo: + a += {expr} +``` + +### Implicit Quotes and Unquotes + +As we can see from the example, it is often the case that quote and unquote expressions will consist of a single variable. For convenience, if a variable name would otherwise be an invalid reference, the parser will try unquoting or quoting it before throwing an error. So, the following code is equivalent to the previous example: +```python +@proc +def foo(a: i32, b: i32): + with python: + exprs = [a, ~{b * 2}] + for expr in exprs: + with exo: + a += expr +``` + +### Unquoting Numbers + +Besides quoted expressions, a Python number can also be unquoted and converted into the corresponding numeric literal in Exo. The following example will alternate between `a += 1` and `a += 2` 10 times: +```python +@proc +def foo(a: i32): + with python: + for i in range(10): + with exo: + a += {i % 2} +``` + +### Unquoting Types + +When an unquote expression occurs in the place that a type would normally be used in Exo, for instance in the declaration of function arguments, the unquote expression will read the Python object as a string and parse it as the corresponding type. The following example will take an argument whose type depends on the first statement: +```python +T = "i32" + +@proc +def foo(a: {T}, b: {T}): + a += b +``` + +### Unquoting Indices + +Unquote expressions can also be used to index into a buffer. The Python object that gets unquoted may be a single Exo expression, a number, or a slice object. + +### Unquoting Memories + +Memory objects can also be unquoted. Note that memories in Exo correspond to Python objects in the base language anyway, so the process of unquoting an object representing a type of memory in Exo is relatively straightforward. For instance, the memory used to pass in the arguments to this function are determined by the first line: +```python +mem = DRAM + +@proc +def foo(a: i32 @ {mem}, b: i32 @ {mem}): + a += b +``` + +## Binding Quoted Statements to Variables + +A quoted Exo statement does not have to be executed immediately in the place that it is declared. Instead, the quote may be stored in a Python variable using the syntax `with exo as ...:`. It can then be unquoted with the `{...}` operator if it appears as a statement. + +The following example is equivalent to `a += b; a += b`: +```python +@proc +def foo(a: i32, b: i32): + with python: + with exo as stmt: + a += b + {stmt} + {stmt} +``` + +## Limitations + +- There is currently no support for defining quotes outside of an Exo procedure. Thus, it is difficult to share metaprogramming logic between two different Exo procedures. +- Attempting to execute a quoted statement while unquoting an expression will result in an error being thrown. Since Exo expressions do not have side effects, the semantics of such a program would be unclear if allowed. For instance: +```python +@proc +def foo(a: i32): + with python: + def bar(): + with exo: + a += 1 + return 2 + a *= {bar()} # illegal! +``` +- Identifiers that appear on the left hand side of assignment and reductions in Exo cannot be unquoted. This is partly due to limitations in the Python grammar, which Exo must conform to. \ No newline at end of file diff --git a/tests/golden/test_metaprogramming/test_eval_expr_in_mem.txt b/tests/golden/test_metaprogramming/test_eval_expr_in_mem.txt new file mode 100644 index 00000000..29bc1782 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_eval_expr_in_mem.txt @@ -0,0 +1,16 @@ +EXO IR: +def foo(a: f32 @ DRAM): + pass +C: +#include "test.h" + +#include +#include + +// foo( +// a : f32 @DRAM +// ) +void foo( void *ctxt, const float* a ) { +; // NO-OP +} + diff --git a/tests/test_metaprogramming.py b/tests/test_metaprogramming.py index d869480e..7067db37 100644 --- a/tests/test_metaprogramming.py +++ b/tests/test_metaprogramming.py @@ -5,6 +5,7 @@ import pytest import warnings from exo.libs.externs import * +from exo.platforms.x86 import DRAM def test_unrolling(golden): @@ -513,3 +514,14 @@ def foo(a: f32): c_file, _ = compile_procs_to_strings([foo], "test.h") assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_eval_expr_in_mem(golden): + mems = [DRAM] + + @proc + def foo(a: f32 @ mems[0]): + pass + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden From 91c62b25111f81e9b2774f3a69c24b3bbd4791be Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Wed, 15 Jan 2025 11:41:44 -0800 Subject: [PATCH 14/15] add to readme --- docs/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/README.md b/docs/README.md index 7fb03de9..9e26170f 100644 --- a/docs/README.md +++ b/docs/README.md @@ -8,6 +8,7 @@ This directory provides detailed documentation about Exo's interface and interna - To learn how to define **hardware targets externally to the compiler**, refer to [externs.md](externs.md), [instructions.md](instructions.md), and [memories.md](memories.md). - To learn how to define **new scheduling operations externally to the compiler**, refer to [Cursors.md](./Cursors.md) and [inspection.md](./inspection.md). - To understand the available scheduling primitives and how to use them, look into the [primitives/](./primitives) directory. +- To learn about metaprogramming as a method for writing cleaner code, see [Metaprogramming.md](Metaprogramming.md). The scheduling primitives are classified into six categories: From e09d7eb7ca1bf714a5946ca658a6400ea116fd00 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Mon, 3 Feb 2025 03:26:57 -0800 Subject: [PATCH 15/15] Clarify unquoting use cases + change implicit quote/unquote scoping rules --- docs/Metaprogramming.md | 45 +++++-- src/exo/frontend/pyparser.py | 113 +++++++++++------- .../test_capture_nested_quote.txt | 12 +- .../test_implicit_lhs_unquote.txt | 19 +++ .../test_scope_collision1.txt | 4 +- .../test_scope_collision2.txt | 4 +- .../test_metaprogramming/test_scoping.txt | 4 +- tests/test_metaprogramming.py | 13 ++ 8 files changed, 147 insertions(+), 67 deletions(-) create mode 100644 tests/golden/test_metaprogramming/test_implicit_lhs_unquote.txt diff --git a/docs/Metaprogramming.md b/docs/Metaprogramming.md index 6be41d43..751701d7 100644 --- a/docs/Metaprogramming.md +++ b/docs/Metaprogramming.md @@ -14,6 +14,7 @@ An unquote statement will only read a quoted fragment when its corresponding `wi ```python @proc def foo(a: i32): + pass with python: if False: with exo: @@ -47,7 +48,7 @@ def foo(a: i32, b: i32): ### Implicit Quotes and Unquotes -As we can see from the example, it is often the case that quote and unquote expressions will consist of a single variable. For convenience, if a variable name would otherwise be an invalid reference, the parser will try unquoting or quoting it before throwing an error. So, the following code is equivalent to the previous example: +As we can see from the example, it is often the case that quote and unquote expressions will consist of a single variable. For convenience, if a variable name would otherwise be an invalid reference in the current scope (within the current `with ...:` block), the parser will search up progressively larger scopes before throwing an error, while implicitly unquoting or quoting if necessary. So, the following code is equivalent to the previous example: ```python @proc def foo(a: i32, b: i32): @@ -58,6 +59,17 @@ def foo(a: i32, b: i32): a += expr ``` +An implicit unquote may occur on the left-hand side of an assignment if and only if it references an implicit quote. Thus, the following code adds 1 to `a` and `b` by implicitly unquoting `sym` on the left-hand side: +```python +@proc +def foo(a: i32, b: i32): + with python: + syms = [a, b] + for sym in syms: + with exo: + sym += 1 +``` + ### Unquoting Numbers Besides quoted expressions, a Python number can also be unquoted and converted into the corresponding numeric literal in Exo. The following example will alternate between `a += 1` and `a += 2` 10 times: @@ -72,33 +84,43 @@ def foo(a: i32): ### Unquoting Types -When an unquote expression occurs in the place that a type would normally be used in Exo, for instance in the declaration of function arguments, the unquote expression will read the Python object as a string and parse it as the corresponding type. The following example will take an argument whose type depends on the first statement: +When an unquote expression occurs in the place that a primitive type would normally be used in Exo, for instance in the declaration of function arguments, the unquote expression will read the Python object as a string and parse it as the corresponding type. More complicated types such as arrays cannot be directly unquoted. The following example will take an argument whose type depends on the first statement: ```python T = "i32" @proc -def foo(a: {T}, b: {T}): - a += b +def foo(a: {T}[5], b: {T}): + a[0] += b ``` ### Unquoting Indices -Unquote expressions can also be used to index into a buffer. The Python object that gets unquoted may be a single Exo expression, a number, or a slice object. +Unquote expressions can also be used to index into a buffer. The Python object that gets unquoted may be a single Exo expression, a number, or a slice object where the bounds are numbers or Exo expressions. The following example will execute `foo` on a variety of slices in `a`: +```python +@proc +def bar(n: size, a: R[n]): + assert n > 2 + with python: + slices = [slice(1, ~{n - 1}), slice(0, 1), slice(0, ~{n})] + for s in slices: + with exo: + foo({s.stop} - {s.start}, a[{s}]) +``` ### Unquoting Memories -Memory objects can also be unquoted. Note that memories in Exo correspond to Python objects in the base language anyway, so the process of unquoting an object representing a type of memory in Exo is relatively straightforward. For instance, the memory used to pass in the arguments to this function are determined by the first line: +Memory objects are always unquoted without the `{...}` notation, since memories in Exo correspond to Python objects in the base language anyway. For instance, the memory used to pass in the arguments to this function are determined by the first line: ```python -mem = DRAM +mems = [DRAM] @proc -def foo(a: i32 @ {mem}, b: i32 @ {mem}): +def foo(a: i32 @ mems[0], b: i32 @ mems[0]): a += b ``` ## Binding Quoted Statements to Variables -A quoted Exo statement does not have to be executed immediately in the place that it is declared. Instead, the quote may be stored in a Python variable using the syntax `with exo as ...:`. It can then be unquoted with the `{...}` operator if it appears as a statement. +A quoted Exo statement does not have to be executed immediately in the place that it is declared. Instead, the quote may be stored in a Python variable using the syntax `with exo as ...:`. It can then be unquoted with the `{...}` operator if it appears as a statement in an Exo scope. The following example is equivalent to `a += b; a += b`: ```python @@ -107,8 +129,9 @@ def foo(a: i32, b: i32): with python: with exo as stmt: a += b - {stmt} - {stmt} + with exo: + {stmt} + {stmt} ``` ## Limitations diff --git a/src/exo/frontend/pyparser.py b/src/exo/frontend/pyparser.py index c2023737..f85c928f 100644 --- a/src/exo/frontend/pyparser.py +++ b/src/exo/frontend/pyparser.py @@ -218,6 +218,16 @@ def pattern(s, filename=None, lineno=None, srclocals=None, srcglobals=None): UNQUOTE_BLOCK_KEYWORD = "python" +@dataclass +class ExoSymbol: + """ + Opaque wrapper class for representing symbols in object code. Can be unquoted in both expressions and + implicitly on left-hand side of statements. + """ + + _inner: Sym + + @dataclass class ExoExpression: """ @@ -244,7 +254,6 @@ class QuoteReplacer(pyast.NodeTransformer): """ src_info: SourceInfo - exo_locals: dict[str, Any] unquote_env: "UnquoteEnv" inside_function: bool = False @@ -266,7 +275,6 @@ def parse_quote_block(): self.src_info, parent_scope=get_parent_scope(depth=3), is_quote_stmt=True, - parent_exo_locals=self.exo_locals, ).result() if stmt_destination is None: @@ -321,7 +329,6 @@ def quote_callback(): self.src_info, parent_scope=get_parent_scope(depth=2), is_quote_expr=True, - parent_exo_locals=self.exo_locals, ).result() ) @@ -376,7 +383,7 @@ class UnquoteEnv: parent_globals: dict[str, Any] parent_locals: dict[str, Local] - exo_local_vars: dict[str, Any] + exo_locals: dict[str, Any] def mangle_name(self, prefix: str) -> str: """ @@ -415,18 +422,22 @@ def interpret_unquote_block( This function is also used to parse metalanguage expressions by representing the expressions as return statements and saving the output returned by the helper function. """ - bound_locals = { - name: val.val for name, val in self.parent_locals.items() if val is not None + quote_locals = { + name: ExoSymbol(val) + for name, val in self.exo_locals.items() + if isinstance(val, Sym) } unbound_names = { - name for name, val in self.parent_locals.items() if val is None + name + for name, val in self.parent_locals.items() + if val is None and name not in quote_locals } - quote_locals = { - name: ExoExpression(val) - for name, val in self.exo_local_vars.items() - if name not in self.parent_locals + bound_locals = { + name: val.val + for name, val in self.parent_locals.items() + if val is not None and name not in quote_locals } - env_locals = {**quote_locals, **bound_locals} + env_locals = {**bound_locals, **quote_locals} old_stmt_processor = ( self.parent_globals[QUOTE_STMT_PROCESSOR] if QUOTE_STMT_PROCESSOR in self.parent_globals @@ -502,6 +513,13 @@ def interpret_unquote_block( ) for arg in unbound_names ], + *[ + pyast.Name( + id=arg, + ctx=pyast.Load(), + ) + for arg in quote_locals + ], ], ctx=pyast.Load(), ), @@ -607,11 +625,10 @@ def __init__( instr=None, is_quote_stmt=False, is_quote_expr=False, - parent_exo_locals=None, ): self.module_ast = module_ast self.parent_scope = parent_scope - self.exo_locals = ChainMap() if parent_exo_locals is None else parent_exo_locals + self.exo_locals = ChainMap() self.src_info = src_info self.is_fragment = is_fragment @@ -673,13 +690,6 @@ def pop(self): def err(self, node, errstr, origin=None): raise ParseError(f"{self.getsrcinfo(node)}: {errstr}") from origin - def make_exo_var_asts(self, srcinfo): - return { - name: self.AST.Read(val, [], srcinfo) - for name, val in self.exo_locals.items() - if isinstance(val, Sym) - } - def try_eval_unquote( self, unquote_node: pyast.expr ) -> Union[tuple[()], tuple[Any]]: @@ -690,11 +700,9 @@ def try_eval_unquote( unquote_env = UnquoteEnv( self.parent_scope.get_globals(), self.parent_scope.read_locals(), - self.make_exo_var_asts(self.getsrcinfo(unquote_node)), - ) - quote_replacer = QuoteReplacer( - self.src_info, self.exo_locals, unquote_env + self.exo_locals, ) + quote_replacer = QuoteReplacer(self.src_info, unquote_env) unquoted = unquote_env.interpret_unquote_expr( quote_replacer.visit(copy.deepcopy(unquote_node.elts[0])) ) @@ -712,7 +720,7 @@ def try_eval_unquote( UnquoteEnv( cur_globals, cur_locals, - self.make_exo_var_asts(self.getsrcinfo(unquote_node)), + self.exo_locals, ).interpret_unquote_expr(unquote_node), ) if unquote_node.id in cur_locals or unquote_node.id in cur_globals @@ -729,7 +737,7 @@ def eval_expr(self, expr): **self.parent_scope.read_locals(), **{k: BoundLocal(v) for k, v in self.exo_locals.items()}, }, - self.make_exo_var_asts(self.getsrcinfo(expr)), + self.exo_locals, ).interpret_unquote_expr(expr) # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - # @@ -1019,11 +1027,10 @@ def parse_stmt_block(self, stmts): unquote_env = UnquoteEnv( self.parent_scope.get_globals(), self.parent_scope.read_locals(), - self.make_exo_var_asts(self.getsrcinfo(s)), + self.exo_locals, ) quote_stmt_replacer = QuoteReplacer( self.src_info, - self.exo_locals, unquote_env, ) unquote_env.interpret_unquote_block( @@ -1160,25 +1167,31 @@ def parse_stmt_block(self, stmts): typ, mem = self.parse_alloc_typmem(s.annotation) rstmts.append(UAST.Alloc(nm, typ, mem, self.getsrcinfo(s))) - # handle cases of ambiguous assignment to undefined - # variables - if ( - isinstance(s, pyast.Assign) - and len(idxs) == 0 - and name_node.id not in self.exo_locals - ): - nm = Sym(name_node.id) - self.exo_locals[name_node.id] = nm - do_fresh_assignment = True - else: - do_fresh_assignment = False + do_fresh_assignment = False # get the symbol corresponding to the name on the # left-hand-side if isinstance(s, (pyast.Assign, pyast.AugAssign)): if name_node.id not in self.exo_locals: - self.err(name_node, f"variable '{name_node.id}' undefined") - nm = self.exo_locals[name_node.id] + unquote_eval_result = self.try_eval_unquote( + pyast.Name(id=name_node.id, ctx=pyast.Load()) + ) + if len(unquote_eval_result) == 1 and isinstance( + unquote_eval_result[0], ExoSymbol + ): + nm = unquote_eval_result[0]._inner + elif len(idxs) == 0 and name_node.id not in self.exo_locals: + # handle cases of ambiguous assignment to undefined + # variables + nm = Sym(name_node.id) + self.exo_locals[name_node.id] = nm + do_fresh_assignment = True + else: + self.err( + name_node, f"variable '{name_node.id}' undefined" + ) + else: + nm = self.exo_locals[name_node.id] if isinstance(nm, SizeStub): self.err( name_node, @@ -1393,6 +1406,8 @@ def unquote_to_index(unquoted, ref_node, srcinfo, top_level): unquoted._inner, self.AST.expr ): return unquoted._inner + elif isinstance(unquoted, ExoSymbol): + return self.AST.Read(unquoted._inner, [], srcinfo) elif isinstance(unquoted, slice) and top_level: if unquoted.step is None: return UAST.Interval( @@ -1493,6 +1508,8 @@ def parse_expr(self, e): unquoted._inner, self.AST.expr ): return unquoted._inner + elif isinstance(unquoted, ExoSymbol): + return self.AST.Read(unquoted._inner, [], self.getsrcinfo(e)) else: self.err(e, "Unquote received input that couldn't be unquoted") elif isinstance(e, (pyast.Name, pyast.Subscript)): @@ -1508,7 +1525,15 @@ def parse_expr(self, e): if nm_node.id in self.exo_locals: nm = self.exo_locals[nm_node.id] else: - self.err(nm_node, f"variable '{nm_node.id}' undefined") + unquote_eval_result = self.try_eval_unquote( + pyast.Name(id=nm_node.id, ctx=pyast.Load()) + ) + if len(unquote_eval_result) == 1 and isinstance( + unquote_eval_result[0], ExoSymbol + ): + nm = unquote_eval_result[0]._inner + else: + self.err(nm_node, f"variable '{nm_node.id}' undefined") if isinstance(nm, SizeStub): nm = nm.nm diff --git a/tests/golden/test_metaprogramming/test_capture_nested_quote.txt b/tests/golden/test_metaprogramming/test_capture_nested_quote.txt index 56a54d49..8e508282 100644 --- a/tests/golden/test_metaprogramming/test_capture_nested_quote.txt +++ b/tests/golden/test_metaprogramming/test_capture_nested_quote.txt @@ -1,8 +1,8 @@ EXO IR: def foo(a: i32 @ DRAM): - a = 2 - a = 2 - a = 2 + a = a + a = a + a = a C: #include "test.h" @@ -13,8 +13,8 @@ C: // a : i32 @DRAM // ) void foo( void *ctxt, int32_t* a ) { -*a = ((int32_t) 2); -*a = ((int32_t) 2); -*a = ((int32_t) 2); +*a = *a; +*a = *a; +*a = *a; } diff --git a/tests/golden/test_metaprogramming/test_implicit_lhs_unquote.txt b/tests/golden/test_metaprogramming/test_implicit_lhs_unquote.txt new file mode 100644 index 00000000..c8720ccf --- /dev/null +++ b/tests/golden/test_metaprogramming/test_implicit_lhs_unquote.txt @@ -0,0 +1,19 @@ +EXO IR: +def foo(a: i32 @ DRAM, b: i32 @ DRAM): + a += 1 + b += 1 +C: +#include "test.h" + +#include +#include + +// foo( +// a : i32 @DRAM, +// b : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a, int32_t* b ) { +*a += ((int32_t) 1); +*b += ((int32_t) 1); +} + diff --git a/tests/golden/test_metaprogramming/test_scope_collision1.txt b/tests/golden/test_metaprogramming/test_scope_collision1.txt index c2d6b20c..bc9b6758 100644 --- a/tests/golden/test_metaprogramming/test_scope_collision1.txt +++ b/tests/golden/test_metaprogramming/test_scope_collision1.txt @@ -2,7 +2,7 @@ EXO IR: def foo(a: i32 @ DRAM): b: i32 @ DRAM b = 2 - a = 1 + a = b C: #include "test.h" @@ -15,6 +15,6 @@ C: void foo( void *ctxt, int32_t* a ) { int32_t b; b = ((int32_t) 2); -*a = ((int32_t) 1); +*a = b; } diff --git a/tests/golden/test_metaprogramming/test_scope_collision2.txt b/tests/golden/test_metaprogramming/test_scope_collision2.txt index a22821c7..fe7faf52 100644 --- a/tests/golden/test_metaprogramming/test_scope_collision2.txt +++ b/tests/golden/test_metaprogramming/test_scope_collision2.txt @@ -1,6 +1,6 @@ EXO IR: def foo(a: i32 @ DRAM, b: i32 @ DRAM): - b = a + b = 1 C: #include "test.h" @@ -12,6 +12,6 @@ C: // b : i32 @DRAM // ) void foo( void *ctxt, const int32_t* a, int32_t* b ) { -*b = *a; +*b = ((int32_t) 1); } diff --git a/tests/golden/test_metaprogramming/test_scoping.txt b/tests/golden/test_metaprogramming/test_scoping.txt index 331db00a..ddd9e9f3 100644 --- a/tests/golden/test_metaprogramming/test_scoping.txt +++ b/tests/golden/test_metaprogramming/test_scoping.txt @@ -1,6 +1,6 @@ EXO IR: def foo(a: i8 @ DRAM): - a = 3 + a = a C: #include "test.h" @@ -11,6 +11,6 @@ C: // a : i8 @DRAM // ) void foo( void *ctxt, int8_t* a ) { -*a = ((int8_t) 3); +*a = *a; } diff --git a/tests/test_metaprogramming.py b/tests/test_metaprogramming.py index 7067db37..c95e6394 100644 --- a/tests/test_metaprogramming.py +++ b/tests/test_metaprogramming.py @@ -525,3 +525,16 @@ def foo(a: f32 @ mems[0]): c_file, _ = compile_procs_to_strings([foo], "test.h") assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_implicit_lhs_unquote(golden): + @proc + def foo(a: i32, b: i32): + with python: + syms = [a, b] + for sym in syms: + with exo: + sym += 1 + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden