Skip to content

Commit

Permalink
augmentation/torchaudio: add Phone effect (mulaw, lpc10 codecs) (lhot…
Browse files Browse the repository at this point in the history
…se-speech#1348)

* augmentation/torchaudio: add Phone effect (mulaw, lpc10 codecs)

* restore_orig_sr option

---------

Co-authored-by: Piotr Żelasko <[email protected]>
  • Loading branch information
rouseabout and pzelasko authored Jul 18, 2024
1 parent c286f28 commit 18436e9
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 0 deletions.
31 changes: 31 additions & 0 deletions lhotse/audio/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
AudioTransform,
DereverbWPE,
LoudnessNormalization,
Narrowband,
Resample,
ReverbWithImpulseResponse,
Speed,
Expand Down Expand Up @@ -732,6 +733,36 @@ def perturb_volume(self, factor: float, affix_id: bool = True) -> "Recording":
transforms=transforms,
)

def narrowband(
    self, codec: str, restore_orig_sr: bool = True, affix_id: bool = True
) -> "Recording":
    """
    Return a new ``Recording`` that will lazily apply narrowband effect while loading audio.

    :param codec: Codec name, forwarded to the ``Narrowband`` transform.
    :param restore_orig_sr: When true, the returned ``Recording`` keeps the current
        sampling rate; otherwise it is declared with a sampling rate of 8000 Hz.
    :param affix_id: When true, we will modify the ``Recording.id`` field
        by affixing it with "_nb_{codec}".
    :return: a modified copy of the current ``Recording``.
    """
    # Sampling rate declared by the returned manifest.
    output_sr = self.sampling_rate if restore_orig_sr else 8000

    effect = Narrowband(
        codec=codec,
        source_sampling_rate=self.sampling_rate,
        restore_orig_sr=restore_orig_sr,
    )

    # Extend a copy of the transform chain so this recording's list is not modified.
    chain = [] if self.transforms is None else self.transforms.copy()
    chain.append(effect.to_dict())

    new_id = f"{self.id}_nb_{codec}" if affix_id else self.id
    return fastcopy(
        self,
        id=new_id,
        num_samples=compute_num_samples(
            self.duration,
            output_sr,
            rounding=ROUND_HALF_UP,
        ),
        sampling_rate=output_sr,
        transforms=chain,
    )

def normalize_loudness(self, target: float, affix_id: bool = False) -> "Recording":
"""
Return a new ``Recording`` that will lazily apply WPE dereverberation.
Expand Down
168 changes: 168 additions & 0 deletions lhotse/augmentation/torchaudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,174 @@ def reverse_timestamps(
)


class Codec:
    """
    Base class for the codecs used by the narrowband effect.

    A codec instance is a callable that encodes the audio and immediately
    decodes it again, returning the degraded signal.
    """

    def __call__(self, samples: torch.Tensor) -> torch.Tensor:
        """
        Apply encoder then decoder.
        To be implemented in derived classes.

        :param samples: audio samples as a ``torch.Tensor`` (``Narrowband`` converts
            its numpy input with ``torch.from_numpy`` before calling the codec).
        :return: the decoded samples as a ``torch.Tensor``.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


class MuLawCodec(Codec):
    """Codec that runs torchaudio's mu-law encoding followed by mu-law decoding."""

    def __init__(self):
        # Imported here rather than at module level, so torchaudio is only
        # required when this codec is actually instantiated.
        from torchaudio.transforms import MuLawDecoding, MuLawEncoding

        self.encoder = MuLawEncoding()
        self.decoder = MuLawDecoding()

    def __call__(self, samples):
        """Encode ``samples`` with mu-law and immediately decode the result."""
        encoded = self.encoder(samples)
        return self.decoder(encoded)


from ctypes import CDLL, POINTER, c_int, c_short, c_uint8, c_void_p

# Number of 16-bit PCM samples in one LPC10 frame; this is the size of the sample
# buffer handed to the encoder and filled by the decoder
# (180 samples are 22.5 ms at the 8000 Hz rate used by ``Narrowband``).
LPC10_FRAME_SAMPLES = 180
# Number of bytes in one encoded LPC10 frame; this is the size of the byte buffer
# exchanged between the encoder and the decoder.
LPC10_FRAME_BYTES = 7


