Skip to content

Commit

Permalink
Improve TorchAO error message (#10627)
Browse files Browse the repository at this point in the history
improve error message
  • Loading branch information
a-r-r-o-w authored Jan 22, 2025
1 parent beacaa5 commit ca60ad8
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/diffusers/quantizers/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,8 +481,15 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]]

TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
raise ValueError(
                    f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability < 8.9. You "
                    f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
)

raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported yet or is incorrect. If you think the "
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)

Expand Down Expand Up @@ -652,13 +659,13 @@ def get_apply_tensor_subclass(self):

def __repr__(self):
r"""
Example of how this looks for `TorchAoConfig("uint_a16w4", group_size=32)`:
Example of how this looks for `TorchAoConfig("uint4wo", group_size=32)`:
```
TorchAoConfig {
"modules_to_not_convert": null,
"quant_method": "torchao",
"quant_type": "uint_a16w4",
"quant_type": "uint4wo",
"quant_type_kwargs": {
"group_size": 32
}
Expand Down

0 comments on commit ca60ad8

Please sign in to comment.