Redefine FBGEMM targets with gpu_cpp_library [18/N] (#3190)
Summary:
Pull Request resolved: #3190

X-link: facebookresearch/FBGEMM#285

- Redefine sparse ops targets using `gpu_cpp_library`

Reviewed By: spcyppt

Differential Revision: D63424455

fbshipit-source-id: 928d92da85f5ae69c8d1f0ee2479ada66bf26380
q10 authored and facebook-github-bot committed Sep 30, 2024
1 parent 93dcc07 commit 4d16892
Showing 4 changed files with 181 additions and 137 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/FbgemmGpu.cmake
@@ -471,6 +471,7 @@ set(fbgemm_gpu_sources_static_cpu
src/layout_transform_ops/layout_transform_ops_cpu.cpp
src/quantize_ops/quantize_ops_cpu.cpp
src/quantize_ops/quantize_ops_meta.cpp
src/sparse_ops/sparse_async_cumsum.cpp
src/sparse_ops/sparse_ops_cpu.cpp
src/sparse_ops/sparse_ops_meta.cpp
src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp
28 changes: 28 additions & 0 deletions fbgemm_gpu/src/sparse_ops/common.h
@@ -0,0 +1,28 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>

using Tensor = at::Tensor;

namespace fbgemm_gpu {

namespace {
inline Tensor native_empty_like(const Tensor& self) {
return at::native::empty_like(
self,
c10::optTypeMetaToScalarType(self.options().dtype_opt()),
self.options().layout_opt(),
self.options().device_opt(),
self.options().pinned_memory_opt(),
c10::nullopt);
}

} // namespace

}; // namespace fbgemm_gpu
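Note: `native_empty_like` allocates an uninitialized tensor matching the input's dtype, layout, device, and pinned-memory settings, calling `at::native::empty_like` directly with no memory-format override. A minimal standalone sketch of the same allocation through the public ATen API (illustration only, not part of this commit; assumes libtorch is available):

#include <ATen/ATen.h>
#include <iostream>

// Allocate an uninitialized tensor shaped and typed like `self`.
// The cumsum kernels below write every element before reading it,
// so skipping zero-initialization is safe.
at::Tensor output_like(const at::Tensor& self) {
  return at::empty_like(self);
}

int main() {
  const auto x = at::arange(6, at::kLong).reshape({2, 3});
  auto y = output_like(x);
  y.zero_();  // contents are undefined until written
  std::cout << y << std::endl;
}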
151 changes: 151 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_async_cumsum.cpp
@@ -0,0 +1,151 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/Parallel.h>

#include "common.h"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/utils/ops_utils.h"
#include "fbgemm_gpu/utils/tensor_utils.h"

using Tensor = at::Tensor;

namespace fbgemm_gpu {

// 1D exclusive scan: output[i] = input[i-1] + input[i-2] + ... + input[0]
// Used as a helper to several functions below.
template <class T, class U>
U exclusive_scan_ptrs_cpu(
const int64_t N,
const T* const input,
U* const output) {
U cumsum = 0;
for (const auto i : c10::irange(N)) {
output[i] = cumsum;
cumsum += input[i];
}
return cumsum;
}

void asynchronous_exclusive_cumsum_cpu_out(Tensor& t_out, const Tensor& t_in) {
TENSOR_ON_CPU(t_in);
TENSOR_ON_CPU(t_out);

const auto t_in_contig = t_in.expect_contiguous();
at::native::resize_(t_out, t_in_contig->sizes(), c10::nullopt);

FBGEMM_DISPATCH_ALL_TYPES(
t_in_contig->scalar_type(),
"asynchronous_exclusive_cumsum_cpu_kernel",
[&] {
exclusive_scan_ptrs_cpu(
t_in_contig->numel(),
t_in_contig->data_ptr<scalar_t>(),
t_out.data_ptr<scalar_t>());
});
}

Tensor asynchronous_exclusive_cumsum_cpu(const Tensor& t_in) {
TENSOR_ON_CPU(t_in);

const auto t_in_contig = t_in.expect_contiguous();
auto output = native_empty_like(*t_in_contig);
asynchronous_exclusive_cumsum_cpu_out(output, *t_in_contig);
return output;
}

Tensor asynchronous_inclusive_cumsum_cpu(const Tensor& t_in) {
TENSOR_ON_CPU(t_in);

const auto t_in_contig = t_in.expect_contiguous();
auto output = native_empty_like(*t_in_contig);
FBGEMM_DISPATCH_ALL_TYPES(
t_in_contig->scalar_type(),
"asynchronous_inclusive_cumsum_cpu_kernel",
[&] {
scalar_t cumsum = 0;
const auto* input_ptr = t_in_contig->data_ptr<scalar_t>();
const auto N = t_in_contig->numel();
auto* output_ptr = output.data_ptr<scalar_t>();

for (const auto i : c10::irange(N)) {
cumsum += input_ptr[i];
output_ptr[i] = cumsum;
}
});
return output;
}

Tensor asynchronous_complete_cumsum_cpu_out(Tensor& t_out, const Tensor& t_in) {
TENSOR_ON_CPU(t_in);
TENSOR_ON_CPU(t_out);
const auto num_dims = t_in.dim();
TORCH_CHECK(num_dims == 1 || num_dims == 2);
const auto t_in_contig = t_in.expect_contiguous();
const auto t_out_contig = t_out.expect_contiguous();

FBGEMM_DISPATCH_ALL_TYPES(
t_in_contig->scalar_type(),
"asynchronous_complete_cumsum_cpu_kernel",
[&] {
if (num_dims == 1) {
const auto N = t_in_contig->numel();
t_out.data_ptr<scalar_t>()[N] = exclusive_scan_ptrs_cpu(
N, t_in_contig->data_ptr<scalar_t>(), t_out.data_ptr<scalar_t>());
} else {
const auto num_vecs = t_in_contig->size(0);
const auto N = t_in_contig->size(1);
at::parallel_for(0, num_vecs, 1, [&](int64_t start, int64_t end) {
for (const auto i : c10::irange(start, end)) {
scalar_t* out_ptr = t_out.data_ptr<scalar_t>() + i * (N + 1);
out_ptr[N] = exclusive_scan_ptrs_cpu(
N, t_in_contig->data_ptr<scalar_t>() + i * N, out_ptr);
}
});
}
});
return t_out;
}

Tensor asynchronous_complete_cumsum_cpu(const Tensor& t_in) {
const auto num_dims = t_in.dim();
TORCH_CHECK(num_dims == 1 || num_dims == 2);
auto output = num_dims == 1
? at::empty({t_in.numel() + 1}, t_in.options())
: at::empty({t_in.size(0), t_in.size(1) + 1}, t_in.options());

return asynchronous_complete_cumsum_cpu_out(output, t_in);
}

} // namespace fbgemm_gpu

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"asynchronous_exclusive_cumsum(Tensor t_in) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"asynchronous_inclusive_cumsum(Tensor t_in) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"asynchronous_complete_cumsum(Tensor t_in) -> Tensor",
{PT2_COMPLIANT_TAG});
}

TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
DISPATCH_TO_CPU(
"asynchronous_exclusive_cumsum",
fbgemm_gpu::asynchronous_exclusive_cumsum_cpu);
DISPATCH_TO_CPU(
"asynchronous_inclusive_cumsum",
fbgemm_gpu::asynchronous_inclusive_cumsum_cpu);
DISPATCH_TO_CPU(
"asynchronous_complete_cumsum",
fbgemm_gpu::asynchronous_complete_cumsum_cpu);
}
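For reference, a minimal standalone sketch (plain C++, no ATen; sample values are illustrative only) of the scan semantics implemented above: the exclusive scan writes the running sum before adding the current element, the inclusive scan writes it after, and the complete variant appends the grand total, so a length-N input yields N+1 offsets.

#include <cstdint>
#include <iostream>
#include <vector>

// output[i] = input[0] + ... + input[i-1]; returns the grand total,
// which the "complete" variant stores at output[N].
int64_t exclusive_scan(const std::vector<int64_t>& input,
                       std::vector<int64_t>& output) {
  int64_t cumsum = 0;
  for (std::size_t i = 0; i < input.size(); ++i) {
    output[i] = cumsum;
    cumsum += input[i];
  }
  return cumsum;
}

int main() {
  const std::vector<int64_t> in{3, 1, 4, 1, 5};

  std::vector<int64_t> excl(in.size());
  exclusive_scan(in, excl);                            // excl     == {0, 3, 4, 8, 9}

  std::vector<int64_t> complete(in.size() + 1);
  complete[in.size()] = exclusive_scan(in, complete);  // complete == {0, 3, 4, 8, 9, 14}

  std::vector<int64_t> incl(in.size());
  int64_t cumsum = 0;
  for (std::size_t i = 0; i < in.size(); ++i) {        // incl     == {3, 4, 8, 9, 14}
    cumsum += in[i];
    incl[i] = cumsum;
  }

  std::cout << excl.back() << ' ' << complete.back() << ' ' << incl.back() << '\n';
}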
138 changes: 1 addition & 137 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -20,6 +20,7 @@
#include <torch/csrc/autograd/custom_function.h>
#include <torch/library.h>

#include "common.h"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/utils/ops_utils.h"
@@ -128,16 +129,6 @@ Tensor pack_segments_autograd(
return PackSegments::apply(t_in, lengths, max_length)[0];
}

Tensor native_empty_like(const Tensor& self) {
return at::native::empty_like(
self,
c10::optTypeMetaToScalarType(self.options().dtype_opt()),
self.options().layout_opt(),
self.options().device_opt(),
self.options().pinned_memory_opt(),
c10::nullopt);
}

template <typename T>
void prefix_sum(const int length, const T* const array, T* const presum) {
presum[0] = 0;
@@ -1317,115 +1308,6 @@ bucketize_sparse_features_cpu(
return {new_lengths, new_indices, new_weights, new_pos};
}

// 1D exclusive scan: output[i] = input[i-1] + input[i-2] + ... + input[0]
// Used as a helper to several functions below.
template <class T, class U>
U exclusive_scan_ptrs_cpu(
const int64_t N,
const T* const input,
U* const output) {
U cumsum = 0;
for (const auto i : c10::irange(N)) {
output[i] = cumsum;
cumsum += input[i];
}
return cumsum;
}

void asynchronous_exclusive_cumsum_cpu_out(
at::Tensor& t_out,
const Tensor& t_in) {
TENSOR_ON_CPU(t_in);
TENSOR_ON_CPU(t_out);

const auto t_in_contig = t_in.expect_contiguous();
at::native::resize_(t_out, t_in_contig->sizes(), c10::nullopt);

FBGEMM_DISPATCH_ALL_TYPES(
t_in_contig->scalar_type(),
"asynchronous_exclusive_cumsum_cpu_kernel",
[&] {
exclusive_scan_ptrs_cpu(
t_in_contig->numel(),
t_in_contig->data_ptr<scalar_t>(),
t_out.data_ptr<scalar_t>());
});
}

Tensor asynchronous_exclusive_cumsum_cpu(const Tensor& t_in) {
TENSOR_ON_CPU(t_in);

const auto t_in_contig = t_in.expect_contiguous();
auto output = native_empty_like(*t_in_contig);
asynchronous_exclusive_cumsum_cpu_out(output, *t_in_contig);
return output;
}

Tensor asynchronous_inclusive_cumsum_cpu(const Tensor& t_in) {
TENSOR_ON_CPU(t_in);

const auto t_in_contig = t_in.expect_contiguous();
auto output = native_empty_like(*t_in_contig);
FBGEMM_DISPATCH_ALL_TYPES(
t_in_contig->scalar_type(),
"asynchronous_inclusive_cumsum_cpu_kernel",
[&] {
scalar_t cumsum = 0;
const auto* input_ptr = t_in_contig->data_ptr<scalar_t>();
const auto N = t_in_contig->numel();
auto* output_ptr = output.data_ptr<scalar_t>();

for (const auto i : c10::irange(N)) {
cumsum += input_ptr[i];
output_ptr[i] = cumsum;
}
});
return output;
}

at::Tensor asynchronous_complete_cumsum_cpu_out(
at::Tensor& t_out,
const at::Tensor& t_in) {
TENSOR_ON_CPU(t_in);
TENSOR_ON_CPU(t_out);
const auto num_dims = t_in.dim();
TORCH_CHECK(num_dims == 1 || num_dims == 2);
const auto t_in_contig = t_in.expect_contiguous();
const auto t_out_contig = t_out.expect_contiguous();

FBGEMM_DISPATCH_ALL_TYPES(
t_in_contig->scalar_type(),
"asynchronous_complete_cumsum_cpu_kernel",
[&] {
if (num_dims == 1) {
const auto N = t_in_contig->numel();
t_out.data_ptr<scalar_t>()[N] = exclusive_scan_ptrs_cpu(
N, t_in_contig->data_ptr<scalar_t>(), t_out.data_ptr<scalar_t>());
} else {
const auto num_vecs = t_in_contig->size(0);
const auto N = t_in_contig->size(1);
at::parallel_for(0, num_vecs, 1, [&](int64_t start, int64_t end) {
for (const auto i : c10::irange(start, end)) {
scalar_t* out_ptr = t_out.data_ptr<scalar_t>() + i * (N + 1);
out_ptr[N] = exclusive_scan_ptrs_cpu(
N, t_in_contig->data_ptr<scalar_t>() + i * N, out_ptr);
}
});
}
});
return t_out;
}

Tensor asynchronous_complete_cumsum_cpu(const Tensor& t_in) {
const auto num_dims = t_in.dim();
TORCH_CHECK(num_dims == 1 || num_dims == 2);
auto output = num_dims == 1
? at::empty({t_in.numel() + 1}, t_in.options())
: at::empty({t_in.size(0), t_in.size(1) + 1}, t_in.options());

return asynchronous_complete_cumsum_cpu_out(output, t_in);
}

template <typename index_t, typename scalar_t>
void reorder_batched_ad_lengths_(
const Tensor& cat_ad_lengths,
@@ -3100,15 +2982,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"block_bucketize_sparse_features_inference(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None, bool return_bucket_mapping=False, bool keep_orig_idx=False) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?)");
m.def(
"bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, SymInt my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?)");
m.def(
"asynchronous_exclusive_cumsum(Tensor t_in) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"asynchronous_inclusive_cumsum(Tensor t_in) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"asynchronous_complete_cumsum(Tensor t_in) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"reorder_batched_sequence_embeddings(Tensor cat_sequence_embeddings_offsets, Tensor cat_sequence_embeddings, Tensor reordered_cat_sequence_embeddings_offsets, Tensor batch_offsets, SymInt num_items_in_batch) -> Tensor");
m.def(
@@ -3214,15 +3087,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
fbgemm_gpu::block_bucketize_sparse_features_inference_cpu);
DISPATCH_TO_CPU(
"bucketize_sparse_features", fbgemm_gpu::bucketize_sparse_features_cpu);
DISPATCH_TO_CPU(
"asynchronous_exclusive_cumsum",
fbgemm_gpu::asynchronous_exclusive_cumsum_cpu);
DISPATCH_TO_CPU(
"asynchronous_inclusive_cumsum",
fbgemm_gpu::asynchronous_inclusive_cumsum_cpu);
DISPATCH_TO_CPU(
"asynchronous_complete_cumsum",
fbgemm_gpu::asynchronous_complete_cumsum_cpu);
DISPATCH_TO_CPU(
"reorder_batched_ad_lengths", fbgemm_gpu::reorder_batched_ad_lengths_cpu);
DISPATCH_TO_CPU(

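After this refactor the three cumsum ops are registered from the new sparse_async_cumsum.cpp translation unit instead of sparse_ops_cpu.cpp, but they keep their `fbgemm::` names and CPU dispatch registrations. A minimal sketch of invoking one of them through the PyTorch dispatcher from C++ (assumes libtorch is available and the fbgemm_gpu library has been linked or loaded so the registrations above have run):

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/torch.h>
#include <iostream>

int main() {
  // Look up the schema registered by TORCH_LIBRARY_FRAGMENT(fbgemm, ...).
  const auto op = c10::Dispatcher::singleton()
                      .findSchemaOrThrow("fbgemm::asynchronous_exclusive_cumsum", "")
                      .typed<at::Tensor(const at::Tensor&)>();

  const auto x = torch::tensor({3, 1, 4, 1, 5}, torch::kLong);
  const auto y = op.call(x);  // expected: [0, 3, 4, 8, 9]
  std::cout << y << std::endl;
}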