Skip to content

Commit

Permalink
Fix MixedCut transforms serialization (lhotse-speech#1370)
Browse files Browse the repository at this point in the history
* Fix MixedCut transforms serialization

* fix
  • Loading branch information
pzelasko authored and Your Name committed Jan 7, 2025
1 parent 83aef17 commit 769daeb
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 5 deletions.
35 changes: 31 additions & 4 deletions lhotse/cut/mixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ def from_dict(data: dict):
cut_dict["type"] = data.pop("type")
return MixTrack(deserialize_cut(cut_dict), **data)

def to_dict(self) -> Dict:
ans = {
"cut": self.cut.to_dict(),
"type": self.type,
"offset": self.offset,
}
if self.snr is not None:
ans["snr"] = self.snr
return ans


@dataclass
class MixedCut(Cut):
Expand Down Expand Up @@ -125,7 +135,7 @@ class MixedCut(Cut):

id: str
tracks: List[MixTrack]
transforms: Optional[List[Dict]] = None
transforms: Optional[List[AudioTransform]] = None

@property
def supervisions(self) -> List[SupervisionSegment]:
Expand Down Expand Up @@ -207,6 +217,16 @@ def num_channels(self) -> Optional[int]:
def features_type(self) -> Optional[str]:
return self._first_non_padding_cut.features.type if self.has_features else None

def to_dict(self) -> dict:
ans = {
"id": self.id,
"tracks": [t.to_dict() for t in self.tracks],
"type": type(self).__name__,
}
if self.transforms:
ans["transforms"] = [t.to_dict() for t in self.transforms]
return ans

def iter_data(
self,
) -> Generator[
Expand Down Expand Up @@ -793,7 +813,7 @@ def normalize_loudness(

if mix_first:
transforms = self.transforms.copy() if self.transforms is not None else []
transforms.append(LoudnessNormalization(target=target).to_dict())
transforms.append(LoudnessNormalization(target=target))
return fastcopy(
self,
id=f"{self.id}_ln{target}" if affix_id else self.id,
Expand Down Expand Up @@ -908,7 +928,7 @@ def reverb_rir(
early_only=early_only,
rir_channels=rir_channels if rir_channels is not None else [0],
rir_generator=rir_generator,
).to_dict()
)
)
return fastcopy(
self,
Expand Down Expand Up @@ -1133,7 +1153,10 @@ def load_audio(

# We'll apply the transforms now (if any).
transforms = [
AudioTransform.from_dict(params) for params in self.transforms or []
tnfm
if isinstance(tnfm, AudioTransform)
else AudioTransform.from_dict(tnfm)
for tnfm in self.transforms or []
]
for tfn in transforms:
audio = tfn(audio, self.sampling_rate)
Expand Down Expand Up @@ -1551,9 +1574,13 @@ def filter_supervisions(
def from_dict(data: dict) -> "MixedCut":
if "type" in data:
data.pop("type")
transforms = None
if "transforms" in data:
transforms = [AudioTransform.from_dict(t) for t in data["transforms"]]
return MixedCut(
id=data["id"],
tracks=[MixTrack.from_dict(track) for track in data["tracks"]],
transforms=transforms,
)

def with_features_path_prefix(self, path: Pathlike) -> "MixedCut":
Expand Down
41 changes: 40 additions & 1 deletion test/cut/test_cut_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from lhotse import AudioSource, CutSet, MonoCut, Recording, SupervisionSegment
from lhotse.audio import RecordingSet
from lhotse.cut import PaddingCut
from lhotse.cut import Cut, MixedCut, PaddingCut
from lhotse.testing.dummies import dummy_cut, dummy_multi_cut
from lhotse.utils import fastcopy, is_module_available, nullcontext

Expand Down Expand Up @@ -424,6 +424,27 @@ def test_mixed_cut_start01_reverb_rir_mix_first(cut_with_supervision_start01, ri
)


def test_mixed_cut_start01_reverb_rir_mix_first_deserialized(
cut_with_supervision_start01, rir
):
mixed_rvb_orig = cut_with_supervision_start01.pad(duration=0.5).reverb_rir(
rir_recording=rir, mix_first=True
)
mixed_rvb = MixedCut.from_dict(mixed_rvb_orig.to_dict())
assert mixed_rvb.start == 0 # MixedCut always starts at 0
assert mixed_rvb.duration == 0.5
assert mixed_rvb.end == 0.5
assert mixed_rvb.num_samples == 4000

# Check that the padding part should not be all zeros afte
np.testing.assert_raises(
AssertionError,
np.testing.assert_array_almost_equal,
mixed_rvb.load_audio()[:, 3200:],
np.zeros((1, 800)),
)


def test_mixed_cut_start01_reverb_rir_with_fast_random(
cut_with_supervision_start01, rir
):
Expand Down Expand Up @@ -498,6 +519,24 @@ def test_mixed_cut_normalize_loudness(cut_with_supervision_start01, target, mix_
assert loudness == pytest.approx(target, abs=0.5)


@pytest.mark.skipif(
not is_module_available("pyloudnorm"),
reason="This test requires pyloudnorm to be installed.",
)
def test_mixed_cut_normalize_loudness_deserialized(cut_with_supervision_start01):
target = -15.0
mixed_cut = cut_with_supervision_start01.append(cut_with_supervision_start01)
mixed_cut_ln_orig = mixed_cut.normalize_loudness(target, mix_first=True)
mixed_cut_ln = MixedCut.from_dict(mixed_cut_ln_orig.to_dict())

import pyloudnorm as pyln

# check if loudness is correct
meter = pyln.Meter(mixed_cut_ln.sampling_rate) # create BS.1770 meter
loudness = meter.integrated_loudness(mixed_cut_ln.load_audio().T)
assert loudness == pytest.approx(target, abs=0.5)


@pytest.mark.skipif(
not is_module_available("nara_wpe"),
reason="This test requires nara_wpe to be installed.",
Expand Down

0 comments on commit 769daeb

Please sign in to comment.