Skip to content

Commit

Permalink
added comptime and reveal_type
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Nov 3, 2024
1 parent d4c1044 commit 0d624e5
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 99 deletions.
3 changes: 3 additions & 0 deletions luisa_lang/hir.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,9 @@ class Ref(TypedNode):
class Value(TypedNode):
pass

class Unit(Value):
def __init__(self) -> None:
super().__init__(UnitType())

class SymbolicConstant(Value):
generic: GenericParameter
Expand Down
78 changes: 1 addition & 77 deletions luisa_lang/lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from luisa_lang.utils import get_full_name, unique_hash
from luisa_lang.math_types import *
from luisa_lang._builtin_decor import _builtin_type, _builtin, _intrinsic_impl
from luisa_lang.lang_builtins import *
import luisa_lang.hir as hir
import luisa_lang.classinfo as classinfo
import luisa_lang.parse as parse
Expand Down Expand Up @@ -224,80 +225,3 @@ def decorator(f):
return decorator


def type_of_opt(value: Any) -> Optional[hir.Type]:
if isinstance(value, hir.Type):
return value
if isinstance(value, type):
return hir.GlobalContext.get().types[value]
return hir.GlobalContext.get().types.get(type(value))


def typeof(value: Any) -> hir.Type:
ty = type_of_opt(value)
if ty is None:
raise TypeError(f"Cannot determine type of {value}")
return ty


_t = hir.SymbolicType(hir.GenericParameter("_T", "luisa_lang.lang"))
_n = hir.SymbolicConstant(hir.GenericParameter(
"_N", "luisa_lang.lang")), typeof(u32)


# @_builtin_type(
# hir.ParametricType(
# "Array", [hir.TypeParameter(_t, bound=[])], hir.ArrayType(_t, _n)
# )
# )
class Array(Generic[_T, _N]):
def __init__(self) -> None:
return _intrinsic_impl()

def __getitem__(self, index: int | u32 | u64) -> _T:
return _intrinsic_impl()

def __setitem__(self, index: int | u32 | u64, value: _T) -> None:
return _intrinsic_impl()

def __len__(self) -> u32 | u64:
return _intrinsic_impl()


# @_builtin_type(
# hir.ParametricType(
# "Buffer", [hir.TypeParameter(_t, bound=[])], hir.OpaqueType("Buffer")
# )
# )
class Buffer(Generic[_T]):
def __getitem__(self, index: int | u32 | u64) -> _T:
return _intrinsic_impl()

def __setitem__(self, index: int | u32 | u64, value: _T) -> None:
return _intrinsic_impl()

def __len__(self) -> u32 | u64:
return _intrinsic_impl()


# @_builtin_type(
# hir.ParametricType(
# "Pointer", [hir.TypeParameter(_t, bound=[])], hir.PointerType(_t)
# )
# )
class Pointer(Generic[_T]):
def __getitem__(self, index: int | i32 | i64 | u32 | u64) -> _T:
return _intrinsic_impl()

def __setitem__(self, index: int | i32 | i64 | u32 | u64, value: _T) -> None:
return _intrinsic_impl()

@property
def value(self) -> _T:
return _intrinsic_impl()

@value.setter
def value(self, value: _T) -> None:
return _intrinsic_impl()


# hir.GlobalContext.get().flush()
98 changes: 95 additions & 3 deletions luisa_lang/lang_builtins.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any, Generic, List, Optional, Sequence, TypeVar, overload
from typing_extensions import TypeAliasType
from luisa_lang.lang import _builtin, _intrinsic_impl
from luisa_lang.lang import *
from luisa_lang.math_types import *
import luisa_lang.hir as hir

_T = TypeVar("_T")

_N = TypeVar("_N")

@_builtin
def dispatch_id() -> uint3:
Expand Down Expand Up @@ -89,8 +91,98 @@ def unroll(range_: Sequence[int]) -> Sequence[int]:


@_builtin
def address_of(a: _T) -> Pointer[_T]:
def address_of(a: _T) -> 'Pointer[_T]':
return _intrinsic_impl()

# class StaticEval:
#


def type_of_opt(value: Any) -> Optional[hir.Type]:
if isinstance(value, hir.Type):
return value
if isinstance(value, type):
return hir.GlobalContext.get().types[value]
return hir.GlobalContext.get().types.get(type(value))


def typeof(value: Any) -> hir.Type:
ty = type_of_opt(value)
if ty is None:
raise TypeError(f"Cannot determine type of {value}")
return ty


_t = hir.SymbolicType(hir.GenericParameter("_T", "luisa_lang.lang"))
_n = hir.SymbolicConstant(hir.GenericParameter(
"_N", "luisa_lang.lang")), typeof(u32)


