Skip to content

Commit

Permalink
added if and while
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Nov 4, 2024
1 parent 0d624e5 commit da68f83
Show file tree
Hide file tree
Showing 4 changed files with 329 additions and 129 deletions.
50 changes: 43 additions & 7 deletions luisa_lang/hir.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
# matched: bool


FunctionTemplateResolvingArgs = List[Tuple[str, Union['Type', 'ComptimeValue']]]
FunctionTemplateResolvingArgs = List[Tuple[str,
Union['Type', 'ComptimeValue']]]
"""
[Function parameter name, Type or Value].
The reason for using parameter name instead of GenericParameter is that python supports passing type[T] as a parameter,
Expand Down Expand Up @@ -727,13 +728,24 @@ class Ref(TypedNode):
pass


class LocalRef(Ref):
value: 'Value'

def __init__(self, value: 'Value') -> None:
super().__init__(value.type)
self.value = value
self.span = value.span


class Value(TypedNode):
pass


class Unit(Value):
def __init__(self) -> None:
super().__init__(UnitType())


class SymbolicConstant(Value):
generic: GenericParameter

Expand Down Expand Up @@ -884,20 +896,32 @@ def __str__(self) -> str:
return f"Template matching error at {self.span}:\n\t{self.message}"


class TypeInferenceError(Exception):
class SpannedError(Exception):
span: Span | None
message: str

def __init__(self, node: Node | Span | None, message: str) -> None:
def __init__(self, node: Node | Span | ast.AST | None, message: str) -> None:
if node is not None:
if isinstance(node, Node):
self.span = node.span
else:
self.span = node
match node:
case Node():
self.span = node.span
case Span():
self.span = node
case ast.AST():
self.span = Span.from_ast(node)
else:
self.span = None
self.message = message


class ParsingError(SpannedError):
def __str__(self) -> str:
if self.span is None:
return f"Parsing error:\n\t{self.message}"
return f"Parsing error at {self.span}:\n\t{self.message}"


class TypeInferenceError(SpannedError):
def __str__(self) -> str:
if self.span is None:
return f"Type inference error:\n\t{self.message}"
Expand Down Expand Up @@ -998,6 +1022,18 @@ def __init__(self, value: Optional[Value], span: Optional[Span] = None) -> None:
self.value = value


class Range(Value):
start: Value
step: Optional[Value]
stop: Optional[Value]

def __init__(self, start: Value, stop: Optional[Value] = None, step: Optional[Value] = None, span: Optional[Span] = None) -> None:
super().__init__(None, span)
self.start = start
self.stop = stop
self.step = step


class ComptimeValue:
value: Any
update_func: Optional[Callable[[Any], None]]
Expand Down
37 changes: 25 additions & 12 deletions luisa_lang/lang.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
List,
Optional,
Sequence,
Set,
Tuple,
TypeAlias,
TypeVar,
Expand Down Expand Up @@ -50,34 +51,48 @@ def _make_func_template(f: Callable[..., Any], func_name: str, func_globals: Dic

func_sig = classinfo.parse_func_signature(f, func_globals, [])
func_sig_converted, sig_parser = parse.convert_func_signature(
func_sig, func_name, func_globals, {}, [], self_type)
func_sig, func_name, func_globals, {}, {}, self_type)
implicit_type_params = sig_parser.implicit_type_params
implicit_generic_params: Set[hir.GenericParameter] = set()
for p in implicit_type_params.values():
assert isinstance(p, hir.SymbolicType)
implicit_generic_params.add(p.param)

def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike:
type_var_ns: Dict[TypeVar, hir.Type | hir.ComptimeValue] = {}
any_param_types: List[hir.Type] = []
mapped_implicit_type_params: Dict[str,
hir.Type] = dict()
if is_generic:
mapping = hir.match_func_template_args(func_sig_converted, args)
if isinstance(mapping, hir.TypeInferenceError):
raise mapping
if len(mapping) != len(func_sig_converted.generic_params):
# print(mapping, func_sig_converted.generic_params)
raise hir.TypeInferenceError(
None, "not all type parameters are resolved")
for p in func_sig_converted.generic_params:
if p not in mapping:
raise hir.TypeInferenceError(
None, f"type parameter {p} is not resolved")
type_var_ns[sig_parser.generic_param_to_type_var[p]
] = mapping[p]
# print(f'binding {p.name} = {mapping[p]}, tv: {sig_parser.generic_param_to_type_var[p]} @{id(sig_parser.generic_param_to_type_var[p])}')
# print('parsing instantiated signature')
func_sig_instantiated, _ = parse.convert_func_signature(
func_sig, func_name, func_globals, type_var_ns, any_param_types, self_type)
if p not in implicit_generic_params:
type_var_ns[sig_parser.generic_param_to_type_var[p]
] = mapping[p]

for name, itp, in implicit_type_params.items():
assert isinstance(itp, hir.SymbolicType)
gp = itp.param
mapped_type = mapping[gp]
assert isinstance(mapped_type, hir.Type)
mapped_implicit_type_params[name] = mapped_type
func_sig_instantiated, _p = parse.convert_func_signature(
func_sig, func_name, func_globals, type_var_ns, mapped_implicit_type_params, self_type, mode='instantiate')
assert len(
func_sig_instantiated.generic_params) == 0, f"generic params should be resolved but found {func_sig_instantiated.generic_params}"
func_parser = FuncParser(
func_name, f, func_sig_instantiated, func_globals, type_var_ns, self_type)
return func_parser.parse_body()
params = [v[0] for v in func_sig.args]
is_generic = len(func_sig.type_vars) > 0
is_generic = len(func_sig_converted.generic_params) > 0
# print(f"func {func_name} is_generic: {is_generic}")
return hir.FunctionTemplate(func_name, params, parsing_func, is_generic)


Expand Down Expand Up @@ -223,5 +238,3 @@ def decorator(f):
return impl(f)

return decorator


Loading

0 comments on commit da68f83

Please sign in to comment.