From b123123996595c9e4b79f6ca5aedfd693ca9512b Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 30 May 2023 07:25:32 -0500 Subject: [PATCH 1/4] fuse indirections in meshmode based code --- .../pytato_indirection_transforms.py | 727 ++++++++++++++++++ 1 file changed, 727 insertions(+) create mode 100644 grudge/pytato_transforms/pytato_indirection_transforms.py diff --git a/grudge/pytato_transforms/pytato_indirection_transforms.py b/grudge/pytato_transforms/pytato_indirection_transforms.py new file mode 100644 index 000000000..a5f05c648 --- /dev/null +++ b/grudge/pytato_transforms/pytato_indirection_transforms.py @@ -0,0 +1,727 @@ +__copyright__ = "Copyright (C) 2023 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +import pytato as pt +from typing import Callable, Dict, FrozenSet, List, Mapping, Optional, Tuple +from pytato.array import (InputArgumentBase, IndexLambda, + Stack, Concatenate, AdvancedIndexInContiguousAxes, + AdvancedIndexInNoncontiguousAxes, BasicIndex, + Einsum, Roll, Array, Reshape, DictOfNamedArrays, + DataWrapper, Placeholder, IndexBase) +from pytato.transform import (ArrayOrNames, CombineMapper, Mapper, + MappedT) +from immutables import Map + + +# {{{ fuse_dof_pick_lists + +def _is_materialized(expr: Array) -> bool: + """ + Returns true if an array is materialized. An array is considered to be + materialized if it is either a :class:`pytato.array.InputArgumentBase` or + is tagged with :class:`pytato.tags.ImplStored`. + """ + from pytato.tags import ImplStored + return (isinstance(expr, InputArgumentBase) + or bool(expr.tags_of_type(ImplStored))) + + +def _can_index_lambda_propagate_indirections_without_changing_axes( + expr: IndexLambda) -> bool: + + from pytato.utils import are_shapes_equal + from pytato.raising import (index_lambda_to_high_level_op, + BinaryOp) + hlo = index_lambda_to_high_level_op(expr) + return (isinstance(hlo, BinaryOp) + and ((not isinstance(hlo.x1, pt.Array)) + or (not isinstance(hlo.x2, pt.Array)) + or are_shapes_equal(hlo.x1.shape, hlo.x2.shape))) + + +def _is_advanced_indexing_from_resample_by_picking( + expr: AdvancedIndexInContiguousAxes +) -> bool: + + from pytato.utils import are_shapes_equal + + if expr.ndim != 2 or expr.array.ndim != 2: + # only worry about dofs-to-dofs like resamplings. + return False + + idx1, idx2 = expr.indices + + if (not isinstance(idx1, Array)) or (not isinstance(idx2, Array)): + # only worry about resamplings of the form + # `u[from_el_indices, dof_pick_list]`. 
+ return False + + if (idx1.ndim != 2) or (idx2.ndim != 2): + return False + + if not are_shapes_equal(idx1.shape, (idx2.shape[0], 1)): + return False + + return True + + +# **Note: implementation of _CanPickIndirectionsBePropagated is restrictive on +# purpose.** +# Although this could be generalized to get a tighter condition on when indirections +# can legally be propagated, (for now) we are only interested in patterns commonly +# seen in meshmode-based expression graphs. +class _CanPickIndirectionsBePropagated(Mapper): + """ + Mapper to test whether the dof pick lists and element pick lists can be + propagated towards the operands. + """ + def __init__(self) -> None: + self._cache: Dict[Tuple[ArrayOrNames, int, int], bool] = {} + super().__init__() + + # type-ignore-reason: incompatible function signature with Mapper.rec + def rec(self, expr: ArrayOrNames, # type: ignore[override] + iel_axis: int, idof_axis: int) -> bool: + if isinstance(expr, Array): + assert 0 <= iel_axis < expr.ndim + assert 0 <= idof_axis < expr.ndim + # the condition below ensures that we are only dealing with indirections + # appearing at contiguous locations. 
+ assert abs(iel_axis-idof_axis) == 1 + + if isinstance(expr, Array) and _is_materialized(expr): + return True + + key = (expr, iel_axis, idof_axis) + try: + return self._cache[key] + except KeyError: + result = super().rec(expr, iel_axis, idof_axis) + self._cache[key] = result + return result + + def _map_input_base(self, expr: InputArgumentBase, + iel_axis: int, idof_axis: int) -> bool: + return True + + map_placeholder = _map_input_base + map_data_wrapper = _map_input_base + + def map_index_lambda(self, + expr: IndexLambda, + iel_axis: int, + idof_axis: int) -> bool: + if _can_index_lambda_propagate_indirections_without_changing_axes(expr): + return all(self.rec(bnd, iel_axis, idof_axis) + for bnd in expr.bindings.values()) + else: + return False + + def map_stack(self, expr: Stack, iel_axis: int, idof_axis: int) -> bool: + if expr.axis in {iel_axis, idof_axis}: + return False + else: + if iel_axis < expr.axis: + assert idof_axis < expr.axis + return all(self.rec(ary, iel_axis, idof_axis) for ary in expr.arrays) + else: + assert idof_axis > expr.axis + return all(self.rec(ary, iel_axis-1, idof_axis-1) + for ary in expr.arrays) + + def map_concatenate(self, + expr: Concatenate, + iel_axis: int, + idof_axis: int) -> bool: + if expr.axis in {iel_axis, idof_axis}: + return False + else: + return all(self.rec(ary, iel_axis, idof_axis) for ary in expr.arrays) + + def map_einsum(self, expr: Einsum, iel_axis: int, idof_axis: int) -> bool: + from pytato.array import EinsumElementwiseAxis + + for arg, acc_descrs in zip(expr.args, expr.access_descriptors): + try: + arg_iel_axis = acc_descrs.index(EinsumElementwiseAxis(iel_axis)) + arg_idof_axis = acc_descrs.index(EinsumElementwiseAxis(idof_axis)) + except ValueError: + return False + else: + if abs(arg_iel_axis - arg_idof_axis) != 1: + return False + + if not self.rec(arg, arg_iel_axis, arg_idof_axis): + return False + + return True + + def map_roll(self, expr: Roll, iel_axis: int, idof_axis: int) -> bool: + return False + 
+ def map_non_contiguous_advanced_index(self, + expr: AdvancedIndexInNoncontiguousAxes, + iel_axis: int, + idof_axis: int) -> bool: + # TODO: In meshmode based codes non-contiguous advanced indices are rare + # i.e. not a first order concern to optimize across these nodes. + return False + + def map_contiguous_advanced_index(self, + expr: AdvancedIndexInContiguousAxes, + iel_axis: int, + idof_axis: int) -> bool: + + return (_is_advanced_indexing_from_resample_by_picking(expr) + and iel_axis == 0 and idof_axis == 1 + and self.rec(expr.array, iel_axis, idof_axis)) + + def map_basic_index(self, expr: BasicIndex, + iel_axis: int, idof_axis: int) -> bool: + # TODO: In meshmode based codes slices are rare i.e. not a first order + # concern to optimize across these nodes. + return False + + def map_reshape(self, expr: Reshape, + iel_axis: int, idof_axis: int) -> bool: + # TODO: In meshmode based codes reshapes in flux computations on sub-domains + # are rare i.e. not a first order concern to optimize across these nodes. 
+ return False + + +def _fuse_from_element_indices(from_element_indices: Tuple[Array, ...]): + assert all(from_el_idx.ndim == 2 for from_el_idx in from_element_indices) + assert all(from_el_idx.shape[1] == 1 for from_el_idx in from_element_indices) + + result = from_element_indices[-1] + for from_el_idx in from_element_indices[-2::-1]: + result = result[from_el_idx, 0] + + return result + + +def _fuse_dof_pick_lists(dof_pick_lists: Tuple[Array, ...], from_element_indices: + Tuple[Array, ...]): + assert all(from_el_idx.ndim == 2 for from_el_idx in from_element_indices) + assert all(dof_pick_list.ndim == 2 for dof_pick_list in dof_pick_lists) + assert all(from_el_idx.shape[1] == 1 for from_el_idx in from_element_indices) + + result = dof_pick_lists[-1] + for from_el_idx, dof_pick_list in zip(from_element_indices[-2::-1], + dof_pick_lists[-2::-1]): + result = result[from_el_idx, dof_pick_list] + + return result + + +def _pick_list_fusers_map_materialized_node(rec_expr: Array, + iel_axis: Optional[int], + idof_axis: Optional[int], + from_element_indices: Tuple[Array, ...], + dof_pick_lists: Tuple[Array, ...] 
+ ) -> Array: + + if iel_axis is not None: + assert idof_axis is not None + assert len(from_element_indices) != 0 + assert len(from_element_indices) == len(dof_pick_lists) + + fused_from_element_indices = _fuse_from_element_indices(from_element_indices) + fused_dof_pick_lists = _fuse_dof_pick_lists(dof_pick_lists, + from_element_indices) + if iel_axis < idof_axis: + assert idof_axis == (iel_axis+1) + indices = (slice(None),)*iel_axis + (fused_from_element_indices, + fused_dof_pick_lists) + else: + assert iel_axis == (idof_axis+1) + indices = (slice(None),)*iel_axis + (fused_dof_pick_lists, + fused_from_element_indices) + + return rec_expr[indices] + else: + assert idof_axis is None + return rec_expr + + +class PickListFusers(Mapper): + def __init__(self) -> None: + self.can_pick_indirections_be_propagated = _CanPickIndirectionsBePropagated() + self._cache: Dict[Tuple[Array, Optional[int], Optional[int], + Tuple[Array, ...], Tuple[Array, ...]], Array] = {} + super().__init__() + + # type-ignore-reason: incompatible signature with Mapper.rec + def rec(self, # type: ignore[override] + expr: Array, + iel_axis: Optional[int], + idof_axis: Optional[int], + from_element_indices: Tuple[Array, ...], + dof_pick_lists: Tuple[Array, ...], + ) -> Array: + if not isinstance(expr, Array): + raise ValueError("Mapping AbstractResultWithNamedArrays" + " is illegal for PickListFusers. Pass arrays" + " instead.") + + if iel_axis is not None: + assert idof_axis is not None + assert 0 <= iel_axis < expr.ndim + assert 0 <= idof_axis < expr.ndim + # the condition below ensures that we are only dealing with indirections + # appearing at contiguous locations. 
+ assert abs(iel_axis-idof_axis) == 1 + else: + assert idof_axis is None + assert len(from_element_indices) == 0 + + assert len(dof_pick_lists) == len(from_element_indices) + + key = (expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) + try: + return self._cache[key] + except KeyError: + result = super().rec(expr, iel_axis, idof_axis, + from_element_indices, dof_pick_lists) + self._cache[key] = result + return result + + # type-ignore-reason: incompatible signature with Mapper.__call__ + def __call__(self, # type: ignore[override] + expr: Array, + iel_axis: Optional[int], + idof_axis: Optional[int], + from_element_indices: Tuple[Array, ...], + dof_pick_lists: Tuple[Array, ...], + ) -> Array: + return self.rec(expr, iel_axis, idof_axis, + from_element_indices, dof_pick_lists) + + def _map_input_base(self, + expr: InputArgumentBase, + iel_axis: int, + idof_axis: int, + from_element_indices: Tuple[Array, ...], + dof_pick_lists: Tuple[Array, ...]) -> Array: + return _pick_list_fusers_map_materialized_node( + expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) + + map_placeholder = _map_input_base + map_data_wrapper = _map_input_base + + def map_index_lambda(self, + expr: IndexLambda, + iel_axis: Optional[int], + idof_axis: Optional[int], + from_element_indices: Tuple[Array, ...], + dof_pick_lists: Tuple[Array, ...]) -> Array: + if _is_materialized(expr): + # Stop propagating indirections and return indirections collected till + # this point. 
+ rec_expr = IndexLambda( + expr.expr, + expr.shape, + expr.dtype, + Map({name: self.rec(bnd, None, None, (), ()) + for name, bnd in expr.bindings.items()}), + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + axes=expr.axes + ) + return _pick_list_fusers_map_materialized_node( + rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) + + if iel_axis is not None: + assert idof_axis is not None + assert _can_index_lambda_propagate_indirections_without_changing_axes( + expr) + from pytato.utils import are_shapes_equal + new_el_dim, new_dofs_dim = dof_pick_lists[0].shape + assert are_shapes_equal(from_element_indices[0].shape, (new_el_dim, 1)) + + new_shape = tuple( + new_el_dim if idim == iel_axis else ( + new_dofs_dim if idim == idof_axis else dim) + for idim, dim in enumerate(expr.shape)) + + return IndexLambda( + expr.expr, + new_shape, + expr.dtype, + Map({name: self.rec(bnd, iel_axis, idof_axis, + from_element_indices, + dof_pick_lists) + for name, bnd in expr.bindings.items()}), + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + axes=expr.axes + ) + else: + return IndexLambda( + expr.expr, + expr.shape, + expr.dtype, + Map({name: self.rec(bnd, None, None, (), ()) + for name, bnd in expr.bindings.items()}), + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + axes=expr.axes + ) + + def map_contiguous_advanced_index(self, + expr: AdvancedIndexInContiguousAxes, + iel_axis: Optional[int], + idof_axis: Optional[int], + from_element_indices: Tuple[Array, ...], + dof_pick_lists: Tuple[Array, ...], + ) -> Array: + if _is_materialized(expr): + # Stop propagating indirections and return indirections collected till + # this point. 
+ rec_expr = AdvancedIndexInContiguousAxes( + self.rec(expr.array, None, None, (), ()), + expr.indices, + tags=expr.tags, + axes=expr.axes) + return _pick_list_fusers_map_materialized_node( + rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) + + if self.can_pick_indirections_be_propagated(expr, + iel_axis or 0, + idof_axis or 1): + idx1, idx2 = expr.indices + assert isinstance(idx1, Array) and isinstance(idx2, Array) + return self.rec(expr.array, 0, 1, + from_element_indices + (idx1,), + dof_pick_lists + (idx2,)) + else: + assert iel_axis is None and idof_axis is None + return AdvancedIndexInContiguousAxes( + self.rec(expr.array, iel_axis, idof_axis, + from_element_indices, dof_pick_lists), + expr.indices, + tags=expr.tags, + axes=expr.axes + ) + + def map_einsum(self, + expr: Einsum, + iel_axis: int, + idof_axis: int, + from_element_indices: Tuple[Array, ...], + dof_pick_lists: Tuple[Array, ...]) -> Array: + from pytato.array import EinsumElementwiseAxis + + if _is_materialized(expr): + # Stop propagating indirections and return indirections collected till + # this point. 
+ rec_expr = Einsum(expr.access_descriptors, + args=tuple(self.rec(arg, None, None, (), ()) + for arg in expr.args), + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, + index_to_access_descr=expr.index_to_access_descr, + tags=expr.tags, + axes=expr.axes) + return _pick_list_fusers_map_materialized_node( + rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) + + if iel_axis is not None: + assert idof_axis is not None + new_args: List[Array] = [] + for arg, acc_descrs in zip(expr.args, expr.access_descriptors): + arg_iel_axis = acc_descrs.index(EinsumElementwiseAxis(iel_axis)) + arg_idof_axis = acc_descrs.index(EinsumElementwiseAxis(idof_axis)) + new_args.append( + self.rec(arg, arg_iel_axis, arg_idof_axis, + from_element_indices, dof_pick_lists) + ) + return Einsum(expr.access_descriptors, + args=tuple(new_args), + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, + index_to_access_descr=expr.index_to_access_descr, + tags=expr.tags, + axes=expr.axes) + else: + assert idof_axis is None + return Einsum(expr.access_descriptors, + args=tuple(self.rec(arg, None, None, (), ()) + for arg in expr.args), + redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, + index_to_access_descr=expr.index_to_access_descr, + tags=expr.tags, + axes=expr.axes) + + def map_stack(self, expr: Stack, iel_axis: int, idof_axis: int, + from_element_indices: Tuple[Array, ...], + dof_pick_lists: Tuple[Array, ...], + ) -> Array: + + if _is_materialized(expr): + # Stop propagating indirections and return indirections collected till + # this point. 
+ rec_expr = Stack(tuple(self.rec(ary, None, None, (), ()) + for ary in expr.arrays), + expr.axis, + tags=expr.tags, + axes=expr.axes) + return _pick_list_fusers_map_materialized_node( + rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) + + if iel_axis is not None: + assert idof_axis is not None + if iel_axis < expr.axis: + assert idof_axis < expr.axis + return Stack(tuple(self.rec(ary, iel_axis, idof_axis, + from_element_indices, dof_pick_lists) + for ary in expr.arrays), + expr.axis, + tags=expr.tags, + axes=expr.axes) + else: + assert idof_axis > expr.axis + return Stack(tuple(self.rec(ary, iel_axis-1, idof_axis-1, + from_element_indices, dof_pick_lists) + for ary in expr.arrays), + expr.axis, + tags=expr.tags, + axes=expr.axes) + return self.rec(expr.array, iel_axis-1, idof_axis-1, + from_element_indices, dof_pick_lists) + else: + assert idof_axis is None + return Stack(tuple(self.rec(ary, iel_axis, idof_axis, + from_element_indices, dof_pick_lists) + for ary in expr.arrays), + expr.axis, + tags=expr.tags, + axes=expr.axes) + + def map_concatenate(self, + expr: Concatenate, + iel_axis: Optional[int], + idof_axis: Optional[int], + from_element_indices: Tuple[Array, ...], + dof_pick_lists: Tuple[Array, ...], + ) -> Array: + if _is_materialized(expr): + # Stop propagating indirections and return indirections collected till + # this point. 
+ rec_expr = Concatenate(tuple(self.rec(ary, None, None, (), ()) + for ary in expr.arrays), + expr.axis, + tags=expr.tags, + axes=expr.axes) + return _pick_list_fusers_map_materialized_node( + rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) + + return Concatenate(tuple(self.rec(ary, iel_axis, idof_axis, + from_element_indices, dof_pick_lists) + for ary in expr.arrays), + expr.axis, + tags=expr.tags, + axes=expr.axes) + + def map_roll(self, + expr: Roll, + iel_axis: Optional[int], + idof_axis: Optional[int], + from_element_indices: Tuple[Array, ...], + dof_pick_lists: Tuple[Array, ...]) -> Array: + + rec_expr = Roll(self.rec(expr.array, None, None, (), ()), + expr.shift, + expr.axis, + tags=expr.tags, + axes=expr.axes) + if _is_materialized(expr): + return _pick_list_fusers_map_materialized_node( + rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) + + assert iel_axis is None and idof_axis is None + return rec_expr + + def map_non_contiguous_advanced_index(self, + expr: AdvancedIndexInNoncontiguousAxes, + iel_axis: Optional[int], + idof_axis: Optional[int], + from_element_indices: Tuple[Array, ...], + dof_pick_lists: Tuple[Array, ...] 
+ ) -> Array: + rec_expr = AdvancedIndexInNoncontiguousAxes( + self.rec(expr.array, None, None, (), ()), + expr.indices, + tags=expr.tags, + axes=expr.axes + ) + + if _is_materialized(expr): + return _pick_list_fusers_map_materialized_node( + rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) + + assert iel_axis is None and idof_axis is None + return rec_expr + + def map_basic_index(self, expr: BasicIndex, + iel_axis: Optional[int], + idof_axis: Optional[int], + from_element_indices: Tuple[Array, ...], + dof_pick_lists: Tuple[Array, ...]) -> Array: + + rec_expr = BasicIndex( + self.rec(expr.array, None, None, (), ()), + expr.indices, + tags=expr.tags, + axes=expr.axes + ) + + if _is_materialized(expr): + return _pick_list_fusers_map_materialized_node( + rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) + + assert iel_axis is None and idof_axis is None + return rec_expr + + def map_reshape(self, + expr: Reshape, + iel_axis: Optional[int], + idof_axis: Optional[int], + from_element_indices: Tuple[Array, ...], + dof_pick_lists: Tuple[Array, ...]) -> Array: + rec_expr = Reshape( + self.rec(expr.array, None, None, (), ()), + expr.newshape, + expr.order, + tags=expr.tags, + axes=expr.axes + ) + + if _is_materialized(expr): + return _pick_list_fusers_map_materialized_node( + rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) + + assert iel_axis is None and idof_axis is None + return rec_expr + + +def fuse_dof_pick_lists(expr: DictOfNamedArrays) -> DictOfNamedArrays: + mapper = PickListFusers() + + return DictOfNamedArrays( + {name: mapper(subexpr, None, None, (), ()) + for name, subexpr in sorted(expr._data.items(), key=lambda x: x[0])} + ) + +# }}} + + +# {{{ fold indirection constants + +class _ConstantIndirectionArrayCollector(CombineMapper[FrozenSet[Array]]): + def __init__(self) -> None: + from pytato.transform import InputGatherer + super().__init__() + self.get_inputs = InputGatherer() + + def combine(self, 
*args: FrozenSet[Array]) -> FrozenSet[Array]: + from functools import reduce + return reduce(frozenset.union, args, frozenset()) + + def _map_input_base(self, expr: InputArgumentBase) -> FrozenSet[Array]: + return frozenset() + + map_placeholder = _map_input_base + map_data_wrapper = _map_input_base + map_size_param = _map_input_base + + def _map_index_base(self, expr: IndexBase) -> FrozenSet[Array]: + rec_results: List[FrozenSet[Array]] = [] + + rec_results.append(self.rec(expr.array)) + + for idx in expr.indices: + if isinstance(idx, Array): + input_deps = self.get_inputs(idx) + + if input_deps: + if any(isinstance(inp, Placeholder) for inp in input_deps): + rec_results.append(self.rec(idx)) + else: + assert all(isinstance(inp, DataWrapper) + for inp in input_deps) + rec_results.append(frozenset([idx])) + + return self.combine(*rec_results) + + +def fold_constant_indirections( + expr: MappedT, + evaluator: Callable[[DictOfNamedArrays], Mapping[str, DataWrapper]] +) -> MappedT: + """ + Returns a copy of *expr* with constant indirection expressions frozen. + + :arg evaluator: A callable that takes in a + :class:`~pytato.array.DictOfNamedArrays` and returns a mapping from the + name of every named array to its corresponding evaluated array as an + instance of :class:`~pytato.array.DataWrapper`. 
+ """ + from pytools import UniqueNameGenerator + from pytato.array import make_dict_of_named_arrays + import collections.abc as abc + from pytato.transform import map_and_copy + + vng = UniqueNameGenerator() + arys_to_evaluate = _ConstantIndirectionArrayCollector()(expr) + dict_of_named_arrays = make_dict_of_named_arrays( + {vng("_pt_folded_cnst"): ary for ary in arys_to_evaluate} + ) + del arys_to_evaluate + evaluated_arys = evaluator(dict_of_named_arrays) + + if not isinstance(evaluated_arys, abc.Mapping): + raise TypeError("evaluator did not return a mapping") + + if set(evaluated_arys.keys()) != set(dict_of_named_arrays.keys()): + raise ValueError("evaluator must return a mapping with " + f"the keys: '{set(dict_of_named_arrays.keys())}'.") + + for key, ary in evaluated_arys.items(): + if not isinstance(ary, DataWrapper): + raise TypeError(f"evaluated array for '{key}' not a DataWrapper") + + before_to_after_subst = { + dict_of_named_arrays._data[name]: evaluated_ary + for name, evaluated_ary in evaluated_arys.items() + } + + def _replace_with_folded_constants(subexpr: ArrayOrNames) -> ArrayOrNames: + if isinstance(subexpr, Array): + return before_to_after_subst.get(subexpr, subexpr) + else: + return subexpr + + return map_and_copy(expr, _replace_with_folded_constants) + +# }}} + + +# vim: fdm=marker From 7359d59fe26f01bd9ab8e5899f7668ff13a116bb Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 30 May 2023 07:26:36 -0500 Subject: [PATCH 2/4] tests pick list fusion --- test/test_pytato_transforms.py | 275 +++++++++++++++++++++++++++++++++ 1 file changed, 275 insertions(+) create mode 100644 test/test_pytato_transforms.py diff --git a/test/test_pytato_transforms.py b/test/test_pytato_transforms.py new file mode 100644 index 000000000..430d644c0 --- /dev/null +++ b/test/test_pytato_transforms.py @@ -0,0 +1,275 @@ +import numpy as np # noqa: F401 +import pyopencl as cl +from typing import Union +from meshmode.mesh import BTAG_ALL +from 
meshmode.mesh.generation import generate_regular_rect_mesh +from arraycontext.metadata import NameHint +from meshmode.array_context import (PytatoPyOpenCLArrayContext, + PyOpenCLArrayContext) +from pytato.transform import CombineMapper +from pytato.array import (Placeholder, DataWrapper, SizeParam, IndexBase, + Array, DictOfNamedArrays, BasicIndex) +from meshmode.discretization.connection import (FACE_RESTR_INTERIOR, + FACE_RESTR_ALL) +from pytools.obj_array import make_obj_array +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl as pytest_generate_tests) +import grudge +import grudge.op as op + + +# {{{ utilities for test_push_indirections_* + +class _IndexeeArraysMaterializedChecker(CombineMapper[bool]): + def combine(self, *args: bool) -> bool: + return all(args) + + def map_placeholder(self, expr: Placeholder) -> bool: + return True + + def map_data_wrapper(self, expr: DataWrapper) -> bool: + return True + + def map_size_param(self, expr: SizeParam) -> bool: + return True + + def _map_index_base(self, expr: IndexBase) -> bool: + from grudge.pytato_transforms.pytato_indirection_transforms import ( + _is_materialized) + return self.combine( + _is_materialized(expr.array) or isinstance(expr.array, BasicIndex), + self.rec(expr.array) + ) + + +def are_all_indexees_materialized_nodes( + expr: Union[Array, DictOfNamedArrays]) -> bool: + """ + Returns *True* only if all indexee arrays are either materialized nodes, + OR, other indexing nodes that have materialized indexees. 
+ """ + return _IndexeeArraysMaterializedChecker()(expr) + + +class _IndexerArrayDatawrapperChecker(CombineMapper[bool]): + def combine(self, *args: bool) -> bool: + return all(args) + + def map_placeholder(self, expr: Placeholder) -> bool: + return True + + def map_data_wrapper(self, expr: DataWrapper) -> bool: + return True + + def map_size_param(self, expr: SizeParam) -> bool: + return True + + def _map_index_base(self, expr: IndexBase) -> bool: + return self.combine( + *[isinstance(idx, DataWrapper) + for idx in expr.indices + if isinstance(idx, Array)], + super()._map_index_base(expr), + ) + + +def are_all_indexer_arrays_datawrappers( + expr: Union[Array, DictOfNamedArrays]) -> bool: + """ + Returns *True* only if all indexer arrays are instances of + :class:`~pytato.array.DataWrapper`. + """ + return _IndexerArrayDatawrapperChecker()(expr) + +# }}} + + +def _evaluate_dict_of_named_arrays(actx, dict_of_named_arrays): + container = make_obj_array([dict_of_named_arrays._data[name] + for name in sorted(dict_of_named_arrays.keys())]) + + evaluated_container = actx.thaw(actx.freeze(container)) + + return {name: evaluated_container[i] + for i, name in enumerate(sorted(dict_of_named_arrays.keys()))} + + +class FluxOptimizerActx(PytatoPyOpenCLArrayContext): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.check_completed = False + + def transform_dag(self, dag): + from grudge.pytato_transforms.pytato_indirection_transforms import ( + fuse_dof_pick_lists, fold_constant_indirections) + from pytato.tags import PrefixNamed + + if all(PrefixNamed("flux_container") in v.tags for v in dag._data.values()): + assert not are_all_indexer_arrays_datawrappers(dag) + assert not are_all_indexees_materialized_nodes(dag) + self.check_completed = True + + dag = fuse_dof_pick_lists(dag) + dag = fold_constant_indirections( + dag, lambda x: _evaluate_dict_of_named_arrays(self, x)) + + if all(PrefixNamed("flux_container") in v.tags for v in 
dag._data.values()): + assert are_all_indexer_arrays_datawrappers(dag) + assert are_all_indexees_materialized_nodes(dag) + self.check_completed = True + + return dag + + +# {{{ test_resampling_indirections_are_fused_0 + +def _compute_flux_0(dcoll, actx, u): + u_interior_tpair, = op.interior_trace_pairs(dcoll, u) + flux_on_interior_faces = u_interior_tpair.avg + flux_on_all_faces = op.project( + dcoll, FACE_RESTR_INTERIOR, FACE_RESTR_ALL, flux_on_interior_faces) + + flux_on_all_faces = actx.tag(NameHint("flux_container"), flux_on_all_faces) + return flux_on_all_faces + + +def test_resampling_indirections_are_fused_0(ctx_factory): + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + + ref_actx = PyOpenCLArrayContext(cq) + actx = FluxOptimizerActx(cq) + + dim = 3 + nel_1d = 4 + mesh = generate_regular_rect_mesh( + a=(-0.5,)*dim, + b=(0.5,)*dim, + nelements_per_axis=(nel_1d,)*dim, + boundary_tag_to_face={"bdry": ["-x", "+x", + "-y", "+y", + "-z", "+z"]} + ) + dcoll = grudge.make_discretization_collection(ref_actx, mesh, order=2) + + x, _, _ = dcoll.nodes() + compiled_flux_0 = actx.compile(lambda ary: _compute_flux_0(dcoll, actx, ary)) + + ref_output = ref_actx.to_numpy( + _compute_flux_0(dcoll, ref_actx, ref_actx.thaw(x))) + output = actx.to_numpy( + compiled_flux_0(actx.thaw(x))) + + np.testing.assert_allclose(ref_output[0], output[0]) + assert actx.check_completed + +# }}} + + +# {{{ test_resampling_indirections_are_fused_1 + +def _compute_flux_1(dcoll, actx, u): + u_interior_tpair, = op.interior_trace_pairs(dcoll, u) + flux_on_interior_faces = u_interior_tpair.avg + flux_on_bdry = op.project(dcoll, "vol", BTAG_ALL, u) + flux_on_all_faces = ( + op.project(dcoll, + FACE_RESTR_INTERIOR, + FACE_RESTR_ALL, + flux_on_interior_faces) + + op.project(dcoll, BTAG_ALL, FACE_RESTR_ALL, flux_on_bdry) + ) + + result = op.inverse_mass(dcoll, op.face_mass(dcoll, flux_on_all_faces)) + + result = actx.tag(NameHint("flux_container"), result) + return result + + +def 
test_resampling_indirections_are_fused_1(ctx_factory): + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + + ref_actx = PyOpenCLArrayContext(cq) + actx = FluxOptimizerActx(cq) + + dim = 3 + nel_1d = 4 + mesh = generate_regular_rect_mesh( + a=(-0.5,)*dim, + b=(0.5,)*dim, + nelements_per_axis=(nel_1d,)*dim, + boundary_tag_to_face={"bdry": ["-x", "+x", + "-y", "+y", + "-z", "+z"]} + ) + dcoll = grudge.make_discretization_collection(ref_actx, mesh, order=2) + + x, _, _ = dcoll.nodes() + compiled_flux_1 = actx.compile(lambda ary: _compute_flux_1(dcoll, actx, ary)) + + ref_output = ref_actx.to_numpy( + _compute_flux_1(dcoll, ref_actx, ref_actx.thaw(x))) + output = actx.to_numpy( + compiled_flux_1(actx.thaw(x))) + + np.testing.assert_allclose(ref_output[0], output[0]) + assert actx.check_completed + +# }}} + + +# {{{ test_resampling_indirections_are_fused_2 + +def _compute_flux_2(dcoll, actx, u): + u_interior_tpair, = op.interior_trace_pairs(dcoll, u) + normal_on_interior_faces = actx.thaw(dcoll.normal(u_interior_tpair.dd)) + normal_on_bdry_faces = actx.thaw(dcoll.normal(BTAG_ALL)) + flux_on_interior_faces = u_interior_tpair.avg * normal_on_interior_faces + flux_on_bdry = op.project(dcoll, "vol", BTAG_ALL, u) * normal_on_bdry_faces + flux_on_all_faces = ( + op.project(dcoll, + FACE_RESTR_INTERIOR, + FACE_RESTR_ALL, + flux_on_interior_faces) + + op.project(dcoll, BTAG_ALL, FACE_RESTR_ALL, flux_on_bdry) + ) + + result = op.inverse_mass(dcoll, op.face_mass(dcoll, flux_on_all_faces)) + + result = actx.tag(NameHint("flux_container"), result) + return result + + +def test_resampling_indirections_are_fused_2(ctx_factory): + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + + ref_actx = PyOpenCLArrayContext(cq) + actx = FluxOptimizerActx(cq) + + dim = 2 + nel_1d = 4 + mesh = generate_regular_rect_mesh( + a=(-0.5,)*dim, + b=(0.5,)*dim, + nelements_per_axis=(nel_1d,)*dim, + boundary_tag_to_face={"bdry": ["-x", "+x", + "-y", "+y"]} + ) + dcoll = 
grudge.make_discretization_collection(ref_actx, mesh, order=2) + + x, _ = dcoll.nodes() + compiled_flux_2 = actx.compile(lambda ary: _compute_flux_2(dcoll, actx, ary)) + + ref_output = ref_actx.to_numpy( + _compute_flux_2(dcoll, ref_actx, ref_actx.thaw(x))) + output = actx.to_numpy( + compiled_flux_2(actx.thaw(x))) + + np.testing.assert_allclose(ref_output[0], output[0]) + assert actx.check_completed + +# }}} + +# vim: fdm=marker From ce9b9fae8d3a9767552b1988ec5b3871f6610428 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 10 Aug 2023 10:41:30 -0500 Subject: [PATCH 3/4] use lazy eval array context for setup actx --- test/test_pytato_transforms.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/test/test_pytato_transforms.py b/test/test_pytato_transforms.py index 430d644c0..e0d3b5c9f 100644 --- a/test/test_pytato_transforms.py +++ b/test/test_pytato_transforms.py @@ -37,7 +37,7 @@ def _map_index_base(self, expr: IndexBase) -> bool: from grudge.pytato_transforms.pytato_indirection_transforms import ( _is_materialized) return self.combine( - _is_materialized(expr.array) or isinstance(expr.array, BasicIndex), + _is_materialized(expr.array) or isinstance(expr, BasicIndex), self.rec(expr.array) ) @@ -227,6 +227,7 @@ def _compute_flux_2(dcoll, actx, u): normal_on_bdry_faces = actx.thaw(dcoll.normal(BTAG_ALL)) flux_on_interior_faces = u_interior_tpair.avg * normal_on_interior_faces flux_on_bdry = op.project(dcoll, "vol", BTAG_ALL, u) * normal_on_bdry_faces + flux_on_all_faces = ( op.project(dcoll, FACE_RESTR_INTERIOR, @@ -245,21 +246,23 @@ def test_resampling_indirections_are_fused_2(ctx_factory): cl_ctx = ctx_factory() cq = cl.CommandQueue(cl_ctx) - ref_actx = PyOpenCLArrayContext(cq) + from grudge.array_context import get_reasonable_array_context_class + + ref_actx = get_reasonable_array_context_class(lazy=True, distributed=False)(cq) actx = FluxOptimizerActx(cq) - dim = 2 - nel_1d = 4 + dim = 3 + nel_1d = 16 + order = 4 
mesh = generate_regular_rect_mesh( a=(-0.5,)*dim, b=(0.5,)*dim, nelements_per_axis=(nel_1d,)*dim, - boundary_tag_to_face={"bdry": ["-x", "+x", - "-y", "+y"]} ) - dcoll = grudge.make_discretization_collection(ref_actx, mesh, order=2) - - x, _ = dcoll.nodes() + dcoll = grudge.make_discretization_collection( + ref_actx, mesh, + order=order) + x, _, _ = dcoll.nodes() compiled_flux_2 = actx.compile(lambda ary: _compute_flux_2(dcoll, actx, ary)) ref_output = ref_actx.to_numpy( From 3a767c1b24748e90d8a3e8791ff8fcb9d67b27fd Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 6 Mar 2024 10:59:09 -0600 Subject: [PATCH 4/4] wip: save work --- .../pytato_indirection_transforms.py | 156 +++++++++++++----- 1 file changed, 113 insertions(+), 43 deletions(-) diff --git a/grudge/pytato_transforms/pytato_indirection_transforms.py b/grudge/pytato_transforms/pytato_indirection_transforms.py index a5f05c648..9d4961a52 100644 --- a/grudge/pytato_transforms/pytato_indirection_transforms.py +++ b/grudge/pytato_transforms/pytato_indirection_transforms.py @@ -46,8 +46,12 @@ def _is_materialized(expr: Array) -> bool: def _can_index_lambda_propagate_indirections_without_changing_axes( - expr: IndexLambda) -> bool: - + expr: IndexLambda, iel_axis: Optional[int], idof_axis: Optional[int] +) -> bool: + """ + Returns *True* only if the axes being reindexed appear at the same + positions in the bindings' indexing locations. 
+ """ from pytato.utils import are_shapes_equal from pytato.raising import (index_lambda_to_high_level_op, BinaryOp) @@ -219,8 +223,8 @@ def _fuse_from_element_indices(from_element_indices: Tuple[Array, ...]): return result -def _fuse_dof_pick_lists(dof_pick_lists: Tuple[Array, ...], from_element_indices: - Tuple[Array, ...]): +def _fuse_dof_pick_lists(dof_pick_lists: Tuple[Array, ...], + from_element_indices: Tuple[Array, ...]): assert all(from_el_idx.ndim == 2 for from_el_idx in from_element_indices) assert all(dof_pick_list.ndim == 2 for dof_pick_list in dof_pick_lists) assert all(from_el_idx.shape[1] == 1 for from_el_idx in from_element_indices) @@ -239,7 +243,10 @@ def _pick_list_fusers_map_materialized_node(rec_expr: Array, from_element_indices: Tuple[Array, ...], dof_pick_lists: Tuple[Array, ...] ) -> Array: - + raise NotImplementedError("We still need to port this from" + " the previous version, where only" + " indirections only along the element" + " axes.") if iel_axis is not None: assert idof_axis is not None assert len(from_element_indices) != 0 @@ -263,6 +270,56 @@ def _pick_list_fusers_map_materialized_node(rec_expr: Array, return rec_expr +def _is_iel_idof_picking(expr: AdvancedIndexInContiguousAxes, + iel_axis: Optional[int], + idof_axis: Optional[int], + ) -> bool: + if expr.ndim != 2: + return False + + if expr.array.ndim != 2: + return False + + if not ((iel_axis is None and idof_axis is None) + or (iel_axis == 0 and idof_axis == 1)): + return False + + if (isinstance(expr.indices[0], Array) + and isinstance(expr.indices[1], Array)): + from pytato.utils import are_shape_components_equal + from_el_indices, dof_pick_lists = expr.indices + assert isinstance(from_el_indices, Array) + assert isinstance(dof_pick_lists, Array) + + if dof_pick_lists.ndim != 1: + return False + if from_el_indices.ndim != 2: + return False + if are_shape_components_equal(from_el_indices.shape[1], 1): + return False + + return True + else: + return False + + +def 
_is_iel_only_picking(expr: AdvancedIndexInContiguousAxes, + iel_axis: Optional[int]) -> bool: + if expr.ndim != 1: + return False + + if expr.array.ndim != 1: + return False + + if not isinstance(expr.indices[0], Array): + return False + + if iel_axis not in [0, None]: + return False + + return True + + class PickListFusers(Mapper): def __init__(self) -> None: self.can_pick_indirections_be_propagated = _CanPickIndirectionsBePropagated() @@ -283,18 +340,22 @@ def rec(self, # type: ignore[override] " is illegal for PickListFusers. Pass arrays" " instead.") - if iel_axis is not None: - assert idof_axis is not None + if idof_axis is not None: + assert iel_axis is not None assert 0 <= iel_axis < expr.ndim assert 0 <= idof_axis < expr.ndim # the condition below ensures that we are only dealing with indirections # appearing at contiguous locations. assert abs(iel_axis-idof_axis) == 1 - else: + assert len(dof_pick_lists) == len(from_element_indices) + elif iel_axis is not None: assert idof_axis is None + assert len(dof_pick_lists) == 0 + assert len(from_element_indices) > 0 + else: + assert iel_axis is None assert len(from_element_indices) == 0 - - assert len(dof_pick_lists) == len(from_element_indices) + assert len(dof_pick_lists) == 0 key = (expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) try: @@ -318,8 +379,8 @@ def __call__(self, # type: ignore[override] def _map_input_base(self, expr: InputArgumentBase, - iel_axis: int, - idof_axis: int, + iel_axis: Optional[int], + idof_axis: Optional[int], from_element_indices: Tuple[Array, ...], dof_pick_lists: Tuple[Array, ...]) -> Array: return _pick_list_fusers_map_materialized_node( @@ -351,30 +412,36 @@ def map_index_lambda(self, rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) if iel_axis is not None: - assert idof_axis is not None assert _can_index_lambda_propagate_indirections_without_changing_axes( - expr) - from pytato.utils import are_shapes_equal - new_el_dim, new_dofs_dim = 
dof_pick_lists[0].shape - assert are_shapes_equal(from_element_indices[0].shape, (new_el_dim, 1)) - - new_shape = tuple( - new_el_dim if idim == iel_axis else ( - new_dofs_dim if idim == idof_axis else dim) - for idim, dim in enumerate(expr.shape)) - - return IndexLambda( - expr.expr, - new_shape, - expr.dtype, - Map({name: self.rec(bnd, iel_axis, idof_axis, - from_element_indices, - dof_pick_lists) - for name, bnd in expr.bindings.items()}), - var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags, - axes=expr.axes - ) + expr, iel_axis, idof_axis) + if idof_axis is None: + # TODO: Not encountered any practical DAGs that take this code path. + # Implement this branch only if seen in any practical applications. + raise NotImplementedError + else: + assert idof_axis is not None + from pytato.utils import are_shapes_equal + new_el_dim, new_dofs_dim = dof_pick_lists[0].shape + assert are_shapes_equal(from_element_indices[0].shape, + (new_el_dim, 1)) + + new_shape = tuple( + new_el_dim if idim == iel_axis else ( + new_dofs_dim if idim == idof_axis else dim) + for idim, dim in enumerate(expr.shape)) + + return IndexLambda( + expr.expr, + new_shape, + expr.dtype, + Map({name: self.rec(bnd, iel_axis, idof_axis, + from_element_indices, + dof_pick_lists) + for name, bnd in expr.bindings.items()}), + var_to_reduction_descr=expr.var_to_reduction_descr, + tags=expr.tags, + axes=expr.axes + ) else: return IndexLambda( expr.expr, @@ -405,14 +472,17 @@ def map_contiguous_advanced_index(self, return _pick_list_fusers_map_materialized_node( rec_expr, iel_axis, idof_axis, from_element_indices, dof_pick_lists) - if self.can_pick_indirections_be_propagated(expr, - iel_axis or 0, - idof_axis or 1): - idx1, idx2 = expr.indices - assert isinstance(idx1, Array) and isinstance(idx2, Array) - return self.rec(expr.array, 0, 1, - from_element_indices + (idx1,), - dof_pick_lists + (idx2,)) + if (_is_iel_idof_picking(expr, iel_axis, idof_axis) + and 
self.can_pick_indirections_be_propagated(expr, + iel_axis or 0, + idof_axis or 1)): + raise NotImplementedError + elif (_is_iel_only_picking(expr, iel_axis) + and self.can_pick_indirections_be_propagated(expr, + iel_axis or 0, + None)): + assert idof_axis is None + raise NotImplementedError else: assert iel_axis is None and idof_axis is None return AdvancedIndexInContiguousAxes(