Skip to content

Commit

Permalink
thinking about refactor parsiing & type inferencing for better metapr…
Browse files Browse the repository at this point in the history
…ogramming support
  • Loading branch information
shiinamiyuki committed Oct 24, 2024
1 parent 85cdbb1 commit 059f91e
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 13 deletions.
2 changes: 2 additions & 0 deletions luisa_lang/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def mangle_impl(self, obj: Union[hir.Type, hir.FunctionLike]) -> str:
case hir.VectorType(element=element, count=count):
return f"V{count}{self.mangle(element)}"
case hir.Function(name=name, params=params, return_type=ret):
assert ret
name = mangle_name(name)
return f'{name}_' + unique_hash(f"F{name}_{self.mangle(ret)}{''.join(self.mangle(unwrap(p.type)) for p in params)}")
case hir.BuiltinFunction(name=name):
Expand Down Expand Up @@ -186,6 +187,7 @@ def __init__(self, base: CppCodeGen, func: hir.Function) -> None:
self.name = base.mangling.mangle(func)
self.func = func
params = ",".join(self.gen_var(p) for p in func.params)
assert func.return_type
self.signature = f'extern "C" auto {self.name}({params}) -> {base.type_cache.gen(func.return_type)}'
self.body = ScratchBuffer()
self.params = set(p.name for p in func.params)
Expand Down
7 changes: 5 additions & 2 deletions luisa_lang/hir/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def __eq__(self, value: object) -> bool:

def __hash__(self) -> int:
return hash(UnitType)

def __str__(self) -> str:
return "NoneType"


