Skip to content

Commit

Permalink
port connections away from is_array_container
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Oct 24, 2021
1 parent 1ef686c commit dd24703
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 23 deletions.
30 changes: 24 additions & 6 deletions meshmode/discretization/connection/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
ConcurrentElementInameTag, ConcurrentDOFInameTag)
from pytools import memoize_in, keyed_memoize_method
from arraycontext import (
ArrayContext, make_loopy_program,
is_array_container_type, map_array_container)
ArrayContext, NotAnArrayContainerError,
serialize_container, deserialize_container, make_loopy_program)


# {{{ interpolation batch
Expand Down Expand Up @@ -318,12 +318,30 @@ def full_resample_matrix(self, actx):
return make_direct_full_resample_matrix(actx, self)

def __call__(self, ary, _force_no_inplace_updates=False):
# _force_no_inplace_updates: Only used to ensure test coverage
# of both code paths.
"""
:arg ary: a :class:`~meshmode.dof_array.DOFArray`, or an
:class:`arraycontext.ArrayContainer` of them, containing nodal
coefficient data on :attr:`from_discr`.
:arg _force_no_inplace_updates: private argument only used to ensure
test coverge of all code paths.
"""

# {{{ recurse into array containers

from meshmode.dof_array import DOFArray
if is_array_container_type(ary) and not isinstance(ary, DOFArray):
return map_array_container(self, ary)
if not isinstance(ary, DOFArray):
try:
iterable = serialize_container(ary)
except NotAnArrayContainerError:
pass
else:
return deserialize_container(ary, [
(key, self(subary, _force_no_inplace_updates))
for key, subary in iterable
])

# }}}

if __debug__:
from meshmode.dof_array import check_dofarray_against_discr
Expand Down
43 changes: 34 additions & 9 deletions meshmode/discretization/connection/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
import numpy.linalg as la
import modepy as mp

from arraycontext import (
NotAnArrayContainerError, serialize_container, deserialize_container)
from meshmode.transform_metadata import FirstAxisIsElementsTag
from arraycontext import is_array_container_type, map_array_container
from meshmode.discretization import InterpolatoryElementGroupBase
from meshmode.discretization.poly_element import QuadratureSimplexElementGroup
from meshmode.discretization.connection.direct import DiscretizationConnection
Expand Down Expand Up @@ -170,12 +171,24 @@ def __call__(self, ary):
"""Computes modal coefficients data from a functions
nodal coefficients.
:arg ary: a :class:`meshmode.dof_array.DOFArray` containing
nodal coefficient data.
:arg ary: a :class:`~meshmode.dof_array.DOFArray`, or an
:class:`arraycontext.ArrayContainer` of them, containing nodal
coefficient data.
"""
# {{{ recurse into array containers

from meshmode.dof_array import DOFArray
if is_array_container_type(ary) and not isinstance(ary, DOFArray):
return map_array_container(self, ary)
if not isinstance(ary, DOFArray):
try:
iterable = serialize_container(ary)
except NotAnArrayContainerError:
pass
else:
return deserialize_container(ary, [
(key, self(subary)) for key, subary in iterable
])

# }}}

if not isinstance(ary, DOFArray):
raise TypeError("Non-array passed to discretization connection")
Expand Down Expand Up @@ -291,12 +304,24 @@ def __init__(self, from_discr, to_discr):
def __call__(self, ary):
"""Computes nodal coefficients from modal data.
:arg ary: a :class:`meshmode.dof_array.DOFArray` containing
modal coefficient data.
:arg ary: a :class:`~meshmode.dof_array.DOFArray`, or an
:class:`arraycontext.ArrayContainer` of them, containing modal
coefficient data.
"""
# {{{ recurse into array containers

from meshmode.dof_array import DOFArray
if is_array_container_type(ary) and not isinstance(ary, DOFArray):
return map_array_container(self, ary)
if not isinstance(ary, DOFArray):
try:
iterable = serialize_container(ary)
except NotAnArrayContainerError:
pass
else:
return deserialize_container(ary, [
(key, self(subary)) for key, subary in iterable
])

# }}}

if not isinstance(ary, DOFArray):
raise TypeError("Non-array passed to discretization connection")
Expand Down
24 changes: 21 additions & 3 deletions meshmode/discretization/connection/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
import loopy as lp

from arraycontext import (
make_loopy_program, is_array_container_type, map_array_container)
NotAnArrayContainerError,
make_loopy_program, serialize_container, deserialize_container)
from meshmode.transform_metadata import FirstAxisIsElementsTag
from meshmode.discretization.connection.direct import (
DiscretizationConnection,
Expand Down Expand Up @@ -119,9 +120,26 @@ def det(v):
return weights

def __call__(self, ary):
"""
:arg ary: a :class:`~meshmode.dof_array.DOFArray`, or an
:class:`arraycontext.ArrayContainer` of them, containing nodal
coefficient data on :attr:`from_discr`.
"""

# {{{ recurse into array containers

from meshmode.dof_array import DOFArray
if is_array_container_type(ary) and not isinstance(ary, DOFArray):
return map_array_container(self, ary)
if not isinstance(ary, DOFArray):
try:
iterable = serialize_container(ary)
except NotAnArrayContainerError:
pass
else:
return deserialize_container(ary, [
(key, self(subary)) for key, subary in iterable
])

# }}}

if __debug__:
from meshmode.dof_array import check_dofarray_against_discr
Expand Down
22 changes: 17 additions & 5 deletions meshmode/dof_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from meshmode.transform_metadata import (
ConcurrentElementInameTag, ConcurrentDOFInameTag)
from arraycontext import (
ArrayContext, ArrayContainerTypeError,
ArrayContext, NotAnArrayContainerError,
make_loopy_program, with_container_arithmetic,
serialize_container, deserialize_container,
thaw as _thaw, freeze as _freeze,
Expand Down Expand Up @@ -361,8 +361,20 @@ def rec_map_dof_array_container(f: Callable[[Any], Any], ary):
Similar to :func:`~arraycontext.map_array_container`, but
does not further recurse on :class:`DOFArray`\ s.
"""
from arraycontext.container.traversal import _map_array_container_impl
return _map_array_container_impl(f, ary, leaf_cls=DOFArray, recursive=True)
def rec(_ary):
if isinstance(_ary, DOFArray):
return f(_ary)

try:
iterable = serialize_container(_ary)
except NotAnArrayContainerError:
return f(_ary)
else:
return deserialize_container(_ary, [
(key, rec(subary)) for key, subary in iterable
])

return rec(ary)


def mapped_over_dof_arrays(f):
Expand Down Expand Up @@ -613,7 +625,7 @@ def _unflatten_like(_ary, _prototype):
iterable = zip(
serialize_container(_ary),
serialize_container(_prototype))
except ArrayContainerTypeError:
except NotAnArrayContainerError:
if strict:
raise ValueError("cannot unflatten array "
f"with prototype '{type(_prototype).__name__}'; "
Expand Down Expand Up @@ -734,7 +746,7 @@ def _rec(_ary):

try:
iterable = serialize_container(_ary)
except ArrayContainerTypeError:
except NotAnArrayContainerError:
raise TypeError(f"unsupported array type: '{type(_ary).__name__}'")
else:
arys = [_rec(subary) for _, subary in iterable]
Expand Down

0 comments on commit dd24703

Please sign in to comment.