-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
352 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
__copyright__ = "Copyright (C) 2022 Alexandru Fikl" | ||
|
||
__license__ = """ | ||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
The above copyright notice and this permission notice shall be included in | ||
all copies or substantial portions of the Software. | ||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
THE SOFTWARE. | ||
""" | ||
|
||
from dataclasses import dataclass | ||
from typing import Any, Dict, Iterable, Optional, Union | ||
|
||
import numpy as np | ||
|
||
from arraycontext import PyOpenCLArrayContext, ArrayOrContainerT | ||
from meshmode.dof_array import DOFArray | ||
|
||
from pytential import GeometryCollection, sym | ||
from pytential.linalg.cluster import ClusterTree, cluster | ||
|
||
__doc__ = """ | ||
Hierarical Matrix Construction | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
""" | ||
|
||
|
||
# {{{ ProxyHierarchicalMatrix | ||
|
||
@dataclass(frozen=True) | ||
class ProxyHierarchicalMatrix: | ||
""" | ||
.. attribute:: ctree | ||
A :class:`~pytential.linalg.cluster.ClusterTree`. | ||
.. attribute:: skeletons | ||
An :class:`~numpy.ndarray` containing skeletonization information | ||
for each level of the hierarchy. For additional details, see | ||
:class:`~pytential.linalg.skeletonization.SkeletonizationResult`. | ||
This class implements the :class:`scipy.sparse.linalg.LinearOperator` | ||
interface. In particular, the following attributes and methods: | ||
.. attribute:: shape | ||
A :class:`tuple` that gives the matrix size ``(m, n)``. | ||
.. attribute:: dtype | ||
The data type of the matrix entries. | ||
.. automethod:: matvec | ||
.. automethod:: __matmul__ | ||
""" | ||
|
||
ctree: ClusterTree | ||
skeletons: np.ndarray | ||
|
||
@property | ||
def shape(self): | ||
return self.skeletons[0].tgt_src_index.shape | ||
|
||
@property | ||
def dtype(self): | ||
# FIXME: assert that everyone has this dtype? | ||
return self.skeletons[0].R[0].dtype | ||
|
||
@property | ||
def nlevels(self): | ||
return self.skeletons.size | ||
|
||
def matvec(self, x: ArrayOrContainerT) -> ArrayOrContainerT: | ||
"""Implements a matrix-vector multiplication :math:`H x`.""" | ||
from arraycontext import get_container_context_recursively_opt | ||
actx = get_container_context_recursively_opt(x) | ||
if actx is None: | ||
raise ValueError("input array is frozen") | ||
|
||
return apply_skeleton_matvec(actx, self, x) | ||
|
||
def __matmul__(self, x: ArrayOrContainerT) -> ArrayOrContainerT: | ||
"""Same as :meth:`matvec`.""" | ||
return self.matvec(x) | ||
|
||
def rmatvec(self, x): | ||
raise NotImplementedError | ||
|
||
def matmat(self, mat): | ||
raise NotImplementedError | ||
|
||
def rmatmat(self, mat): | ||
raise NotImplementedError | ||
|
||
|
||
def apply_skeleton_matvec( | ||
actx: PyOpenCLArrayContext, | ||
hmat: ProxyHierarchicalMatrix, | ||
x: ArrayOrContainerT, | ||
) -> ArrayOrContainerT: | ||
from arraycontext import flatten | ||
x = actx.to_numpy(flatten(x, actx, leaf_class=DOFArray)) | ||
|
||
from pytential.linalg.utils import split_array | ||
y = split_array(x, hmat.skeletons[0].tgt_src_index.sources) | ||
|
||
assert x.dtype == hmat.dtype | ||
assert x.shape == (hmat.shape[1],) | ||
|
||
d_dot_y = np.empty(hmat.nlevels, dtype=object) | ||
r_dot_y = np.empty(hmat.nlevels, dtype=object) | ||
|
||
# recurse down | ||
for k, clevel in enumerate(hmat.ctree.levels(root=True)): | ||
skeleton = hmat.skeletons[k] | ||
assert skeleton.tgt_src_index.shape[1] == sum(xi.size for xi in y) | ||
|
||
d_dot_y_k = np.empty(skeleton.nclusters, dtype=object) | ||
r_dot_y_k = np.empty(skeleton.nclusters, dtype=object) | ||
|
||
for i in range(skeleton.nclusters): | ||
r_dot_y_k[i] = skeleton.R[i] @ y[i] | ||
d_dot_y_k[i] = skeleton.D[i] @ y[i] | ||
|
||
r_dot_y[k] = r_dot_y_k | ||
d_dot_y[k] = d_dot_y_k | ||
y = cluster(r_dot_y_k, clevel) | ||
|
||
# recurse up | ||
for k, skeleton in reversed(list(enumerate(hmat.skeletons))): | ||
r_dot_y_k = r_dot_y[k] | ||
d_dot_y_k = d_dot_y[k] | ||
|
||
result = np.empty(skeleton.nclusters, dtype=object) | ||
for i in range(skeleton.nclusters): | ||
result[i] = skeleton.L[i] @ r_dot_y_k[i] + d_dot_y_k[i] | ||
|
||
from arraycontext import unflatten | ||
return unflatten( | ||
x, | ||
actx.from_numpy(np.concatenate(result)), | ||
actx) | ||
|
||
# }}} | ||
|
||
|
||
# {{{ build_hmatrix_matvec_by_proxy | ||
|
||
def build_hmatrix_matvec_by_proxy( | ||
actx: PyOpenCLArrayContext, | ||
places: GeometryCollection, | ||
exprs: Union[sym.Expression, Iterable[sym.Expression]], | ||
input_exprs: Union[sym.Expression, Iterable[sym.Expression]], *, | ||
domains: Optional[Iterable[sym.DOFDescriptorLike]] = None, | ||
context: Optional[Dict[str, Any]] = None, | ||
id_eps: float = 1.0e-8, | ||
|
||
# NOTE: these are dev variables and can disappear at any time! | ||
# TODO: plugin in error model to get an estimate for: | ||
# * how many points we want per cluster? | ||
# * how many proxy points we want? | ||
# * how far away should the proxy points be? | ||
# based on id_eps. How many of these should be user tunable? | ||
_tree_kind: Optional[str] = "adaptive-level-restricted", | ||
_max_particles_in_box: Optional[int] = None, | ||
|
||
_id_rank: Optional[int] = None, | ||
|
||
_approx_nproxy: Optional[int] = None, | ||
_proxy_radius_factor: Optional[float] = None, | ||
_proxy_cls: Optional[type] = None, | ||
): | ||
from pytential.linalg.cluster import partition_by_nodes | ||
cluster_index, ctree = partition_by_nodes( | ||
actx, places, | ||
tree_kind=_tree_kind, | ||
max_particles_in_box=_max_particles_in_box) | ||
|
||
from pytential.linalg.utils import TargetAndSourceClusterList | ||
tgt_src_index = TargetAndSourceClusterList( | ||
targets=cluster_index, sources=cluster_index) | ||
|
||
from pytential.linalg.skeletonization import rec_skeletonize_by_proxy | ||
skeletons = rec_skeletonize_by_proxy( | ||
actx, places, ctree, tgt_src_index, exprs, input_exprs, | ||
domains=domains, | ||
context=context, | ||
id_eps=id_eps, | ||
id_rank=_id_rank, | ||
approx_nproxy=_approx_nproxy, | ||
proxy_radius_factor=_proxy_radius_factor, | ||
max_particles_in_box=_max_particles_in_box, | ||
_proxy_cls=_proxy_cls, | ||
) | ||
|
||
return ProxyHierarchicalMatrix(ctree=ctree, skeletons=skeletons) | ||
|
||
# }}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.