diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py index bb7212f7..dd11ddfd 100644 --- a/optimum_benchmark/backends/pytorch/backend.py +++ b/optimum_benchmark/backends/pytorch/backend.py @@ -308,7 +308,7 @@ def process_quantization_config(self) -> None: self.logger.info("\t+ Processing AutoQuantization config") self.quantization_config = AutoQuantizationConfig.from_dict( - dict(**getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config) + dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config) ) @property @@ -404,9 +404,9 @@ def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict @torch.inference_mode() def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: - assert ( - kwargs.get("max_new_tokens") == kwargs.get("min_new_tokens") == 1 - ), "For prefilling, max_new_tokens and min_new_tokens must be equal to 1" + assert kwargs.get("max_new_tokens") == kwargs.get("min_new_tokens") == 1, ( + "For prefilling, max_new_tokens and min_new_tokens must be equal to 1" + ) return self.pretrained_model.generate(**inputs, **kwargs) @torch.inference_mode()