Skip to content

Commit

Permalink
various fix
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Dec 4, 2024
1 parent baf2020 commit 78733cb
Show file tree
Hide file tree
Showing 6 changed files with 424 additions and 306 deletions.
10 changes: 5 additions & 5 deletions luisa_lang/_builtin_decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _make_func_template(f: Callable[..., Any], func_name: str, func_sig: Optiona
assert isinstance(p, hir.SymbolicType)
implicit_generic_params.add(p.param)

def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike:
def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
type_var_ns: Dict[TypeVar, hir.Type |
hir.ComptimeValue] = foreign_type_var_ns.copy()
mapped_implicit_type_params: Dict[str,
Expand Down Expand Up @@ -127,7 +127,7 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike:
func_name, f, func_sig_instantiated, func_globals, type_var_ns, self_type)
ret = func_parser.parse_body()
ret.inline_hint = props.inline
ret.export = props.export
ret.export = props.export
return ret
params = [v[0] for v in func_sig.args]
is_generic = len(func_sig_converted.generic_params) > 0
Expand Down Expand Up @@ -271,9 +271,9 @@ def volume(self) -> float:
return _dsl_decorator_impl(cls, _ObjKind.STRUCT, {})


def builtin_type(ty: hir.Type, *args, **kwargs) -> Callable[[type[_TT]], _TT]:
def decorator(cls: type[_TT]) -> _TT:
return typing.cast(_TT, _dsl_struct_impl(cls, {}, ir_ty_override=ty))
def builtin_type(ty: hir.Type, *args, **kwargs) -> Callable[[type[_TT]], type[_TT]]:
def decorator(cls: type[_TT]) -> type[_TT]:
return typing.cast(type[_TT], _dsl_struct_impl(cls, {}, ir_ty_override=ty))
return decorator


Expand Down
15 changes: 10 additions & 5 deletions luisa_lang/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,20 +120,20 @@ def mangle_name(name: str) -> str:


class Mangling:
cache: Dict[hir.Type | hir.FunctionLike, str]
cache: Dict[hir.Type | hir.Function, str]

def __init__(self) -> None:
self.cache = {}

def mangle(self, obj: Union[hir.Type, hir.FunctionLike]) -> str:
def mangle(self, obj: Union[hir.Type, hir.Function]) -> str:
if obj in self.cache:
return self.cache[obj]
else:
res = self.mangle_impl(obj)
self.cache[obj] = res
return res

def mangle_impl(self, obj: Union[hir.Type, hir.FunctionLike]) -> str:
def mangle_impl(self, obj: Union[hir.Type, hir.Function]) -> str:

match obj:
case hir.UnitType():
Expand Down Expand Up @@ -266,7 +266,7 @@ def gen_ref(self, ref: hir.Ref) -> str:
case _:
raise NotImplementedError(f"unsupported reference: {ref}")

def gen_func(self, f: hir.FunctionLike) -> str:
def gen_func(self, f: hir.Function) -> str:
if isinstance(f, hir.Function):
return self.base.gen_function(f)
else:
Expand Down Expand Up @@ -346,7 +346,12 @@ def do():
comps = intrin_name.split('.')
gened_args = [self.gen_value_or_ref(
arg) for arg in intrin.args]
if comps[0] == 'cmp':
if comps[0] == 'init':
assert expr.type
ty = self.base.type_cache.gen(expr.type)
self.body.writeln(
f"{ty} v{vid}{{ {','.join(gened_args)} }};")
elif comps[0] == 'cmp':
cmp_dict = {
'__eq__': '==',
'__ne__': '!=',
Expand Down
136 changes: 117 additions & 19 deletions luisa_lang/hir.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@

PATH_PREFIX = "luisa_lang"

FunctionLike = Union["Function"]


# @dataclass
# class FunctionTemplateResolveResult:
# func: Optional[FunctionLike]
# func: Optional[Function]
# matched: bool


Expand All @@ -40,17 +38,18 @@


FunctionTemplateResolvingFunc = Callable[[
FunctionTemplateResolvingArgs], Union[FunctionLike, 'TemplateMatchingError']]
FunctionTemplateResolvingArgs], Union['Function', 'TemplateMatchingError']]


