From 96d2c10a590aaad20fa9ef564189176363f4ea20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Portela=20Afonso?= Date: Mon, 5 Feb 2024 17:39:40 +0000 Subject: [PATCH] fix(synthesizer): add already_fitted for synthesizer --- src/ydata/sdk/synthesizers/synthesizer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/ydata/sdk/synthesizers/synthesizer.py b/src/ydata/sdk/synthesizers/synthesizer.py index 9dfcadfb..1cd951b2 100644 --- a/src/ydata/sdk/synthesizers/synthesizer.py +++ b/src/ydata/sdk/synthesizers/synthesizer.py @@ -100,7 +100,7 @@ def fit(self, X: Union[DataSource, pdDataFrame], anonymize (Optional[str]): (optional) fields to anonymize and the anonymization strategy condition_on: (Optional[List[str]]): (optional) list of features to condition upon """ - if self._is_initialized(): + if self._already_fitted(): raise AlreadyFittedError() _datatype = DataSourceType(datatype) if isinstance( @@ -383,6 +383,16 @@ def _is_initialized(self) -> bool: """ return self._model is not None + def _already_fitted(self) -> bool: + """Determine if a synthesizer is already fitted. + + Returns: + True if the synthesizer is instanciated + """ + + return self._is_initialized() and (self._model.status and self._model.status.training and + self._model.status.training.state is not [TrainingState.PREPARING]) + @staticmethod def _resolve_api_status(api_status: Dict) -> Status: """Determine the status of the Synthesizer.