From 06b42959bd33d7b3eca5def055f0bbf71aa55ddc Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 30 Jan 2025 16:04:35 -0600 Subject: [PATCH] add TagCountMapper (#326) --- pytato/analysis/__init__.py | 57 ++++++++++++++++++++++++++++++++++-- pytato/transform/__init__.py | 12 ++++++-- test/test_pytato.py | 37 +++++++++++++++++++++++ 3 files changed, 101 insertions(+), 5 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 2d91e51eb..6db21e863 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,12 +26,12 @@ THE SOFTWARE. """ -from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Never from orderedsets import FrozenOrderedSet from typing_extensions import Self +import pytools from loopy.tools import LoopyKeyBuilder from pymbolic.mapper.optimize import optimize_mapper @@ -49,11 +49,11 @@ Stack, ) from pytato.function import Call, FunctionDefinition, NamedCallResult -from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper +from pytato.transform import ArrayOrNames, CachedWalkMapper, CombineMapper, Mapper, P if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Iterable, Mapping from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.loopy import LoopyCall @@ -74,6 +74,9 @@ .. autofunction:: get_num_call_sites .. autoclass:: DirectPredecessorsGetter + +.. autoclass:: TagCountMapper +.. autofunction:: get_num_tags_of_type """ @@ -594,6 +597,54 @@ def get_num_call_sites(outputs: Array | DictOfNamedArrays) -> int: # }}} +# {{{ TagCountMapper + +class TagCountMapper(CombineMapper[int, Never]): + """ + Returns the number of nodes in a DAG that are tagged with all the tags in *tags*. + """ + + def __init__(self, tags: pytools.tag.Tag | Iterable[pytools.tag.Tag]) -> None: + super().__init__() + if isinstance(tags, pytools.tag.Tag): + tags = frozenset((tags,)) + elif not isinstance(tags, frozenset): + tags = frozenset(tags) + self._tags = tags + + def combine(self, *args: int) -> int: + return sum(args) + + def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> int: + key = self._cache.get_key(expr, *args, **kwargs) + try: + return self._cache.retrieve((expr, args, kwargs), key=key) + except KeyError: + s = super().rec(expr, *args, **kwargs) + if isinstance(expr, Array) and self._tags <= expr.tags: + result = 1 + s + else: + result = 0 + s + + self._cache.add((expr, args, kwargs), + 0, + key=key) + return result + + +def get_num_tags_of_type( + outputs: Array | DictOfNamedArrays, + tags: pytools.tag.Tag | Iterable[pytools.tag.Tag]) -> int: + """Returns the number of nodes in DAG *outputs* that are tagged with + all the tags in *tags*.""" + + tcm = TagCountMapper(tags) + + return tcm(outputs) + +# }}} + + # {{{ PytatoKeyBuilder class PytatoKeyBuilder(LoopyKeyBuilder): diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index dbab9148c..942cf4f48 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1809,13 +1809,21 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: ====== ======== ======= """ - from pytato.analysis import get_nusers + from pytato.analysis import get_num_nodes, get_num_tags_of_type, get_nusers materializer = MPMSMaterializer(get_nusers(expr)) new_data = {} for name, ary in expr.items(): new_data[name] = materializer(ary.expr).expr - return DictOfNamedArrays(new_data, tags=expr.tags) + res = DictOfNamedArrays(new_data, tags=expr.tags) + + from pytato import DEBUG_ENABLED + if DEBUG_ENABLED: + transform_logger.info("materialize_with_mpms: materialized " + f"{get_num_tags_of_type(res, ImplStored())} out of " + f"{get_num_nodes(res)} nodes") + + return res # }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index 40f1e7b40..3d16e28d9 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1139,6 +1139,43 @@ def test_adv_indexing_into_zero_long_axes(): # }}} +def test_tagcountmapper(): + from testlib import RandomDAGContext, make_random_dag + + from pytools.tag import Tag + + from pytato.analysis import get_num_nodes, get_num_tags_of_type + + class NonExistentTag(Tag): + pass + + class ExistentTag(Tag): + pass + + seed = 199 + axis_len = 3 + + rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed), + axis_len=axis_len, use_numpy=False) + + out = make_random_dag(rdagc_pt).tagged(ExistentTag()) + + dag = pt.make_dict_of_named_arrays({"out": out}) + + # get_num_nodes() returns an extra DictOfNamedArrays node + assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag) + + assert get_num_tags_of_type(dag, NonExistentTag()) == 0 + assert get_num_tags_of_type(dag, frozenset((ExistentTag(),))) == 1 + assert get_num_tags_of_type(dag, + frozenset((ExistentTag(), NonExistentTag()))) == 0 + + a = pt.make_data_wrapper(np.arange(27)) + dag = a+a+a+a+a+a+a+a + + assert get_num_tags_of_type(dag, frozenset()) == get_num_nodes(dag) + + def test_expand_dims_input_validate(): a = pt.make_placeholder("x", (10, 4))