From 0e787f1bc8725ef3130738c6f0c7ac7bdaba3903 Mon Sep 17 00:00:00 2001
From: Damian Szwichtenberg
Date: Fri, 17 Nov 2023 14:24:32 +0100
Subject: [PATCH] Add `softmax_csr` implementation (#264)

This PR adds forward and backward implementations of the sparse softmax
operation defined
[here](https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/utils/softmax.py#L9).
In the `pytorch_geometric` implementation we cannot take advantage of model
compilation when groups are defined via `ptr`. The `softmax_csr` operation
introduced here provides a well-performing kernel for exactly that scenario.

Performance boost (achieved on a 28-core, single-socket machine):
- ~7x for the forward pass
- ~8x for the backward pass

Additionally, GAT training time was reduced by ~5%.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 CHANGELOG.md                            |   1 +
 benchmark/ops/softmax.py                |  68 +++++++
 pyg_lib/csrc/ops/cpu/softmax_kernel.cpp | 258 ++++++++++++++++++++++++
 pyg_lib/csrc/ops/softmax.cpp            |  58 ++++++
 pyg_lib/csrc/ops/softmax.h              |  21 ++
 pyg_lib/ops/__init__.py                 |  62 +++++-
 test/csrc/ops/test_softmax.cpp          |  61 ++++++
 test/ops/test_softmax.py                |  51 +++++
 8 files changed, 577 insertions(+), 3 deletions(-)
 create mode 100644 benchmark/ops/softmax.py
 create mode 100644 pyg_lib/csrc/ops/cpu/softmax_kernel.cpp
 create mode 100644 pyg_lib/csrc/ops/softmax.cpp
 create mode 100644 pyg_lib/csrc/ops/softmax.h
 create mode 100644 test/csrc/ops/test_softmax.cpp
 create mode 100644 test/ops/test_softmax.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b7404fe5a..4876414fd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ## [0.4.0] - 2023-MM-DD
 ### Added
+- Added `softmax_csr` implementation ([#264](https://github.com/pyg-team/pyg-lib/pull/264))
 - Added support for edge-level sampling ([#280](https://github.com/pyg-team/pyg-lib/pull/280))
 - Added support for `bfloat16` data type in `segment_matmul` and `grouped_matmul` (CPU only) ([#272](https://github.com/pyg-team/pyg-lib/pull/272))
 ### Changed
diff --git a/benchmark/ops/softmax.py b/benchmark/ops/softmax.py
new file mode 100644
index 000000000..a438f397c
--- /dev/null
+++ b/benchmark/ops/softmax.py
@@ -0,0 +1,68 @@
+import argparse
+from time import perf_counter as timestamp
+
+import torch
+from torch_geometric.utils import segment
+
+import pyg_lib
+
+
+def softmax_reference_ptr(src, ptr, dim=0):
+    dim = dim + src.dim() if dim < 0 else dim
+    size = ([1] * dim) + [-1]
+    count = ptr[1:] - ptr[:-1]
+    ptr = ptr.view(size)
+    src_max = segment(src.detach(), ptr, reduce='max')
+    src_max = src_max.repeat_interleave(count, dim=dim)
+    out = (src - src_max).exp()
+    out_sum = segment(out, ptr, reduce='sum') + 1e-16
+    out_sum = out_sum.repeat_interleave(count, dim=dim)
+
+    return out / out_sum
+
+
+def measure_perf(impl_func, ptr, out_grad, num_warmups, num_steps, backward):
+    t_fwd = t_bwd = 0
+    for i in range(num_warmups + num_steps):
+        src = torch.randn(num_rows, num_heads)
+        src.requires_grad = backward
+
+        t_start = timestamp()
+        out = impl_func(src=src, ptr=ptr)
+        if i >= num_warmups:
+            t_fwd += timestamp() - t_start
+
+        if backward:
+            t_start = timestamp()
+            out.backward(out_grad)
+            if i >= num_warmups:
+                t_bwd += timestamp() - t_start
+
+    return t_fwd, t_bwd
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--backward', action='store_true')
+    parser.add_argument('--num-heads', type=int, default=4)
+    args = parser.parse_args()
+
+    num_rows, num_heads = 50000, args.num_heads
+    num_warmups, num_steps = 100, 500
+    group_size = 100
+
+    ptr = torch.arange(0, num_rows + 1, group_size)
+    out_grad = torch.randn(num_rows, num_heads)
+
+    func_args = [ptr, out_grad, num_warmups, num_steps, args.backward]
+
+    t_fwd, t_bwd = measure_perf(softmax_reference_ptr, *func_args)
+    print(f'Vanilla forward: {t_fwd:.4f}s')
+    if args.backward:
+        print(f'Vanilla backward: {t_bwd:.4f}s')
+    print('=========================')
+
+    t_fwd, t_bwd = measure_perf(pyg_lib.ops.softmax_csr, *func_args)
+    print(f'pyg_lib forward: {t_fwd:.4f}s')
+    if args.backward:
+        print(f'pyg_lib backward: {t_bwd:.4f}s')
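Note on the benchmark above: the segment-based reference and the new fused kernel are meant to be numerically interchangeable. A minimal sketch of such a check (assuming `pyg_lib` was built with this patch, and reusing the `softmax_reference_ptr` helper defined in the benchmark):

```python
import torch

import pyg_lib

src = torch.randn(10, 4)
ptr = torch.tensor([0, 3, 7, 10])  # three groups: rows [0, 3), [3, 7), [7, 10)

ref = softmax_reference_ptr(src, ptr)    # segment-based reference from above
out = pyg_lib.ops.softmax_csr(src, ptr)  # fused kernel added in this patch
assert torch.allclose(ref, out, atol=1e-6)
```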
diff --git a/pyg_lib/csrc/ops/cpu/softmax_kernel.cpp b/pyg_lib/csrc/ops/cpu/softmax_kernel.cpp
new file mode 100644
index 000000000..88575af13
--- /dev/null
+++ b/pyg_lib/csrc/ops/cpu/softmax_kernel.cpp
@@ -0,0 +1,258 @@
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <utility>
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/Parallel.h>
+#include <torch/library.h>
+
+namespace pyg {
+namespace ops {
+
+namespace {
+
+std::vector<int64_t> create_per_thread_groups(const int64_t* groups_ptr,
+                                              const int64_t n_groups,
+                                              const int64_t dim_size) {
+  std::vector<int64_t> new_groups = {0};
+  const auto avg_work_per_thread = at::divup(dim_size, at::get_num_threads());
+  int64_t cur_work = 0;
+  for (int64_t i = 0; i < n_groups; ++i) {
+    cur_work += groups_ptr[i + 1] - groups_ptr[i];
+    if (cur_work >= avg_work_per_thread) {
+      new_groups.push_back(i + 1);
+      cur_work = 0;
+    }
+  }
+  new_groups.push_back(n_groups);
+
+  return new_groups;
+}
+
+std::pair<std::vector<int64_t>, std::vector<int64_t>>
+precompute_data_access_patterns(const int64_t outer_range,
+                                const int64_t inner_range,
+                                const int64_t global_dim_size,
+                                const int64_t dim_stride) {
+  std::vector<int64_t> data_ids(outer_range * inner_range);
+  std::vector<int64_t> aux_ids(outer_range * inner_range);
+  for (int64_t i = 0; i < outer_range; ++i) {
+    const auto contiguous_offset = i * global_dim_size * dim_stride;
+    for (int64_t j = 0; j < inner_range; ++j) {
+      const auto k = i * inner_range + j;
+      const auto data_id = j + contiguous_offset;
+      const auto aux_id = k % dim_stride + (dim_stride * (k / inner_range));
+      data_ids[k] = data_id;
+      aux_ids[k] = aux_id;
+    }
+  }
+
+  return {std::move(data_ids), std::move(aux_ids)};
+}
+
+at::Tensor softmax_csr_forward_kernel_impl(const at::Tensor& src,
+                                           const at::Tensor& groups,
+                                           const int64_t dim) {
+  auto out = at::zeros_like(src);
+
+  AT_DISPATCH_FLOATING_TYPES(
+      src.scalar_type(), "softmax_csr_forward_kernel_impl", [&] {
+        const auto n_groups = groups.size(0) - 1;
+        const auto n_heads = src.numel() / src.size(dim);
+        auto max = at::full({n_groups, n_heads},
+                            std::numeric_limits<scalar_t>::lowest(),
+                            src.options());
+        auto sum = at::zeros({n_groups, n_heads}, src.options());
+        const auto groups_ptr = groups.data_ptr<int64_t>();
+        const auto src_base_ptr = src.data_ptr<scalar_t>();
+        auto out_base_ptr = out.data_ptr<scalar_t>();
+        auto max_base_ptr = max.data_ptr<scalar_t>();
+        auto sum_base_ptr = sum.data_ptr<scalar_t>();
+        const auto global_dim_size = src.size(dim);
+        const auto new_groups = std::move(
+            create_per_thread_groups(groups_ptr, n_groups, global_dim_size));
+
+        at::parallel_for(
+            0, new_groups.size() - 1, 1, [&](int64_t beg, int64_t end) {
+              // each thread may cover several groups
+              for (auto group_id = new_groups[beg]; group_id < new_groups[end];
+                   ++group_id) {
+                const auto dim_beg = groups_ptr[group_id];
+                const auto dim_end = groups_ptr[group_id + 1];
+                const auto local_dim_size = dim_end - dim_beg;
+                const auto dim_stride = src.stride(dim);
+                // outer_range says how many data jumps we need to make
+                const auto outer_range = [&src, dim]() {
+                  int64_t range = 1;
+                  for (int64_t i = 0; i < dim; ++i)
+                    range *= src.size(i);
+                  return range;
+                }();
+                // inner_range says how many contiguous elements we can visit
+                const auto inner_range = local_dim_size * dim_stride;
+                const auto inout_offset = dim_beg * dim_stride;
+                const auto aux_offset = group_id * n_heads;
+
+                const auto src_ptr = src_base_ptr + inout_offset;
+                auto out_ptr = out_base_ptr + inout_offset;
+                auto max_ptr = max_base_ptr + aux_offset;
+                auto sum_ptr = sum_base_ptr + aux_offset;
+
+                const auto indices = precompute_data_access_patterns(
+                    outer_range, inner_range, global_dim_size, dim_stride);
+                const auto& data_ids = indices.first;
+                const auto& aux_ids = indices.second;
+
+                if (local_dim_size == 1) {
+                  for (int64_t i = 0; i < outer_range; ++i) {
+                    const auto k = i * inner_range;
+                    const auto data_id = data_ids[k];
+                    std::fill(out_ptr + data_id,
+                              out_ptr + data_id + inner_range,
+                              static_cast<scalar_t>(1.0));
+                  }
+                } else {
+                  // calculate max
+                  for (int64_t i = 0; i < outer_range; ++i) {
+                    for (int64_t j = 0; j < inner_range; ++j) {
+                      const auto k = i * inner_range + j;
+                      const auto data_id = data_ids[k];
+                      const auto aux_id = aux_ids[k];
+                      max_ptr[aux_id] =
+                          std::max(max_ptr[aux_id], src_ptr[data_id]);
+                    }
+                  }
+
+                  // calculate sum
+                  for (int64_t i = 0; i < outer_range; ++i) {
+                    for (int64_t j = 0; j < inner_range; ++j) {
+                      const auto k = i * inner_range + j;
+                      const auto data_id = data_ids[k];
+                      const auto aux_id = aux_ids[k];
+                      const auto value =
+                          std::exp(src_ptr[data_id] - max_ptr[aux_id]);
+                      sum_ptr[aux_id] += value;
+                      out_ptr[data_id] = value;
+                    }
+                  }
+
+                  // unify
+                  for (int64_t i = 0; i < outer_range; ++i) {
+                    for (int64_t j = 0; j < inner_range; ++j) {
+                      const auto k = i * inner_range + j;
+                      const auto data_id = data_ids[k];
+                      const auto aux_id = aux_ids[k];
+                      out_ptr[data_id] /= sum_ptr[aux_id];
+                    }
+                  }
+                }
+              }
+            });
+      });
+
+  return out;
+}
+
+at::Tensor softmax_csr_backward_kernel_impl(const at::Tensor& out,
+                                            const at::Tensor& out_grad,
+                                            const at::Tensor& groups,
+                                            const int64_t dim) {
+  auto in_grad = at::zeros_like(out);
+
+  AT_DISPATCH_FLOATING_TYPES(
+      out.scalar_type(), "softmax_csr_backward_kernel_impl", [&] {
+        const auto n_groups = groups.size(0) - 1;
+        const auto n_heads = out.numel() / out.size(dim);
+        auto sum = at::zeros({n_groups, n_heads}, out.options());
+        const auto groups_ptr = groups.data_ptr<int64_t>();
+        const auto out_base_ptr = out.data_ptr<scalar_t>();
+        const auto out_grad_base_ptr = out_grad.data_ptr<scalar_t>();
+        auto in_grad_base_ptr = in_grad.data_ptr<scalar_t>();
+        auto sum_base_ptr = sum.data_ptr<scalar_t>();
+        const auto global_dim_size = out.size(dim);
+        const auto new_groups = std::move(
+            create_per_thread_groups(groups_ptr, n_groups, global_dim_size));
+
+        at::parallel_for(
+            0, new_groups.size() - 1, 1, [&](int64_t beg, int64_t end) {
+              for (auto group_id = new_groups[beg]; group_id < new_groups[end];
+                   ++group_id) {
+                const auto dim_beg = groups_ptr[group_id];
+                const auto dim_end = groups_ptr[group_id + 1];
+                const auto local_dim_size = dim_end - dim_beg;
+                const auto dim_stride = out.stride(dim);
+                // outer_range says how many data jumps we need to make
+                const auto outer_range = [&out, dim]() {
+                  int64_t range = 1;
+                  for (int64_t i = 0; i < dim; ++i)
+                    range *= out.size(i);
+                  return range;
+                }();
+                // inner_range says how many contiguous elements we can visit
+                const auto inner_range = local_dim_size * dim_stride;
+                const auto inout_offset = dim_beg * dim_stride;
+                const auto sum_offset = group_id * n_heads;
+
+                const auto out_ptr = out_base_ptr + inout_offset;
+                const auto out_grad_ptr = out_grad_base_ptr + inout_offset;
+                auto in_grad_ptr = in_grad_base_ptr + inout_offset;
+                auto sum_ptr = sum_base_ptr + sum_offset;
+
+                const auto indices = precompute_data_access_patterns(
+                    outer_range, inner_range, global_dim_size, dim_stride);
+                const auto& data_ids = indices.first;
+                const auto& aux_ids = indices.second;
+
+                // calculate sum of out * out_grad
+                for (int64_t i = 0; i < outer_range; ++i) {
+                  for (int64_t j = 0; j < inner_range; ++j) {
+                    const auto k = i * inner_range + j;
+                    const auto data_id = data_ids[k];
+                    const auto aux_id = aux_ids[k];
+                    sum_ptr[aux_id] += out_ptr[data_id] * out_grad_ptr[data_id];
+                  }
+                }
+
+                // calculate out * (out_grad - sum)
+                for (int64_t i = 0; i < outer_range; ++i) {
+                  for (int64_t j = 0; j < inner_range; ++j) {
+                    const auto k = i * inner_range + j;
+                    const auto data_id = data_ids[k];
+                    const auto aux_id = aux_ids[k];
+                    in_grad_ptr[data_id] =
+                        out_ptr[data_id] *
+                        (out_grad_ptr[data_id] - sum_ptr[aux_id]);
+                  }
+                }
+              }
+            });
+      });
+
+  return in_grad;
+}
+
+at::Tensor softmax_csr_forward_kernel(const at::Tensor& src,
+                                      const at::Tensor& ptr,
+                                      const int64_t dim) {
+  return softmax_csr_forward_kernel_impl(src, ptr, dim);
+}
+
+at::Tensor softmax_csr_backward_kernel(const at::Tensor& out,
+                                       const at::Tensor& out_grad,
+                                       const at::Tensor& ptr,
+                                       const int64_t dim) {
+  return softmax_csr_backward_kernel_impl(out, out_grad, ptr, dim);
+}
+
+}  // namespace
+
+TORCH_LIBRARY_IMPL(pyg, CPU, m) {
+  m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_forward"),
+         TORCH_FN(softmax_csr_forward_kernel));
+  m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_backward"),
+         TORCH_FN(softmax_csr_backward_kernel));
+}
+
+}  // namespace ops
+}  // namespace pyg
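For each CSR group, the kernel above performs the usual numerically stable three passes (group max, exponentiation plus sum, normalization), with whole groups packed onto threads by `create_per_thread_groups`. A rough Python sketch of the per-group math for a 2-D input and `dim=0` (illustrative only, not the kernel itself):

```python
import torch

def softmax_csr_sketch(src: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(src)
    for g in range(ptr.numel() - 1):  # one CSR group at a time
        beg, end = int(ptr[g]), int(ptr[g + 1])
        group = src[beg:end]
        group_max = group.max(dim=0, keepdim=True).values  # pass 1: max
        ex = (group - group_max).exp()                      # pass 2: exp
        out[beg:end] = ex / ex.sum(dim=0, keepdim=True)     # pass 3: normalize
    return out
```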
diff --git a/pyg_lib/csrc/ops/softmax.cpp b/pyg_lib/csrc/ops/softmax.cpp
new file mode 100644
index 000000000..92512cfae
--- /dev/null
+++ b/pyg_lib/csrc/ops/softmax.cpp
@@ -0,0 +1,58 @@
+#include "softmax.h"
+
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <torch/library.h>
+
+namespace pyg {
+namespace ops {
+
+// Performs softmax operations for each group.
+PYG_API at::Tensor softmax_csr_forward(const at::Tensor& src,
+                                       const at::Tensor& ptr,
+                                       const int64_t dim) {
+  at::TensorArg src_arg{src, "src", 0};
+  at::TensorArg ptr_arg{ptr, "ptr", 1};
+  at::CheckedFrom c{"softmax_forward"};
+
+  at::checkAllDefined(c, {src_arg, ptr_arg});
+  at::checkContiguous(c, src_arg);
+  at::checkContiguous(c, ptr_arg);
+
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("pyg::softmax_csr_forward", "")
+                       .typed<decltype(softmax_csr_forward)>();
+  return op.call(src, ptr, dim);
+}
+
+// Computes gradient for grouped softmax operation.
+PYG_API at::Tensor softmax_csr_backward(const at::Tensor& out,
+                                        const at::Tensor& out_grad,
+                                        const at::Tensor& ptr,
+                                        const int64_t dim) {
+  at::TensorArg out_arg{out, "out", 0};
+  at::TensorArg out_grad_arg{out_grad, "out_grad", 1};
+  at::TensorArg ptr_arg{ptr, "ptr", 2};
+  at::CheckedFrom c{"softmax_backward"};
+
+  at::checkAllDefined(c, {out_arg, out_grad_arg, ptr_arg});
+  at::checkContiguous(c, out_arg);
+  at::checkContiguous(c, out_grad_arg);
+  at::checkContiguous(c, ptr_arg);
+
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("pyg::softmax_csr_backward", "")
+                       .typed<decltype(softmax_csr_backward)>();
+  return op.call(out, out_grad, ptr, dim);
+}
+
+TORCH_LIBRARY_FRAGMENT(pyg, m) {
+  m.def(
+      TORCH_SELECTIVE_SCHEMA("pyg::softmax_csr_forward(Tensor src, Tensor ptr, "
+                             "int dim=0) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA(
+      "pyg::softmax_csr_backward(Tensor out, Tensor out_grad, "
+      "Tensor ptr, int dim=0) -> Tensor"));
+}
+
+}  // namespace ops
+}  // namespace pyg
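The backward op registered above applies the standard softmax vector-Jacobian product within each group: `in_grad = out * (out_grad - sum(out * out_grad))`, where the sum runs over the group along `dim`. An illustrative per-group sketch for `dim=0`:

```python
import torch

def softmax_csr_backward_sketch(out, out_grad, ptr):
    in_grad = torch.zeros_like(out)
    for g in range(ptr.numel() - 1):
        beg, end = int(ptr[g]), int(ptr[g + 1])
        # sum of out * out_grad over the group, kept separately per column ("head")
        s = (out[beg:end] * out_grad[beg:end]).sum(dim=0, keepdim=True)
        in_grad[beg:end] = out[beg:end] * (out_grad[beg:end] - s)
    return in_grad
```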
diff --git a/pyg_lib/csrc/ops/softmax.h b/pyg_lib/csrc/ops/softmax.h
new file mode 100644
index 000000000..f381ae825
--- /dev/null
+++ b/pyg_lib/csrc/ops/softmax.h
@@ -0,0 +1,21 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include "pyg_lib/csrc/macros.h"
+
+namespace pyg {
+namespace ops {
+
+// Performs softmax operations for each group.
+PYG_API at::Tensor softmax_csr_forward(const at::Tensor& src,
+                                       const at::Tensor& ptr,
+                                       const int64_t dim = 0);
+
+// Computes gradient for grouped softmax operations.
+PYG_API at::Tensor softmax_csr_backward(const at::Tensor& out,
+                                        const at::Tensor& out_grad,
+                                        const at::Tensor& ptr,
+                                        const int64_t dim = 0);
+
+}  // namespace ops
+}  // namespace pyg
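Once the schemas and CPU kernels are registered, the ops are also reachable from Python through the dispatcher; this is what the `pyg_lib.ops` wrapper below builds on. A small sketch (assuming the compiled extension is loaded by `import pyg_lib`):

```python
import torch

import pyg_lib  # loads the extension and registers the pyg::* operators

src = torch.randn(6, 2)
ptr = torch.tensor([0, 2, 6])

out = torch.ops.pyg.softmax_csr_forward(src, ptr, 0)
in_grad = torch.ops.pyg.softmax_csr_backward(out, torch.ones_like(out), ptr, 0)
```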
diff --git a/pyg_lib/ops/__init__.py b/pyg_lib/ops/__init__.py
index 7d8404d9e..81b66c3f1 100644
--- a/pyg_lib/ops/__init__.py
+++ b/pyg_lib/ops/__init__.py
@@ -1,9 +1,8 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
-from torch import Tensor
-
 import torch.utils._pytree as pytree
+from torch import Tensor
 
 
 def pytreeify(cls):
@@ -332,6 +331,62 @@ def index_sort(
     return torch.ops.pyg.index_sort(inputs, max_value)
 
 
+class Softmax(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        src: Tensor,
+        ptr: Tensor,
+        dim: int = 0,
+    ) -> Tensor:
+        out = torch.ops.pyg.softmax_csr_forward(src, ptr, dim)
+        ctx.save_for_backward(out, ptr)
+        ctx.dim = dim
+
+        return out
+
+    @staticmethod
+    def backward(ctx, out_grad: Tensor) -> Tuple[Union[Tensor, int]]:
+        out, ptr = ctx.saved_tensors
+        in_grad = torch.ops.pyg.softmax_csr_backward(out, out_grad, ptr,
+                                                     ctx.dim)
+
+        return in_grad, None, None
+
+
+def softmax_csr(
+    src: Tensor,
+    ptr: Tensor,
+    dim: int = 0,
+) -> Tensor:
+    r"""Computes a sparsely evaluated softmax.
+    Given a value tensor :attr:`src`, this function first groups the values
+    along the given dimension :attr:`dim`, based on the indices specified via
+    :attr:`ptr`, and then proceeds to compute the softmax individually for
+    each group.
+
+    Args:
+        src (Tensor): The source tensor.
+        ptr (LongTensor): Groups defined by the CSR representation.
+        dim (int, optional): The dimension in which to normalize.
+            (default: :obj:`0`)
+
+    :rtype: :class:`Tensor`
+
+    Examples:
+
+        >>> src = torch.randn(4, 4)
+        >>> ptr = torch.tensor([0, 4])
+        >>> softmax_csr(src, ptr)
+        tensor([[0.0157, 0.0984, 0.1250, 0.4523],
+                [0.1453, 0.2591, 0.5907, 0.2410],
+                [0.0598, 0.2923, 0.1206, 0.0921],
+                [0.7792, 0.3502, 0.1638, 0.2145]])
+    """
+    dim = dim + src.dim() if dim < 0 else dim
+    return Softmax.apply(src, ptr, dim)
+
+
 __all__ = [
     'grouped_matmul',
     'segment_matmul',
@@ -340,4 +395,5 @@ def index_sort(
     'sampled_mul',
     'sampled_div',
     'index_sort',
+    'softmax_csr',
 ]
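As a usage note, `softmax_csr` participates in autograd through the `Softmax` function above, so gradients flow back to `src`. A short example (group layout chosen arbitrarily):

```python
import torch

import pyg_lib

src = torch.randn(8, 4, requires_grad=True)
ptr = torch.tensor([0, 3, 8])            # two groups along dim=0

out = pyg_lib.ops.softmax_csr(src, ptr)  # forward -> pyg::softmax_csr_forward
out.backward(torch.rand_like(out))       # backward -> pyg::softmax_csr_backward
print(src.grad.shape)                    # torch.Size([8, 4])
```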
diff --git a/test/csrc/ops/test_softmax.cpp b/test/csrc/ops/test_softmax.cpp
new file mode 100644
index 000000000..476cd51ee
--- /dev/null
+++ b/test/csrc/ops/test_softmax.cpp
@@ -0,0 +1,61 @@
+#include <ATen/ATen.h>
+#include <gtest/gtest.h>
+
+#include <ATen/TensorIndexing.h>
+
+#include "pyg_lib/csrc/ops/softmax.h"
+
+using namespace at::indexing;
+
+at::Tensor softmax2D_ref_impl(const at::Tensor& src,
+                              const at::Tensor& ptr,
+                              const int64_t dim) {
+  auto out = at::zeros_like(src);
+
+  for (int64_t i = 0; i < src.size(1 - dim); ++i) {
+    for (int64_t j = 0; j < ptr.size(0) - 1; ++j) {
+      const auto beg = ptr[j].item<int64_t>();
+      const auto end = ptr[j + 1].item<int64_t>();
+      const auto row_slice = (dim == 0) ? Slice(beg, end) : Slice(i, i + 1);
+      const auto col_slice = (dim == 0) ? Slice(i, i + 1) : Slice(beg, end);
+      out.index_put_({row_slice, col_slice},
+                     src.index({row_slice, col_slice}).softmax(dim));
+    }
+  }
+
+  return out;
+}
+
+class CPUTest : public testing::TestWithParam<int64_t> {};
+
+TEST_P(CPUTest, SoftmaxCSRForward) {
+  const auto dim = ::testing::TestWithParam<int64_t>::GetParam();
+  const auto src = at::rand({8, 8});
+  const auto ptr = at::tensor({0, 3, 4, 7, 8}, at::kLong);
+  const auto expected_out = softmax2D_ref_impl(src, ptr, dim);
+
+  const auto out = pyg::ops::softmax_csr_forward(src, ptr, dim);
+  EXPECT_EQ(expected_out.size(0), out.size(0));
+  EXPECT_EQ(expected_out.size(1), out.size(1));
+  EXPECT_TRUE(at::allclose(expected_out, out, 1e-04, 1e-04));
+}
+
+TEST_P(CPUTest, SoftmaxCSRBackward) {
+  const auto dim = ::testing::TestWithParam<int64_t>::GetParam();
+  const auto src = at::rand({8, 8});
+  src.set_requires_grad(true);
+  const auto ptr = at::tensor({0, 3, 4, 7, 8}, at::kLong);
+  const auto out = softmax2D_ref_impl(src, ptr, dim);
+  const auto out_grad = at::rand({8, 8});
+
+  const auto in_grad = pyg::ops::softmax_csr_backward(out, out_grad, ptr, dim);
+  out.backward(out_grad);
+  EXPECT_EQ(src.grad().size(0), in_grad.size(0));
+  EXPECT_EQ(src.grad().size(1), in_grad.size(1));
+  EXPECT_TRUE(at::allclose(src.grad(), in_grad, 1e-04, 1e-04));
+}
+
+INSTANTIATE_TEST_SUITE_P(OpsTest,
+                         CPUTest,
+                         // dim
+                         testing::Values(0, 1));
diff --git a/test/ops/test_softmax.py b/test/ops/test_softmax.py
new file mode 100644
index 000000000..f2aa97f12
--- /dev/null
+++ b/test/ops/test_softmax.py
@@ -0,0 +1,51 @@
+import pytest
+import torch
+
+import pyg_lib
+
+
+def ptr2index(ptr):
+    group_sizes = ptr[1:] - ptr[:-1]
+    return torch.repeat_interleave(
+        torch.arange(0, group_sizes.numel(), dtype=group_sizes.dtype),
+        group_sizes)
+
+
+def broadcast(src, ref, dim):
+    size = ((1, ) * dim) + (-1, ) + ((1, ) * (ref.dim() - dim - 1))
+    return src.view(size).expand_as(ref)
+
+
+def softmax_reference(src, ptr, dim):
+    index = ptr2index(ptr)
+    N = int(index.max()) + 1 if index.numel() > 0 else 0
+    size = src.size()[:dim] + (N, ) + src.size()[dim + 1:]
+    src_max = src.detach().new_zeros(size).scatter_reduce_(
+        dim, broadcast(index, src, dim), src, reduce='amax',
+        include_self=False)
+    out = src - src_max.index_select(dim, index)
+    out = out.exp()
+    out_sum = out.new_zeros(size).scatter_add_(dim, broadcast(index, out, dim),
+                                               out)
+    out_sum = out_sum.index_select(dim, index)
+
+    return out / out_sum
+
+
+@pytest.mark.parametrize('dim', list(range(3)))
+def test_softmax_csr_autograd(dim):
+    sizes = (16, 32, 64)
+    src1 = torch.rand(sizes, requires_grad=True)
+    src2 = src1.detach().clone()
+    src2.requires_grad = True
+    dim_size = sizes[dim]
+    ptr = torch.tensor([0, 1, 4, 5, dim_size - 1, dim_size])
+    out_grad = torch.randn(sizes)
+
+    expected_out = softmax_reference(src1, ptr, dim)
+    out = pyg_lib.ops.softmax_csr(src2, ptr, dim)
+    assert torch.allclose(expected_out, out, atol=1e-6)
+
+    expected_out.backward(out_grad)
+    out.backward(out_grad)
+    assert torch.allclose(src1.grad, src2.grad, atol=1e-6)