Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Aug 2, 2023
1 parent 775b117 commit 9a4d7e6
Showing 1 changed file with 29 additions and 13 deletions.
42 changes: 29 additions & 13 deletions numalogic/models/vae/variants/conv.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from collections.abc import Sequence, Callable
from typing import Final

Check warning on line 2 in numalogic/models/vae/variants/conv.py

View check run for this annotation

Codecov / codecov/patch

numalogic/models/vae/variants/conv.py#L1-L2

Added lines #L1 - L2 were not covered by tests

import torch
from torch import nn, Tensor, optim
from torch.distributions import MultivariateNormal, kl_divergence
import torch.nn.functional as F
import pytorch_lightning as pl

Check warning on line 8 in numalogic/models/vae/variants/conv.py

View check run for this annotation

Codecov / codecov/patch

numalogic/models/vae/variants/conv.py#L4-L8

Added lines #L4 - L8 were not covered by tests

from numalogic.tools.exceptions import ModelInitializationError

Check warning on line 10 in numalogic/models/vae/variants/conv.py

View check run for this annotation

Codecov / codecov/patch

numalogic/models/vae/variants/conv.py#L10

Added line #L10 was not covered by tests


_DEFAULT_KERNEL_SIZE: Final[int] = 3
_DEFAULT_STRIDE: Final[int] = 2

Check warning on line 14 in numalogic/models/vae/variants/conv.py

View check run for this annotation

Codecov / codecov/patch

numalogic/models/vae/variants/conv.py#L13-L14

Added lines #L13 - L14 were not covered by tests


class CausalConv1d(nn.Conv1d):

Check warning on line 17 in numalogic/models/vae/variants/conv.py

View check run for this annotation

Codecov / codecov/patch

numalogic/models/vae/variants/conv.py#L17

Added line #L17 was not covered by tests
"""Temporal convolutional layer with causal padding."""
Expand Down Expand Up @@ -96,8 +103,8 @@ def __init__(
conv_layer = CausalConvBlock(

Check warning on line 103 in numalogic/models/vae/variants/conv.py

View check run for this annotation

Codecov / codecov/patch

numalogic/models/vae/variants/conv.py#L103

Added line #L103 was not covered by tests
in_channels=n_features,
out_channels=conv_channels[0],
kernel_size=3,
stride=2,
kernel_size=_DEFAULT_KERNEL_SIZE,
stride=_DEFAULT_STRIDE,
dilation=1,
)
layers = self._construct_conv_layers(conv_channels)

Check warning on line 110 in numalogic/models/vae/variants/conv.py

View check run for this annotation

Codecov / codecov/patch

numalogic/models/vae/variants/conv.py#L110

Added line #L110 was not covered by tests
Expand All @@ -114,18 +121,18 @@ def __init__(
@staticmethod
def _construct_conv_layers(conv_channels) -> nn.ModuleList:
layers = nn.ModuleList()
idx = 0
while idx < len(conv_channels) - 1:
layer_idx = 1

Check warning on line 124 in numalogic/models/vae/variants/conv.py

View check run for this annotation

Codecov / codecov/patch

numalogic/models/vae/variants/conv.py#L121-L124

Added lines #L121 - L124 were not covered by tests
while layer_idx < len(conv_channels):
layers.append(

Check warning on line 126 in numalogic/models/vae/variants/conv.py

View check run for this annotation

Codecov / codecov/patch

numalogic/models/vae/variants/conv.py#L126

Added line #L126 was not covered by tests
CausalConvBlock(
conv_channels[idx],
conv_channels[idx + 1],
kernel_size=3,
stride=2,
dilation=2,
conv_channels[layer_idx - 1],
conv_channels[layer_idx],
kernel_size=_DEFAULT_KERNEL_SIZE,
stride=_DEFAULT_STRIDE,
dilation=2**layer_idx,
)
)
idx += 1
layer_idx += 1
return layers

Check warning on line 136 in numalogic/models/vae/variants/conv.py

View check run for this annotation

Codecov / codecov/patch

numalogic/models/vae/variants/conv.py#L135-L136

Added lines #L135 - L136 were not covered by tests

def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
Expand Down Expand Up @@ -156,10 +163,10 @@ def __init__(self, seq_len: int, n_features: int, num_conv_filters: int, latent_
self.fc = nn.Linear(latent_dim, num_conv_filters * 6)
self.unflatten = nn.Unflatten(dim=1, unflattened_size=(num_conv_filters, 6))
self.conv_tr = nn.ConvTranspose1d(

Check warning on line 165 in numalogic/models/vae/variants/conv.py

View check run for this annotation

Codecov / codecov/patch

numalogic/models/vae/variants/conv.py#L159-L165

Added lines #L159 - L165 were not covered by tests
in_channels=16,
in_channels=num_conv_filters,
out_channels=n_features,
kernel_size=3,
stride=2,
kernel_size=_DEFAULT_KERNEL_SIZE,
stride=_DEFAULT_STRIDE,
padding=1,
output_padding=1,
)
Expand Down Expand Up @@ -249,6 +256,15 @@ def __init__(
latent_dim=latent_dim,
)

# Do a dry run to initialize lazy modules
try:
self.forward(torch.rand(1, seq_len, n_features))
except (ValueError, RuntimeError) as err:
raise ModelInitializationError(

Check warning on line 263 in numalogic/models/vae/variants/conv.py

View check run for this annotation

Codecov / codecov/patch

numalogic/models/vae/variants/conv.py#L260-L263

Added lines #L260 - L263 were not covered by tests
"Model forward pass failed. "
"Please validate input arguments and the expected input shape "
) from err

def forward(self, x: Tensor) -> tuple[MultivariateNormal, Tensor]:
x = self.configure_shape(x)
z_mu, z_logvar = self.encoder(x)
Expand Down

0 comments on commit 9a4d7e6

Please sign in to comment.