Skip to content

Commit

Permalink
fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Sep 8, 2023
1 parent b8a441d commit 5062dd1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 31 deletions.
51 changes: 22 additions & 29 deletions pytential/linalg/hmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
"""

from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Optional, Union
from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union

import numpy as np
import numpy.linalg as la
from scipy.sparse.linalg import LinearOperator

from arraycontext import PyOpenCLArrayContext, ArrayOrContainerT, flatten, unflatten
from meshmode.dof_array import DOFArray
Expand All @@ -35,13 +36,6 @@
from pytential.linalg.skeletonization import (
SkeletonizationWrangler, SkeletonizationResult)

try:
from scipy.sparse.linalg import LinearOperator
except ImportError:
# NOTE: scipy should be available (for interp_decomp), but just in case
class LinearOperator:
pass

import logging
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -124,7 +118,7 @@ def _update_skeleton_diagonal(
targets, sources = parent.skel_tgt_src_index

# FIXME: nicer way to do this?
mat = np.empty(skeleton.nclusters, dtype=object)
mat: np.ndarray = np.empty(skeleton.nclusters, dtype=object)
for k in range(skeleton.nclusters):
D = skeleton.D[k].copy()

Expand All @@ -146,9 +140,9 @@ def _update_skeleton_diagonal(

def _update_skeletons_diagonal(
wrangler: "ProxyHierarchicalMatrixWrangler",
func: Callable[[SkeletonizationResult], np.ndarray],
func: Callable[[SkeletonizationResult], Optional[np.ndarray]],
) -> np.ndarray:
skeletons = np.empty(wrangler.skeletons.shape, dtype=object)
skeletons: np.ndarray = np.empty(wrangler.skeletons.shape, dtype=object)
skeletons[0] = wrangler.skeletons[0]

for i in range(1, wrangler.ctree.nlevels):
Expand Down Expand Up @@ -263,11 +257,14 @@ def _matvec(self, x: ArrayOrContainerT) -> ArrayOrContainerT:
else:
raise TypeError(f"unsupported input type: {type(x)}")

assert actx is None or isinstance(actx, PyOpenCLArrayContext)
result = apply_skeleton_forward_matvec(self, ary)

if isinstance(x, DOFArray):
assert actx is not None
result = unflatten(x, actx.from_numpy(result), actx)

return result
return result # type: ignore[return-value]


def apply_skeleton_forward_matvec(
Expand All @@ -276,7 +273,7 @@ def apply_skeleton_forward_matvec(
) -> ArrayOrContainerT:
from pytential.linalg.cluster import split_array
targets, sources = hmat.skeletons[0].tgt_src_index
x = split_array(ary, sources)
x = split_array(ary, sources) # type: ignore[arg-type]

# NOTE: this computes a telescoping product of the form
#
Expand All @@ -297,7 +294,7 @@ def apply_skeleton_forward_matvec(
#
# which gives back the desired product when we reach the leaf level again.

d_dot_x = np.empty(hmat.nlevels, dtype=object)
d_dot_x: np.ndarray = np.empty(hmat.nlevels, dtype=object)

# {{{ recurse down

Expand All @@ -307,8 +304,8 @@ def apply_skeleton_forward_matvec(
assert x.shape == (skeleton.nclusters,)
assert skeleton.tgt_src_index.shape[1] == sum([xi.size for xi in x])

d_dot_x_k = np.empty(skeleton.nclusters, dtype=object)
r_dot_x_k = np.empty(skeleton.nclusters, dtype=object)
d_dot_x_k: np.ndarray = np.empty(skeleton.nclusters, dtype=object)
r_dot_x_k: np.ndarray = np.empty(skeleton.nclusters, dtype=object)

for i in range(skeleton.nclusters):
r_dot_x_k[i] = skeleton.R[i] @ x[i]
Expand Down Expand Up @@ -366,23 +363,26 @@ def _matvec(self, x: ArrayOrContainerT) -> ArrayOrContainerT:
else:
raise TypeError(f"unsupported input type: {type(x)}")

assert actx is None or isinstance(actx, PyOpenCLArrayContext)
result = apply_skeleton_backward_matvec(actx, self, ary)

if isinstance(x, DOFArray):
assert actx is not None
result = unflatten(x, actx.from_numpy(result), actx)

return result
return result # type: ignore[return-value]


def apply_skeleton_backward_matvec(
actx: PyOpenCLArrayContext,
actx: Optional[PyOpenCLArrayContext],
hmat: ProxyHierarchicalMatrix,
ary: ArrayOrContainerT,
) -> ArrayOrContainerT:
from pytential.linalg.cluster import split_array
targets, sources = hmat.skeletons[0].tgt_src_index

b = split_array(ary, targets)
r_dot_b = np.empty(hmat.nlevels, dtype=object)
b = split_array(ary, targets) # type: ignore[arg-type]
r_dot_b: np.ndarray = np.empty(hmat.nlevels, dtype=object)

# {{{ recurse down

Expand Down Expand Up @@ -412,7 +412,7 @@ def apply_skeleton_backward_matvec(
assert b.shape == (skeleton.nclusters,)
assert skeleton.tgt_src_index.shape[0] == sum([bi.size for bi in b])

dhat_dot_b_k = np.empty(skeleton.nclusters, dtype=object)
dhat_dot_b_k: np.ndarray = np.empty(skeleton.nclusters, dtype=object)
for i in range(skeleton.nclusters):
dhat_dot_b_k[i] = (
skeleton.Dhat[i] @ (skeleton.R[i] @ (skeleton.invD[i] @ b[i]))
Expand Down Expand Up @@ -467,7 +467,7 @@ def build_hmatrix_by_proxy(
exprs: Union[sym.Expression, Iterable[sym.Expression]],
input_exprs: Union[sym.Expression, Iterable[sym.Expression]], *,
auto_where: Optional[sym.DOFDescriptorLike] = None,
domains: Optional[Iterable[sym.DOFDescriptorLike]] = None,
domains: Optional[Sequence[sym.DOFDescriptorLike]] = None,
context: Optional[Dict[str, Any]] = None,
id_eps: float = 1.0e-8,

Expand All @@ -483,13 +483,6 @@ def build_hmatrix_by_proxy(
_approx_nproxy: Optional[int] = None,
_proxy_radius_factor: Optional[float] = None,
) -> ProxyHierarchicalMatrixWrangler:
try:
import scipy # noqa: F401
except ImportError:
raise ImportError(
"The direct solver requires 'scipy' for the interpolative "
"decomposition used in skeletonization")

from pytential.symbolic.matrix import P2PClusterMatrixBuilder
from pytential.linalg.skeletonization import make_skeletonization_wrangler

Expand Down
4 changes: 2 additions & 2 deletions pytential/linalg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,8 @@ def mnorm(x: np.ndarray, y: np.ndarray) -> "np.floating[Any]":
def skeletonization_matrix(
mat: np.ndarray, skeleton: "SkeletonizationResult",
) -> Tuple[np.ndarray, np.ndarray]:
D = np.empty(skeleton.nclusters, dtype=object)
S = np.empty((skeleton.nclusters, skeleton.nclusters), dtype=object)
D: np.ndarray = np.empty(skeleton.nclusters, dtype=object)
S: np.ndarray = np.empty((skeleton.nclusters, skeleton.nclusters), dtype=object)

from itertools import product
for i, j in product(range(skeleton.nclusters), repeat=2):
Expand Down

0 comments on commit 5062dd1

Please sign in to comment.