Skip to content

Commit

Permalink
simplify elementwise reductions
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Aug 18, 2022
1 parent 0c42040 commit 879ad35
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 45 deletions.
56 changes: 24 additions & 32 deletions pytential/symbolic/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,29 +124,31 @@ def map_node_min(self, expr):
def _map_elementwise_reduction(self, reduction_name, expr):
import loopy as lp
from arraycontext import make_loopy_program
from meshmode.transform_metadata import (
ConcurrentElementInameTag, ConcurrentDOFInameTag)
from meshmode.transform_metadata import ConcurrentElementInameTag
actx = self.array_context

@memoize_in(self.places, "elementwise_node_"+reduction_name)
@memoize_in(actx, (
EvaluationMapperBase._map_elementwise_reduction,
f"elementwise_node_{reduction_name}"))
def node_knl():
t_unit = make_loopy_program(
"""{[iel, idof, jdof]:
0<=iel<nelements and
0<=idof, jdof<ndofs}""",
"""
result[iel, idof] = %s(jdof, operand[iel, jdof])
<> el_result = %s(jdof, operand[iel, jdof])
result[iel, idof] = el_result
""" % reduction_name,
name="nodewise_reduce")
name=f"elementwise_node_{reduction_name}")

return lp.tag_inames(t_unit, {
"iel": ConcurrentElementInameTag(),
"idof": ConcurrentDOFInameTag(),
})

@memoize_in(self.places, "elementwise_"+reduction_name)
@memoize_in(actx, (
EvaluationMapperBase._map_elementwise_reduction,
f"elementwise_element_{reduction_name}"))
def element_knl():
# FIXME: This computes the reduction value redundantly for each
# output DOF.
t_unit = make_loopy_program(
"""{[iel, jdof]:
0<=iel<nelements and
Expand All @@ -155,37 +157,27 @@ def element_knl():
"""
result[iel, 0] = %s(jdof, operand[iel, jdof])
""" % reduction_name,
name="elementwise_reduce")
name=f"elementwise_element_{reduction_name}")

return lp.tag_inames(t_unit, {
"iel": ConcurrentElementInameTag(),
})

discr = self.places.get_discretization(
expr.dofdesc.geometry, expr.dofdesc.discr_stage)
dofdesc = expr.dofdesc
operand = self.rec(expr.operand)
assert operand.shape == (len(discr.groups),)

def _reduce(knl, result):
for g_operand, g_result in zip(operand, result):
self.array_context.call_loopy(
knl, operand=g_operand, result=g_result)

return result

dtype = operand.entry_dtype
granularity = expr.dofdesc.granularity
if granularity is sym.GRANULARITY_NODE:
return _reduce(node_knl(),
discr.empty(self.array_context, dtype=dtype))
elif granularity is sym.GRANULARITY_ELEMENT:
result = DOFArray(self.array_context, tuple([
self.array_context.empty((grp.nelements, 1), dtype=dtype)
for grp in discr.groups

if dofdesc.granularity is sym.GRANULARITY_NODE:
return type(operand)(actx, tuple([
actx.call_loopy(node_knl(), operand=operand_i)["result"]
for operand_i in operand
]))
elif dofdesc.granularity is sym.GRANULARITY_ELEMENT:
return type(operand)(actx, tuple([
actx.call_loopy(element_knl(), operand=operand_i)["result"]
for operand_i in operand
]))
return _reduce(element_knl(), result)
else:
raise ValueError(f"unsupported granularity: {granularity}")
raise ValueError(f"unsupported granularity: {dofdesc.granularity}")

def map_elementwise_sum(self, expr):
return self._map_elementwise_reduction("sum", expr)
Expand Down
58 changes: 45 additions & 13 deletions test/test_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,26 +306,58 @@ def test_node_reduction(actx_factory):

# {{{ test

# create a shuffled [1, nelements + 1] array
ary = []
el_nr_base = 0
for grp in discr.groups:
x = 1 + np.arange(el_nr_base, grp.nelements)
np.random.shuffle(x)
# create a shuffled [1, ndofs + 1] array
rng = np.random.default_rng(seed=42)

ary.append(actx.freeze(actx.from_numpy(x.reshape(-1, 1))))
el_nr_base += grp.nelements
def randrange_like(xi, offset):
x = offset + np.arange(1, xi.size + 1)
rng.shuffle(x)

return actx.from_numpy(x.reshape(xi.shape))

from meshmode.dof_array import DOFArray
ary = DOFArray(actx, tuple(ary))
base_node_nrs = np.cumsum([0] + [grp.ndofs for grp in discr.groups])
ary = DOFArray(actx, tuple([
randrange_like(xi, offset)
for xi, offset in zip(discr.nodes()[0], base_node_nrs)
]))

n = discr.ndofs
for func, expected in [
(sym.NodeSum, nelements * (nelements + 1) // 2),
(sym.NodeMax, nelements),
(sym.NodeSum, n * (n + 1) // 2),
(sym.NodeMax, n),
(sym.NodeMin, 1),
]:
r = bind(discr, func(sym.var("x")))(actx, x=ary)
assert abs(actx.to_numpy(r) - expected) < 1.0e-15, r
r = actx.to_numpy(
bind(discr, func(sym.var("x")))(actx, x=ary)
)
assert r == expected, (r, expected)

arys = tuple([rng.random(size=xi.shape) for xi in ary])
x = DOFArray(actx, tuple([actx.from_numpy(xi) for xi in arys]))

from meshmode.dof_array import flat_norm
for func, np_func in [
(sym.ElementwiseSum, np.sum),
(sym.ElementwiseMax, np.max)
]:
expected = DOFArray(actx, tuple([
actx.from_numpy(np.tile(np_func(xi, axis=1, keepdims=True), xi.shape[1]))
for xi in arys
]))
r = bind(
discr, func(sym.var("x"), dofdesc=sym.GRANULARITY_NODE)
)(actx, x=x)
assert actx.to_numpy(flat_norm(r - expected)) < 1.0e-15

expected = DOFArray(actx, tuple([
actx.from_numpy(np_func(xi, axis=1, keepdims=True))
for xi in arys
]))
r = bind(
discr, func(sym.var("x"), dofdesc=sym.GRANULARITY_ELEMENT)
)(actx, x=x)
assert actx.to_numpy(flat_norm(r - expected)) < 1.0e-15

# }}}

Expand Down

0 comments on commit 879ad35

Please sign in to comment.