Skip to content

Commit

Permalink
Don't call dequantize in __repr__ (#965)
Browse files Browse the repository at this point in the history
Summary:
To reduce the need for `dequantize` to work when people add new things,
if people want to see the dequantized value, they can print `weight.dequantize()`
instead of relying on `__repr__`

Test Plan:
test locally:
```
from torchao import quantize_, int8_weight_only
import torch
l = torch.nn.Linear(2, 2)
quantize_(l, int8_weight_only())
print(l)
print(l.dequantize())
```

```
AffineQuantizedTensor(layout_tensor=PlainAQTLayout(data=tensor([[ 127,  -77],
        [-128,  -40]], dtype=torch.int8)... , scale=tensor([0.0007, 0.0032])... , zero_point=tensor([0, 0])... , layout_type=PlainLayoutType()), block_size=(1, 2), shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, requires_grad=False)
tensor([[ 0.0856, -0.0519],
        [-0.4070, -0.1272]])
```

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 authored Oct 1, 2024
1 parent a382752 commit b4055da
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def __init__(

def __repr__(self):
return (
f"{self.__class__.__name__}(data={str(self.dequantize())}..., shape={self.shape}, block_size={self.block_size}, "
f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad}, layout_tensor={self.layout_tensor})"
f"{self.__class__.__name__}(layout_tensor={self.layout_tensor}, block_size={self.block_size}, "
f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
)

def _quantization_type(self):
Expand Down Expand Up @@ -694,7 +694,7 @@ class BlockSparseAQTLayout(PlainAQTLayout):
scale: Optional[torch.Tensor]
zero_point: Optional[torch.Tensor]

__slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"]
__slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"]

@staticmethod
def __new__( # noqa: PYI034
Expand All @@ -703,7 +703,7 @@ def __new__( # noqa: PYI034
bsr_crow_indices: Optional[torch.Tensor],
bsr_col_indices: Optional[torch.Tensor],
bsr_values: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
zero_point: Optional[torch.Tensor],
layout_type: LayoutType,
requires_grad: bool = False,
Expand All @@ -727,7 +727,7 @@ def __init__( # noqa: PYI034
bsr_crow_indices: Optional[torch.Tensor],
bsr_col_indices: Optional[torch.Tensor],
bsr_values: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
zero_point: Optional[torch.Tensor],
layout_type: LayoutType,
requires_grad: bool = False,
Expand All @@ -739,7 +739,7 @@ def __init__( # noqa: PYI034
self.zero_point = zero_point
self.layout_type = layout_type

def __tensor_flatten__(self):
def __tensor_flatten__(self):
inner_tensors = list(
filter(lambda x: getattr(self, x) is not None, self.__slots__)
)
Expand Down Expand Up @@ -774,7 +774,7 @@ def from_plain(cls, int_data, scale, zero_point, layout_type):
bsr_crow_indices=bsr_tensor.crow_indices(),
bsr_col_indices=bsr_tensor.col_indices(),
bsr_values=bsr_tensor.values(),
scale=scale,
scale=scale,
zero_point=zero_point,
layout_type = layout_type,
requires_grad=False,
Expand Down Expand Up @@ -820,7 +820,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
return args[0].bsr_values.detach()

if func is aten._nnz.default:
return args[0].bsr_values.shape[0]
return args[0].bsr_values.shape[0]

raise NotImplementedError(
f"BlockSparseAQTLayout dispatch: attempting to run {func}, this is not supported"
Expand Down

0 comments on commit b4055da

Please sign in to comment.