Skip to content

Commit

Permalink
Refactor bucket selection for customization (#1377)
Browse files Browse the repository at this point in the history
* Refactor bucket selection to allow customization

* Extend the API further

* Prune imports
  • Loading branch information
pzelasko authored Jul 24, 2024
1 parent bd12d5d commit 21b102c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
21 changes: 20 additions & 1 deletion lhotse/dataset/sampling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import warnings
from abc import ABCMeta, abstractmethod
from bisect import bisect_right
from copy import deepcopy
from dataclasses import asdict, dataclass
from math import isclose
Expand All @@ -15,7 +16,7 @@
from lhotse.cut.text import TextExample
from lhotse.lazy import Dillable
from lhotse.manipulation import combine
from lhotse.utils import Seconds, ifnone, is_none_or_gt
from lhotse.utils import Seconds, exactly_one_not_null, ifnone, is_none_or_gt


class CutSampler(Sampler, Dillable):
Expand Down Expand Up @@ -407,6 +408,24 @@ def measure_length(self, example: Any) -> float:
"""
pass

def select_bucket(
self, buckets: Any, example: Any = None, example_len: Any = None
) -> int:
"""
Given a list of buckets and an example, assign the example to the correct bucket.
This is leveraged by bucketing samplers.
Default implementation assumes that buckets are expressed in the same units as
the output of :meth:`SamplingConstraint.measure_length` and returns the index
of the first bucket that has a larger length than the example.
"""
assert exactly_one_not_null(
example, example_len
), f"select_bucket requires either example= or example_len= as the input (we received {example=} and {example_len=})."
if example_len is None:
example_len = self.measure_length(example)
return bisect_right(buckets, example_len)

def copy(self) -> "SamplingConstraint":
"""Return a shallow copy of this constraint."""
return copy.copy(self)
Expand Down
18 changes: 9 additions & 9 deletions lhotse/dataset/sampling/dynamic_bucketing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import threading
import time
import warnings
from bisect import bisect_right
from collections import deque
from dataclasses import asdict, dataclass
from itertools import islice
from queue import Queue
Expand Down Expand Up @@ -350,7 +348,9 @@ def add(self, example: Cut) -> None:
selecting the right property from the input ``cut`` object.
"""
seqlen = self.measure_length(example)
bucket_idx = bisect_right(self.max_seq_len_buckets, seqlen)
bucket_idx = self.select_bucket(
buckets=self.max_seq_len_buckets, example_len=seqlen
)
assert bucket_idx < len(self.max_seq_len_buckets), (
f"Received example with sequence length {seqlen} that exceeds "
f"the highest allowed length {self.max_seq_len_buckets[-1]}."
Expand Down Expand Up @@ -742,10 +742,10 @@ def producer():
time.sleep(0.1)
continue
cuts = next(self.cuts_iter)
duration = self.constraint.measure_length(
cuts[0] if isinstance(cuts, tuple) else cuts
bucket_idx = self.constraint.select_bucket(
buckets=self.duration_bins,
example=cuts[0] if isinstance(cuts, tuple) else cuts,
)
bucket_idx = bisect_right(self.duration_bins, duration)
self.buckets[bucket_idx].put(cuts)
except StopIteration:
self._source_exhausted = True
Expand All @@ -766,10 +766,10 @@ def _collect_cuts_in_buckets(self, n_cuts: int) -> None:
try:
for _ in range(n_cuts):
cuts = next(self.cuts_iter)
duration = self.constraint.measure_length(
cuts[0] if isinstance(cuts, tuple) else cuts
bucket_idx = self.constraint.select_bucket(
buckets=self.duration_bins,
example=cuts[0] if isinstance(cuts, tuple) else cuts,
)
bucket_idx = bisect_right(self.duration_bins, duration)
self.buckets[bucket_idx].put(cuts)
except StopIteration:
pass
Expand Down

0 comments on commit 21b102c

Please sign in to comment.