class FuncProperties:
inline: bool | Literal["always"]
inline: bool | Literal["never", "always"]
export: bool
byref: Set[str]

def __init__(self):
self.inline = False
self.export = False
self.byref = set()
self.byref = set()


class FunctionTemplate:
Expand All @@ -63,7 +62,7 @@ class FunctionTemplate:
"""
parsing_func: FunctionTemplateResolvingFunc
__resolved: Dict[Tuple[Tuple[str,
Union['Type', Any]], ...], FunctionLike]
Union['Type', Any]], ...], "Function"]
is_generic: bool
name: str
params: List[str]
Expand All @@ -78,7 +77,7 @@ def __init__(self, name: str, params: List[str], parsing_func: FunctionTemplateR
self.name = name
self.props = None

def resolve(self, args: FunctionTemplateResolvingArgs | None) -> Union[FunctionLike, 'TemplateMatchingError']:
def resolve(self, args: FunctionTemplateResolvingArgs | None) -> Union["Function", 'TemplateMatchingError']:
args = args or []
if not self.is_generic:
key = tuple(args)
Expand All @@ -101,7 +100,7 @@ class DynamicIndex:


class Type(ABC):
methods: Dict[str, Union[FunctionLike]]
methods: Dict[str, Union["Function", FunctionTemplate]]
is_builtin: bool

def __init__(self):
Expand Down Expand Up @@ -132,7 +131,7 @@ def member(self, field: Any) -> Optional['Type']:
return FunctionType(m, None)
return None

def method(self, name: str) -> Optional[FunctionLike | FunctionTemplate]:
def method(self, name: str) -> Optional[Union["Function", FunctionTemplate]]:
m = self.methods.get(name)
if m:
return m
Expand Down Expand Up @@ -738,7 +737,7 @@ def member(self, field) -> Optional['Type']:
raise RuntimeError("member access on uninstantiated BoundType")

@override
def method(self, name) -> Optional[FunctionLike | FunctionTemplate]:
def method(self, name) -> Optional[Union["Function", FunctionTemplate]]:
if self.instantiated is not None:
return self.instantiated.method(name)
else:
Expand Down Expand Up @@ -766,10 +765,10 @@ def __hash__(self) -> int:


class FunctionType(Type):
func_like: FunctionLike | FunctionTemplate
func_like: Union["Function", FunctionTemplate]
bound_object: Optional['Ref']

def __init__(self, func_like: FunctionLike | FunctionTemplate, bound_object: Optional['Ref']) -> None:
def __init__(self, func_like: Union["Function", FunctionTemplate], bound_object: Optional['Ref']) -> None:
super().__init__()
self.func_like = func_like
self.bound_object = bound_object
Expand Down Expand Up @@ -950,6 +949,8 @@ def __eq__(self, value: object) -> bool:

def __hash__(self) -> int:
return hash(self.value)


class TypeValue(Value):
def __init__(self, ty: Type, span: Optional[Span] = None) -> None:
super().__init__(TypeConstructorType(ty), span)
Expand All @@ -958,10 +959,12 @@ def inner_type(self) -> Type:
assert isinstance(self.type, TypeConstructorType)
return self.type.inner


class FunctionValue(Value):
def __init__(self, ty:FunctionType, span: Optional[Span] = None) -> None:
def __init__(self, ty: FunctionType, span: Optional[Span] = None) -> None:
super().__init__(ty, span)


class Alloca(Ref):
"""
A temporary variable
Expand Down Expand Up @@ -1003,14 +1006,14 @@ def __repr__(self) -> str:


class Call(Value):
op: FunctionLike
op: "Function"
"""After type inference, op should be a Value."""

args: List[Value | Ref]