# @_builtin_type(
# hir.ParametricType(
# "Array", [hir.TypeParameter(_t, bound=[])], hir.ArrayType(_t, _n)
# )
# )
class Array(Generic[_T, _N]):
def __init__(self) -> None:
return _intrinsic_impl()

def __getitem__(self, index: int | u32 | u64) -> _T:
return _intrinsic_impl()

def __setitem__(self, index: int | u32 | u64, value: _T) -> None:
return _intrinsic_impl()

def __len__(self) -> u32 | u64:
return _intrinsic_impl()


# @_builtin_type(
# hir.ParametricType(
# "Buffer", [hir.TypeParameter(_t, bound=[])], hir.OpaqueType("Buffer")
# )
# )
class Buffer(Generic[_T]):
def __getitem__(self, index: int | u32 | u64) -> _T:
return _intrinsic_impl()

def __setitem__(self, index: int | u32 | u64, value: _T) -> None:
return _intrinsic_impl()

def __len__(self) -> u32 | u64:
return _intrinsic_impl()


# @_builtin_type(
# hir.ParametricType(
# "Pointer", [hir.TypeParameter(_t, bound=[])], hir.PointerType(_t)
# )
# )
class Pointer(Generic[_T]):
def __getitem__(self, index: int | i32 | i64 | u32 | u64) -> _T:
return _intrinsic_impl()

def __setitem__(self, index: int | i32 | i64 | u32 | u64, value: _T) -> None:
return _intrinsic_impl()

@property
def value(self) -> _T:
return _intrinsic_impl()

@value.setter
def value(self, value: _T) -> None:
return _intrinsic_impl()


