Commit b6b605d
add quantizer dropout, for soundstream training
lucidrains committed Nov 3, 2022
1 parent ab44bb9 commit b6b605d
Showing 2 changed files with 39 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'vector_quantize_pytorch',
  packages = find_packages(),
  version = '0.10.4',
  version = '0.10.5',
  license='MIT',
  description = 'Vector Quantization - Pytorch',
  long_description_content_type = 'text/markdown',
39 changes: 38 additions & 1 deletion vector_quantize_pytorch/residual_vq.py
@@ -1,11 +1,18 @@
from functools import partial
from random import randrange

import torch
from torch import nn
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize

from einops import rearrange, repeat

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

class ResidualVQ(nn.Module):
""" Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
def __init__(
@@ -14,13 +21,22 @@ def __init__(
        num_quantizers,
        shared_codebook = False,
        heads = 1,
        quantize_dropout = False,
        quantize_dropout_cutoff_index = 0,
        **kwargs
    ):
        super().__init__()
        assert heads == 1, 'residual vq is not compatible with multi-headed codes'

        self.num_quantizers = num_quantizers

        self.layers = nn.ModuleList([VectorQuantize(**kwargs) for _ in range(num_quantizers)])

        self.quantize_dropout = quantize_dropout

        assert quantize_dropout_cutoff_index >= 0
        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index

        if not shared_codebook:
            return

@@ -42,21 +58,42 @@ def get_codes_from_indices(self, indices):
        codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch)
        gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1])

        # take care of quantizer dropout
        mask = gather_indices == -1.
        gather_indices = gather_indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later

        all_codes = codebooks.gather(2, gather_indices) # gather all codes

        # mask out any codes that were dropout-ed
        all_codes = all_codes.masked_fill(mask, 0.)
        return all_codes

    def forward(
        self,
        x,
        return_all_codes = False
    ):
        b, n, *_, num_quant, device = *x.shape, self.num_quantizers, x.device

        quantized_out = 0.
        residual = x

        all_losses = []
        all_indices = []

        for layer in self.layers:
        if self.training and self.quantize_dropout:
            rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)

        for quantizer_index, layer in enumerate(self.layers):

            if self.training and self.quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                null_indices = torch.full((b, n), -1, device = device, dtype = torch.long)
                null_loss = torch.full((b,), 0., device = device, dtype = x.dtype)

                all_indices.append(null_indices)
                all_losses.append(null_loss)
                continue

            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized
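
Not part of the commit, but a minimal usage sketch of the new quantizer dropout flag, assuming the vector_quantize_pytorch 0.10.5 release produced by this change; the dim, codebook_size and sequence-length values below are illustrative, not taken from the repository.

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    codebook_size = 1024,
    num_quantizers = 8,
    quantize_dropout = True,             # enable the dropout added in this commit
    quantize_dropout_cutoff_index = 1    # quantizers 0 and 1 are never dropped
)

residual_vq.train()                      # dropout is only applied in training mode

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)

# quantizers past the randomly drawn cutoff contribute nothing for this step:
# their entries in `indices` are filled with -1 and their losses with 0
print((indices == -1).any())

# get_codes_from_indices maps the -1 placeholders to all-zero code vectors,
# so the returned tensor keeps a fixed number of quantizer levels
all_codes = residual_vq.get_codes_from_indices(indices)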
