diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
index 4429e580d..3b6ecace0 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
@@ -70,62 +70,66 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu(
   const int32_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);
 
-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
-    using uidx_t =
-        std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
-
-    const auto* indices_acc = indices.data_ptr<index_t>();
-    const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
-    const auto* offsets_acc = offsets.data_ptr<index_t>();
-
-    auto hash_table_acc = hash_table.accessor<int32_t, 2>();
-    const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
-
-    for (const auto t : c10::irange(T)) {
-      const auto table_start = hash_table_offsets_acc[t];
-      const auto table_end = hash_table_offsets_acc[t + 1];
-      if (table_start == table_end) {
-        continue;
-      }
-      const auto capacity = table_end - table_start;
-
-      for (const auto b : c10::irange(B)) {
-        const auto indices_start = offsets_acc[t * B + b];
-        const auto indices_end = offsets_acc[t * B + b + 1];
-        const auto L = indices_end - indices_start;
-
-        for (const auto l : c10::irange(L)) {
-          const auto idx = indices_acc[indices_start + l];
-          const auto dense_idx = dense_indices_acc[indices_start + l];
-          if (dense_idx == -1) {
-            // -1 means this row has been pruned, do not insert it.
-            continue;
-          }
+  AT_DISPATCH_INDEX_TYPES(hash_table.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
+    using hash_t = index_t;
-          auto slot = pruned_hash_function(static_cast<uidx_t>(idx)) % capacity;
-          while (true) {
-            const auto ht_idx = table_start + static_cast<int64_t>(slot);
-            const auto slot_sparse_idx = hash_table_acc[ht_idx][0];
-
-            // Empty slot
-            if (slot_sparse_idx == -1) {
-              hash_table_acc[ht_idx][0] = idx;
-              hash_table_acc[ht_idx][1] = dense_idx;
-              break;
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
+      using uidx_t =
+          std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
+
+      const auto* indices_acc = indices.data_ptr<index_t>();
+      const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+      const auto* offsets_acc = offsets.data_ptr<index_t>();
+
+      auto hash_table_acc = hash_table.accessor<hash_t, 2>();
+      const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
+
+      for (const auto t : c10::irange(T)) {
+        const auto table_start = hash_table_offsets_acc[t];
+        const auto table_end = hash_table_offsets_acc[t + 1];
+        if (table_start == table_end) {
+          continue;
+        }
+        const auto capacity = table_end - table_start;
+
+        for (const auto b : c10::irange(B)) {
+          const auto indices_start = offsets_acc[t * B + b];
+          const auto indices_end = offsets_acc[t * B + b + 1];
+          const auto L = indices_end - indices_start;
+
+          for (const auto l : c10::irange(L)) {
+            const auto idx = indices_acc[indices_start + l];
+            const auto dense_idx = dense_indices_acc[indices_start + l];
+            if (dense_idx == -1) {
+              // -1 means this row has been pruned, do not insert it.
+              continue;
+            }
-
-            // Already exists (shouldn't happen in practice)
-            if (slot_sparse_idx == idx) {
-              hash_table_acc[ht_idx][1] = dense_idx;
-              break;
+
+            auto slot = pruned_hash_function(static_cast<uidx_t>(idx)) % capacity;
+            while (true) {
+              const auto ht_idx = table_start + static_cast<int64_t>(slot);
+              const auto slot_sparse_idx = hash_table_acc[ht_idx][0];
+
+              // Empty slot
+              if (slot_sparse_idx == -1) {
+                hash_table_acc[ht_idx][0] = idx;
+                hash_table_acc[ht_idx][1] = dense_idx;
+                break;
+              }
+
+              // Already exists (shouldn't happen in practice)
+              if (slot_sparse_idx == idx) {
+                hash_table_acc[ht_idx][1] = dense_idx;
+                break;
+              }
+
+              // Linear probe
+              slot = (slot + 1) % capacity;
+            }
           }
         }
       }
-    }
+    });
   });
 
   return;
diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
index 922c8ebd2..23692982c 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
@@ -100,14 +100,14 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   }
 }
 
-template <typename index_t, typename hash_t>
+template <typename index_t, typename remap_t>
 __global__ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel(
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
-    const pta::PackedTensorAccessor32<hash_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<remap_t, 1, at::RestrictPtrTraits>
         index_remappings,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         index_remappings_offsets,
@@ -231,7 +231,7 @@ Tensor pruned_array_lookup_cuda(
   AT_DISPATCH_INDEX_TYPES(
       index_remappings.scalar_type(), "pruned_array_lookup_cuda_0", [&] {
-        using hash_t = index_t;
+        using remap_t = index_t;
 
         AT_DISPATCH_INDEX_TYPES(
             indices.scalar_type(), "pruned_array_lookup_cuda_1", [&] {
@@ -249,7 +249,7 @@ Tensor pruned_array_lookup_cuda(
                   MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
                   MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
                   MAKE_PTA_WITH_NAME(
-                      func_name, index_remappings, hash_t, 1, 32),
+                      func_name, index_remappings, remap_t, 1, 32),
                   MAKE_PTA_WITH_NAME(
                       func_name, index_remappings_offsets, int64_t, 1, 32),
                   B,
diff --git a/fbgemm_gpu/test/tbe/inference/nbit_cache_test.py b/fbgemm_gpu/test/tbe/inference/nbit_cache_test.py
index 8a64162d2..e68e10ae8 100644
--- a/fbgemm_gpu/test/tbe/inference/nbit_cache_test.py
+++ b/fbgemm_gpu/test/tbe/inference/nbit_cache_test.py
@@ -118,9 +118,12 @@ def test_nbit_cache_update_function(self, L: int, H: int, S: int) -> None:
         self.assertEqual(total_access_count, expected_total_access)
 
     @unittest.skipIf(*gpu_unavailable)
-    @given(N=st.integers(min_value=1, max_value=8))
+    @given(
+        N=st.integers(min_value=1, max_value=8),
+        indices_dtype=st.sampled_from([torch.int, torch.long]),
+    )
     @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
-    def test_nbit_cache_miss_counter(self, N: int) -> None:
+    def test_nbit_cache_miss_counter(self, N: int, indices_dtype: torch.dtype) -> None:
         # Create an abstract split table
         D = 8
         T = 2
@@ -156,7 +159,7 @@ def test_nbit_cache_miss_counter(self, N: int) -> None:
         ):
             (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
             for _ in range(N):
-                cc(indices.int(), offsets.int())
+                cc(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 cache_miss_forward_count,
                 unique_cache_miss_count,
@@ -173,9 +176,12 @@ def test_nbit_cache_miss_counter(self, N: int) -> None:
     @given(
         N=st.integers(min_value=1, max_value=8),
         dtype=st.sampled_from([SparseType.INT8, SparseType.INT4, SparseType.INT2]),
+        indices_dtype=st.sampled_from([torch.int, torch.long]),
     )
     @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
-    def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
+    def test_nbit_uvm_cache_stats(
+        self, N: int, dtype: SparseType, indices_dtype: torch.dtype
+    ) -> None:
         # Create an abstract split table
         D = 8
         T = 2
@@ -215,7 +221,7 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
             for _ in range(N):
                 num_calls_expected = num_calls_expected + 1
                 num_indices_expcted = num_indices_expcted + len(indices)
-                cc(indices.int(), offsets.int())
+                cc(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 num_calls,
                 num_indices,
@@ -271,7 +277,7 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
         for x, e in zip((indices1, indices2, indices3), expected):
             (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
             for _ in range(N):
-                cc1(indices.int(), offsets.int())
+                cc1(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 _,
                 _,
@@ -288,10 +294,11 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
     @given(
         N=st.integers(min_value=1, max_value=8),
         dtype=st.sampled_from([SparseType.INT8, SparseType.INT4, SparseType.INT2]),
+        indices_dtype=st.sampled_from([torch.int, torch.long]),
     )
     @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
     def test_nbit_direct_mapped_uvm_cache_stats(
-        self, N: int, dtype: SparseType
+        self, N: int, dtype: SparseType, indices_dtype: torch.dtype
     ) -> None:
         # Create an abstract split table
         D = 8
@@ -333,7 +340,7 @@ def test_nbit_direct_mapped_uvm_cache_stats(
             for _ in range(N):
                 num_calls_expected = num_calls_expected + 1
                 num_indices_expcted = num_indices_expcted + len(indices)
-                cc(indices.int(), offsets.int())
+                cc(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 num_calls,
                 num_indices,
@@ -393,7 +400,7 @@ def test_nbit_direct_mapped_uvm_cache_stats(
         for x, e in zip((indices1, indices2, indices3), expected):
             (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
             for _ in range(N):
-                cc1(indices.int(), offsets.int())
+                cc1(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 _,
                 _,
diff --git a/fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py b/fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py
index 920a86cbd..368c118bb 100644
--- a/fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py
+++ b/fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py
@@ -105,6 +105,7 @@ def execute_nbit_forward_(  # noqa C901
         use_array_for_index_remapping: bool,
         do_pruning: bool,
         mixed_weights_ty: bool,
+        indices_dtype: torch.dtype,
         output_dtype: SparseType,
     ) -> None:
         # NOTE: weighted operation can be done only for SUM.
@@ -311,19 +312,22 @@ def execute_nbit_forward_(  # noqa C901
             fp8_config=fp8_config if has_fp8_weight else None,
         )
 
+        indices = indices.to(dtype=indices_dtype)
+        offsets = offsets.to(dtype=indices_dtype)
+
         if not use_cpu:
             fc2 = (
-                cc(indices.int(), offsets.int())
+                cc(indices, offsets)
                 if not weighted
-                else cc(indices.int(), offsets.int(), xw.contiguous().view(-1))
+                else cc(indices, offsets, xw.contiguous().view(-1))
             )
         else:
             cc = cc.cpu()
             indices, offsets = indices.cpu(), offsets.cpu()
             fc2 = (
-                cc(indices.int(), offsets.int())
+                cc(indices, offsets)
                 if not weighted
-                else cc(indices.int(), offsets.int(), xw.contiguous().view(-1).cpu())
+                else cc(indices, offsets, xw.contiguous().view(-1).cpu())
             )
 
         if do_pooling and B == 0:
@@ -373,6 +377,7 @@ def execute_nbit_forward_(  # noqa C901
         pooling_mode=st.sampled_from(
             [PoolingMode.SUM, PoolingMode.MEAN, PoolingMode.NONE]
         ),
+        indices_dtype=st.sampled_from([torch.int32, torch.int64]),
         output_dtype=st.sampled_from(
             [SparseType.FP32, SparseType.FP16, SparseType.BF16]
         ),
@@ -386,6 +391,7 @@ def test_nbit_forward_cpu_autovec(
         self,
         nbit_weights_ty: Optional[SparseType],
         pooling_mode: PoolingMode,
+        indices_dtype: torch.dtype,
         output_dtype: SparseType,
     ) -> None:
         use_cpu = True
@@ -432,6 +438,7 @@ def test_nbit_forward_cpu_autovec(
             False,
             False,
             mixed_weights_ty,
+            indices_dtype,
             output_dtype,
         )
diff --git a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
index 8f4c32eea..682d384f3 100644
--- a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
+++ b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
@@ -332,6 +332,7 @@ def execute_nbit_forward_(  # noqa C901
         use_array_for_index_remapping: bool,
         do_pruning: bool,
         mixed_weights_ty: bool,
+        indices_dtype: torch.dtype,
         output_dtype: SparseType,
     ) -> None:
         # NOTE: weighted operation can be done only for SUM.
@@ -538,19 +539,22 @@ def execute_nbit_forward_(  # noqa C901
             fp8_config=fp8_config if has_fp8_weight else None,
         )
 
+        indices = indices.to(dtype=indices_dtype)
+        offsets = offsets.to(dtype=indices_dtype)
+
         if not use_cpu:
             fc2 = (
-                cc(indices.int(), offsets.int())
+                cc(indices, offsets)
                 if not weighted
-                else cc(indices.int(), offsets.int(), xw.contiguous().view(-1))
+                else cc(indices, offsets, xw.contiguous().view(-1))
             )
         else:
             cc = cc.cpu()
             indices, offsets = indices.cpu(), offsets.cpu()
             fc2 = (
-                cc(indices.int(), offsets.int())
+                cc(indices, offsets)
                 if not weighted
-                else cc(indices.int(), offsets.int(), xw.contiguous().view(-1).cpu())
+                else cc(indices, offsets, xw.contiguous().view(-1).cpu())
             )
 
         if do_pooling and B == 0:
@@ -594,6 +598,7 @@ def execute_nbit_forward_(  # noqa C901
             )
         else:
             fc2_float = fc2.float()
+
         torch.testing.assert_close(
             fc2_float.cpu(),
             f.float().cpu(),
@@ -608,6 +613,7 @@ def execute_nbit_forward_(  # noqa C901
         pooling_mode=st.sampled_from(
             [PoolingMode.SUM, PoolingMode.NONE, PoolingMode.MEAN]
         ),
+        indices_dtype=st.sampled_from([torch.int32, torch.int64]),
         output_dtype=st.sampled_from(
             [SparseType.FP32, SparseType.FP16, SparseType.BF16]
         ),
@@ -623,6 +629,7 @@ def test_nbit_forward_cpu(
         use_array_for_index_remapping: bool,
         do_pruning: bool,
         pooling_mode: PoolingMode,
+        indices_dtype: torch.dtype,
         output_dtype: SparseType,
     ) -> None:
         use_cpu = True
@@ -666,11 +673,18 @@ def test_nbit_forward_cpu(
             use_array_for_index_remapping,
             do_pruning,
             mixed_weights_ty,
+            indices_dtype,
             output_dtype,
         )
 
+    @given(
+        indices_dtype=st.sampled_from([torch.int32, torch.int64]),
+    )
+    @settings(deadline=None)
     @unittest.skipIf(*gpu_unavailable)
-    def test_nbit_forward_gpu_no_cache_fp8_2048(self) -> None:
+    def test_nbit_forward_gpu_no_cache_fp8_2048(
+        self, indices_dtype: torch.dtype
+    ) -> None:
         # Test the case of FB8 table with 128B*8 < D <= 128B*16
         self.execute_nbit_forward_(
             T=1,
@@ -688,14 +702,30 @@ def test_nbit_forward_gpu_no_cache_fp8_2048(self) -> None:
             use_array_for_index_remapping=True,
             do_pruning=False,
             mixed_weights_ty=False,
+            indices_dtype=indices_dtype,
             output_dtype=SparseType.FP16,
         )
 
     @unittest.skipIf(*gpu_unavailable)
     @given(
-        nbit_weights_ty=get_nbit_weights_ty(),
-        use_array_for_index_remapping=st.booleans(),
+        nbit_weights_ty=st.sampled_from(
+            [
+                SparseType.FP32,
+                SparseType.FP16,
+                SparseType.FP8,
+                SparseType.INT8,
+                SparseType.INT4,
+                # None,
+                # SparseType.INT2,
+            ]
+        ),
+        # nbit_weights_ty=get_nbit_weights_ty(),
+        use_array_for_index_remapping=st.just(True),
         do_pruning=st.booleans(),
+        indices_dtype=st.sampled_from([torch.int32, torch.int64]),
+        output_dtype=st.sampled_from(
+            [SparseType.FP32, SparseType.FP16, SparseType.BF16]
+        ),
     )
     @settings(
         verbosity=VERBOSITY,
@@ -706,7 +736,9 @@ def test_nbit_forward_gpu_no_cache(
         self,
         nbit_weights_ty: Optional[SparseType],
         use_array_for_index_remapping: bool,
+        indices_dtype: torch.dtype,
         do_pruning: bool,
+        output_dtype: SparseType,
     ) -> None:
         use_cpu = False
         T = random.randint(1, 50)
@@ -742,9 +774,7 @@ def test_nbit_forward_gpu_no_cache(
         else:
             weights_ty: SparseType = nbit_weights_ty
             mixed_weights_ty = False
-        output_dtype = random.choice(
-            [SparseType.FP32, SparseType.FP16, SparseType.BF16]
-        )
+
         self.execute_nbit_forward_(
             T,
             D,
@@ -761,6 +791,7 @@ def test_nbit_forward_gpu_no_cache(
             use_array_for_index_remapping,
             do_pruning,
             mixed_weights_ty,
+            indices_dtype,
             output_dtype,
         )
@@ -983,6 +1014,7 @@ def test_nbit_forward_cpu_seq_int8(
         T=st.integers(min_value=10, max_value=20),
         L=st.integers(min_value=10, max_value=100),
         MAXH=st.integers(min_value=50, max_value=100),
+        indices_dtype=st.sampled_from([torch.int32, torch.int64]),
     )
     @settings(
         verbosity=VERBOSITY,
@@ -996,6 +1028,7 @@ def test_nbit_forward_cpu_seq_int4(
         T: int,
         L: int,
         MAXH: int,
+        indices_dtype: torch.dtype,
     ) -> None:
         """
         we init a quant table split embedding bag with int4 weights and scale of 1 and 0 bias
@@ -1017,6 +1050,7 @@ def test_nbit_forward_cpu_seq_int4(
             use_array_for_index_remapping=True,
             do_pruning=False,
             mixed_weights_ty=False,
+            indices_dtype=indices_dtype,
             output_dtype=SparseType.INT4,
         )
diff --git a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py
index a1433e61d..9b9bf27a3 100644
--- a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py
+++ b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py
@@ -191,12 +191,14 @@ def test_nbit_split_embedding_weights_with_scale_and_bias(
             ]
         ),
         emulate_pruning=st.booleans(),
+        indices_dtype=st.sampled_from([torch.int, torch.int64]),
     )
     @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
     def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
         self,
         weights_ty: SparseType,
         emulate_pruning: bool,
+        indices_dtype: torch.dtype,
     ) -> None:
         # TODO: support direct-mapped in int_nbit_split_embedding_uvm_caching_codegen_lookup_function
         # This test is for int_nbit_split_embedding_uvm_caching_codegen_lookup_function.
@@ -260,8 +262,8 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
         )
         for req in requests:
             indices, offsets = req.unpack_2()
-            indices = indices.int()
-            offsets = offsets.int()
+            indices = indices.to(dtype=indices_dtype)
+            offsets = offsets.to(dtype=indices_dtype)
             output_ref = cc_ref(indices, offsets)
 
             # int_nbit_split_embedding_uvm_caching_codegen_lookup_function for UVM_CACHING.
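
Usage sketch (illustrative, not part of the patch): the new `indices_dtype` test parameter exercises the property that the same inference TBE module now accepts both int32 and int64 indices/offsets and returns the same result. The sketch below assumes the `IntNBitTableBatchedEmbeddingBagsCodegen` constructor and call pattern used in the tests above; the table spec, sizes, and index values are made up.

import torch

from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)

# One INT4 table (128 rows, dim 32) held on the host, FP16 output; values are arbitrary.
cc = IntNBitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[("t0", 128, 32, SparseType.INT4, EmbeddingLocation.HOST)],
    output_dtype=SparseType.FP16,
    device="cpu",
)
cc.fill_random_weights()

# Two bags over the single table: rows {0, 3} and rows {7, 42}.
indices = torch.tensor([0, 3, 7, 42], dtype=torch.int32)
offsets = torch.tensor([0, 2, 4], dtype=torch.int32)

# int32 and int64 indices/offsets should both be accepted and produce the same pooled output.
out_int32 = cc(indices, offsets)
out_int64 = cc(indices.long(), offsets.long())
torch.testing.assert_close(out_int32, out_int64)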