Skip to content

Commit

Permalink
remove allow-untyped-defs from ao/quantization/experimental/fake_quan…
Browse files Browse the repository at this point in the history
…tize.py (pytorch#144091)

Pull Request resolved: pytorch#144091
Approved by: https://github.com/aorenste
  • Loading branch information
bobrenjc93 authored and pytorchmergebot committed Jan 3, 2025
1 parent 377e297 commit 891a86d
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions torch/ao/quantization/experimental/fake_quantize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Callable

import torch
from torch import Tensor
from torch.ao.quantization.experimental.fake_quantize_function import (
Expand All @@ -14,15 +15,15 @@ class APoTFakeQuantize(FakeQuantizeBase):
quantization_levels: Tensor
level_indices: Tensor

def __init__(self, observer=APoTObserver, **observer_kwargs):
def __init__(self, observer: Callable = APoTObserver, **observer_kwargs: Any):
super().__init__()
self.activation_post_process = observer(**observer_kwargs)
self.dtype = self.activation_post_process.dtype

def calculate_qparams(self, signed=False): # type: ignore[override]
def calculate_qparams(self, signed: bool = False) -> tuple[Tensor, Tensor, Tensor, Tensor]: # type: ignore[override]
return self.activation_post_process.calculate_qparams(signed=signed)

def forward(self, X: torch.Tensor): # type: ignore[override]
def forward(self, X: torch.Tensor) -> Tensor: # type: ignore[override]
if self.observer_enabled[0] == 1:
self.activation_post_process.forward(X)
result = self.activation_post_process.calculate_qparams(signed=False)
Expand Down

0 comments on commit 891a86d

Please sign in to comment.