From cb416d4ab8213e95c0858aa3c71fd20718ef5486 Mon Sep 17 00:00:00 2001 From: Yuka Ikarashi Date: Sun, 13 Oct 2024 14:27:55 -0400 Subject: [PATCH] Externalize builtins and fix bugs (#637) - Rename builtins to _externs_ - Externalize _externs_ - Fix bug 1: extern did not work on read expressions with index accesses - Fix bug 2: extern fnarg codegen should be `comp_e` not `comp_fnarg` - Fix bug 3: extern did not work on constants - Change pattern_match and parse_fragment to take local and global variables in scope. Fixes #505 --- apps/x86/conv/conv.py | 2 +- src/exo/API.py | 10 +- src/exo/API_cursors.py | 23 +- src/exo/API_scheduling.py | 11 +- src/exo/LoopIR.py | 19 +- src/exo/LoopIR_compiler.py | 46 +-- src/exo/LoopIR_pprint.py | 4 +- src/exo/LoopIR_unification.py | 6 +- src/exo/__init__.py | 2 + src/exo/boundscheck.py | 2 +- src/exo/builtins.py | 142 --------- src/exo/extern.py | 36 +++ src/exo/libs/externs.py | 234 ++++++++++++++ src/exo/mem_analysis.py | 2 +- src/exo/new_eff.py | 2 +- src/exo/parse_fragment.py | 28 +- src/exo/pattern_match.py | 22 +- src/exo/platforms/gemmini.py | 7 +- src/exo/platforms/x86.py | 1 + src/exo/prec_analysis.py | 34 ++ src/exo/pyparser.py | 55 ++-- src/exo/stdlib/inspection.py | 6 +- src/exo/typecheck.py | 8 +- .../test_gemmini_matmul_new/test_matmul.txt | 5 - .../test_gemmini_matmul_old/test_matmul.txt | 5 - tests/golden/test_apps/test_gemmini_conv.txt | 16 +- .../golden/test_apps/test_gemmini_matmul.txt | 22 +- tests/golden/test_apps/test_x86_conv.txt | 5 - tests/golden/test_externs/test_expf.txt | 61 ++++ .../golden/test_externs/test_extern_find.txt | 2 + tests/golden/test_externs/test_fmaxf.txt | 61 ++++ tests/golden/test_externs/test_relu.txt | 63 ++++ tests/golden/test_externs/test_relu2.txt | 63 ++++ tests/golden/test_externs/test_relu3.txt | 67 ++++ tests/golden/test_externs/test_relu4.txt | 63 ++++ tests/golden/test_externs/test_select.txt | 67 ++++ tests/golden/test_externs/test_sigmoid.txt | 66 ++++ tests/golden/test_externs/test_sin.txt | 59 ++++ tests/golden/test_externs/test_sqrt.txt | 61 ++++ tests/test_codegen.py | 1 + tests/test_config.py | 1 + tests/test_cursors.py | 1 + tests/test_externs.py | 292 ++++++++++++++++++ tests/test_typecheck.py | 1 + 44 files changed, 1402 insertions(+), 282 deletions(-) delete mode 100644 src/exo/builtins.py create mode 100644 src/exo/extern.py create mode 100644 src/exo/libs/externs.py create mode 100644 tests/golden/test_externs/test_expf.txt create mode 100644 tests/golden/test_externs/test_extern_find.txt create mode 100644 tests/golden/test_externs/test_fmaxf.txt create mode 100644 tests/golden/test_externs/test_relu.txt create mode 100644 tests/golden/test_externs/test_relu2.txt create mode 100644 tests/golden/test_externs/test_relu3.txt create mode 100644 tests/golden/test_externs/test_relu4.txt create mode 100644 tests/golden/test_externs/test_select.txt create mode 100644 tests/golden/test_externs/test_sigmoid.txt create mode 100644 tests/golden/test_externs/test_sin.txt create mode 100644 tests/golden/test_externs/test_sqrt.txt create mode 100644 tests/test_externs.py diff --git a/apps/x86/conv/conv.py b/apps/x86/conv/conv.py index c8282c610..b89706e4f 100644 --- a/apps/x86/conv/conv.py +++ b/apps/x86/conv/conv.py @@ -1,7 +1,7 @@ from __future__ import annotations from exo import * -from exo.builtins import * +from exo.libs.externs import * from exo.platforms.x86 import * from exo.syntax import * from exo.stdlib.scheduling import * diff --git a/src/exo/API.py b/src/exo/API.py index 7889f5509..2dea74528 100644 --- a/src/exo/API.py +++ b/src/exo/API.py @@ -245,7 +245,7 @@ def body(self): block = self._root()._child_block("body") return C.lift_cursor(block, self) - def find(self, pattern, many=False): + def find(self, pattern, many=False, call_depth=1): """ Find the most specific possible cursor for the given pattern. For example, a pattern matching a single assignment statement @@ -256,7 +256,7 @@ def find(self, pattern, many=False): In any event, if no matches are found, a SchedulingError is raised """ - return C.find(self._root(), self, pattern, many) + return C.find(self._root(), self, pattern, many, call_depth=call_depth + 1) def find_loop(self, pattern, many=False): """ @@ -273,7 +273,7 @@ def find_loop(self, pattern, many=False): name, count = results[1], (results[2] if results[2] else "") pattern = f"for {name} in _: _{count}" - return self.find(pattern, many) + return self.find(pattern, many, call_depth=1) def find_alloc_or_arg(self, pattern): _name_count_re = r"^([a-zA-Z_]\w*)\s*(\#\s*[0-9]+)?$" @@ -286,10 +286,10 @@ def find_alloc_or_arg(self, pattern): pattern = f"{name}: _{count}" - return self.find(pattern) + return self.find(pattern, call_depth=1) def find_all(self, pattern): - return self.find(pattern, many=True) + return self.find(pattern, many=True, call_depth=1) # ---------------------------------------------- # # execution / compilation operations diff --git a/src/exo/API_cursors.py b/src/exo/API_cursors.py index 2a8b8b755..639157462 100644 --- a/src/exo/API_cursors.py +++ b/src/exo/API_cursors.py @@ -72,9 +72,8 @@ class Cursor(ABC): | Literal( value : bool, int, or float ) | UnaryMinus( arg : Expr ) | BinaryOp( op : str, lhs : Expr, rhs : Expr ) - | BuiltIn( name : str, args : ExprList ) + | Extern( name : str, args : ExprList ) | WindowExpr( name : str, idx : *(see below) ) - | BuiltIn( name : str, args : ExprList ) The `idx` argument of `WindowExpr` is a list containing either `Expr` or `(Expr,Expr)` (a pair of expressions) at each position. @@ -128,8 +127,8 @@ def parent(self): return InvalidCursor() return lift_cursor(impl_parent, self._proc) - def find(self, pattern, many=False): - return find(self._impl, self._proc, pattern, many) + def find(self, pattern, many=False, call_depth=1): + return find(self._impl, self._proc, pattern, many, call_depth=call_depth + 1) def _child_node(self, *args, **kwargs): return lift_cursor(self._impl._child_node(*args, **kwargs), self._proc) @@ -783,7 +782,7 @@ def rhs(self) -> ExprCursor: return self._child_node("rhs") -class BuiltInFunctionCursor(ExprCursor): +class ExternFunctionCursor(ExprCursor): """ Cursor pointing to the call to some built-in function `name ( args )` @@ -791,13 +790,13 @@ class BuiltInFunctionCursor(ExprCursor): def name(self) -> str: assert isinstance(self._impl, C.Node) - assert isinstance(self._impl._node, LoopIR.BuiltIn) + assert isinstance(self._impl._node, LoopIR.Extern) return self._impl._node.f.name() def args(self) -> ExprListCursor: assert isinstance(self._impl, C.Node) - assert isinstance(self._impl._node, LoopIR.BuiltIn) + assert isinstance(self._impl._node, LoopIR.Extern) return ExprListCursor(self._impl._child_block("args"), self._proc) @@ -923,8 +922,8 @@ def lift_cursor(impl, proc): return UnaryMinusCursor(impl, proc) elif isinstance(n, LoopIR.BinOp): return BinaryOpCursor(impl, proc) - elif isinstance(n, LoopIR.BuiltIn): - return BuiltInFunctionCursor(impl, proc) + elif isinstance(n, LoopIR.Extern): + return ExternFunctionCursor(impl, proc) elif isinstance(n, LoopIR.WindowExpr): return WindowExprCursor(impl, proc) elif isinstance(n, LoopIR.StrideExpr): @@ -937,7 +936,7 @@ def lift_cursor(impl, proc): assert False, f"bad case: {type(impl)}" -def find(scope: C, proc: API.Procedure, pattern: str, many: bool): +def find(scope: C, proc: API.Procedure, pattern: str, many: bool, call_depth=1): """ Find the most specific possible cursor for the given pattern in the given scope of the proc. For example, a pattern matching a @@ -953,7 +952,7 @@ def find(scope: C, proc: API.Procedure, pattern: str, many: bool): raise TypeError("expected a pattern string") default_match_no = None if many else 0 raw_cursors = match_pattern( - scope, pattern, call_depth=1, default_match_no=default_match_no + scope, pattern, call_depth=call_depth + 1, default_match_no=default_match_no ) assert isinstance(raw_cursors, list) cursors = [] @@ -1000,7 +999,7 @@ def find(scope: C, proc: API.Procedure, pattern: str, many: bool): "LiteralCursor", "UnaryMinusCursor", "BinaryOpCursor", - "BuiltInFunctionCursor", + "ExternFunctionCursor", "WindowExprCursor", "StrideExprCursor", # diff --git a/src/exo/API_scheduling.py b/src/exo/API_scheduling.py index 54513d99a..9a23d6faf 100644 --- a/src/exo/API_scheduling.py +++ b/src/exo/API_scheduling.py @@ -381,8 +381,7 @@ def _cursor_call(self, expr_pattern, all_args): self.err("expected an ExprCursor or pattern string") proc = all_args["proc"] - # TODO: Remove all need for `call_depth` - matches = proc.find(expr_pattern, many=self.match_many) + matches = proc.find(expr_pattern, many=self.match_many, call_depth=1) if self.match_many: for m in matches: @@ -411,8 +410,7 @@ def _cursor_call(self, stmt_pattern, all_args): self.err("expected a StmtCursor or pattern string") proc = all_args["proc"] - # TODO: Remove all need for `call_depth` - matches = proc.find(stmt_pattern, many=self.match_many) + matches = proc.find(stmt_pattern, many=self.match_many, call_depth=1) match = matches[0] if self.match_many else matches if not isinstance(match, PC.StmtCursor): @@ -441,8 +439,7 @@ def _cursor_call(self, block_pattern, all_args): self.err("expected a Cursor or pattern string") proc = all_args["proc"] - # TODO: Remove all need for `call_depth` - matches = proc.find(block_pattern, many=self.match_many) + matches = proc.find(block_pattern, many=self.match_many, call_depth=1) match = matches[0] if self.match_many else matches if isinstance(match, PC.StmtCursor): @@ -540,7 +537,7 @@ def _cursor_call(self, alloc_pattern, all_args): if not isinstance(cursor, (PC.AllocCursor, PC.ArgCursor)): proc = all_args["proc"] try: - cursor = proc.find(alloc_pattern) + cursor = proc.find(alloc_pattern, call_depth=1) except: for arg in proc.args(): if arg.name() == name: diff --git a/src/exo/LoopIR.py b/src/exo/LoopIR.py index 733b0d133..9ee862779 100644 --- a/src/exo/LoopIR.py +++ b/src/exo/LoopIR.py @@ -4,7 +4,7 @@ from asdl_adt import ADT, validators -from .builtins import BuiltIn +from .extern import Extern from .configs import Config from .memory import Memory from .prelude import Sym, SrcInfo, extclass @@ -92,7 +92,7 @@ def __new__(cls, op): | Const( object val ) | USub( expr arg ) -- i.e. -(...) | BinOp( binop op, expr lhs, expr rhs ) - | BuiltIn( builtin f, expr* args ) + | Extern( extern f, expr* args ) | WindowExpr( sym name, w_access* idx ) | StrideExpr( sym name, int dim ) | ReadConfig( config config, string field ) @@ -130,7 +130,7 @@ def __new__(cls, op): "name": validators.instance_of(Identifier, convert=True), "sym": Sym, "mem": Type[Memory], - "builtin": BuiltIn, + "extern": Extern, "config": Config, "binop": validators.instance_of(Operator, convert=True), "srcinfo": SrcInfo, @@ -190,7 +190,7 @@ def __new__(cls, op): | Const ( object val ) | USub ( expr arg ) -- i.e. -(...) | BinOp ( op op, expr lhs, expr rhs ) - | BuiltIn( builtin f, expr* args ) + | Extern( extern f, expr* args ) | WindowExpr( sym name, w_access* idx ) | StrideExpr( sym name, int dim ) | ParRange( expr lo, expr hi ) -- only use for loop cond @@ -221,7 +221,7 @@ def __new__(cls, op): "name": validators.instance_of(Identifier, convert=True), "sym": Sym, "mem": Type[Memory], - "builtin": BuiltIn, + "extern": Extern, "config": Config, "loopir_proc": LoopIR.proc, "op": validators.instance_of(Operator, convert=True), @@ -270,14 +270,13 @@ def __new__(cls, op): | Const ( object val ) | USub ( expr arg ) -- i.e. -(...) | BinOp ( op op, expr lhs, expr rhs ) - | BuiltIn ( builtin f, expr* args ) + | Extern ( name f, expr* args ) | ReadConfig( string config, string field ) attributes( srcinfo srcinfo ) } """, ext_types={ "name": validators.instance_of(IdentifierOrHole, convert=True), - "builtin": BuiltIn, "op": validators.instance_of(Operator, convert=True), "srcinfo": SrcInfo, }, @@ -673,7 +672,7 @@ def map_e(self, e): rhs=new_rhs or e.rhs, type=new_type or e.type, ) - elif isinstance(e, LoopIR.BuiltIn): + elif isinstance(e, LoopIR.Extern): new_type = self.map_t(e.type) new_args = self.map_exprs(e.args) if any((new_type, new_args is not None)): @@ -810,7 +809,7 @@ def do_e(self, e): elif etyp is LoopIR.BinOp: self.do_e(e.lhs) self.do_e(e.rhs) - elif etyp is LoopIR.BuiltIn: + elif etyp is LoopIR.Extern: for a in e.args: self.do_e(a) elif etyp is LoopIR.USub: @@ -914,7 +913,7 @@ def match_e(self, e1, e2): and self.match_e(e1.lhs, e2.lhs) and self.match_e(e1.rhs, e2.rhs) ) - elif isinstance(e1, LoopIR.BuiltIn): + elif isinstance(e1, LoopIR.Extern): # TODO: check f equality return e1.f is e2.f and all( self.match_e(a1, a2) for a1, a2 in zip(e1.args, e2.args) diff --git a/src/exo/LoopIR_compiler.py b/src/exo/LoopIR_compiler.py index b590c19bf..0520b7387 100644 --- a/src/exo/LoopIR_compiler.py +++ b/src/exo/LoopIR_compiler.py @@ -196,18 +196,18 @@ def do_t(self, t): pass -class LoopIR_FindBuiltIns(LoopIR_Do): +class LoopIR_FindExterns(LoopIR_Do): def __init__(self, proc): - self._builtins = set() + self._externs = set() super().__init__(proc) def result(self): - return self._builtins + return self._externs # to improve efficiency def do_e(self, e): - if isinstance(e, LoopIR.BuiltIn): - self._builtins.add(e.f) + if isinstance(e, LoopIR.Extern): + self._externs.add((e.f, e.type.basetype().ctype())) else: super().do_e(e) @@ -247,12 +247,12 @@ def find_all_mems(proc_list): return [m for m in mems] -def find_all_builtins(proc_list): - builtins = set() +def find_all_externs(proc_list): + externs = set() for p in proc_list: - builtins.update(LoopIR_FindBuiltIns(p).result()) + externs.update(LoopIR_FindExterns(p).result()) - return [b for b in builtins] + return externs def find_all_configs(proc_list): @@ -376,10 +376,10 @@ def from_lines(x): # Body contents memory_code = _compile_memories(find_all_mems(proc_list)) - builtin_code = _compile_builtins(find_all_builtins(proc_list)) private_fwd_decls = [] proc_bodies = [] instrs_global = [] + analyzed_proc_list = [] needed_helpers = set() @@ -424,6 +424,8 @@ def from_lines(x): proc_bodies.append(b) + analyzed_proc_list.append(p) + # Structs are just blobs of code... still sort them for output stability struct_defns = [x.definition for x in sorted(struct_defns, key=lambda x: x.name)] @@ -454,12 +456,14 @@ def from_lines(x): {from_lines(public_fwd_decls)} """ + extern_code = _compile_externs(find_all_externs(analyzed_proc_list)) + helper_code = [_static_helpers[v] for v in needed_helpers] body_contents = [ helper_code, instrs_global, memory_code, - builtin_code, + extern_code, private_fwd_decls, proc_bodies, ] @@ -470,12 +474,12 @@ def from_lines(x): return header_contents, body_contents -def _compile_builtins(builtins): - builtin_code = [] - for b in sorted(builtins, key=lambda x: x.name()): - if glb := b.globl(): - builtin_code.append(glb) - return builtin_code +def _compile_externs(externs): + extern_code = [] + for f, t in sorted(externs, key=lambda x: x[0].name() + x[1]): + if glb := f.globl(t): + extern_code.append(glb) + return extern_code def _compile_memories(mems): @@ -971,7 +975,7 @@ def comp_fnarg(self, e, fn, i, *, prec=0): x for x, _ in get_writes_of_stmts(fn.body) ) else: - raise NotImplementedError("Passing windows to built-ins") + raise NotImplementedError("Passing windows to externs") win_struct = self.get_window_type(e.type, is_const) data, strides = self.window_struct_fields(e) return f"(struct {win_struct}){{ &{data}, {{ {strides} }} }}" @@ -1044,9 +1048,9 @@ def comp_e(self, e, prec=0): elif isinstance(e, LoopIR.USub): return f'-{self.comp_e(e.arg, op_prec["~"])}' - elif isinstance(e, LoopIR.BuiltIn): - args = [self.comp_fnarg(a, e, i) for i, a in enumerate(e.args)] - return e.f.compile(args) + elif isinstance(e, LoopIR.Extern): + args = [self.comp_e(a) for a in e.args] + return e.f.compile(args, e.type.basetype().ctype()) elif isinstance(e, LoopIR.StrideExpr): basetyp = self.envtyp[e.name] diff --git a/src/exo/LoopIR_pprint.py b/src/exo/LoopIR_pprint.py index 79eb13e5f..b47bbfc2a 100644 --- a/src/exo/LoopIR_pprint.py +++ b/src/exo/LoopIR_pprint.py @@ -271,7 +271,7 @@ def pacc(w): return f"{self.get_name(e.name)}[{', '.join([pacc(w) for w in e.idx])}]" elif isinstance(e, UAST.StrideExpr): return f"stride({self.get_name(e.name)}, {e.dim})" - elif isinstance(e, UAST.BuiltIn): + elif isinstance(e, UAST.Extern): pname = e.f.name() or "_anon_" args = [self.pexpr(a) for a in e.args] return f"{pname}({','.join(args)})" @@ -507,7 +507,7 @@ def _print_expr(e, env: PrintEnv, prec: int = 0) -> str: elif isinstance(e, LoopIR.StrideExpr): return f"stride({env.get_name(e.name)}, {e.dim})" - elif isinstance(e, LoopIR.BuiltIn): + elif isinstance(e, LoopIR.Extern): pname = e.f.name() or "_anon_" args = [_print_expr(a, env) for a in e.args] return f"{pname}({', '.join(args)})" diff --git a/src/exo/LoopIR_unification.py b/src/exo/LoopIR_unification.py index 9a6e49621..ce7d65622 100644 --- a/src/exo/LoopIR_unification.py +++ b/src/exo/LoopIR_unification.py @@ -797,7 +797,7 @@ def all_bound_e(self, be): return self.all_bound_e(be.arg) elif isinstance(be, LoopIR.BinOp): return self.all_bound_e(be.lhs) and self.all_bound_e(be.rhs) - elif isinstance(be, LoopIR.BuiltIn): + elif isinstance(be, LoopIR.Extern): return all(self.all_bound_e(a) for a in be.args) else: assert False, "unsupported case" @@ -819,7 +819,7 @@ def is_exact_e(self, e0, e1): and self.is_exact_e(e0.lhs, e1.lhs) and self.is_exact_e(e0.rhs, e1.rhs) ) - elif isinstance(e0, LoopIR.BuiltIn): + elif isinstance(e0, LoopIR.Extern): return e0.f == e1.f and all( self.is_exact_e(a0, a1) for a0, a1 in zip(e0.args, e1.args) ) @@ -1165,7 +1165,7 @@ def unify_e(self, pe, be): ) self.unify_e(pe.lhs, be.lhs) self.unify_e(pe.rhs, be.rhs) - elif isinstance(pe, LoopIR.BuiltIn): + elif isinstance(pe, LoopIR.Extern): if pe.f != be.f: raise UnificationError( f"cannot unify builtin '{pe.f.name()}' (@{pe.srcinfo}) " diff --git a/src/exo/__init__.py b/src/exo/__init__.py index 6eba4861b..6aa3a1ac1 100644 --- a/src/exo/__init__.py +++ b/src/exo/__init__.py @@ -11,6 +11,7 @@ from .parse_fragment import ParseFragmentError from .configs import Config from .memory import Memory, DRAM +from .extern import Extern from . import stdlib @@ -25,6 +26,7 @@ "config", "Config", "Memory", + "Extern", "DRAM", "SchedulingError", "ParseFragmentError", diff --git a/src/exo/boundscheck.py b/src/exo/boundscheck.py index da6ea372c..20eb8dc59 100644 --- a/src/exo/boundscheck.py +++ b/src/exo/boundscheck.py @@ -1117,7 +1117,7 @@ def eff_e(self, e, type_env): return eff_null(e.srcinfo) elif isinstance(e, LoopIR.WindowExpr): return eff_null(e.srcinfo) - elif isinstance(e, LoopIR.BuiltIn): + elif isinstance(e, LoopIR.Extern): return eff_null(e.srcinfo) elif isinstance(e, LoopIR.StrideExpr): return eff_null(e.srcinfo) diff --git a/src/exo/builtins.py b/src/exo/builtins.py deleted file mode 100644 index 913620171..000000000 --- a/src/exo/builtins.py +++ /dev/null @@ -1,142 +0,0 @@ -# --------------------------------------------------------------------------- # -# --------------------------------------------------------------------------- # -# BuiltIn superclass - - -class BuiltIn_Typecheck_Error(Exception): - def __init__(self, msg): - self._builtin_err_msg = str(msg) - - def __str__(self): - return self._builtin_err_msg - - -_BErr = BuiltIn_Typecheck_Error - - -class BuiltIn: - def __init__(self, name): - self._name = name - - def name(self): - return self._name - - def globl(self): - raise NotImplementedError() - - def typecheck(self, args): - raise NotImplementedError() - - def compile(self, args): - raise NotImplementedError() - - -class _Sin(BuiltIn): - def __init__(self): - super().__init__("sin") - - def typecheck(self, args): - if len(args) != 1: - raise _BErr(f"expected 1 argument, got {len(args)}") - - atyp = args[0].type - if not atyp.is_real_scalar(): - raise _BErr( - f"expected argument 1 to be a real scalar value, but " - f"got type {atyp}" - ) - return atyp - - def globl(self): - return "#include " - - def compile(self, args): - return f"sin((double)*{args[0]})" - - -sin = _Sin() - - -class _Relu(BuiltIn): - def __init__(self): - super().__init__("relu") - - def typecheck(self, args): - if len(args) != 1: - raise _BErr(f"expected 1 argument, got {len(args)}") - - atyp = args[0].type - if not atyp.is_real_scalar(): - raise _BErr( - f"expected argument 1 to be a real scalar value, but " - f"got type {atyp}" - ) - return atyp - - def globl(self): - s = ( - "double _relu_(double x) {\n" - " if (x > 0.0) return x;\n" - " else return 0.0;\n" - "}\n" - ) - return s - - def compile(self, args): - return f"_relu_((double)*{args[0]})" - - -relu = _Relu() - - -class _Select(BuiltIn): - def __init__(self): - super().__init__("select") - - def typecheck(self, args): - if len(args) != 4: - raise _BErr(f"expected 4 arguments, got {len(args)}") - - atyp = args[0].type - if not atyp.is_real_scalar(): - raise _BErr( - f"expected argument 1 to be a real scalar value, but " - f"got type {atyp}" - ) - - atyp = args[1].type - if not atyp.is_real_scalar(): - raise _BErr( - f"expected argument 2 to be a real scalar value, but " - f"got type {atyp}" - ) - - atyp = args[2].type - if not atyp.is_real_scalar(): - raise _BErr( - f"expected argument 3 to be a real scalar value, but " - f"got type {atyp}" - ) - - atyp = args[3].type - if not atyp.is_real_scalar(): - raise _BErr( - f"expected argument 4 to be a real scalar value, but " - f"got type {atyp}" - ) - return atyp - - def globl(self): - s = ( - "double _select_(double x, double v, double y, double z) {\n" - " if (x < v) return y;\n" - " else return z;\n" - "}\n" - ) - return s - - def compile(self, args): - return f"_select_((double)*{args[0]}, (double)*{args[1]}, (double)*{args[2]}, (double)*{args[3]})" - - -select = _Select() diff --git a/src/exo/extern.py b/src/exo/extern.py new file mode 100644 index 000000000..b1ae39d6d --- /dev/null +++ b/src/exo/extern.py @@ -0,0 +1,36 @@ +import math + +# --------------------------------------------------------------------------- # +# --------------------------------------------------------------------------- # +# Extern superclass + + +class Extern_Typecheck_Error(Exception): + def __init__(self, msg): + self._builtin_err_msg = str(msg) + + def __str__(self): + return self._builtin_err_msg + + +_EErr = Extern_Typecheck_Error + + +class Extern: + def __init__(self, name): + self._name = name + + def name(self): + return self._name + + def globl(self, prim_type): + raise NotImplementedError() + + def typecheck(self, args): + raise NotImplementedError() + + def interpret(self, args): + raise NotImplementedError() + + def compile(self, args, prim_type): + raise NotImplementedError() diff --git a/src/exo/libs/externs.py b/src/exo/libs/externs.py new file mode 100644 index 000000000..eb95ddf32 --- /dev/null +++ b/src/exo/libs/externs.py @@ -0,0 +1,234 @@ +from ..extern import Extern, _EErr + + +class _Sin(Extern): + def __init__(self): + super().__init__("sin") + + def typecheck(self, args): + if len(args) != 1: + raise _EErr(f"expected 1 argument, got {len(args)}") + + atyp = args[0].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument 1 to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + return "#include " + + # def interpret(self, args): + # return math.sin(args[0]) + + def compile(self, args, prim_type): + return f"sin(({prim_type}){args[0]})" + + +sin = _Sin() + + +class _Relu(Extern): + def __init__(self): + super().__init__("relu") + + def typecheck(self, args): + if len(args) != 1: + raise _EErr(f"expected 1 argument, got {len(args)}") + + atyp = args[0].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument 1 to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + s = ( + f"{prim_type} _relu_{prim_type}({prim_type} x) " + "{\n" + " if (x > 0.0) return x;\n" + " else return 0.0;\n" + "}\n" + ) + return s + + # def interpret(self, args): + # if args[0] > 0: + # return args[0] + # else: + # return 0 + + def compile(self, args, prim_type): + return f"_relu_{prim_type}(({prim_type}){args[0]})" + + +relu = _Relu() + + +class _Select(Extern): + def __init__(self): + super().__init__("select") + + def typecheck(self, args): + if len(args) != 4: + raise _EErr(f"expected 4 arguments, got {len(args)}") + + for i in range(len(args)): + atyp = args[i].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument {i+1} to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + s = ( + f"{prim_type} _select_{prim_type}({prim_type} x,{prim_type} v,{prim_type} y,{prim_type} z)" + + " {\n" + " if (x < v) return y;\n" + " else return z;\n" + "}\n" + ) + return s + + # def interpret(self, args): + # x = args[0] + # v = args[1] + # y = args[2] + # z = args[3] + # if x < v: + # return y + # else: + # return z + + def compile(self, args, prim_type): + return f"_select_{prim_type}(({prim_type}){args[0]}, ({prim_type}){args[1]}, ({prim_type}){args[2]}, ({prim_type}){args[3]})" + + +select = _Select() + + +class _Expf(Extern): + def __init__(self): + super().__init__("expf") + + def typecheck(self, args): + if len(args) != 1: + raise _EErr(f"expected 1 argument, got {len(args)}") + + atyp = args[0].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument 1 to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + return "#include " + + # def interpret(self, args): + # return math.expf(args[0]) + + def compile(self, args, prim_type): + return f"expf(({prim_type})({args[0]}))" + + +expf = _Expf() + + +class _FmaxF(Extern): + def __init__(self): + super().__init__("fmaxf") + + def typecheck(self, args): + if len(args) != 2: + raise _EErr(f"expected 2 argument, got {len(args)}") + + for i in range(len(args)): + atyp = args[i].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument {i+1} to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + return "#include " + + # def interpret(self, args): + # return math.fmaxf(args[0], args[1]) + + def compile(self, args, prim_type): + return f"fmaxf(({prim_type})({args[0]}), ({prim_type})({args[1]}))" + + +fmaxf = _FmaxF() + + +class _Sigmoid(Extern): + def __init__(self): + super().__init__("sigmoid") + + def typecheck(self, args): + if len(args) != 1: + raise _EErr(f"expected 1 argument, got {len(args)}") + + atyp = args[0].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument 1 to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + return f""" +#include +{prim_type} sigmoid({prim_type} x) {{ + return 1 / (1 + exp(-x)); +}} +""" + + # def interpret(self, args): + # return math.sigmoid(args[0]) + + def compile(self, args, prim_type): + return f"sigmoid(({prim_type})({args[0]}))" + + +sigmoid = _Sigmoid() + + +class _Sqrt(Extern): + def __init__(self): + super().__init__("sqrt") + + def typecheck(self, args): + if len(args) != 1: + raise _EErr(f"expected 1 argument, got {len(args)}") + + atyp = args[0].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument 1 to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + return "#include " + + # def interpret(self, args): + # return math.sqrt(args[0]) + + def compile(self, args, prim_type): + return f"sqrt(({prim_type})({args[0]}))" + + +sqrt = _Sqrt() diff --git a/src/exo/mem_analysis.py b/src/exo/mem_analysis.py index 0743503d3..0835a199b 100644 --- a/src/exo/mem_analysis.py +++ b/src/exo/mem_analysis.py @@ -69,7 +69,7 @@ def used_e(e): elif isinstance(e, LoopIR.BinOp): res += used_e(e.lhs) res += used_e(e.rhs) - elif isinstance(e, LoopIR.BuiltIn): + elif isinstance(e, LoopIR.Extern): for ei in e.args: res += used_e(ei) elif isinstance(e, (LoopIR.WindowExpr, LoopIR.StrideExpr)): diff --git a/src/exo/new_eff.py b/src/exo/new_eff.py index ec03b6b3c..f7fa0e1d5 100644 --- a/src/exo/new_eff.py +++ b/src/exo/new_eff.py @@ -1133,7 +1133,7 @@ def expr_effs(e): return expr_effs(e.arg) elif isinstance(e, LoopIR.BinOp): return expr_effs(e.lhs) + expr_effs(e.rhs) - elif isinstance(e, LoopIR.BuiltIn): + elif isinstance(e, LoopIR.Extern): return list_expr_effs(e.args) elif isinstance(e, LoopIR.WindowExpr): diff --git a/src/exo/parse_fragment.py b/src/exo/parse_fragment.py index 35ac914da..908c16166 100644 --- a/src/exo/parse_fragment.py +++ b/src/exo/parse_fragment.py @@ -20,13 +20,23 @@ class ParseFragmentError(Exception): def parse_fragment( - proc, fragment, ctx_stmt, call_depth=0, configs=[], scope="before", expr_holes=None + proc, fragment, ctx_stmt, call_depth=1, configs=[], scope="before", expr_holes=None ): + stack_frames: [inspect.FrameInfo] = inspect.stack() # get source location where this is getting called from - caller = inspect.getframeinfo(inspect.stack()[call_depth + 1][0]) + caller = inspect.getframeinfo(stack_frames[call_depth][0]) + func_locals = ChainMap(stack_frames[call_depth].frame.f_locals) + func_globals = ChainMap(stack_frames[call_depth].frame.f_globals) # parse the pattern we're going to use to match - p_ast = pyparser.pattern(fragment, filename=caller.filename, lineno=caller.lineno) + p_ast = pyparser.pattern( + fragment, + filename=caller.filename, + lineno=caller.lineno, + srclocals=func_locals, + srcglobals=func_globals, + ) + if isinstance(p_ast, PAST.expr): return ParseFragment( p_ast, proc, ctx_stmt, configs, scope, expr_holes @@ -47,7 +57,7 @@ def parse_fragment( PAST.USub: LoopIR.USub, PAST.BinOp: LoopIR.BinOp, PAST.StrideExpr: LoopIR.StrideExpr, - PAST.BuiltIn: LoopIR.BuiltIn, + PAST.Extern: LoopIR.Extern, PAST.ReadConfig: LoopIR.ReadConfig, } @@ -234,14 +244,14 @@ def parse_e(self, pat): typ = {float: T.R, bool: T.bool, int: T.int}.get(type(pat.val)) assert typ is not None, "bad type!" return LoopIR.Const(pat.val, typ, self.srcinfo) - elif isinstance(pat, PAST.BuiltIn): + elif isinstance(pat, PAST.Extern): args = [self.parse_e(a) for a in pat.args] try: typ = pat.f.typecheck(args) - except BuiltIn_Typecheck_Error as err: + except Extern_Typecheck_Error as err: raise ParseFragmentError(err) - return LoopIR.BuiltIn(pat.f, args, typ, self.srcinfo) + return LoopIR.Extern(pat.f, args, typ, self.srcinfo) elif isinstance(pat, PAST.ReadConfig): if pat.config not in self.configs: raise ParseFragmentError( @@ -304,12 +314,12 @@ def check_sym_consistency(sym): rhs=self.rebuild_ast(loopIR_expr.rhs), srcinfo=self.srcinfo, ) - elif isinstance(loopIR_expr, LoopIR.BuiltIn): + elif isinstance(loopIR_expr, LoopIR.Extern): args = [self.rebuild_ast(a) for a in loopIR_expr.args] try: typ = loopIR_expr.f.typecheck(args) - except BuiltIn_Typecheck_Error as err: + except Extern_Typecheck_Error as err: raise ParseFragmentError(err) if typ != loopIR_expr.typ: diff --git a/src/exo/pattern_match.py b/src/exo/pattern_match.py index 7dca84c83..f3e2131e5 100644 --- a/src/exo/pattern_match.py +++ b/src/exo/pattern_match.py @@ -3,6 +3,7 @@ import inspect import re from typing import Optional, Iterable +from collections import ChainMap import exo.pyparser as pyparser from exo.LoopIR import LoopIR, PAST @@ -59,7 +60,7 @@ def get_match_no(pattern_str: str) -> Optional[int]: def match_pattern( context: Cursor, pattern_str: str, - call_depth=0, + call_depth=1, default_match_no=None, use_sym_id=False, ): @@ -78,12 +79,19 @@ def match_pattern( else: match_no = default_match_no # None means match-all + stack_frames: [inspect.FrameInfo] = inspect.stack() # get source location where this is getting called from - caller = inspect.getframeinfo(inspect.stack()[call_depth + 1][0]) + caller = inspect.getframeinfo(stack_frames[call_depth][0]) + func_locals = ChainMap(stack_frames[call_depth].frame.f_locals) + func_globals = ChainMap(stack_frames[call_depth].frame.f_globals) # parse the pattern we're going to use to match p_ast = pyparser.pattern( - pattern_str, filename=caller.filename, lineno=caller.lineno + pattern_str, + filename=caller.filename, + lineno=caller.lineno, + srclocals=func_locals, + srcglobals=func_globals, ) # do the pattern match, to find the nodes in ast @@ -109,7 +117,7 @@ def match_pattern( PAST.Const: [LoopIR.Const], PAST.USub: [LoopIR.USub], PAST.BinOp: [LoopIR.BinOp], - PAST.BuiltIn: [LoopIR.BuiltIn], + PAST.Extern: [LoopIR.Extern], PAST.ReadConfig: [LoopIR.ReadConfig], PAST.E_Hole: None, } @@ -324,8 +332,8 @@ def match_e(self, pat, e): ) elif isinstance(e, LoopIR.USub): return self.match_e(pat.arg, e.arg) - elif isinstance(e, LoopIR.BuiltIn): - return pat.f is e.f and all( + elif isinstance(e, LoopIR.Extern): + return self.match_name(pat.f, e.f.name()) and all( self.match_e(pa, sa) for pa, sa in zip(pat.args, e.args) ) elif isinstance(e, LoopIR.ReadConfig): @@ -383,7 +391,7 @@ def _children(cur) -> Iterable[Node]: yield from _children_from_attrs(cur, n, "arg") elif isinstance(n, LoopIR.BinOp): yield from _children_from_attrs(cur, n, "lhs", "rhs") - elif isinstance(n, LoopIR.BuiltIn): + elif isinstance(n, LoopIR.Extern): yield from _children_from_attrs(cur, n, "args") else: assert False, f"case {type(n)} unsupported" diff --git a/src/exo/platforms/gemmini.py b/src/exo/platforms/gemmini.py index c91e4a460..5f598817b 100644 --- a/src/exo/platforms/gemmini.py +++ b/src/exo/platforms/gemmini.py @@ -2,6 +2,7 @@ from exo import proc, instr, DRAM, config from exo.libs.memories import GEMM_SCRATCH, GEMM_ACCUM +from exo.libs.externs import select, relu from exo.stdlib.scheduling import * @@ -800,8 +801,10 @@ def clamp(src: f32, dst: i8): h: f32 l = -128.0 h = 127.0 - dst = select(h, src, h, src) - dst = select(src, l, l, dst) + tmp: f32 + tmp = select(h, src, h, src) + tmp = select(src, l, l, tmp) + dst = tmp def new_config_st(): diff --git a/src/exo/platforms/x86.py b/src/exo/platforms/x86.py index 9a7b6bc70..f049e1c49 100644 --- a/src/exo/platforms/x86.py +++ b/src/exo/platforms/x86.py @@ -2,6 +2,7 @@ from .. import instr, DRAM from ..libs.memories import AVX2, AVX512 +from ..libs.externs import relu, select # --------------------------------------------------------------------------- # # Prefetching diff --git a/src/exo/prec_analysis.py b/src/exo/prec_analysis.py index 224173ebc..19c118435 100644 --- a/src/exo/prec_analysis.py +++ b/src/exo/prec_analysis.py @@ -199,6 +199,28 @@ def map_e(self, e): typ = lhs.type return LoopIR.BinOp(e.op, lhs, rhs, typ, e.srcinfo) + elif isinstance(e, LoopIR.Extern): + typ = T.R + for a in e.args: + if a.type != T.R: + typ = a.type + + new_args = [] + for a in e.args: + a = self.apply_e(a) + if typ != a.type: + # coerce if const and real + if a.type == T.R: + a = self.coerce_e(a, typ) + else: + self.err( + e, + f"all extern arguments must have a same type, got {typ} and {a.type}", + ) + new_args.append(a) + + return LoopIR.Extern(e.f, new_args, typ, e.srcinfo) + return super().map_e(e) # this routine allows for us to retro-actively @@ -224,6 +246,18 @@ def coerce_e(self, e, btyp): assert rhs.type == btyp return LoopIR.BinOp(e.op, lhs, rhs, btyp, e.srcinfo) + elif isinstance(e, LoopIR.Extern): + assert e.type == T.R + # coerce if T.R + args = [] + for a in e.args: + if a.type == T.R: + args.append(self.coerce_e(a, btyp)) + else: + assert a.type == btyp + args.append(a) + return LoopIR.Extern(e.f, args, btyp, e.srcinfo) + else: assert False, f"Should not be coercing a {type(e)} Node" diff --git a/src/exo/pyparser.py b/src/exo/pyparser.py index f997543c6..9edbcbb48 100644 --- a/src/exo/pyparser.py +++ b/src/exo/pyparser.py @@ -10,10 +10,10 @@ from asdl_adt.validators import ValidationError from .API_types import ProcedureBase -from .builtins import * from .configs import Config from .LoopIR import UAST, PAST, front_ops from .prelude import * +from .extern import Extern # --------------------------------------------------------------------------- # @@ -90,7 +90,7 @@ def get_src_locals(*, depth): # Pattern-Parser top-level, invoked on strings rather than as a decorator -def pattern(s, filename=None, lineno=None): +def pattern(s, filename=None, lineno=None, srclocals=None, srcglobals=None): assert isinstance(s, str) src = s @@ -119,7 +119,13 @@ def getsrcinfo(node): ), ) - parser = Parser(module.body, getsrcinfo, is_fragment=True) + parser = Parser( + module.body, + getsrcinfo, + is_fragment=True, + func_globals=srcglobals, + srclocals=srclocals, + ) return parser.result() @@ -166,15 +172,18 @@ def __init__( self.is_fragment = is_fragment self.push() + special_cases = ["stride"] + for key, val in self.globals.items(): + if isinstance(val, Extern): + special_cases.append(key) + for key, val in self.locals.items(): + if isinstance(val, Extern): + special_cases.append(key) - builtins = {"sin": sin, "relu": relu, "select": select} if is_fragment: self.AST = PAST else: self.AST = UAST - # add builtins - for key, val in builtins.items(): - self.locals[key] = val if as_func: self._cached_result = self.parse_fdef(module_ast, instr=instr) @@ -184,7 +193,6 @@ def __init__( is_expr = False if len(module_ast) == 1: s = module_ast[0] - special_cases = list(builtins.keys()) + ["stride"] if isinstance(s, pyast.Expr) and ( not isinstance(s.value, pyast.Call) or s.value.func.id in special_cases @@ -1064,21 +1072,30 @@ def parse_expr(self, e): # handle built-in functions else: - f = self.eval_expr(e.func) fname = e.func.id + if self.is_fragment: + if len(e.keywords) > 0: + self.err( + f, "cannot call a extern function " "with keyword arguments" + ) + args = [self.parse_expr(a) for a in e.args] - if not isinstance(f, BuiltIn): - self.err( - e.func, f"expected called object " "to be a builtin function" - ) + return self.AST.Extern(fname, args, self.getsrcinfo(e)) + else: + f = self.eval_expr(e.func) - if len(e.keywords) > 0: - self.err( - f, "cannot call a builtin function " "with keyword arguments" - ) - args = [self.parse_expr(a) for a in e.args] + if not isinstance(f, Extern): + self.err( + e.func, f"expected called object " "to be a extern function" + ) + + if len(e.keywords) > 0: + self.err( + f, "cannot call a extern function " "with keyword arguments" + ) + args = [self.parse_expr(a) for a in e.args] - return self.AST.BuiltIn(f, args, self.getsrcinfo(e)) + return self.AST.Extern(f, args, self.getsrcinfo(e)) else: self.err(e, "unsupported form of expression") diff --git a/src/exo/stdlib/inspection.py b/src/exo/stdlib/inspection.py index 4bbedf04a..cfb98e80a 100644 --- a/src/exo/stdlib/inspection.py +++ b/src/exo/stdlib/inspection.py @@ -24,7 +24,7 @@ def expr_children(expr): elif isinstance(expr, BinaryOpCursor): yield expr.lhs() yield expr.rhs() - elif isinstance(expr, BuiltInFunctionCursor): + elif isinstance(expr, ExternFunctionCursor): yield from expr.args() elif isinstance(expr, (LiteralCursor, ReadConfigCursor)): pass @@ -381,7 +381,7 @@ def is_mod(proc, expr): def is_builtin(proc, expr, name): expr = proc.forward(expr) - return isinstance(expr, BuiltInFunctionCursor) and expr.name() == name + return isinstance(expr, ExternFunctionCursor) and expr.name() == name def is_select(proc, expr): @@ -563,7 +563,7 @@ def expr_list_to_string(expr_list, subst): lhs_str = expr_to_string(expr_cursor.lhs(), subst) rhs_str = expr_to_string(expr_cursor.rhs(), subst) return f"({lhs_str}{binop_str}{rhs_str})" - elif isinstance(expr_cursor, BuiltInFunctionCursor): + elif isinstance(expr_cursor, ExternFunctionCursor): name = expr_cursor.name() args_str = expr_list_to_string(expr_cursor.args(), subst) return f"({name}({args_str[1:-1]}))" diff --git a/src/exo/typecheck.py b/src/exo/typecheck.py index a9581f9b2..b44c39a16 100644 --- a/src/exo/typecheck.py +++ b/src/exo/typecheck.py @@ -6,7 +6,7 @@ get_writeconfigs, get_loop_iters, ) -from .builtins import BuiltIn_Typecheck_Error +from .extern import Extern_Typecheck_Error from .memory import * @@ -555,17 +555,17 @@ def check_e(self, e, is_index=False): return LoopIR.BinOp(e.op, lhs, rhs, typ, e.srcinfo) - elif isinstance(e, UAST.BuiltIn): + elif isinstance(e, UAST.Extern): args = [self.check_e(a) for a in e.args] try: typ = e.f.typecheck(args) - except BuiltIn_Typecheck_Error as err: + except Extern_Typecheck_Error as err: typ = T.err self.err(e, str(err)) - return LoopIR.BuiltIn(e.f, args, typ, e.srcinfo) + return LoopIR.Extern(e.f, args, typ, e.srcinfo) elif isinstance(e, UAST.StrideExpr): idx, typ = self.check_access(e, e.name, [], lvalue=False) diff --git a/tests/golden/asplos25/test_gemmini_matmul_new/test_matmul.txt b/tests/golden/asplos25/test_gemmini_matmul_new/test_matmul.txt index cd100f9dc..b55d9f095 100644 --- a/tests/golden/asplos25/test_gemmini_matmul_new/test_matmul.txt +++ b/tests/golden/asplos25/test_gemmini_matmul_new/test_matmul.txt @@ -99,11 +99,6 @@ void matmul_on_gemmini( c_code_str_Context *ctxt, int_fast32_t N, int_fast32_t M #include "gemm_acc_malloc.h" #include #include "gemm_malloc.h" -double _relu_(double x) { - if (x > 0.0) return x; - else return 0.0; -} - /* relying on the following instruction..." config_ld_i8_id1(src_stride) diff --git a/tests/golden/asplos25/test_gemmini_matmul_old/test_matmul.txt b/tests/golden/asplos25/test_gemmini_matmul_old/test_matmul.txt index c892fcf29..f43b34b48 100644 --- a/tests/golden/asplos25/test_gemmini_matmul_old/test_matmul.txt +++ b/tests/golden/asplos25/test_gemmini_matmul_old/test_matmul.txt @@ -99,11 +99,6 @@ void matmul_on_cpu( c_code_str_Context *ctxt, int_fast32_t N, int_fast32_t M, co #include "gemm_acc_malloc.h" #include #include "gemm_malloc.h" -double _relu_(double x) { - if (x > 0.0) return x; - else return 0.0; -} - /* relying on the following instruction..." config_ld_i8_id1(src_stride) diff --git a/tests/golden/test_apps/test_gemmini_conv.txt b/tests/golden/test_apps/test_gemmini_conv.txt index 0f4ceb63f..c50409bde 100644 --- a/tests/golden/test_apps/test_gemmini_conv.txt +++ b/tests/golden/test_apps/test_gemmini_conv.txt @@ -161,12 +161,12 @@ void conv_3_cpu( test_case_Context *ctxt, int8_t* output, const int32_t* bias, c #include "gemm_acc_malloc.h" #include #include "gemm_malloc.h" -double _relu_(double x) { +int8_t _relu_int8_t(int8_t x) { if (x > 0.0) return x; else return 0.0; } -double _select_(double x, double v, double y, double z) { +float _select_float(float x,float v,float y,float z) { if (x < v) return y; else return z; } @@ -191,8 +191,10 @@ float l; float h; l = -128.0f; h = 127.0f; -*dst = (int8_t)(_select_((double)*&h, (double)*src, (double)*&h, (double)*src)); -*dst = _select_((double)*src, (double)*&l, (double)*&l, (double)*dst); +float tmp; +tmp = _select_float((float)h, (float)*src, (float)h, (float)*src); +tmp = _select_float((float)*src, (float)l, (float)l, (float)tmp); +*dst = (int8_t)(tmp); } @@ -1336,7 +1338,7 @@ for (int_fast32_t b = 0; b < 4; b++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } output[b * 100352 + orow * 3584 + ocol * 128 + och] = tmp_res2; } @@ -1570,7 +1572,7 @@ for (int_fast32_t b = 0; b < 4; b++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } output[b * 50176 + orow * 3584 + ocol * 256 + och] = tmp_res2; } @@ -1617,7 +1619,7 @@ for (int_fast32_t b = 0; b < 4; b++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } output[b * 200704 + orow * 3584 + ocol * 64 + och] = tmp_res2; } diff --git a/tests/golden/test_apps/test_gemmini_matmul.txt b/tests/golden/test_apps/test_gemmini_matmul.txt index 205df4e61..89a5be56e 100644 --- a/tests/golden/test_apps/test_gemmini_matmul.txt +++ b/tests/golden/test_apps/test_gemmini_matmul.txt @@ -213,12 +213,12 @@ void matmul_6( test_case_Context *ctxt, const float* scale, bool act, const int8 #include "gemm_acc_malloc.h" #include #include "gemm_malloc.h" -double _relu_(double x) { +int8_t _relu_int8_t(int8_t x) { if (x > 0.0) return x; else return 0.0; } -double _select_(double x, double v, double y, double z) { +float _select_float(float x,float v,float y,float z) { if (x < v) return y; else return z; } @@ -243,8 +243,10 @@ float l; float h; l = -128.0f; h = 127.0f; -*dst = (int8_t)(_select_((double)*&h, (double)*src, (double)*&h, (double)*src)); -*dst = _select_((double)*src, (double)*&l, (double)*&l, (double)*dst); +float tmp; +tmp = _select_float((float)h, (float)*src, (float)h, (float)*src); +tmp = _select_float((float)*src, (float)l, (float)l, (float)tmp); +*dst = (int8_t)(tmp); } @@ -307,7 +309,7 @@ for (int_fast32_t i = 0; i < 3136; i++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } C[i * 512 + j] = tmp_res2; } @@ -344,7 +346,7 @@ for (int_fast32_t i = 0; i < 3136; i++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } C[i * 128 + j] = tmp_res2; } @@ -381,7 +383,7 @@ for (int_fast32_t i = 0; i < 784; i++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } C[i * 1024 + j] = tmp_res2; } @@ -418,7 +420,7 @@ for (int_fast32_t i = 0; i < 12544; i++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } C[i * 256 + j] = tmp_res2; } @@ -455,7 +457,7 @@ for (int_fast32_t i = 0; i < 512; i++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } C[i * 512 + j] = tmp_res2; } @@ -492,7 +494,7 @@ for (int_fast32_t i = 0; i < 12544; i++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } C[i * 64 + j] = tmp_res2; } diff --git a/tests/golden/test_apps/test_x86_conv.txt b/tests/golden/test_apps/test_x86_conv.txt index a2b815842..7070eca2d 100644 --- a/tests/golden/test_apps/test_x86_conv.txt +++ b/tests/golden/test_apps/test_x86_conv.txt @@ -65,11 +65,6 @@ void conv_specialized( void *ctxt, const float* inp, float* output, const float* #include #include -double _relu_(double x) { - if (x > 0.0) return x; - else return 0.0; -} - // conv_specialized( // inp : f32[5, 82, 102, 128] @DRAM, // output : f32[5, 80, 100, 128] @DRAM, diff --git a/tests/golden/test_externs/test_expf.txt b/tests/golden/test_externs/test_expf.txt new file mode 100644 index 000000000..ebbab2553 --- /dev/null +++ b/tests/golden/test_externs/test_expf.txt @@ -0,0 +1,61 @@ +#include "test.h" + +#include +#include + +#include +// foo( +// x : i8[16] @DRAM, +// y : i8[16] @DRAM +// ) +void foo( void *ctxt, const int8_t* x, int8_t* y ) { +for (int_fast32_t i = 0; i < 16; i++) { + y[i] = expf((int8_t)(x[i] + y[i])); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : i8[16] @DRAM, +// y : i8[16] @DRAM +// ) +void foo( void *ctxt, const int8_t* x, int8_t* y ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_extern_find.txt b/tests/golden/test_externs/test_extern_find.txt new file mode 100644 index 000000000..bc820ef0e --- /dev/null +++ b/tests/golden/test_externs/test_extern_find.txt @@ -0,0 +1,2 @@ +def foo(a: f32 @ DRAM): + a = sin(a) # <-- NODE \ No newline at end of file diff --git a/tests/golden/test_externs/test_fmaxf.txt b/tests/golden/test_externs/test_fmaxf.txt new file mode 100644 index 000000000..af16a7798 --- /dev/null +++ b/tests/golden/test_externs/test_fmaxf.txt @@ -0,0 +1,61 @@ +#include "test.h" + +#include +#include + +#include +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, float* y ) { +for (int_fast32_t i = 0; i < 16; i++) { + y[i] = fmaxf((float)(x[i]), (float)(y[i] * 2.0f)); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, float* y ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_relu.txt b/tests/golden/test_externs/test_relu.txt new file mode 100644 index 000000000..f2fd00d91 --- /dev/null +++ b/tests/golden/test_externs/test_relu.txt @@ -0,0 +1,63 @@ +#include "test.h" + +#include +#include + +float _relu_float(float x) { + if (x > 0.0) return x; + else return 0.0; +} + +// foo( +// x : f32[16] @DRAM +// ) +void foo( void *ctxt, float* x ) { +for (int_fast32_t i = 0; i < 16; i++) { + x[i] = _relu_float((float)3.0f); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : f32[16] @DRAM +// ) +void foo( void *ctxt, float* x ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_relu2.txt b/tests/golden/test_externs/test_relu2.txt new file mode 100644 index 000000000..8d5174c56 --- /dev/null +++ b/tests/golden/test_externs/test_relu2.txt @@ -0,0 +1,63 @@ +#include "test.h" + +#include +#include + +float _relu_float(float x) { + if (x > 0.0) return x; + else return 0.0; +} + +// foo( +// x : f32[16] @DRAM +// ) +void foo( void *ctxt, float* x ) { +for (int_fast32_t i = 0; i < 16; i++) { + x[i] = _relu_float((float)x[i]); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : f32[16] @DRAM +// ) +void foo( void *ctxt, float* x ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_relu3.txt b/tests/golden/test_externs/test_relu3.txt new file mode 100644 index 000000000..d1b294fc3 --- /dev/null +++ b/tests/golden/test_externs/test_relu3.txt @@ -0,0 +1,67 @@ +#include "test.h" + +#include +#include + +float _relu_float(float x) { + if (x > 0.0) return x; + else return 0.0; +} + +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM, +// z : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, const float* y, float* z ) { +for (int_fast32_t i = 0; i < 16; i++) { + z[i] = _relu_float((float)x[i] + y[i]); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM, +// z : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, const float* y, float* z ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_relu4.txt b/tests/golden/test_externs/test_relu4.txt new file mode 100644 index 000000000..e1d141c51 --- /dev/null +++ b/tests/golden/test_externs/test_relu4.txt @@ -0,0 +1,63 @@ +#include "test.h" + +#include +#include + +int8_t _relu_int8_t(int8_t x) { + if (x > 0.0) return x; + else return 0.0; +} + +// foo( +// x : i8[16] @DRAM +// ) +void foo( void *ctxt, int8_t* x ) { +for (int_fast32_t i = 0; i < 16; i++) { + x[i] = _relu_int8_t((int8_t)((int8_t) 3.0)); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : i8[16] @DRAM +// ) +void foo( void *ctxt, int8_t* x ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_select.txt b/tests/golden/test_externs/test_select.txt new file mode 100644 index 000000000..fa71ccbad --- /dev/null +++ b/tests/golden/test_externs/test_select.txt @@ -0,0 +1,67 @@ +#include "test.h" + +#include +#include + +int8_t _select_int8_t(int8_t x,int8_t v,int8_t y,int8_t z) { + if (x < v) return y; + else return z; +} + +// foo( +// x : i8[16] @DRAM, +// y : i8[16] @DRAM, +// z : i8[16] @DRAM +// ) +void foo( void *ctxt, const int8_t* x, const int8_t* y, int8_t* z ) { +for (int_fast32_t i = 0; i < 16; i++) { + z[i] = _select_int8_t((int8_t)x[i] * ((int8_t) 2), (int8_t)y[i], (int8_t)z[i] + y[i], (int8_t)-x[i]); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : i8[16] @DRAM, +// y : i8[16] @DRAM, +// z : i8[16] @DRAM +// ) +void foo( void *ctxt, const int8_t* x, const int8_t* y, int8_t* z ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_sigmoid.txt b/tests/golden/test_externs/test_sigmoid.txt new file mode 100644 index 000000000..bc202a82b --- /dev/null +++ b/tests/golden/test_externs/test_sigmoid.txt @@ -0,0 +1,66 @@ +#include "test.h" + +#include +#include + + +#include +float sigmoid(float x) { + return 1 / (1 + exp(-x)); +} + +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, float* y ) { +for (int_fast32_t i = 0; i < 16; i++) { + y[i] = sigmoid((float)(x[i] + y[i])); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, float* y ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_sin.txt b/tests/golden/test_externs/test_sin.txt new file mode 100644 index 000000000..3c6784c39 --- /dev/null +++ b/tests/golden/test_externs/test_sin.txt @@ -0,0 +1,59 @@ +#include "test.h" + +#include +#include + +#include +// foo( +// x : i8[16] @DRAM +// ) +void foo( void *ctxt, int8_t* x ) { +for (int_fast32_t i = 0; i < 16; i++) { + x[i] = sin((int8_t)x[i] * ((int8_t) 2)); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : i8[16] @DRAM +// ) +void foo( void *ctxt, int8_t* x ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_sqrt.txt b/tests/golden/test_externs/test_sqrt.txt new file mode 100644 index 000000000..d37ce59b5 --- /dev/null +++ b/tests/golden/test_externs/test_sqrt.txt @@ -0,0 +1,61 @@ +#include "test.h" + +#include +#include + +#include +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, float* y ) { +for (int_fast32_t i = 0; i < 16; i++) { + y[i] = sqrt((float)(x[i] + y[i])); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, float* y ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/test_codegen.py b/tests/test_codegen.py index bdc95a3f9..3fe2ab678 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -8,6 +8,7 @@ from exo import proc, instr, Procedure, DRAM, compile_procs_to_strings from exo.libs.memories import MDRAM, MemGenError, StaticMemory, DRAM_STACK +from exo.libs.externs import * from exo.stdlib.scheduling import * mock_registers = 0 diff --git a/tests/test_config.py b/tests/test_config.py index 88447d88c..c30d7fbb9 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,6 +4,7 @@ from exo import proc, DRAM, config, instr from exo.libs.memories import GEMM_SCRATCH +from exo.libs.externs import * from exo.stdlib.scheduling import * diff --git a/tests/test_cursors.py b/tests/test_cursors.py index 672ca0f52..97dee83b1 100644 --- a/tests/test_cursors.py +++ b/tests/test_cursors.py @@ -4,6 +4,7 @@ from exo import proc, ExoType from exo.libs.memories import * +from exo.libs.externs import * from exo.API_cursors import * from exo.stdlib.inspection import * diff --git a/tests/test_externs.py b/tests/test_externs.py new file mode 100644 index 000000000..5b5d1033f --- /dev/null +++ b/tests/test_externs.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import pytest + +from exo import proc, DRAM, Procedure, config, compile_procs_to_strings +from exo.libs.externs import * +from exo.stdlib.scheduling import SchedulingError + + +def test_relu(golden, compiler): + @proc + def foo(x: f32[16]): + for i in seq(0, 16): + x[i] = relu(3.0) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_relu2(golden, compiler): + @proc + def foo(x: f32[16]): + for i in seq(0, 16): + x[i] = relu(x[i]) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_relu3(golden, compiler): + @proc + def foo(x: f32[16], y: f32[16], z: f32[16]): + for i in seq(0, 16): + z[i] = relu(x[i] + y[i]) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_relu4(golden, compiler): + @proc + def foo(x: i8[16]): + for i in seq(0, 16): + x[i] = relu(3.0) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_relu5(): + with pytest.raises(TypeError, match="expected 1 argument, got 2"): + + @proc + def foo(x: i8[16]): + for i in seq(0, 16): + x[i] = relu(3.0, 2.0) + + with pytest.raises( + TypeError, match="expected argument 1 to be a real scalar value," + ): + + @proc + def foo(x: i8[16]): + for i in seq(0, 16): + x[i] = relu(i) + + +def test_sin(golden, compiler): + @proc + def foo(x: i8[16]): + for i in seq(0, 16): + x[i] = sin(x[i] * 2) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_sin2(golden, compiler): + with pytest.raises(TypeError, match="expected 1 argument, got 2"): + + @proc + def foo(x: i8[16]): + for i in seq(0, 16): + x[i] = sin(x[i] * 2, 3) + + with pytest.raises( + TypeError, match="expected argument 1 to be a real scalar value," + ): + + @proc + def foo(x: i8[16]): + for i in seq(0, 16): + x[i] = sin(i) + + +def test_select(golden, compiler): + @proc + def foo(x: i8[16], y: i8[16], z: i8[16]): + for i in seq(0, 16): + z[i] = select(x[i] * 2, y[i], z[i] + y[i], -x[i]) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_expf(golden, compiler): + @proc + def foo(x: i8[16], y: i8[16]): + for i in seq(0, 16): + y[i] = expf(x[i] + y[i]) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_expf2(): + with pytest.raises(TypeError, match="expected 1 argument, got"): + + @proc + def foo(x: i8[16], y: i8[16]): + y[0] = expf(x[0], x[0]) + + with pytest.raises( + TypeError, match="expected argument 1 to be a real scalar value," + ): + + @proc + def foo(x: i8[16], y: i8[16]): + y[0] = expf(True) + + +def test_fmaxf(golden, compiler): + @proc + def foo(x: f32[16], y: f32[16]): + for i in seq(0, 16): + y[i] = fmaxf(x[i], y[i] * 2) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_fmaxf2(): + with pytest.raises(TypeError, match="expected 2 argument, got 1"): + + @proc + def foo(x: f32[16], y: f32[16]): + for i in seq(0, 16): + y[i] = fmaxf(x[i]) + + with pytest.raises( + TypeError, match="expected argument 1 to be a real scalar value," + ): + + @proc + def foo(x: f32[16], y: f32[16]): + for i in seq(0, 16): + y[i] = fmaxf(i, x[i]) + + with pytest.raises( + TypeError, match="expected argument 2 to be a real scalar value," + ): + + @proc + def foo(x: f32[16], y: f32[16]): + for i in seq(0, 16): + y[i] = fmaxf(x[i], i) + + +def test_sigmoid(golden, compiler): + @proc + def foo(x: f32[16], y: f32[16]): + for i in seq(0, 16): + y[i] = sigmoid(x[i] + y[i]) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_sigmoid2(): + with pytest.raises(TypeError, match="expected 1 argument, got"): + + @proc + def foo(x: i8[16], y: i8[16]): + y[0] = sigmoid(x[0], x[0]) + + with pytest.raises( + TypeError, match="expected argument 1 to be a real scalar value," + ): + + @proc + def foo(x: i8[16], y: i8[16]): + y[0] = sigmoid(True) + + +def test_sqrt(golden, compiler): + @proc + def foo(x: f32[16], y: f32[16]): + for i in seq(0, 16): + y[i] = sqrt(x[i] + y[i]) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_sqrt2(): + with pytest.raises(TypeError, match="expected 1 argument, got"): + + @proc + def foo(x: i8[16], y: i8[16]): + y[0] = sqrt(x[0], x[0]) + + with pytest.raises( + TypeError, match="expected argument 1 to be a real scalar value," + ): + + @proc + def foo(x: i8[16], y: i8[16]): + y[0] = sqrt(True) + + +def test_select_error(): + @proc + def foo(x: i8[16], y: f32[16], z: f64[16]): + for i in seq(0, 16): + z[i] = select(x[i] * 2, y[i], z[i], -x[i]) + + with pytest.raises(TypeError, match="all extern arguments must have a same type"): + c_file, h_file = compile_procs_to_strings([foo], "test.h") + + +def test_type_error(): + with pytest.raises(TypeError, match="expected scalar type"): + + @proc + def foo(x: i8[16], y: f32[16], z: f64[16]): + for i in seq(0, 16): + z[i] = select(i * 2, y[i], z[i], -x[i]) + + with pytest.raises(TypeError, match="expected 4 arguments, got 3"): + + @proc + def foo(x: i8[16], y: f32[16], z: f64[16]): + for i in seq(0, 16): + z[i] = select(i * 2, y[i], z[i]) + + +def test_select_fine(): + @proc + def foo(x: i8[16], y: i8[16], z: i8[16]): + for i in seq(0, 16): + z[i] = select(0.0, y[i], z[i], -x[i]) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + + +def test_two(): + c = 2 + + @proc + def foo(a: f32): + a = a + c + + with pytest.raises(SchedulingError, match="find: failed to find matches"): + foo.find("a + c").parent() + + +def test_extern_find(golden): + @proc + def foo(a: f32): + a = sin(a) + + assert golden == str(foo.find("sin(a)").parent()) diff --git a/tests/test_typecheck.py b/tests/test_typecheck.py index 393bc4d74..b3b3aa442 100644 --- a/tests/test_typecheck.py +++ b/tests/test_typecheck.py @@ -5,6 +5,7 @@ from exo import proc, config from exo.libs.memories import GEMM_SCRATCH from exo.pyparser import ParseError +from exo.libs.externs import * # --- Typechecking tests ---