From 54b62c10f32efce7830052ea32aee9398bd609d7 Mon Sep 17 00:00:00 2001
From: Svetlana Karslioglu
Date: Tue, 1 Oct 2024 08:18:11 -0700
Subject: [PATCH] Update torchao to 0.5.0 and fix GPU quantization tutorial
 (#3069)

* Update torchao to 0.5.0 and fix GPU quantization tutorial

---------

Co-authored-by: HDCharles
---
 .ci/docker/requirements.txt              |  2 +-
 .../gpu_quantization_torchao_tutorial.py | 31 +++++++++++++------
 2 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt
index 2384fb1b00..afa5588919 100644
--- a/.ci/docker/requirements.txt
+++ b/.ci/docker/requirements.txt
@@ -68,5 +68,5 @@ iopath
 pygame==2.6.0
 pycocotools
 semilearn==0.3.2
-torchao==0.0.3
+torchao==0.5.0
 segment_anything==1.0
diff --git a/prototype_source/gpu_quantization_torchao_tutorial.py b/prototype_source/gpu_quantization_torchao_tutorial.py
index 4050a88e56..f901f8abd3 100644
--- a/prototype_source/gpu_quantization_torchao_tutorial.py
+++ b/prototype_source/gpu_quantization_torchao_tutorial.py
@@ -44,7 +44,8 @@
 #

 import torch
-from torchao.quantization import change_linear_weights_to_int8_dqtensors
+from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
+from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5
 from segment_anything import sam_model_registry
 from torch.utils.benchmark import Timer

@@ -156,9 +157,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # in memory bound situations where the benefit comes from loading less
 # weight data, rather than doing less computation. The torchao APIs:
 #
-# ``change_linear_weights_to_int8_dqtensors``,
-# ``change_linear_weights_to_int8_woqtensors`` or
-# ``change_linear_weights_to_int4_woqtensors``
+# ``int8_dynamic_activation_int8_weight()``,
+# ``int8_weight_only()`` or
+# ``int4_weight_only()``
 #
 # can be used to easily apply the desired quantization technique and then
 # once the model is compiled with ``torch.compile`` with ``max-autotune``, quantization is
@@ -170,7 +171,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # ``apply_weight_only_int8_quant`` instead as drop in replacement for the two
 # above (no replacement for int4).
 #
-# The difference between the two APIs is that ``change_linear_weights`` API
+# The difference between the two APIs is that ``int8_dynamic_activation`` API
 # alters the weight tensor of the linear module so instead of doing a
 # normal linear, it does a quantized operation. This is helpful when you
 # have non-standard linear ops that do more than one thing. The ``apply``
@@ -185,7 +186,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model, image = get_sam_model(only_one_block, batchsize)
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -220,7 +224,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -251,7 +258,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.coordinate_descent_check_all_directions = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -280,7 +290,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
     model, image = get_sam_model(False, batchsize)
     model = model.to(torch.bfloat16)
     image = image.to(torch.bfloat16)
-    change_linear_weights_to_int8_dqtensors(model)
+    quantize_(model, int8_dynamic_activation_int8_weight())
+    if not TORCH_VERSION_AT_LEAST_2_5:
+        # needed for subclass + compile to work on older versions of pytorch
+        unwrap_tensor_subclass(model)
     model_c = torch.compile(model, mode='max-autotune')
     quant_res = benchmark(model_c, image)
     print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
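
For readers migrating their own code, below is a minimal sketch of the API change this patch applies, assuming torchao 0.5.0 on a CUDA machine; the one-layer toy model is hypothetical and stands in for the SAM block quantized in the tutorial:

    # Minimal sketch of the torchao 0.0.3 -> 0.5.0 migration shown in the diff.
    # The toy model below is hypothetical; the tutorial quantizes a SAM block.
    import torch
    from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
    from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5

    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)

    # Old API (torchao 0.0.3): change_linear_weights_to_int8_dqtensors(model)
    # New API (torchao 0.5.0): pass a quantization config to the in-place quantize_
    quantize_(model, int8_dynamic_activation_int8_weight())
    if not TORCH_VERSION_AT_LEAST_2_5:
        # needed for subclass + compile to work on older versions of pytorch
        unwrap_tensor_subclass(model)

    model_c = torch.compile(model, mode='max-autotune')
    out = model_c(torch.randn(1, 1024, device='cuda', dtype=torch.bfloat16))

Note that ``quantize_`` mutates the model in place, so the version guard and ``unwrap_tensor_subclass`` call must run before ``torch.compile``, exactly as each hunk of the diff does.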