def libspandsp_api():
    """
    Load the SpanDSP shared library and declare the prototypes of its LPC10 functions.

    :return: a ``ctypes.CDLL`` handle for ``libspandsp.so`` with ``restype`` and
        ``argtypes`` configured for the ``lpc10_encode*`` / ``lpc10_decode*`` functions.
    :raises RuntimeError: when the library cannot be loaded; the underlying
        ``OSError`` is chained as the cause so the loader's message is not lost.
    """
    try:
        api = CDLL("libspandsp.so")
    except OSError as e:
        raise RuntimeError(
            "We cannot apply the narrowband transformation using the LPC10 codec as the SpanDSP library cannot be found. "
            "To install use `apt-get install libspandsp-dev` or visit <https://github.com/freeswitch/spandsp>."
        ) from e

    # The init functions return a pointer to the codec state. ctypes assumes a C int
    # return value unless ``restype`` is set, so it has to be declared as a pointer.
    api.lpc10_encode_init.restype = c_void_p
    api.lpc10_encode_init.argtypes = [c_void_p, c_int]

    # (state, output bytes, input samples, number of samples) -> number of bytes
    api.lpc10_encode.restype = c_int
    api.lpc10_encode.argtypes = [c_void_p, POINTER(c_uint8), POINTER(c_short), c_int]

    api.lpc10_encode_free.argtypes = [c_void_p]

    api.lpc10_decode_init.restype = c_void_p
    api.lpc10_decode_init.argtypes = [c_void_p, c_int]

    # (state, output samples, input bytes, number of bytes) -> number of samples
    api.lpc10_decode.restype = c_int
    api.lpc10_decode.argtypes = [c_void_p, POINTER(c_short), POINTER(c_uint8), c_int]

    api.lpc10_decode_free.argtypes = [c_void_p]

    return api


class LPC10Codec(Codec):
    """
    LPC10 codec implemented by the SpanDSP library, called through ``ctypes``.

    The audio is processed in frames of ``LPC10_FRAME_SAMPLES`` samples; each frame
    is encoded to ``LPC10_FRAME_BYTES`` bytes and immediately decoded again.
    Only the first channel (``samples[0]``) is processed.
    """

    def __init__(self):
        self.api = libspandsp_api()
        # Scratch buffers reused for every frame: one encoded frame and one PCM frame.
        self.c_data = (c_uint8 * LPC10_FRAME_BYTES)()
        self.c_samples = (c_short * LPC10_FRAME_SAMPLES)()

    def __call__(self, samples):
        """
        Encode and decode the first channel of ``samples``.

        :param samples: a tensor indexed as ``samples[0]`` to obtain the channel to process;
            values are scaled by 32768 to 16-bit PCM (out-of-range values are clipped).
        :return: a tensor of shape ``(1, num_frames * LPC10_FRAME_SAMPLES)``; the last
            frame is zero-padded before encoding, so the output may be longer than the input.
        :raises RuntimeError: when SpanDSP fails to create the codec state or returns
            an unexpected frame size.
        """
        encoder = self.api.lpc10_encode_init(None, 0)
        decoder = self.api.lpc10_decode_init(None, 0)

        decoded = []
        try:
            # With ``restype = c_void_p`` a NULL pointer is returned as ``None``.
            if not encoder or not decoder:
                raise RuntimeError(
                    "SpanDSP failed to initialize the LPC10 encoder/decoder state."
                )

            for frame in samples[0].split(LPC10_FRAME_SAMPLES):
                # Clip before the int16 conversion: a sample of 1.0 scales to 32768,
                # which does not fit into int16.
                pcm = (frame * 32768).clamp(-32768, 32767).to(torch.int16).tolist()
                # Zero-pad the (possibly shorter) last frame to a full frame.
                pcm.extend([0] * (LPC10_FRAME_SAMPLES - len(pcm)))
                self.c_samples[:] = pcm

                # The codec calls are real work, so they must not live inside ``assert``
                # statements (those are removed when Python runs with ``-O``).
                num_bytes = self.api.lpc10_encode(
                    encoder, self.c_data, self.c_samples, LPC10_FRAME_SAMPLES
                )
                if num_bytes != LPC10_FRAME_BYTES:
                    raise RuntimeError(
                        f"LPC10 encoder returned {num_bytes} bytes, expected {LPC10_FRAME_BYTES}."
                    )

                num_samples = self.api.lpc10_decode(
                    decoder, self.c_samples, self.c_data, LPC10_FRAME_BYTES
                )
                if num_samples != LPC10_FRAME_SAMPLES:
                    raise RuntimeError(
                        f"LPC10 decoder returned {num_samples} samples, expected {LPC10_FRAME_SAMPLES}."
                    )

                decoded.extend(self.c_samples[:])
        finally:
            # Always release the native codec state, also when an error occurred.
            if encoder:
                self.api.lpc10_encode_free(encoder)
            if decoder:
                self.api.lpc10_decode_free(decoder)

        out = torch.tensor(decoded, dtype=torch.get_default_dtype()).unsqueeze(0)
        return out / 32768


