diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py
index a37e65684f..b26a204f4a 100644
--- a/torchbenchmark/util/backends/torchdynamo.py
+++ b/torchbenchmark/util/backends/torchdynamo.py
@@ -184,18 +184,16 @@ def apply_torchdynamo_args(
     if args.quantization:
         import torchao
         from torchao.quantization import (
-            change_linear_weights_to_int4_woqtensors,
-            change_linear_weights_to_int8_dqtensors,
-            change_linear_weights_to_int8_woqtensors,
+            quantize, int8_weight_only, int4_weight_only, int8_dynamic_activation_int8_weight
         )
-        torch._dynamo.epilogue_fusion = False
+        from torchao.utils import unwrap_tensor_subclass
+
         torch._dynamo.config.automatic_dynamic_shapes = False
-        torch._dynamo.config.force_parameter_static_shapes = False
+        torch._inductor.config.force_fuse_int_mm_with_mul = True
+        torch._inductor.config.use_mixed_mm = True
         torch._dynamo.config.cache_size_limit = 10000
         assert "cuda" in model.device
         module, example_inputs = model.get_module()
-        torch._inductor.config.force_fuse_int_mm_with_mul = True
-        torch._inductor.config.use_mixed_mm = True
         if isinstance(example_inputs, tuple([tuple, list])):
             example_inputs = tuple([
                 x.to(torch.bfloat16)
@@ -209,22 +207,22 @@ def apply_torchdynamo_args(
                 module(**example_inputs)
             else:
                 module(*example_inputs)
-
         if args.quantization == "int8dynamic":
-            change_linear_weights_to_int8_dqtensors(module)
+            quantize(module, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
         elif args.quantization == "int8weightonly":
-            change_linear_weights_to_int8_woqtensors(module)
+            quantize(module, int8_weight_only(), set_inductor_config=False)
         elif args.quantization == "int4weightonly":
-            change_linear_weights_to_int4_woqtensors(module)
-        elif args.quantization == "autoquant":
-            torchao.autoquant(module, error_on_unseen=False)
+            quantize(module, int4_weight_only(), set_inductor_config=False)
+        if args.quantization == "autoquant":
+            torchao.autoquant(module, error_on_unseen=False, mode=["interpolate", .85], set_inductor_config=False)
             if isinstance(example_inputs, dict):
                 module(**example_inputs)
             else:
                 module(*example_inputs)
             from torchao.quantization.autoquant import AUTOQUANT_CACHE
             assert len(AUTOQUANT_CACHE)>0, f"Err: found no autoquantizable layers in model {type(module)}, stopping autoquantization"
-
+        else:
+            unwrap_tensor_subclass(module)
 
     if args.freeze_prepack_weights:
         torch._inductor.config.freezing = True
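
Note (reviewer sketch, not part of the diff): the new torchao flow this change adopts, written out standalone. This is a minimal sketch assuming a torchao version where quantize accepts set_inductor_config and unwrap_tensor_subclass lives in torchao.utils, as the diff implies; the model and shapes below are made up for illustration.

    import torch
    from torchao.quantization import quantize, int8_weight_only
    from torchao.utils import unwrap_tensor_subclass

    # Toy stand-in for the benchmark module (hypothetical shapes).
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)
    example_input = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

    # Swap Linear weights for int8 weight-only quantized tensor subclasses.
    # set_inductor_config=False leaves the inductor flags to the caller, since
    # the harness above sets force_fuse_int_mm_with_mul / use_mixed_mm itself.
    quantize(model, int8_weight_only(), set_inductor_config=False)

    # Unwrap the tensor subclasses into plain parameters before compiling,
    # mirroring the new "else: unwrap_tensor_subclass(module)" branch, which
    # runs for the three quantize-based modes but not for autoquant.
    unwrap_tensor_subclass(model)

    compiled = torch.compile(model)
    out = compiled(example_input)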
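
The autoquant path, now a separate if so that its calibration forward passes and cache check only run for that mode, works differently: autoquant instruments eligible layers, a forward pass drives shape recording and kernel selection, and AUTOQUANT_CACHE records the choices. A hedged sketch under the same assumptions (toy model; mode=["interpolate", .85] is copied from the change above):

    import torch
    import torchao
    from torchao.quantization.autoquant import AUTOQUANT_CACHE

    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to("cuda", torch.bfloat16)
    example_input = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)

    # Instrument the module; error_on_unseen=False leaves layers that autoquant
    # cannot handle in their original form instead of raising.
    model = torchao.autoquant(model, error_on_unseen=False, mode=["interpolate", .85], set_inductor_config=False)

    # A forward pass triggers the benchmarking/selection step.
    model(example_input)

    # Same sanity check as the diff: an empty cache means nothing in this
    # model was autoquantizable.
    assert len(AUTOQUANT_CACHE) > 0, "found no autoquantizable layers"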