Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

W4A8 based on CUTLASS #880

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

alexsamardzic
Copy link
Contributor

@alexsamardzic alexsamardzic commented Sep 12, 2024

Copy link

pytorch-bot bot commented Sep 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/880

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link

Hi @alexsamardzic!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 12, 2024
@alexsamardzic
Copy link
Contributor Author

alexsamardzic commented Sep 12, 2024

The kernel implements W4A8 GEMM, with float16 scaling factors. The zero point support is to be eventually added later, for now several hacks (to be removed) are put in the code, that will force int8_dynamic_activation_int4_weight to do symmetric quantization for both activation and weight.

There are several points to discuss:

CUTLASS would have to be made a dependency. IMO, the best approach to satisfy the dependency would be to install nvidia-cutlass package, the only problem is that it doesn't always contain latest changes in CUTLASS. An alternative would be to have CUTLASS repo as submodule of this repo, like in PyTorch.

The group quantization may be a problem. Let's say X is input matrix of size MxK, with Xs vector of input scales of size M, and Wis weight matrix of size NxK. If group size parameter is equal to K, then weight scales Ws will be a vector of size N, and an element of output matrix Y of a linear operator would be calculated as follows (let's ignore bias for now, as it's not relevant):

$$y_{i,j}=\sum_{k}xs_{i}\cdot x_{i,k}\cdot w_{j,k}\cdot ws_{j}=xs_{i}\cdot ws_{j}\cdot \sum_{k}x_{i,k}\cdot w_{j,k}$$

The sum in the last expression could be efficiently calculated as mixed integer data types GEMM on tensor cores, and the result could be then updated by mulitplying the scale factors in. However, if group size parameter is less than K, say 32 for example (32 < K, K % 32 == 0), then weight scales will be matrix of size Nx(K/32). In this case, an element of output matrix Y of a linear operator would be calculated as follows:

$$y_{i,j}=\sum_{k}xs_{i}\cdot x_{i,k}\cdot w_{j,k}\cdot ws_{j,k/32}=xs_{i}\cdot \sum_{k}x_{i,k}\cdot w_{j,k}\cdot ws_{j,k/32}$$

Now, the only approach possible in CUTLASS to do this calculation in integer mixed data types on tensor cores would be to split it into K/32 GEMMs, and try to run them at the same time as so-called grouped GEMM. The code would be much more complicated, and also the update with the scaling factors will be still different for each of these individual GEMMs, so I don't think this approach would be performant. So my question here is: Does it make sense to create a quantization different than int8_dynamic_activation_int4_weight, that would match this kernel better, in particular that would not use group quantization for weight at all? (BTW, creating a new quantization, or at least adding a variant of int8_dynamic_activation_int4_weight is needed anyway, as this one is not packing two 4-bit weight values into a byte, that is required by CUTLASS for int8/int4 GEMM.)

Another related issue is zero point handling. Let's say Xz is vector of size M of input zero point values, and Wz is vector of size N of weight zero point values. Then the linear operator calculation, in PyTorch notation would be as follows: Y=((X-Xz)*Xs)@((W-Wz)*Ws).T (again, let's ignore bias), that translates into following calculation for an individual element of output matrix Y:

$$ \begin{array}{lcl} y_{i,j} & = & \sum_{k}xs_{i}\cdot (x_{i,k}-xz_{i})\cdot (w_{j,k}-wz_{j})\cdot ws_{j} \\ & = & xs_{i}\cdot ws_{j}\cdot (\sum_{k}x_{i,k}\cdot w_{j,k}-wz_{j}\sum_{k}a_{i,k}-xz_{i}\sum_{k}w_{k,j}+K\cdot xz_{i}\cdot wz_{j}) \\ \end{array} $$

Only the first expression within parentheses could be calculated on tensor cores as mixed integer data types GEMM, while the sums in the next two expression are best to be pre-calculated in case of weight values, or calculated on the fly during the input quantization. So it seems to me these are also calling for specialized type of quantization. (Note also that if group quantization used, above mentioned complications for Ws are extended to Wz too.)

All comments/suggestions welcome; in particular I'm pretty much new to quantization specifics so please let me know if I'm missing something obvious.

@msaroufim
Copy link
Member

I'm on PTO today and tomorrow so will review asap, apologies for the delay

@cpuhrsch
Copy link
Contributor

@alexsamardzic - Can we use the CUTLASS that ships with PyTorch? As in, should we change PyTorch to ship the headers used to build its CUTLASS kernels / does the PyTorch nightly already ship those? I see the test is using group size 128. I think it's ok if we don't necessarily support all group sizes or shapes right away.

