diff --git a/src/ydata/sdk/common/client/client.py b/src/ydata/sdk/common/client/client.py index 3040c0c3..c00c6773 100644 --- a/src/ydata/sdk/common/client/client.py +++ b/src/ydata/sdk/common/client/client.py @@ -47,6 +47,8 @@ class Client(metaclass=SingletonClient): codes = codes + DEFAULT_PROJECT: Optional[Project] = environ.get("DEFAULT_PROJECT", None) + def __init__(self, credentials: Optional[Union[str, Dict]] = None, project: Optional[Project] = None, set_as_global: bool = False): self._base_url = environ.get("YDATA_BASE_URL", DEFAULT_URL) self._scheme = 'https' @@ -56,10 +58,18 @@ def __init__(self, credentials: Optional[Union[str, Dict]] = None, project: Opti self._handshake() - self._default_project = project or self._get_default_project(credentials) + self._default_project = project or Client.DEFAULT_PROJECT or self._get_default_project(credentials) if set_as_global: self.__set_global() + @property + def project(self) -> Project: + return Client.DEFAULT_PROJECT or self._default_project + + @project.setter + def project(self, value: Project): + self._default_project = value + def post( self, endpoint: str, data: Optional[Dict] = None, json: Optional[Dict] = None, project: Optional[Project] = None, files: Optional[Dict] = None, raise_for_status: bool = True diff --git a/src/ydata/sdk/connectors/connector.py b/src/ydata/sdk/connectors/connector.py index 85dc720e..30a3e22f 100644 --- a/src/ydata/sdk/connectors/connector.py +++ b/src/ydata/sdk/connectors/connector.py @@ -55,6 +55,10 @@ def uid(self) -> UID: def type(self) -> ConnectorType: return self._model.type + @property + def project(self) -> Project: + return self._project or self._client.project + @staticmethod @init_client def get(uid: UID, project: Optional[Project] = None, client: Optional[Client] = None) -> "Connector": diff --git a/src/ydata/sdk/datasources/datasource.py b/src/ydata/sdk/datasources/datasource.py index 80743c15..cb1f98f9 100644 --- a/src/ydata/sdk/datasources/datasource.py +++ b/src/ydata/sdk/datasources/datasource.py @@ -67,6 +67,10 @@ def uid(self) -> UID: def datatype(self) -> DataSourceType: return self._model.datatype + @property + def project(self) -> Project: + return self._project or self._client.project + @property def status(self) -> Status: try: diff --git a/src/ydata/sdk/synthesizers/multitable.py b/src/ydata/sdk/synthesizers/multitable.py index faff0a75..e2c35a7e 100644 --- a/src/ydata/sdk/synthesizers/multitable.py +++ b/src/ydata/sdk/synthesizers/multitable.py @@ -28,7 +28,8 @@ class MultiTableSynthesizer(BaseSynthesizer): The synthesizer instance is created in the backend only when the `fit` method is called. Arguments: - write_connector (UID): Connector of type RDBMS to be used to write the samples + write_connector (UID | Connector): Connector of type RDBMS to be used to write the samples + uid (UID): (optional) UID to identify this synthesizer name (str): (optional) Name to be used when creating the synthesizer. Calculated internally if not provided client (Client): (optional) Client to connect to the backend """ @@ -126,3 +127,4 @@ def _check_or_fetch_connector(self, write_connector: Union[Connector, UID]) -> C f"Invalid type `{write_connector.type}` for the provided connector") return write_connector + diff --git a/src/ydata/sdk/synthesizers/synthesizer.py b/src/ydata/sdk/synthesizers/synthesizer.py index 604c3211..9dfcadfb 100644 --- a/src/ydata/sdk/synthesizers/synthesizer.py +++ b/src/ydata/sdk/synthesizers/synthesizer.py @@ -61,6 +61,10 @@ def _init_common(self, client: Optional[Client] = None): self._client = client self._logger = create_logger(__name__, level=LOG_LEVEL) + @property + def project(self) -> Project: + return self._project or self._client.project + def fit(self, X: Union[DataSource, pdDataFrame], privacy_level: PrivacyLevel = PrivacyLevel.HIGH_FIDELITY, datatype: Optional[Union[DataSourceType, str]] = None,