From 3602496078ae48f6360310ef3c9490a48b56796d Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Sat, 5 Oct 2024 11:35:18 +0000 Subject: [PATCH] 2024-10-05 nightly release (7a4472a8d14912ab7a3b7ca12bca030448f8fec8) --- fbgemm_gpu/experimental/gen_ai/CMakeLists.txt | 1 + .../bf16i4bf16_rowwise_batched.cu | 298 ++++++++++++++++++ .../f8f8bf16_rowwise_batched.cu | 13 + .../cutlass_extensions/include/kernel_mode.h | 16 +- .../gen_ai/src/quantize/quantize.cpp | 28 +- .../gen_ai/test/quantize/quantize_test.py | 60 ++-- .../sparse_pack_segments_backward.cu | 16 +- fbgemm_gpu/test/release/utils.py | 2 + fbgemm_gpu/test/sparse/failures_dict.json | 10 + fbgemm_gpu/test/sparse/pack_segments_test.py | 174 +++++++++- 10 files changed, 581 insertions(+), 37 deletions(-) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt index 5accb9c53..dd6f165ed 100644 --- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt @@ -49,6 +49,7 @@ else() src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu src/quantize/cutlass_extensions/i8i8bf16_dynamic.cu src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu + src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu src/quantize/quantize.cu src/quantize/quantize.cpp) endif() diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu new file mode 100644 index 000000000..871543a2f --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu @@ -0,0 +1,298 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include + +// clang-format off +// The fixed ordering of the headers is required for CUTLASS 3.2+ +#include +#include // @manual +#include // @manual +#include // @manual +// clang-format on + +#include "cutlass_extensions/include/kernel_mode.h" + +namespace fbgemm_gpu { + +#if CUDART_VERSION >= 12000 + +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool PONG, + typename WEIGHT_SCALE_DTYPE> +at::Tensor bf16i4bf16_rowwise_batched_impl( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor w_scale, + at::Tensor w_zp) { + // XQ: B x M x K + // WQ: B x N x K + // output: B x M x N + int B = X.size(0); + int M = X.size(1); + int N = WQ.size(1); + int K = X.size(2); + + int num_groups = w_scale.size(0) / B; + + TORCH_CHECK(X.is_cuda() && X.is_contiguous()); + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + TORCH_CHECK(w_scale.is_cuda() && w_scale.is_contiguous()); + TORCH_CHECK(w_zp.is_cuda() && w_zp.is_contiguous()); + TORCH_CHECK(K >= num_groups && K % num_groups == 0); + + int group_size = K / num_groups; + + auto Y = at::empty({B, M, N}, X.options().dtype(at::kBFloat16)); + + using ElementInputA = cutlass::bfloat16_t; + using LayoutInputA = cutlass::layout::ColumnMajor; + constexpr int AlignmentInputA = + 128 / + cutlass::sizeof_bits< + ElementInputA>::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) + + using ElementInputB = cutlass::int4b_t; + using LayoutInputB = cutlass::layout::RowMajor; + constexpr int AlignmentInputB = + 128 / + cutlass::sizeof_bits< + ElementInputB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + + using ElementScale = WEIGHT_SCALE_DTYPE; + using ElementZeroPoint = WEIGHT_SCALE_DTYPE; + using ElementComputeEpilogue = float; + using ElementAccumulator = float; + + using ElementOutput = cutlass::bfloat16_t; + using LayoutOutput = cutlass::layout::ColumnMajor; + constexpr int AlignmentOutput = + 128 / + cutlass::sizeof_bits< + ElementOutput>::value; // Memory access granularity/alignment of C + // matrix in units of elements (up to 16 bytes) + + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Threadblock-level + // tile size + using ClusterShape = cute::Shape< + cute::Int, + cute::Int, + cute::Int>; // Shape of the + // threadblocks in a + // cluster + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput; + using PongSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + using MainLoopSchedule = + cute::conditional_t; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + EpilogueTileType, + ElementAccumulator, + ElementAccumulator, + ElementOutput, + LayoutOutput, + AlignmentOutput, + ElementOutput, + LayoutOutput, + AlignmentOutput, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + cute::tuple, + LayoutInputB, + AlignmentInputB, + ElementInputA, + LayoutInputA, + AlignmentInputA, 
+ ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideInputA = typename Gemm::GemmKernel::StrideA; + using StrideInputB = typename Gemm::GemmKernel::StrideB; + using StrideOutput = typename Gemm::GemmKernel::StrideC; + using StrideS = typename CollectiveMainloop::StrideScale; + + StrideInputA stride_a = cutlass::make_cute_packed_stride( + StrideInputA{}, cute::make_shape(M, K, B)); + StrideInputB stride_b = cutlass::make_cute_packed_stride( + StrideInputB{}, cute::make_shape(N, K, B)); + StrideOutput stride_output = cutlass::make_cute_packed_stride( + StrideOutput{}, cute::make_shape(N, M, B)); + StrideS stride_S = cutlass::make_cute_packed_stride( + StrideS{}, cute::make_shape(N, num_groups, B)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {N, M, K, B}, + {reinterpret_cast(WQ.data_ptr()), + stride_b, + reinterpret_cast(X.data_ptr()), + stride_a, + reinterpret_cast(w_scale.data_ptr()), + stride_S, + group_size, + reinterpret_cast(w_zp.data_ptr())}, + {{1.0, 0.0}, + (ElementOutput*)Y.data_ptr(), + stride_output, + (ElementOutput*)Y.data_ptr(), + stride_output}}; + + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return Y; +} + +template +at::Tensor dispatch_bf16i4bf16_rowwise_batched_kernel( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor w_scale, + at::Tensor w_zp) { + KernelMode kernel = get_batched_kernel_mode(X, WQ); + if (kernel == KernelMode::Small) { + return bf16i4bf16_rowwise_batched_impl< + 64, + 128, + 64, + 2, + 1, + 1, + true, + WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); + } else if (kernel == KernelMode::Large) { + return bf16i4bf16_rowwise_batched_impl< + 128, + 128, + 64, + 2, + 1, + 1, + true, + WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); + } else { + return bf16i4bf16_rowwise_batched_impl< + 128, + 128, + 64, + 2, + 1, + 1, + false, + WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp); + } +} + +at::Tensor bf16i4bf16_rowwise_batched( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor w_scale, + at::Tensor w_zp) { + // Check datatypes. 
+ TORCH_CHECK( + (w_scale.dtype() == at::kFloat && w_zp.dtype() == at::kFloat) || + (w_scale.dtype() == at::kHalf && w_zp.dtype() == at::kHalf) || + (w_scale.dtype() == at::kBFloat16 && w_zp.dtype() == at::kBFloat16), + "Weight scale and zero point tensors must be float32, bfloat16, or float16, and dtype of weight scale and zero point tensors must be the same ."); + + if (w_scale.dtype() == at::kFloat) { + return dispatch_bf16i4bf16_rowwise_batched_kernel( + X, WQ, w_scale, w_zp); + } else if (w_scale.dtype() == at::kHalf) { + return dispatch_bf16i4bf16_rowwise_batched_kernel( + X, WQ, w_scale, w_zp); + } else if (w_scale.dtype() == at::kBFloat16) { + return dispatch_bf16i4bf16_rowwise_batched_kernel( + X, WQ, w_scale, w_zp); + } else { + throw std::runtime_error( + "Weight scale and zero point data type not supported in bf16i4bf16_rowwise_batched"); + } +} + +#else + +at::Tensor bf16i4bf16_rowwise_batched( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor w_scale, + at::Tensor w_zp) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu index 313c81298..a34c694e0 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu @@ -335,6 +335,19 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel( UseBias, InputDType, BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); + } else if (kernel == KernelMode::Medium) { + return f8f8bf16_rowwise_batched_impl< + 64, + 128, + 128, + 1, + 2, + 1, + true, + FastAccum, + UseBias, + InputDType, + BiasDType>(XQ, WQ, x_scale, w_scale, bias, output); } else if (kernel == KernelMode::Large) { return f8f8bf16_rowwise_batched_impl< 128, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h index 93b96fb04..9a267193a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h @@ -12,7 +12,7 @@ namespace fbgemm_gpu { -enum class KernelMode { Small, Large, Default }; +enum class KernelMode { Small, Medium, Large, Default }; inline KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) { auto M = XQ.size(0); @@ -37,14 +37,14 @@ inline KernelMode get_batched_kernel_mode(at::Tensor XQ, at::Tensor WQ) { auto K = XQ.size(2); auto N = WQ.size(1); auto BM = B * M; - auto BN = B * N; - auto BK = B * K; - // Use a large kernel if at least two shapes are large.... 
- bool use_large_kernel = - ((BM >= 2048 && BK >= 2048) || (BM >= 2048 && BK >= 2048) || - (BK >= 2048 && BN >= 2048)); - if (BM <= 128 || BN <= 128) { + // Heuristic to determine kernel mode + bool use_medium_kernel = + ((BM <= 512 && ((N <= 8192 && K < 8192) || (N < 8192 && K <= 8192)))); + bool use_large_kernel = ((BM > 512 && (N >= 1024 || K >= 1024))); + if (BM <= 128 || N <= 128) { return KernelMode::Small; + } else if (use_medium_kernel) { + return KernelMode::Medium; } else if (use_large_kernel) { return KernelMode::Large; } else { diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index 101a5cba1..1abf8fb40 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -96,6 +96,11 @@ at::Tensor bf16i4bf16_rowwise( at::Tensor WQ, at::Tensor w_scale, at::Tensor w_zp); +at::Tensor bf16i4bf16_rowwise_batched( + at::Tensor X, + at::Tensor WQ, + at::Tensor w_scale, + at::Tensor w_zp); at::Tensor per_tensor_quantize_i8(at::Tensor X, double scale); std::tuple per_tensor_dynamic_quantize_i8(at::Tensor X); @@ -152,6 +157,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor"); m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise); + m.def( + "bf16i4bf16_rowwise_batched(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor"); + m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched); + m.def( "i8i8bf16_dynamic(Tensor XQ, Tensor WQ, Tensor scale, int split_k=1) -> Tensor"); m.impl("i8i8bf16_dynamic", i8i8bf16_dynamic); @@ -326,14 +335,28 @@ at::Tensor f8i4bf16_rowwise_meta( at::Tensor bf16i4bf16_rowwise_meta( at::Tensor X, // BF16 at::Tensor WQ, // INT4 - at::Tensor w_scale, - at::Tensor w_zp) { + at::Tensor /* w_scale */, + at::Tensor /* w_zp */ +) { int M = X.size(0); int N = WQ.size(0); auto Y = at::empty({M, N}, X.options().dtype(at::kBFloat16)); return Y; } +at::Tensor bf16i4bf16_rowwise_batched_meta( + at::Tensor X, // BF16 + at::Tensor WQ, // INT4 + at::Tensor /* w_scale */, + at::Tensor /* w_zp */ +) { + int B = X.size(0); + int M = X.size(1); + int N = WQ.size(1); + auto Y = at::empty({B, M, N}, X.options().dtype(at::kBFloat16)); + return Y; +} + std::vector quantize_fp8_per_row_meta( at::Tensor input, std::optional bs, @@ -370,6 +393,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched_meta); m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta); m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta); + m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta); #endif } diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index c21c1713a..38a09f360 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -673,6 +673,7 @@ def fp8_loopover_bmm( M=st.sampled_from([2048, 4096]), N=st.sampled_from([256, 512]), K=st.sampled_from([256, 512]), + use_loopover=st.sampled_from([True, False]), ) def test_int4_batched_gemm( self, @@ -680,6 +681,7 @@ def test_int4_batched_gemm( M: int, N: int, K: int, + use_loopover: bool, ) -> None: if not MARLIN_ENABLED: return @@ -689,28 +691,48 @@ def test_int4_batched_gemm( wq = [] w_scale = [] group_size = 128 - for i in range(B): - _, wq_, w_scale_ = marlin_quantize(w[i].cuda().t().contiguous(), 
group_size) - wq.append(wq_) - w_scale.append(w_scale_) - wq = torch.stack(wq) - w_scale = torch.stack(w_scale) - - def int4_loopover_bmm( - x: torch.Tensor, - wq: torch.Tensor, - w_scale: torch.Tensor, - ) -> torch.Tensor: - B = x.shape[0] - M = x.shape[1] - N = w_scale.shape[2] - y = torch.empty((B, M, N), dtype=torch.bfloat16, device=x[0].device) + + if use_loopover: for i in range(B): - y[i] = torch.ops.marlin.marlin_gemm(x[i], wq[i], w_scale[i]) - return y + _, wq_, w_scale_ = marlin_quantize( + w[i].cuda().t().contiguous(), group_size + ) + wq.append(wq_) + w_scale.append(w_scale_) + wq = torch.stack(wq) + w_scale = torch.stack(w_scale) + + def int4_loopover_bmm( + x: torch.Tensor, + wq: torch.Tensor, + w_scale: torch.Tensor, + ) -> torch.Tensor: + B = x.shape[0] + M = x.shape[1] + N = w_scale.shape[2] + y = torch.empty((B, M, N), dtype=torch.bfloat16, device=x[0].device) + for i in range(B): + y[i] = torch.ops.marlin.marlin_gemm(x[i], wq[i], w_scale[i]) + return y + + y_int4 = int4_loopover_bmm(x, wq, w_scale) + else: + w_zp = [] + for i in range(B): + wq_, w_scale_, w_zp_ = int4_row_quantize(w[i], group_size) + + wq_ = pack_int4(wq_).contiguous().to(device="cuda") + w_scale_ = w_scale_.contiguous().to(device="cuda") + w_zp_ = w_zp_.contiguous().to(device="cuda") + wq.append(wq_) + w_scale.append(w_scale_) + w_zp.append(w_zp_) + wq = torch.stack(wq) + w_scale = torch.stack(w_scale).view(-1, N) + w_zp = torch.stack(w_zp).view(-1, N) + y_int4 = torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, wq, w_scale, w_zp) y_ref = torch.bmm(x, w.transpose(1, 2)) - y_int4 = int4_loopover_bmm(x, wq, w_scale) torch.testing.assert_close(y_ref, y_int4, atol=8.0e-2, rtol=8.0e-2) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu b/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu index 9037b7c09..c899bbf9b 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_pack_segments_backward.cu @@ -62,18 +62,21 @@ DLL_PUBLIC Tensor pack_segments_backward_cuda( CUDA_DEVICE_GUARD(data); + const auto data_contig = data.expect_contiguous(); + Tensor unpacked_tensor; // The output tensor AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "unpack_segments_cuda", [&] { const auto* const lengths_data = lengths.data_ptr(); // Create output tensor of appropriate dimensions - auto shape = data.sizes().vec(); + auto shape = data_contig->sizes().vec(); shape.erase(shape.begin()); shape[0] = total_length; - unpacked_tensor = at::empty(shape, data.options()); + unpacked_tensor = at::empty(shape, data_contig->options()); - if (!(data.size(0) && data.size(1))) { // TODO: What does this mean? + if (!(data_contig->size(0) && + data_contig->size(1))) { // TODO: What does this mean? 
return; } @@ -82,10 +85,11 @@ DLL_PUBLIC Tensor pack_segments_backward_cuda( auto lps_data = lengths_prefix_sum.data_ptr(); FBGEMM_DISPATCH_ALL_TYPES( - data.scalar_type(), "unpack_segments_cuda-unpacking", [&] { + data_contig->scalar_type(), "unpack_segments_cuda-unpacking", [&] { const auto num_seq = lengths.size(0); - const auto cell_size = data.numel() / (data.size(0) * data.size(1)); - const auto* const data_ptr = data.data_ptr(); + const auto cell_size = data_contig->numel() / + (data_contig->size(0) * data_contig->size(1)); + const auto* const data_ptr = data_contig->data_ptr(); auto* const out_data = unpacked_tensor.data_ptr(); unpack_segments_cuda_kernel diff --git a/fbgemm_gpu/test/release/utils.py b/fbgemm_gpu/test/release/utils.py index 005cf38b5..4218ebc95 100644 --- a/fbgemm_gpu/test/release/utils.py +++ b/fbgemm_gpu/test/release/utils.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import inspect import typing from typing import Iterable, List, Optional, Sequence, Union # noqa: F401 diff --git a/fbgemm_gpu/test/sparse/failures_dict.json b/fbgemm_gpu/test/sparse/failures_dict.json index 40cfacc06..fb2fcf85e 100644 --- a/fbgemm_gpu/test/sparse/failures_dict.json +++ b/fbgemm_gpu/test/sparse/failures_dict.json @@ -2,6 +2,16 @@ "_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit", "_version": 1, "data": { + "fb::pack_segments": { + "PackedSegmentsTest.test_aot_dispatch_dynamic__test_pack_segments_noncontig": { + "comment": "", + "status": "xfail" + }, + "PackedSegmentsTest.test_faketensor__test_pack_segments_noncontig": { + "comment": "", + "status": "xfail" + } + }, "fbgemm::asynchronous_complete_cumsum": {}, "fbgemm::asynchronous_exclusive_cumsum": {}, "fbgemm::asynchronous_inclusive_cumsum": {}, diff --git a/fbgemm_gpu/test/sparse/pack_segments_test.py b/fbgemm_gpu/test/sparse/pack_segments_test.py index dd5319277..095ea4377 100644 --- a/fbgemm_gpu/test/sparse/pack_segments_test.py +++ b/fbgemm_gpu/test/sparse/pack_segments_test.py @@ -23,9 +23,9 @@ if open_source: # pyre-ignore[21] - from test_utils import gpu_available + from test_utils import gpu_available, gpu_unavailable else: - from fbgemm_gpu.test.test_utils import gpu_available + from fbgemm_gpu.test.test_utils import gpu_available, gpu_unavailable def get_n_rand_num_summing_to_k(n: int, k: int) -> npt.NDArray: @@ -47,6 +47,15 @@ def get_n_rand_num_summing_to_k(n: int, k: int) -> npt.NDArray: # pyre-fixme[2] # pyre-fixme[24] def torch_compiled(model: Callable, **kwargs) -> Callable: + """A helper function to apply torch.compile if python < 3.12. + + Args: + model: The model to be compiled. + kwargs: The arguments to be passed to torch.compile. + + Returns: + The model. + """ if sys.version_info < (3, 12, 0): return torch.compile(model, **kwargs) else: @@ -60,6 +69,17 @@ def _pack_segments_ref( tensor: torch.Tensor, max_length: Optional[int] = None, ) -> npt.NDArray: + """ + This function is a reference implementation of pack_segments. + + Args: + lengths (Tensor): The lengths of tensor. + tensor (Tensor): The tensor to be packed. + max_length (Optional[int]): The maximum length of the packed tensor. + + Returns: + The packed tensor. 
+ """ lengths = lengths.numpy() sections = np.split(tensor, np.cumsum(lengths)) max_length = np.max(lengths, initial=0) if max_length is None else max_length @@ -106,6 +126,22 @@ def test_pack_segments( dtype: torch.dtype, torch_compile: bool, ) -> None: + """ + This function tests pack_segments ops compared to the reference implementation. + Both CPU and GPU (if available) are tested. + + Args: + n - The number of rows in the input tensor + k - The number of columns in the input tensor + batch_size - The number of batches of the input tensor + divisions - The number of segments to be packed + dtype - The data type + torch_compile - Whether to use torch.compile + + Returns: + None + """ + input_raw = np.random.rand(batch_size, n, k) input_data = torch.tensor(input_raw, dtype=dtype, requires_grad=True) lengths = torch.tensor( @@ -209,6 +245,23 @@ def test_pack_segments_smaller_max_len( dtype: torch.dtype, torch_compile: bool, ) -> None: + """ + This function tests pack_segments ops with set max_length + Both CPU and GPU (if available) are tested. + + Args: + n - The number of rows in the input tensor + k - The number of columns in the input tensor + batch_size - The number of batches of the input tensor + divisions - The number of segments to be packed + max_length - The maximum length of the packed tensor + dtype - The data type + torch_compile - Whether to use torch.compile + + Returns: + None + """ + input_raw = np.random.rand(batch_size, n, k) input_data = torch.tensor(input_raw, dtype=dtype) lengths = torch.tensor( @@ -264,6 +317,20 @@ def test_pack_segments_meta_backend( divisions: int, dtype: torch.dtype, ) -> None: + """ + This function tests pack_segments ops with meta backend. + + Args: + n - The number of rows in the input tensor + k - The number of columns in the input tensor + batch_size - The number of batches of the input tensor + divisions - The number of segments to be packed + dtype - The data type + + Returns: + None + """ + input_raw = np.random.rand(batch_size, n, k) input_data = torch.tensor( input_raw, dtype=torch.float32, requires_grad=True @@ -281,6 +348,109 @@ def test_pack_segments_meta_backend( # verify forward assert packed_tensor.size() == torch.Tensor(packed_ref).size() + @unittest.skipIf(*gpu_unavailable) + @given( + n=st.integers(2, 10), + k=st.integers(2, 10), + batch_size=st.integers(1, 30), + divisions=st.integers(1, 10), + dtype=st.sampled_from( + [ + torch.float, + torch.half, + ] + ), + torch_compile=st.booleans(), + use_cpu=st.booleans(), + ) + @settings(deadline=None) + def test_pack_segments_noncontig( + self, + n: int, + k: int, + batch_size: int, + divisions: int, + dtype: torch.dtype, + torch_compile: bool, + use_cpu: bool, + ) -> None: + """ + This function tests pack_segments ops when input gradients to backward are non-contiguous. 
+ + Args: + n - The number of rows in the input tensor + k - The number of columns in the input tensor + batch_size - The number of batches of the input tensor + divisions - The number of segments to be packed + dtype - The data type + torch_compile - Whether to use torch.compile + use_cpu - Whether to use CPU or GPU + + Returns: + None + """ + + input_raw = np.random.rand(batch_size, n, k) + # create input + input_data_ref = torch.tensor(input_raw, dtype=dtype, requires_grad=True) + input_data = torch.tensor(input_raw, dtype=dtype, requires_grad=True).cuda() + # retain grad to compare gradients of the inputs later + input_data.retain_grad() + input_data_ref.retain_grad() + + # set lengths + lengths = torch.tensor( + get_n_rand_num_summing_to_k(divisions, batch_size), + dtype=torch.int, + ) + max_length = lengths.max().item() + + packed_ref = torch.ops.fbgemm.pack_segments( + t_in=input_data_ref, lengths=lengths, max_length=max_length + ) + packed_ref.retain_grad() + + # pack segments using fbgemm and fb + packed_tensor = torch.ops.fbgemm.pack_segments( + t_in=input_data, lengths=lengths.cuda(), max_length=max_length + ) + packed_tensor.retain_grad() + + # verify forward + self.assertTrue(torch.equal(packed_tensor.cpu(), packed_ref)) + + # create non-contiguous grad + shape = tuple(x * 2 for x in packed_ref.shape) + grads = torch.tensor( + np.random.uniform(low=0.01, high=0.5, size=shape).astype(np.float32) + ).to(dtype) + grad_noncontig_cpu = grads.as_strided(packed_ref.shape, grads.stride()) + grad_noncontig_cuda = grads.cuda().as_strided(packed_ref.shape, grads.stride()) + + self.assertTrue( + not ( + grad_noncontig_cpu.is_contiguous() + and grad_noncontig_cuda.is_contiguous() + ), + msg="Expected grads to be non-contiguous but they are contiguous", + ) + + # verify backward + packed_ref.backward(grad_noncontig_cpu) + packed_tensor.backward(grad_noncontig_cuda) + self.assertTrue( + torch.equal(packed_tensor.cpu(), packed_ref), + msg="Expected packed tensors to be equal but they are not", + ) + + # verify backward input gradients + self.assertTrue( + # pyre-fixme[16]: Optional type has no attribute `cpu`. + # pyre-fixme[6]: For 2nd param expected `Tensor` but got `Optional[Tensor]`. + torch.equal(input_data.grad.cpu(), input_data_ref.grad.cpu()), + msg="Expected input gradients to be equal but they are not", + ) + extend_test_class(PackedSegmentsTest)
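
Usage note: the patch registers a new operator, "bf16i4bf16_rowwise_batched(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor", but carries no prose describing its calling convention. The sketch below is inferred from the kernel comments (X: B x M x K in BF16, WQ: B x N x K in INT4, output: B x M x N in BF16) and from the test added in quantize_test.py. The import line, the dummy int4 packing, and the tensor sizes are illustrative assumptions only; the test itself produces real packed weights with its int4_row_quantize / pack_int4 helpers, which are not reproduced here.

    import torch
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401  (assumed entry point that registers the gen_ai ops)

    B, M, N, K, group_size = 2, 2048, 256, 512, 128  # hypothetical sizes; K must be divisible by the group count
    num_groups = K // group_size

    # Activations: B x M x K in bfloat16.
    x = torch.randn(B, M, K, dtype=torch.bfloat16, device="cuda")

    # Weights: logically B x N x K int4, packed two values per byte into an
    # int8 tensor of shape B x N x K/2. Random bytes here only illustrate the
    # expected shape/dtype; the test builds real data via pack_int4(int4_row_quantize(...)).
    wq = torch.randint(-128, 127, (B, N, K // 2), dtype=torch.int8, device="cuda")

    # Per-group scale and zero point, stacked across the batch to
    # (B * num_groups, N), matching the test's .view(-1, N). fp32, fp16, and
    # bf16 are accepted, and both tensors must share a dtype.
    w_scale = torch.rand(B * num_groups, N, dtype=torch.float, device="cuda")
    w_zp = torch.zeros(B * num_groups, N, dtype=torch.float, device="cuda")

    # Output: B x M x N in bfloat16; the tile configuration is chosen internally
    # by get_batched_kernel_mode() from kernel_mode.h.
    y = torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, wq, w_scale, w_zp)

The test_int4_batched_gemm change above exercises exactly this path when use_loopover is False, comparing the result against torch.bmm on the unquantized bf16 weights.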