diff --git a/setup.py b/setup.py
index 9ac1bbd..103efed 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'soundstorm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.0',
+  version = '0.3.0',
   license='MIT',
   description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
   author = 'Phil Wang',
diff --git a/soundstorm_pytorch/soundstorm.py b/soundstorm_pytorch/soundstorm.py
index 67c7928..4538344 100644
--- a/soundstorm_pytorch/soundstorm.py
+++ b/soundstorm_pytorch/soundstorm.py
@@ -6,6 +6,7 @@ from pathlib import Path
 
 import torch
+from torch.cuda.amp import autocast
 from torch import Tensor, nn, einsum
 import torch.nn.functional as F
 
@@ -124,6 +125,7 @@ def __init__(self, dim, theta = 10000):
     def device(self):
         return next(self.buffers()).device
 
+    @autocast(enabled = False)
     def forward(self, seq_len):
         t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
         freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
@@ -134,6 +136,7 @@ def rotate_half(x):
     x1, x2 = x.chunk(2, dim=-1)
     return torch.cat((-x2, x1), dim=-1)
 
+@autocast(enabled = False)
def apply_rotary_pos_emb(pos, t):
     return (t * pos.cos()) + (rotate_half(t) * pos.sin())
 
@@ -216,9 +219,18 @@ def __init__(self, chan_in, chan_out, kernel_size, padding):
         self.padding = padding
         self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)
 
-    def forward(self, x):
+    def forward(self, x, mask = None):
+        if exists(mask):
+            mask = rearrange(mask, 'b n -> b 1 n')
+            x = x.masked_fill(~mask, 0.)
+
         x = F.pad(x, self.padding)
-        return self.conv(x)
+        out = self.conv(x)
+
+        if exists(mask):
+            out = out.masked_fill(~mask, 0.)
+
+        return out
 
 # attention, feedforward, and conv module
 
@@ -333,12 +345,16 @@ def __init__(
         inner_dim = dim * expansion_factor
         padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
 
-        self.net = nn.Sequential(
+        self.net1 = nn.Sequential(
             nn.LayerNorm(dim),
             Rearrange('b n c -> b c n'),
             nn.Conv1d(dim, inner_dim * 2, 1),
-            GLU(dim=1),
-            DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
+            GLU(dim=1)
+        )
+
+        self.ds_conv = DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding)
+
+        self.net2 = nn.Sequential(
             Swish(),
             ChanLayerNorm(inner_dim),
             nn.Conv1d(inner_dim, dim, 1),
@@ -346,8 +362,10 @@ def __init__(
             nn.Dropout(dropout)
         )
 
-    def forward(self, x):
-        return self.net(x)
+    def forward(self, x, mask = None):
+        x = self.net1(x)
+        x = self.ds_conv(x, mask = mask)
+        return self.net2(x)
 
 # Conformer Block
 
@@ -388,7 +406,7 @@ def forward(
     ):
         x = self.ff1(x) + x
         x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x
-        x = self.conv(x) + x
+        x = self.conv(x, mask = mask) + x
         x = self.ff2(x) + x
         x = self.post_norm(x)
         return x
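
Below is a minimal usage sketch, not part of the patch, illustrating what the changes buy. The padding mask is now threaded from ConformerBlock down into the depthwise convolution (expanded to shape (b, 1, n) so it broadcasts over channels), which stops padded frames from bleeding into valid ones through the conv's receptive field, and the rotary helpers opt out of autocast so their sin/cos tables stay in fp32. The import paths, the RotaryEmbedding class name, and all shapes and constructor arguments here are illustrative assumptions, not values confirmed by the repo.

import torch
# assumed import paths; class names for the rotary module are not visible in the hunks above
from soundstorm_pytorch.soundstorm import ConformerConvModule, RotaryEmbedding

conv = ConformerConvModule(dim = 256)   # defaults assumed: expansion_factor = 2, kernel_size = 31
conv.eval()                             # dropout off, so the comparison below is deterministic

x = torch.randn(2, 128, 256)                    # (batch, seq, dim)
mask = torch.ones(2, 128, dtype = torch.bool)   # True = real frame
mask[:, 100:] = False                           # pretend the last 28 frames are padding

with torch.no_grad():
    out = conv(x, mask = mask)

    # corrupt only the padded tail; valid positions should be unaffected,
    # since the depthwise conv now zeroes padded inputs before convolving
    x_corrupt = x.clone()
    x_corrupt[:, 100:] = torch.randn(2, 28, 256)
    out_corrupt = conv(x_corrupt, mask = mask)

assert out.shape == x.shape
assert torch.allclose(out[:, :100], out_corrupt[:, :100])

# the rotary table is likewise kept in full precision under mixed precision,
# thanks to the @autocast(enabled = False) decorators added in the patch
if torch.cuda.is_available():
    rotary = RotaryEmbedding(dim = 32).cuda()
    with torch.cuda.amp.autocast():
        pos = rotary(64)
    assert pos.dtype == torch.float32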