Skip to content

Commit

Permalink
eliminate _OrderedSets
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Jul 25, 2024
1 parent 817b255 commit 26492ce
Showing 1 changed file with 6 additions and 63 deletions.
69 changes: 6 additions & 63 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,13 @@
THE SOFTWARE.
"""

import collections
from functools import reduce
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
FrozenSet,
Hashable,
Iterable,
Iterator,
Mapping,
Sequence,
TypeVar,
Expand Down Expand Up @@ -131,61 +128,6 @@ class CommunicationOpIdentifier:
_ValueT = TypeVar("_ValueT")


# {{{ crude ordered set


class _OrderedSet(collections.abc.MutableSet[_ValueT]):
def __init__(self, items: Iterable[_ValueT] | None = None):
# Could probably also use a valueless dictionary; not sure if it matters
self._items: set[_ValueT] = set()
self._items_ordered: list[_ValueT] = []
if items is not None:
for item in items:
self.add(item)

def add(self, item: _ValueT) -> None:
if item not in self._items:
self._items.add(item)
self._items_ordered.append(item)

def discard(self, item: _ValueT) -> None:
# Not currently needed
raise NotImplementedError

def __len__(self) -> int:
return len(self._items)

def __iter__(self) -> Iterator[_ValueT]:
return iter(self._items_ordered)

def __contains__(self, item: Any) -> bool:
return item in self._items

def __and__(self, other: AbstractSet[_ValueT]) -> _OrderedSet[_ValueT]:
result: _OrderedSet[_ValueT] = _OrderedSet()
for item in self._items_ordered:
if item in other:
result.add(item)
return result

# Must be "Any" instead of "_ValueT", otherwise it violates Liskov substitution
# according to mypy. *shrug*
def __or__(self, other: AbstractSet[Any]) -> _OrderedSet[_ValueT]:
result: _OrderedSet[_ValueT] = _OrderedSet(self._items_ordered)
for item in other:
result.add(item)
return result

def __sub__(self, other: AbstractSet[_ValueT]) -> _OrderedSet[_ValueT]:
result: _OrderedSet[_ValueT] = _OrderedSet()
for item in self._items_ordered:
if item not in other:
result.add(item)
return result

# }}}


# {{{ distributed graph part

PartId = Hashable
Expand Down Expand Up @@ -986,14 +928,15 @@ def find_distributed_partition(

direct_preds_getter = DirectPredecessorsGetter()

def get_materialized_predecessors(ary: Array) -> _OrderedSet[Array]:
materialized_preds: _OrderedSet[Array] = _OrderedSet()
def get_materialized_predecessors(ary: Array) -> tuple[Array]:
materialized_preds: dict[Array, None] = {}
for pred in direct_preds_getter(ary):
if pred in materialized_arrays:
materialized_preds.add(pred)
materialized_preds[pred] = None
else:
materialized_preds |= get_materialized_predecessors(pred)
return materialized_preds
for p in get_materialized_predecessors(pred):
materialized_preds[p] = None
return tuple(materialized_preds.keys())

stored_arrays_promoted_to_part_outputs = tuple(unique(
stored_pred
Expand Down

0 comments on commit 26492ce

Please sign in to comment.