class ScalarType(Type):
Expand Down Expand Up @@ -1066,7 +1069,7 @@ class Function:
name: str
generic_params: Dict[str, GenericParameter]
params: List[Var]
return_type: Type
return_type: Type | None
body: List[Stmt]
builtin: bool
export: bool
Expand All @@ -1079,7 +1082,7 @@ def __init__(
name: str,
generic_params: Dict[str, GenericParameter],
params: List[Var],
return_type: Type,
return_type: Type | None,
body: List[Stmt],
locals: List[Var],
builtin: bool = False,
Expand Down
8 changes: 6 additions & 2 deletions luisa_lang/hir/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def wrapper(inferencer: 'FuncTypeInferencer', node: hir.TypedNode, *args) -> Opt


def is_function_fully_typed(func: hir.Function) -> bool:
if not func.return_type:
return False
for stmt in func.body:
if not is_stmt_fully_typed(stmt):
return False
Expand Down Expand Up @@ -112,7 +114,9 @@ def infer_stmt(self, stmt: hir.Stmt) -> None:
case hir.Return(value=value):
if value:
ty = self.infer_expr(value)
if self.func.return_type != ty:
if not self.func.return_type:
self.func.return_type = ty
elif self.func.return_type != ty:
report_error(
stmt.span,
f"Return type mismatch: expected {self.func.return_type}, got {ty}",
Expand Down Expand Up @@ -247,7 +251,7 @@ def _infer_call_helper(
# traceback.print_exc()
raise hir.TypeInferenceError(
node,
f"Error during instantiating function template {f.name}: {e}")
f"Error during instantiating function template {f.name}: {e}") from e
else:
resolved_f = f.resolve(None)
node.op = hir.Constant(resolved_f)
Expand Down
3 changes: 2 additions & 1 deletion luisa_lang/lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,15 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike:
if is_generic:
mapping = hir.match_func_template_args(func_sig, args)
if len(mapping) != len(func_sig.generic_params):
print(mapping, func_sig.generic_params)
raise hir.TypeInferenceError(
None, "not all type parameters are resolved")
for p in func_sig.generic_params.values():
if p not in mapping:
raise hir.TypeInferenceError(
None, f"type parameter {p} is not resolved")
parsing_ctx.bound_type_vars[p.name] = mapping[p]
# print(f'binding {p.name} = {mapping[p]}')
print(f'binding {p.name} = {mapping[p]}')
func_parser = parse.FuncParser(func_name, f, parsing_ctx, self_type)
func_ir = func_parser.parse_body()
hir.run_inference_on_function(func_ir)
Expand Down
48 changes: 41 additions & 7 deletions luisa_lang/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ def get_access_key() -> Optional[Tuple[AccessKind, List[AccessKey]]]:
else:
raise RuntimeError(
"Associated type not supported by Python")
elif cur is typing.Any:
# generic
if access is None:
return cur, None
else:
raise RuntimeError(
"Associated type not supported by Python")
if chain_idx is None:
if len(self.chain) == 0:
break
Expand All @@ -175,6 +182,7 @@ class ParsingContext:
bound_type_vars: Dict[str, Union[hir.Type, hir.Value]]
type_vars: Dict[typing.TypeVar,
Tuple[hir.GenericParameter, Union[hir.Type, hir.Value]]]
any_cnt: int

def __init__(self, ctx_name: str, globals: Dict[str, Any]):
self.globals = globals
Expand All @@ -183,7 +191,8 @@ def __init__(self, ctx_name: str, globals: Dict[str, Any]):
self.ctx_name = ctx_name
self.type_vars = {}
self.bound_type_vars = {}

self.any_cnt = 0

def __eval_name(self, name: str) -> Optional[Any]:
try:
if name in self.name_eval_cache:
Expand Down Expand Up @@ -264,17 +273,38 @@ def check_is_access(tree: ast.AST) -> bool:
return None
# report_error(tree, f"unsupported access chain {tree}")

def parse_type(self, type_tree: ast.AST, allow_new_typevar: bool = False) -> Optional[Type]:
def parse_type(self, type_tree: ast.AST, is_sig_params: bool = False) -> Optional[Type]:
"""
`is_sig_params` should be set if is parsing function arguments
"""
acess_chain: AccessChain | None = self._parse_access_chain(
type_tree, True)
if acess_chain is None:
return None
# print(acess_chain)
resolved, remaining = acess_chain.resolve()
if remaining is not None:
report_error(type_tree, f"failed to resolve type. {resolved},{remaining}")
if remaining is not None and remaining != []:
report_error(
type_tree, f"failed to resolve type. {resolved},{remaining}")
if isinstance(resolved, hir.Type):
return resolved
if resolved == typing.Any:
if is_sig_params:
# create a new generic parameter
param = hir.GenericParameter(
f"Any#{self.any_cnt}", self.ctx_name, None)
type_var = typing.TypeVar( # type: ignore
f"Any#{self.any_cnt}", bound=Any) # type: ignore
self.any_cnt += 1
if param.name in self.bound_type_vars:
any_ty = self.bound_type_vars[param.name]
assert isinstance(any_ty, hir.Type)
return any_ty
generic_ty = hir.SymbolicType(param)
self.type_vars[type_var] = (param, generic_ty)
return generic_ty
else:
return None
if isinstance(resolved, typing.TypeVar):
# if resolved.__name__ in self.bound_type_vars:
# ty_or_val = self.bound_type_vars[resolved.__name__]
Expand All @@ -289,7 +319,7 @@ def parse_type(self, type_tree: ast.AST, allow_new_typevar: bool = False) -> Opt
else:
report_error(
type_tree, f"expected generic parameter {resolved} to be a type but got a value: {ty_or_val}")
elif allow_new_typevar:
elif is_sig_params:
ty_bound: hir.TypeBound | None = None
# create new type var
constraints, bound = get_typevar_constrains_and_bounds(
Expand All @@ -316,6 +346,7 @@ def parse_type(self, type_tree: ast.AST, allow_new_typevar: bool = False) -> Opt
type_tree, f"undefined type parameter {resolved}. type parameter must be included in the function signature or class definition")
return None


class FuncParser:
p_ctx: ParsingContext
vars: Dict[str, hir.Var]
Expand All @@ -330,6 +361,7 @@ class FuncParser:

def __init__(self, name: str, func: object, p_ctx: ParsingContext, self_type: Optional[Type] = None) -> None:
obj_ast, _obj_file = retrieve_ast_and_filename(func)
print(ast.dump(obj_ast))
assert isinstance(obj_ast, ast.Module), f"{obj_ast} is not a module"
if not isinstance(obj_ast.body[0], ast.FunctionDef):
raise RuntimeError("Function definition expected.")
Expand All @@ -346,7 +378,7 @@ def __init__(self, name: str, func: object, p_ctx: ParsingContext, self_type: Op
self.signature_initialized = True
# print(self.arg_types, "->", self.return_type)

assert self.return_type is not None
# assert self.return_type is not None
generic_params: Dict[str, hir.GenericParameter] = {}
for tv in self.p_ctx.type_vars:
param, _ = self.p_ctx.type_vars[tv]
Expand Down Expand Up @@ -401,9 +433,11 @@ def _init_signature(
self.return_type = self.self_type
else:
if func.returns is None:
self.return_type = None
elif isinstance(func.returns,ast.Constant) and func.returns.value is None:
self.return_type = hir.UnitType()
else:
self.return_type = p_ctx.parse_type(func.returns, True)
self.return_type = p_ctx.parse_type(func.returns, False)

def parse_const(self, const: ast.Constant) -> hir.Value:
span = hir.Span.from_ast(const)
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from setuptools import setup, find_packages

setup(
name="luisa_lang",
name="luisa-python-lang",
description="A New DSL Frontend for LuisaCompute",
version="0.1",
packages=find_packages(),
package_data={"luisa_lang": ["py.typed"]},
Expand Down

0 comments on commit 059f91e

Please sign in to comment.