Skip to content

Commit

Permalink
add test to make sure a linear operator can be passed.
Browse files Browse the repository at this point in the history
  • Loading branch information
jcapriot committed Oct 12, 2024
1 parent efaf7a8 commit bae9724
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
11 changes: 11 additions & 0 deletions tests/test_Scipy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pymatsolver import Solver, Diagonal, SolverCG, SolverLU
import scipy.sparse as sp
from scipy.sparse.linalg import aslinearoperator
import numpy as np
import numpy.testing as npt
import pytest
Expand Down Expand Up @@ -57,6 +58,16 @@ def test_solver(a_matrix, n_rhs, solver):

npt.assert_allclose(x, b, atol=tol)

def test_iterative_solver_linear_op():
n = 10
A = aslinearoperator(sp.eye(n))

Ainv = SolverCG(A)

rhs = np.linspace(0.9, 1.1, n)

npt.assert_allclose(Ainv @ rhs, rhs)

@pytest.mark.parametrize('n_rhs', [1, 5])
def test_diag_solver(n_rhs):
n = 10
Expand Down
5 changes: 4 additions & 1 deletion tests/test_Wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ def test_wrapper_unused_kwargs(solver_class):
with pytest.warns(UnusedArgumentWarning, match="Unused keyword argument.*"):
solver_class(A, not_a_keyword_arg=True)


def test_good_arg_iterative():
# Ensure this doesn't throw a warning!
with warnings.catch_warnings():
warnings.simplefilter("error")
SolverCG(sp.eye(10), rtol=1e-4)


def test_good_arg_direct():
# Ensure this doesn't throw a warning!
with warnings.catch_warnings():
Expand All @@ -40,7 +42,6 @@ def __init__(self, A):
WrappedClass(sp.eye(2))



def test_direct_clean_function():
def direct_func(A):
class Empty():
Expand All @@ -67,6 +68,7 @@ def clean(self):
Ainv.clean()
assert Ainv.solver.A is None


def test_iterative_deprecations():

with pytest.warns(FutureWarning, match="check_accuracy and accuracy_tol were unused.*"):
Expand All @@ -75,6 +77,7 @@ def test_iterative_deprecations():
with pytest.warns(FutureWarning, match="check_accuracy and accuracy_tol were unused.*"):
wrap_iterative(lambda a, x: x, accuracy_tol=1E-3)


def test_non_scipy_iterative():
def iterative_solver(A, x):
return x
Expand Down

0 comments on commit bae9724

Please sign in to comment.