diff --git a/mimikit/networks/sample_rnn_v2.py b/mimikit/networks/sample_rnn_v2.py index b03f636..1d2ce77 100644 --- a/mimikit/networks/sample_rnn_v2.py +++ b/mimikit/networks/sample_rnn_v2.py @@ -202,7 +202,7 @@ def from_config(cls, config: "SampleRNN.Config") -> "SampleRNN": # only one input module supported spec_input_module = config.io_spec.inputs[0].module for i, fs in enumerate(config.frame_sizes[:-1]): - if isinstance(spec_input_module, FramedIO) and i == 0: # only the top-tier has no proj of the input + if i == 0: # the top-tier never has proj of the input input_module = FramedIO() \ .set(class_size=spec_input_module.class_size, frame_size=fs, hop_length=fs).module() in_dim = fs @@ -230,9 +230,16 @@ def from_config(cls, config: "SampleRNN.Config") -> "SampleRNN": else 1) )] - modules = [spec_input_module.copy() - .set(frame_size=config.frame_sizes[-1], - hop_length=1, out_dim=h_dim, h_dim=config.embedding_dim).module()] + if isinstance(config.frame_sizes[-1], tuple): + # TODO: would be nice! needs support in batch_items, generate... + modules = [spec_input_module.copy() + .set(frame_size=fs, + hop_length=1, out_dim=h_dim, h_dim=config.embedding_dim).module() + for fs in config.frame_sizes[-1]] + else: + modules = [spec_input_module.copy() + .set(frame_size=config.frame_sizes[-1], + hop_length=1, out_dim=h_dim, h_dim=config.embedding_dim).module()] input_module = ZipReduceVariables(mode=config.inputs_mode, modules=modules) tiers += [ SampleRNNTier(