Add support for int64_t indices and offsets in TBE inference [8/N] #3233

Open · wants to merge 1 commit into base: main
@@ -70,62 +70,66 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu(
   const int32_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);

+  AT_DISPATCH_INDEX_TYPES(hash_table.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
+    using hash_t = index_t;
+
   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
     using uidx_t =
         std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;

     const auto* indices_acc = indices.data_ptr<index_t>();
     const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
     const auto* offsets_acc = offsets.data_ptr<index_t>();

-    auto hash_table_acc = hash_table.accessor<int64_t, 2>();
+    auto hash_table_acc = hash_table.accessor<hash_t, 2>();
     const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();

     for (const auto t : c10::irange(T)) {
       const auto table_start = hash_table_offsets_acc[t];
       const auto table_end = hash_table_offsets_acc[t + 1];
       if (table_start == table_end) {
         continue;
       }
       const auto capacity = table_end - table_start;

       for (const auto b : c10::irange(B)) {
         const auto indices_start = offsets_acc[t * B + b];
         const auto indices_end = offsets_acc[t * B + b + 1];
         const auto L = indices_end - indices_start;

         for (const auto l : c10::irange(L)) {
           const auto idx = indices_acc[indices_start + l];
           const auto dense_idx = dense_indices_acc[indices_start + l];
           if (dense_idx == -1) {
             // -1 means this row has been pruned, do not insert it.
             continue;
           }

           auto slot = pruned_hash_function(static_cast<uidx_t>(idx)) % capacity;
           while (true) {
             const auto ht_idx = table_start + static_cast<int64_t>(slot);
             const auto slot_sparse_idx = hash_table_acc[ht_idx][0];

             // Empty slot
             if (slot_sparse_idx == -1) {
               hash_table_acc[ht_idx][0] = idx;
               hash_table_acc[ht_idx][1] = dense_idx;
               break;
             }

             // Already exists (shouldn't happen in practice)
             if (slot_sparse_idx == idx) {
               hash_table_acc[ht_idx][1] = dense_idx;
               break;
             }

             // Linear probe
             slot = (slot + 1) % capacity;
           }
         }
       }
     }
   });
+  });

   return;
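The pattern worth noting in this hunk is the nested AT_DISPATCH_INDEX_TYPES: both the hash table's scalar type and the indices' scalar type are dispatched at runtime, and because the macro always binds the deduced type to the name `index_t`, the outer type has to be captured under a new alias (`hash_t`) before the inner dispatch shadows it. The sketch below is a hedged stand-in written in plain C++14 without ATen; `dispatch_index_type` and the two boolean flags are hypothetical names invented for illustration, not FBGEMM APIs.

```cpp
// Minimal sketch of the nested-dispatch aliasing used above (plain C++14, no ATen).
// `dispatch_index_type` is a hypothetical stand-in for AT_DISPATCH_INDEX_TYPES:
// each level binds the deduced type to the same name, so the outer type must be
// aliased (hash_t here) before the inner dispatch shadows it.
#include <cstdint>
#include <cstdio>

template <typename F>
void dispatch_index_type(bool is_int64, F&& f) {
  // Pass a zero-initialized value whose static type encodes the runtime choice.
  if (is_int64) {
    f(int64_t{0});
  } else {
    f(int32_t{0});
  }
}

int main() {
  const bool hash_table_is_int64 = false;  // e.g. hash_table holds 32-bit rows
  const bool indices_are_int64 = true;     // e.g. indices arrive as int64

  dispatch_index_type(hash_table_is_int64, [&](auto ht_tag) {
    using hash_t = decltype(ht_tag);  // capture the outer type under its own name
    dispatch_index_type(indices_are_int64, [&](auto idx_tag) {
      using index_t = decltype(idx_tag);
      // Both types are now visible together, mirroring accessor<hash_t, 2>()
      // and data_ptr<index_t>() in the kernel above.
      std::printf("sizeof(hash_t)=%zu sizeof(index_t)=%zu\n",
                  sizeof(hash_t), sizeof(index_t));
    });
  });
  return 0;
}
```

The same reasoning applies to the `remap_t` alias introduced in the CUDA lookup below.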
@@ -100,14 +100,14 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   }
 }

-template <typename index_t, typename hash_t>
+template <typename index_t, typename remap_t>
 __global__
 __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel(
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         offsets,
-    const pta::PackedTensorAccessor32<hash_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<remap_t, 1, at::RestrictPtrTraits>
         index_remappings,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         index_remappings_offsets,
@@ -231,7 +231,7 @@ Tensor pruned_array_lookup_cuda(

   AT_DISPATCH_INDEX_TYPES(
       index_remappings.scalar_type(), "pruned_array_lookup_cuda_0", [&] {
-        using hash_t = index_t;
+        using remap_t = index_t;

         AT_DISPATCH_INDEX_TYPES(
             indices.scalar_type(), "pruned_array_lookup_cuda_1", [&] {
@@ -249,7 +249,7 @@ Tensor pruned_array_lookup_cuda(
                   MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
                   MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
                   MAKE_PTA_WITH_NAME(
-                      func_name, index_remappings, hash_t, 1, 32),
+                      func_name, index_remappings, remap_t, 1, 32),
                   MAKE_PTA_WITH_NAME(
                       func_name, index_remappings_offsets, int64_t, 1, 32),
                   B,
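For orientation, here is a rough CPU reference of what a pruned-array lookup computes: each table owns a slice [index_remappings_offsets[t], index_remappings_offsets[t+1]) of the flat index_remappings array, a sparse index is translated to its dense row through that slice (with -1 marking a pruned row), and an empty slice passes indices through unchanged. This is an assumption-laden sketch rather than the CUDA kernel above; `pruned_array_lookup_ref` and `table_of_index` are names invented for the example.

```cpp
// Hedged CPU reference for the remapping performed by a pruned-array lookup
// (an approximation for illustration, not the FBGEMM implementation).
#include <cstdint>
#include <vector>

template <typename index_t, typename remap_t>
std::vector<index_t> pruned_array_lookup_ref(
    const std::vector<index_t>& indices,            // sparse indices, all tables concatenated
    const std::vector<int32_t>& table_of_index,     // table id owning each index
    const std::vector<remap_t>& index_remappings,   // flat remapping array, -1 = pruned row
    const std::vector<int64_t>& index_remappings_offsets) {  // per-table slice boundaries
  std::vector<index_t> dense_indices(indices.size());
  for (std::size_t i = 0; i < indices.size(); ++i) {
    const auto t = table_of_index[i];
    const auto start = index_remappings_offsets[t];
    const auto end = index_remappings_offsets[t + 1];
    if (start == end) {
      // No remapping stored for this table: indices pass through unchanged.
      dense_indices[i] = indices[i];
    } else {
      // Translate to the dense row id; a stored -1 means the row was pruned.
      dense_indices[i] = static_cast<index_t>(index_remappings[start + indices[i]]);
    }
  }
  return dense_indices;
}
```

Keeping `index_t` (the dtype of indices and offsets) separate from `remap_t` (the dtype of the stored remapping table) is what allows int64_t indices to be looked up against an int32_t remapping table and vice versa, which is the combination this PR series is enabling.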
25 changes: 16 additions & 9 deletions fbgemm_gpu/test/tbe/inference/nbit_cache_test.py
@@ -118,9 +118,12 @@ def test_nbit_cache_update_function(self, L: int, H: int, S: int) -> None:
         self.assertEqual(total_access_count, expected_total_access)

     @unittest.skipIf(*gpu_unavailable)
-    @given(N=st.integers(min_value=1, max_value=8))
+    @given(
+        N=st.integers(min_value=1, max_value=8),
+        indices_dtype=st.sampled_from([torch.int, torch.long]),
+    )
     @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
-    def test_nbit_cache_miss_counter(self, N: int) -> None:
+    def test_nbit_cache_miss_counter(self, N: int, indices_dtype: torch.dtype) -> None:
         # Create an abstract split table
         D = 8
         T = 2
@@ -156,7 +159,7 @@ def test_nbit_cache_miss_counter(self, N: int) -> None:
         ):
             (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
             for _ in range(N):
-                cc(indices.int(), offsets.int())
+                cc(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 cache_miss_forward_count,
                 unique_cache_miss_count,
@@ -173,9 +176,12 @@ def test_nbit_cache_miss_counter(self, N: int) -> None:
     @given(
         N=st.integers(min_value=1, max_value=8),
         dtype=st.sampled_from([SparseType.INT8, SparseType.INT4, SparseType.INT2]),
+        indices_dtype=st.sampled_from([torch.int, torch.long]),
     )
     @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
-    def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
+    def test_nbit_uvm_cache_stats(
+        self, N: int, dtype: SparseType, indices_dtype: torch.dtype
+    ) -> None:
         # Create an abstract split table
         D = 8
         T = 2
@@ -215,7 +221,7 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
             for _ in range(N):
                 num_calls_expected = num_calls_expected + 1
                 num_indices_expcted = num_indices_expcted + len(indices)
-                cc(indices.int(), offsets.int())
+                cc(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 num_calls,
                 num_indices,
@@ -271,7 +277,7 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
         for x, e in zip((indices1, indices2, indices3), expected):
             (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
             for _ in range(N):
-                cc1(indices.int(), offsets.int())
+                cc1(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 _,
                 _,
@@ -288,10 +294,11 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
     @given(
         N=st.integers(min_value=1, max_value=8),
         dtype=st.sampled_from([SparseType.INT8, SparseType.INT4, SparseType.INT2]),
+        indices_dtype=st.sampled_from([torch.int, torch.long]),
     )
     @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
     def test_nbit_direct_mapped_uvm_cache_stats(
-        self, N: int, dtype: SparseType
+        self, N: int, dtype: SparseType, indices_dtype: torch.dtype
     ) -> None:
         # Create an abstract split table
         D = 8
@@ -333,7 +340,7 @@ def test_nbit_direct_mapped_uvm_cache_stats(
             for _ in range(N):
                 num_calls_expected = num_calls_expected + 1
                 num_indices_expcted = num_indices_expcted + len(indices)
-                cc(indices.int(), offsets.int())
+                cc(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 num_calls,
                 num_indices,
@@ -393,7 +400,7 @@ def test_nbit_direct_mapped_uvm_cache_stats(
         for x, e in zip((indices1, indices2, indices3), expected):
             (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
             for _ in range(N):
-                cc1(indices.int(), offsets.int())
+                cc1(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 _,
                 _,
15 changes: 11 additions & 4 deletions fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py
@@ -105,6 +105,7 @@ def execute_nbit_forward_( # noqa C901
         use_array_for_index_remapping: bool,
         do_pruning: bool,
         mixed_weights_ty: bool,
+        indices_dtype: torch.dtype,
         output_dtype: SparseType,
     ) -> None:
         # NOTE: weighted operation can be done only for SUM.
@@ -311,19 +312,22 @@ def execute_nbit_forward_( # noqa C901
             fp8_config=fp8_config if has_fp8_weight else None,
         )

+        indices = indices.to(dtype=indices_dtype)
+        offsets = offsets.to(dtype=indices_dtype)
+
         if not use_cpu:
             fc2 = (
-                cc(indices.int(), offsets.int())
+                cc(indices, offsets)
                 if not weighted
-                else cc(indices.int(), offsets.int(), xw.contiguous().view(-1))
+                else cc(indices, offsets, xw.contiguous().view(-1))
             )
         else:
             cc = cc.cpu()
             indices, offsets = indices.cpu(), offsets.cpu()
             fc2 = (
-                cc(indices.int(), offsets.int())
+                cc(indices, offsets)
                 if not weighted
-                else cc(indices.int(), offsets.int(), xw.contiguous().view(-1).cpu())
+                else cc(indices, offsets, xw.contiguous().view(-1).cpu())
             )

         if do_pooling and B == 0:
@@ -373,6 +377,7 @@ def execute_nbit_forward_( # noqa C901
         pooling_mode=st.sampled_from(
             [PoolingMode.SUM, PoolingMode.MEAN, PoolingMode.NONE]
         ),
+        indices_dtype=st.sampled_from([torch.int32, torch.int64]),
         output_dtype=st.sampled_from(
             [SparseType.FP32, SparseType.FP16, SparseType.BF16]
         ),
@@ -386,6 +391,7 @@ def test_nbit_forward_cpu_autovec(
         self,
         nbit_weights_ty: Optional[SparseType],
         pooling_mode: PoolingMode,
+        indices_dtype: torch.dtype,
         output_dtype: SparseType,
     ) -> None:
         use_cpu = True
@@ -432,6 +438,7 @@ def test_nbit_forward_cpu_autovec(
             False,
             False,
             mixed_weights_ty,
+            indices_dtype,
             output_dtype,
         )