Skip to content

Commit

Permalink
Merge pull request #81 from btalamini/fix_gmres_errors
Browse files Browse the repository at this point in the history
Fix gmres errors in CI triggered by new Scipy version
  • Loading branch information
btalamini authored Feb 28, 2024
2 parents c1450ca + 6025ebd commit be8a2fe
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions optimism/NewtonSolver.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as onp
from scipy.sparse.linalg import LinearOperator, gmres

from optimism.JaxConfig import *
Expand Down Expand Up @@ -27,11 +28,13 @@ def compute_min_p(ps, bounds):


def newton_step(residual, linear_op, x, settings=Settings(1e-2,100), precond=None):
sz = x.size
sz = x.size
# The call to onp.array copies the jax array output into a plain numpy
# array. The copy is necessary for safety, since as far as scipy knows,
# it is allowed to modify the output in place.
A = LinearOperator((sz,sz),
lambda v: linear_op(v))
r = residual(x)
rNorm = np.linalg.norm(r)
lambda v: onp.array(linear_op(v)))
r = onp.array(residual(x))

numIters = 0
def callback(xk):
Expand All @@ -41,13 +44,14 @@ def callback(xk):
relTol = settings.relative_gmres_tol
maxIters = settings.max_gmres_iters

if precond==None:
dx, exitcode = gmres(A, r, tol=relTol*rNorm, atol=0, callback_type='legacy', callback=callback, maxiter=maxIters)
else:
if precond is not None:
# Another copy to a plain numpy array, see comment for A above.
M = LinearOperator((sz,sz),
lambda v: precond(v))
dx, exitcode = gmres(A, r, tol=relTol*rNorm, atol=0, M=M, callback_type='legacy', callback=callback, maxiter=maxIters)
lambda v: onp.array(precond(v)))
else:
M = None

dx, exitcode = gmres(A, r, rtol=relTol, atol=0, M=M, callback_type='legacy', callback=callback, maxiter=maxIters)
print('Number of GMRES iters = ', numIters)

return -dx, exitcode
Expand Down

0 comments on commit be8a2fe

Please sign in to comment.