feat: parallel scan extension for CPU (#17)
* draft: scan extension on cpu

* include cpp extension in setup.py

* fix: remove extra arg

* fix: compile errors

* use dev versioning

* load library file when being imported

* test equivalence to numba version

* refactor: return tensor instead of void function

* refactor: rename RecurrenceCUDA to Recurrence to cover CPU device

* refactor: update functions to use Recurrence for CPU and CUDA devices

* refactor: remove contiguous checks except for the output tensor

* refactor: add warning for missing _C*.so file and check extension loading in Recurrence

* ci: add workflow step to build CPP extension and copy shared objects

* apply suggestions and remove comments

* refactor: apply google style format
yoyolicoris authored Jan 24, 2025
1 parent eb7b778 commit 5801748
Showing 8 changed files with 226 additions and 35 deletions.
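For orientation: every code path touched here evaluates the first-order linear recurrence y[t] = A[t] * y[t-1] + x[t] with initial state zi, and the new CPU extension expresses it as an inclusive scan over (A[t], x[t]) pairs. The sketch below is a plain sequential reference in PyTorch, not part of the commit; the helper name reference_recurrence is illustrative, and the shapes (x and A of shape (batch, time), zi of shape (batch,)) match the tests added here.

import torch

def reference_recurrence(x, A, zi):
    # Sequential reference for y[t] = A[t] * y[t-1] + x[t], with y[-1] = zi.
    # x, A: (batch, time); zi: (batch,).
    y = torch.empty_like(x)
    state = zi
    for t in range(x.shape[1]):
        state = A[:, t] * state + x[:, t]
        y[:, t] = state
    return y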
4 changes: 4 additions & 0 deletions .github/workflows/python-package.yml
@@ -36,6 +36,10 @@ jobs:
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
    - name: Build CPP extension
      run: |
        python setup.py build
        find build/ -name "_C*.so" -exec cp {} ./torchlpc/ \;
    - name: Test with pytest
      run: |
        pytest
7 changes: 6 additions & 1 deletion setup.py
@@ -1,7 +1,8 @@
import setuptools
from torch.utils import cpp_extension

NAME = "torchlpc"
VERSION = "0.6"
VERSION = "0.7.dev"
MAINTAINER = "Chin-Yun Yu"
EMAIL = "chin-yun.yu@qmul.ac.uk"

@@ -25,4 +26,8 @@
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    ext_modules=[
        cpp_extension.CppExtension("torchlpc._C", ["torchlpc/csrc/scan_cpu.cpp"])
    ],
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)
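Building the extension locally mirrors the CI step above: python setup.py build compiles torchlpc/csrc/scan_cpu.cpp into a _C*.so, which must end up next to torchlpc/__init__.py for the package to pick it up at import time. A minimal sanity check, assuming the build succeeded (the tensor values are arbitrary; only the shapes (batch, time) and (batch,) matter):

import torch
import torchlpc

# EXTENSION_LOADED is defined in torchlpc/__init__.py below: True only when a
# _C*.so file was found and loaded through torch.ops.load_library.
print(torchlpc.EXTENSION_LOADED)

if torchlpc.EXTENSION_LOADED:
    # The op registered by scan_cpu.cpp takes (input, weights, initials).
    y = torch.ops.torchlpc.scan_cpu(
        torch.randn(2, 16), torch.rand(2, 16) * 1.8 - 0.9, torch.randn(2)
    )
    print(y.shape)  # torch.Size([2, 16])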
35 changes: 35 additions & 0 deletions tests/test_extension.py
@@ -0,0 +1,35 @@
import torch
import torch.nn.functional as F
import pytest
from torchlpc.core import lpc_np


from .test_grad import create_test_inputs


@pytest.mark.parametrize(
    "samples",
    [64, 4097],
)
@pytest.mark.parametrize(
    "cmplx",
    [True, False],
)
def test_scan_cpu_equiv(samples: int, cmplx: bool):
    batch_size = 4
    x = torch.randn(
        batch_size, samples, dtype=torch.float32 if not cmplx else torch.complex64
    )
    A = torch.rand_like(x) * 1.8 - 0.9
    zi = torch.randn(batch_size, dtype=x.dtype)

    numba_y = torch.from_numpy(
        lpc_np(
            x.cpu().numpy(),
            -A.cpu().unsqueeze(2).numpy(),
            zi.cpu().unsqueeze(1).numpy(),
        )
    )
    ext_y = torch.ops.torchlpc.scan_cpu(x, A, zi)

    assert torch.allclose(numba_y, ext_y)
28 changes: 20 additions & 8 deletions tests/test_grad.py
@@ -2,7 +2,7 @@
import torch
from torch.autograd.gradcheck import gradcheck, gradgradcheck
from torchlpc.core import LPC
from torchlpc.recurrence import RecurrenceCUDA
from torchlpc.recurrence import Recurrence


def get_random_biquads(cmplx=False):
@@ -123,21 +123,33 @@ def test_float64_vs_32_cuda():
    "zi_requires_grad",
    [True, False],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_parallel_scan(
@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param(
            "cuda",
            marks=pytest.mark.skipif(
                not torch.cuda.is_available(), reason="CUDA not available"
            ),
        ),
    ],
)
def test_parallel_scan(
    x_requires_grad: bool,
    a_requires_grad: bool,
    zi_requires_grad: bool,
    device: str,
):
    batch_size = 2
    samples = 123
    x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
    A = torch.rand(batch_size, samples, dtype=torch.double, device="cuda") * 2 - 1
    zi = torch.randn(batch_size, dtype=torch.double, device="cuda")
    x = torch.randn(batch_size, samples, dtype=torch.double, device=device)
    A = torch.rand(batch_size, samples, dtype=torch.double, device=device) * 2 - 1
    zi = torch.randn(batch_size, dtype=torch.double, device=device)

    A.requires_grad = a_requires_grad
    x.requires_grad = x_requires_grad
    zi.requires_grad = zi_requires_grad

    assert gradcheck(RecurrenceCUDA.apply, (A, x, zi), check_forward_ad=True)
    assert gradgradcheck(RecurrenceCUDA.apply, (A, x, zi))
    assert gradcheck(Recurrence.apply, (A, x, zi), check_forward_ad=True)
    assert gradgradcheck(Recurrence.apply, (A, x, zi))
27 changes: 19 additions & 8 deletions tests/test_vmap.py
@@ -3,7 +3,7 @@
from torch.func import jacfwd
import pytest
from torchlpc.core import LPC
from torchlpc.recurrence import RecurrenceCUDA
from torchlpc.recurrence import Recurrence


from .test_grad import create_test_inputs
@@ -48,14 +48,25 @@ def func(x, A, zi):
        assert torch.allclose(jac, arg.grad)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_parallel_scan_vmap():
@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param(
            "cuda",
            marks=pytest.mark.skipif(
                not torch.cuda.is_available(), reason="CUDA not available"
            ),
        ),
    ],
)
def test_parallel_scan_vmap(device: str):
    batch_size = 3
    samples = 255
    x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
    A = torch.rand(batch_size, samples, dtype=torch.double, device="cuda") * 2 - 1
    zi = torch.randn(batch_size, dtype=torch.double, device="cuda")
    y = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
    x = torch.randn(batch_size, samples, dtype=torch.double, device=device)
    A = torch.rand(batch_size, samples, dtype=torch.double, device=device) * 2 - 1
    zi = torch.randn(batch_size, dtype=torch.double, device=device)
    y = torch.randn(batch_size, samples, dtype=torch.double, device=device)

    A.requires_grad = True
    x.requires_grad = True
@@ -64,7 +75,7 @@ def test_cuda_parallel_scan_vmap():
    args = (x, A, zi)

    def func(x, A, zi):
        return F.mse_loss(RecurrenceCUDA.apply(A, x, zi), y)
        return F.mse_loss(Recurrence.apply(A, x, zi), y)

    jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args)

24 changes: 20 additions & 4 deletions torchlpc/__init__.py
@@ -1,9 +1,23 @@
import torch
from typing import Optional
from pathlib import Path
import warnings

so_files = list(Path(__file__).parent.glob("_C*.so"))
# assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
if len(so_files) == 1:
    torch.ops.load_library(so_files[0])
    EXTENSION_LOADED = True
elif len(so_files) > 1:
    raise ValueError(f"Expected one _C*.so file, found {len(so_files)}")
else:
    warnings.warn("No _C*.so file found. Custom extension not loaded.")
    EXTENSION_LOADED = False

from .core import LPC
from .parallel_scan import WARPSIZE
from .recurrence import RecurrenceCUDA

# from .parallel_scan import WARPSIZE
from .recurrence import Recurrence

__all__ = ["sample_wise_lpc"]

@@ -37,7 +51,9 @@ def sample_wise_lpc(
    else:
        assert zi.shape == (B, order)

    if order == 1 and x.is_cuda and B * WARPSIZE < T:
        return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1))
    # if order == 1 and x.is_cuda and B * WARPSIZE < T:
    #     return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1))
    if order == 1:
        return Recurrence.apply(-a.squeeze(2), x, zi.squeeze(1))

    return LPC.apply(x, a, zi)
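With the dispatch above, sample_wise_lpc now routes every first-order filter through Recurrence on both CPU and CUDA, and Recurrence in turn chooses between the CUDA scan, the Numba kernel, the new C++ op, or the NumPy fallback (see torchlpc/recurrence.py below). A minimal usage sketch for the order-1 case; the shapes follow the assertions in this function, and the values are arbitrary:

import torch
from torchlpc import sample_wise_lpc

B, T = 4, 1024
x = torch.randn(B, T)                # input signal, shape (B, T)
a = torch.rand(B, T, 1) * 1.8 - 0.9  # time-varying coefficients, shape (B, T, order)
zi = torch.randn(B, 1)               # initial conditions, shape (B, order)

y = sample_wise_lpc(x, a, zi)        # order == 1, so this takes the Recurrence path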
86 changes: 86 additions & 0 deletions torchlpc/csrc/scan_cpu.cpp
@@ -0,0 +1,86 @@
#include <torch/script.h>
#include <torch/torch.h>

#include <algorithm>
#include <utility>
#include <vector>

template <typename scalar_t>
void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
              const at::Tensor &initials, const at::Tensor &output) {
    TORCH_CHECK(input.dim() == 2, "Input must be 2D");
    TORCH_CHECK(initials.dim() == 1, "Initials must be 1D");
    TORCH_CHECK(weights.sizes() == input.sizes(),
                "Weights must have the same size as input");
    TORCH_CHECK(output.sizes() == input.sizes(),
                "Output must have the same size as input");
    TORCH_CHECK(initials.size(0) == input.size(0),
                "The first dimension of initials must be the same as the first "
                "dimension of input");
    TORCH_INTERNAL_ASSERT(input.device().is_cpu(), "Input must be on CPU");
    TORCH_INTERNAL_ASSERT(initials.device().is_cpu(),
                          "Initials must be on CPU");
    TORCH_INTERNAL_ASSERT(weights.device().is_cpu(), "Weights must be on CPU");
    TORCH_INTERNAL_ASSERT(output.device().is_cpu(), "Output must be on CPU");
    TORCH_INTERNAL_ASSERT(output.is_contiguous(), "Output must be contiguous");

    auto input_contiguous = input.contiguous();
    auto weights_contiguous = weights.contiguous();
    auto initials_contiguous = initials.contiguous();

    auto n_batch = input.size(0);
    auto T = input.size(1);
    auto total_size = input.numel();

    std::pair<scalar_t, scalar_t> buffer[total_size];

    const scalar_t *input_ptr = input_contiguous.data_ptr<scalar_t>();
    const scalar_t *initials_ptr = initials_contiguous.data_ptr<scalar_t>();
    const scalar_t *weights_ptr = weights_contiguous.data_ptr<scalar_t>();
    scalar_t *output_ptr = output.data_ptr<scalar_t>();

    std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer,
                   [](const scalar_t &a, const scalar_t &b) {
                       return std::make_pair(a, b);
                   });

    at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end) {
        for (auto b = start; b < end; b++) {
            std::inclusive_scan(
                buffer + b * T, buffer + (b + 1) * T, buffer + b * T,
                [](const std::pair<scalar_t, scalar_t> &a,
                   const std::pair<scalar_t, scalar_t> &b) {
                    return std::make_pair(a.first * b.first,
                                          a.second * b.first + b.second);
                },
                std::make_pair((scalar_t)1.0, initials_ptr[b]));
        }
    });

    std::transform(
        buffer, buffer + total_size, output_ptr,
        [](const std::pair<scalar_t, scalar_t> &a) { return a.second; });
}

at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
                            const at::Tensor &initials) {
    TORCH_CHECK(input.is_floating_point() || input.is_complex(),
                "Input must be floating point or complex");
    TORCH_CHECK(initials.scalar_type() == input.scalar_type(),
                "Initials must have the same scalar type as input");
    TORCH_CHECK(weights.scalar_type() == input.scalar_type(),
                "Weights must have the same scalar type as input");

    auto output = at::empty_like(input);

    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
        input.scalar_type(), "scan_cpu",
        [&] { scan_cpu<scalar_t>(input, weights, initials, output); });
    return output;
}

TORCH_LIBRARY(torchlpc, m) {
    m.def("torchlpc::scan_cpu(Tensor a, Tensor b, Tensor c) -> Tensor");
}

TORCH_LIBRARY_IMPL(torchlpc, CPU, m) { m.impl("scan_cpu", &scan_cpu_wrapper); }
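Why a scan applies here: each time step is the affine map y ↦ A[t] * y + x[t], and composing two such maps yields another affine map whose coefficients are exactly what the combine lambda above computes. That composition is associative, which is what lets the recurrence be written as an inclusive scan over (weight, input) pairs seeded with (1, zi); at::parallel_for then spreads the per-row scans over the batch dimension. A small PyTorch sketch, not part of the commit, checking the composition rule:

import torch

def combine(f, g):
    # Compose "y -> f[0]*y + f[1]" followed by "y -> g[0]*y + g[1]",
    # mirroring the lambda passed to std::inclusive_scan in scan_cpu.cpp.
    return (f[0] * g[0], f[1] * g[0] + g[1])

a1, b1, a2, b2, y0 = torch.randn(5)
two_steps = a2 * (a1 * y0 + b1) + b2          # apply the two steps sequentially
A, B = combine((a1, b1), (a2, b2))            # compose them first
assert torch.allclose(A * y0 + B, two_steps)  # same result either way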
50 changes: 36 additions & 14 deletions torchlpc/recurrence.py
@@ -4,26 +4,45 @@
from numba import cuda
from typing import Tuple, Optional, Any, List

from .parallel_scan import compute_linear_recurrence
from .parallel_scan import compute_linear_recurrence, WARPSIZE
from .core import lpc_cuda, lpc_np
from . import EXTENSION_LOADED


class RecurrenceCUDA(Function):
class Recurrence(Function):
    @staticmethod
    def forward(
        decay: torch.Tensor,
        impulse: torch.Tensor,
        initial_state: torch.Tensor,
    ) -> torch.Tensor:
        n_dims, n_steps = decay.shape
        out = torch.empty_like(impulse)
        compute_linear_recurrence(
            cuda.as_cuda_array(decay.detach()),
            cuda.as_cuda_array(impulse.detach()),
            cuda.as_cuda_array(initial_state.detach()),
            cuda.as_cuda_array(out),
            n_dims,
            n_steps,
        )
        if decay.is_cuda:
            if n_dims * WARPSIZE < n_steps:
                out = torch.empty_like(impulse)
                compute_linear_recurrence(
                    cuda.as_cuda_array(decay.detach()),
                    cuda.as_cuda_array(impulse.detach()),
                    cuda.as_cuda_array(initial_state.detach()),
                    cuda.as_cuda_array(out),
                    n_dims,
                    n_steps,
                )
            else:
                out = lpc_cuda(impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1))
        else:
            num_threads = torch.get_num_threads()
            # This is just a rough estimation of the computational cost
            if EXTENSION_LOADED and min(n_dims, num_threads) < num_threads / 3:
                out = torch.ops.torchlpc.scan_cpu(impulse, decay, initial_state)
            else:
                out = torch.from_numpy(
                    lpc_np(
                        impulse.detach().numpy(),
                        -decay.unsqueeze(2).detach().numpy(),
                        initial_state.unsqueeze(1).detach().numpy(),
                    )
                )
        return out

    @staticmethod
@@ -48,7 +67,7 @@ def backward(
        padded_decay = padded_decay[:, 1:]

        init = padded_grad_out.new_zeros(n_dims)
        flipped_grad_impulse = RecurrenceCUDA.apply(
        flipped_grad_impulse = Recurrence.apply(
            padded_decay.flip(1).conj_physical(),
            padded_grad_out.flip(1),
            init,
@@ -91,7 +110,7 @@ def jvp(
        fwd_decay = concat_out * grad_decay
        fwd_impulse = fwd_impulse + fwd_decay

        return RecurrenceCUDA.apply(decay, fwd_impulse, fwd_initial_state)
        return Recurrence.apply(decay, fwd_impulse, fwd_initial_state)

    @staticmethod
    def vmap(info, in_dims, *args):
@@ -107,5 +126,8 @@ def maybe_expand_bdim_at_front(x, x_bdim):
            )
        )

        out = RecurrenceCUDA.apply(decay, impulse, initial_state)
        out = Recurrence.apply(decay, impulse, initial_state)
        return out.reshape(info.batch_size, -1, *out.shape[1:]), 0


RecurrenceCUDA = Recurrence
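Note the alias at the end of the file: RecurrenceCUDA now simply points at Recurrence, so code written against the old name keeps working while the renamed class handles both devices. A two-line sketch of what that means for downstream imports:

from torchlpc.recurrence import Recurrence, RecurrenceCUDA

# Backward-compatible alias kept by this commit.
assert RecurrenceCUDA is Recurrence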
