From b4055daa0b06e6a60d227cb6ea756c48de99bc74 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 1 Oct 2024 08:48:22 -0700
Subject: [PATCH] Don't call dequantize in `__repr__` (#965)

Summary:
`__repr__` no longer calls `dequantize`, so `dequantize` does not have to work for every new layout people add. If people want to see the dequantized value, they can print `weight.dequantize()` explicitly instead of relying on `__repr__`.

Test Plan:
test locally:
```
from torchao.quantization import quantize_, int8_weight_only
import torch

l = torch.nn.Linear(2, 2)
quantize_(l, int8_weight_only())
print(l.weight)
print(l.weight.dequantize())
```
```
AffineQuantizedTensor(layout_tensor=PlainAQTLayout(data=tensor([[ 127,  -77],
        [-128,  -40]], dtype=torch.int8)... , scale=tensor([0.0007, 0.0032])... , zero_point=tensor([0, 0])... , layout_type=PlainLayoutType()), block_size=(1, 2), shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, requires_grad=False)
tensor([[ 0.0856, -0.0519],
        [-0.4070, -0.1272]])
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/dtypes/affine_quantized_tensor.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 16f2dd6a3..b082a7f30 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -188,8 +188,8 @@ def __init__(
 
     def __repr__(self):
         return (
-            f"{self.__class__.__name__}(data={str(self.dequantize())}..., shape={self.shape}, block_size={self.block_size}, "
-            f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad}, layout_tensor={self.layout_tensor})"
+            f"{self.__class__.__name__}(layout_tensor={self.layout_tensor}, block_size={self.block_size}, "
+            f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
         )
 
     def _quantization_type(self):
@@ -694,7 +694,7 @@ class BlockSparseAQTLayout(PlainAQTLayout):
     scale: Optional[torch.Tensor]
     zero_point: Optional[torch.Tensor]
 
-    __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"] 
+    __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"]
 
     @staticmethod
     def __new__(  # noqa: PYI034
@@ -703,7 +703,7 @@ def __new__(  # noqa: PYI034
         bsr_crow_indices: Optional[torch.Tensor],
         bsr_col_indices: Optional[torch.Tensor],
         bsr_values: Optional[torch.Tensor],
-        scale: Optional[torch.Tensor], 
+        scale: Optional[torch.Tensor],
         zero_point: Optional[torch.Tensor],
         layout_type: LayoutType,
         requires_grad: bool = False,
@@ -727,7 +727,7 @@ def __init__(  # noqa: PYI034
         bsr_crow_indices: Optional[torch.Tensor],
         bsr_col_indices: Optional[torch.Tensor],
         bsr_values: Optional[torch.Tensor],
-        scale: Optional[torch.Tensor], 
+        scale: Optional[torch.Tensor],
         zero_point: Optional[torch.Tensor],
         layout_type: LayoutType,
         requires_grad: bool = False,
@@ -739,7 +739,7 @@ def __init__(  # noqa: PYI034
         self.zero_point = zero_point
         self.layout_type = layout_type
 
-    def __tensor_flatten__(self): 
+    def __tensor_flatten__(self):
         inner_tensors = list(
             filter(lambda x: getattr(self, x) is not None, self.__slots__)
         )
@@ -774,7 +774,7 @@ def from_plain(cls, int_data, scale, zero_point, layout_type):
             bsr_crow_indices=bsr_tensor.crow_indices(),
             bsr_col_indices=bsr_tensor.col_indices(),
             bsr_values=bsr_tensor.values(),
-            scale=scale, 
+            scale=scale,
             zero_point=zero_point,
             layout_type = layout_type,
             requires_grad=False,
@@ -820,7 +820,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             return args[0].bsr_values.detach()
 
         if func is aten._nnz.default:
-            return args[0].bsr_values.shape[0] 
+            return args[0].bsr_values.shape[0]
 
         raise NotImplementedError(
             f"BlockSparseAQTLayout dispatch: attempting to run {func}, this is not supported"
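
For illustration, a minimal sketch (not part of the diff) of the pattern the new `__repr__` follows: format only the stored fields, so a layout whose `dequantize` is not implemented yet still prints. `ToyQuantizedTensor` and its fields are hypothetical and are not torchao APIs.

```
# Hypothetical example; ToyQuantizedTensor is not a torchao class.
import torch


class ToyQuantizedTensor:
    def __init__(self, int_data, scale, block_size):
        self.int_data = int_data
        self.scale = scale
        self.block_size = block_size

    def dequantize(self):
        # A freshly added layout may not support this yet.
        raise NotImplementedError("dequantize is not implemented for this toy layout")

    def __repr__(self):
        # Same idea as the patched __repr__: show stored fields, never call dequantize().
        return (
            f"{self.__class__.__name__}(int_data={self.int_data}, "
            f"scale={self.scale}, block_size={self.block_size})"
        )


w = ToyQuantizedTensor(torch.tensor([[127, -77]], dtype=torch.int8), torch.tensor([0.0007]), (1, 2))
print(w)                 # works even though dequantize() would raise
# print(w.dequantize())  # only do this once dequantize() is implemented
```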