diff --git a/pytential/symbolic/execution.py b/pytential/symbolic/execution.py index 3dee394e4..eaa718152 100644 --- a/pytential/symbolic/execution.py +++ b/pytential/symbolic/execution.py @@ -124,29 +124,31 @@ def map_node_min(self, expr): def _map_elementwise_reduction(self, reduction_name, expr): import loopy as lp from arraycontext import make_loopy_program - from meshmode.transform_metadata import ( - ConcurrentElementInameTag, ConcurrentDOFInameTag) + from meshmode.transform_metadata import ConcurrentElementInameTag + actx = self.array_context - @memoize_in(self.places, "elementwise_node_"+reduction_name) + @memoize_in(actx, ( + EvaluationMapperBase._map_elementwise_reduction, + f"elementwise_node_{reduction_name}")) def node_knl(): t_unit = make_loopy_program( """{[iel, idof, jdof]: 0<=iel el_result = %s(jdof, operand[iel, jdof]) + result[iel, idof] = el_result """ % reduction_name, - name="nodewise_reduce") + name=f"elementwise_node_{reduction_name}") return lp.tag_inames(t_unit, { "iel": ConcurrentElementInameTag(), - "idof": ConcurrentDOFInameTag(), }) - @memoize_in(self.places, "elementwise_"+reduction_name) + @memoize_in(actx, ( + EvaluationMapperBase._map_elementwise_reduction, + f"elementwise_element_{reduction_name}")) def element_knl(): - # FIXME: This computes the reduction value redundantly for each - # output DOF. t_unit = make_loopy_program( """{[iel, jdof]: 0<=iel