diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
index 92eff015f..4429e580d 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
@@ -64,55 +64,70 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu(
   TENSOR_ON_CPU(offsets);
   TENSOR_ON_CPU(hash_table);
   TENSOR_ON_CPU(hash_table_offsets);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
-  int32_t T = hash_table_offsets.size(0) - 1;
-  int32_t B = (offsets.size(0) - 1) / T;
+  const int32_t T = hash_table_offsets.size(0) - 1;
+  const int32_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);
 
-  const auto* indices_acc = indices.data_ptr<int32_t>();
-  const auto* dense_indices_acc = dense_indices.data_ptr<int32_t>();
-
-  const auto* offsets_acc = offsets.data_ptr<int32_t>();
-  auto hash_table_acc = hash_table.accessor<int32_t, 2>();
-  const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
-  for (const auto t : c10::irange(T)) {
-    int64_t table_start = hash_table_offsets_acc[t];
-    int64_t table_end = hash_table_offsets_acc[t + 1];
-    if (table_start == table_end) {
-      continue;
-    }
-    int64_t capacity = table_end - table_start;
-    for (const auto b : c10::irange(B)) {
-      int32_t indices_start = offsets_acc[t * B + b];
-      int32_t indices_end = offsets_acc[t * B + b + 1];
-      int32_t L = indices_end - indices_start;
-      for (const auto l : c10::irange(L)) {
-        int32_t idx = indices_acc[indices_start + l];
-        int32_t dense_idx = dense_indices_acc[indices_start + l];
-        if (dense_idx == -1) {
-          // -1 means this row has been pruned, do not insert it.
-          continue;
-        }
-        uint32_t slot = pruned_hash_function(static_cast<uint32_t>(idx)) % capacity;
-        while (true) {
-          int32_t slot_sparse_idx = hash_table_acc[table_start + static_cast<int64_t>(slot)][0];
-          // empty slot
-          if (slot_sparse_idx == -1) {
-            hash_table_acc[table_start + static_cast<int64_t>(slot)][0] = idx;
-            hash_table_acc[table_start + static_cast<int64_t>(slot)][1] = dense_idx;
-            break;
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
+    using uidx_t =
+        std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
+
+    const auto* indices_acc = indices.data_ptr<index_t>();
+    const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+    const auto* offsets_acc = offsets.data_ptr<index_t>();
+
+    auto hash_table_acc = hash_table.accessor<int64_t, 2>();
+    const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
+
+    for (const auto t : c10::irange(T)) {
+      const auto table_start = hash_table_offsets_acc[t];
+      const auto table_end = hash_table_offsets_acc[t + 1];
+      if (table_start == table_end) {
+        continue;
+      }
+      const auto capacity = table_end - table_start;
+
+      for (const auto b : c10::irange(B)) {
+        const auto indices_start = offsets_acc[t * B + b];
+        const auto indices_end = offsets_acc[t * B + b + 1];
+        const auto L = indices_end - indices_start;
+
+        for (const auto l : c10::irange(L)) {
+          const auto idx = indices_acc[indices_start + l];
+          const auto dense_idx = dense_indices_acc[indices_start + l];
+          if (dense_idx == -1) {
+            // -1 means this row has been pruned, do not insert it.
+            continue;
           }
-          // already exists (shouldn't happen in practice)
-          if (slot_sparse_idx == idx) {
-            hash_table_acc[table_start + static_cast<int64_t>(slot)][1] = dense_idx;
-            break;
+
+          auto slot = pruned_hash_function(static_cast<uidx_t>(idx)) % capacity;
+          while (true) {
+            const auto ht_idx = table_start + static_cast<int64_t>(slot);
+            const auto slot_sparse_idx = hash_table_acc[ht_idx][0];
+
+            // Empty slot
+            if (slot_sparse_idx == -1) {
+              hash_table_acc[ht_idx][0] = idx;
+              hash_table_acc[ht_idx][1] = dense_idx;
+              break;
+            }
+
+            // Already exists (shouldn't happen in practice)
+            if (slot_sparse_idx == idx) {
+              hash_table_acc[ht_idx][1] = dense_idx;
+              break;
+            }
+
+            // Linear probe
+            slot = (slot + 1) % capacity;
           }
-          // linear probe
-          slot = (slot + 1) % capacity;
         }
       }
     }
-  }
+  });
+  return;
 }
 
@@ -414,7 +429,7 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
   TENSOR_ON_CPU(offsets);
   TENSOR_ON_CPU(hash_table);
   TENSOR_ON_CPU(hash_table_offsets);
-  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
   int32_t T = hash_table_offsets.size(0) - 1;
   int32_t B = (offsets.size(0) - 1) / T;
@@ -422,57 +437,63 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
 
   auto dense_indices = empty_like(indices);
 
-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_{{ wdesc }}_cpu", [&] {
-    using hash_t =
-        std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
+  AT_DISPATCH_INDEX_TYPES(hash_table.scalar_type(), "pruned_hashmap_lookup_{{ wdesc }}_cpu_0", [&] {
+    using hash_t = index_t;
 
-    const auto* indices_acc = indices.data_ptr<index_t>();
-    auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_{{ wdesc }}_cpu_1", [&] {
+      using uidx_t =
+          std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
 
-    const auto* offsets_acc = offsets.data_ptr<index_t>();
-    const auto hash_table_acc = hash_table.accessor<int32_t, 2>();
-    const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
-
-    for (const auto t : c10::irange(T)) {
-      const auto table_start = hash_table_offsets_acc[t];
-      const auto table_end = hash_table_offsets_acc[t + 1];
-      const auto capacity = table_end - table_start;
+      const auto* indices_acc = indices.data_ptr<index_t>();
+      auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+      const auto* offsets_acc = offsets.data_ptr<index_t>();
 
-      for (const auto b : c10::irange(B)) {
-        const auto indices_start = offsets_acc[t * B + b];
-        const auto indices_end = offsets_acc[t * B + b + 1];
-        const auto L = indices_end - indices_start;
+      const auto hash_table_acc = hash_table.accessor<hash_t, 2>();
+      const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
 
-        if (table_start == table_end) {
-          for (const auto l : c10::irange(L)) {
-            dense_indices_acc[indices_start + l] = indices_acc[indices_start + l];
-          }
-
-        } else {
-          for (const auto l : c10::irange(L)) {
-            const auto idx = indices_acc[indices_start + l];
-            auto slot = pruned_hash_function(static_cast<hash_t>(idx)) % capacity;
-
-            while (true) {
-              const auto slot_sparse_idx = hash_table_acc[table_start + static_cast<int64_t>(slot)][0];
-
-              // empty slot
-              if (slot_sparse_idx == -1) {
-                dense_indices_acc[indices_start + l] = -1;
-                break;
-              }
-              // already exists
-              if (slot_sparse_idx == idx) {
-                dense_indices_acc[indices_start + l] = hash_table_acc[table_start + static_cast<int64_t>(slot)][1];
-                break;
+      for (const auto t : c10::irange(T)) {
+        const auto table_start = hash_table_offsets_acc[t];
+        const auto table_end = hash_table_offsets_acc[t + 1];
+        const auto capacity = table_end - table_start;
+
+        for (const auto b : c10::irange(B)) {
+          const auto indices_start = offsets_acc[t * B + b];
+          const auto indices_end = offsets_acc[t * B + b + 1];
+          const auto L = indices_end - indices_start;
+
+          if (table_start == table_end) {
+            for (const auto l : c10::irange(L)) {
+              dense_indices_acc[indices_start + l] = indices_acc[indices_start + l];
+            }
+
+          } else {
+            for (const auto l : c10::irange(L)) {
+              const auto idx = indices_acc[indices_start + l];
+              auto slot = pruned_hash_function(static_cast<uidx_t>(idx)) % capacity;
+
+              while (true) {
+                const auto ht_idx = table_start + static_cast<int64_t>(slot);
+                const auto slot_sparse_idx = hash_table_acc[ht_idx][0];
+
+                // Empty slot
+                if (slot_sparse_idx == -1) {
+                  dense_indices_acc[indices_start + l] = -1;
+                  break;
+                }
+                // Already exists
+                if (slot_sparse_idx == idx) {
+                  dense_indices_acc[indices_start + l] = static_cast<index_t>(hash_table_acc[ht_idx][1]);
+                  break;
+                }
+
+                // Linear probe
+                slot = (slot + 1) % capacity;
               }
-              // linear probe
-              slot = (slot + 1) % capacity;
             }
           }
         }
       }
-  }
+    });
   });
 
   return dense_indices;
 
@@ -489,7 +510,7 @@ Tensor pruned_array_lookup_cpu(
   TENSOR_ON_CPU(offsets);
   TENSOR_ON_CPU(index_remappings);
   TENSOR_ON_CPU(index_remappings_offsets);
-  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
   int32_t T = index_remappings_offsets.size(0) - 1;
   int32_t B = (offsets.size(0) - 1) / T;
@@ -497,35 +518,39 @@ Tensor pruned_array_lookup_cpu(
 
   auto dense_indices = empty_like(indices);
 
-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cpu", [&] {
-    const auto* indices_acc = indices.data_ptr<index_t>();
-    auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
-    const auto* offsets_acc = offsets.data_ptr<index_t>();
+  AT_DISPATCH_INDEX_TYPES(index_remappings.scalar_type(), "pruned_array_lookup_cpu_0", [&] {
+    using hash_t = index_t;
+
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cpu_1", [&] {
+      const auto* indices_acc = indices.data_ptr<index_t>();
+      auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+      const auto* offsets_acc = offsets.data_ptr<index_t>();
 
-    const auto index_remappings_acc = index_remappings.data_ptr<index_t>();
-    const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();
+      const auto index_remappings_acc = index_remappings.data_ptr<hash_t>();
+      const auto index_remappings_offsets_acc = index_remappings_offsets.data_ptr<int64_t>();
 
-    at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
-      for (const auto t : c10::irange(begin, end)) {
-        const auto index_remappings_start = index_remappings_offsets_acc[t];
-        const auto index_remappings_end = index_remappings_offsets_acc[t + 1];
-        const auto capacity = index_remappings_end - index_remappings_start;
+      at::parallel_for(0, T, 1, [&](int64_t begin, int64_t end) {
+        for (const auto t : c10::irange(begin, end)) {
+          const auto index_remappings_start = index_remappings_offsets_acc[t];
+          const auto index_remappings_end = index_remappings_offsets_acc[t + 1];
+          const auto capacity = index_remappings_end - index_remappings_start;
 
-        const auto indices_start = offsets_acc[t * B];
-        const auto indices_end = offsets_acc[(t + 1) * B];
+          const auto indices_start = offsets_acc[t * B];
+          const auto indices_end = offsets_acc[(t + 1) * B];
 
-        if (capacity > 0) {
-          for (const auto i : c10::irange(indices_start, indices_end)) {
-            auto idx = indices_acc[i];
-            dense_indices_acc[i] = index_remappings_acc[index_remappings_start + idx];
+          if (capacity > 0) {
+            for (const auto i : c10::irange(indices_start, indices_end)) {
+              auto idx = indices_acc[i];
+              dense_indices_acc[i] = static_cast<index_t>(index_remappings_acc[index_remappings_start + idx]);
+            }
+          } else {
+            std::memcpy(
+                dense_indices_acc + indices_start,
+                indices_acc + indices_start,
+                (indices_end - indices_start) * sizeof(index_t));
           }
-      } else {
-        std::memcpy(
-            dense_indices_acc + indices_start,
-            indices_acc + indices_start,
-            (indices_end - indices_start) * sizeof(index_t));
-      }
+        });
     });
   });
 
diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
index 846cd4763..922c8ebd2 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu
@@ -14,14 +14,14 @@ using Tensor = at::Tensor;
 
 namespace nbit {
 
-template <typename index_t>
+template <typename index_t, typename hash_t>
 __global__
 __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel(
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
-    const pta::PackedTensorAccessor64<int32_t, 2, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor64<hash_t, 2, at::RestrictPtrTraits>
         hash_table,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         hash_table_offsets,
@@ -52,9 +52,13 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel(
     return;
   }
 
-  using hash_t =
+  using uidx_t =
       std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
 
+  // Use nv type of size (hash_t x 2)
+  using nv_hash_t =
+      std::conditional_t<std::is_same_v<hash_t, int64_t>, longlong2, int2>;
+
   const uint32_t subwarp_id = threadIdx.x / 4;
   const uint32_t subwarp_tid = threadIdx.x % 4;
 #ifdef USE_ROCM
@@ -66,15 +70,15 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel(
   for (int32_t l_start = 0; l_start + subwarp_id < L;
        l_start += kWarpSize / 4) {
     const index_t idx = indices[indices_start + l_start + subwarp_id];
-    hash_t slot_start =
-        pruned_hash_function(static_cast<hash_t>(idx)) % capacity;
+    auto slot_start = pruned_hash_function(static_cast<uidx_t>(idx)) % capacity;
 
     while (true) {
-      const hash_t slot = (slot_start + subwarp_tid) % capacity;
-      const int2 val = *reinterpret_cast<const int2*>(
+      const auto slot = (slot_start + subwarp_tid) % capacity;
+
+      const nv_hash_t val = *reinterpret_cast<const nv_hash_t*>(
           &hash_table[table_start + static_cast<int64_t>(slot)][0]);
-      const int32_t slot_sparse_idx = val.x;
-      const int32_t slot_dense_idx = val.y;
+      const auto slot_sparse_idx = val.x;
+      const auto slot_dense_idx = val.y;
 
       bool found = false;
       bool empty = false;
@@ -96,14 +100,14 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel(
   }
 }
 
-template <typename index_t>
+template <typename index_t, typename hash_t>
 __global__
 __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel(
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
-    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<hash_t, 1, at::RestrictPtrTraits>
         index_remappings,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         index_remappings_offsets,
@@ -129,7 +133,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel(
     for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
       index_t idx = indices[indices_start + l];
       dense_indices[indices_start + l] =
-          index_remappings[index_remappings_start + idx];
+          static_cast<index_t>(index_remappings[index_remappings_start + idx]);
     }
   } else {
     for (index_t l = threadIdx.x; l < L; l += blockDim.x) {
@@ -149,7 +153,7 @@ Tensor pruned_hashmap_lookup_cuda(
     Tensor hash_table_offsets) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
       indices, offsets, hash_table, hash_table_offsets);
-  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, hash_table);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
   CUDA_DEVICE_GUARD(indices);
 
@@ -160,25 +164,32 @@ Tensor pruned_hashmap_lookup_cuda(
   TORCH_CHECK(hash_table.size(0) < std::numeric_limits<int32_t>::max());
   constexpr size_t kForwardMaxThreads = 256;
 
-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_lookup_cuda", [&] {
+  AT_DISPATCH_INDEX_TYPES(
+      hash_table.scalar_type(), "pruned_hashmap_lookup_cuda_0", [&] {
+        using hash_t = index_t;
+
+        AT_DISPATCH_INDEX_TYPES(
+            indices.scalar_type(), "pruned_hashmap_lookup_cuda_1", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-    const auto func_name =
-        "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
+              const auto func_name =
+                  "int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel";
 #endif
-    int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
-        nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
-        dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
-        0,
-        at::cuda::getCurrentCUDAStream()>>>(
-        MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name, hash_table, index_t, 2, 64),
-        MAKE_PTA_WITH_NAME(func_name, hash_table_offsets, int64_t, 1, 32),
-        B,
-        T,
-        MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
-  });
+              int_nbit_split_embedding_codegen_forward_pruned_hashmap_lookup_kernel<<<
+                  nbit::div_round_up(B * T + 1, kForwardMaxThreads / kWarpSize),
+                  dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
+                  0,
+                  at::cuda::getCurrentCUDAStream()>>>(
+                  MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+                  MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
+                  MAKE_PTA_WITH_NAME(func_name, hash_table, hash_t, 2, 64),
+                  MAKE_PTA_WITH_NAME(
+                      func_name, hash_table_offsets, int64_t, 1, 32),
+                  B,
+                  T,
+                  MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
+            });
+      });
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   return dense_indices;
@@ -191,7 +202,7 @@ Tensor pruned_array_lookup_cuda(
     Tensor index_remappings_offsets) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
      indices, offsets, index_remappings, index_remappings_offsets);
-  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets, index_remappings);
+  TENSORS_HAVE_SAME_SCALAR_TYPE(indices, offsets);
 
   CUDA_DEVICE_GUARD(indices);
 
@@ -218,25 +229,34 @@ Tensor pruned_array_lookup_cuda(
   TORCH_CHECK(dense_indices.dim() == 1, "Tensor dim: ", dense_indices.dim());
   constexpr size_t kForwardMaxThreads = 256;
 
-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_array_lookup_cuda", [&] {
+  AT_DISPATCH_INDEX_TYPES(
+      index_remappings.scalar_type(), "pruned_array_lookup_cuda_0", [&] {
+        using hash_t = index_t;
+
+        AT_DISPATCH_INDEX_TYPES(
+            indices.scalar_type(), "pruned_array_lookup_cuda_1", [&] {
 #ifdef FBGEMM_GPU_MEMCHECK
-    const auto func_name =
-        "int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel";
+              const auto func_name =
+                  "int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel";
 #endif
-    int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
-        nbit::div_round_up(offsets.size(0), kForwardMaxThreads / kWarpSize),
-        dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
-        0,
-        at::cuda::getCurrentCUDAStream()>>>(
-        MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name, index_remappings, index_t, 1, 32),
-        MAKE_PTA_WITH_NAME(func_name, index_remappings_offsets, int64_t, 1, 32),
-        B,
-        T,
-        MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
-  });
+              int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel<<<
+                  nbit::div_round_up(
+                      offsets.size(0), kForwardMaxThreads / kWarpSize),
+                  dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
+                  0,
+                  at::cuda::getCurrentCUDAStream()>>>(
+                  MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+                  MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
+                  MAKE_PTA_WITH_NAME(
+                      func_name, index_remappings, hash_t, 1, 32),
+                  MAKE_PTA_WITH_NAME(
+                      func_name, index_remappings_offsets, int64_t, 1, 32),
+                  B,
+                  T,
+                  MAKE_PTA_WITH_NAME(func_name, dense_indices, index_t, 1, 32));
+            });
+      });
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   return dense_indices;
 
diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
index d90f12e0a..306a93728 100644
--- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -2460,43 +2460,106 @@ def need_split_k(SIZE_M, SIZE_N, SIZE_K):
     return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
 
 
+# Configs adapted from https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/tools/tune_gemm/tune_gemm.py
 def prune_configs(configs, named_args, **kwargs):
+    pruned_configs = []
+    M = named_args["M"]
+    N = named_args["N"]
+    K = named_args["K"]
+    elemBytes_a = named_args["A"].element_size()
+    elemBytes_b = named_args["B"].element_size()
+
+    if M < 32 or N < 32:
+        mfma = 16
+    else:
+        mfma = 32
 
-    SIZE_M = named_args["A"].shape[0]
-    SIZE_N = named_args["B"].shape[1]
-    SIZE_K = named_args["C"].shape[1]
+    # TODO (zhanglx): figure out the boundary between large and small gemms
+    large_gemm = False
+    if M >= 2048 and N >= 2048:
+        large_gemm = True
 
-    pruned_configs = []
     for config in configs:
-        kw = config.kwargs
-        BLOCK_SIZE_M, BLOCK_SIZE_N, _ = (
-            kw["BLOCK_M"],
-            kw["BLOCK_N"],
-            kw["BLOCK_K"],
-        )
-        SPLIT_K = kw["SPLIT_K"]
-        if SIZE_M <= 32 and BLOCK_SIZE_M != 32:
+        BLOCK_SIZE_M = config.kwargs.get("BLOCK_M")
+        BLOCK_SIZE_N = config.kwargs.get("BLOCK_N")
+        BLOCK_SIZE_K = config.kwargs.get("BLOCK_K")
+        num_warps = config.num_warps
+        matrix_instr_nonkdim = config.kwargs.get("matrix_instr_nonkdim")
+        if matrix_instr_nonkdim > mfma:
+            continue
+        if mfma == 4 and BLOCK_SIZE_K < 64:
+            continue
+        # some layouts could not work properly in case
+        # the number of elements per thread is less than 1
+        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
+            continue
+        SPLIT_K = config.kwargs.get("SPLIT_K")
+        GROUP_M = config.kwargs.get("GROUP_M")
+        if BLOCK_SIZE_M < matrix_instr_nonkdim or BLOCK_SIZE_N < matrix_instr_nonkdim:
+            continue
+        if M <= matrix_instr_nonkdim and BLOCK_SIZE_M != matrix_instr_nonkdim:
+            continue
+        if N <= matrix_instr_nonkdim and BLOCK_SIZE_N != matrix_instr_nonkdim:
            continue
-        if SIZE_N <= 32 and BLOCK_SIZE_N != 32:
+        # Skip BLOCK_SIZE that is too large compared to M/N
+        # unless BLOCK_SIZE is already small enough
+        if BLOCK_SIZE_M > M * 2 and BLOCK_SIZE_M != 16:
+            continue
+        if BLOCK_SIZE_N > N * 2 and BLOCK_SIZE_N != 16:
             continue
         # skip large split_k when not necessary
-        if SPLIT_K != 1 and not need_split_k(SIZE_M, SIZE_N, SIZE_K):
+        if SPLIT_K != 1 and not need_split_k(M, N, K):
+            continue
+        # skip split_k that leads to EVEN_K = false
+        leap = SPLIT_K * BLOCK_SIZE_K
+        modv = K % leap
+        if modv != 0:
+            continue
+        # skip large GROUP_M
+        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
+            continue
+        # out of shared memory resource
+        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
+        LDS = (
+            BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
+            + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
+        )
+        if LDS > 65536:
             continue
+        # Skip small block sizes and num_warps for large gemm
+        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
+        if large_gemm:
+            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
+                continue
+            if BLOCK_SIZE_K < 64:
+                continue
+            if num_warps < 4:
+                continue
+
         pruned_configs.append(config)
-    logging.info(f"pruned_configs: config len{len(pruned_configs)}")
+
+    print(f"{len(configs)=} {len(pruned_configs)=}")
+    if len(pruned_configs) == 0:
+        print(f"No configs left after pruning! {M=} {N=} {K=}")
+        pruned_configs = configs[:10]
     return pruned_configs
 
 
-def get_full_non_persistent_tuning_space(use_split_k):
-    if torch.version.hip is None:
-        logger.warning("Using HIP configs on CUDA device, this may be slow.")
+def get_full_non_persistent_tuning_space():
     configs = []
-    block_mn_range = [32, 64, 128, 256]
-    block_k_range = [32, 64, 128]
+
+    block_mn_range = [16, 32, 64, 128, 256]
+    block_k_range = [16, 32, 64, 128, 256]
     split_k_range = [1]
-    num_warps_range = [1, 2, 4, 8, 16]
-    group_m_range = [1, 4, 8]
+    num_warps_range = [1, 2, 4, 8]
+    group_m_range = [1, 2, 4, 8, 16, 32]
+    # For now we see better perf with num_stages=0 for all gemm configs we care about.
+    # But keep this explicit so that we do not forget we may need to set it to
+    # other values in the future
     num_stage_range = [0]
+    waves_per_eu_range = [0]
+    matrix_instr_nonkdim_range = [16, 32]
+    kpack_range = [1, 2]
 
     for block_m in block_mn_range:
         for block_n in block_mn_range:
@@ -2505,28 +2568,36 @@ def get_full_non_persistent_tuning_space(use_split_k):
                 for group_m in group_m_range:
                     for split_k in split_k_range:
                         for num_stages in num_stage_range:
-                            configs.append(
-                                triton.Config(
-                                    {
-                                        "BLOCK_M": block_m,
-                                        "BLOCK_N": block_n,
-                                        "BLOCK_K": block_k,
-                                        "GROUP_M": group_m,
-                                        "SPLIT_K": split_k,
-                                    },
-                                    num_stages=num_stages,
-                                    num_warps=num_warps,
-                                )
-                            )
-
+                            for waves_per_eu in waves_per_eu_range:
+                                for (
+                                    matrix_instr_nonkdim
+                                ) in matrix_instr_nonkdim_range:
+                                    for kpack in kpack_range:
+                                        configs.append(
+                                            triton.Config(
+                                                {
+                                                    "BLOCK_M": block_m,
+                                                    "BLOCK_N": block_n,
+                                                    "BLOCK_K": block_k,
+                                                    "GROUP_M": group_m,
+                                                    "SPLIT_K": split_k,
+                                                    "waves_per_eu": waves_per_eu,
+                                                    "matrix_instr_nonkdim": matrix_instr_nonkdim,
+                                                    "kpack": kpack,
+                                                },
+                                                num_warps=num_warps,
+                                                num_stages=num_stages,
+                                            )
+                                        )
+    logger.info(f"all configs #: {len(configs)}")
     return configs
 
 
-MATMUL_CONFIGS: List[Config] = get_full_non_persistent_tuning_space(True)
+MATMUL_CONFIGS_NON_PERSISTENT: List[Config] = get_full_non_persistent_tuning_space()
 
 
 @triton.autotune(
-    configs=MATMUL_CONFIGS,
+    configs=MATMUL_CONFIGS_NON_PERSISTENT,
     key=["M", "N", "K"],
     prune_configs_by={
         "early_config_prune": prune_configs,
 
diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
index f1671d29d..541297b90 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -397,13 +397,14 @@ def max_ty_D(ty: SparseType) -> int:
         self.assign_embedding_weights(weight_lists)
 
         # Handle index remapping for embedding pruning.
+        # All buffers are int64 in order to support both int32 and int64 indices.
         self.register_buffer(
             "index_remappings_array_offsets",
             torch.empty(0, device=self.current_device, dtype=torch.int64),
         )
         self.register_buffer(
             "index_remappings_array",
-            torch.empty(0, device=self.current_device, dtype=torch.int32),
+            torch.empty(0, device=self.current_device, dtype=torch.int64),
         )
         self.register_buffer(
             "index_remapping_hash_table_offsets",
@@ -411,7 +412,7 @@ def max_ty_D(ty: SparseType) -> int:
         )
         self.register_buffer(
             "index_remapping_hash_table",
-            torch.empty(0, device=self.current_device, dtype=torch.int32),
+            torch.empty(0, device=self.current_device, dtype=torch.int64),
         )
         self.register_buffer(
             "original_rows_per_table",
@@ -946,8 +947,9 @@ def reset_embedding_spec_location(
     @torch.jit.export
     def recompute_module_buffers(self) -> None:
         """
-        Compute module buffers that're on meta device and are not materialized in reset_weights_placements_and_offsets().
-        Currently those buffers are `weights_tys`, `rows_per_table`, `D_offsets` and `bounds_check_warning`.
+        Compute module buffers that're on meta device and are not materialized
+        in reset_weights_placements_and_offsets(). Currently those buffers are
+        `weights_tys`, `rows_per_table`, `D_offsets` and `bounds_check_warning`.
         Pruning related or uvm related buffers are not computed right now.
         """
         if (
@@ -1527,11 +1529,11 @@ def set_index_remappings_array(
             index_remappings_filter_nones.append(mapping)
         if len(index_remappings_filter_nones) == 0:
             self.index_remappings_array = torch.empty(
-                0, dtype=torch.int32, device=self.current_device
+                0, dtype=torch.int64, device=self.current_device
             )
         else:
             self.index_remappings_array = torch.cat(index_remappings_filter_nones).to(
-                self.current_device
+                dtype=torch.int64, device=self.current_device
             )
 
     def set_index_remappings(
@@ -1554,7 +1556,7 @@ def set_index_remappings(
         ]
         hash_table = torch.empty(
             (sum(capacities), 2),
-            dtype=torch.int32,
+            dtype=torch.int64,
         )
         hash_table[:, :] = -1
         hash_table_offsets = torch.tensor([0] + list(accumulate(capacities))).long()
 
diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
index 11948bb22..8b78bb854 100644
--- a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
+++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
@@ -24,7 +24,7 @@ using Tensor = at::Tensor;
 
 namespace fbgemm_gpu {
 
-// Kernerl for permute pooled embedding op.
+// Kernel for permute pooled embedding op.
 // This kernel is moving D elements per warp.
 template <typename scalar_t>
 __global__ void permute_multi_embs_kernel(
@@ -40,7 +40,7 @@ __global__ void permute_multi_embs_kernel(
     const int32_t permute_size) {
   // workers in a warp handle exact one permute (of a feature/key)
   const int32_t worker_id = threadIdx.x;
-  const int32_t permute_id = threadIdx.y + blockIdx.x * blockDim.x;
+  const int32_t permute_id = threadIdx.y + blockIdx.x * blockDim.y;
   const int32_t batch_id = blockIdx.y + gridDim.y * blockIdx.z;
   if (batch_id >= batch_size) {
     return;
 
diff --git a/fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py b/fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py
index 1aee221a2..920a86cbd 100644
--- a/fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py
+++ b/fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py
@@ -87,7 +87,7 @@ def get_nbit_weights_ty(draw) -> Optional[SparseType]:
 
 
 # @optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
-class NBitFowardTest(unittest.TestCase):
+class NBitFowardAutovecTest(unittest.TestCase):
     def execute_nbit_forward_(  # noqa C901
         self,
         T: int,
 
diff --git a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
index f2872bb4e..8f4c32eea 100644
--- a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
+++ b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py
@@ -122,6 +122,11 @@ def get_nbit_weights_ty(draw) -> Optional[SparseType]:
             "Operator outputs int4 tensors which do not support opcheck tests"
         ),
     },
+    "test_pt2_compliant_tag_fbgemm_int_nbit_split_embedding_codegen_lookup_function": [
+        unittest.skip(
+            "Operator outputs int4 tensors which do not support opcheck tests"
+        ),
+    ],
 }
 
 
diff --git a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py
index 439797688..a1433e61d 100644
--- a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py
+++ b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py
@@ -11,6 +11,7 @@
 
 import random
 import unittest
+from typing import Callable, Dict, List
 
 import hypothesis.strategies as st
 import numpy as np
@@ -38,8 +39,22 @@
 
 VERBOSITY: Verbosity = Verbosity.verbose
 
+# pyre-ignore
+additional_decorators: Dict[str, List[Callable]] = {
+    "test_faketensor__test_nbit_forward_cpu_seq_int4": {
+        unittest.skip(
+            "Operator outputs int4 tensors which do not support opcheck tests"
+        ),
+    },
+    "test_pt2_compliant_tag_fbgemm_int_nbit_split_embedding_codegen_lookup_function": {
+        unittest.skip(
+            "Operator outputs int4 tensors which do not support opcheck tests"
+        ),
+    },
+}
+
 
-@optests.generate_opcheck_tests(fast=True)
+@optests.generate_opcheck_tests(fast=True, additional_decorators=additional_decorators)
 class NBitSplitEmbeddingsTest(unittest.TestCase):
     @unittest.skipIf(*gpu_unavailable)
     @given(
 
diff --git a/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py b/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py
index 1d475909c..5d9b3eabe 100644
--- a/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py
+++ b/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py
@@ -469,7 +469,7 @@ def test_pruning(
         # Initialize and insert Hashmap index remapping based data structure
         hash_table = torch.empty(
             (sum(capacities), 2),
-            dtype=torch.int32,
+            dtype=torch.int64,
         )
         hash_table[:, :] = -1
         hash_table_offsets = torch.tensor([0] + np.cumsum(capacities).tolist()).long()
@@ -486,7 +486,7 @@ def test_pruning(
         # Initialize and insert Array index remapping based data structure
         index_remappings_array = torch.tensor(
             [-1] * original_E * T,
-            dtype=torch.int32,
+            dtype=torch.int64,
             device=current_device,
         )
         index_remappings_array_offsets = torch.empty(
@@ -498,7 +498,7 @@ def test_pruning(
         for t in range(T):
             indice_t = (indices.view(T, B, L))[t].long().view(-1).to(current_device)
             dense_indice_t = (
-                (dense_indices.view(T, B, L))[t].view(-1).to(current_device)
+                (dense_indices.view(T, B, L))[t].long().view(-1).to(current_device)
            )
             selected_indices = torch.add(indice_t, t * original_E)[:E]
             index_remappings_array[selected_indices] = dense_indice_t
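
Note (illustration only, not part of the patch): the CPU and CUDA pruned-hashmap paths changed above all share the same open-addressing scheme. The hash table is a (capacity, 2) tensor of [sparse_index, dense_index] pairs, initialized to -1, probed linearly, and this patch widens its storage to int64 so that both int32 and int64 sparse indices can be remapped. The sketch below models that scheme in plain Python/PyTorch so the kernel logic is easier to follow; toy_hash is a placeholder, not the library's pruned_hash_function, whose definition does not appear in this diff.

    # Minimal sketch of the linear-probing scheme used by the pruned hashmap ops.
    # Assumptions: toy_hash is a stand-in hash, and the table is never completely full
    # (the real kernels make the same assumption, so probing always terminates).
    import torch

    def toy_hash(x: int) -> int:
        # Placeholder integer mix; NOT the real pruned_hash_function.
        return (x * 2654435761) & 0xFFFFFFFFFFFFFFFF

    def insert(hash_table: torch.Tensor, sparse_idx: int, dense_idx: int) -> None:
        capacity = hash_table.size(0)
        slot = toy_hash(sparse_idx) % capacity
        while True:
            slot_sparse_idx = int(hash_table[slot, 0])
            if slot_sparse_idx == -1:          # empty slot: claim it
                hash_table[slot, 0] = sparse_idx
                hash_table[slot, 1] = dense_idx
                return
            if slot_sparse_idx == sparse_idx:  # already present: update mapping
                hash_table[slot, 1] = dense_idx
                return
            slot = (slot + 1) % capacity       # linear probe

    def lookup(hash_table: torch.Tensor, sparse_idx: int) -> int:
        capacity = hash_table.size(0)
        slot = toy_hash(sparse_idx) % capacity
        while True:
            slot_sparse_idx = int(hash_table[slot, 0])
            if slot_sparse_idx == -1:          # empty slot: index was pruned
                return -1
            if slot_sparse_idx == sparse_idx:  # hit
                return int(hash_table[slot, 1])
            slot = (slot + 1) % capacity       # linear probe

    # Table initialized to -1, as in set_index_remappings() above.
    table = torch.full((8, 2), -1, dtype=torch.int64)
    insert(table, sparse_idx=2**40, dense_idx=7)  # an index that would overflow int32
    assert lookup(table, 2**40) == 7
    assert lookup(table, 12345) == -1

With an int32 table an index such as 2**40 could not be stored at all; the int64 layout above is what the new hash_t dispatch paths in the C++/CUDA code operate on. Separately, note that the widened Triton tuning space in fp8_gemm.py enumerates 5 x 5 x 5 x 4 x 6 x 2 x 2 = 12,000 configs per (M, N, K) key, which is why the rewritten prune_configs applies the additional MFMA, LDS, and block-size filters before autotuning.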