# Registry of the supported codecs: maps the codec name accepted by ``Narrowband``
# to the class that is instantiated for it.
CODECS = {
"lpc10": LPC10Codec,
"mulaw": MuLawCodec,
}


@dataclass
class Narrowband(AudioTransform):
    """
    Narrowband effect.
    Resample input audio to 8000 Hz, apply codec (encode then immediately decode), then (optionally) resample back to the original sampling rate.

    :param codec: name of the codec, must be one of the keys of ``CODECS``.
    :param source_sampling_rate: sampling rate of the audio passed to ``__call__``.
    :param restore_orig_sr: when true, the output is resampled back to
        ``source_sampling_rate`` and has the same number of samples as the input.
    """

    codec: str
    source_sampling_rate: int
    restore_orig_sr: bool

    def __post_init__(self):
        check_torchaudio_version()
        import torchaudio  # noqa: F401 -- fail early when torchaudio is missing

        codec_cls = CODECS.get(self.codec)
        if codec_cls is None:
            raise ValueError(f"unsupported codec: {self.codec}")
        self.codec_instance = codec_cls()

    def __call__(self, samples: np.ndarray, sampling_rate: int) -> np.ndarray:
        """
        Apply the narrowband effect.

        :param samples: audio samples with time as the last axis.
        :param sampling_rate: not used; the rate given as ``source_sampling_rate`` applies.
        :return: the processed samples. With ``restore_orig_sr`` the last axis has the
            same length as in the input, otherwise it has the length of the 8000 Hz signal.
        """
        orig_num_samples = samples.shape[-1]
        needs_resampling = self.source_sampling_rate != 8000

        samples = torch.from_numpy(samples)

        if needs_resampling:
            resampler_down = get_or_create_resampler(self.source_sampling_rate, 8000)
            samples = resampler_down(samples)

        num_samples_8k = samples.shape[-1]
        samples = self.codec_instance(samples)
        # Frame-based codecs (LPC10) zero-pad the signal to whole frames;
        # drop that padding so the codec does not change the duration.
        samples = samples[..., :num_samples_8k]

        if self.restore_orig_sr and needs_resampling:
            resampler_up = get_or_create_resampler(8000, self.source_sampling_rate)
            samples = resampler_up(samples)

        samples = samples.numpy()

        if self.restore_orig_sr:
            # The resampling round trip can be off by a few samples. Trim or zero-pad
            # along the time axis only: ``np.resize`` would repeat the beginning of
            # the audio when padding and would flatten multi-channel audio.
            missing = orig_num_samples - samples.shape[-1]
            if missing < 0:
                samples = np.ascontiguousarray(samples[..., :orig_num_samples])
            elif missing > 0:
                pad_width = [(0, 0)] * (samples.ndim - 1) + [(0, missing)]
                samples = np.pad(samples, pad_width, mode="constant")

        return samples

    def reverse_timestamps(
        self,
        offset: Seconds,
        duration: Optional[Seconds],
        sampling_rate: Optional[int],
    ) -> Tuple[Seconds, Optional[Seconds]]:
        """
        This method just returns the original offset and duration as the narrowband effect
        doesn't change any of these audio properties.
        """

        return offset, duration


