diff --git a/numalogic/models/vae/variants/conv.py b/numalogic/models/vae/variants/conv.py
index 6aa99cfd..aebb1d70 100644
--- a/numalogic/models/vae/variants/conv.py
+++ b/numalogic/models/vae/variants/conv.py
@@ -1,4 +1,5 @@
 from collections.abc import Sequence, Callable
+from typing import Final
 
 import torch
 from torch import nn, Tensor, optim
@@ -6,6 +7,12 @@
 import torch.nn.functional as F
 import pytorch_lightning as pl
 
+from numalogic.tools.exceptions import ModelInitializationError
+
+
+_DEFAULT_KERNEL_SIZE: Final[int] = 3
+_DEFAULT_STRIDE: Final[int] = 2
+
 
 class CausalConv1d(nn.Conv1d):
     """Temporal convolutional layer with causal padding."""
@@ -96,8 +103,8 @@ def __init__(
         conv_layer = CausalConvBlock(
             in_channels=n_features,
             out_channels=conv_channels[0],
-            kernel_size=3,
-            stride=2,
+            kernel_size=_DEFAULT_KERNEL_SIZE,
+            stride=_DEFAULT_STRIDE,
             dilation=1,
         )
         layers = self._construct_conv_layers(conv_channels)
@@ -114,18 +121,18 @@ def __init__(
     @staticmethod
     def _construct_conv_layers(conv_channels) -> nn.ModuleList:
         layers = nn.ModuleList()
-        idx = 0
-        while idx < len(conv_channels) - 1:
+        layer_idx = 1
+        while layer_idx < len(conv_channels):
             layers.append(
                 CausalConvBlock(
-                    conv_channels[idx],
-                    conv_channels[idx + 1],
-                    kernel_size=3,
-                    stride=2,
-                    dilation=2,
+                    conv_channels[layer_idx - 1],
+                    conv_channels[layer_idx],
+                    kernel_size=_DEFAULT_KERNEL_SIZE,
+                    stride=_DEFAULT_STRIDE,
+                    dilation=2**layer_idx,
                 )
             )
-            idx += 1
+            layer_idx += 1
         return layers
 
     def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
@@ -156,10 +163,10 @@ def __init__(self, seq_len: int, n_features: int, num_conv_filters: int, latent_
         self.fc = nn.Linear(latent_dim, num_conv_filters * 6)
         self.unflatten = nn.Unflatten(dim=1, unflattened_size=(num_conv_filters, 6))
         self.conv_tr = nn.ConvTranspose1d(
-            in_channels=16,
+            in_channels=num_conv_filters,
             out_channels=n_features,
-            kernel_size=3,
-            stride=2,
+            kernel_size=_DEFAULT_KERNEL_SIZE,
+            stride=_DEFAULT_STRIDE,
             padding=1,
             output_padding=1,
         )
@@ -249,6 +256,15 @@ def __init__(
             latent_dim=latent_dim,
         )
 
+        # Do a dry run to initialize lazy modules
+        try:
+            self.forward(torch.rand(1, seq_len, n_features))
+        except (ValueError, RuntimeError) as err:
+            raise ModelInitializationError(
+                "Model forward pass failed. "
+                "Please validate input arguments and the expected input shape."
+            ) from err
+
     def forward(self, x: Tensor) -> tuple[MultivariateNormal, Tensor]:
        x = self.configure_shape(x)
        z_mu, z_logvar = self.encoder(x)
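
The two substantive changes here are the per-layer dilation (`dilation=2**layer_idx` instead of a fixed `dilation=2`) and the dry-run forward pass in `__init__`. The sketch below is not part of the patch: it uses a simplified stand-in class (`_ToyCausalConv1d`, hypothetical) and arbitrary `conv_channels` to show why doubling the dilation per block grows the temporal receptive field geometrically with depth, and how a single dry-run forward surfaces shape errors at construction time rather than mid-training.

```python
import torch
import torch.nn.functional as F
from torch import nn


class _ToyCausalConv1d(nn.Conv1d):
    """Left-pads the input so the output at time t only sees inputs <= t.

    Simplified stand-in for the CausalConv1d in this diff; the real class
    may differ in details.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 stride: int = 1, dilation: int = 1):
        super().__init__(in_channels, out_channels, kernel_size,
                         stride=stride, dilation=dilation, padding=0)
        # Causal padding: dilation * (kernel_size - 1) zeros on the left only.
        self._left_pad = dilation * (kernel_size - 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(F.pad(x, (self._left_pad, 0)))


# Mirror the rewritten _construct_conv_layers loop: the first block uses
# dilation=1, each later block uses 2**layer_idx, so the receptive field
# grows geometrically with depth instead of staying fixed at dilation=2.
conv_channels = (8, 16, 32)  # illustrative widths, not the library defaults
layers = [_ToyCausalConv1d(1, conv_channels[0], kernel_size=3, stride=2, dilation=1)]
for layer_idx in range(1, len(conv_channels)):
    layers.append(
        _ToyCausalConv1d(
            conv_channels[layer_idx - 1],
            conv_channels[layer_idx],
            kernel_size=3,
            stride=2,
            dilation=2 ** layer_idx,
        )
    )

# Dry run, analogous to the one added in __init__ above: one forward pass
# catches shape mismatches before any training step runs.
x = torch.rand(1, 1, 64)  # (batch, channels, seq_len)
for layer in layers:
    x = layer(x)
print(x.shape)  # torch.Size([1, 32, 8])
```

Note that the dry run is also what makes the `in_channels=16` → `in_channels=num_conv_filters` fix in the decoder's `ConvTranspose1d` safe: with the old hardcoded value, any `num_conv_filters != 16` would now fail loudly at construction with a `ModelInitializationError` instead of at first use.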