Add option to passthrough num_seqs in PostprocessingDataset #1677

Merged · 10 commits · Jan 19, 2025
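This PR adds a `map_seq_stream_preserves_num_seqs` option to `PostprocessingDataset`: when the `map_seq_stream` callable yields exactly one output seq per input seq, the wrapped dataset's `num_seqs` is passed through instead of being treated as unknown. Mappers can also advertise the guarantee themselves via a `preserves_num_seqs` attribute, which `LaplaceOrdering` now sets and `Sequential` derives from its members. A minimal config sketch of the explicit option; the wrapped dataset and the mapper below are placeholders, not part of this PR:

```python
def my_stream_mapper(iterator, **kwargs):
    """Placeholder stream mapper: yields exactly one output seq per input seq."""
    for tensor_dict in iterator:
        yield tensor_dict  # e.g. modify tensor_dict in place here

train = {
    "class": "PostprocessingDataset",
    "dataset": {"class": "HDFDataset", "files": ["train.hdf"]},  # placeholder wrapped dataset
    "map_seq_stream": my_stream_mapper,
    # new in this PR: pass the wrapped dataset's num_seqs through
    "map_seq_stream_preserves_num_seqs": True,
}
```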
38 changes: 34 additions & 4 deletions returnn/datasets/postprocessing.py
@@ -88,6 +88,7 @@ def __init__(
        map_seq: Optional[Callable] = None,
        map_seq_stream: Optional[Callable] = None,
        map_outputs: Optional[Dict[str, Any]] = None,
        map_seq_stream_preserves_num_seqs: Optional[bool] = None,
        **kwargs,
    ):
        """
@@ -111,6 +112,8 @@ def __init__(
            To simplify the common case when no shapes change, this value can be left unspecified. The dataset then
            assumes the same data layout as returned by the wrapped dataset.
            Example: `map_outputs={"data": {"dim": 42}}`
        :param map_seq_stream_preserves_num_seqs: whether the function in map_seq_stream preserves the number of
            sequences, i.e. for every input sequence there is exactly one output sequence.
        :param kwargs: see :class:`CachedDataset2`, :class:`Dataset`
        """
        super().__init__(**kwargs)
@@ -121,19 +124,25 @@ def __init__(
            raise ValueError(f"{self}: need to either set map_seq or map_seq_stream")
        if map_seq and map_seq_stream:
            raise ValueError(f"{self}: cannot set both map_seq and map_seq_stream")
        if map_seq and map_seq_stream_preserves_num_seqs is not None:
            raise ValueError(f"{self}: map_seq_stream_preserves_num_seqs is only allowed with map_seq_stream")

        self._dataset_def = dataset
        self._map_seq = map_seq
        self._map_seq_stream = map_seq_stream
        if map_seq_stream_preserves_num_seqs is None and map_seq_stream is not None:
            # not set explicitly: check whether the callable itself advertises the guarantee
            # via a `preserves_num_seqs` attribute (e.g. LaplaceOrdering, Sequential)
            map_seq_stream_preserves_num_seqs = getattr(map_seq_stream, "preserves_num_seqs", None)
        self._map_seq_stream_preserves_num_seqs = map_seq_stream_preserves_num_seqs or False
        self._map_outputs = map_outputs
        self._rng = RandomState(self._get_random_seed_for_epoch(0))
        self._seq_list_for_validation: Optional[List[str]] = None

        self._dataset = init_dataset(self._dataset_def, parent_dataset=self)
        if self._map_seq_stream is None:
        if self._map_seq_stream is None or self._map_seq_stream_preserves_num_seqs:
            # if the stream mapper is set and does not preserve num_seqs, the seq count may
            # change and the wrapped dataset's estimate would be inaccurate
            self._estimated_num_seqs = self._dataset.estimated_num_seqs
        self._data_iter: Optional[Iterator[Tuple[int, TensorDict]]] = None
        self._data_iter_produced_num_seqs = 0
@@ -187,10 +196,11 @@ def init_seq_order(
        assert self._dataset is not None
        self._dataset.init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
        self._data_iter = enumerate(self._build_mapping_iter())
        self._data_iter_produced_num_seqs = 0
        self._seq_list_for_validation = seq_list
        if self._map_seq_stream is None:
            # If we don't have an iterable mapper we know the number of segments exactly
            # equals the number of segments in the wrapped dataset
        if self._map_seq_stream is None or self._map_seq_stream_preserves_num_seqs:
            # If we don't have an iterable mapper (or it is marked as preserving num_seqs),
            # we know the number of segments exactly equals the number of segments in the wrapped dataset
            try:
                self._num_seqs = self._dataset.num_seqs
            except NotImplementedError:
Expand Down Expand Up @@ -221,7 +231,20 @@ def _collect_single_seq(self, seq_idx: int) -> Optional[DatasetSeq]:
while True:
try:
loaded_seq_idx, tensor_dict = next(self._data_iter)
self._data_iter_produced_num_seqs += 1
if self._num_seqs is not None:
assert self._data_iter_produced_num_seqs <= self._num_seqs, (
f"{self}: map_seq_stream yielded more seqs ({self._data_iter_produced_num_seqs}) "
f"than expected ({self._num_seqs}). map_seq_stream_preserves_num_seqs is set to "
f"{self._map_seq_stream_preserves_num_seqs}"
)
except StopIteration:
if self._num_seqs is not None:
assert self._data_iter_produced_num_seqs == self._num_seqs, (
f"{self}: map_seq_stream yielded {self._data_iter_produced_num_seqs} seqs, "
f"while {self._num_seqs} were expected. map_seq_stream_preserves_num_seqs is set to "
f"{self._map_seq_stream_preserves_num_seqs}"
)
return None
assert loaded_seq_idx <= seq_idx, "_collect_single_seq must be done monotonically"
if loaded_seq_idx != seq_idx:
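The two asserts above validate the promise at runtime: the mapper may never yield more seqs than `num_seqs`, and must have yielded exactly `num_seqs` once exhausted. A hypothetical mapper that breaks the promise and would trip these checks (names illustrative):

```python
def drop_every_other(iterator, **kwargs):
    """Yields only every second seq, so it does NOT preserve num_seqs."""
    for i, tensor_dict in enumerate(iterator):
        if i % 2 == 0:
            yield tensor_dict

# wrongly advertising the guarantee: the assert on StopIteration fires,
# because fewer seqs were produced than the wrapped dataset's num_seqs
drop_every_other.preserves_num_seqs = True
```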
@@ -324,6 +347,8 @@ class LaplaceOrdering(Callable[[Iterator[TensorDict]], Iterator[TensorDict]]):
    To be composed with any custom data postprocessing logic via :class:`Sequential`.
    """

    # pure reordering: every input seq is yielded exactly once
    preserves_num_seqs = True

    def __init__(self, num_seqs_per_bin: int, length_key: str = "data"):
        """
        :param num_seqs_per_bin: number of segments in a single laplace bin.
Expand Down Expand Up @@ -394,3 +419,8 @@ def __call__(self, arg: T, **kwargs) -> T:
for func in self.funcs:
arg = func(arg, **kwargs)
return arg

@property
def preserves_num_seqs(self):
""":return: whether the composed functions all preserve the number of sequences"""
return all(getattr(f, "preserves_num_seqs", False) for f in self.funcs)
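With the property above, a `Sequential` composition advertises the guarantee only if every member does. A sketch, assuming `Sequential(*funcs)` and imports from `returnn.datasets.postprocessing`, where this diff defines both classes; the custom mapper is illustrative:

```python
from returnn.datasets.postprocessing import LaplaceOrdering, Sequential

def add_noise(iterator, **kwargs):
    """Illustrative mapper: transforms each seq, one output per input."""
    for tensor_dict in iterator:
        yield tensor_dict

add_noise.preserves_num_seqs = True  # advertise the one-in-one-out guarantee

pipeline = Sequential(add_noise, LaplaceOrdering(num_seqs_per_bin=1000))
assert pipeline.preserves_num_seqs  # True only because both members preserve it
```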