Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid duplicating arrays #515

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
5861c8e
refactor deduplicate_data_wrappers to avoid dependence on erroneous s…
majosm Feb 21, 2025
4ff4017
call Mapper.rec instead of super().rec to avoid double caching
majosm Feb 20, 2025
99b1d47
call Mapper.rec from CachedMapper too just to avoid copy/paste errors
majosm Feb 25, 2025
9bc0a14
add assertion to check for double caching
majosm Feb 17, 2025
b7d9fdd
disable default implementation of get_cache_key and get_function_defi…
majosm Feb 7, 2025
190cb68
add CacheInputs to simplify cache key handling logic
majosm Feb 7, 2025
ea03645
rename expr_key* to input_key*
majosm Feb 7, 2025
13d5ab0
refactor to avoid performance drop
majosm Feb 18, 2025
3297f4c
add map_dict_of_named_arrays to DirectPredecessorsGetter
majosm Sep 20, 2024
6c43a55
support functions as inputs and outputs in DirectPredecessorsGetter
majosm Sep 24, 2024
c1f0681
add collision/duplication checks to CachedMapper/TransformMapper/Tran…
majosm Aug 29, 2024
b3b86a0
fix doc
majosm Feb 18, 2025
1feea92
change terminology from 'no-op duplication' to 'mapper-created duplic…
majosm Feb 18, 2025
b97edc6
reword explanation of predecessor check in duplication check
majosm Feb 19, 2025
cbd9a62
change CacheExprT constraint to use bound=
majosm Feb 26, 2025
7e572d5
add result deduplication to transform mappers
majosm Sep 24, 2024
e1ab346
add FIXME
majosm Sep 5, 2024
8b15622
avoid unnecessary duplication in CopyMapper/CopyMapperWithExtraArgs
majosm Jun 10, 2024
f3674c5
add Deduplicator
majosm Sep 20, 2024
5ac33bb
avoid unnecessary duplication in InlineMarker
majosm Jun 11, 2024
f0d48ce
avoid duplication in tagged() for Axis/ReductionDescriptor/_SuppliedA…
majosm Aug 27, 2024
3e56653
avoid duplication in Array.with_tagged_axis
majosm Jun 11, 2024
5cd2204
avoid duplication in with_tagged_reduction for IndexLambda/Einsum
majosm Jun 11, 2024
e413cc9
attempt to avoid duplication in CodeGenPreprocessor
majosm Jun 10, 2024
0a805ed
limit PlaceholderSubstitutor to one call stack frame
majosm Jul 3, 2024
ffa4fcc
tweak Inliner/PlaceholderSubstitutor implementations
majosm Jul 12, 2024
e58a283
use context manager to avoid leaking traceback tag setting in test
majosm Jul 16, 2024
6aa44eb
refactor FFTRealizationMapper to avoid resetting cache in __init__
majosm Jul 16, 2024
9c01112
add allow_duplicate_nodes option to RandomDAGContext in tests
majosm Aug 27, 2024
ccb4188
fix some more tests
majosm Aug 27, 2024
7258ccf
don't check for collisions in ArrayToDotNodeInfoMapper
majosm Sep 24, 2024
177385c
avoid duplication in MPMSMaterializer
majosm Sep 18, 2024
14cc4d6
avoid duplicates in EinsumWithNoBroadcastsRewriter
majosm Feb 7, 2025
0cd78a1
forbid DependencyMapper from being called on functions
majosm Feb 7, 2025
2ca7a65
deduplicate in advection example
majosm Feb 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def test_advection_convergence(order, flux_type):
op = AdvectionOperator(discr, c=1, flux_type=flux_type,
dg_ops=dg_ops)
result = op.apply(u)
result = pt.transform.Deduplicator()(result)

prog = pt.generate_loopy(result, cl_device=queue.device)

Expand Down
35 changes: 28 additions & 7 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,11 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:

# {{{ DirectPredecessorsGetter

