diff --git a/optimism/NewtonSolver.py b/optimism/NewtonSolver.py index 22ad8fd..f7b6694 100644 --- a/optimism/NewtonSolver.py +++ b/optimism/NewtonSolver.py @@ -1,3 +1,4 @@ +import numpy as onp from scipy.sparse.linalg import LinearOperator, gmres from optimism.JaxConfig import * @@ -27,11 +28,13 @@ def compute_min_p(ps, bounds): def newton_step(residual, linear_op, x, settings=Settings(1e-2,100), precond=None): - sz = x.size + sz = x.size + # The call to onp.array copies the jax array output into a plain numpy + # array. The copy is necessary for safety, since as far as scipy knows, + # it is allowed to modify the output in place. A = LinearOperator((sz,sz), - lambda v: linear_op(v)) - r = residual(x) - rNorm = np.linalg.norm(r) + lambda v: onp.array(linear_op(v))) + r = onp.array(residual(x)) numIters = 0 def callback(xk): @@ -41,13 +44,14 @@ def callback(xk): relTol = settings.relative_gmres_tol maxIters = settings.max_gmres_iters - if precond==None: - dx, exitcode = gmres(A, r, tol=relTol*rNorm, atol=0, callback_type='legacy', callback=callback, maxiter=maxIters) - else: + if precond is not None: + # Another copy to a plain numpy array, see comment for A above. M = LinearOperator((sz,sz), - lambda v: precond(v)) - dx, exitcode = gmres(A, r, tol=relTol*rNorm, atol=0, M=M, callback_type='legacy', callback=callback, maxiter=maxIters) + lambda v: onp.array(precond(v))) + else: + M = None + dx, exitcode = gmres(A, r, rtol=relTol, atol=0, M=M, callback_type='legacy', callback=callback, maxiter=maxIters) print('Number of GMRES iters = ', numIters) return -dx, exitcode