diff --git a/optimum_benchmark/backends/ipex/backend.py b/optimum_benchmark/backends/ipex/backend.py index 7e4983a9..049b2c7b 100644 --- a/optimum_benchmark/backends/ipex/backend.py +++ b/optimum_benchmark/backends/ipex/backend.py @@ -45,41 +45,29 @@ def load(self) -> None: self.tmpdir.cleanup() - def _load_automodel_from_pretrained(self) -> None: - self.pretrained_model = self.automodel_loader.from_pretrained(self.config.model, **self.config.model_kwargs) - - def _load_automodel_with_no_weights(self) -> None: - original_model, self.config.model = self.config.model, self.no_weights_model - - with fast_weights_init(): - self._load_automodel_from_pretrained() - - self.logger.info("\t+ Tying model weights") - self.pretrained_model.tie_weights() - - self.config.model = original_model - def _load_ipexmodel_from_pretrained(self) -> None: self.pretrained_model = self.ipexmodel_class.from_pretrained( self.config.model, - export=self.config.export, **self.config.model_kwargs, - **self.automodel_kwargs, + **self.ipexmodel_kwargs, ) def _load_ipexmodel_with_no_weights(self) -> None: with fast_weights_init(): + self.logger.info("\t+ Loading no weights IPEXModel") original_model, self.config.model = self.config.model, self.no_weights_model original_export, self.config.export = self.config.export, True - self.logger.info("\t+ Loading no weights IPEXModel") self._load_ipexmodel_from_pretrained() self.config.export = original_export self.config.model = original_model @property - def automodel_kwargs(self) -> Dict[str, Any]: + def ipexmodel_kwargs(self) -> Dict[str, Any]: kwargs = {} + if self.config.export: + kwargs["export"] = self.config.export + if self.config.torch_dtype is not None: kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype) @@ -89,7 +77,7 @@ def automodel_kwargs(self) -> Dict[str, Any]: def split_between_processes(self) -> bool: return is_torch_distributed_available() and torch.distributed.is_initialized() - def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def prepare_inputs_before_load(self, inputs: Dict[str, Any]) -> Dict[str, Any]: if self.split_between_processes: with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs: inputs = process_inputs diff --git a/optimum_benchmark/backends/ipex/config.py b/optimum_benchmark/backends/ipex/config.py index 5ee4aad1..4fb553da 100644 --- a/optimum_benchmark/backends/ipex/config.py +++ b/optimum_benchmark/backends/ipex/config.py @@ -13,17 +13,17 @@ class IPEXConfig(BackendConfig): version: Optional[str] = ipex_version() _target_: str = "optimum_benchmark.backends.ipex.backend.IPEXBackend" - # load options no_weights: bool = False - torch_dtype: Optional[str] = None - # export options - export: bool = True + # ipexmodel kwargs + export: Optional[bool] = None + torch_dtype: Optional[str] = None def __post_init__(self): super().__post_init__() self.device = self.device.lower() + if self.device not in ["cpu", "gpu"]: raise ValueError(f"IPEXBackend only supports CPU devices, got {self.device}")