Skip to content

Commit

Permalink
break out a function needed for audiolm
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 2, 2022
1 parent e7ced71 commit 970d416
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vector_quantize_pytorch',
packages = find_packages(),
version = '0.10.2',
version = '0.10.3',
license='MIT',
description = 'Vector Quantization - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
14 changes: 9 additions & 5 deletions vector_quantize_pytorch/residual_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def codebooks(self):
codebooks = rearrange(codebooks, 'q 1 c d -> q c d')
return codebooks

def get_codes_from_indices(self, indices):
batch = indices.shape[0]
codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch)
gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1])

all_codes = codebooks.gather(2, gather_indices) # gather all codes
return all_codes

def forward(
self,
x,
Expand All @@ -62,11 +70,7 @@ def forward(

if return_all_codes:
# whether to return all codes from all codebooks across layers

codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = x.shape[0])
gather_indices = repeat(all_indices, 'b n q -> q b n d', d = codebooks.shape[-1])

all_codes = codebooks.gather(2, gather_indices) # gather all codes
all_codes = self.get_codes_from_indices(all_indices)

# will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
ret = (*ret, all_codes)
Expand Down

0 comments on commit 970d416

Please sign in to comment.