@dataclass
class Volume(AudioTransform):
"""
Expand Down
1 change: 1 addition & 0 deletions lhotse/cut/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ class Cut:
perturb_speed: Callable
perturb_tempo: Callable
perturb_volume: Callable
# Declared under the name that the cut classes actually implement (``narrowband``).
narrowband: Callable
reverb_rir: Callable
map_supervisions: Callable
merge_supervisions: Callable
Expand Down
39 changes: 39 additions & 0 deletions lhotse/cut/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,45 @@ def perturb_volume(self, factor: float, affix_id: bool = True) -> "DataCut":
supervisions=supervisions_vp,
)

def narrowband(
    self, codec: str, restore_orig_sr: bool = True, affix_id: bool = True
) -> "DataCut":
    """
    Return a new ``DataCut`` that will lazily apply narrowband effect.

    The returned cut has no features attached, because pre-computed features
    would not match the modified audio. The current cut is left unchanged.

    :param codec: Codec name.
    :param restore_orig_sr: Restore original sampling rate.
    :param affix_id: When true, we will modify the ``DataCut.id`` field
        by affixing it with "_nb_{codec}".
    :return: a modified copy of the current ``DataCut``.
    """
    # Pre-conditions
    assert (
        self.has_recording
    ), "Cannot apply narrowband effect on a DataCut without Recording."
    if self.has_features:
        logging.warning(
            "Attempting to apply narrowband effect on a DataCut that references pre-computed features. "
            "The feature manifest will be detached, as we do not support feature-domain "
            "narrowband effect."
        )
    # Actual audio perturbation.
    recording_nb = self.recording.narrowband(
        codec=codec, restore_orig_sr=restore_orig_sr, affix_id=affix_id
    )
    # Match the supervision's id (and its underlying recording id).
    supervisions_nb = [
        s.narrowband(codec=codec, affix_id=affix_id) for s in self.supervisions
    ]

    # Detach the features on the copy only, instead of mutating ``self``.
    return fastcopy(
        self,
        id=f"{self.id}_nb_{codec}" if affix_id else self.id,
        recording=recording_nb,
        features=None,
        supervisions=supervisions_nb,
    )

def normalize_loudness(
self, target: float, affix_id: bool = False, **kwargs
) -> "DataCut":
Expand Down
21 changes: 21 additions & 0 deletions lhotse/cut/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,6 +1592,27 @@ def perturb_volume(self, factor: float, affix_id: bool = True) -> "CutSet":
"""
return self.map(partial(_perturb_volume, factor=factor, affix_id=affix_id))

def narrowband(
    self, codec: str, restore_orig_sr: bool = True, affix_id: bool = True
) -> "CutSet":
    """
    Return a new :class:`~lhotse.cut.CutSet` that contains narrowband effect cuts.
    It requires the recording manifests to be present.
    If the feature manifests are attached, they are dropped.
    The supervision manifests are kept; for ``DataCut`` their ids are affixed
    together with the cut id when ``affix_id`` is true.

    :param codec: Codec name.
    :param restore_orig_sr: Restore original sampling rate.
    :param affix_id: Should we modify the ID (useful if both versions of the same
        cut are going to be present in a single manifest).
    :return: a modified copy of the ``CutSet``.
    """
    from operator import methodcaller

    # ``methodcaller`` can be pickled, unlike a lambda, so the mapped function
    # does not prevent the resulting ``CutSet`` from being pickled.
    return self.map(
        methodcaller(
            "narrowband",
            codec=codec,
            restore_orig_sr=restore_orig_sr,
            affix_id=affix_id,
        )
    )

def normalize_loudness(
self, target: float, mix_first: bool = True, affix_id: bool = True
) -> "CutSet":
Expand Down
18 changes: 18 additions & 0 deletions lhotse/supervision.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,24 @@ def perturb_volume(
else self.recording_id,
)

def narrowband(self, codec: str, affix_id: bool = True) -> "SupervisionSegment":
    """
    Return a copy of this ``SupervisionSegment`` whose ids match the narrowband recording.

    :param codec: Codec name.
    :param affix_id: When true, we will modify the ``id`` and ``recording_id`` fields
        by affixing it with "_nb_{codec}".
    :return: a modified copy of the current ``SupervisionSegment``.
    """
    if affix_id:
        new_id = f"{self.id}_nb_{codec}"
        new_recording_id = f"{self.recording_id}_nb_{codec}"
    else:
        new_id = self.id
        new_recording_id = self.recording_id

    return fastcopy(self, id=new_id, recording_id=new_recording_id)

def reverb_rir(
self, affix_id: bool = True, channel: Optional[Union[int, List[int]]] = None
) -> "SupervisionSegment":
Expand Down
9 changes: 9 additions & 0 deletions test/augmentation/test_torchaudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lhotse import MonoCut, Recording, Seconds
from lhotse.augmentation import (
AudioTransform,
Narrowband,
Resample,
ReverbWithImpulseResponse,
Speed,
Expand Down Expand Up @@ -266,3 +267,11 @@ def test_augmentation_chain_randomized(
recording=recording_aug,
)
assert cut_aug.load_audio().shape[1] == cut_aug.num_samples


def test_narrowband(mono_audio):
    """The mu-law narrowband effect with restored sampling rate keeps the audio shape."""
    transform = Narrowband(
        codec="mulaw",
        source_sampling_rate=SAMPLING_RATE,
        restore_orig_sr=True,
    )
    processed = transform(mono_audio, SAMPLING_RATE)
    assert processed.shape == mono_audio.shape

0 comments on commit 18436e9

Please sign in to comment.