Skip to content

Commit

Permalink
Merge pull request #607 from AxelBreuer/solve_banded_branch
Browse files Browse the repository at this point in the history
add support of scipy.linalg.solve_banded()
  • Loading branch information
j-towns authored Nov 16, 2023
2 parents e18f656 + 526051c commit 9a90bd6
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
57 changes: 57 additions & 0 deletions autograd/scipy/linalg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import division
from functools import partial
import scipy.linalg

import autograd.numpy as anp
Expand Down Expand Up @@ -35,6 +36,62 @@ def vjp(g):
lambda ans, a, b, trans=0, lower=False, **kwargs:
lambda g: solve_triangular(a, g, trans=_flip(a, trans), lower=lower))

def grad_solve_banded(argnum, ans, l_and_u, a, b):

updim = lambda x: x if x.ndim == a.ndim else x[...,None]

def transpose_banded(l_and_u, a):

# Compute the transpose of a banded matrix.
# The transpose is itself a banded matrix.

num_rows = a.shape[0]

shifts = anp.arange(-l_and_u[1], l_and_u[0]+1)

T_a = anp.roll(a[:1, :], shifts[0])
for rr in range(1, num_rows):
T_a = anp.vstack([T_a, anp.flipud(anp.roll(a[rr:rr+1, :], shifts[rr]))])
T_a = anp.flipud(T_a)

T_l_and_u = anp.flip(l_and_u)

return T_l_and_u, T_a

def banded_dot(l_and_u, uu, vv):

# Compute tensor product of vectors uu and vv.
# Tensor product elements are resticted to the bands specified by l_and_u.

# TODO: replace the brute-force ravel() by smarter dimension handeling of uu and vv

# main diagonal
banded_uv = anp.ravel(uu)*anp.ravel(vv)

# stack below the sub-diagonals
for rr in range(1, l_and_u[0]+1):
banded_uv_rr = anp.hstack([anp.ravel(uu)[rr:]*anp.ravel(vv)[:-rr], anp.zeros(rr)])
banded_uv = anp.vstack([banded_uv, banded_uv_rr])

# stack above the sup-diagonals
for rr in range(1, l_and_u[1]+1):
banded_uv_rr = anp.hstack([anp.zeros(rr), anp.ravel(uu)[:-rr]*anp.ravel(vv)[rr:]])
banded_uv = anp.vstack([banded_uv_rr, banded_uv])

return(banded_uv)

T_l_and_u, T_a = transpose_banded(l_and_u, a)

if argnum == 1:
return lambda g: -banded_dot(l_and_u, updim(solve_banded(T_l_and_u, T_a, g)), anp.transpose(updim(ans)))
elif argnum == 2:
return lambda g: solve_banded(T_l_and_u, T_a, g)

defvjp(solve_banded,
partial(grad_solve_banded, 1),
partial(grad_solve_banded, 2),
argnums=[1, 2])

def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64):
assert disp, "sqrtm jvp not implemented for disp=False"
return solve_sylvester(ans, ans, dA)
Expand Down
1 change: 1 addition & 0 deletions tests/test_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,4 @@ def test_odeint():
def test_sqrtm(): combo_check(spla.sqrtm, modes=['fwd'], order=2)([R(3, 3)])
def test_sqrtm(): combo_check(symmetrize_matrix_arg(spla.sqrtm, 0), modes=['fwd', 'rev'], order=2)([R(3, 3)])
def test_solve_sylvester(): combo_check(spla.solve_sylvester, [0, 1, 2], modes=['rev', 'fwd'], order=2)([R(3, 3)], [R(3, 3)], [R(3, 3)])
def test_solve_banded(): combo_check(spla.solve_banded, [1, 2], modes=['rev'], order=1)([(1, 1)], [R(3,5)], [R(5)])

0 comments on commit 9a90bd6

Please sign in to comment.