diff --git a/src/ydata/sdk/synthesizers/synthesizer.py b/src/ydata/sdk/synthesizers/synthesizer.py index 9dfcadfb..1cd951b2 100644 --- a/src/ydata/sdk/synthesizers/synthesizer.py +++ b/src/ydata/sdk/synthesizers/synthesizer.py @@ -100,7 +100,7 @@ def fit(self, X: Union[DataSource, pdDataFrame], anonymize (Optional[str]): (optional) fields to anonymize and the anonymization strategy condition_on: (Optional[List[str]]): (optional) list of features to condition upon """ - if self._is_initialized(): + if self._already_fitted(): raise AlreadyFittedError() _datatype = DataSourceType(datatype) if isinstance( @@ -383,6 +383,16 @@ def _is_initialized(self) -> bool: """ return self._model is not None + def _already_fitted(self) -> bool: + """Determine if a synthesizer is already fitted. + + Returns: + True if the synthesizer is instanciated + """ + + return self._is_initialized() and (self._model.status and self._model.status.training and + self._model.status.training.state is not [TrainingState.PREPARING]) + @staticmethod def _resolve_api_status(api_status: Dict) -> Status: """Determine the status of the Synthesizer.