We have some int4 support via the pattern matched in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/post_grad.py#L345-L403 which dispatches to https://github.com/pytorch/pytorch/blob/dab7d646d55a2b6696d51dee4816a6743ec1ae5a/torch/_inductor/kernel/unpack_mixed_mm.py#L76 - would an extension for int4x2 X int8 of this be interesting here?

@alexsamardzic
Copy link
Contributor Author

I'm on PTO today and tomorrow so will review asap, apologies for the delay

Thanks Mark - it's really just a draft, so not yet ready for review, but it would be useful to discuss points that I mentioned in my comment above.

@alexsamardzic
Copy link
Contributor Author

@alexsamardzic - Can we use the CUTLASS that ships with PyTorch? As in, should we change PyTorch to ship the headers used to build its CUTLASS kernels / does the PyTorch nightly already ship those?

This CUTLASS version is also lagging behind. My CUTLASS PR with mixed int4/int8 GEMM is merged after the latest (3.5.1) CUTLASS release, hopefully there will be a new release soon. But in any case, this is a kind of problem that we'll have if we use more CUTLASS from torchao - for lots of time, the torchao build will have to be pointed to a bleeding edge CUTLASS checkout.

I see the test is using group size 128. I think it's ok if we don't necessarily support all group sizes or shapes right away.

It uses group size 128 in order to force weight scale to be a vector, and not a matrix. I tried to explain the issue in my comment above, if group quantization is obligatory here, then it's going to be rather complicated to make this work.

We have some int4 support via the pattern matched in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/post_grad.py#L345-L403 which dispatches to https://github.com/pytorch/pytorch/blob/dab7d646d55a2b6696d51dee4816a6743ec1ae5a/torch/_inductor/kernel/unpack_mixed_mm.py#L76 - would an extension for int4x2 X int8 of this be interesting here?

I'm just looking into the quantization code, to see is it possible to do it there - it's not hard to make this change, but CUTLASS in general doesn't support doing things before GEMM (while fusing operations after GEMM calculated is reasonably well supported), so it would be the best if the quantization code actually put the weight values in int4x2 format.

@alexsamardzic
Copy link
Contributor Author

Updated so that there is a new int8_dynamic_activation_int4_weight_cutlass quantization method available that, for now, would quantize both input and weight symmetrically, and won't use group quantization for weight (so weight scales are always a vector). It should be now possible to try kernel on arbitrary models, if quantized by above quantization method.

@@ -506,6 +508,41 @@ def int8_dynamic_activation_int4_weight(group_size=32, mapping_type=MappingType.
return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size, mapping_type=mapping_type)


def apply_int8_dynamic_activation_int4_weight_quant_cutlass(weight):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be represented as a different Layout for int8 dynamic activation/int4 weight quantization? docs for Packing/Layout can be found in #391 "Layout and Packing" and simplified example in https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/my_dtype_tensor_subclass.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pointer! Yes, this will need refinement on this and several other places, as I learn about doing things the "torchao way"; but my main goal initially is to connect the dots, so that some benchmarks could be run, and that we could verify that CUTLASS provides some value here.

@alexsamardzic alexsamardzic force-pushed the w4a8-cutlass branch 7 times, most recently from f6383ca to 02f8805 Compare September 17, 2024 08:26
@alexsamardzic
Copy link
Contributor Author

alexsamardzic commented Sep 17, 2024

Made some minor updates, including added support for bfloat16.

Micro-benchmarking script
import copy

import torch

from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_5,
    unwrap_tensor_subclass,
)
from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int4_weight_cutlass,
)

# FIXME: change this!
_CUTLASS_DIR = ".../cutlass"


