From 487465f67b21cc34a66fcd1ef516120cfb2b98e5 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 20 Jan 2024 03:29:31 -0800 Subject: [PATCH 01/15] Borrow some rizz from the future. --- streaming/base/coord/filesystem/__init__.py | 9 + streaming/base/coord/filesystem/waiting.py | 71 ++++++ streaming/base/coord/job/__init__.py | 9 + streaming/base/coord/job/dir.py | 49 ++++ streaming/base/coord/job/entry.py | 65 +++++ streaming/base/coord/job/file.py | 130 ++++++++++ streaming/base/coord/job/registry.py | 257 ++++++++++++++++++++ streaming/base/coord/waiting.py | 72 ++++++ 8 files changed, 662 insertions(+) create mode 100644 streaming/base/coord/filesystem/__init__.py create mode 100644 streaming/base/coord/filesystem/waiting.py create mode 100644 streaming/base/coord/job/__init__.py create mode 100644 streaming/base/coord/job/dir.py create mode 100644 streaming/base/coord/job/entry.py create mode 100644 streaming/base/coord/job/file.py create mode 100644 streaming/base/coord/job/registry.py create mode 100644 streaming/base/coord/waiting.py diff --git a/streaming/base/coord/filesystem/__init__.py b/streaming/base/coord/filesystem/__init__.py new file mode 100644 index 000000000..5febe7967 --- /dev/null +++ b/streaming/base/coord/filesystem/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Coordinating using files.""" + +from streaming.base.coord.filesystem.waiting import (create_file, wait_for_creation, + wait_for_deletion) + +__all__ = ['create_file', 'wait_for_creation', 'wait_for_deletion'] diff --git a/streaming/base/coord/filesystem/waiting.py b/streaming/base/coord/filesystem/waiting.py new file mode 100644 index 000000000..daf4a8996 --- /dev/null +++ b/streaming/base/coord/filesystem/waiting.py @@ -0,0 +1,71 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Waiting on files.""" + +import os +from typing import Any, Optional + +from streaming.base.coord.waiting import wait + +__all__ = ['wait_for_creation', 'wait_for_deletion', 'create_file'] + + +def wait_for_creation( + path: str, + timeout: Optional[float] = 30, + tick: float = 0.007, + lock: Optional[Any] = None, +) -> None: + """Wait for the creation of a path on the local filesystem. + + Args: + path (str): Local path to wait on the creation of. + timeout (float, optional): How long to wait before raising an exception, in seconds. + Defaults to ``30``. + tick (float): Check interval, in seconds. Defaults to ``0.007``. + lock (Any, optional): Context manager (this is intended for locks) to be held when + checking the predicate. Defaults to ``None``. + """ + + def stop(): + return os.path.exists(path) + + wait(stop, timeout, tick, lock) + + +def wait_for_deletion( + path: str, + timeout: Optional[float] = 30, + tick: float = 0.007, + lock: Optional[Any] = None, +) -> None: + """Wait for the deletion of a path on the local filesystem. + + Args: + path (str): Local path to wait on the deletion of. + timeout (float, optional): How long to wait before raising an exception, in seconds. + Defaults to ``30``. + tick (float): Check interval, in seconds. Defaults to ``0.007``. + lock (Any, optional): Context manager (this is intended for locks) to be held when + checking the predicate. Defaults to ``None``. + """ + + def stop(): + return not os.path.exists(path) + + wait(stop, timeout, tick, lock) + + +def create_file(filename: str) -> None: + """Create a file at the given path on the local filesystem. + + Raises an exception if the path already exists. + + Args: + filename (str): Filename to create. + """ + dirname = os.path.dirname(filename) + os.makedirs(dirname, exist_ok=True) + with open(filename, 'x'): + pass diff --git a/streaming/base/coord/job/__init__.py b/streaming/base/coord/job/__init__.py new file mode 100644 index 000000000..69207a1db --- /dev/null +++ b/streaming/base/coord/job/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Handling for jobs, which are collections of StreamingDataset replicas with the same config.""" + +from streaming.base.coord.job.dir import JobDir +from streaming.base.coord.job.registry import JobRegistry + +__all__ = ['JobDir', 'JobRegistry'] diff --git a/streaming/base/coord/job/dir.py b/streaming/base/coord/job/dir.py new file mode 100644 index 000000000..8e8a1dc2b --- /dev/null +++ b/streaming/base/coord/job/dir.py @@ -0,0 +1,49 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""A directory containing all dataset-wide filesystem state for a Streaming job.""" + +import os +from typing import Sequence + +from streaming.base.coord.job.registry import JobRegistry +from streaming.base.coord.world import World +from streaming.base.stream import Stream + +__all__ = ['JobDir'] + + +class JobDir: + """Represents a Streaming job lease. On ``__del__``, cleans up after itself. + + When it goes out of scope naturally, this Job will delete its config dir and its hold on all + the local dirs it is streaming to. + + If this process dies badly and the destructor is not reached, the same cleanup will be done by + some future process incidentally as it registers or unregisters a Streaming job. It can tell it + died by a combination of pid and process create time. + + Args: + registry (JobRegistry): Stremaing job registry. + """ + + def __init__(self, registry: JobRegistry, streams: Sequence[Stream], world: World) -> None: + self.registry = registry + self.streams = streams + self.world = world + self.job_hash = registry.register(streams, world) + + def get_filename(self, path: str) -> str: + """Get a filename by relative path under its job dir. + + Args: + path (str): Path relative to job dir. + + Returns: + str: Filename. + """ + return os.path.join(self.registry.config_root, self.job_hash, path) + + def __del__(self) -> None: + """Destructor.""" + self.registry.unregister(self.job_hash, self.world) diff --git a/streaming/base/coord/job/entry.py b/streaming/base/coord/job/entry.py new file mode 100644 index 000000000..6ebf88a6f --- /dev/null +++ b/streaming/base/coord/job/entry.py @@ -0,0 +1,65 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""An entry in a Streaming job registry file.""" + +from typing import Any, Dict, List, Optional + +from typing_extensions import Self + +__all__ = ['JobEntry'] + + +class JobEntry: + """Info about a Streaming job for local dir reuse detection purposes. + + Args: + index (int, optional): The job's index in the total list. + job_hash (str): Job hash. + stream_hashes (List[str]): Stream hashes. + stream_locals (List[str], optional): Stream locals, if available. + process_id (int): PID of local rank zero of the Streaming job. + register_time (int): Process registration time. + """ + + def __init__( + self, + *, + index: Optional[int] = None, + job_hash: str, + stream_hashes: List[str], + stream_locals: Optional[List[str]] = None, + process_id: int, + register_time: int, + ) -> None: + self.index = index + self.job_hash = job_hash + self.stream_hashes = stream_hashes + self.stream_locals = stream_locals + self.process_id = process_id + self.register_time = register_time + + @classmethod + def from_json(cls, obj: Dict[str, Any]) -> Self: + """Load from JSON. + + Args: + obj (Dict[str, Any]): Source JSON object. + + Returns: + Self: Loaded JobEntry. + """ + return cls(job_hash=obj['job_hash'], + stream_hashes=obj['stream_hashes'], + stream_locals=obj.get('stream_locals'), + process_id=obj['process_id'], + register_time=obj['register_time']) + + def to_json(self) -> Dict[str, Any]: + return { + 'job_hash': self.job_hash, + 'stream_hashes': self.stream_hashes, + # stream_locals is not saved, only their hashes. + 'process_id': self.process_id, + 'register_time': self.register_time, + } diff --git a/streaming/base/coord/job/file.py b/streaming/base/coord/job/file.py new file mode 100644 index 000000000..213394eac --- /dev/null +++ b/streaming/base/coord/job/file.py @@ -0,0 +1,130 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""A Streaming job registry file.""" + +import json +import os +from typing import Dict, List + +from typing_extensions import Self + +from streaming.base.coord.job.entry import JobEntry + +__all__ = ['RegistryFile'] + + +class RegistryFile: + """StreamingDataset job registry, which is backed by a JSON file. + + Args: + jobs (List[JobEntry]): List of StreamingDataset jobs. + """ + + def __init__(self, jobs: List[JobEntry]) -> None: + self.jobs = [] + self.job_hash2job = {} + self.stream_hash2job = {} + self.num_jobs = 0 + for job in jobs: + self.add(job) + + @classmethod + def read(cls, filename: str) -> Self: + if os.path.exists(filename): + obj = json.load(open(filename)) + else: + obj = {} + jobs = obj.get('jobs') or [] + jobs = [JobEntry.from_json(job) for job in jobs] + return cls(jobs) + + def write(self, filename: str) -> None: + jobs = [job.to_json() for job in filter(bool, self.jobs)] + obj = {'jobs': jobs} + with open(filename, 'w') as out: + json.dump(obj, out) + + def __len__(self) -> int: + """Get the number of jobs registered. + + Returns: + int: Number of registered jobs. + """ + return self.num_jobs + + def add(self, job: JobEntry) -> None: + """Register a Stremaing job. + + Args: + job (Job): The job. + """ + # Check that stream locals line up. + if job.stream_locals: + if len(job.stream_hashes) != len(job.stream_locals): + raise ValueError(f'If locals are provided, must have one local per stream hash, ' + + f'but got: {len(job.stream_hashes)} hashes vs ' + + f'{len(job.stream_locals)} locals.') + norm_stream_locals = job.stream_locals + else: + norm_stream_locals = [None] * len(job.stream_hashes) + + # Check dataset hash for reuse. + if job.job_hash in self.job_hash2job: + if job.stream_locals: + raise ValueError(f'Reused dataset local path(s): {job.stream_locals}.') + else: + raise ValueError(f'Reused dataset local path(s): stream hashes = ' + + f'{job.stream_hashes}, dataset hash = {job.job_hash}.') + + # Check each stream hash for reuse. + for stream_hash, norm_stream_local in zip(job.stream_hashes, norm_stream_locals): + if stream_hash in self.stream_hash2job: + if norm_stream_local: + raise ValueError('Reused stream local path: {norm_stream_local}.') + else: + raise ValueError('Reused stream local path: stream hash = {stream_hash}.') + + # Do the insertion. + job.index = len(self.jobs) + self.jobs.append(job) + self.job_hash2job[job.job_hash] = job + for stream_hash in job.stream_hashes: + self.stream_hash2job[stream_hash] = job + self.num_jobs += 1 + + def remove(self, job_hash: str) -> None: + """Deregister a Streaming job. + + Args: + job_hash (str): Job hash. + """ + job = self.job_hash2job.get(job_hash) + if not job: + raise ValueError(f'Job hash not found: {job_hash}.') + + if job.index is None: + raise ValueError('Internal error in job registration: job index is missing.') + + self.jobs[job.index] = None + del self.job_hash2job[job.job_hash] + for stream_hash in job.stream_hashes: + del self.stream_hash2job[stream_hash] + self.num_jobs -= 1 + + def filter(self, pid2create_time: Dict[int, int]) -> List[str]: + """Filter our collection of Streaming jobs. + + Args: + pid2create_time (Dict[int, int]): Mapping of pid to creation time. + + Returns: + List[str]: List of hashes of removed datasets. + """ + del_job_hashes = [] + for job in filter(bool, self.jobs): + create_time = pid2create_time.get(job.process_id) + if not create_time or job.register_time < create_time: + self.remove(job.job_hash) + del_job_hashes.append(job.job_hash) + return del_job_hashes diff --git a/streaming/base/coord/job/registry.py b/streaming/base/coord/job/registry.py new file mode 100644 index 000000000..446f1fac5 --- /dev/null +++ b/streaming/base/coord/job/registry.py @@ -0,0 +1,257 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""A directory containing all Streaming-wide filesystem state. + +Useful for detecting collisions between different jobs' local dirs. +""" + +import os +from hashlib import sha3_224 +from shutil import rmtree +from time import time_ns +from typing import Dict, List, Optional, Sequence, Tuple + +from filelock import FileLock +from psutil import process_iter + +from streaming.base.coord.filesystem.waiting import wait_for_creation, wait_for_deletion +from streaming.base.coord.job.entry import JobEntry +from streaming.base.coord.job.file import RegistryFile +from streaming.base.coord.world import World +from streaming.base.stream import Stream + +__all__ = ['JobRegistry'] + + +class JobRegistry: + """StreamingDataset job registry, for the purpose of detecting local dir reuse. + + This class is safe for concurrent access via a filelock. + + Args: + config_root (str): Streaming configuration root directory, used for collision detection, + filelock paths, etc. Defaults to ``/tmp/streaming``, using the equivalent temp root on + your system. + timeout (float, optional): How long to wait before raising an exception, in seconds. + Defaults to ``30``. + tick (float): Check interval, in seconds. Defaults to ``0.007``. + """ + + def __init__( + self, + config_root: str, + timeout: Optional[float] = 30, + tick: float = 0.007, + ) -> None: + self.config_root = config_root + self.timeout = timeout + self.tick = tick + + self.lock_filename = os.path.join(config_root, 'registry.lock') + self.lock = FileLock(self.lock_filename) + + self.registry_filename = os.path.join(config_root, 'registry.json') + + def _get_live_procs(self) -> Dict[int, int]: + """List the pids and creation times of every live process in the system. + + The creation times protect us from PID reuse. + + Returns: + Dict[int, int]: Mapping of pid to integer creation time. + """ + ret = {} + for obj in process_iter(['pid', 'create_time']): + ret[obj.pid] = int(obj.create_time() * 1e9) + return ret + + def _hash(self, data: bytes) -> str: + """Get a short, deterministic, fixed-length code for the given data. + + Args: + data (bytes): The data to hash. + + Returns: + str: Truncated hex digest. + """ + return sha3_224(data).hexdigest()[:8] + + def _hash_streams(self, streams: Sequence[Stream]) -> Tuple[List[str], List[str], str]: + """Get a short, opaque str key for a StreamingDataset and each of its Streams. + + This is useful for collision detection. + + Args: + streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in + combination with process IDs and creation times lets us uniquely identify a + Streaming job. + + Returns: + Tuple[str, List[str], List[str]]: Triple of (normalized stream locals, stream hashes, + and dataset hash). + """ + # Get a list of the normalized locals of each Stream. + stream_locals = [] + for stream in streams: + local = os.path.join(stream.local, stream.split) + local = os.path.normpath(local) + local = os.path.abspath(local) + stream_locals.append(local) + + # Collect the locals into a deduped set. + stream_locals_set = set() + for local in stream_locals: + if local in stream_locals_set: + raise ValueError(f'Reused local path: {local}.') + stream_locals_set.add(local) + + # Verify that no local is contained within another local. + for local in stream_locals: + parts = local.split(os.path.sep)[1:] + for num_parts in range(1, len(parts) - 1): # Leftmost is '' because they start with /. + parent = os.path.sep.join(parts[:num_parts]) + if parent in stream_locals_set: + raise ValueError(f'One local path contains another local path: {parent} vs ' + + f'{local}.') + + # Hash each local. + stream_hashes = [] + for local in sorted(stream_locals): + data = local.encode('utf-8') + stream_hash = self._hash(data) + stream_hashes.append(stream_hash) + + # Hash the dataset. + text = ','.join(stream_hashes) + data = text.encode('utf-8') + job_hash = self._hash(data) + + return stream_locals, stream_hashes, job_hash + + def _make_dir(self, job_hash: str) -> None: + """Create a Streaming job config dir. + + Args: + job_hash: Streaming config subdir for this job. + """ + dirname = os.path.join(self.config_root, job_hash) + os.makedirs(dirname) + + def _remove_dir(self, job_hash: str) -> None: + """Delete a Streaming job config dir. + + Args: + job_hash: Streaming config subdir for this job. + """ + dirname = os.path.join(self.config_root, job_hash) + rmtree(dirname) + + def _register(self, streams: Sequence[Stream]) -> str: + """Register this collection of StreamingDataset replicas. + + Called by the local leader. + + Args: + streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in + combination with process IDs and creation times lets us uniquely identify a + Streaming job. + + Returns: + str: Streaming config subdir for this job. + """ + register_time = time_ns() + pid2create_time = self._get_live_procs() + pid = os.getpid() + create_time = pid2create_time.get(pid) + if create_time is None: + raise RuntimeError('`psutil` thinks we are dead, and yet here we are: pid = {pid}.') + + stream_locals, stream_hashes, job_hash = self._hash_streams(streams) + + entry = JobEntry(job_hash=job_hash, + stream_hashes=stream_hashes, + stream_locals=stream_locals, + process_id=pid, + register_time=register_time) + + with self.lock: + conf = RegistryFile.read(self.registry_filename) + conf.add(entry) + del_job_hashes = conf.filter(pid2create_time) + conf.write(self.registry_filename) + map(self._remove_dir, del_job_hashes) + self._make_dir(job_hash) + + return job_hash + + def _lookup(self, streams: Sequence[Stream]) -> str: + """Look up this collection of StreamingDataset replicas. + + Called by the local leader. + + Args: + streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in + combination with process IDs and creation times lets us uniquely identify a + Streaming job. + + Returns: + str: Streaming config subdir for this job. + """ + _, _, job_hash = self._hash_streams(streams) + return job_hash + + def register(self, streams: Sequence[Stream], world: World) -> str: + """Register or look up this collection of StreamingDataset replicas. + + Called by all ranks. + + Args: + streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in + combination with process IDs and creation times lets us uniquely identify a + Streaming job. + world (World): Rank-wise world state. + + Returns: + str: Subdir for this collection of StreamingDataset replicas. + """ + if world.is_local_leader: + job_hash = self._register(streams) + else: + job_hash = self._lookup(streams) + dirname = os.path.join(self.config_root, job_hash) + wait_for_creation(dirname, self.timeout, self.tick, self.lock) + return job_hash + + def _unregister(self, job_hash: str) -> None: + """Unregister this collection of StreamingDataset replicas. + + Called by the local leader. + + Args: + job_hash (str): Subdir identifying this Streaming job. + """ + pid2create_time = self._get_live_procs() + + with self.lock: + conf = RegistryFile.read(self.registry_filename) + conf.remove(job_hash) + del_job_hashes = conf.filter(pid2create_time) + conf.write(self.registry_filename) + map(self._remove_dir, del_job_hashes) + self._remove_dir(job_hash) + + def unregister(self, job_hash: str, world: World) -> None: + """Unregister this collection of StreamingDataset replicas. + + Called by all ranks. + + Args: + job_hash (str): Subdir identifying this Streaming job. + world (World): Rank-wise world state. + """ + if world.is_local_leader: + self._unregister(job_hash) + else: + dirname = os.path.join(self.config_root, job_hash) + wait_for_deletion(dirname, self.timeout, self.tick, self.lock) diff --git a/streaming/base/coord/waiting.py b/streaming/base/coord/waiting.py new file mode 100644 index 000000000..92a640630 --- /dev/null +++ b/streaming/base/coord/waiting.py @@ -0,0 +1,72 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Waiting on predicates.""" + +from contextlib import nullcontext +from time import sleep, time +from typing import Any, Callable, Optional + +__all__ = ['wait'] + + +def _say_duration(duration: float) -> str: + """Pretty-print a duration. + + Args: + duration (float): The duration as a float. + + Returns: + str: The duration as a str. + """ + return f'{duration:.3f}'.rstrip('0').rstrip('.') + + +def wait( + stop: Callable[[], bool], + timeout: Optional[float] = 30, + tick: float = 0.007, + lock: Optional[Any] = None, +) -> None: + """Wait for the predicate to succeed. + + Args: + stop (Callable[[], bool]): When this check returns True, you break out of the retry loop. + timeout (float, optional): How long to wait before raising an exception, in seconds. + Defaults to ``30``. + tick (float): Check interval, in seconds. Defaults to ``0.007``. + lock (Any, optional): Context manager (this is intended for locks) to be held when + checking the predicate. Defaults to ``None``. + """ + start = time() + + if timeout is not None and timeout <= 0: + raise ValueError(f'Timeout must be positive if provided, but got: ' + + f'{_say_duration(timeout)} sec.') + + if tick <= 0: + raise ValueError(f'Tick must be positive if provided, but got: {_say_duration(tick)} sec.') + + if lock is not None: + if not hasattr(lock, '__enter__'): + raise ValueError(f'Lock must support `__enter__`, but got: {lock}.') + + if not hasattr(lock, '__exit__'): + raise ValueError(f'Lock must support `__exit__`, but got: {lock}.') + + norm_lock = lock + else: + norm_lock = nullcontext() + + while True: + with norm_lock: + if stop(): + break + + if timeout is not None: + now = time() + if timeout <= now - start: + raise RuntimeError(f'Wait timed out: timeout {_say_duration(timeout)} sec vs ' + + f'elapsed {_say_duration(now - start)} sec.') + + sleep(tick) From 1f4272dbb7fb9982ff179f911ccd3d9e6f5ce97e Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 20 Jan 2024 07:12:31 -0800 Subject: [PATCH 02/15] Farizzle (_pregen_epoch, _gen_epoch). --- streaming/base/dataset.py | 289 +++++++++++++++++++++++++++++++++++--- 1 file changed, 273 insertions(+), 16 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index e9f5b96d6..fbc45195d 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -12,8 +12,10 @@ from concurrent.futures._base import Future from enum import IntEnum from math import ceil +from multiprocessing import Process +from tempfile import gettempdir from threading import Event, Lock -from time import sleep, time_ns +from time import sleep, time, time_ns from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union import numpy as np @@ -27,6 +29,8 @@ from streaming.base.constant import (BARRIER, BARRIER_FILELOCK, CACHE_FILELOCK, CACHE_USAGE, EPOCH_DATA, EPOCH_SHAPE, NEXT_EPOCH, RESUME, SHARD_ACCESS_TIMES, SHARD_STATES, TICK) +from streaming.base.coord.job.dir import JobDir +from streaming.base.coord.job.registry import JobRegistry from streaming.base.distributed import maybe_init_dist from streaming.base.format import get_index_basename from streaming.base.sampling import get_sampling @@ -187,8 +191,13 @@ class StreamingDataset(Array, IterableDataset): * What to iterate: + * Dataset/job registry: + + * ``config_root`` + * One or more streams (you must provide either ``streams`` or ``remote``/``local``): + * ``epoch_size`` * ``streams`` * ``remote`` * ``local`` @@ -202,11 +211,16 @@ class StreamingDataset(Array, IterableDataset): * ``validate_hash`` * ``keep_zip`` - * Absolute dataset size, if streams were weighted relatively: + * How to iterate: - * ``epoch_size`` + * Epoch pre-generation: - * How to iterate: + * ``init_pregen_epoch`` + * ``inti_pregen_sample`` + * ``pregen_next_epoch`` + * ``pregen_epoch_timeout`` + * ``pregen_epoch_tick`` + * ``num_workers`` * Shard lifecycle: @@ -237,6 +251,14 @@ class StreamingDataset(Array, IterableDataset): Args: + config_root (str, optional): Streaming configuration root directory, used for collision + detection, filelock paths, etc. If ``None``, uses a ``/streaming/`` subdir under your + system's temp root. Defaults to ``None``. + epoch_size (int | str, optional): Number of samples to draw per epoch balanced + across all streams. If ``None``, takes its value from the total number of underlying + samples. Provide this field if you are weighting streams relatively to target a larger + or smaller epoch size. Defaults to ``None``. Can also take in human-readable number + abbreviations (e.g., ``"100k"``, ``"64M"``, ``"77b"``, etc). Defaults to ``None``. streams (Sequence[Stream], optional): One or more streams to stream/cache samples from, which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. @@ -256,17 +278,28 @@ class StreamingDataset(Array, IterableDataset): keep_zip (bool): Whether to keep or delete the compressed form when decompressing downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to ``False``. - epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced - across all streams. If ``None``, takes its value from the total number of underlying - samples. Provide this field if you are weighting streams relatively to target a larger - or smaller epoch size. Defaults to ``None``. Can also take in human-readable number - abbreviations (e.g., ``"100k"``, ``"64M"``, ``"77b"``, etc). Defaults to ``None``. + init_pregen_epoch (int, optional): What epoch to pre-generate in the background at init + time, if any. This is useful if you do a lot of work between instantiating your + StreamingDataset and iterating it. Defaults to ``None``. + init_pregen_sample (int, optional): What sample offset into the epoch to pre-generate with + in the background at init time. If ``init_pregen_epoch`` is not set, must not be set + either. Defaults to ``None``. + pregen_next_epoch (bool): Whether to pre-generate the next epoch in the background at the + start of iter after generating or loading the current about-to-be-iterated epoch. + Defaults to ``True``. + pregen_epoch_timeout (float, optional): Timeout when waiting on this epoch to be + pre-generated. Defaults to ``float(np.arange(1, 7).prod())``, i.e. 12 minutes. + pregen_epoch_tick (float): Polling interval when waiting on this epoch to be pre-generated. + Defaults to ``0xCAFE / 1337 / 42``, i.e. about 925ms. + num_workers (int, optional): Number of workers per rank, same as PyTorch DataLoader + ``num_workers``. Required iff you are pre-generating an epoch at init time, otherwise + this information is determined automatically elsewhere. Defaults to ``None``. predownload (int, optional): Target number of samples to download per worker in advance of current sample. Workers will attempt to download ahead by this many samples during, but not before, training. Recommendation is to provide a value greater than per device batch size to ensure at-least per device batch size number of samples cached locally. If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``. - cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's + cache_limit (int | str, optional): Maximum size in bytes of this StreamingDataset's shard cache. Before downloading a shard, the least recently used resident shard(s) may be evicted (deleted from the local cache) in order to stay under the limit. Set to ``None`` to disable shard eviction. Supports integer bytes as well as string @@ -310,6 +343,8 @@ class StreamingDataset(Array, IterableDataset): def __init__(self, *, + config_root: Optional[str] = None, + epoch_size: Optional[Union[int, str]] = None, streams: Optional[Sequence[Stream]] = None, remote: Optional[str] = None, local: Optional[str] = None, @@ -318,7 +353,12 @@ def __init__(self, download_timeout: float = 60, validate_hash: Optional[str] = None, keep_zip: bool = False, - epoch_size: Optional[Union[int, str]] = None, + init_pregen_epoch: Optional[int] = None, + init_pregen_sample: Optional[int] = None, + pregen_next_epoch: bool = True, + pregen_epoch_timeout: Optional[float] = float(np.arange(1, 7).prod()), + pregen_epoch_tick: float = 0xCAFE / 1337 / 42, + num_workers: Optional[int] = None, predownload: Optional[int] = None, cache_limit: Optional[Union[int, str]] = None, sampling_method: str = 'balanced', @@ -505,7 +545,53 @@ def __init__(self, # Length (__len__) is the resampled epoch size divided over the number of devices. self.length = ceil(self.epoch_size / world.num_ranks) - # Register/lookup our shared memory prefix and filelock root directory. + # Args about pre-generating epochs. + if init_pregen_epoch is not None: + if init_pregen_epoch < 0: + raise ValueError(f'Init pregen epoch must be non-negative, but got: ' + + f'{init_pregen_epoch}.') + self.init_pregen_epoch = init_pregen_epoch + + if init_pregen_sample is not None: + if not (0 <= init_pregen_sample <= self.num_samples): + raise ValueError(f'Init pregen sample must be from 0 to {self.num_samples}, but ' + + f'got: {init_pregen_sample}.') + if init_pregen_epoch is not None: + self.init_pregen_sample = init_pregen_sample or 0 + else: + if init_pregen_sample is None: + raise ValueError(f'Init pregen epoch is not set, but init pregen sample is: ' + + f'epoch {init_pregen_epoch}, sample {init_pregen_sample}.') + self.init_pregen_sample = init_pregen_sample + + self.pregen_next_epoch = pregen_next_epoch + + if pregen_epoch_timeout is not None and pregen_epoch_timeout < 0: + raise ValueError(f'Pregen epoch timeout must be non-negative if set, but got: ' + + f'{pregen_epoch_timeout}.') + self.pregen_epoch_timeout = pregen_epoch_timeout + + if pregen_epoch_tick <= 0: + raise ValueError(f'Pregen epoch tick must be positive seconds, but got: ' + + f'{pregen_epoch_tick}.') + self.pregen_epoch_tick = pregen_epoch_tick + + self.num_workers = num_workers + + # Init registry, then register/lookup this Streaming job (new style). + self.config_root = self._get_config_root(config_root) + self._test_config_root(self.config_root) + self.registry = JobRegistry(self.config_root, 42, 0.007) + self.job = JobDir(self.registry, streams, world) + + # Maybe pre-generate some epoch. + if init_pregen_epoch is not None: + process = Process(target=self._pregen_epoch, + args=(self.init_pregen_epoch, self.init_pregen_sample), + daemon=True) + process.start() + + # Register/lookup our shared memory prefix and filelock root directory (old style). streams_local = [os.path.abspath(os.path.join(x.local, x.split)) for x in streams] streams_remote = [ os.path.join(x.remote, x.split) if x.remote is not None else None for x in streams @@ -596,6 +682,41 @@ def __del__(self) -> None: except: pass + @classmethod + def _test_config_root(cls, config_root: str) -> None: + """Validate that the provided config root is usable. + + If you are unable to get root or 777 perms, you may encounter problems in registering your + Streaming jobs for collision detection, getting unique interprocess filelock paths, etc. + You can sort of get around this by changing config root to a directory you control, but + this may negatively impact collision detection. + + Args: + config_root (str): Streaming configuration root directory. + """ + os.makedirs(config_root, exist_ok=True) + filename = os.path.join(config_root, 'test.txt') + try: + with open(filename, 'wb') as out: + out.write(b'') + except: + raise ValueError('Please provide a `config_root` dir that is writeable and readable.') + os.remove(filename) + + @classmethod + def _get_config_root(cls, config_root: Optional[str] = None) -> str: + """Get the Streaming configuration root directory. + + Args: + config_root (str, optional): Config root, if explicitly provided. Defaults to ``None``. + + Returns: + str: Streaming configuration root directory. + """ + if config_root is None: + config_root = os.path.join(gettempdir(), 'streaming') + return config_root + @property def size(self) -> int: """Get the size of the dataset in samples. @@ -941,7 +1062,144 @@ def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, SharedMemory]: return sample_ids, shape_shm, data_shm - def _get_work(self, world: World, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]: + def _locate_epoch_work(self, epoch: int, sample: int) -> str: + """Get the filename for generated epoch work given its epoch and sample offset. + + Args: + epoch (int): Which epoch. + sample (int): What sample offset. + + Returns: + str: Filename of serialized epoch work. + """ + return self.job.get_filename(f'epoch.{epoch:09}.{sample:012}.npy') + + def _serialize_epoch_work(self, work: NDArray[np.int64]) -> bytes: + """Serialize a 5-dimensional sample ID arrangement tensor to bytes. + + Args: + work (NDArray[np.int64]): Sample IDs tensor. + + Returns: + bytes: The serialized data. + """ + # Serialize to bytes prefixed with shape (we use int64 for alignment reasons). + return b''.join([ + np.int64(work.ndim).tobytes(), + np.array(work.shape, np.int64).tobytes(), + work.tobytes(), + ]) + + def _deserialize_epoch_work(self, data: bytes) -> NDArray[np.int64]: + """Deserialize a 5-dimensional sample ID arrangement tensor from bytes. + + Args: + data (bytes): The serialized data. + + Returns: + NDArray[np.int64]: Sample IDs tensor. + """ + arr = np.ndarray(shape=-1, dtype=np.int64, buffer=data) + ndim = arr[0] + shape = tuple(arr[1:1 + ndim].tolist()) + offset = (1 + ndim) * np.int64().nbytes + return np.ndarray(shape, np.int64, arr, offset) + + def _pregen_epoch(self, epoch: int, sample: int) -> None: + """Pre-generate the sample ID arrangement for some epoch. + + This is typically run in the background in a daemon process. + + Args: + epoch (int): Which epoch. + sample (int): What sample offset. + """ + if self.num_workers is None: + raise ValueError(f'You must provide DataLoader num_workers to StreamingDataset in ' + + f'order for it to be able to pre-generate the epoch at init time.') + + # Locate epoch data, e.g. "epoch.000000007.000000001000.npy". + filename = self._locate_epoch_work(epoch, sample) + + # If there is already a file there, either someone has pre-generated it already (non-empty) + # or they are in the process of pre-generating it (empty) and we are done. If no file, + # create one to claim it ourself. + try: + with open(filename, 'xb'): + pass + except: + return + + # Create the world a worker will see. + world = World() + if 1 < self.num_workers: + world.workers_per_rank = self.num_workers + world.num_workers = world.num_ranks * world.workers_per_rank + world.workers_per_node = world.ranks_per_node * world.workers_per_rank + + # Do the epoch generation (heavy). + work = generate_work(self.batching_method, self, world, epoch, sample) + + # Serialize to bytes. + data = self._serialize_epoch_work(work) + + # Write those bytes, to be picked up by the main process/thread. + tmp_filename = filename + '.tmp' + with open(tmp_filename, 'wb') as out: + out.write(data) + os.rename(tmp_filename, filename) + + def _gen_epoch(self, world: World, epoch: int, sample: int) -> NDArray[np.int64]: + """Generate (or load pre-generated) the sample ID arrangement for some epoch. + + Args: + world (World): The world dimensions to generate it for. + epoch (int): Which epoch. + sample (int): What sample offset. + + Returns: + NDArray[np.int64]: 5-dim sample IDs tensor. + """ + # Get where our pre-generated epoch data would be found, if it exists. + filename = self._locate_epoch_work(epoch, sample) + + # If the file is taken, it either is populated or will be soon. If not, we have to generate + # the epoch ourself. + if os.path.exists(filename): + # Wait for the file to become populated. + then = time() + while True: + # If it's populated, break out. + stat = os.stat(filename) + if stat.st_size: + break + + # If it's not yet populated, you then check how much time we've taken. + now = time() + elapsed = now - then + if self.pregen_epoch_timeout is not None and self.pregen_epoch_timeout < elapsed: + raise ValueError(f'Timed out while waiting on epoch pre-generation: epoch ' + + f'{epoch}, sample {sample}, timeout ' + + f'{self.pregen_epoch_timeout}, elapsed {elapsed}.') + + # If we're still waiting, sleep a bit. + sleep(self.pregen_epoch_tick) + + # Deserialize the populated file. + data = open(filename, 'rb').read() + work = self._deserialize_epoch_work(data) + else: + # Generate the epoch ourself. + work = generate_work(self.batching_method, self, world, epoch, sample) + + # Maybe pre-generate the next epoch in the background. + if self.pregen_next_epoch: + process = Process(target=self._pregen_epoch, args=(epoch + 1, 0), daemon=True) + process.start() + + return work + + def _get_epoch(self, world: World, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]: """Get this worker's partition of this epoch's sample space. Args: @@ -959,8 +1217,7 @@ def _get_work(self, world: World, epoch: int, sample_in_epoch: int) -> NDArray[n # Do expensive work that may use a lot of cores/memory just once, in the local leader. if world.is_local_leader: - epoch_sample_ids = generate_work(self.batching_method, self, world, epoch, - sample_in_epoch) + epoch_sample_ids = self._gen_epoch(world, epoch, sample_in_epoch) shape_shm, data_shm = self._share_work(epoch_sample_ids) self._shared_barrier(world.workers_per_node) else: @@ -1417,7 +1674,7 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: epoch, sample_in_epoch = self._resume_incr_epoch(world) # Get this worker's partition of samples to process. - sample_ids = self._get_work(world, epoch, sample_in_epoch) + sample_ids = self._get_epoch(world, epoch, sample_in_epoch) if not len(sample_ids): # Resumed at end of epoch, out of samples. return From 7e7c2b72f055f2fbe8ed2baee255b4cc4d43aad5 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 20 Jan 2024 07:17:24 -0800 Subject: [PATCH 03/15] Womp womp. --- streaming/base/coord/__init__,py | 0 streaming/base/coord/job/dir.py | 2 +- streaming/base/coord/job/registry.py | 2 +- streaming/base/dataset.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 streaming/base/coord/__init__,py diff --git a/streaming/base/coord/__init__,py b/streaming/base/coord/__init__,py new file mode 100644 index 000000000..e69de29bb diff --git a/streaming/base/coord/job/dir.py b/streaming/base/coord/job/dir.py index 8e8a1dc2b..77c2a620e 100644 --- a/streaming/base/coord/job/dir.py +++ b/streaming/base/coord/job/dir.py @@ -7,8 +7,8 @@ from typing import Sequence from streaming.base.coord.job.registry import JobRegistry -from streaming.base.coord.world import World from streaming.base.stream import Stream +from streaming.base.world import World __all__ = ['JobDir'] diff --git a/streaming/base/coord/job/registry.py b/streaming/base/coord/job/registry.py index 446f1fac5..90372a9de 100644 --- a/streaming/base/coord/job/registry.py +++ b/streaming/base/coord/job/registry.py @@ -18,8 +18,8 @@ from streaming.base.coord.filesystem.waiting import wait_for_creation, wait_for_deletion from streaming.base.coord.job.entry import JobEntry from streaming.base.coord.job.file import RegistryFile -from streaming.base.coord.world import World from streaming.base.stream import Stream +from streaming.base.world import World __all__ = ['JobRegistry'] diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index fbc45195d..9dcf24604 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -559,7 +559,7 @@ def __init__(self, if init_pregen_epoch is not None: self.init_pregen_sample = init_pregen_sample or 0 else: - if init_pregen_sample is None: + if init_pregen_sample is not None: raise ValueError(f'Init pregen epoch is not set, but init pregen sample is: ' + f'epoch {init_pregen_epoch}, sample {init_pregen_sample}.') self.init_pregen_sample = init_pregen_sample From 708f236b238a1fda35299a09f1368fbdf6f0093c Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 20 Jan 2024 07:39:51 -0800 Subject: [PATCH 04/15] Add psutil. --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index c70a84143..de2116c9e 100644 --- a/setup.py +++ b/setup.py @@ -58,6 +58,7 @@ 'azure-storage-blob>=12.0.0,<13', 'azure-storage-file-datalake>=12.11.0,<13', 'azure-identity>=1.13.0', + 'psutil==5.9.4', ] extra_deps = {} From c4b35abd291832375df7c9d100cac110a015ee1b Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 21 Jan 2024 01:40:22 -0800 Subject: [PATCH 05/15] Temp disable. --- streaming/base/dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 9dcf24604..bc5c64fb2 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -1193,9 +1193,10 @@ def _gen_epoch(self, world: World, epoch: int, sample: int) -> NDArray[np.int64] work = generate_work(self.batching_method, self, world, epoch, sample) # Maybe pre-generate the next epoch in the background. - if self.pregen_next_epoch: - process = Process(target=self._pregen_epoch, args=(epoch + 1, 0), daemon=True) - process.start() + # TODO: re-enable: + # if self.pregen_next_epoch: + # process = Process(target=self._pregen_epoch, args=(epoch + 1, 0), daemon=True) + # process.start() return work From f6e327cac792d40f36b916df2d25e0069933244a Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 21 Jan 2024 02:03:28 -0800 Subject: [PATCH 06/15] Update tests. --- tests/test_streaming.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 7ef98dfec..6928a3c83 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -782,6 +782,7 @@ def test_streamingdataloader_mid_epoch_resumption(local_remote_dir: Any, batch_s sample_order.extend(batch['id'][:]) del dataloader + del dataset.job # TODO: Why do we need this hack? del dataset clean_stale_shared_memory() @@ -861,6 +862,9 @@ def test_multiple_dataset_instantiation(local_remote_dir: Any, shuffle_seed: tup assert len(set(train_sample_order)) == len(set(val_sample_order)), 'Duplicate samples' +@pytest.mark.skip('We could be resuming with shard files not all be in their final phases, so ' + + 'the directory could still change on the fly even if there is no remote, so ' + + 'we cannot reuse local even in this case.') def test_same_local_no_remote(local_remote_dir: Tuple[str, str]): local_0, _ = local_remote_dir convert_to_mds(out_root=local_0, @@ -893,5 +897,5 @@ def test_same_local_diff_remote(local_remote_dir: Tuple[str, str]): # Build StreamingDataset _ = StreamingDataset(local=local_0, remote=remote_0, batch_size=4, num_canonical_nodes=1) # Build StreamingDataset - with pytest.raises(ValueError, match='Reused local directory.*vs.*Provide a different one.'): + with pytest.raises(ValueError): _ = StreamingDataset(local=local_0, remote=remote_1, batch_size=2, num_canonical_nodes=1) From 5607931fe916222c67f63e0e6aeb326d37791676 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 21 Jan 2024 21:21:31 -0800 Subject: [PATCH 07/15] Updates (need docstrings). def _push_back_pregen_epoch_todo(self, todo_filename: str, epoch: int, sample: int) -> None: def _pop_front_pregen_epoch_todo(self, todo_filename: str) -> Tuple[int, int, int]: def _request_pregen_epoch(self, epoch: int, sample: int) -> None: def _each_pregen_epoch_todo(self) -> Iterator[Tuple[int, int]]: def _pregen_epoch_loop(self) -> None: --- streaming/base/dataset.py | 82 +++++++++++++++++++++++++++++++++------ 1 file changed, 70 insertions(+), 12 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index bc5c64fb2..8fd0b7b29 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -341,6 +341,9 @@ class StreamingDataset(Array, IterableDataset): if ``False``. Defaults to ``False``. """ + pregen_todos_lock_path = 'pregen_todos.lock' + pregen_todos_path = 'pregen_todos.npy' + def __init__(self, *, config_root: Optional[str] = None, @@ -584,12 +587,14 @@ def __init__(self, self.registry = JobRegistry(self.config_root, 42, 0.007) self.job = JobDir(self.registry, streams, world) - # Maybe pre-generate some epoch. - if init_pregen_epoch is not None: - process = Process(target=self._pregen_epoch, - args=(self.init_pregen_epoch, self.init_pregen_sample), - daemon=True) - process.start() + # Maybe note some epoch to pre-generate (like epoch 0, sample offset 0)? + if self.init_pregen_epoch is not None: + self._request_pregen_epoch(self.init_pregen_epoch, self.init_pregen_sample or 0) + + # Start the epoch pre-generation loop as a daemon process. + if init_pregen_epoch is not None or pregen_next_epoch: + self.process = Process(target=self._pregen_epoch_loop, daemon=True) + self.process.run() # Register/lookup our shared memory prefix and filelock root directory (old style). streams_local = [os.path.abspath(os.path.join(x.local, x.split)) for x in streams] @@ -675,7 +680,12 @@ def __init__(self, del self._shared_barrier.lock # Remote the lock that makes it unpickleable. def __del__(self) -> None: - """Destructor, which releases its local working directories.""" + """Destructor,kill which releases its local working directories.""" + try: + self.process.kill() + except: + pass + if hasattr(self, '_locals_shm'): try: self._locals_shm.buf[:4] = np.int32(0).tobytes() @@ -1072,7 +1082,7 @@ def _locate_epoch_work(self, epoch: int, sample: int) -> str: Returns: str: Filename of serialized epoch work. """ - return self.job.get_filename(f'epoch.{epoch:09}.{sample:012}.npy') + return self.job.get_filename(f'epoch.{epoch:06}.{sample:012}.npy') def _serialize_epoch_work(self, work: NDArray[np.int64]) -> bytes: """Serialize a 5-dimensional sample ID arrangement tensor to bytes. @@ -1149,6 +1159,49 @@ def _pregen_epoch(self, epoch: int, sample: int) -> None: out.write(data) os.rename(tmp_filename, filename) + def _push_back_pregen_epoch_todo(self, todo_filename: str, epoch: int, sample: int) -> None: + now = time_ns() + push_back = np.array([epoch, sample, now], np.int64) + if os.path.exists(todo_filename): + old = np.fromfile(todo_filename, np.int64) + old = old.reshape(-1, 3) + new = np.concatenate([old, push_back], 0) + else: + new = push_back + new.tofile(todo_filename) + + def _pop_front_pregen_epoch_todo(self, todo_filename: str) -> Tuple[int, int, int]: + old = np.fromfile(todo_filename, np.int64) + old = old.reshape(-1, 3) + pop_front = old[0] + new = old[1:] + if len(new): + new.tofile(todo_filename) + else: + os.remove(todo_filename) + return tuple(pop_front.tolist()) + + def _request_pregen_epoch(self, epoch: int, sample: int) -> None: + lock_filename = self.job.get_filename(self.pregen_todos_lock_path) + todo_filename = self.job.get_filename(self.pregen_todos_path) + with FileLock(lock_filename): + self._push_back_pregen_epoch_todo(todo_filename, epoch, sample) + + def _each_pregen_epoch_todo(self) -> Iterator[Tuple[int, int]]: + lock_filename = self.job.get_filename(self.pregen_todos_lock_path) + todo_filename = self.job.get_filename(self.pregen_todos_path) + lock = FileLock(lock_filename) + while True: + with lock: + if os.path.exists(todo_filename): + epoch, sample, _ = self._pop_front_pregen_epoch_todo(todo_filename) + yield epoch, sample + sleep(0.777) + + def _pregen_epoch_loop(self) -> None: + for epoch, sample in self._each_pregen_epoch_todo(): + self._pregen_epoch(epoch, sample) + def _gen_epoch(self, world: World, epoch: int, sample: int) -> NDArray[np.int64]: """Generate (or load pre-generated) the sample ID arrangement for some epoch. @@ -1189,14 +1242,19 @@ def _gen_epoch(self, world: World, epoch: int, sample: int) -> NDArray[np.int64] data = open(filename, 'rb').read() work = self._deserialize_epoch_work(data) else: + # Claim the epoch generation work, preventing the epoch pregen process from doing it. + try: + with open(filename, 'xb'): + pass + except: + pass + # Generate the epoch ourself. work = generate_work(self.batching_method, self, world, epoch, sample) # Maybe pre-generate the next epoch in the background. - # TODO: re-enable: - # if self.pregen_next_epoch: - # process = Process(target=self._pregen_epoch, args=(epoch + 1, 0), daemon=True) - # process.start() + if self.pregen_next_epoch: + self._request_pregen_epoch(epoch + 1, 0) return work From ce5329830f8794d0d0ad964e307190d93b0add09 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 21 Jan 2024 23:44:10 -0800 Subject: [PATCH 08/15] Blame the dummy. --- streaming/base/dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 8fd0b7b29..d70e9647a 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -679,13 +679,11 @@ def __init__(self, del self._shared_barrier.lock # Remote the lock that makes it unpickleable. + self._dummy = None + def __del__(self) -> None: """Destructor,kill which releases its local working directories.""" - try: - self.process.kill() - except: - pass - + del self._dummy if hasattr(self, '_locals_shm'): try: self._locals_shm.buf[:4] = np.int32(0).tobytes() @@ -1196,7 +1194,9 @@ def _each_pregen_epoch_todo(self) -> Iterator[Tuple[int, int]]: if os.path.exists(todo_filename): epoch, sample, _ = self._pop_front_pregen_epoch_todo(todo_filename) yield epoch, sample - sleep(0.777) + if not hasattr(self, '_dummy'): + break + sleep(0.1337) def _pregen_epoch_loop(self) -> None: for epoch, sample in self._each_pregen_epoch_todo(): From f905849667d4959567b41229614dbc25eca2d2bb Mon Sep 17 00:00:00 2001 From: James Knighton Date: Mon, 22 Jan 2024 01:32:27 -0800 Subject: [PATCH 09/15] Fix. --- streaming/base/dataset.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index d70e9647a..33e6917a0 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -1159,25 +1159,26 @@ def _pregen_epoch(self, epoch: int, sample: int) -> None: def _push_back_pregen_epoch_todo(self, todo_filename: str, epoch: int, sample: int) -> None: now = time_ns() - push_back = np.array([epoch, sample, now], np.int64) + todo = np.array([epoch, sample, now], np.int64) + todo = np.expand_dims(todo, 0) if os.path.exists(todo_filename): old = np.fromfile(todo_filename, np.int64) old = old.reshape(-1, 3) - new = np.concatenate([old, push_back], 0) + new = np.concatenate([old, todo], 0) else: - new = push_back + new = todo new.tofile(todo_filename) def _pop_front_pregen_epoch_todo(self, todo_filename: str) -> Tuple[int, int, int]: old = np.fromfile(todo_filename, np.int64) old = old.reshape(-1, 3) - pop_front = old[0] + todo = old[0] new = old[1:] if len(new): new.tofile(todo_filename) else: os.remove(todo_filename) - return tuple(pop_front.tolist()) + return tuple(todo.tolist()) def _request_pregen_epoch(self, epoch: int, sample: int) -> None: lock_filename = self.job.get_filename(self.pregen_todos_lock_path) From 5d11262cc2636aeca9d4cf49119fabe820114a87 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Mon, 22 Jan 2024 01:43:58 -0800 Subject: [PATCH 10/15] Fix init file. --- docs/source/conf.py | 1 + streaming/base/coord/__init__,py | 0 streaming/base/coord/__init__.py | 4 ++++ 3 files changed, 5 insertions(+) delete mode 100644 streaming/base/coord/__init__,py create mode 100644 streaming/base/coord/__init__.py diff --git a/docs/source/conf.py b/docs/source/conf.py index dbd0f1b83..41b253732 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -365,6 +365,7 @@ def _modules_to_rst() -> List[types.ModuleType]: document_modules: List[types.Module] = [ streaming, streaming.base.compression, + streaming.base.coord, streaming.base.format, streaming.base.hashing, streaming.base.partition, diff --git a/streaming/base/coord/__init__,py b/streaming/base/coord/__init__,py deleted file mode 100644 index e69de29bb..000000000 diff --git a/streaming/base/coord/__init__.py b/streaming/base/coord/__init__.py new file mode 100644 index 000000000..1bd1d49d9 --- /dev/null +++ b/streaming/base/coord/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2022-2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Functionality having to do with coordination between replicas.""" From 046b2d750bb23418e6601bf5b1d4b2e86c080e5f Mon Sep 17 00:00:00 2001 From: James Knighton Date: Mon, 22 Jan 2024 02:14:30 -0800 Subject: [PATCH 11/15] Fix (re: garbage collection and destructors). --- streaming/base/coord/job/dir.py | 18 +++++++++-- streaming/base/coord/job/file.py | 13 +++++++- streaming/base/coord/job/registry.py | 46 +++++++++++++++++++++++++--- streaming/base/dataset.py | 3 +- tests/test_streaming.py | 1 - 5 files changed, 72 insertions(+), 9 deletions(-) diff --git a/streaming/base/coord/job/dir.py b/streaming/base/coord/job/dir.py index 77c2a620e..dab97bb7a 100644 --- a/streaming/base/coord/job/dir.py +++ b/streaming/base/coord/job/dir.py @@ -44,6 +44,20 @@ def get_filename(self, path: str) -> str: """ return os.path.join(self.registry.config_root, self.job_hash, path) - def __del__(self) -> None: - """Destructor.""" + def manual_unregister(self) -> None: + """Explicitly un-register the job ahead of its deletion. + + This is useful when you want to ensure that this job is un-registered synchronously instead + of whenever the garbage collector eventually gets around to it. + + This job must be registered when this is called. + """ self.registry.unregister(self.job_hash, self.world) + + def __del__(self) -> None: + """Destructor. + + You may unregister the job explicitly ahead of time (to ensure it happens synchronously + instead of eventually). + """ + self.registry.ensure_unregistered(self.job_hash, self.world) diff --git a/streaming/base/coord/job/file.py b/streaming/base/coord/job/file.py index 213394eac..41457f198 100644 --- a/streaming/base/coord/job/file.py +++ b/streaming/base/coord/job/file.py @@ -54,7 +54,7 @@ def __len__(self) -> int: return self.num_jobs def add(self, job: JobEntry) -> None: - """Register a Stremaing job. + """Register a Streaming job. Args: job (Job): The job. @@ -93,6 +93,17 @@ def add(self, job: JobEntry) -> None: self.stream_hash2job[stream_hash] = job self.num_jobs += 1 + def contains(self, job_hash: str) -> bool: + """Tell whether the given job_hash is registered. + + Args: + job_hash (str): Potentially registered job hash. + + Returns: + bool: Whether the job hash is registered. + """ + return job_hash in self.job_hash2job + def remove(self, job_hash: str) -> None: """Deregister a Streaming job. diff --git a/streaming/base/coord/job/registry.py b/streaming/base/coord/job/registry.py index 90372a9de..3912511dd 100644 --- a/streaming/base/coord/job/registry.py +++ b/streaming/base/coord/job/registry.py @@ -147,7 +147,7 @@ def _remove_dir(self, job_hash: str) -> None: dirname = os.path.join(self.config_root, job_hash) rmtree(dirname) - def _register(self, streams: Sequence[Stream]) -> str: + def _do_register(self, streams: Sequence[Stream]) -> str: """Register this collection of StreamingDataset replicas. Called by the local leader. @@ -216,14 +216,29 @@ def register(self, streams: Sequence[Stream], world: World) -> str: str: Subdir for this collection of StreamingDataset replicas. """ if world.is_local_leader: - job_hash = self._register(streams) + job_hash = self._do_register(streams) else: job_hash = self._lookup(streams) dirname = os.path.join(self.config_root, job_hash) wait_for_creation(dirname, self.timeout, self.tick, self.lock) return job_hash - def _unregister(self, job_hash: str) -> None: + def is_registered(self, job_hash: str) -> bool: + """Tell whether the given job_hash is registered. + + Called by all ranks. + + Args: + job_hash (str): Potentially registered job hash. + + Returns: + bool: Whether the job hash is registered. + """ + with self.lock: + conf = RegistryFile.read(self.registry_filename) + return conf.contains(job_hash) + + def _do_unregister(self, job_hash: str) -> None: """Unregister this collection of StreamingDataset replicas. Called by the local leader. @@ -251,7 +266,30 @@ def unregister(self, job_hash: str, world: World) -> None: world (World): Rank-wise world state. """ if world.is_local_leader: - self._unregister(job_hash) + self._do_unregister(job_hash) else: dirname = os.path.join(self.config_root, job_hash) wait_for_deletion(dirname, self.timeout, self.tick, self.lock) + + def ensure_unregistered(self, job_hash: str, world: World) -> None: + """Ensure that this collection of StreamingDataset replicas is unregistered. + + Called by all ranks. + + Args: + job_hash (str): Subdir identifying this Streaming job. + world (World): Rank-wise world state. + """ + pid2create_time = self._get_live_procs() + + with self.lock: + conf = RegistryFile.read(self.registry_filename) + is_registered = conf.contains(job_hash) + if not is_registered: + return + + conf.remove(job_hash) + del_job_hashes = conf.filter(pid2create_time) + conf.write(self.registry_filename) + map(self._remove_dir, del_job_hashes) + self._remove_dir(job_hash) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index cf5c56531..c3545afbe 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -683,12 +683,13 @@ def __init__(self, def __del__(self) -> None: """Destructor,kill which releases its local working directories.""" - del self._dummy if hasattr(self, '_locals_shm'): try: self._locals_shm.buf[:4] = np.int32(0).tobytes() except: pass + del self._dummy + self.job.manual_unregister() @classmethod def _test_config_root(cls, config_root: str) -> None: diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 6928a3c83..fd8aebc56 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -782,7 +782,6 @@ def test_streamingdataloader_mid_epoch_resumption(local_remote_dir: Any, batch_s sample_order.extend(batch['id'][:]) del dataloader - del dataset.job # TODO: Why do we need this hack? del dataset clean_stale_shared_memory() From b3fc2a2343b802f6bbfabeaaf0646d3ef6a6df85 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Mon, 22 Jan 2024 02:50:10 -0800 Subject: [PATCH 12/15] Try gc. --- streaming/base/coord/job/registry.py | 20 ++++++++++++++++++++ streaming/base/dataset.py | 1 - 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/streaming/base/coord/job/registry.py b/streaming/base/coord/job/registry.py index 3912511dd..fbf6d61fc 100644 --- a/streaming/base/coord/job/registry.py +++ b/streaming/base/coord/job/registry.py @@ -6,6 +6,7 @@ Useful for detecting collisions between different jobs' local dirs. """ +import gc import os from hashlib import sha3_224 from shutil import rmtree @@ -152,6 +153,24 @@ def _do_register(self, streams: Sequence[Stream]) -> str: Called by the local leader. + Note: we explicitly garbage collect under the lock right before registration. This is to + save us and you from the following scenario: + + ```py + dataset = StreamingDataset(...) + + del dataset + + # The dataset is marked deleted, but the python garbage collector does not execute + # dataset.__del__ in time, much less dataset.job.__del__ in time, which would have + # automatically ensured the job was un-registered. + + dataset = StreamingDataset(...) # Same locals as before. + + # *Boom*, due to "reused" locals, matching the locals still registered from the first + # time the dataset was created. + ``` + Args: streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in combination with process IDs and creation times lets us uniquely identify a @@ -176,6 +195,7 @@ def _do_register(self, streams: Sequence[Stream]) -> str: register_time=register_time) with self.lock: + gc.collect() conf = RegistryFile.read(self.registry_filename) conf.add(entry) del_job_hashes = conf.filter(pid2create_time) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index c3545afbe..33c92180b 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -688,7 +688,6 @@ def __del__(self) -> None: self._locals_shm.buf[:4] = np.int32(0).tobytes() except: pass - del self._dummy self.job.manual_unregister() @classmethod From 9487b1f2f20042c62328176b2ea63990d7cde84f Mon Sep 17 00:00:00 2001 From: James Knighton Date: Mon, 22 Jan 2024 02:59:36 -0800 Subject: [PATCH 13/15] Try gc #2. --- streaming/base/coord/job/registry.py | 39 ++++++++++++++-------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/streaming/base/coord/job/registry.py b/streaming/base/coord/job/registry.py index fbf6d61fc..9222e1fc9 100644 --- a/streaming/base/coord/job/registry.py +++ b/streaming/base/coord/job/registry.py @@ -153,24 +153,6 @@ def _do_register(self, streams: Sequence[Stream]) -> str: Called by the local leader. - Note: we explicitly garbage collect under the lock right before registration. This is to - save us and you from the following scenario: - - ```py - dataset = StreamingDataset(...) - - del dataset - - # The dataset is marked deleted, but the python garbage collector does not execute - # dataset.__del__ in time, much less dataset.job.__del__ in time, which would have - # automatically ensured the job was un-registered. - - dataset = StreamingDataset(...) # Same locals as before. - - # *Boom*, due to "reused" locals, matching the locals still registered from the first - # time the dataset was created. - ``` - Args: streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in combination with process IDs and creation times lets us uniquely identify a @@ -195,7 +177,6 @@ def _do_register(self, streams: Sequence[Stream]) -> str: register_time=register_time) with self.lock: - gc.collect() conf = RegistryFile.read(self.registry_filename) conf.add(entry) del_job_hashes = conf.filter(pid2create_time) @@ -226,6 +207,25 @@ def register(self, streams: Sequence[Stream], world: World) -> str: Called by all ranks. + + Note: we explicitly garbage collect right before registration. This is to save us from the + following scenario: + + ```py + dataset = StreamingDataset(...) + + del dataset + + # The dataset is marked deleted, but the python garbage collector does not execute + # dataset.__del__ in time, much less dataset.job.__del__ in time, which would have + # automatically ensured the job was un-registered. + + dataset = StreamingDataset(...) # Same locals as before. + + # *Boom*, due to "reused" locals, matching the locals still registered from the first + # time the dataset was created. + ``` + Args: streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in combination with process IDs and creation times lets us uniquely identify a @@ -235,6 +235,7 @@ def register(self, streams: Sequence[Stream], world: World) -> str: Returns: str: Subdir for this collection of StreamingDataset replicas. """ + gc.collect() if world.is_local_leader: job_hash = self._do_register(streams) else: From db3041f75e4a47f53cbccf2219ee911036a4d659 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Mon, 22 Jan 2024 03:10:11 -0800 Subject: [PATCH 14/15] Try gc 3. --- streaming/base/coord/job/dir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/coord/job/dir.py b/streaming/base/coord/job/dir.py index dab97bb7a..6f5a96641 100644 --- a/streaming/base/coord/job/dir.py +++ b/streaming/base/coord/job/dir.py @@ -52,7 +52,7 @@ def manual_unregister(self) -> None: This job must be registered when this is called. """ - self.registry.unregister(self.job_hash, self.world) + self.registry.ensure_unregistered(self.job_hash, self.world) def __del__(self) -> None: """Destructor. From 7470d27cdabeee5fbbf0f1588584875e02e1c906 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Mon, 22 Jan 2024 07:16:28 -0800 Subject: [PATCH 15/15] Refactor registry. --- streaming/base/coord/job/dir.py | 4 +- streaming/base/coord/job/file.py | 6 +- streaming/base/coord/job/registry.py | 190 ++++++++++----------------- streaming/base/dataset.py | 2 + 4 files changed, 77 insertions(+), 125 deletions(-) diff --git a/streaming/base/coord/job/dir.py b/streaming/base/coord/job/dir.py index 6f5a96641..e553183bd 100644 --- a/streaming/base/coord/job/dir.py +++ b/streaming/base/coord/job/dir.py @@ -52,7 +52,7 @@ def manual_unregister(self) -> None: This job must be registered when this is called. """ - self.registry.ensure_unregistered(self.job_hash, self.world) + self.registry.unregister(self.job_hash, self.world, True) def __del__(self) -> None: """Destructor. @@ -60,4 +60,4 @@ def __del__(self) -> None: You may unregister the job explicitly ahead of time (to ensure it happens synchronously instead of eventually). """ - self.registry.ensure_unregistered(self.job_hash, self.world) + self.registry.unregister(self.job_hash, self.world, False) diff --git a/streaming/base/coord/job/file.py b/streaming/base/coord/job/file.py index 41457f198..cbb631b53 100644 --- a/streaming/base/coord/job/file.py +++ b/streaming/base/coord/job/file.py @@ -32,7 +32,11 @@ def __init__(self, jobs: List[JobEntry]) -> None: @classmethod def read(cls, filename: str) -> Self: if os.path.exists(filename): - obj = json.load(open(filename)) + try: + obj = json.load(open(filename)) + except: + os.remove(filename) + obj = {} else: obj = {} jobs = obj.get('jobs') or [] diff --git a/streaming/base/coord/job/registry.py b/streaming/base/coord/job/registry.py index 9222e1fc9..0ce748175 100644 --- a/streaming/base/coord/job/registry.py +++ b/streaming/base/coord/job/registry.py @@ -6,7 +6,6 @@ Useful for detecting collisions between different jobs' local dirs. """ -import gc import os from hashlib import sha3_224 from shutil import rmtree @@ -130,7 +129,7 @@ def _hash_streams(self, streams: Sequence[Stream]) -> Tuple[List[str], List[str] return stream_locals, stream_hashes, job_hash - def _make_dir(self, job_hash: str) -> None: + def _make_job_dir(self, job_hash: str) -> None: """Create a Streaming job config dir. Args: @@ -139,7 +138,7 @@ def _make_dir(self, job_hash: str) -> None: dirname = os.path.join(self.config_root, job_hash) os.makedirs(dirname) - def _remove_dir(self, job_hash: str) -> None: + def _remove_job_dir(self, job_hash: str) -> None: """Delete a Streaming job config dir. Args: @@ -148,101 +147,60 @@ def _remove_dir(self, job_hash: str) -> None: dirname = os.path.join(self.config_root, job_hash) rmtree(dirname) - def _do_register(self, streams: Sequence[Stream]) -> str: - """Register this collection of StreamingDataset replicas. + def register(self, streams: Sequence[Stream], world: World) -> str: + """Register or look up this collection of StreamingDataset replicas. - Called by the local leader. + Called by all ranks. Args: streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in combination with process IDs and creation times lets us uniquely identify a Streaming job. + world (World): Rank-wise world state. Returns: - str: Streaming config subdir for this job. + str: Subdir for this collection of StreamingDataset replicas. """ - register_time = time_ns() - pid2create_time = self._get_live_procs() - pid = os.getpid() - create_time = pid2create_time.get(pid) - if create_time is None: - raise RuntimeError('`psutil` thinks we are dead, and yet here we are: pid = {pid}.') + if not world.is_local_leader: + _, _, job_hash = self._hash_streams(streams) + dirname = os.path.join(self.config_root, job_hash) + wait_for_creation(dirname, self.timeout, self.tick, self.lock) + return job_hash + # Collect our stream locals and hash them, resulting in a job hash. stream_locals, stream_hashes, job_hash = self._hash_streams(streams) - entry = JobEntry(job_hash=job_hash, - stream_hashes=stream_hashes, - stream_locals=stream_locals, - process_id=pid, - register_time=register_time) - with self.lock: - conf = RegistryFile.read(self.registry_filename) - conf.add(entry) - del_job_hashes = conf.filter(pid2create_time) - conf.write(self.registry_filename) - map(self._remove_dir, del_job_hashes) - self._make_dir(job_hash) - - return job_hash - - def _lookup(self, streams: Sequence[Stream]) -> str: - """Look up this collection of StreamingDataset replicas. - - Called by the local leader. - - Args: - streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in - combination with process IDs and creation times lets us uniquely identify a - Streaming job. - - Returns: - str: Streaming config subdir for this job. - """ - _, _, job_hash = self._hash_streams(streams) - return job_hash - - def register(self, streams: Sequence[Stream], world: World) -> str: - """Register or look up this collection of StreamingDataset replicas. - - Called by all ranks. - + # Get registration time. + register_time = time_ns() - Note: we explicitly garbage collect right before registration. This is to save us from the - following scenario: + # Load the job database. + db = RegistryFile.read(self.registry_filename) - ```py - dataset = StreamingDataset(...) + # Perform liveness checks on the jobs we have registered. + pid2create_time = self._get_live_procs() + del_job_hashes = db.filter(pid2create_time) - del dataset + # Add an entry for this job. + pid = os.getpid() + create_time = pid2create_time.get(pid) + if create_time is None: + raise RuntimeError('`psutil` thinks we are dead, and yet here we are: pid {pid}.') + entry = JobEntry(job_hash=job_hash, + stream_hashes=stream_hashes, + stream_locals=stream_locals, + process_id=pid, + register_time=register_time) + db.add(entry) - # The dataset is marked deleted, but the python garbage collector does not execute - # dataset.__del__ in time, much less dataset.job.__del__ in time, which would have - # automatically ensured the job was un-registered. + # Save the new db to disk. + db.write(self.registry_filename) - dataset = StreamingDataset(...) # Same locals as before. - - # *Boom*, due to "reused" locals, matching the locals still registered from the first - # time the dataset was created. - ``` - - Args: - streams (Sequence[Stream]): List of this StreamingDataset's Streams, which in - combination with process IDs and creation times lets us uniquely identify a - Streaming job. - world (World): Rank-wise world state. + # Add and remove job directories accordingly. + self._make_job_dir(job_hash) + map(self._remove_job_dir, del_job_hashes) - Returns: - str: Subdir for this collection of StreamingDataset replicas. - """ - gc.collect() - if world.is_local_leader: - job_hash = self._do_register(streams) - else: - job_hash = self._lookup(streams) - dirname = os.path.join(self.config_root, job_hash) - wait_for_creation(dirname, self.timeout, self.tick, self.lock) - return job_hash + return job_hash def is_registered(self, job_hash: str) -> bool: """Tell whether the given job_hash is registered. @@ -255,29 +213,11 @@ def is_registered(self, job_hash: str) -> bool: Returns: bool: Whether the job hash is registered. """ + dirname = os.path.join(self.config_root, job_hash) with self.lock: - conf = RegistryFile.read(self.registry_filename) - return conf.contains(job_hash) - - def _do_unregister(self, job_hash: str) -> None: - """Unregister this collection of StreamingDataset replicas. - - Called by the local leader. - - Args: - job_hash (str): Subdir identifying this Streaming job. - """ - pid2create_time = self._get_live_procs() + return os.path.isdir(dirname) - with self.lock: - conf = RegistryFile.read(self.registry_filename) - conf.remove(job_hash) - del_job_hashes = conf.filter(pid2create_time) - conf.write(self.registry_filename) - map(self._remove_dir, del_job_hashes) - self._remove_dir(job_hash) - - def unregister(self, job_hash: str, world: World) -> None: + def unregister(self, job_hash: str, world: World, strict: bool = True) -> None: """Unregister this collection of StreamingDataset replicas. Called by all ranks. @@ -285,32 +225,38 @@ def unregister(self, job_hash: str, world: World) -> None: Args: job_hash (str): Subdir identifying this Streaming job. world (World): Rank-wise world state. + strict (bool): If strict, require the job to be currently registered at start. """ - if world.is_local_leader: - self._do_unregister(job_hash) - else: + if not world.is_local_leader: dirname = os.path.join(self.config_root, job_hash) wait_for_deletion(dirname, self.timeout, self.tick, self.lock) + return - def ensure_unregistered(self, job_hash: str, world: World) -> None: - """Ensure that this collection of StreamingDataset replicas is unregistered. + with self.lock: + # Load the job database. + db = RegistryFile.read(self.registry_filename) - Called by all ranks. + # Check if the job hash is registered. + was_registered = db.contains(job_hash) - Args: - job_hash (str): Subdir identifying this Streaming job. - world (World): Rank-wise world state. - """ - pid2create_time = self._get_live_procs() + # If strict, require the job to be registered. + if strict and not was_registered: + raise ValueError(f'Attempted to unregister job {job_hash}, but it was not ' + + f'registered.') - with self.lock: - conf = RegistryFile.read(self.registry_filename) - is_registered = conf.contains(job_hash) - if not is_registered: - return - - conf.remove(job_hash) - del_job_hashes = conf.filter(pid2create_time) - conf.write(self.registry_filename) - map(self._remove_dir, del_job_hashes) - self._remove_dir(job_hash) + # Unregister the job, if it is registered. + if was_registered: + db.remove(job_hash) + self._remove_job_dir(job_hash) + + # Perform liveness checks on the jobs we have registered. + pid2create_time = self._get_live_procs() + del_job_hashes = db.filter(pid2create_time) + + # If we unregistered the job and/or we garbage collected job(s), save the new jobs + # database back to disk. + if was_registered or del_job_hashes: + db.write(self.registry_filename) + + # Remove each directory corresponding to a job that was garbage collected. + map(self._remove_job_dir, del_job_hashes) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 33c92180b..283606429 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -1189,6 +1189,8 @@ def _request_pregen_epoch(self, epoch: int, sample: int) -> None: def _each_pregen_epoch_todo(self) -> Iterator[Tuple[int, int]]: lock_filename = self.job.get_filename(self.pregen_todos_lock_path) todo_filename = self.job.get_filename(self.pregen_todos_path) + dirname = os.path.dirname(lock_filename) + os.makedirs(dirname, exist_ok=True) lock = FileLock(lock_filename) while True: with lock: