Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yoyolicoris committed Apr 18, 2024
1 parent 2170835 commit e8ac293
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tests/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,23 @@
from torchlpc.core import LPC


def get_random_biquads(cplx=False):
if cplx:
def get_random_biquads(cmplx=False):
if cmplx:
mag = torch.rand(2, dtype=torch.double)
phase = torch.rand(2, dtype=torch.double) * 2 * torch.pi
roots = mag * torch.exp(1j * phase)
return torch.tensor(
[-roots[0] - roots[1], roots[0] * roots[1]], dtype=torch.complex128
)
mag = torch.rand(1, dtype=torch.double)
phase = torch.rand(1, dtype=torch.double) * 2 * torch.pi
phase = torch.rand(1, dtype=torch.double) * torch.pi
return torch.tensor([-mag * torch.cos(phase) * 2, mag**2], dtype=torch.double)


def create_test_inputs(batch_size, samples, cplx=False):
start_coeffs = get_random_biquads(cplx)
end_coeffs = get_random_biquads(cplx)
dtype = torch.complex128 if cplx else torch.double
def create_test_inputs(batch_size, samples, cmplx=False):
start_coeffs = get_random_biquads(cmplx)
end_coeffs = get_random_biquads(cmplx)
dtype = torch.complex128 if cmplx else torch.double

A = (
torch.stack(
Expand Down Expand Up @@ -55,7 +55,7 @@ def create_test_inputs(batch_size, samples, cplx=False):
)
@pytest.mark.parametrize(
"cmplx",
[True],
[True, False],
)
@pytest.mark.parametrize(
"device",
Expand Down

0 comments on commit e8ac293

Please sign in to comment.