Add Float8 support for AQT tensor parallel (#1003)
jainapurva authored Oct 4, 2024
1 parent d2982d5 commit 9e2a253
Showing 2 changed files with 50 additions and 11 deletions.
15 changes: 10 additions & 5 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -1,12 +1,17 @@
 import torch
 from torchao.testing.utils import copy_tests, TorchAOTensorParallelTestCase
 from torch.testing._internal.common_utils import run_tests
-from torchao.quantization import int8_weight_only
+from torchao.quantization import int8_weight_only, float8_weight_only
 
-class TestAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
-    pass
+class TestInt8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+    QUANT_METHOD_FN = staticmethod(int8_weight_only)
+copy_tests(TorchAOTensorParallelTestCase, TestInt8woAffineQuantizedTensorParallel, "int8wo_tp")
+
-
-copy_tests(TorchAOTensorParallelTestCase, TestAffineQuantizedTensorParallel, "aqt_tp")
+# Run only on H100
+if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
+    class TestFloat8woAffineQuantizedTensorParallel(TorchAOTensorParallelTestCase):
+        QUANT_METHOD_FN = staticmethod(float8_weight_only)
+    copy_tests(TorchAOTensorParallelTestCase, TestFloat8woAffineQuantizedTensorParallel, "fp8wo_tp")
 
 if __name__ == "__main__":
     run_tests()
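
Note: a minimal sketch (not part of this commit) of the kind of workload these copied tensor-parallel tests exercise, using torchao's quantize_ API with the newly imported float8_weight_only config; the real TorchAOTensorParallelTestCase presumably also shards the quantized weight across a device mesh with DTensor.

import torch
from torchao.quantization import quantize_, float8_weight_only

# quantize a single linear layer's weight to float8 (weight-only), then run a forward pass
linear = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
quantize_(linear, float8_weight_only())
out = linear(torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda"))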
46 changes: 40 additions & 6 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -1094,20 +1094,31 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
-        if func is aten.clone.default:
+        elif func is aten.clone.default:
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
             )
-        if func is aten.t.default:
+        elif func is aten.t.default:
             """we don't need to repack the weight and just rely on external
             shape being changed and record the status of transpose/no-transpose
             """
             args[0].transposed = not args[0].transposed
             return return_and_correct_aliasing(func, args, kwargs, args[0])
-
-        raise NotImplementedError(
-            f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported"
-        )
+        elif func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                return return_and_correct_aliasing(
+                    func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
+                )
+            elif dim == 1:
+                assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}"
+                return Float8AQTLayout(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self.layout_type)
+            else:
+                raise NotImplementedError(f"Float8AQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported")
+        else:
+            raise NotImplementedError(
+                f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported"
+            )
 
     __torch_function__ = torch._C._disabled_torch_function_impl
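
Why slicing matters here (a rough illustration under assumptions, not code from the commit): when the float8 weight is distributed for tensor parallelism, sharding slices the packed weight along dim 0 or dim 1, which now routes into the aten.slice.Tensor branch above instead of hitting the NotImplementedError.

import torch

# plain-tensor illustration of the two slice patterns the new branch handles;
# `full`, `row_shard`, and `col_shard` are hypothetical names for this sketch
full = torch.randn(16, 8)
row_shard = torch.ops.aten.slice.Tensor(full, 0, 0, 8, 1)  # dim 0: rows 0..7
col_shard = torch.ops.aten.slice.Tensor(full, 1, 0, 4, 1)  # dim 1: cols 0..3; per the assert above, only supported when the scale is 1-D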

@@ -1644,6 +1655,28 @@ def _linear_fp8_act_fp8_weight_impl(
         use_fast_accum=scaled_mm_config.use_fast_accum,
     ).reshape(out_shape)
 
+def _linear_fp_act_fp8_weight_check(
+    input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
+    weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
+    bias: Optional[torch.Tensor],
+) -> bool:
+    return (
+        # input is native float tensor
+        not is_traceable_wrapper_subclass(input_tensor) and
+        input_tensor.is_floating_point() and
+        # weight is float8 quantized affine quantized tensor
+        isinstance(weight_tensor, AffineQuantizedTensor) and
+        isinstance(weight_tensor.layout_type, Float8LayoutType)
+        and weight_tensor.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
+        and (weight_tensor.shape == weight_tensor.block_size or _is_rowwise_scaled(weight_tensor))
+    )
+
+def _linear_fp_act_fp8_weight_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: AffineQuantizedTensor,
+    bias: Optional[torch.Tensor],
+):
+    return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)
 
 def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
     return (
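
A hedged standalone sketch (not the torchao implementation) of what the new check/impl pair above amounts to: when the activation is an ordinary floating-point tensor and the weight is float8 weight-only quantized, the impl simply dequantizes the weight back to high precision and calls the regular linear kernel.

import torch
import torch.nn.functional as F

def fp_act_fp8_weight_fallback(x, w_fp8, w_scale, bias=None):
    # hypothetical standalone version: upcast the float8 weight and rescale,
    # roughly what weight_tensor.dequantize() does for a per-tensor scale
    w = w_fp8.to(x.dtype) * w_scale
    return F.linear(x, w, bias)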
@@ -1694,6 +1727,7 @@ def _register_aqt_quantized_linear_dispatches():
         (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
         (_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl),
         (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl),
+        (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl),
         (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
         (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
         (_linear_f16_act_floatx_weight_check, _linear_f16_act_floatx_weight_impl),
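
For context, a simplified sketch of how (check, impl) pairs like the one registered above are typically consumed; this is an assumption about the surrounding dispatch machinery, which is not shown in this hunk: at linear time the pairs are tried in order and the first check that returns True selects the implementation.

def _dispatch_quantized_linear(dispatch_table, input_tensor, weight_tensor, bias):
    # dispatch_table is a list of (check_fn, impl_fn) pairs, e.g. the list above
    for check_fn, impl_fn in dispatch_table:
        if check_fn(input_tensor, weight_tensor, bias):
            return impl_fn(input_tensor, weight_tensor, bias)
    raise NotImplementedError("no matching quantized linear dispatch")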
