diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 942cf4f48..ca93a8187 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -38,6 +38,7 @@ TypeAlias, TypeVar, cast, + overload, ) import numpy as np @@ -262,10 +263,36 @@ def rec_function_definition( assert method is not None return method(expr, *args, **kwargs) - def __call__(self, - expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: + @overload + def __call__( + self, + expr: ArrayOrNames, + *args: P.args, + **kwargs: P.kwargs) -> ResultT: + ... + + @overload + def __call__( + self, + expr: FunctionDefinition, + *args: P.args, + **kwargs: P.kwargs) -> FunctionResultT: + ... + + def __call__( + self, + expr: ArrayOrNames | FunctionDefinition, + *args: P.args, + **kwargs: P.kwargs) -> ResultT | FunctionResultT: """Handle the mapping of *expr*.""" - return self.rec(expr, *args, **kwargs) + if isinstance(expr, ArrayOrNames): + return self.rec(expr, *args, **kwargs) + elif isinstance(expr, FunctionDefinition): + return self.rec_function_definition(expr, *args, **kwargs) + else: + raise ForeignObjectError( + f"{type(self).__name__} encountered invalid foreign " + f"object: {expr!r}") from None # }}} @@ -1847,7 +1874,7 @@ def __init__(self) -> None: self.node_to_users: dict[ArrayOrNames, set[DistributedSend | ArrayOrNames]] = {} - def __call__(self, expr: ArrayOrNames) -> None: + def __call__(self, expr: ArrayOrNames) -> None: # type: ignore[override] # Root node has no predecessor self.node_to_users[expr] = set() self.rec(expr) diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 205d578e7..0b91c716f 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -688,7 +688,8 @@ def handle_unsupported_array(self, expr: Array) -> Array: def rec(self, expr: Array) -> Array: # type: ignore[override] return expr - __call__ = Mapper.rec + def __call__(self, expr: Array) -> Array: # type: ignore[override] + return Mapper.rec(self, expr) def to_index_lambda(expr: Array) -> IndexLambda: