diff --git a/numalogic/models/vae/__init__.py b/numalogic/models/vae/__init__.py index e69de29b..c6f64d4e 100644 --- a/numalogic/models/vae/__init__.py +++ b/numalogic/models/vae/__init__.py @@ -0,0 +1,3 @@ +from numalogic.models.vae.trainer import VAETrainer + +__all__ = ["VAETrainer"] diff --git a/numalogic/models/vae/variants/__init__.py b/numalogic/models/vae/variants/__init__.py index e69de29b..419e5838 100644 --- a/numalogic/models/vae/variants/__init__.py +++ b/numalogic/models/vae/variants/__init__.py @@ -0,0 +1,3 @@ +from numalogic.models.vae.variants.conv import Conv1dVAE + +__all__ = ["Conv1dVAE"] diff --git a/tests/models/vae/test_conv.py b/tests/models/vae/test_conv.py index d9c48064..f78648dc 100644 --- a/tests/models/vae/test_conv.py +++ b/tests/models/vae/test_conv.py @@ -9,8 +9,8 @@ from torch.utils.data import DataLoader from numalogic._constants import TESTS_DIR -from numalogic.models.vae.trainer import VAETrainer -from numalogic.models.vae.variants.conv import Conv1dVAE +from numalogic.models.vae import VAETrainer +from numalogic.models.vae.variants import Conv1dVAE from numalogic.tools.data import TimeseriesDataModule, StreamingDataset from numalogic.tools.exceptions import ModelInitializationError