class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]):
class DirectPredecessorsGetter(
Mapper[
FrozenOrderedSet[ArrayOrNames | FunctionDefinition],
FrozenOrderedSet[ArrayOrNames],
[]]):
"""
Mapper to get the
`direct predecessors
Expand All @@ -334,9 +338,17 @@ class DirectPredecessorsGetter(Mapper[frozenset[ArrayOrNames], Never, []]):

We only consider the predecessors of a node in a data-flow sense.
"""
def __init__(self, *, include_functions: bool = False) -> None:
    """
    :arg include_functions: if *True*, ``map_call`` additionally reports the
        called ``FunctionDefinition`` itself as a predecessor of the call.
    """
    super().__init__()
    # Read by map_call to decide whether expr.function is part of the result.
    self.include_functions = include_functions

def _get_preds_from_shape(self, shape: ShapeType) -> FrozenOrderedSet[ArrayOrNames]:
    """Collect the symbolic (:class:`Array`-valued) entries of *shape*."""
    # Concrete integer axis lengths carry no data flow; only Array-valued
    # dimensions count as predecessors.
    symbolic_dims = [dim for dim in shape if isinstance(dim, Array)]
    return FrozenOrderedSet(symbolic_dims)

def map_dict_of_named_arrays(
        self, expr: DictOfNamedArrays) -> FrozenOrderedSet[ArrayOrNames]:
    """Return the arrays bound to names in *expr* as its predecessors."""
    return FrozenOrderedSet(ary for ary in expr._data.values())

def map_index_lambda(self, expr: IndexLambda) -> FrozenOrderedSet[ArrayOrNames]:
    """Return the bindings of *expr* plus any symbolic shape dimensions."""
    bound_arrays: FrozenOrderedSet[ArrayOrNames] = \
        FrozenOrderedSet(expr.bindings.values())
    return bound_arrays | self._get_preds_from_shape(expr.shape)
Expand Down Expand Up @@ -397,8 +409,17 @@ def map_distributed_send_ref_holder(self,
) -> FrozenOrderedSet[ArrayOrNames]:
return FrozenOrderedSet([expr.passthrough_data])

def map_call(self, expr: Call) -> FrozenOrderedSet[ArrayOrNames]:
return FrozenOrderedSet(expr.bindings.values())
def map_call(
        self, expr: Call) -> FrozenOrderedSet[ArrayOrNames | FunctionDefinition]:
    """
    Return the argument bindings of *expr*; when ``include_functions`` was
    set at construction, the called function definition is included as well.
    """
    preds: FrozenOrderedSet[ArrayOrNames | FunctionDefinition] = \
        FrozenOrderedSet(expr.bindings.values())
    if not self.include_functions:
        return preds
    return preds | FrozenOrderedSet([expr.function])

def map_function_definition(
        self, expr: FunctionDefinition) -> FrozenOrderedSet[ArrayOrNames]:
    """Return the arrays returned by *expr* as its predecessors."""
    return FrozenOrderedSet(ret for ret in expr.returns.values())

def map_named_call_result(
self, expr: NamedCallResult) -> FrozenOrderedSet[ArrayOrNames]:
Expand Down Expand Up @@ -622,11 +643,11 @@ def combine(self, *args: int) -> int:
return sum(args)

