Add options to precompute the epoch #569

Open · wants to merge 17 commits into base: main
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -365,6 +365,7 @@ def _modules_to_rst() -> List[types.ModuleType]:
    document_modules: List[types.ModuleType] = [
        streaming,
        streaming.base.compression,
        streaming.base.coord,
        streaming.base.format,
        streaming.base.hashing,
        streaming.base.partition,
1 change: 1 addition & 0 deletions setup.py
@@ -58,6 +58,7 @@
    'azure-storage-blob>=12.0.0,<13',
    'azure-storage-file-datalake>=12.11.0,<13',
    'azure-identity>=1.13.0',
    'psutil==5.9.4',
]

extra_deps = {}
4 changes: 4 additions & 0 deletions streaming/base/coord/__init__.py
@@ -0,0 +1,4 @@
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Functionality having to do with coordination between replicas."""
9 changes: 9 additions & 0 deletions streaming/base/coord/filesystem/__init__.py
@@ -0,0 +1,9 @@
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Coordinating using files."""

from streaming.base.coord.filesystem.waiting import (create_file, wait_for_creation,
                                                     wait_for_deletion)

__all__ = ['create_file', 'wait_for_creation', 'wait_for_deletion']
71 changes: 71 additions & 0 deletions streaming/base/coord/filesystem/waiting.py
@@ -0,0 +1,71 @@
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Waiting on files."""

import os
from typing import Any, Optional

from streaming.base.coord.waiting import wait

__all__ = ['wait_for_creation', 'wait_for_deletion', 'create_file']


def wait_for_creation(
    path: str,
    timeout: Optional[float] = 30,
    tick: float = 0.007,
    lock: Optional[Any] = None,
) -> None:
    """Wait for the creation of a path on the local filesystem.

    Args:
        path (str): Local path to wait on the creation of.
        timeout (float, optional): How long to wait before raising an exception, in seconds.
            Defaults to ``30``.
        tick (float): Check interval, in seconds. Defaults to ``0.007``.
        lock (Any, optional): Context manager (this is intended for locks) to be held when
            checking the predicate. Defaults to ``None``.
    """

    def stop():
        return os.path.exists(path)

    wait(stop, timeout, tick, lock)


def wait_for_deletion(
    path: str,
    timeout: Optional[float] = 30,
    tick: float = 0.007,
    lock: Optional[Any] = None,
) -> None:
    """Wait for the deletion of a path on the local filesystem.

    Args:
        path (str): Local path to wait on the deletion of.
        timeout (float, optional): How long to wait before raising an exception, in seconds.
            Defaults to ``30``.
        tick (float): Check interval, in seconds. Defaults to ``0.007``.
        lock (Any, optional): Context manager (this is intended for locks) to be held when
            checking the predicate. Defaults to ``None``.
    """

    def stop():
        return not os.path.exists(path)

    wait(stop, timeout, tick, lock)


def create_file(filename: str) -> None:
    """Create a file at the given path on the local filesystem.

    Raises an exception if the path already exists.

    Args:
        filename (str): Filename to create.
    """
    dirname = os.path.dirname(filename)
    os.makedirs(dirname, exist_ok=True)
    with open(filename, 'x'):
        pass
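For reviewers skimming the diff, here is a minimal sketch of how these helpers compose into a cross-process barrier. The marker path, the lock, and the `rank` variable are illustrative assumptions, not part of this PR:

```python
# Hypothetical usage sketch: rank 0 drops a marker file, the other ranks
# block on it. `rank` is assumed to come from the caller's topology info.
from multiprocessing import Lock

from streaming.base.coord.filesystem import create_file, wait_for_creation

lock = Lock()  # Held while polling, per the `lock` argument's contract.
marker = '/tmp/streaming_job/ready.marker'  # Illustrative path.

if rank == 0:
    create_file(marker)  # Raises FileExistsError if the marker already exists.
else:
    # Polls every 7 ms; raises if the file has not appeared within 30 s.
    wait_for_creation(marker, timeout=30, tick=0.007, lock=lock)
```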
9 changes: 9 additions & 0 deletions streaming/base/coord/job/__init__.py
@@ -0,0 +1,9 @@
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Handling for jobs, which are collections of StreamingDataset replicas with the same config."""

from streaming.base.coord.job.dir import JobDir
from streaming.base.coord.job.registry import JobRegistry

__all__ = ['JobDir', 'JobRegistry']
63 changes: 63 additions & 0 deletions streaming/base/coord/job/dir.py
@@ -0,0 +1,63 @@
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""A directory containing all dataset-wide filesystem state for a Streaming job."""

import os
from typing import Sequence

from streaming.base.coord.job.registry import JobRegistry
from streaming.base.stream import Stream
from streaming.base.world import World

__all__ = ['JobDir']


class JobDir:
    """Represents a Streaming job lease. On ``__del__``, it cleans up after itself.

    When it goes out of scope naturally, this job deletes its config dir and releases its hold
    on all the local dirs it is streaming to.

    If this process dies badly and the destructor is never reached, the same cleanup is done
    incidentally by some future process as it registers or unregisters a Streaming job. That
    process detects the death by checking the recorded pid against live processes' creation
    times: a missing pid, or a pid reused by a younger process, marks this job as stale.

    Args:
        registry (JobRegistry): Streaming job registry.
        streams (Sequence[Stream]): The streams this job is reading from.
        world (World): Rank/node topology of this process.
    """

    def __init__(self, registry: JobRegistry, streams: Sequence[Stream], world: World) -> None:
        self.registry = registry
        self.streams = streams
        self.world = world
        self.job_hash = registry.register(streams, world)

    def get_filename(self, path: str) -> str:
        """Get a filename by relative path under its job dir.

        Args:
            path (str): Path relative to job dir.

        Returns:
            str: Filename.
        """
        return os.path.join(self.registry.config_root, self.job_hash, path)

    def manual_unregister(self) -> None:
        """Explicitly un-register the job ahead of its deletion.

        This is useful when you want to ensure that this job is un-registered synchronously
        instead of whenever the garbage collector eventually gets around to it.

        This job must still be registered when this is called.
        """
        self.registry.unregister(self.job_hash, self.world, True)

    def __del__(self) -> None:
        """Destructor.

        You may unregister the job explicitly ahead of time (to ensure it happens synchronously
        instead of eventually).
        """
        self.registry.unregister(self.job_hash, self.world, False)
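A rough sketch of the lease lifecycle follows. The `JobRegistry` constructor argument and the `World()` construction are assumptions for illustration; only `JobDir`'s own API is taken from this diff:

```python
# Hypothetical sketch of holding a job lease for the life of a dataset.
from streaming.base.coord.job import JobDir, JobRegistry
from streaming.base.stream import Stream
from streaming.base.world import World

registry = JobRegistry('/tmp/streaming_config')  # Assumed config-root argument.
streams = [Stream(remote='s3://bucket/data', local='/tmp/data')]
world = World()  # Assumed: detected from the environment.

job = JobDir(registry, streams, world)  # register() runs here, taking the lease.
index_path = job.get_filename('index.json')  # <config_root>/<job_hash>/index.json

# Release eagerly (synchronous), or skip this and rely on __del__ instead.
job.manual_unregister()
```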
65 changes: 65 additions & 0 deletions streaming/base/coord/job/entry.py
@@ -0,0 +1,65 @@
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""An entry in a Streaming job registry file."""

from typing import Any, Dict, List, Optional

from typing_extensions import Self

__all__ = ['JobEntry']


class JobEntry:
    """Info about a Streaming job, used to detect local dir reuse.

    Args:
        index (int, optional): The job's index in the total list.
        job_hash (str): Job hash.
        stream_hashes (List[str]): Stream hashes.
        stream_locals (List[str], optional): Stream locals, if available.
        process_id (int): PID of local rank zero of the Streaming job.
        register_time (int): Process registration time.
    """

    def __init__(
        self,
        *,
        index: Optional[int] = None,
        job_hash: str,
        stream_hashes: List[str],
        stream_locals: Optional[List[str]] = None,
        process_id: int,
        register_time: int,
    ) -> None:
        self.index = index
        self.job_hash = job_hash
        self.stream_hashes = stream_hashes
        self.stream_locals = stream_locals
        self.process_id = process_id
        self.register_time = register_time

    @classmethod
    def from_json(cls, obj: Dict[str, Any]) -> Self:
        """Load from JSON.

        Args:
            obj (Dict[str, Any]): Source JSON object.

        Returns:
            Self: Loaded JobEntry.
        """
        return cls(job_hash=obj['job_hash'],
                   stream_hashes=obj['stream_hashes'],
                   stream_locals=obj.get('stream_locals'),
                   process_id=obj['process_id'],
                   register_time=obj['register_time'])

    def to_json(self) -> Dict[str, Any]:
        """Dump to JSON.

        Returns:
            Dict[str, Any]: JSON representation of this entry.
        """
        return {
            'job_hash': self.job_hash,
            'stream_hashes': self.stream_hashes,
            # stream_locals is not saved, only their hashes.
            'process_id': self.process_id,
            'register_time': self.register_time,
        }
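The JSON round trip is deliberately lossy in one field, which is easy to miss; a small check with made-up values:

```python
# Illustrative round trip (all values invented): stream_locals survives in
# memory but is intentionally dropped by to_json().
entry = JobEntry(job_hash='abc123',
                 stream_hashes=['s0', 's1'],
                 stream_locals=['/tmp/data0', '/tmp/data1'],
                 process_id=4242,
                 register_time=1700000000)

obj = entry.to_json()
assert 'stream_locals' not in obj  # Only the stream hashes are persisted.

restored = JobEntry.from_json(obj)
assert restored.stream_locals is None  # obj.get('stream_locals') returns None.
```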
145 changes: 145 additions & 0 deletions streaming/base/coord/job/file.py
@@ -0,0 +1,145 @@
# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""A Streaming job registry file."""

import json
import os
from typing import Dict, List, Optional

from typing_extensions import Self

from streaming.base.coord.job.entry import JobEntry

__all__ = ['RegistryFile']


class RegistryFile:
    """StreamingDataset job registry, which is backed by a JSON file.

    Args:
        jobs (List[JobEntry]): List of StreamingDataset jobs.
    """

    def __init__(self, jobs: List[JobEntry]) -> None:
        self.jobs: List[Optional[JobEntry]] = []
        self.job_hash2job: Dict[str, JobEntry] = {}
        self.stream_hash2job: Dict[str, JobEntry] = {}
        self.num_jobs = 0
        for job in jobs:
            self.add(job)

    @classmethod
    def read(cls, filename: str) -> Self:
        """Load a RegistryFile, removing the file if it is corrupt.

        Args:
            filename (str): File to load from.

        Returns:
            Self: Loaded RegistryFile.
        """
        if os.path.exists(filename):
            try:
                with open(filename) as file:
                    obj = json.load(file)
            except (OSError, ValueError):  # Unreadable or corrupt JSON: start fresh.
                os.remove(filename)
                obj = {}
        else:
            obj = {}
        jobs = obj.get('jobs') or []
        jobs = [JobEntry.from_json(job) for job in jobs]
        return cls(jobs)

    def write(self, filename: str) -> None:
        """Save this registry as JSON to the given file.

        Args:
            filename (str): File to save to.
        """
        jobs = [job.to_json() for job in filter(bool, self.jobs)]
        obj = {'jobs': jobs}
        with open(filename, 'w') as out:
            json.dump(obj, out)

    def __len__(self) -> int:
        """Get the number of jobs registered.

        Returns:
            int: Number of registered jobs.
        """
        return self.num_jobs

    def add(self, job: JobEntry) -> None:
        """Register a Streaming job.

        Args:
            job (JobEntry): The job.
        """
        # Check that stream locals line up.
        if job.stream_locals:
            if len(job.stream_hashes) != len(job.stream_locals):
                raise ValueError(f'If locals are provided, must have one local per stream hash, ' +
                                 f'but got: {len(job.stream_hashes)} hashes vs ' +
                                 f'{len(job.stream_locals)} locals.')
            norm_stream_locals = job.stream_locals
        else:
            norm_stream_locals = [None] * len(job.stream_hashes)

        # Check dataset hash for reuse.
        if job.job_hash in self.job_hash2job:
            if job.stream_locals:
                raise ValueError(f'Reused dataset local path(s): {job.stream_locals}.')
            else:
                raise ValueError(f'Reused dataset local path(s): stream hashes = ' +
                                 f'{job.stream_hashes}, dataset hash = {job.job_hash}.')

        # Check each stream hash for reuse.
        for stream_hash, norm_stream_local in zip(job.stream_hashes, norm_stream_locals):
            if stream_hash in self.stream_hash2job:
                if norm_stream_local:
                    raise ValueError(f'Reused stream local path: {norm_stream_local}.')
                else:
                    raise ValueError(f'Reused stream local path: stream hash = {stream_hash}.')

        # Do the insertion.
        job.index = len(self.jobs)
        self.jobs.append(job)
        self.job_hash2job[job.job_hash] = job
        for stream_hash in job.stream_hashes:
            self.stream_hash2job[stream_hash] = job
        self.num_jobs += 1

    def contains(self, job_hash: str) -> bool:
        """Tell whether the given job hash is registered.

        Args:
            job_hash (str): Potentially registered job hash.

        Returns:
            bool: Whether the job hash is registered.
        """
        return job_hash in self.job_hash2job

    def remove(self, job_hash: str) -> None:
        """Deregister a Streaming job.

        Args:
            job_hash (str): Job hash.
        """
        job = self.job_hash2job.get(job_hash)
        if not job:
            raise ValueError(f'Job hash not found: {job_hash}.')

        if job.index is None:
            raise ValueError('Internal error in job registration: job index is missing.')

        self.jobs[job.index] = None
        del self.job_hash2job[job.job_hash]
        for stream_hash in job.stream_hashes:
            del self.stream_hash2job[stream_hash]
        self.num_jobs -= 1

    def filter(self, pid2create_time: Dict[int, int]) -> List[str]:
        """Filter out jobs whose registering process has died or been replaced.

        A job is removed when its pid is no longer alive, or when the live process with that
        pid was created after the job registered (i.e., the pid has been recycled).

        Args:
            pid2create_time (Dict[int, int]): Mapping of live pid to process creation time.

        Returns:
            List[str]: Job hashes of the removed jobs.
        """
        del_job_hashes = []
        for job in filter(bool, self.jobs):
            create_time = pid2create_time.get(job.process_id)
            if not create_time or job.register_time < create_time:
                self.remove(job.job_hash)
                del_job_hashes.append(job.job_hash)
        return del_job_hashes
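This is where the new `psutil` dependency from setup.py plausibly comes in: `filter()` needs a map of live pids to creation times to spot dead or recycled registrants. A hedged caller sketch, with the registry path invented for illustration:

```python
# Hypothetical sweep of stale registry entries. The path and the psutil
# usage are assumptions; only RegistryFile's API comes from this diff.
import psutil

from streaming.base.coord.job.file import RegistryFile

registry_path = '/tmp/streaming_config/registry.json'  # Illustrative path.

# Map live pids to creation times, truncated to whole seconds to match the
# registry's integer register_time field.
pid2create_time = {}
for proc in psutil.process_iter():
    try:
        pid2create_time[proc.pid] = int(proc.create_time())
    except psutil.NoSuchProcess:  # Process exited mid-iteration.
        pass

db = RegistryFile.read(registry_path)
dead_job_hashes = db.filter(pid2create_time)  # Removes dead/recycled jobs.
db.write(registry_path)  # Persist the cleaned registry.
```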