added manual calibration configuration support
IlyasMoutawwakil committed Jan 14, 2024
1 parent 688fd8c commit e9ff880
Showing 4 changed files with 71 additions and 32 deletions.
11 changes: 4 additions & 7 deletions examples/onnxruntime_static_quant_vit.yaml
@@ -9,18 +9,15 @@ defaults:
 
 experiment_name: onnxruntime_static_quant_vit
 model: google/vit-base-patch16-224
-device: cuda
+device: cpu
 
 backend:
-  auto_quantization: avx2
-  auto_quantization_config:
+  quantization: true
+  quantization_config:
     is_static: true
     per_channel: false
 
-  auto_calibration: minmax
-
-launcher:
-  device_isolation: true
+  calibration: true
 
 hydra:
   run:
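Since `calibration_config` is omitted in the YAML above, the defaults defined in `config.py` (further down in this diff) are merged in. A minimal sketch of the effective backend options as plain Python, with values taken from this diff:

```python
# Effective backend options for the YAML above, after ORTConfig.__post_init__
# merges in the QUANTIZATION_CONFIG / CALIBRATION_CONFIG defaults (see config.py below).
effective_backend_options = {
    "quantization": True,
    "quantization_config": {"is_static": True, "format": "QOperator", "per_channel": False},
    "calibration": True,
    "calibration_config": {"method": "MinMax"},  # default calibration method
}
```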
51 changes: 35 additions & 16 deletions optimum_benchmark/backends/onnxruntime/backend.py
@@ -18,6 +18,7 @@
     AutoCalibrationConfig,
     AutoOptimizationConfig,
     AutoQuantizationConfig,
+    CalibrationConfig,
     OptimizationConfig,
     QuantizationConfig,
 )
@@ -31,7 +32,12 @@
 from ..peft_utils import get_peft_config_class
 from ..pytorch.utils import randomize_weights
 from .config import ORTConfig
-from .utils import TASKS_TO_ORTMODELS, TASKS_TO_ORTSD, format_quantization_config
+from .utils import (
+    TASKS_TO_ORTMODELS,
+    TASKS_TO_ORTSD,
+    format_calibration_config,
+    format_quantization_config,
+)
 
 # disable transformers logging
 set_verbosity_error()
@@ -201,7 +207,7 @@ def is_quantized(self) -> bool:
 
     @property
     def is_calibrated(self) -> bool:
-        return self.config.auto_calibration is not None
+        return (self.config.auto_calibration is not None) or self.config.calibration
 
     @property
     def automodel_kwargs(self) -> Dict[str, Any]:
@@ -290,22 +296,38 @@ def quantize_onnx_files(self) -> None:
 
         LOGGER.info("\t+ Processing quantization config")
         if self.config.auto_quantization is not None:
-            self.config.auto_quantization_config = format_quantization_config(self.config.auto_quantization_config)
-            auto_quantization_config_class = getattr(AutoQuantizationConfig, self.config.auto_quantization)
-            quantization_config = auto_quantization_config_class(**self.config.auto_quantization_config)
+            auto_quantization_config = format_quantization_config(self.config.auto_quantization_config)
+            auto_quantization_class = getattr(AutoQuantizationConfig, self.config.auto_quantization)
+            quantization_config = auto_quantization_class(**auto_quantization_config)
         elif self.config.quantization:
-            self.config.quantization_config = format_quantization_config(self.config.quantization_config)
-            quantization_config = QuantizationConfig(**self.config.quantization_config)
+            quantization_config = format_quantization_config(self.config.quantization_config)
+            quantization_config = QuantizationConfig(**quantization_config)
 
         if self.is_calibrated:
             LOGGER.info("\t+ Generating calibration dataset")
             dataset_shapes = {"dataset_size": 1, "sequence_length": 1, **self.model_shapes}
             calibration_dataset = DatasetGenerator(task=self.task, dataset_shapes=dataset_shapes).generate()
             columns_to_be_removed = list(set(calibration_dataset.column_names) - set(self.inputs_names))
             calibration_dataset = calibration_dataset.remove_columns(columns_to_be_removed)
 
-            LOGGER.info("\t+ Processing calibration config")
-            calibration_config_method = getattr(AutoCalibrationConfig, self.config.auto_calibration)
-            calibration_config = calibration_config_method(calibration_dataset, **self.config.auto_calibration_config)
+            if self.config.auto_calibration is not None:
+                LOGGER.info("\t+ Processing calibration config")
+                auto_calibration_method = getattr(AutoCalibrationConfig, self.config.auto_calibration)
+                calibration_config = auto_calibration_method(
+                    calibration_dataset,
+                    **self.config.auto_calibration_config,
+                )
+            elif self.config.calibration:
+                LOGGER.info("\t+ Processing calibration config")
+                calibration_config = format_calibration_config(self.config.calibration_config)
+                calibration_config = CalibrationConfig(
+                    dataset_name="calibration_dataset",
+                    dataset_split=calibration_dataset.split,
+                    dataset_num_samples=calibration_dataset.num_rows,
+                    dataset_config_name=calibration_dataset.config_name,
+                    **self.config.calibration_config,
+                )
 
         for onnx_file_name_to_quantize in self.onnx_files_names_to_quantize:
             LOGGER.info(f"\t+ Creating quantizer for {onnx_file_name_to_quantize}")
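For reference, a minimal standalone sketch of the two calibration paths above; the toy dataset stands in for the generated calibration dataset, whose `split`/`config_name` metadata may well be `None`:

```python
from datasets import Dataset
from onnxruntime.quantization import CalibrationMethod
from optimum.onnxruntime.configuration import AutoCalibrationConfig, CalibrationConfig

toy_dataset = Dataset.from_dict({"input_ids": [[101, 102, 103]]})

# Auto path: the named constructor derives the dataset metadata itself.
auto_config = AutoCalibrationConfig.minmax(toy_dataset)

# Manual path (new in this commit): metadata is spelled out explicitly,
# and the method comes from calibration_config via format_calibration_config.
manual_config = CalibrationConfig(
    dataset_name="calibration_dataset",
    dataset_config_name=toy_dataset.config_name,  # None for an in-memory dataset
    dataset_split=toy_dataset.split,              # likewise None here
    dataset_num_samples=toy_dataset.num_rows,
    method=CalibrationMethod.MinMax,
)
```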
@@ -358,15 +380,12 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if self.library == "diffusers":
             return {"prompt": inputs["prompt"]}
 
-        for key in list(inputs.keys()):
-            # sometimes optimum onnx exported models don't have inputs
-            # that their pytorch counterparts have, for instance token_type_ids
-            if key not in self.inputs_names:
-                inputs.pop(key)
-
         LOGGER.info(f"\t+ Moving inputs tensors to device {self.device}")
         for key, value in inputs.items():
-            inputs[key] = value.to(self.device)
+            if key not in self.inputs_names:
+                inputs.pop(key)
+            else:
+                inputs[key] = value.to(self.device)
 
         return inputs

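One caveat on the merged loop above: popping entries while iterating `inputs.items()` directly raises `RuntimeError: dictionary changed size during iteration` on the next step, so if a key can actually be missing from `inputs_names` the iteration needs a snapshot. A standalone sketch of the safe variant:

```python
import torch

inputs = {
    "input_ids": torch.ones(1, 8, dtype=torch.long),
    "token_type_ids": torch.zeros(1, 8, dtype=torch.long),  # not an ONNX input here
}
inputs_names = {"input_ids"}

# Iterating over a snapshot (list(...)) keeps pop() from invalidating the iterator.
for key, value in list(inputs.items()):
    if key not in inputs_names:
        inputs.pop(key)
    else:
        inputs[key] = value.to("cpu")

print(list(inputs))  # ['input_ids']
```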
27 changes: 19 additions & 8 deletions optimum_benchmark/backends/onnxruntime/config.py
@@ -10,13 +10,17 @@
 
 QUANTIZATION_CONFIG = {
     "is_static": False,
-    "format": "QOperator",  # QOperator, QDQ
+    "format": "QOperator",
     # is_static and format are mandatory
 }
 
+CALIBRATION_CONFIG = {
+    "method": "MinMax"
+    # method is mandatory
+}
+
 AUTO_QUANTIZATION_CONFIG = {
     "is_static": False,
-    # full auto quantization config depends on the hardware but
     # is_static is mandatory
 }
@@ -78,6 +82,10 @@ class ORTConfig(BackendConfig):
     quantization: bool = False
     quantization_config: Dict[str, Any] = field(default_factory=dict)
 
+    # manual calibration options
+    calibration: bool = False
+    calibration_config: Dict[str, Any] = field(default_factory=dict)
+
     # ort-training is basically a different package so we might need to separate these two backends in the future
     use_inference_session: bool = "${is_inference:${benchmark.name}}"
@@ -100,22 +108,25 @@ def __post_init__(self):
                 OmegaConf.merge(QUANTIZATION_CONFIG, self.quantization_config)
             )
             # raise ValueError if the quantization is static but calibration is not enabled
-            if self.quantization_config["is_static"] and self.auto_calibration is None:
+            if self.quantization_config["is_static"] and self.auto_calibration is None and not self.calibration:
                 raise ValueError(
-                    "Quantization is static but auto calibration is not enabled. "
-                    "Please enable auto calibration or disable static quantization."
+                    "Quantization is static but calibration is not enabled. "
+                    "Please enable calibration or disable static quantization."
                 )
 
         if self.auto_quantization is not None:
             self.auto_quantization_config = OmegaConf.to_object(
                 OmegaConf.merge(AUTO_QUANTIZATION_CONFIG, self.auto_quantization_config)
             )
-            if self.auto_quantization_config["is_static"] and self.auto_calibration is None:
+            if self.auto_quantization_config["is_static"] and self.auto_calibration is None and not self.calibration:
                 raise ValueError(
-                    "Quantization is static but auto calibration is not enabled. "
-                    "Please enable auto calibration or disable static quantization."
+                    "Quantization is static but calibration is not enabled. "
+                    "Please enable calibration or disable static quantization."
                 )
 
+        if self.calibration:
+            self.calibration_config = OmegaConf.to_object(OmegaConf.merge(CALIBRATION_CONFIG, self.calibration_config))
+
         if self.peft_strategy is not None:
             if self.peft_strategy not in PEFT_CONFIGS:
                 raise ValueError(
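The new `calibration_config` handling reuses the same merge pattern as the quantization configs; a minimal standalone sketch of what `__post_init__` does with the defaults:

```python
from omegaconf import OmegaConf

CALIBRATION_CONFIG = {"method": "MinMax"}  # defaults from config.py above

# `calibration: true` with no overrides keeps the default method.
print(OmegaConf.to_object(OmegaConf.merge(CALIBRATION_CONFIG, {})))
# {'method': 'MinMax'}

# User-provided keys win over the defaults.
print(OmegaConf.to_object(OmegaConf.merge(CALIBRATION_CONFIG, {"method": "Entropy"})))
# {'method': 'Entropy'}
```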
14 changes: 13 additions & 1 deletion optimum_benchmark/backends/onnxruntime/utils.py
@@ -1,6 +1,11 @@
 from typing import Any, Dict
 
-from onnxruntime.quantization import QuantFormat, QuantizationMode, QuantType
+from onnxruntime.quantization import (
+    CalibrationMethod,
+    QuantFormat,
+    QuantizationMode,
+    QuantType,
+)
 from optimum.pipelines import ORT_SUPPORTED_TASKS
 
 TASKS_TO_ORTSD = {
@@ -11,6 +16,13 @@
 TASKS_TO_ORTMODELS = {task: task_dict["class"][0] for task, task_dict in ORT_SUPPORTED_TASKS.items()}
 
 
+def format_calibration_config(calibration_config: Dict[str, Any]) -> None:
+    if calibration_config.get("method", None) is not None:
+        calibration_config["method"] = CalibrationMethod[calibration_config["method"]]
+
+    return calibration_config
+
+
 def format_quantization_config(quantization_config: Dict[str, Any]) -> None:
     """Format the quantization dictionary for onnxruntime."""
     # the conditionals are here because some quantization strategies don't have all the options
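`format_calibration_config` mirrors `format_quantization_config`: it swaps the user-facing method string for the corresponding onnxruntime enum member, in place. A quick usage sketch:

```python
from onnxruntime.quantization import CalibrationMethod

calibration_config = {"method": "MinMax"}
# Same enum-by-name lookup that format_calibration_config performs.
calibration_config["method"] = CalibrationMethod[calibration_config["method"]]
print(calibration_config["method"])  # CalibrationMethod.MinMax
```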
