Add support for int64_t indices and offsets in TBE inference [8/N]
Summary: Update tests to use int64_t indices and offsets.

Differential Revision: D63807049
q10 authored and facebook-github-bot committed Oct 9, 2024
1 parent b5d2d3e commit 95dc4c2
Showing 6 changed files with 130 additions and 83 deletions.
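The kernel diffs below extend the CPU and CUDA lookup paths to dispatch on the dtype of the hash/remapping tensors, and the test diffs then drive the forward and cache paths with both int32 and int64 indices. The recurring test change is the same everywhere: instead of unconditionally down-casting with indices.int() / offsets.int(), the module is called with indices and offsets cast to a Hypothesis-sampled dtype. A minimal, self-contained sketch of that pattern (the plain index_select-based lookup below is only a stand-in for the quantized TBE module the real tests construct; names here are illustrative, not FBGEMM APIs):

import torch
from hypothesis import given, settings, strategies as st

# Stand-in embedding table; the real tests build an n-bit TBE module instead.
WEIGHTS = torch.randn(10, 4)


def pooled_lookup(indices: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """Sum-pool rows of WEIGHTS per bag; torch.index_select accepts int32 or int64."""
    rows = torch.index_select(WEIGHTS, 0, indices)
    bags = []
    for b in range(offsets.numel() - 1):
        start, end = int(offsets[b]), int(offsets[b + 1])
        bags.append(rows[start:end].sum(dim=0))
    return torch.stack(bags)


@given(indices_dtype=st.sampled_from([torch.int32, torch.int64]))
@settings(deadline=None, max_examples=10)
def test_lookup_handles_both_index_dtypes(indices_dtype: torch.dtype) -> None:
    indices = torch.tensor([1, 3, 5, 7])
    offsets = torch.tensor([0, 2, 4])
    # Same pattern as the updated tests: cast to the sampled dtype instead of
    # forcing .int(), so the int64 path is exercised as well.
    out = pooled_lookup(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
    assert out.shape == (2, 4)


if __name__ == "__main__":
    test_lookup_handles_both_index_dtypes()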
This hunk wraps the existing index-type dispatch in an outer AT_DISPATCH_INDEX_TYPES over hash_table.scalar_type() and types the hash table accessor with the dispatched hash_t instead of a hard-coded int64_t. The loop body itself is unchanged apart from being re-indented one level inside the new outer lambda, so it is shown below as unchanged context.

@@ -70,62 +70,66 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu(
   const int32_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);

-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
+  AT_DISPATCH_INDEX_TYPES(hash_table.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
+    using hash_t = index_t;
+
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
       using uidx_t =
           std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;

       const auto* indices_acc = indices.data_ptr<index_t>();
       const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
       const auto* offsets_acc = offsets.data_ptr<index_t>();

-      auto hash_table_acc = hash_table.accessor<int64_t, 2>();
+      auto hash_table_acc = hash_table.accessor<hash_t, 2>();
       const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();

       for (const auto t : c10::irange(T)) {
         const auto table_start = hash_table_offsets_acc[t];
         const auto table_end = hash_table_offsets_acc[t + 1];
         if (table_start == table_end) {
           continue;
         }
         const auto capacity = table_end - table_start;

         for (const auto b : c10::irange(B)) {
           const auto indices_start = offsets_acc[t * B + b];
           const auto indices_end = offsets_acc[t * B + b + 1];
           const auto L = indices_end - indices_start;

           for (const auto l : c10::irange(L)) {
             const auto idx = indices_acc[indices_start + l];
             const auto dense_idx = dense_indices_acc[indices_start + l];
             if (dense_idx == -1) {
               // -1 means this row has been pruned, do not insert it.
               continue;
             }

             auto slot = pruned_hash_function(static_cast<uidx_t>(idx)) % capacity;
             while (true) {
               const auto ht_idx = table_start + static_cast<int64_t>(slot);
               const auto slot_sparse_idx = hash_table_acc[ht_idx][0];

               // Empty slot
               if (slot_sparse_idx == -1) {
                 hash_table_acc[ht_idx][0] = idx;
                 hash_table_acc[ht_idx][1] = dense_idx;
                 break;
               }

               // Already exists (shouldn't happen in practice)
               if (slot_sparse_idx == idx) {
                 hash_table_acc[ht_idx][1] = dense_idx;
                 break;
               }

               // Linear probe
               slot = (slot + 1) % capacity;
             }
           }
         }
       }
+    });
   });

   return;
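For readers unfamiliar with the kernel above, the insertion logic is ordinary open addressing with linear probing over a per-table slice of hash_table; the only functional change in the hunk is that the table's element type is now dispatched rather than fixed to int64_t. A rough Python model of the same insertion loop (illustrative only; Python's built-in hash stands in for pruned_hash_function):

import torch


def hashmap_insert(
    hash_table: torch.Tensor,  # [num_slots, 2]: column 0 = sparse idx, column 1 = dense idx
    table_start: int,
    capacity: int,
    idx: int,
    dense_idx: int,
) -> None:
    """Mirror of the kernel's per-element insert; -1 marks an empty slot."""
    if dense_idx == -1:
        return  # pruned row, do not insert
    slot = hash(idx) % capacity  # placeholder for pruned_hash_function()
    while True:
        ht_idx = table_start + slot
        slot_sparse_idx = int(hash_table[ht_idx, 0])
        if slot_sparse_idx == -1:       # empty slot: claim it
            hash_table[ht_idx, 0] = idx
            hash_table[ht_idx, 1] = dense_idx
            return
        if slot_sparse_idx == idx:      # already present: just update the mapping
            hash_table[ht_idx, 1] = dense_idx
            return
        slot = (slot + 1) % capacity    # linear probe


# Tiny usage example with one table of capacity 4.
table = torch.full((4, 2), -1, dtype=torch.int64)
hashmap_insert(table, table_start=0, capacity=4, idx=42, dense_idx=7)
assert int(table[hash(42) % 4, 1]) == 7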
@@ -100,14 +100,14 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   }
 }

-template <typename index_t, typename hash_t>
+template <typename index_t, typename remap_t>
 __global__
 __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel(
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         offsets,
-    const pta::PackedTensorAccessor32<hash_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<remap_t, 1, at::RestrictPtrTraits>
         index_remappings,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         index_remappings_offsets,

@@ -231,7 +231,7 @@ Tensor pruned_array_lookup_cuda(
   AT_DISPATCH_INDEX_TYPES(
       index_remappings.scalar_type(), "pruned_array_lookup_cuda_0", [&] {
-        using hash_t = index_t;
+        using remap_t = index_t;
         AT_DISPATCH_INDEX_TYPES(
             indices.scalar_type(), "pruned_array_lookup_cuda_1", [&] {

@@ -249,7 +249,7 @@ Tensor pruned_array_lookup_cuda(
                   MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
                   MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
                   MAKE_PTA_WITH_NAME(
-                      func_name, index_remappings, hash_t, 1, 32),
+                      func_name, index_remappings, remap_t, 1, 32),
                   MAKE_PTA_WITH_NAME(
                       func_name, index_remappings_offsets, int64_t, 1, 32),
                   B,
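The kernel body is not shown in this hunk, but the host code makes the intent clear: the remapping table's element type (remap_t) is dispatched independently of the indices' type (index_t), which is why the template parameter was renamed away from hash_t. Under the assumption of the usual pruned-array semantics — each table owns a contiguous slice of index_remappings, a sparse index is used as a direct offset into that slice, and -1 marks a pruned row — the lookup reduces to the following rough Python sketch (the function name and layout here are illustrative, not the FBGEMM API):

import torch


def pruned_array_lookup(
    indices: torch.Tensor,                   # int32 or int64, flattened over T tables * B samples
    offsets: torch.Tensor,                   # CSR offsets per (table, sample), length T * B + 1
    index_remappings: torch.Tensor,          # concatenated per-table remap arrays ("remap_t")
    index_remappings_offsets: torch.Tensor,  # int64, length T + 1
    B: int,
) -> torch.Tensor:
    T = index_remappings_offsets.numel() - 1
    dense_indices = torch.empty_like(indices)
    for t in range(T):
        remap_start = int(index_remappings_offsets[t])
        seg_start = int(offsets[t * B])
        seg_end = int(offsets[(t + 1) * B])
        seg = indices[seg_start:seg_end].long()
        dense_indices[seg_start:seg_end] = index_remappings[remap_start + seg].to(indices.dtype)
    return dense_indices


# One table (T=1), batch size B=2: remap rows {0 -> 5, 1 -> -1 (pruned), 2 -> 0}.
remap = torch.tensor([5, -1, 0], dtype=torch.int32)
remap_offsets = torch.tensor([0, 3], dtype=torch.int64)
idx = torch.tensor([0, 2, 1], dtype=torch.int64)
off = torch.tensor([0, 2, 3], dtype=torch.int64)
print(pruned_array_lookup(idx, off, remap, remap_offsets, B=2))  # tensor([ 5,  0, -1])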
25 changes: 16 additions & 9 deletions fbgemm_gpu/test/tbe/inference/nbit_cache_test.py

@@ -118,9 +118,12 @@ def test_nbit_cache_update_function(self, L: int, H: int, S: int) -> None:
         self.assertEqual(total_access_count, expected_total_access)

     @unittest.skipIf(*gpu_unavailable)
-    @given(N=st.integers(min_value=1, max_value=8))
+    @given(
+        N=st.integers(min_value=1, max_value=8),
+        indices_dtype=st.sampled_from([torch.int, torch.long]),
+    )
     @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
-    def test_nbit_cache_miss_counter(self, N: int) -> None:
+    def test_nbit_cache_miss_counter(self, N: int, indices_dtype: torch.dtype) -> None:
         # Create an abstract split table
         D = 8
         T = 2

@@ -156,7 +159,7 @@ def test_nbit_cache_miss_counter(self, N: int) -> None:
         ):
             (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
             for _ in range(N):
-                cc(indices.int(), offsets.int())
+                cc(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 cache_miss_forward_count,
                 unique_cache_miss_count,

@@ -173,9 +176,12 @@ def test_nbit_cache_miss_counter(self, N: int) -> None:
     @given(
         N=st.integers(min_value=1, max_value=8),
         dtype=st.sampled_from([SparseType.INT8, SparseType.INT4, SparseType.INT2]),
+        indices_dtype=st.sampled_from([torch.int, torch.long]),
     )
     @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
-    def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
+    def test_nbit_uvm_cache_stats(
+        self, N: int, dtype: SparseType, indices_dtype: torch.dtype
+    ) -> None:
         # Create an abstract split table
         D = 8
         T = 2

@@ -215,7 +221,7 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
         for _ in range(N):
             num_calls_expected = num_calls_expected + 1
             num_indices_expcted = num_indices_expcted + len(indices)
-            cc(indices.int(), offsets.int())
+            cc(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 num_calls,
                 num_indices,

@@ -271,7 +277,7 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
         for x, e in zip((indices1, indices2, indices3), expected):
             (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
             for _ in range(N):
-                cc1(indices.int(), offsets.int())
+                cc1(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 _,
                 _,

@@ -288,10 +294,11 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
     @given(
         N=st.integers(min_value=1, max_value=8),
         dtype=st.sampled_from([SparseType.INT8, SparseType.INT4, SparseType.INT2]),
+        indices_dtype=st.sampled_from([torch.int, torch.long]),
     )
     @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
     def test_nbit_direct_mapped_uvm_cache_stats(
-        self, N: int, dtype: SparseType
+        self, N: int, dtype: SparseType, indices_dtype: torch.dtype
     ) -> None:
         # Create an abstract split table
         D = 8

@@ -333,7 +340,7 @@ def test_nbit_direct_mapped_uvm_cache_stats(
         for _ in range(N):
             num_calls_expected = num_calls_expected + 1
             num_indices_expcted = num_indices_expcted + len(indices)
-            cc(indices.int(), offsets.int())
+            cc(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 num_calls,
                 num_indices,

@@ -393,7 +400,7 @@ def test_nbit_direct_mapped_uvm_cache_stats(
         for x, e in zip((indices1, indices2, indices3), expected):
            (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
            for _ in range(N):
-                cc1(indices.int(), offsets.int())
+                cc1(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
            (
                _,
                _,
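These cache tests feed the module with (indices, offsets) pairs produced by get_table_batched_offsets_from_dense. That helper is not shown here, but the layout it produces is the standard TBE CSR form: one flat indices tensor plus a T * B + 1 offsets tensor whose consecutive entries bracket each (table, sample) bag. A small illustration of that layout, under the assumption that the dense input has shape (T, B, L) with a fixed bag length L, together with the same dtype sweep the updated tests now apply (the helper name below is illustrative):

import torch


def dense_to_batched_csr(dense: torch.Tensor, dtype: torch.dtype = torch.int64):
    """Flatten a (T, B, L) tensor of indices into TBE's (indices, offsets) form."""
    T, B, L = dense.shape
    indices = dense.reshape(-1).to(dtype)
    offsets = torch.arange(0, T * B * L + 1, L, dtype=dtype)
    return indices, offsets


dense = torch.tensor([[[1, 3], [2, 0]],      # table 0: two samples, two indices each
                      [[4, 4], [5, 1]]])     # table 1
for indices_dtype in (torch.int32, torch.int64):   # same sweep the tests now do
    indices, offsets = dense_to_batched_csr(dense, dtype=indices_dtype)
    assert offsets.tolist() == [0, 2, 4, 6, 8]
    assert indices.dtype == offsets.dtype == indices_dtype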
15 changes: 11 additions & 4 deletions fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py

@@ -105,6 +105,7 @@ def execute_nbit_forward_( # noqa C901
         use_array_for_index_remapping: bool,
         do_pruning: bool,
         mixed_weights_ty: bool,
+        indices_dtype: torch.dtype,
         output_dtype: SparseType,
     ) -> None:
         # NOTE: weighted operation can be done only for SUM.

@@ -311,19 +312,22 @@ def execute_nbit_forward_( # noqa C901
             fp8_config=fp8_config if has_fp8_weight else None,
         )

+        indices = indices.to(dtype=indices_dtype)
+        offsets = offsets.to(dtype=indices_dtype)
+
         if not use_cpu:
             fc2 = (
-                cc(indices.int(), offsets.int())
+                cc(indices, offsets)
                 if not weighted
-                else cc(indices.int(), offsets.int(), xw.contiguous().view(-1))
+                else cc(indices, offsets, xw.contiguous().view(-1))
             )
         else:
             cc = cc.cpu()
             indices, offsets = indices.cpu(), offsets.cpu()
             fc2 = (
-                cc(indices.int(), offsets.int())
+                cc(indices, offsets)
                 if not weighted
-                else cc(indices.int(), offsets.int(), xw.contiguous().view(-1).cpu())
+                else cc(indices, offsets, xw.contiguous().view(-1).cpu())
             )

         if do_pooling and B == 0:

@@ -373,6 +377,7 @@ def execute_nbit_forward_( # noqa C901
         pooling_mode=st.sampled_from(
             [PoolingMode.SUM, PoolingMode.MEAN, PoolingMode.NONE]
         ),
+        indices_dtype=st.sampled_from([torch.int32, torch.int64]),
         output_dtype=st.sampled_from(
             [SparseType.FP32, SparseType.FP16, SparseType.BF16]
         ),

@@ -386,6 +391,7 @@ def test_nbit_forward_cpu_autovec(
         self,
         nbit_weights_ty: Optional[SparseType],
         pooling_mode: PoolingMode,
+        indices_dtype: torch.dtype,
         output_dtype: SparseType,
     ) -> None:
         use_cpu = True

@@ -432,6 +438,7 @@ def test_nbit_forward_cpu_autovec(
             False,
             False,
             mixed_weights_ty,
+            indices_dtype,
             output_dtype,
         )

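execute_nbit_forward_ now casts indices and offsets once, up front, and then runs the same reference comparison for whichever dtype Hypothesis picked. The property this protects is that the forward result does not depend on the index dtype. A self-contained sketch of that property check (again with a plain dense lookup standing in for the quantized TBE op; this is not the test's literal assertion, which compares against a float reference instead):

import torch

weights = torch.randn(16, 8)
indices = torch.randint(0, 16, (12,))
offsets = torch.tensor([0, 3, 7, 12])


def forward(idx: torch.Tensor, off: torch.Tensor) -> torch.Tensor:
    # Sum-pooled lookup; torch.index_select accepts int32 or int64 indices.
    rows = torch.index_select(weights, 0, idx)
    return torch.stack(
        [rows[int(off[i]) : int(off[i + 1])].sum(dim=0) for i in range(off.numel() - 1)]
    )


out32 = forward(indices.to(torch.int32), offsets.to(torch.int32))
out64 = forward(indices.to(torch.int64), offsets.to(torch.int64))
# The int32 and int64 index paths must produce identical embeddings.
torch.testing.assert_close(out32, out64)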
(Diffs for the remaining two changed files are not shown here.)