__all__: List[str] = [
'Pointer',
'Buffer',
'Array',
'comptime',
'device_log',
'address_of',
'unroll',
'static_assert',
'type_of_opt',
'typeof',
]
97 changes: 78 additions & 19 deletions luisa_lang/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, overload
import typing
import luisa_lang
from luisa_lang.lang_builtins import comptime
from luisa_lang.utils import get_typevar_constrains_and_bounds, report_error
import luisa_lang.hir as hir
import sys
Expand Down Expand Up @@ -114,6 +115,12 @@ def convert_func_signature(signature: classinfo.MethodType,
return hir.FunctionSignature(type_parser.generic_params, params, return_type), type_parser


SPECIAL_FUNCTIONS: Set[Callable[..., Any]] = {
comptime,
reveal_type,
}


class FuncParser:
name: str
func: object
Expand All @@ -138,7 +145,7 @@ def __init__(self, name: str,
self.signature = signature
self.globalns = globalns
obj_ast, _obj_file = retrieve_ast_and_filename(func)
print(ast.dump(obj_ast))
# print(ast.dump(obj_ast))
assert isinstance(obj_ast, ast.Module), f"{obj_ast} is not a module"
if not isinstance(obj_ast.body[0], ast.FunctionDef):
raise RuntimeError("Function definition expected.")
Expand Down Expand Up @@ -205,6 +212,18 @@ def parse_const(self, const: ast.Constant) -> hir.Value:
report_error(
const, f"unsupported constant type {type(value)}, wrap it in lc.comptime(...) if you intead to use it as a compile-time expression")

def convert_any_to_value(self, a: Any, span: hir.Span | None) -> hir.Value | ComptimeValue:
if not isinstance(a, ComptimeValue):
a = ComptimeValue(a, None)
if a.value in SPECIAL_FUNCTIONS:
return a
if (converted := self.convert_constexpr(a, span)) is not None:
return converted
if is_valid_comptime_value_in_dsl_code(a.value):
return a
report_error(
span, f"unsupported constant type {type(a.value)}, wrap it in lc.comptime(...) if you intead to use it as a compile-time expression")

def parse_name(self, name: ast.Name, maybe_new_var: bool) -> hir.Ref | hir.Value | ComptimeValue:
span = hir.Span.from_ast(name)
var = self.vars.get(name.id)
Expand All @@ -218,13 +237,9 @@ def parse_name(self, name: ast.Name, maybe_new_var: bool) -> hir.Ref | hir.Value
# look up in global namespace
if name.id in self.globalns:
resolved = self.globalns[name.id]
return self.convert_any_to_value(resolved, span)
# assert isinstance(resolved, ComptimeValue), type(resolved)
if not isinstance(resolved, ComptimeValue):
resolved = ComptimeValue(resolved, None)
if (converted := self.convert_constexpr(resolved, span)) is not None:
return converted
if is_valid_comptime_value_in_dsl_code(resolved.value):
return resolved

report_error(name, f"unknown variable {name.id}")

def try_convert_comptime_value(self, value: ComptimeValue, span: hir.Span | None = None) -> hir.Value:
Expand Down Expand Up @@ -346,12 +361,49 @@ def parse_call_impl(self, span: hir.Span | None, f: hir.FunctionLike | hir.Funct
return self.cur_bb().append(hir.Call(resolved_f, args, type=ty, span=span))
raise NotImplementedError() # unreachable

def parse_call(self, expr: ast.Call) -> hir.Value:
def handle_special_functions(self, f: Callable[..., Any], expr: ast.Call) -> hir.Value | ComptimeValue:
match f:
case _ if f == comptime:
if len(expr.args) != 1:
report_error(
expr, f"when used in expressions, lc.comptime function expects exactly one argument")
arg = expr.args[0]
# print(ast.dump(arg))
if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
evaled = self.eval_expr(arg.value)
else:
evaled = self.eval_expr(arg)
# print(evaled)
v = self.convert_any_to_value(evaled, hir.Span.from_ast(expr))
return v
case _ if f == reveal_type:
if len(expr.args) != 1:
report_error(
expr, f"lc.reveal_type expects exactly one argument")
arg = expr.args[0]
cur_bb = self.cur_bb()
cur_bb_len = len(cur_bb.nodes)
value = self.parse_expr(arg)
assert cur_bb is self.cur_bb()
del self.cur_bb().nodes[cur_bb_len:]
unparsed_arg = ast.unparse(arg)
if isinstance(value, ComptimeValue):
print(
f"Type of {unparsed_arg} is ComptimeValue({type(value.value)})")
else:
print(f"Type of {unparsed_arg} is {value.type}")
return hir.Unit()
case _:
raise RuntimeError(f"Unsupported special function {f}")

def parse_call(self, expr: ast.Call) -> hir.Value | ComptimeValue:
func = self.parse_expr(expr.func)

if isinstance(func, hir.Ref):
report_error(expr, f"function expected")
elif isinstance(func, ComptimeValue):
if func.value in SPECIAL_FUNCTIONS:
return self.handle_special_functions(func.value, expr)
func = self.try_convert_comptime_value(
func, hir.Span.from_ast(expr))

Expand Down Expand Up @@ -471,9 +523,10 @@ def parse_expr(self, expr: ast.expr) -> hir.Value | ComptimeValue:
case _:
raise RuntimeError(f"Unsupported expression: {ast.dump(expr)}")

def eval_expr(self, tree: ast.Expression | ast.expr):
def eval_expr(self, tree: str | ast.Expression | ast.expr):
if isinstance(tree, ast.expr):
tree = ast.Expression(tree)
# print(tree)
code_object = compile(tree, "<string>", "eval")
localns = {}
for name, v in self.vars.items():
Expand Down Expand Up @@ -531,18 +584,19 @@ def check_return_type(ty: hir.Type):
report_error(
stmt, f"expected {var.type}, got {value.type}")
else:
if not value.type.is_concrete():
report_error(
stmt, "only concrete type can be assigned, please annotate the variable with type hint")
var.type = value.type
self.cur_bb().append(hir.Assign(var, value, span))
case ast.AnnAssign():
var = self.parse_ref(stmt.target, maybe_new_var=True)
if isinstance(var, hir.Value):
report_error(stmt, f"value cannot be assigned")
elif isinstance(var, hir.Ref):
type_annotation = self.eval_expr(stmt.annotation)
type_hint = classinfo.parse_type_hint(type_annotation)
ty = self.parse_type(type_hint)
assert ty
var.type = ty

type_annotation = self.eval_expr(stmt.annotation)
type_hint = classinfo.parse_type_hint(type_annotation)
ty = self.parse_type(type_hint)
assert ty
var.type = ty

if stmt.value:
value = self.parse_expr(stmt.value)
Expand All @@ -560,14 +614,19 @@ def check_return_type(ty: hir.Type):
value = hir.Load(value)
assert value.type
assert ty
if not var.type.is_concrete():
report_error(
stmt, "only concrete type can be assigned, please annotate the variable with concrete types")
if not hir.is_type_compatible_to(value.type, ty):
report_error(
stmt, f"expected {ty}, got {value.type}")
if not value.type.is_concrete():
value.type = var.type
self.cur_bb().append(hir.Assign(var, value, span))
else:
assert isinstance(var, hir.Var)
case ast.Expression():
self.parse_expr(stmt.body)
case ast.Expr():
self.parse_expr(stmt.value)
case ast.Pass():
return
case _:
Expand Down

0 comments on commit 0d624e5

Please sign in to comment.