From 898f6479aee10e5b0604a6a5a00b8a3fa6359521 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 23 Nov 2023 09:51:09 -0800
Subject: [PATCH] use an assert to guide researchers

---
 audiolm_pytorch/soundstream.py | 11 +++++++----
 audiolm_pytorch/version.py     |  2 +-
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/audiolm_pytorch/soundstream.py b/audiolm_pytorch/soundstream.py
index 4fe2d96..7b52b36 100644
--- a/audiolm_pytorch/soundstream.py
+++ b/audiolm_pytorch/soundstream.py
@@ -2,7 +2,7 @@
 from pathlib import Path
 from functools import partial, wraps
 from itertools import cycle, zip_longest
-from typing import Optional
+from typing import Optional, List
 
 import torch
 from torch import nn, einsum
@@ -455,7 +455,8 @@ def __init__(
         strides = (2, 4, 5, 8),
         channel_mults = (2, 4, 8, 16),
         codebook_dim = 512,
-        codebook_size = 4096,
+        codebook_size: Optional[int] = None,
+        finite_scalar_quantizer_levels: Optional[List[int]] = None,
         rq_num_quantizers = 8,
         rq_commitment_weight = 1.,
         rq_ema_decay = 0.95,
@@ -465,7 +466,6 @@ def __init__(
         rq_kwargs: dict = {},
         use_lookup_free_quantizer = False,    # proposed in https://arxiv.org/abs/2310.05737, adapted for residual quantization
         use_finite_scalar_quantizer = False,  # proposed in https://arxiv.org/abs/2309.15505, adapted for residual quantization
-        finite_scalar_quantizer_levels = [8, 5, 5, 5],
         input_channels = 1,
         discr_multi_scales = (1, 0.5, 0.25),
         stft_normalized = False,
@@ -557,6 +557,7 @@ def __init__(
         self.use_finite_scalar_quantizer = use_finite_scalar_quantizer
 
         if use_lookup_free_quantizer:
+            assert exists(codebook_size) and not exists(finite_scalar_quantizer_levels), 'when using the lookup free quantizer, `codebook_size` must be set (and not `finite_scalar_quantizer_levels`)'
 
             self.rq = GroupedResidualLFQ(
                 dim = codebook_dim,
@@ -571,6 +572,7 @@ def __init__(
             self.codebook_size = codebook_size
 
         elif use_finite_scalar_quantizer:
+            assert not exists(codebook_size) and exists(finite_scalar_quantizer_levels), 'when using the finite scalar quantizer, `finite_scalar_quantizer_levels` must be set (and not `codebook_size`). the effective codebook size is the product of all the FSQ levels'
 
             self.rq = GroupedResidualFSQ(
                 dim = codebook_dim,
@@ -583,8 +585,9 @@ def __init__(
             )
 
             self.codebook_size = self.rq.codebook_size
 
-        else:
+        else:
+            assert exists(codebook_size) and not exists(finite_scalar_quantizer_levels), 'when using the default residual VQ, `codebook_size` must be set (and not `finite_scalar_quantizer_levels`)'
             self.rq = GroupedResidualVQ(
                 dim = codebook_dim,
                 num_quantizers = rq_num_quantizers,
diff --git a/audiolm_pytorch/version.py b/audiolm_pytorch/version.py
index aa1a8c4..cfe6447 100644
--- a/audiolm_pytorch/version.py
+++ b/audiolm_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.8.2'
+__version__ = '1.8.3'
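
---

Note (not part of the patch): with `codebook_size` and `finite_scalar_quantizer_levels` now mutually exclusive optional arguments, construction looks like the sketch below. Only the keyword arguments touched by this diff are shown; everything else is left at the constructor defaults, and the import follows the package's top-level exports.

    from audiolm_pytorch import SoundStream

    # default residual VQ path: `codebook_size` must now be passed explicitly
    soundstream = SoundStream(
        codebook_dim = 512,
        codebook_size = 4096
    )

    # residual FSQ path: pass the levels instead; the effective codebook size
    # is the product of the levels, here 8 * 5 * 5 * 5 = 1000
    soundstream_fsq = SoundStream(
        codebook_dim = 512,
        use_finite_scalar_quantizer = True,
        finite_scalar_quantizer_levels = [8, 5, 5, 5]
    )

    # passing both arguments (or neither) now trips the corresponding assert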