Add softmax_csr implementation (pyg-team#264)
This PR adds forward and backward implementations of the sparse softmax
operation as defined
[here](https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/utils/softmax.py#L9).

In the `pytorch_geometric` implementation, we cannot take advantage of
model compilation when groups are defined via `ptr`. The `softmax_csr`
operator introduced here provides a well-performing kernel for that scenario.

Performance boost (measured on a 28-core, single-socket machine):
- ~7x for the forward pass
- ~8x for the backward pass

Additionally, GAT training time was reduced by ~5%.
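
For reference, a minimal usage sketch (mirroring the call made in the benchmark script added in this PR; shapes are illustrative):

```python
import torch
import pyg_lib

src = torch.randn(6, 4)           # six rows, four heads
ptr = torch.tensor([0, 2, 5, 6])  # CSR-style group boundaries over dim 0
out = pyg_lib.ops.softmax_csr(src=src, ptr=ptr)  # softmax within each [ptr[i], ptr[i+1]) group
```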

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
DamianSzwichtenberg and pre-commit-ci[bot] committed Nov 17, 2023
1 parent 2b9af1c commit 0e787f1
Showing 8 changed files with 577 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.4.0] - 2023-MM-DD
### Added
- Added `softmax_csr` implementation ([#264](https://github.com/pyg-team/pyg-lib/pull/264))
- Added support for edge-level sampling ([#280](https://github.com/pyg-team/pyg-lib/pull/280))
- Added support for `bfloat16` data type in `segment_matmul` and `grouped_matmul` (CPU only) ([#272](https://github.com/pyg-team/pyg-lib/pull/272))
### Changed
68 changes: 68 additions & 0 deletions benchmark/ops/softmax.py
@@ -0,0 +1,68 @@
import argparse
from time import perf_counter as timestamp

import torch
from torch_geometric.utils import segment

import pyg_lib


def softmax_reference_ptr(src, ptr, dim=0):
    dim = dim + src.dim() if dim < 0 else dim
    size = ([1] * dim) + [-1]
    count = ptr[1:] - ptr[:-1]
    ptr = ptr.view(size)
    src_max = segment(src.detach(), ptr, reduce='max')
    src_max = src_max.repeat_interleave(count, dim=dim)
    out = (src - src_max).exp()
    out_sum = segment(out, ptr, reduce='sum') + 1e-16
    out_sum = out_sum.repeat_interleave(count, dim=dim)

    return out / out_sum


def measure_perf(impl_func, ptr, out_grad, num_warmups, num_steps, backward):
    # `num_rows` and `num_heads` are module-level globals defined in `__main__`.
    t_fwd = t_bwd = 0
    for i in range(num_warmups + num_steps):
        src = torch.randn(num_rows, num_heads)
        src.requires_grad = backward

        t_start = timestamp()
        out = impl_func(src=src, ptr=ptr)
        if i >= num_warmups:
            t_fwd += timestamp() - t_start

        if backward:
            t_start = timestamp()
            out.backward(out_grad)
            if i >= num_warmups:
                t_bwd += timestamp() - t_start

    return t_fwd, t_bwd


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--backward', action='store_true')
    parser.add_argument('--num-heads', type=int, default=4)
    args = parser.parse_args()

    num_rows, num_heads = 50000, args.num_heads
    num_warmups, num_steps = 100, 500
    group_size = 100

    ptr = torch.arange(0, num_rows + 1, group_size)
    out_grad = torch.randn(num_rows, num_heads)

    func_args = [ptr, out_grad, num_warmups, num_steps, args.backward]

    t_fwd, t_bwd = measure_perf(softmax_reference_ptr, *func_args)
    print(f'Vanilla forward: {t_fwd:.4f}s')
    if args.backward:
        print(f'Vanilla backward: {t_bwd:.4f}s')
    print('=========================')

    t_fwd, t_bwd = measure_perf(pyg_lib.ops.softmax_csr, *func_args)
    print(f'pyg_lib forward: {t_fwd:.4f}s')
    if args.backward:
        print(f'pyg_lib backward: {t_bwd:.4f}s')
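
A correctness spot-check against the pure-PyTorch reference above can complement the timing numbers; a minimal sketch (not part of the benchmark script; it reuses `softmax_reference_ptr` and the imports from above, and the tolerance is illustrative):

```python
src = torch.randn(1000, 4)
ptr = torch.arange(0, 1001, 100)  # ten groups of 100 rows each
expected = softmax_reference_ptr(src, ptr)
actual = pyg_lib.ops.softmax_csr(src=src, ptr=ptr)
assert torch.allclose(expected, actual, atol=1e-6)
```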
258 changes: 258 additions & 0 deletions pyg_lib/csrc/ops/cpu/softmax_kernel.cpp
@@ -0,0 +1,258 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/library.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <utility>
#include <vector>

namespace pyg {
namespace ops {

namespace {

std::vector<int64_t> create_per_thread_groups(const int64_t* groups_ptr,
const int64_t n_groups,
const int64_t dim_size) {
std::vector<int64_t> new_groups = {0};
const auto avg_work_per_thread = at::divup(dim_size, at::get_num_threads());
int64_t cur_work = 0;
for (int64_t i = 0; i < n_groups; ++i) {
cur_work += groups_ptr[i + 1] - groups_ptr[i];
if (cur_work >= avg_work_per_thread) {
new_groups.push_back(i + 1);
cur_work = 0;
}
}
new_groups.push_back(n_groups);

return new_groups;
}
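// Example for create_per_thread_groups (hypothetical numbers): with 2 threads,
// dim_size = 8 and groups_ptr = {0, 3, 6, 8}, avg_work_per_thread = 4 and the
// returned boundaries are {0, 2, 3}, i.e. the first parallel task covers
// groups 0-1 (6 rows) and the second covers group 2 (2 rows).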

std::pair<std::vector<int64_t>, std::vector<int64_t>>
precompute_data_access_patterns(const int64_t outer_range,
const int64_t inner_range,
const int64_t global_dim_size,
const int64_t dim_stride) {
std::vector<int64_t> data_ids(outer_range * inner_range);
std::vector<int64_t> aux_ids(outer_range * inner_range);
for (int64_t i = 0; i < outer_range; ++i) {
const auto contiguous_offset = i * global_dim_size * dim_stride;
for (int64_t j = 0; j < inner_range; ++j) {
const auto k = i * inner_range + j;
const auto data_id = j + contiguous_offset;
const auto aux_id = k % dim_stride + (dim_stride * (k / inner_range));
data_ids[k] = data_id;
aux_ids[k] = aux_id;
}
}

return {std::move(data_ids), std::move(aux_ids)};
}
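// Example for precompute_data_access_patterns: for a contiguous 2-D `src` of
// shape (N, H) with dim = 0, dim_stride = H, outer_range = 1 and
// inner_range = local_dim_size * H. data_ids then enumerate the group's
// elements row by row, while aux_ids[k] = k % H selects the per-head slot
// within the group's row of the `max`/`sum` buffers.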

at::Tensor softmax_csr_forward_kernel_impl(const at::Tensor& src,
const at::Tensor& groups,
const int64_t dim) {
auto out = at::zeros_like(src);

AT_DISPATCH_FLOATING_TYPES(
src.scalar_type(), "softmax_csr_forward_kernel_impl", [&] {
const auto n_groups = groups.size(0) - 1;
const auto n_heads = src.numel() / src.size(dim);
// Allocate the per-group accumulators with the same dtype/device as `src`,
// so the `data_ptr<scalar_t>()` accesses below stay valid for both float
// and double inputs.
auto max = at::full({n_groups, n_heads},
std::numeric_limits<scalar_t>::lowest(), src.options());
auto sum = at::zeros({n_groups, n_heads}, src.options());
const auto groups_ptr = groups.data_ptr<int64_t>();
const auto src_base_ptr = src.data_ptr<scalar_t>();
auto out_base_ptr = out.data_ptr<scalar_t>();
auto max_base_ptr = max.data_ptr<scalar_t>();
auto sum_base_ptr = sum.data_ptr<scalar_t>();
const auto global_dim_size = src.size(dim);
const auto new_groups =
create_per_thread_groups(groups_ptr, n_groups, global_dim_size);

at::parallel_for(
0, new_groups.size() - 1, 1, [&](int64_t beg, int64_t end) {
// each thread may cover several groups
for (auto group_id = new_groups[beg]; group_id < new_groups[end];
++group_id) {
const auto dim_beg = groups_ptr[group_id];
const auto dim_end = groups_ptr[group_id + 1];
const auto local_dim_size = dim_end - dim_beg;
const auto dim_stride = src.stride(dim);
// outer_range says how many data jumps we need to make
const auto outer_range = [&src, dim]() {
int64_t range = 1;
for (int64_t i = 0; i < dim; ++i)
range *= src.size(i);
return range;
}();
// inner_range says how many contiguous elements we can visit
const auto inner_range = local_dim_size * dim_stride;
const auto inout_offset = dim_beg * dim_stride;
const auto aux_offset = group_id * n_heads;

const auto src_ptr = src_base_ptr + inout_offset;
auto out_ptr = out_base_ptr + inout_offset;
auto max_ptr = max_base_ptr + aux_offset;
auto sum_ptr = sum_base_ptr + aux_offset;

const auto indices = precompute_data_access_patterns(
outer_range, inner_range, global_dim_size, dim_stride);
const auto& data_ids = indices.first;
const auto& aux_ids = indices.second;

if (local_dim_size == 1) {
for (int64_t i = 0; i < outer_range; ++i) {
const auto k = i * inner_range;
const auto data_id = data_ids[k];
std::fill(out_ptr + data_id,
out_ptr + data_id + inner_range,
static_cast<scalar_t>(1.0));
}
} else {
// calculate max
for (int64_t i = 0; i < outer_range; ++i) {
for (int64_t j = 0; j < inner_range; ++j) {
const auto k = i * inner_range + j;
const auto data_id = data_ids[k];
const auto aux_id = aux_ids[k];
max_ptr[aux_id] =
std::max(max_ptr[aux_id], src_ptr[data_id]);
}
}

// calculate sum
for (int64_t i = 0; i < outer_range; ++i) {
for (int64_t j = 0; j < inner_range; ++j) {
const auto k = i * inner_range + j;
const auto data_id = data_ids[k];
const auto aux_id = aux_ids[k];
const auto value =
std::exp(src_ptr[data_id] - max_ptr[aux_id]);
sum_ptr[aux_id] += value;
out_ptr[data_id] = value;
}
}

// unify
for (int64_t i = 0; i < outer_range; ++i) {
for (int64_t j = 0; j < inner_range; ++j) {
const auto k = i * inner_range + j;
const auto data_id = data_ids[k];
const auto aux_id = aux_ids[k];
out_ptr[data_id] /= sum_ptr[aux_id];
}
}
}
}
});
});

return out;
}

at::Tensor softmax_csr_backward_kernel_impl(const at::Tensor& out,
const at::Tensor& out_grad,
const at::Tensor& groups,
const int64_t dim) {
auto in_grad = at::zeros_like(out);

AT_DISPATCH_FLOATING_TYPES(
out.scalar_type(), "softmax_csr_backward_kernel_impl", [&] {
const auto n_groups = groups.size(0) - 1;
const auto n_heads = out.numel() / out.size(dim);
// Match the dtype/device of `out` so the `data_ptr<scalar_t>()` accesses
// below stay valid for both float and double inputs.
auto sum = at::zeros({n_groups, n_heads}, out.options());
const auto groups_ptr = groups.data_ptr<int64_t>();
const auto out_base_ptr = out.data_ptr<scalar_t>();
const auto out_grad_base_ptr = out_grad.data_ptr<scalar_t>();
auto in_grad_base_ptr = in_grad.data_ptr<scalar_t>();
auto sum_base_ptr = sum.data_ptr<scalar_t>();
const auto global_dim_size = out.size(dim);
const auto new_groups =
create_per_thread_groups(groups_ptr, n_groups, global_dim_size);

at::parallel_for(
0, new_groups.size() - 1, 1, [&](int64_t beg, int64_t end) {
for (auto group_id = new_groups[beg]; group_id < new_groups[end];
++group_id) {
const auto dim_beg = groups_ptr[group_id];
const auto dim_end = groups_ptr[group_id + 1];
const auto local_dim_size = dim_end - dim_beg;
const auto dim_stride = out.stride(dim);
// outer_range says how many data jumps we need to make
const auto outer_range = [&out, dim]() {
int64_t range = 1;
for (int64_t i = 0; i < dim; ++i)
range *= out.size(i);
return range;
}();
// inner_range says how many contiguous elements we can visit
const auto inner_range = local_dim_size * dim_stride;
const auto inout_offset = dim_beg * dim_stride;
const auto sum_offset = group_id * n_heads;

const auto out_ptr = out_base_ptr + inout_offset;
const auto out_grad_ptr = out_grad_base_ptr + inout_offset;
auto in_grad_ptr = in_grad_base_ptr + inout_offset;
auto sum_ptr = sum_base_ptr + sum_offset;

const auto indices = precompute_data_access_patterns(
outer_range, inner_range, global_dim_size, dim_stride);
const auto& data_ids = indices.first;
const auto& aux_ids = indices.second;

// calculate sum of out * out_grad
for (int64_t i = 0; i < outer_range; ++i) {
for (int64_t j = 0; j < inner_range; ++j) {
const auto k = i * inner_range + j;
const auto data_id = data_ids[k];
const auto aux_id = aux_ids[k];
sum_ptr[aux_id] += out_ptr[data_id] * out_grad_ptr[data_id];
}
}

// calculate out * (out_grad - sum)
for (int64_t i = 0; i < outer_range; ++i) {
for (int64_t j = 0; j < inner_range; ++j) {
const auto k = i * inner_range + j;
const auto data_id = data_ids[k];
const auto aux_id = aux_ids[k];
in_grad_ptr[data_id] =
out_ptr[data_id] *
(out_grad_ptr[data_id] - sum_ptr[aux_id]);
}
}
}
});
});

return in_grad;
}

at::Tensor softmax_csr_forward_kernel(const at::Tensor& src,
const at::Tensor& ptr,
const int64_t dim) {
return softmax_csr_forward_kernel_impl(src, ptr, dim);
}

at::Tensor softmax_csr_backward_kernel(const at::Tensor& out,
const at::Tensor& out_grad,
const at::Tensor& ptr,
const int64_t dim) {
return softmax_csr_backward_kernel_impl(out, out_grad, ptr, dim);
}

} // namespace

TORCH_LIBRARY_IMPL(pyg, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_forward"),
TORCH_FN(softmax_csr_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_backward"),
TORCH_FN(softmax_csr_backward_kernel));
}

} // namespace ops
} // namespace pyg
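
For readers of the kernel above: the forward pass performs three per-group passes (max, exp-and-sum, normalize), and the backward pass computes `in_grad = out * (out_grad - sum(out * out_grad))` per group and head. Below is a plain-PyTorch sketch of this per-group math for the 2-D, `dim = 0` case only; it is an illustration, not the implementation, since the kernel additionally balances groups across threads via `create_per_thread_groups` and precomputes index maps to support arbitrary `dim`.

```python
import torch


def softmax_csr_forward_reference(src, ptr):
    out = torch.empty_like(src)
    for g in range(ptr.numel() - 1):
        beg, end = int(ptr[g]), int(ptr[g + 1])
        group = src[beg:end]
        group_max = group.max(dim=0, keepdim=True).values  # pass 1: per-head max
        e = (group - group_max).exp()                       # pass 2: exp ...
        out[beg:end] = e / e.sum(dim=0, keepdim=True)       # ... and pass 3: normalize by per-head sum
    return out


def softmax_csr_backward_reference(out, out_grad, ptr):
    in_grad = torch.empty_like(out)
    for g in range(ptr.numel() - 1):
        beg, end = int(ptr[g]), int(ptr[g + 1])
        o, go = out[beg:end], out_grad[beg:end]
        s = (o * go).sum(dim=0, keepdim=True)  # per-head sum of out * out_grad
        in_grad[beg:end] = o * (go - s)        # out * (out_grad - sum)
    return in_grad
```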
58 changes: 58 additions & 0 deletions pyg_lib/csrc/ops/softmax.cpp
@@ -0,0 +1,58 @@
#include "softmax.h"

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

namespace pyg {
namespace ops {

// Performs softmax operations for each group.
PYG_API at::Tensor softmax_csr_forward(const at::Tensor& src,
const at::Tensor& ptr,
const int64_t dim) {
at::TensorArg src_arg{src, "src", 0};
at::TensorArg ptr_arg{ptr, "ptr", 1};
at::CheckedFrom c{"softmax_forward"};

at::checkAllDefined(c, {src_arg, ptr_arg});
at::checkContiguous(c, src_arg);
at::checkContiguous(c, ptr_arg);

static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::softmax_csr_forward", "")
.typed<decltype(softmax_csr_forward)>();
return op.call(src, ptr, dim);
}

// Computes gradient for grouped softmax operation.
PYG_API at::Tensor softmax_csr_backward(const at::Tensor& out,
const at::Tensor& out_grad,
const at::Tensor& ptr,
const int64_t dim) {
at::TensorArg out_arg{out, "out", 0};
at::TensorArg out_grad_arg{out_grad, "out_grad", 1};
at::TensorArg ptr_arg{ptr, "ptr", 2};
at::CheckedFrom c{"softmax_backward"};

at::checkAllDefined(c, {out_arg, out_grad_arg, ptr_arg});
at::checkContiguous(c, out_arg);
at::checkContiguous(c, out_grad_arg);
at::checkContiguous(c, ptr_arg);

static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::softmax_csr_backward", "")
.typed<decltype(softmax_csr_backward)>();
return op.call(out, out_grad, ptr, dim);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(
TORCH_SELECTIVE_SCHEMA("pyg::softmax_csr_forward(Tensor src, Tensor ptr, "
"int dim=0) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::softmax_csr_backward(Tensor out, Tensor out_grad, "
"Tensor ptr, int dim=0) -> Tensor"));
}

} // namespace ops
} // namespace pyg
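
The operators registered above are reachable through the dispatcher as `torch.ops.pyg.softmax_csr_forward` and `torch.ops.pyg.softmax_csr_backward`. The following is only a minimal sketch of how the two could be tied together on the Python side with `torch.autograd.Function` (the `_SoftmaxCSR` name is hypothetical; `pyg_lib.ops.softmax_csr` ships the actual wrapper):

```python
import torch
import pyg_lib  # noqa: F401  -- importing pyg_lib loads the compiled extension that registers the `pyg` ops


class _SoftmaxCSR(torch.autograd.Function):
    @staticmethod
    def forward(ctx, src, ptr, dim=0):
        out = torch.ops.pyg.softmax_csr_forward(src, ptr, dim)
        ctx.save_for_backward(out, ptr)
        ctx.dim = dim
        return out

    @staticmethod
    def backward(ctx, out_grad):
        out, ptr = ctx.saved_tensors
        # pyg::softmax_csr_backward(Tensor out, Tensor out_grad, Tensor ptr, int dim)
        in_grad = torch.ops.pyg.softmax_csr_backward(out, out_grad.contiguous(),
                                                     ptr, ctx.dim)
        return in_grad, None, None
```

Usage would then look like `_SoftmaxCSR.apply(src, ptr)`, with gradients flowing to `src` only.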