Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ASR,ST and CS recipies #1307

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
45 changes: 33 additions & 12 deletions lhotse/dataset/speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
input_strategy: BatchIO = PrecomputedFeatures(),
lid: bool = False,
):
"""
k2 ASR IterableDataset constructor.
Expand All @@ -78,13 +79,15 @@
Examples: normalization, SpecAugment, etc.
:param input_strategy: Converts cuts into a collated batch of audio/features.
By default, reads pre-computed features from disk.
:param lid: adding lid information to the batch.
"""
super().__init__()
# Initialize the fields
self.return_cuts = return_cuts
self.cut_transforms = ifnone(cut_transforms, [])
self.input_transforms = ifnone(input_transforms, [])
self.input_strategy = input_strategy
self.lid = lid

# This attribute is a workaround to constantly growing HDF5 memory
# throughout the epoch. It regularly closes open file handles to
Expand Down Expand Up @@ -132,19 +135,37 @@
segments = torch.stack(list(supervision_intervals.values()), dim=1)
for tnfm in self.input_transforms:
inputs = tnfm(inputs, supervision_segments=segments)

batch = {
"inputs": inputs,
"supervisions": default_collate(
[
{
"text": supervision.text,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can just add ”language”: supervision.language, in the line below and always return it to get rid of the extra option and code duplication.

}
for sequence_idx, cut in enumerate(cuts)
if self.lid == True:
batch = {

Check warning on line 139 in lhotse/dataset/speech_recognition.py

View check run for this annotation

Codecov / codecov/patch

lhotse/dataset/speech_recognition.py#L139

Added line #L139 was not covered by tests
"inputs": inputs,
"lids": [
supervision.language
for _, cut in enumerate(cuts)
for supervision in cut.supervisions
]
),
}
],
"supervisions": default_collate(
[
{
"text": supervision.text,
}
for sequence_idx, cut in enumerate(cuts)
for supervision in cut.supervisions
]
),
}
else:
batch = {
"inputs": inputs,
"supervisions": default_collate(
[
{
"text": supervision.text,
}
for sequence_idx, cut in enumerate(cuts)
for supervision in cut.supervisions
]
),
}
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
batch["supervisions"].update(supervision_intervals)
if self.return_cuts:
Expand Down
Loading