
Optionally print ranges?
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
andrewor14 committed Jun 14, 2024
1 parent b50609c commit 71298a9
Showing 1 changed file with 29 additions and 0 deletions.
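The new code turns the QAT linear's forward pass into an opt-in range logger: when the environment variables PRINT_RANGES and IS_RANK_ZERO are both "true", each forward call records the min/max of the input activations (x_fq), the raw weights, and the fake-quantized weights (w_fq) under the current GLOBAL_STEP; whenever GLOBAL_STEP advances, the values accumulated for the previous step are printed and the buffers are reset. A minimal driver sketch follows; build_model, train_one_step, and num_steps are hypothetical stand-ins for the surrounding training loop, which the commit assumes keeps GLOBAL_STEP up to date in the environment.

import os

# Hypothetical training scaffolding; only the env var handling below is
# what torchao/quantization/prototype/qat.py actually reads.
def build_model(): ...
def train_one_step(model): ...

os.environ["PRINT_RANGES"] = "true"
os.environ["IS_RANK_ZERO"] = "true"  # assumption: set only on the rank-0 process

model = build_model()  # model whose linears use the QAT fake-quant forward
num_steps = 10
for step in range(num_steps):
    os.environ["GLOBAL_STEP"] = str(step)
    train_one_step(model)  # each forward appends to PRINT_VALUES for this step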
29 changes: 29 additions & 0 deletions torchao/quantization/prototype/qat.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
 from typing import Any, Callable, Optional, Tuple
 
 import torch
@@ -14,6 +15,8 @@
 from torchao.quantization.quant_primitives import get_group_qparams_symmetric
 from torchao.quantization.unified import TwoStepQuantizer
 
+CURRENT_STEP = -1
+PRINT_VALUES = {}
 
 if TORCH_VERSION_AFTER_2_4:
     from torchao.quantization.GPTQ import (
@@ -199,6 +202,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
         else:
             w_fq = self.weight
+
+        print_ranges = os.getenv("PRINT_RANGES", "") == "true" and os.getenv("IS_RANK_ZERO", "") == "true"
+        global_step = int(os.getenv("GLOBAL_STEP", "0"))
+        if print_ranges:
+            global CURRENT_STEP
+            global PRINT_VALUES
+
+            # If it's a new step, print and reinitialize
+            all_attributes = ["act_min", "act_max", "weight_min", "weight_max", "weight_fq_min", "weight_fq_max"]
+            if global_step != CURRENT_STEP:
+                if len(PRINT_VALUES) > 0:
+                    for attr in all_attributes:
+                        print("ANDREW %s %s %s" % (attr, CURRENT_STEP, ",".join([str(x) for x in PRINT_VALUES[attr]])))
+                CURRENT_STEP = global_step
+                PRINT_VALUES = {}
+                for attr in all_attributes:
+                    PRINT_VALUES[attr] = []
+
+            # Record current values
+            PRINT_VALUES["act_min"].append(x_fq.min().item())
+            PRINT_VALUES["act_max"].append(x_fq.max().item())
+            PRINT_VALUES["weight_min"].append(self.weight.min().item())
+            PRINT_VALUES["weight_max"].append(self.weight.max().item())
+            PRINT_VALUES["weight_fq_min"].append(w_fq.min().item())
+            PRINT_VALUES["weight_fq_max"].append(w_fq.max().item())
+
         return torch.nn.functional.linear(x_fq, w_fq)
 
 # TODO: move this to common util
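Each printed line has the flat form "ANDREW <attr> <step> <v1,v2,...>", with one comma-separated value per forward call in that step, so a captured stdout log can be folded back into per-step series. A parsing sketch, assuming the training output was redirected to a file (parse_range_log and train.log are illustrative names, not part of the commit):

def parse_range_log(path):
    """Parse 'ANDREW act_min 3 0.1,0.2' lines into {(attr, step): [floats]}."""
    ranges = {}
    with open(path) as f:
        for line in f:
            if not line.startswith("ANDREW "):
                continue
            _, attr, step, values = line.rstrip("\n").split(" ", 3)
            ranges[(attr, int(step))] = [float(v) for v in values.split(",")]
    return ranges

ranges = parse_range_log("train.log")  # e.g. ranges[("act_max", 3)]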
