PoC for reading cuts in background thread in dynamic bucketing #680

Open · wants to merge 6 commits into base: master
20 changes: 17 additions & 3 deletions lhotse/dataset/sampling/dynamic_bucketing.py
@@ -1,7 +1,9 @@
+import concurrent.futures
 import random
 import warnings
 from bisect import bisect_right
 from collections import deque
+from concurrent.futures import ThreadPoolExecutor
 from itertools import islice
 from typing import Any, Deque, Dict, Generator, Iterable, List, Optional, Tuple, Union

@@ -334,6 +336,9 @@ def __init__(
             deque() for _ in range(len(duration_bins) + 1)
         ]

+        self._cut_reading_thread = ThreadPoolExecutor(1)
Contributor:

Any reason to not use a process pool? Due to the global interpreter lock, there can be only one running thread at any given time in Python, I think.
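For background on the GIL point: only one thread executes Python bytecode at a time, but blocking IO releases the GIL, so a thread pool can still overlap IO-heavy work such as reading manifests. A minimal timing sketch (illustrative, not part of this PR):

import time
from concurrent.futures import ThreadPoolExecutor

def io_bound_task():
    # Stands in for disk/network reads; time.sleep releases the GIL.
    time.sleep(1.0)

start = time.perf_counter()
with ThreadPoolExecutor(max_workers=4) as pool:
    for f in [pool.submit(io_bound_task) for _ in range(4)]:
        f.result()
# Prints roughly 1s, not 4s: the four sleeps overlapped despite the GIL.
print(f"elapsed: {time.perf_counter() - start:.2f}s")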

Collaborator (Author):

Yes, with some setups that use IterableDatasetWrapper you are placing the sampler in a dataloader worker process, and AFAIK you can't spawn a nested process pool there because that process is daemonic.

Anyway, a thread should be sufficient here, as I expect the CPU to be mostly idle while the forward and backward passes run on the GPU... The reason it didn't work for you is likely that the thread could not populate the buckets fast enough, so the sampler thought they were depleted (a race condition). This can be solved with a proper synchronization mechanism, but unfortunately I don't have the time to add it right now. I'll return to it sometime.
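For context on the nested-pool point: PyTorch DataLoader workers are daemonic processes, and multiprocessing does not allow a daemonic process to spawn children, so a nested process pool would fail there. As for the race condition, here is a minimal sketch of the synchronization the author describes (the names executor, fill_buckets, prefetch, and wait_for_prefetch are illustrative, not Lhotse APIs): the consumer blocks on the producer's Future before inspecting the buckets, so it can never observe them half-filled.

from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional

executor = ThreadPoolExecutor(max_workers=1)
buckets = [deque(), deque()]
pending: Optional[Future] = None

def fill_buckets(n: int) -> None:
    # Producer: runs in the background thread while the GPU is busy.
    for i in range(n):
        buckets[i % len(buckets)].append(i)

def prefetch(n: int) -> None:
    global pending
    pending = executor.submit(fill_buckets, n)

def wait_for_prefetch() -> None:
    # Consumer side: Future.result() blocks until the producer finishes
    # and re-raises any exception it hit, closing the race window in which
    # the buckets look depleted merely because the reader is still running.
    global pending
    if pending is not None:
        pending.result()
        pending = None

prefetch(10)
wait_for_prefetch()  # without this wait, the buckets may appear empty
assert sum(len(b) for b in buckets) == 10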

+        self._cut_reading_future: Optional[concurrent.futures.Future] = None
+
     def __iter__(self) -> Generator[CutSet, None, None]:
         # Init: sample `buffer_size` cuts and assign them to the right buckets.
         self.cuts_iter = iter(self.cuts)
@@ -356,6 +361,7 @@ def is_ready(bucket: Deque[Cut]):
         # On each step we're sampling a new batch.
         try:
             while True:
+                self._wait_for_cut_collection()
                 ready_buckets = [b for b in self.buckets if is_ready(b)]
                 if not ready_buckets:
                     # No bucket has enough data to yield for the last full batch.
@@ -394,13 +400,21 @@ def is_ready(bucket: Deque[Cut]):
         self.cuts_iter = None

     def _collect_cuts_in_buckets(self, n_cuts: int):
-        try:
+        def collect():
             for _ in range(n_cuts):
                 cuts = next(self.cuts_iter)
                 duration = (
                     cuts[0].duration if isinstance(cuts, tuple) else cuts.duration
                 )
                 bucket_idx = bisect_right(self.duration_bins, duration)
                 self.buckets[bucket_idx].append(cuts)
-        except StopIteration:
-            pass
+
+        assert self._cut_reading_future is None
+        self._cut_reading_future = self._cut_reading_thread.submit(collect)
+
+    def _wait_for_cut_collection(self):
+        assert self._cut_reading_future is not None
+        err = self._cut_reading_future.exception()
+        if err is not None and not isinstance(err, StopIteration):
+            raise err
+        self._cut_reading_future = None
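Taken together, the patch implements a double-buffering pattern: kick off the next background read, hand the current batch to the consumer, and block on the pending Future before sampling again, re-raising anything except StopIteration. A self-contained sketch of the same pattern with illustrative names (batches, refill — not Lhotse's API):

from collections import deque
from concurrent.futures import ThreadPoolExecutor
from typing import Deque, Iterator, List

def batches(source: Iterator[int], batch_size: int) -> Iterator[List[int]]:
    buffer: Deque[int] = deque()
    pool = ThreadPoolExecutor(max_workers=1)

    def refill() -> None:
        # Background read; a StopIteration from `next` is stored in the Future.
        for _ in range(batch_size):
            buffer.append(next(source))

    future = pool.submit(refill)  # initial fill
    while True:
        err = future.exception()  # blocks until the background read finishes
        if err is not None and not isinstance(err, StopIteration):
            raise err             # real reader errors still surface
        if len(buffer) < batch_size:
            break                 # source exhausted; drop the partial batch
        batch = [buffer.popleft() for _ in range(batch_size)]
        future = pool.submit(refill)  # overlap the next read with consumption
        yield batch
    pool.shutdown()

print(list(batches(iter(range(10)), 3)))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]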