def rec(self, expr: ArrayOrNames) -> int:
key = self._cache.get_key(expr)
inputs = self._make_cache_inputs(expr)
try:
return self._cache.retrieve(expr, key=key)
return self._cache_retrieve(inputs)
except KeyError:
s = super().rec(expr)
s = Mapper.rec(self, expr)
if (
isinstance(expr, Array)
and (
Expand All @@ -636,7 +657,7 @@ def rec(self, expr: ArrayOrNames) -> int:
else:
result = 0 + s

self._cache.add(expr, 0, key=key)
self._cache_add(inputs, 0)
return result


Expand Down
141 changes: 83 additions & 58 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,11 @@ class Axis(Taggable):
tags: frozenset[Tag]

def _with_new_tags(self, tags: frozenset[Tag]) -> Axis:
from dataclasses import replace
return replace(self, tags=tags)
if tags != self.tags:
from dataclasses import replace
return replace(self, tags=tags)
else:
return self


@dataclasses.dataclass(frozen=True)
Expand All @@ -480,8 +483,11 @@ class ReductionDescriptor(Taggable):
tags: frozenset[Tag]

def _with_new_tags(self, tags: frozenset[Tag]) -> ReductionDescriptor:
from dataclasses import replace
return replace(self, tags=tags)
if tags != self.tags:
from dataclasses import replace
return replace(self, tags=tags)
else:
return self


@array_dataclass()
Expand Down Expand Up @@ -833,10 +839,14 @@ def with_tagged_axis(self, iaxis: int,
"""
Returns a copy of *self* with *iaxis*-th axis tagged with *tags*.
"""
new_axes = (self.axes[:iaxis]
+ (self.axes[iaxis].tagged(tags),)
+ self.axes[iaxis+1:])
return self.copy(axes=new_axes)
new_axis = self.axes[iaxis].tagged(tags)
if new_axis is not self.axes[iaxis]:
new_axes = (self.axes[:iaxis]
+ (self.axes[iaxis].tagged(tags),)
+ self.axes[iaxis+1:])
return self.copy(axes=new_axes)
else:
return self

@memoize_method
def __repr__(self) -> str:
Expand Down Expand Up @@ -865,7 +875,10 @@ class _SuppliedAxesAndTagsMixin(Taggable):
default=frozenset())

def _with_new_tags(self: Self, tags: frozenset[Tag]) -> Self:
return dataclasses.replace(self, tags=tags)
if tags != self.tags:
return dataclasses.replace(self, tags=tags)
else:
return self


@dataclasses.dataclass(frozen=True, eq=False, repr=False)
Expand Down Expand Up @@ -1114,20 +1127,22 @@ def with_tagged_reduction(self,
f" '{self.var_to_reduction_descr.keys()}',"
f" got '{reduction_variable}'.")

assert isinstance(self.var_to_reduction_descr, immutabledict)
new_var_to_redn_descr = dict(self.var_to_reduction_descr)
new_var_to_redn_descr[reduction_variable] = \
self.var_to_reduction_descr[reduction_variable].tagged(tags)

return type(self)(expr=self.expr,
shape=self.shape,
dtype=self.dtype,
bindings=self.bindings,
axes=self.axes,
var_to_reduction_descr=immutabledict
(new_var_to_redn_descr),
tags=self.tags,
non_equality_tags=self.non_equality_tags)
new_redn_descr = self.var_to_reduction_descr[reduction_variable].tagged(tags)
if new_redn_descr is not self.var_to_reduction_descr[reduction_variable]:
assert isinstance(self.var_to_reduction_descr, immutabledict)
new_var_to_redn_descr = dict(self.var_to_reduction_descr)
new_var_to_redn_descr[reduction_variable] = new_redn_descr
return type(self)(expr=self.expr,
shape=self.shape,
dtype=self.dtype,
bindings=self.bindings,
axes=self.axes,
var_to_reduction_descr=immutabledict
(new_var_to_redn_descr),
tags=self.tags,
non_equality_tags=self.non_equality_tags)
else:
return self

# }}}

Expand Down Expand Up @@ -1162,6 +1177,34 @@ class EinsumReductionAxis(EinsumAxisDescriptor):
dim: int


def _get_einsum_access_descr_to_axis_len(
        access_descriptors: tuple[tuple[EinsumAxisDescriptor, ...], ...],
        args: tuple[Array, ...],
        ) -> Mapping[EinsumAxisDescriptor, ShapeComponent]:
    """
    Map each einsum axis descriptor to the axis length it stands for,
    resolving broadcasting: an axis of length 1 yields to a differing
    length seen for the same descriptor on another argument.

    :arg access_descriptors: per-argument tuples of axis descriptors,
        one descriptor per axis of the corresponding entry of *args*.
    :arg args: the einsum operands; ``args[i].ndim`` must equal
        ``len(access_descriptors[i])``.
    """
    from pytato.utils import are_shape_components_equal
    descr_to_axis_len: dict[EinsumAxisDescriptor, ShapeComponent] = {}

    for access_descrs, arg in zip(access_descriptors,
                                  args, strict=True):
        # Each axis of the argument must have exactly one descriptor.
        assert arg.ndim == len(access_descrs)
        for arg_axis_len, descr in zip(arg.shape, access_descrs, strict=True):
            if descr in descr_to_axis_len:
                seen_axis_len = descr_to_axis_len[descr]

                if not are_shape_components_equal(seen_axis_len,
                                                  arg_axis_len):
                    if are_shape_components_equal(arg_axis_len, 1):
                        # this axis would be broadcasted
                        pass
                    else:
                        # Mismatched lengths are only legal if the previously
                        # recorded one is 1; the non-1 length wins.
                        assert are_shape_components_equal(seen_axis_len, 1)
                        descr_to_axis_len[descr] = arg_axis_len
            else:
                # First time this descriptor is seen: record its length.
                descr_to_axis_len[descr] = arg_axis_len

    return immutabledict(descr_to_axis_len)


@array_dataclass()
class Einsum(_SuppliedAxesAndTagsMixin, Array):
"""
Expand Down Expand Up @@ -1209,28 +1252,8 @@ def __post_init__(self) -> None:
@memoize_method
def _access_descr_to_axis_len(self
) -> Mapping[EinsumAxisDescriptor, ShapeComponent]:
from pytato.utils import are_shape_components_equal
descr_to_axis_len: dict[EinsumAxisDescriptor, ShapeComponent] = {}

for access_descrs, arg in zip(self.access_descriptors,
self.args, strict=True):
assert arg.ndim == len(access_descrs)
for arg_axis_len, descr in zip(arg.shape, access_descrs, strict=True):
if descr in descr_to_axis_len:
seen_axis_len = descr_to_axis_len[descr]

if not are_shape_components_equal(seen_axis_len,
arg_axis_len):
if are_shape_components_equal(arg_axis_len, 1):
# this axis would be broadcasted
pass
else:
assert are_shape_components_equal(seen_axis_len, 1)
descr_to_axis_len[descr] = arg_axis_len
else:
descr_to_axis_len[descr] = arg_axis_len

return immutabledict(descr_to_axis_len)
return _get_einsum_access_descr_to_axis_len(
self.access_descriptors, self.args)

@cached_property
def shape(self) -> ShapeType:
Expand Down Expand Up @@ -1278,19 +1301,21 @@ def with_tagged_reduction(self,

# }}}

assert isinstance(self.redn_axis_to_redn_descr, immutabledict)
new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr)
new_redn_axis_to_redn_descr[redn_axis] = \
self.redn_axis_to_redn_descr[redn_axis].tagged(tags)

return type(self)(access_descriptors=self.access_descriptors,
args=self.args,
axes=self.axes,
redn_axis_to_redn_descr=immutabledict
(new_redn_axis_to_redn_descr),
tags=self.tags,
non_equality_tags=self.non_equality_tags,
)
new_redn_descr = self.redn_axis_to_redn_descr[redn_axis].tagged(tags)
if new_redn_descr is not self.redn_axis_to_redn_descr[redn_axis]:
assert isinstance(self.redn_axis_to_redn_descr, immutabledict)
new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr)
new_redn_axis_to_redn_descr[redn_axis] = new_redn_descr
return type(self)(access_descriptors=self.access_descriptors,
args=self.args,
axes=self.axes,
redn_axis_to_redn_descr=immutabledict
(new_redn_axis_to_redn_descr),
tags=self.tags,
non_equality_tags=self.non_equality_tags,
)
else:
return self


EINSUM_FIRST_INDEX = re.compile(r"^\s*((?P<alpha>[a-zA-Z])|(?P<ellipsis>\.\.\.))\s*")
Expand Down
Loading
Loading