Skip to content

Commit

Permalink
rename test
Browse files Browse the repository at this point in the history
  • Loading branch information
yoyolicoris committed Sep 9, 2024
1 parent bdb72be commit d826e14
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions tests/test_hessian.py → tests/test_vmap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import torch.nn.functional as F
from torch.func import hessian
from torch.func import hessian, jacfwd
import pytest
from torchlpc.core import LPC

Expand All @@ -20,9 +20,9 @@
),
],
)
def test_hessian(device: str):
batch_size = 1
samples = 10
def test_vmap(device: str):
batch_size = 4
samples = 40
x, A, zi = tuple(
x.to(device) for x in create_test_inputs(batch_size, samples, False)
)
Expand All @@ -32,13 +32,16 @@ def test_hessian(device: str):

A.requires_grad = True
zi.requires_grad = True
x.requires_grad = True

args = (A, zi)
args = (x, A, zi)

def func(A, zi):
def func(x, A, zi):
return F.mse_loss(LPC.apply(x, A[:, None, :].expand(-1, samples, -1), zi), y)

h = hessian(func, 0)(*args)
assert torch.any(h != 0)
jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args)

h_inv = torch.linalg.inv(h.squeeze())
loss = func(*args)
loss.backward()
for jac, arg in zip(jacs, args):
assert torch.allclose(jac, arg.grad)

0 comments on commit d826e14

Please sign in to comment.