diff --git a/doc/conf.py b/doc/conf.py index 6966a6062..19cfe8581 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -30,6 +30,7 @@ "jax": ("https://jax.readthedocs.io/en/latest/", None), "mpi4py": ("https://mpi4py.readthedocs.io/en/latest", None), "immutabledict": ("https://immutabledict.corenting.fr/", None), + "orderedsets": ("https://matthiasdiener.github.io/orderedsets", None), } # Some modules need to import things just so that sphinx can resolve symbols in diff --git a/pyproject.toml b/pyproject.toml index b91a3568e..38019888c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "pytools>=2024.1.21", "pymbolic>=2024.2", "typing_extensions>=4", + "orderedsets", ] [project.urls] diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index fc2435db9..ce21160dd 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,8 +26,10 @@ THE SOFTWARE. """ +from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Never +from orderedsets import FrozenOrderedSet from typing_extensions import Self from loopy.tools import LoopyKeyBuilder @@ -329,37 +331,37 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]): We only consider the predecessors of a nodes in a data-flow sense. """ - def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[ArrayOrNames]: - return frozenset({dim for dim in shape if isinstance(dim, Array)}) + def _get_preds_from_shape(self, shape: ShapeType) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet(dim for dim in shape if isinstance(dim, Array)) - def map_index_lambda(self, expr: IndexLambda) -> frozenset[ArrayOrNames]: - return (frozenset(expr.bindings.values()) + def map_index_lambda(self, expr: IndexLambda) -> FrozenOrderedSet[ArrayOrNames]: + return (FrozenOrderedSet(expr.bindings.values()) | self._get_preds_from_shape(expr.shape)) - def map_stack(self, expr: Stack) -> frozenset[ArrayOrNames]: - return (frozenset(expr.arrays) + def map_stack(self, expr: Stack) -> FrozenOrderedSet[ArrayOrNames]: + return (FrozenOrderedSet(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_concatenate(self, expr: Concatenate) -> frozenset[ArrayOrNames]: - return (frozenset(expr.arrays) + def map_concatenate(self, expr: Concatenate) -> FrozenOrderedSet[ArrayOrNames]: + return (FrozenOrderedSet(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_einsum(self, expr: Einsum) -> frozenset[ArrayOrNames]: - return (frozenset(expr.args) + def map_einsum(self, expr: Einsum) -> FrozenOrderedSet[ArrayOrNames]: + return (FrozenOrderedSet(expr.args) | self._get_preds_from_shape(expr.shape)) - def map_loopy_call_result(self, expr: NamedArray) -> frozenset[ArrayOrNames]: + def map_loopy_call_result(self, expr: NamedArray) -> FrozenOrderedSet[ArrayOrNames]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) - return (frozenset(ary + return (FrozenOrderedSet(ary for ary in expr._container.bindings.values() if isinstance(ary, Array)) | self._get_preds_from_shape(expr.shape)) - def _map_index_base(self, expr: IndexBase) -> frozenset[ArrayOrNames]: - return (frozenset([expr.array]) - | frozenset(idx for idx in expr.indices + def _map_index_base(self, expr: IndexBase) -> FrozenOrderedSet[ArrayOrNames]: + return (FrozenOrderedSet([expr.array]) + | FrozenOrderedSet(idx for idx in expr.indices if isinstance(idx, Array)) | self._get_preds_from_shape(expr.shape)) @@ -368,34 +370,36 @@ def _map_index_base(self, expr: IndexBase) -> frozenset[ArrayOrNames]: map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> frozenset[ArrayOrNames]: - return frozenset([expr.array]) + ) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet([expr.array]) map_roll = _map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> frozenset[ArrayOrNames]: + def _map_input_base(self, expr: InputArgumentBase) \ + -> FrozenOrderedSet[ArrayOrNames]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[ArrayOrNames]: + def map_distributed_recv(self, + expr: DistributedRecv) -> FrozenOrderedSet[ArrayOrNames]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> frozenset[ArrayOrNames]: - return frozenset([expr.passthrough_data]) + ) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet([expr.passthrough_data]) - def map_call(self, expr: Call) -> frozenset[ArrayOrNames]: - return frozenset(expr.bindings.values()) + def map_call(self, expr: Call) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet(expr.bindings.values()) def map_named_call_result( - self, expr: NamedCallResult) -> frozenset[ArrayOrNames]: - return frozenset([expr._container]) + self, expr: NamedCallResult) -> FrozenOrderedSet[ArrayOrNames]: + return FrozenOrderedSet([expr._container]) # }}} diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 216403191..7662c7205 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -62,20 +62,19 @@ THE SOFTWARE. """ -import collections import dataclasses -from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence, Set +from collections.abc import Hashable, Mapping, Sequence, Set from functools import reduce from typing import ( TYPE_CHECKING, Any, - Generic, Never, TypeVar, cast, ) from immutabledict import immutabledict +from orderedsets import FrozenOrderedSet, OrderedSet from pymbolic.mapper.optimize import optimize_mapper from pytools import UniqueNameGenerator, memoize_method @@ -134,61 +133,6 @@ class CommunicationOpIdentifier: _ValueT = TypeVar("_ValueT") -# {{{ crude ordered set - - -class _OrderedSet(Generic[_ValueT], collections.abc.MutableSet[_ValueT]): - def __init__(self, items: Iterable[_ValueT] | None = None): - # Could probably also use a valueless dictionary; not sure if it matters - self._items: set[_ValueT] = set() - self._items_ordered: list[_ValueT] = [] - if items is not None: - for item in items: - self.add(item) - - def add(self, item: _ValueT) -> None: - if item not in self._items: - self._items.add(item) - self._items_ordered.append(item) - - def discard(self, item: _ValueT) -> None: - # Not currently needed - raise NotImplementedError - - def __len__(self) -> int: - return len(self._items) - - def __iter__(self) -> Iterator[_ValueT]: - return iter(self._items_ordered) - - def __contains__(self, item: Any) -> bool: - return item in self._items - - def __and__(self, other: Set[_ValueT]) -> _OrderedSet[_ValueT]: - result: _OrderedSet[_ValueT] = _OrderedSet() - for item in self._items_ordered: - if item in other: - result.add(item) - return result - - # Must be "Any" instead of "_ValueT", otherwise it violates Liskov substitution - # according to mypy. *shrug* - def __or__(self, other: Set[Any]) -> _OrderedSet[_ValueT]: - result: _OrderedSet[_ValueT] = _OrderedSet(self._items_ordered) - for item in other: - result.add(item) - return result - - def __sub__(self, other: Set[_ValueT]) -> _OrderedSet[_ValueT]: - result: _OrderedSet[_ValueT] = _OrderedSet() - for item in self._items_ordered: - if item not in other: - result.add(item) - return result - -# }}} - - # {{{ distributed graph part PartId = Hashable @@ -377,8 +321,8 @@ def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: class _PartCommIDs: """A *part*, unlike a *batch*, begins with receives and ends with sends. """ - recv_ids: frozenset[CommunicationOpIdentifier] - send_ids: frozenset[CommunicationOpIdentifier] + recv_ids: FrozenOrderedSet[CommunicationOpIdentifier] + send_ids: FrozenOrderedSet[CommunicationOpIdentifier] # {{{ _make_distributed_partition @@ -464,12 +408,12 @@ def _recv_to_comm_id( class _LocalSendRecvDepGatherer( - CombineMapper[frozenset[CommunicationOpIdentifier], Never]): + CombineMapper[FrozenOrderedSet[CommunicationOpIdentifier], Never]): def __init__(self, local_rank: int) -> None: super().__init__() self.local_comm_ids_to_needed_comm_ids: \ dict[CommunicationOpIdentifier, - frozenset[CommunicationOpIdentifier]] = {} + FrozenOrderedSet[CommunicationOpIdentifier]] = {} self.local_recv_id_to_recv_node: \ dict[CommunicationOpIdentifier, DistributedRecv] = {} @@ -479,13 +423,14 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: frozenset[CommunicationOpIdentifier] - ) -> frozenset[CommunicationOpIdentifier]: - return reduce(frozenset.union, args, frozenset()) + self, *args: FrozenOrderedSet[CommunicationOpIdentifier] + ) -> FrozenOrderedSet[CommunicationOpIdentifier]: + return reduce(FrozenOrderedSet.union, args, FrozenOrderedSet()) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> frozenset[CommunicationOpIdentifier]: + ) \ + -> FrozenOrderedSet[CommunicationOpIdentifier]: send_id = _send_to_comm_id(self.local_rank, expr.send) if send_id in self.local_send_id_to_send_node: @@ -499,8 +444,9 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> frozenset[CommunicationOpIdentifier]: - return frozenset() + def _map_input_base(self, expr: Array) \ + -> FrozenOrderedSet[CommunicationOpIdentifier]: + return FrozenOrderedSet() map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -508,21 +454,21 @@ def _map_input_base(self, expr: Array) -> frozenset[CommunicationOpIdentifier]: def map_distributed_recv( self, expr: DistributedRecv - ) -> frozenset[CommunicationOpIdentifier]: + ) -> FrozenOrderedSet[CommunicationOpIdentifier]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: from pytato.distributed.verify import DuplicateRecvError raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'") - self.local_comm_ids_to_needed_comm_ids[recv_id] = frozenset() + self.local_comm_ids_to_needed_comm_ids[recv_id] = FrozenOrderedSet() self.local_recv_id_to_recv_node[recv_id] = expr - return frozenset({recv_id}) + return FrozenOrderedSet({recv_id}) - def map_named_call_result( - self, expr: NamedCallResult) -> frozenset[CommunicationOpIdentifier]: + def map_named_call_result(self, expr: NamedCallResult) \ + -> FrozenOrderedSet[CommunicationOpIdentifier]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -536,7 +482,7 @@ def map_named_call_result( def _schedule_task_batches( task_ids_to_needed_task_ids: Mapping[TaskType, Set[TaskType]]) \ - -> Sequence[Set[TaskType]]: + -> Sequence[OrderedSet[TaskType]]: """For each :type:`TaskType`, determine the 'round'/'batch' during which it will be performed. A 'batch' of tasks consists of tasks which do not depend on each other. @@ -551,7 +497,7 @@ def _schedule_task_batches( def _schedule_task_batches_counted( task_ids_to_needed_task_ids: Mapping[TaskType, Set[TaskType]]) \ - -> tuple[Sequence[Set[TaskType]], int]: + -> tuple[Sequence[OrderedSet[TaskType]], int]: """ Static type checkers need the functions to return the same type regardless of the input. The testing code needs to know about the number of tasks visited @@ -560,7 +506,8 @@ def _schedule_task_batches_counted( task_to_dep_level, visits_in_depend = \ _calculate_dependency_levels(task_ids_to_needed_task_ids) nlevels = 1 + max(task_to_dep_level.values(), default=-1) - task_batches: Sequence[set[TaskType]] = [set() for _ in range(nlevels)] + task_batches: Sequence[OrderedSet[TaskType]] = \ + [OrderedSet() for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): task_batches[dep_level].add(task_id) @@ -626,7 +573,7 @@ class _MaterializedArrayCollector(CachedWalkMapper[[]]): """ def __init__(self) -> None: super().__init__() - self.materialized_arrays: _OrderedSet[Array] = _OrderedSet() + self.materialized_arrays: OrderedSet[Array] = OrderedSet() def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @@ -654,13 +601,14 @@ def post_visit(self, expr: Any) -> None: # {{{ _set_dict_union_mpi def _set_dict_union_mpi( - dict_a: Mapping[_KeyT, frozenset[_ValueT]], - dict_b: Mapping[_KeyT, frozenset[_ValueT]], - mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, frozenset[_ValueT]]: + dict_a: Mapping[_KeyT, FrozenOrderedSet[_ValueT]], + dict_b: Mapping[_KeyT, FrozenOrderedSet[_ValueT]], + mpi_data_type: mpi4py.MPI.Datatype | None) \ + -> Mapping[_KeyT, FrozenOrderedSet[_ValueT]]: assert mpi_data_type is None result = dict(dict_a) for key, values in dict_b.items(): - result[key] = result.get(key, frozenset()) | values + result[key] = result.get(key, FrozenOrderedSet()) | values return result # }}} @@ -828,9 +776,7 @@ def find_distributed_partition( if isinstance(comm_batches_or_exc, Exception): raise comm_batches_or_exc - comm_batches = cast( - "Sequence[Set[CommunicationOpIdentifier]]", - comm_batches_or_exc) + comm_batches = comm_batches_or_exc # }}} @@ -838,9 +784,9 @@ def find_distributed_partition( part_comm_ids: list[_PartCommIDs] = [] if comm_batches: - recv_ids: frozenset[CommunicationOpIdentifier] = frozenset() + recv_ids: FrozenOrderedSet[CommunicationOpIdentifier] = FrozenOrderedSet() for batch in comm_batches: - send_ids = frozenset( + send_ids = FrozenOrderedSet( comm_id for comm_id in batch if comm_id.src_rank == local_rank) if recv_ids or send_ids: @@ -849,19 +795,19 @@ def find_distributed_partition( recv_ids=recv_ids, send_ids=send_ids)) # These go into the next part - recv_ids = frozenset( + recv_ids = FrozenOrderedSet( comm_id for comm_id in batch if comm_id.dest_rank == local_rank) if recv_ids: part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, - send_ids=frozenset())) + send_ids=FrozenOrderedSet())) else: part_comm_ids.append( _PartCommIDs( - recv_ids=frozenset(), - send_ids=frozenset())) + recv_ids=FrozenOrderedSet(), + send_ids=FrozenOrderedSet())) nparts = len(part_comm_ids) @@ -891,10 +837,10 @@ def find_distributed_partition( # The sets of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic - sent_arrays = _OrderedSet( + sent_arrays = FrozenOrderedSet( send_node.data for send_node in lsrdg.local_send_id_to_send_node.values()) - received_arrays = _OrderedSet(lsrdg.local_recv_id_to_recv_node.values()) + received_arrays = FrozenOrderedSet(lsrdg.local_recv_id_to_recv_node.values()) # While receive nodes may be marked as materialized, we shouldn't be # including them here because we're using them (along with the send nodes) @@ -908,7 +854,7 @@ def find_distributed_partition( - sent_arrays) # "mso" for "materialized/sent/output" - output_arrays = _OrderedSet(outputs._data.values()) + output_arrays = FrozenOrderedSet(outputs._data.values()) mso_arrays = materialized_arrays | sent_arrays | output_arrays # FIXME: This gathers up materialized_arrays recursively, leading to @@ -973,15 +919,15 @@ def find_distributed_partition( assert all(0 <= part_id < nparts for part_id in stored_ary_to_part_id.values()) - stored_arrays = _OrderedSet(stored_ary_to_part_id) + stored_arrays = FrozenOrderedSet(stored_ary_to_part_id) # {{{ find which stored arrays should become part outputs # (because they are used in not just their local part, but also others) direct_preds_getter = DirectPredecessorsGetter() - def get_materialized_predecessors(ary: Array) -> _OrderedSet[Array]: - materialized_preds: _OrderedSet[Array] = _OrderedSet() + def get_materialized_predecessors(ary: Array) -> OrderedSet[Array]: + materialized_preds: OrderedSet[Array] = OrderedSet() for pred in direct_preds_getter(ary): assert isinstance(pred, Array) if pred in materialized_arrays: @@ -990,13 +936,13 @@ def get_materialized_predecessors(ary: Array) -> _OrderedSet[Array]: materialized_preds |= get_materialized_predecessors(pred) return materialized_preds - stored_arrays_promoted_to_part_outputs = { + stored_arrays_promoted_to_part_outputs = FrozenOrderedSet( stored_pred for stored_ary in stored_arrays for stored_pred in get_materialized_predecessors(stored_ary) if (stored_ary_to_part_id[stored_ary] != stored_ary_to_part_id[stored_pred]) - } + ) # }}} diff --git a/test/test_distributed.py b/test/test_distributed.py index ac7ca1389..1554a024b 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -899,13 +899,11 @@ def test_number_symbolic_tags_bare_classes(ctx_factory): outputs = pt.make_dict_of_named_arrays({"out": res}) partition = pt.find_distributed_partition(comm, outputs) - (_distp, next_tag) = pt.number_distributed_tags(comm, partition, base_tag=4242) + (distp, next_tag) = pt.number_distributed_tags(comm, partition, base_tag=4242) assert next_tag == 4244 - # FIXME: For the next assertion, find_distributed_partition needs to be - # deterministic too (https://github.com/inducer/pytato/pull/465). - # assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # noqa: E501 + assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # }}}