diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index b0ad85c25c572..587fc3901eb7c 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -7,23 +7,32 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.marlin_perms import ( + marlin_perm) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - MarlinWorkspace, is_marlin_supported, marlin_quantize, marlin_weights) + MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize, + marlin_quantize, marlin_weights) from vllm.model_executor.layers.quantization.utils.quant_utils import ( gptq_pack, quantize_weights, sort_weights) ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] -K_CHUNKS = [128, 256] -N_CHUNKS = [64, 128, 256] +MARLIN_K_CHUNKS = [128] +MARLIN_N_CHUNKS = [64, 128, 256] + +MARLIN_24_K_CHUNKS = [128] +MARLIN_24_N_CHUNKS = [256] MNK_FACTORS = [ (1, 1, 1), (1, 4, 8), (1, 7, 5), - (1, 7 * 4, 5 * 1), (13, 17, 67), (26, 37, 13), (67, 13, 11), @@ -31,14 +40,13 @@ def rand_data(shape): - data = torch.rand(shape).to(torch.half).cuda() - return data + return torch.randn(shape, dtype=torch.half, device="cuda") @pytest.mark.skipif(not is_marlin_supported(), reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", K_CHUNKS) -@pytest.mark.parametrize("n_chunk", N_CHUNKS) +@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) @pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @@ -82,7 +90,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Pack to Marlin format - marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits) + marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, + marlin_perm[num_bits]) # Run Marlin repack GPU kernel marlin_q_w_2 = ops.gptq_marlin_repack( @@ -99,8 +108,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order, @pytest.mark.skipif(not is_marlin_supported(), reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", K_CHUNKS) -@pytest.mark.parametrize("n_chunk", N_CHUNKS) +@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS) @pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @@ -136,7 +145,8 @@ def test_marlin_gemm( w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( b_weight, num_bits, group_size, act_order) - workspace = MarlinWorkspace(size_n) + workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) output = ops.gptq_marlin_gemm( a_input, @@ -155,4 +165,55 @@ def test_marlin_gemm( torch.cuda.synchronize() - assert torch.allclose(output, output_ref, rtol=1e-2) + max_diff 
= compute_max_diff(output, output_ref) + print("max_diff = {}".format(max_diff)) + + assert max_diff < 0.04 + + +@pytest.mark.skipif(not is_marlin_supported(), + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) +@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS) +@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_24_SUPPORTED_NUM_BITS) +@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors): + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + print(f"MNK = {size_m} {size_n} {size_k}") + print(f"groupsize = {group_size}") + + a_input = rand_data((size_m, size_k)) + b_weight = rand_data((size_k, size_n)) + + (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, + marlin_24_s) = marlin_24_quantize(b_weight, num_bits, group_size) + + workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_MAX_PARALLEL) + + output_ref = torch.matmul(a_input, w_24_ref) + + output = ops.gptq_marlin_24_gemm( + a_input, + marlin_24_q_w_comp, + marlin_24_meta, + marlin_24_s, + workspace_24.scratch, + num_bits, + a_input.shape[0], + b_weight.shape[1], + a_input.shape[1], + ) + + torch.cuda.synchronize() + + max_diff = compute_max_diff(output, output_ref) + print("max_diff = {}".format(max_diff)) + + assert max_diff < 0.04 diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index 1bd6127104654..f5345c0443029 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -12,6 +12,15 @@ logger = init_logger(__name__) +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] +GPTQ_MARLIN_24_SUPPORTED_SYM = [True] + class GPTQMarlin24Config(QuantizationConfig): """Config class for Marlin24. @@ -25,15 +34,17 @@ def __init__( self.weight_bits = weight_bits self.group_size = group_size - if self.weight_bits != 4 and self.weight_bits != 8: - raise ValueError("weight_bits must be 4 or 8. Got = {}".format( - self.weight_bits)) - - if self.group_size != 128 and self.group_size != -1: + # Verify + if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS: + raise ValueError( + f"Marlin_24 does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} " + "are supported.") + if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: raise ValueError( - "Currently, only group size 128 and -1 (channelwise) " - "is supported for Marlin24, but got group_size of " - f"{self.group_size}") + f"Marlin_24 does not support group_size = {self.group_size}. " + f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} " + "are supported.") # 4 Bits packed into 32 bit datatype. 
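Side note (not part of the diff): the pack factor assigned just below is the number of quantized values that fit into one 32-bit word. A minimal standalone sketch of the shift-and-or packing that `marlin_weights()` later in this diff relies on; the concrete values are illustrative only:

```python
# Hedged sketch: pack eight 4-bit values into one 32-bit word, mirroring the
# shift-and-or loop used by marlin_weights() further down in this diff.
num_bits = 4
pack_factor = 32 // num_bits                  # 8, matching self.pack_factor
vals = list(range(pack_factor))               # eight 4-bit values: 0..7
packed = 0
for i, v in enumerate(vals):
    packed |= v << (num_bits * i)             # value i occupies bits [4*i, 4*i + 4)
assert packed == 0x76543210
```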
self.pack_factor = 32 // self.weight_bits diff --git a/vllm/model_executor/layers/quantization/utils/format_24.py b/vllm/model_executor/layers/quantization/utils/format_24.py new file mode 100644 index 0000000000000..01c8cf789204b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/format_24.py @@ -0,0 +1,308 @@ +# +# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es). +# + +import torch + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, + device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
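For orientation, a minimal round-trip sketch of the helpers in this new file (an illustrative example, not part of the diff): it assumes a CUDA device, uses `mask_creator` and `sparse_semi_structured_to_dense_cutlass` defined later in this file, and picks 128x128 only to satisfy the divisibility checks below.

```python
import torch

from vllm.model_executor.layers.quantization.utils.format_24 import (
    mask_creator, sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass)

dense = torch.randn(128, 128, dtype=torch.half, device="cuda")
mask = mask_creator(dense).bool()                    # keep 2 of every 4 values
dense_24 = dense * mask                              # enforce 2:4 sparsity
sparse, meta = sparse_semi_structured_from_dense_cutlass(dense_24)
assert sparse.shape == (128, 64)                     # k dimension is halved
recovered = sparse_semi_structured_to_dense_cutlass(sparse, meta)
assert torch.equal(recovered, dense_24)              # lossless for 2:4 inputs
```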
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError( + "Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
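As a concrete check of the table above (a standalone sketch; `encode_quad` is illustrative and not part of vLLM): the quadruple [True, False, False, True] keeps positions 0 and 3 and encodes to 0b1100.

```python
import torch

def encode_quad(m0, m1, m2, m3):
    # Mirrors the minimized boolean expressions used just below
    # (m2 is not needed by them).
    expr0, expr1, expr2 = m0 & m1, ~m0 & m1, ~m0 & ~m1
    bit0, bit1 = expr1, expr2
    bit2, bit3 = expr0 | expr2 | m3, expr1 | ~m1
    idx0 = int(bit0) | (int(bit1) << 1)   # index of first kept value
    idx1 = int(bit2) | (int(bit3) << 1)   # index of second kept value
    return idx0 | (idx1 << 2)             # 4-bit metadata nibble

flags = torch.tensor([True, False, False, True])
assert encode_quad(*flags) == 0b1100      # matches the table above
```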
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, + idxs0.unsqueeze(-1) // 2).view( + m, + k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view( + (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12)) + elif quadbits_per_meta_elem == 8: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28)) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols, )) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix") + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta = torch.gather(meta_reordered.view(-1), 0, + meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( + -1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_(0, dense_offsets, + sparse.view(torch.half).view(-1)) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. + + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " + f"{M} groups") + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask diff --git a/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py b/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py new file mode 100644 index 0000000000000..12e77cb710687 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_24_perms.py @@ -0,0 +1,58 @@ +"""This file is used for /tests and /benchmarks""" +import numpy +import torch + + +# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501 +# +# Marlin works on [16*2,64] tiles. 
The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 +# (without the need to use ldmatrix instructions) # noqa: E501 +def get_perms_24(num_bits): + perm_list = [] + for i in range(32): + perm1 = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return perm, scale_perm, scale_perm_single + + +marlin_24_perm = {} +marlin_24_scale_perm = {} +marlin_24_scale_perm_single = {} +for num_bits in [4, 8]: + perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits) + marlin_24_perm[num_bits] = perm_24 + marlin_24_scale_perm[num_bits] = scale_perm_24 + marlin_24_scale_perm_single[num_bits] = scale_perm_single_24 diff --git a/vllm/model_executor/layers/quantization/utils/marlin_perms.py b/vllm/model_executor/layers/quantization/utils/marlin_perms.py new file mode 100644 index 0000000000000..76bd2ff7c724e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_perms.py @@ -0,0 +1,58 @@ +"""This file is used for /tests and /benchmarks""" +import numpy +import torch + + +# Precompute permutations for Marlin weight and scale shuffling # noqa: E501 +# +# Marlin works on [16,64] tiles. 
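Aside (not part of the diff): each entry of the `marlin_24_perm` table defined above is a flat index map over a block of 1024 quantized values, and applying it is plain fancy indexing, as `marlin_permute_weights()` does later in this diff. A quick hedged sanity check:

```python
import torch

from vllm.model_executor.layers.quantization.utils.marlin_24_perms import (
    marlin_24_perm)

perm = marlin_24_perm[4]
assert perm.shape == (1024,)
# The indices form a full permutation of 0..1023, so indexing with them
# only reorders values and never drops or duplicates any.
assert torch.equal(torch.sort(perm).values, torch.arange(1024))
```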
The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 +# (without the need to use ldmatrix instructions) # noqa: E501 +def get_perms(num_bits): + perm_list = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm, scale_perm, scale_perm_single + + +marlin_perm = {} +marlin_scale_perm = {} +marlin_scale_perm_single = {} +for num_bits in [4, 8]: + perm, scale_perm, scale_perm_single = get_perms(num_bits) + marlin_perm[num_bits] = perm + marlin_scale_perm[num_bits] = scale_perm + marlin_scale_perm_single[num_bits] = scale_perm_single diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 33b3169983475..0d027d0620ab3 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -1,79 +1,28 @@ """This file is used for /tests and /benchmarks""" +import random + import numpy import torch -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_TILE) +from vllm.model_executor.layers.quantization.utils.format_24 import ( + mask_creator, sparse_semi_structured_from_dense_cutlass) +from vllm.model_executor.layers.quantization.utils.marlin_24_perms import ( + marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single) +from vllm.model_executor.layers.quantization.utils.marlin_perms import ( + marlin_perm, marlin_scale_perm, marlin_scale_perm_single) from vllm.model_executor.layers.quantization.utils.quant_utils import ( get_pack_factor, quantize_weights, sort_weights) __cuda_arch = torch.cuda.get_device_capability() +MARLIN_TILE = 16 + def is_marlin_supported(): return __cuda_arch[0] >= 8 -# Precompute permutations for Marlin weight and scale shuffling # noqa: E501 -# -# Marlin works on [16,64] tiles. 
The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501 -# with the tensor-core format that is described here: -# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 -# -# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 -# (without the need to use ldmatrix instructions) # noqa: E501 -def _get_perms(num_bits): - perm_list = [] - for i in range(32): - perm1 = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - scale_perm = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return perm, scale_perm, scale_perm_single - - -_perm = {} -_scale_perm = {} -_scale_perm_single = {} -for num_bits in [4, 8]: - perm, scale_perm, scale_perm_single = _get_perms(num_bits) - _perm[num_bits] = perm - _scale_perm[num_bits] = scale_perm - _scale_perm_single[num_bits] = scale_perm_single - - -def marlin_permute_weights(q_w, - size_k, - size_n, - num_bits, - tile=GPTQ_MARLIN_TILE): +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE): assert q_w.shape == (size_k, size_n) assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" @@ -83,15 +32,14 @@ def marlin_permute_weights(q_w, q_w = q_w.permute((0, 2, 1, 3)) q_w = q_w.reshape((size_k // tile, size_n * tile)) - q_w = q_w.reshape( - (-1, _perm[num_bits].numel()))[:, _perm[num_bits]].reshape(q_w.shape) + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) return q_w -def marlin_weights(q_w, size_k, size_n, num_bits): +def marlin_weights(q_w, size_k, size_n, num_bits, perm): # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, num_bits) + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) # Pack pack_factor = get_pack_factor(num_bits) @@ -101,7 +49,6 @@ def marlin_weights(q_w, size_k, size_n, num_bits): q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=numpy.uint32) - for i in range(pack_factor): q_packed |= q_w[:, i::pack_factor] << num_bits * i @@ -110,15 +57,12 @@ def marlin_weights(q_w, size_k, size_n, num_bits): return q_packed -def marlin_permute_scales(s, size_k, size_n, group_size, num_bits): +def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm, + scale_perm_single): if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(_scale_perm[num_bits])))[:, - _scale_perm[num_bits]] + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] else: - s = s.reshape( - (-1, - len(_scale_perm_single[num_bits])))[:, - _scale_perm_single[num_bits]] + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] s = s.reshape((-1, size_n)).contiguous() return s @@ -148,8 +92,11 @@ def marlin_quantize( 
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) # Reformat to marlin - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, + marlin_perm[num_bits]) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, + marlin_scale_perm[num_bits], + marlin_scale_perm_single[num_bits]) # Create result res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] @@ -159,15 +106,118 @@ def marlin_quantize( return res_list +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j:j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, num_bits): + assert q_24.shape == (size_k, size_n) + + # Remove zp to normalize over 0 + max_q_val = (1 << num_bits) - 1 + zp = (max_q_val + 1) // 2 + q_24_no_zp = q_24 - zp + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( + q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore zp + q_24_comp = q_24_no_zp_comp + zp + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def marlin_24_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24, + num_bits, + group_size, + act_order=False) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, + num_bits) + size_k_comp = size_k // 2 + + # Reformat to marlin + marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, + num_bits, marlin_24_perm[num_bits]) + marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size, + marlin_24_scale_perm[num_bits], + marlin_24_scale_perm_single[num_bits]) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + class MarlinWorkspace: - def __init__(self, out_features): - assert (out_features % GPTQ_MARLIN_MIN_THREAD_N == 0), ( - "out_features = {} is undivisible by GPTQ_MARLIN_MIN_THREAD_N = {}" - .format(out_features, GPTQ_MARLIN_MIN_THREAD_N)) 
+ def __init__(self, out_features, min_thread_n, max_parallel): + assert (out_features % min_thread_n == 0), ( + "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n)) - max_workspace_size = ((out_features // GPTQ_MARLIN_MIN_THREAD_N) * - GPTQ_MARLIN_MAX_PARALLEL) + max_workspace_size = ((out_features // min_thread_n) * max_parallel) self.scratch = torch.zeros(max_workspace_size, dtype=torch.int,
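Sizing note (an illustrative sketch, not part of the diff): for out_features = 4096 with the Marlin 2:4 constants defined earlier in this diff, the scratch buffer holds 4096 // 128 * 16 = 512 int32 slots, which the kernel uses as scratch for its global reduction.

```python
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MarlinWorkspace)

# Assumes a CUDA device (marlin_utils queries the device capability on import).
ws = MarlinWorkspace(4096, GPTQ_MARLIN_24_MIN_THREAD_N,
                     GPTQ_MARLIN_24_MAX_PARALLEL)
assert ws.scratch.numel() == 4096 // 128 * 16    # 512 zero-initialized int32 slots
```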