Add support for int64_t indices and offsets in TBE inference [7C/N] (#3215)

Summary:
Pull Request resolved: #3215

X-link: facebookresearch/FBGEMM#312

- Default `index_remapping` in `pruned_array_lookup` to `int64_t`, since the remapping tables are set up before `indices` and `offsets` are passed in (see the dispatch sketch below)

Reviewed By: spcyppt

Differential Revision: D63778645

fbshipit-source-id: 270834722e8b7b7b316e5f1b3d29763601b2ae67
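
Editor's aside (not part of the commit): the mechanism this stack leans on is AT_DISPATCH_INDEX_TYPES, which switches on a tensor's scalar type (int32 or int64) and exposes it as `index_t` inside a lambda, so one templated body serves both index widths. A minimal sketch follows; the helper `sum_indices` is hypothetical and exists only for illustration.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Illustration only: sum an index tensor whose dtype may be int32 or int64.
int64_t sum_indices(const at::Tensor& indices) {
  int64_t total = 0;
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "sum_indices", [&] {
    // Inside this lambda, index_t is int32_t or int64_t, matching the tensor's dtype.
    const auto* data = indices.data_ptr<index_t>();
    for (int64_t i = 0; i < indices.numel(); ++i) {
      total += static_cast<int64_t>(data[i]);
    }
  });
  return total;
}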
q10 authored and facebook-github-bot committed Oct 5, 2024
1 parent 42dca08 commit 9408072
Showing 7 changed files with 233 additions and 166 deletions.
241 changes: 133 additions & 108 deletions fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
@@ -64,55 +64,70 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu(
   TENSOR_ON_CPU(offsets);
   TENSOR_ON_CPU(hash_table);
   TENSOR_ON_CPU(hash_table_offsets);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
-  int32_t T = hash_table_offsets.size(0) - 1;
-  int32_t B = (offsets.size(0) - 1) / T;
+  const int32_t T = hash_table_offsets.size(0) - 1;
+  const int32_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);
-  const auto* indices_acc = indices.data_ptr<int32_t>();
-  const auto* dense_indices_acc = dense_indices.data_ptr<int32_t>();
-
-  const auto* offsets_acc = offsets.data_ptr<int32_t>();
-  auto hash_table_acc = hash_table.accessor<int32_t, 2>();
-  const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
-  for (const auto t : c10::irange(T)) {
-    int64_t table_start = hash_table_offsets_acc[t];
-    int64_t table_end = hash_table_offsets_acc[t + 1];
-    if (table_start == table_end) {
-      continue;
-    }
-    int64_t capacity = table_end - table_start;
-    for (const auto b : c10::irange(B)) {
-      int32_t indices_start = offsets_acc[t * B + b];
-      int32_t indices_end = offsets_acc[t * B + b + 1];
-      int32_t L = indices_end - indices_start;
-      for (const auto l : c10::irange(L)) {
-        int32_t idx = indices_acc[indices_start + l];
-        int32_t dense_idx = dense_indices_acc[indices_start + l];
-        if (dense_idx == -1) {
-          // -1 means this row has been pruned, do not insert it.
-          continue;
-        }
-
-        uint32_t slot = pruned_hash_function(static_cast<uint32_t>(idx)) % capacity;
-        while (true) {
-          int32_t slot_sparse_idx = hash_table_acc[table_start + static_cast<int64_t>(slot)][0];
-          // empty slot
-          if (slot_sparse_idx == -1) {
-            hash_table_acc[table_start + static_cast<int64_t>(slot)][0] = idx;
-            hash_table_acc[table_start + static_cast<int64_t>(slot)][1] = dense_idx;
-            break;
-          }
-          // already exists (shouldn't happen in practice)
-          if (slot_sparse_idx == idx) {
-            hash_table_acc[table_start + static_cast<int64_t>(slot)][1] = dense_idx;
-            break;
-          }
-          // linear probe
-          slot = (slot + 1) % capacity;
-        }
-      }
-    }
-  }
+
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
+    using uidx_t =
+        std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
+
+    const auto* indices_acc = indices.data_ptr<index_t>();
+    const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+    const auto* offsets_acc = offsets.data_ptr<index_t>();
+
+    auto hash_table_acc = hash_table.accessor<int64_t, 2>();
+    const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
+
+    for (const auto t : c10::irange(T)) {
+      const auto table_start = hash_table_offsets_acc[t];
+      const auto table_end = hash_table_offsets_acc[t + 1];
+      if (table_start == table_end) {
+        continue;
+      }
+      const auto capacity = table_end - table_start;
+
+      for (const auto b : c10::irange(B)) {
+        const auto indices_start = offsets_acc[t * B + b];
+        const auto indices_end = offsets_acc[t * B + b + 1];
+        const auto L = indices_end - indices_start;
+
+        for (const auto l : c10::irange(L)) {
+          const auto idx = indices_acc[indices_start + l];
+          const auto dense_idx = dense_indices_acc[indices_start + l];
+          if (dense_idx == -1) {
+            // -1 means this row has been pruned, do not insert it.
+            continue;
+          }
+
+          auto slot = pruned_hash_function(static_cast<uidx_t>(idx)) % capacity;
+          while (true) {
+            const auto ht_idx = table_start + static_cast<int64_t>(slot);
+            const auto slot_sparse_idx = hash_table_acc[ht_idx][0];
+
+            // Empty slot
+            if (slot_sparse_idx == -1) {
+              hash_table_acc[ht_idx][0] = idx;
+              hash_table_acc[ht_idx][1] = dense_idx;
+              break;
+            }
+
+            // Already exists (shouldn't happen in practice)
+            if (slot_sparse_idx == idx) {
+              hash_table_acc[ht_idx][1] = dense_idx;
+              break;
+            }
+
+            // Linear probe
+            slot = (slot + 1) % capacity;
+          }
+        }
+      }
+    }
+  });
 
   return;
 }
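
Editor's aside (not from the diff): the insert path above is open addressing with linear probing, where column 0 of each table row holds the sparse index, column 1 holds the remapped dense index, and a sparse index of -1 marks an empty slot. A stripped-down sketch of the same scheme on a std::vector, with `toy_hash` standing in for pruned_hash_function (whose definition is outside this hunk):

#include <cstdint>
#include <utility>
#include <vector>

// Assumed stand-in mixer; the real pruned_hash_function may differ.
inline uint64_t toy_hash(uint64_t x) {
  x ^= x >> 33;
  x *= 0xff51afd7ed558ccdULL;
  x ^= x >> 33;
  return x;
}

// Each slot stores {sparse_idx, dense_idx}; sparse_idx == -1 means empty.
void insert_with_linear_probe(
    std::vector<std::pair<int64_t, int64_t>>& table,
    int64_t idx,
    int64_t dense_idx) {
  const auto capacity = static_cast<uint64_t>(table.size());
  auto slot = toy_hash(static_cast<uint64_t>(idx)) % capacity;
  while (true) {
    auto& entry = table[slot];
    if (entry.first == -1 || entry.first == idx) {
      entry = {idx, dense_idx};    // claim the empty slot, or refresh an existing key
      return;
    }
    slot = (slot + 1) % capacity;  // linear probe to the next slot
  }
}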

@@ -414,65 +429,71 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
   TENSOR_ON_CPU(offsets);
   TENSOR_ON_CPU(hash_table);
   TENSOR_ON_CPU(hash_table_offsets);
-  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
   int32_t T = hash_table_offsets.size(0) - 1;
   int32_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);
 
   auto dense_indices = empty_like(indices);
 
-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_{{ wdesc }}_cpu", [&] {
-    using hash_t =
-        std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
-
-    const auto* indices_acc = indices.data_ptr<index_t>();
-    auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
-
-    const auto* offsets_acc = offsets.data_ptr<index_t>();
-    const auto hash_table_acc = hash_table.accessor<index_t, 2>();
-    const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
-
-    for (const auto t : c10::irange(T)) {
-      const auto table_start = hash_table_offsets_acc[t];
-      const auto table_end = hash_table_offsets_acc[t + 1];
-      const auto capacity = table_end - table_start;
-
-      for (const auto b : c10::irange(B)) {
-        const auto indices_start = offsets_acc[t * B + b];
-        const auto indices_end = offsets_acc[t * B + b + 1];
-        const auto L = indices_end - indices_start;
-
-        if (table_start == table_end) {
-          for (const auto l : c10::irange(L)) {
-            dense_indices_acc[indices_start + l] = indices_acc[indices_start + l];
-          }
-
-        } else {
-          for (const auto l : c10::irange(L)) {
-            const auto idx = indices_acc[indices_start + l];
-            auto slot = pruned_hash_function(static_cast<hash_t>(idx)) % capacity;
-
-            while (true) {
-              const auto slot_sparse_idx = hash_table_acc[table_start + static_cast<int64_t>(slot)][0];
-
-              // empty slot
-              if (slot_sparse_idx == -1) {
-                dense_indices_acc[indices_start + l] = -1;
-                break;
-              }
-              // already exists
-              if (slot_sparse_idx == idx) {
-                dense_indices_acc[indices_start + l] = hash_table_acc[table_start + static_cast<int64_t>(slot)][1];
-                break;
-              }
-              // linear probe
-              slot = (slot + 1) % capacity;
-            }
-          }
-        }
-      }
-    }
-  });
+  AT_DISPATCH_INDEX_TYPES(hash_table.scalar_type(), "pruned_hashmap_lookup_{{ wdesc }}_cpu_0", [&] {
+    using hash_t = index_t;
+
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_{{ wdesc }}_cpu_1", [&] {
+      using utdx_t =
+          std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
+
+      const auto* indices_acc = indices.data_ptr<index_t>();
+      auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+      const auto* offsets_acc = offsets.data_ptr<index_t>();
+
+      const auto hash_table_acc = hash_table.accessor<hash_t, 2>();
+      const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
+
+      for (const auto t : c10::irange(T)) {
+        const auto table_start = hash_table_offsets_acc[t];
+        const auto table_end = hash_table_offsets_acc[t + 1];
+        const auto capacity = table_end - table_start;
+
+        for (const auto b : c10::irange(B)) {
+          const auto indices_start = offsets_acc[t * B + b];
+          const auto indices_end = offsets_acc[t * B + b + 1];
+          const auto L = indices_end - indices_start;
+
+          if (table_start == table_end) {
+            for (const auto l : c10::irange(L)) {
+              dense_indices_acc[indices_start + l] = indices_acc[indices_start + l];
+            }
+
+          } else {
+            for (const auto l : c10::irange(L)) {
+              const auto idx = indices_acc[indices_start + l];
+              auto slot = pruned_hash_function(static_cast<utdx_t>(idx)) % capacity;
+
+              while (true) {
+                const auto ht_idx = table_start + static_cast<int64_t>(slot);
+                const auto slot_sparse_idx = hash_table_acc[ht_idx][0];
+
+                // Empty slot
+                if (slot_sparse_idx == -1) {
+                  dense_indices_acc[indices_start + l] = -1;
+                  break;
+                }
+                // Already exists
+                if (slot_sparse_idx == idx) {
+                  dense_indices_acc[indices_start + l] = static_cast<index_t>(hash_table_acc[ht_idx][1]);
+                  break;
+                }
+
+                // Linear probe
+                slot = (slot + 1) % capacity;
+              }
+            }
+          }
+        }
+      }
+    });
+  });
 
   return dense_indices;
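
Editor's aside (illustration only): the lookup now nests two AT_DISPATCH_INDEX_TYPES calls because the hash table's element type and the indices' element type can differ after this stack. The outer dispatch binds the table's type and aliases it as `hash_t` before the inner dispatch re-binds `index_t` to the indices' type. A minimal sketch of that pattern; `gather_rows` is a hypothetical helper reading from a 1-D table.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Illustration only: table may be int64 while indices are int32, or vice versa.
at::Tensor gather_rows(const at::Tensor& table, const at::Tensor& indices) {
  auto out = at::empty_like(indices);
  AT_DISPATCH_INDEX_TYPES(table.scalar_type(), "gather_rows_0", [&] {
    using hash_t = index_t;  // capture the outer type before the inner dispatch shadows index_t
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "gather_rows_1", [&] {
      const auto* table_acc = table.data_ptr<hash_t>();
      const auto* indices_acc = indices.data_ptr<index_t>();
      auto* out_acc = out.data_ptr<index_t>();
      for (int64_t i = 0; i < indices.numel(); ++i) {
        // Cast explicitly in case the table is wider than the output dtype.
        out_acc[i] = static_cast<index_t>(table_acc[indices_acc[i]]);
      }
    });
  });
  return out;
}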
@@ -489,43 +510,47 @@ Tensor pruned_array_lookup_cpu(
   TENSOR_ON_CPU(offsets);
   TENSOR_ON_CPU(index_remappings);
   TENSOR_ON_CPU(index_remappings_offsets);
-  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
   int32_t T = index_remappings_offsets.size(0) - 1;
   int32_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);
 
   auto dense_indices = empty_like(indices);
 
-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cpu", [&] {
-    const auto* indices_acc = indices.data_ptr<index_t>();
-    auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
-    const auto* offsets_acc = offsets.data_ptr<index_t>();
-
-    const auto index_remappings_acc = index_remappings.data_ptr<index_t>();
-    const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();
-
-    at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
-      for (const auto t : c10::irange(begin, end)) {
-        const auto index_remappings_start = index_remappings_offsets_acc[t];
-        const auto index_remappings_end = index_remappings_offsets_acc[t + 1];
-        const auto capacity = index_remappings_end - index_remappings_start;
-
-        const auto indices_start = offsets_acc[t * B];
-        const auto indices_end = offsets_acc[(t + 1) * B];
-
-        if (capacity > 0) {
-          for (const auto i : c10::irange(indices_start, indices_end)) {
-            auto idx = indices_acc[i];
-            dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx];
-          }
-        } else {
-          std::memcpy(
-              dense_indices_acc + indices_start,
-              indices_acc + indices_start,
-              (indices_end - indices_start) * sizeof(index_t));
-        }
-      }
-    });
-  });
+  AT_DISPATCH_INDEX_TYPES(index_remappings.scalar_type(), "pruned_array_lookup_cpu_0", [&] {
+    using hash_t = index_t;
+
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cpu_1", [&] {
+      const auto* indices_acc = indices.data_ptr<index_t>();
+      auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+      const auto* offsets_acc = offsets.data_ptr<index_t>();
+
+      const auto index_remappings_acc = index_remappings.data_ptr<hash_t>();
+      const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();
+
+      at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
+        for (const auto t : c10::irange(begin, end)) {
+          const auto index_remappings_start = index_remappings_offsets_acc[t];
+          const auto index_remappings_end = index_remappings_offsets_acc[t + 1];
+          const auto capacity = index_remappings_end - index_remappings_start;
+
+          const auto indices_start = offsets_acc[t * B];
+          const auto indices_end = offsets_acc[(t + 1) * B];
+
+          if (capacity > 0) {
+            for (const auto i : c10::irange(indices_start, indices_end)) {
+              auto idx = indices_acc[i];
+              dense_indices_acc[i] = static_cast<index_t>(index_remappings_acc[index_remappings_start + idx]);
+            }
+          } else {
+            std::memcpy(
+                dense_indices_acc + indices_start,
+                indices_acc + indices_start,
+                (indices_end - indices_start) * sizeof(index_t));
+          }
+        }
+      });
+    });
+  });
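
Editor's aside (illustration only, outside ATen): what the array-remap path above computes. A per-table remapping slice maps an original row id to its post-pruning row id (-1 for pruned rows); a slice with capacity 0 means the table was never pruned, so indices are copied through unchanged. The remapping is kept as int64_t here to mirror the new default called out in the summary; `remap_indices` is a hypothetical helper.

#include <cstdint>
#include <cstring>
#include <vector>

std::vector<int32_t> remap_indices(
    const std::vector<int32_t>& indices,
    const std::vector<int64_t>& remapping) {  // remapping kept as int64_t
  std::vector<int32_t> dense(indices.size());
  if (remapping.empty()) {
    // capacity == 0: pass the indices through untouched (the memcpy branch above)
    std::memcpy(dense.data(), indices.data(), indices.size() * sizeof(int32_t));
    return dense;
  }
  for (size_t i = 0; i < indices.size(); ++i) {
    // Narrow back to the indices' dtype, as the kernel does with static_cast<index_t>.
    dense[i] = static_cast<int32_t>(remapping[static_cast<size_t>(indices[i])]);
  }
  return dense;
}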
