Skip to content

Commit

Permalink
Merge pull request #150 from scikit-learn-contrib/chp_add_rand_state_…
Browse files Browse the repository at this point in the history
…ddpm

Chp add rand state ddpm
  • Loading branch information
JulienRoussel77 authored Jun 13, 2024
2 parents 47565ff + bb53bae commit e3ae559
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 20 deletions.
26 changes: 26 additions & 0 deletions examples/benchmark.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,32 @@ except ModuleNotFoundError:
For the example, we use a simple MLP model with 3 layers of neurons.
Then we train the model without grouping the data by station.

```python
import numpy as np
from qolmat.imputations.imputers_pytorch import ImputerDiffusion
from qolmat.imputations.diffusions.ddpms import TabDDPM

X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
imputer = ImputerDiffusion(model=TabDDPM(random_state=11), epochs=50, batch_size=1)

imputer.fit_transform(X)
```

```python
import numpy as np
from qolmat.imputations.imputers_pytorch import ImputerDiffusion
from qolmat.imputations.diffusions.ddpms import TabDDPM

X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
imputer = ImputerDiffusion(model=TabDDPM(random_state=11), epochs=50, batch_size=1)

imputer.fit_transform(X)
```

```python
1.33573675, 1.40472937
```

```python
fig = plt.figure(figsize=(10 * n_stations, 3 * n_cols))
for i_station, (station, df) in enumerate(df_data.groupby("station")):
Expand Down
Empty file added qolmat/analysis/__init__.py
Empty file.
5 changes: 3 additions & 2 deletions qolmat/analysis/holes_characterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ class LittleTest(McarTest):
imputer : Optional[ImputerEM]
Imputer based on the EM algorithm. The 'model' attribute must be equal to 'multinormal'.
If None, the default ImputerEM is taken.
random_state : Union[None, int, np.random.RandomState], optional
Controls the randomness of the fit_transform, by default None
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
"""

def __init__(
Expand Down
35 changes: 21 additions & 14 deletions qolmat/benchmark/missing_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ class _HoleGenerator:
Names of the columns for which holes must be created, by default None
ratio_masked : Optional[float]
Ratio of values to mask, by default 0.05.
random_state : Optional[int]
The seed used by the random number generator, by default 42.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
groups: Tuple[str, ...]
Column names used to group the data
"""
Expand Down Expand Up @@ -150,8 +151,9 @@ class UniformHoleGenerator(_HoleGenerator):
Names of the columns for which holes must be created, by default None
ratio_masked : Optional[float], optional
Ratio of masked values to add, by default 0.05.
random_state : Optional[int], optional
The seed used by the random number generator, by default 42.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
sample_proportional: bool, optional
If True, generates holes in target columns with equal frequency.
If False, reproduces the empirical proportions between the variables.
Expand Down Expand Up @@ -215,8 +217,9 @@ class _SamplerHoleGenerator(_HoleGenerator):
Names of the columns for which holes must be created, by default None
ratio_masked : Optional[float], optional
Ratio of masked values to add, by default 0.05.
random_state : Optional[int], optional
The seed used by the random number generator, by default 42.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
groups: Tuple[str, ...]
Column names used to group the data
"""
Expand Down Expand Up @@ -321,8 +324,9 @@ class GeometricHoleGenerator(_SamplerHoleGenerator):
Names of the columns for which holes must be created, by default None
ratio_masked : Optional[float], optional
Ratio of masked values to add, by default 0.05.
random_state : Union[None, int, np.random.RandomState], optional
The seed used by the random number generator, by default 42.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
groups: Tuple[str, ...]
Column names used to group the data
"""
Expand Down Expand Up @@ -390,8 +394,9 @@ class EmpiricalHoleGenerator(_SamplerHoleGenerator):
Names of the columns for which holes must be created, by default None
ratio_masked : Optional[float], optional
Ratio of masked values to add, by default 0.05.
random_state : Optional[int], optional
The seed used by the random number generator, by default 42.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
groups: Tuple[str, ...]
Column names used to group the data
"""
Expand Down Expand Up @@ -485,8 +490,9 @@ class MultiMarkovHoleGenerator(_HoleGenerator):
Names of the columns for which holes must be created, by default None
ratio_masked : Optional[float], optional
Ratio of masked values to add, by default 0.05
random_state : Optional[int], optional
The seed used by the random number generator, by default 42.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
groups: Tuple[str, ...]
Column names used to group the data
"""
Expand Down Expand Up @@ -634,8 +640,9 @@ class GroupedHoleGenerator(_HoleGenerator):
Names of the columns for which holes must be created, by default None
ratio_masked : Optional[float], optional
Ratio of masked values to add, by default 0.05
random_state : Optional[int], optional
The seed used by the random number generator, by default 42.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
groups : Tuple[str, ...]
Names of the columns forming the groups, by default []
"""
Expand Down
20 changes: 16 additions & 4 deletions qolmat/imputations/diffusions/ddpms.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from typing import Dict, List, Callable, Tuple
from typing import Dict, List, Callable, Tuple, Union
from typing_extensions import Self
import math
import sys
import numpy as np
import pandas as pd
import time
from datetime import timedelta
from tqdm import tqdm
import gc

import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn import preprocessing
from sklearn import utils as sku


from qolmat.imputations.diffusions.base import AutoEncoder, ResidualBlock, ResidualBlockTS
from qolmat.imputations.diffusions.utils import get_num_params
Expand Down Expand Up @@ -39,6 +40,7 @@ def __init__(
p_dropout: float = 0.0,
num_sampling: int = 1,
is_clip: bool = True,
random_state: Union[None, int, np.random.RandomState] = None,
):
"""Diffusion model for tabular data based on
Denoising Diffusion Probabilistic Models (DDPM) of
Expand Down Expand Up @@ -68,6 +70,9 @@ def __init__(
Dropout probability, by default 0.0
num_sampling : int, optional
Number of samples generated for each cell, by default 1
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
"""
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Expand Down Expand Up @@ -108,6 +113,9 @@ def __init__(
self.is_clip = is_clip

self.normalizer_x = preprocessing.StandardScaler()
self.random_state = sku.check_random_state(random_state)
seed_torch = self.random_state.randint(2**31 - 1)
torch.manual_seed(seed_torch)

def _q_sample(self, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Section 3.2, algorithm 1 formula implementation. Forward process, defined by `q`.
Expand Down Expand Up @@ -345,7 +353,6 @@ def fit(
round: int = 10,
cols_imputed: Tuple[str, ...] = (),
) -> Self:

"""Fit data
Parameters
Expand Down Expand Up @@ -537,6 +544,7 @@ def __init__(
p_dropout: float = 0.0,
num_sampling: int = 1,
is_rolling: bool = False,
random_state: Union[None, int, np.random.RandomState] = None,
):
"""Diffusion model for time-series data based on the works of
Ho et al., 2020 (https://arxiv.org/abs/2006.11239),
Expand Down Expand Up @@ -575,6 +583,9 @@ def __init__(
Number of samples generated for each cell, by default 1
is_rolling : bool, optional
Use pandas.DataFrame.rolling for preprocessing data, by default False
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
"""
super().__init__(
num_noise_steps,
Expand All @@ -586,6 +597,7 @@ def __init__(
num_blocks,
p_dropout,
num_sampling,
random_state=random_state,
)

self.dim_feedforward = dim_feedforward
Expand Down
11 changes: 11 additions & 0 deletions qolmat/imputations/imputers_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,17 @@ def __init__(
freq_str : str
Frequency string of DateOffset of Pandas.
It is for processing time-series data, used in diffusion models e.g., TsDDPM.
Examples
--------
>>> import numpy as np
>>> from qolmat.imputations.imputers_pytorch import ImputerDiffusion
>>> from qolmat.imputations.diffusions.ddpms import TabDDPM
>>>
>>> X = np.array([[1, 1, 1, 1], [np.nan, np.nan, 3, 2], [1, 2, 2, 1], [2, 2, 2, 2]])
>>> imputer = ImputerDiffusion(model=TabDDPM(random_state=11), epochs=50, batch_size=1)
>>>
>>> df_imputed = imputer.fit_transform(X)
"""
super().__init__(groups=groups, columnwise=False)
self.model = model
Expand Down

0 comments on commit e3ae559

Please sign in to comment.