diff --git a/returnn/compile.py b/returnn/compile.py index 4fe236a1..67d1763c 100644 --- a/returnn/compile.py +++ b/returnn/compile.py @@ -6,7 +6,8 @@ import logging import shutil import subprocess as sp -from typing import Any, Dict, Optional, Sequence, Union +import sys +from typing import Any, Dict, List, Optional, Sequence, Union import i6_core.util as util @@ -242,17 +243,6 @@ def __init__( self.returnn_config = returnn_config self.checkpoint = checkpoint - # Get the list here, because ReturnnConfig serialization might potentially reorder via `sort_config=True`. - input_names = ( - list(returnn_config.config["extern_data"].keys()) - if ("extern_data" in returnn_config.config and input_names is None) - else input_names - ) - output_names = ( - list(returnn_config.config["model_outputs"].keys()) - if ("model_outputs" in returnn_config.config and output_names is None) - else output_names - ) self.input_names = input_names self.output_names = output_names @@ -281,10 +271,43 @@ def run(self): "--verbosity", "5", ] - if self.input_names: - cmd += ["--input_names", ",".join(self.input_names)] - if self.output_names: - cmd += ["--output_names", ",".join(self.output_names)] + + # Pass the tensor names here because ReturnnConfig serialization might + # potentially reorder via `sort_config=True`, and the order matters for + # output tensor <-> tensor name association. 
+        input_names = self.input_names
+        if input_names is None and "extern_data" in self.returnn_config.config:
+            input_names = self.collect_tensor_names(
+                str(self.returnn_root),
+                self.returnn_config.config["extern_data"],
+            )
+        output_names = self.output_names
+        if output_names is None and "model_outputs" in self.returnn_config.config:
+            output_names = self.collect_tensor_names(
+                str(self.returnn_root),
+                self.returnn_config.config["model_outputs"],
+            )
+        if input_names:
+            cmd += ["--input_names", ",".join(input_names)]
+        if output_names:
+            cmd += ["--output_names", ",".join(output_names)]
 
         util.create_executable("compile.sh", cmd)  # convenience file for manual execution
         sp.run(cmd, check=True)
+
+    @staticmethod
+    def collect_tensor_names(returnn_root: str, data_dict: Dict[str, Any]) -> List[str]:
+        """
+        Collect the tensor names described by a RETURNN data dict, in dict order.
+
+        Every entry contributes its own name, followed by one ``"{name}:size{i}"``
+        entry for each axis ``i`` of the entry whose dim carries a non-scalar
+        dynamic size (i.e. a sequence-length tensor is needed for that axis).
+
+        :param returnn_root: path of the RETURNN checkout; appended to ``sys.path``
+            if it is not already listed, so that ``returnn`` can be imported
+        :param data_dict: maps a tensor name to the keyword arguments passed to
+            ``returnn.tensor.Tensor`` for that tensor
+        :return: the collected names, following the iteration order of ``data_dict``
+        """
+        if returnn_root not in sys.path:
+            sys.path.append(returnn_root)
+
+        # Imported lazily: RETURNN is only importable once its root is on sys.path.
+        from returnn.tensor import Tensor
+
+        names = []
+        for name, opts in data_dict.items():
+            names.append(name)
+            tensor = Tensor(name=name, **opts)
+            for i, dim in enumerate(tensor.dims):
+                dyn_size = dim.dyn_size_ext
+                # We need seq lengths if there is a dyn size which is not a scalar.
+                # Explicit None check: do not rely on the truth value of a Tensor object.
+                if dyn_size is not None and dyn_size.dims:
+                    names.append(f"{name}:size{i}")
+        return names