From 4b0d505ce4d60167acc566a9e5353b81a4772073 Mon Sep 17 00:00:00 2001 From: Author Name Date: Fri, 16 Apr 2021 17:21:26 +0200 Subject: [PATCH 1/6] add max_gm_samples param and subsample continuous columns before fitting GMs --- ctgan/data_transformer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ctgan/data_transformer.py b/ctgan/data_transformer.py index 9f373ac9..254feb74 100644 --- a/ctgan/data_transformer.py +++ b/ctgan/data_transformer.py @@ -19,7 +19,7 @@ class DataTransformer(object): Discrete columns are encoded using a scikit-learn OneHotEncoder. """ - def __init__(self, max_clusters=10, weight_threshold=0.005): + def __init__(self, max_clusters=10, weight_threshold=0.005, max_gm_samples=None): """Create a data transformer. Args: @@ -27,12 +27,19 @@ def __init__(self, max_clusters=10, weight_threshold=0.005): Maximum number of Gaussian distributions in Bayesian GMM. weight_threshold (float): Weight threshold for a Gaussian distribution to be kept. + _max_gm_samples (int): + Maximum number of sample to use during GMM fit """ self._max_clusters = max_clusters self._weight_threshold = weight_threshold + self._max_gm_samples = np.inf if max_gm_samples is None else max_gm_samples def _fit_continuous(self, column_name, raw_column_data): """Train Bayesian GMM for continuous column.""" + if self._max_gm_samples <= raw_column_data.shape[0]: + raw_column_data = np.random.choice(raw_column_data, + size=self._max_gm_samples, + replace=False) gm = BayesianGaussianMixture( self._max_clusters, weight_concentration_prior_type='dirichlet_process', From 4560e78f484c6e75c64d68f75b38f1625c3c515d Mon Sep 17 00:00:00 2001 From: Author Name Date: Fri, 16 Apr 2021 17:22:42 +0200 Subject: [PATCH 2/6] add data_transformer_params to have control over data_transformer --- ctgan/synthesizers/ctgan.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index d280e72f..02b0600b 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -267,7 +267,8 @@ def _validate_discrete_columns(self, train_data, discrete_columns): if invalid_columns: raise ValueError('Invalid columns found: {}'.format(invalid_columns)) - def fit(self, train_data, discrete_columns=tuple(), epochs=None): + def fit(self, train_data, discrete_columns=tuple(), epochs=None, + data_transformer_params={}): """Fit the CTGAN Synthesizer models to the training data. Args: @@ -278,6 +279,8 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None): Vector. If ``train_data`` is a Numpy array, this list should contain the integer indices of the columns. Otherwise, if it is a ``pandas.DataFrame``, this list should contain the column names. + data_transformer_params (dict): + Dictionary of parameters for ``DataTransformer`` initialization. """ self._validate_discrete_columns(train_data, discrete_columns) @@ -290,7 +293,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None): DeprecationWarning ) - self._transformer = DataTransformer() + self._transformer = DataTransformer(**data_transformer_params) self._transformer.fit(train_data, discrete_columns) train_data = self._transformer.transform(train_data) From 96a632131188e3019b6cf8240d75c2eac067551a Mon Sep 17 00:00:00 2001 From: Author Name Date: Fri, 16 Apr 2021 17:23:42 +0200 Subject: [PATCH 3/6] add test to check max_gm_samples --- tests/integration/test_ctgan.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/integration/test_ctgan.py b/tests/integration/test_ctgan.py index d84ffdcf..220b6021 100644 --- a/tests/integration/test_ctgan.py +++ b/tests/integration/test_ctgan.py @@ -184,3 +184,14 @@ def test_wrong_sampling_conditions(): with pytest.raises(ValueError): ctgan.sample(1, 'discrete', "d") + + +def test_ctgan_data_transformer_params(): + data = pd.DataFrame({ + 'continuous': np.random.random(1000) + }) + + ctgan = CTGANSynthesizer(epochs=1) + ctgan.fit(data, [], data_transformer_params={'max_gm_samples': 100}) + + assert ctgan._transformer._max_gm_samples == 100 From 86695993bf72bc9bb2ac390c8c913a0f25c60974 Mon Sep 17 00:00:00 2001 From: Florent Rambaud Date: Thu, 22 Apr 2021 10:11:20 +0200 Subject: [PATCH 4/6] fix docs --- ctgan/data_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ctgan/data_transformer.py b/ctgan/data_transformer.py index 254feb74..e0fc2789 100644 --- a/ctgan/data_transformer.py +++ b/ctgan/data_transformer.py @@ -27,8 +27,8 @@ def __init__(self, max_clusters=10, weight_threshold=0.005, max_gm_samples=None) Maximum number of Gaussian distributions in Bayesian GMM. weight_threshold (float): Weight threshold for a Gaussian distribution to be kept. - _max_gm_samples (int): - Maximum number of sample to use during GMM fit + max_gm_samples (int): + Maximum number of samples to use during GMM fit. """ self._max_clusters = max_clusters self._weight_threshold = weight_threshold From 2a3222fc01d72e0a23a148244fe3f8f6ae249a3f Mon Sep 17 00:00:00 2001 From: Florent Rambaud Date: Mon, 26 Apr 2021 09:38:04 +0200 Subject: [PATCH 5/6] fix indentation --- ctgan/data_transformer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ctgan/data_transformer.py b/ctgan/data_transformer.py index e0fc2789..e4606ccc 100644 --- a/ctgan/data_transformer.py +++ b/ctgan/data_transformer.py @@ -37,9 +37,11 @@ def __init__(self, max_clusters=10, weight_threshold=0.005, max_gm_samples=None) def _fit_continuous(self, column_name, raw_column_data): """Train Bayesian GMM for continuous column.""" if self._max_gm_samples <= raw_column_data.shape[0]: - raw_column_data = np.random.choice(raw_column_data, - size=self._max_gm_samples, - replace=False) + raw_column_data = np.random.choice( + raw_column_data, + size=self._max_gm_samples, + replace=False + ) gm = BayesianGaussianMixture( self._max_clusters, weight_concentration_prior_type='dirichlet_process', From 1b401596ffaf5a2259f9f693b5e05be7b814b930 Mon Sep 17 00:00:00 2001 From: Florent Rambaud Date: Mon, 26 Apr 2021 09:40:11 +0200 Subject: [PATCH 6/6] move data_transformers args to init --- ctgan/synthesizers/ctgan.py | 14 ++++++++------ tests/integration/test_ctgan.py | 4 ++-- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index 02b0600b..646bc79a 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -1,3 +1,4 @@ +import copy import warnings import numpy as np @@ -129,12 +130,15 @@ class CTGANSynthesizer(BaseSynthesizer): Whether to attempt to use cuda for GPU computation. If this is False or CUDA is not available, CPU will be used. Defaults to ``True``. + data_transformer_params (dict): + Dictionary of parameters for ``DataTransformer`` initialization. """ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500, discriminator_steps=1, - log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True): + log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True, + data_transformer_params={}): assert batch_size % 2 == 0 @@ -163,6 +167,7 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di self._device = torch.device(device) + self._data_transformer_params = copy.deepcopy(data_transformer_params) self._transformer = None self._data_sampler = None self._generator = None @@ -267,8 +272,7 @@ def _validate_discrete_columns(self, train_data, discrete_columns): if invalid_columns: raise ValueError('Invalid columns found: {}'.format(invalid_columns)) - def fit(self, train_data, discrete_columns=tuple(), epochs=None, - data_transformer_params={}): + def fit(self, train_data, discrete_columns=tuple(), epochs=None): """Fit the CTGAN Synthesizer models to the training data. Args: @@ -279,8 +283,6 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None, Vector. If ``train_data`` is a Numpy array, this list should contain the integer indices of the columns. Otherwise, if it is a ``pandas.DataFrame``, this list should contain the column names. - data_transformer_params (dict): - Dictionary of parameters for ``DataTransformer`` initialization. """ self._validate_discrete_columns(train_data, discrete_columns) @@ -293,7 +295,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None, DeprecationWarning ) - self._transformer = DataTransformer(**data_transformer_params) + self._transformer = DataTransformer(**self._data_transformer_params) self._transformer.fit(train_data, discrete_columns) train_data = self._transformer.transform(train_data) diff --git a/tests/integration/test_ctgan.py b/tests/integration/test_ctgan.py index 220b6021..8180f762 100644 --- a/tests/integration/test_ctgan.py +++ b/tests/integration/test_ctgan.py @@ -191,7 +191,7 @@ def test_ctgan_data_transformer_params(): 'continuous': np.random.random(1000) }) - ctgan = CTGANSynthesizer(epochs=1) - ctgan.fit(data, [], data_transformer_params={'max_gm_samples': 100}) + ctgan = CTGANSynthesizer(epochs=1, data_transformer_params={'max_gm_samples': 100}) + ctgan.fit(data, []) assert ctgan._transformer._max_gm_samples == 100