From d92e550c3d40335ff2356145501e70c273bc3fc7 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Tue, 5 Nov 2024 02:14:09 -0500 Subject: [PATCH] fixed loop codegen --- README.md | 56 +++++++++++++++++++++++++++++++++++++-- luisa_lang/codegen/cpp.py | 53 ++++++++++++++++++++++++------------ luisa_lang/hir.py | 16 ++++++----- luisa_lang/parse.py | 52 ++++++++++++++++++++++++++++-------- 4 files changed, 140 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 3c4e5ea..5d1e99a 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,12 @@ A new Python DSL frontend for LuisaCompute. Will be integrated into LuisaCompute ## Content - [Introduction](#introduction) - [Basics](#basic-syntax) + - [Difference from Python](#difference-from-python) - [Types](#types) + - [Value & Reference Semantics](#value--reference-semantics) - [Functions](#functions) - [User-defined Structs](#user-defined-structs) + - [Control Flow](#control-flow) - [Advanced Usage](#advanced-syntax) - [Generics](#generics) - [Metaprogramming](#metaprogramming) @@ -20,11 +23,16 @@ A new Python DSL frontend for LuisaCompute. Will be integrated into LuisaCompute import luisa_lang as lc ``` ## Basic Syntax +### Difference from Python +There are some notable differences between luisa_lang and Python: +- Variables have value semantics by default. Use `inout` to indicate that an argument is passed by reference. +- Generic functions and structs are implemented via monomorphization (a.k.a. instantiation) at compile time rather than via type erasure. +- Overloading the subscript operator and attribute access is different from Python. Only `__getitem__` and `__getattr__` are needed, and each returns a local reference. + ### Types ```python ``` - ### Functions Functions are defined using the `@lc.func` decorator. The function body can contain any valid LuisaCompute code. You can also include normal Python code that will be executed at DSL comile time using `lc.comptime()`. 
(See [Metaprogramming](#metaprogramming) for more details) @@ -37,13 +45,57 @@ def add(a: lc.float, b: lc.float) -> lc.float: ``` -LuisaCompute uses value semantics, which means that all types are passed by value. You can use `inout` to indicate that a variable can be modified in place. + +### Value & Reference Semantics +Variables have value semantics by default. This means that when you assign one variable to another, a copy is made. +```python +a = lc.float3(1.0, 2.0, 3.0) +b = a +a.x = 2.0 +lc.print(f'{a.x} {b.x}') # prints 2.0 1.0 +``` + +You can use `inout` to indicate that a variable is passed as a *local reference*. Assigning to an `inout` variable will update the original variable. ```python @luisa.func(a=inout, b=inout) def swap(a: int, b: int): a, b = b, a + +a = lc.float3(1.0, 2.0, 3.0) +b = lc.float3(4.0, 5.0, 6.0) +swap(a.x, b.x) +lc.print(f'{a.x} {b.x}') # prints 4.0 1.0 ``` +When overloading the subscript operator or attribute access, you actually return a local reference to the object. + +#### Local References +Local references are like pointers in C++. However, they cannot escape the expression boundary. This means that you cannot store a local reference in a variable and use it later. While you can return a local reference from a function, it must be returned from a uniform path. That is, you cannot return different local references based on a condition. + + +```python +@lc.struct +class InfiniteArray: + def __getitem__(self, index: int) -> int: + return self.data[index] # returns a local reference + + # this method will be ignored by the compiler. 
but you can still put it here for linting + def __setitem__(self, index: int, value: int): + pass + + # Not allowed, non-uniform return + def __getitem__(self, index: int) -> int: + if index == 0: + return self.data[0] + else: + return self.data[1] + +``` + + + + + ### User-defined Structs ```python @lc.struct diff --git a/luisa_lang/codegen/cpp.py b/luisa_lang/codegen/cpp.py index 1956fa7..0c3b0b7 100644 --- a/luisa_lang/codegen/cpp.py +++ b/luisa_lang/codegen/cpp.py @@ -133,6 +133,8 @@ def mangle_impl(self, obj: Union[hir.Type, hir.FunctionLike]) -> str: case hir.Function(name=name, params=params, return_type=ret): assert ret name = mangle_name(name) + params = list(filter(lambda p: not isinstance( + p.type, (hir.FunctionType)), params)) return f'{name}_' + unique_hash(f"F{name}_{self.mangle(ret)}{''.join(self.mangle(unwrap(p.type)) for p in params)}") case hir.BuiltinFunction(name=name): name = map_builtin_to_cpp_func(name) @@ -203,7 +205,8 @@ def __init__(self, base: CppCodeGen, func: hir.Function) -> None: self.base = base self.name = base.mangling.mangle(func) self.func = func - params = ",".join(self.gen_var(p) for p in func.params) + params = ",".join(self.gen_var( + p) for p in func.params if not isinstance(p.type, hir.FunctionType)) assert func.return_type self.signature = f'extern "C" auto {self.name}({params}) -> {base.type_cache.gen(func.return_type)}' self.body = ScratchBuffer() @@ -250,6 +253,8 @@ def gen_value_or_ref(self, value: hir.Value | hir.Ref) -> str: f"unsupported value or reference: {value}") def gen_expr(self, expr: hir.Value) -> str: + if expr.type and isinstance(expr.type, hir.FunctionType): + return '' if expr in self.node_map: return self.node_map[expr] vid = self.new_vid() @@ -269,8 +274,10 @@ def impl() -> None: f"const auto v{vid} = {base}.{member.field};") case hir.Call() as call: op = self.gen_func(call.op) + args_s = ','.join(self.gen_value_or_ref( + arg) for arg in call.args if not isinstance(arg.type, hir.FunctionType)) 
self.body.writeln( - f"auto v{vid} ={op}({','.join(self.gen_value_or_ref(arg) for arg in call.args)});") + f"auto v{vid} ={op}({args_s});") case hir.Constant() as constant: value = constant.value if isinstance(value, int): @@ -302,6 +309,7 @@ def impl() -> None: return f'v{vid}' def gen_node(self, node: hir.Node): + match node: case hir.Return() as ret: if ret.value: @@ -324,31 +332,42 @@ def gen_node(self, node: hir.Node): self.gen_bb(if_stmt.else_body) self.body.indent -= 1 self.gen_bb(if_stmt.merge) + case hir.Break(): + self.body.writeln("__loop_break = true; break;") + case hir.Continue(): + self.body.writeln("break;") case hir.Loop() as loop: - vid = self.new_vid() - self.body.write(f"auto loop{vid}_prepare = [&]()->bool {{") + """ + while(true) { + bool loop_break = false; + prepare(); + if (!cond()) break; + do { + // break => { loop_break = true; break; } + // continue => { break; } + } while(false); + if (loop_break) break; + update(); + } + + """ + self.body.writeln("while(true) {") self.body.indent += 1 + self.body.writeln("bool __loop_break = false;") self.gen_bb(loop.prepare) if loop.cond: - self.body.writeln(f"return {self.gen_expr(loop.cond)};") - else: - self.body.writeln("return true;") - self.body.indent -= 1 - self.body.writeln("};") - self.body.writeln(f"auto loop{vid}_body = [&]() {{") + cond = self.gen_expr(loop.cond) + self.body.writeln(f"if (!{cond}) break;") + self.body.writeln("do {") self.body.indent += 1 self.gen_bb(loop.body) self.body.indent -= 1 - self.body.writeln("};") - self.body.writeln(f"auto loop{vid}_update = [&]() {{") - self.body.indent += 1 + self.body.writeln("} while(false);") + self.body.writeln("if (__loop_break) break;") if loop.update: self.gen_bb(loop.update) self.body.indent -= 1 - self.body.writeln("};") - self.body.writeln( - f"for(;loop{vid}_prepare();loop{vid}_update());") - self.gen_bb(loop.merge) + self.body.writeln("}") case hir.Alloca() as alloca: vid = self.new_vid() assert alloca.type diff --git 
a/luisa_lang/hir.py b/luisa_lang/hir.py index b2ffbd3..ee4e834 100644 --- a/luisa_lang/hir.py +++ b/luisa_lang/hir.py @@ -127,7 +127,7 @@ def method(self, name: str) -> Optional[FunctionLike | FunctionTemplate]: def is_concrete(self) -> bool: return True - + def __len__(self) -> int: return 1 @@ -341,7 +341,8 @@ def member(self, field: Any) -> Optional['Type']: def __len__(self) -> int: return self.count - + + class ArrayType(Type): element: Type count: Union[int, "SymbolicConstant"] @@ -868,6 +869,7 @@ def __init__(self, args: List[Value], type: Type, span: Optional[Span] = None) - super().__init__(type, span) self.args = args + class Call(Value): op: FunctionLike """After type inference, op should be a Value.""" @@ -988,17 +990,17 @@ def __init__( class Break(Terminator): - target: Loop + target: Loop | None - def __init__(self, target: Loop, span: Optional[Span] = None) -> None: + def __init__(self, target: Loop | None, span: Optional[Span] = None) -> None: super().__init__(span) self.target = target class Continue(Terminator): - target: Loop + target: Loop | None - def __init__(self, target: Loop, span: Optional[Span] = None) -> None: + def __init__(self, target: Loop | None, span: Optional[Span] = None) -> None: super().__init__(span) self.target = target @@ -1057,7 +1059,7 @@ def update(self, value: Any) -> None: self.update_func(value) else: raise RuntimeError("unable to update comptime value") - + def __str__(self) -> str: return f"ComptimeValue({self.value})" diff --git a/luisa_lang/parse.py b/luisa_lang/parse.py index c2b5674..d9e5c66 100644 --- a/luisa_lang/parse.py +++ b/luisa_lang/parse.py @@ -151,6 +151,7 @@ class FuncParser: type_var_ns: Dict[typing.TypeVar, hir.Type | ComptimeValue] bb_stack: List[hir.BasicBlock] type_parser: TypeParser + break_and_continues: List[hir.Break | hir.Continue] | None def __init__(self, name: str, func: object, @@ -173,7 +174,7 @@ def __init__(self, name: str, self.parsed_func = hir.Function(name, [], None) 
self.type_var_ns = type_var_ns self.bb_stack = [] - + self.break_and_continues = None self.parsed_func.params = signature.params for p in self.parsed_func.params: self.vars[p.name] = p @@ -262,11 +263,12 @@ def parse_name(self, name: ast.Name, new_var_hint: NewVarHint) -> hir.Ref | hir. if name.id in self.globalns: resolved = self.globalns[name.id] return self.convert_any_to_value(resolved, span) - elif name.id in __builtins__: # type: ignore + elif name.id in __builtins__: # type: ignore resolved = __builtins__[name.id] # type: ignore return self.convert_any_to_value(resolved, span) elif new_var_hint == 'comptime': self.globalns[name.id] = None + def update_fn(value: Any) -> None: self.globalns[name.id] = value return ComptimeValue(None, update_fn) @@ -379,7 +381,9 @@ def parse_call_impl(self, span: hir.Span | None, f: hir.FunctionLike | hir.Funct span, f"Expected {len(template_params)} arguments, got {len(args)}") for i, (param, arg) in enumerate(zip(template_params, args)): - assert arg.type is not None + if arg.type is None: + raise hir.TypeInferenceError( + span, f"failed to infer type of argument {i}") template_resolve_args.append((param, arg.type)) resolved_f = f.resolve(template_resolve_args) if isinstance(resolved_f, hir.TemplateMatchingError): @@ -467,6 +471,7 @@ def handle_range() -> hir.Value | ComptimeValue: args[i] = self.try_convert_comptime_value( arg, hir.Span.from_ast(expr.args[i])) converted_args = cast(List[hir.Value], args) + def make_int(i: int) -> hir.Value: return hir.Constant(i, type=hir.GenericIntType()) if len(args) == 1: @@ -516,10 +521,12 @@ def collect_args() -> List[hir.Value | hir.Ref]: raise hir.ParsingError(expr, call.message) assert isinstance(call, hir.Call) return self.cur_bb().append(hir.Load(tmp)) - - if not isinstance(func, hir.Constant) or not isinstance(func.value, (hir.Function, hir.BuiltinFunction, hir.FunctionTemplate)): + if func.type is not None and isinstance(func.type, hir.FunctionType): + func_like = 
func.type.func_like + elif not isinstance(func, hir.Constant) or not isinstance(func.value, (hir.Function, hir.BuiltinFunction, hir.FunctionTemplate)): raise hir.ParsingError(expr, f"function expected") - func_like = func.value + else: + func_like = func.value ret = self.parse_call_impl( hir.Span.from_ast(expr), func_like, collect_args()) if isinstance(ret, hir.TemplateMatchingError): @@ -791,13 +798,19 @@ def parse_stmt(self, stmt: ast.stmt) -> None: stmt, "while loop condition must not be a comptime value") body = hir.BasicBlock(span) self.bb_stack.append(body) + old_break_and_continues = self.break_and_continues + self.break_and_continues = [] for s in stmt.body: self.parse_stmt(s) + break_and_continues = self.break_and_continues + self.break_and_continues = old_break_and_continues body = self.bb_stack.pop() update = hir.BasicBlock(span) merge = hir.BasicBlock(span) - pred_bb.append( - hir.Loop(prepare, cond, body, update, merge, span)) + loop_node = hir.Loop(prepare, cond, body, update, merge, span) + pred_bb.append(loop_node) + for bc in break_and_continues: + bc.target = loop_node self.bb_stack.append(merge) case ast.For(): iter_val = self.parse_expr(stmt.iter) @@ -828,12 +841,16 @@ def parse_stmt(self, stmt: ast.stmt) -> None: self.bb_stack.pop() body = hir.BasicBlock(span) self.bb_stack.append(body) + old_break_and_continues = self.break_and_continues + self.break_and_continues = [] for s in stmt.body: self.parse_stmt(s) body = self.bb_stack.pop() + break_and_continues = self.break_and_continues + self.break_and_continues = old_break_and_continues update = hir.BasicBlock(span) self.bb_stack.append(update) - inc =loop_range.step + inc = loop_range.step int_add = loop_var.type.method("__add__") assert int_add is not None add = self.parse_call_impl( @@ -842,9 +859,22 @@ def parse_stmt(self, stmt: ast.stmt) -> None: self.cur_bb().append(hir.Assign(loop_var, add)) self.bb_stack.pop() merge = hir.BasicBlock(span) - pred_bb.append( - hir.Loop(prepare, cmp_result, 
body, update, merge, span)) + loop_node = hir.Loop(prepare, cmp_result, + body, update, merge, span) + pred_bb.append(loop_node) + for bc in break_and_continues: + bc.target = loop_node self.bb_stack.append(merge) + case ast.Break(): + if self.break_and_continues is None: + raise hir.ParsingError( + stmt, "break statement must be inside a loop") + self.cur_bb().append(hir.Break(None, span)) + case ast.Continue(): + if self.break_and_continues is None: + raise hir.ParsingError( + stmt, "continue statement must be inside a loop") + self.cur_bb().append(hir.Continue(None, span)) case ast.Return(): def check_return_type(ty: hir.Type) -> None: assert self.parsed_func