def __init__(
self,
op: FunctionLike,
op: "Function",
args: List[Value | Ref],
type: Type,
span: Optional[Span] = None,
Expand Down Expand Up @@ -1077,7 +1080,7 @@ class Assign(Node):
value: Value

def __init__(self, ref: Ref, value: Value, span: Optional[Span] = None) -> None:
assert not isinstance(value.type, (FunctionType, TypeConstructorType))
assert not isinstance(value.type, (FunctionType, TypeConstructorType))
super().__init__(span)
self.ref = ref
self.value = value
Expand Down Expand Up @@ -1206,7 +1209,7 @@ class Function:
locals: List[Var]
complete: bool
is_method: bool
inline_hint: Literal[True, 'always', 'never'] | None
inline_hint: bool | Literal['always', 'never']

def __init__(
self,
Expand All @@ -1223,7 +1226,7 @@ def __init__(
self.locals = []
self.complete = False
self.is_method = is_method
self.inline_hint = None
self.inline_hint = False


def match_template_args(
Expand Down Expand Up @@ -1408,3 +1411,98 @@ def is_type_compatible_to(ty: Type, target: Type) -> bool:
if isinstance(target, IntType):
return isinstance(ty, GenericIntType)
return False


class FunctionInliner:
mapping: Dict[Ref | Value, Ref | Value]
ret: Value | None

def __init__(self, func: Function, args: List[Value | Ref], body: BasicBlock, span: Optional[Span] = None) -> None:
self.mapping = {}
for param, arg in zip(func.params, args):
self.mapping[param] = arg
assert func.body
self.do_inline(func.body, body)

def do_inline(self, func_body: BasicBlock, body: BasicBlock) -> None:
for node in func_body.nodes:
assert node not in self.mapping

match node:
case Var():
assert node.type
assert node.semantic == ParameterSemantic.BYVAL
self.mapping[node] = Alloca(node.type, node.span)
case Load():
mapped_var = self.mapping[node.ref]
assert isinstance(mapped_var, Ref)
body.append(Load(mapped_var))
case Index():
base = self.mapping.get(node.base)
assert isinstance(base, Value)
index = self.mapping.get(node.index)
assert isinstance(index, Value)
assert node.type
self.mapping[node] = body.append(
Index(base, index, node.type, node.span))
case IndexRef():
base = self.mapping.get(node.base)
index = self.mapping.get(node.index)
assert isinstance(base, Ref)
assert isinstance(index, Value)
assert node.type
self.mapping[node] = body.append(IndexRef(
base, index, node.type, node.span))
case Member():
base = self.mapping.get(node.base)
assert isinstance(base, Value)
assert node.type
self.mapping[node] = body.append(Member(
base, node.field, node.type, node.span))
case MemberRef():
base = self.mapping.get(node.base)
assert isinstance(base, Ref)
assert node.type
self.mapping[node] = body.append(MemberRef(
base, node.field, node.type, node.span))
case Call() as call:
def do():
args: List[Ref | Value] = []
for arg in call.args:
mapped_arg = self.mapping.get(arg)
if mapped_arg is None:
raise ParsingError(node, "unable to inline call")
args.append(mapped_arg)
assert call.type
self.mapping[call] = body.append(
Call(call.op, args, call.type, node.span))
do()
case Intrinsic() as intrin:
def do():
args: List[Ref | Value] = []
for arg in intrin.args:
mapped_arg = self.mapping.get(arg)
if mapped_arg is None:
raise ParsingError(
node, "unable to inline intrinsic")
args.append(mapped_arg)
assert intrin.type
self.mapping[intrin] = body.append(
Intrinsic(intrin.name, args, intrin.type, node.span))
do()
case Return():
if self.ret is not None:
raise ParsingError(node, "multiple return statement")
assert node.value is not None
mapped_value = self.mapping.get(node.value)
if mapped_value is None or isinstance(mapped_value, Ref):
raise ParsingError(node, "unable to inline return")
self.ret = mapped_value
case _:
raise ParsingError(node, "invalid node for inlining")

@staticmethod
def inline(func: Function, args: List[Value | Ref], body: BasicBlock, span: Optional[Span] = None) -> Value:
inliner = FunctionInliner(func, args, body, span)
assert inliner.ret
return inliner.ret
Loading

0 comments on commit 78733cb

Please sign in to comment.