diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 16f2dd6a3..b082a7f30 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -188,8 +188,8 @@ def __init__(
 
     def __repr__(self):
         return (
-            f"{self.__class__.__name__}(data={str(self.dequantize())}..., shape={self.shape}, block_size={self.block_size}, "
-            f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad}, layout_tensor={self.layout_tensor})"
+            f"{self.__class__.__name__}(layout_tensor={self.layout_tensor}, block_size={self.block_size}, "
+            f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
         )
 
     def _quantization_type(self):
@@ -694,7 +694,7 @@ class BlockSparseAQTLayout(PlainAQTLayout):
     scale: Optional[torch.Tensor]
     zero_point: Optional[torch.Tensor]
 
-    __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"] 
+    __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"]
 
     @staticmethod
     def __new__(  # noqa: PYI034
@@ -703,7 +703,7 @@ def __new__(  # noqa: PYI034
        bsr_crow_indices: Optional[torch.Tensor],
        bsr_col_indices: Optional[torch.Tensor],
        bsr_values: Optional[torch.Tensor],
-       scale: Optional[torch.Tensor], 
+       scale: Optional[torch.Tensor],
        zero_point: Optional[torch.Tensor],
        layout_type: LayoutType,
        requires_grad: bool = False,
@@ -727,7 +727,7 @@ def __init__(  # noqa: PYI034
        bsr_crow_indices: Optional[torch.Tensor],
        bsr_col_indices: Optional[torch.Tensor],
        bsr_values: Optional[torch.Tensor],
-       scale: Optional[torch.Tensor], 
+       scale: Optional[torch.Tensor],
        zero_point: Optional[torch.Tensor],
        layout_type: LayoutType,
        requires_grad: bool = False,
@@ -739,7 +739,7 @@ def __init__(  # noqa: PYI034
        self.zero_point = zero_point
        self.layout_type = layout_type
 
-    def __tensor_flatten__(self): 
+    def __tensor_flatten__(self):
        inner_tensors = list(
            filter(lambda x: getattr(self, x) is not None, self.__slots__)
        )
@@ -774,7 +774,7 @@ def from_plain(cls, int_data, scale, zero_point, layout_type):
            bsr_crow_indices=bsr_tensor.crow_indices(),
            bsr_col_indices=bsr_tensor.col_indices(),
            bsr_values=bsr_tensor.values(),
-           scale=scale, 
+           scale=scale,
            zero_point=zero_point,
            layout_type = layout_type,
            requires_grad=False,
@@ -820,7 +820,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
            return args[0].bsr_values.detach()
 
        if func is aten._nnz.default:
-           return args[0].bsr_values.shape[0] 
+           return args[0].bsr_values.shape[0]
 
        raise NotImplementedError(
            f"BlockSparseAQTLayout dispatch: attempting to run {func}, this is not supported"
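
Note (not part of the patch): the from_plain hunk above decomposes a torch BSR tensor into the three component tensors kept in __slots__, via crow_indices(), col_indices(), and values(). A minimal runnable sketch of that round-trip, assuming a PyTorch build with BSR support and an illustrative blocksize of (4, 4); the variable names here are hypothetical, not from the patch:

    import torch

    # Dense 8x8 matrix with an all-zero top half, so the sparse
    # layout has some block structure to exploit.
    dense = torch.randn(8, 8)
    dense[:4] = 0

    # Convert to blocked-sparse-row format, as from_plain does with
    # to_sparse_bsr on the quantized int_data.
    bsr = dense.to_sparse_bsr((4, 4))

    # The three tensors BlockSparseAQTLayout stores alongside
    # scale and zero_point.
    crow, col, vals = bsr.crow_indices(), bsr.col_indices(), bsr.values()

    # Reassemble and verify the round-trip; this mirrors what a matching
    # __tensor_unflatten__ would have to do with the flattened slots.
    rebuilt = torch.sparse_bsr_tensor(crow, col, vals, size=dense.shape)
    assert torch.equal(rebuilt.to_dense(), dense)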