feat: parallel scan extension for CPU #17

Merged · 15 commits · Jan 24, 2025
4 changes: 4 additions & 0 deletions .github/workflows/python-package.yml
@@ -36,6 +36,10 @@ jobs:
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Build CPP extension
run: |
python setup.py build
find build/ -name "_C*.so" -exec cp {} ./torchlpc/ \;
- name: Test with pytest
run: |
pytest
7 changes: 6 additions & 1 deletion setup.py
@@ -1,7 +1,8 @@
import setuptools
from torch.utils import cpp_extension

NAME = "torchlpc"
VERSION = "0.6"
VERSION = "0.7.dev"
MAINTAINER = "Chin-Yun Yu"
EMAIL = "[email protected]"

@@ -25,4 +26,8 @@
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
ext_modules=[
cpp_extension.CppExtension("torchlpc._C", ["torchlpc/csrc/scan_cpu.cpp"])
],
cmdclass={"build_ext": cpp_extension.BuildExtension},
)
35 changes: 35 additions & 0 deletions tests/test_extension.py
@@ -0,0 +1,35 @@
import torch
import torch.nn.functional as F
import pytest
from torchlpc.core import lpc_np


from .test_grad import create_test_inputs


@pytest.mark.parametrize(
"samples",
[64, 4097],
)
@pytest.mark.parametrize(
"cmplx",
[True, False],
)
def test_scan_cpu_equiv(samples: int, cmplx: bool):
batch_size = 4
x = torch.randn(
batch_size, samples, dtype=torch.float32 if not cmplx else torch.complex64
)
A = torch.rand_like(x) * 1.8 - 0.9
zi = torch.randn(batch_size, dtype=x.dtype)

numba_y = torch.from_numpy(
lpc_np(
x.cpu().numpy(),
-A.cpu().unsqueeze(2).numpy(),
zi.cpu().unsqueeze(1).numpy(),
)
)
ext_y = torch.ops.torchlpc.scan_cpu(x, A, zi)

assert torch.allclose(numba_y, ext_y)
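
For reference, once the extension has been built (as in the CI step above) the new op can be called directly through the dispatcher. A minimal sketch, assuming `import torchlpc` has loaded the `_C*.so` as done in `torchlpc/__init__.py` below:

```python
import torch
import torchlpc  # side effect: loads _C*.so and registers torch.ops.torchlpc.scan_cpu

x = torch.randn(2, 16)               # input, shape (batch, time)
a = torch.rand_like(x) * 1.8 - 0.9   # per-sample decay coefficients in (-0.9, 0.9)
zi = torch.randn(2)                  # one initial state per batch element

# Computes y[t] = a[t] * y[t-1] + x[t] with y[-1] = zi, independently per batch row.
y = torch.ops.torchlpc.scan_cpu(x, a, zi)
assert y.shape == x.shape
```
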
28 changes: 20 additions & 8 deletions tests/test_grad.py
@@ -2,7 +2,7 @@
import torch
from torch.autograd.gradcheck import gradcheck, gradgradcheck
from torchlpc.core import LPC
from torchlpc.recurrence import RecurrenceCUDA
from torchlpc.recurrence import Recurrence


def get_random_biquads(cmplx=False):
@@ -123,21 +123,33 @@ def test_float64_vs_32_cuda():
"zi_requires_grad",
[True, False],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_parallel_scan(
@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param(
"cuda",
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA not available"
),
),
],
)
def test_parallel_scan(
x_requires_grad: bool,
a_requires_grad: bool,
zi_requires_grad: bool,
device: str,
):
batch_size = 2
samples = 123
x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
A = torch.rand(batch_size, samples, dtype=torch.double, device="cuda") * 2 - 1
zi = torch.randn(batch_size, dtype=torch.double, device="cuda")
x = torch.randn(batch_size, samples, dtype=torch.double, device=device)
A = torch.rand(batch_size, samples, dtype=torch.double, device=device) * 2 - 1
zi = torch.randn(batch_size, dtype=torch.double, device=device)

A.requires_grad = a_requires_grad
x.requires_grad = x_requires_grad
zi.requires_grad = zi_requires_grad

assert gradcheck(RecurrenceCUDA.apply, (A, x, zi), check_forward_ad=True)
assert gradgradcheck(RecurrenceCUDA.apply, (A, x, zi))
assert gradcheck(Recurrence.apply, (A, x, zi), check_forward_ad=True)
assert gradgradcheck(Recurrence.apply, (A, x, zi))
27 changes: 19 additions & 8 deletions tests/test_vmap.py
@@ -3,7 +3,7 @@
from torch.func import jacfwd
import pytest
from torchlpc.core import LPC
from torchlpc.recurrence import RecurrenceCUDA
from torchlpc.recurrence import Recurrence


from .test_grad import create_test_inputs
@@ -48,14 +48,25 @@ def func(x, A, zi):
assert torch.allclose(jac, arg.grad)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_parallel_scan_vmap():
@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param(
"cuda",
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason="CUDA not available"
),
),
],
)
def test_parallel_scan_vmap(device: str):
batch_size = 3
samples = 255
x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
A = torch.rand(batch_size, samples, dtype=torch.double, device="cuda") * 2 - 1
zi = torch.randn(batch_size, dtype=torch.double, device="cuda")
y = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
x = torch.randn(batch_size, samples, dtype=torch.double, device=device)
A = torch.rand(batch_size, samples, dtype=torch.double, device=device) * 2 - 1
zi = torch.randn(batch_size, dtype=torch.double, device=device)
y = torch.randn(batch_size, samples, dtype=torch.double, device=device)

A.requires_grad = True
x.requires_grad = True
@@ -64,7 +75,7 @@ def test_cuda_parallel_scan_vmap():
args = (x, A, zi)

def func(x, A, zi):
return F.mse_loss(RecurrenceCUDA.apply(A, x, zi), y)
return F.mse_loss(Recurrence.apply(A, x, zi), y)

jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args)

24 changes: 20 additions & 4 deletions torchlpc/__init__.py
@@ -1,9 +1,23 @@
import torch
from typing import Optional
from pathlib import Path
import warnings

so_files = list(Path(__file__).parent.glob("_C*.so"))
# assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
if len(so_files) == 1:
torch.ops.load_library(so_files[0])
EXTENSION_LOADED = True
elif len(so_files) > 1:
raise ValueError(f"Expected one _C*.so file, found {len(so_files)}")
else:
warnings.warn("No _C*.so file found. Custom extension not loaded.")
EXTENSION_LOADED = False

from .core import LPC
from .parallel_scan import WARPSIZE
from .recurrence import RecurrenceCUDA

# from .parallel_scan import WARPSIZE
from .recurrence import Recurrence

__all__ = ["sample_wise_lpc"]

@@ -37,7 +51,9 @@ def sample_wise_lpc(
else:
assert zi.shape == (B, order)

if order == 1 and x.is_cuda and B * WARPSIZE < T:
return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1))
# if order == 1 and x.is_cuda and B * WARPSIZE < T:
# return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1))
if order == 1:
return Recurrence.apply(-a.squeeze(2), x, zi.squeeze(1))

return LPC.apply(x, a, zi)
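
With this change, first-order filters no longer require CUDA: `sample_wise_lpc` dispatches them to `Recurrence`, which picks a backend internally. A minimal sketch of the order-1 path (argument order and shapes are inferred from the assertions in this file, so treat them as assumptions):

```python
import torch
from torchlpc import sample_wise_lpc

B, T = 2, 1024
x = torch.randn(B, T)                 # input signal
a = torch.rand(B, T, 1) * 1.8 - 0.9   # time-varying order-1 coefficients
zi = torch.randn(B, 1)                # initial conditions, shape (B, order)

# order == 1, so this goes through Recurrence.apply(-a.squeeze(2), x, zi.squeeze(1))
y = sample_wise_lpc(x, a, zi)
```
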
86 changes: 86 additions & 0 deletions torchlpc/csrc/scan_cpu.cpp
@@ -0,0 +1,86 @@
#include <torch/script.h>
#include <torch/torch.h>

#include <algorithm>
#include <utility>
#include <vector>

template <typename scalar_t>
void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
const at::Tensor &initials, const at::Tensor &output) {
TORCH_CHECK(input.dim() == 2, "Input must be 2D");
TORCH_CHECK(initials.dim() == 1, "Initials must be 1D");
TORCH_CHECK(weights.sizes() == input.sizes(),
"Weights must have the same size as input");
TORCH_CHECK(output.sizes() == input.sizes(),
"Output must have the same size as input");
TORCH_CHECK(initials.size(0) == input.size(0),
"The first dimension of initials must be the same as the first "
"dimension of input");
TORCH_INTERNAL_ASSERT(input.device().is_cpu(), "Input must be on CPU");
TORCH_INTERNAL_ASSERT(initials.device().is_cpu(),
"Initials must be on CPU");
TORCH_INTERNAL_ASSERT(weights.device().is_cpu(), "Weights must be on CPU");
TORCH_INTERNAL_ASSERT(output.device().is_cpu(), "Output must be on CPU");
TORCH_INTERNAL_ASSERT(output.is_contiguous(), "Output must be contiguous");

auto input_contiguous = input.contiguous();
auto weights_contiguous = weights.contiguous();
auto initials_contiguous = initials.contiguous();

auto n_batch = input.size(0);
auto T = input.size(1);
auto total_size = input.numel();

std::pair<scalar_t, scalar_t> buffer[total_size];

const scalar_t *input_ptr = input_contiguous.data_ptr<scalar_t>();
const scalar_t *initials_ptr = initials_contiguous.data_ptr<scalar_t>();
const scalar_t *weights_ptr = weights_contiguous.data_ptr<scalar_t>();
scalar_t *output_ptr = output.data_ptr<scalar_t>();

std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer,
[](const scalar_t &a, const scalar_t &b) {
return std::make_pair(a, b);
});

at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end) {
for (auto b = start; b < end; b++) {
std::inclusive_scan(
buffer + b * T, buffer + (b + 1) * T, buffer + b * T,
[](const std::pair<scalar_t, scalar_t> &a,
const std::pair<scalar_t, scalar_t> &b) {
return std::make_pair(a.first * b.first,
a.second * b.first + b.second);
},
std::make_pair((scalar_t)1.0, initials_ptr[b]));
}
});

std::transform(
buffer, buffer + total_size, output_ptr,
[](const std::pair<scalar_t, scalar_t> &a) { return a.second; });
}

at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
const at::Tensor &initials) {
TORCH_CHECK(input.is_floating_point() || input.is_complex(),
"Input must be floating point or complex");
TORCH_CHECK(initials.scalar_type() == input.scalar_type(),
"Initials must have the same scalar type as input");
TORCH_CHECK(weights.scalar_type() == input.scalar_type(),
"Weights must have the same scalar type as input");

auto output = at::empty_like(input);

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
input.scalar_type(), "scan_cpu",
[&] { scan_cpu<scalar_t>(input, weights, initials, output); });
return output;
}

TORCH_LIBRARY(torchlpc, m) {
m.def("torchlpc::scan_cpu(Tensor a, Tensor b, Tensor c) -> Tensor");
}

TORCH_LIBRARY_IMPL(torchlpc, CPU, m) { m.impl("scan_cpu", &scan_cpu_wrapper); }
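
The kernel reduces each row with `std::inclusive_scan` over (weight, input) pairs: composing two first-order updates y -> w*y + x is itself a first-order update, and this composition is associative, which is what makes the scan formulation valid. A small NumPy sketch of that pair algebra (illustrative only, not part of the PR):

```python
import numpy as np

def combine(p, q):
    # Applying p = (w1, x1) and then q = (w2, x2) to a state y gives
    # w2 * (w1 * y + x1) + x2, i.e. the pair (w2 * w1, w2 * x1 + x2).
    return (q[0] * p[0], q[0] * p[1] + q[1])

rng = np.random.default_rng(0)
T = 8
w = rng.uniform(-0.9, 0.9, T)     # decay per step
x = rng.standard_normal(T)        # input per step
y0 = rng.standard_normal()        # initial state

# Sequential reference: y[t] = w[t] * y[t-1] + x[t]
y_ref, prev = [], y0
for t in range(T):
    prev = w[t] * prev + x[t]
    y_ref.append(prev)

# Inclusive scan over pairs, seeded with (1, y0) as in the C++ lambda above.
acc, y_scan = (1.0, y0), []
for t in range(T):
    acc = combine(acc, (w[t], x[t]))
    y_scan.append(acc[1])

assert np.allclose(y_ref, y_scan)
```

In the C++ op, `at::parallel_for` only parallelises across the batch dimension; the associativity is what would also allow splitting a single long sequence across workers, as the CUDA path does.
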
50 changes: 36 additions & 14 deletions torchlpc/recurrence.py
@@ -4,26 +4,45 @@
from numba import cuda
from typing import Tuple, Optional, Any, List

from .parallel_scan import compute_linear_recurrence
from .parallel_scan import compute_linear_recurrence, WARPSIZE
from .core import lpc_cuda, lpc_np
from . import EXTENSION_LOADED


class RecurrenceCUDA(Function):
class Recurrence(Function):
@staticmethod
def forward(
decay: torch.Tensor,
impulse: torch.Tensor,
initial_state: torch.Tensor,
) -> torch.Tensor:
n_dims, n_steps = decay.shape
out = torch.empty_like(impulse)
compute_linear_recurrence(
cuda.as_cuda_array(decay.detach()),
cuda.as_cuda_array(impulse.detach()),
cuda.as_cuda_array(initial_state.detach()),
cuda.as_cuda_array(out),
n_dims,
n_steps,
)
if decay.is_cuda:
if n_dims * WARPSIZE < n_steps:
out = torch.empty_like(impulse)
compute_linear_recurrence(
cuda.as_cuda_array(decay.detach()),
cuda.as_cuda_array(impulse.detach()),
cuda.as_cuda_array(initial_state.detach()),
cuda.as_cuda_array(out),
n_dims,
n_steps,
)
else:
out = lpc_cuda(impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1))
else:
num_threads = torch.get_num_threads()
# This is just a rough estimation of the computational cost
if EXTENSION_LOADED and min(n_dims, num_threads) < num_threads / 3:
out = torch.ops.torchlpc.scan_cpu(impulse, decay, initial_state)
else:
out = torch.from_numpy(
lpc_np(
impulse.detach().numpy(),
-decay.unsqueeze(2).detach().numpy(),
initial_state.unsqueeze(1).detach().numpy(),
)
)
return out

@staticmethod
@@ -48,7 +67,7 @@ def backward(
padded_decay = padded_decay[:, 1:]

init = padded_grad_out.new_zeros(n_dims)
flipped_grad_impulse = RecurrenceCUDA.apply(
flipped_grad_impulse = Recurrence.apply(
padded_decay.flip(1).conj_physical(),
padded_grad_out.flip(1),
init,
@@ -91,7 +110,7 @@ def jvp(
fwd_decay = concat_out * grad_decay
fwd_impulse = fwd_impulse + fwd_decay

return RecurrenceCUDA.apply(decay, fwd_impulse, fwd_initial_state)
return Recurrence.apply(decay, fwd_impulse, fwd_initial_state)

@staticmethod
def vmap(info, in_dims, *args):
@@ -107,5 +126,8 @@ def maybe_expand_bdim_at_front(x, x_bdim):
)
)

out = RecurrenceCUDA.apply(decay, impulse, initial_state)
out = Recurrence.apply(decay, impulse, initial_state)
return out.reshape(info.batch_size, -1, *out.shape[1:]), 0


RecurrenceCUDA = Recurrence
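
The renamed `Recurrence` now chooses a backend at run time, and the `RecurrenceCUDA = Recurrence` alias keeps old imports working. The selection logic in `forward` can be summarised as follows (a paraphrase of the code above; the helper name `pick_backend` and the backend labels are hypothetical):

```python
import torch

def pick_backend(decay: torch.Tensor, extension_loaded: bool, warpsize: int = 32) -> str:
    """Hypothetical summary of the dispatch performed in Recurrence.forward."""
    n_dims, n_steps = decay.shape
    if decay.is_cuda:
        # The CUDA parallel scan is used only when the sequence is long
        # relative to the batch size; otherwise the numba CUDA LPC kernel runs.
        return "cuda_parallel_scan" if n_dims * warpsize < n_steps else "lpc_cuda"
    num_threads = torch.get_num_threads()
    # Rough cost heuristic from the code: prefer the C++ scan only when the
    # batch dimension is well below the number of available threads.
    if extension_loaded and min(n_dims, num_threads) < num_threads / 3:
        return "cpp_scan_cpu"
    return "numba_lpc_np"
```
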