diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 22315f9..f349a77 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -1,4 +1,5 @@ from __future__ import annotations +from collections.abc import Mapping __copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner" @@ -24,11 +25,30 @@ """ from abc import ABC, abstractmethod -from typing import Any +from typing import ( + TYPE_CHECKING, + AbstractSet, + Callable, + Generic, + Hashable, + Iterable, + TypeVar, + cast, +) +from warnings import warn from immutabledict import immutabledict +from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeIs import pymbolic.primitives as primitives +from pymbolic.typing import ExpressionT + + +if TYPE_CHECKING: + import numpy as np + + from pymbolic.rational import Rational + from pymbolic.geometric_algebra import MultiVector __doc__ = """ @@ -96,14 +116,20 @@ """ -try: - import numpy +if TYPE_CHECKING: + import numpy as np + + def is_numpy_array(val) -> TypeIs[np.ndarray]: + return isinstance(val, np.ndarray) +else: + try: + import numpy as np - def is_numpy_array(val): - return isinstance(val, numpy.ndarray) -except ImportError: - def is_numpy_array(ary): - return False + def is_numpy_array(val): + return isinstance(val, np.ndarray) + except ImportError: + def is_numpy_array(ary): + return False class UnsupportedExpressionError(ValueError): @@ -112,7 +138,11 @@ class UnsupportedExpressionError(ValueError): # {{{ mapper base -class Mapper: +ResultT = TypeVar("ResultT") +P = ParamSpec("P") + + +class Mapper(Generic[ResultT, P]): """A visitor for trees of :class:`pymbolic.Expression` subclasses. Each expression-derived object is dispatched to the method named by the :attr:`pymbolic.Expression.mapper_method` @@ -120,7 +150,8 @@ class Mapper: *mapper_method* in the method resolution order of the object. """ - def handle_unsupported_expression(self, expr, *args, **kwargs): + def handle_unsupported_expression(self, + expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Mapper method that is invoked for :class:`pymbolic.Expression` subclasses for which a mapper method does not exist in this mapper. @@ -130,7 +161,8 @@ def handle_unsupported_expression(self, expr, *args, **kwargs): "{} cannot handle expressions of type {}".format( type(self), type(expr))) - def __call__(self, expr, *args, **kwargs): + def __call__(self, + expr: ExpressionT, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Dispatch *expr* to its corresponding mapper method. Pass on ``*args`` and ``**kwargs`` unmodified. @@ -162,7 +194,8 @@ def __call__(self, expr, *args, **kwargs): rec = __call__ - def rec_fallback(self, expr, *args, **kwargs): + def rec_fallback(self, + expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT: if isinstance(expr, primitives.Expression): for cls in type(expr).__mro__[1:]: method_name = getattr(cls, "mapper_method", None) @@ -175,76 +208,135 @@ def rec_fallback(self, expr, *args, **kwargs): else: return self.map_foreign(expr, *args, **kwargs) - def map_algebraic_leaf(self, expr, *args, **kwargs): + def map_algebraic_leaf(self, + expr: primitives.AlgebraicLeaf, + *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_variable(self, expr, *args, **kwargs): + def map_variable(self, + expr: primitives.Variable, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, + expr: primitives.Subscript, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_call(self, expr, *args, **kwargs): + def map_call(self, + expr: primitives.Call, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_lookup(self, expr, *args, **kwargs): + def map_call_with_kwargs(self, + expr: primitives.CallWithKwargs, + *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_if_positive(self, expr, *args, **kwargs): + def map_lookup(self, + expr: primitives.Lookup, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_rational(self, expr, *args, **kwargs): - return self.map_quotient(expr, *args, **kwargs) + def map_if(self, + expr: primitives.If, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError - def map_quotient(self, expr, *args, **kwargs): + def map_rational(self, + expr: Rational, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_constant(self, expr, *args, **kwargs): + def map_quotient(self, + expr: primitives.Quotient, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_list(self, expr, *args, **kwargs): + def map_floor_div(self, + expr: primitives.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_tuple(self, expr, *args, **kwargs): + def map_remainder(self, + expr: primitives.Remainder, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_numpy_array(self, expr, *args, **kwargs): + def map_constant(self, + expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_nan(self, expr, *args, **kwargs): + def map_comparison(self, + expr: primitives.Comparison, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_min(self, + expr: primitives.Min, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_max(self, + expr: primitives.Max, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_list(self, + expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_tuple(self, + expr: tuple[ExpressionT, ...], + *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_numpy_array(self, + expr: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_nan(self, + expr: primitives.NaN, + *args: P.args, + **kwargs: P.kwargs + ) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_foreign(self, expr, *args, **kwargs): + def map_foreign(self, + expr: object, + *args: P.args, + **kwargs: P.kwargs + ) -> ResultT: """Mapper method dispatch for non-:mod:`pymbolic` objects.""" if isinstance(expr, primitives.VALID_CONSTANT_CLASSES): return self.map_constant(expr, *args, **kwargs) elif is_numpy_array(expr): return self.map_numpy_array(expr, *args, **kwargs) - elif isinstance(expr, list): - return self.map_list(expr, *args, **kwargs) elif isinstance(expr, tuple): return self.map_tuple(expr, *args, **kwargs) + elif isinstance(expr, list): + warn("List found in expression graph. " + "This is deprecated and will stop working in 2025. " + "Use tuples instead.", DeprecationWarning, stacklevel=2 + ) + return self.map_list(expr, *args, **kwargs) else: raise ValueError( "{} encountered invalid foreign object: {}".format( self.__class__, repr(expr))) -_NOT_IN_CACHE = object() +class _NotInCache: + pass + + +CacheKeyT: TypeAlias = Hashable -class CachedMapper(Mapper): +class CachedMapper(Mapper[ResultT, P]): """ A mapper that memoizes the mapped result for the expressions traversed. .. automethod:: get_cache_key """ - def __init__(self): - self._cache: dict[Any, Any] = {} + def __init__(self) -> None: + self._cache: dict[CacheKeyT, ResultT] = {} Mapper.__init__(self) - def get_cache_key(self, expr, *args, **kwargs): + def get_cache_key(self, + expr: ExpressionT, + *args: P.args, + **kwargs: P.kwargs + ) -> CacheKeyT: """ Returns the key corresponding to which the result of a mapper method is stored in the cache. @@ -260,16 +352,23 @@ def get_cache_key(self, expr, *args, **kwargs): # and "4 == 4.0", but their traversal results cannot be re-used. return (type(expr), expr, args, immutabledict(kwargs)) - def __call__(self, expr, *args, **kwargs): + def __call__(self, + expr: ExpressionT, + *args: P.args, + **kwargs: P.kwargs + ) -> ResultT: result = self._cache.get( (cache_key := self.get_cache_key(expr, *args, **kwargs)), - _NOT_IN_CACHE) - if result is not _NOT_IN_CACHE: + _NotInCache) + if not isinstance(result, type): return result method_name = getattr(expr, "mapper_method", None) if method_name is not None: - method = getattr(self, method_name, None) + method = cast( + Callable[Concatenate[ExpressionT, P], ResultT], + getattr(self, method_name, None) + ) if method is not None: result = method(expr, *args, **kwargs) self._cache[cache_key] = result @@ -286,7 +385,10 @@ def __call__(self, expr, *args, **kwargs): # {{{ combine mapper -class CombineMapper(Mapper): +CombineArgT = TypeVar("CombineArgT") + + +class CombineMapper(Mapper[ResultT, P]): """A mapper whose goal it is to *combine* all branches of the expression tree into one final result. The default implementation of all mapper methods simply recurse (:meth:`Mapper.rec`) on all branches emanating from @@ -304,16 +406,19 @@ class CombineMapper(Mapper): :class:`pymbolic.mapper.dependency.DependencyMapper` is another example. """ - def combine(self, values): + def combine(self, values: Iterable[ResultT]) -> ResultT: raise NotImplementedError - def map_call(self, expr, *args, **kwargs): + def map_call(self, + expr: primitives.Call, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.function, *args, **kwargs), *[self.rec(child, *args, **kwargs) for child in expr.parameters] )) - def map_call_with_kwargs(self, expr, *args, **kwargs): + def map_call_with_kwargs(self, + expr: primitives.CallWithKwargs, + *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.function, *args, **kwargs), *[self.rec(child, *args, **kwargs) for child in expr.parameters], @@ -321,87 +426,140 @@ def map_call_with_kwargs(self, expr, *args, **kwargs): for child in expr.kw_parameters.values()] )) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, + expr: primitives.Subscript, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine( [self.rec(expr.aggregate, *args, **kwargs), self.rec(expr.index, *args, **kwargs)]) - def map_lookup(self, expr, *args, **kwargs): + def map_lookup(self, + expr: primitives.Lookup, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.rec(expr.aggregate, *args, **kwargs) - def map_sum(self, expr, *args, **kwargs): + def map_sum(self, + expr: primitives.Sum, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(self.rec(child, *args, **kwargs) for child in expr.children) - map_product = map_sum + def map_product(self, + expr: primitives.Product, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) - def map_quotient(self, expr, *args, **kwargs): + def map_quotient(self, + expr: primitives.Quotient, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.numerator, *args, **kwargs), self.rec(expr.denominator, *args, **kwargs))) - map_floor_div = map_quotient - map_remainder = map_quotient + def map_floor_div(self, + expr: primitives.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(( + self.rec(expr.numerator, *args, **kwargs), + self.rec(expr.denominator, *args, **kwargs))) - def map_power(self, expr, *args, **kwargs): + def map_remainder(self, + expr: primitives.Remainder, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(( + self.rec(expr.numerator, *args, **kwargs), + self.rec(expr.denominator, *args, **kwargs))) + + def map_power(self, + expr: primitives.Power, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.base, *args, **kwargs), self.rec(expr.exponent, *args, **kwargs))) - def map_polynomial(self, expr, *args, **kwargs): + def map_left_shift(self, + expr: primitives.LeftShift, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( - self.rec(expr.base, *args, **kwargs), - *[self.rec(coeff, *args, **kwargs) for exp, coeff in expr.data] - )) + self.rec(expr.shiftee, *args, **kwargs), + self.rec(expr.shift, *args, **kwargs))) - def map_left_shift(self, expr, *args, **kwargs): + def map_right_shift(self, + expr: primitives.RightShift, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.shiftee, *args, **kwargs), self.rec(expr.shift, *args, **kwargs))) - map_right_shift = map_left_shift + def map_bitwise_not(self, + expr: primitives.BitwiseNot, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.rec(expr.child, *args, **kwargs) - def map_bitwise_not(self, expr, *args, **kwargs): + def map_bitwise_or(self, + expr: primitives.BitwiseOr, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_bitwise_and(self, + expr: primitives.BitwiseAnd, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_bitwise_xor(self, + expr: primitives.BitwiseXor, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_logical_not(self, + expr: primitives.LogicalNot, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.rec(expr.child, *args, **kwargs) - map_bitwise_or = map_sum - map_bitwise_xor = map_sum - map_bitwise_and = map_sum - map_logical_not = map_bitwise_not - map_logical_and = map_sum - map_logical_or = map_sum + def map_logical_or(self, + expr: primitives.LogicalOr, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) - def map_comparison(self, expr, *args, **kwargs): + def map_logical_and(self, + expr: primitives.LogicalAnd, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_comparison(self, + expr: primitives.Comparison, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.left, *args, **kwargs), self.rec(expr.right, *args, **kwargs))) - map_max = map_sum - map_min = map_sum + def map_max(self, + expr: primitives.Max, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) - def map_list(self, expr, *args, **kwargs): + def map_min(self, + expr: primitives.Min, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) + + def map_tuple(self, + expr: tuple[ExpressionT, ...], *args: P.args, **kwargs: P.kwargs + ) -> ResultT: return self.combine(self.rec(child, *args, **kwargs) for child in expr) - map_tuple = map_list + def map_list(self, + expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs + ) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) for child in expr) - def map_numpy_array(self, expr, *args, **kwargs): + def map_numpy_array(self, + expr: np.ndarray, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: return self.combine(self.rec(el, *args, **kwargs) for el in expr.flat) - def map_multivector(self, expr, *args, **kwargs): + def map_multivector(self, + expr: MultiVector, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: return self.combine( self.rec(coeff, *args, **kwargs) for bits, coeff in expr.data.items()) - def map_common_subexpression(self, expr, *args, **kwargs): + def map_common_subexpression(self, + expr: primitives.CommonSubexpression, *args: P.args, **kwargs: P.kwargs + ) -> ResultT: return self.rec(expr.child, *args, **kwargs) - def map_if_positive(self, expr, *args, **kwargs): - return self.combine([ - self.rec(expr.criterion, *args, **kwargs), - self.rec(expr.then, *args, **kwargs), - self.rec(expr.else_, *args, **kwargs)]) - - def map_if(self, expr, *args, **kwargs): + def map_if(self, + expr: primitives.If, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine([ self.rec(expr.condition, *args, **kwargs), self.rec(expr.then, *args, **kwargs), @@ -416,7 +574,10 @@ class CachedCombineMapper(CachedMapper, CombineMapper): # {{{ collector -class Collector(CombineMapper): +CollectedT = TypeVar("CollectedT") + + +class Collector(CombineMapper[AbstractSet[CollectedT], P]): """A subclass of :class:`CombineMapper` for the common purpose of collecting data derived from an expression in a set that gets 'unioned' across children at each non-leaf node in the expression tree. @@ -426,19 +587,36 @@ class Collector(CombineMapper): .. versionadded:: 2014.3 """ - def combine(self, values): + def combine(self, + values: Iterable[AbstractSet[CollectedT]] + ) -> AbstractSet[CollectedT]: import operator from functools import reduce return reduce(operator.or_, values, set()) - def map_constant(self, expr, *args, **kwargs): + def map_constant(self, expr: object, + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: return set() - map_variable = map_constant - map_wildcard = map_constant - map_dot_wildcard = map_constant - map_star_wildcard = map_constant - map_function_symbol = map_constant + def map_variable(self, expr: primitives.Variable, + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: + return set() + + def map_wildcard(self, expr: primitives.Wildcard, + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: + return set() + + def map_dot_wildcard(self, expr: primitives.DotWildcard, + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: + return set() + + def map_star_wildcard(self, expr: primitives.StarWildcard, + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: + return set() + + def map_function_symbol(self, expr: primitives.FunctionSymbol, + *args: P.args, **kwargs: P.kwargs) -> AbstractSet[CollectedT]: + return set() class CachedCollector(CachedMapper, Collector): @@ -449,34 +627,49 @@ class CachedCollector(CachedMapper, Collector): # {{{ identity mapper -class IdentityMapper(Mapper): +class IdentityMapper(Mapper[ExpressionT, P]): """A :class:`Mapper` whose default mapper methods make a deep copy of each subexpression. See :ref:`custom-manipulation` for an example of the manipulations that can be implemented this way. """ - def map_constant(self, expr, *args, **kwargs): + def map_constant(self, + expr: object, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: # leaf -- no need to rebuild + assert primitives.is_valid_operand(expr) return expr - def map_variable(self, expr, *args, **kwargs): + def map_variable(self, + expr: primitives.Variable, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: # leaf -- no need to rebuild return expr - def map_wildcard(self, expr, *args, **kwargs): + def map_wildcard(self, + expr: primitives.Wildcard, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr - def map_dot_wildcard(self, expr, *args, **kwargs): + def map_dot_wildcard(self, + expr: primitives.DotWildcard, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr - def map_star_wildcard(self, expr, *args, **kwargs): + def map_star_wildcard(self, + expr: primitives.StarWildcard, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr - def map_function_symbol(self, expr, *args, **kwargs): + def map_function_symbol(self, + expr: primitives.FunctionSymbol, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr - def map_call(self, expr, *args, **kwargs): + def map_call(self, + expr: primitives.Call, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: function = self.rec(expr.function, *args, **kwargs) parameters = tuple([ self.rec(child, *args, **kwargs) for child in expr.parameters @@ -488,12 +681,14 @@ def map_call(self, expr, *args, **kwargs): return type(expr)(function, parameters) - def map_call_with_kwargs(self, expr, *args, **kwargs): + def map_call_with_kwargs(self, + expr: primitives.CallWithKwargs, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: function = self.rec(expr.function, *args, **kwargs) parameters = tuple([ self.rec(child, *args, **kwargs) for child in expr.parameters ]) - kw_parameters = immutabledict({ + kw_parameters: Mapping[str, ExpressionT] = immutabledict({ key: self.rec(val, *args, **kwargs) for key, val in expr.kw_parameters.items()}) @@ -505,20 +700,26 @@ def map_call_with_kwargs(self, expr, *args, **kwargs): return expr return type(expr)(function, parameters, kw_parameters) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, + expr: primitives.Subscript, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: aggregate = self.rec(expr.aggregate, *args, **kwargs) index = self.rec(expr.index, *args, **kwargs) if aggregate is expr.aggregate and index is expr.index: return expr return type(expr)(aggregate, index) - def map_lookup(self, expr, *args, **kwargs): + def map_lookup(self, + expr: primitives.Lookup, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: aggregate = self.rec(expr.aggregate, *args, **kwargs) if aggregate is expr.aggregate: return expr return type(expr)(aggregate, expr.name) - def map_sum(self, expr, *args, **kwargs): + def map_sum(self, + expr: primitives.Sum, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child for child, orig_child in zip(children, expr.children)): @@ -526,41 +727,81 @@ def map_sum(self, expr, *args, **kwargs): return type(expr)(tuple(children)) - map_product = map_sum + def map_product(self, + expr: primitives.Product, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children)): + return expr - def map_quotient(self, expr, *args, **kwargs): + return type(expr)(tuple(children)) + + def map_quotient(self, + expr: primitives.Quotient, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: numerator = self.rec(expr.numerator, *args, **kwargs) denominator = self.rec(expr.denominator, *args, **kwargs) if numerator is expr.numerator and denominator is expr.denominator: return expr return expr.__class__(numerator, denominator) - map_floor_div = map_quotient - map_remainder = map_quotient + def map_floor_div(self, + expr: primitives.FloorDiv, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + numerator = self.rec(expr.numerator, *args, **kwargs) + denominator = self.rec(expr.denominator, *args, **kwargs) + if numerator is expr.numerator and denominator is expr.denominator: + return expr + return expr.__class__(numerator, denominator) - def map_power(self, expr, *args, **kwargs): + def map_remainder(self, + expr: primitives.Remainder, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + numerator = self.rec(expr.numerator, *args, **kwargs) + denominator = self.rec(expr.denominator, *args, **kwargs) + if numerator is expr.numerator and denominator is expr.denominator: + return expr + return expr.__class__(numerator, denominator) + + def map_power(self, + expr: primitives.Power, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: base = self.rec(expr.base, *args, **kwargs) exponent = self.rec(expr.exponent, *args, **kwargs) if base is expr.base and exponent is expr.exponent: return expr return expr.__class__(base, exponent) - def map_left_shift(self, expr, *args, **kwargs): + def map_left_shift(self, + expr: primitives.LeftShift, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: shiftee = self.rec(expr.shiftee, *args, **kwargs) shift = self.rec(expr.shift, *args, **kwargs) if shiftee is expr.shiftee and shift is expr.shift: return expr return type(expr)(shiftee, shift) - map_right_shift = map_left_shift + def map_right_shift(self, + expr: primitives.RightShift, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + shiftee = self.rec(expr.shiftee, *args, **kwargs) + shift = self.rec(expr.shift, *args, **kwargs) + if shiftee is expr.shiftee and shift is expr.shift: + return expr + return type(expr)(shiftee, shift) - def map_bitwise_not(self, expr, *args, **kwargs): + def map_bitwise_not(self, + expr: primitives.BitwiseNot, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: child = self.rec(expr.child, *args, **kwargs) if child is expr.child: return expr return type(expr)(child) - def map_bitwise_or(self, expr, *args, **kwargs): + def map_bitwise_or(self, + expr: primitives.BitwiseOr, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr.children] if all(child is orig_child for child, orig_child in zip(children, expr.children)): @@ -568,14 +809,57 @@ def map_bitwise_or(self, expr, *args, **kwargs): return type(expr)(tuple(children)) - map_bitwise_xor = map_bitwise_or - map_bitwise_and = map_bitwise_or + def map_bitwise_and(self, + expr: primitives.BitwiseAnd, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children)): + return expr - map_logical_not = map_bitwise_not - map_logical_or = map_bitwise_or - map_logical_and = map_bitwise_or + return type(expr)(tuple(children)) - def map_comparison(self, expr, *args, **kwargs): + def map_bitwise_xor(self, + expr: primitives.BitwiseXor, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children)): + return expr + + return type(expr)(tuple(children)) + + def map_logical_not(self, + expr: primitives.LogicalNot, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + child = self.rec(expr.child, *args, **kwargs) + if child is expr.child: + return expr + return type(expr)(child) + + def map_logical_or(self, + expr: primitives.LogicalOr, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children)): + return expr + + return type(expr)(tuple(children)) + + def map_logical_and(self, + expr: primitives.LogicalAnd, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + children = [self.rec(child, *args, **kwargs) for child in expr.children] + if all(child is orig_child + for child, orig_child in zip(children, expr.children)): + return expr + + return type(expr)(tuple(children)) + + def map_comparison(self, + expr: primitives.Comparison, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: left = self.rec(expr.left, *args, **kwargs) right = self.rec(expr.right, *args, **kwargs) if left is expr.left and right is expr.right: @@ -583,10 +867,16 @@ def map_comparison(self, expr, *args, **kwargs): return type(expr)(left, expr.operator, right) - def map_list(self, expr, *args, **kwargs): - return [self.rec(child, *args, **kwargs) for child in expr] + def map_list(self, + expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: - def map_tuple(self, expr, *args, **kwargs): + # True fact: lists aren't expressions + return [self.rec(child, *args, **kwargs) for child in expr] # type: ignore[return-value] + + def map_tuple(self, + expr: tuple[ExpressionT, ...], *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: children = [self.rec(child, *args, **kwargs) for child in expr] if all(child is orig_child for child, orig_child in zip(children, expr)): @@ -594,14 +884,21 @@ def map_tuple(self, expr, *args, **kwargs): return tuple(children) - def map_numpy_array(self, expr, *args, **kwargs): + def map_numpy_array(self, + expr: np.ndarray, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: + import numpy result = numpy.empty(expr.shape, dtype=object) for i in numpy.ndindex(expr.shape): result[i] = self.rec(expr[i], *args, **kwargs) - return result - def map_multivector(self, expr, *args, **kwargs): + # True fact: ndarrays aren't expressions + return result # type: ignore[return-value] + + def map_multivector(self, + expr: MultiVector, *args: P.args, **kwargs: P.kwargs + ) -> ExpressionT: return expr.map(lambda ch: self.rec(ch, *args, **kwargs)) def map_common_subexpression(self, expr, *args, **kwargs): diff --git a/pymbolic/mapper/constant_folder.py b/pymbolic/mapper/constant_folder.py index 62b3edd..8932743 100644 --- a/pymbolic/mapper/constant_folder.py +++ b/pymbolic/mapper/constant_folder.py @@ -100,7 +100,7 @@ class ConstantFoldingMapper( # Yes, map_product incompatible: missing *args, **kwargs -class CommutativeConstantFoldingMapper( # type: ignore[misc] +class CommutativeConstantFoldingMapper( CSECachingMapperMixin, CommutativeConstantFoldingMapperBase, IdentityMapper):