diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py
index 51070d8b..07e45980 100644
--- a/optimum_benchmark/backends/pytorch/backend.py
+++ b/optimum_benchmark/backends/pytorch/backend.py
@@ -308,7 +308,7 @@ def process_quantization_config(self) -> None:
 
         self.logger.info("\t+ Processing AutoQuantization config")
         self.quantization_config = AutoQuantizationConfig.from_dict(
-            getattr(self.pretrained_config, "quantization_config", {}).update(self.config.quantization_config)
+            (getattr(self.pretrained_config, "quantization_config") or {}).update(self.config.quantization_config)
         )
 
     @property
@@ -404,9 +404,9 @@ def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict
 
     @torch.inference_mode()
     def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
-        assert kwargs.get("max_new_tokens") == kwargs.get("min_new_tokens") == 1, (
-            "For prefilling, max_new_tokens and min_new_tokens must be equal to 1"
-        )
+        assert (
+            kwargs.get("max_new_tokens") == kwargs.get("min_new_tokens") == 1
+        ), "For prefilling, max_new_tokens and min_new_tokens must be equal to 1"
         return self.pretrained_model.generate(**inputs, **kwargs)
 
     @torch.inference_mode()
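
Note on the first hunk: in Python, dict.update() mutates its receiver and returns None, so both the removed and the added line still hand None to AutoQuantizationConfig.from_dict. Dropping the getattr default also means getattr would raise AttributeError if pretrained_config has no quantization_config attribute at all. Below is a minimal sketch of the pitfall and of a merge expression that actually returns the combined mapping; the names base and override are illustrative only, standing in for the model's quantization_config and self.config.quantization_config.

    # dict.update() mutates in place and returns None:
    base = {"bits": 4}
    result = base.update({"group_size": 128})
    assert result is None
    assert base == {"bits": 4, "group_size": 128}

    # A merge expression that returns the combined dict instead, with the
    # later mapping winning on key collisions (what from_dict needs):
    base = {"bits": 4}              # e.g. getattr(pretrained_config, "quantization_config", None) or {}
    override = {"group_size": 128}  # e.g. self.config.quantization_config
    merged = {**base, **override}
    assert merged == {"bits": 4, "group_size": 128}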