From cc8bf8595dfbc6e5e2ca3f18bbd6e9384e794c04 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Tue, 8 Oct 2024 13:49:27 -0700 Subject: [PATCH] Rename Layout -> TensorImpl (#1028) --- .github/workflows/regression_test.yml | 2 +- benchmarks/benchmark_fp6.py | 2 +- test/dtypes/test_affine_quantized_float.py | 14 +- test/dtypes/test_floatx.py | 14 +- test/hqq/test_hqq_affine.py | 4 +- test/integration/test_integration.py | 2 +- torchao/dtypes/__init__.py | 4 +- torchao/dtypes/affine_quantized_tensor.py | 210 +++++++++--------- torchao/dtypes/floatx/__init__.py | 2 +- torchao/dtypes/floatx/floatx.py | 18 +- torchao/dtypes/uintx/uintx.py | 6 +- torchao/dtypes/utils.py | 8 +- torchao/prototype/hqq/example.py | 4 +- torchao/quantization/autoquant.py | 8 +- torchao/quantization/quant_api.py | 2 +- torchao/sparsity/marlin/utils.py | 2 +- torchao/utils.py | 46 ++-- .../my_dtype_tensor_subclass.py | 64 +++--- .../my_trainable_tensor_subclass.py | 18 +- .../developer_api_guide/tensor_parallel.py | 16 +- 20 files changed, 223 insertions(+), 223 deletions(-) diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml index 3aee8dbfb..13cd4e2e7 100644 --- a/.github/workflows/regression_test.yml +++ b/.github/workflows/regression_test.yml @@ -35,7 +35,7 @@ jobs: gpu-arch-version: "12.1" - name: CUDA 2.4 runs-on: linux.g5.12xlarge.nvidia.gpu - torch-spec: 'torch==2.4.0' + torch-spec: 'torch==2.4.1' gpu-arch-type: "cuda" gpu-arch-version: "12.1" - name: CUDA Nightly (Oct 1) diff --git a/benchmarks/benchmark_fp6.py b/benchmarks/benchmark_fp6.py index 425507bd9..509ea6e86 100644 --- a/benchmarks/benchmark_fp6.py +++ b/benchmarks/benchmark_fp6.py @@ -2,7 +2,7 @@ import pandas as pd import torch.nn.functional as F from torchao.dtypes import to_affine_quantized_fpx -from torchao.dtypes.floatx import FloatxTensorCoreAQTLayout, FloatxTensorCoreLayoutType +from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayoutType from torchao.utils import benchmark_torch_function_in_microseconds from tqdm import tqdm diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py index 621e3596e..761b233fc 100644 --- a/test/dtypes/test_affine_quantized_float.py +++ b/test/dtypes/test_affine_quantized_float.py @@ -210,18 +210,18 @@ def test_serialization(self, mode: str): # Compare weights if mode == "weight-only": - original_weight = original_layer.weight.layout_tensor.float8_data.to( - torch.float32 - ) - new_weight = new_layer.weight.layout_tensor.float8_data.to( + original_weight = original_layer.weight.tensor_impl.float8_data.to( torch.float32 ) + new_weight = new_layer.weight.tensor_impl.float8_data.to(torch.float32) else: - original_weight = original_layer.weight.original_weight_tensor.layout_tensor.float8_data.to( + original_weight = original_layer.weight.original_weight_tensor.tensor_impl.float8_data.to( torch.float32 ) - new_weight = new_layer.weight.original_weight_tensor.layout_tensor.float8_data.to( - torch.float32 + new_weight = ( + new_layer.weight.original_weight_tensor.tensor_impl.float8_data.to( + torch.float32 + ) ) assert torch.allclose( diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index b4776f95e..f228c4c0c 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -9,7 +9,7 @@ run_tests, ) from torchao.dtypes.floatx import ( - FloatxTensorCoreAQTLayout, + FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayoutType, to_scaled_tc_floatx, from_scaled_tc_floatx, @@ -28,7 +28,7 
@@ _Floatx_DTYPES = [(3, 2), (2, 2)] -class TestFloatxTensorCoreAQTLayout(TestCase): +class TestFloatxTensorCoreAQTTensorImpl(TestCase): @parametrize("device", _DEVICES) def test_pack_tc_fp6_correctness(self, device): x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device) @@ -82,10 +82,10 @@ def test_to_copy_device(self, ebits, mbits): scale = choose_qparams_affine_floatx(x, ebits, mbits) x = quantize_affine_floatx(x, scale, ebits, mbits) layout_type = FloatxTensorCoreLayoutType(ebits, mbits) - floatx_layout_tensor = FloatxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda() - assert floatx_layout_tensor.device.type == "cuda" - floatx_layout_tensor = floatx_layout_tensor.cpu() - assert floatx_layout_tensor.device.type == "cpu" + floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, layout_type).cuda() + assert floatx_tensor_impl.device.type == "cuda" + floatx_tensor_impl = floatx_tensor_impl.cpu() + assert floatx_tensor_impl.device.type == "cpu" @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+") @@ -106,7 +106,7 @@ def test_fpx_weight_only(self, ebits, mbits, bias): torch.testing.assert_close(actual, expected) -instantiate_parametrized_tests(TestFloatxTensorCoreAQTLayout) +instantiate_parametrized_tests(TestFloatxTensorCoreAQTTensorImpl) if __name__ == "__main__": diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py index f3fa41c64..c1177d2d4 100644 --- a/test/hqq/test_hqq_affine.py +++ b/test/hqq/test_hqq_affine.py @@ -3,9 +3,9 @@ from torchao.dtypes.affine_quantized_tensor import ( to_affine_quantized_intx, ZeroPointDomain, - PlainAQTLayout, + PlainAQTTensorImpl, PlainLayoutType, - TensorCoreTiledAQTLayout, + TensorCoreTiledAQTTensorImpl, TensorCoreTiledLayoutType, MappingType, ) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index be8f2f954..46799b491 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1051,7 +1051,7 @@ def forward(self, x): self.assertTrue(torch.equal(ref_q, test)) @parameterized.expand(COMMON_DEVICE_DTYPE) - @unittest.skipIf(is_fbcode(), "'PlainAQTLayout' object has no attribute 'int_data'") + @unittest.skipIf(is_fbcode(), "'PlainAQTTensorImpl' object has no attribute 'int_data'") @torch.no_grad() def test_save_load_dqtensors(self, device, dtype): if device == "cpu": diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index e27bf6497..8d4be52dc 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -14,7 +14,7 @@ SemiSparseLayoutType, TensorCoreTiledLayoutType, Float8LayoutType, - Float8AQTLayout, + Float8AQTTensorImpl, MarlinSparseLayoutType, ) @@ -33,6 +33,6 @@ "SemiSparseLayoutType", "TensorCoreTiledLayoutType", "Float8LayoutType", - "Float8AQTLayout", + "Float8AQTTensorImpl", "MarlinSparseLayoutType", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index c2c8e3c0b..0fa864f8b 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -51,17 +51,17 @@ aten = torch.ops.aten ############################### -# Base Layout Tensor Subclass # +# Base Tensor Impl Subclass # ############################### -class AQTLayout(TorchAOBaseTensor): +class AQTTensorImpl(TorchAOBaseTensor): """ - Base class for the layout tensor for 
`AffineQuantizedTensor` + Base class for the tensor impl for `AffineQuantizedTensor` """ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Get the plain (unpacked) Tensor for the layout Tensor + """Get the plain (unpacked) Tensor for the tensor impl Returns data, scale and zero_point - Can be overwritten if other types of AQTLayout Tensor has different numbers of plain tensors + Can be overwritten if other types of AQTTensorImpl has different numbers of plain tensors """ pass @@ -76,7 +76,7 @@ def from_plain( zero_point: torch.Tensor, layout_type: LayoutType, ): - """ Construct a Layout from data, scale, zero_point and the layout_type""" + """ Construct a TensorImpl from data, scale, zero_point and the layout_type""" pass def __repr__(self): @@ -131,7 +131,7 @@ class AffineQuantizedTensor(TorchAOBaseTensor): regardless of the internal representation's type or orientation. fields: - layout_tensor (AQTLayout): tensor that serves as a general layout storage for the quantized data, + tensor_impl (AQTTensorImpl): tensor that serves as a general tensor impl storage for the quantized data, e.g. storing plain tensors (int_data, scale, zero_point) or packed formats depending on device and operator/kernel block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam @@ -151,7 +151,7 @@ class AffineQuantizedTensor(TorchAOBaseTensor): @staticmethod def __new__( cls, - layout_tensor: AQTLayout, + tensor_impl: AQTTensorImpl, block_size: Tuple[int, ...], shape: torch.Size, quant_min: Optional[Union[int, float]] = None, @@ -161,9 +161,9 @@ def __new__( strides=None, ): kwargs = {} - kwargs["device"] = layout_tensor.device + kwargs["device"] = tensor_impl.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else layout_tensor.layout + kwargs.get("layout") if kwargs.get("layout", False) else tensor_impl.layout ) kwargs["dtype"] = dtype if strides is not None: @@ -173,7 +173,7 @@ def __new__( def __init__( self, - layout_tensor: AQTLayout, + tensor_impl: AQTTensorImpl, block_size: Tuple[int, ...], shape: torch.Size, quant_min: Optional[Union[int, float]] = None, @@ -182,7 +182,7 @@ def __init__( dtype=None, strides=None, ): - self.layout_tensor = layout_tensor + self.tensor_impl = tensor_impl self.block_size = block_size self.quant_min = quant_min self.quant_max = quant_max @@ -190,12 +190,12 @@ def __init__( def __repr__(self): return ( - f"{self.__class__.__name__}(layout_tensor={self.layout_tensor}, block_size={self.block_size}, " + f"{self.__class__.__name__}(tensor_impl={self.tensor_impl}, block_size={self.block_size}, " f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" ) def _quantization_type(self): - return f"shape={self.shape}, block_size={self.block_size}, device={self.device}, layout_type={self.layout_type}, layout_tensor_dtype={self.layout_tensor.dtype}, quant_min={self.quant_min}, quant_max={self.quant_max}" + return f"shape={self.shape}, block_size={self.block_size}, device={self.device}, layout_type={self.layout_type}, tensor_impl_dtype={self.tensor_impl.dtype}, quant_min={self.quant_min}, quant_max={self.quant_max}" def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: if output_dtype is None: @@ -203,10 +203,10 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor from torchao.dtypes.floatx import FloatxTensorCoreLayoutType if isinstance(self.layout_type, 
FloatxTensorCoreLayoutType): - int_data, scale = self.layout_tensor.get_plain() + int_data, scale = self.tensor_impl.get_plain() return dequantize_affine_floatx(int_data, scale, self.layout_type.ebits, self.layout_type.mbits, output_dtype=output_dtype) else: - data, scale, zero_point = self.layout_tensor.get_plain() + data, scale, zero_point = self.tensor_impl.get_plain() dq = dequantize_affine( data, self.block_size, @@ -232,16 +232,16 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias): raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op") def __tensor_flatten__(self): - return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] + return ["tensor_impl"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - layout_tensor = tensor_data_dict["layout_tensor"] + tensor_impl = tensor_data_dict["tensor_impl"] block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes return cls( - layout_tensor, + tensor_impl, block_size, shape if outer_size is None else outer_size, quant_min, @@ -289,10 +289,10 @@ def from_hp_to_intx( # Note: output will be uint8 tensor for sub byte tensors for now data = layout_type.post_process(data) - layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) - layout_tensor = layout_tensor_ctr(data, scale, zero_point, layout_type) + tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) + tensor_impl = tensor_impl_ctr(data, scale, zero_point, layout_type) return cls( - layout_tensor, + tensor_impl, block_size, original_shape, quant_min, @@ -324,10 +324,10 @@ def from_hp_to_intx_static( int_data = layout_type.post_process(int_data) - layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) - layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type) + tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) + tensor_impl = tensor_impl_ctr(int_data, scale, zero_point, layout_type) return cls( - layout_tensor, + tensor_impl, block_size, original_shape, quant_min, @@ -410,10 +410,10 @@ def from_hp_to_fpx( floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits) floatx_packed = layout_type.post_process(floatx_unpacked) - layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) - layout_tensor = layout_tensor_ctr(floatx_packed, scale, None, layout_type) + tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) + tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, layout_type) return cls( - layout_tensor, + tensor_impl, block_size, original_shape, dtype=input_float.dtype @@ -421,13 +421,13 @@ def from_hp_to_fpx( @property def layout_type(self) -> LayoutType: - return self.layout_tensor.layout_type + return self.tensor_impl.layout_type def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs.pop("device") return self.__class__( - self.layout_tensor.to(device), + self.tensor_impl.to(device), self.block_size, self.shape, self.quant_min, @@ -438,7 +438,7 @@ def to(self, *args, **kwargs): def _apply_fn_to_data(self, fn): return self.__class__( - fn(self.layout_tensor), + fn(self.tensor_impl), self.block_size, self.shape, self.quant_min, @@ -464,10 +464,10 @@ def _apply_fn_to_data(self, fn): ###################################################### -# 
LayoutType and Layout Tensor Subclass Registration # +# LayoutType and TensorImpl Subclass Registration # ###################################################### -register_layout_cls = AffineQuantizedTensor.register_layout_cls -get_layout_tensor_constructor = AffineQuantizedTensor.get_layout_tensor_constructor +register_layout = AffineQuantizedTensor.register_layout +get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor @dataclass(frozen=True) class SemiSparseLayoutType(LayoutType): @@ -548,10 +548,10 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: return w_24.t() -@register_layout_cls(PlainLayoutType) -class PlainAQTLayout(AQTLayout): +@register_layout(PlainLayoutType) +class PlainAQTTensorImpl(AQTTensorImpl): """ - Layout storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point + TensorImpl storage class for plain layout for affine quantized tensor, it stores int_data, scale, zero_point tensors directly as plain tensors. fields: @@ -645,12 +645,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif dim == 1: assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" - return PlainAQTLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self.layout_type) + return PlainAQTTensorImpl(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self.layout_type) else: - raise NotImplementedError(f"PlainAQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError(f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") raise NotImplementedError( - f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported" + f"PlainAQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl @@ -672,10 +672,10 @@ def from_plain( assert isinstance(layout_type, PlainLayoutType) return cls(int_data, scale, zero_point, layout_type) -@register_layout_cls(SemiSparseLayoutType) -class SemiSparseAQTLayout(PlainAQTLayout): +@register_layout(SemiSparseLayoutType) +class SemiSparseAQTTensorImpl(PlainAQTTensorImpl): """ - Layout storage class for semi_sparse_cusparselt layout for affine quantized tensor + TensorImpl storage class for semi_sparse_cusparselt layout for affine quantized tensor """ @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): @@ -687,7 +687,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) raise NotImplementedError( - f"SparseAQTLayout dispatch: attempting to run {func}, this is not supported" + f"SparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) def get_plain(self): @@ -712,8 +712,8 @@ def from_plain( int_data_compressed = torch._cslt_compress(int_data) return cls(int_data_compressed, scale, zero_point, layout_type) -@register_layout_cls(BlockSparseLayoutType) -class BlockSparseAQTLayout(PlainAQTLayout): +@register_layout(BlockSparseLayoutType) +class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): bsr_crow_indices: Optional[torch.Tensor] bsr_col_indices: Optional[torch.Tensor] bsr_values: Optional[torch.Tensor] @@ -849,13 +849,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return args[0].bsr_values.shape[0] raise NotImplementedError( - f"BlockSparseAQTLayout dispatch: attempting to run 
{func}, this is not supported" + f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) -@register_layout_cls(MarlinSparseLayoutType) -class MarlinSparseAQTLayout(AQTLayout): +@register_layout(MarlinSparseLayoutType) +class MarlinSparseAQTTensorImpl(AQTTensorImpl): """ - Layout storage class for sparse_marlin_24 layout for affine quantized tensor. + TensorImpl storage class for sparse_marlin_24 layout for affine quantized tensor. Can be used with 4 bits and 8 bits quantization. @@ -922,7 +922,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) raise NotImplementedError( - f"MarlinSparseAQTLayout dispatch: attempting to run {func}, this is not supported" + f"MarlinSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) def __tensor_flatten__(self): @@ -1022,10 +1022,10 @@ def _apply_fn_to_data(self, fn): return self -@register_layout_cls(Float8LayoutType) -class Float8AQTLayout(AQTLayout): +@register_layout(Float8LayoutType) +class Float8AQTTensorImpl(AQTTensorImpl): """ - Layout storage class for float8 layout for affine quantized tensor + TensorImpl storage class for float8 tensor impl for affine quantized tensor """ float8_data: torch.Tensor scale: torch.Tensor @@ -1112,12 +1112,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif dim == 1: assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" - return Float8AQTLayout(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self.layout_type) + return Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self.layout_type) else: - raise NotImplementedError(f"Float8AQTLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") else: raise NotImplementedError( - f"Float8AQTLayout dispatch: attempting to run {func}, this is not supported" + f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl @@ -1136,9 +1136,9 @@ def from_plain( zero_point: Optional[torch.Tensor], layout_type: LayoutType, ): - """ Main entrypoint for constructing Float8Layout Tensor""" - assert _is_float8_type(data.dtype), f"Float8 Layout must be constructed from float8 dtype but got {data.dtype}" - assert isinstance(layout_type, Float8LayoutType), f"Float8 Layout must be constructed from Float8LayoutType but got {layout_type}" + """ Main entrypoint for constructing Float8TensorImpl""" + assert _is_float8_type(data.dtype), f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}" + assert isinstance(layout_type, Float8LayoutType), f"Float8 TensorImpl must be constructed from Float8LayoutType but got {layout_type}" return cls(data, scale, False, layout_type) def __repr__(self): @@ -1151,10 +1151,10 @@ def __repr__(self): f"layout_type={layout_type})") -@register_layout_cls(TensorCoreTiledLayoutType) -class TensorCoreTiledAQTLayout(AQTLayout): +@register_layout(TensorCoreTiledLayoutType) +class TensorCoreTiledAQTTensorImpl(AQTTensorImpl): """ - Layout storage class for tensor_core_tiled layout for affine quantized tensor, this is for int4 only, + TensorImpl storage class for tensor_core_tiled tensor impl for affine quantized tensor, this is for int4 only, it stores the original tensor of dimension [n][k] 
(int32 dtype) as packed weight of 4-d tensor of dimension: [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2] @@ -1230,7 +1230,7 @@ def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs["device"] if not is_device("cuda", device): - raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}") + raise ValueError(f"TensorCoreTiledAQTTensorImpl is only available for cuda device, can't convert to {device}") return self.__class__( self.packed_weight.to(device), self.scale_and_zero.to(device), @@ -1265,7 +1265,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, args[0]) raise NotImplementedError( - f"TensorCoreTiledAQTLayout dispatch: attempting to run {func}, this is not supported" + f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl @@ -1311,14 +1311,14 @@ def get_layout_type(self) -> LayoutType: def _aqt_is_int8(aqt): """Check if an AffineQuantizedTensor is int8 quantized Tensor""" return ( - aqt.layout_tensor.dtype == torch.int8 and + aqt.tensor_impl.dtype == torch.int8 and (aqt.quant_min is None or aqt.quant_min == -128) and (aqt.quant_max is None or aqt.quant_max == 127) ) def _aqt_is_int8_reduced_range(aqt): return ( - aqt.layout_tensor.dtype == torch.int8 and + aqt.tensor_impl.dtype == torch.int8 and aqt.quant_min == -127 and (aqt.quant_max is None or aqt.quant_max == 127) ) @@ -1327,7 +1327,7 @@ def _aqt_is_tensor_core_tile_uint4(aqt): """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" # TODO: use torch.uint4 return ( - aqt.layout_tensor.dtype == torch.int32 and + aqt.tensor_impl.dtype == torch.int32 and aqt.quant_min == 0 and aqt.quant_max == 15 ) @@ -1364,10 +1364,10 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): # value of a float 16, (which results in a value of inf even if multiplying # by the other scale would bring it within the expected range) - x_vals_int8 = input_tensor.layout_tensor.int_data - x_scales = input_tensor.layout_tensor.scale - w_vals_int8_t = weight_tensor.layout_tensor.int_data.contiguous().t() - w_scales = weight_tensor.layout_tensor.scale + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() + w_scales = weight_tensor.tensor_impl.scale tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1)) @@ -1395,10 +1395,10 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_check(input_tensor, weig ) def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weight_tensor, bias): - x_vals_int8 = input_tensor.layout_tensor.int_data - x_scales = input_tensor.layout_tensor.scale - w_vals_int8 = weight_tensor.layout_tensor.int_data - w_scales = weight_tensor.layout_tensor.scale + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals_int8 = weight_tensor.tensor_impl.int_data + w_scales = weight_tensor.tensor_impl.scale tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( @@ -1427,10 +1427,10 @@ def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, def 
_linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): - x_vals_int8 = input_tensor.layout_tensor.int_data - x_scales = input_tensor.layout_tensor.scale - w_vals = weight_tensor.layout_tensor - w_scales = weight_tensor.layout_tensor.scale + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals = weight_tensor.tensor_impl + w_scales = weight_tensor.tensor_impl.scale tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) tmp_t = tmp.t() @@ -1456,7 +1456,7 @@ def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): # input is native bfloat16 tensor not is_traceable_wrapper_subclass(input_tensor) and input_tensor.dtype == torch.bfloat16 and - # weight is uint4, group quantized tensor_core_tiled layout affine quantized tensor + # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor isinstance(weight_tensor, AffineQuantizedTensor) and _aqt_is_tensor_core_tile_uint4(weight_tensor) and weight_tensor.dtype == torch.bfloat16 and @@ -1478,8 +1478,8 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): act_mat = input_tensor # weight is packed from padded (out_features, in_features) weight tensor # (same dimension requirement as F.linear weight) - packed_weight = weight_tensor.layout_tensor.packed_weight - scale_and_zero = weight_tensor.layout_tensor.scale_and_zero + packed_weight = weight_tensor.tensor_impl.packed_weight + scale_and_zero = weight_tensor.tensor_impl.scale_and_zero orig_act_size = act_mat.size() orig_dtype = act_mat.dtype @@ -1522,11 +1522,11 @@ def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): # TODO: enable cpu and mps efficient path # is_cpu and is_mps only, some issue with is_contiguous() currently - # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.layout_tensor.scale) + # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale) # per channel int8 weight only quantizated mm - w_vals_int8_t = weight_tensor.layout_tensor.int_data.t() - scale = weight_tensor.layout_tensor.scale + w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() + scale = weight_tensor.tensor_impl.scale orig_dtype = input_tensor.dtype m = torch.mm( input_tensor.reshape(-1, input_tensor.shape[-1]), @@ -1580,8 +1580,8 @@ def _linear_f16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): weight.layout_type.ebits, weight.layout_type.mbits, act_reshaped, - weight.layout_tensor.packed_floatx_data, - weight.layout_tensor.scale, + weight.tensor_impl.packed_floatx_data, + weight.tensor_impl.scale, splitK=splitK, ) @@ -1599,7 +1599,7 @@ def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: return ( isinstance(aqt, AffineQuantizedTensor) and isinstance(aqt.layout_type, Float8LayoutType) - and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt)) ) return check_aqt(input_tensor) and check_aqt(weight_tensor) @@ -1624,14 +1624,14 @@ def _linear_fp8_act_fp8_weight_impl( out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) # Weight tensor preprocessing - w_layout = weight_tensor.layout_tensor - assert not w_layout.transposed, "Weight tensor must be contiguous" - w_data = w_layout.float8_data - 
w_scale = w_layout.scale + w_tensor_impl = weight_tensor.tensor_impl + assert not w_tensor_impl.transposed, "Weight tensor must be contiguous" + w_data = w_tensor_impl.float8_data + w_scale = w_tensor_impl.scale # Input tensor preprocessing - inpt_data = input_tensor.layout_tensor.float8_data - input_scale = input_tensor.layout_tensor.scale + inpt_data = input_tensor.tensor_impl.float8_data + input_scale = input_tensor.tensor_impl.scale # Handle case where input tensor is more than 2D inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) @@ -1667,7 +1667,7 @@ def _linear_fp_act_fp8_weight_check( # weight is float8 quantized affine quantized tensor isinstance(weight_tensor, AffineQuantizedTensor) and isinstance(weight_tensor.layout_type, Float8LayoutType) - and weight_tensor.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and (weight_tensor.shape == weight_tensor.block_size or _is_rowwise_scaled(weight_tensor)) ) @@ -1694,11 +1694,11 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b assert isinstance(weight_tensor, AffineQuantizedTensor) - sparse_w_int4 = weight_tensor.layout_tensor.int_data - scale = weight_tensor.layout_tensor.scale - meta = weight_tensor.layout_tensor.meta - original_shape = weight_tensor.layout_tensor.original_shape - num_bits = weight_tensor.layout_tensor.num_bits + sparse_w_int4 = weight_tensor.tensor_impl.int_data + scale = weight_tensor.tensor_impl.scale + meta = weight_tensor.tensor_impl.meta + original_shape = weight_tensor.tensor_impl.original_shape + num_bits = weight_tensor.tensor_impl.num_bits # Folds batch dimension into the first dimension input_2d = input_tensor.view(-1, input_tensor.shape[-1]) @@ -1845,7 +1845,7 @@ def _(func, types, args, kwargs): tensor = args[0] shape = tensor.shape[::-1] new = tensor.__class__( - tensor.layout_tensor.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride() + tensor.tensor_impl.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride() ) return return_and_correct_aliasing(func, args, kwargs, new) @@ -1863,7 +1863,7 @@ def _(func, types, args, kwargs): # with slice, some shape dimension might be smaller than block_size dimension, so # we need to make sure there is no overflow block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) - new = self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + new = self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) return return_and_correct_aliasing(func, args, kwargs, new) # this is needed for DTensor.from_local() and for flattening tensor @@ -1872,12 +1872,12 @@ def _(func, types, args, kwargs): self, shape = args if tuple(self.shape) == tuple(shape): - return self.__class__(self.layout_tensor, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + return self.__class__(self.tensor_impl, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) if len(shape) == 
1 and shape[0] == -1: assert len(self.block_size) == 2 and self.block_size[0] == 1 block_size = (self.block_size[1],) - return self.__class__(self.layout_tensor, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + return self.__class__(self.tensor_impl, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) raise ValueError(f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]") diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 0eb1e7052..39461d886 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1 +1 @@ -from .floatx import FloatxTensorCoreLayoutType, FloatxTensorCoreAQTLayout, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP +from .floatx import FloatxTensorCoreLayoutType, FloatxTensorCoreAQTTensorImpl, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP diff --git a/torchao/dtypes/floatx/floatx.py b/torchao/dtypes/floatx/floatx.py index dcbfd5f69..5a9aab035 100644 --- a/torchao/dtypes/floatx/floatx.py +++ b/torchao/dtypes/floatx/floatx.py @@ -10,7 +10,7 @@ ) from torchao.quantization.quant_api import _get_linear_subclass_inserter from dataclasses import dataclass -from torchao.dtypes.affine_quantized_tensor import AQTLayout, register_layout_cls +from torchao.dtypes.affine_quantized_tensor import AQTTensorImpl, register_layout aten = torch.ops.aten @@ -354,14 +354,14 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> @dataclass(frozen=True) class FloatxTensorCoreLayoutType(LayoutType): - """Layout type for FloatxTensorCoreAQTLayout + """Layout type for FloatxTensorCoreAQTTensorImpl """ ebits: int mbits: int -@register_layout_cls(FloatxTensorCoreLayoutType) -class FloatxTensorCoreAQTLayout(AQTLayout): - """FloatxTensorCoreAQTLayout represents a Tensor with dtype floatx(ebits=a, mbits=b), +@register_layout(FloatxTensorCoreLayoutType) +class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): + """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), it has a internal tensor field of "packed_floatx_data", which is packed from the uint8 unpacked data (the output of `quantize_affine_floatx` operator) @@ -377,10 +377,10 @@ class FloatxTensorCoreAQTLayout(AQTLayout): If original Tensor shape is (M, N), and the data is in nbit, the shape of the packed data will be (M, N // 8 * nbit) - FloatxTensorCoreAQTLayout.from_plain takes an unpacked uint8 floatx Tensor of shape (M, N), with format of + FloatxTensorCoreAQTTensorImpl.from_plain takes an unpacked uint8 floatx Tensor of shape (M, N), with format of (zero padding bits + sign bit + exponent bits + mantissa bits), e.g. 
00SEEEMM for fp6_e3_m2 - it will then pack the weight and instantiate the FloatxTensorCoreAQTLayout tensor - FloatxTensorCoreAQTLayout.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) + it will then pack the weight and instantiate the FloatxTensorCoreAQTTensorImpl tensor + FloatxTensorCoreAQTTensorImpl.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) """ def __new__( cls, @@ -483,7 +483,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) raise NotImplementedError( - f"FloatxTensorCoreAQTLayout dispatch: attempting to run {func}, this is not supported" + f"FloatxTensorCoreAQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) __torch_function__ = torch._C._disabled_torch_function_impl diff --git a/torchao/dtypes/uintx/uintx.py b/torchao/dtypes/uintx/uintx.py index a0cd687f5..eb63fc619 100644 --- a/torchao/dtypes/uintx/uintx.py +++ b/torchao/dtypes/uintx/uintx.py @@ -8,7 +8,7 @@ LayoutType, ) from torchao.utils import TorchAOBaseTensor -from torchao.dtypes.affine_quantized_tensor import PlainAQTLayout, register_layout_cls +from torchao.dtypes.affine_quantized_tensor import PlainAQTTensorImpl, register_layout from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 aten = torch.ops.aten @@ -194,8 +194,8 @@ class UintxLayoutType(LayoutType): def post_process(self, input: torch.Tensor) -> torch.Tensor: return to_uintx(input, self.dtype, self.pack_dim) -@register_layout_cls(UintxLayoutType) -class UintxAQTLayout(PlainAQTLayout): +@register_layout(UintxLayoutType) +class UintxAQTTensorImpl(PlainAQTTensorImpl): def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return self.int_data.get_plain(), self.scale, self.zero_point diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 52a9c5719..4a6b3a0bb 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -4,11 +4,11 @@ """ Base class for different LayoutType, should not be instantiated directly -used to allow users to pass around configurations for the layout tensor, e.g. inner_k_tiles -for int4 tensor core tiled layout +used to allow users to pass around configurations for the tensor impl, e.g. inner_k_tiles +for int4 tensor core tiled tensor impl -Note: layout is an abstraction not only for custom data representation, it is also used for how the -layout interacts with different operators, e.g. the same data representation can have different +Note: TensorImpl is an abstraction not only for custom data representation, it is also used for how the +tensorImpl interacts with different operators, e.g. the same data representation can have different behaviors when running the same operator, e.g. transpose, quantized_linear. 
""" @dataclass(frozen=True) diff --git a/torchao/prototype/hqq/example.py b/torchao/prototype/hqq/example.py index cd5a93b56..f410a11cd 100644 --- a/torchao/prototype/hqq/example.py +++ b/torchao/prototype/hqq/example.py @@ -3,9 +3,9 @@ from torchao.dtypes.affine_quantized_tensor import ( to_affine_quantized_intx, ZeroPointDomain, - PlainAQTLayout, + PlainAQTTensorImpl, PlainLayoutType, - TensorCoreTiledAQTLayout, + TensorCoreTiledAQTTensorImpl, TensorCoreTiledLayoutType, MappingType, ) diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index a5568c4e1..7439c982b 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -348,7 +348,7 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): ) q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs") with torch.no_grad(): - w_vals_int8 = w_qtensor.original_weight_tensor.layout_tensor.int_data.contiguous().t() + w_vals_int8 = w_qtensor.original_weight_tensor.tensor_impl.int_data.contiguous().t() res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales.reshape(-1,1), w_vals_int8) print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms") @@ -399,8 +399,8 @@ def _quantized_linear_op(act_mat, w_qtensor, bias): orig_dtype = act_mat.dtype orig_shape = act_mat.shape act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1) - y = (act_mat*w_qtensor.layout_tensor.int_data.t().unsqueeze(0)).sum(dim=-2) - y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.layout_tensor.scale + y = (act_mat*w_qtensor.tensor_impl.int_data.t().unsqueeze(0)).sum(dim=-2) + y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.tensor_impl.scale if bias is not None: y += bias return y.to(orig_dtype) @@ -420,7 +420,7 @@ class AQInt8WeightOnlyQuantizedLinearWeight3(AQInt8WeightOnlyQuantizedLinearWeig @staticmethod def _quantized_linear_op(act_mat, w_qtensor, bias): orig_shape = act_mat.shape - y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.layout_tensor.int_data.t()*w_qtensor.layout_tensor.scale) + y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.tensor_impl.int_data.t()*w_qtensor.tensor_impl.scale) y=y.reshape(*orig_shape[:-1], y.shape[-1]) if bias is not None: y += bias diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 6c4142506..aef873e2f 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -886,7 +886,7 @@ def fpx_weight_only(ebits: int, mbits: int): e.g. fp6_e3_m2, fp6_e2_m3, ... 
The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112 github repo: https://github.com/usyd-fsalab/fp6_llm, now renamed to quant-llm - For more details for packing please see: :class:`~torchao.dtypes.fpx.FpxTensorCoreAQTLayout` + For more details for packing please see: :class:`~torchao.dtypes.fpx.FpxTensorCoreAQTTensorImpl` This is experimental, will be merged with `to_affine_quantized_floatx` in the future diff --git a/torchao/sparsity/marlin/utils.py b/torchao/sparsity/marlin/utils.py index 4ebdf432e..4c5572553 100644 --- a/torchao/sparsity/marlin/utils.py +++ b/torchao/sparsity/marlin/utils.py @@ -9,7 +9,7 @@ class Marlin24Constants: MIN_THREAD_N: int = 128 MAX_PARALLEL: int = 64 - # NOTE: Cuda kernel supports fp8, but not implemented yet in SparseMarlinAQTLayout + # NOTE: Cuda kernel supports fp8, but not implemented yet in SparseMarlinAQTTensorImpl SUPPORTED_NUM_BITS: List[int] = field(default_factory=lambda: [4, 8]) SUPPORTED_GROUP_SIZES: List[int] = field(default_factory=lambda: [-1, 32, 64, 128]) const = Marlin24Constants() diff --git a/torchao/utils.py b/torchao/utils.py index a0302cabe..36bc1be36 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -392,34 +392,34 @@ class MyTensor(torch.Tensor): kwarg_types = {k: type(arg) for k, arg in kwargs} raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}") -def _register_layout_cls(cls: Callable, layout_type_class: Callable): +def _register_layout(cls: Callable, layout_type_class: Callable): """Helper function for layout registrations, this is used to implement - register_layout_cls decorator for each tensor subclass, see aqt.py for example usage + register_layout decorator for each tensor subclass, see aqt.py for example usage Args: cls: Tensor subclass type layout_type_class: the class type of subclass of `LayoutType`, e.g. 
`PlainLayoutType` Returns: - a decorator that registers the layout tensor constructor in the table + a decorator that registers the tensor impl constructor in the table """ # cls._LAYOUT_CONSTRUCTOR_TABLE is a map from layout_type_class like TensorCoreTiledLayout - # to layout class constructor like TensorCoreTiledAQTLayout.from_plain that can construct a layout_tensor + # to tensor_impl class constructor like TensorCoreTiledAQTTensorImpl.from_plain that can construct a tensor_impl # from plain data like (quantized, unpacked) `data`, `scale`, `zero_point` if not hasattr(cls, "_LAYOUT_CONSTRUCTOR_TABLE"): cls._LAYOUT_CONSTRUCTOR_TABLE = {} - def decorator(layout_class): - cls._LAYOUT_CONSTRUCTOR_TABLE[layout_type_class] = layout_class.from_plain + def decorator(tensor_impl_class): + cls._LAYOUT_CONSTRUCTOR_TABLE[layout_type_class] = tensor_impl_class.from_plain if TORCH_VERSION_AT_LEAST_2_5: - # Allow serialization to work for models uses this layout tensor subclass - torch.serialization.add_safe_globals([layout_type_class, layout_class]) - return layout_class + # Allow serialization to work for models uses this tensor impl subclass + torch.serialization.add_safe_globals([layout_type_class, tensor_impl_class]) + return tensor_impl_class return decorator -def _get_layout_tensor_constructor(cls: Callable, layout_type_class: Callable) -> Callable: - """Get Layout class constructor (LayoutClass.from_plain) for `cls` based on `layout_type_class` +def _get_tensor_impl_constructor(cls: Callable, layout_type_class: Callable) -> Callable: + """Get TensorImpl class constructor (TensorImplClass.from_plain) for `cls` based on `layout_type_class` `layout_type_class` means the class type of subclass of `LayoutType`, e.g. `PlainLayoutType` Args: @@ -427,10 +427,10 @@ def _get_layout_tensor_constructor(cls: Callable, layout_type_class: Callable) - layout_type_class: the class type of subclass of `LayoutType`, e.g. `PlainLayoutType` Returns: - layout tensor subclass constructor for the layout_type_class + tensor impl subclass constructor for the layout_type_class """ if not hasattr(cls, "_LAYOUT_CONSTRUCTOR_TABLE"): - raise ValueError(f"no registered layout class constructor for: {cls}") + raise ValueError(f"no registered tensor_impl class constructor for: {cls}") if layout_type_class not in cls._LAYOUT_CONSTRUCTOR_TABLE: raise ValueError(f"layout_name: {layout_type_class} is not supported yet for {cls}") @@ -457,25 +457,25 @@ def to(self, *args, **kwargs): def _(func, types, args, kwargs): ... - `register_layout_cls`: - register_layout_cls = MyTensor.register_layout_cls + `register_layout`: + register_layout = MyTensor.register_layout - @register_layout_cls(PlainLayoutType) - class PlainAQTLayout(...): + @register_layout(PlainLayoutType) + class PlainAQTTensorImpl(...): ... 
- `get_layout_tensor_constructor`: - get_layout_tensor_constructor = MyTensor.get_layout_tensor_constructor + `get_tensor_impl_constructor`: + get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor # in constructor of MyTensor: - layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) - layout_tensor = layout_tensor_ctr(data, scale, zero_point, layout_type) + tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) + tensor_impl = tensor_impl_ctr(data, scale, zero_point, layout_type) """ implements = classmethod(_implements) __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) __torch_function__ = classmethod(_dispatch__torch_function__) - register_layout_cls = classmethod(_register_layout_cls) - get_layout_tensor_constructor = classmethod(_get_layout_tensor_constructor) + register_layout = classmethod(_register_layout) + get_tensor_impl_constructor = classmethod(_get_tensor_impl_constructor) def _get_to_kwargs(self, *args, **kwargs): # `torch._C._nn._parse_to` can't handle `layout` argument diff --git a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py index bc85d26f5..c714df2a7 100644 --- a/tutorials/developer_api_guide/my_dtype_tensor_subclass.py +++ b/tutorials/developer_api_guide/my_dtype_tensor_subclass.py @@ -33,11 +33,11 @@ aten = torch.ops.aten ############################### -# Base Layout Tensor Subclass # +# Base Tensor Impl Subclass # ############################### -class MyDTypeLayout(torch.Tensor): +class MyDTypeTensorImpl(torch.Tensor): """ - Base class for the layout tensor for `MyDTypeTensor` + Base class for the tensor impl for `MyDTypeTensor` """ # get the original unpacked Tensors def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: @@ -53,7 +53,7 @@ def from_plain( scale: torch.Tensor, layout_type: LayoutType, ): - """Construct a layout tensor from plain tensors and a layout_type, which main contain + """Construct a tensor impl from plain tensors and a layout_type, which main contain extra metadata for packing etc. 
""" pass @@ -82,17 +82,17 @@ class MyDTypeTensor(TorchAOBaseTensor): @staticmethod def __new__( cls, - layout_tensor: MyDTypeLayout, + tensor_impl: MyDTypeTensorImpl, shape: torch.Size, dtype: Optional[torch.dtype] = None, requires_grad: bool = False, ): kwargs = {} - kwargs["device"] = layout_tensor.device + kwargs["device"] = tensor_impl.device kwargs["layout"] = ( kwargs.get("layout") if kwargs.get("layout", False) - else layout_tensor.layout + else tensor_impl.layout ) kwargs["dtype"] = dtype kwargs["requires_grad"] = requires_grad @@ -100,12 +100,12 @@ def __new__( def __init__( self, - layout_tensor: MyDTypeLayout, + tensor_impl: MyDTypeTensorImpl, shape: torch.Size, dtype: Optional[torch.dtype] = None, requires_grad: bool = False, ): - self.layout_tensor = layout_tensor + self.tensor_impl = tensor_impl """__tensor_flatten__ and __tensor_unflatten__ are used to desugar the tensor into native Tensors/attributes and reconstruct the tensor subclass instance from the desugared tensor and attributes, these are required to define @@ -118,7 +118,7 @@ def __tensor_flatten__(self): The first one contains any tensor fields such as int_data and scale as keys to a dictionary The second one contains all other non tensor type fields as values of a list """ - return ["layout_tensor"], [self.shape, self.dtype, self.requires_grad] + return ["tensor_impl"], [self.shape, self.dtype, self.requires_grad] @classmethod def __tensor_unflatten__( @@ -129,10 +129,10 @@ def __tensor_unflatten__( tensor_data_dict contains the tensor fields of the class as a dictionary tensor_attributes contains all other non tensor type fields """ - layout_tensor = tensor_data_dict["layout_tensor"] + tensor_impl = tensor_data_dict["tensor_impl"] shape, dtype, requires_grad = tensor_attributes return cls( - layout_tensor, + tensor_impl, shape if outer_size is None else outer_size, dtype=dtype, requires_grad=requires_grad, @@ -152,25 +152,25 @@ def from_float( dtype = torch.int16 scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, dtype) int_data = quantize_affine(input_float, block_size, scale, zero_point, dtype) - layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) - layout_tensor = layout_tensor_ctr(int_data, scale, layout_type) - return cls(layout_tensor, input_float.shape) + tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type)) + tensor_impl = tensor_impl_ctr(int_data, scale, layout_type) + return cls(tensor_impl, input_float.shape) """[Optional] We can overwrite layout property of the Tensor to represent different packing formats """ @property def layout_type(self) -> LayoutType: - return self.layout_tensor.layout_type + return self.tensor_impl.layout_type def dequantize(self, output_dtype=None): """We can define a dequantize method to convert the quantized tensor to a floating point tensor""" if output_dtype is None: output_dtype = torch.get_default_dtype() - int_data, scale = self.layout_tensor.get_plain() + int_data, scale = self.tensor_impl.get_plain() transposed = False block_size = (1, int_data.shape[-1]) - if hasattr(self.layout_tensor, "transposed") and self.layout_tensor.transposed: + if hasattr(self.tensor_impl, "transposed") and self.tensor_impl.transposed: transposed = True res = dequantize_affine(int_data, block_size, scale, None, int_data.dtype, output_dtype=output_dtype) if transposed: @@ -186,10 +186,10 @@ def __repr__(self): def _apply_fn_to_data(self, fn): """ Used for implementing aten ops by applying them only to the relevant tensor atributes - In 
this case we only want to call things like to() or view() on the layout tensor + In this case we only want to call things like to() or view() on the tensor impl """ return self.__class__( - fn(self.layout_tensor), + fn(self.tensor_impl), self.shape, self.dtype, ) @@ -206,14 +206,14 @@ def _apply_fn_to_data(self, fn): """ ###################################################### -# LayoutType and Layout Tensor Subclass Registration # +# LayoutType and TensorImpl Subclass Registration # ###################################################### -register_layout_cls = MyDTypeTensor.register_layout_cls -get_layout_tensor_constructor = MyDTypeTensor.get_layout_tensor_constructor +register_layout = MyDTypeTensor.register_layout +get_tensor_impl_constructor = MyDTypeTensor.get_tensor_impl_constructor -@register_layout_cls(PlainLayoutType) -class PlainMyDTypeLayout(MyDTypeLayout): +@register_layout(PlainLayoutType) +class PlainMyDTypeTensorImpl(MyDTypeTensorImpl): def __new__( cls, int_data: torch.Tensor, @@ -261,7 +261,7 @@ def from_plain( scale: torch.Tensor, layout_type: LayoutType, ): - """Construct a layout tensor from plain tensors and a layout_type, which main contain + """Construct a tensor impl from plain tensors and a layout_type, which main contain extra metadata for packing etc. """ assert isinstance(layout_type, PlainLayoutType) @@ -292,11 +292,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs): elif func is aten.split.Tensor: int_data_list = func(args[0].int_data, *args[1:], **kwargs) scale_list = func(args[0].scale, *args[1:], **kwargs) - out = [PlainMyDTypeLayout(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)] + out = [PlainMyDTypeTensorImpl(int_data, scale, args[0].transposed, args[0].layout_type) for int_data, scale in zip(int_data_list, scale_list)] return out elif func is aten.empty_like.default: int_data_empty_like = func(args[0].int_data, *args[1:], **kwargs) - return PlainMyDTypeLayout(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type) + return PlainMyDTypeTensorImpl(int_data_empty_like, args[0].scale, args[0].transposed, args[0].layout_type) elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: @@ -304,16 +304,16 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) ) elif dim == 1: - return PlainMyDTypeLayout(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.transposed, self.layout_type) + return PlainMyDTypeTensorImpl(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.transposed, self.layout_type) else: - raise NotImplementedError(f"PlainMyDTypeLayout dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError(f"PlainMyDTypeTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") elif func is aten.t.default: - return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type)) + return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeTensorImpl(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type)) # Tensor parallel support END raise NotImplementedError( - f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported" + 
f"PlainMyDTypeTensorImpl dispatch: attempting to run {func}, this is not supported" ) ##################################################### diff --git a/tutorials/developer_api_guide/my_trainable_tensor_subclass.py b/tutorials/developer_api_guide/my_trainable_tensor_subclass.py index b702ac4f9..59e72efb6 100644 --- a/tutorials/developer_api_guide/my_trainable_tensor_subclass.py +++ b/tutorials/developer_api_guide/my_trainable_tensor_subclass.py @@ -43,8 +43,8 @@ def _quantize( dtype = torch.int16 scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype) int_data = (input_float / scale).to(torch.int8) - layout_tensor_ctr = cls.get_layout_tensor_constructor(type(layout_type)) - return layout_tensor_ctr(int_data, scale, layout_type) + tensor_impl_ctr = cls.get_tensor_impl_constructor(type(layout_type)) + return tensor_impl_ctr(int_data, scale, layout_type) @classmethod def from_float( @@ -71,9 +71,9 @@ def forward( input_float: torch.Tensor, layout_type: LayoutType, ) -> "MyTrainableDTypeTensor": - layout_tensor = MyTrainableDTypeTensor._quantize(input_float, layout_type) + tensor_impl = MyTrainableDTypeTensor._quantize(input_float, layout_type) return MyTrainableDTypeTensor( - layout_tensor, + tensor_impl, input_float.shape, requires_grad=True, ) @@ -137,15 +137,15 @@ def _(func, types, args, kwargs): """ assert len(args) == 2 assert isinstance(args[0], MyTrainableDTypeTensor) - assert args[0].layout_tensor.int_data.dtype == torch.int8 + assert args[0].tensor_impl.int_data.dtype == torch.int8 float0 = args[0].dequantize() float1 = args[1].dequantize() if isinstance(args[1], MyTrainableDTypeTensor) else args[1] new_value = torch.add(float0, float1, **kwargs) - new_layout_tensor = MyTrainableDTypeTensor._quantize( + new_tensor_impl = MyTrainableDTypeTensor._quantize( new_value, - args[0].layout_tensor.get_layout_type(), + args[0].tensor_impl.get_layout_type(), ) - args[0].layout_tensor = new_layout_tensor + args[0].tensor_impl = new_tensor_impl return return_and_correct_aliasing(func, args, kwargs, args[0]) @implements(aten.add.Tensor) @@ -190,7 +190,7 @@ def main(): loss = loss_fn(output, target) loss.backward() if VERBOSE: - weight = m.linear.weight.layout_tensor.int_data.flatten()[:3] + weight = m.linear.weight.tensor_impl.int_data.flatten()[:3] weight_grad = m.linear.weight.grad.flatten()[:3] print(" * step %s: weight grad = %s, weight value = %s" % (i, weight_grad, weight)) optimizer.step() diff --git a/tutorials/developer_api_guide/tensor_parallel.py b/tutorials/developer_api_guide/tensor_parallel.py index 0ed3bc9a2..84de815a3 100644 --- a/tutorials/developer_api_guide/tensor_parallel.py +++ b/tutorials/developer_api_guide/tensor_parallel.py @@ -24,14 +24,14 @@ def _(func, types, args, kwargs): @implements([aten.split.Tensor]) def _(func, types, args, kwargs): - layout_tensor_list = func(args[0].layout_tensor, *args[1:], **kwargs) - out = [MyDTypeTensorTP(layout_tensor, layout_tensor.shape) for layout_tensor in layout_tensor_list] + tensor_impl_list = func(args[0].tensor_impl, *args[1:], **kwargs) + out = [MyDTypeTensorTP(tensor_impl, tensor_impl.shape) for tensor_impl in tensor_impl_list] return out @implements([aten.empty_like.default]) def _(func, types, args, kwargs): - empty_like_layout_tensor = func(args[0].layout_tensor, *args[1:], **kwargs) - return MyDTypeTensorTP(empty_like_layout_tensor, empty_like_layout_tensor.shape) + empty_like_tensor_impl = func(args[0].tensor_impl, *args[1:], **kwargs) + return MyDTypeTensorTP(empty_like_tensor_impl, 
empty_like_tensor_impl.shape) @implements(aten.slice.Tensor) def _(func, types, args, kwargs): @@ -41,7 +41,7 @@ def _(func, types, args, kwargs): end = self.shape[dim] shape = list(self.shape) shape[dim] = end - start - return self.__class__(aten.slice.Tensor(self.layout_tensor, dim, start, end, step), shape, self.dtype) + return self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), shape, self.dtype) # this is needed for DTensor.from_local() and for flattening tensor @implements(aten.view.default) @@ -49,10 +49,10 @@ def _(func, types, args, kwargs): x, shape = args if tuple(x.shape) == tuple(shape): - return x.__class__(x.layout_tensor, x.shape, x.dtype) + return x.__class__(x.tensor_impl, x.shape, x.dtype) if len(shape) == 1 and shape[0] == -1: - return x.__class__(x.layout_tensor, (x.numel(),), x.dtype) + return x.__class__(x.tensor_impl, (x.numel(),), x.dtype) raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]") @@ -60,7 +60,7 @@ def _(func, types, args, kwargs): def _(func, types, args, kwargs): tensor = args[0] shape = tensor.shape[::-1] - new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype) + new = tensor.__class__(tensor.tensor_impl.t(), shape, tensor.dtype) return return_and_correct_aliasing(func, args, kwargs, new) @implements(aten.addmm.default)
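
Reviewer note on the renamed registration API: this patch renames `register_layout_cls` -> `register_layout`, `get_layout_tensor_constructor` -> `get_tensor_impl_constructor`, the `AQTLayout`/`*AQTLayout` classes to `AQTTensorImpl`/`*AQTTensorImpl`, and the stored `layout_tensor` field to `tensor_impl`. The sketch below expands the docstring example from `torchao/utils.py` (as changed above) into a small self-contained script that exercises the renamed hooks. It is a minimal illustration, not part of the patch: `MyTensor`, `MyTensorImpl`, and `MyLayoutType` are hypothetical placeholder names, and a plain Python class stands in for the impl purely to show the registry mechanics; real impls such as `PlainAQTTensorImpl` subclass `AQTTensorImpl` and carry full tensor-subclass machinery.

# Minimal sketch of the renamed TorchAOBaseTensor registration hooks.
# Assumes torchao at this commit; MyTensor / MyTensorImpl / MyLayoutType are
# hypothetical names used only for illustration.
from dataclasses import dataclass
from typing import Tuple

import torch
from torchao.dtypes.utils import LayoutType
from torchao.utils import TorchAOBaseTensor


@dataclass(frozen=True)
class MyLayoutType(LayoutType):
    """Placeholder layout type; real ones carry packing params (e.g. inner_k_tiles)."""
    pass


class MyTensor(TorchAOBaseTensor):
    """Placeholder outer tensor subclass that owns the impl registry."""
    pass


# Formerly: register_layout_cls / get_layout_tensor_constructor.
register_layout = MyTensor.register_layout
get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor


@register_layout(MyLayoutType)
class MyTensorImpl:
    """Placeholder tensor impl (formerly an "*AQTLayout" class).

    The registry only records `from_plain`, so a plain class is enough to
    demonstrate the mechanism; real impls subclass AQTTensorImpl.
    """

    def __init__(self, data, scale, zero_point, layout_type):
        self.data, self.scale, self.zero_point = data, scale, zero_point
        self.layout_type = layout_type

    @classmethod
    def from_plain(cls, data, scale, zero_point, layout_type):
        return cls(data, scale, zero_point, layout_type)

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.data, self.scale, self.zero_point


# Lookup is keyed by the LayoutType class, exactly as before the rename;
# only the names changed (layout_tensor_ctr -> tensor_impl_ctr, etc.).
layout_type = MyLayoutType()
tensor_impl_ctr = get_tensor_impl_constructor(type(layout_type))
tensor_impl = tensor_impl_ctr(
    torch.zeros(4, 4, dtype=torch.int8),  # quantized data
    torch.ones(4),                        # scale
    torch.zeros(4),                       # zero_point
    layout_type,
)
print(type(tensor_impl).__name__)  # -> MyTensorImpl

Downstream code correspondingly reads the stored impl through `weight.tensor_impl` (e.g. `weight.tensor_impl.int_data`, `weight.tensor_impl.scale`) instead of `weight.layout_tensor`, as the test, autoquant, and tutorial hunks above show.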