Skip to content

Commit

Permalink
replace orderedsets with unique tuples in DirectPredecessorsGetter
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Aug 13, 2024
1 parent bd70620 commit 3c87726
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 38 deletions.
70 changes: 33 additions & 37 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from typing import TYPE_CHECKING, Any, Mapping

from pymbolic.mapper.optimize import optimize_mapper
from pytools import memoize_method
from pytools import memoize_method, unique

from pytato.array import (
Array,
Expand Down Expand Up @@ -314,11 +314,6 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:

# {{{ DirectPredecessorsGetter

from collections.abc import Set as abc_Set

from orderedsets import FrozenOrderedSet


class DirectPredecessorsGetter(Mapper):
"""
Mapper to get the
Expand All @@ -327,74 +322,75 @@ class DirectPredecessorsGetter(Mapper):
of a node.
.. note::
We only consider the predecessors of a nodes in a data-flow sense.
"""
def _get_preds_from_shape(self, shape: ShapeType) -> abc_Set[ArrayOrNames]:
return FrozenOrderedSet([dim for dim in shape if isinstance(dim, Array)])
def _get_preds_from_shape(self, shape: ShapeType) -> tuple[ArrayOrNames]:
return tuple(unique(dim for dim in shape if isinstance(dim, Array)))

def map_index_lambda(self, expr: IndexLambda) -> abc_Set[ArrayOrNames]:
return (FrozenOrderedSet(expr.bindings.values())
| self._get_preds_from_shape(expr.shape))
def map_index_lambda(self, expr: IndexLambda) -> tuple[ArrayOrNames]:
return tuple(unique(tuple(expr.bindings.values())
+ self._get_preds_from_shape(expr.shape)))

def map_stack(self, expr: Stack) -> abc_Set[ArrayOrNames]:
return (FrozenOrderedSet(expr.arrays)
| self._get_preds_from_shape(expr.shape))
def map_stack(self, expr: Stack) -> tuple[ArrayOrNames]:
return tuple(unique(tuple(expr.arrays)
+ self._get_preds_from_shape(expr.shape)))

def map_concatenate(self, expr: Concatenate) -> abc_Set[ArrayOrNames]:
return (FrozenOrderedSet(expr.arrays)
| self._get_preds_from_shape(expr.shape))
map_concatenate = map_stack

def map_einsum(self, expr: Einsum) -> abc_Set[ArrayOrNames]:
return (FrozenOrderedSet(expr.args)
| self._get_preds_from_shape(expr.shape))
def map_einsum(self, expr: Einsum) -> tuple[ArrayOrNames]:
return tuple(unique((tuple(expr.args)
+ self._get_preds_from_shape(expr.shape))))

def map_loopy_call_result(self, expr: NamedArray) -> abc_Set[Array]:
def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames]:
from pytato.loopy import LoopyCall, LoopyCallResult
assert isinstance(expr, LoopyCallResult)
assert isinstance(expr._container, LoopyCall)
return (FrozenOrderedSet(ary
return tuple(unique(tuple(ary
for ary in expr._container.bindings.values()
if isinstance(ary, Array))
| self._get_preds_from_shape(expr.shape))
+ self._get_preds_from_shape(expr.shape)))

def _map_index_base(self, expr: IndexBase) -> abc_Set[ArrayOrNames]:
return (FrozenOrderedSet([expr.array])
| FrozenOrderedSet(idx for idx in expr.indices
def _map_index_base(self, expr: IndexBase) -> tuple[ArrayOrNames]:
return tuple(unique(tuple([expr.array])
+ tuple(idx for idx in expr.indices
if isinstance(idx, Array))
| self._get_preds_from_shape(expr.shape))
+ self._get_preds_from_shape(expr.shape)))

map_basic_index = _map_index_base
map_contiguous_advanced_index = _map_index_base
map_non_contiguous_advanced_index = _map_index_base

def _map_index_remapping_base(self, expr: IndexRemappingBase
) -> abc_Set[ArrayOrNames]:
return FrozenOrderedSet([expr.array])
) -> tuple[ArrayOrNames]:
return (expr.array,)

map_roll = _map_index_remapping_base
map_axis_permutation = _map_index_remapping_base
map_reshape = _map_index_remapping_base

def _map_input_base(self, expr: InputArgumentBase) -> abc_Set[ArrayOrNames]:
def _map_input_base(self, expr: InputArgumentBase) -> tuple[ArrayOrNames]:
return self._get_preds_from_shape(expr.shape)

map_placeholder = _map_input_base
map_data_wrapper = _map_input_base
map_size_param = _map_input_base

def map_distributed_recv(self, expr: DistributedRecv) -> abc_Set[ArrayOrNames]:
def map_distributed_recv(self, expr: DistributedRecv) -> tuple[ArrayOrNames]:
return self._get_preds_from_shape(expr.shape)

def map_distributed_send_ref_holder(self,
expr: DistributedSendRefHolder
) -> abc_Set[ArrayOrNames]:
return FrozenOrderedSet([expr.passthrough_data])
) -> tuple[ArrayOrNames]:
return (expr.passthrough_data,)

def map_call(self, expr: Call) -> tuple[ArrayOrNames]:
return tuple(unique(expr.bindings.values()))

def map_call(self, expr: Call) -> abc_Set[ArrayOrNames]:
return FrozenOrderedSet(expr.bindings.values())
def map_named_call_result(
self, expr: NamedCallResult) -> tuple[ArrayOrNames]:
return (expr._container,)

def map_named_call_result(self, expr: NamedCallResult) -> abc_Set[ArrayOrNames]:
return FrozenOrderedSet([expr._container])

# }}}

Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
"immutabledict",
"attrs",
"bidict",
"orderedsets",
],
package_data={"pytato": ["py.typed"]},
author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei",
Expand Down

0 comments on commit 3c87726

Please sign in to comment.