From 5861c8ed325b9198af93bf7e6be310a366ecfcbf Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 21 Feb 2025 13:41:41 -0600 Subject: [PATCH 01/35] refactor deduplicate_data_wrappers to avoid dependence on erroneous super().rec usage in CachedMapAndCopyMapper Here is a sketch of what happens with super().rec vs Mapper.rec for the previous implementation of deduplicate_data_wrappers. Suppose we have 2 data wrappers a and b with the same data pointer. With super().rec: 1) map_fn maps a to itself, then mapper copies a to a'; mapper caches a -> a' (twice, once in super().rec and then again in rec), 2) map_fn maps b to a, then mapper maps (via cache in super().rec call) a to a'; mapper caches b -> a'. => Only a' in output DAG. With Mapper.rec: 1) map_fn maps a to itself, then mapper copies a to a'; caches a -> a', 2) map_fn maps b to a, then mapper copies a to a''; caches b -> a''. => Both a' and a'' in output DAG. --- pytato/transform/__init__.py | 122 ++++++++++++++++++----------------- 1 file changed, 64 insertions(+), 58 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 5b1ba02c4..32f5e31b5 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -2050,44 +2050,67 @@ def rec_get_user_nodes(expr: ArrayOrNames, # {{{ deduplicate_data_wrappers -def _get_data_dedup_cache_key(ary: DataInterface) -> CacheKeyT: - import sys - if "pyopencl" in sys.modules: - from pyopencl import MemoryObjectHolder - from pyopencl.array import Array as CLArray - try: - from pyopencl import SVMPointer - except ImportError: - SVMPointer = None # noqa: N806 - - if isinstance(ary, CLArray): - base_data = ary.base_data - if isinstance(ary.base_data, MemoryObjectHolder): - ptr = base_data.int_ptr - elif SVMPointer is not None and isinstance(base_data, SVMPointer): - ptr = base_data.svm_ptr - elif base_data is None: - # pyopencl represents 0-long arrays' base_data as None - ptr = None - else: - raise ValueError("base_data of array not understood") - +class DataWrapperDeduplicator(CopyMapper): + """ + Mapper to replace all :class:`pytato.array.DataWrapper` instances containing + identical data with a single instance. 
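+
+    A rough usage sketch (the array names are illustrative, not part of the
+    API)::
+
+        import numpy as np
+        import pytato as pt
+
+        data = np.zeros(10)
+        a = pt.make_data_wrapper(data)
+        b = pt.make_data_wrapper(data)
+        dag = pt.make_dict_of_named_arrays({"a": 2*a, "b": 3*b})
+        dag = DataWrapperDeduplicator()(dag)
+
+    Afterwards, both outputs are computed from a single
+    :class:`~pytato.array.DataWrapper` instance instead of two.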
+ """ + def __init__( + self, + _cache: TransformMapperCache[ArrayOrNames] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) + self.data_wrapper_cache: dict[CacheKeyT, DataWrapper] = {} + self.data_wrappers_encountered = 0 + + def _get_data_dedup_cache_key(self, ary: DataInterface) -> CacheKeyT: + import sys + if "pyopencl" in sys.modules: + from pyopencl import MemoryObjectHolder + from pyopencl.array import Array as CLArray + try: + from pyopencl import SVMPointer + except ImportError: + SVMPointer = None # noqa: N806 + + if isinstance(ary, CLArray): + base_data = ary.base_data + if isinstance(ary.base_data, MemoryObjectHolder): + ptr = base_data.int_ptr + elif SVMPointer is not None and isinstance(base_data, SVMPointer): + ptr = base_data.svm_ptr + elif base_data is None: + # pyopencl represents 0-long arrays' base_data as None + ptr = None + else: + raise ValueError("base_data of array not understood") + + return ( + ptr, + ary.offset, + ary.shape, + ary.strides, + ary.dtype, + ) + if isinstance(ary, np.ndarray): return ( - ptr, - ary.offset, + ary.__array_interface__["data"], ary.shape, ary.strides, ary.dtype, ) - if isinstance(ary, np.ndarray): - return ( - ary.__array_interface__["data"], - ary.shape, - ary.strides, - ary.dtype, - ) - else: - raise NotImplementedError(str(type(ary))) + else: + raise NotImplementedError(str(type(ary))) + + def map_data_wrapper(self, expr: DataWrapper) -> Array: + self.data_wrappers_encountered += 1 + cache_key = self._get_data_dedup_cache_key(expr.data) + try: + return self.data_wrapper_cache[cache_key] + except KeyError: + self.data_wrapper_cache[cache_key] = expr + return expr def deduplicate_data_wrappers(array_or_names: ArrayOrNames) -> ArrayOrNames: @@ -2108,34 +2131,17 @@ def deduplicate_data_wrappers(array_or_names: ArrayOrNames) -> ArrayOrNames: this, but it must *also* tolerate this function doing a more thorough job of deduplication. 
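+
+    Here, two data wrappers count as identical if they agree on the
+    underlying base pointer (plus the offset, for
+    :class:`pyopencl.array.Array` data), shape, strides and dtype; for
+    :class:`numpy.ndarray` data the pointer is taken from
+    ``ary.__array_interface__["data"]``.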
""" + dedup = DataWrapperDeduplicator() + array_or_names = dedup(array_or_names) - data_wrapper_cache: dict[CacheKeyT, DataWrapper] = {} - data_wrappers_encountered = 0 - - def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames: - nonlocal data_wrappers_encountered - - if isinstance(ary, DataWrapper): - data_wrappers_encountered += 1 - cache_key = _get_data_dedup_cache_key(ary.data) - - try: - return data_wrapper_cache[cache_key] - except KeyError: - result = ary - data_wrapper_cache[cache_key] = result - return result - else: - return ary - - array_or_names = map_and_copy(array_or_names, cached_data_wrapper_if_present) - - if data_wrappers_encountered: + if dedup.data_wrappers_encountered: transform_logger.debug("data wrapper de-duplication: " "%d encountered, %d kept, %d eliminated", - data_wrappers_encountered, - len(data_wrapper_cache), - data_wrappers_encountered - len(data_wrapper_cache)) + dedup.data_wrappers_encountered, + len(dedup.data_wrapper_cache), + ( + dedup.data_wrappers_encountered + - len(dedup.data_wrapper_cache))) return array_or_names From 4ff4017626e7205b1d5ed37bf4d1915fb7ccb251 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 20 Feb 2025 14:22:11 -0600 Subject: [PATCH 02/35] call Mapper.rec instead of super().rec to avoid double caching --- pytato/analysis/__init__.py | 2 +- pytato/distributed/partition.py | 2 ++ pytato/transform/__init__.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index c100e7d31..fc71199d7 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -626,7 +626,7 @@ def rec(self, expr: ArrayOrNames) -> int: try: return self._cache.retrieve(expr, key=key) except KeyError: - s = super().rec(expr) + s = Mapper.rec(self, expr) if ( isinstance(expr, Array) and ( diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 27b1e2cee..741e36548 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -309,6 +309,8 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: if name is not None: return self._get_placeholder_for(name, expr) + # Calling super().rec instead of Mapper.rec is OK here, because we're not + # implementing cache insertion and thus are not double caching return cast("ArrayOrNames", super().rec(expr)) def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 32f5e31b5..2000e4513 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1532,7 +1532,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: return self._cache.retrieve(expr, key=key) except KeyError: return self._cache.add( - expr, super().rec(self.map_fn(expr)), key=key) + expr, Mapper.rec(self, self.map_fn(expr)), key=key) # }}} From 99b1d47dcdfaeeae92dbe1d21e88770569fdcd91 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 25 Feb 2025 07:56:37 -0600 Subject: [PATCH 03/35] call Mapper.rec from CachedMapper too just to avoid copy/paste errors --- pytato/transform/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 2000e4513..b6c3348bd 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -418,7 +418,7 @@ def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: except KeyError: return self._cache.add( (expr, args, kwargs), - super().rec(expr, 
*args, **kwargs), + Mapper.rec(self, expr, *args, **kwargs), key=key) def rec_function_definition( @@ -430,7 +430,7 @@ def rec_function_definition( except KeyError: return self._function_cache.add( (expr, args, kwargs), - super().rec_function_definition(expr, *args, **kwargs), + Mapper.rec_function_definition(self, expr, *args, **kwargs), key=key) def clone_for_callee( From 9bc0a14180c438118fe882e378da0371a4955df7 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 17 Feb 2025 13:33:39 -0600 Subject: [PATCH 04/35] add assertion to check for double caching --- pytato/transform/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index b6c3348bd..312b8e7cd 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -349,6 +349,9 @@ def add( else: key = self.get_key(key_inputs) + assert key not in self._expr_key_to_result, \ + f"Cache entry is already present for key '{key}'." + self._expr_key_to_result[key] = result return result From b7d9fdd7c5c5b8917d11bc662b6851f96f99c270 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 7 Feb 2025 16:16:10 -0600 Subject: [PATCH 05/35] disable default implementation of get_cache_key and get_function_definition_cache_key for extra args case ambiguous due to the fact that any arg can be specified with/without keyword --- pytato/transform/__init__.py | 15 +++++++++++---- pytato/transform/einsum_distributive_law.py | 9 +++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 312b8e7cd..c44f6b8fc 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -406,13 +406,20 @@ def __init__( def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs - ) -> Hashable: - return (expr, *args, tuple(sorted(kwargs.items()))) + ) -> CacheKeyT: + if args or kwargs: + raise NotImplementedError( + "Derived classes must override get_cache_key if using extra inputs.") + return expr def get_function_definition_cache_key( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs - ) -> Hashable: - return (expr, *args, tuple(sorted(kwargs.items()))) + ) -> CacheKeyT: + if args or kwargs: + raise NotImplementedError( + "Derived classes must override get_function_definition_cache_key if " + "using extra inputs.") + return expr def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: key = self._cache.get_key(expr, *args, **kwargs) diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py index 8cd635f61..694901b03 100644 --- a/pytato/transform/einsum_distributive_law.py +++ b/pytato/transform/einsum_distributive_law.py @@ -57,6 +57,8 @@ Stack, ) from pytato.transform import ( + ArrayOrNames, + CacheKeyT, MappedT, TransformMapperWithExtraArgs, _verify_is_array, @@ -160,6 +162,13 @@ def __init__(self, super().__init__() self.how_to_distribute = how_to_distribute + def get_cache_key( + self, + expr: ArrayOrNames, + ctx: _EinsumDistributiveLawMapperContext | None + ) -> CacheKeyT: + return (expr, ctx) + def _map_input_base(self, expr: InputArgumentBase, ctx: _EinsumDistributiveLawMapperContext | None, From 190cb68725582f6a3bdec71d914f3368d1550cea Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 7 Feb 2025 16:16:55 -0600 Subject: [PATCH 06/35] add CacheInputs to simplify cache key handling logic --- pytato/analysis/__init__.py | 6 +- pytato/codegen.py | 4 +- 
pytato/distributed/partition.py | 10 +-- pytato/transform/__init__.py | 152 +++++++++++++++++--------------- pytato/transform/metadata.py | 10 +-- 5 files changed, 95 insertions(+), 87 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index fc71199d7..c102f710e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -622,9 +622,9 @@ def combine(self, *args: int) -> int: return sum(args) def rec(self, expr: ArrayOrNames) -> int: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache.retrieve(inputs) except KeyError: s = Mapper.rec(self, expr) if ( @@ -636,7 +636,7 @@ def rec(self, expr: ArrayOrNames) -> int: else: result = 0 + s - self._cache.add(expr, 0, key=key) + self._cache.add(inputs, 0) return result diff --git a/pytato/codegen.py b/pytato/codegen.py index 86a328929..cb957f076 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -138,8 +138,8 @@ def __init__( self, target: Target, kernels_seen: dict[str, lp.LoopKernel] | None = None, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 741e36548..a022f8f8e 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -240,9 +240,9 @@ def __init__(self, recvd_ary_to_name: Mapping[Array, str], sptpo_ary_to_name: Mapping[Array, str], name_to_output: Mapping[str, Array], - _cache: TransformMapperCache[ArrayOrNames] | None = None, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None, + TransformMapperCache[FunctionDefinition, []] | None = None, ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -261,7 +261,7 @@ def clone_for_callee( return type(self)( {}, {}, {}, _function_cache=cast( - "TransformMapperCache[FunctionDefinition]", self._function_cache)) + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) def map_placeholder(self, expr: Placeholder) -> Placeholder: self.user_input_names.add(expr.name) @@ -294,9 +294,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: return new_send def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache.retrieve(inputs) except KeyError: pass diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index c44f6b8fc..e2f5eccb4 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -46,6 +46,7 @@ from typing_extensions import Self from pymbolic.mapper.optimize import optimize_mapper +from pytools import memoize_method from pytato.array import ( AbstractResultWithNamedArrays, @@ -93,6 +94,7 @@ __doc__ = """ .. autoclass:: Mapper +.. autoclass:: CacheInputs .. autoclass:: CachedMapperCache .. autoclass:: CachedMapper .. 
autoclass:: TransformMapperCache @@ -304,12 +306,45 @@ def __call__( CacheKeyT: TypeAlias = Hashable -class CachedMapperCache(Generic[CacheExprT, CacheResultT]): +class CacheInputs(Generic[CacheExprT, P]): + """ + Data structure for inputs to :class:`CachedMapperCache`. + + .. attribute:: expr + + The input expression being mapped. + + .. attribute:: key + + The cache key corresponding to *expr* and any additional inputs that were + passed. + + """ + def __init__( + self, + expr: CacheExprT, + key_func: Callable[..., CacheKeyT], + *args: P.args, + **kwargs: P.kwargs): + self.expr: CacheExprT = expr + self._args: tuple[Any, ...] = args + self._kwargs: dict[str, Any] = kwargs + self._key_func = key_func + + @memoize_method + def _get_key(self) -> CacheKeyT: + return self._key_func(self.expr, *self._args, **self._kwargs) + + @property + def key(self) -> CacheKeyT: + return self._get_key() + + +class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): """ Cache for mappers. .. automethod:: __init__ - .. method:: get_key Compute the key for an input expression. @@ -317,37 +352,16 @@ class CachedMapperCache(Generic[CacheExprT, CacheResultT]): .. automethod:: retrieve .. automethod:: clear """ - def __init__( - self, - key_func: Callable[..., CacheKeyT]) -> None: - """ - Initialize the cache. - - :arg key_func: Function to compute a hashable cache key from an input - expression and any extra arguments. - """ - self.get_key = key_func - + def __init__(self) -> None: + """Initialize the cache.""" self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {} def add( self, - key_inputs: - CacheExprT - # Currently, Python's type system doesn't have a way to annotate - # containers of args/kwargs (ParamSpec won't work here). So we have - # to fall back to using Any. More details here: - # https://github.com/python/typing/issues/1252 - | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], - result: CacheResultT, - key: CacheKeyT | None = None) -> CacheResultT: + inputs: CacheInputs[CacheExprT, P], + result: CacheResultT) -> CacheResultT: """Cache a mapping result.""" - if key is None: - if isinstance(key_inputs, tuple): - expr, key_args, key_kwargs = key_inputs - key = self.get_key(expr, *key_args, **key_kwargs) - else: - key = self.get_key(key_inputs) + key = inputs.key assert key not in self._expr_key_to_result, \ f"Cache entry is already present for key '{key}'." 
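A sketch of the intended usage of ``CacheInputs`` (``m`` and ``expr`` here are
a hypothetical mapper and expression, not names from the diff); the cache key
is computed lazily on first access and memoized via
:func:`pytools.memoize_method`::

    inputs = CacheInputs(expr, m.get_cache_key)
    key = inputs.key   # computed and memoized on first access
    key = inputs.key   # served from the memoized value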
@@ -356,20 +370,9 @@ def add( return result - def retrieve( - self, - key_inputs: - CacheExprT - | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], - key: CacheKeyT | None = None) -> CacheResultT: + def retrieve(self, inputs: CacheInputs[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" - if key is None: - if isinstance(key_inputs, tuple): - expr, key_args, key_kwargs = key_inputs - key = self.get_key(expr, *key_args, **key_kwargs) - else: - key = self.get_key(key_inputs) - + key = inputs.key return self._expr_key_to_result[key] def clear(self) -> None: @@ -389,20 +392,20 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): def __init__( self, _cache: - CachedMapperCache[ArrayOrNames, ResultT] | None = None, + CachedMapperCache[ArrayOrNames, ResultT, P] | None = None, _function_cache: - CachedMapperCache[FunctionDefinition, FunctionResultT] | None = None + CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None ) -> None: super().__init__() - self._cache: CachedMapperCache[ArrayOrNames, ResultT] = ( + self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = ( _cache if _cache is not None - else CachedMapperCache(self.get_cache_key)) + else CachedMapperCache()) self._function_cache: CachedMapperCache[ - FunctionDefinition, FunctionResultT] = ( + FunctionDefinition, FunctionResultT, P] = ( _function_cache if _function_cache is not None - else CachedMapperCache(self.get_function_definition_cache_key)) + else CachedMapperCache()) def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs @@ -421,27 +424,33 @@ def get_function_definition_cache_key( "using extra inputs.") return expr + def _make_cache_inputs( + self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs + ) -> CacheInputs[ArrayOrNames, P]: + return CacheInputs(expr, self.get_cache_key, *args, **kwargs) + + def _make_function_definition_cache_inputs( + self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs + ) -> CacheInputs[FunctionDefinition, P]: + return CacheInputs( + expr, self.get_function_definition_cache_key, *args, **kwargs) + def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: - key = self._cache.get_key(expr, *args, **kwargs) + inputs = self._make_cache_inputs(expr, *args, **kwargs) try: - return self._cache.retrieve((expr, args, kwargs), key=key) + return self._cache.retrieve(inputs) except KeyError: - return self._cache.add( - (expr, args, kwargs), - Mapper.rec(self, expr, *args, **kwargs), - key=key) + return self._cache.add(inputs, Mapper.rec(self, expr, *args, **kwargs)) def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> FunctionResultT: - key = self._function_cache.get_key(expr, *args, **kwargs) + inputs = self._make_function_definition_cache_inputs(expr, *args, **kwargs) try: - return self._function_cache.retrieve((expr, args, kwargs), key=key) + return self._function_cache.retrieve(inputs) except KeyError: return self._function_cache.add( - (expr, args, kwargs), - Mapper.rec_function_definition(self, expr, *args, **kwargs), - key=key) + inputs, Mapper.rec_function_definition(self, expr, *args, **kwargs)) def clone_for_callee( self, function: FunctionDefinition) -> Self: @@ -457,7 +466,7 @@ def clone_for_callee( # {{{ TransformMapper -class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT]): +class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]): pass @@ -471,8 +480,8 @@ class TransformMapper(CachedMapper[ArrayOrNames, 
FunctionDefinition, []]): """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -493,9 +502,9 @@ class TransformMapperWithExtraArgs( """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, + _cache: TransformMapperCache[ArrayOrNames, P] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None + TransformMapperCache[FunctionDefinition, P] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) @@ -1523,8 +1532,8 @@ class CachedMapAndCopyMapper(CopyMapper): def __init__( self, map_fn: Callable[[ArrayOrNames], ArrayOrNames], - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn @@ -1534,15 +1543,14 @@ def clone_for_callee( return type(self)( self.map_fn, _function_cache=cast( - "TransformMapperCache[FunctionDefinition]", self._function_cache)) + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return self._cache.retrieve(inputs) except KeyError: - return self._cache.add( - expr, Mapper.rec(self, self.map_fn(expr)), key=key) + return self._cache.add(inputs, Mapper.rec(self, self.map_fn(expr))) # }}} @@ -2067,8 +2075,8 @@ class DataWrapperDeduplicator(CopyMapper): """ def __init__( self, - _cache: TransformMapperCache[ArrayOrNames] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition] | None = None + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: super().__init__(_cache=_cache, _function_cache=_function_cache) self.data_wrapper_cache: dict[CacheKeyT, DataWrapper] = {} diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index e654e8b51..485e2a76b 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -416,9 +416,9 @@ class AxisTagAttacher(CopyMapper): def __init__(self, axis_to_tags: Mapping[tuple[Array, int | str], Collection[Tag]], tag_corresponding_redn_descr: bool, - _cache: TransformMapperCache[ArrayOrNames] | None = None, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: - TransformMapperCache[FunctionDefinition] | None = None): + TransformMapperCache[FunctionDefinition, []] | None = None): super().__init__(_cache=_cache, _function_cache=_function_cache) self.axis_to_tags: Mapping[tuple[Array, int | str], Collection[Tag]] = axis_to_tags @@ -465,9 +465,9 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: return result def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self._cache.get_key(expr) + inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(expr, key=key) + return 
self._cache.retrieve(inputs) except KeyError: result = Mapper.rec(self, expr) if not isinstance( @@ -475,7 +475,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: assert isinstance(expr, Array) # type-ignore reason: passed "ArrayOrNames"; expected "Array" result = self._attach_tags(expr, result) # type: ignore[arg-type] - return self._cache.add(expr, result, key=key) + return self._cache.add(inputs, result) def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError( From ea03645dd31433428709493eabd9aa0dfa7d1e61 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 7 Feb 2025 09:25:55 -0600 Subject: [PATCH 07/35] rename expr_key* to input_key* --- pytato/transform/__init__.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index e2f5eccb4..4ad5e9751 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -354,7 +354,7 @@ class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): """ def __init__(self) -> None: """Initialize the cache.""" - self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {} + self._input_key_to_result: dict[CacheKeyT, CacheResultT] = {} def add( self, @@ -363,21 +363,20 @@ def add( """Cache a mapping result.""" key = inputs.key - assert key not in self._expr_key_to_result, \ + assert key not in self._input_key_to_result, \ f"Cache entry is already present for key '{key}'." - self._expr_key_to_result[key] = result - + self._input_key_to_result[key] = result return result def retrieve(self, inputs: CacheInputs[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" key = inputs.key - return self._expr_key_to_result[key] + return self._input_key_to_result[key] def clear(self) -> None: """Reset the cache.""" - self._expr_key_to_result = {} + self._input_key_to_result = {} class CachedMapper(Mapper[ResultT, FunctionResultT, P]): From 13d5ab00e1e842ad32f5114005d658b3dd920e8d Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 18 Feb 2025 14:22:24 -0600 Subject: [PATCH 08/35] refactor to avoid performance drop --- pytato/transform/__init__.py | 45 ++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 4ad5e9751..35fb499eb 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -46,7 +46,6 @@ from typing_extensions import Self from pymbolic.mapper.optimize import optimize_mapper -from pytools import memoize_method from pytato.array import ( AbstractResultWithNamedArrays, @@ -94,7 +93,7 @@ __doc__ = """ .. autoclass:: Mapper -.. autoclass:: CacheInputs +.. autoclass:: CacheInputsWithKey .. autoclass:: CachedMapperCache .. autoclass:: CachedMapper .. autoclass:: TransformMapperCache @@ -306,7 +305,7 @@ def __call__( CacheKeyT: TypeAlias = Hashable -class CacheInputs(Generic[CacheExprT, P]): +class CacheInputsWithKey(Generic[CacheExprT, P]): """ Data structure for inputs to :class:`CachedMapperCache`. @@ -314,6 +313,14 @@ class CacheInputs(Generic[CacheExprT, P]): The input expression being mapped. + .. attribute:: args + + A :class:`tuple` of extra positional arguments. + + .. attribute:: kwargs + + A :class:`dict` of extra keyword arguments. + .. 
attribute:: key The cache key corresponding to *expr* and any additional inputs that were @@ -323,21 +330,13 @@ class CacheInputs(Generic[CacheExprT, P]): def __init__( self, expr: CacheExprT, - key_func: Callable[..., CacheKeyT], + key: CacheKeyT, *args: P.args, **kwargs: P.kwargs): self.expr: CacheExprT = expr - self._args: tuple[Any, ...] = args - self._kwargs: dict[str, Any] = kwargs - self._key_func = key_func - - @memoize_method - def _get_key(self) -> CacheKeyT: - return self._key_func(self.expr, *self._args, **self._kwargs) - - @property - def key(self) -> CacheKeyT: - return self._get_key() + self.args: tuple[Any, ...] = args + self.kwargs: dict[str, Any] = kwargs + self.key: CacheKeyT = key class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): @@ -358,7 +357,7 @@ def __init__(self) -> None: def add( self, - inputs: CacheInputs[CacheExprT, P], + inputs: CacheInputsWithKey[CacheExprT, P], result: CacheResultT) -> CacheResultT: """Cache a mapping result.""" key = inputs.key @@ -369,7 +368,7 @@ def add( self._input_key_to_result[key] = result return result - def retrieve(self, inputs: CacheInputs[CacheExprT, P]) -> CacheResultT: + def retrieve(self, inputs: CacheInputsWithKey[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" key = inputs.key return self._input_key_to_result[key] @@ -425,14 +424,16 @@ def get_function_definition_cache_key( def _make_cache_inputs( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs - ) -> CacheInputs[ArrayOrNames, P]: - return CacheInputs(expr, self.get_cache_key, *args, **kwargs) + ) -> CacheInputsWithKey[ArrayOrNames, P]: + return CacheInputsWithKey( + expr, self.get_cache_key(expr, *args, **kwargs), *args, **kwargs) def _make_function_definition_cache_inputs( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs - ) -> CacheInputs[FunctionDefinition, P]: - return CacheInputs( - expr, self.get_function_definition_cache_key, *args, **kwargs) + ) -> CacheInputsWithKey[FunctionDefinition, P]: + return CacheInputsWithKey( + expr, self.get_function_definition_cache_key(expr, *args, **kwargs), + *args, **kwargs) def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: inputs = self._make_cache_inputs(expr, *args, **kwargs) From 3297f4c8a53461d9c7b5c4825904feae9a3f2bf7 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 19 Sep 2024 19:31:13 -0500 Subject: [PATCH 09/35] add map_dict_of_named_arrays to DirectPredecessorsGetter --- pytato/analysis/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index c102f710e..d579f5cf0 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -337,6 +337,10 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): def _get_preds_from_shape(self, shape: ShapeType) -> FrozenOrderedSet[ArrayOrNames]: return FrozenOrderedSet(dim for dim in shape if isinstance(dim, Array)) + def map_dict_of_named_arrays( + self, expr: DictOfNamedArrays) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet(expr._data.values()) + def map_index_lambda(self, expr: IndexLambda) -> FrozenOrderedSet[ArrayOrNames]: return (FrozenOrderedSet(expr.bindings.values()) | self._get_preds_from_shape(expr.shape)) From 6c43a5597d21a92e554ad1f4f404d440d770f918 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 24 Sep 2024 14:42:38 -0500 Subject: [PATCH 10/35] support functions as inputs and outputs in DirectPredecessorsGetter --- pytato/analysis/__init__.py 
| 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index d579f5cf0..3f6d3f165 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -323,7 +323,11 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter -class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): +class DirectPredecessorsGetter( + Mapper[ + FrozenOrderedSet[ArrayOrNames | FunctionDefinition], + FrozenOrderedSet[ArrayOrNames], + []]): """ Mapper to get the `direct predecessors @@ -334,6 +338,10 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): We only consider the predecessors of a nodes in a data-flow sense. """ + def __init__(self, *, include_functions: bool = False) -> None: + super().__init__() + self.include_functions = include_functions + def _get_preds_from_shape(self, shape: ShapeType) -> FrozenOrderedSet[ArrayOrNames]: return FrozenOrderedSet(dim for dim in shape if isinstance(dim, Array)) @@ -401,8 +409,17 @@ def map_distributed_send_ref_holder(self, ) -> FrozenOrderedSet[ArrayOrNames]: return FrozenOrderedSet([expr.passthrough_data]) - def map_call(self, expr: Call) -> FrozenOrderedSet[ArrayOrNames]: - return FrozenOrderedSet(expr.bindings.values()) + def map_call( + self, expr: Call) -> FrozenOrderedSet[ArrayOrNames | FunctionDefinition]: + result: FrozenOrderedSet[ArrayOrNames | FunctionDefinition] = \ + FrozenOrderedSet(expr.bindings.values()) + if self.include_functions: + result = result | FrozenOrderedSet([expr.function]) + return result + + def map_function_definition( + self, expr: FunctionDefinition) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet(expr.returns.values()) def map_named_call_result( self, expr: NamedCallResult) -> FrozenOrderedSet[ArrayOrNames]: From c1f0681d98ac38965424d788b6ef0bc35ccf266e Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 29 Aug 2024 16:57:13 -0500 Subject: [PATCH 11/35] add collision/duplication checks to CachedMapper/TransformMapper/TransformMapperWithExtraArgs --- pytato/analysis/__init__.py | 4 +- pytato/distributed/partition.py | 2 +- pytato/transform/__init__.py | 270 ++++++++++++++++++++++++++++++-- pytato/transform/metadata.py | 4 +- 4 files changed, 258 insertions(+), 22 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 3f6d3f165..ba23254df 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -645,7 +645,7 @@ def combine(self, *args: int) -> int: def rec(self, expr: ArrayOrNames) -> int: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: s = Mapper.rec(self, expr) if ( @@ -657,7 +657,7 @@ def rec(self, expr: ArrayOrNames) -> int: else: result = 0 + s - self._cache.add(inputs, 0) + self._cache_add(inputs, 0) return result diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index a022f8f8e..73eec2745 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -296,7 +296,7 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: def rec(self, expr: ArrayOrNames) -> ArrayOrNames: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: pass diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 
35fb499eb..be7905576 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -188,6 +188,14 @@ class ForeignObjectError(ValueError): pass +class CacheCollisionError(ValueError): + pass + + +class CacheNoOpDuplicationError(ValueError): + pass + + # {{{ mapper base class ResultT = TypeVar("ResultT") @@ -300,7 +308,7 @@ def __call__( # {{{ CachedMapper -CacheExprT = TypeVar("CacheExprT") +CacheExprT = TypeVar("CacheExprT", ArrayOrNames, FunctionDefinition) CacheResultT = TypeVar("CacheResultT") CacheKeyT: TypeAlias = Hashable @@ -351,9 +359,18 @@ class CachedMapperCache(Generic[CacheExprT, CacheResultT, P]): .. automethod:: retrieve .. automethod:: clear """ - def __init__(self) -> None: - """Initialize the cache.""" + def __init__(self, err_on_collision: bool) -> None: + """ + Initialize the cache. + + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + """ + self.err_on_collision = err_on_collision + self._input_key_to_result: dict[CacheKeyT, CacheResultT] = {} + if self.err_on_collision: + self._input_key_to_expr: dict[CacheKeyT, CacheExprT] = {} def add( self, @@ -366,16 +383,27 @@ def add( f"Cache entry is already present for key '{key}'." self._input_key_to_result[key] = result + if self.err_on_collision: + self._input_key_to_expr[key] = inputs.expr + return result def retrieve(self, inputs: CacheInputsWithKey[CacheExprT, P]) -> CacheResultT: """Retrieve the cached mapping result.""" key = inputs.key - return self._input_key_to_result[key] + + result = self._input_key_to_result[key] + + if self.err_on_collision and inputs.expr is not self._input_key_to_expr[key]: + raise CacheCollisionError + + return result def clear(self) -> None: """Reset the cache.""" self._input_key_to_result = {} + if self.err_on_collision: + self._input_key_to_expr = {} class CachedMapper(Mapper[ResultT, FunctionResultT, P]): @@ -389,6 +417,7 @@ class CachedMapper(Mapper[ResultT, FunctionResultT, P]): """ def __init__( self, + err_on_collision: bool = False, _cache: CachedMapperCache[ArrayOrNames, ResultT, P] | None = None, _function_cache: @@ -398,12 +427,12 @@ def __init__( self._cache: CachedMapperCache[ArrayOrNames, ResultT, P] = ( _cache if _cache is not None - else CachedMapperCache()) + else CachedMapperCache(err_on_collision=err_on_collision)) self._function_cache: CachedMapperCache[ FunctionDefinition, FunctionResultT, P] = ( _function_cache if _function_cache is not None - else CachedMapperCache()) + else CachedMapperCache(err_on_collision=err_on_collision)) def get_cache_key( self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs @@ -435,21 +464,50 @@ def _make_function_definition_cache_inputs( expr, self.get_function_definition_cache_key(expr, *args, **kwargs), *args, **kwargs) + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, P], + result: ResultT) -> ResultT: + return self._cache.add(inputs, result) + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, P], + result: FunctionResultT) -> FunctionResultT: + return self._function_cache.add(inputs, result) + + def _cache_retrieve(self, inputs: CacheInputsWithKey[ArrayOrNames, P]) -> ResultT: + try: + return self._cache.retrieve(inputs) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_retrieve( + self, inputs: CacheInputsWithKey[FunctionDefinition, P]) -> FunctionResultT: + try: + return 
self._function_cache.retrieve(inputs) + except CacheCollisionError as e: + raise ValueError( + f"cache collision detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: inputs = self._make_cache_inputs(expr, *args, **kwargs) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: - return self._cache.add(inputs, Mapper.rec(self, expr, *args, **kwargs)) + return self._cache_add(inputs, Mapper.rec(self, expr, *args, **kwargs)) def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> FunctionResultT: inputs = self._make_function_definition_cache_inputs(expr, *args, **kwargs) try: - return self._function_cache.retrieve(inputs) + return self._function_cache_retrieve(inputs) except KeyError: - return self._function_cache.add( + return self._function_cache_add( inputs, Mapper.rec_function_definition(self, expr, *args, **kwargs)) def clone_for_callee( @@ -458,8 +516,10 @@ def clone_for_callee( Called to clone *self* before starting traversal of a :class:`pytato.function.FunctionDefinition`. """ - # Functions are cached globally, but arrays aren't - return type(self)(_function_cache=self._function_cache) + return type(self)( + err_on_collision=self._cache.err_on_collision, + # Functions are cached globally, but arrays aren't + _function_cache=self._function_cache) # }}} @@ -467,7 +527,67 @@ def clone_for_callee( # {{{ TransformMapper class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]): - pass + """ + Cache for :class:`TransformMapper` and :class:`TransformMapperWithExtraArgs`. + + .. automethod:: __init__ + .. automethod:: add + """ + def __init__( + self, + err_on_collision: bool, + err_on_no_op_duplication: bool) -> None: + """ + Initialize the cache. + + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + super().__init__(err_on_collision=err_on_collision) + + self.err_on_no_op_duplication = err_on_no_op_duplication + + def add( + self, + inputs: CacheInputsWithKey[CacheExprT, P], + result: CacheExprT) -> CacheExprT: + """ + Cache a mapping result. + + Returns the cached result (which may not be identical to *result* if a + result was already cached with the same result key). + """ + key = inputs.key + + assert key not in self._input_key_to_result, \ + f"Cache entry is already present for key '{key}'." + + if self.err_on_no_op_duplication: + from pytato.analysis import DirectPredecessorsGetter + pred_getter = DirectPredecessorsGetter(include_functions=True) + if ( + hash(result) == hash(inputs.expr) + and result == inputs.expr + and result is not inputs.expr + # Need this check in order to handle input DAGs that have existing + # duplicates. Deduplication will potentially replace predecessors + # of `expr` with cached versions, producing a new `result` that has + # the same cache key as `expr`. 
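+                # (Sketch: if `expr` is `a + b` and the mapper rebuilds
+                # `a + b` from the *same* `a` and `b` instances, all
+                # conditions hold and an error is raised; if `a` was
+                # instead swapped for an equal-but-distinct cached
+                # instance, the `all(...)` below is False and the
+                # rebuild is accepted.)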
+ and all( + result_pred is pred + for pred, result_pred in zip( + pred_getter(inputs.expr), + pred_getter(result), + strict=True))): + raise CacheNoOpDuplicationError from None + + self._input_key_to_result[key] = result + if self.err_on_collision: + self._input_key_to_expr[key] = inputs.expr + + return result class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): @@ -477,13 +597,71 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): Enables certain operations that can only be done if the mapping results are also arrays (e.g., computing a cache key from them). Does not implement default mapper methods; for that, see :class:`CopyMapper`. + + .. automethod:: __init__ + .. automethod:: clone_for_callee """ def __init__( self, + err_on_collision: bool = False, + err_on_no_op_duplication: bool = False, _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + """ + :arg err_on_collision: Raise an exception if two distinct input array + instances have the same key. + :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + if _cache is None: + _cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + if _function_cache is None: + _function_cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + super().__init__( + err_on_collision=err_on_collision, + _cache=_cache, + _function_cache=_function_cache) + + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, []], + result: ArrayOrNames) -> ArrayOrNames: + try: + return self._cache.add(inputs, result) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, []], + result: FunctionDefinition) -> FunctionDefinition: + try: + return self._function_cache.add(inputs, result) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + function_cache = cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache) + return type(self)( + err_on_collision=function_cache.err_on_collision, + err_on_no_op_duplication=function_cache.err_on_no_op_duplication, + _function_cache=function_cache) # }}} @@ -499,14 +677,72 @@ class TransformMapperWithExtraArgs( The logic in :class:`TransformMapper` purposely does not take the extra arguments to keep the cost of its each call frame low. + + .. automethod:: __init__ + .. automethod:: clone_for_callee """ def __init__( self, + err_on_collision: bool = False, + err_on_no_op_duplication: bool = False, _cache: TransformMapperCache[ArrayOrNames, P] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, P] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + """ + :arg err_on_collision: Raise an exception if two distinct input array + instances have the same key. 
+ :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + if _cache is None: + _cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + if _function_cache is None: + _function_cache = TransformMapperCache( + err_on_collision=err_on_collision, + err_on_no_op_duplication=err_on_no_op_duplication) + + super().__init__( + err_on_collision=err_on_collision, + _cache=_cache, + _function_cache=_function_cache) + + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, P], + result: ArrayOrNames) -> ArrayOrNames: + try: + return self._cache.add(inputs, result) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def _function_cache_add( + self, + inputs: CacheInputsWithKey[FunctionDefinition, P], + result: FunctionDefinition) -> FunctionDefinition: + try: + return self._function_cache.add(inputs, result) + except CacheNoOpDuplicationError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + function_cache = cast( + "TransformMapperCache[FunctionDefinition, P]", self._function_cache) + return type(self)( + err_on_collision=function_cache.err_on_collision, + err_on_no_op_duplication=function_cache.err_on_no_op_duplication, + _function_cache=function_cache) # }}} @@ -1548,9 +1784,9 @@ def clone_for_callee( def rec(self, expr: ArrayOrNames) -> ArrayOrNames: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: - return self._cache.add(inputs, Mapper.rec(self, self.map_fn(expr))) + return self._cache_add(inputs, Mapper.rec(self, self.map_fn(expr))) # }}} diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 485e2a76b..0b4f08990 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -467,7 +467,7 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: def rec(self, expr: ArrayOrNames) -> ArrayOrNames: inputs = self._make_cache_inputs(expr) try: - return self._cache.retrieve(inputs) + return self._cache_retrieve(inputs) except KeyError: result = Mapper.rec(self, expr) if not isinstance( @@ -475,7 +475,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: assert isinstance(expr, Array) # type-ignore reason: passed "ArrayOrNames"; expected "Array" result = self._attach_tags(expr, result) # type: ignore[arg-type] - return self._cache.add(inputs, result) + return self._cache_add(inputs, result) def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError( From b3b86a0973bddb554d6a26368325c92d0797b9f9 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 18 Feb 2025 14:41:56 -0600 Subject: [PATCH 12/35] fix doc --- pytato/transform/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index be7905576..c12a69133 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -556,8 +556,7 @@ def add( """ Cache a mapping result. 
- Returns the cached result (which may not be identical to *result* if a - result was already cached with the same result key). + Returns *result*. """ key = inputs.key From 1feea92b835a214ad12ca8638068527569fe1973 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 18 Feb 2025 14:49:23 -0600 Subject: [PATCH 13/35] change terminology from 'no-op duplication' to 'mapper-created duplicate' --- pytato/transform/__init__.py | 48 ++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index c12a69133..d3c4fe410 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -192,7 +192,7 @@ class CacheCollisionError(ValueError): pass -class CacheNoOpDuplicationError(ValueError): +class MapperCreatedDuplicateError(ValueError): pass @@ -536,18 +536,18 @@ class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]): def __init__( self, err_on_collision: bool, - err_on_no_op_duplication: bool) -> None: + err_on_created_duplicate: bool) -> None: """ Initialize the cache. :arg err_on_collision: Raise an exception if two distinct input expression instances have the same key. - :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + :arg err_on_created_duplicate: Raise an exception if mapping produces a new array instance that has the same key as the input array. """ super().__init__(err_on_collision=err_on_collision) - self.err_on_no_op_duplication = err_on_no_op_duplication + self.err_on_created_duplicate = err_on_created_duplicate def add( self, @@ -563,7 +563,7 @@ def add( assert key not in self._input_key_to_result, \ f"Cache entry is already present for key '{key}'." - if self.err_on_no_op_duplication: + if self.err_on_created_duplicate: from pytato.analysis import DirectPredecessorsGetter pred_getter = DirectPredecessorsGetter(include_functions=True) if ( @@ -580,7 +580,7 @@ def add( pred_getter(inputs.expr), pred_getter(result), strict=True))): - raise CacheNoOpDuplicationError from None + raise MapperCreatedDuplicateError from None self._input_key_to_result[key] = result if self.err_on_collision: @@ -603,25 +603,25 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): def __init__( self, err_on_collision: bool = False, - err_on_no_op_duplication: bool = False, + err_on_created_duplicate: bool = False, _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: """ :arg err_on_collision: Raise an exception if two distinct input array instances have the same key. - :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + :arg err_on_created_duplicate: Raise an exception if mapping produces a new array instance that has the same key as the input array. 
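+
+        For example, ``CopyMapper(err_on_created_duplicate=True)`` (assuming
+        the subclass forwards these flags; :class:`CopyMapper` inherits this
+        constructor) raises if a ``map_*`` method rebuilds an array that is
+        equal to, but not identical to, its input while leaving all direct
+        predecessors untouched. Both flags default to *False* and are
+        intended as debugging aids.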
""" if _cache is None: _cache = TransformMapperCache( err_on_collision=err_on_collision, - err_on_no_op_duplication=err_on_no_op_duplication) + err_on_created_duplicate=err_on_created_duplicate) if _function_cache is None: _function_cache = TransformMapperCache( err_on_collision=err_on_collision, - err_on_no_op_duplication=err_on_no_op_duplication) + err_on_created_duplicate=err_on_created_duplicate) super().__init__( err_on_collision=err_on_collision, @@ -634,9 +634,9 @@ def _cache_add( result: ArrayOrNames) -> ArrayOrNames: try: return self._cache.add(inputs, result) - except CacheNoOpDuplicationError as e: + except MapperCreatedDuplicateError as e: raise ValueError( - f"no-op duplication detected on {type(inputs.expr)} in " + f"mapper-created duplicate detected on {type(inputs.expr)} in " f"{type(self)}.") from e def _function_cache_add( @@ -645,9 +645,9 @@ def _function_cache_add( result: FunctionDefinition) -> FunctionDefinition: try: return self._function_cache.add(inputs, result) - except CacheNoOpDuplicationError as e: + except MapperCreatedDuplicateError as e: raise ValueError( - f"no-op duplication detected on {type(inputs.expr)} in " + f"mapper-created duplicate detected on {type(inputs.expr)} in " f"{type(self)}.") from e def clone_for_callee(self, function: FunctionDefinition) -> Self: @@ -659,7 +659,7 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self: "TransformMapperCache[FunctionDefinition, []]", self._function_cache) return type(self)( err_on_collision=function_cache.err_on_collision, - err_on_no_op_duplication=function_cache.err_on_no_op_duplication, + err_on_created_duplicate=function_cache.err_on_created_duplicate, _function_cache=function_cache) # }}} @@ -683,7 +683,7 @@ class TransformMapperWithExtraArgs( def __init__( self, err_on_collision: bool = False, - err_on_no_op_duplication: bool = False, + err_on_created_duplicate: bool = False, _cache: TransformMapperCache[ArrayOrNames, P] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, P] | None = None @@ -691,18 +691,18 @@ def __init__( """ :arg err_on_collision: Raise an exception if two distinct input array instances have the same key. - :arg err_on_no_op_duplication: Raise an exception if mapping produces a new + :arg err_on_created_duplicate: Raise an exception if mapping produces a new array instance that has the same key as the input array. 
""" if _cache is None: _cache = TransformMapperCache( err_on_collision=err_on_collision, - err_on_no_op_duplication=err_on_no_op_duplication) + err_on_created_duplicate=err_on_created_duplicate) if _function_cache is None: _function_cache = TransformMapperCache( err_on_collision=err_on_collision, - err_on_no_op_duplication=err_on_no_op_duplication) + err_on_created_duplicate=err_on_created_duplicate) super().__init__( err_on_collision=err_on_collision, @@ -715,9 +715,9 @@ def _cache_add( result: ArrayOrNames) -> ArrayOrNames: try: return self._cache.add(inputs, result) - except CacheNoOpDuplicationError as e: + except MapperCreatedDuplicateError as e: raise ValueError( - f"no-op duplication detected on {type(inputs.expr)} in " + f"mapper-created duplicate detected on {type(inputs.expr)} in " f"{type(self)}.") from e def _function_cache_add( @@ -726,9 +726,9 @@ def _function_cache_add( result: FunctionDefinition) -> FunctionDefinition: try: return self._function_cache.add(inputs, result) - except CacheNoOpDuplicationError as e: + except MapperCreatedDuplicateError as e: raise ValueError( - f"no-op duplication detected on {type(inputs.expr)} in " + f"mapper-created duplicate detected on {type(inputs.expr)} in " f"{type(self)}.") from e def clone_for_callee(self, function: FunctionDefinition) -> Self: @@ -740,7 +740,7 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self: "TransformMapperCache[FunctionDefinition, P]", self._function_cache) return type(self)( err_on_collision=function_cache.err_on_collision, - err_on_no_op_duplication=function_cache.err_on_no_op_duplication, + err_on_created_duplicate=function_cache.err_on_created_duplicate, _function_cache=function_cache) # }}} From b97edc6146768bd46fc2dceda407aa28edc07f98 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 19 Feb 2025 13:04:20 -0600 Subject: [PATCH 14/35] reword explanation of predecessor check in duplication check --- pytato/transform/__init__.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index d3c4fe410..6b5418861 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -570,10 +570,14 @@ def add( hash(result) == hash(inputs.expr) and result == inputs.expr and result is not inputs.expr - # Need this check in order to handle input DAGs that have existing - # duplicates. Deduplication will potentially replace predecessors - # of `expr` with cached versions, producing a new `result` that has - # the same cache key as `expr`. + # Only consider "direct" duplication, not duplication resulting + # from equality-preserving changes to predecessors. Assume that + # such changes are OK, otherwise they would have been detected + # at the point at which they originated. (For example, consider + # a DAG containing pre-existing duplicates. If a subexpression + # of *expr* is a duplicate and is replaced with a previously + # encountered version from the cache, a new instance of *expr* + # must be created. This should not trigger an error.) 
and all( result_pred is pred for pred, result_pred in zip( From cbd9a62136515d4df64da24f5c26258bebd4c7c0 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 26 Feb 2025 15:25:19 -0600 Subject: [PATCH 15/35] change CacheExprT constraint to use bound= apparently TypeVar(..., ) doesn't include subclasses of --- pytato/transform/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 6b5418861..7d141903e 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -308,7 +308,7 @@ def __call__( # {{{ CachedMapper -CacheExprT = TypeVar("CacheExprT", ArrayOrNames, FunctionDefinition) +CacheExprT = TypeVar("CacheExprT", bound=ArrayOrNames | FunctionDefinition) CacheResultT = TypeVar("CacheResultT") CacheKeyT: TypeAlias = Hashable @@ -581,8 +581,10 @@ def add( and all( result_pred is pred for pred, result_pred in zip( - pred_getter(inputs.expr), - pred_getter(result), + # type-ignore-reason: mypy doesn't seem to recognize + # overloaded Mapper.__call__ here + pred_getter(inputs.expr), # type: ignore[arg-type] + pred_getter(result), # type: ignore[arg-type] strict=True))): raise MapperCreatedDuplicateError from None From 7e572d5b3bd939b0f5dc03ad1e6f6ec88e297ed8 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 23 Sep 2024 20:45:47 -0500 Subject: [PATCH 16/35] add result deduplication to transform mappers --- pytato/transform/__init__.py | 59 +++++++++++++++++++++--------------- test/test_codegen.py | 3 +- 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 7d141903e..5b8358804 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -549,6 +549,8 @@ def __init__( self.err_on_created_duplicate = err_on_created_duplicate + self._result_to_cached_result: dict[CacheExprT, CacheExprT] = {} + def add( self, inputs: CacheInputsWithKey[CacheExprT, P], @@ -556,37 +558,44 @@ def add( """ Cache a mapping result. - Returns *result*. + Returns the cached result (which may not be identical to *result* if a + result was already cached with the same result key). """ key = inputs.key assert key not in self._input_key_to_result, \ f"Cache entry is already present for key '{key}'." - if self.err_on_created_duplicate: - from pytato.analysis import DirectPredecessorsGetter - pred_getter = DirectPredecessorsGetter(include_functions=True) - if ( - hash(result) == hash(inputs.expr) - and result == inputs.expr - and result is not inputs.expr - # Only consider "direct" duplication, not duplication resulting - # from equality-preserving changes to predecessors. Assume that - # such changes are OK, otherwise they would have been detected - # at the point at which they originated. (For example, consider - # a DAG containing pre-existing duplicates. If a subexpression - # of *expr* is a duplicate and is replaced with a previously - # encountered version from the cache, a new instance of *expr* - # must be created. This should not trigger an error.) 
- and all( - result_pred is pred - for pred, result_pred in zip( - # type-ignore-reason: mypy doesn't seem to recognize - # overloaded Mapper.__call__ here - pred_getter(inputs.expr), # type: ignore[arg-type] - pred_getter(result), # type: ignore[arg-type] - strict=True))): - raise MapperCreatedDuplicateError from None + try: + result = self._result_to_cached_result[result] + except KeyError: + if self.err_on_created_duplicate: + from pytato.analysis import DirectPredecessorsGetter + pred_getter = DirectPredecessorsGetter(include_functions=True) + if ( + hash(result) == hash(inputs.expr) + and result == inputs.expr + and result is not inputs.expr + # Only consider "direct" duplication, not duplication + # resulting from equality-preserving changes to predecessors. + # Assume that such changes are OK, otherwise they would have + # been detected at the point at which they originated. (For + # example, consider a DAG containing pre-existing duplicates. + # If a subexpression of *expr* is a duplicate and is replaced + # with a previously encountered version from the cache, a + # new instance of *expr* must be created. This should not + # trigger an error.) + and all( + result_pred is pred + for pred, result_pred in zip( + # type-ignore-reason: mypy doesn't seem to recognize + # overloaded Mapper.__call__ here + pred_getter(inputs.expr), # type: ignore[arg-type] + pred_getter(result), # type: ignore[arg-type] + strict=True))): + raise MapperCreatedDuplicateError from None + + self._result_to_cached_result[result] = result self._input_key_to_result[key] = result if self.err_on_collision: diff --git a/test/test_codegen.py b/test/test_codegen.py index 0c6972cf6..2193b7fc9 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1621,7 +1621,8 @@ def test_zero_size_cl_array_dedup(ctx_factory): dedup_dw_out, count_duplicates=True) # 'x2' would be merged with 'x1' as both of them point to the same data # 'x3' would be merged with 'x4' as both of them point to the same data - assert num_nodes_new == (num_nodes_old - 2) + # '2*x2' would be merged with '2*x1' as they are identical expressions + assert num_nodes_new == (num_nodes_old - 3) # {{{ test_deterministic_codegen From e1ab346c5dd57a0373e31fee502238c0df2e0608 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 4 Sep 2024 21:35:04 -0500 Subject: [PATCH 17/35] add FIXME --- pytato/transform/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 5b8358804..e4c421bb4 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1781,6 +1781,7 @@ class CachedMapAndCopyMapper(CopyMapper): def __init__( self, + # FIXME: Should map_fn be applied to functions too? map_fn: Callable[[ArrayOrNames], ArrayOrNames], _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None From 8b15622273c3e67cf96259f6787eb6bdef34fd97 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 10 Jun 2024 13:32:42 -0500 Subject: [PATCH 18/35] avoid unnecessary duplication in CopyMapper/CopyMapperWithExtraArgs --- pytato/transform/__init__.py | 622 +++++++++++++++++++++++------------ 1 file changed, 403 insertions(+), 219 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index e4c421bb4..2af55c24a 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -776,63 +776,104 @@ def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] 
) -> tuple[IndexOrShapeExpr, ...]: # type-ignore-reason: apparently mypy cannot substitute typevars # here. - return tuple(self.rec(s) if isinstance(s, Array) else s # type: ignore[misc] - for s in situp) + new_situp = tuple( + self.rec(s) if isinstance(s, Array) else s + for s in situp) + if all(new_s is s for s, new_s in zip(situp, new_situp, strict=True)): + return situp + else: + return new_situp # type: ignore[return-value] def map_index_lambda(self, expr: IndexLambda) -> Array: - bindings: Mapping[str, Array] = immutabledict({ + new_shape = self.rec_idx_or_size_tuple(expr.shape) + new_bindings: Mapping[str, Array] = immutabledict({ name: self.rec(subexpr) for name, subexpr in sorted(expr.bindings.items())}) - return IndexLambda(expr=expr.expr, - shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, - bindings=bindings, - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if ( + new_shape is expr.shape + and frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return expr + else: + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=new_bindings, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_placeholder(self, expr: Placeholder) -> Array: assert expr.name is not None - return Placeholder(name=expr.name, - shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if new_shape is expr.shape: + return expr + else: + return Placeholder(name=expr.name, + shape=new_shape, + dtype=expr.dtype, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack) -> Array: - arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) - return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Stack(arrays=new_arrays, axis=expr.axis, axes=expr.axes, + tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate) -> Array: - arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) - return Concatenate(arrays=arrays, axis=expr.axis, - axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(_verify_is_array(self.rec(arr)) for arr in expr.arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Concatenate(arrays=new_arrays, axis=expr.axis, + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_roll(self, expr: Roll) -> Array: - return Roll(array=_verify_is_array(self.rec(expr.array)), - shift=expr.shift, - axis=expr.axis, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + if new_ary is expr.array: + return expr + else: + return Roll(array=new_ary, + shift=expr.shift, + axis=expr.axis, + axes=expr.axes, + tags=expr.tags, + 
non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation) -> Array: - return AxisPermutation(array=_verify_is_array(self.rec(expr.array)), - axis_permutation=expr.axis_permutation, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + if new_ary is expr.array: + return expr + else: + return AxisPermutation(array=new_ary, + axis_permutation=expr.axis_permutation, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase) -> Array: - return type(expr)(_verify_is_array(self.rec(expr.array)), - indices=self.rec_idx_or_size_tuple(expr.indices), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + new_indices = self.rec_idx_or_size_tuple(expr.indices) + if new_ary is expr.array and new_indices is expr.indices: + return expr + else: + return type(expr)(new_ary, + indices=new_indices, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_basic_index(self, expr: BasicIndex) -> Array: return self._map_index_base(expr) @@ -848,91 +889,131 @@ def map_non_contiguous_advanced_index(self, return self._map_index_base(expr) def map_data_wrapper(self, expr: DataWrapper) -> Array: - return DataWrapper( - data=expr.data, - shape=self.rec_idx_or_size_tuple(expr.shape), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if new_shape is expr.shape: + return expr + else: + return DataWrapper( + data=expr.data, + shape=new_shape, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_size_param(self, expr: SizeParam) -> Array: assert expr.name is not None - return SizeParam( - name=expr.name, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + return expr def map_einsum(self, expr: Einsum) -> Array: - return Einsum(expr.access_descriptors, - tuple(_verify_is_array(self.rec(arg)) for arg in expr.args), - axes=expr.axes, - redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def map_named_array(self, expr: NamedArray) -> Array: - container = self.rec(expr._container) - assert isinstance(container, AbstractResultWithNamedArrays) - return type(expr)(container, - expr.name, + new_args = tuple(_verify_is_array(self.rec(arg)) for arg in expr.args) + if all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)): + return expr + else: + return Einsum(expr.access_descriptors, + new_args, axes=expr.axes, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, tags=expr.tags, non_equality_tags=expr.non_equality_tags) + def map_named_array(self, expr: NamedArray) -> Array: + new_container = self.rec(expr._container) + assert isinstance(new_container, AbstractResultWithNamedArrays) + if new_container is expr._container: + return expr + else: + return type(expr)(new_container, + expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> DictOfNamedArrays: - return DictOfNamedArrays({key: _verify_is_array(self.rec(val.expr)) - for key, val in expr.items()}, - tags=expr.tags - ) + new_data = { + key: _verify_is_array(self.rec(val.expr)) + for key, val in expr.items()} + if all( + new_data_val is val.expr + for 
val, new_data_val in zip( + expr.values(), + new_data.values(), + strict=True)): + return expr + else: + return DictOfNamedArrays(new_data, tags=expr.tags) def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: - bindings: Mapping[Any, Any] = immutabledict( + new_bindings: Mapping[Any, Any] = immutabledict( {name: (self.rec(subexpr) if isinstance(subexpr, Array) else subexpr) for name, subexpr in sorted(expr.bindings.items())}) - - return LoopyCall(translation_unit=expr.translation_unit, - bindings=bindings, - entrypoint=expr.entrypoint, - tags=expr.tags, - ) + if ( + frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return expr + else: + return LoopyCall(translation_unit=expr.translation_unit, + bindings=new_bindings, + entrypoint=expr.entrypoint, + tags=expr.tags, + ) def map_loopy_call_result(self, expr: LoopyCallResult) -> Array: - rec_container = self.rec(expr._container) - assert isinstance(rec_container, LoopyCall) - return LoopyCallResult( - _container=rec_container, - name=expr.name, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_container = self.rec(expr._container) + assert isinstance(new_container, LoopyCall) + if new_container is expr._container: + return expr + else: + return LoopyCallResult( + _container=new_container, + name=expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_reshape(self, expr: Reshape) -> Array: - return Reshape(_verify_is_array(self.rec(expr.array)), - newshape=self.rec_idx_or_size_tuple(expr.newshape), - order=expr.order, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array)) + new_newshape = self.rec_idx_or_size_tuple(expr.newshape) + if new_ary is expr.array and new_newshape is expr.newshape: + return expr + else: + return Reshape(new_ary, + newshape=new_newshape, + order=expr.order, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_distributed_send_ref_holder( self, expr: DistributedSendRefHolder) -> Array: - return DistributedSendRefHolder( - send=DistributedSend( - data=_verify_is_array(self.rec(expr.send.data)), - dest_rank=expr.send.dest_rank, - comm_tag=expr.send.comm_tag), - passthrough_data=_verify_is_array(self.rec(expr.passthrough_data)), - ) + new_send_data = _verify_is_array(self.rec(expr.send.data)) + if new_send_data is expr.send.data: + new_send = expr.send + else: + new_send = DistributedSend( + data=new_send_data, + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag) + new_passthrough = _verify_is_array(self.rec(expr.passthrough_data)) + if new_send is expr.send and new_passthrough is expr.passthrough_data: + return expr + else: + return DistributedSendRefHolder(new_send, new_passthrough) def map_distributed_recv(self, expr: DistributedRecv) -> Array: - return DistributedRecv( - src_rank=expr.src_rank, comm_tag=expr.comm_tag, - shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, tags=expr.tags, axes=expr.axes, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if new_shape is expr.shape: + return expr + else: + return DistributedRecv( + src_rank=expr.src_rank, comm_tag=expr.comm_tag, + shape=new_shape, dtype=expr.dtype, tags=expr.tags, + axes=expr.axes, non_equality_tags=expr.non_equality_tags) def map_function_definition(self, expr: FunctionDefinition) -> 
FunctionDefinition: @@ -941,19 +1022,37 @@ def map_function_definition(self, new_mapper = self.clone_for_callee(expr) new_returns = {name: new_mapper(ret) for name, ret in expr.returns.items()} - return dataclasses.replace(expr, returns=immutabledict(new_returns)) + if all( + new_ret is ret + for ret, new_ret in zip( + expr.returns.values(), + new_returns.values(), + strict=True)): + return expr + else: + return dataclasses.replace(expr, returns=immutabledict(new_returns)) def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - return Call(self.rec_function_definition(expr.function), - immutabledict({name: self.rec(bnd) - for name, bnd in expr.bindings.items()}), - tags=expr.tags, - ) + new_function = self.rec_function_definition(expr.function) + new_bindings = { + name: _verify_is_array(self.rec(bnd)) + for name, bnd in expr.bindings.items()} + if ( + new_function is expr.function + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values(), + strict=True))): + return expr + else: + return Call(new_function, immutabledict(new_bindings), tags=expr.tags) def map_named_call_result(self, expr: NamedCallResult) -> Array: - call = self.rec(expr._container) - assert isinstance(call, Call) - return call[expr.name] + new_call = self.rec(expr._container) + assert isinstance(new_call, Call) + return new_call[expr.name] class CopyMapperWithExtraArgs(TransformMapperWithExtraArgs[P]): @@ -977,70 +1076,102 @@ def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...], def map_index_lambda(self, expr: IndexLambda, *args: P.args, **kwargs: P.kwargs) -> Array: - bindings: Mapping[str, Array] = immutabledict({ + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + new_bindings: Mapping[str, Array] = immutabledict({ name: self.rec(subexpr, *args, **kwargs) for name, subexpr in sorted(expr.bindings.items())}) - return IndexLambda(expr=expr.expr, - shape=self.rec_idx_or_size_tuple(expr.shape, - *args, **kwargs), - dtype=expr.dtype, - bindings=bindings, - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if ( + new_shape is expr.shape + and frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return expr + else: + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=new_bindings, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_placeholder(self, expr: Placeholder, *args: P.args, **kwargs: P.kwargs) -> Array: assert expr.name is not None - return Placeholder(name=expr.name, - shape=self.rec_idx_or_size_tuple(expr.shape, - *args, **kwargs), - dtype=expr.dtype, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + if new_shape is expr.shape: + return expr + else: + return Placeholder(name=expr.name, + shape=new_shape, + dtype=expr.dtype, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack, *args: P.args, **kwargs: P.kwargs) -> Array: - arrays = tuple( + new_arrays: tuple[Array, ...] 
= tuple( _verify_is_array(self.rec(arr, *args, **kwargs)) for arr in expr.arrays) - return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Stack(arrays=new_arrays, axis=expr.axis, axes=expr.axes, + tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate, *args: P.args, **kwargs: P.kwargs) -> Array: - arrays = tuple( + new_arrays: tuple[Array, ...] = tuple( _verify_is_array(self.rec(arr, *args, **kwargs)) for arr in expr.arrays) - return Concatenate(arrays=arrays, axis=expr.axis, - axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + return expr + else: + return Concatenate(arrays=new_arrays, axis=expr.axis, + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> Array: - return Roll(array=_verify_is_array(self.rec(expr.array, *args, **kwargs)), - shift=expr.shift, - axis=expr.axis, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + if new_ary is expr.array: + return expr + else: + return Roll(array=new_ary, + shift=expr.shift, + axis=expr.axis, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation, *args: P.args, **kwargs: P.kwargs) -> Array: - return AxisPermutation(array=_verify_is_array( - self.rec(expr.array, *args, **kwargs)), - axis_permutation=expr.axis_permutation, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + if new_ary is expr.array: + return expr + else: + return AxisPermutation(array=new_ary, + axis_permutation=expr.axis_permutation, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase, *args: P.args, **kwargs: P.kwargs) -> Array: assert isinstance(expr, _SuppliedAxesAndTagsMixin) - return type(expr)(_verify_is_array(self.rec(expr.array, *args, **kwargs)), - indices=self.rec_idx_or_size_tuple(expr.indices, - *args, **kwargs), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + new_indices = self.rec_idx_or_size_tuple(expr.indices, *args, **kwargs) + if new_ary is expr.array and new_indices is expr.indices: + return expr + else: + return type(expr)(new_ary, + indices=new_indices, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_basic_index(self, expr: BasicIndex, *args: P.args, **kwargs: P.kwargs) -> Array: @@ -1061,98 +1192,141 @@ def map_non_contiguous_advanced_index(self, def map_data_wrapper(self, expr: DataWrapper, *args: P.args, **kwargs: P.kwargs) -> Array: - return DataWrapper( - data=expr.data, - shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + if new_shape is expr.shape: + return expr + else: + return DataWrapper( + data=expr.data, + shape=new_shape, + axes=expr.axes, + tags=expr.tags, + 
non_equality_tags=expr.non_equality_tags) def map_size_param(self, expr: SizeParam, *args: P.args, **kwargs: P.kwargs) -> Array: assert expr.name is not None - return SizeParam(name=expr.name, axes=expr.axes, tags=expr.tags) + return expr def map_einsum(self, expr: Einsum, *args: P.args, **kwargs: P.kwargs) -> Array: - return Einsum(expr.access_descriptors, - tuple(_verify_is_array( - self.rec(arg, *args, **kwargs)) for arg in expr.args), - axes=expr.axes, - redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) - - def map_named_array(self, - expr: NamedArray, *args: P.args, **kwargs: P.kwargs) -> Array: - container = self.rec(expr._container, *args, **kwargs) - assert isinstance(container, AbstractResultWithNamedArrays) - return type(expr)(container, - expr.name, + new_args: tuple[Array, ...] = tuple( + _verify_is_array(self.rec(arg, *args, **kwargs)) for arg in expr.args) + if all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)): + return expr + else: + return Einsum(expr.access_descriptors, + new_args, axes=expr.axes, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, tags=expr.tags, non_equality_tags=expr.non_equality_tags) + def map_named_array(self, + expr: NamedArray, *args: P.args, **kwargs: P.kwargs) -> Array: + new_container = self.rec(expr._container, *args, **kwargs) + assert isinstance(new_container, AbstractResultWithNamedArrays) + if new_container is expr._container: + return expr + else: + return type(expr)(new_container, + expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, *args: P.args, **kwargs: P.kwargs ) -> DictOfNamedArrays: - return DictOfNamedArrays({key: _verify_is_array( - self.rec(val.expr, *args, **kwargs)) - for key, val in expr.items()}, - tags=expr.tags, - ) + new_data = { + key: _verify_is_array(self.rec(val.expr, *args, **kwargs)) + for key, val in expr.items()} + if all( + new_data_val is val.expr + for val, new_data_val in zip( + expr.values(), + new_data.values(), + strict=True)): + return expr + else: + return DictOfNamedArrays(new_data, tags=expr.tags) def map_loopy_call(self, expr: LoopyCall, *args: P.args, **kwargs: P.kwargs) -> LoopyCall: - bindings: Mapping[Any, Any] = immutabledict( + new_bindings: Mapping[Any, Any] = immutabledict( {name: (self.rec(subexpr, *args, **kwargs) if isinstance(subexpr, Array) else subexpr) for name, subexpr in sorted(expr.bindings.items())}) - - return LoopyCall(translation_unit=expr.translation_unit, - bindings=bindings, - entrypoint=expr.entrypoint, - tags=expr.tags, - ) + if ( + frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return expr + else: + return LoopyCall(translation_unit=expr.translation_unit, + bindings=new_bindings, + entrypoint=expr.entrypoint, + tags=expr.tags, + ) def map_loopy_call_result(self, expr: LoopyCallResult, *args: P.args, **kwargs: P.kwargs) -> Array: - rec_loopy_call = self.rec(expr._container, *args, **kwargs) - assert isinstance(rec_loopy_call, LoopyCall) - return LoopyCallResult( - _container=rec_loopy_call, - name=expr.name, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_container = self.rec(expr._container, *args, **kwargs) + assert isinstance(new_container, LoopyCall) + if new_container is expr._container: + return expr + else: + return 
LoopyCallResult( + _container=new_container, + name=expr.name, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_reshape(self, expr: Reshape, *args: P.args, **kwargs: P.kwargs) -> Array: - return Reshape(_verify_is_array(self.rec(expr.array, *args, **kwargs)), - newshape=self.rec_idx_or_size_tuple(expr.newshape, - *args, **kwargs), - order=expr.order, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_ary = _verify_is_array(self.rec(expr.array, *args, **kwargs)) + new_newshape = self.rec_idx_or_size_tuple(expr.newshape, *args, **kwargs) + if new_ary is expr.array and new_newshape is expr.newshape: + return expr + else: + return Reshape(new_ary, + newshape=new_newshape, + order=expr.order, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder, *args: P.args, **kwargs: P.kwargs) -> Array: - return DistributedSendRefHolder( - send=DistributedSend( - data=_verify_is_array(self.rec(expr.send.data, *args, **kwargs)), - dest_rank=expr.send.dest_rank, - comm_tag=expr.send.comm_tag), - passthrough_data=_verify_is_array( - self.rec(expr.passthrough_data, *args, **kwargs))) + new_send_data = _verify_is_array(self.rec(expr.send.data, *args, **kwargs)) + if new_send_data is expr.send.data: + new_send = expr.send + else: + new_send = DistributedSend( + data=new_send_data, + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag) + new_passthrough = _verify_is_array( + self.rec(expr.passthrough_data, *args, **kwargs)) + if new_send is expr.send and new_passthrough is expr.passthrough_data: + return expr + else: + return DistributedSendRefHolder(new_send, new_passthrough) def map_distributed_recv(self, expr: DistributedRecv, *args: P.args, **kwargs: P.kwargs) -> Array: - return DistributedRecv( - src_rank=expr.src_rank, comm_tag=expr.comm_tag, - shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), - dtype=expr.dtype, tags=expr.tags, axes=expr.axes, - non_equality_tags=expr.non_equality_tags) + new_shape = self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs) + if new_shape is expr.shape: + return expr + else: + return DistributedRecv( + src_rank=expr.src_rank, comm_tag=expr.comm_tag, + shape=new_shape, dtype=expr.dtype, tags=expr.tags, + axes=expr.axes, non_equality_tags=expr.non_equality_tags) def map_function_definition( self, expr: FunctionDefinition, @@ -1164,17 +1338,27 @@ def map_function_definition( def map_call(self, expr: Call, *args: P.args, **kwargs: P.kwargs) -> AbstractResultWithNamedArrays: - return Call(self.rec_function_definition(expr.function, *args, **kwargs), - immutabledict({name: self.rec(bnd, *args, **kwargs) - for name, bnd in expr.bindings.items()}), - tags=expr.tags, - ) + new_function = self.rec_function_definition(expr.function, *args, **kwargs) + new_bindings = { + name: self.rec(bnd, *args, **kwargs) + for name, bnd in expr.bindings.items()} + if ( + new_function is expr.function + and all( + new_bnd is bnd + for bnd, new_bnd in zip( + expr.bindings.values(), + new_bindings.values(), + strict=True))): + return expr + else: + return Call(new_function, immutabledict(new_bindings), tags=expr.tags) def map_named_call_result(self, expr: NamedCallResult, *args: P.args, **kwargs: P.kwargs) -> Array: - call = self.rec(expr._container, *args, **kwargs) - assert isinstance(call, Call) - return call[expr.name] + new_call = self.rec(expr._container, *args, **kwargs) + assert isinstance(new_call, Call) 
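+        # Indexing into the mapped Call is presumed to return the Call's
+        # existing NamedCallResult instance, so an unchanged container
+        # yields an unchanged result rather than a fresh duplicate.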
+ return new_call[expr.name] # }}} From f3674c5a62046812f5a2eb66ad8474ec01a25aff Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 20 Sep 2024 12:03:14 -0500 Subject: [PATCH 19/35] add Deduplicator --- pytato/transform/__init__.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 2af55c24a..16496063b 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -101,6 +101,7 @@ .. autoclass:: TransformMapperWithExtraArgs .. autoclass:: CopyMapper .. autoclass:: CopyMapperWithExtraArgs +.. autoclass:: Deduplicator .. autoclass:: CombineMapper .. autoclass:: DependencyMapper .. autoclass:: InputGatherer @@ -1363,6 +1364,28 @@ def map_named_call_result(self, expr: NamedCallResult, # }}} +# {{{ Deduplicator + +class Deduplicator(CopyMapper): + """Removes duplicate nodes from an expression.""" + def __init__( + self, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + super().__init__( + err_on_collision=False, err_on_created_duplicate=False, + _cache=_cache, + _function_cache=_function_cache) + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) + +# }}} + + # {{{ CombineMapper class CombineMapper(CachedMapper[ResultT, FunctionResultT, []]): From 5ac33bb49dc95bf8477c19747b5bb5bdea526292 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 11 Jun 2024 10:58:24 -0500 Subject: [PATCH 20/35] avoid unnecessary duplication in InlineMarker --- pytato/transform/calls.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index 34f89cbc1..298b5351b 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -104,7 +104,11 @@ class InlineMarker(CopyMapper): Primary mapper for :func:`tag_all_calls_to_be_inlined`. 
""" def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - return super().map_call(expr).tagged(InlineCallTag()) + rec_expr = super().map_call(expr) + if rec_expr.tags_of_type(InlineCallTag): + return rec_expr + else: + return rec_expr.tagged(InlineCallTag()) def inline_calls(expr: ArrayOrNames) -> ArrayOrNames: From f0d48ce566a7b5810d039ede31f228aa957d2b73 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 27 Aug 2024 15:02:08 -0500 Subject: [PATCH 21/35] avoid duplication in tagged() for Axis/ReductionDescriptor/_SuppliedAxesAndTagsMixin --- pytato/array.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 8f92e5118..73d970a2e 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -467,8 +467,11 @@ class Axis(Taggable): tags: frozenset[Tag] def _with_new_tags(self, tags: frozenset[Tag]) -> Axis: - from dataclasses import replace - return replace(self, tags=tags) + if tags != self.tags: + from dataclasses import replace + return replace(self, tags=tags) + else: + return self @dataclasses.dataclass(frozen=True) @@ -480,8 +483,11 @@ class ReductionDescriptor(Taggable): tags: frozenset[Tag] def _with_new_tags(self, tags: frozenset[Tag]) -> ReductionDescriptor: - from dataclasses import replace - return replace(self, tags=tags) + if tags != self.tags: + from dataclasses import replace + return replace(self, tags=tags) + else: + return self @array_dataclass() @@ -865,7 +871,10 @@ class _SuppliedAxesAndTagsMixin(Taggable): default=frozenset()) def _with_new_tags(self: Self, tags: frozenset[Tag]) -> Self: - return dataclasses.replace(self, tags=tags) + if tags != self.tags: + return dataclasses.replace(self, tags=tags) + else: + return self @dataclasses.dataclass(frozen=True, eq=False, repr=False) From 3e56653e3a08f8d3c477ca28fee33ea9ce803843 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 11 Jun 2024 12:48:20 -0500 Subject: [PATCH 22/35] avoid duplication in Array.with_tagged_axis --- pytato/array.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 73d970a2e..c246a92a6 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -839,10 +839,14 @@ def with_tagged_axis(self, iaxis: int, """ Returns a copy of *self* with *iaxis*-th axis tagged with *tags*. 
""" - new_axes = (self.axes[:iaxis] - + (self.axes[iaxis].tagged(tags),) - + self.axes[iaxis+1:]) - return self.copy(axes=new_axes) + new_axis = self.axes[iaxis].tagged(tags) + if new_axis is not self.axes[iaxis]: + new_axes = (self.axes[:iaxis] + + (self.axes[iaxis].tagged(tags),) + + self.axes[iaxis+1:]) + return self.copy(axes=new_axes) + else: + return self @memoize_method def __repr__(self) -> str: From 5cd2204dcf04c415201e685bf430383f8930b6b0 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 11 Jun 2024 12:48:54 -0500 Subject: [PATCH 23/35] avoid duplication in with_tagged_reduction for IndexLambda/Einsum --- pytato/array.py | 58 ++++++++++++++++++++++++++----------------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index c246a92a6..53d5215b8 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1127,20 +1127,22 @@ def with_tagged_reduction(self, f" '{self.var_to_reduction_descr.keys()}'," f" got '{reduction_variable}'.") - assert isinstance(self.var_to_reduction_descr, immutabledict) - new_var_to_redn_descr = dict(self.var_to_reduction_descr) - new_var_to_redn_descr[reduction_variable] = \ - self.var_to_reduction_descr[reduction_variable].tagged(tags) - - return type(self)(expr=self.expr, - shape=self.shape, - dtype=self.dtype, - bindings=self.bindings, - axes=self.axes, - var_to_reduction_descr=immutabledict - (new_var_to_redn_descr), - tags=self.tags, - non_equality_tags=self.non_equality_tags) + new_redn_descr = self.var_to_reduction_descr[reduction_variable].tagged(tags) + if new_redn_descr is not self.var_to_reduction_descr[reduction_variable]: + assert isinstance(self.var_to_reduction_descr, immutabledict) + new_var_to_redn_descr = dict(self.var_to_reduction_descr) + new_var_to_redn_descr[reduction_variable] = new_redn_descr + return type(self)(expr=self.expr, + shape=self.shape, + dtype=self.dtype, + bindings=self.bindings, + axes=self.axes, + var_to_reduction_descr=immutabledict + (new_var_to_redn_descr), + tags=self.tags, + non_equality_tags=self.non_equality_tags) + else: + return self # }}} @@ -1291,19 +1293,21 @@ def with_tagged_reduction(self, # }}} - assert isinstance(self.redn_axis_to_redn_descr, immutabledict) - new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr) - new_redn_axis_to_redn_descr[redn_axis] = \ - self.redn_axis_to_redn_descr[redn_axis].tagged(tags) - - return type(self)(access_descriptors=self.access_descriptors, - args=self.args, - axes=self.axes, - redn_axis_to_redn_descr=immutabledict - (new_redn_axis_to_redn_descr), - tags=self.tags, - non_equality_tags=self.non_equality_tags, - ) + new_redn_descr = self.redn_axis_to_redn_descr[redn_axis].tagged(tags) + if new_redn_descr is not self.redn_axis_to_redn_descr[redn_axis]: + assert isinstance(self.redn_axis_to_redn_descr, immutabledict) + new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr) + new_redn_axis_to_redn_descr[redn_axis] = new_redn_descr + return type(self)(access_descriptors=self.access_descriptors, + args=self.args, + axes=self.axes, + redn_axis_to_redn_descr=immutabledict + (new_redn_axis_to_redn_descr), + tags=self.tags, + non_equality_tags=self.non_equality_tags, + ) + else: + return self EINSUM_FIRST_INDEX = re.compile(r"^\s*((?P[a-zA-Z])|(?P\.\.\.))\s*") From e413cc9c287b129948adbb6ae2939e373b88ccd8 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 10 Jun 2024 15:35:47 -0500 Subject: [PATCH 24/35] attempt to avoid duplication in CodeGenPreprocessor --- pytato/array.py | 52 ++++--- 
pytato/codegen.py | 106 +++++++++----- pytato/transform/lower_to_index_lambda.py | 169 ++++++++++++++-------- 3 files changed, 206 insertions(+), 121 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 53d5215b8..1bcfbe696 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1177,6 +1177,34 @@ class EinsumReductionAxis(EinsumAxisDescriptor): dim: int +def _get_einsum_access_descr_to_axis_len( + access_descriptors: tuple[tuple[EinsumAxisDescriptor, ...], ...], + args: tuple[Array, ...], + ) -> Mapping[EinsumAxisDescriptor, ShapeComponent]: + from pytato.utils import are_shape_components_equal + descr_to_axis_len: dict[EinsumAxisDescriptor, ShapeComponent] = {} + + for access_descrs, arg in zip(access_descriptors, + args, strict=True): + assert arg.ndim == len(access_descrs) + for arg_axis_len, descr in zip(arg.shape, access_descrs, strict=True): + if descr in descr_to_axis_len: + seen_axis_len = descr_to_axis_len[descr] + + if not are_shape_components_equal(seen_axis_len, + arg_axis_len): + if are_shape_components_equal(arg_axis_len, 1): + # this axis would be broadcasted + pass + else: + assert are_shape_components_equal(seen_axis_len, 1) + descr_to_axis_len[descr] = arg_axis_len + else: + descr_to_axis_len[descr] = arg_axis_len + + return immutabledict(descr_to_axis_len) + + @array_dataclass() class Einsum(_SuppliedAxesAndTagsMixin, Array): """ @@ -1224,28 +1252,8 @@ def __post_init__(self) -> None: @memoize_method def _access_descr_to_axis_len(self ) -> Mapping[EinsumAxisDescriptor, ShapeComponent]: - from pytato.utils import are_shape_components_equal - descr_to_axis_len: dict[EinsumAxisDescriptor, ShapeComponent] = {} - - for access_descrs, arg in zip(self.access_descriptors, - self.args, strict=True): - assert arg.ndim == len(access_descrs) - for arg_axis_len, descr in zip(arg.shape, access_descrs, strict=True): - if descr in descr_to_axis_len: - seen_axis_len = descr_to_axis_len[descr] - - if not are_shape_components_equal(seen_axis_len, - arg_axis_len): - if are_shape_components_equal(arg_axis_len, 1): - # this axis would be broadcasted - pass - else: - assert are_shape_components_equal(seen_axis_len, 1) - descr_to_axis_len[descr] = arg_axis_len - else: - descr_to_axis_len[descr] = arg_axis_len - - return immutabledict(descr_to_axis_len) + return _get_einsum_access_descr_to_axis_len( + self.access_descriptors, self.args) @cached_property def shape(self) -> ShapeType: diff --git a/pytato/codegen.py b/pytato/codegen.py index cb957f076..f01296e1b 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -141,54 +141,72 @@ def __init__( _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) + super().__init__( + # ToIndexLambdaMixin operates on certain array types for which `shape` + # is a derived property (e.g. BasicIndex). For these types, `shape` + # is an expression that may contain duplicate nodes. Mappers do not + # traverse properties, so these expressions are not subject to any prior + # deduplication. Once transformed into an IndexLambda, however, `shape` + # becomes a field and is subject to traversal and duplication checks. + # Without `err_on_collision=False`, these duplicates would lead to + # collision errors. 
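+            # (Hypothetical example: lowering a BasicIndex can rebuild its
+            # derived `shape` as two equal but non-identical expressions;
+            # the second cache lookup would then hit by equality while
+            # failing the identity check that backs err_on_collision.)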
+ err_on_collision=False, + _cache=_cache, _function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator() self.target = target self.kernels_seen: dict[str, lp.LoopKernel] = kernels_seen or {} def map_size_param(self, expr: SizeParam) -> Array: - name = expr.name - assert name is not None - return SizeParam( # pylint: disable=missing-kwoa - name=name, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + assert expr.name is not None + return expr def map_placeholder(self, expr: Placeholder) -> Array: - name = expr.name - if name is None: - name = self.var_name_gen("_pt_in") - return Placeholder(name=name, - shape=tuple(self.rec(s) if isinstance(s, Array) else s - for s in expr.shape), - dtype=expr.dtype, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_name = expr.name + if new_name is None: + new_name = self.var_name_gen("_pt_in") + new_shape = self.rec_idx_or_size_tuple(expr.shape) + if ( + new_name is expr.name + and new_shape is expr.shape): + return expr + else: + return Placeholder(name=new_name, + shape=new_shape, + dtype=expr.dtype, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: from pytato.target.loopy import LoopyTarget if not isinstance(self.target, LoopyTarget): raise ValueError("Got a LoopyCall for a non-loopy target.") - translation_unit = expr.translation_unit.copy( - target=self.target.get_loopy_target()) + new_target = self.target.get_loopy_target() + + # FIXME: Can't use "is" here because targets aren't unique. Is it OK to + # use the existing target if it's equal to self.target.get_loopy_target()? + # If not, may have to set err_on_created_duplicate=False + if new_target == expr.translation_unit.target: + new_translation_unit = expr.translation_unit + else: + new_translation_unit = expr.translation_unit.copy(target=new_target) namegen = UniqueNameGenerator(set(self.kernels_seen)) - entrypoint = expr.entrypoint + new_entrypoint = expr.entrypoint # {{{ eliminate callable name collision - for name, clbl in translation_unit.callables_table.items(): + for name, clbl in new_translation_unit.callables_table.items(): if isinstance(clbl, lp.CallableKernel): assert isinstance(name, str) if name in self.kernels_seen and ( - translation_unit[name] != self.kernels_seen[name]): + new_translation_unit[name] != self.kernels_seen[name]): # callee name collision => must rename # {{{ see if it's one of the other kernels for other_knl in self.kernels_seen.values(): - if other_knl.copy(name=name) == translation_unit[name]: + if other_knl.copy(name=name) == new_translation_unit[name]: new_name = other_knl.name break else: @@ -198,37 +216,55 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: # }}} - if name == entrypoint: + if name == new_entrypoint: # if the colliding name is the entrypoint, then rename the # entrypoint as well. 
- entrypoint = new_name + new_entrypoint = new_name - translation_unit = lp.rename_callable( - translation_unit, name, new_name) + new_translation_unit = lp.rename_callable( + new_translation_unit, name, new_name) name = new_name self.kernels_seen[name] = clbl.subkernel # }}} - bindings: Mapping[str, Any] = immutabledict( + new_bindings: Mapping[str, Any] = immutabledict( {name: (self.rec(subexpr) if isinstance(subexpr, Array) else subexpr) for name, subexpr in sorted(expr.bindings.items())}) - return LoopyCall(translation_unit=translation_unit, - bindings=bindings, - entrypoint=entrypoint, - tags=expr.tags - ) + assert ( + new_entrypoint is expr.entrypoint + or new_entrypoint != expr.entrypoint) + for bnd, new_bnd in zip( + expr.bindings.values(), new_bindings.values(), strict=True): + assert new_bnd is bnd or new_bnd != bnd + + if ( + new_translation_unit == expr.translation_unit + and ( + frozenset(new_bindings.keys()) + == frozenset(expr.bindings.keys())) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings) + and new_entrypoint is expr.entrypoint): + return expr + else: + return LoopyCall(translation_unit=new_translation_unit, + bindings=new_bindings, + entrypoint=new_entrypoint, + tags=expr.tags + ) def map_data_wrapper(self, expr: DataWrapper) -> Array: name = _generate_name_for_temp(expr, self.var_name_gen, "_pt_data") + shape = self.rec_idx_or_size_tuple(expr.shape) self.bound_arguments[name] = expr.data return Placeholder(name=name, - shape=tuple(self.rec(s) if isinstance(s, Array) else s - for s in expr.shape), + shape=shape, dtype=expr.dtype, axes=expr.axes, tags=expr.tags, diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 507a450cd..f71ccc010 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -53,15 +53,18 @@ ShapeComponent, ShapeType, Stack, + _get_einsum_access_descr_to_axis_len, ) from pytato.diagnostic import CannotBeLoweredToIndexLambda from pytato.scalar_expr import INT_CLASSES, ScalarExpression from pytato.tags import AssumeNonNegative -from pytato.transform import Mapper +from pytato.transform import IndexOrShapeExpr, Mapper from pytato.utils import normalized_slice_does_not_change_axis if TYPE_CHECKING: + from collections.abc import Mapping + import numpy as np @@ -126,16 +129,14 @@ def _generate_index_expressions( for old_size_till, old_stride in zip(old_size_tills, old_strides, strict=True)) -def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]: +def _get_reshaped_indices( + order: str, old_shape: ShapeType, new_shape: ShapeType + ) -> tuple[ScalarExpression, ...]: - if expr.order.upper() not in ["C", "F"]: + if order.upper() not in ["C", "F"]: raise NotImplementedError("Order expected to be 'C' or 'F'", " (case insensitive). 
Found order = ", - f"{expr.order}") - - order = expr.order - old_shape = expr.array.shape - new_shape = expr.shape + f"{order}") # index variables need to be unique and depend on the new shape length index_vars = [prim.Variable(f"_{i}") for i in range(len(new_shape))] @@ -143,7 +144,8 @@ def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]: # {{{ check for scalars if old_shape == (): - assert expr.size == 1 + from pytools import product + assert product(new_shape) == 1 return () if new_shape == (): @@ -256,10 +258,17 @@ def _get_reshaped_indices(expr: Reshape) -> tuple[ScalarExpression, ...]: class ToIndexLambdaMixin: - def _rec_shape(self, shape: ShapeType) -> ShapeType: - return tuple(self.rec(s) if isinstance(s, Array) - else s - for s in shape) + def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] + ) -> tuple[IndexOrShapeExpr, ...]: + # type-ignore-reason: apparently mypy cannot substitute typevars + # here. + new_situp = tuple( + self.rec(s) if isinstance(s, Array) else s + for s in situp) + if all(new_s is s for s, new_s in zip(situp, new_situp, strict=True)): + return situp + else: + return new_situp # type: ignore[return-value] if TYPE_CHECKING: def rec( @@ -270,17 +279,27 @@ def rec( return super().rec( # type: ignore[no-any-return,misc] expr, *args, **kwargs) - def map_index_lambda(self, expr: IndexLambda) -> IndexLambda: - return IndexLambda(expr=expr.expr, - shape=self._rec_shape(expr.shape), - dtype=expr.dtype, - bindings=immutabledict({name: self.rec(bnd) - for name, bnd - in sorted(expr.bindings.items())}), - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + def map_index_lambda(self, expr: IndexLambda) -> Array: + new_shape = self.rec_idx_or_size_tuple(expr.shape) + new_bindings: Mapping[str, Array] = immutabledict({ + name: self.rec(subexpr) + for name, subexpr in sorted(expr.bindings.items())}) + if ( + new_shape is expr.shape + and frozenset(new_bindings.keys()) == frozenset(expr.bindings.keys()) + and all( + new_bindings[name] is expr.bindings[name] + for name in expr.bindings)): + return expr + else: + return IndexLambda(expr=expr.expr, + shape=new_shape, + dtype=expr.dtype, + bindings=new_bindings, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack) -> IndexLambda: subscript = tuple(prim.Variable(f"_{i}") @@ -305,11 +324,11 @@ def map_stack(self, expr: Stack) -> IndexLambda: subarray_expr, stack_expr) - bindings = {f"_in{i}": self.rec(array) - for i, array in enumerate(expr.arrays)} + bindings = {f"_in{i}": self.rec(ary) + for i, ary in enumerate(expr.arrays)} return IndexLambda(expr=stack_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, axes=expr.axes, bindings=immutabledict(bindings), @@ -328,10 +347,12 @@ def get_subscript(array_index: int, offset: ScalarExpression) -> Subscript: for i in range(len(expr.shape))] return Subscript(aggregate, tuple(index)) + rec_arrays: tuple[Array, ...] 
= tuple(self.rec(ary) for ary in expr.arrays) + lbounds: list[Any] = [0] - ubounds: list[Any] = [expr.arrays[0].shape[expr.axis]] + ubounds: list[Any] = [rec_arrays[0].shape[expr.axis]] - for i, array in enumerate(expr.arrays[1:], start=1): + for i, array in enumerate(rec_arrays[1:], start=1): ubounds.append(ubounds[i-1]+array.shape[expr.axis]) lbounds.append(ubounds[i-1]) @@ -354,11 +375,11 @@ def get_subscript(array_index: int, offset: ScalarExpression) -> Subscript: subarray_expr, concat_expr) - bindings = {f"_in{i}": self.rec(array) - for i, array in enumerate(expr.arrays)} + bindings = {f"_in{i}": ary + for i, ary in enumerate(rec_arrays)} return IndexLambda(expr=concat_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, bindings=immutabledict(bindings), axes=expr.axes, @@ -377,7 +398,9 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: dim_to_index_lambda_components, ) - bindings = {f"_in{k}": self.rec(arg) for k, arg in enumerate(expr.args)} + rec_args: tuple[Array, ...] = tuple(self.rec(arg) for arg in expr.args) + + bindings = {f"_in{k}": arg for k, arg in enumerate(rec_args)} redn_bounds: dict[str, tuple[ScalarExpression, ScalarExpression]] = {} args_as_pym_expr: list[prim.Subscript] = [] namegen = UniqueNameGenerator(set(bindings)) @@ -385,13 +408,16 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: # {{{ add bindings coming from the shape expressions + access_descr_to_axis_len = _get_einsum_access_descr_to_axis_len( + expr.access_descriptors, rec_args) + for access_descr, (iarg, arg) in zip(expr.access_descriptors, - enumerate(expr.args), strict=True): + enumerate(rec_args), strict=True): subscript_indices: list[ArithmeticExpression] = [] for iaxis, axis in enumerate(access_descr): if not are_shape_components_equal( arg.shape[iaxis], - expr._access_descr_to_axis_len()[axis]): + access_descr_to_axis_len[axis]): # axis is broadcasted assert are_shape_components_equal(arg.shape[iaxis], 1) subscript_indices.append(0) @@ -432,7 +458,7 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: immutabledict(redn_bounds)) return IndexLambda(expr=inner_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, bindings=immutabledict(bindings), axes=expr.axes, @@ -443,12 +469,14 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: def map_roll(self, expr: Roll) -> IndexLambda: from pytato.utils import dim_to_index_lambda_components + rec_array = self.rec(expr.array) + index_expr: prim.ExpressionNode = prim.Variable("_in0") indices: list[ArithmeticExpression] = [ prim.Variable(f"_{d}") for d in range(expr.ndim)] axis = expr.axis axis_len_expr, bindings = dim_to_index_lambda_components( - expr.shape[axis], + rec_array.shape[axis], UniqueNameGenerator({"_in0"})) # Mypy has a point: the type system does not prove that the operands are @@ -459,13 +487,12 @@ def map_roll(self, expr: Roll) -> IndexLambda: index_expr = index_expr[tuple(indices)] # type-ignore-reason: `bindings` was returned as Dict[str, SizeParam] - bindings["_in0"] = expr.array # type: ignore[assignment] + bindings["_in0"] = rec_array # type: ignore[assignment] return IndexLambda(expr=index_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, - bindings=immutabledict({name: self.rec(bnd) - for name, bnd in bindings.items()}), + bindings=immutabledict(bindings), axes=expr.axes, var_to_reduction_descr=immutabledict(), tags=expr.tags, @@ -476,27 +503,30 @@ def 
map_contiguous_advanced_index(self, ) -> IndexLambda: from pytato.utils import get_indexing_expression, get_shape_after_broadcasting + rec_array = self.rec(expr.array) + rec_indices = self.rec_idx_or_size_tuple(expr.indices) + i_adv_indices = tuple(i - for i, idx_expr in enumerate(expr.indices) + for i, idx_expr in enumerate(rec_indices) if isinstance(idx_expr, (Array, *INT_CLASSES))) adv_idx_shape = get_shape_after_broadcasting([ - cast("Array | int | np.integer[Any]", expr.indices[i_idx]) + cast("Array | int | np.integer[Any]", rec_indices[i_idx]) for i_idx in i_adv_indices]) vng = UniqueNameGenerator() indices: list[ArithmeticExpression] = [] in_ary = vng("in") - bindings = {in_ary: self.rec(expr.array)} + bindings = {in_ary: rec_array} islice_idx = 0 for i_idx, (idx, axis_len) in enumerate( - zip(expr.indices, expr.array.shape, strict=True)): + zip(rec_indices, rec_array.shape, strict=True)): if isinstance(idx, INT_CLASSES): if isinstance(axis_len, INT_CLASSES): indices.append(idx % axis_len) else: bnd_name = vng("in") - bindings[bnd_name] = self.rec(axis_len) + bindings[bnd_name] = axis_len indices.append(idx % prim.Variable(bnd_name)) elif isinstance(idx, NormalizedSlice): if normalized_slice_does_not_change_axis(idx, axis_len): @@ -508,7 +538,7 @@ def map_contiguous_advanced_index(self, elif isinstance(idx, Array): if isinstance(axis_len, INT_CLASSES): bnd_name = vng("in") - bindings[bnd_name] = self.rec(idx) + bindings[bnd_name] = idx indirect_idx_expr: ArithmeticExpression = prim.Subscript( prim.Variable(bnd_name), get_indexing_expression( @@ -536,7 +566,7 @@ def map_contiguous_advanced_index(self, return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), bindings=immutabledict(bindings), - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, axes=expr.axes, var_to_reduction_descr=immutabledict(), @@ -547,28 +577,32 @@ def map_contiguous_advanced_index(self, def map_non_contiguous_advanced_index( self, expr: AdvancedIndexInNoncontiguousAxes) -> IndexLambda: from pytato.utils import get_indexing_expression, get_shape_after_broadcasting + + rec_array = self.rec(expr.array) + rec_indices = self.rec_idx_or_size_tuple(expr.indices) + i_adv_indices = tuple(i - for i, idx_expr in enumerate(expr.indices) + for i, idx_expr in enumerate(rec_indices) if isinstance(idx_expr, (Array, *INT_CLASSES))) adv_idx_shape = get_shape_after_broadcasting([ - cast("Array | int | np.integer[Any]", expr.indices[i_idx]) + cast("Array | int | np.integer[Any]", rec_indices[i_idx]) for i_idx in i_adv_indices]) vng = UniqueNameGenerator() indices: list[ArithmeticExpression] = [] in_ary = vng("in") - bindings = {in_ary: self.rec(expr.array)} + bindings = {in_ary: rec_array} islice_idx = len(adv_idx_shape) - for idx, axis_len in zip(expr.indices, expr.array.shape, strict=True): + for idx, axis_len in zip(rec_indices, rec_array.shape, strict=True): if isinstance(idx, INT_CLASSES): if isinstance(axis_len, INT_CLASSES): indices.append(idx % axis_len) else: bnd_name = vng("in") - bindings[bnd_name] = self.rec(axis_len) + bindings[bnd_name] = axis_len indices.append(idx % prim.Variable(bnd_name)) elif isinstance(idx, NormalizedSlice): if normalized_slice_does_not_change_axis(idx, axis_len): @@ -580,7 +614,7 @@ def map_non_contiguous_advanced_index( elif isinstance(idx, Array): if isinstance(axis_len, INT_CLASSES): bnd_name = vng("in") - bindings[bnd_name] = self.rec(idx) + bindings[bnd_name] = idx indirect_idx_expr: ArithmeticExpression = 
prim.Subscript( prim.Variable(bnd_name), @@ -605,7 +639,7 @@ def map_non_contiguous_advanced_index( return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), bindings=immutabledict(bindings), - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, axes=expr.axes, var_to_reduction_descr=immutabledict(), @@ -614,20 +648,23 @@ def map_non_contiguous_advanced_index( ) def map_basic_index(self, expr: BasicIndex) -> IndexLambda: + rec_array = self.rec(expr.array) + rec_indices = self.rec_idx_or_size_tuple(expr.indices) + vng = UniqueNameGenerator() indices: list[ArithmeticExpression] = [] in_ary = vng("in") - bindings = {in_ary: self.rec(expr.array)} + bindings = {in_ary: rec_array} islice_idx = 0 - for idx, axis_len in zip(expr.indices, expr.array.shape, strict=True): + for idx, axis_len in zip(rec_indices, rec_array.shape, strict=True): if isinstance(idx, INT_CLASSES): if isinstance(axis_len, INT_CLASSES): indices.append(idx % axis_len) else: bnd_name = vng("in") - bindings[bnd_name] = self.rec(axis_len) + bindings[bnd_name] = axis_len indices.append(idx % prim.Variable(bnd_name)) elif isinstance(idx, NormalizedSlice): if normalized_slice_does_not_change_axis(idx, axis_len): @@ -642,7 +679,7 @@ def map_basic_index(self, expr: BasicIndex) -> IndexLambda: return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), bindings=immutabledict(bindings), - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, axes=expr.axes, var_to_reduction_descr=immutabledict(), @@ -651,18 +688,22 @@ def map_basic_index(self, expr: BasicIndex) -> IndexLambda: ) def map_reshape(self, expr: Reshape) -> IndexLambda: - indices = _get_reshaped_indices(expr) + rec_array = self.rec(expr.array) + rec_newshape = self.rec_idx_or_size_tuple(expr.shape) + indices = _get_reshaped_indices(expr.order, rec_array.shape, rec_newshape) index_expr = prim.Variable("_in0")[tuple(indices)] return IndexLambda(expr=index_expr, - shape=self._rec_shape(expr.shape), + shape=rec_newshape, dtype=expr.dtype, - bindings=immutabledict({"_in0": self.rec(expr.array)}), + bindings=immutabledict({"_in0": rec_array}), axes=expr.axes, var_to_reduction_descr=immutabledict(), tags=expr.tags, non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda: + rec_array = self.rec(expr.array) + indices: list[ArithmeticExpression | None] = [None] * expr.ndim for from_index, to_index in enumerate(expr.axis_permutation): indices[to_index] = prim.Variable(f"_{from_index}") @@ -671,9 +712,9 @@ def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda: cast("tuple[ArithmeticExpression]", tuple(indices))] return IndexLambda(expr=index_expr, - shape=self._rec_shape(expr.shape), + shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, - bindings=immutabledict({"_in0": self.rec(expr.array)}), + bindings=immutabledict({"_in0": rec_array}), axes=expr.axes, var_to_reduction_descr=immutabledict(), tags=expr.tags, From 0a805edbb35b3880d62412149a9f36d0e228b10d Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 3 Jul 2024 15:54:52 -0500 Subject: [PATCH 25/35] limit PlaceholderSubstitutor to one call stack frame --- pytato/transform/calls.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index 298b5351b..d9bc078fd 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -65,9 
+65,10 @@ def __init__(self, substitutions: Mapping[str, Array]) -> None: def map_placeholder(self, expr: Placeholder) -> Array: return self.substitutions[expr.name] - def map_named_call_result(self, expr: NamedCallResult) -> NamedCallResult: - raise NotImplementedError( - "PlaceholderSubstitutor does not support functions.") + def map_function_definition( + self, expr: FunctionDefinition) -> FunctionDefinition: + # Only operates within the current stack frame + return expr class Inliner(CopyMapper): From ffa4fcc14dd6cc8654eb354a280dab360de784e4 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 11 Jul 2024 23:02:13 -0500 Subject: [PATCH 26/35] tweak Inliner/PlaceholderSubstitutor implementations --- pytato/transform/calls.py | 55 ++++++++++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 12 deletions(-) diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index d9bc078fd..2cf8c6479 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -30,7 +30,9 @@ """ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast + +from typing_extensions import Self from pytato.array import ( AbstractResultWithNamedArrays, @@ -38,9 +40,14 @@ DictOfNamedArrays, Placeholder, ) -from pytato.function import Call, NamedCallResult +from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.tags import InlineCallTag -from pytato.transform import ArrayOrNames, CopyMapper, _verify_is_array +from pytato.transform import ( + ArrayOrNames, + CopyMapper, + TransformMapperCache, + _verify_is_array, +) if TYPE_CHECKING: @@ -55,6 +62,12 @@ class PlaceholderSubstitutor(CopyMapper): A mapping from the placeholder name to the array that it is to be substituted with. + + .. note:: + + This mapper does not deduplicate subexpressions that occur in both the mapped + expression and the substitutions. Must follow up with a + :class:`pytato.transform.Deduplicator` if duplicates need to be removed. """ def __init__(self, substitutions: Mapping[str, Array]) -> None: @@ -63,6 +76,9 @@ def __init__(self, substitutions: Mapping[str, Array]) -> None: self.substitutions = substitutions def map_placeholder(self, expr: Placeholder) -> Array: + # Can't call rec() to remove duplicates here, because the substituted-in + # expression may potentially contain unrelated placeholders whose names + # collide with the ones being replaced return self.substitutions[expr.name] def map_function_definition( @@ -75,21 +91,36 @@ class Inliner(CopyMapper): """ Primary mapper for :func:`inline_calls`. """ - def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: - # inline call sites within the callee. 
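# At its core, inlining a call site is placeholder-for-binding substitution:
# each Placeholder in the function body is replaced by the array bound to its
# name at the call site. A minimal sketch of just that idea, with a
# hypothetical helper name and all caching/deduplication concerns ignored:
#
#     def inline_one_call(call: Call) -> DictOfNamedArrays:
#         subst = PlaceholderSubstitutor(call.bindings)
#         return DictOfNamedArrays(
#             {name: subst(ret)
#              for name, ret in call.function.returns.items()},
#             tags=call.tags)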
- new_expr = super().map_call(expr) - assert isinstance(new_expr, Call) + def __init__( + self, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + # Must disable collision/duplication checking because we're combining + # expressions that were previously in two different call stack frames + # (and were thus cached separately) + super().__init__( + err_on_collision=False, + err_on_created_duplicate=False, + _cache=_cache, + _function_cache=_function_cache) + + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: if expr.tags_of_type(InlineCallTag): - substitutor = PlaceholderSubstitutor(new_expr.bindings) + substitutor = PlaceholderSubstitutor(expr.bindings) return DictOfNamedArrays( - {name: _verify_is_array(substitutor.rec(ret)) - for name, ret in new_expr.function.returns.items()}, - tags=new_expr.tags + {name: _verify_is_array(self.rec(substitutor(ret))) + for name, ret in expr.function.returns.items()}, + tags=expr.tags ) else: - return new_expr + return super().map_call(expr) def map_named_call_result(self, expr: NamedCallResult) -> Array: new_call_or_inlined_expr = self.rec(expr._container) From e58a283195a86ea9f671e9f5bb0ede1c6f8f1d5a Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 16 Jul 2024 13:31:48 -0500 Subject: [PATCH 27/35] use context manager to avoid leaking traceback tag setting in test --- test/test_pytato.py | 139 +++++++++++++++++++++++--------------------- 1 file changed, 73 insertions(+), 66 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index da176f124..e0e1f90d5 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -29,6 +29,7 @@ import dataclasses import sys +from contextlib import contextmanager import numpy as np import pytest @@ -932,111 +933,117 @@ def test_einsum_dot_axes_has_correct_dim(): assert len(einsum.axes) == einsum.ndim -def test_created_at(): - pt.set_traceback_tag_enabled() +@contextmanager +def enable_traceback_tag(): + try: + pt.set_traceback_tag_enabled(True) + yield + finally: + pt.set_traceback_tag_enabled(False) - a = pt.make_placeholder("a", (10, 10), "float64") - b = pt.make_placeholder("b", (10, 10), "float64") - # res1 and res2 are defined on different lines and should have different - # CreatedAt tags. - res1 = a+b - res2 = a+b +def test_created_at(): + with enable_traceback_tag(): + a = pt.make_placeholder("a", (10, 10), "float64") + b = pt.make_placeholder("b", (10, 10), "float64") - # res3 and res4 are defined on the same line and should have the same - # CreatedAt tags. - res3 = a+b; res4 = a+b # noqa: E702 + # res1 and res2 are defined on different lines and should have different + # CreatedAt tags. + res1 = a+b + res2 = a+b - # {{{ Check that CreatedAt tags are handled correctly for equality/hashing + # res3 and res4 are defined on the same line and should have the same + # CreatedAt tags. 
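# (CreatedAt lives in non_equality_tags, which participate in neither __eq__
# nor __hash__; that is why the equality and hash assertions below hold for
# res1 through res4 even when their CreatedAt tags differ.)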
+ res3 = a+b; res4 = a+b # noqa: E702 - assert res1 == res2 == res3 == res4 - assert hash(res1) == hash(res2) == hash(res3) == hash(res4) + # {{{ Check that CreatedAt tags are handled correctly for equality/hashing - assert res1.non_equality_tags != res2.non_equality_tags - assert res3.non_equality_tags == res4.non_equality_tags - assert hash(res1.non_equality_tags) != hash(res2.non_equality_tags) - assert hash(res3.non_equality_tags) == hash(res4.non_equality_tags) + assert res1 == res2 == res3 == res4 + assert hash(res1) == hash(res2) == hash(res3) == hash(res4) - assert res1.tags == res2.tags == res3.tags == res4.tags - assert hash(res1.tags) == hash(res2.tags) == hash(res3.tags) == hash(res4.tags) + assert res1.non_equality_tags != res2.non_equality_tags + assert res3.non_equality_tags == res4.non_equality_tags + assert hash(res1.non_equality_tags) != hash(res2.non_equality_tags) + assert hash(res3.non_equality_tags) == hash(res4.non_equality_tags) - # }}} + assert res1.tags == res2.tags == res3.tags == res4.tags + assert hash(res1.tags) == hash(res2.tags) == hash(res3.tags) == hash(res4.tags) - from pytato.tags import CreatedAt + # }}} - created_tag = frozenset({tag - for tag in res1.non_equality_tags - if isinstance(tag, CreatedAt)}) + from pytato.tags import CreatedAt - assert len(created_tag) == 1 + created_tag = frozenset({tag + for tag in res1.non_equality_tags + if isinstance(tag, CreatedAt)}) - # {{{ Make sure the function name appears in the traceback + assert len(created_tag) == 1 - tag, = created_tag + # {{{ Make sure the function name appears in the traceback - found = False + tag, = created_tag - stacksummary = tag.traceback.to_stacksummary() - assert len(stacksummary) > 10 + found = False - for frame in tag.traceback.frames: - if frame.name == "test_created_at" and "a+b" in frame.line: - found = True - break + stacksummary = tag.traceback.to_stacksummary() + assert len(stacksummary) > 10 - assert found + for frame in tag.traceback.frames: + if frame.name == "test_created_at" and "a+b" in frame.line: + found = True + break - # }}} + assert found - # {{{ Make sure that CreatedAt tags are in the visualization + # }}} - from pytato.visualization import get_dot_graph - s = get_dot_graph(res1) - assert "test_created_at" in s - assert "a+b" in s + # {{{ Make sure that CreatedAt tags are in the visualization - # }}} + from pytato.visualization import get_dot_graph + s = get_dot_graph(res1) + assert "test_created_at" in s + assert "a+b" in s - # {{{ Make sure only a single CreatedAt tag is created + # }}} - old_tag = tag + # {{{ Make sure only a single CreatedAt tag is created - res1 = res1 + res2 + old_tag = tag - created_tag = frozenset({tag - for tag in res1.non_equality_tags - if isinstance(tag, CreatedAt)}) + res1 = res1 + res2 - assert len(created_tag) == 1 + created_tag = frozenset({tag + for tag in res1.non_equality_tags + if isinstance(tag, CreatedAt)}) - tag, = created_tag + assert len(created_tag) == 1 - # Tag should be recreated - assert tag != old_tag + tag, = created_tag - # }}} + # Tag should be recreated + assert tag != old_tag - # {{{ Make sure that copying preserves the tag + # }}} - old_tag = tag + # {{{ Make sure that copying preserves the tag - res1_new = pt.transform.map_and_copy(res1, lambda x: x) + old_tag = tag - created_tag = frozenset({tag - for tag in res1_new.non_equality_tags - if isinstance(tag, CreatedAt)}) + res1_new = pt.transform.map_and_copy(res1, lambda x: x) - assert len(created_tag) == 1 + created_tag = frozenset({tag + for tag in 
res1_new.non_equality_tags + if isinstance(tag, CreatedAt)}) - tag, = created_tag + assert len(created_tag) == 1 - assert old_tag == tag + tag, = created_tag - # }}} + assert old_tag == tag - # {{{ Test disabling traceback creation + # }}} - pt.set_traceback_tag_enabled(False) + # {{{ Test disabling traceback creation a = pt.make_placeholder("a", (10, 10), "float64") From 6aa44ebae06b6d82c4f8c4da5d8261743c35a59f Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 16 Jul 2024 16:57:54 -0500 Subject: [PATCH 28/35] refactor FFTRealizationMapper to avoid resetting cache in __init__ --- test/test_apps.py | 71 +++++++++++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 33 deletions(-) diff --git a/test/test_apps.py b/test/test_apps.py index f39be848c..bdb3afc14 100644 --- a/test/test_apps.py +++ b/test/test_apps.py @@ -39,7 +39,7 @@ from pytools.tag import Tag, tag_dataclass import pytato as pt -from pytato.transform import CopyMapper, WalkMapper +from pytato.transform import CopyMapper, Deduplicator, WalkMapper # {{{ Trace an FFT @@ -78,40 +78,21 @@ def map_constant(self, expr): class FFTRealizationMapper(CopyMapper): - def __init__(self, fft_vec_gatherer): - super().__init__() - - self.fft_vec_gatherer = fft_vec_gatherer - - self.old_array_to_new_array = {} - levels = sorted(fft_vec_gatherer.level_to_arrays, reverse=True) - - lev = 0 - arrays = fft_vec_gatherer.level_to_arrays[lev] - self.finalized = False - - for lev in levels: - arrays = fft_vec_gatherer.level_to_arrays[lev] - rec_arrays = [self.rec(ary) for ary in arrays] - # reset cache so that the partial subs are not stored - self._cache.clear() - lev_array = pt.concatenate(rec_arrays, axis=0) - assert lev_array.shape == (fft_vec_gatherer.n,) - - startidx = 0 - for array in arrays: - size = array.shape[0] - sub_array = lev_array[startidx:startidx+size] - startidx += size - self.old_array_to_new_array[array] = sub_array - - assert startidx == fft_vec_gatherer.n - self.finalized = True + def __init__(self, old_array_to_new_array): + # Must use err_on_created_duplicate=False, because the use of ConstantSizer + # in map_index_lambda creates IndexLambdas that differ only in the type of + # their contained constants, which changes their identity but not their + # equality + super().__init__(err_on_created_duplicate=False) + self.old_array_to_new_array = old_array_to_new_array def map_index_lambda(self, expr): tags = expr.tags_of_type(FFTIntermediate) - if tags and (self.finalized or expr in self.old_array_to_new_array): - return self.old_array_to_new_array[expr] + if tags: + try: + return self.old_array_to_new_array[expr] + except KeyError: + pass return super().map_index_lambda( expr.copy(expr=ConstantSizer()(expr.expr))) @@ -122,6 +103,29 @@ def map_concatenate(self, expr): (ImplStored(), PrefixNamed("concat"))) +def make_fft_realization_mapper(fft_vec_gatherer): + old_array_to_new_array = {} + levels = sorted(fft_vec_gatherer.level_to_arrays, reverse=True) + + for lev in levels: + lev_mapper = FFTRealizationMapper(old_array_to_new_array) + arrays = fft_vec_gatherer.level_to_arrays[lev] + rec_arrays = [lev_mapper(ary) for ary in arrays] + lev_array = pt.concatenate(rec_arrays, axis=0) + assert lev_array.shape == (fft_vec_gatherer.n,) + + startidx = 0 + for array in arrays: + size = array.shape[0] + sub_array = lev_array[startidx:startidx+size] + startidx += size + old_array_to_new_array[array] = sub_array + + assert startidx == fft_vec_gatherer.n + + return FFTRealizationMapper(old_array_to_new_array) + + def 
test_trace_fft(ctx_factory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) @@ -134,10 +138,11 @@ def test_trace_fft(ctx_factory): wrap_intermediate_with_level=( lambda level, ary: ary.tagged(FFTIntermediate(level)))) + result = Deduplicator()(result) fft_vec_gatherer = FFTVectorGatherer(n) fft_vec_gatherer(result) - mapper = FFTRealizationMapper(fft_vec_gatherer) + mapper = make_fft_realization_mapper(fft_vec_gatherer) result = mapper(result) From 9c01112e93fa0e98247d62097320f45cf703af84 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 27 Aug 2024 16:10:14 -0500 Subject: [PATCH 29/35] add allow_duplicate_nodes option to RandomDAGContext in tests --- test/testlib.py | 47 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/test/testlib.py b/test/testlib.py index a28dec67e..7d58df480 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -101,6 +101,7 @@ def __init__( rng: np.random.Generator, axis_len: int, use_numpy: bool, + allow_duplicate_nodes: bool = False, additional_generators: ( Sequence[tuple[int, Callable[[RandomDAGContext], Array]]] | None) = None @@ -115,6 +116,7 @@ def __init__( self.axis_len = axis_len self.past_results: list[Array] = [] self.use_numpy = use_numpy + self.allow_duplicate_nodes = allow_duplicate_nodes if additional_generators is None: additional_generators = [] @@ -156,6 +158,14 @@ def make_random_reshape( def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: + if not rdagc.use_numpy and not rdagc.allow_duplicate_nodes: + def dedup(expr: Array) -> Array: + return pt.transform._verify_is_array(pt.transform.Deduplicator()(expr)) + + else: + def dedup(expr: Array) -> Array: + return expr + rng = rdagc.rng max_prob_hardcoded = 1500 @@ -166,7 +176,7 @@ def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: v = rng.integers(0, max_prob_hardcoded + additional_prob) if v < 600: - return make_random_constant(rdagc, naxes=rng.integers(1, 3)) + return dedup(make_random_constant(rdagc, naxes=rng.integers(1, 3))) elif v < 1000: op1 = make_random_dag(rdagc) @@ -189,9 +199,9 @@ def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: # just inserted a few new 1-long axes. Those need to go before we # return. 
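# (np.squeeze drops those length-1 axes again; wrapping the result in
# dedup(), the Deduplicator shim defined at the top of this function, then
# collapses equal subexpressions into a single object whenever
# allow_duplicate_nodes is off.)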
if which_op in ["maximum", "minimum"]: - return rdagc.np.squeeze(getattr(rdagc.np, which_op)(op1, op2)) + return dedup(rdagc.np.squeeze(getattr(rdagc.np, which_op)(op1, op2))) else: - return rdagc.np.squeeze(which_op(op1, op2)) + return dedup(rdagc.np.squeeze(which_op(op1, op2))) elif v < 1075: op1 = make_random_dag(rdagc) @@ -199,24 +209,26 @@ def make_random_dag_inner(rdagc: RandomDAGContext) -> Any: if op1.ndim <= 1 and op2.ndim <= 1: continue - return op1 @ op2 + return dedup(op1 @ op2) elif v < 1275: if not rdagc.past_results: continue - return rdagc.past_results[rng.integers(0, len(rdagc.past_results))] + return dedup( + rdagc.past_results[rng.integers(0, len(rdagc.past_results))]) elif v < max_prob_hardcoded: result = make_random_dag(rdagc) - return rdagc.np.transpose( + return dedup( + rdagc.np.transpose( result, - tuple(rng.permuted(list(range(result.ndim))))) + tuple(rng.permuted(list(range(result.ndim)))))) else: base_prob = max_prob_hardcoded for fake_prob, gen_func in rdagc.additional_generators: if base_prob <= v < base_prob + fake_prob: - return gen_func(rdagc) + return dedup(gen_func(rdagc)) base_prob += fake_prob @@ -237,6 +249,14 @@ def make_random_dag(rdagc: RandomDAGContext) -> Any: of the array are of length :attr:`RandomDAGContext.axis_len` (there is at least one axis, but arbitrarily more may be present). """ + if not rdagc.use_numpy and not rdagc.allow_duplicate_nodes: + def dedup(expr: Array) -> Array: + return pt.transform._verify_is_array(pt.transform.Deduplicator()(expr)) + + else: + def dedup(expr: Array) -> Array: + return expr + rng = rdagc.rng result = make_random_dag_inner(rdagc) @@ -248,14 +268,15 @@ def make_random_dag(rdagc: RandomDAGContext) -> Any: subscript[rng.integers(0, result.ndim)] = int( rng.integers(0, rdagc.axis_len)) - return result[tuple(subscript)] + return dedup(result[tuple(subscript)]) elif v == 1: # reduce away an axis # FIXME do reductions other than sum? 
- return rdagc.np.sum( - result, axis=int(rng.integers(0, result.ndim))) + return dedup( + rdagc.np.sum( + result, axis=int(rng.integers(0, result.ndim)))) else: raise AssertionError() @@ -275,7 +296,8 @@ def get_random_pt_dag(seed: int, Sequence[tuple[int, Callable[[RandomDAGContext], Array]]] | None) = None, axis_len: int = 4, - convert_dws_to_placeholders: bool = False + convert_dws_to_placeholders: bool = False, + allow_duplicate_nodes: bool = False ) -> pt.DictOfNamedArrays: if additional_generators is None: additional_generators = [] @@ -286,6 +308,7 @@ def get_random_pt_dag(seed: int, rdagc_comm = RandomDAGContext(np.random.default_rng(seed=seed), axis_len=axis_len, use_numpy=False, + allow_duplicate_nodes=allow_duplicate_nodes, additional_generators=additional_generators) dag = pt.make_dict_of_named_arrays({"result": make_random_dag(rdagc_comm)}) From ccb4188a41d12ba8818ee834b8a1698034cf841d Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 27 Aug 2024 16:10:44 -0500 Subject: [PATCH 30/35] fix some more tests --- pytato/utils.py | 2 ++ test/test_codegen.py | 17 ++++++++++------- test/test_distributed.py | 5 +++-- test/test_pytato.py | 10 ++++++---- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/pytato/utils.py b/pytato/utils.py index 31247897d..77cecc3bd 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -340,8 +340,10 @@ def are_shape_components_equal( if isinstance(dim1, INT_CLASSES) and isinstance(dim2, INT_CLASSES): return dim1 == dim2 + from pytato.transform import Deduplicator dim1_minus_dim2 = dim1 - dim2 assert isinstance(dim1_minus_dim2, Array) + dim1_minus_dim2 = Deduplicator()(dim1_minus_dim2) from pytato.transform import InputGatherer inputs = InputGatherer()(dim1_minus_dim2) diff --git a/test/test_codegen.py b/test/test_codegen.py index 2193b7fc9..a4ca538de 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -926,7 +926,7 @@ def _get_x_shape(_m, n_): x = pt.make_data_wrapper(x_in, shape=_get_x_shape(m, n)) np_out = np.einsum("ij, j -> i", A_in, x_in) - pt_expr = pt.einsum("ij, j -> i", A, x) + pt_expr = pt.transform.Deduplicator()(pt.einsum("ij, j -> i", A, x)) _, (pt_out,) = pt.generate_loopy(pt_expr)(cq, m=m_in, n=n_in) @@ -1582,8 +1582,9 @@ def get_np_input_args(): np_inputs = get_np_input_args() np_result = kernel(np, **np_inputs) - pt_dag = kernel(pt, **{kw: pt.make_data_wrapper(arg) - for kw, arg in np_inputs.items()}) + pt_dag = pt.transform.Deduplicator()( + kernel(pt, **{kw: pt.make_data_wrapper(arg) + for kw, arg in np_inputs.items()})) knl = pt.generate_loopy(pt_dag, options=lp.Options(write_code=True)) @@ -1939,10 +1940,12 @@ def build_expression(tracer): "baz": 65 * twice_x, "quux": 7 * twice_x_2} - result_with_functions = pt.tag_all_calls_to_be_inlined( - pt.make_dict_of_named_arrays(build_expression(pt.trace_call))) - result_without_functions = pt.make_dict_of_named_arrays( - build_expression(lambda fn, *args: fn(*args))) + expr = pt.transform.Deduplicator()( + pt.make_dict_of_named_arrays(build_expression(pt.trace_call))) + + result_with_functions = pt.tag_all_calls_to_be_inlined(expr) + result_without_functions = pt.transform.Deduplicator()( + pt.make_dict_of_named_arrays(build_expression(lambda fn, *args: fn(*args)))) # test that visualizing graphs with functions works dot = pt.get_dot_graph(result_with_functions) diff --git a/test/test_distributed.py b/test/test_distributed.py index d78479e08..65214c4b0 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -555,12 +555,13 @@ def 
_test_dag_with_multiple_send_nodes_per_sent_array_inner(ctx_factory): x_np = rng.random((10, 4)) x = pt.make_data_wrapper(cla.to_device(queue, x_np)) y = 2 * x + ones = pt.ones(10) send1 = pt.staple_distributed_send( y, dest_rank=1, comm_tag=42, - stapled_to=pt.ones(10)) + stapled_to=ones) send2 = pt.staple_distributed_send( y, dest_rank=2, comm_tag=42, - stapled_to=pt.ones(10)) + stapled_to=ones) z = 4 * y dag = pt.make_dict_of_named_arrays({"z": z, "send1": send1, "send2": send2}) else: diff --git a/test/test_pytato.py b/test/test_pytato.py index e0e1f90d5..7fe5a3b49 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -724,7 +724,7 @@ def test_small_dag_with_duplicates_count(): # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( - pt.transform.DependencyMapper()(dag)) + pt.transform.DependencyMapper(err_on_collision=False)(dag)) assert node_count - num_duplicates == get_num_nodes( dag, count_duplicates=False) @@ -761,7 +761,7 @@ def test_large_dag_with_duplicates_count(): # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( - pt.transform.DependencyMapper()(dag)) + pt.transform.DependencyMapper(err_on_collision=False)(dag)) assert node_count - num_duplicates == get_num_nodes( dag, count_duplicates=False) @@ -806,6 +806,8 @@ def post_visit(self, expr): assert expr.name == "x" expr, inp = construct_intestine_graph() + expr = pt.transform.Deduplicator()(expr) + result = pt.transform.rec_get_user_nodes(expr, inp) SubexprRecorder()(expr) @@ -1029,7 +1031,7 @@ def test_created_at(): old_tag = tag - res1_new = pt.transform.map_and_copy(res1, lambda x: x) + res1_new = pt.transform.Deduplicator()(res1) created_tag = frozenset({tag for tag in res1_new.non_equality_tags @@ -1167,7 +1169,7 @@ class ExistentTag(Tag): out = make_random_dag(rdagc_pt).tagged(ExistentTag()) - dag = pt.make_dict_of_named_arrays({"out": out}) + dag = pt.transform.Deduplicator()(pt.make_dict_of_named_arrays({"out": out})) # get_num_nodes() returns an extra DictOfNamedArrays node assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag) From 7258ccf17350c3988c27522977bd5b7a88cf2216 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 23 Sep 2024 22:51:56 -0500 Subject: [PATCH 31/35] don't check for collisions in ArrayToDotNodeInfoMapper --- pytato/visualization/dot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index c0c3e7945..7420d1708 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -178,9 +178,10 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" +# FIXME: Make this inherit from CachedWalkMapper instead? 
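# err_on_collision controls whether a mapper's cache raises when it sees two
# expressions that are equal but not identical. Visualization must tolerate
# DAGs that still contain such duplicates, hence the opt-out below. A rough
# sketch of the difference, assuming dag holds equal-but-distinct subtrees:
#
#     pt.transform.DependencyMapper(err_on_collision=False)(dag)  # tolerates
#     pt.transform.DependencyMapper(err_on_collision=True)(dag)   # would raise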
 class ArrayToDotNodeInfoMapper(CachedMapper[None, None, []]):
     def __init__(self) -> None:
-        super().__init__()
+        super().__init__(err_on_collision=False)
         self.node_to_dot: dict[ArrayOrNames, _DotNodeInfo] = {}
         self.functions: set[FunctionDefinition] = set()

From 177385cc6c92e7ee9246e7fc712025eb66714b64 Mon Sep 17 00:00:00 2001
From: Matthew Smith
Date: Wed, 18 Sep 2024 15:40:25 -0500
Subject: [PATCH 32/35] avoid duplication in MPMSMaterializer

MPMSMaterializer now inherits from CachedMapper.

---
 pytato/transform/__init__.py | 319 +++++++++++++++++++++++++----------
 1 file changed, 231 insertions(+), 88 deletions(-)

diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py
index 16496063b..209c065b5 100644
--- a/pytato/transform/__init__.py
+++ b/pytato/transform/__init__.py
@@ -527,6 +527,31 @@ def clone_for_callee(
 
 # {{{ TransformMapper
 
+def _is_mapper_created_duplicate(expr: CacheExprT, result: CacheExprT) -> bool:
+    """Returns *True* if *result* is not identical to *expr* when it ought to be."""
+    from pytato.analysis import DirectPredecessorsGetter
+    pred_getter = DirectPredecessorsGetter(include_functions=True)
+    return (
+        hash(result) == hash(expr)
+        and result == expr
+        and result is not expr
+        # Only consider "direct" duplication, not duplication resulting from
+        # equality-preserving changes to predecessors. Assume that such changes are
+        # OK, otherwise they would have been detected at the point at which they
+        # originated. (For example, consider a DAG containing pre-existing
+        # duplicates. If a subexpression of *expr* is a duplicate and is replaced
+        # with a previously encountered version from the cache, a new instance of
+        # *expr* must be created. This should not trigger an error.)
+        and all(
+            result_pred is pred
+            for pred, result_pred in zip(
+                # type-ignore-reason: mypy doesn't seem to recognize overloaded
+                # Mapper.__call__ here
+                pred_getter(expr),  # type: ignore[arg-type]
+                pred_getter(result),  # type: ignore[arg-type]
+                strict=True)))
+
+
 class TransformMapperCache(CachedMapperCache[CacheExprT, CacheExprT, P]):
     """
     Cache for :class:`TransformMapper` and :class:`TransformMapperWithExtraArgs`.
@@ -570,31 +595,10 @@ def add(
         try:
             result = self._result_to_cached_result[result]
         except KeyError:
-            if self.err_on_created_duplicate:
-                from pytato.analysis import DirectPredecessorsGetter
-                pred_getter = DirectPredecessorsGetter(include_functions=True)
-                if (
-                        hash(result) == hash(inputs.expr)
-                        and result == inputs.expr
-                        and result is not inputs.expr
-                        # Only consider "direct" duplication, not duplication
-                        # resulting from equality-preserving changes to predecessors.
-                        # Assume that such changes are OK, otherwise they would have
-                        # been detected at the point at which they originated. (For
-                        # example, consider a DAG containing pre-existing duplicates.
-                        # If a subexpression of *expr* is a duplicate and is replaced
-                        # with a previously encountered version from the cache, a
-                        # new instance of *expr* must be created. This should not
-                        # trigger an error.)
- and all( - result_pred is pred - for pred, result_pred in zip( - # type-ignore-reason: mypy doesn't seem to recognize - # overloaded Mapper.__call__ here - pred_getter(inputs.expr), # type: ignore[arg-type] - pred_getter(result), # type: ignore[arg-type] - strict=True))): - raise MapperCreatedDuplicateError from None + if ( + self.err_on_created_duplicate + and _is_mapper_created_duplicate(inputs.expr, result)): + raise MapperCreatedDuplicateError from None self._result_to_cached_result[result] = result @@ -2025,6 +2029,65 @@ class MPMSMaterializerAccumulator: expr: Array +class MPMSMaterializerCache( + CachedMapperCache[ArrayOrNames, MPMSMaterializerAccumulator, []]): + """ + Cache for :class:`MPMSMaterializer`. + + .. automethod:: __init__ + .. automethod:: add + """ + def __init__( + self, + err_on_collision: bool, + err_on_created_duplicate: bool) -> None: + """ + Initialize the cache. + + :arg err_on_collision: Raise an exception if two distinct input expression + instances have the same key. + :arg err_on_created_duplicate: Raise an exception if mapping produces a new + array instance that has the same key as the input array. + """ + super().__init__(err_on_collision=err_on_collision) + + self.err_on_created_duplicate = err_on_created_duplicate + + self._result_key_to_result: dict[ + ArrayOrNames, MPMSMaterializerAccumulator] = {} + + def add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, []], + result: MPMSMaterializerAccumulator) -> MPMSMaterializerAccumulator: + """ + Cache a mapping result. + + Returns the cached result (which may not be identical to *result* if a + result was already cached with the same result key). + """ + key = inputs.key + + assert key not in self._input_key_to_result, \ + f"Cache entry is already present for key '{key}'." + + try: + result = self._result_key_to_result[result.expr] + except KeyError: + if ( + self.err_on_created_duplicate + and _is_mapper_created_duplicate(inputs.expr, result.expr)): + raise MapperCreatedDuplicateError from None + + self._result_key_to_result[result.expr] = result + + self._input_key_to_result[key] = result + if self.err_on_collision: + self._input_key_to_expr[key] = inputs.expr + + return result + + def _materialize_if_mpms(expr: Array, nsuccessors: int, predecessors: Iterable[MPMSMaterializerAccumulator] @@ -2042,13 +2105,16 @@ def _materialize_if_mpms(expr: Array, for pred in predecessors), frozenset()) if nsuccessors > 1 and len(materialized_predecessors) > 1: - new_expr = expr.tagged(ImplStored()) + if not expr.tags_of_type(ImplStored): + new_expr = expr.tagged(ImplStored()) + else: + new_expr = expr return MPMSMaterializerAccumulator(frozenset([new_expr]), new_expr) else: return MPMSMaterializerAccumulator(materialized_predecessors, expr) -class MPMSMaterializer(Mapper[MPMSMaterializerAccumulator, Never, []]): +class MPMSMaterializer(CachedMapper[MPMSMaterializerAccumulator, Never, []]): """ See :func:`materialize_with_mpms` for an explanation. @@ -2057,17 +2123,41 @@ class MPMSMaterializer(Mapper[MPMSMaterializerAccumulator, Never, []]): A mapping from a node in the expression graph (i.e. an :class:`~pytato.Array`) to its number of successors. 
""" - def __init__(self, nsuccessors: Mapping[Array, int]): - super().__init__() + def __init__( + self, + nsuccessors: Mapping[Array, int], + _cache: MPMSMaterializerCache | None = None): + err_on_collision = False + err_on_created_duplicate = False + + if _cache is None: + _cache = MPMSMaterializerCache( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate) + + # Does not support functions, so function_cache is ignored + super().__init__(err_on_collision=err_on_collision, _cache=_cache) + self.nsuccessors = nsuccessors - self.cache: dict[ArrayOrNames, MPMSMaterializerAccumulator] = {} - def rec(self, expr: ArrayOrNames) -> MPMSMaterializerAccumulator: - if expr in self.cache: - return self.cache[expr] - result: MPMSMaterializerAccumulator = super().rec(expr) - self.cache[expr] = result - return result + def _cache_add( + self, + inputs: CacheInputsWithKey[ArrayOrNames, []], + result: MPMSMaterializerAccumulator) -> MPMSMaterializerAccumulator: + try: + return self._cache.add(inputs, result) + except MapperCreatedDuplicateError as e: + raise ValueError( + f"no-op duplication detected on {type(inputs.expr)} in " + f"{type(self)}.") from e + + def clone_for_callee( + self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + raise AssertionError("Control shouldn't reach this point.") def _map_input_base(self, expr: InputArgumentBase ) -> MPMSMaterializerAccumulator: @@ -2084,24 +2174,40 @@ def map_named_array(self, expr: NamedArray) -> MPMSMaterializerAccumulator: def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator: children_rec = {bnd_name: self.rec(bnd) for bnd_name, bnd in sorted(expr.bindings.items())} + new_children: Mapping[str, Array] = immutabledict({ + bnd_name: bnd.expr + for bnd_name, bnd in sorted(children_rec.items())}) + + if ( + frozenset(new_children.keys()) == frozenset(expr.bindings.keys()) + and all( + new_children[name] is expr.bindings[name] + for name in expr.bindings)): + new_expr = expr + else: + new_expr = IndexLambda( + expr=expr.expr, + shape=expr.shape, + dtype=expr.dtype, + bindings=new_children, + axes=expr.axes, + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) - new_expr = IndexLambda(expr=expr.expr, - shape=expr.shape, - dtype=expr.dtype, - bindings=immutabledict({bnd_name: bnd.expr - for bnd_name, bnd in sorted(children_rec.items())}), - axes=expr.axes, - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], children_rec.values()) def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: rec_arrays = [self.rec(ary) for ary in expr.arrays] - new_expr = Stack(tuple(ary.expr for ary in rec_arrays), - expr.axis, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(ary.expr for ary in rec_arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + new_expr = expr + else: + new_expr = Stack(new_arrays, expr.axis, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -2109,29 +2215,44 @@ def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator: rec_arrays = 
[self.rec(ary) for ary in expr.arrays] - new_expr = Concatenate(tuple(ary.expr for ary in rec_arrays), - expr.axis, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_arrays = tuple(ary.expr for ary in rec_arrays) + if all( + new_ary is ary + for ary, new_ary in zip(expr.arrays, new_arrays, strict=True)): + new_expr = expr + else: + new_expr = Concatenate(new_arrays, + expr.axis, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + return _materialize_if_mpms(new_expr, self.nsuccessors[expr], rec_arrays) def map_roll(self, expr: Roll) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) - new_expr = Roll(rec_array.expr, expr.shift, expr.axis, axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if rec_array.expr is expr.array: + new_expr = expr + else: + new_expr = Roll(rec_array.expr, expr.shift, expr.axis, axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) def map_axis_permutation(self, expr: AxisPermutation ) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) - new_expr = AxisPermutation(rec_array.expr, expr.axis_permutation, - axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if rec_array.expr is expr.array: + new_expr = expr + else: + new_expr = AxisPermutation(rec_array.expr, expr.axis_permutation, + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) @@ -2141,16 +2262,23 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: rec_indices = {i: self.rec(idx) for i, idx in enumerate(expr.indices) if isinstance(idx, Array)} - - new_expr = type(expr)(rec_array.expr, - tuple(rec_indices[i].expr - if i in rec_indices - else expr.indices[i] - for i in range( - len(expr.indices))), - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + new_indices = tuple(rec_indices[i].expr + if i in rec_indices + else expr.indices[i] + for i in range( + len(expr.indices))) + if ( + rec_array.expr is expr.array + and all( + new_idx is idx + for idx, new_idx in zip(expr.indices, new_indices, strict=True))): + new_expr = expr + else: + new_expr = type(expr)(rec_array.expr, + new_indices, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -2163,26 +2291,35 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: def map_reshape(self, expr: Reshape) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) - new_expr = Reshape(rec_array.expr, expr.newshape, - expr.order, axes=expr.axes, tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + if rec_array.expr is expr.array: + new_expr = expr + else: + new_expr = Reshape(rec_array.expr, expr.newshape, + expr.order, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator: - rec_arrays = [self.rec(ary) for ary in expr.args] - new_expr = Einsum(expr.access_descriptors, - tuple(ary.expr for ary in rec_arrays), - expr.redn_axis_to_redn_descr, - axes=expr.axes, - tags=expr.tags, - non_equality_tags=expr.non_equality_tags) + rec_args = [self.rec(ary) for ary in expr.args] + new_args = 
tuple(ary.expr for ary in rec_args) + if all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)): + new_expr = expr + else: + new_expr = Einsum(expr.access_descriptors, + new_args, + expr.redn_axis_to_redn_descr, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], - rec_arrays) + rec_args) def map_dict_of_named_arrays(self, expr: DictOfNamedArrays ) -> MPMSMaterializerAccumulator: @@ -2195,15 +2332,21 @@ def map_loopy_call_result(self, expr: NamedArray) -> MPMSMaterializerAccumulator def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder ) -> MPMSMaterializerAccumulator: - rec_passthrough = self.rec(expr.passthrough_data) rec_send_data = self.rec(expr.send.data) - new_expr = DistributedSendRefHolder( - send=DistributedSend(rec_send_data.expr, - dest_rank=expr.send.dest_rank, - comm_tag=expr.send.comm_tag, - tags=expr.send.tags), - passthrough_data=rec_passthrough.expr, - ) + if rec_send_data.expr is expr.send.data: + new_send = expr.send + else: + new_send = DistributedSend( + rec_send_data.expr, + dest_rank=expr.send.dest_rank, + comm_tag=expr.send.comm_tag, + tags=expr.send.tags) + rec_passthrough = self.rec(expr.passthrough_data) + if new_send is expr.send and rec_passthrough.expr is expr.passthrough_data: + new_expr = expr + else: + new_expr = DistributedSendRefHolder(new_send, rec_passthrough.expr) + return MPMSMaterializerAccumulator( rec_passthrough.materialized_predecessors, new_expr) From 14cc4d695e55802440bdadf9df796987796bf25b Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 6 Feb 2025 21:43:35 -0600 Subject: [PATCH 33/35] avoid duplicates in EinsumWithNoBroadcastsRewriter --- pytato/transform/remove_broadcasts_einsum.py | 110 +++++++++++++++---- 1 file changed, 91 insertions(+), 19 deletions(-) diff --git a/pytato/transform/remove_broadcasts_einsum.py b/pytato/transform/remove_broadcasts_einsum.py index 2d8f7e0f0..50ee4967c 100644 --- a/pytato/transform/remove_broadcasts_einsum.py +++ b/pytato/transform/remove_broadcasts_einsum.py @@ -28,46 +28,118 @@ THE SOFTWARE. """ -from typing import cast +from typing import TYPE_CHECKING, cast from pytato.array import Array, Einsum, EinsumAxisDescriptor -from pytato.transform import CopyMapper, MappedT, _verify_is_array +from pytato.transform import ( + ArrayOrNames, + CacheKeyT, + CopyMapperWithExtraArgs, + MappedT, + Mapper, + _verify_is_array, +) from pytato.utils import are_shape_components_equal -class EinsumWithNoBroadcastsRewriter(CopyMapper): - def map_einsum(self, expr: Einsum) -> Array: +if TYPE_CHECKING: + from pytato.function import FunctionDefinition + + +class EinsumWithNoBroadcastsRewriter(CopyMapperWithExtraArgs[[tuple[int, ...] | None]]): + def get_cache_key( + self, + expr: ArrayOrNames, + axes_to_squeeze: tuple[int, ...] | None = None + ) -> CacheKeyT: + return (expr, axes_to_squeeze) + + def get_function_definition_cache_key( + self, + expr: FunctionDefinition, + axes_to_squeeze: tuple[int, ...] | None = None + ) -> CacheKeyT: + assert axes_to_squeeze is None + return expr + + def _squeeze_axes( + self, + expr: Array, + axes_to_squeeze: tuple[int, ...] | None = None) -> Array: + result = ( + expr[ + tuple( + slice(None) if idim not in axes_to_squeeze else 0 + for idim in range(expr.ndim))] + if axes_to_squeeze else expr) + return result + + def rec( + self, + expr: ArrayOrNames, + axes_to_squeeze: tuple[int, ...] 
| None = None) -> ArrayOrNames: + inputs = self._make_cache_inputs(expr, axes_to_squeeze) + try: + return self._cache_retrieve(inputs) + except KeyError: + rec_result: ArrayOrNames = Mapper.rec(self, expr, None) + result: ArrayOrNames + if isinstance(expr, Array): + result = self._squeeze_axes( + _verify_is_array(rec_result), + axes_to_squeeze) + else: + result = rec_result + return self._cache_add(inputs, result) + + def map_einsum( + self, expr: Einsum, axes_to_squeeze: tuple[int, ...] | None) -> Array: new_args: list[Array] = [] new_access_descriptors: list[tuple[EinsumAxisDescriptor, ...]] = [] descr_to_axis_len = expr._access_descr_to_axis_len() - for acc_descrs, arg in zip(expr.access_descriptors, expr.args, strict=True): - arg = _verify_is_array(self.rec(arg)) - axes_to_squeeze: list[int] = [] + for arg, acc_descrs in zip(expr.args, expr.access_descriptors, strict=True): + axes_to_squeeze_list: list[int] = [] for idim, acc_descr in enumerate(acc_descrs): if not are_shape_components_equal(arg.shape[idim], descr_to_axis_len[acc_descr]): assert are_shape_components_equal(arg.shape[idim], 1) - axes_to_squeeze.append(idim) + axes_to_squeeze_list.append(idim) + axes_to_squeeze = tuple(axes_to_squeeze_list) if axes_to_squeeze: - arg = arg[tuple(slice(None) if idim not in axes_to_squeeze else 0 - for idim in range(arg.ndim))] - acc_descrs = tuple(acc_descr + new_arg = _verify_is_array(self.rec(arg, axes_to_squeeze)) + new_acc_descrs = tuple(acc_descr for idim, acc_descr in enumerate(acc_descrs) if idim not in axes_to_squeeze) + else: + new_arg = _verify_is_array(self.rec(arg)) + new_acc_descrs = acc_descrs - new_args.append(arg) - new_access_descriptors.append(acc_descrs) + new_args.append(new_arg) + new_access_descriptors.append(new_acc_descrs) assert len(new_args) == len(expr.args) assert len(new_access_descriptors) == len(expr.access_descriptors) - return Einsum(tuple(new_access_descriptors), - tuple(new_args), - expr.redn_axis_to_redn_descr, - tags=expr.tags, - axes=expr.axes,) + if ( + all( + new_arg is arg + for arg, new_arg in zip(expr.args, new_args, strict=True)) + and all( + new_acc_descr is acc_descr + for acc_descr, new_acc_descr in zip( + expr.access_descriptors, + new_access_descriptors, + strict=True))): + return expr + else: + return Einsum(tuple(new_access_descriptors), + tuple(new_args), + axes=expr.axes, + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def rewrite_einsums_with_no_broadcasts(expr: MappedT) -> MappedT: @@ -97,6 +169,6 @@ def rewrite_einsums_with_no_broadcasts(expr: MappedT) -> MappedT: alter its value. 
""" mapper = EinsumWithNoBroadcastsRewriter() - return cast("MappedT", mapper(expr)) + return cast("MappedT", mapper(expr, None)) # vim:fdm=marker From 0cd78a1fcb2c6c6e14c8099906eb14947ea48c43 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 6 Feb 2025 22:01:19 -0600 Subject: [PATCH 34/35] forbid DependencyMapper from being called on functions --- pytato/transform/__init__.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 209c065b5..9f6e57db5 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1507,7 +1507,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> ResultT: # {{{ DependencyMapper -class DependencyMapper(CombineMapper[R, R]): +class DependencyMapper(CombineMapper[R, Never]): """ Maps a :class:`pytato.array.Array` to a :class:`frozenset` of :class:`pytato.array.Array`'s it depends on. @@ -1569,14 +1569,10 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> R: return self.combine(frozenset([expr]), super().map_distributed_recv(expr)) - def map_function_definition(self, expr: FunctionDefinition) -> R: + def map_call(self, expr: Call) -> R: # do not include arrays from the function's body as it would involve # putting arrays from different namespaces into the same collection. - return frozenset() - - def map_call(self, expr: Call) -> R: - return self.combine(self.rec_function_definition(expr.function), - *[self.rec(bnd) for bnd in expr.bindings.values()]) + return self.combine(*[self.rec(bnd) for bnd in expr.bindings.values()]) def map_named_call_result(self, expr: NamedCallResult) -> R: return self.rec(expr._container) From 2ca7a65971e3806e4c7952c77086156ebdecc335 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Thu, 27 Feb 2025 17:59:51 -0600 Subject: [PATCH 35/35] deduplicate in advection example --- examples/advection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/advection.py b/examples/advection.py index 339ff80a8..fd308ae50 100755 --- a/examples/advection.py +++ b/examples/advection.py @@ -156,6 +156,7 @@ def test_advection_convergence(order, flux_type): op = AdvectionOperator(discr, c=1, flux_type=flux_type, dg_ops=dg_ops) result = op.apply(u) + result = pt.transform.Deduplicator()(result) prog = pt.generate_loopy(result, cl_device=queue.device)