diff --git a/luisa_lang/_utils.py b/luisa_lang/_utils.py index 71c038a..fd88019 100644 --- a/luisa_lang/_utils.py +++ b/luisa_lang/_utils.py @@ -1,4 +1,7 @@ -from typing import Optional, TypeVar +import ast +import textwrap +from typing import Optional, Tuple, TypeVar +import sourceinspect T = TypeVar("T") @@ -7,3 +10,48 @@ def unwrap(opt: Optional[T]) -> T: if opt is None: raise ValueError("unwrap None") return opt + + +def increment_lineno_and_col_offset( + node: ast.AST, lineno: int, col_offset: int +) -> ast.AST: + """ + Increment the line number and end line number of each node in the tree + starting at *node* by *n*. This is useful to "move code" to a different + location in a file. + """ + for child in ast.walk(node): + if "lineno" in child._attributes: + child.lineno = getattr(child, "lineno", 0) + lineno + if ( + "end_lineno" in child._attributes + and (end_lineno := getattr(child, "end_lineno", 0)) is not None + ): + child.end_lineno = end_lineno + lineno + if "col_offset" in child._attributes: + child.col_offset = getattr(child, "col_offset", 0) + col_offset + if ( + "end_col_offset" in child._attributes + and (end_col_offset := getattr(child, "end_col_offset", 0)) is not None + ): + child.end_col_offset = end_col_offset + col_offset + return node + + +def dedent_and_retrieve_indentation(lines: str) -> Tuple[str, int]: + """ + Dedent the lines and return the indentation level of the first line. + """ + if not lines: + return "", 0 + return textwrap.dedent("".join(lines)), len(lines[0]) - len(lines[0].lstrip()) + + +def retrieve_ast_and_filename(f: object) -> Tuple[ast.AST, str]: + source_file = sourceinspect.getsourcefile(f) + if source_file is None: + source_file = "" + source_lines, lineno = sourceinspect.getsourcelines(f) + src, indent = dedent_and_retrieve_indentation(source_lines) + tree = increment_lineno_and_col_offset(ast.parse(src), lineno - 1, indent + 1) + return tree, source_file diff --git a/luisa_lang/hir.py b/luisa_lang/hir.py index a8c9154..40beb4a 100644 --- a/luisa_lang/hir.py +++ b/luisa_lang/hir.py @@ -105,6 +105,9 @@ def __eq__(self, value: object) -> bool: and value.signed == self.signed ) + def __repr__(self) -> str: + return f"IntType({self.bits}, {self.signed})" + class FloatType(ScalarType): bits: int @@ -121,6 +124,9 @@ def align(self) -> int: def __eq__(self, value: object) -> bool: return isinstance(value, FloatType) and value.bits == self.bits + def __repr__(self) -> str: + return f"FloatType({self.bits})" + # INT8: Final[IntType] = IntType(8, True) # INT16: Final[IntType] = IntType(16, True) @@ -171,6 +177,9 @@ def __eq__(self, value: object) -> bool: and value.count == self.count ) + def __repr__(self) -> str: + return f"VectorType({self.element}, {self.count})" + class ArrayType(Type): element: Type @@ -192,6 +201,9 @@ def __eq__(self, value: object) -> bool: and value.element == self.element and value.count == self.count ) + + def __repr__(self) -> str: + return f"ArrayType({self.element}, {self.count})" class StructType(Type): @@ -545,6 +557,7 @@ def __init__(self) -> None: self.types = {} self.functions = {} + class FuncMetadata: pass diff --git a/luisa_lang/lang.py b/luisa_lang/lang.py index 25752ec..3002e1f 100644 --- a/luisa_lang/lang.py +++ b/luisa_lang/lang.py @@ -16,6 +16,8 @@ ) from luisa_lang._math_type_exports import * from luisa_lang._markers import _builtin_type, _builtin, _intrinsic_impl +import luisa_lang.hir as hir +import luisa_lang.parse as parse import ast _T = TypeVar("_T") @@ -29,23 +31,40 @@ class _ObjKind(Enum): KERNEL = auto() -def _dsl_decorator_impl(obj: _T, kind: _ObjKind, attrs: Dict[str, Any]) -> _T: +def _dsl_func_impl(f: _T, kind: _ObjKind, attrs: Dict[str, Any]) -> _T: import sourceinspect - obj_src = sourceinspect.getsource(obj) - obj_ast = ast.parse(obj_src) + from luisa_lang._utils import retrieve_ast_and_filename + import inspect + + assert inspect.isfunction(f), f"{f} is not a function" + obj_ast, obj_file = retrieve_ast_and_filename(f) + assert isinstance(obj_ast, ast.Module), f"{obj_ast} is not a module" + + ctx = hir.GlobalContext.get() + + func_globals: Any = getattr(f, "__globals__", {}) + parsing_ctx = parse.ParsingContext(func_globals) + print(ast.dump(obj_ast)) + func_def = obj_ast.body[0] + if not isinstance(func_def, ast.FunctionDef): + raise RuntimeError("Function definition expected.") + func_parser = parse.FuncParser(func_def, parsing_ctx) + if kind == _ObjKind.FUNC: + + def dummy(*args, **kwargs): + raise RuntimeError("DSL function should only be called in DSL context.") + + return cast(_T, dummy) + else: + return cast(_T, f) + + +def _dsl_decorator_impl(obj: _T, kind: _ObjKind, attrs: Dict[str, Any]) -> _T: + if kind == _ObjKind.STRUCT: return obj elif kind == _ObjKind.FUNC or kind == _ObjKind.KERNEL: - assert callable(obj) - func_globals = getattr(obj, "__globals__", None) - if kind == _ObjKind.FUNC: - - def dummy(*args, **kwargs): - raise RuntimeError("DSL function should only be called in DSL context.") - - return cast(_T, dummy) - else: - return cast(_T, obj) + return _dsl_func_impl(obj, kind, attrs) raise NotImplementedError() @@ -109,16 +128,35 @@ def swap(a: int, b: int): a, b = b, a ``` """ + + def impl(f: _F) -> _F: + return _dsl_decorator_impl(f, _ObjKind.FUNC, kwargs) + if len(args) == 1 and len(kwargs) == 0: f = args[0] - return f + return impl(f) def decorator(f): - return f + return impl(f) return decorator +@_builtin_type +class Array(Generic[_T]): + def __init__(self, size: u32 | u64) -> None: + return _intrinsic_impl() + + def __getitem__(self, index: int | u32 | u64) -> _T: + return _intrinsic_impl() + + def __setitem__(self, index: int | u32 | u64, value: _T) -> None: + return _intrinsic_impl() + + def __len__(self) -> u32 | u64: + return _intrinsic_impl() + + @_builtin_type class Buffer(Generic[_T]): def __getitem__(self, index: int | u32 | u64) -> _T: diff --git a/luisa_lang/math_types.py b/luisa_lang/math_types.py index 9e7884f..899bfeb 100644 --- a/luisa_lang/math_types.py +++ b/luisa_lang/math_types.py @@ -1,6 +1,9 @@ # fmt: off import typing as tp from luisa_lang._markers import _builtin, _builtin_type, _intrinsic_impl +import luisa_lang.hir as _hir +_ctx = _hir.GlobalContext.get() +_ctx.types[bool] = _hir.BoolType() FLOAT_TYPES: tp.Final[tp.List[str]] = ["f32", "f64", "float2", "double2", "float3", "double3", "float4", "double4"] FloatType = tp.Union["f32", "f64", "float2", "double2", "float3", "double3", "float4", "double4"] _F = tp.TypeVar("_F") @@ -95,6 +98,7 @@ def __rpow__(self, _other: tp.Union['f32', float]) -> 'f32': return _intrinsic_ def __ipow__(self, _other: tp.Union['f32', float]) -> 'f32': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['f32', float]) -> 'f32': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['f32', float]) -> 'f32': return _intrinsic_impl() +_ctx.types[f32] = _hir.FloatType(32) @_builtin_type class f64(FloatBuiltin['f64']): @@ -119,6 +123,7 @@ def __rpow__(self, _other: tp.Union['f64', float]) -> 'f64': return _intrinsic_ def __ipow__(self, _other: tp.Union['f64', float]) -> 'f64': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['f64', float]) -> 'f64': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['f64', float]) -> 'f64': return _intrinsic_impl() +_ctx.types[f64] = _hir.FloatType(64) @_builtin_type class i8: @@ -153,6 +158,7 @@ def __ior__(self, _other: tp.Union['i8', int]) -> 'i8': return _intrinsic_impl( def __xor__(self, _other: tp.Union['i8', int]) -> 'i8': return _intrinsic_impl() def __rxor__(self, _other: tp.Union['i8', int]) -> 'i8': return _intrinsic_impl() def __ixor__(self, _other: tp.Union['i8', int]) -> 'i8': return _intrinsic_impl() +_ctx.types[i8] = _hir.IntType(8, True) @_builtin_type class u8: @@ -187,6 +193,7 @@ def __ior__(self, _other: tp.Union['u8', int]) -> 'u8': return _intrinsic_impl( def __xor__(self, _other: tp.Union['u8', int]) -> 'u8': return _intrinsic_impl() def __rxor__(self, _other: tp.Union['u8', int]) -> 'u8': return _intrinsic_impl() def __ixor__(self, _other: tp.Union['u8', int]) -> 'u8': return _intrinsic_impl() +_ctx.types[u8] = _hir.IntType(8, False) @_builtin_type class i16: @@ -221,6 +228,7 @@ def __ior__(self, _other: tp.Union['i16', int]) -> 'i16': return _intrinsic_imp def __xor__(self, _other: tp.Union['i16', int]) -> 'i16': return _intrinsic_impl() def __rxor__(self, _other: tp.Union['i16', int]) -> 'i16': return _intrinsic_impl() def __ixor__(self, _other: tp.Union['i16', int]) -> 'i16': return _intrinsic_impl() +_ctx.types[i16] = _hir.IntType(16, True) @_builtin_type class u16: @@ -255,6 +263,7 @@ def __ior__(self, _other: tp.Union['u16', int]) -> 'u16': return _intrinsic_imp def __xor__(self, _other: tp.Union['u16', int]) -> 'u16': return _intrinsic_impl() def __rxor__(self, _other: tp.Union['u16', int]) -> 'u16': return _intrinsic_impl() def __ixor__(self, _other: tp.Union['u16', int]) -> 'u16': return _intrinsic_impl() +_ctx.types[u16] = _hir.IntType(16, False) @_builtin_type class i32: @@ -289,6 +298,7 @@ def __ior__(self, _other: tp.Union['i32', int]) -> 'i32': return _intrinsic_imp def __xor__(self, _other: tp.Union['i32', int]) -> 'i32': return _intrinsic_impl() def __rxor__(self, _other: tp.Union['i32', int]) -> 'i32': return _intrinsic_impl() def __ixor__(self, _other: tp.Union['i32', int]) -> 'i32': return _intrinsic_impl() +_ctx.types[i32] = _hir.IntType(32, True) @_builtin_type class u32: @@ -323,6 +333,7 @@ def __ior__(self, _other: tp.Union['u32', int]) -> 'u32': return _intrinsic_imp def __xor__(self, _other: tp.Union['u32', int]) -> 'u32': return _intrinsic_impl() def __rxor__(self, _other: tp.Union['u32', int]) -> 'u32': return _intrinsic_impl() def __ixor__(self, _other: tp.Union['u32', int]) -> 'u32': return _intrinsic_impl() +_ctx.types[u32] = _hir.IntType(32, False) @_builtin_type class i64: @@ -357,6 +368,7 @@ def __ior__(self, _other: tp.Union['i64', int]) -> 'i64': return _intrinsic_imp def __xor__(self, _other: tp.Union['i64', int]) -> 'i64': return _intrinsic_impl() def __rxor__(self, _other: tp.Union['i64', int]) -> 'i64': return _intrinsic_impl() def __ixor__(self, _other: tp.Union['i64', int]) -> 'i64': return _intrinsic_impl() +_ctx.types[i64] = _hir.IntType(64, True) @_builtin_type class u64: @@ -391,6 +403,7 @@ def __ior__(self, _other: tp.Union['u64', int]) -> 'u64': return _intrinsic_imp def __xor__(self, _other: tp.Union['u64', int]) -> 'u64': return _intrinsic_impl() def __rxor__(self, _other: tp.Union['u64', int]) -> 'u64': return _intrinsic_impl() def __ixor__(self, _other: tp.Union['u64', int]) -> 'u64': return _intrinsic_impl() +_ctx.types[u64] = _hir.IntType(64, False) @_builtin_type class bool2: @@ -416,6 +429,7 @@ def __rpow__(self, _other: tp.Union['bool2', bool, bool]) -> 'bool2': return _i def __ipow__(self, _other: tp.Union['bool2', bool, bool]) -> 'bool2': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['bool2', bool, bool]) -> 'bool2': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['bool2', bool, bool]) -> 'bool2': return _intrinsic_impl() +_ctx.types[bool2] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[bool]), 2) @_builtin_type class float2: @@ -441,6 +455,7 @@ def __rpow__(self, _other: tp.Union['float2', f32, float]) -> 'float2': return def __ipow__(self, _other: tp.Union['float2', f32, float]) -> 'float2': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['float2', f32, float]) -> 'float2': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['float2', f32, float]) -> 'float2': return _intrinsic_impl() +_ctx.types[float2] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[f32]), 2) @_builtin_type class double2: @@ -466,6 +481,7 @@ def __rpow__(self, _other: tp.Union['double2', f64, float]) -> 'double2': retur def __ipow__(self, _other: tp.Union['double2', f64, float]) -> 'double2': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['double2', f64, float]) -> 'double2': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['double2', f64, float]) -> 'double2': return _intrinsic_impl() +_ctx.types[double2] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[f64]), 2) @_builtin_type class byte2: @@ -491,6 +507,7 @@ def __rpow__(self, _other: tp.Union['byte2', i8, int]) -> 'byte2': return _intr def __ipow__(self, _other: tp.Union['byte2', i8, int]) -> 'byte2': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['byte2', i8, int]) -> 'byte2': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['byte2', i8, int]) -> 'byte2': return _intrinsic_impl() +_ctx.types[byte2] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[i8]), 2) @_builtin_type class ubyte2: @@ -516,6 +533,7 @@ def __rpow__(self, _other: tp.Union['ubyte2', u8, int]) -> 'ubyte2': return _in def __ipow__(self, _other: tp.Union['ubyte2', u8, int]) -> 'ubyte2': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['ubyte2', u8, int]) -> 'ubyte2': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['ubyte2', u8, int]) -> 'ubyte2': return _intrinsic_impl() +_ctx.types[ubyte2] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[u8]), 2) @_builtin_type class short2: @@ -541,6 +559,7 @@ def __rpow__(self, _other: tp.Union['short2', i16, int]) -> 'short2': return _i def __ipow__(self, _other: tp.Union['short2', i16, int]) -> 'short2': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['short2', i16, int]) -> 'short2': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['short2', i16, int]) -> 'short2': return _intrinsic_impl() +_ctx.types[short2] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[i16]), 2) @_builtin_type class ushort2: @@ -566,6 +585,7 @@ def __rpow__(self, _other: tp.Union['ushort2', u16, int]) -> 'ushort2': return def __ipow__(self, _other: tp.Union['ushort2', u16, int]) -> 'ushort2': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['ushort2', u16, int]) -> 'ushort2': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['ushort2', u16, int]) -> 'ushort2': return _intrinsic_impl() +_ctx.types[ushort2] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[u16]), 2) @_builtin_type class int2: @@ -591,6 +611,7 @@ def __rpow__(self, _other: tp.Union['int2', i32, int]) -> 'int2': return _intri def __ipow__(self, _other: tp.Union['int2', i32, int]) -> 'int2': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['int2', i32, int]) -> 'int2': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['int2', i32, int]) -> 'int2': return _intrinsic_impl() +_ctx.types[int2] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[i32]), 2) @_builtin_type class uint2: @@ -616,6 +637,7 @@ def __rpow__(self, _other: tp.Union['uint2', u32, int]) -> 'uint2': return _int def __ipow__(self, _other: tp.Union['uint2', u32, int]) -> 'uint2': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['uint2', u32, int]) -> 'uint2': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['uint2', u32, int]) -> 'uint2': return _intrinsic_impl() +_ctx.types[uint2] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[u32]), 2) @_builtin_type class long2: @@ -641,6 +663,7 @@ def __rpow__(self, _other: tp.Union['long2', i64, int]) -> 'long2': return _int def __ipow__(self, _other: tp.Union['long2', i64, int]) -> 'long2': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['long2', i64, int]) -> 'long2': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['long2', i64, int]) -> 'long2': return _intrinsic_impl() +_ctx.types[long2] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[i64]), 2) @_builtin_type class ulong2: @@ -666,6 +689,7 @@ def __rpow__(self, _other: tp.Union['ulong2', u64, int]) -> 'ulong2': return _i def __ipow__(self, _other: tp.Union['ulong2', u64, int]) -> 'ulong2': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['ulong2', u64, int]) -> 'ulong2': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['ulong2', u64, int]) -> 'ulong2': return _intrinsic_impl() +_ctx.types[ulong2] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[u64]), 2) @_builtin_type class bool3: @@ -692,6 +716,7 @@ def __rpow__(self, _other: tp.Union['bool3', bool, bool]) -> 'bool3': return _i def __ipow__(self, _other: tp.Union['bool3', bool, bool]) -> 'bool3': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['bool3', bool, bool]) -> 'bool3': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['bool3', bool, bool]) -> 'bool3': return _intrinsic_impl() +_ctx.types[bool3] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[bool]), 3) @_builtin_type class float3: @@ -718,6 +743,7 @@ def __rpow__(self, _other: tp.Union['float3', f32, float]) -> 'float3': return def __ipow__(self, _other: tp.Union['float3', f32, float]) -> 'float3': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['float3', f32, float]) -> 'float3': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['float3', f32, float]) -> 'float3': return _intrinsic_impl() +_ctx.types[float3] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[f32]), 3) @_builtin_type class double3: @@ -744,6 +770,7 @@ def __rpow__(self, _other: tp.Union['double3', f64, float]) -> 'double3': retur def __ipow__(self, _other: tp.Union['double3', f64, float]) -> 'double3': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['double3', f64, float]) -> 'double3': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['double3', f64, float]) -> 'double3': return _intrinsic_impl() +_ctx.types[double3] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[f64]), 3) @_builtin_type class byte3: @@ -770,6 +797,7 @@ def __rpow__(self, _other: tp.Union['byte3', i8, int]) -> 'byte3': return _intr def __ipow__(self, _other: tp.Union['byte3', i8, int]) -> 'byte3': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['byte3', i8, int]) -> 'byte3': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['byte3', i8, int]) -> 'byte3': return _intrinsic_impl() +_ctx.types[byte3] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[i8]), 3) @_builtin_type class ubyte3: @@ -796,6 +824,7 @@ def __rpow__(self, _other: tp.Union['ubyte3', u8, int]) -> 'ubyte3': return _in def __ipow__(self, _other: tp.Union['ubyte3', u8, int]) -> 'ubyte3': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['ubyte3', u8, int]) -> 'ubyte3': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['ubyte3', u8, int]) -> 'ubyte3': return _intrinsic_impl() +_ctx.types[ubyte3] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[u8]), 3) @_builtin_type class short3: @@ -822,6 +851,7 @@ def __rpow__(self, _other: tp.Union['short3', i16, int]) -> 'short3': return _i def __ipow__(self, _other: tp.Union['short3', i16, int]) -> 'short3': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['short3', i16, int]) -> 'short3': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['short3', i16, int]) -> 'short3': return _intrinsic_impl() +_ctx.types[short3] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[i16]), 3) @_builtin_type class ushort3: @@ -848,6 +878,7 @@ def __rpow__(self, _other: tp.Union['ushort3', u16, int]) -> 'ushort3': return def __ipow__(self, _other: tp.Union['ushort3', u16, int]) -> 'ushort3': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['ushort3', u16, int]) -> 'ushort3': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['ushort3', u16, int]) -> 'ushort3': return _intrinsic_impl() +_ctx.types[ushort3] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[u16]), 3) @_builtin_type class int3: @@ -874,6 +905,7 @@ def __rpow__(self, _other: tp.Union['int3', i32, int]) -> 'int3': return _intri def __ipow__(self, _other: tp.Union['int3', i32, int]) -> 'int3': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['int3', i32, int]) -> 'int3': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['int3', i32, int]) -> 'int3': return _intrinsic_impl() +_ctx.types[int3] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[i32]), 3) @_builtin_type class uint3: @@ -900,6 +932,7 @@ def __rpow__(self, _other: tp.Union['uint3', u32, int]) -> 'uint3': return _int def __ipow__(self, _other: tp.Union['uint3', u32, int]) -> 'uint3': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['uint3', u32, int]) -> 'uint3': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['uint3', u32, int]) -> 'uint3': return _intrinsic_impl() +_ctx.types[uint3] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[u32]), 3) @_builtin_type class long3: @@ -926,6 +959,7 @@ def __rpow__(self, _other: tp.Union['long3', i64, int]) -> 'long3': return _int def __ipow__(self, _other: tp.Union['long3', i64, int]) -> 'long3': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['long3', i64, int]) -> 'long3': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['long3', i64, int]) -> 'long3': return _intrinsic_impl() +_ctx.types[long3] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[i64]), 3) @_builtin_type class ulong3: @@ -952,6 +986,7 @@ def __rpow__(self, _other: tp.Union['ulong3', u64, int]) -> 'ulong3': return _i def __ipow__(self, _other: tp.Union['ulong3', u64, int]) -> 'ulong3': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['ulong3', u64, int]) -> 'ulong3': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['ulong3', u64, int]) -> 'ulong3': return _intrinsic_impl() +_ctx.types[ulong3] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[u64]), 3) @_builtin_type class bool4: @@ -979,6 +1014,7 @@ def __rpow__(self, _other: tp.Union['bool4', bool, bool]) -> 'bool4': return _i def __ipow__(self, _other: tp.Union['bool4', bool, bool]) -> 'bool4': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['bool4', bool, bool]) -> 'bool4': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['bool4', bool, bool]) -> 'bool4': return _intrinsic_impl() +_ctx.types[bool4] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[bool]), 4) @_builtin_type class float4: @@ -1006,6 +1042,7 @@ def __rpow__(self, _other: tp.Union['float4', f32, float]) -> 'float4': return def __ipow__(self, _other: tp.Union['float4', f32, float]) -> 'float4': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['float4', f32, float]) -> 'float4': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['float4', f32, float]) -> 'float4': return _intrinsic_impl() +_ctx.types[float4] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[f32]), 4) @_builtin_type class double4: @@ -1033,6 +1070,7 @@ def __rpow__(self, _other: tp.Union['double4', f64, float]) -> 'double4': retur def __ipow__(self, _other: tp.Union['double4', f64, float]) -> 'double4': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['double4', f64, float]) -> 'double4': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['double4', f64, float]) -> 'double4': return _intrinsic_impl() +_ctx.types[double4] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[f64]), 4) @_builtin_type class byte4: @@ -1060,6 +1098,7 @@ def __rpow__(self, _other: tp.Union['byte4', i8, int]) -> 'byte4': return _intr def __ipow__(self, _other: tp.Union['byte4', i8, int]) -> 'byte4': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['byte4', i8, int]) -> 'byte4': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['byte4', i8, int]) -> 'byte4': return _intrinsic_impl() +_ctx.types[byte4] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[i8]), 4) @_builtin_type class ubyte4: @@ -1087,6 +1126,7 @@ def __rpow__(self, _other: tp.Union['ubyte4', u8, int]) -> 'ubyte4': return _in def __ipow__(self, _other: tp.Union['ubyte4', u8, int]) -> 'ubyte4': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['ubyte4', u8, int]) -> 'ubyte4': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['ubyte4', u8, int]) -> 'ubyte4': return _intrinsic_impl() +_ctx.types[ubyte4] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[u8]), 4) @_builtin_type class short4: @@ -1114,6 +1154,7 @@ def __rpow__(self, _other: tp.Union['short4', i16, int]) -> 'short4': return _i def __ipow__(self, _other: tp.Union['short4', i16, int]) -> 'short4': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['short4', i16, int]) -> 'short4': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['short4', i16, int]) -> 'short4': return _intrinsic_impl() +_ctx.types[short4] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[i16]), 4) @_builtin_type class ushort4: @@ -1141,6 +1182,7 @@ def __rpow__(self, _other: tp.Union['ushort4', u16, int]) -> 'ushort4': return def __ipow__(self, _other: tp.Union['ushort4', u16, int]) -> 'ushort4': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['ushort4', u16, int]) -> 'ushort4': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['ushort4', u16, int]) -> 'ushort4': return _intrinsic_impl() +_ctx.types[ushort4] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[u16]), 4) @_builtin_type class int4: @@ -1168,6 +1210,7 @@ def __rpow__(self, _other: tp.Union['int4', i32, int]) -> 'int4': return _intri def __ipow__(self, _other: tp.Union['int4', i32, int]) -> 'int4': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['int4', i32, int]) -> 'int4': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['int4', i32, int]) -> 'int4': return _intrinsic_impl() +_ctx.types[int4] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[i32]), 4) @_builtin_type class uint4: @@ -1195,6 +1238,7 @@ def __rpow__(self, _other: tp.Union['uint4', u32, int]) -> 'uint4': return _int def __ipow__(self, _other: tp.Union['uint4', u32, int]) -> 'uint4': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['uint4', u32, int]) -> 'uint4': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['uint4', u32, int]) -> 'uint4': return _intrinsic_impl() +_ctx.types[uint4] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[u32]), 4) @_builtin_type class long4: @@ -1222,6 +1266,7 @@ def __rpow__(self, _other: tp.Union['long4', i64, int]) -> 'long4': return _int def __ipow__(self, _other: tp.Union['long4', i64, int]) -> 'long4': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['long4', i64, int]) -> 'long4': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['long4', i64, int]) -> 'long4': return _intrinsic_impl() +_ctx.types[long4] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[i64]), 4) @_builtin_type class ulong4: @@ -1249,4 +1294,5 @@ def __rpow__(self, _other: tp.Union['ulong4', u64, int]) -> 'ulong4': return _i def __ipow__(self, _other: tp.Union['ulong4', u64, int]) -> 'ulong4': return _intrinsic_impl() def __floordiv__(self, _other: tp.Union['ulong4', u64, int]) -> 'ulong4': return _intrinsic_impl() def __rfloordiv__(self, _other: tp.Union['ulong4', u64, int]) -> 'ulong4': return _intrinsic_impl() +_ctx.types[ulong4] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[u64]), 4) diff --git a/luisa_lang/parse.py b/luisa_lang/parse.py index 56c3095..2312be0 100644 --- a/luisa_lang/parse.py +++ b/luisa_lang/parse.py @@ -75,9 +75,100 @@ class AccessChain: parent: Any chain: List[Tuple[AccessKind, List[AccessKey]]] - def __init__(self, parent: Any): + def __init__( + self, + parent: Any, + chain: Optional[List[Tuple[AccessKind, List[AccessKey]]]] = None, + ): self.parent = parent - self.chain = [] + if chain is None: + self.chain = [] + else: + self.chain = chain + + def __repr__(self) -> str: + return f"AccessChain({self.parent}, {self.chain})" + + def resolve(self) -> Tuple[Any, Optional["AccessChain"]]: + """ + Attempt to resolve the access chain. + if the chain is fully resolved, return the final object and None. + otherwise, return most resolved object and the remaining chain. + """ + cur = self.parent + chain_idx = None + + def eval_keys( + cur: Any, access: Tuple[AccessKind, List[AccessKey]] + ) -> Tuple[bool, Any]: + kind, keys = access + if kind == AccessKind.ATTRIBUTE: + assert len(keys) == 1, f"expected single key" + assert isinstance(keys[0], str), f"expected string key" + return True, getattr(cur, keys[0]) + # kind == AccessKind.SUBSCRIPT + evaled_keys: List[Any] = [] + for key in keys: + if isinstance(key, str): + evaled_keys.append(key) + elif isinstance(key, ast.AST): + raise NotImplementedError("TODO: eval ast key") + else: + assert isinstance(key, AccessChain), f"expected AccessChain key" + resolved, remaining = key.resolve() + if remaining is not None: + return False, None + evaled_keys.append(resolved) + try: + return True, cur[tuple(evaled_keys)] + except KeyError: + return False, None + + def get_access_key() -> Optional[Tuple[AccessKind, List[AccessKey]]]: + if chain_idx is None: + if len(self.chain) == 0: + return None + return self.chain[0] + if chain_idx + 1 >= len(self.chain): + return None + return self.chain[chain_idx + 1] + + ctx = hir.GlobalContext.get() + while True: + access = get_access_key() + + # check type of cur to determine what to do + ## is cur a module? + if type(cur) == ModuleType: + if access is None: + break + success, cur = eval_keys(cur, access) + if not success: + break + # is cur a type? + elif type(cur) == type: + hir_ty = ctx.types.get(cur) + if hir_ty is None: + if access is None: + break + success, cur = eval_keys(cur, access) + if not success: + break + else: + if access is None: + return hir_ty, None + raise NotImplementedError() + if chain_idx is None: + if len(self.chain) == 0: + break + chain_idx = 0 + else: + chain_idx += 1 + if chain_idx >= len(self.chain): + break + if chain_idx is None: + return cur, None + return cur, AccessChain(cur, self.chain[chain_idx:]) class ParsingContext: @@ -85,9 +176,9 @@ class ParsingContext: global_ctx: hir.GlobalContext name_eval_cache: Dict[str, Optional[Any]] - def __init__(self, globals: Dict[str, Any], global_ctx: hir.GlobalContext): + def __init__(self, globals: Dict[str, Any]): self.globals = globals - self.global_ctx = global_ctx + self.global_ctx = hir.GlobalContext.get() self.name_eval_cache = {} def __eval_name(self, name: str) -> Optional[Any]: @@ -155,16 +246,14 @@ def parse_type(self, type_tree: ast.AST) -> Optional[Type]: acess_chain = self.__resolve_access_chain(type_tree) if acess_chain is None: return None - cur: Any = None - chain_idx = None - while True: - if isinstance(cur, type): - if cur in self.global_ctx.types: - if chain_idx is None or chain_idx >= len(acess_chain.chain): - return self.global_ctx.types[cur] - return self.global_ctx.types[cur] - break - raise NotImplementedError("TODO: parse type") + print(acess_chain) + resolved, remaining = acess_chain.resolve() + print(resolved, remaining) + if remaining is not None: + report_error(type_tree, f"failed to resolve type") + if isinstance(resolved, hir.Type): + return resolved + return None class FuncParser: @@ -181,6 +270,7 @@ def __init__(self, func: ast.FunctionDef, p_ctx: ParsingContext) -> None: self.arg_types = [] self.return_type = None self._init_signature() + print(self.arg_types, "->", self.return_type) def _init_signature( self, @@ -193,12 +283,11 @@ def _init_signature( report_error(args.vararg, f"vararg not supported") if args.kwarg is not None: report_error(args.kwarg, f"kwarg not supported") - arg_types: List[Type] = [] for arg in args.args: if arg.annotation is None: raise RuntimeError("TODO: infer type") if (arg_ty := p_ctx.parse_type(arg.annotation)) is not None: - arg_types.append(arg_ty) + self.arg_types.append(arg_ty) if func.returns is None: self.return_type = hir.UnitType() else: diff --git a/scripts/gen_math_types.py b/scripts/gen_math_types.py index 5efda00..aaad7a6 100644 --- a/scripts/gen_math_types.py +++ b/scripts/gen_math_types.py @@ -53,6 +53,9 @@ def print(s: str) -> None: print( "from luisa_lang._markers import _builtin, _builtin_type, _intrinsic_impl", ) + print("import luisa_lang.hir as _hir") + print("_ctx = _hir.GlobalContext.get()") + print("_ctx.types[bool] = _hir.BoolType()") def gen_float_builtins(): print("class FloatBuiltin(tp.Generic[_F]):") @@ -120,6 +123,12 @@ def __init__(self, _value: tp.Union['{ty}', {literal_ty}]) -> None: return _intr """ ) gen_common_binop(ty, f" tp.Union['{ty}', {literal_ty}]", kind) + if kind == Kind.FLOAT: + bits = 32 if ty == "f32" else 64 + print(f"_ctx.types[{ty}] = _hir.FloatType({bits})") + else: + signed = ty[0] == "i" + print(f"_ctx.types[{ty}] = _hir.IntType({int(ty[1:])}, {signed})") print("") def gen_vector_type(ty: str, scalar_ty: str, literal_scalar_ty: str, size: int): @@ -136,6 +145,7 @@ class {ty}: gen_common_binop( ty, f" tp.Union['{ty}', {scalar_ty}, {literal_scalar_ty}]", Kind.FLOAT ) + print(f"_ctx.types[{ty}] = _hir.VectorType(tp.cast(_hir.ScalarType, _ctx.types[{scalar_ty}]), {size})") print("") float_types = ["f32", "f64"]