class ToyModel(torch.nn.Module):
    def __init__(self, nin, nout1, nout2):
        super().__init__()
        self.linear1 = torch.nn.Linear(nin, nout1)
        self.linear2 = torch.nn.Linear(nout1, nout2, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


methodq = int8_dynamic_activation_int4_weight_cutlass()
compile = False
dtype = torch.float16  # dtype = torch.bfloat16
device = "cuda"
bs, nin, nout1, nout2 = 256, 1024, 2048, 128

inputs = (torch.randn((1, bs, nin), dtype=dtype, device=device),)
model = ToyModel(nin, nout1, nout2).eval().to(dtype).to(device)
modelq = copy.deepcopy(model)

if compile:
    model = torch.compile(model, mode="max-autotune")

quantize_(modelq, methodq)
if not TORCH_VERSION_AT_LEAST_2_5:
    unwrap_tensor_subclass(modelq)

if compile:
    modelq = torch.compile(
        modelq,
        options={
            "max_autotune": True,
            "autotune_in_subproc": False,
            "max_autotune_gemm_backends": "Triton,CUTLASS",
            "cuda.cutlass_dir": _CUTLASS_DIR,
            "use_mixed_mm": True,
        },
    )


if __name__ == "__main__":
    from torchao.utils import benchmark_model

    nruns = 100
    torch._dynamo.reset()
    time = benchmark_model(model, nruns, inputs)
    timeq = benchmark_model(modelq, nruns, inputs)
    print(f"original model mean time  : {time:8.3f}")
    print(f"quantized model mean time : {timeq:8.3f}")
    print(f"speedup by quantization   : {time / timeq:8.3f}")

For particular shapes given in the script above, on A100 the micro-benchmark shows around 2x speedup over the case when float16 MM used, and around 1.8x speedup over the case when bfloat16 MM used. (Note that this is for eager mode execution, as compilation to corresponding CUTLASS kernel is not yet supported by PyTorch.)

Patch to run torchao/_models/llama/generate.py
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 5fb905d..e5b891b 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -206,6 +206,7 @@ def main(
             quantize_,
             int8_weight_only,
             int8_dynamic_activation_int8_weight,
+            int8_dynamic_activation_int4_weight_cutlass,
             int4_weight_only,
             fpx_weight_only,
             uintx_weight_only,
@@ -216,6 +217,8 @@ def main(
             quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
             quantize_(model, int8_dynamic_activation_int8_weight())
+        if "w4a8-cutlass" in quantization:
+            quantize_(model, int8_dynamic_activation_int4_weight_cutlass())
         if "int4wo" in quantization:
             if "hqq" in quantization:
                 use_hqq=True
@@ -414,7 +417,7 @@ if __name__ == '__main__':
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
     parser.add_argument('-q', '--quantization', type=str, 
         help=(
-            'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+            'Which quantization techniques to apply: int8dq, w4a8-cutlass, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
             +'autoquant-int4, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
         )
     )
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 1df3549..1252bb8 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -1158,6 +1158,7 @@ implements = AffineQuantizedTensor.implements
 # so that these can be shared by F.linear, aten.mm, aten.addmm dispatches
 
 def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
+    return False
     return (
         isinstance(input_tensor, AffineQuantizedTensor) and
         _aqt_is_int8_reduced_range(input_tensor) and
diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py
index 3005cb1..451d0e6 100644
--- a/torchao/kernel/intmm.py
+++ b/torchao/kernel/intmm.py
@@ -54,6 +54,8 @@ if TORCH_VERSION_AT_LEAST_2_2:
             and k_is_nonzero_multiple_of_8
         )
 
+        bad_dimensions_for_cublas = False
+
         if device_cpu or bad_dimensions_for_cublas:
             # fallback path
             return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(

With the patch above, I was able to run Llama generator.py script. The command to run is as follows:

python generate.py -q w4a8-cutlass

and the output is as follows (again, this is run on A100):

==========
Average tokens/sec: 10.21
Average Bandwidth: 33.78 GB/s
Peak Memory Usage: 14.22 GB
Model Size: 3.31 GB

while the reference output, for the case when no arguments supplied to generate.py, is as follows:

==========
Average tokens/sec: 32.87
Average Bandwidth: 434.31 GB/s
Peak Memory Usage: 13.62 GB
Model Size: 13.21 GB

So the tokens/sec is more than 3x slower, but this is not even that bad, considering that batch size is 1 here, and that the CUTLASS code has it hard-coded for a block of threads to handle input tile size that is 128 for the same dimension, so most of the work is wasted.

So there is a room for improvement regarding the speed. The text generated is garbage, however. Even for the micro-benchmark above, output values visibly deviate from the values produced when native precision used (but at least they resemble each other).

@alexsamardzic alexsamardzic force-pushed the w4a8-cutlass branch 2 times, most recently from 575e074 to 956fc80 Compare September 18, 2024 13:14
@alexsamardzic
Copy link
Contributor Author

alexsamardzic commented Sep 18, 2024

Made an update - turns out that actually CUTLASS needs a fix (posted below for now), and then generate.py script for Llama model would generate meaningful content.

CUTLASS fix
diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
index 1692cc30..5a1b164c 100644
--- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
+++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
@@ -263,6 +263,44 @@ struct DefaultIteratorsTensorOp<
   static int const kFragmentsPerIteration = 2;
 };
 
+/// Partial specialization for bfloat16 <= int32_t x 8 epilogues avoids shared memory bank conflicts.
+template <
+  typename ThreadblockShape,
+  typename WarpShape,
+  typename InstructionShape,
+  typename ThreadMap
+>
+struct DefaultIteratorsTensorOp<
+  bfloat16_t, 
+  int32_t, 
+  8, 
+  ThreadblockShape, 
+  WarpShape, 
+  InstructionShape, 
+  ThreadMap> {
+  
+  using WarpTileIterator = cutlass::epilogue::warp::TileIteratorTensorOpMixed<
+    WarpShape,
+    InstructionShape,
+    int32_t,
+    32,
+    16,
+    8,
+    8
+  >;
+
+  using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<
+    ThreadMap,
+    int32_t,
+    32,
+    16,
+    8,
+    8
+  >;
+
+  static int const kFragmentsPerIteration = 2;
+};
+
 /// Partial specialization for int8/int4b_t <= int32 x 16/8 epilogues avoids shared memory bank conflicts.
 /// Threadblock::kN = 256 still has bank conflicts.
 template <

On the other side, I tried with adapting tile sizes processed by block/warp of threads of corresponding CUTLASS kernel, in order to adapt to the fact that batch size is 1 here. Here is an example of such change:

+++ b/torchao/csrc/cuda/s8s4_linear_cutlass/s8s4_linear_cutlass.cu
@@ -418,8 +418,8 @@ s8s4_linear_cutlass(const at::Tensor& input, const at::Tensor& input_scale,
   using ElementA = int8_t;
   using ElementB = cutlass::int4b_t;
   using ElementAccumulator = int32_t;
-  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
-  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 128>;
+  using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>;
   using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
   AT_DISPATCH_SWITCH(
     input_scale.scalar_type(),

However, tokens/sec is not much improved this way. Thus, the performance of this kernel for Llama model will require more work.

Edit: CUTLASS fix posted upstream here.

Copy link
Member

@msaroufim msaroufim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will make a second pass for the kernel code

@@ -65,6 +65,12 @@ def get_extensions():
extension = CUDAExtension if use_cuda else CppExtension

if not IS_WINDOWS:
import cutlass_library
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

interesting: not too familiar with cutlass packaging but what is cutlass_library exactly? only reference I found is this https://github.com/NVIDIA/cutlass/tree/main/python/cutlass_library

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a recent addition to CUTLASS: a Python library that is able to generate C++ code for CUTLASS GEMM templates instantiation (which is nice to have, as these templates have dozen or more arguments, and it's oftentimes hard to get them right). It's used in CUTLASS codegen for TorchInductor, like here. However, recently CUTLASS itself also added a functionality to generate and compile C++ code for GEMM kernels, from a high-level specification in Python - this is part of cutlass Python package, see here. Both cutlass and cutlass_library are available through nvidia-cutlass pip package. It's important to note that this package also contains all of the CUTLASS C++ header files, in order to make it possible to compile the C++ generated kernels.

cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
cutlass_include_dir = os.path.join(cutlass_library_dir, "source", "include")
# FIXME: remove this once CUTLASS package updated to include int4/int8 MM
cutlass_include_dir = "/data/quansight/scratch/cutlass/include"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n00b q: what is this exactly? Do you need any help packaging CUTLASS?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I discussed this a bit in my first comment on this PR. In order ao to compile after this PR eventually merged, CUTLASS C++ header files are to be made available. There are at least two ways to do it:

  1. To make CUTLASS repo a submodule of ao repo, just like PyTorch did it.
  2. To make above mentioned nvidia-cutlass package a dependency of ao.

I'm leaning towards the later, and this is what above code, before "FIXME" is expecting. However, in both of above cases, we'll certainly face an issue of having to depend on stuff that is not yet merged into CUTLASS, but we need it. For example, at this very moment:

  1. My CUTLASS PR with int4/int8 GEMM support for CUTLASS is merged, but CUTLASS team has not made a release in the meantime, so this functionality is only available in CUTLASS main branch, and also above mentioned nvidia-cutlass package doesn't contain it yet.
  2. As mentioned in one of my comments above, while working in this PR, I found an omission in CUTLASS. I created a CUTLASS PR with a fix, but this one is not yet merged, so neither CUTLASS main branch nor nvidia-cutlass package contain the fix at the moment, it's only available in my branch. So the only way to proceed with the development of my PR was to create a local copy of this branch - I created it in /data/quansight/scratch/cutlass directory on my machine; in order to try this PR, the local copy of this branch is to be created, and this last line in the snippet above is to be changed to the local directory.

From my experience with this stuff from PyTorch development based on CUTLASS, this is going to be permanent issue - if we decide to use CUTLASS in ao, the for the most of the time we'll need bleeding edge features. So this is to be discussed further, IMO the best approach would be to build our own nvidia-cutlass package, from whatever CUTLASS branch we find the most appropriate.

@@ -85,6 +85,7 @@
"_get_subclass_inserter",
"quantize_",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int4_weight_cutlass",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you have some baseline numbers vs int8_dynamic_activation_int4_weight

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now when I have the dots connected, in the sense that I can run a micro-benchmark, and also Lllama model, using this kernel, I'm working on a more detailed profiling, part of this is also comparing the performance of this kernel with int8_dynamic_activation_int4_weight kernel. I'll report all my findings here when I'm done with the profiling.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a quick update here: Using the micro-benchmarking script above, it seems this PR is just 3-5% faster than int8_dynamic_activation_int4_weight. However, on the Llama generator, it seems about 2x faster, when tokens/sec numbers compared. (Remember that all the caveats from my first comment above still apply, so let's not jump into any conclusions for now.)

@@ -0,0 +1,51 @@
# FIXME: move this test to the appropriate test file!!!
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah maybe make yourself a cutlass folder to park all your work

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Again, as mentioned in one of my comments above: At the moment, most of the "FIXME"-s in the PR are as I'm aware that I took shortcuts to make things work. If/when we're happy with the main stuff, I'll revisit all of these, and redo them in the proper "ao-way".

output_ref = model(input)

modelq = copy.deepcopy(model)
quantize_(modelq, int8_dynamic_activation_int4_weight_cutlass())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe another reference would be the non cutlass variant

# then corresponding changes made in
# _linear_int8_act_int4_weight_cutlass_check and for the check in
# the CUTLASS kernel!!!
weight.original_weight_tensor.layout_tensor.int_data = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a comment like

# Combine pairs of 4-bit values into single bytes
weight.original_weight_tensor.layout_tensor.int_data = (
    # Take odd-indexed columns, keep lower 4 bits, shift left by 4 bits
    (weight.original_weight_tensor.layout_tensor.int_data[:, 1::2] & 0xF) << 4
) | (
    # Take even-indexed columns, keep lower 4 bits
    weight.original_weight_tensor.layout_tensor.int_data[:, 0::2] & 0xF
)

"""
return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant_cutlass)


def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated comment, what is this use_hqq? @jerryzh168 do you know?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah this means use hqq algorithm to choose qparams and quantize the weight, since it is reusing the tinygemm kernel, we just added this as a separate option here

const int n = tensor_b.size(0);
const int k = tensor_a.size(1);

constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: mind adding a comment for why 128

Also how do you think about padding vs erroring

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 128 bits here is because of how tensor cores work (so it's not CUTLASS-specific), at least for SM 8.x. It's related to the layout of tiles of matrix operands that single warp of thread is multiplying cooperatively. The best explanation that I found so far is in GTC 2020 talk, by CUTLASS team, around slide 15.

We can consider padding (maybe at the later stage?), I believe it would the best to incorporate padding together with the quantization.

using SmArch = cutlass::arch::Sm80;
using ThreadblockSwizzle =
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK;
constexpr auto NumStages = 4;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cutlass n00b but how do you pick these hyperparams?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These, and others, are the CUTLASS GEMM C++ template arguments. As mentioned above, there is dozen of these to set, but on the other side only small number of combinations of these arguments actually works. The above mentioned cutlass_library package enumerates some of these working combinations. The CUTLASS itself doesn't include any sort of heuristic for selection of these parameters, for example based on GEMM operand shapes. So I had to hard-code some values, at least for now. The values selected here are based on my previous experimentation with different combinations, and different operand shapes - in the sense that these values should provide acceptable performance for number of cases. But certainly there are cases where these values are not good fit, Lllama inference, having batch size 1, is one such example. So we may want to consider adding some heuristic here, but on the longer term we'd probably prefer to do support some auto-tuning, just like what is possible with Triton kernels.

@alexsamardzic
Copy link
Contributor Author

(Pushed an update, where the branch is just rebased on the latest main.)

I did lots of profiling in the meantime, focusing primarily on running Llama generator (torchao/_models/llama/generate.py), using tokens/sec as performance measure, and comparing between this PR and W8A8DQ case (i.e. when model quantized using int8_dynamic_activation_int8_weight). All of the results presented below were for A100 runs, the W8A8DQ run was as follows:

python generate.py -q int8dq

and the run for this PR was as follows (with the patch mentioned above applied beforehand):

python generate.py -q w4a8-cutlass

TLDR (note that each of these items could be verified by profiling W8A8DQ alone, without using this PR at all):

  1. The CUTLASS MM kernel in case of this PR, and also the Triton kernel MM for the W8A8DQ are not, at least at this moment, the most critical for performance. Instead, the other parts of the code, that are run each time along with the linear operator, are taking more execution time - see the remaining two items in the list.
  2. The dispatch checks registered here are re-run over and over again. These take considerable time, and also they make the performance depending on the position of registering the kernel, and corresponding check, in this list: if corresponding item moved between top and bottom of the list, the tokens/sec differ up to 10%. (@jerryzh168)
  3. The dynamic quantization takes considerable time too, more than the MM kernel itself. This could be improved, by working on fusing PyTorch operators used to perform quantization, or by implementing dedicated kernel(s) for dynamic quantization; also for Llama generator in particular by adjusting configs of these kernels to the fact that the number of inputs is 1. (Still, IMO it's questionable is there any performance benefit in using dynamic quantization vs. weight quantization only.)

As an example for item 1 above, here are the performance results, as printed by generate.py script, in case when item registering given kernel and check moved to the first place in the list:

python generate -q int8dq
# ... lots of output here
==========
Average tokens/sec: 4.81
Average Bandwidth: 31.83 GB/s
Peak Memory Usage: 14.86 GB
Model Size: 6.62 GB

python generate -q w4a8-cutlass
# ... lots of output here
==========
Average tokens/sec: 10.31
Average Bandwidth: 34.11 GB/s
Peak Memory Usage: 14.22 GB
Model Size: 3.31 GB

and when moved to the last place in the list:

python generate -q int8dq
# ... lots of output here
==========
Average tokens/sec: 4.35
Average Bandwidth: 28.82 GB/s
Peak Memory Usage: 14.86 GB
Model Size: 6.62 GB

==========
# ... lots of output here
Average tokens/sec: 9.92
Average Bandwidth: 32.82 GB/s
Peak Memory Usage: 14.22 GB
Model Size: 3.31 GB

The generator runs are profiled using pyinstrument, and verified using cProfile and nsys profilers. With the profiling run launched as follows:

python -m pyinstrument generate.py -q w4a8-cutlass

here is the relevant part of the pyinstrument output:

pyinstrument

So, for the attention segment of the model, one could see that everything related to running the linear operator takes about 34s in total. Out of this time, 24s are spend in the dynamic quantization, while about 9.4s only are spent on the linear operator itself, and then out of these 9.4s, only 2.4s are spent on the CUTLASS MM kernel execution, while the rest of time get spent on checking to which kernel to dispatch (note that for this run, the check for applicability of the CUTLASS kernel is added last to the list) - these checks are not visible in this snippet, as pyinstrument by default suppresses calls that take shorter time, but attached below is full pyinstrument output to verify it. The distribution of time spent is alike for the feed-forward part of the network - this could be also seen from the full output below.

Here is the pyinstrument --show-all ... output for the run above: pyinstrument.txt.

As mentioned above, profiling results are verified using cProfile and nsys. For example, for nsys run as follows:

nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s cpu -python-sampling=true $(which python) generate.py -q w4a8-cutlass

here is a screenshot of the timeline as shown by nsys-ui:

nsys-ui-1

Here, one could see that loading of model takes about 30s, then there is a short sequence of copying model to GPU and doing weights quantization, and then the rest of the timeline is the inference. The CUTLASS MM kernel, designated as Kernel2 here, takes less 30% of time of all of the CUDA kernels executed. If timeline zoomed into a segment of time during the inference, one could see that CUDA kernels are not actually executed tightly (because the checks and dynamic quantization are actually a sequence of calls to PyTorch kernels that are not fused):

nsys-ui-2

@cpuhrsch
Copy link
Contributor

cpuhrsch commented Oct 9, 2024

@alexsamardzic - Was the model torch.compile'd with mode 'max-autotune'? Also you can use torch.profiler to generate kernel traces potentially a bit more quickly than with nsys (at least for rapid iteration). You can then open these with https://ui.perfetto.dev/

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants