From a859b5853318b9228b525ea3cb7e933bca296aa9 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Fri, 27 Oct 2023 22:39:28 -0700 Subject: [PATCH 01/45] Stub out index_delta(), index_lance(), index_parquet(). --- streaming/base/format/__init__.py | 7 ++- streaming/base/format/delta/__init__.py | 8 +++ streaming/base/format/delta/indexing.py | 41 +++++++++++++ streaming/base/format/lance/__init__.py | 8 +++ streaming/base/format/lance/indexing.py | 35 +++++++++++ streaming/base/format/parquet/__init__.py | 8 +++ streaming/base/format/parquet/indexing.py | 72 +++++++++++++++++++++++ 7 files changed, 177 insertions(+), 2 deletions(-) create mode 100644 streaming/base/format/delta/__init__.py create mode 100644 streaming/base/format/delta/indexing.py create mode 100644 streaming/base/format/lance/__init__.py create mode 100644 streaming/base/format/lance/indexing.py create mode 100644 streaming/base/format/parquet/__init__.py create mode 100644 streaming/base/format/parquet/indexing.py diff --git a/streaming/base/format/__init__.py b/streaming/base/format/__init__.py index 962828ae2..d819a6a9a 100644 --- a/streaming/base/format/__init__.py +++ b/streaming/base/format/__init__.py @@ -6,15 +6,18 @@ from typing import Any, Dict, Optional from streaming.base.format.base import FileInfo, Reader +from streaming.base.format.delta import index_delta from streaming.base.format.index import get_index_basename from streaming.base.format.json import JSONReader, JSONWriter +from streaming.base.format.lance import index_lance from streaming.base.format.mds import MDSReader, MDSWriter +from streaming.base.format.parquet import index_parquet from streaming.base.format.xsv import (CSVReader, CSVWriter, TSVReader, TSVWriter, XSVReader, XSVWriter) __all__ = [ - 'CSVWriter', 'FileInfo', 'get_index_basename', 'JSONWriter', 'MDSWriter', 'Reader', - 'reader_from_json', 'TSVWriter', 'XSVWriter' + 'CSVWriter', 'FileInfo', 'JSONWriter', 'MDSWriter', 'Reader', 'TSVWriter', 'XSVWriter', + 'get_index_basename', 'index_delta', 'index_lance', 'index_parquet', 'reader_from_json' ] _readers = { diff --git a/streaming/base/format/delta/__init__.py b/streaming/base/format/delta/__init__.py new file mode 100644 index 000000000..248e928a0 --- /dev/null +++ b/streaming/base/format/delta/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Integration with Delta tables.""" + +from streaming.base.format.delta.indexing import index_delta + +__all__ = ['index_delta'] diff --git a/streaming/base/format/delta/indexing.py b/streaming/base/format/delta/indexing.py new file mode 100644 index 000000000..6071ae0a8 --- /dev/null +++ b/streaming/base/format/delta/indexing.py @@ -0,0 +1,41 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Index a Delta table for use by Streaming.""" + +from typing import Any, Dict, Optional, Union + + +def index_delta(*, + local: str, + remote: Optional[str] = None, + split: Optional[str] = None, + version: Optional[int] = None, + num_threads: Optional[int] = 0, + download_timeout: Union[float, str] = '1m', + max_file_bytes: Optional[Union[int, str]] = '200mb', + columns: Optional[Dict[str, Optional[str]]] = None, + show_progress: bool = True) -> Dict[str, Any]: + """Initialize from a local and/or remote Delta table directory. + + Args: + local (str): Where the dataset is cached on the local filesystem. + remote (str, optional): Where the dataset is downloaded from. Defaults to ``None``. 
+ split (str, optional): Which dataset split to use. Defaults to ``None``. + version (int, optional): Which snapshot version of the dataset to use, or else take the + latest if ``None``. Defaults to ``None``. + num_threads (int, optional): Number of threads for downloading potentially many very small + files. ``None`` means single-threaded; ``0`` means threads; positive + int means that number of threads. Default: ``0``. + download_timeout (Union[float, str]): For each Delta metadata file. Defaults to ``1m``. + max_file_bytes (Union[int, str], optional): File size limit, above which we raise an error. + This is a performance guard rail, as choppiness increases linearly with shard size. The + sweet spot is typically around 32mb. Defaults to ``200mb``. + columns (Dict[str, str], optional): For field names and types specified here, override the + inferred schema to configure it manually. Defaults to ``None``. + show_progress (bool): Show progress bar for downloading Delta logs. Defaults to ``True``. + + Returns: + Dict[str, Any]: StreamingDataset index configuration to stream this Delta table. + """ + raise NotImplementedError # TODO diff --git a/streaming/base/format/lance/__init__.py b/streaming/base/format/lance/__init__.py new file mode 100644 index 000000000..3e3d3ac87 --- /dev/null +++ b/streaming/base/format/lance/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Integration with Lance datasets.""" + +from streaming.base.format.lance.indexing import index_lance + +__all__ = ['index_lance'] diff --git a/streaming/base/format/lance/indexing.py b/streaming/base/format/lance/indexing.py new file mode 100644 index 000000000..65e6da30b --- /dev/null +++ b/streaming/base/format/lance/indexing.py @@ -0,0 +1,35 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Indexing a Lance dataset for use by Streaming.""" + +from typing import Any, Dict, Optional, Union + + +def index_lance(*, + local: str, + remote: Optional[str] = None, + split: Optional[str] = None, + version: Optional[int] = None, + download_timeout: Union[float, str] = '1m', + max_file_bytes: Optional[Union[int, str]] = '200mb', + columns: Optional[Dict[str, str]] = None) -> Dict[str, Any]: + """Initialize from a local and/or remote Lance dataset directory. + + Args: + local (str): Where the dataset is cached on the local filesystem. + remote (str, optional): Where the dataset is downloaded from. Defaults to ``None``. + split (str, optional): Which dataset split to use. Defaults to ``None``. + version (int, optional): Which snapshot version of the dataset to use, or else take + the latest if ``None``. Defaults to ``None``. + download_timeout (Union[float, str]): For each Lance metadata file. Defaults to ``1m``. + max_file_bytes (Union[int, str], optional): File size limit, above which we raise an + error. This is a performance guard rail, as choppiness increases linearly with + shard size. The sweet spot is typically around 32mb. Defaults to ``200mb``. + columns (Dict[str, str], optional): For field names and types specified here, override + the inferred schema to configure it manually. Defaults to ``None``. + + Returns: + Dict[str, Any]: StreamingDataset index configuration to stream this Lance dataset. 
+ """ + raise NotImplementedError # TODO diff --git a/streaming/base/format/parquet/__init__.py b/streaming/base/format/parquet/__init__.py new file mode 100644 index 000000000..c2847cee8 --- /dev/null +++ b/streaming/base/format/parquet/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Integration with Parquet datasets.""" + +from streaming.base.format.parquet.indexing import index_parquet + +__all__ = ['index_parquet'] diff --git a/streaming/base/format/parquet/indexing.py b/streaming/base/format/parquet/indexing.py new file mode 100644 index 000000000..a933a6f5a --- /dev/null +++ b/streaming/base/format/parquet/indexing.py @@ -0,0 +1,72 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Indexing a Parquet dataset for use by Streaming.""" + +from re import Pattern +from typing import Any, Callable, Dict, Iterable, Optional, Union + +Filter = Union[str, Pattern, Callable[[str], bool]] + + +def index_parquet(*, + local: str, + remote: Optional[str] = None, + split: Optional[str] = None, + files: Optional[Iterable[str]] = None, + keep: Optional[Filter] = r'.*\.parquet$', + num_procs: Optional[int] = 0, + download_timeout: Union[float, str] = '2m', + max_file_bytes: Optional[Union[int, str]] = '200mb', + same_schema: bool = True, + columns: Optional[Dict[str, str]] = None, + show_progress: bool = True) -> Dict[str, Any]: + r"""Initialize from a local and/or remote Parquet dataset directory. + + "Parquet dataset" means the samples live in a collection of naked Parquet files. There is not + any kind of index or manifest we can count on existing, so we will have to create one. + + Assumptions: + * Samples live in a collection of naked Parquet files. + * There is not any kind of index or manifest that we can count on existing. + * Files are all found under a common root directory, which local/remote point to. + * This root directory may contain other files, which we ignore. + * Ideally, but not necessarily, the Parquets all have the same schema. + + Locality: + * If we are given an explicit list of Parquet files, we try local first, then remote. Both + are cross-checked for completeness. + * If we are default listing all files instead, and just have a local, it is assumed to be + complete. + * If we are listing files, and remote is provided, the remote must be authoritative. + + Args: + local (str): Where the dataset is cached on the local filesystem. + remote (str, optional): Where the dataset is downloaded from. Defaults to ``None``. + split (str, optional): Which dataset split to use. Defaults to ``None``. + files (Iterable[str], optional): An Iterable of file paths relative to dataset root. These + paths filtered for the Parquets constituting this dataset by ``keep``. If not set, we + default to a sorted listing of all the files under dataset root. We list the remote if + provided, else we assume local is complete. Defaults to ``None``. + keep (Union[str, Pattern, Callable[[str], bool]], optional): Iterating ``files``, we keep + the ones this regex matches (if str) or predicate accepts (if Callable). Defaults to + ``.*\.parquet$``, i.e. include every file that ends with ".parquet". + num_procs (int, optional): Number of processes for download/processing of potentially many + large Parquet files. ``None`` means single-process; ``0`` means + processes; positive int means that number of processes. Defaults to ``0``. + download_timeout (Union[float, str]): For each Parquet file. 
Defaults to ``2m``. + max_file_bytes (Union[int, str], optional): File size limit, above which we raise an error. + This is a performance guard rail, as choppiness increases linearly with shard size. The + sweet spot is typically around 32mb. Defaults to ``200mb``. + same_schema (bool): Whether to require that all the dataset Parquets have exactly the same + Parquet schema. This is a correctness guard rail, preventing non-dataset Parquet shards + from sneaking into our dataset. Streaming for its part is fine with shards being + "incompatible"; assumes client will handle it. Defaults to ``True``. + columns (Dict[str, str], optional): For field names and types specified here, override the + inferred schema to configure it manually. Defaults to ``None``. + show_progress (bool): Show progress bar for download/processing. Defaults to ``True``. + + Returns: + Dict[str, Any]: StreamingDataset index configuration to stream this Parquet dataset. + """ + raise NotImplementedError # TODO From bd0208afdd1b4f109b87f98019a6dd8822f8d7fc Mon Sep 17 00:00:00 2001 From: James Knighton Date: Fri, 27 Oct 2023 23:25:11 -0700 Subject: [PATCH 02/45] index_backend(). --- streaming/base/format/__init__.py | 79 ++++++++++++++++++++++- streaming/base/format/delta/indexing.py | 2 +- streaming/base/format/lance/indexing.py | 2 +- streaming/base/format/parquet/indexing.py | 2 +- 4 files changed, 80 insertions(+), 5 deletions(-) diff --git a/streaming/base/format/__init__.py b/streaming/base/format/__init__.py index d819a6a9a..7b21a8cb3 100644 --- a/streaming/base/format/__init__.py +++ b/streaming/base/format/__init__.py @@ -3,7 +3,7 @@ """Individual dataset writer for every format.""" -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from streaming.base.format.base import FileInfo, Reader from streaming.base.format.delta import index_delta @@ -17,7 +17,8 @@ __all__ = [ 'CSVWriter', 'FileInfo', 'JSONWriter', 'MDSWriter', 'Reader', 'TSVWriter', 'XSVWriter', - 'get_index_basename', 'index_delta', 'index_lance', 'index_parquet', 'reader_from_json' + 'get_index_basename', 'index_backend', 'index_delta', 'index_lance', 'index_parquet', + 'reader_from_json' ] _readers = { @@ -43,3 +44,77 @@ def reader_from_json(dirname: str, split: Optional[str], obj: Dict[str, Any]) -> assert obj['version'] == 2 cls = _readers[obj['format']] return cls.from_json(dirname, split, obj) + + +def index_backend(backend: str, + local: str, + remote: Optional[str] = None, + split: Optional[str] = None, + version: Optional[int] = None, + num_procs: Optional[int] = 0, + download_timeout: Union[float, str] = '1m', + max_file_bytes: Union[int, str] = '200mb', + same_schema: bool = True, + columns: Optional[Dict[str, Any]] = None, + show_progress: bool = True) -> Dict[str, Any]: + """Index a local and/or remote third-party dataset directory for use by Streaming. + + Args: + backend (str): What dataset/database system serves this entire dataset, whose files we + convert, wrap, or both as Streaming shards. Must be one of ``delta`` (Delta table), + ``lance`` (Lance dataset), or ``parquet`` (Parquet dataset) (if ``streaming``, the + index is created at dataset write time). + local (str): Where the dataset is cached on the local filesystem. + remote (str, optional): Where the dataset is downloaded from. Defaults to ``None``. + split (str, optional): Which dataset split to use. Defaults to ``None``. + version (int, optional): Dataset snapshot version (used by ``delta`` and ``lance`` + datasets). 
If not provided, takes the latest version. Defaults to ``None``. + num_procs (int, optional): Parallelism for downloading/processing of third-party dataset + files. ``None`` means single-process. ``0`` means processes. Positive + integer means use that number of processes. Defaults to ``0``. + download_timeout (Union[float, str]): For each Parquet file. Defaults to ``2m``. + max_file_bytes (Union[int, str], optional): File size limit, above which we raise an error. + This is a performance guard rail, as choppiness increases linearly with shard size. The + sweet spot is typically around 32mb. Defaults to ``200mb``. + same_schema (bool): Whether to require that all the dataset shards have exactly the same + MDS schema. Applicable to indexless Parquet datasets. This is a correctness guard rail, + preventingh non-dataset shards from sneaking into our dataset. Streaming for its part + is fine with shards being "incompatible"; assumes client will handle it. Defaults to + ``True``. + columns (Dict[str, str], optional): For field names and types specified here, override the + inferred schema to configure it manually. Defaults to ``None``. + show_progress (bool): Show progress bar for download/processing. Defaults to ``True``. + + Returns: + Dict[str, Any]: StreamingDataset index configuration to stream this Parquet dataset. + """ + if backend == 'delta': + return index_delta(local=local, + remote=remote, + split=split, + version=version, + num_threads=num_procs, + download_timeout=download_timeout, + max_file_bytes=max_file_bytes, + columns=columns, + show_progress=show_progress) + elif backend == 'lance': + return index_lance(local=local, + remote=remote, + split=split, + version=version, + download_timeout=download_timeout, + max_file_bytes=max_file_bytes, + columns=columns) + elif backend == 'parquet': + return index_parquet(local=local, + remote=remote, + split=split, + num_procs=num_procs, + download_timeout=download_timeout, + max_file_bytes=max_file_bytes, + same_schema=same_schema, + columns=columns, + show_progress=show_progress) + else: + raise ValueError('Unsupported backend: {backend}.') diff --git a/streaming/base/format/delta/indexing.py b/streaming/base/format/delta/indexing.py index 6071ae0a8..75656d8a4 100644 --- a/streaming/base/format/delta/indexing.py +++ b/streaming/base/format/delta/indexing.py @@ -16,7 +16,7 @@ def index_delta(*, max_file_bytes: Optional[Union[int, str]] = '200mb', columns: Optional[Dict[str, Optional[str]]] = None, show_progress: bool = True) -> Dict[str, Any]: - """Initialize from a local and/or remote Delta table directory. + """Index a local and/or remote Delta table directory for use by Streaming. Args: local (str): Where the dataset is cached on the local filesystem. diff --git a/streaming/base/format/lance/indexing.py b/streaming/base/format/lance/indexing.py index 65e6da30b..6ad3d0cd5 100644 --- a/streaming/base/format/lance/indexing.py +++ b/streaming/base/format/lance/indexing.py @@ -14,7 +14,7 @@ def index_lance(*, download_timeout: Union[float, str] = '1m', max_file_bytes: Optional[Union[int, str]] = '200mb', columns: Optional[Dict[str, str]] = None) -> Dict[str, Any]: - """Initialize from a local and/or remote Lance dataset directory. + """Index a local and/or remote Lance dataset directory for use by Streaming. Args: local (str): Where the dataset is cached on the local filesystem. 
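For orientation, a minimal sketch of how ``index_backend()`` (added earlier in this patch to ``streaming/base/format/__init__.py``) is expected to be called once the per-backend stubs are implemented. The backend choice, paths, and bucket name below are placeholders, and at this point in the series all three backends still raise ``NotImplementedError``:

    from streaming.base.format import index_backend

    # Build a StreamingDataset index config (a dict) for a third-party dataset.
    index = index_backend(backend='parquet',                # 'delta', 'lance', or 'parquet'
                          local='/tmp/cache/numbers',       # placeholder local cache dir
                          remote='s3://my-bucket/numbers',  # placeholder remote dataset root
                          split='train',
                          download_timeout='2m',
                          max_file_bytes='200mb')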
diff --git a/streaming/base/format/parquet/indexing.py b/streaming/base/format/parquet/indexing.py index a933a6f5a..a8cfddca4 100644 --- a/streaming/base/format/parquet/indexing.py +++ b/streaming/base/format/parquet/indexing.py @@ -21,7 +21,7 @@ def index_parquet(*, same_schema: bool = True, columns: Optional[Dict[str, str]] = None, show_progress: bool = True) -> Dict[str, Any]: - r"""Initialize from a local and/or remote Parquet dataset directory. + """Index a local and/or remote Parquet dataset directory for use by Streaming. "Parquet dataset" means the samples live in a collection of naked Parquet files. There is not any kind of index or manifest we can count on existing, so we will have to create one. From 69575bc334d199b73fdbff355551a58db2264965 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Fri, 27 Oct 2023 23:37:41 -0700 Subject: [PATCH 03/45] Fix. --- streaming/base/format/parquet/indexing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/format/parquet/indexing.py b/streaming/base/format/parquet/indexing.py index a8cfddca4..c68b1b3be 100644 --- a/streaming/base/format/parquet/indexing.py +++ b/streaming/base/format/parquet/indexing.py @@ -21,7 +21,7 @@ def index_parquet(*, same_schema: bool = True, columns: Optional[Dict[str, str]] = None, show_progress: bool = True) -> Dict[str, Any]: - """Index a local and/or remote Parquet dataset directory for use by Streaming. + r"""Index a local and/or remote Parquet dataset directory for use by Streaming. "Parquet dataset" means the samples live in a collection of naked Parquet files. There is not any kind of index or manifest we can count on existing, so we will have to create one. From 42b59f106a05123a0ad2f5fec9067ba225d1e6bc Mon Sep 17 00:00:00 2001 From: James Knighton Date: Fri, 27 Oct 2023 23:38:43 -0700 Subject: [PATCH 04/45] task.py for benchmarking. --- benchmarks/backends-and-formats/task.py | 107 ++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 benchmarks/backends-and-formats/task.py diff --git a/benchmarks/backends-and-formats/task.py b/benchmarks/backends-and-formats/task.py new file mode 100644 index 000000000..0bf550f55 --- /dev/null +++ b/benchmarks/backends-and-formats/task.py @@ -0,0 +1,107 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Generate infinite samples for a 'saying numbers as words' task.""" + +from typing import List, Tuple + +import numpy as np +from tqdm import tqdm + +_ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen ' + 'fifteen sixteen seventeen eighteen nineteen').split() + +_tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split() + + +def _say(i: int) -> List[str]: + """Get the word form of a number. + + Args: + i (int): The number. + + Returns: + List[str]: The number in word form. + """ + if i < 0: + return ['negative'] + _say(-i) + elif i <= 19: + return [_ones[i]] + elif i < 100: + return [_tens[i // 10 - 2]] + ([_ones[i % 10]] if i % 10 else []) + elif i < 1_000: + return [_ones[i // 100], 'hundred'] + (_say(i % 100) if i % 100 else []) + elif i < 1_000_000: + return _say(i // 1_000) + ['thousand'] + (_say(i % 1_000) if i % 1_000 else []) + elif i < 1_000_000_000: + return _say(i // 1_000_000) + ['million'] + (_say(i % 1_000_000) if i % 1_000_000 else []) + else: + raise ValueError('Integer must be less than a billion, but got: {i}') + + +def _generate_number() -> int: + """Generate a random integer to say. 
+ + Returns: + int: The integer. + """ + sign = (np.random.uniform() < 0.8) * 2 - 1 + expt = np.random.uniform(0, 9) + mag = int(10**expt) + return sign * mag + + +def _generate_numbers(num_train: int, num_val: int, + show_progress: bool) -> Tuple[List[int], List[int]]: + """Get two non-overlapping splits of integers to say. + + Args: + num_train (int): Number of training samples. + num_val (int): Number of validation samples. + show_progress (bool): Whether to display a progress bar. + + Returns: + Tuple[List[int], List[int]]: The two generated splits. + """ + total = num_train + num_val + nums = set() + pbar = tqdm(total=total, leave=False) if show_progress else None + while len(nums) < total: + num = _generate_number() + if num in nums: + continue + nums.add(num) + if pbar: + pbar.update(1) + if pbar: + pbar.close() + nums = sorted(nums) + np.random.shuffle(nums) + train_nums = nums[:num_train] + val_nums = nums[num_train:] + return train_nums, val_nums + + +_split_type = Tuple[str, List[int], List[str]] + + +def generate_dataset(num_train: int, num_val: int, show_progress: bool) -> List[_split_type]: + """Generate the dataset, which will be saved in different forms for comparison. + + Args: + num_train (int): Number of train samples. + num_val (int): Number of val samples. + show_progress (bool): Whether to show a progress bar. + + Returns: + List[Tuple[str, List[int], List[str]]]: List of dataset splits. + """ + train_nums, val_nums = _generate_numbers(num_train, num_val, show_progress) + + train_txts = [' '.join(_say(num)) for num in train_nums] + val_txts = [' '.join(_say(num)) for num in val_nums] + + return [ + ('train', train_nums, train_txts), + ('val', val_nums, val_txts), + ] From 2fb1b09879f671ef9e4a8c86829a7f190703c139 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 00:10:16 -0700 Subject: [PATCH 05/45] generate_datasets.py. 
--- benchmarks/backends-and-formats/__init__.py | 4 + .../backends-and-formats/generate_datasets.py | 293 ++++++++++++++++++ 2 files changed, 297 insertions(+) create mode 100644 benchmarks/backends-and-formats/__init__.py create mode 100644 benchmarks/backends-and-formats/generate_datasets.py diff --git a/benchmarks/backends-and-formats/__init__.py b/benchmarks/backends-and-formats/__init__.py new file mode 100644 index 000000000..7ce06d32c --- /dev/null +++ b/benchmarks/backends-and-formats/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Benchmarking generating/iterating datasets of different backends and formats.""" diff --git a/benchmarks/backends-and-formats/generate_datasets.py b/benchmarks/backends-and-formats/generate_datasets.py new file mode 100644 index 000000000..1bb661863 --- /dev/null +++ b/benchmarks/backends-and-formats/generate_datasets.py @@ -0,0 +1,293 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Generate a parquet dataset for testing.""" + +import os +from argparse import ArgumentParser, Namespace +from functools import partial +from shutil import rmtree +from time import time +from typing import List, Optional + +import lance +import pyarrow as pa +import pyspark +import pyspark.sql +from delta import configure_spark_with_delta_pip +from pyarrow import parquet as pq +from pyspark.sql.types import IntegerType, StringType, StructField, StructType +from task import generate_dataset +from tqdm import tqdm +from wurlitzer import pipes + +from streaming import CSVWriter, JSONWriter, MDSWriter + + +def parse_args() -> Namespace: + """Parse command-line arguments. + + Returns: + Namespace: Command-line arguments. + """ + args = ArgumentParser() + args.add_argument('--show_progress', type=int, default=1) + + args.add_argument('--seed', type=int, default=1337) + args.add_argument('--num_train', type=int, default=1 << 21) + args.add_argument('--num_val', type=int, default=1 << 17) + + args.add_argument('--data_root', type=str, default='data/compare-backends/') + args.add_argument('--csv', type=str, default='csv') + args.add_argument('--jsonl', type=str, default='jsonl') + args.add_argument('--lance', type=str, default='lance') + args.add_argument('--mds', type=str, default='mds') + args.add_argument('--parquet', type=str, default='parquet') + args.add_argument('--delta', type=str, default='delta') + + args.add_argument('--size_limit', type=int, default=1 << 23) + args.add_argument('--samples_per_shard', type=int, default=1 << 17) + args.add_argument('--quiet_delta', type=int, default=1) + return args.parse_args() + + +def _save_csv(nums: List[int], + txts: List[str], + root: str, + size_limit: Optional[int], + show_progress: bool = True) -> None: + """Save the dataset in Streaming CSV form. + + Args: + nums (List[int]): The sample numbers. + txts (List[str]): The sample texts. + root (str): Root directory. + size_limit (int, optional): Maximum shard size in bytes, or no limit. + show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``. 
+ """ + columns = {'num': 'int', 'txt': 'str'} + with CSVWriter(out=root, columns=columns, size_limit=size_limit) as out: + each_sample = zip(nums, txts) + if show_progress: + each_sample = tqdm(each_sample, total=len(nums), leave=False) + for num, txt in each_sample: + sample = {'num': num, 'txt': txt} + out.write(sample) + + +def _save_jsonl(nums: List[int], + txts: List[str], + root: str, + size_limit: Optional[int], + show_progress: bool = True) -> None: + """Save the dataset Streaming JSONL form. + + Args: + nums (List[int]): The sample numbers. + txts (List[str]): The sample texts. + root (str): Root directory. + size_limit (int, optional): Maximum shard size in bytes, or no limit. + show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``. + """ + columns = {'num': 'int', 'txt': 'str'} + with JSONWriter(out=root, columns=columns, size_limit=size_limit) as out: + each_sample = zip(nums, txts) + if show_progress: + each_sample = tqdm(each_sample, total=len(nums), leave=False) + for num, txt in each_sample: + sample = {'num': num, 'txt': txt} + out.write(sample) + + +def _save_mds(nums: List[int], + txts: List[str], + root: str, + size_limit: Optional[int], + show_progress: bool = True) -> None: + """Save the dataset in Streaming MDS form. + + Args: + nums (List[int]): The sample numbers. + txts (List[str]): The sample texts. + root (str): Root directory. + size_limit (int, optional): Maximum shard size in bytes, or no limit. + show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``. + """ + columns = {'num': 'int', 'txt': 'str'} + with MDSWriter(out=root, columns=columns, size_limit=size_limit) as out: + each_sample = zip(nums, txts) + if show_progress: + each_sample = tqdm(each_sample, total=len(nums), leave=False) + for num, txt in each_sample: + sample = {'num': num, 'txt': txt} + out.write(sample) + + +def _save_parquet(nums: List[int], + txts: List[str], + root: str, + samples_per_shard: int, + show_progress: bool = True) -> None: + """Save the dataset in Streaming MDS form. + + Args: + nums (List[int]): The sample numbers. + txts (List[str]): The sample texts. + root (str): Root directory. + samples_per_shard (int): Maximum numbero of samples per shard. + show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``. + """ + if not os.path.exists(root): + os.makedirs(root) + num_samples = len(nums) + num_shards = (num_samples + samples_per_shard - 1) // samples_per_shard + each_shard = range(num_shards) + if show_progress: + each_shard = tqdm(each_shard, total=num_shards, leave=False) + for i in each_shard: + begin = i * samples_per_shard + end = min(begin + samples_per_shard, num_samples) + shard_nums = nums[begin:end] + shard_txts = txts[begin:end] + path = os.path.join(root, f'{i:05}.parquet') + obj = { + 'num': shard_nums, + 'txt': shard_txts, + } + table = pa.Table.from_pydict(obj) + pq.write_table(table, path) + + +def _wrapped_save_delta(nums: List[int], txts: List[str], root: str, + samples_per_shard: int) -> None: + """Save the dataset in Streaming MDS form. + + Args: + nums (List[int]): The sample numbers. + txts (List[str]): The sample texts. + root (str): Root directory. + samples_per_shard (int): Maximum numbero of samples per shard. 
+ """ + builder = pyspark.sql.SparkSession.builder.appName('deltatorch-example') # pyright: ignore + builder = builder.config('spark.sql.extensions', 'io.delta.sql.DeltaSparkSessionExtension') + builder = builder.config('spark.sql.catalog.spark_catalog', + 'org.apache.spark.sql.delta.catalog.DeltaCatalog') + spark = configure_spark_with_delta_pip(builder).getOrCreate() + schema = StructType([ + StructField('num', IntegerType(), False), + StructField('txt', StringType(), False), + ]) + samples = list(zip(nums, txts)) + df = spark.createDataFrame(samples, schema) + df.write.format('delta').option('maxRecordsPerFile', samples_per_shard).save(root) + + +def _save_delta(nums: List[int], + txts: List[str], + root: str, + samples_per_shard: int, + quiet: bool = True) -> None: + """Save the dataset in Streaming MDS form. + + Args: + nums (List[int]): The sample numbers. + txts (List[str]): The sample texts. + root (str): Root directory. + samples_per_shard (int): Maximum numbero of samples per shard. + quiet (bool): Whether to capture the Delta logging. Defaults to ``True``. + """ + bang_on_pipes = lambda: _wrapped_save_delta(nums, txts, root, samples_per_shard) + if quiet: + with pipes(): + bang_on_pipes() + else: + bang_on_pipes() + + +def _save_lance(nums: List[int], txts: List[str], root: str, samples_per_shard: int) -> None: + """Save the dataset in Lance form. + + Args: + nums (List[int]): The sample numbers. + txts (List[str]): The sample texts. + root (str): Root directory. + samples_per_shard (int): Maximum numbero of samples per shard. + """ + column_names = 'num', 'txt' + column_values = nums, txts + table = pa.Table.from_arrays(column_values, column_names) + lance.write_dataset(table, root, mode='create', max_rows_per_file=samples_per_shard) + + +def _stat(root: str): + """Inventory what was written, collecting total files and total bytes. + + Args: + root (str): Dataset root. + + Returns: + Tuple[int, int]: Total files and total bytes written. + """ + rf = 0 + rz = 0 + for p, _, ff in os.walk(root): + rf += len(ff) + for f in ff: + g = os.path.join(p, f) + rz += os.stat(g).st_size + return rf, rz + + +def main(args: Namespace) -> None: + """Generate identical datasets in various formats for performance comparison. + + Args: + args (Namespace): Command-line arguments. 
+ """ + if os.path.exists(args.data_root): + rmtree(args.data_root) + + kinds = 'csv', 'jsonl', 'lance', 'mds', 'parquet', 'delta' + + show_progress = bool(args.show_progress) + quiet_delta = bool(args.quiet_delta) + + kind2save = { + 'csv': + partial(_save_csv, size_limit=args.size_limit, show_progress=show_progress), + 'delta': + partial(_save_delta, samples_per_shard=args.samples_per_shard, quiet=quiet_delta), + 'jsonl': + partial(_save_jsonl, size_limit=args.size_limit, show_progress=show_progress), + 'lance': + partial(_save_lance, samples_per_shard=args.samples_per_shard), + 'mds': + partial(_save_mds, size_limit=args.size_limit, show_progress=show_progress), + 'parquet': + partial(_save_parquet, + samples_per_shard=args.samples_per_shard, + show_progress=show_progress), + } + + start = time() + dataset = generate_dataset(args.num_train, args.num_val, show_progress) + elapsed = time() - start + print(f'Dataset generation: {elapsed:.3f} sec.') + + for split, nums, txts in dataset: + print(f'Split {split}:') + for kind in kinds: + kind_subdir = getattr(args, kind) + split_root = os.path.join(args.data_root, 'gold', kind_subdir, split) + save = kind2save[kind] + start = time() + save(nums, txts, split_root) + elapsed = time() - start + num_files, num_bytes = _stat(split_root) + bytes_per_file = num_bytes // num_files + print(f'* Saving dataset as {kind:8}: {elapsed:8.3f} sec; {num_files:3,} files; ' + + f'{num_bytes:12,} bytes; {bytes_per_file:12,} bytes/file.') + + +if __name__ == '__main__': + main(parse_args()) From 11dd6737a09e00f74b773a7d27babc49a51f1b08 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 00:16:47 -0700 Subject: [PATCH 06/45] Fix. --- streaming/base/format/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/format/__init__.py b/streaming/base/format/__init__.py index 7b21a8cb3..82b506562 100644 --- a/streaming/base/format/__init__.py +++ b/streaming/base/format/__init__.py @@ -53,7 +53,7 @@ def index_backend(backend: str, version: Optional[int] = None, num_procs: Optional[int] = 0, download_timeout: Union[float, str] = '1m', - max_file_bytes: Union[int, str] = '200mb', + max_file_bytes: Optional[Union[int, str]] = '200mb', same_schema: bool = True, columns: Optional[Dict[str, Any]] = None, show_progress: bool = True) -> Dict[str, Any]: From 82737e002fdeb1722e1f81fa4ebb8da7e2b73b1e Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 00:38:15 -0700 Subject: [PATCH 07/45] Organize/divide streaming/base/util.py: Into: - importing - merging, - pretty - retrying - shared - storage. 
---
 streaming/base/util/__init__.py             |  16 ++
 streaming/base/util/importing.py            |  20 ++
 streaming/base/{util.py => util/merging.py} | 297 +-------------------
 streaming/base/util/pretty.py               | 115 ++++++++
 streaming/base/util/retrying.py             | 109 +++++++
 streaming/base/util/shared.py               |  48 ++++
 streaming/base/util/storage.py              |  33 +++
 7 files changed, 345 insertions(+), 293 deletions(-)
 create mode 100644 streaming/base/util/__init__.py
 create mode 100644 streaming/base/util/importing.py
 rename streaming/base/{util.py => util/merging.py} (51%)
 create mode 100644 streaming/base/util/pretty.py
 create mode 100644 streaming/base/util/retrying.py
 create mode 100644 streaming/base/util/shared.py
 create mode 100644 streaming/base/util/storage.py

diff --git a/streaming/base/util/__init__.py b/streaming/base/util/__init__.py
new file mode 100644
index 000000000..0ab5e2492
--- /dev/null
+++ b/streaming/base/util/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2023 MosaicML Streaming authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Utilities and helkper methods needed by Streaming."""
+
+from streaming.base.util.pretty import bytes_to_int, get_list_arg, number_abbrev_to_int
+from streaming.base.util.importing import get_import_exception_message
+from streaming.base.util.merging import merge_index
+from streaming.base.util.retrying import retry
+from streaming.base.util.shared import clean_stale_shared_memory
+from streaming.base.util.storage import wait_for_file_to_exist
+
+__all__ = [
+    'bytes_to_int', 'clean_stale_shared_memory', 'get_import_exception_message', 'get_list_arg',
+    'merge_index', 'number_abbrev_to_int', 'retry', 'wait_for_file_to_exist'
+]
diff --git a/streaming/base/util/importing.py b/streaming/base/util/importing.py
new file mode 100644
index 000000000..7cf62d6f6
--- /dev/null
+++ b/streaming/base/util/importing.py
@@ -0,0 +1,20 @@
+# Copyright 2023 MosaicML Streaming authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""User-friendly import exception message."""
+
+__all__ = ['get_import_exception_message']
+
+
+def get_import_exception_message(package_name: str, extra_deps: str) -> str:
+    """Get import exception message.
+
+    Args:
+        package_name (str): Package name.
+
+    Returns:
+        str: Exception message.
+    """
+    return f'Streaming was installed without {package_name} support. ' + \
+        f'To use {package_name} related packages with Streaming, run ' + \
+        f'`pip install \'mosaicml-streaming[{package_name}]\'`.'
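For context, a minimal sketch of how ``get_import_exception_message()`` is typically used to guard an optional import. The ``pyspark`` / ``'spark'`` names below are illustrative stand-ins, not something this patch adds:

    from streaming.base.util import get_import_exception_message

    try:
        import pyspark  # hypothetical optional dependency
    except ImportError as error:
        # Re-raise with a friendlier hint about the matching pip extra.
        raise ImportError(get_import_exception_message('pyspark', extra_deps='spark')) from error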
diff --git a/streaming/base/util.py b/streaming/base/util/merging.py similarity index 51% rename from streaming/base/util.py rename to streaming/base/util/merging.py index e86876ee1..8411d5cec 100644 --- a/streaming/base/util.py +++ b/streaming/base/util/merging.py @@ -1,219 +1,23 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Utility and helper functions for datasets.""" +"""Merging serialized streaming datasets.""" -import collections.abc -import functools import json import logging import os -import random import shutil import tempfile import urllib.parse from collections import OrderedDict -from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory from pathlib import Path -from time import sleep, time -from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, cast, overload +from typing import Any, List, Tuple, Union -import torch.distributed as dist - -from streaming.base.constant import SHM_TO_CLEAN -from streaming.base.distributed import get_local_rank, maybe_init_dist from streaming.base.format.index import get_index_basename -from streaming.base.shared.prefix import _get_path - -logger = logging.getLogger(__name__) - -TCallable = TypeVar('TCallable', bound=Callable) - -__all__ = [ - 'get_list_arg', 'wait_for_file_to_exist', 'bytes_to_int', 'number_abbrev_to_int', - 'clean_stale_shared_memory', 'get_import_exception_message', 'merge_index', 'retry' -] - - -def get_list_arg(text: str) -> List[str]: - """Pass a list as a command-line flag. - - Args: - text (str): Text to split. - - Returns: - List[str]: Splits, if any. - """ - return text.split(',') if text else [] - - -def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float, - err_msg: str) -> None: - """Wait for the file to exist till timeout seconds. Raise an Exception after that. - - Args: - filename (str): A file name - poll_interval (float): Number of seconds to wait before next polling - timeout (float): Number of seconds to wait for a file to exist before raising an exception - err_msg (str): Error message description for an exception - - Raises: - RuntimeError: Raise an Exception if file does not exist after timeout - """ - start_time = time() - while True: - sleep(poll_interval) - if os.path.exists(filename): - sleep(poll_interval) - break - dt = time() - start_time - if dt > timeout: - raise RuntimeError(f'{err_msg}' + f'{timeout:.3f} < {dt:.3f} secs.') - - -def bytes_to_int(bytes_str: Union[int, str]) -> int: - """Convert human readable byte format to an integer. - - Args: - bytes_str (Union[int, str]): Value to convert. - - Raises: - ValueError: Invalid byte suffix. - - Returns: - int: Integer value of bytes. - """ - #input is already an int - if isinstance(bytes_str, int) or isinstance(bytes_str, float): - return int(bytes_str) - - units = { - 'kb': 1024, - 'mb': 1024**2, - 'gb': 1024**3, - 'tb': 1024**4, - 'pb': 1024**5, - 'eb': 1024**6, - 'zb': 1024**7, - 'yb': 1024**8, - } - # Convert a various byte types to an integer - for suffix in units: - bytes_str = bytes_str.lower().strip() - if bytes_str.lower().endswith(suffix): - try: - return int(float(bytes_str[0:-len(suffix)]) * units[suffix]) - except ValueError: - raise ValueError(''.join([ - f'Unsupported value/suffix {bytes_str}. Supported suffix are ', - f'{["b"] + list(units.keys())}.' 
- ])) - else: - # Convert bytes to an integer - if bytes_str.endswith('b') and bytes_str[0:-1].isdigit(): - return int(bytes_str[0:-1]) - # Convert string representation of a number to an integer - elif bytes_str.isdigit(): - return int(bytes_str) - else: - raise ValueError(''.join([ - f'Unsupported value/suffix {bytes_str}. Supported suffix are ', - f'{["b"] + list(units.keys())}.' - ])) - - -def number_abbrev_to_int(abbrev_str: Union[int, str]) -> int: - """Convert human readable number abbreviations to an integer. - - Args: - abbrev_str (Union[int, str]): Value to convert. - Raises: - ValueError: Invalid number suffix. +__all__ = ['merge_index'] - Returns: - int: Integer value of number abbreviation. - """ - #input is already an int - if isinstance(abbrev_str, int) or isinstance(abbrev_str, float): - return int(abbrev_str) - - units = { - 'k': 10**3, - 'm': 10**6, - 'b': 10**9, - 't': 10**12, - } - # Convert a various abbreviation types to an integer - for suffix in units: - abbrev_str = abbrev_str.lower().strip() - if abbrev_str.lower().endswith(suffix): - try: - return int(float(abbrev_str[0:-len(suffix)]) * units[suffix]) - except ValueError: - raise ValueError(''.join([ - f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', - f'{list(units.keys())}.' - ])) - else: - # Convert string representation of a number to an integer - if abbrev_str.isdigit(): - return int(abbrev_str) - else: - raise ValueError(''.join([ - f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', - f'{list(units.keys())}.' - ])) - - -def clean_stale_shared_memory() -> None: - """Clean up all the leaked shared memory. - - In case of a distributed run, clean up happens on local rank 0 while other local ranks wait for - the local rank 0 to finish. - """ - # Initialize torch.distributed ourselves, if necessary. - destroy_dist = maybe_init_dist() - - # Perform clean up on local rank 0 - if get_local_rank() == 0: - for prefix_int in range(1000000): - leaked_shm = False - for shm_name in SHM_TO_CLEAN: - name = _get_path(prefix_int, shm_name) - try: - shm = BuiltinSharedMemory(name, True, 4) - except FileExistsError: - shm = BuiltinSharedMemory(name, False, 4) - leaked_shm = True - finally: - shm.close() # pyright: ignore - shm.unlink() - # Come out of loop if no leaked shared memory - if not leaked_shm: - break - - # Sync all ranks - if dist.is_available() and dist.is_initialized(): - dist.barrier() - - # Delete the process group if Streaming initialized it. - if destroy_dist: - dist.destroy_process_group() - - -def get_import_exception_message(package_name: str, extra_deps: str) -> str: - """Get import exception message. - - Args: - package_name (str): Package name. - - Returns: - str: Exception message. - """ - return f'Streaming was installed without {package_name} support. ' + \ - f'To use {package_name} related packages with Streaming, run ' + \ - f'`pip install \'mosaicml-streaming[{package_name}]\'`.' +logger = logging.getLogger(__name__) def merge_index(*args: Any, **kwargs: Any): @@ -430,96 +234,3 @@ def not_merged_index(index_file_path: str, out: str): out, keep_local=keep_local, download_timeout=download_timeout) - - -@overload -def retry( - exc_class: Union[Type[Exception], Sequence[Type[Exception]]] = ..., - num_attempts: int = ..., - initial_backoff: float = ..., - max_jitter: float = ..., -) -> Callable[[TCallable], TCallable]: - ... - - -@overload -def retry(exc_class: TCallable) -> TCallable: - # Use the decorator without parenthesis - ... 
- - -# error: Type "(TCallable@retry) -> TCallable@retry" cannot be assigned to type -# "(func: Never) -> Never" -def retry( # type: ignore - exc_class: Union[TCallable, Type[Exception], Sequence[Type[Exception]]] = Exception, - num_attempts: int = 3, - initial_backoff: float = 1.0, - max_jitter: float = 0.5, -): - """Decorator to retry a function with backoff and jitter. - - Attempts are spaced out with - ``initial_backoff * 2**num_attempts + random.random() * max_jitter`` seconds. - - Example: - .. testcode:: - - from streaming.base.util import retry - - num_tries = 0 - - @retry(RuntimeError, num_attempts=3, initial_backoff=0.1) - def flaky_function(): - global num_tries - if num_tries < 2: - num_tries += 1 - raise RuntimeError("Called too soon!") - return "Third time's a charm." - - print(flaky_function()) - - .. testoutput:: - - Third time's a charm. - - Args: - exc_class (Type[Exception] | Sequence[Type[Exception]]], optional): The exception class or - classes to retry. Defaults to Exception. - num_attempts (int, optional): The total number of attempts to make. Defaults to 3. - initial_backoff (float, optional): The initial backoff, in seconds. Defaults to 1.0. - max_jitter (float, optional): The maximum amount of random jitter to add. Defaults to 0.5. - - Increasing the ``max_jitter`` can help prevent overloading a resource when multiple - processes in parallel are calling the same underlying function. - """ - if num_attempts < 1: - raise ValueError('num_attempts must be at-least 1') - - def wrapped_func(func: TCallable) -> TCallable: - - @functools.wraps(func) - def new_func(*args: Any, **kwargs: Any): - i = 0 - while True: - try: - return func(*args, **kwargs) - except exc_class as e: - if i + 1 == num_attempts: - logger.debug(f'Attempt {i + 1}/{num_attempts} failed with: {e}') - raise e - else: - sleep(initial_backoff * 2**i + random.random() * max_jitter) - logger.debug(f'Attempt {i + 1}/{num_attempts} failed with: {e}') - i += 1 - - return cast(TCallable, new_func) - - if not isinstance(exc_class, collections.abc.Sequence) and not (isinstance( - exc_class, type) and issubclass(exc_class, Exception)): - # Using the decorator without (), like @retry_with_backoff - func = cast(TCallable, exc_class) - exc_class = Exception - - return wrapped_func(func) - - return wrapped_func diff --git a/streaming/base/util/pretty.py b/streaming/base/util/pretty.py new file mode 100644 index 000000000..fc63b699f --- /dev/null +++ b/streaming/base/util/pretty.py @@ -0,0 +1,115 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Conversions between human-friendly string forms and int/float.""" + +from typing import List, Union + +__all__ = ['bytes_to_int', 'get_list_arg', 'number_abbrev_to_int'] + + +def get_list_arg(text: str) -> List[str]: + """Pass a list as a comma-delimted string. + + Args: + text (str): Text to split. + + Returns: + List[str]: Splits, if any. + """ + return text.split(',') if text else [] + + +def bytes_to_int(bytes_str: Union[int, str]) -> int: + """Convert human readable byte format to an integer. + + Args: + bytes_str (Union[int, str]): Value to convert. + + Raises: + ValueError: Invalid byte suffix. + + Returns: + int: Integer value of bytes. 
+ """ + #input is already an int + if isinstance(bytes_str, int) or isinstance(bytes_str, float): + return int(bytes_str) + + units = { + 'kb': 1024, + 'mb': 1024**2, + 'gb': 1024**3, + 'tb': 1024**4, + 'pb': 1024**5, + 'eb': 1024**6, + 'zb': 1024**7, + 'yb': 1024**8, + } + # Convert a various byte types to an integer + for suffix in units: + bytes_str = bytes_str.lower().strip() + if bytes_str.lower().endswith(suffix): + try: + return int(float(bytes_str[0:-len(suffix)]) * units[suffix]) + except ValueError: + raise ValueError(''.join([ + f'Unsupported value/suffix {bytes_str}. Supported suffix are ', + f'{["b"] + list(units.keys())}.' + ])) + else: + # Convert bytes to an integer + if bytes_str.endswith('b') and bytes_str[0:-1].isdigit(): + return int(bytes_str[0:-1]) + # Convert string representation of a number to an integer + elif bytes_str.isdigit(): + return int(bytes_str) + else: + raise ValueError(''.join([ + f'Unsupported value/suffix {bytes_str}. Supported suffix are ', + f'{["b"] + list(units.keys())}.' + ])) + + +def number_abbrev_to_int(abbrev_str: Union[int, str]) -> int: + """Convert human readable number abbreviations to an integer. + + Args: + abbrev_str (Union[int, str]): Value to convert. + + Raises: + ValueError: Invalid number suffix. + + Returns: + int: Integer value of number abbreviation. + """ + #input is already an int + if isinstance(abbrev_str, int) or isinstance(abbrev_str, float): + return int(abbrev_str) + + units = { + 'k': 10**3, + 'm': 10**6, + 'b': 10**9, + 't': 10**12, + } + # Convert a various abbreviation types to an integer + for suffix in units: + abbrev_str = abbrev_str.lower().strip() + if abbrev_str.lower().endswith(suffix): + try: + return int(float(abbrev_str[0:-len(suffix)]) * units[suffix]) + except ValueError: + raise ValueError(''.join([ + f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', + f'{list(units.keys())}.' + ])) + else: + # Convert string representation of a number to an integer + if abbrev_str.isdigit(): + return int(abbrev_str) + else: + raise ValueError(''.join([ + f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', + f'{list(units.keys())}.' + ])) diff --git a/streaming/base/util/retrying.py b/streaming/base/util/retrying.py new file mode 100644 index 000000000..3d006655b --- /dev/null +++ b/streaming/base/util/retrying.py @@ -0,0 +1,109 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Decorator that retries the wrapped function with backoff.""" + +import collections.abc +import functools +import logging +import random +from time import sleep +from typing import Any, Callable, Sequence, Type, TypeVar, Union, cast, overload + +__all__ = ['retry'] + +logger = logging.getLogger(__name__) +TCallable = TypeVar('TCallable', bound=Callable) + + +@overload +def retry( + exc_class: Union[Type[Exception], Sequence[Type[Exception]]] = ..., + num_attempts: int = ..., + initial_backoff: float = ..., + max_jitter: float = ..., +) -> Callable[[TCallable], TCallable]: + ... + + +@overload +def retry(exc_class: TCallable) -> TCallable: + # Use the decorator without parenthesis + ... + + +# error: Type "(TCallable@retry) -> TCallable@retry" cannot be assigned to type +# "(func: Never) -> Never" +def retry( # type: ignore + exc_class: Union[TCallable, Type[Exception], Sequence[Type[Exception]]] = Exception, + num_attempts: int = 3, + initial_backoff: float = 1.0, + max_jitter: float = 0.5, +): + """Decorator to retry a function with backoff and jitter. 
+ + Attempts are spaced out with + ``initial_backoff * 2**num_attempts + random.random() * max_jitter`` seconds. + + Example: + .. testcode:: + + from streaming.base.util import retry + + num_tries = 0 + + @retry(RuntimeError, num_attempts=3, initial_backoff=0.1) + def flaky_function(): + global num_tries + if num_tries < 2: + num_tries += 1 + raise RuntimeError("Called too soon!") + return "Third time's a charm." + + print(flaky_function()) + + .. testoutput:: + + Third time's a charm. + + Args: + exc_class (Type[Exception] | Sequence[Type[Exception]]], optional): The exception class or + classes to retry. Defaults to Exception. + num_attempts (int, optional): The total number of attempts to make. Defaults to 3. + initial_backoff (float, optional): The initial backoff, in seconds. Defaults to 1.0. + max_jitter (float, optional): The maximum amount of random jitter to add. Defaults to 0.5. + + Increasing the ``max_jitter`` can help prevent overloading a resource when multiple + processes in parallel are calling the same underlying function. + """ + if num_attempts < 1: + raise ValueError('num_attempts must be at-least 1') + + def wrapped_func(func: TCallable) -> TCallable: + + @functools.wraps(func) + def new_func(*args: Any, **kwargs: Any): + i = 0 + while True: + try: + return func(*args, **kwargs) + except exc_class as e: + if i + 1 == num_attempts: + logger.debug(f'Attempt {i + 1}/{num_attempts} failed with: {e}') + raise e + else: + sleep(initial_backoff * 2**i + random.random() * max_jitter) + logger.debug(f'Attempt {i + 1}/{num_attempts} failed with: {e}') + i += 1 + + return cast(TCallable, new_func) + + if not isinstance(exc_class, collections.abc.Sequence) and not (isinstance( + exc_class, type) and issubclass(exc_class, Exception)): + # Using the decorator without (), like @retry_with_backoff + func = cast(TCallable, exc_class) + exc_class = Exception + + return wrapped_func(func) + + return wrapped_func diff --git a/streaming/base/util/shared.py b/streaming/base/util/shared.py new file mode 100644 index 000000000..956d3427c --- /dev/null +++ b/streaming/base/util/shared.py @@ -0,0 +1,48 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Shared memory utilities.""" + +from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory + +import torch.distributed as dist + +from streaming.base.constant import SHM_TO_CLEAN +from streaming.base.distributed import get_local_rank, maybe_init_dist +from streaming.base.shared.prefix import _get_path + + +def clean_stale_shared_memory() -> None: + """Clean up all the leaked shared memory. + + In case of a distributed run, clean up happens on local rank 0 while other local ranks wait for + the local rank 0 to finish. + """ + # Initialize torch.distributed ourselves, if necessary. + destroy_dist = maybe_init_dist() + + # Perform clean up on local rank 0 + if get_local_rank() == 0: + for prefix_int in range(1000000): + leaked_shm = False + for shm_name in SHM_TO_CLEAN: + name = _get_path(prefix_int, shm_name) + try: + shm = BuiltinSharedMemory(name, True, 4) + except FileExistsError: + shm = BuiltinSharedMemory(name, False, 4) + leaked_shm = True + finally: + shm.close() # pyright: ignore + shm.unlink() + # Come out of loop if no leaked shared memory + if not leaked_shm: + break + + # Sync all ranks + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + # Delete the process group if Streaming initialized it. 
+ if destroy_dist: + dist.destroy_process_group() diff --git a/streaming/base/util/storage.py b/streaming/base/util/storage.py new file mode 100644 index 000000000..470b02869 --- /dev/null +++ b/streaming/base/util/storage.py @@ -0,0 +1,33 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Storage utilities and helpers.""" + +import os +from time import sleep, time + +__all__ = ['wait_for_file_to_exist'] + + +def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float, + err_msg: str) -> None: + """Wait for the file to exist till timeout seconds. Raise an Exception after that. + + Args: + filename (str): A file name + poll_interval (float): Number of seconds to wait before next polling + timeout (float): Number of seconds to wait for a file to exist before raising an exception + err_msg (str): Error message description for an exception + + Raises: + RuntimeError: Raise an Exception if file does not exist after timeout + """ + start_time = time() + while True: + sleep(poll_interval) + if os.path.exists(filename): + sleep(poll_interval) + break + dt = time() - start_time + if dt > timeout: + raise RuntimeError(f'{err_msg}' + f'{timeout:.3f} < {dt:.3f} secs.') From 3212f6605ba33ca25997b1da144d1b6287375eb4 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 00:49:39 -0700 Subject: [PATCH 08/45] Completely rip out and rewrite pretty args handling: Was: - bytes_to_int - number_abbrev_to_int Now: - normalize_dec_bytes - normalize_bin_bytes - normalize_bytes - normalize_count - normalize_duration --- streaming/base/dataset.py | 7 +- streaming/base/format/base/reader.py | 5 +- streaming/base/format/base/writer.py | 4 +- streaming/base/util/__init__.py | 8 +- streaming/base/util/pretty.py | 359 +++++++++++++++++++++------ tests/test_util.py | 20 +- 6 files changed, 301 insertions(+), 102 deletions(-) diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 40cc10b8e..0faae33fd 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -34,7 +34,7 @@ _get_path, get_shm_prefix) from streaming.base.spanner import Spanner from streaming.base.stream import Stream -from streaming.base.util import bytes_to_int, number_abbrev_to_int +from streaming.base.util import normalize_bytes, normalize_count from streaming.base.world import World # An arbitrary time in the future, used for cold shard eviction. @@ -394,7 +394,7 @@ def __init__(self, # Convert epoch size from string to int, if needed. Cannot be negative. epoch_size_value = None if epoch_size: - epoch_size_value = number_abbrev_to_int(epoch_size) + epoch_size_value = normalize_count(epoch_size) if epoch_size_value < 0: raise ValueError(f'Epoch size cannot be negative. Received {epoch_size_value}.') @@ -465,8 +465,7 @@ def __init__(self, # Check that cache limit is possible. 
if self.cache_limit: - if isinstance(self.cache_limit, str): - self.cache_limit = bytes_to_int(self.cache_limit) + self.cache_limit = normalize_bytes(self.cache_limit) min_cache_usage = sum((stream.get_index_size() for stream in streams)) if self.cache_limit <= min_cache_usage: raise ValueError(f'Minimum cache usage ({min_cache_usage} bytes) is larger than ' + diff --git a/streaming/base/format/base/reader.py b/streaming/base/format/base/reader.py index 80ec45231..7db3521cc 100644 --- a/streaming/base/format/base/reader.py +++ b/streaming/base/format/base/reader.py @@ -9,7 +9,7 @@ from typing import Any, Dict, Iterator, List, Optional, Set, Union from streaming.base.array import Array -from streaming.base.util import bytes_to_int +from streaming.base.util import normalize_bytes __all__ = ['FileInfo', 'Reader', 'JointReader', 'SplitReader'] @@ -54,8 +54,7 @@ def __init__( ) -> None: if size_limit: - if (isinstance(size_limit, str)): - size_limit = bytes_to_int(size_limit) + size_limit = normalize_bytes(size_limit) if size_limit < 0: raise ValueError(f'`size_limit` must be greater than zero, instead, ' + f'found as {size_limit}.') diff --git a/streaming/base/format/base/writer.py b/streaming/base/format/base/writer.py index 7cc3add3d..7b182c539 100644 --- a/streaming/base/format/base/writer.py +++ b/streaming/base/format/base/writer.py @@ -22,7 +22,7 @@ from streaming.base.format.index import get_index_basename from streaming.base.hashing import get_hash, is_hash from streaming.base.storage.upload import CloudUploader -from streaming.base.util import bytes_to_int +from streaming.base.util import normalize_bytes __all__ = ['JointWriter', 'SplitWriter'] @@ -93,7 +93,7 @@ def __init__(self, size_limit_value = None if size_limit: - size_limit_value = bytes_to_int(size_limit) + size_limit_value = normalize_bytes(size_limit) if size_limit_value < 0: raise ValueError(f'`size_limit` must be greater than zero, instead, ' + f'found as {size_limit_value}.') diff --git a/streaming/base/util/__init__.py b/streaming/base/util/__init__.py index 0ab5e2492..db581ab1b 100644 --- a/streaming/base/util/__init__.py +++ b/streaming/base/util/__init__.py @@ -3,14 +3,16 @@ """Utilities and helkper methods needed by Streaming.""" -from streaming.base.util.pretty import bytes_to_int, get_list_arg, number_abbrev_to_int from streaming.base.util.importing import get_import_exception_message from streaming.base.util.merging import merge_index +from streaming.base.util.pretty import (get_list_arg, normalize_bin_bytes, normalize_bytes, + normalize_count, normalize_dec_bytes, normalize_duration) from streaming.base.util.retrying import retry from streaming.base.util.shared import clean_stale_shared_memory from streaming.base.util.storage import wait_for_file_to_exist __all__ = [ - 'bytes_to_int', 'clean_stale_shared_memory', 'get_import_exception_message', 'get_list_arg', - 'merge_index', 'number_abbrev_to_int', 'retry', 'wait_for_file_to_exist' + 'clean_stale_shared_memory', 'get_import_exception_message', 'get_list_arg', 'merge_index', + 'normalize_bin_bytes', 'normalize_bytes', 'normalize_count', 'normalize_dec_bytes', + 'normalize_duration', 'retry', 'wait_for_file_to_exist' ] diff --git a/streaming/base/util/pretty.py b/streaming/base/util/pretty.py index fc63b699f..47ca28687 100644 --- a/streaming/base/util/pretty.py +++ b/streaming/base/util/pretty.py @@ -3,9 +3,13 @@ """Conversions between human-friendly string forms and int/float.""" -from typing import List, Union +from collections import defaultdict +from typing 
import Dict, List, Union -__all__ = ['bytes_to_int', 'get_list_arg', 'number_abbrev_to_int'] +__all__ = [ + 'get_list_arg', 'normalize_dec_bytes', 'normalize_bin_bytes', 'normalize_bytes', + 'normalize_count', 'normalize_duration' +] def get_list_arg(text: str) -> List[str]: @@ -20,96 +24,291 @@ def get_list_arg(text: str) -> List[str]: return text.split(',') if text else [] -def bytes_to_int(bytes_str: Union[int, str]) -> int: - """Convert human readable byte format to an integer. +def _normalize_arg(text: str, units: Dict[str, int], to_type: type) -> Union[int, float]: + """Normalize a human-friendly unit string to number. Args: - bytes_str (Union[int, str]): Value to convert. + text (str): Human-friendly string. + units (Dict[str, Any]): Mapping of unit name to value. + to_type (Union[int, float]): The return type. - Raises: - ValueError: Invalid byte suffix. + Returns: + type: Computer-friendly number. + """ + # Must be non-empty. + if not text: + raise ValueError(f'Text is empty.') + + # Drop commas and underscores (useful to demarcate thousands '1,337' or '1_337'). + text = text.replace(',', '') + text = text.replace('_', '') + + # Must start with a digit. + char = text[0] + if not char.isdigit(): + raise ValueError(f'Text must start with a digit, but got {text[0]} instead (input: ' + + f'{text}).') + + # Must alternative between numbers and units, starting with a number. + in_num = True + part = [] + parts = [] + for char in text: + is_digit = char.isdigit() or char == '.' + if in_num: + if is_digit: + part.append(char) + else: + part = ''.join(part) + parts.append(part) + part = [char] + in_num = False + else: + if is_digit: + part = ''.join(part) + parts.append(part) + part = [char] + in_num = True + else: + part.append(char) + part = ''.join(part) + parts.append(part) + + # If just a number, that's it. + if len(parts) == 1: + part, = parts + try: + return to_type(part) + except: + raise ValueError(f'Simple text must be numeric, but got {part} instead (input: ' + + f'{text}).') + + # Pair up numbers and units. + if len(parts) % 2: + if '' in units: + # Special case where the implied unit is the empty string, i.e. the smallest unit. + parts.append('') + else: + # If not just a number, each number must be paired with a corresponding unit. + raise ValueError(f'Text must contain pairs of number and unit, but got an odd ' + + f'number of parts instead: {parts} (input: {text}).') + + # Assign parts as numbers and units. + part_nums = [] + part_units = [] + for i, part in enumerate(parts): + if i % 2: + part_units.append(part) + else: + part_nums.append(part) + + # Each number before the last one must be integral + for i, num in enumerate(part_nums[:-0]): + try: + part_nums[i] = int(num) + except: + raise ValueError(f'Non-final numbers must be integral, but got part {i} as {num} ' + + f'instead (input: {text}).') + + # The last number may be fractional. + try: + part_nums[-1] = to_type(part_nums[-1]) + except: + raise ValueError(f'Final number must be numeric, but got {part_nums[-1]} instead ' + + f'(input: {text}.') + + # Each unit must be known to us. + part_muls = [] + for i, unit in enumerate(part_units): + mul = units.get(unit) + if mul is None: + raise ValueError(f'Unit is unknown: {unit} in part {i} (input: {text}).') + part_muls.append(mul) + + # Each unit must be used at most once. 
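    # For example, with the decimal byte units defined below, '2gb2gb' (a reused
    # unit) and '512mb2gb' (units listed out of descending order) are both rejected
    # by the checks that follow, while a single pair such as '2gb' normalizes to
    # 2 * 1000**3.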
+ unit2count = defaultdict(int) + for i, unit in enumerate(part_units): + unit2count[unit] += 1 + for unit in sorted(unit2count): + count = unit2count[unit] + if count != 1: + raise ValueError(f'Unit is reused: {unit} is used {count} times (input: {text}).') + + # Units must be listed in descending order of size. + prev_mul = part_muls[0] + for i in range(1, len(part_muls)): + mul = part_muls[i] + if mul < prev_mul: + prev_mul = mul + else: + unit = part_units[i] + raise ValueError(f'Units are out of order: {unit} in part {i} (input: {text}).') + + # The number of any given part must not exceed the size of the next biggest part's unit. + # + # (Otherwise you would just roll its overage into the next biggest part.) + for i in range(1, len(part_muls)): + parent_mul = part_muls[i - 1] + mul = part_muls[i] + num = part_nums[i] + if parent_mul < mul * num: + parent_unit = part_units[i - 1] + unit = part_units[i] + raise ValueError(f'The number of any non-initial part must not exceed the ratio of ' + + f'the unit of the next biggest part to its own unit (otherwise it ' + + f'should have been rolled into the bigger part): part {i} having ' + + f'{num} of {unit} ({mul}x) vs parent part {i - 1} in units of ' + + f'{parent_unit} ({parent_mul}x) (input: {text}).') + + # Collect parts. + ret = 0 + for num, mul in zip(part_nums, part_muls): + ret += num * mul + return ret + + +def _normalize_num(arg: Union[int, float, str], units: Dict[str, int], + to_type: type) -> Union[int, float]: + """Normalize from human-friendly argument to number. + + Args: + arg (Union[int, float, str]): Human-friendly argument. + units (Dict[str, Any]): Mapping of unit name to value. + to_type (type): The return type. Returns: - int: Integer value of bytes. + Union[int, float]: Numeric argument. """ - #input is already an int - if isinstance(bytes_str, int) or isinstance(bytes_str, float): - return int(bytes_str) - - units = { - 'kb': 1024, - 'mb': 1024**2, - 'gb': 1024**3, - 'tb': 1024**4, - 'pb': 1024**5, - 'eb': 1024**6, - 'zb': 1024**7, - 'yb': 1024**8, - } - # Convert a various byte types to an integer - for suffix in units: - bytes_str = bytes_str.lower().strip() - if bytes_str.lower().endswith(suffix): - try: - return int(float(bytes_str[0:-len(suffix)]) * units[suffix]) - except ValueError: - raise ValueError(''.join([ - f'Unsupported value/suffix {bytes_str}. Supported suffix are ', - f'{["b"] + list(units.keys())}.' - ])) + if isinstance(arg, (int, float)): + return to_type(arg) else: - # Convert bytes to an integer - if bytes_str.endswith('b') and bytes_str[0:-1].isdigit(): - return int(bytes_str[0:-1]) - # Convert string representation of a number to an integer - elif bytes_str.isdigit(): - return int(bytes_str) - else: - raise ValueError(''.join([ - f'Unsupported value/suffix {bytes_str}. Supported suffix are ', - f'{["b"] + list(units.keys())}.' - ])) + return _normalize_arg(arg, units, to_type) -def number_abbrev_to_int(abbrev_str: Union[int, str]) -> int: - """Convert human readable number abbreviations to an integer. +def _normalize_int(arg: Union[int, str], units: Dict[str, int]) -> int: + """Normalize from human-friendly argument to int. Args: - abbrev_str (Union[int, str]): Value to convert. + arg (Union[int, str]): Human-friendly argument. + units (Dict[str, int]): Mapping of unit name to value. + + Returns: + int: Integral argument. + """ + return _normalize_num(arg, units, int) # pyright: ignore + - Raises: - ValueError: Invalid number suffix. 
+def _normalize_float(arg: Union[int, float, str], units: Dict[str, int]) -> int: + """Normalize from human-friendly argument to float. + + Args: + arg (Union[int, float, str]): Human-friendly argument. + units (Dict[str, int]): Mapping of unit name to value. Returns: - int: Integer value of number abbreviation. + float: Floating argument. """ - #input is already an int - if isinstance(abbrev_str, int) or isinstance(abbrev_str, float): - return int(abbrev_str) - - units = { - 'k': 10**3, - 'm': 10**6, - 'b': 10**9, - 't': 10**12, - } - # Convert a various abbreviation types to an integer - for suffix in units: - abbrev_str = abbrev_str.lower().strip() - if abbrev_str.lower().endswith(suffix): - try: - return int(float(abbrev_str[0:-len(suffix)]) * units[suffix]) - except ValueError: - raise ValueError(''.join([ - f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', - f'{list(units.keys())}.' - ])) - else: - # Convert string representation of a number to an integer - if abbrev_str.isdigit(): - return int(abbrev_str) - else: - raise ValueError(''.join([ - f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', - f'{list(units.keys())}.' - ])) + return _normalize_num(arg, units, float) # pyright: ignore + + +def _get_units(base: int, names: List[str]) -> Dict[str, int]: + """Generate units mapping given a base and names of powers of that base. + + Args: + base (int): Base to exponentiate. + names (List[str]): Name of each power of base. + + Returns: + Dic[str, int]: Mapping of unit name to value. + """ + units = {} + for i, name in enumerate(names): + if name in units: + raise ValueError(f'Reused unit name: {name}.') + units[name] = base**i + return units + + +_dec_bytes_units = _get_units(1000, 'b kb mb tb pb eb zb yb rb qb'.split()) + + +def normalize_dec_bytes(bytes: Union[int, str]) -> int: + """Normalize from human-friendly base-1000 bytes to int. + + Args: + bytes (Union[int, str]): Human-friendly base-1000 bytes. + + Returns: + int: Integral bytes. + """ + return _normalize_int(bytes, _dec_bytes_units) + + +_bin_bytes_units = _get_units(1024, 'ib kib mib tib pib eib zib yib rib qib'.split()) + + +def normalize_bin_bytes(bytes: Union[int, str]) -> int: + """Normalize from human-friendly base-1024 bytes to int. + + Args: + bytes (Union[int, str]): Human-friendly base-1024 bytes. + + Returns: + int: Integral bytes. + """ + return _normalize_int(bytes, _bin_bytes_units) + + +def normalize_bytes(bytes: Union[int, str]) -> int: + """Normalize from human-friendly base-1000 or base-1024 bytes to int. + + Args: + bytes (Union[int, str]): Human-friendly base-1000 or base-1024 bytes. + + Returns: + int: Integral bytes. + """ + for norm in [normalize_dec_bytes, normalize_bin_bytes]: + try: + return norm(bytes) + except: + pass + raise ValueError('Invalid bytes: {bytes}.') + + +_count_units = _get_units(1000, ' k m b t'.split(' ')) + + +def normalize_count(count: Union[int, str]) -> int: + """Normalize from human-friendly count to int. + + Args: + count (Union[int, str]): Human-friendly count. + + Returns: + int: Integral count. + """ + ret = _normalize_int(count, _count_units) + if ret < 0: + raise ValueError(f'Counts cannot be negative, but got {ret} (input: {count}).') + return ret + + +_duration_units = { + 's': 1, + 'm': 60, + 'h': 60 * 60, + 'd': 24 * 60 * 60, +} + + +def normalize_duration(duration: Union[int, float, str]) -> float: + """Normalize from human-friendly duration to float. + + Args: + duration (Union[int, float, str]): Human-friendly duration. 
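    For example, ``'90s'``, ``'1.5m'``, and the plain number ``90`` all normalize to
    ``90.0`` seconds.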
+ + Returns: + float: Float duration. + """ + return _normalize_float(duration, _duration_units) diff --git a/tests/test_util.py b/tests/test_util.py index aa107cc17..bba0d2724 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -15,8 +15,8 @@ from streaming.base.shared.prefix import _get_path from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader -from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, get_list_arg, - merge_index, number_abbrev_to_int, retry) +from streaming.base.util import (clean_stale_shared_memory, get_list_arg, merge_index, + normalize_bytes, normalize_count, retry) MY_PREFIX = 'train_' + str(time.time()) MY_BUCKET = { @@ -35,7 +35,7 @@ def test_get_list_arg(text: str, expected_output: List[Optional[str]]): assert output == expected_output -def test_bytes_to_int(): +def test_normalize_bytes(): input_to_expected = [ ('1234', 1234), ('1b', 1), @@ -63,18 +63,18 @@ def test_bytes_to_int(): (325388903.203984, 325388903), ] for size_pair in input_to_expected: - output = bytes_to_int(size_pair[0]) + output = normalize_bytes(size_pair[0]) assert output == size_pair[1] -def test_bytes_to_int_Exception(): +def test_normalize_bytes_Exception(): input_data = ['', '12kbb', '27mxb', '79kkb'] for value in input_data: with pytest.raises(ValueError, match=f'Unsupported value/suffix.*'): - _ = bytes_to_int(value) + _ = normalize_bytes(value) -def test_number_abbrev_to_int(): +def test_normalize_count(): input_to_expected = [ ('1234', 1234), ('1k', 1000), @@ -99,15 +99,15 @@ def test_number_abbrev_to_int(): (325388903.203984, 325388903), ] for size_pair in input_to_expected: - output = number_abbrev_to_int(size_pair[0]) + output = normalize_count(size_pair[0]) assert output == size_pair[1] -def test_number_abbrev_to_int_Exception(): +def test_normalize_count_Exception(): input_data = ['', '12kbb', '27mxb', '79bk', '79bb', '79 b m', 'p 64', '64p'] for value in input_data: with pytest.raises(ValueError, match=f'Unsupported value/suffix.*'): - _ = number_abbrev_to_int(value) + _ = normalize_count(value) def test_clean_stale_shared_memory(): From eb93bea3fef38ce59da45a9ffdaadda0d1759840 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 00:59:07 -0700 Subject: [PATCH 09/45] Layer several new storage APIs wrapping/complementing streaming/base/storage/. Let's properly integrate these later. - walk_dir() - Very Fancy list_dataset_files() - smart_download_file() --- streaming/base/storage/extra.py | 251 ++++++++++++++++++++++++++++++++ streaming/base/stream.py | 3 +- streaming/base/util/__init__.py | 3 +- streaming/base/util/storage.py | 33 ----- 4 files changed, 254 insertions(+), 36 deletions(-) create mode 100644 streaming/base/storage/extra.py delete mode 100644 streaming/base/util/storage.py diff --git a/streaming/base/storage/extra.py b/streaming/base/storage/extra.py new file mode 100644 index 000000000..286c05f8a --- /dev/null +++ b/streaming/base/storage/extra.py @@ -0,0 +1,251 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Some extras which wrap and/or complement the Streaming storage API. + +TODO: deliberately design the storage API, in a future PR. 
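For example, ``list_dataset_files(local, remote, split, keep=r'.*\.parquet$')`` walks the
dataset root (remote if provided, else local) and returns the matching file paths relative
to that root.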
+""" + +import os +import re +from re import Pattern +from time import sleep, time +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union +from urllib.parse import urlparse + +from streaming.base.hashing import get_hash +from streaming.base.storage import CloudUploader, download_file +from streaming.base.util.pretty import normalize_bytes, normalize_duration + +__all__ = ['wait_for_file_to_exist', 'walk_dir', 'list_dataset_files', 'smart_download_file'] + + +def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float, + err_msg: str) -> None: + """Wait for the file to exist till timeout seconds. Raise an Exception after that. + + Args: + filename (str): A file name + poll_interval (float): Number of seconds to wait before next polling + timeout (float): Number of seconds to wait for a file to exist before raising an exception + err_msg (str): Error message description for an exception + + Raises: + RuntimeError: Raise an Exception if file does not exist after timeout + """ + start_time = time() + while True: + sleep(poll_interval) + if os.path.exists(filename): + sleep(poll_interval) + break + dt = time() - start_time + if dt > timeout: + raise RuntimeError(f'{err_msg}' + f'{timeout:.3f} < {dt:.3f} secs.') + + +def walk_dir(root: str) -> List[str]: + """Recursively list the given directory in sorted order. + + Notes: + * Supported across various storage backends, including local filesystem. + * Root must be a directory, not a generic path prefix, to make the local case nicer. + * There seems to be inconsistency in list_objects() about what the returned paths are + relative to: cwd, the given root, some local... let's just wrap it for our purposes. + + Args: + root (str): Root directory to walk. + + Returns: + List[str]: File paths, which are relative to the given root. + """ + obj = urlparse(root) + if obj.scheme == '': + is_local = True + elif obj.scheme == 'file': + is_local = True + root = obj.path + else: + is_local = False + + if is_local: + if not os.path.isdir(root): + raise ValueError(f'Path is not a directory: {root}.') + paths = [] + for sub_root, _, file_basenames in os.walk(root): + sub_path = sub_root.lstrip(root) + paths += [os.path.join(sub_path, name) for name in file_basenames] + else: + neither = CloudUploader.get(root, exist_ok=True) + paths = neither.list_objects(root) + + return sorted(paths) + + +def _filter(keep: Optional[Union[str, Pattern, Callable[[str], bool]]], + paths: Optional[Iterable[str]]) -> Iterable[str]: + """Filter the given paths according to the pattern or predicate. + + Args: + keep (Union[str, Pattern, Callable[[str], bool]], optional): A regex or Callable which is + applied to each path, keeping or dropping it. If not provided, do no filtering. + paths (Iterable[str], optional): Iterable of paths to filter. If empty, is the empty list. + """ + paths = paths or [] + if keep is None: + pass + elif isinstance(keep, str): + keep_regex = re.compile(keep) + paths = filter(keep_regex.match, paths) + elif isinstance(keep, Pattern): + paths = filter(keep.match, paths) + elif isinstance(keep, Callable): + paths = filter(keep, paths) + else: + raise ValueError(f'Unsupported type of keep: {keep}.') + yield from paths + + +def _get_overlap(want: Set[str], have: Set[str]) -> Dict[str, Any]: + """Get the overlap between two sets for informational/debugging purposes. + + Args: + want (Set[str]): What we want. + have (Set[str]): What we have. + + Returns: + Dict[str, Any]: Information about overlaps. 
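    For example, ``want={'a', 'b'}`` and ``have={'b', 'c'}`` yields
    ``{'present': 1, 'missing': 1, 'ignored': 1}``.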
+ """ + return { + 'present': len(want & have), + 'missing': len(want.difference(have)), + 'ignored': len(have.difference(want)), + } + + +def list_dataset_files( + local: str, + remote: Optional[str] = None, + split: Optional[str] = None, + paths: Optional[Iterable[str]] = None, + keep: Optional[Union[str, Pattern, Callable[[str], bool]]] = None) -> List[str]: + """Collect all/certain local/remote dataset files, which are then filtered. + + Args: + local (str): Local dataset root. + remote (str, optional): Remote dataset root, if we have a remote. + split (str, optional): Split subdir, if used. + paths (Iterable[str], optional): Iterable of paths relative to dataset root (i.e., + local/remote + split). These are then filtered by the keep predicate, if any. If not + provided, defaults to a sorted, recursive listing of all dataset files. Such a listing + treats remote as authoritative if provided, else uses local. Defaults to ``None``. + keep (Union[str, Pattern, Callable[[str], bool]], optional): A regex or Callable which is + applied to each path in order to keep or drop it from the listing. If not provided, no + filtering is performed to paths. Defaults to ``None``. + + Returns: + List[str]: List of paths, relative to dataset root, ordered by ``paths``. + """ + # Tack on the split dir, if any. + if split: + local = os.path.join(local, split) + if remote: + remote = os.path.join(remote, split) + + # If no paths Iterable was not provided, list all the files, filter, and we're done. + if paths is None: + root = remote if remote else local + paths = walk_dir(root) + return list(_filter(keep, paths)) + + # If we were indeed provided explicit paths, cross-check those against a listing of local + # before we start assuming everything is fine. + want_paths = list(_filter(keep, paths)) + want_paths_set = set(want_paths) + have_local_paths_set = set(walk_dir(local)) + if want_paths_set.issubset(have_local_paths_set): # All exist in local? + return want_paths + + # If local is incomplete, and there is no remote, give up. + if not remote: + obj = _get_overlap(want_paths_set, have_local_paths_set) + raise ValueError(f'Local does not contain all listed shards, and no remote was ' + + f'provided. Overlap of listed vs local: {obj["present"]} present, ' + + f'{obj["missing"]} missing, {obj["ignored"]} ignored.') + + # Explicit paths, incomplete local, but we do have a remote to fall back to. Let's cross-check + # against that. + have_remote_paths_set = set(walk_dir(remote)) + if want_paths_set.issubset(have_remote_paths_set): + return want_paths + + # Both local and remote do not contain all the needed files, so give up. + l_obj = _get_overlap(want_paths_set, have_local_paths_set) + r_obj = _get_overlap(want_paths_set, have_remote_paths_set) + raise ValueError(f'Neither local nor remote contains all shards listed. Overlap of listed ' + + f'vs local: {l_obj["present"]} present, {l_obj["missing"]} missing, ' + + f'{l_obj["ignored"]} ignored. Overlap of listed vs remote: ' + + f'{r_obj["present"]} present, {r_obj["missing"]} missing, ' + + f'{r_obj["ignored"]} ignored.') + + +def smart_download_file(*, + remote: str, + local: str, + timeout: Union[float, str] = 60, + size: Optional[Union[int, str]] = None, + max_size: Optional[Union[int, str]] = None, + hashes: Optional[Dict[str, str]] = None) -> None: + """Download a file from the remote path to the local path, with size/hash checks. + + Args: + remote (str): Remote path. + local (str): Local path. 
+ timeout (Union[float, str]): Maximum time to download, in seconds. Defaults to ``60``. + size (Union[int, str], optional): Expected file size. This check is a weak but fast/cheap + way to detect overwrites, truncation, tampering, and corruption. Defaults to ``None``. + max_size (Union[int, str], optional): Maximum file size. This check is a fast/cheap way to + prevent the user from inadvertently using shards that are far too large for Streaming + purposes, which is non-obvious and would result in a terrible user experience. Defaults + to ``None``. + hashes (Dict[str, str], optional): Hashes to check, as a dict of hash algo name to expected + hex digest. These checks are a very strong but slow/expensive way to detect changes to + data. See our benchmarks for more details. Defaults to ``None``. + """ + # Download. + want_timeout = normalize_duration(timeout) + download_file(remote, local, want_timeout) + + # Size checks. + if size is not None or max_size is not None: + have_size = os.stat(local).st_size + + # Exact size check. + if size is not None: + want_size = normalize_bytes(size) + if want_size != have_size: + raise ValueError( + f'The file as downloaded does not match the expected size: remote path = ' + + f'{remote}, local path = {local}, expected size = {want_size}, got size = ' + + f'{have_size}.') + + # Size limit check. + if max_size is not None: + want_max_size = normalize_bytes(max_size) + if want_max_size < have_size: + raise ValueError( + f'The file is too large for efficient use by Streaming, please reduce shard ' + + f'size: remote path = {remote}, local path = {local}, maximum size = ' + + f'{want_max_size}, got size = {have_size}.') + + # Hash checks. + if hashes: + data = open(local, 'rb').read() + for hash_algo in sorted(hashes): + want_hex_digest = hashes[hash_algo] + have_hex_digest = get_hash(hash_algo, data) + if want_hex_digest != have_hex_digest: + raise ValueError( + f'The file as downloaded does not match the expected hash: remote path = ' + + f'{remote}, local path = {local}, hash algo = {hash_algo}, expected hex ' + + f'digest = {want_hex_digest}, got digest = {have_hex_digest}.') diff --git a/streaming/base/stream.py b/streaming/base/stream.py index d707f9a6b..c0a8f7f3b 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -13,13 +13,14 @@ from numpy.typing import NDArray from typing_extensions import Self +from streaming.baes.storage.extra import wait_for_file_to_exist from streaming.base.compression import decompress from streaming.base.constant import TICK from streaming.base.distributed import barrier, get_local_rank from streaming.base.format import FileInfo, Reader, get_index_basename, reader_from_json from streaming.base.hashing import get_hash from streaming.base.storage import download_file -from streaming.base.util import retry, wait_for_file_to_exist +from streaming.base.util.retrying import retry from streaming.base.world import World diff --git a/streaming/base/util/__init__.py b/streaming/base/util/__init__.py index db581ab1b..bce2e06ab 100644 --- a/streaming/base/util/__init__.py +++ b/streaming/base/util/__init__.py @@ -9,10 +9,9 @@ normalize_count, normalize_dec_bytes, normalize_duration) from streaming.base.util.retrying import retry from streaming.base.util.shared import clean_stale_shared_memory -from streaming.base.util.storage import wait_for_file_to_exist __all__ = [ 'clean_stale_shared_memory', 'get_import_exception_message', 'get_list_arg', 'merge_index', 'normalize_bin_bytes', 'normalize_bytes', 'normalize_count', 
'normalize_dec_bytes', - 'normalize_duration', 'retry', 'wait_for_file_to_exist' + 'normalize_duration', 'retry' ] diff --git a/streaming/base/util/storage.py b/streaming/base/util/storage.py deleted file mode 100644 index 470b02869..000000000 --- a/streaming/base/util/storage.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Storage utilities and helpers.""" - -import os -from time import sleep, time - -__all__ = ['wait_for_file_to_exist'] - - -def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float, - err_msg: str) -> None: - """Wait for the file to exist till timeout seconds. Raise an Exception after that. - - Args: - filename (str): A file name - poll_interval (float): Number of seconds to wait before next polling - timeout (float): Number of seconds to wait for a file to exist before raising an exception - err_msg (str): Error message description for an exception - - Raises: - RuntimeError: Raise an Exception if file does not exist after timeout - """ - start_time = time() - while True: - sleep(poll_interval) - if os.path.exists(filename): - sleep(poll_interval) - break - dt = time() - start_time - if dt > timeout: - raise RuntimeError(f'{err_msg}' + f'{timeout:.3f} < {dt:.3f} secs.') From 23554aca06b8b9404b280f05163e7a66f7cff190 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 01:16:05 -0700 Subject: [PATCH 10/45] Use those APIs to index a Parquet dataset (single-threaded). --- streaming/base/format/delta/indexing.py | 2 + streaming/base/format/lance/indexing.py | 2 + streaming/base/format/parquet/indexing.py | 146 +++++++++++++++++++++- 3 files changed, 149 insertions(+), 1 deletion(-) diff --git a/streaming/base/format/delta/indexing.py b/streaming/base/format/delta/indexing.py index 75656d8a4..3a407387f 100644 --- a/streaming/base/format/delta/indexing.py +++ b/streaming/base/format/delta/indexing.py @@ -5,6 +5,8 @@ from typing import Any, Dict, Optional, Union +__all__ = ['index_delta'] + def index_delta(*, local: str, diff --git a/streaming/base/format/lance/indexing.py b/streaming/base/format/lance/indexing.py index 6ad3d0cd5..a38adc125 100644 --- a/streaming/base/format/lance/indexing.py +++ b/streaming/base/format/lance/indexing.py @@ -5,6 +5,8 @@ from typing import Any, Dict, Optional, Union +__all__ = ['index_lance'] + def index_lance(*, local: str, diff --git a/streaming/base/format/parquet/indexing.py b/streaming/base/format/parquet/indexing.py index c68b1b3be..db9e2b2ab 100644 --- a/streaming/base/format/parquet/indexing.py +++ b/streaming/base/format/parquet/indexing.py @@ -3,9 +3,134 @@ """Indexing a Parquet dataset for use by Streaming.""" +import os from re import Pattern from typing import Any, Callable, Dict, Iterable, Optional, Union +from pyarrow import parquet as pq +from tqdm import tqdm + +from streaming.base.format.mds.encodings import get_mds_encoded_size +from streaming.base.storage.extra import list_dataset_files, smart_download_file + +__all__ = ['index_parquet'] + + +def _get_mds_column(val: Any) -> str: + """Get the MDS column encoding of one field. + + Args: + val (Any): The field. + + Returns: + str: Its corresponding MDS encoding. + """ + if isinstance(val, int): + return 'int' + elif isinstance(val, str): + return 'str' + else: + raise ValueError('Unsupported column type: {type(val)}.') + + +def _sample_to_schema(sample: Dict[str, Any]) -> Dict[str, Any]: + """Get column names, encodings, and sizes. 
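    For example, the sample ``{'id': 7, 'text': 'hi'}`` yields column names
    ``['id', 'text']``, encodings ``['int', 'str']``, and their corresponding MDS
    encoded sizes.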
+ + Args: + sample (Dict[str, Any]): A sample to derive column info from. + + Returns: + Dict[str, Any]: MDS column names, encodings, and sizes. + """ + col_names = sorted(sample) + col_encs = [] + for name in col_names: + val = sample[name] + enc = _get_mds_column(val) + col_encs.append(enc) + col_sizes = list(map(get_mds_encoded_size, col_encs)) + return { + 'column_names': col_names, + 'column_encodings': col_encs, + 'column_sizes': col_sizes, + } + + +def _index_file(local: str, + remote: Optional[str], + split: Optional[str], + rel_path: str, + download_timeout: Union[float, str] = '2m', + max_file_bytes: Optional[Union[int, str]] = '200mb', + want_mds_schema: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """Get info a Streaming index needs about a Parquet shard. + + Args: + local (str): Local dataset root. + remote (str, optional): Remote dataset root, if remote is provided. + split (str, optional): Split, if used. + rel_path (str): Path to file, relative to serialized dataset root. + download_timeout (Union[float, str]): Maximum download time. Defaults to ``2m``. + max_file_bytes (Union[int, str], optional): Maximum file size. This is to catch people + trying to stream gigantic Parquet shards. Defaults to ``200mb``. + want_mds_schema (Dict[str, Any], optional): If provided, MDS schemna that this Parquet + shard must match upon conversion to MDS. + + Returns: + Dict[str, Any]: Shard info, or None upon failure. + """ + local_path = os.path.join(local, split or '', rel_path) + if not os.path.exists(local): + if not remote: + raise ValueError('Remote was needed, but not provided.') + + remote_path = os.path.join(remote, split or '', rel_path) + smart_download_file(remote=remote_path, + local=local_path, + timeout=download_timeout, + max_size=max_file_bytes) + + num_bytes = os.stat(local).st_size + + table = pq.read_table(local_path) + samples = table.to_pylist() + num_samples = len(samples) + mds_schema = _sample_to_schema(samples[0]) + if want_mds_schema and want_mds_schema != mds_schema: + raise ValueError(f'MDS schema mismatch: required {want_mds_schema}, but got ' + + f'{mds_schema}.') + + ret = { + 'version': 2, + 'format': 'parquet', + 'raw_parquet': { + 'basename': rel_path, + 'bytes': num_bytes, + }, + 'raw_data': { + 'basename': rel_path + '.mds', + }, + 'samples': num_samples, + } + ret.update(mds_schema) + return ret + + +def _shard_info_to_schema(info: Dict[str, Any]) -> Dict[str, Any]: + """Extract MDS schema information from the info for a shard. + + Args: + info (Dict[str, Any]): Shard info. + + Returns: + Dict[str, Any]: MDS schema. + """ + ret = {} + for key in ['column_names', 'column_encoding', 'column_sizes']: + ret[key] = info[key] + return ret + + Filter = Union[str, Pattern, Callable[[str], bool]] @@ -69,4 +194,23 @@ def index_parquet(*, Returns: Dict[str, Any]: StreamingDataset index configuration to stream this Parquet dataset. 
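    Example (paths shown are illustrative):

        index = index_parquet(local='/tmp/dataset',
                              remote='s3://bucket/dataset',
                              keep=r'.*\.parquet$')
        # ``index`` is the dict to serialize as this split's ``index.json``.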
""" - raise NotImplementedError # TODO + rel_paths = list_dataset_files(local, remote, split, files, keep) + if show_progress: + rel_paths = tqdm(rel_paths, leave=False) + + want_mds_schema = None + infos = [] + for rel_path in rel_paths: + info = _index_file(local, remote, split, rel_path, download_timeout, max_file_bytes, + want_mds_schema) + infos.append(info) + + if same_schema and not want_mds_schema: + want_mds_schema = _shard_info_to_schema(info) + + obj = { + 'version': 2, + 'shards': infos, + } + + return obj From c71156719f47bbd83b902c473872aa51ad5185e0 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 06:10:27 -0700 Subject: [PATCH 11/45] Add cli/index_parquet.py. --- streaming/base/cli/index_parquet.py | 76 +++++++++++++++++++++++ streaming/base/format/parquet/indexing.py | 5 +- streaming/base/storage/extra.py | 2 +- streaming/base/stream.py | 2 +- streaming/base/util/pretty.py | 37 +++++++++-- 5 files changed, 114 insertions(+), 8 deletions(-) create mode 100644 streaming/base/cli/index_parquet.py diff --git a/streaming/base/cli/index_parquet.py b/streaming/base/cli/index_parquet.py new file mode 100644 index 000000000..310fd2980 --- /dev/null +++ b/streaming/base/cli/index_parquet.py @@ -0,0 +1,76 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Generate a Streaming index file for the given Parquet dataset.""" + +import json +from argparse import ArgumentParser, Namespace + +from streaming.base.format import index_parquet +from streaming.base.util.pretty import parse_str2str + + +def parse_args() -> Namespace: + """Parse command-line arguments. + + Returns: + Namespace: Command-line arguments. + """ + args = ArgumentParser() + args.add_argument('--local', type=str, required=True, help='Path to dataset cache.') + args.add_argument('--remote', type=str, default='', help='Path to gold copy of dataset.') + args.add_argument('--split', type=str, default='', help='Dataset split subdir.') + args.add_argument('--keep', type=str, default='', help='Optional regex for filtering shards.') + args.add_argument('--num_procs', + type=int, + default=0, + help='Process parallelism. Set to -1 for single process, 0 for processes, and positive int for that many processes.') + args.add_argument('--download_timeout', + type=str, + default='2m', + help='Download timeout per Parquet file.') + args.add_argument('--max_file_bytes', + type=str, + default='200m', + help='Maximum file size in bytes, or 0 to disable..') + args.add_argument('--same_schema', + type=int, + default=1, + help='Whether all shards must be of the same MDS schema.') + args.add_argument('--columns', + type=str, + default='', + help='Override hte inferred schema to set any field names and types ' + + 'specified here.') + args.add_argument('--show_progress', type=int, default=1, help='Show progress bar.') + args.add_argument('--sort_keys', type=int, default=1, help='Whether to sort JSON keys.') + args.add_argument('--indent', type=int, default=-1, help='JSON indent level (0 to disable).') + return args.parse_args() + + +def main(args: Namespace) -> None: + """Generate a Streaming index for the given Parquet dataset. + + Args: + args (Namespace): Command-line arguments. 
+ """ + columns = parse_str2str(args.columns) + obj = index_parquet(local=args.local, + remote=args.remote, + split=args.split, + keep=args.keep, + num_procs=args.num_procs, + download_timeout=args.download_timeout, + max_file_bytes=args.max_file_bytes, + same_schema=args.same_schema, + columns=columns, + show_progress=args.show_progress) + + indent = None if args.indent < 0 else args.indent + text = json.dumps(obj, sort_keys=args.sort_keys, indent=indent) + print(text) + + +if __name__ == '__main__': + main(parse_args()) diff --git a/streaming/base/format/parquet/indexing.py b/streaming/base/format/parquet/indexing.py index db9e2b2ab..cacdbc583 100644 --- a/streaming/base/format/parquet/indexing.py +++ b/streaming/base/format/parquet/indexing.py @@ -126,7 +126,7 @@ def _shard_info_to_schema(info: Dict[str, Any]) -> Dict[str, Any]: Dict[str, Any]: MDS schema. """ ret = {} - for key in ['column_names', 'column_encoding', 'column_sizes']: + for key in ['column_names', 'column_encodings', 'column_sizes']: ret[key] = info[key] return ret @@ -165,6 +165,9 @@ def index_parquet(*, complete. * If we are listing files, and remote is provided, the remote must be authoritative. + TODO: use num_procs. + TODO: use columns. + Args: local (str): Where the dataset is cached on the local filesystem. remote (str, optional): Where the dataset is downloaded from. Defaults to ``None``. diff --git a/streaming/base/storage/extra.py b/streaming/base/storage/extra.py index 286c05f8a..18ce1c87f 100644 --- a/streaming/base/storage/extra.py +++ b/streaming/base/storage/extra.py @@ -92,7 +92,7 @@ def _filter(keep: Optional[Union[str, Pattern, Callable[[str], bool]]], paths (Iterable[str], optional): Iterable of paths to filter. If empty, is the empty list. """ paths = paths or [] - if keep is None: + if not keep: pass elif isinstance(keep, str): keep_regex = re.compile(keep) diff --git a/streaming/base/stream.py b/streaming/base/stream.py index c0a8f7f3b..d32bb63a5 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -13,13 +13,13 @@ from numpy.typing import NDArray from typing_extensions import Self -from streaming.baes.storage.extra import wait_for_file_to_exist from streaming.base.compression import decompress from streaming.base.constant import TICK from streaming.base.distributed import barrier, get_local_rank from streaming.base.format import FileInfo, Reader, get_index_basename, reader_from_json from streaming.base.hashing import get_hash from streaming.base.storage import download_file +from streaming.base.storage.extra import wait_for_file_to_exist from streaming.base.util.retrying import retry from streaming.base.world import World diff --git a/streaming/base/util/pretty.py b/streaming/base/util/pretty.py index 47ca28687..0153f15f7 100644 --- a/streaming/base/util/pretty.py +++ b/streaming/base/util/pretty.py @@ -12,16 +12,43 @@ ] -def get_list_arg(text: str) -> List[str]: - """Pass a list as a comma-delimted string. +def get_list_arg(text: str, sep: str = ',') -> List[str]: + """Pass a list as a comma-delimited string. Args: - text (str): Text to split. + text (str): Text to parse. Returns: - List[str]: Splits, if any. + List[str]: List of items. """ - return text.split(',') if text else [] + if not text: + return [] + + return text.split(sep) + + +def parse_str2str(text: str, sep: str = ',', eq: str = '=') -> Dict[str, str]: + """Pass a dict as a comma- and equals-delimited string. + + Args: + text (str): Text to parse. + sep (str): Separator text. Defaults to ``,``. 
+ eq (str): Assignment text. Deffaults to ``=``. + + Returns: + Dict[str, str]: Mapping of str to str. + """ + if not text: + return {} + + ret = {} + parts = text.split(sep) + for part in parts: + key, val = part.split(eq) + if key in ret: + raise ValueError(f'Repeated key: {key} (text: {text}).') + ret[key] = val + return ret def _normalize_arg(text: str, units: Dict[str, int], to_type: type) -> Union[int, float]: From 4ea01b294e139b36722de26ba0e5020e430a7b8c Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 06:19:40 -0700 Subject: [PATCH 12/45] Rename get_list_arg() -> parse_strs() in keeping with parse_str2str(), etc. --- scripts/serialization/survey_fixed_decimals.py | 5 +++-- streaming/base/util/__init__.py | 4 ++-- streaming/base/util/pretty.py | 4 ++-- streaming/text/convert/c4.py | 4 ++-- streaming/text/convert/enwiki_text.py | 4 ++-- streaming/text/convert/pile.py | 4 ++-- streaming/vision/convert/ade20k.py | 4 ++-- streaming/vision/convert/cifar10.py | 6 +++--- streaming/vision/convert/coco.py | 4 ++-- streaming/vision/convert/imagenet.py | 8 ++++---- tests/test_util.py | 6 +++--- 11 files changed, 27 insertions(+), 26 deletions(-) diff --git a/scripts/serialization/survey_fixed_decimals.py b/scripts/serialization/survey_fixed_decimals.py index 52c9a38cc..3cb8bc8db 100644 --- a/scripts/serialization/survey_fixed_decimals.py +++ b/scripts/serialization/survey_fixed_decimals.py @@ -7,6 +7,8 @@ import numpy as np +from streaming.base.util.pretty import parse_strs + def parse_args() -> Namespace: """Parse command-line arguments. @@ -107,8 +109,7 @@ def main(args: Namespace) -> None: print('- dec range: Range of decimal places (half left, half right).') print() - get_list_arg = lambda x: x.split(',') if x else [] - byte_widths = list(map(int, get_list_arg(args.byte_widths))) + byte_widths = list(map(int, parse_strs(args.byte_widths))) for byte_width in byte_widths: for is_signed in [False, True]: survey(args.min_exp_range, args.max_exp_range, byte_width, is_signed) diff --git a/streaming/base/util/__init__.py b/streaming/base/util/__init__.py index bce2e06ab..dcf747f73 100644 --- a/streaming/base/util/__init__.py +++ b/streaming/base/util/__init__.py @@ -5,7 +5,7 @@ from streaming.base.util.importing import get_import_exception_message from streaming.base.util.merging import merge_index -from streaming.base.util.pretty import (get_list_arg, normalize_bin_bytes, normalize_bytes, +from streaming.base.util.pretty import (parse_strs, parse_str2str, normalize_bin_bytes, normalize_bytes, normalize_count, normalize_dec_bytes, normalize_duration) from streaming.base.util.retrying import retry from streaming.base.util.shared import clean_stale_shared_memory @@ -13,5 +13,5 @@ __all__ = [ 'clean_stale_shared_memory', 'get_import_exception_message', 'get_list_arg', 'merge_index', 'normalize_bin_bytes', 'normalize_bytes', 'normalize_count', 'normalize_dec_bytes', - 'normalize_duration', 'retry' + 'normalize_duration', 'parsee_strs', 'parse_str2str', 'retry' ] diff --git a/streaming/base/util/pretty.py b/streaming/base/util/pretty.py index 0153f15f7..d2114d901 100644 --- a/streaming/base/util/pretty.py +++ b/streaming/base/util/pretty.py @@ -7,12 +7,12 @@ from typing import Dict, List, Union __all__ = [ - 'get_list_arg', 'normalize_dec_bytes', 'normalize_bin_bytes', 'normalize_bytes', + 'parsea_strs', 'parse_str2str', 'normalize_dec_bytes', 'normalize_bin_bytes', 'normalize_bytes', 'normalize_count', 'normalize_duration' ] -def get_list_arg(text: str, sep: str = ',') -> 
List[str]: +def parse_strs(text: str, sep: str = ',') -> List[str]: """Pass a list as a comma-delimited string. Args: diff --git a/streaming/text/convert/c4.py b/streaming/text/convert/c4.py index 5dc186c52..4c4efe78d 100644 --- a/streaming/text/convert/c4.py +++ b/streaming/text/convert/c4.py @@ -13,7 +13,7 @@ from tqdm import tqdm from streaming.base import MDSWriter -from streaming.base.util import get_list_arg +from streaming.base.util.pretty import parse_strs def parse_args() -> Namespace: @@ -141,7 +141,7 @@ def main(args: Namespace) -> None: ('validation', 'val', 364608, 8), ] columns = {'text': 'str', 'timestamp': 'str', 'url': 'str'} - hashes = get_list_arg(args.hashes) + hashes = parse_strs(args.hashes) for old_split, new_split, num_samples, num_workers in splits: dataset = get(old_split) split_dir = os.path.join(args.out_root, new_split) diff --git a/streaming/text/convert/enwiki_text.py b/streaming/text/convert/enwiki_text.py index 97f428d11..345c6f15b 100644 --- a/streaming/text/convert/enwiki_text.py +++ b/streaming/text/convert/enwiki_text.py @@ -10,7 +10,7 @@ from tqdm import tqdm from streaming.base import MDSWriter -from streaming.base.util import get_list_arg +from streaming.base.util import parse_strs def parse_args() -> Namespace: @@ -108,7 +108,7 @@ def main(args: Namespace) -> None: Args: args (Namespace): command-line arguments. """ - hashes = get_list_arg(args.hashes) + hashes = parse_strs(args.hashes) basenames = [f'part-{i:05}-of-00500' for i in range(500)] split = 'train' diff --git a/streaming/text/convert/pile.py b/streaming/text/convert/pile.py index b01fb8027..26f4a9581 100644 --- a/streaming/text/convert/pile.py +++ b/streaming/text/convert/pile.py @@ -12,7 +12,7 @@ from typing import Dict, Iterator, List, Tuple from streaming.base import MDSWriter -from streaming.base.util import get_list_arg +from streaming.base.util import parse_strs def parse_args() -> Namespace: @@ -190,7 +190,7 @@ def main(args: Namespace) -> None: Args: args (Namespace): Command-line arguments. """ - hashes = get_list_arg(args.hashes) + hashes = parse_strs(args.hashes) # Find the original JSONL files to convert. pattern = os.path.join(args.in_root, 'train', '*.jsonl') diff --git a/streaming/vision/convert/ade20k.py b/streaming/vision/convert/ade20k.py index 8d0598666..1ceed7187 100644 --- a/streaming/vision/convert/ade20k.py +++ b/streaming/vision/convert/ade20k.py @@ -12,7 +12,7 @@ from tqdm import tqdm from streaming.base import MDSWriter -from streaming.base.util import get_list_arg +from streaming.base.util.pretty import parse_strs def parse_args() -> Namespace: @@ -164,7 +164,7 @@ def main(args: Namespace) -> None: out_dir = os.path.join(args.out_root, split) - hashes = get_list_arg(args.hashes) + hashes = parse_strs(args.hashes) if args.progress_bar: samples = tqdm(samples, leave=args.leave) diff --git a/streaming/vision/convert/cifar10.py b/streaming/vision/convert/cifar10.py index 1251338b1..669673c2c 100644 --- a/streaming/vision/convert/cifar10.py +++ b/streaming/vision/convert/cifar10.py @@ -7,7 +7,7 @@ from torchvision.datasets import CIFAR10 -from streaming.base.util import get_list_arg +from streaming.base.util import parse_strs from streaming.vision.convert.base import convert_image_class_dataset @@ -76,8 +76,8 @@ def main(args: Namespace) -> None: Args: args (Namespace): command-line arguments. 
""" - splits = get_list_arg(args.splits) - hashes = get_list_arg(args.hashes) + splits = parse_strs(args.splits) + hashes = parse_strs(args.hashes) for split in splits: dataset = CIFAR10(root=args.in_root, train=(split == 'train'), download=True) convert_image_class_dataset(dataset, args.out_root, split, args.compression, hashes, diff --git a/streaming/vision/convert/coco.py b/streaming/vision/convert/coco.py index 2456fc953..c8cd16279 100644 --- a/streaming/vision/convert/coco.py +++ b/streaming/vision/convert/coco.py @@ -15,7 +15,7 @@ from tqdm import tqdm from streaming.base import MDSWriter -from streaming.base.util import get_list_arg +from streaming.base.util import parse_strs def parse_args() -> Namespace: @@ -226,7 +226,7 @@ def main(args: Namespace) -> None: raise ValueError(f'Number of samples in a dataset doesn\'t match. Expected ' + f'{expected_num_samples}, but got {len(dataset)}') - hashes = get_list_arg(args.hashes) + hashes = parse_strs(args.hashes) if args.progress_bar: dataset = tqdm(each(dataset, shuffle), leave=args.leave, total=len(dataset)) diff --git a/streaming/vision/convert/imagenet.py b/streaming/vision/convert/imagenet.py index d350e9029..6ec07966a 100644 --- a/streaming/vision/convert/imagenet.py +++ b/streaming/vision/convert/imagenet.py @@ -13,7 +13,7 @@ from tqdm import tqdm from streaming.base import MDSWriter -from streaming.base.util import get_list_arg +from streaming.base.util import parse_strs def parse_args() -> Namespace: @@ -133,10 +133,10 @@ def main(args: Namespace) -> None: Args: args (Namespace): command-line arguments. """ - splits = get_list_arg(args.splits) + splits = parse_strs(args.splits) columns = {'i': 'int', 'x': 'jpeg', 'y': 'int'} - hashes = get_list_arg(args.hashes) - extensions = set(get_list_arg(args.extensions)) + hashes = parse_strs(args.hashes) + extensions = set(parse_strs(args.extensions)) class_names = None for split in splits: pattern = os.path.join(args.in_root, split, '*', '*') diff --git a/tests/test_util.py b/tests/test_util.py index bba0d2724..fa8703099 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -15,7 +15,7 @@ from streaming.base.shared.prefix import _get_path from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader -from streaming.base.util import (clean_stale_shared_memory, get_list_arg, merge_index, +from streaming.base.util import (clean_stale_shared_memory, parse_strs, merge_index, normalize_bytes, normalize_count, retry) MY_PREFIX = 'train_' + str(time.time()) @@ -30,8 +30,8 @@ @pytest.mark.parametrize(('text', 'expected_output'), [('hello,world', ['hello', 'world']), ('hello', ['hello']), ('', [])]) -def test_get_list_arg(text: str, expected_output: List[Optional[str]]): - output = get_list_arg(text) +def test_parse_strs(text: str, expected_output: List[Optional[str]]): + output = parse_strs(text) assert output == expected_output From 157381a136c873e1362f56cd3a490630bd2fe423 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 06:26:43 -0700 Subject: [PATCH 13/45] Rename parse_(args stuff) -> unpack_(args stuff). 
--- scripts/serialization/survey_fixed_decimals.py | 4 ++-- streaming/base/cli/index_parquet.py | 4 ++-- streaming/base/util/__init__.py | 4 ++-- streaming/base/util/pretty.py | 6 +++--- streaming/text/convert/c4.py | 4 ++-- streaming/text/convert/enwiki_text.py | 4 ++-- streaming/text/convert/pile.py | 4 ++-- streaming/vision/convert/ade20k.py | 4 ++-- streaming/vision/convert/cifar10.py | 6 +++--- streaming/vision/convert/coco.py | 4 ++-- streaming/vision/convert/imagenet.py | 8 ++++---- tests/test_util.py | 6 +++--- 12 files changed, 29 insertions(+), 29 deletions(-) diff --git a/scripts/serialization/survey_fixed_decimals.py b/scripts/serialization/survey_fixed_decimals.py index 3cb8bc8db..d0ebebe83 100644 --- a/scripts/serialization/survey_fixed_decimals.py +++ b/scripts/serialization/survey_fixed_decimals.py @@ -7,7 +7,7 @@ import numpy as np -from streaming.base.util.pretty import parse_strs +from streaming.base.util.pretty import unpack_strs def parse_args() -> Namespace: @@ -109,7 +109,7 @@ def main(args: Namespace) -> None: print('- dec range: Range of decimal places (half left, half right).') print() - byte_widths = list(map(int, parse_strs(args.byte_widths))) + byte_widths = list(map(int, unpack_strs(args.byte_widths))) for byte_width in byte_widths: for is_signed in [False, True]: survey(args.min_exp_range, args.max_exp_range, byte_width, is_signed) diff --git a/streaming/base/cli/index_parquet.py b/streaming/base/cli/index_parquet.py index 310fd2980..1d0b0377f 100644 --- a/streaming/base/cli/index_parquet.py +++ b/streaming/base/cli/index_parquet.py @@ -7,7 +7,7 @@ from argparse import ArgumentParser, Namespace from streaming.base.format import index_parquet -from streaming.base.util.pretty import parse_str2str +from streaming.base.util.pretty import unpack_str2str def parse_args() -> Namespace: @@ -55,7 +55,7 @@ def main(args: Namespace) -> None: Args: args (Namespace): Command-line arguments. 
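    Example invocation (illustrative; assumes the CLI module is runnable as shown):

        python -m streaming.base.cli.index_parquet \
            --local /tmp/dataset --remote s3://bucket/dataset \
            --keep '.*\.parquet$' --columns 'text=str,id=int' > index.json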
""" - columns = parse_str2str(args.columns) + columns = unpack_str2str(args.columns) obj = index_parquet(local=args.local, remote=args.remote, split=args.split, diff --git a/streaming/base/util/__init__.py b/streaming/base/util/__init__.py index dcf747f73..8fabba8e3 100644 --- a/streaming/base/util/__init__.py +++ b/streaming/base/util/__init__.py @@ -5,7 +5,7 @@ from streaming.base.util.importing import get_import_exception_message from streaming.base.util.merging import merge_index -from streaming.base.util.pretty import (parse_strs, parse_str2str, normalize_bin_bytes, normalize_bytes, +from streaming.base.util.pretty import (unpack_strs, unpack_str2str, normalize_bin_bytes, normalize_bytes, normalize_count, normalize_dec_bytes, normalize_duration) from streaming.base.util.retrying import retry from streaming.base.util.shared import clean_stale_shared_memory @@ -13,5 +13,5 @@ __all__ = [ 'clean_stale_shared_memory', 'get_import_exception_message', 'get_list_arg', 'merge_index', 'normalize_bin_bytes', 'normalize_bytes', 'normalize_count', 'normalize_dec_bytes', - 'normalize_duration', 'parsee_strs', 'parse_str2str', 'retry' + 'normalize_duration', 'parsee_strs', 'unpack_str2str', 'retry' ] diff --git a/streaming/base/util/pretty.py b/streaming/base/util/pretty.py index d2114d901..f588ce207 100644 --- a/streaming/base/util/pretty.py +++ b/streaming/base/util/pretty.py @@ -7,12 +7,12 @@ from typing import Dict, List, Union __all__ = [ - 'parsea_strs', 'parse_str2str', 'normalize_dec_bytes', 'normalize_bin_bytes', 'normalize_bytes', + 'unpack_strs', 'unpack_str2str', 'normalize_dec_bytes', 'normalize_bin_bytes', 'normalize_bytes', 'normalize_count', 'normalize_duration' ] -def parse_strs(text: str, sep: str = ',') -> List[str]: +def unpack_strs(text: str, sep: str = ',') -> List[str]: """Pass a list as a comma-delimited string. Args: @@ -27,7 +27,7 @@ def parse_strs(text: str, sep: str = ',') -> List[str]: return text.split(sep) -def parse_str2str(text: str, sep: str = ',', eq: str = '=') -> Dict[str, str]: +def unpack_str2str(text: str, sep: str = ',', eq: str = '=') -> Dict[str, str]: """Pass a dict as a comma- and equals-delimited string. Args: diff --git a/streaming/text/convert/c4.py b/streaming/text/convert/c4.py index 4c4efe78d..395cf1120 100644 --- a/streaming/text/convert/c4.py +++ b/streaming/text/convert/c4.py @@ -13,7 +13,7 @@ from tqdm import tqdm from streaming.base import MDSWriter -from streaming.base.util.pretty import parse_strs +from streaming.base.util.pretty import unpack_strs def parse_args() -> Namespace: @@ -141,7 +141,7 @@ def main(args: Namespace) -> None: ('validation', 'val', 364608, 8), ] columns = {'text': 'str', 'timestamp': 'str', 'url': 'str'} - hashes = parse_strs(args.hashes) + hashes = unpack_strs(args.hashes) for old_split, new_split, num_samples, num_workers in splits: dataset = get(old_split) split_dir = os.path.join(args.out_root, new_split) diff --git a/streaming/text/convert/enwiki_text.py b/streaming/text/convert/enwiki_text.py index 345c6f15b..2a60043fe 100644 --- a/streaming/text/convert/enwiki_text.py +++ b/streaming/text/convert/enwiki_text.py @@ -10,7 +10,7 @@ from tqdm import tqdm from streaming.base import MDSWriter -from streaming.base.util import parse_strs +from streaming.base.util import unpack_strs def parse_args() -> Namespace: @@ -108,7 +108,7 @@ def main(args: Namespace) -> None: Args: args (Namespace): command-line arguments. 
""" - hashes = parse_strs(args.hashes) + hashes = unpack_strs(args.hashes) basenames = [f'part-{i:05}-of-00500' for i in range(500)] split = 'train' diff --git a/streaming/text/convert/pile.py b/streaming/text/convert/pile.py index 26f4a9581..ce24d26ae 100644 --- a/streaming/text/convert/pile.py +++ b/streaming/text/convert/pile.py @@ -12,7 +12,7 @@ from typing import Dict, Iterator, List, Tuple from streaming.base import MDSWriter -from streaming.base.util import parse_strs +from streaming.base.util import unpack_strs def parse_args() -> Namespace: @@ -190,7 +190,7 @@ def main(args: Namespace) -> None: Args: args (Namespace): Command-line arguments. """ - hashes = parse_strs(args.hashes) + hashes = unpack_strs(args.hashes) # Find the original JSONL files to convert. pattern = os.path.join(args.in_root, 'train', '*.jsonl') diff --git a/streaming/vision/convert/ade20k.py b/streaming/vision/convert/ade20k.py index 1ceed7187..5043bd9c2 100644 --- a/streaming/vision/convert/ade20k.py +++ b/streaming/vision/convert/ade20k.py @@ -12,7 +12,7 @@ from tqdm import tqdm from streaming.base import MDSWriter -from streaming.base.util.pretty import parse_strs +from streaming.base.util.pretty import unpack_strs def parse_args() -> Namespace: @@ -164,7 +164,7 @@ def main(args: Namespace) -> None: out_dir = os.path.join(args.out_root, split) - hashes = parse_strs(args.hashes) + hashes = unpack_strs(args.hashes) if args.progress_bar: samples = tqdm(samples, leave=args.leave) diff --git a/streaming/vision/convert/cifar10.py b/streaming/vision/convert/cifar10.py index 669673c2c..7cab9207e 100644 --- a/streaming/vision/convert/cifar10.py +++ b/streaming/vision/convert/cifar10.py @@ -7,7 +7,7 @@ from torchvision.datasets import CIFAR10 -from streaming.base.util import parse_strs +from streaming.base.util import unpack_strs from streaming.vision.convert.base import convert_image_class_dataset @@ -76,8 +76,8 @@ def main(args: Namespace) -> None: Args: args (Namespace): command-line arguments. """ - splits = parse_strs(args.splits) - hashes = parse_strs(args.hashes) + splits = unpack_strs(args.splits) + hashes = unpack_strs(args.hashes) for split in splits: dataset = CIFAR10(root=args.in_root, train=(split == 'train'), download=True) convert_image_class_dataset(dataset, args.out_root, split, args.compression, hashes, diff --git a/streaming/vision/convert/coco.py b/streaming/vision/convert/coco.py index c8cd16279..65ba7da8c 100644 --- a/streaming/vision/convert/coco.py +++ b/streaming/vision/convert/coco.py @@ -15,7 +15,7 @@ from tqdm import tqdm from streaming.base import MDSWriter -from streaming.base.util import parse_strs +from streaming.base.util import unpack_strs def parse_args() -> Namespace: @@ -226,7 +226,7 @@ def main(args: Namespace) -> None: raise ValueError(f'Number of samples in a dataset doesn\'t match. 
Expected ' + f'{expected_num_samples}, but got {len(dataset)}') - hashes = parse_strs(args.hashes) + hashes = unpack_strs(args.hashes) if args.progress_bar: dataset = tqdm(each(dataset, shuffle), leave=args.leave, total=len(dataset)) diff --git a/streaming/vision/convert/imagenet.py b/streaming/vision/convert/imagenet.py index 6ec07966a..6b2efd670 100644 --- a/streaming/vision/convert/imagenet.py +++ b/streaming/vision/convert/imagenet.py @@ -13,7 +13,7 @@ from tqdm import tqdm from streaming.base import MDSWriter -from streaming.base.util import parse_strs +from streaming.base.util import unpack_strs def parse_args() -> Namespace: @@ -133,10 +133,10 @@ def main(args: Namespace) -> None: Args: args (Namespace): command-line arguments. """ - splits = parse_strs(args.splits) + splits = unpack_strs(args.splits) columns = {'i': 'int', 'x': 'jpeg', 'y': 'int'} - hashes = parse_strs(args.hashes) - extensions = set(parse_strs(args.extensions)) + hashes = unpack_strs(args.hashes) + extensions = set(unpack_strs(args.extensions)) class_names = None for split in splits: pattern = os.path.join(args.in_root, split, '*', '*') diff --git a/tests/test_util.py b/tests/test_util.py index fa8703099..6210bb422 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -15,7 +15,7 @@ from streaming.base.shared.prefix import _get_path from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader -from streaming.base.util import (clean_stale_shared_memory, parse_strs, merge_index, +from streaming.base.util import (clean_stale_shared_memory, unpack_strs, merge_index, normalize_bytes, normalize_count, retry) MY_PREFIX = 'train_' + str(time.time()) @@ -30,8 +30,8 @@ @pytest.mark.parametrize(('text', 'expected_output'), [('hello,world', ['hello', 'world']), ('hello', ['hello']), ('', [])]) -def test_parse_strs(text: str, expected_output: List[Optional[str]]): - output = parse_strs(text) +def test_unpack_strs(text: str, expected_output: List[Optional[str]]): + output = unpack_strs(text) assert output == expected_output From c72127f1bae8e715c3230f3ff63562500fb86204 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 06:36:01 -0700 Subject: [PATCH 14/45] Long lines. --- streaming/base/storage/upload.py | 17 +++++++++-------- streaming/base/util/__init__.py | 9 +++++---- tests/base/converters/test_dataframe_to_mds.py | 10 +++++----- tests/test_util.py | 14 ++++++++------ 4 files changed, 27 insertions(+), 23 deletions(-) diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index dab805bf5..2c89c08de 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -75,14 +75,15 @@ def get(cls, progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` + already exists and has contents. Defaults to ``False``. Returns: CloudUploader: An instance of sub-class. 
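        For example, ``CloudUploader.get('s3://bucket/dir')`` returns the uploader
        registered for the ``s3://`` scheme, while a bare local path selects the
        local-filesystem implementation.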
""" cls._validate(cls, out) - obj = urllib.parse.urlparse(out) if isinstance(out, str) else urllib.parse.urlparse(out[1]) + obj = urllib.parse.urlparse(out) if isinstance(out, str) else \ + urllib.parse.urlparse(out[1]) provider_prefix = obj.scheme if obj.scheme == 'dbfs': path = pathlib.Path(out) if isinstance(out, str) else pathlib.Path(out[1]) @@ -142,8 +143,8 @@ def __init__(self, progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` + already exists and has contents. Defaults to ``False``. Raises: FileExistsError: Local directory must be empty. @@ -171,8 +172,8 @@ def __init__(self, raise FileExistsError(f'Directory is not empty: {self.local}') else: logger.warning( - f'Directory {self.local} exists and not empty. But continue to mkdir since exist_ok is set to be True.' - ) + f'Directory {self.local} exists and not empty. But continue to mkdir since ' + + f'exist_ok is set to be True.') os.makedirs(self.local, exist_ok=True) @@ -774,7 +775,7 @@ def check_container_exists(self, remote: str): error: Container does not exist. """ container_name = urllib.parse.urlparse(remote).netloc - if self.azure_service.get_file_system_client(file_system=container_name).exists() is False: + if not self.azure_service.get_file_system_client(file_system=container_name).exists(): raise FileNotFoundError( f'Either container `{container_name}` does not exist! ' + f'or check the container permission.',) diff --git a/streaming/base/util/__init__.py b/streaming/base/util/__init__.py index 8fabba8e3..b5352dd47 100644 --- a/streaming/base/util/__init__.py +++ b/streaming/base/util/__init__.py @@ -5,13 +5,14 @@ from streaming.base.util.importing import get_import_exception_message from streaming.base.util.merging import merge_index -from streaming.base.util.pretty import (unpack_strs, unpack_str2str, normalize_bin_bytes, normalize_bytes, - normalize_count, normalize_dec_bytes, normalize_duration) +from streaming.base.util.pretty import (normalize_bin_bytes, normalize_bytes, normalize_count, + normalize_dec_bytes, normalize_duration, unpack_str2str, + unpack_strs) from streaming.base.util.retrying import retry from streaming.base.util.shared import clean_stale_shared_memory __all__ = [ - 'clean_stale_shared_memory', 'get_import_exception_message', 'get_list_arg', 'merge_index', + 'clean_stale_shared_memory', 'get_import_exception_message', 'merge_index', 'normalize_bin_bytes', 'normalize_bytes', 'normalize_count', 'normalize_dec_bytes', - 'normalize_duration', 'parsee_strs', 'unpack_str2str', 'retry' + 'normalize_duration', 'unpack_strs', 'unpack_str2str', 'retry' ] diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index a99ea973a..dc19d219b 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -13,8 +13,8 @@ from streaming.base.converters import dataframe_to_mds -os.environ[ - 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls +# set to yes to all fork process in spark calls +os.environ['OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' class TestDataFrameToMDS: @@ -178,9 +178,9 @@ def 
test_end_to_end_conversion_local(self, dataframe: Any, keep_local: bool, mer nsamples += shards[0]['samples'] assert nsamples == sum([a['samples'] for a in mgi['shards']]) else: - assert os.path.exists( - os.path.join(out, 'index.json') - ), 'merged index.json was not found when keep_local is False but no remote part exists' + assert os.path.exists(os.path.join(out, 'index.json')), ( + 'merged index.json was not found when keep_local is False but no remote ' + + 'part exists') else: assert not os.path.exists(os.path.join( out, 'index.json')), 'merged index is created when merge_index=False' diff --git a/tests/test_util.py b/tests/test_util.py index 6210bb422..65bf54529 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -15,8 +15,8 @@ from streaming.base.shared.prefix import _get_path from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader -from streaming.base.util import (clean_stale_shared_memory, unpack_strs, merge_index, - normalize_bytes, normalize_count, retry) +from streaming.base.util import (clean_stale_shared_memory, merge_index, normalize_bytes, + normalize_count, retry, unpack_strs) MY_PREFIX = 'train_' + str(time.time()) MY_BUCKET = { @@ -24,8 +24,8 @@ 's3://': 'testing-bucket', 'oci://': 'testing-bucket', } -os.environ[ - 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls +# set to yes to all fork process in spark calls +os.environ['OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' @pytest.mark.parametrize(('text', 'expected_output'), [('hello,world', ['hello', 'world']), @@ -167,7 +167,8 @@ def get_expected(mds_root: str): ), f'{local_merged_index_path} does not exist when keep_local is {keep_local}' merged_index = json.load(open(local_merged_index_path, 'r')) n_shard_files = len({b['raw_data']['basename'] for b in merged_index['shards']}) - assert n_shard_files == expected_n_shard_files, f'expected {expected_n_shard_files} shard files but got {n_shard_files}' + assert n_shard_files == expected_n_shard_files, \ + f'expected {expected_n_shard_files} shard files but got {n_shard_files}' @pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3]) @@ -179,7 +180,8 @@ def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_loc 1. All URLs are str (local). All URLs are accessible locally -> no download 2. All URLs are str (local). At least one url is unaccessible locally -> Error 3. All URLs are tuple (local, remote). All URLs are accessible locally -> no download - 4. All URLs are tuple (local, remote). At least one url is not accessible locally -> download all + 4. All URLs are tuple (local, remote). At least one url is not accessible locally -> \ + download all 5. All URLs are str (remote) -> download all """ from decimal import Decimal From d2be6a01c6f679e8f2e00893937095964823c324 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 07:28:25 -0700 Subject: [PATCH 15/45] Populate streaming/examples/ with SD subclasses, also streaming/benchmarks/. 
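Besides relocating the example datasets, this patch folds the old `streaming/vision/convert/base.py` helper into `streaming/base/vision.py` (see the `convert_image_class_dataset` hunk below). A minimal usage sketch of the relocated helper, assuming this branch is installed; the dataset root and output paths are illustrative only:

```python
from torchvision.datasets import CIFAR10

from streaming.base.vision import convert_image_class_dataset

# Any map-style dataset yielding (image, label) pairs works here.
dataset = CIFAR10(root='/tmp/cifar10-raw', train=True, download=True)

# Writes MDS shards with columns {'i': int, 'x': 'pil', 'y': int} under
# /tmp/cifar10-mds/train, hashing each shard file with sha1.
convert_image_class_dataset(dataset,
                            out_root='/tmp/cifar10-mds',
                            split='train',
                            compression=None,
                            hashes=['sha1'],
                            size_limit=1 << 24,
                            progress_bar=True,
                            leave=False,
                            encoding='pil')
```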
--- {examples => notebooks}/cifar10.ipynb | 0 {examples => notebooks}/facesynthetics.ipynb | 0 .../multiprocess_dataset_conversion.ipynb | 0 .../spark_dataframe_to_MDS.ipynb | 0 {examples => notebooks}/synthetic_nlp.ipynb | 0 streaming/{vision/base.py => base/vision.py} | 63 +++++++++- .../benchmarks/backends}/__init__.py | 0 .../benchmarks/backends}/generate_datasets.py | 2 +- .../benchmarks/backends}/task.py | 0 .../benchmarks}/compression/bench.py | 0 .../benchmarks}/compression/plot.py | 0 .../benchmarks}/epoch/bench.py | 0 .../benchmarks}/hashing/bench.py | 0 .../benchmarks}/hashing/plot.py | 0 .../benchmarks/partitioning}/bench.py | 0 .../benchmarks/partitioning}/diff.py | 0 .../benchmarks/partitioning}/plot.py | 0 .../benchmarks/partitioning}/txt.py | 0 .../benchmarks/partitioning}/web.py | 0 .../benchmarks}/samples/bench_and_plot.py | 0 .../benchmarks}/serialization/compare.py | 0 .../serialization/survey_fixed_decimals.py | 0 .../benchmarks/shuffling}/bench.py | 0 .../benchmarks/shuffling}/plot.py | 0 .../benchmarks/shuffling}/vis.py | 0 .../__init__.py => examples/__init__py} | 0 .../multimodal/convert/__init__.py | 0 .../multimodal}/laion400m/README.md | 0 .../multimodal}/laion400m/__init__.py | 0 .../laion400m/convert_and_upload.py | 0 .../laion400m/convert_and_upload.sh | 0 .../multimodal}/laion400m/download_data.sh | 0 .../multimodal}/laion400m/download_meta.sh | 0 .../multimodal/webvid/read.py} | 0 .../multimodal/webvid}/webvid/bench_inside.py | 0 .../webvid}/webvid/bench_outside_dt.py | 0 .../webvid}/webvid/bench_outside_gi.py | 0 .../multimodal/webvid}/webvid/plot.py | 0 .../multimodal/webvid/write}/README.md | 0 .../multimodal/webvid/write}/__init__.py | 0 .../multimodal/webvid/write}/crawl_webvid.py | 0 .../webvid/write}/crawl_webvid_subsets.py | 0 .../webvid/write}/extract_webvid_videos.py | 0 streaming/examples/text/c4/README.md | 7 ++ .../{text/c4.py => examples/text/c4/read.py} | 0 .../c4.py => examples/text/c4/write.py} | 0 .../text/enwiki_tok}/__init__.py | 0 .../text/enwiki_tok}/mds/README.md | 0 .../text/enwiki_tok/mds}/__init__.py | 0 .../mds/create_pretraining_data.py | 0 .../text/enwiki_tok}/mds/make_eval.sh | 0 .../enwiki_tok}/mds/make_train_parallel.py | 0 .../enwiki_tok}/mds/merge_shard_groups.py | 0 .../text/enwiki_tok}/mds/pick_eval_samples.py | 0 .../text/enwiki_tok}/mds/tokenization.py | 0 .../text/enwiki_tok}/mds/vocab.txt | 0 .../text/enwiki_tok/tfrecord/__init__.py | 0 .../enwiki_tok}/tfrecord/count_samples.py | 0 .../tfrecord/create_pretraining_data.py | 0 .../text/enwiki_tok}/tfrecord/make_eval.sh | 0 .../text/enwiki_tok}/tfrecord/make_train.sh | 0 .../tfrecord/make_train_parallel.py | 0 .../enwiki_tok}/tfrecord/pick_eval_samples.py | 0 .../text/enwiki_tok}/tfrecord/tokenization.py | 0 .../text/enwiki_tok}/tfrecord/vocab.txt | 0 streaming/examples/text/enwiki_txt/README.md | 26 ++++ .../text/enwiki_txt}/enwiki.py | 0 .../text/enwiki_txt/write.py} | 0 streaming/examples/text/pile/README.md | 19 +++ .../pile.py => examples/text/pile/read.py} | 0 .../pile.py => examples/text/pile/write.py} | 0 streaming/examples/vision/ade20k/README.md | 19 +++ .../vision/ade20k/read.py} | 0 .../vision/ade20k/write.py} | 0 streaming/examples/vision/cifar10/README.md | 7 ++ .../vision/cifar10/read.py} | 0 .../vision/cifar10/write.py} | 0 .../vision/cifar10/write_fake.py} | 0 streaming/examples/vision/coco/README.md | 38 ++++++ .../coco.py => examples/vision/coco/read.py} | 0 .../coco.py => examples/vision/coco/write.py} | 0 streaming/examples/vision/imagenet/README.md 
| 38 ++++++ .../vision/imagenet/read.py} | 0 .../vision/imagenet/write.py} | 0 streaming/multimodal/__init__.py | 8 -- .../multimodal/convert/laion/__init__.py | 4 - streaming/text/__init__.py | 10 -- streaming/text/convert/README.md | 69 ----------- streaming/text/convert/__init__.py | 4 - streaming/vision/__init__.py | 11 -- streaming/vision/convert/README.md | 113 ------------------ streaming/vision/convert/__init__.py | 4 - streaming/vision/convert/base.py | 68 ----------- tests/test_streaming_remote.py | 51 -------- 94 files changed, 216 insertions(+), 345 deletions(-) rename {examples => notebooks}/cifar10.ipynb (100%) rename {examples => notebooks}/facesynthetics.ipynb (100%) rename {examples => notebooks}/multiprocess_dataset_conversion.ipynb (100%) rename {examples => notebooks}/spark_dataframe_to_MDS.ipynb (100%) rename {examples => notebooks}/synthetic_nlp.ipynb (100%) rename streaming/{vision/base.py => base/vision.py} (77%) rename {benchmarks/backends-and-formats => streaming/benchmarks/backends}/__init__.py (100%) rename {benchmarks/backends-and-formats => streaming/benchmarks/backends}/generate_datasets.py (99%) rename {benchmarks/backends-and-formats => streaming/benchmarks/backends}/task.py (100%) rename {scripts => streaming/benchmarks}/compression/bench.py (100%) rename {scripts => streaming/benchmarks}/compression/plot.py (100%) rename {scripts => streaming/benchmarks}/epoch/bench.py (100%) rename {scripts => streaming/benchmarks}/hashing/bench.py (100%) rename {scripts => streaming/benchmarks}/hashing/plot.py (100%) rename {scripts/partition => streaming/benchmarks/partitioning}/bench.py (100%) rename {scripts/partition => streaming/benchmarks/partitioning}/diff.py (100%) rename {scripts/partition => streaming/benchmarks/partitioning}/plot.py (100%) rename {scripts/partition => streaming/benchmarks/partitioning}/txt.py (100%) rename {scripts/partition => streaming/benchmarks/partitioning}/web.py (100%) rename {scripts => streaming/benchmarks}/samples/bench_and_plot.py (100%) rename {scripts => streaming/benchmarks}/serialization/compare.py (100%) rename {scripts => streaming/benchmarks}/serialization/survey_fixed_decimals.py (100%) rename {scripts/shuffle => streaming/benchmarks/shuffling}/bench.py (100%) rename {scripts/shuffle => streaming/benchmarks/shuffling}/plot.py (100%) rename {scripts/shuffle => streaming/benchmarks/shuffling}/vis.py (100%) rename streaming/{text/convert/enwiki/__init__.py => examples/__init__py} (100%) rename streaming/{ => examples}/multimodal/convert/__init__.py (100%) rename streaming/{multimodal/convert/laion => examples/multimodal}/laion400m/README.md (100%) rename streaming/{multimodal/convert/laion => examples/multimodal}/laion400m/__init__.py (100%) rename streaming/{multimodal/convert/laion => examples/multimodal}/laion400m/convert_and_upload.py (100%) rename streaming/{multimodal/convert/laion => examples/multimodal}/laion400m/convert_and_upload.sh (100%) rename streaming/{multimodal/convert/laion => examples/multimodal}/laion400m/download_data.sh (100%) rename streaming/{multimodal/convert/laion => examples/multimodal}/laion400m/download_meta.sh (100%) rename streaming/{multimodal/webvid.py => examples/multimodal/webvid/read.py} (100%) rename {scripts => streaming/examples/multimodal/webvid}/webvid/bench_inside.py (100%) rename {scripts => streaming/examples/multimodal/webvid}/webvid/bench_outside_dt.py (100%) rename {scripts => streaming/examples/multimodal/webvid}/webvid/bench_outside_gi.py (100%) rename {scripts => 
streaming/examples/multimodal/webvid}/webvid/plot.py (100%) rename streaming/{multimodal/convert/webvid => examples/multimodal/webvid/write}/README.md (100%) rename streaming/{multimodal/convert/webvid => examples/multimodal/webvid/write}/__init__.py (100%) rename streaming/{multimodal/convert/webvid => examples/multimodal/webvid/write}/crawl_webvid.py (100%) rename streaming/{multimodal/convert/webvid => examples/multimodal/webvid/write}/crawl_webvid_subsets.py (100%) rename streaming/{multimodal/convert/webvid => examples/multimodal/webvid/write}/extract_webvid_videos.py (100%) create mode 100644 streaming/examples/text/c4/README.md rename streaming/{text/c4.py => examples/text/c4/read.py} (100%) rename streaming/{text/convert/c4.py => examples/text/c4/write.py} (100%) rename streaming/{text/convert/enwiki/mds => examples/text/enwiki_tok}/__init__.py (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/mds/README.md (100%) rename streaming/{text/convert/enwiki/tfrecord => examples/text/enwiki_tok/mds}/__init__.py (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/mds/create_pretraining_data.py (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/mds/make_eval.sh (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/mds/make_train_parallel.py (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/mds/merge_shard_groups.py (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/mds/pick_eval_samples.py (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/mds/tokenization.py (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/mds/vocab.txt (100%) create mode 100644 streaming/examples/text/enwiki_tok/tfrecord/__init__.py rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/count_samples.py (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/create_pretraining_data.py (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/make_eval.sh (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/make_train.sh (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/make_train_parallel.py (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/pick_eval_samples.py (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/tokenization.py (100%) rename streaming/{text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/vocab.txt (100%) create mode 100644 streaming/examples/text/enwiki_txt/README.md rename streaming/{text => examples/text/enwiki_txt}/enwiki.py (100%) rename streaming/{text/convert/enwiki_text.py => examples/text/enwiki_txt/write.py} (100%) create mode 100644 streaming/examples/text/pile/README.md rename streaming/{text/pile.py => examples/text/pile/read.py} (100%) rename streaming/{text/convert/pile.py => examples/text/pile/write.py} (100%) create mode 100644 streaming/examples/vision/ade20k/README.md rename streaming/{vision/ade20k.py => examples/vision/ade20k/read.py} (100%) rename streaming/{vision/convert/ade20k.py => examples/vision/ade20k/write.py} (100%) create mode 100644 streaming/examples/vision/cifar10/README.md rename streaming/{vision/cifar10.py => examples/vision/cifar10/read.py} (100%) rename streaming/{vision/convert/cifar10.py => examples/vision/cifar10/write.py} (100%) rename streaming/{vision/convert/fake_cifar10.py => 
examples/vision/cifar10/write_fake.py} (100%) create mode 100644 streaming/examples/vision/coco/README.md rename streaming/{vision/coco.py => examples/vision/coco/read.py} (100%) rename streaming/{vision/convert/coco.py => examples/vision/coco/write.py} (100%) create mode 100644 streaming/examples/vision/imagenet/README.md rename streaming/{vision/imagenet.py => examples/vision/imagenet/read.py} (100%) rename streaming/{vision/convert/imagenet.py => examples/vision/imagenet/write.py} (100%) delete mode 100644 streaming/multimodal/__init__.py delete mode 100644 streaming/multimodal/convert/laion/__init__.py delete mode 100644 streaming/text/__init__.py delete mode 100644 streaming/text/convert/README.md delete mode 100644 streaming/text/convert/__init__.py delete mode 100644 streaming/vision/__init__.py delete mode 100644 streaming/vision/convert/README.md delete mode 100644 streaming/vision/convert/__init__.py delete mode 100644 streaming/vision/convert/base.py diff --git a/examples/cifar10.ipynb b/notebooks/cifar10.ipynb similarity index 100% rename from examples/cifar10.ipynb rename to notebooks/cifar10.ipynb diff --git a/examples/facesynthetics.ipynb b/notebooks/facesynthetics.ipynb similarity index 100% rename from examples/facesynthetics.ipynb rename to notebooks/facesynthetics.ipynb diff --git a/examples/multiprocess_dataset_conversion.ipynb b/notebooks/multiprocess_dataset_conversion.ipynb similarity index 100% rename from examples/multiprocess_dataset_conversion.ipynb rename to notebooks/multiprocess_dataset_conversion.ipynb diff --git a/examples/spark_dataframe_to_MDS.ipynb b/notebooks/spark_dataframe_to_MDS.ipynb similarity index 100% rename from examples/spark_dataframe_to_MDS.ipynb rename to notebooks/spark_dataframe_to_MDS.ipynb diff --git a/examples/synthetic_nlp.ipynb b/notebooks/synthetic_nlp.ipynb similarity index 100% rename from examples/synthetic_nlp.ipynb rename to notebooks/synthetic_nlp.ipynb diff --git a/streaming/vision/base.py b/streaming/base/vision.py similarity index 77% rename from streaming/vision/base.py rename to streaming/base/vision.py index 564305849..b3fc20790 100644 --- a/streaming/vision/base.py +++ b/streaming/base/vision.py @@ -3,12 +3,16 @@ """Base classes for computer vision :class:`StreamingDataset`s.""" -from typing import Any, Callable, Optional, Tuple +import os +from typing import Any, Callable, List, Optional, Tuple +import numpy as np +from torch.utils.data import Dataset from torchvision.datasets import VisionDataset from torchvision.transforms.functional import to_tensor +from tqdm import tqdm -from streaming.base import StreamingDataset +from streaming.base import MDSWriter, StreamingDataset __all__ = ['StreamingVisionDataset'] @@ -174,3 +178,58 @@ def get_item(self, idx: int) -> Any: x = obj['x'] y = obj['y'] return self.transforms(x, y) + + +def convert_image_class_dataset(dataset: Dataset, + out_root: str, + split: Optional[str] = None, + compression: Optional[str] = None, + hashes: Optional[List[str]] = None, + size_limit: int = 1 << 24, + progress_bar: bool = True, + leave: bool = False, + encoding: str = 'pil') -> None: + """Convert an image classification Dataset. + + Args: + dataset (Dataset): The dataset object to convert. + out_root (str): Output directory where shards are cached by split. + remote (str, optional): Remote dataset directory where shards are uploaded by split. + split (str, optional): Which dataset split to use, if any. Defaults to ``None``. + compression (str, optional): Optional compression. 
Defaults to ``None``. + hashes (List[str], optional): Optional list of hash algorithms to apply to shard files. + Defaults to ``None``. + size_limit (int): Uncompressed shard size limit, at which point it flushes the shard and + starts a new one. Defaults to ``1 << 26``. + progress_bar (bool): Whether to display a progress bar while converting. + Defaults to ``True``. + leave (bool): Whether to leave the progress bar in the console when done. Defaults to + ``False``. + encoding (str): MDS encoding to use for the image data. Defaults to ``pil``. + """ + split = split or '' + columns = { + 'i': 'int', + 'x': encoding, + 'y': 'int', + } + hashes = hashes or [] + indices = np.random.permutation(len(dataset)).tolist() # pyright: ignore + if progress_bar: + indices = tqdm(indices, leave=leave) + + out_split_dir = os.path.join(out_root, split) + + with MDSWriter(out=out_split_dir, + columns=columns, + compression=compression, + hashes=hashes, + size_limit=size_limit, + progress_bar=progress_bar) as out: + for i in indices: + x, y = dataset[i] + out.write({ + 'i': i, + 'x': x, + 'y': y, + }) diff --git a/benchmarks/backends-and-formats/__init__.py b/streaming/benchmarks/backends/__init__.py similarity index 100% rename from benchmarks/backends-and-formats/__init__.py rename to streaming/benchmarks/backends/__init__.py diff --git a/benchmarks/backends-and-formats/generate_datasets.py b/streaming/benchmarks/backends/generate_datasets.py similarity index 99% rename from benchmarks/backends-and-formats/generate_datasets.py rename to streaming/benchmarks/backends/generate_datasets.py index 1bb661863..ad658860d 100644 --- a/benchmarks/backends-and-formats/generate_datasets.py +++ b/streaming/benchmarks/backends/generate_datasets.py @@ -37,7 +37,7 @@ def parse_args() -> Namespace: args.add_argument('--num_train', type=int, default=1 << 21) args.add_argument('--num_val', type=int, default=1 << 17) - args.add_argument('--data_root', type=str, default='data/compare-backends/') + args.add_argument('--data_root', type=str, default='data/backendss/') args.add_argument('--csv', type=str, default='csv') args.add_argument('--jsonl', type=str, default='jsonl') args.add_argument('--lance', type=str, default='lance') diff --git a/benchmarks/backends-and-formats/task.py b/streaming/benchmarks/backends/task.py similarity index 100% rename from benchmarks/backends-and-formats/task.py rename to streaming/benchmarks/backends/task.py diff --git a/scripts/compression/bench.py b/streaming/benchmarks/compression/bench.py similarity index 100% rename from scripts/compression/bench.py rename to streaming/benchmarks/compression/bench.py diff --git a/scripts/compression/plot.py b/streaming/benchmarks/compression/plot.py similarity index 100% rename from scripts/compression/plot.py rename to streaming/benchmarks/compression/plot.py diff --git a/scripts/epoch/bench.py b/streaming/benchmarks/epoch/bench.py similarity index 100% rename from scripts/epoch/bench.py rename to streaming/benchmarks/epoch/bench.py diff --git a/scripts/hashing/bench.py b/streaming/benchmarks/hashing/bench.py similarity index 100% rename from scripts/hashing/bench.py rename to streaming/benchmarks/hashing/bench.py diff --git a/scripts/hashing/plot.py b/streaming/benchmarks/hashing/plot.py similarity index 100% rename from scripts/hashing/plot.py rename to streaming/benchmarks/hashing/plot.py diff --git a/scripts/partition/bench.py b/streaming/benchmarks/partitioning/bench.py similarity index 100% rename from scripts/partition/bench.py rename to 
streaming/benchmarks/partitioning/bench.py diff --git a/scripts/partition/diff.py b/streaming/benchmarks/partitioning/diff.py similarity index 100% rename from scripts/partition/diff.py rename to streaming/benchmarks/partitioning/diff.py diff --git a/scripts/partition/plot.py b/streaming/benchmarks/partitioning/plot.py similarity index 100% rename from scripts/partition/plot.py rename to streaming/benchmarks/partitioning/plot.py diff --git a/scripts/partition/txt.py b/streaming/benchmarks/partitioning/txt.py similarity index 100% rename from scripts/partition/txt.py rename to streaming/benchmarks/partitioning/txt.py diff --git a/scripts/partition/web.py b/streaming/benchmarks/partitioning/web.py similarity index 100% rename from scripts/partition/web.py rename to streaming/benchmarks/partitioning/web.py diff --git a/scripts/samples/bench_and_plot.py b/streaming/benchmarks/samples/bench_and_plot.py similarity index 100% rename from scripts/samples/bench_and_plot.py rename to streaming/benchmarks/samples/bench_and_plot.py diff --git a/scripts/serialization/compare.py b/streaming/benchmarks/serialization/compare.py similarity index 100% rename from scripts/serialization/compare.py rename to streaming/benchmarks/serialization/compare.py diff --git a/scripts/serialization/survey_fixed_decimals.py b/streaming/benchmarks/serialization/survey_fixed_decimals.py similarity index 100% rename from scripts/serialization/survey_fixed_decimals.py rename to streaming/benchmarks/serialization/survey_fixed_decimals.py diff --git a/scripts/shuffle/bench.py b/streaming/benchmarks/shuffling/bench.py similarity index 100% rename from scripts/shuffle/bench.py rename to streaming/benchmarks/shuffling/bench.py diff --git a/scripts/shuffle/plot.py b/streaming/benchmarks/shuffling/plot.py similarity index 100% rename from scripts/shuffle/plot.py rename to streaming/benchmarks/shuffling/plot.py diff --git a/scripts/shuffle/vis.py b/streaming/benchmarks/shuffling/vis.py similarity index 100% rename from scripts/shuffle/vis.py rename to streaming/benchmarks/shuffling/vis.py diff --git a/streaming/text/convert/enwiki/__init__.py b/streaming/examples/__init__py similarity index 100% rename from streaming/text/convert/enwiki/__init__.py rename to streaming/examples/__init__py diff --git a/streaming/multimodal/convert/__init__.py b/streaming/examples/multimodal/convert/__init__.py similarity index 100% rename from streaming/multimodal/convert/__init__.py rename to streaming/examples/multimodal/convert/__init__.py diff --git a/streaming/multimodal/convert/laion/laion400m/README.md b/streaming/examples/multimodal/laion400m/README.md similarity index 100% rename from streaming/multimodal/convert/laion/laion400m/README.md rename to streaming/examples/multimodal/laion400m/README.md diff --git a/streaming/multimodal/convert/laion/laion400m/__init__.py b/streaming/examples/multimodal/laion400m/__init__.py similarity index 100% rename from streaming/multimodal/convert/laion/laion400m/__init__.py rename to streaming/examples/multimodal/laion400m/__init__.py diff --git a/streaming/multimodal/convert/laion/laion400m/convert_and_upload.py b/streaming/examples/multimodal/laion400m/convert_and_upload.py similarity index 100% rename from streaming/multimodal/convert/laion/laion400m/convert_and_upload.py rename to streaming/examples/multimodal/laion400m/convert_and_upload.py diff --git a/streaming/multimodal/convert/laion/laion400m/convert_and_upload.sh b/streaming/examples/multimodal/laion400m/convert_and_upload.sh similarity index 100% 
rename from streaming/multimodal/convert/laion/laion400m/convert_and_upload.sh rename to streaming/examples/multimodal/laion400m/convert_and_upload.sh diff --git a/streaming/multimodal/convert/laion/laion400m/download_data.sh b/streaming/examples/multimodal/laion400m/download_data.sh similarity index 100% rename from streaming/multimodal/convert/laion/laion400m/download_data.sh rename to streaming/examples/multimodal/laion400m/download_data.sh diff --git a/streaming/multimodal/convert/laion/laion400m/download_meta.sh b/streaming/examples/multimodal/laion400m/download_meta.sh similarity index 100% rename from streaming/multimodal/convert/laion/laion400m/download_meta.sh rename to streaming/examples/multimodal/laion400m/download_meta.sh diff --git a/streaming/multimodal/webvid.py b/streaming/examples/multimodal/webvid/read.py similarity index 100% rename from streaming/multimodal/webvid.py rename to streaming/examples/multimodal/webvid/read.py diff --git a/scripts/webvid/bench_inside.py b/streaming/examples/multimodal/webvid/webvid/bench_inside.py similarity index 100% rename from scripts/webvid/bench_inside.py rename to streaming/examples/multimodal/webvid/webvid/bench_inside.py diff --git a/scripts/webvid/bench_outside_dt.py b/streaming/examples/multimodal/webvid/webvid/bench_outside_dt.py similarity index 100% rename from scripts/webvid/bench_outside_dt.py rename to streaming/examples/multimodal/webvid/webvid/bench_outside_dt.py diff --git a/scripts/webvid/bench_outside_gi.py b/streaming/examples/multimodal/webvid/webvid/bench_outside_gi.py similarity index 100% rename from scripts/webvid/bench_outside_gi.py rename to streaming/examples/multimodal/webvid/webvid/bench_outside_gi.py diff --git a/scripts/webvid/plot.py b/streaming/examples/multimodal/webvid/webvid/plot.py similarity index 100% rename from scripts/webvid/plot.py rename to streaming/examples/multimodal/webvid/webvid/plot.py diff --git a/streaming/multimodal/convert/webvid/README.md b/streaming/examples/multimodal/webvid/write/README.md similarity index 100% rename from streaming/multimodal/convert/webvid/README.md rename to streaming/examples/multimodal/webvid/write/README.md diff --git a/streaming/multimodal/convert/webvid/__init__.py b/streaming/examples/multimodal/webvid/write/__init__.py similarity index 100% rename from streaming/multimodal/convert/webvid/__init__.py rename to streaming/examples/multimodal/webvid/write/__init__.py diff --git a/streaming/multimodal/convert/webvid/crawl_webvid.py b/streaming/examples/multimodal/webvid/write/crawl_webvid.py similarity index 100% rename from streaming/multimodal/convert/webvid/crawl_webvid.py rename to streaming/examples/multimodal/webvid/write/crawl_webvid.py diff --git a/streaming/multimodal/convert/webvid/crawl_webvid_subsets.py b/streaming/examples/multimodal/webvid/write/crawl_webvid_subsets.py similarity index 100% rename from streaming/multimodal/convert/webvid/crawl_webvid_subsets.py rename to streaming/examples/multimodal/webvid/write/crawl_webvid_subsets.py diff --git a/streaming/multimodal/convert/webvid/extract_webvid_videos.py b/streaming/examples/multimodal/webvid/write/extract_webvid_videos.py similarity index 100% rename from streaming/multimodal/convert/webvid/extract_webvid_videos.py rename to streaming/examples/multimodal/webvid/write/extract_webvid_videos.py diff --git a/streaming/examples/text/c4/README.md b/streaming/examples/text/c4/README.md new file mode 100644 index 000000000..e819daf8d --- /dev/null +++ b/streaming/examples/text/c4/README.md @@ -0,0 
+1,7 @@ +### [C4: Colossal, Cleaned, Common Crawl dataset](https://huggingface.co/datasets/c4) + +1. Run the [c4.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/c4.py) script as shown below. The script downloads the raw format with `train` and `val` splits from HuggingFace hub and converts to StreamingDataset MDS format into their own split directories. For more advanced use cases, please see the supported arguments for [c4.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/c4.py) and modify as necessary. + + ``` + python c4.py --out_root + ``` diff --git a/streaming/text/c4.py b/streaming/examples/text/c4/read.py similarity index 100% rename from streaming/text/c4.py rename to streaming/examples/text/c4/read.py diff --git a/streaming/text/convert/c4.py b/streaming/examples/text/c4/write.py similarity index 100% rename from streaming/text/convert/c4.py rename to streaming/examples/text/c4/write.py diff --git a/streaming/text/convert/enwiki/mds/__init__.py b/streaming/examples/text/enwiki_tok/__init__.py similarity index 100% rename from streaming/text/convert/enwiki/mds/__init__.py rename to streaming/examples/text/enwiki_tok/__init__.py diff --git a/streaming/text/convert/enwiki/mds/README.md b/streaming/examples/text/enwiki_tok/mds/README.md similarity index 100% rename from streaming/text/convert/enwiki/mds/README.md rename to streaming/examples/text/enwiki_tok/mds/README.md diff --git a/streaming/text/convert/enwiki/tfrecord/__init__.py b/streaming/examples/text/enwiki_tok/mds/__init__.py similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/__init__.py rename to streaming/examples/text/enwiki_tok/mds/__init__.py diff --git a/streaming/text/convert/enwiki/mds/create_pretraining_data.py b/streaming/examples/text/enwiki_tok/mds/create_pretraining_data.py similarity index 100% rename from streaming/text/convert/enwiki/mds/create_pretraining_data.py rename to streaming/examples/text/enwiki_tok/mds/create_pretraining_data.py diff --git a/streaming/text/convert/enwiki/mds/make_eval.sh b/streaming/examples/text/enwiki_tok/mds/make_eval.sh similarity index 100% rename from streaming/text/convert/enwiki/mds/make_eval.sh rename to streaming/examples/text/enwiki_tok/mds/make_eval.sh diff --git a/streaming/text/convert/enwiki/mds/make_train_parallel.py b/streaming/examples/text/enwiki_tok/mds/make_train_parallel.py similarity index 100% rename from streaming/text/convert/enwiki/mds/make_train_parallel.py rename to streaming/examples/text/enwiki_tok/mds/make_train_parallel.py diff --git a/streaming/text/convert/enwiki/mds/merge_shard_groups.py b/streaming/examples/text/enwiki_tok/mds/merge_shard_groups.py similarity index 100% rename from streaming/text/convert/enwiki/mds/merge_shard_groups.py rename to streaming/examples/text/enwiki_tok/mds/merge_shard_groups.py diff --git a/streaming/text/convert/enwiki/mds/pick_eval_samples.py b/streaming/examples/text/enwiki_tok/mds/pick_eval_samples.py similarity index 100% rename from streaming/text/convert/enwiki/mds/pick_eval_samples.py rename to streaming/examples/text/enwiki_tok/mds/pick_eval_samples.py diff --git a/streaming/text/convert/enwiki/mds/tokenization.py b/streaming/examples/text/enwiki_tok/mds/tokenization.py similarity index 100% rename from streaming/text/convert/enwiki/mds/tokenization.py rename to streaming/examples/text/enwiki_tok/mds/tokenization.py diff --git a/streaming/text/convert/enwiki/mds/vocab.txt b/streaming/examples/text/enwiki_tok/mds/vocab.txt similarity 
index 100% rename from streaming/text/convert/enwiki/mds/vocab.txt rename to streaming/examples/text/enwiki_tok/mds/vocab.txt diff --git a/streaming/examples/text/enwiki_tok/tfrecord/__init__.py b/streaming/examples/text/enwiki_tok/tfrecord/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/streaming/text/convert/enwiki/tfrecord/count_samples.py b/streaming/examples/text/enwiki_tok/tfrecord/count_samples.py similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/count_samples.py rename to streaming/examples/text/enwiki_tok/tfrecord/count_samples.py diff --git a/streaming/text/convert/enwiki/tfrecord/create_pretraining_data.py b/streaming/examples/text/enwiki_tok/tfrecord/create_pretraining_data.py similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/create_pretraining_data.py rename to streaming/examples/text/enwiki_tok/tfrecord/create_pretraining_data.py diff --git a/streaming/text/convert/enwiki/tfrecord/make_eval.sh b/streaming/examples/text/enwiki_tok/tfrecord/make_eval.sh similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/make_eval.sh rename to streaming/examples/text/enwiki_tok/tfrecord/make_eval.sh diff --git a/streaming/text/convert/enwiki/tfrecord/make_train.sh b/streaming/examples/text/enwiki_tok/tfrecord/make_train.sh similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/make_train.sh rename to streaming/examples/text/enwiki_tok/tfrecord/make_train.sh diff --git a/streaming/text/convert/enwiki/tfrecord/make_train_parallel.py b/streaming/examples/text/enwiki_tok/tfrecord/make_train_parallel.py similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/make_train_parallel.py rename to streaming/examples/text/enwiki_tok/tfrecord/make_train_parallel.py diff --git a/streaming/text/convert/enwiki/tfrecord/pick_eval_samples.py b/streaming/examples/text/enwiki_tok/tfrecord/pick_eval_samples.py similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/pick_eval_samples.py rename to streaming/examples/text/enwiki_tok/tfrecord/pick_eval_samples.py diff --git a/streaming/text/convert/enwiki/tfrecord/tokenization.py b/streaming/examples/text/enwiki_tok/tfrecord/tokenization.py similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/tokenization.py rename to streaming/examples/text/enwiki_tok/tfrecord/tokenization.py diff --git a/streaming/text/convert/enwiki/tfrecord/vocab.txt b/streaming/examples/text/enwiki_tok/tfrecord/vocab.txt similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/vocab.txt rename to streaming/examples/text/enwiki_tok/tfrecord/vocab.txt diff --git a/streaming/examples/text/enwiki_txt/README.md b/streaming/examples/text/enwiki_txt/README.md new file mode 100644 index 000000000..cf0721e2e --- /dev/null +++ b/streaming/examples/text/enwiki_txt/README.md @@ -0,0 +1,26 @@ +### [Wikipedia](https://huggingface.co/datasets/wikipedia) + +1. Download English Wikipedia 2020-01-01 from [here](https://drive.google.com/drive/folders/1cywmDnAsrP5-2vsr8GDc6QUc7VWe-M3v). +2. Unzip the file `results_text.zip` as shown below. + + ```bash + unzip results_text.zip + ``` + + Listing the output should show the following directory structure: + + ```bash + ├── eval.txt + ├── part-00000-of-00500 + ├── part-00001-of-00500 + ├── part-00002-of-00500 + ├── ..... + ├── part-00498-of-00500 + └── part-00499-of-00500 + ``` + +3. 
Run the [enwiki_text.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/enwiki_text.py) script. The script converts the `train` and `val` dataset splits into their own split directories. For more advanced use cases, please see the supported arguments for [enwiki_text.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/enwiki_text.py) and modify as necessary. + + ``` + python enwiki_text.py --in_root --out_root + ``` diff --git a/streaming/text/enwiki.py b/streaming/examples/text/enwiki_txt/enwiki.py similarity index 100% rename from streaming/text/enwiki.py rename to streaming/examples/text/enwiki_txt/enwiki.py diff --git a/streaming/text/convert/enwiki_text.py b/streaming/examples/text/enwiki_txt/write.py similarity index 100% rename from streaming/text/convert/enwiki_text.py rename to streaming/examples/text/enwiki_txt/write.py diff --git a/streaming/examples/text/pile/README.md b/streaming/examples/text/pile/README.md new file mode 100644 index 000000000..ec2301e61 --- /dev/null +++ b/streaming/examples/text/pile/README.md @@ -0,0 +1,19 @@ +### [Pile](https://pile.eleuther.ai/) + +1. Download the Pile dataset from [here](https://the-eye.eu/public/AI/pile/). + + Listing the output should show the following directory structure: + + ```bash + ├── SHA256SUMS.txt + ├── test.jsonl.zst + ├── train + │   ├── 00.jsonl.zst + │   ├── 01.jsonl.zst + │   ├── 02.jsonl.zst + │   ├── 03.jsonl.zst + │   ├── ..... + │   ├── 28.jsonl.zst + │   └── 29.jsonl.zst + └── val.jsonl.zst + ``` diff --git a/streaming/text/pile.py b/streaming/examples/text/pile/read.py similarity index 100% rename from streaming/text/pile.py rename to streaming/examples/text/pile/read.py diff --git a/streaming/text/convert/pile.py b/streaming/examples/text/pile/write.py similarity index 100% rename from streaming/text/convert/pile.py rename to streaming/examples/text/pile/write.py diff --git a/streaming/examples/vision/ade20k/README.md b/streaming/examples/vision/ade20k/README.md new file mode 100644 index 000000000..07af2b898 --- /dev/null +++ b/streaming/examples/vision/ade20k/README.md @@ -0,0 +1,19 @@ +### [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/) + +1. Download the ADE20K dataset from [here](https://groups.csail.mit.edu/vision/datasets/ADE20K/). +2. Listing the output should show the following directory structure: + + ```bash + ├── annotations + │ ├── training + │ └── validation + └── images + ├── training + └── validation + ``` + +3. Run the [ade20k.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/ade20k.py) script as shown below. The script converts the `train` and `val` dataset splits into their own directories. For advanced use cases, please see the supported arguments for [ade20k.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/ade20k.py) and modify according as necessary. 
+ + ``` + python ade20k.py --in_root --out_root + ``` diff --git a/streaming/vision/ade20k.py b/streaming/examples/vision/ade20k/read.py similarity index 100% rename from streaming/vision/ade20k.py rename to streaming/examples/vision/ade20k/read.py diff --git a/streaming/vision/convert/ade20k.py b/streaming/examples/vision/ade20k/write.py similarity index 100% rename from streaming/vision/convert/ade20k.py rename to streaming/examples/vision/ade20k/write.py diff --git a/streaming/examples/vision/cifar10/README.md b/streaming/examples/vision/cifar10/README.md new file mode 100644 index 000000000..7f12df567 --- /dev/null +++ b/streaming/examples/vision/cifar10/README.md @@ -0,0 +1,7 @@ +### [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) + +1. Run the [cifar10.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/cifar10.py) script as shown below. The CIFAR10 dataset will be automatically downloaded if it doesn't exist locally. For advanced use cases, please see the supported arguments for [cifar10.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/cifar10.py) and modify as necessary. + + ``` + python cifar10.py --in_root --out_root + ``` diff --git a/streaming/vision/cifar10.py b/streaming/examples/vision/cifar10/read.py similarity index 100% rename from streaming/vision/cifar10.py rename to streaming/examples/vision/cifar10/read.py diff --git a/streaming/vision/convert/cifar10.py b/streaming/examples/vision/cifar10/write.py similarity index 100% rename from streaming/vision/convert/cifar10.py rename to streaming/examples/vision/cifar10/write.py diff --git a/streaming/vision/convert/fake_cifar10.py b/streaming/examples/vision/cifar10/write_fake.py similarity index 100% rename from streaming/vision/convert/fake_cifar10.py rename to streaming/examples/vision/cifar10/write_fake.py diff --git a/streaming/examples/vision/coco/README.md b/streaming/examples/vision/coco/README.md new file mode 100644 index 000000000..3e98d3880 --- /dev/null +++ b/streaming/examples/vision/coco/README.md @@ -0,0 +1,38 @@ +### [MS-COCO](https://cocodataset.org/#home) + +1. Download the COCO 2017 dataset from [here](https://cocodataset.org/#download). Please download both the COCO images and annotations and unzip the files as shown below. + + ```bash + mkdir coco + wget -c http://images.cocodataset.org/annotations/annotations_trainval2017.zip + wget -c http://images.cocodataset.org/zips/train2017.zip + wget -c http://images.cocodataset.org/zips/val2017.zip + + unzip annotations_trainval2017.zip + unzip train2017.zip + unzip val2017.zip + + rm annotations_trainval2017.zip + rm train2017.zip + rm val2017.zip + ``` + + Listing the output should show the following directory structure: + + ```bash + ├── annotations + │ ├── instances_train2017.json + │ └── instances_val2017.json + ├── train2017 + │ ├── 000000391895.jpg + | |── ... + └── val2017 + │ ├── 000000000139.jpg + | |── ... + ``` + +2. Run the [coco.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/coco.py) script as shown below. The script converts the `train` and `val` dataset splits into their own directories. For advanced use cases, please seet the supported arguments for [coco.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/coco.py) and modify as necessary. 
+ + ``` + python coco.py --in_root --out_root + ``` diff --git a/streaming/vision/coco.py b/streaming/examples/vision/coco/read.py similarity index 100% rename from streaming/vision/coco.py rename to streaming/examples/vision/coco/read.py diff --git a/streaming/vision/convert/coco.py b/streaming/examples/vision/coco/write.py similarity index 100% rename from streaming/vision/convert/coco.py rename to streaming/examples/vision/coco/write.py diff --git a/streaming/examples/vision/imagenet/README.md b/streaming/examples/vision/imagenet/README.md new file mode 100644 index 000000000..ce27b258d --- /dev/null +++ b/streaming/examples/vision/imagenet/README.md @@ -0,0 +1,38 @@ +### [ImageNet](https://www.image-net.org/) + +1. Download the ImageNet dataset from [here](https://image-net.org/download.php). Two files are needed, `ILSVRC2012_img_train.tar` for training and `ILSVRC2012_img_val.tar` for validation. Next untar both the files as shown below. + + ```bash + mkdir val + mv ILSVRC2012_img_val.tar val/ + tar -xvf ILSVRC2012_img_val.tar -C val/ + rm ILSVRC2012_img_val.tar + + mkdir train + mv ILSVRC2012_img_train.tar train/ + tar -xvf ILSVRC2012_img_train.tar -C train/ + rm ILSVRC2012_img_train.tar + ``` + + Listing the output should show the following directory structure: + + ```bash + ├── train/ + ├── n01440764 + │ ├── n01440764_10026.JPEG + │ ├── n01440764_10027.JPEG + │ ├── ...... + ├── ...... + ├── val/ + ├── n01440764 + │ ├── ILSVRC2012_val_00000293.JPEG + │ ├── ILSVRC2012_val_00002138.JPEG + │ ├── ...... + ├── ...... + ``` + +2. Run the [imagenet.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/imagenet.py) script as shown below. The script converts the `train` and `val` dataset splits into their own directories. For advanced uses cases, please see the supported arguments for [imagenet.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/imagenet.py) and modify as needed. 
+ + ``` + python imagenet.py --in_root --out_root + ``` diff --git a/streaming/vision/imagenet.py b/streaming/examples/vision/imagenet/read.py similarity index 100% rename from streaming/vision/imagenet.py rename to streaming/examples/vision/imagenet/read.py diff --git a/streaming/vision/convert/imagenet.py b/streaming/examples/vision/imagenet/write.py similarity index 100% rename from streaming/vision/convert/imagenet.py rename to streaming/examples/vision/imagenet/write.py diff --git a/streaming/multimodal/__init__.py b/streaming/multimodal/__init__.py deleted file mode 100644 index cac23533f..000000000 --- a/streaming/multimodal/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Natively supported multimodal datasets.""" - -from streaming.multimodal.webvid import StreamingInsideWebVid as StreamingInsideWebVid -from streaming.multimodal.webvid import StreamingOutsideDTWebVid as StreamingOutsideDTWebVid -from streaming.multimodal.webvid import StreamingOutsideGIWebVid as StreamingOutsideGIWebVid diff --git a/streaming/multimodal/convert/laion/__init__.py b/streaming/multimodal/convert/laion/__init__.py deleted file mode 100644 index dc40547ef..000000000 --- a/streaming/multimodal/convert/laion/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""LAION dataset creation.""" diff --git a/streaming/text/__init__.py b/streaming/text/__init__.py deleted file mode 100644 index 0452f4430..000000000 --- a/streaming/text/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Natively supported NLP datasets.""" - -from streaming.text.c4 import StreamingC4 as StreamingC4 -from streaming.text.enwiki import StreamingEnWiki as StreamingEnWiki -from streaming.text.pile import StreamingPile as StreamingPile - -__all__ = ['StreamingPile', 'StreamingC4', 'StreamingEnWiki'] diff --git a/streaming/text/convert/README.md b/streaming/text/convert/README.md deleted file mode 100644 index 029ddae09..000000000 --- a/streaming/text/convert/README.md +++ /dev/null @@ -1,69 +0,0 @@ -# Dataset preparation - -To use Streaming Dataset we must first convert the dataset from its native format to MosaicML's Streaming Dataset format called Mosaic Dataset Shard (MDS). Once in MDS format, we can access the dataset from the local file system (disk network attached storage, etc.) or object store (GCS, OCS, S3, etc.). From object store, data can be streamed to train deep learning models and it all just works efficiently. - -Check out steps below for information on converting common NLP datasets to MDS format. Please see [MDSWriter()](https://streaming.docs.mosaicml.com/en/latest/api_reference/generated/streaming.MDSWriter.html) parameters for details on advanced usage. - -## NLP Dataset Conversion Examples - -### [C4: Colossal, Cleaned, Common Crawl dataset](https://huggingface.co/datasets/c4) - -1. Run the [c4.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/c4.py) script as shown below. The script downloads the raw format with `train` and `val` splits from HuggingFace hub and converts to StreamingDataset MDS format into their own split directories. For more advanced use cases, please see the supported arguments for [c4.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/c4.py) and modify as necessary. 
- - ``` - python c4.py --out_root - ``` - -### [Wikipedia](https://huggingface.co/datasets/wikipedia) - -1. Download English Wikipedia 2020-01-01 from [here](https://drive.google.com/drive/folders/1cywmDnAsrP5-2vsr8GDc6QUc7VWe-M3v). -2. Unzip the file `results_text.zip` as shown below. - - ```bash - unzip results_text.zip - ``` - - Listing the output should show the following directory structure: - - ```bash - ├── eval.txt - ├── part-00000-of-00500 - ├── part-00001-of-00500 - ├── part-00002-of-00500 - ├── ..... - ├── part-00498-of-00500 - └── part-00499-of-00500 - ``` - -3. Run the [enwiki_text.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/enwiki_text.py) script. The script converts the `train` and `val` dataset splits into their own split directories. For more advanced use cases, please see the supported arguments for [enwiki_text.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/enwiki_text.py) and modify as necessary. - - ``` - python enwiki_text.py --in_root --out_root - ``` - -### [Pile](https://pile.eleuther.ai/) - -1. Download the Pile dataset from [here](https://the-eye.eu/public/AI/pile/). - - Listing the output should show the following directory structure: - - ```bash - ├── SHA256SUMS.txt - ├── test.jsonl.zst - ├── train - │   ├── 00.jsonl.zst - │   ├── 01.jsonl.zst - │   ├── 02.jsonl.zst - │   ├── 03.jsonl.zst - │   ├── ..... - │   ├── 28.jsonl.zst - │   └── 29.jsonl.zst - └── val.jsonl.zst - ``` - -2. Run the [pile.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/pile.py) script. The script converts the `train`, `test`, and `val` dataset splits into their own split directories. For more advanced use cases, please see the supported arguments for [pile.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/pile.py) and modify as necessary. - - - ```bash - python pile.py --in_root --out_root - ``` diff --git a/streaming/text/convert/__init__.py b/streaming/text/convert/__init__.py deleted file mode 100644 index a807b9660..000000000 --- a/streaming/text/convert/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Data conversion scripts for Natural Language Processing.""" diff --git a/streaming/vision/__init__.py b/streaming/vision/__init__.py deleted file mode 100644 index f7ceab7b7..000000000 --- a/streaming/vision/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Natively supported CV datasets.""" - -from streaming.vision.ade20k import StreamingADE20K as StreamingADE20K -from streaming.vision.cifar10 import StreamingCIFAR10 as StreamingCIFAR10 -from streaming.vision.coco import StreamingCOCO as StreamingCOCO -from streaming.vision.imagenet import StreamingImageNet as StreamingImageNet - -__all__ = ['StreamingADE20K', 'StreamingCIFAR10', 'StreamingCOCO', 'StreamingImageNet'] diff --git a/streaming/vision/convert/README.md b/streaming/vision/convert/README.md deleted file mode 100644 index 58eda5148..000000000 --- a/streaming/vision/convert/README.md +++ /dev/null @@ -1,113 +0,0 @@ -# Dataset preparation - -To use Streaming Dataset we must first convert the dataset from its native format to MosaicML's Streaming Dataset format called Mosaic Dataset Shard (MDS). Once in MDS format, we can access the dataset from the local file system (disk network attached storage, etc.) or object store (GCS, OCS, S3, etc.). 
From object store, data can be streamed to train deep learning models and it all just works efficiently. - -Check out steps below for information on converting common Computer Vision datasets to MDS format. Please see [MDSWriter()](https://streaming.docs.mosaicml.com/en/latest/api_reference/generated/streaming.MDSWriter.html) parameters for details on advanced usage. - -## Vision Datasets Conversion Examples - -### [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/) - -1. Download the ADE20K dataset from [here](https://groups.csail.mit.edu/vision/datasets/ADE20K/). -2. Listing the output should show the following directory structure: - - ```bash - ├── annotations - │ ├── training - │ └── validation - └── images - ├── training - └── validation - ``` - -3. Run the [ade20k.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/ade20k.py) script as shown below. The script converts the `train` and `val` dataset splits into their own directories. For advanced use cases, please see the supported arguments for [ade20k.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/ade20k.py) and modify according as necessary. - - ``` - python ade20k.py --in_root --out_root - ``` - -### [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) - -1. Run the [cifar10.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/cifar10.py) script as shown below. The CIFAR10 dataset will be automatically downloaded if it doesn't exist locally. For advanced use cases, please see the supported arguments for [cifar10.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/cifar10.py) and modify as necessary. - - ``` - python cifar10.py --in_root --out_root - ``` - -### [MS-COCO](https://cocodataset.org/#home) - -1. Download the COCO 2017 dataset from [here](https://cocodataset.org/#download). Please download both the COCO images and annotations and unzip the files as shown below. - - ```bash - mkdir coco - wget -c http://images.cocodataset.org/annotations/annotations_trainval2017.zip - wget -c http://images.cocodataset.org/zips/train2017.zip - wget -c http://images.cocodataset.org/zips/val2017.zip - - unzip annotations_trainval2017.zip - unzip train2017.zip - unzip val2017.zip - - rm annotations_trainval2017.zip - rm train2017.zip - rm val2017.zip - ``` - - Listing the output should show the following directory structure: - - ```bash - ├── annotations - │ ├── instances_train2017.json - │ └── instances_val2017.json - ├── train2017 - │ ├── 000000391895.jpg - | |── ... - └── val2017 - │ ├── 000000000139.jpg - | |── ... - ``` - -2. Run the [coco.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/coco.py) script as shown below. The script converts the `train` and `val` dataset splits into their own directories. For advanced use cases, please seet the supported arguments for [coco.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/coco.py) and modify as necessary. - - ``` - python coco.py --in_root --out_root - ``` - -### [ImageNet](https://www.image-net.org/) - -1. Download the ImageNet dataset from [here](https://image-net.org/download.php). Two files are needed, `ILSVRC2012_img_train.tar` for training and `ILSVRC2012_img_val.tar` for validation. Next untar both the files as shown below. 
- - ```bash - mkdir val - mv ILSVRC2012_img_val.tar val/ - tar -xvf ILSVRC2012_img_val.tar -C val/ - rm ILSVRC2012_img_val.tar - - mkdir train - mv ILSVRC2012_img_train.tar train/ - tar -xvf ILSVRC2012_img_train.tar -C train/ - rm ILSVRC2012_img_train.tar - ``` - - Listing the output should show the following directory structure: - - ```bash - ├── train/ - ├── n01440764 - │ ├── n01440764_10026.JPEG - │ ├── n01440764_10027.JPEG - │ ├── ...... - ├── ...... - ├── val/ - ├── n01440764 - │ ├── ILSVRC2012_val_00000293.JPEG - │ ├── ILSVRC2012_val_00002138.JPEG - │ ├── ...... - ├── ...... - ``` - -2. Run the [imagenet.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/imagenet.py) script as shown below. The script converts the `train` and `val` dataset splits into their own directories. For advanced use cases, please see the supported arguments for [imagenet.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/imagenet.py) and modify as needed. - - ``` - python imagenet.py --in_root --out_root - ``` diff --git a/streaming/vision/convert/__init__.py b/streaming/vision/convert/__init__.py deleted file mode 100644 index fcea5a2a2..000000000 --- a/streaming/vision/convert/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Data conversion scripts for Computer Vision.""" diff --git a/streaming/vision/convert/base.py b/streaming/vision/convert/base.py deleted file mode 100644 index 5194816fd..000000000 --- a/streaming/vision/convert/base.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Utility and helper functions to convert CV datasets.""" - -import os -from typing import List, Optional - -import numpy as np -from torch.utils.data import Dataset -from tqdm import tqdm - -from streaming.base import MDSWriter - - -def convert_image_class_dataset(dataset: Dataset, - out_root: str, - split: Optional[str] = None, - compression: Optional[str] = None, - hashes: Optional[List[str]] = None, - size_limit: int = 1 << 24, - progress_bar: bool = True, - leave: bool = False, - encoding: str = 'pil') -> None: - """Convert an image classification Dataset. - - Args: - dataset (Dataset): The dataset object to convert. - out_root (str): Output directory where shards are cached by split. - remote (str, optional): Remote dataset directory where shards are uploaded by split. - split (str, optional): Which dataset split to use, if any. Defaults to ``None``. - compression (str, optional): Optional compression. Defaults to ``None``. - hashes (List[str], optional): Optional list of hash algorithms to apply to shard files. - Defaults to ``None``. - size_limit (int): Uncompressed shard size limit, at which point it flushes the shard and - starts a new one. Defaults to ``1 << 24``. - progress_bar (bool): Whether to display a progress bar while converting. - Defaults to ``True``. - leave (bool): Whether to leave the progress bar in the console when done. Defaults to - ``False``. - encoding (str): MDS encoding to use for the image data. Defaults to ``pil``.
- """ - split = split or '' - columns = { - 'i': 'int', - 'x': encoding, - 'y': 'int', - } - hashes = hashes or [] - indices = np.random.permutation(len(dataset)).tolist() # pyright: ignore - if progress_bar: - indices = tqdm(indices, leave=leave) - - out_split_dir = os.path.join(out_root, split) - - with MDSWriter(out=out_split_dir, - columns=columns, - compression=compression, - hashes=hashes, - size_limit=size_limit, - progress_bar=progress_bar) as out: - for i in indices: - x, y = dataset[i] - out.write({ - 'i': i, - 'x': x, - 'y': y, - }) diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 7e3dd7fc9..206dd10cd 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -8,8 +8,6 @@ import pytest from streaming.base import StreamingDataset -from streaming.text import StreamingC4 -from streaming.vision import StreamingADE20K, StreamingCIFAR10, StreamingCOCO, StreamingImageNet def get_dataset(name: str, @@ -20,55 +18,6 @@ def get_dataset(name: str, other_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[int, StreamingDataset]: other_kwargs = {} if other_kwargs is None else other_kwargs dataset_map = { - 'ade20k': { - 'remote': 's3://mosaicml-internal-dataset-ade20k/mds/2/', - 'num_samples': { - 'train': 20206, - 'val': 2000, - }, - 'class': StreamingADE20K, - 'kwargs': {}, - }, - 'imagenet1k': { - 'remote': 's3://mosaicml-internal-dataset-imagenet1k/mds/2/', - 'num_samples': { - 'train': 1281167, - 'val': 50000, - }, - 'class': StreamingImageNet, - 'kwargs': {}, - }, - 'coco': { - 'remote': 's3://mosaicml-internal-dataset-coco/mds/2/', - 'num_samples': { - 'train': 117266, - 'val': 4952, - }, - 'class': StreamingCOCO, - 'kwargs': {}, - }, - 'c4': { - 'remote': 's3://mosaicml-internal-dataset-c4/mds/2/', - 'num_samples': { - 'train': 364868892, - 'val': 364608, - }, - 'class': StreamingC4, - 'kwargs': { - 'tokenizer_name': 'bert-base-uncased', - 'max_seq_len': 512, - 'group_method': 'truncate' - }, - }, - 'cifar10': { - 'remote': 's3://mosaicml-internal-dataset-cifar10/mds/2/', - 'num_samples': { - 'train': 50000, - 'val': 10000, - }, - 'class': StreamingCIFAR10, - 'kwargs': {}, - }, 'test_streaming_upload': { 'remote': 's3://streaming-upload-test-bucket/', 'num_samples': { From da6f4afdcaeab1d2aecec412b75251119bb56a70 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 07:32:18 -0700 Subject: [PATCH 16/45] Fix. --- streaming/examples/multimodal/convert/__init__.py | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 streaming/examples/multimodal/convert/__init__.py diff --git a/streaming/examples/multimodal/convert/__init__.py b/streaming/examples/multimodal/convert/__init__.py deleted file mode 100644 index 36f008387..000000000 --- a/streaming/examples/multimodal/convert/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Dataset conversion for natively supported multimodal datasets.""" From b0fa3d7bc6d2d04abe48ce22a887a5d84ac683af Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 07:35:06 -0700 Subject: [PATCH 17/45] Move benchmarks up and out. 
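Moving `streaming/benchmarks/` up to a repo-level `benchmarks/` directory takes the benchmark scripts out of the `streaming` package tree; they become ordinary scripts that import the installed library like any downstream code. A minimal sketch of what a moved file looks like at this point in the series, using only imports that appear in the later hunks for `benchmarks/compression/bench.py` (the `.base` segment is only dropped later, by the rename patches):

```python
# benchmarks/compression/bench.py (imports excerpt): the file now lives outside
# the streaming package and pulls in the installed library like user code would.
import numpy as np

from streaming.base.compression import compress, decompress, get_compressions
```

The rest of the series only has to update these import paths; the benchmark logic itself is untouched, which is why every rename below reports 100% similarity.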
--- {streaming/benchmarks => benchmarks}/backends/__init__.py | 0 .../benchmarks => benchmarks}/backends/generate_datasets.py | 0 {streaming/benchmarks => benchmarks}/backends/task.py | 0 {streaming/benchmarks => benchmarks}/compression/bench.py | 0 {streaming/benchmarks => benchmarks}/compression/plot.py | 0 {streaming/benchmarks => benchmarks}/epoch/bench.py | 0 {streaming/benchmarks => benchmarks}/hashing/bench.py | 0 {streaming/benchmarks => benchmarks}/hashing/plot.py | 0 {streaming/benchmarks => benchmarks}/partitioning/bench.py | 0 {streaming/benchmarks => benchmarks}/partitioning/diff.py | 0 {streaming/benchmarks => benchmarks}/partitioning/plot.py | 0 {streaming/benchmarks => benchmarks}/partitioning/txt.py | 0 {streaming/benchmarks => benchmarks}/partitioning/web.py | 0 {streaming/benchmarks => benchmarks}/samples/bench_and_plot.py | 0 {streaming/benchmarks => benchmarks}/serialization/compare.py | 0 .../serialization/survey_fixed_decimals.py | 0 {streaming/benchmarks => benchmarks}/shuffling/bench.py | 0 {streaming/benchmarks => benchmarks}/shuffling/plot.py | 0 {streaming/benchmarks => benchmarks}/shuffling/vis.py | 0 19 files changed, 0 insertions(+), 0 deletions(-) rename {streaming/benchmarks => benchmarks}/backends/__init__.py (100%) rename {streaming/benchmarks => benchmarks}/backends/generate_datasets.py (100%) rename {streaming/benchmarks => benchmarks}/backends/task.py (100%) rename {streaming/benchmarks => benchmarks}/compression/bench.py (100%) rename {streaming/benchmarks => benchmarks}/compression/plot.py (100%) rename {streaming/benchmarks => benchmarks}/epoch/bench.py (100%) rename {streaming/benchmarks => benchmarks}/hashing/bench.py (100%) rename {streaming/benchmarks => benchmarks}/hashing/plot.py (100%) rename {streaming/benchmarks => benchmarks}/partitioning/bench.py (100%) rename {streaming/benchmarks => benchmarks}/partitioning/diff.py (100%) rename {streaming/benchmarks => benchmarks}/partitioning/plot.py (100%) rename {streaming/benchmarks => benchmarks}/partitioning/txt.py (100%) rename {streaming/benchmarks => benchmarks}/partitioning/web.py (100%) rename {streaming/benchmarks => benchmarks}/samples/bench_and_plot.py (100%) rename {streaming/benchmarks => benchmarks}/serialization/compare.py (100%) rename {streaming/benchmarks => benchmarks}/serialization/survey_fixed_decimals.py (100%) rename {streaming/benchmarks => benchmarks}/shuffling/bench.py (100%) rename {streaming/benchmarks => benchmarks}/shuffling/plot.py (100%) rename {streaming/benchmarks => benchmarks}/shuffling/vis.py (100%) diff --git a/streaming/benchmarks/backends/__init__.py b/benchmarks/backends/__init__.py similarity index 100% rename from streaming/benchmarks/backends/__init__.py rename to benchmarks/backends/__init__.py diff --git a/streaming/benchmarks/backends/generate_datasets.py b/benchmarks/backends/generate_datasets.py similarity index 100% rename from streaming/benchmarks/backends/generate_datasets.py rename to benchmarks/backends/generate_datasets.py diff --git a/streaming/benchmarks/backends/task.py b/benchmarks/backends/task.py similarity index 100% rename from streaming/benchmarks/backends/task.py rename to benchmarks/backends/task.py diff --git a/streaming/benchmarks/compression/bench.py b/benchmarks/compression/bench.py similarity index 100% rename from streaming/benchmarks/compression/bench.py rename to benchmarks/compression/bench.py diff --git a/streaming/benchmarks/compression/plot.py b/benchmarks/compression/plot.py similarity index 100% rename from 
streaming/benchmarks/compression/plot.py rename to benchmarks/compression/plot.py diff --git a/streaming/benchmarks/epoch/bench.py b/benchmarks/epoch/bench.py similarity index 100% rename from streaming/benchmarks/epoch/bench.py rename to benchmarks/epoch/bench.py diff --git a/streaming/benchmarks/hashing/bench.py b/benchmarks/hashing/bench.py similarity index 100% rename from streaming/benchmarks/hashing/bench.py rename to benchmarks/hashing/bench.py diff --git a/streaming/benchmarks/hashing/plot.py b/benchmarks/hashing/plot.py similarity index 100% rename from streaming/benchmarks/hashing/plot.py rename to benchmarks/hashing/plot.py diff --git a/streaming/benchmarks/partitioning/bench.py b/benchmarks/partitioning/bench.py similarity index 100% rename from streaming/benchmarks/partitioning/bench.py rename to benchmarks/partitioning/bench.py diff --git a/streaming/benchmarks/partitioning/diff.py b/benchmarks/partitioning/diff.py similarity index 100% rename from streaming/benchmarks/partitioning/diff.py rename to benchmarks/partitioning/diff.py diff --git a/streaming/benchmarks/partitioning/plot.py b/benchmarks/partitioning/plot.py similarity index 100% rename from streaming/benchmarks/partitioning/plot.py rename to benchmarks/partitioning/plot.py diff --git a/streaming/benchmarks/partitioning/txt.py b/benchmarks/partitioning/txt.py similarity index 100% rename from streaming/benchmarks/partitioning/txt.py rename to benchmarks/partitioning/txt.py diff --git a/streaming/benchmarks/partitioning/web.py b/benchmarks/partitioning/web.py similarity index 100% rename from streaming/benchmarks/partitioning/web.py rename to benchmarks/partitioning/web.py diff --git a/streaming/benchmarks/samples/bench_and_plot.py b/benchmarks/samples/bench_and_plot.py similarity index 100% rename from streaming/benchmarks/samples/bench_and_plot.py rename to benchmarks/samples/bench_and_plot.py diff --git a/streaming/benchmarks/serialization/compare.py b/benchmarks/serialization/compare.py similarity index 100% rename from streaming/benchmarks/serialization/compare.py rename to benchmarks/serialization/compare.py diff --git a/streaming/benchmarks/serialization/survey_fixed_decimals.py b/benchmarks/serialization/survey_fixed_decimals.py similarity index 100% rename from streaming/benchmarks/serialization/survey_fixed_decimals.py rename to benchmarks/serialization/survey_fixed_decimals.py diff --git a/streaming/benchmarks/shuffling/bench.py b/benchmarks/shuffling/bench.py similarity index 100% rename from streaming/benchmarks/shuffling/bench.py rename to benchmarks/shuffling/bench.py diff --git a/streaming/benchmarks/shuffling/plot.py b/benchmarks/shuffling/plot.py similarity index 100% rename from streaming/benchmarks/shuffling/plot.py rename to benchmarks/shuffling/plot.py diff --git a/streaming/benchmarks/shuffling/vis.py b/benchmarks/shuffling/vis.py similarity index 100% rename from streaming/benchmarks/shuffling/vis.py rename to benchmarks/shuffling/vis.py From 4a226389eb8c3b8262671ff2dadab6282a76cd8c Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 07:39:27 -0700 Subject: [PATCH 18/45] Fix. 
--- pyproject.toml | 2 +- streaming/base/util/pretty.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 77e597e05..a32957f39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ include = [ exclude = [ "build/**", "node_modules/**", - "streaming/text/convert/enwiki/**", + "streaming/examples/text/enwiki_tok/**", "docs/source/conf.py" ] diff --git a/streaming/base/util/pretty.py b/streaming/base/util/pretty.py index f588ce207..afc4631c1 100644 --- a/streaming/base/util/pretty.py +++ b/streaming/base/util/pretty.py @@ -7,8 +7,8 @@ from typing import Dict, List, Union __all__ = [ - 'unpack_strs', 'unpack_str2str', 'normalize_dec_bytes', 'normalize_bin_bytes', 'normalize_bytes', - 'normalize_count', 'normalize_duration' + 'unpack_strs', 'unpack_str2str', 'normalize_dec_bytes', 'normalize_bin_bytes', + 'normalize_bytes', 'normalize_count', 'normalize_duration' ] From cb808652489b002e24268cfeda310d2d6f9a4c72 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 08:07:02 -0700 Subject: [PATCH 19/45] Now, rename streaming/base/... -> streaming/.... --- streaming/__init__.py | 25 ++++++------------- streaming/{base => }/array.py | 0 streaming/base/__init__.py | 15 ----------- streaming/{base => }/batching/__init__.py | 0 streaming/{base => }/batching/per_stream.py | 0 streaming/{base => }/batching/random.py | 0 streaming/{base => }/batching/stratified.py | 0 streaming/{base => }/cli/index_parquet.py | 0 streaming/{base => }/compression.py | 0 streaming/{base => }/constant.py | 0 streaming/{base => }/converters/README.md | 0 streaming/{base => }/converters/__init__.py | 0 .../{base => }/converters/dataframe_to_mds.py | 0 streaming/{base => }/dataloader.py | 0 streaming/{base => }/dataset.py | 0 streaming/{base => }/distributed.py | 0 streaming/{base => }/format/__init__.py | 0 streaming/{base => }/format/base/__init__.py | 0 streaming/{base => }/format/base/reader.py | 0 streaming/{base => }/format/base/writer.py | 0 streaming/{base => }/format/delta/__init__.py | 0 streaming/{base => }/format/delta/indexing.py | 0 streaming/{base => }/format/index.py | 0 streaming/{base => }/format/json/README.md | 0 streaming/{base => }/format/json/__init__.py | 0 streaming/{base => }/format/json/encodings.py | 0 streaming/{base => }/format/json/reader.py | 0 streaming/{base => }/format/json/writer.py | 0 streaming/{base => }/format/lance/__init__.py | 0 streaming/{base => }/format/lance/indexing.py | 0 streaming/{base => }/format/mds/README.md | 0 streaming/{base => }/format/mds/__init__.py | 0 streaming/{base => }/format/mds/encodings.py | 0 streaming/{base => }/format/mds/reader.py | 0 streaming/{base => }/format/mds/writer.py | 0 .../{base => }/format/parquet/__init__.py | 0 .../{base => }/format/parquet/indexing.py | 0 streaming/{base => }/format/xsv/README.md | 0 streaming/{base => }/format/xsv/__init__.py | 0 streaming/{base => }/format/xsv/encodings.py | 0 streaming/{base => }/format/xsv/reader.py | 0 streaming/{base => }/format/xsv/writer.py | 0 streaming/{base => }/hashing.py | 0 streaming/{base => }/local.py | 0 streaming/{base => }/partition/__init__.py | 0 streaming/{base => }/partition/orig.py | 0 streaming/{base => }/partition/relaxed.py | 0 streaming/{base => }/sampling.py | 0 streaming/{base => }/shared/__init__.py | 0 streaming/{base => }/shared/array.py | 0 streaming/{base => }/shared/barrier.py | 0 streaming/{base => }/shared/memory.py | 0 streaming/{base => }/shared/prefix.py | 0 streaming/{base => 
}/shared/scalar.py | 0 streaming/{base => }/shuffle/__init__.py | 0 streaming/{base => }/shuffle/naive.py | 0 streaming/{base => }/shuffle/py1b.py | 0 streaming/{base => }/shuffle/py1br.py | 0 streaming/{base => }/shuffle/py1e.py | 0 streaming/{base => }/shuffle/py1s.py | 0 streaming/{base => }/shuffle/py2s.py | 0 streaming/{base => }/spanner.py | 0 streaming/{base => }/storage/__init__.py | 0 streaming/{base => }/storage/download.py | 0 streaming/{base => }/storage/extra.py | 0 streaming/{base => }/storage/upload.py | 0 streaming/{base => }/stream.py | 0 streaming/{base => }/util/__init__.py | 0 streaming/{base => }/util/importing.py | 0 streaming/{base => }/util/merging.py | 0 streaming/{base => }/util/pretty.py | 0 streaming/{base => }/util/retrying.py | 0 streaming/{base => }/util/shared.py | 0 streaming/{base => }/vision.py | 0 streaming/{base => }/world.py | 0 75 files changed, 7 insertions(+), 33 deletions(-) rename streaming/{base => }/array.py (100%) delete mode 100644 streaming/base/__init__.py rename streaming/{base => }/batching/__init__.py (100%) rename streaming/{base => }/batching/per_stream.py (100%) rename streaming/{base => }/batching/random.py (100%) rename streaming/{base => }/batching/stratified.py (100%) rename streaming/{base => }/cli/index_parquet.py (100%) rename streaming/{base => }/compression.py (100%) rename streaming/{base => }/constant.py (100%) rename streaming/{base => }/converters/README.md (100%) rename streaming/{base => }/converters/__init__.py (100%) rename streaming/{base => }/converters/dataframe_to_mds.py (100%) rename streaming/{base => }/dataloader.py (100%) rename streaming/{base => }/dataset.py (100%) rename streaming/{base => }/distributed.py (100%) rename streaming/{base => }/format/__init__.py (100%) rename streaming/{base => }/format/base/__init__.py (100%) rename streaming/{base => }/format/base/reader.py (100%) rename streaming/{base => }/format/base/writer.py (100%) rename streaming/{base => }/format/delta/__init__.py (100%) rename streaming/{base => }/format/delta/indexing.py (100%) rename streaming/{base => }/format/index.py (100%) rename streaming/{base => }/format/json/README.md (100%) rename streaming/{base => }/format/json/__init__.py (100%) rename streaming/{base => }/format/json/encodings.py (100%) rename streaming/{base => }/format/json/reader.py (100%) rename streaming/{base => }/format/json/writer.py (100%) rename streaming/{base => }/format/lance/__init__.py (100%) rename streaming/{base => }/format/lance/indexing.py (100%) rename streaming/{base => }/format/mds/README.md (100%) rename streaming/{base => }/format/mds/__init__.py (100%) rename streaming/{base => }/format/mds/encodings.py (100%) rename streaming/{base => }/format/mds/reader.py (100%) rename streaming/{base => }/format/mds/writer.py (100%) rename streaming/{base => }/format/parquet/__init__.py (100%) rename streaming/{base => }/format/parquet/indexing.py (100%) rename streaming/{base => }/format/xsv/README.md (100%) rename streaming/{base => }/format/xsv/__init__.py (100%) rename streaming/{base => }/format/xsv/encodings.py (100%) rename streaming/{base => }/format/xsv/reader.py (100%) rename streaming/{base => }/format/xsv/writer.py (100%) rename streaming/{base => }/hashing.py (100%) rename streaming/{base => }/local.py (100%) rename streaming/{base => }/partition/__init__.py (100%) rename streaming/{base => }/partition/orig.py (100%) rename streaming/{base => }/partition/relaxed.py (100%) rename streaming/{base => }/sampling.py (100%) rename streaming/{base => 
}/shared/__init__.py (100%) rename streaming/{base => }/shared/array.py (100%) rename streaming/{base => }/shared/barrier.py (100%) rename streaming/{base => }/shared/memory.py (100%) rename streaming/{base => }/shared/prefix.py (100%) rename streaming/{base => }/shared/scalar.py (100%) rename streaming/{base => }/shuffle/__init__.py (100%) rename streaming/{base => }/shuffle/naive.py (100%) rename streaming/{base => }/shuffle/py1b.py (100%) rename streaming/{base => }/shuffle/py1br.py (100%) rename streaming/{base => }/shuffle/py1e.py (100%) rename streaming/{base => }/shuffle/py1s.py (100%) rename streaming/{base => }/shuffle/py2s.py (100%) rename streaming/{base => }/spanner.py (100%) rename streaming/{base => }/storage/__init__.py (100%) rename streaming/{base => }/storage/download.py (100%) rename streaming/{base => }/storage/extra.py (100%) rename streaming/{base => }/storage/upload.py (100%) rename streaming/{base => }/stream.py (100%) rename streaming/{base => }/util/__init__.py (100%) rename streaming/{base => }/util/importing.py (100%) rename streaming/{base => }/util/merging.py (100%) rename streaming/{base => }/util/pretty.py (100%) rename streaming/{base => }/util/retrying.py (100%) rename streaming/{base => }/util/shared.py (100%) rename streaming/{base => }/vision.py (100%) rename streaming/{base => }/world.py (100%) diff --git a/streaming/__init__.py b/streaming/__init__.py index 7023580ce..37f0c0277 100644 --- a/streaming/__init__.py +++ b/streaming/__init__.py @@ -3,24 +3,13 @@ """MosaicML Streaming Datasets for cloud-native model training.""" -import streaming.multimodal as multimodal -import streaming.text as text -import streaming.vision as vision -from streaming._version import __version__ -from streaming.base import (CSVWriter, JSONWriter, LocalDataset, MDSWriter, Stream, - StreamingDataLoader, StreamingDataset, TSVWriter, XSVWriter) +from streaming.dataloader import StreamingDataLoader +from streaming.dataset import StreamingDataset +from streaming.format import CSVWriter, JSONWriter, MDSWriter, TSVWriter, XSVWriter +from streaming.local import LocalDataset +from streaming.stream import Stream __all__ = [ - 'StreamingDataLoader', - 'Stream', - 'StreamingDataset', - 'CSVWriter', - 'JSONWriter', - 'MDSWriter', - 'TSVWriter', - 'XSVWriter', - 'LocalDataset', - 'multimodal', - 'vision', - 'text', + 'StreamingDataLoader', 'Stream', 'StreamingDataset', 'CSVWriter', 'JSONWriter', 'LocalDataset', + 'MDSWriter', 'TSVWriter', 'XSVWriter' ] diff --git a/streaming/base/array.py b/streaming/array.py similarity index 100% rename from streaming/base/array.py rename to streaming/array.py diff --git a/streaming/base/__init__.py b/streaming/base/__init__.py deleted file mode 100644 index 8834b9bea..000000000 --- a/streaming/base/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""MosaicML Streaming Datasets for cloud-native model training.""" - -from streaming.base.dataloader import StreamingDataLoader -from streaming.base.dataset import StreamingDataset -from streaming.base.format import CSVWriter, JSONWriter, MDSWriter, TSVWriter, XSVWriter -from streaming.base.local import LocalDataset -from streaming.base.stream import Stream - -__all__ = [ - 'StreamingDataLoader', 'Stream', 'StreamingDataset', 'CSVWriter', 'JSONWriter', 'LocalDataset', - 'MDSWriter', 'TSVWriter', 'XSVWriter' -] diff --git a/streaming/base/batching/__init__.py b/streaming/batching/__init__.py similarity index 100% rename from 
streaming/base/batching/__init__.py rename to streaming/batching/__init__.py diff --git a/streaming/base/batching/per_stream.py b/streaming/batching/per_stream.py similarity index 100% rename from streaming/base/batching/per_stream.py rename to streaming/batching/per_stream.py diff --git a/streaming/base/batching/random.py b/streaming/batching/random.py similarity index 100% rename from streaming/base/batching/random.py rename to streaming/batching/random.py diff --git a/streaming/base/batching/stratified.py b/streaming/batching/stratified.py similarity index 100% rename from streaming/base/batching/stratified.py rename to streaming/batching/stratified.py diff --git a/streaming/base/cli/index_parquet.py b/streaming/cli/index_parquet.py similarity index 100% rename from streaming/base/cli/index_parquet.py rename to streaming/cli/index_parquet.py diff --git a/streaming/base/compression.py b/streaming/compression.py similarity index 100% rename from streaming/base/compression.py rename to streaming/compression.py diff --git a/streaming/base/constant.py b/streaming/constant.py similarity index 100% rename from streaming/base/constant.py rename to streaming/constant.py diff --git a/streaming/base/converters/README.md b/streaming/converters/README.md similarity index 100% rename from streaming/base/converters/README.md rename to streaming/converters/README.md diff --git a/streaming/base/converters/__init__.py b/streaming/converters/__init__.py similarity index 100% rename from streaming/base/converters/__init__.py rename to streaming/converters/__init__.py diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/converters/dataframe_to_mds.py similarity index 100% rename from streaming/base/converters/dataframe_to_mds.py rename to streaming/converters/dataframe_to_mds.py diff --git a/streaming/base/dataloader.py b/streaming/dataloader.py similarity index 100% rename from streaming/base/dataloader.py rename to streaming/dataloader.py diff --git a/streaming/base/dataset.py b/streaming/dataset.py similarity index 100% rename from streaming/base/dataset.py rename to streaming/dataset.py diff --git a/streaming/base/distributed.py b/streaming/distributed.py similarity index 100% rename from streaming/base/distributed.py rename to streaming/distributed.py diff --git a/streaming/base/format/__init__.py b/streaming/format/__init__.py similarity index 100% rename from streaming/base/format/__init__.py rename to streaming/format/__init__.py diff --git a/streaming/base/format/base/__init__.py b/streaming/format/base/__init__.py similarity index 100% rename from streaming/base/format/base/__init__.py rename to streaming/format/base/__init__.py diff --git a/streaming/base/format/base/reader.py b/streaming/format/base/reader.py similarity index 100% rename from streaming/base/format/base/reader.py rename to streaming/format/base/reader.py diff --git a/streaming/base/format/base/writer.py b/streaming/format/base/writer.py similarity index 100% rename from streaming/base/format/base/writer.py rename to streaming/format/base/writer.py diff --git a/streaming/base/format/delta/__init__.py b/streaming/format/delta/__init__.py similarity index 100% rename from streaming/base/format/delta/__init__.py rename to streaming/format/delta/__init__.py diff --git a/streaming/base/format/delta/indexing.py b/streaming/format/delta/indexing.py similarity index 100% rename from streaming/base/format/delta/indexing.py rename to streaming/format/delta/indexing.py diff --git a/streaming/base/format/index.py 
b/streaming/format/index.py similarity index 100% rename from streaming/base/format/index.py rename to streaming/format/index.py diff --git a/streaming/base/format/json/README.md b/streaming/format/json/README.md similarity index 100% rename from streaming/base/format/json/README.md rename to streaming/format/json/README.md diff --git a/streaming/base/format/json/__init__.py b/streaming/format/json/__init__.py similarity index 100% rename from streaming/base/format/json/__init__.py rename to streaming/format/json/__init__.py diff --git a/streaming/base/format/json/encodings.py b/streaming/format/json/encodings.py similarity index 100% rename from streaming/base/format/json/encodings.py rename to streaming/format/json/encodings.py diff --git a/streaming/base/format/json/reader.py b/streaming/format/json/reader.py similarity index 100% rename from streaming/base/format/json/reader.py rename to streaming/format/json/reader.py diff --git a/streaming/base/format/json/writer.py b/streaming/format/json/writer.py similarity index 100% rename from streaming/base/format/json/writer.py rename to streaming/format/json/writer.py diff --git a/streaming/base/format/lance/__init__.py b/streaming/format/lance/__init__.py similarity index 100% rename from streaming/base/format/lance/__init__.py rename to streaming/format/lance/__init__.py diff --git a/streaming/base/format/lance/indexing.py b/streaming/format/lance/indexing.py similarity index 100% rename from streaming/base/format/lance/indexing.py rename to streaming/format/lance/indexing.py diff --git a/streaming/base/format/mds/README.md b/streaming/format/mds/README.md similarity index 100% rename from streaming/base/format/mds/README.md rename to streaming/format/mds/README.md diff --git a/streaming/base/format/mds/__init__.py b/streaming/format/mds/__init__.py similarity index 100% rename from streaming/base/format/mds/__init__.py rename to streaming/format/mds/__init__.py diff --git a/streaming/base/format/mds/encodings.py b/streaming/format/mds/encodings.py similarity index 100% rename from streaming/base/format/mds/encodings.py rename to streaming/format/mds/encodings.py diff --git a/streaming/base/format/mds/reader.py b/streaming/format/mds/reader.py similarity index 100% rename from streaming/base/format/mds/reader.py rename to streaming/format/mds/reader.py diff --git a/streaming/base/format/mds/writer.py b/streaming/format/mds/writer.py similarity index 100% rename from streaming/base/format/mds/writer.py rename to streaming/format/mds/writer.py diff --git a/streaming/base/format/parquet/__init__.py b/streaming/format/parquet/__init__.py similarity index 100% rename from streaming/base/format/parquet/__init__.py rename to streaming/format/parquet/__init__.py diff --git a/streaming/base/format/parquet/indexing.py b/streaming/format/parquet/indexing.py similarity index 100% rename from streaming/base/format/parquet/indexing.py rename to streaming/format/parquet/indexing.py diff --git a/streaming/base/format/xsv/README.md b/streaming/format/xsv/README.md similarity index 100% rename from streaming/base/format/xsv/README.md rename to streaming/format/xsv/README.md diff --git a/streaming/base/format/xsv/__init__.py b/streaming/format/xsv/__init__.py similarity index 100% rename from streaming/base/format/xsv/__init__.py rename to streaming/format/xsv/__init__.py diff --git a/streaming/base/format/xsv/encodings.py b/streaming/format/xsv/encodings.py similarity index 100% rename from streaming/base/format/xsv/encodings.py rename to 
streaming/format/xsv/encodings.py diff --git a/streaming/base/format/xsv/reader.py b/streaming/format/xsv/reader.py similarity index 100% rename from streaming/base/format/xsv/reader.py rename to streaming/format/xsv/reader.py diff --git a/streaming/base/format/xsv/writer.py b/streaming/format/xsv/writer.py similarity index 100% rename from streaming/base/format/xsv/writer.py rename to streaming/format/xsv/writer.py diff --git a/streaming/base/hashing.py b/streaming/hashing.py similarity index 100% rename from streaming/base/hashing.py rename to streaming/hashing.py diff --git a/streaming/base/local.py b/streaming/local.py similarity index 100% rename from streaming/base/local.py rename to streaming/local.py diff --git a/streaming/base/partition/__init__.py b/streaming/partition/__init__.py similarity index 100% rename from streaming/base/partition/__init__.py rename to streaming/partition/__init__.py diff --git a/streaming/base/partition/orig.py b/streaming/partition/orig.py similarity index 100% rename from streaming/base/partition/orig.py rename to streaming/partition/orig.py diff --git a/streaming/base/partition/relaxed.py b/streaming/partition/relaxed.py similarity index 100% rename from streaming/base/partition/relaxed.py rename to streaming/partition/relaxed.py diff --git a/streaming/base/sampling.py b/streaming/sampling.py similarity index 100% rename from streaming/base/sampling.py rename to streaming/sampling.py diff --git a/streaming/base/shared/__init__.py b/streaming/shared/__init__.py similarity index 100% rename from streaming/base/shared/__init__.py rename to streaming/shared/__init__.py diff --git a/streaming/base/shared/array.py b/streaming/shared/array.py similarity index 100% rename from streaming/base/shared/array.py rename to streaming/shared/array.py diff --git a/streaming/base/shared/barrier.py b/streaming/shared/barrier.py similarity index 100% rename from streaming/base/shared/barrier.py rename to streaming/shared/barrier.py diff --git a/streaming/base/shared/memory.py b/streaming/shared/memory.py similarity index 100% rename from streaming/base/shared/memory.py rename to streaming/shared/memory.py diff --git a/streaming/base/shared/prefix.py b/streaming/shared/prefix.py similarity index 100% rename from streaming/base/shared/prefix.py rename to streaming/shared/prefix.py diff --git a/streaming/base/shared/scalar.py b/streaming/shared/scalar.py similarity index 100% rename from streaming/base/shared/scalar.py rename to streaming/shared/scalar.py diff --git a/streaming/base/shuffle/__init__.py b/streaming/shuffle/__init__.py similarity index 100% rename from streaming/base/shuffle/__init__.py rename to streaming/shuffle/__init__.py diff --git a/streaming/base/shuffle/naive.py b/streaming/shuffle/naive.py similarity index 100% rename from streaming/base/shuffle/naive.py rename to streaming/shuffle/naive.py diff --git a/streaming/base/shuffle/py1b.py b/streaming/shuffle/py1b.py similarity index 100% rename from streaming/base/shuffle/py1b.py rename to streaming/shuffle/py1b.py diff --git a/streaming/base/shuffle/py1br.py b/streaming/shuffle/py1br.py similarity index 100% rename from streaming/base/shuffle/py1br.py rename to streaming/shuffle/py1br.py diff --git a/streaming/base/shuffle/py1e.py b/streaming/shuffle/py1e.py similarity index 100% rename from streaming/base/shuffle/py1e.py rename to streaming/shuffle/py1e.py diff --git a/streaming/base/shuffle/py1s.py b/streaming/shuffle/py1s.py similarity index 100% rename from streaming/base/shuffle/py1s.py rename to 
streaming/shuffle/py1s.py diff --git a/streaming/base/shuffle/py2s.py b/streaming/shuffle/py2s.py similarity index 100% rename from streaming/base/shuffle/py2s.py rename to streaming/shuffle/py2s.py diff --git a/streaming/base/spanner.py b/streaming/spanner.py similarity index 100% rename from streaming/base/spanner.py rename to streaming/spanner.py diff --git a/streaming/base/storage/__init__.py b/streaming/storage/__init__.py similarity index 100% rename from streaming/base/storage/__init__.py rename to streaming/storage/__init__.py diff --git a/streaming/base/storage/download.py b/streaming/storage/download.py similarity index 100% rename from streaming/base/storage/download.py rename to streaming/storage/download.py diff --git a/streaming/base/storage/extra.py b/streaming/storage/extra.py similarity index 100% rename from streaming/base/storage/extra.py rename to streaming/storage/extra.py diff --git a/streaming/base/storage/upload.py b/streaming/storage/upload.py similarity index 100% rename from streaming/base/storage/upload.py rename to streaming/storage/upload.py diff --git a/streaming/base/stream.py b/streaming/stream.py similarity index 100% rename from streaming/base/stream.py rename to streaming/stream.py diff --git a/streaming/base/util/__init__.py b/streaming/util/__init__.py similarity index 100% rename from streaming/base/util/__init__.py rename to streaming/util/__init__.py diff --git a/streaming/base/util/importing.py b/streaming/util/importing.py similarity index 100% rename from streaming/base/util/importing.py rename to streaming/util/importing.py diff --git a/streaming/base/util/merging.py b/streaming/util/merging.py similarity index 100% rename from streaming/base/util/merging.py rename to streaming/util/merging.py diff --git a/streaming/base/util/pretty.py b/streaming/util/pretty.py similarity index 100% rename from streaming/base/util/pretty.py rename to streaming/util/pretty.py diff --git a/streaming/base/util/retrying.py b/streaming/util/retrying.py similarity index 100% rename from streaming/base/util/retrying.py rename to streaming/util/retrying.py diff --git a/streaming/base/util/shared.py b/streaming/util/shared.py similarity index 100% rename from streaming/base/util/shared.py rename to streaming/util/shared.py diff --git a/streaming/base/vision.py b/streaming/vision.py similarity index 100% rename from streaming/base/vision.py rename to streaming/vision.py diff --git a/streaming/base/world.py b/streaming/world.py similarity index 100% rename from streaming/base/world.py rename to streaming/world.py From 4851888d3b0c038f3592d2453b8a2d20fa87710f Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 08:19:03 -0700 Subject: [PATCH 20/45] Update paths accordingly. 
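With the `base` subpackage gone, every import that previously reached through `streaming.base` drops that segment. A minimal before/after sketch, using only imports that actually appear in the hunks below (no other API changes are implied by this patch):

```python
# Old layout (before the streaming/base flattening):
#   from streaming.base import MDSWriter, StreamingDataset
#   from streaming.base.util import merge_index
#   from streaming.base.format.mds.encodings import Encoding, _encodings

# New layout, as updated throughout this patch:
from streaming import MDSWriter, StreamingDataset  # top-level writer and dataset classes
from streaming.format.mds.encodings import Encoding, _encodings  # custom MDS encodings
from streaming.util import merge_index  # index-merging helper
```

Docs, notebooks, benchmarks, regression scripts, and tests below all get the same mechanical `streaming.base.X` to `streaming.X` rewrite.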
--- STYLE_GUIDE.md | 14 ++++---- benchmarks/compression/bench.py | 2 +- benchmarks/epoch/bench.py | 4 +-- benchmarks/hashing/bench.py | 2 +- benchmarks/partitioning/bench.py | 2 +- benchmarks/partitioning/diff.py | 2 +- benchmarks/partitioning/txt.py | 2 +- benchmarks/partitioning/web.py | 2 +- .../serialization/survey_fixed_decimals.py | 2 +- benchmarks/shuffling/bench.py | 2 +- benchmarks/shuffling/vis.py | 2 +- docs/source/conf.py | 20 ++++++------ .../fundamentals/dataset_conversion_guide.md | 2 +- docs/source/getting_started/user_guide.md | 2 +- .../multiprocess_dataset_conversion.ipynb | 2 +- notebooks/spark_dataframe_to_MDS.ipynb | 6 ++-- regression/iterate_data.py | 2 +- regression/synthetic_dataset.py | 2 +- streaming/batching/__init__.py | 10 +++--- streaming/batching/per_stream.py | 8 ++--- streaming/batching/random.py | 8 ++--- streaming/batching/stratified.py | 8 ++--- streaming/cli/index_parquet.py | 4 +-- streaming/converters/__init__.py | 2 +- streaming/converters/dataframe_to_mds.py | 10 +++--- streaming/dataloader.py | 4 +-- streaming/dataset.py | 22 ++++++------- .../laion400m/convert_and_upload.py | 2 +- streaming/examples/multimodal/webvid/read.py | 6 ++-- streaming/examples/text/c4/read.py | 2 +- streaming/examples/text/c4/write.py | 4 +-- streaming/examples/text/enwiki_txt/enwiki.py | 2 +- streaming/examples/text/enwiki_txt/write.py | 4 +-- streaming/examples/text/pile/read.py | 2 +- streaming/examples/text/pile/write.py | 4 +-- streaming/examples/vision/ade20k/read.py | 2 +- streaming/examples/vision/ade20k/write.py | 4 +-- streaming/examples/vision/cifar10/write.py | 2 +- streaming/examples/vision/coco/read.py | 2 +- streaming/examples/vision/coco/write.py | 4 +-- streaming/examples/vision/imagenet/write.py | 4 +-- streaming/format/__init__.py | 16 +++++----- streaming/format/base/__init__.py | 8 ----- streaming/format/delta/__init__.py | 2 +- streaming/format/json/__init__.py | 4 +-- streaming/format/json/reader.py | 2 +- streaming/format/json/writer.py | 4 +-- streaming/format/lance/__init__.py | 2 +- streaming/format/mds/__init__.py | 4 +-- streaming/format/mds/reader.py | 4 +-- streaming/format/mds/writer.py | 4 +-- streaming/format/parquet/__init__.py | 2 +- streaming/format/parquet/indexing.py | 4 +-- streaming/format/{base => }/reader.py | 4 +-- streaming/format/{base => }/writer.py | 10 +++--- streaming/format/xsv/__init__.py | 4 +-- streaming/format/xsv/reader.py | 4 +-- streaming/format/xsv/writer.py | 4 +-- streaming/local.py | 6 ++-- streaming/partition/__init__.py | 4 +-- streaming/partition/relaxed.py | 2 +- streaming/shared/__init__.py | 12 +++---- streaming/shared/array.py | 2 +- streaming/shared/barrier.py | 4 +-- streaming/shared/memory.py | 2 +- streaming/shared/prefix.py | 8 ++--- streaming/shared/scalar.py | 2 +- streaming/shuffle/__init__.py | 12 +++---- streaming/shuffle/py1b.py | 2 +- streaming/shuffle/py1br.py | 2 +- streaming/shuffle/py1e.py | 2 +- streaming/storage/__init__.py | 4 +-- streaming/storage/download.py | 2 +- streaming/storage/extra.py | 6 ++-- streaming/storage/upload.py | 4 +-- streaming/stream.py | 18 +++++------ streaming/util/__init__.py | 10 +++--- streaming/util/merging.py | 8 ++--- streaming/util/retrying.py | 2 +- streaming/util/shared.py | 6 ++-- streaming/vision.py | 2 +- streaming/world.py | 2 +- .../base/converters/test_dataframe_to_mds.py | 2 +- tests/common/datasets.py | 2 +- tests/test_array.py | 2 +- tests/test_barrier.py | 2 +- tests/test_compression.py | 4 +-- tests/test_distributed.py | 4 +-- 
tests/test_download.py | 18 +++++------ tests/test_encodings.py | 6 ++-- tests/test_hashing.py | 4 +-- tests/test_local.py | 2 +- tests/test_partition.py | 2 +- tests/test_reader.py | 2 +- tests/test_sampling.py | 2 +- tests/test_shared.py | 4 +-- tests/test_shuffle.py | 2 +- tests/test_spanner.py | 2 +- tests/test_stream.py | 2 +- tests/test_streaming.py | 4 +-- tests/test_streaming_remote.py | 2 +- tests/test_upload.py | 32 +++++++++---------- tests/test_util.py | 14 ++++---- 103 files changed, 252 insertions(+), 260 deletions(-) delete mode 100644 streaming/format/base/__init__.py rename streaming/format/{base => }/reader.py (99%) rename streaming/format/{base => }/writer.py (98%) diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md index 265ea7d24..67156e2a0 100644 --- a/STYLE_GUIDE.md +++ b/STYLE_GUIDE.md @@ -142,10 +142,10 @@ so other contributors will know why this error was silenced. A public API, generally speaking, can be invoked by a user without a leading underscore in any portion of the path. The following are examples of public APIs: -* Standalone functions in public modules (e.g. `streaming.base.distributed.get_world_size`) -* Classes in public modules (e.g. `streaming.base.format.MDSWriter`) -* Public methods in public classes (e.g. `streaming.base.format.MDSWriter.write`) -* Public modules (e.g. `streaming.base.dataset`) +* Standalone functions in public modules (e.g. `streaming.distributed.get_world_size`) +* Classes in public modules (e.g. `streaming.format.MDSWriter`) +* Public methods in public classes (e.g. `streaming.format.MDSWriter.write`) +* Public modules (e.g. `streaming.dataset`) The following rules apply to public APIs: 1. All public APIs must have a docstring (see the Documentation section below) @@ -201,14 +201,14 @@ All public modules must define `__all__` to be the list of members that should b The variable is necessary to 1) limit what `from XXX import *` imports, and 2) ensure that the documentation only includes exported members, not unrelated re-imports. 
-For example, from [streaming/base/dataset.py](streaming/base/dataset.py) +For example, from [streaming/dataset.py](streaming/dataset.py) ```python """The :class:`Dataset` class, used for building streaming iterable datasets.""" from torch.utils.data import IterableDataset -from streaming.base.format import reader_from_json -from streaming.base.spanner import Spanner +from streaming.format import reader_from_json +from streaming.spanner import Spanner __all__ = ["Dataset"] # export only the Dataset, not other imports like `Spanner` or `reader_from_json` diff --git a/benchmarks/compression/bench.py b/benchmarks/compression/bench.py index 7fff5149b..d3740e335 100644 --- a/benchmarks/compression/bench.py +++ b/benchmarks/compression/bench.py @@ -9,7 +9,7 @@ import numpy as np -from streaming.base.compression import compress, decompress, get_compressions +from streaming.compression import compress, decompress, get_compressions def parse_args() -> Namespace: diff --git a/benchmarks/epoch/bench.py b/benchmarks/epoch/bench.py index 393ea66af..a1c8b73e0 100644 --- a/benchmarks/epoch/bench.py +++ b/benchmarks/epoch/bench.py @@ -9,8 +9,8 @@ import numpy as np -from streaming.base.partition import get_partitions -from streaming.base.shuffle import get_shuffle +from streaming.partition import get_partitions +from streaming.shuffle import get_shuffle def parse_args() -> Namespace: diff --git a/benchmarks/hashing/bench.py b/benchmarks/hashing/bench.py index 6be145006..45e4d4855 100644 --- a/benchmarks/hashing/bench.py +++ b/benchmarks/hashing/bench.py @@ -9,7 +9,7 @@ import numpy as np -from streaming.base.hashing import get_hash, get_hashes +from streaming.hashing import get_hash, get_hashes def parse_args() -> Namespace: diff --git a/benchmarks/partitioning/bench.py b/benchmarks/partitioning/bench.py index 3d83d3b63..d52629d25 100644 --- a/benchmarks/partitioning/bench.py +++ b/benchmarks/partitioning/bench.py @@ -6,7 +6,7 @@ from argparse import ArgumentParser, Namespace from time import time -from streaming.base.partition import get_partitions +from streaming.partition import get_partitions def parse_args() -> Namespace: diff --git a/benchmarks/partitioning/diff.py b/benchmarks/partitioning/diff.py index 43c10224b..0c6f68171 100644 --- a/benchmarks/partitioning/diff.py +++ b/benchmarks/partitioning/diff.py @@ -10,7 +10,7 @@ import numpy as np from tqdm import tqdm -from streaming.base.partition import get_partitions +from streaming.partition import get_partitions def parse_args() -> Namespace: diff --git a/benchmarks/partitioning/txt.py b/benchmarks/partitioning/txt.py index 4f6793825..8d71f6294 100644 --- a/benchmarks/partitioning/txt.py +++ b/benchmarks/partitioning/txt.py @@ -6,7 +6,7 @@ import math from argparse import ArgumentParser, Namespace -from streaming.base.partition import get_partitions +from streaming.partition import get_partitions def parse_args() -> Namespace: diff --git a/benchmarks/partitioning/web.py b/benchmarks/partitioning/web.py index c37a849f2..f961b06ba 100644 --- a/benchmarks/partitioning/web.py +++ b/benchmarks/partitioning/web.py @@ -16,7 +16,7 @@ from fastapi.responses import HTMLResponse from pydantic import BaseModel -from streaming.base.partition import get_partitions +from streaming.partition import get_partitions INDEX = ''' diff --git a/benchmarks/serialization/survey_fixed_decimals.py b/benchmarks/serialization/survey_fixed_decimals.py index d0ebebe83..6b28c80d0 100644 --- a/benchmarks/serialization/survey_fixed_decimals.py +++ 
b/benchmarks/serialization/survey_fixed_decimals.py @@ -7,7 +7,7 @@ import numpy as np -from streaming.base.util.pretty import unpack_strs +from streaming.util.pretty import unpack_strs def parse_args() -> Namespace: diff --git a/benchmarks/shuffling/bench.py b/benchmarks/shuffling/bench.py index 74ec02021..774906fea 100644 --- a/benchmarks/shuffling/bench.py +++ b/benchmarks/shuffling/bench.py @@ -11,7 +11,7 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.shuffle import (get_shuffle_naive, get_shuffle_py1b, get_shuffle_py1s, +from streaming.shuffle import (get_shuffle_naive, get_shuffle_py1b, get_shuffle_py1s, get_shuffle_py2s) diff --git a/benchmarks/shuffling/vis.py b/benchmarks/shuffling/vis.py index 7819e6b2a..1b7f387d5 100644 --- a/benchmarks/shuffling/vis.py +++ b/benchmarks/shuffling/vis.py @@ -8,7 +8,7 @@ import numpy as np -from streaming.base.shuffle import algos, get_shuffle +from streaming.shuffle import algos, get_shuffle def parse_args() -> Namespace: diff --git a/docs/source/conf.py b/docs/source/conf.py index e25dc24ba..3d9efc843 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -364,17 +364,17 @@ def _modules_to_rst() -> List[types.ModuleType]: """Return the list of modules for which to generate API reference rst files.""" document_modules: List[types.Module] = [ streaming, - streaming.base.compression, - streaming.base.format, - streaming.base.hashing, - streaming.base.partition, - streaming.base.shared, - streaming.base.shuffle, - streaming.base.storage, - streaming.base.util, - streaming.base.world, + streaming.compression, + streaming.format, + streaming.hashing, + streaming.partition, + streaming.shared, + streaming.shuffle, + streaming.storage, + streaming.util, + streaming.world, ] - exclude_modules: List[types.Module] = [streaming.base, streaming._version] + exclude_modules: List[types.Module] = [streaming, streaming._version] for name in streaming.__dict__: obj = streaming.__dict__[name] if isinstance(obj, types.ModuleType) and obj not in exclude_modules: diff --git a/docs/source/fundamentals/dataset_conversion_guide.md b/docs/source/fundamentals/dataset_conversion_guide.md index c2f750ac6..e480f2724 100644 --- a/docs/source/fundamentals/dataset_conversion_guide.md +++ b/docs/source/fundamentals/dataset_conversion_guide.md @@ -42,7 +42,7 @@ column = { import numpy as np from typing import Any -from streaming.base.format.mds.encodings import Encoding, _encodings +from streaming.format.mds.encodings import Encoding, _encodings class Int32(Encoding): def encode(self, obj: Any) -> bytes: diff --git a/docs/source/getting_started/user_guide.md b/docs/source/getting_started/user_guide.md index 0226b94a6..e8025cbe3 100644 --- a/docs/source/getting_started/user_guide.md +++ b/docs/source/getting_started/user_guide.md @@ -106,7 +106,7 @@ def each(samples): It's time to call the {class}`streaming.MDSWriter` with the above initialized parameters and write the samples by iterating over a dataset. 
```python -from streaming.base import MDSWriter +from streaming import MDSWriter dataset = RandomClassificationDataset() with MDSWriter(out=output_dir, columns=columns, compression=compression, hashes=hashes, size_limit=limit) as out: diff --git a/notebooks/multiprocess_dataset_conversion.ipynb b/notebooks/multiprocess_dataset_conversion.ipynb index d0ce9f134..c3591ea3c 100644 --- a/notebooks/multiprocess_dataset_conversion.ipynb +++ b/notebooks/multiprocess_dataset_conversion.ipynb @@ -424,7 +424,7 @@ }, "outputs": [], "source": [ - "from streaming.base.util import merge_index\n", + "from streaming.util import merge_index\n", "merge_index(out_root, keep_local=True)" ] }, diff --git a/notebooks/spark_dataframe_to_MDS.ipynb b/notebooks/spark_dataframe_to_MDS.ipynb index c5617d464..72c72961b 100644 --- a/notebooks/spark_dataframe_to_MDS.ipynb +++ b/notebooks/spark_dataframe_to_MDS.ipynb @@ -137,7 +137,7 @@ { "cell_type": "code", "source": [ - "from streaming.base.converters import dataframeToMDS" + "from streaming.converters import dataframeToMDS" ], "metadata": { "id": "uzYHe6yYRzyV" @@ -500,7 +500,7 @@ "from streaming import StreamingDataset\n", "\n", "# clean stale shared memory if any\n", - "streaming.base.util.clean_stale_shared_memory()\n", + "streaming.util.clean_stale_shared_memory()\n", "\n", "dataset = StreamingDataset(local=out_path, remote=None, batch_size=2, predownload=4)\n", "\n", @@ -773,7 +773,7 @@ "from streaming import StreamingDataset\n", "\n", "# clean stale shared memory if any\n", - "streaming.base.util.clean_stale_shared_memory()\n", + "streaming.util.clean_stale_shared_memory()\n", "\n", "dataset = StreamingDataset(local=out_path, remote=None, batch_size=2, predownload=4)\n", "\n", diff --git a/regression/iterate_data.py b/regression/iterate_data.py index eab9131b5..ffa9f2c81 100644 --- a/regression/iterate_data.py +++ b/regression/iterate_data.py @@ -17,7 +17,7 @@ get_streaming_dataset_params) from streaming import StreamingDataset -from streaming.base.distributed import (all_gather, barrier, get_rank, get_world_size, +from streaming.distributed import (all_gather, barrier, get_rank, get_world_size, maybe_init_dist) logger = logging.getLogger(__name__) diff --git a/regression/synthetic_dataset.py b/regression/synthetic_dataset.py index c90cbb888..3e8f44d78 100644 --- a/regression/synthetic_dataset.py +++ b/regression/synthetic_dataset.py @@ -14,7 +14,7 @@ import torch from utils import delete_gcs, delete_oci, delete_s3, get_kwargs, get_writer_params -from streaming.base import MDSWriter +from streaming import MDSWriter _DATASET_MAP = { 'sequencedataset': 'SequenceDataset', diff --git a/streaming/batching/__init__.py b/streaming/batching/__init__.py index f4fd7f788..a95de0147 100644 --- a/streaming/batching/__init__.py +++ b/streaming/batching/__init__.py @@ -9,13 +9,13 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.batching.per_stream import generate_work_per_stream_batching -from streaming.base.batching.random import generate_work_random_batching -from streaming.base.batching.stratified import generate_work_stratified_batching -from streaming.base.world import World +from streaming.batching.per_stream import generate_work_per_stream_batching +from streaming.batching.random import generate_work_random_batching +from streaming.batching.stratified import generate_work_stratified_batching +from streaming.world import World if TYPE_CHECKING: - from streaming.base.dataset import StreamingDataset + from streaming.dataset import 
StreamingDataset batching_methods = { 'random': generate_work_random_batching, diff --git a/streaming/batching/per_stream.py b/streaming/batching/per_stream.py index 1686720b9..e955b0114 100644 --- a/streaming/batching/per_stream.py +++ b/streaming/batching/per_stream.py @@ -10,12 +10,12 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.partition import get_partitions -from streaming.base.shuffle import get_shuffle -from streaming.base.world import World +from streaming.partition import get_partitions +from streaming.shuffle import get_shuffle +from streaming.world import World if TYPE_CHECKING: - from streaming.base.dataset import StreamingDataset + from streaming.dataset import StreamingDataset logger = logging.getLogger(__name__) diff --git a/streaming/batching/random.py b/streaming/batching/random.py index 48e803acb..a716e0515 100644 --- a/streaming/batching/random.py +++ b/streaming/batching/random.py @@ -10,12 +10,12 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.partition import get_partitions -from streaming.base.shuffle import get_shuffle -from streaming.base.world import World +from streaming.partition import get_partitions +from streaming.shuffle import get_shuffle +from streaming.world import World if TYPE_CHECKING: - from streaming.base.dataset import StreamingDataset + from streaming.dataset import StreamingDataset logger = logging.getLogger(__name__) diff --git a/streaming/batching/stratified.py b/streaming/batching/stratified.py index 2eef06fd5..aff18eba6 100644 --- a/streaming/batching/stratified.py +++ b/streaming/batching/stratified.py @@ -11,12 +11,12 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.partition import get_partitions -from streaming.base.shuffle import get_shuffle -from streaming.base.world import World +from streaming.partition import get_partitions +from streaming.shuffle import get_shuffle +from streaming.world import World if TYPE_CHECKING: - from streaming.base.dataset import StreamingDataset + from streaming.dataset import StreamingDataset logger = logging.getLogger(__name__) diff --git a/streaming/cli/index_parquet.py b/streaming/cli/index_parquet.py index 1d0b0377f..90a08c1fb 100644 --- a/streaming/cli/index_parquet.py +++ b/streaming/cli/index_parquet.py @@ -6,8 +6,8 @@ import json from argparse import ArgumentParser, Namespace -from streaming.base.format import index_parquet -from streaming.base.util.pretty import unpack_str2str +from streaming.format import index_parquet +from streaming.util.pretty import unpack_str2str def parse_args() -> Namespace: diff --git a/streaming/converters/__init__.py b/streaming/converters/__init__.py index 8fbbed094..dec288307 100644 --- a/streaming/converters/__init__.py +++ b/streaming/converters/__init__.py @@ -3,7 +3,7 @@ """Utility function for converting spark dataframe to MDS dataset.""" -from streaming.base.converters.dataframe_to_mds import (MAPPING_SPARK_TO_MDS, dataframe_to_mds, +from streaming.converters.dataframe_to_mds import (MAPPING_SPARK_TO_MDS, dataframe_to_mds, dataframeToMDS) __all__ = ['dataframeToMDS', 'dataframe_to_mds', 'MAPPING_SPARK_TO_MDS'] diff --git a/streaming/converters/dataframe_to_mds.py b/streaming/converters/dataframe_to_mds.py index c74460b3f..5093a9d7f 100644 --- a/streaming/converters/dataframe_to_mds.py +++ b/streaming/converters/dataframe_to_mds.py @@ -11,8 +11,8 @@ import pandas as pd -from streaming.base.util import get_import_exception_message -from streaming.base.util import merge_index as 
do_merge_index +from streaming.util import get_import_exception_message +from streaming.util import merge_index as do_merge_index try: from pyspark import TaskContext @@ -26,9 +26,9 @@ raise e from streaming import MDSWriter -from streaming.base.format.index import get_index_basename -from streaming.base.format.mds.encodings import _encodings -from streaming.base.storage.upload import CloudUploader +from streaming.format.index import get_index_basename +from streaming.format.mds.encodings import _encodings +from streaming.storage.upload import CloudUploader logger = logging.getLogger(__name__) diff --git a/streaming/dataloader.py b/streaming/dataloader.py index 89cdb0026..a0c881f34 100644 --- a/streaming/dataloader.py +++ b/streaming/dataloader.py @@ -9,8 +9,8 @@ from torch.utils.data import DataLoader from transformers import BatchEncoding, BatchFeature -from streaming.base.dataset import StreamingDataset -from streaming.base.world import World +from streaming.dataset import StreamingDataset +from streaming.world import World class StreamingDataLoader(DataLoader): diff --git a/streaming/dataset.py b/streaming/dataset.py index 0faae33fd..c949841d9 100644 --- a/streaming/dataset.py +++ b/streaming/dataset.py @@ -22,20 +22,20 @@ from torch import distributed as dist from torch.utils.data import IterableDataset -from streaming.base.array import Array -from streaming.base.batching import generate_work -from streaming.base.constant import (BARRIER, BARRIER_FILELOCK, CACHE_FILELOCK, CACHE_USAGE, +from streaming.array import Array +from streaming.batching import generate_work +from streaming.constant import (BARRIER, BARRIER_FILELOCK, CACHE_FILELOCK, CACHE_USAGE, EPOCH_DATA, EPOCH_SHAPE, NEXT_EPOCH, RESUME, SHARD_ACCESS_TIMES, SHARD_STATES, TICK) -from streaming.base.distributed import maybe_init_dist -from streaming.base.format import get_index_basename -from streaming.base.sampling import get_sampling -from streaming.base.shared import (SharedArray, SharedBarrier, SharedMemory, SharedScalar, +from streaming.distributed import maybe_init_dist +from streaming.format import get_index_basename +from streaming.sampling import get_sampling +from streaming.shared import (SharedArray, SharedBarrier, SharedMemory, SharedScalar, _get_path, get_shm_prefix) -from streaming.base.spanner import Spanner -from streaming.base.stream import Stream -from streaming.base.util import normalize_bytes, normalize_count -from streaming.base.world import World +from streaming.spanner import Spanner +from streaming.stream import Stream +from streaming.util import normalize_bytes, normalize_count +from streaming.world import World # An arbitrary time in the future, used for cold shard eviction. 
NEVER = np.iinfo(np.uint64).max diff --git a/streaming/examples/multimodal/laion400m/convert_and_upload.py b/streaming/examples/multimodal/laion400m/convert_and_upload.py index 8af84a3d1..ddc1e4cb5 100644 --- a/streaming/examples/multimodal/laion400m/convert_and_upload.py +++ b/streaming/examples/multimodal/laion400m/convert_and_upload.py @@ -13,7 +13,7 @@ from pyarrow import parquet as pq from streaming import MDSWriter -from streaming.base.storage import CloudUploader +from streaming.storage import CloudUploader def parse_args() -> Namespace: diff --git a/streaming/examples/multimodal/webvid/read.py b/streaming/examples/multimodal/webvid/read.py index 260ae7bc3..d3f74c2d9 100644 --- a/streaming/examples/multimodal/webvid/read.py +++ b/streaming/examples/multimodal/webvid/read.py @@ -7,9 +7,9 @@ from time import sleep from typing import Any, Optional -from streaming.base import StreamingDataset -from streaming.base.dataset import TICK, _Iterator -from streaming.base.storage import download_file +from streaming import StreamingDataset +from streaming.dataset import TICK, _Iterator +from streaming.storage import download_file class StreamingInsideWebVid(StreamingDataset): diff --git a/streaming/examples/text/c4/read.py b/streaming/examples/text/c4/read.py index 82a24a255..d30340f97 100644 --- a/streaming/examples/text/c4/read.py +++ b/streaming/examples/text/c4/read.py @@ -11,7 +11,7 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer -from streaming.base import StreamingDataset +from streaming import StreamingDataset __all__ = ['StreamingC4'] diff --git a/streaming/examples/text/c4/write.py b/streaming/examples/text/c4/write.py index 395cf1120..efe1aed32 100644 --- a/streaming/examples/text/c4/write.py +++ b/streaming/examples/text/c4/write.py @@ -12,8 +12,8 @@ from torch.utils.data import DataLoader, IterableDataset, get_worker_info from tqdm import tqdm -from streaming.base import MDSWriter -from streaming.base.util.pretty import unpack_strs +from streaming import MDSWriter +from streaming.util.pretty import unpack_strs def parse_args() -> Namespace: diff --git a/streaming/examples/text/enwiki_txt/enwiki.py b/streaming/examples/text/enwiki_txt/enwiki.py index 63c24a5a3..4385e7394 100644 --- a/streaming/examples/text/enwiki_txt/enwiki.py +++ b/streaming/examples/text/enwiki_txt/enwiki.py @@ -7,7 +7,7 @@ import numpy as np -from streaming.base import StreamingDataset +from streaming import StreamingDataset __all__ = ['StreamingEnWiki'] diff --git a/streaming/examples/text/enwiki_txt/write.py b/streaming/examples/text/enwiki_txt/write.py index 2a60043fe..e554952e2 100644 --- a/streaming/examples/text/enwiki_txt/write.py +++ b/streaming/examples/text/enwiki_txt/write.py @@ -9,8 +9,8 @@ from tqdm import tqdm -from streaming.base import MDSWriter -from streaming.base.util import unpack_strs +from streaming import MDSWriter +from streaming.util import unpack_strs def parse_args() -> Namespace: diff --git a/streaming/examples/text/pile/read.py b/streaming/examples/text/pile/read.py index f2f06113b..58c4afc68 100644 --- a/streaming/examples/text/pile/read.py +++ b/streaming/examples/text/pile/read.py @@ -11,7 +11,7 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer -from streaming.base import StreamingDataset +from streaming import StreamingDataset __all__ = ['StreamingPile'] diff --git a/streaming/examples/text/pile/write.py b/streaming/examples/text/pile/write.py index ce24d26ae..fc1d4557f 100644 --- a/streaming/examples/text/pile/write.py +++ 
b/streaming/examples/text/pile/write.py @@ -11,8 +11,8 @@ from multiprocessing import Pool from typing import Dict, Iterator, List, Tuple -from streaming.base import MDSWriter -from streaming.base.util import unpack_strs +from streaming import MDSWriter +from streaming.util import unpack_strs def parse_args() -> Namespace: diff --git a/streaming/examples/vision/ade20k/read.py b/streaming/examples/vision/ade20k/read.py index bba847115..f04fc423f 100644 --- a/streaming/examples/vision/ade20k/read.py +++ b/streaming/examples/vision/ade20k/read.py @@ -9,7 +9,7 @@ from typing import Any, Callable, Optional, Tuple -from streaming.base import StreamingDataset +from streaming import StreamingDataset __all__ = ['StreamingADE20K'] diff --git a/streaming/examples/vision/ade20k/write.py b/streaming/examples/vision/ade20k/write.py index 5043bd9c2..5c5a613d7 100644 --- a/streaming/examples/vision/ade20k/write.py +++ b/streaming/examples/vision/ade20k/write.py @@ -11,8 +11,8 @@ from tqdm import tqdm -from streaming.base import MDSWriter -from streaming.base.util.pretty import unpack_strs +from streaming import MDSWriter +from streaming.util.pretty import unpack_strs def parse_args() -> Namespace: diff --git a/streaming/examples/vision/cifar10/write.py b/streaming/examples/vision/cifar10/write.py index 7cab9207e..83bdce334 100644 --- a/streaming/examples/vision/cifar10/write.py +++ b/streaming/examples/vision/cifar10/write.py @@ -7,7 +7,7 @@ from torchvision.datasets import CIFAR10 -from streaming.base.util import unpack_strs +from streaming.util import unpack_strs from streaming.vision.convert.base import convert_image_class_dataset diff --git a/streaming/examples/vision/coco/read.py b/streaming/examples/vision/coco/read.py index 162b17581..a9622eab3 100644 --- a/streaming/examples/vision/coco/read.py +++ b/streaming/examples/vision/coco/read.py @@ -9,7 +9,7 @@ from typing import Any, Callable, Optional -from streaming.base import StreamingDataset +from streaming import StreamingDataset __all__ = ['StreamingCOCO'] diff --git a/streaming/examples/vision/coco/write.py b/streaming/examples/vision/coco/write.py index 65ba7da8c..29cf20865 100644 --- a/streaming/examples/vision/coco/write.py +++ b/streaming/examples/vision/coco/write.py @@ -14,8 +14,8 @@ from torch.utils.data import Dataset from tqdm import tqdm -from streaming.base import MDSWriter -from streaming.base.util import unpack_strs +from streaming import MDSWriter +from streaming.util import unpack_strs def parse_args() -> Namespace: diff --git a/streaming/examples/vision/imagenet/write.py b/streaming/examples/vision/imagenet/write.py index 6b2efd670..7da15d1dd 100644 --- a/streaming/examples/vision/imagenet/write.py +++ b/streaming/examples/vision/imagenet/write.py @@ -12,8 +12,8 @@ from PIL import Image from tqdm import tqdm -from streaming.base import MDSWriter -from streaming.base.util import unpack_strs +from streaming import MDSWriter +from streaming.util import unpack_strs def parse_args() -> Namespace: diff --git a/streaming/format/__init__.py b/streaming/format/__init__.py index 82b506562..5cb9e234b 100644 --- a/streaming/format/__init__.py +++ b/streaming/format/__init__.py @@ -5,14 +5,14 @@ from typing import Any, Dict, Optional, Union -from streaming.base.format.base import FileInfo, Reader -from streaming.base.format.delta import index_delta -from streaming.base.format.index import get_index_basename -from streaming.base.format.json import JSONReader, JSONWriter -from streaming.base.format.lance import index_lance -from 
streaming.base.format.mds import MDSReader, MDSWriter -from streaming.base.format.parquet import index_parquet -from streaming.base.format.xsv import (CSVReader, CSVWriter, TSVReader, TSVWriter, XSVReader, +from streaming.format.base import FileInfo, Reader +from streaming.format.delta import index_delta +from streaming.format.index import get_index_basename +from streaming.format.json import JSONReader, JSONWriter +from streaming.format.lance import index_lance +from streaming.format.mds import MDSReader, MDSWriter +from streaming.format.parquet import index_parquet +from streaming.format.xsv import (CSVReader, CSVWriter, TSVReader, TSVWriter, XSVReader, XSVWriter) __all__ = [ diff --git a/streaming/format/base/__init__.py b/streaming/format/base/__init__.py deleted file mode 100644 index 46bf9f730..000000000 --- a/streaming/format/base/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Base module for dataset reader and writer.""" - -from streaming.base.format.base.reader import FileInfo, Reader - -__all__ = ['FileInfo', 'Reader'] diff --git a/streaming/format/delta/__init__.py b/streaming/format/delta/__init__.py index 248e928a0..c94680487 100644 --- a/streaming/format/delta/__init__.py +++ b/streaming/format/delta/__init__.py @@ -3,6 +3,6 @@ """Integration with Delta tables.""" -from streaming.base.format.delta.indexing import index_delta +from streaming.format.delta.indexing import index_delta __all__ = ['index_delta'] diff --git a/streaming/format/json/__init__.py b/streaming/format/json/__init__.py index fe37c8570..47e8be8f6 100644 --- a/streaming/format/json/__init__.py +++ b/streaming/format/json/__init__.py @@ -3,7 +3,7 @@ """Module to write and read the dataset in JSON format.""" -from streaming.base.format.json.reader import JSONReader -from streaming.base.format.json.writer import JSONWriter +from streaming.format.json.reader import JSONReader +from streaming.format.json.writer import JSONWriter __all__ = ['JSONReader', 'JSONWriter'] diff --git a/streaming/format/json/reader.py b/streaming/format/json/reader.py index 4aaeb91cc..afcef25bd 100644 --- a/streaming/format/json/reader.py +++ b/streaming/format/json/reader.py @@ -11,7 +11,7 @@ import numpy as np from typing_extensions import Self -from streaming.base.format.base.reader import FileInfo, SplitReader +from streaming.format.base.reader import FileInfo, SplitReader __all__ = ['JSONReader'] diff --git a/streaming/format/json/writer.py b/streaming/format/json/writer.py index aae9d1d28..74c11c7ac 100644 --- a/streaming/format/json/writer.py +++ b/streaming/format/json/writer.py @@ -8,8 +8,8 @@ import numpy as np -from streaming.base.format.base.writer import SplitWriter -from streaming.base.format.json.encodings import is_json_encoded, is_json_encoding +from streaming.format.base.writer import SplitWriter +from streaming.format.json.encodings import is_json_encoded, is_json_encoding __all__ = ['JSONWriter'] diff --git a/streaming/format/lance/__init__.py b/streaming/format/lance/__init__.py index 3e3d3ac87..d3885d0b6 100644 --- a/streaming/format/lance/__init__.py +++ b/streaming/format/lance/__init__.py @@ -3,6 +3,6 @@ """Integration with Lance datasets.""" -from streaming.base.format.lance.indexing import index_lance +from streaming.format.lance.indexing import index_lance __all__ = ['index_lance'] diff --git a/streaming/format/mds/__init__.py b/streaming/format/mds/__init__.py index 2c18ca0e7..67a5be56f 100644 --- 
a/streaming/format/mds/__init__.py +++ b/streaming/format/mds/__init__.py @@ -3,7 +3,7 @@ """Module to write and read the dataset in MDS format.""" -from streaming.base.format.mds.reader import MDSReader -from streaming.base.format.mds.writer import MDSWriter +from streaming.format.mds.reader import MDSReader +from streaming.format.mds.writer import MDSWriter __all__ = ['MDSReader', 'MDSWriter'] diff --git a/streaming/format/mds/reader.py b/streaming/format/mds/reader.py index 275f01192..c779433ad 100644 --- a/streaming/format/mds/reader.py +++ b/streaming/format/mds/reader.py @@ -10,8 +10,8 @@ import numpy as np from typing_extensions import Self -from streaming.base.format.base.reader import FileInfo, JointReader -from streaming.base.format.mds.encodings import mds_decode +from streaming.format.base.reader import FileInfo, JointReader +from streaming.format.mds.encodings import mds_decode __all__ = ['MDSReader'] diff --git a/streaming/format/mds/writer.py b/streaming/format/mds/writer.py index e82fc02a8..00b9e8da4 100644 --- a/streaming/format/mds/writer.py +++ b/streaming/format/mds/writer.py @@ -8,8 +8,8 @@ import numpy as np -from streaming.base.format.base.writer import JointWriter -from streaming.base.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, +from streaming.format.base.writer import JointWriter +from streaming.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, is_mds_encoding, mds_encode) __all__ = ['MDSWriter'] diff --git a/streaming/format/parquet/__init__.py b/streaming/format/parquet/__init__.py index c2847cee8..b98171f5d 100644 --- a/streaming/format/parquet/__init__.py +++ b/streaming/format/parquet/__init__.py @@ -3,6 +3,6 @@ """Integration with Parquet datasets.""" -from streaming.base.format.parquet.indexing import index_parquet +from streaming.format.parquet.indexing import index_parquet __all__ = ['index_parquet'] diff --git a/streaming/format/parquet/indexing.py b/streaming/format/parquet/indexing.py index cacdbc583..e615c7f89 100644 --- a/streaming/format/parquet/indexing.py +++ b/streaming/format/parquet/indexing.py @@ -10,8 +10,8 @@ from pyarrow import parquet as pq from tqdm import tqdm -from streaming.base.format.mds.encodings import get_mds_encoded_size -from streaming.base.storage.extra import list_dataset_files, smart_download_file +from streaming.format.mds.encodings import get_mds_encoded_size +from streaming.storage.extra import list_dataset_files, smart_download_file __all__ = ['index_parquet'] diff --git a/streaming/format/base/reader.py b/streaming/format/reader.py similarity index 99% rename from streaming/format/base/reader.py rename to streaming/format/reader.py index 7db3521cc..3fa9c6b3a 100644 --- a/streaming/format/base/reader.py +++ b/streaming/format/reader.py @@ -8,8 +8,8 @@ from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional, Set, Union -from streaming.base.array import Array -from streaming.base.util import normalize_bytes +from streaming.array import Array +from streaming.util import normalize_bytes __all__ = ['FileInfo', 'Reader', 'JointReader', 'SplitReader'] diff --git a/streaming/format/base/writer.py b/streaming/format/writer.py similarity index 98% rename from streaming/format/base/writer.py rename to streaming/format/writer.py index 7b182c539..25a5ef28e 100644 --- a/streaming/format/base/writer.py +++ b/streaming/format/writer.py @@ -18,11 +18,11 @@ from typing_extensions import Self -from streaming.base.compression import compress, 
get_compression_extension, is_compression -from streaming.base.format.index import get_index_basename -from streaming.base.hashing import get_hash, is_hash -from streaming.base.storage.upload import CloudUploader -from streaming.base.util import normalize_bytes +from streaming.compression import compress, get_compression_extension, is_compression +from streaming.format.index import get_index_basename +from streaming.hashing import get_hash, is_hash +from streaming.storage.upload import CloudUploader +from streaming.util import normalize_bytes __all__ = ['JointWriter', 'SplitWriter'] diff --git a/streaming/format/xsv/__init__.py b/streaming/format/xsv/__init__.py index 6d5ca2489..985010a42 100644 --- a/streaming/format/xsv/__init__.py +++ b/streaming/format/xsv/__init__.py @@ -3,7 +3,7 @@ """Module to write and read the dataset in Tabular format.""" -from streaming.base.format.xsv.reader import CSVReader, TSVReader, XSVReader -from streaming.base.format.xsv.writer import CSVWriter, TSVWriter, XSVWriter +from streaming.format.xsv.reader import CSVReader, TSVReader, XSVReader +from streaming.format.xsv.writer import CSVWriter, TSVWriter, XSVWriter __all__ = ['CSVReader', 'CSVWriter', 'TSVReader', 'TSVWriter', 'XSVReader', 'XSVWriter'] diff --git a/streaming/format/xsv/reader.py b/streaming/format/xsv/reader.py index 896d9cda9..7c0df2bc9 100644 --- a/streaming/format/xsv/reader.py +++ b/streaming/format/xsv/reader.py @@ -10,8 +10,8 @@ import numpy as np from typing_extensions import Self -from streaming.base.format.base.reader import FileInfo, SplitReader -from streaming.base.format.xsv.encodings import xsv_decode +from streaming.format.base.reader import FileInfo, SplitReader +from streaming.format.xsv.encodings import xsv_decode __all__ = ['XSVReader', 'CSVReader', 'TSVReader'] diff --git a/streaming/format/xsv/writer.py b/streaming/format/xsv/writer.py index 2888597b2..88d00fd54 100644 --- a/streaming/format/xsv/writer.py +++ b/streaming/format/xsv/writer.py @@ -8,8 +8,8 @@ import numpy as np -from streaming.base.format.base.writer import SplitWriter -from streaming.base.format.xsv.encodings import is_xsv_encoding, xsv_encode +from streaming.format.base.writer import SplitWriter +from streaming.format.xsv.encodings import is_xsv_encoding, xsv_encode __all__ = ['XSVWriter', 'CSVWriter', 'TSVWriter'] diff --git a/streaming/local.py b/streaming/local.py index 48eea91a5..47dd8134f 100644 --- a/streaming/local.py +++ b/streaming/local.py @@ -10,9 +10,9 @@ import numpy as np from torch.utils.data import Dataset -from streaming.base.array import Array -from streaming.base.format import get_index_basename, reader_from_json -from streaming.base.spanner import Spanner +from streaming.array import Array +from streaming.format import get_index_basename, reader_from_json +from streaming.spanner import Spanner __all__ = ['LocalDataset'] diff --git a/streaming/partition/__init__.py b/streaming/partition/__init__.py index ad1edefa2..5e67e485c 100644 --- a/streaming/partition/__init__.py +++ b/streaming/partition/__init__.py @@ -8,8 +8,8 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.partition.orig import get_partitions_orig -from streaming.base.partition.relaxed import get_partitions_relaxed +from streaming.partition.orig import get_partitions_orig +from streaming.partition.relaxed import get_partitions_relaxed algos = { 'orig': get_partitions_orig, diff --git a/streaming/partition/relaxed.py b/streaming/partition/relaxed.py index c2f0d83a8..f57529874 100644 --- 
a/streaming/partition/relaxed.py +++ b/streaming/partition/relaxed.py @@ -9,7 +9,7 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.partition.orig import get_partitions_orig +from streaming.partition.orig import get_partitions_orig logger = logging.getLogger(__name__) diff --git a/streaming/shared/__init__.py b/streaming/shared/__init__.py index cf507c4fe..8d599d4fe 100644 --- a/streaming/shared/__init__.py +++ b/streaming/shared/__init__.py @@ -7,11 +7,11 @@ we are coordinating separately instantiated pytorch worker processes. """ -from streaming.base.shared.array import SharedArray as SharedArray -from streaming.base.shared.barrier import SharedBarrier as SharedBarrier -from streaming.base.shared.memory import SharedMemory as SharedMemory -from streaming.base.shared.prefix import _get_path as _get_path -from streaming.base.shared.prefix import get_shm_prefix as get_shm_prefix -from streaming.base.shared.scalar import SharedScalar as SharedScalar +from streaming.shared.array import SharedArray as SharedArray +from streaming.shared.barrier import SharedBarrier as SharedBarrier +from streaming.shared.memory import SharedMemory as SharedMemory +from streaming.shared.prefix import _get_path as _get_path +from streaming.shared.prefix import get_shm_prefix as get_shm_prefix +from streaming.shared.scalar import SharedScalar as SharedScalar __all__ = ['SharedArray', 'SharedBarrier', 'SharedMemory', 'get_shm_prefix', 'SharedScalar'] diff --git a/streaming/shared/array.py b/streaming/shared/array.py index 20689d125..cd69db85f 100644 --- a/streaming/shared/array.py +++ b/streaming/shared/array.py @@ -8,7 +8,7 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.shared.memory import SharedMemory +from streaming.shared.memory import SharedMemory class SharedArray: diff --git a/streaming/shared/barrier.py b/streaming/shared/barrier.py index ceeb3ec43..b4adda46e 100644 --- a/streaming/shared/barrier.py +++ b/streaming/shared/barrier.py @@ -11,8 +11,8 @@ import numpy as np from filelock import FileLock -from streaming.base.constant import TICK -from streaming.base.shared.array import SharedArray +from streaming.constant import TICK +from streaming.shared.array import SharedArray # Time out to wait before raising exception TIMEOUT = 60 diff --git a/streaming/shared/memory.py b/streaming/shared/memory.py index b5b70f55e..b235b7e32 100644 --- a/streaming/shared/memory.py +++ b/streaming/shared/memory.py @@ -9,7 +9,7 @@ from time import sleep from typing import Any, Optional -from streaming.base.constant import TICK +from streaming.constant import TICK class SharedMemory: diff --git a/streaming/shared/prefix.py b/streaming/shared/prefix.py index 48d2aaa6c..f51f9f1a6 100644 --- a/streaming/shared/prefix.py +++ b/streaming/shared/prefix.py @@ -14,9 +14,9 @@ import numpy as np from torch import distributed as dist -from streaming.base.constant import LOCALS, TICK -from streaming.base.shared import SharedMemory -from streaming.base.world import World +from streaming.constant import LOCALS, TICK +from streaming.shared import SharedMemory +from streaming.world import World def _each_prefix_int() -> Iterator[int]: @@ -128,7 +128,7 @@ def _check_and_find(streams_local: List[str], streams_remote: List[Union[str, No f'Reused local directory: {streams_local} vs ' + f'{their_locals}. Provide a different one. 
If using ' + f'a unique local directory, try deleting the local directory and ' + - f'call `streaming.base.util.clean_stale_shared_memory()` only once ' + + f'call `streaming.util.clean_stale_shared_memory()` only once ' + f'in your script to clean up the stale shared memory before ' + f'instantiation of `StreamingDataset`.') return prefix_int diff --git a/streaming/shared/scalar.py b/streaming/shared/scalar.py index 14cd5e7fa..c9714befc 100644 --- a/streaming/shared/scalar.py +++ b/streaming/shared/scalar.py @@ -5,7 +5,7 @@ from typing import Any -from streaming.base.shared.array import SharedArray +from streaming.shared.array import SharedArray class SharedScalar: diff --git a/streaming/shuffle/__init__.py b/streaming/shuffle/__init__.py index e5e529c42..d34eb71fd 100644 --- a/streaming/shuffle/__init__.py +++ b/streaming/shuffle/__init__.py @@ -6,12 +6,12 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.shuffle.naive import get_shuffle_naive -from streaming.base.shuffle.py1b import get_shuffle_py1b -from streaming.base.shuffle.py1br import get_shuffle_py1br -from streaming.base.shuffle.py1e import get_shuffle_py1e -from streaming.base.shuffle.py1s import get_shuffle_py1s -from streaming.base.shuffle.py2s import get_shuffle_py2s +from streaming.shuffle.naive import get_shuffle_naive +from streaming.shuffle.py1b import get_shuffle_py1b +from streaming.shuffle.py1br import get_shuffle_py1br +from streaming.shuffle.py1e import get_shuffle_py1e +from streaming.shuffle.py1s import get_shuffle_py1s +from streaming.shuffle.py2s import get_shuffle_py2s algos = { 'py1b': get_shuffle_py1b, diff --git a/streaming/shuffle/py1b.py b/streaming/shuffle/py1b.py index bb59f0c73..fdfaf9dd0 100644 --- a/streaming/shuffle/py1b.py +++ b/streaming/shuffle/py1b.py @@ -10,7 +10,7 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.shuffle.py1s import divide_spans +from streaming.shuffle.py1s import divide_spans def get_shuffle_py1b(shard_sizes: NDArray[np.int64], diff --git a/streaming/shuffle/py1br.py b/streaming/shuffle/py1br.py index eff32210c..bc4c5053a 100644 --- a/streaming/shuffle/py1br.py +++ b/streaming/shuffle/py1br.py @@ -10,7 +10,7 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.shuffle.py1s import divide_spans +from streaming.shuffle.py1s import divide_spans def get_shuffle_py1br(shard_sizes: NDArray[np.int64], diff --git a/streaming/shuffle/py1e.py b/streaming/shuffle/py1e.py index 3583caa22..e5dfc6291 100644 --- a/streaming/shuffle/py1e.py +++ b/streaming/shuffle/py1e.py @@ -13,7 +13,7 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.shuffle.py1s import divide_spans +from streaming.shuffle.py1s import divide_spans def get_shuffle_py1e(shard_sizes: NDArray[np.int64], diff --git a/streaming/storage/__init__.py b/streaming/storage/__init__.py index d3658656b..3a312a669 100644 --- a/streaming/storage/__init__.py +++ b/streaming/storage/__init__.py @@ -3,13 +3,13 @@ """Base module for downloading/uploading files from/to cloud storage.""" -from streaming.base.storage.download import (download_file, download_from_azure, +from streaming.storage.download import (download_file, download_from_azure, download_from_azure_datalake, download_from_databricks_unity_catalog, download_from_dbfs, download_from_gcs, download_from_local, download_from_oci, download_from_s3, download_from_sftp) -from streaming.base.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, +from 
streaming.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, GCSUploader, LocalUploader, OCIUploader, S3Uploader) __all__ = [ diff --git a/streaming/storage/download.py b/streaming/storage/download.py index 9db4af328..edb88943c 100644 --- a/streaming/storage/download.py +++ b/streaming/storage/download.py @@ -10,7 +10,7 @@ from time import sleep, time from typing import Any, Dict, Optional -from streaming.base.util import get_import_exception_message +from streaming.util import get_import_exception_message __all__ = [ 'download_from_s3', diff --git a/streaming/storage/extra.py b/streaming/storage/extra.py index 18ce1c87f..7f993edfd 100644 --- a/streaming/storage/extra.py +++ b/streaming/storage/extra.py @@ -13,9 +13,9 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union from urllib.parse import urlparse -from streaming.base.hashing import get_hash -from streaming.base.storage import CloudUploader, download_file -from streaming.base.util.pretty import normalize_bytes, normalize_duration +from streaming.hashing import get_hash +from streaming.storage import CloudUploader, download_file +from streaming.util.pretty import normalize_bytes, normalize_duration __all__ = ['wait_for_file_to_exist', 'walk_dir', 'list_dataset_files', 'smart_download_file'] diff --git a/streaming/storage/upload.py b/streaming/storage/upload.py index 2c89c08de..66f0b32fe 100644 --- a/streaming/storage/upload.py +++ b/streaming/storage/upload.py @@ -15,9 +15,9 @@ import tqdm -from streaming.base.storage.download import (BOTOCORE_CLIENT_ERROR_CODES, +from streaming.storage.download import (BOTOCORE_CLIENT_ERROR_CODES, GCS_ERROR_NO_AUTHENTICATION) -from streaming.base.util import get_import_exception_message, retry +from streaming.util import get_import_exception_message, retry __all__ = [ 'CloudUploader', diff --git a/streaming/stream.py b/streaming/stream.py index d32bb63a5..6d970cde3 100644 --- a/streaming/stream.py +++ b/streaming/stream.py @@ -13,15 +13,15 @@ from numpy.typing import NDArray from typing_extensions import Self -from streaming.base.compression import decompress -from streaming.base.constant import TICK -from streaming.base.distributed import barrier, get_local_rank -from streaming.base.format import FileInfo, Reader, get_index_basename, reader_from_json -from streaming.base.hashing import get_hash -from streaming.base.storage import download_file -from streaming.base.storage.extra import wait_for_file_to_exist -from streaming.base.util.retrying import retry -from streaming.base.world import World +from streaming.compression import decompress +from streaming.constant import TICK +from streaming.distributed import barrier, get_local_rank +from streaming.format import FileInfo, Reader, get_index_basename, reader_from_json +from streaming.hashing import get_hash +from streaming.storage import download_file +from streaming.storage.extra import wait_for_file_to_exist +from streaming.util.retrying import retry +from streaming.world import World class Stream: diff --git a/streaming/util/__init__.py b/streaming/util/__init__.py index b5352dd47..f96a2fb26 100644 --- a/streaming/util/__init__.py +++ b/streaming/util/__init__.py @@ -3,13 +3,13 @@ """Utilities and helkper methods needed by Streaming.""" -from streaming.base.util.importing import get_import_exception_message -from streaming.base.util.merging import merge_index -from streaming.base.util.pretty import (normalize_bin_bytes, normalize_bytes, normalize_count, +from streaming.util.importing import 
get_import_exception_message +from streaming.util.merging import merge_index +from streaming.util.pretty import (normalize_bin_bytes, normalize_bytes, normalize_count, normalize_dec_bytes, normalize_duration, unpack_str2str, unpack_strs) -from streaming.base.util.retrying import retry -from streaming.base.util.shared import clean_stale_shared_memory +from streaming.util.retrying import retry +from streaming.util.shared import clean_stale_shared_memory __all__ = [ 'clean_stale_shared_memory', 'get_import_exception_message', 'merge_index', diff --git a/streaming/util/merging.py b/streaming/util/merging.py index 8411d5cec..a7243b32f 100644 --- a/streaming/util/merging.py +++ b/streaming/util/merging.py @@ -13,7 +13,7 @@ from pathlib import Path from typing import Any, List, Tuple, Union -from streaming.base.format.index import get_index_basename +from streaming.format.index import get_index_basename __all__ = ['merge_index'] @@ -77,8 +77,8 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` download_timeout (int): The allowed time for downloading each json file. Defaults to 60. """ - from streaming.base.storage.download import download_file - from streaming.base.storage.upload import CloudUploader + from streaming.storage.download import download_file + from streaming.storage.upload import CloudUploader if not index_file_urls or not out: logger.warning('Either index_file_urls or out are None. ' + @@ -180,7 +180,7 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` download_timeout (int): The allowed time for downloading each json file. Defaults to 60. """ - from streaming.base.storage.upload import CloudUploader + from streaming.storage.upload import CloudUploader def not_merged_index(index_file_path: str, out: str): """Check if index_file_path is the merged index at folder out. diff --git a/streaming/util/retrying.py b/streaming/util/retrying.py index 3d006655b..e2c78a8c6 100644 --- a/streaming/util/retrying.py +++ b/streaming/util/retrying.py @@ -48,7 +48,7 @@ def retry( # type: ignore Example: .. 
testcode:: - from streaming.base.util import retry + from streaming.util import retry num_tries = 0 diff --git a/streaming/util/shared.py b/streaming/util/shared.py index 956d3427c..d68ad8d8a 100644 --- a/streaming/util/shared.py +++ b/streaming/util/shared.py @@ -7,9 +7,9 @@ import torch.distributed as dist -from streaming.base.constant import SHM_TO_CLEAN -from streaming.base.distributed import get_local_rank, maybe_init_dist -from streaming.base.shared.prefix import _get_path +from streaming.constant import SHM_TO_CLEAN +from streaming.distributed import get_local_rank, maybe_init_dist +from streaming.shared.prefix import _get_path def clean_stale_shared_memory() -> None: diff --git a/streaming/vision.py b/streaming/vision.py index b3fc20790..2cf3300b3 100644 --- a/streaming/vision.py +++ b/streaming/vision.py @@ -12,7 +12,7 @@ from torchvision.transforms.functional import to_tensor from tqdm import tqdm -from streaming.base import MDSWriter, StreamingDataset +from streaming import MDSWriter, StreamingDataset __all__ = ['StreamingVisionDataset'] diff --git a/streaming/world.py b/streaming/world.py index c787c2f97..b512b4132 100644 --- a/streaming/world.py +++ b/streaming/world.py @@ -5,7 +5,7 @@ from torch.utils.data import get_worker_info -from streaming.base import distributed as dist +from streaming import distributed as dist class World: diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index dc19d219b..0a1a3107b 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -11,7 +11,7 @@ from pyspark.sql.functions import col from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType -from streaming.base.converters import dataframe_to_mds +from streaming.converters import dataframe_to_mds # set to yes to all fork process in spark calls os.environ['OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' diff --git a/tests/common/datasets.py b/tests/common/datasets.py index dbaefcd38..ce76010c3 100644 --- a/tests/common/datasets.py +++ b/tests/common/datasets.py @@ -5,7 +5,7 @@ import numpy as np -from streaming.base import MDSWriter +from streaming import MDSWriter class SequenceDataset: diff --git a/tests/test_array.py b/tests/test_array.py index 30816665f..7cfeb3f42 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -7,7 +7,7 @@ import pytest from numpy.typing import NDArray -from streaming.base.array import Array +from streaming.array import Array class Range(Array): diff --git a/tests/test_barrier.py b/tests/test_barrier.py index fdc5eb87d..72fcb6d13 100644 --- a/tests/test_barrier.py +++ b/tests/test_barrier.py @@ -11,7 +11,7 @@ import pytest -from streaming.base.shared import SharedArray, SharedBarrier +from streaming.shared import SharedArray, SharedBarrier class TestSharedBarrier: diff --git a/tests/test_compression.py b/tests/test_compression.py index 24c72a158..b1f3b9e03 100644 --- a/tests/test_compression.py +++ b/tests/test_compression.py @@ -7,8 +7,8 @@ import numpy as np import pytest -from streaming.base import StreamingDataset -from streaming.base.compression import (Brotli, Bzip2, Gzip, Snappy, Zstandard, compress, +from streaming import StreamingDataset +from streaming.compression import (Brotli, Bzip2, Gzip, Snappy, Zstandard, compress, decompress, get_compression_extension, is_compression) from tests.common.datasets import SequenceDataset, write_mds_dataset diff --git a/tests/test_distributed.py 
b/tests/test_distributed.py index 73fc1726b..8da2b8673 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -12,8 +12,8 @@ import torch.distributed as dist from torch.utils.data import DataLoader -import streaming.base.distributed as ms_dist -from streaming.base import StreamingDataset +import streaming.distributed as ms_dist +from streaming import StreamingDataset from tests.common.datasets import SequenceDataset, write_mds_dataset from tests.common.distributed import DistributedTest diff --git a/tests/test_download.py b/tests/test_download.py index 2d33cefb3..d5f169521 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -10,7 +10,7 @@ import pytest from botocore.exceptions import ClientError -from streaming.base.storage.download import (download_file, download_from_azure, +from streaming.storage.download import (download_file, download_from_azure, download_from_azure_datalake, download_from_databricks_unity_catalog, download_from_dbfs, download_from_gcs, @@ -167,7 +167,7 @@ def test_download_from_local(): class TestDownload: - @patch('streaming.base.storage.download.download_from_s3') + @patch('streaming.storage.download.download_from_s3') @pytest.mark.usefixtures('remote_local_file') def test_download_from_s3_gets_called(self, mocked_requests: Mock, remote_local_file: Any): mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='s3://') @@ -175,7 +175,7 @@ def test_download_from_s3_gets_called(self, mocked_requests: Mock, remote_local_ mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath, 60) - @patch('streaming.base.storage.download.download_from_gcs') + @patch('streaming.storage.download.download_from_gcs') @pytest.mark.usefixtures('remote_local_file') def test_download_from_gcs_gets_called(self, mocked_requests: Mock, remote_local_file: Any): mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='gs://') @@ -183,7 +183,7 @@ def test_download_from_gcs_gets_called(self, mocked_requests: Mock, remote_local mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) - @patch('streaming.base.storage.download.download_from_azure') + @patch('streaming.storage.download.download_from_azure') @pytest.mark.usefixtures('remote_local_file') def test_download_from_azure_gets_called(self, mocked_requests: Mock, remote_local_file: Any): mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='azure://') @@ -191,7 +191,7 @@ def test_download_from_azure_gets_called(self, mocked_requests: Mock, remote_loc mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) - @patch('streaming.base.storage.download.download_from_azure_datalake') + @patch('streaming.storage.download.download_from_azure_datalake') @pytest.mark.usefixtures('remote_local_file') def test_download_from_azure_datalake_gets_called(self, mocked_requests: Mock, remote_local_file: Any): @@ -200,7 +200,7 @@ def test_download_from_azure_datalake_gets_called(self, mocked_requests: Mock, mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) - @patch('streaming.base.storage.download.download_from_sftp') + @patch('streaming.storage.download.download_from_sftp') @pytest.mark.usefixtures('remote_local_file') def test_download_from_sftp_gets_called(self, mocked_requests: Mock, remote_local_file: Any): 
mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='sftp://') @@ -208,7 +208,7 @@ def test_download_from_sftp_gets_called(self, mocked_requests: Mock, remote_loca mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) - @patch('streaming.base.storage.download.download_from_databricks_unity_catalog') + @patch('streaming.storage.download.download_from_databricks_unity_catalog') @pytest.mark.usefixtures('remote_local_file') def test_download_from_databricks_unity_catalog_gets_called(self, mocked_requests: Mock, remote_local_file: Any): @@ -217,7 +217,7 @@ def test_download_from_databricks_unity_catalog_gets_called(self, mocked_request mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) - @patch('streaming.base.storage.download.download_from_dbfs') + @patch('streaming.storage.download.download_from_dbfs') @pytest.mark.usefixtures('remote_local_file') def test_download_from_dbfs_gets_called(self, mocked_requests: Mock, remote_local_file: Any): mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='dbfs:/') @@ -225,7 +225,7 @@ def test_download_from_dbfs_gets_called(self, mocked_requests: Mock, remote_loca mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) - @patch('streaming.base.storage.download.download_from_local') + @patch('streaming.storage.download.download_from_local') @pytest.mark.usefixtures('remote_local_file') def test_download_from_local_gets_called(self, mocked_requests: Mock, remote_local_file: Any): mock_remote_filepath, mock_local_filepath = remote_local_file() diff --git a/tests/test_encodings.py b/tests/test_encodings.py index 88e6ba203..70d048647 100644 --- a/tests/test_encodings.py +++ b/tests/test_encodings.py @@ -10,9 +10,9 @@ import pytest from PIL import Image -import streaming.base.format.json.encodings as jsonEnc -import streaming.base.format.mds.encodings as mdsEnc -import streaming.base.format.xsv.encodings as xsvEnc +import streaming.format.json.encodings as jsonEnc +import streaming.format.mds.encodings as mdsEnc +import streaming.format.xsv.encodings as xsvEnc class TestMDSEncodings: diff --git a/tests/test_hashing.py b/tests/test_hashing.py index 225ce7458..a9558f493 100644 --- a/tests/test_hashing.py +++ b/tests/test_hashing.py @@ -6,8 +6,8 @@ import pytest -import streaming.base.hashing as shash -from streaming.base import StreamingDataset +import streaming.hashing as shash +from streaming import StreamingDataset from tests.common.utils import convert_to_mds logger = logging.getLogger(__name__) diff --git a/tests/test_local.py b/tests/test_local.py index df6fb5f05..4292f6c92 100644 --- a/tests/test_local.py +++ b/tests/test_local.py @@ -7,7 +7,7 @@ from torch.utils.data import DataLoader from streaming import MDSWriter -from streaming.base.local import LocalDataset +from streaming.local import LocalDataset def test_local_dataset(): diff --git a/tests/test_partition.py b/tests/test_partition.py index 37da79ce1..2c4af319e 100644 --- a/tests/test_partition.py +++ b/tests/test_partition.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from streaming.base.partition import get_partitions +from streaming.partition import get_partitions @pytest.mark.parametrize('partition_algo', ['orig', 'relaxed']) diff --git a/tests/test_reader.py b/tests/test_reader.py index fbe7ff723..3d43e7aad 100644 --- a/tests/test_reader.py +++ 
b/tests/test_reader.py @@ -12,7 +12,7 @@ import pytest from numpy.typing import NDArray -from streaming.base import StreamingDataset +from streaming import StreamingDataset from tests.common.utils import convert_to_mds, copy_all_files logger = logging.getLogger(__name__) diff --git a/tests/test_sampling.py b/tests/test_sampling.py index b8b661d8a..e2be7484c 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -3,7 +3,7 @@ import numpy as np -from streaming.base.sampling import get_sampling +from streaming.sampling import get_sampling def test_choose_per_shard_adds_up(): diff --git a/tests/test_shared.py b/tests/test_shared.py index c28229472..02a4c531b 100644 --- a/tests/test_shared.py +++ b/tests/test_shared.py @@ -5,8 +5,8 @@ import pytest -from streaming.base.shared import get_shm_prefix -from streaming.base.world import World +from streaming.shared import get_shm_prefix +from streaming.world import World @pytest.mark.usefixtures('local_remote_dir') diff --git a/tests/test_shuffle.py b/tests/test_shuffle.py index 76eeb7dd9..c3885047f 100644 --- a/tests/test_shuffle.py +++ b/tests/test_shuffle.py @@ -5,7 +5,7 @@ import numpy as np -from streaming.base.shuffle import (get_shuffle_py1b, get_shuffle_py1br, get_shuffle_py1e, +from streaming.shuffle import (get_shuffle_py1b, get_shuffle_py1br, get_shuffle_py1e, get_shuffle_py1s, get_shuffle_py2s) diff --git a/tests/test_spanner.py b/tests/test_spanner.py index 340facd4d..46b802d5c 100644 --- a/tests/test_spanner.py +++ b/tests/test_spanner.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from streaming.base.spanner import Spanner +from streaming.spanner import Spanner def test_spanner_success(): diff --git a/tests/test_stream.py b/tests/test_stream.py index 9a7f64af4..818c19ae8 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -11,7 +11,7 @@ from _pytest.monkeypatch import MonkeyPatch from streaming import Stream, StreamingDataset -from streaming.base.distributed import barrier +from streaming.distributed import barrier from tests.common.utils import convert_to_mds diff --git a/tests/test_streaming.py b/tests/test_streaming.py index ad55b3659..f0736b9da 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -10,8 +10,8 @@ import pytest from torch.utils.data import DataLoader -from streaming.base import Stream, StreamingDataLoader, StreamingDataset -from streaming.base.util import clean_stale_shared_memory +from streaming import Stream, StreamingDataLoader, StreamingDataset +from streaming.util import clean_stale_shared_memory from tests.common.utils import convert_to_mds diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 206dd10cd..d5ea3bb5d 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -7,7 +7,7 @@ import pytest -from streaming.base import StreamingDataset +from streaming import StreamingDataset def get_dataset(name: str, diff --git a/tests/test_upload.py b/tests/test_upload.py index 57c0046f0..fd8a83b32 100644 --- a/tests/test_upload.py +++ b/tests/test_upload.py @@ -10,7 +10,7 @@ import boto3 import pytest -from streaming.base.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, +from streaming.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, DatabricksUnityCatalogUploader, DBFSUploader, GCSAuthentication, GCSUploader, LocalUploader, S3Uploader) @@ -37,8 +37,8 @@ def _method(cloud_prefix: str = '') -> Tuple[str, str]: class TestCloudUploader: - 
@patch('streaming.base.storage.upload.S3Uploader.check_bucket_exists') - @patch('streaming.base.storage.upload.GCSUploader.check_bucket_exists') + @patch('streaming.storage.upload.S3Uploader.check_bucket_exists') + @patch('streaming.storage.upload.GCSUploader.check_bucket_exists') @pytest.mark.parametrize( 'mapping', [ @@ -111,7 +111,7 @@ def test_check_bucket_exists_exception(self, out: str): with pytest.raises(botocore.exceptions.ClientError): _ = CloudUploader.get(out=out) - @patch('streaming.base.storage.LocalUploader.list_objects') + @patch('streaming.storage.LocalUploader.list_objects') @pytest.mark.usefixtures('remote_local_dir') def test_list_objects_from_local_gets_called(self, mocked_requests: Mock, remote_local_dir: Any): @@ -123,7 +123,7 @@ def test_list_objects_from_local_gets_called(self, mocked_requests: Mock, class TestS3Uploader: - @patch('streaming.base.storage.upload.S3Uploader.check_bucket_exists') + @patch('streaming.storage.upload.S3Uploader.check_bucket_exists') @pytest.mark.parametrize('out', ['s3://bucket/dir', ('./dir1', 's3://bucket/dir/')]) def test_instantiation(self, mocked_requests: Mock, out: Any): mocked_requests.side_effect = None @@ -215,7 +215,7 @@ def test_invalid_cloud_prefix(self, remote_local_dir: Any): class TestGCSUploader: - @patch('streaming.base.storage.upload.GCSUploader.check_bucket_exists') + @patch('streaming.storage.upload.GCSUploader.check_bucket_exists') @pytest.mark.parametrize('out', ['gs://bucket/dir', ('./dir1', 'gs://bucket/dir/')]) @pytest.mark.usefixtures('gcs_hmac_credentials') def test_instantiation(self, mocked_requests: Mock, out: Any): @@ -268,7 +268,7 @@ def test_check_bucket_exists_exception(self, out: str): with pytest.raises(botocore.exceptions.ClientError): _ = GCSUploader(out=out) - @patch('streaming.base.storage.upload.GCSUploader.check_bucket_exists') + @patch('streaming.storage.upload.GCSUploader.check_bucket_exists') @pytest.mark.usefixtures('gcs_hmac_credentials') @pytest.mark.parametrize('out', ['gs://bucket/dir']) def test_hmac_authentication(self, mocked_requests: Mock, out: str): @@ -284,7 +284,7 @@ def test_service_account_authentication(self, mock_client: Mock, mock_default: M uploader = GCSUploader(out=out) assert uploader.authentication == GCSAuthentication.SERVICE_ACCOUNT - @patch('streaming.base.storage.upload.GCSUploader.check_bucket_exists') + @patch('streaming.storage.upload.GCSUploader.check_bucket_exists') @patch('google.auth.default') @patch('google.cloud.storage.Client') @pytest.mark.usefixtures('gcs_service_account_credentials', 'gcs_hmac_credentials') @@ -324,7 +324,7 @@ def test_no_credentials_error(self, remote_local_dir: Any): class TestAzureUploader: - @patch('streaming.base.storage.upload.AzureUploader.check_bucket_exists') + @patch('streaming.storage.upload.AzureUploader.check_bucket_exists') @pytest.mark.usefixtures('azure_credentials') @pytest.mark.parametrize('out', ['azure://bucket/dir', ('./dir1', 'azure://bucket/dir/')]) def test_instantiation(self, mocked_requests: Mock, out: Any): @@ -356,7 +356,7 @@ def test_local_directory_is_empty(self, local_remote_dir: Tuple[str, str]): class TestAzureDataLakeUploader: - @patch('streaming.base.storage.upload.AzureDataLakeUploader.check_container_exists') + @patch('streaming.storage.upload.AzureDataLakeUploader.check_container_exists') @pytest.mark.usefixtures('azure_credentials') @pytest.mark.parametrize('out', ['azure://container/dir', ('./dir1', 'azure://container/dir/')]) @@ -389,7 +389,7 @@ def test_local_directory_is_empty(self, 
local_remote_dir: Tuple[str, str]): class TestDatabricksUnityCatalogUploader: - @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') + @patch('streaming.storage.upload.DatabricksUploader._create_workspace_client') @pytest.mark.parametrize( 'out', ['dbfs:/Volumes/container/dir', ('./dir1', 'dbfs:/Volumes/container/dir/')]) def test_instantiation(self, mock_create_client: Mock, out: Any): @@ -398,14 +398,14 @@ def test_instantiation(self, mock_create_client: Mock, out: Any): if not isinstance(out, str): shutil.rmtree(out[0], ignore_errors=True) - @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') + @patch('streaming.storage.upload.DatabricksUploader._create_workspace_client') @pytest.mark.parametrize('out', ['ss4://bucket/dir', ('./dir1', 'gcs://bucket/dir/')]) def test_invalid_remote_list(self, mock_create_client: Mock, out: Any): mock_create_client.side_effect = None with pytest.raises(ValueError, match=f'Invalid Cloud provider prefix.*'): _ = DatabricksUnityCatalogUploader(out=out) - @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') + @patch('streaming.storage.upload.DatabricksUploader._create_workspace_client') def test_local_directory_is_empty(self, mock_create_client: Mock, local_remote_dir: Tuple[str, str]): mock_create_client.side_effect = None @@ -421,7 +421,7 @@ def test_local_directory_is_empty(self, mock_create_client: Mock, class TestDBFSUploader: - @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') + @patch('streaming.storage.upload.DatabricksUploader._create_workspace_client') @pytest.mark.parametrize('out', ['dbfs:/container/dir', ('./dir1', 'dbfs:/container/dir/')]) def test_instantiation(self, mock_create_client: Mock, out: Any): mock_create_client.side_effect = None @@ -429,14 +429,14 @@ def test_instantiation(self, mock_create_client: Mock, out: Any): if not isinstance(out, str): shutil.rmtree(out[0], ignore_errors=True) - @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') + @patch('streaming.storage.upload.DatabricksUploader._create_workspace_client') @pytest.mark.parametrize('out', ['ss4://bucket/dir', ('./dir1', 'gcs://bucket/dir/')]) def test_invalid_remote_list(self, mock_create_client: Mock, out: Any): mock_create_client.side_effect = None with pytest.raises(ValueError, match=f'Invalid Cloud provider prefix.*'): _ = DBFSUploader(out=out) - @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') + @patch('streaming.storage.upload.DatabricksUploader._create_workspace_client') def test_local_directory_is_empty(self, mock_create_client: Mock, local_remote_dir: Tuple[str, str]): with pytest.raises(FileExistsError, match=f'Directory is not empty.*'): diff --git a/tests/test_util.py b/tests/test_util.py index 65bf54529..1486836d6 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -11,11 +11,11 @@ import pytest -from streaming.base.constant import RESUME -from streaming.base.shared.prefix import _get_path -from streaming.base.storage.download import download_file -from streaming.base.storage.upload import CloudUploader -from streaming.base.util import (clean_stale_shared_memory, merge_index, normalize_bytes, +from streaming.constant import RESUME +from streaming.shared.prefix import _get_path +from streaming.storage.download import download_file +from streaming.storage.upload import CloudUploader +from streaming.util import (clean_stale_shared_memory, merge_index, 
normalize_bytes, normalize_count, retry, unpack_strs) MY_PREFIX = 'train_' + str(time.time()) @@ -189,7 +189,7 @@ def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_loc from pyspark.sql import SparkSession from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType - from streaming.base.converters import dataframeToMDS + from streaming.converters import dataframeToMDS def not_merged_index(index_file_path: str, out: str): """Check if index_file_path is the merged index at folder out.""" @@ -256,7 +256,7 @@ def test_merge_index_from_root_local(local_remote_dir: Tuple[str, str], n_partit from pyspark.sql import SparkSession from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType - from streaming.base.converters import dataframeToMDS + from streaming.converters import dataframeToMDS out, _ = local_remote_dir From 1051474a180bf047b02eec9257a23e80c78f0d22 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 08:20:41 -0700 Subject: [PATCH 21/45] Update more paths. --- streaming/format/__init__.py | 2 +- streaming/format/json/reader.py | 2 +- streaming/format/json/writer.py | 2 +- streaming/format/mds/reader.py | 2 +- streaming/format/mds/writer.py | 2 +- streaming/format/xsv/reader.py | 2 +- streaming/format/xsv/writer.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/streaming/format/__init__.py b/streaming/format/__init__.py index 5cb9e234b..f84b6cbf2 100644 --- a/streaming/format/__init__.py +++ b/streaming/format/__init__.py @@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Union -from streaming.format.base import FileInfo, Reader +from streaming.format import FileInfo, Reader from streaming.format.delta import index_delta from streaming.format.index import get_index_basename from streaming.format.json import JSONReader, JSONWriter diff --git a/streaming/format/json/reader.py b/streaming/format/json/reader.py index afcef25bd..698783d71 100644 --- a/streaming/format/json/reader.py +++ b/streaming/format/json/reader.py @@ -11,7 +11,7 @@ import numpy as np from typing_extensions import Self -from streaming.format.base.reader import FileInfo, SplitReader +from streaming.format.reader import FileInfo, SplitReader __all__ = ['JSONReader'] diff --git a/streaming/format/json/writer.py b/streaming/format/json/writer.py index 74c11c7ac..ff8a6e42f 100644 --- a/streaming/format/json/writer.py +++ b/streaming/format/json/writer.py @@ -8,7 +8,7 @@ import numpy as np -from streaming.format.base.writer import SplitWriter +from streaming.format.writer import SplitWriter from streaming.format.json.encodings import is_json_encoded, is_json_encoding __all__ = ['JSONWriter'] diff --git a/streaming/format/mds/reader.py b/streaming/format/mds/reader.py index c779433ad..847fe4368 100644 --- a/streaming/format/mds/reader.py +++ b/streaming/format/mds/reader.py @@ -10,7 +10,7 @@ import numpy as np from typing_extensions import Self -from streaming.format.base.reader import FileInfo, JointReader +from streaming.format.reader import FileInfo, JointReader from streaming.format.mds.encodings import mds_decode __all__ = ['MDSReader'] diff --git a/streaming/format/mds/writer.py b/streaming/format/mds/writer.py index 00b9e8da4..babc18408 100644 --- a/streaming/format/mds/writer.py +++ b/streaming/format/mds/writer.py @@ -8,7 +8,7 @@ import numpy as np -from streaming.format.base.writer import JointWriter +from streaming.format.writer import JointWriter from streaming.format.mds.encodings 
import (get_mds_encoded_size, get_mds_encodings, is_mds_encoding, mds_encode) diff --git a/streaming/format/xsv/reader.py b/streaming/format/xsv/reader.py index 7c0df2bc9..f43ee6f5d 100644 --- a/streaming/format/xsv/reader.py +++ b/streaming/format/xsv/reader.py @@ -10,7 +10,7 @@ import numpy as np from typing_extensions import Self -from streaming.format.base.reader import FileInfo, SplitReader +from streaming.format.reader import FileInfo, SplitReader from streaming.format.xsv.encodings import xsv_decode __all__ = ['XSVReader', 'CSVReader', 'TSVReader'] diff --git a/streaming/format/xsv/writer.py b/streaming/format/xsv/writer.py index 88d00fd54..b1ab720d3 100644 --- a/streaming/format/xsv/writer.py +++ b/streaming/format/xsv/writer.py @@ -8,7 +8,7 @@ import numpy as np -from streaming.format.base.writer import SplitWriter +from streaming.format.writer import SplitWriter from streaming.format.xsv.encodings import is_xsv_encoding, xsv_encode __all__ = ['XSVWriter', 'CSVWriter', 'TSVWriter'] From 65ef0de5aaee014b063c75451d04a82434779a67 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 08:25:43 -0700 Subject: [PATCH 22/45] Formatting. --- benchmarks/shuffling/bench.py | 2 +- regression/iterate_data.py | 3 +- streaming/converters/__init__.py | 2 +- streaming/dataset.py | 10 ++-- streaming/format/__init__.py | 3 +- streaming/format/json/writer.py | 2 +- streaming/format/mds/reader.py | 2 +- streaming/format/mds/writer.py | 4 +- streaming/storage/__init__.py | 11 ++-- streaming/storage/upload.py | 3 +- streaming/util/__init__.py | 4 +- streaming/vision.py | 100 ++----------------------------- tests/test_compression.py | 4 +- tests/test_download.py | 7 +-- tests/test_shuffle.py | 2 +- tests/test_upload.py | 5 +- tests/test_util.py | 2 +- 17 files changed, 34 insertions(+), 132 deletions(-) diff --git a/benchmarks/shuffling/bench.py b/benchmarks/shuffling/bench.py index 774906fea..ac15f641a 100644 --- a/benchmarks/shuffling/bench.py +++ b/benchmarks/shuffling/bench.py @@ -12,7 +12,7 @@ from numpy.typing import NDArray from streaming.shuffle import (get_shuffle_naive, get_shuffle_py1b, get_shuffle_py1s, - get_shuffle_py2s) + get_shuffle_py2s) def parse_args() -> Namespace: diff --git a/regression/iterate_data.py b/regression/iterate_data.py index ffa9f2c81..bdbc77a10 100644 --- a/regression/iterate_data.py +++ b/regression/iterate_data.py @@ -17,8 +17,7 @@ get_streaming_dataset_params) from streaming import StreamingDataset -from streaming.distributed import (all_gather, barrier, get_rank, get_world_size, - maybe_init_dist) +from streaming.distributed import all_gather, barrier, get_rank, get_world_size, maybe_init_dist logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) diff --git a/streaming/converters/__init__.py b/streaming/converters/__init__.py index dec288307..d12602e57 100644 --- a/streaming/converters/__init__.py +++ b/streaming/converters/__init__.py @@ -4,6 +4,6 @@ """Utility function for converting spark dataframe to MDS dataset.""" from streaming.converters.dataframe_to_mds import (MAPPING_SPARK_TO_MDS, dataframe_to_mds, - dataframeToMDS) + dataframeToMDS) __all__ = ['dataframeToMDS', 'dataframe_to_mds', 'MAPPING_SPARK_TO_MDS'] diff --git a/streaming/dataset.py b/streaming/dataset.py index c949841d9..a5934c3bf 100644 --- a/streaming/dataset.py +++ b/streaming/dataset.py @@ -24,14 +24,14 @@ from streaming.array import Array from streaming.batching import generate_work -from streaming.constant import (BARRIER, BARRIER_FILELOCK, 
CACHE_FILELOCK, CACHE_USAGE, - EPOCH_DATA, EPOCH_SHAPE, NEXT_EPOCH, RESUME, - SHARD_ACCESS_TIMES, SHARD_STATES, TICK) +from streaming.constant import (BARRIER, BARRIER_FILELOCK, CACHE_FILELOCK, CACHE_USAGE, EPOCH_DATA, + EPOCH_SHAPE, NEXT_EPOCH, RESUME, SHARD_ACCESS_TIMES, SHARD_STATES, + TICK) from streaming.distributed import maybe_init_dist from streaming.format import get_index_basename from streaming.sampling import get_sampling -from streaming.shared import (SharedArray, SharedBarrier, SharedMemory, SharedScalar, - _get_path, get_shm_prefix) +from streaming.shared import (SharedArray, SharedBarrier, SharedMemory, SharedScalar, _get_path, + get_shm_prefix) from streaming.spanner import Spanner from streaming.stream import Stream from streaming.util import normalize_bytes, normalize_count diff --git a/streaming/format/__init__.py b/streaming/format/__init__.py index f84b6cbf2..20acf5b53 100644 --- a/streaming/format/__init__.py +++ b/streaming/format/__init__.py @@ -12,8 +12,7 @@ from streaming.format.lance import index_lance from streaming.format.mds import MDSReader, MDSWriter from streaming.format.parquet import index_parquet -from streaming.format.xsv import (CSVReader, CSVWriter, TSVReader, TSVWriter, XSVReader, - XSVWriter) +from streaming.format.xsv import CSVReader, CSVWriter, TSVReader, TSVWriter, XSVReader, XSVWriter __all__ = [ 'CSVWriter', 'FileInfo', 'JSONWriter', 'MDSWriter', 'Reader', 'TSVWriter', 'XSVWriter', diff --git a/streaming/format/json/writer.py b/streaming/format/json/writer.py index ff8a6e42f..b0117a47f 100644 --- a/streaming/format/json/writer.py +++ b/streaming/format/json/writer.py @@ -8,8 +8,8 @@ import numpy as np -from streaming.format.writer import SplitWriter from streaming.format.json.encodings import is_json_encoded, is_json_encoding +from streaming.format.writer import SplitWriter __all__ = ['JSONWriter'] diff --git a/streaming/format/mds/reader.py b/streaming/format/mds/reader.py index 847fe4368..245458bf4 100644 --- a/streaming/format/mds/reader.py +++ b/streaming/format/mds/reader.py @@ -10,8 +10,8 @@ import numpy as np from typing_extensions import Self -from streaming.format.reader import FileInfo, JointReader from streaming.format.mds.encodings import mds_decode +from streaming.format.reader import FileInfo, JointReader __all__ = ['MDSReader'] diff --git a/streaming/format/mds/writer.py b/streaming/format/mds/writer.py index babc18408..950c60f20 100644 --- a/streaming/format/mds/writer.py +++ b/streaming/format/mds/writer.py @@ -8,9 +8,9 @@ import numpy as np -from streaming.format.writer import JointWriter from streaming.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, - is_mds_encoding, mds_encode) + is_mds_encoding, mds_encode) +from streaming.format.writer import JointWriter __all__ = ['MDSWriter'] diff --git a/streaming/storage/__init__.py b/streaming/storage/__init__.py index 3a312a669..674d4fbad 100644 --- a/streaming/storage/__init__.py +++ b/streaming/storage/__init__.py @@ -4,13 +4,12 @@ """Base module for downloading/uploading files from/to cloud storage.""" from streaming.storage.download import (download_file, download_from_azure, - download_from_azure_datalake, - download_from_databricks_unity_catalog, - download_from_dbfs, download_from_gcs, - download_from_local, download_from_oci, - download_from_s3, download_from_sftp) + download_from_azure_datalake, + download_from_databricks_unity_catalog, download_from_dbfs, + download_from_gcs, download_from_local, download_from_oci, + download_from_s3, 
download_from_sftp) from streaming.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, - GCSUploader, LocalUploader, OCIUploader, S3Uploader) + GCSUploader, LocalUploader, OCIUploader, S3Uploader) __all__ = [ 'download_file', diff --git a/streaming/storage/upload.py b/streaming/storage/upload.py index 66f0b32fe..c2a2f40bd 100644 --- a/streaming/storage/upload.py +++ b/streaming/storage/upload.py @@ -15,8 +15,7 @@ import tqdm -from streaming.storage.download import (BOTOCORE_CLIENT_ERROR_CODES, - GCS_ERROR_NO_AUTHENTICATION) +from streaming.storage.download import BOTOCORE_CLIENT_ERROR_CODES, GCS_ERROR_NO_AUTHENTICATION from streaming.util import get_import_exception_message, retry __all__ = [ diff --git a/streaming/util/__init__.py b/streaming/util/__init__.py index f96a2fb26..c5110183c 100644 --- a/streaming/util/__init__.py +++ b/streaming/util/__init__.py @@ -6,8 +6,8 @@ from streaming.util.importing import get_import_exception_message from streaming.util.merging import merge_index from streaming.util.pretty import (normalize_bin_bytes, normalize_bytes, normalize_count, - normalize_dec_bytes, normalize_duration, unpack_str2str, - unpack_strs) + normalize_dec_bytes, normalize_duration, unpack_str2str, + unpack_strs) from streaming.util.retrying import retry from streaming.util.shared import clean_stale_shared_memory diff --git a/streaming/vision.py b/streaming/vision.py index 2cf3300b3..d0b8ca17f 100644 --- a/streaming/vision.py +++ b/streaming/vision.py @@ -51,107 +51,15 @@ def __call__(self, x: Any, y: Any) -> Tuple[Any, Any]: class StreamingVisionDataset(StreamingDataset, VisionDataset): - """A streaming, iterable, torchvision VisionDataset. - - Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. 
- If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. - transforms (callable, optional): A function/transforms that takes in an image and a label - and returns the transformed versions of both. Defaults to ``None``. - transform (callable, optional): A function/transform that takes in an image and returns a - transformed version. Defaults to ``None``. - target_transform (callable, optional): A function/transform that takes in a target and - returns a transformed version. Defaults to ``None``. 
- """ + """A streaming, iterable, torchvision VisionDataset.""" def __init__(self, *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - cache_limit: Optional[int] = None, - partition_algo: str = 'orig', - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, transforms: Optional[Callable] = None, transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None) -> None: - StreamingDataset.__init__(self, - remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) + target_transform: Optional[Callable] = None, + **kwargs) -> None: + StreamingDataset.__init__(self, **kwargs) has_transforms = transforms is not None has_separate_transform = transform is not None or target_transform is not None diff --git a/tests/test_compression.py b/tests/test_compression.py index b1f3b9e03..ff62ecb9d 100644 --- a/tests/test_compression.py +++ b/tests/test_compression.py @@ -8,8 +8,8 @@ import pytest from streaming import StreamingDataset -from streaming.compression import (Brotli, Bzip2, Gzip, Snappy, Zstandard, compress, - decompress, get_compression_extension, is_compression) +from streaming.compression import (Brotli, Bzip2, Gzip, Snappy, Zstandard, compress, decompress, + get_compression_extension, is_compression) from tests.common.datasets import SequenceDataset, write_mds_dataset diff --git a/tests/test_download.py b/tests/test_download.py index d5f169521..26814641d 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -11,10 +11,9 @@ from botocore.exceptions import ClientError from streaming.storage.download import (download_file, download_from_azure, - download_from_azure_datalake, - download_from_databricks_unity_catalog, - download_from_dbfs, download_from_gcs, - download_from_local, download_from_s3) + download_from_azure_datalake, + download_from_databricks_unity_catalog, download_from_dbfs, + download_from_gcs, download_from_local, download_from_s3) from tests.conftest import GCS_URL, MY_BUCKET, R2_URL MY_PREFIX = 'train' diff --git a/tests/test_shuffle.py b/tests/test_shuffle.py index c3885047f..221a8793f 100644 --- a/tests/test_shuffle.py +++ b/tests/test_shuffle.py @@ -6,7 +6,7 @@ import numpy as np from streaming.shuffle import (get_shuffle_py1b, get_shuffle_py1br, get_shuffle_py1e, - get_shuffle_py1s, get_shuffle_py2s) + get_shuffle_py1s, get_shuffle_py2s) def check(get_shuffle: Callable) -> None: diff --git a/tests/test_upload.py b/tests/test_upload.py index fd8a83b32..3e4056dbb 100644 --- a/tests/test_upload.py +++ b/tests/test_upload.py @@ -11,9 +11,8 @@ import pytest from streaming.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, - DatabricksUnityCatalogUploader, DBFSUploader, - GCSAuthentication, GCSUploader, LocalUploader, - S3Uploader) + 
DatabricksUnityCatalogUploader, DBFSUploader, + GCSAuthentication, GCSUploader, LocalUploader, S3Uploader) from tests.conftest import MY_BUCKET, R2_URL MY_PREFIX = 'train' diff --git a/tests/test_util.py b/tests/test_util.py index 1486836d6..55e8f4a19 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -16,7 +16,7 @@ from streaming.storage.download import download_file from streaming.storage.upload import CloudUploader from streaming.util import (clean_stale_shared_memory, merge_index, normalize_bytes, - normalize_count, retry, unpack_strs) + normalize_count, retry, unpack_strs) MY_PREFIX = 'train_' + str(time.time()) MY_BUCKET = { From 408999a65a81a480ba7479ea60ecafe580f1753e Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 08:29:21 -0700 Subject: [PATCH 23/45] Fix. --- streaming/format/__init__.py | 2 +- streaming/vision.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/streaming/format/__init__.py b/streaming/format/__init__.py index 20acf5b53..6454c4a93 100644 --- a/streaming/format/__init__.py +++ b/streaming/format/__init__.py @@ -5,13 +5,13 @@ from typing import Any, Dict, Optional, Union -from streaming.format import FileInfo, Reader from streaming.format.delta import index_delta from streaming.format.index import get_index_basename from streaming.format.json import JSONReader, JSONWriter from streaming.format.lance import index_lance from streaming.format.mds import MDSReader, MDSWriter from streaming.format.parquet import index_parquet +from streaming.format.reader import FileInfo, Reader from streaming.format.xsv import CSVReader, CSVWriter, TSVReader, TSVWriter, XSVReader, XSVWriter __all__ = [ diff --git a/streaming/vision.py b/streaming/vision.py index d0b8ca17f..972a8e004 100644 --- a/streaming/vision.py +++ b/streaming/vision.py @@ -4,7 +4,7 @@ """Base classes for computer vision :class:`StreamingDataset`s.""" import os -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np from torch.utils.data import Dataset @@ -58,7 +58,7 @@ def __init__(self, transforms: Optional[Callable] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, - **kwargs) -> None: + **kwargs: Dict[str, Any]) -> None: StreamingDataset.__init__(self, **kwargs) has_transforms = transforms is not None From b38f8a3f9fca880404a72258f5142b5215c936ec Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 11:00:33 -0700 Subject: [PATCH 24/45] Move examples/ to top level. 
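Example datasets and their conversion scripts move out of the installed package and into a top-level examples/ directory, so example code no longer lives under streaming/examples/. The exclude list in pyproject.toml is updated to point at the new examples/text/enwiki_tok/** location.

A rough sketch of the downstream effect on imports, assuming the repository root is on PYTHONPATH (illustrative only; StreamingC4 is the example class defined in examples/text/c4/read.py):

    # Before this patch, the class sat inside the package tree:
    #     streaming/examples/text/c4/read.py
    # After this patch, it is loaded from the repository checkout instead:
    from examples.text.c4.read import StreamingC4
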
--- {streaming/examples => examples}/__init__py | 0 .../examples => examples}/multimodal/laion400m/README.md | 0 .../examples => examples}/multimodal/laion400m/__init__.py | 0 .../multimodal/laion400m/convert_and_upload.py | 0 .../multimodal/laion400m/convert_and_upload.sh | 0 .../multimodal/laion400m/download_data.sh | 0 .../multimodal/laion400m/download_meta.sh | 0 {streaming/examples => examples}/multimodal/webvid/read.py | 0 .../multimodal/webvid/webvid/bench_inside.py | 0 .../multimodal/webvid/webvid/bench_outside_dt.py | 0 .../multimodal/webvid/webvid/bench_outside_gi.py | 0 .../examples => examples}/multimodal/webvid/webvid/plot.py | 0 .../examples => examples}/multimodal/webvid/write/README.md | 0 .../examples => examples}/multimodal/webvid/write/__init__.py | 0 .../multimodal/webvid/write/crawl_webvid.py | 0 .../multimodal/webvid/write/crawl_webvid_subsets.py | 0 .../multimodal/webvid/write/extract_webvid_videos.py | 0 {streaming/examples => examples}/text/c4/README.md | 0 {streaming/examples => examples}/text/c4/read.py | 0 {streaming/examples => examples}/text/c4/write.py | 0 {streaming/examples => examples}/text/enwiki_tok/__init__.py | 0 .../examples => examples}/text/enwiki_tok/mds/README.md | 0 .../examples => examples}/text/enwiki_tok/mds/__init__.py | 0 .../text/enwiki_tok/mds/create_pretraining_data.py | 0 .../examples => examples}/text/enwiki_tok/mds/make_eval.sh | 0 .../text/enwiki_tok/mds/make_train_parallel.py | 0 .../text/enwiki_tok/mds/merge_shard_groups.py | 0 .../text/enwiki_tok/mds/pick_eval_samples.py | 0 .../examples => examples}/text/enwiki_tok/mds/tokenization.py | 0 .../examples => examples}/text/enwiki_tok/mds/vocab.txt | 0 .../text/enwiki_tok/tfrecord/__init__.py | 0 .../text/enwiki_tok/tfrecord/count_samples.py | 0 .../text/enwiki_tok/tfrecord/create_pretraining_data.py | 0 .../text/enwiki_tok/tfrecord/make_eval.sh | 0 .../text/enwiki_tok/tfrecord/make_train.sh | 0 .../text/enwiki_tok/tfrecord/make_train_parallel.py | 0 .../text/enwiki_tok/tfrecord/pick_eval_samples.py | 0 .../text/enwiki_tok/tfrecord/tokenization.py | 0 .../examples => examples}/text/enwiki_tok/tfrecord/vocab.txt | 0 {streaming/examples => examples}/text/enwiki_txt/README.md | 0 {streaming/examples => examples}/text/enwiki_txt/enwiki.py | 0 {streaming/examples => examples}/text/enwiki_txt/write.py | 0 {streaming/examples => examples}/text/pile/README.md | 0 {streaming/examples => examples}/text/pile/read.py | 0 {streaming/examples => examples}/text/pile/write.py | 0 {streaming/examples => examples}/vision/ade20k/README.md | 0 {streaming/examples => examples}/vision/ade20k/read.py | 0 {streaming/examples => examples}/vision/ade20k/write.py | 0 {streaming/examples => examples}/vision/cifar10/README.md | 0 {streaming/examples => examples}/vision/cifar10/read.py | 0 {streaming/examples => examples}/vision/cifar10/write.py | 0 {streaming/examples => examples}/vision/cifar10/write_fake.py | 0 {streaming/examples => examples}/vision/coco/README.md | 0 {streaming/examples => examples}/vision/coco/read.py | 0 {streaming/examples => examples}/vision/coco/write.py | 0 {streaming/examples => examples}/vision/imagenet/README.md | 0 {streaming/examples => examples}/vision/imagenet/read.py | 0 {streaming/examples => examples}/vision/imagenet/write.py | 0 pyproject.toml | 4 ++-- 59 files changed, 2 insertions(+), 2 deletions(-) rename {streaming/examples => examples}/__init__py (100%) rename {streaming/examples => examples}/multimodal/laion400m/README.md (100%) rename {streaming/examples => 
examples}/multimodal/laion400m/__init__.py (100%) rename {streaming/examples => examples}/multimodal/laion400m/convert_and_upload.py (100%) rename {streaming/examples => examples}/multimodal/laion400m/convert_and_upload.sh (100%) rename {streaming/examples => examples}/multimodal/laion400m/download_data.sh (100%) rename {streaming/examples => examples}/multimodal/laion400m/download_meta.sh (100%) rename {streaming/examples => examples}/multimodal/webvid/read.py (100%) rename {streaming/examples => examples}/multimodal/webvid/webvid/bench_inside.py (100%) rename {streaming/examples => examples}/multimodal/webvid/webvid/bench_outside_dt.py (100%) rename {streaming/examples => examples}/multimodal/webvid/webvid/bench_outside_gi.py (100%) rename {streaming/examples => examples}/multimodal/webvid/webvid/plot.py (100%) rename {streaming/examples => examples}/multimodal/webvid/write/README.md (100%) rename {streaming/examples => examples}/multimodal/webvid/write/__init__.py (100%) rename {streaming/examples => examples}/multimodal/webvid/write/crawl_webvid.py (100%) rename {streaming/examples => examples}/multimodal/webvid/write/crawl_webvid_subsets.py (100%) rename {streaming/examples => examples}/multimodal/webvid/write/extract_webvid_videos.py (100%) rename {streaming/examples => examples}/text/c4/README.md (100%) rename {streaming/examples => examples}/text/c4/read.py (100%) rename {streaming/examples => examples}/text/c4/write.py (100%) rename {streaming/examples => examples}/text/enwiki_tok/__init__.py (100%) rename {streaming/examples => examples}/text/enwiki_tok/mds/README.md (100%) rename {streaming/examples => examples}/text/enwiki_tok/mds/__init__.py (100%) rename {streaming/examples => examples}/text/enwiki_tok/mds/create_pretraining_data.py (100%) rename {streaming/examples => examples}/text/enwiki_tok/mds/make_eval.sh (100%) rename {streaming/examples => examples}/text/enwiki_tok/mds/make_train_parallel.py (100%) rename {streaming/examples => examples}/text/enwiki_tok/mds/merge_shard_groups.py (100%) rename {streaming/examples => examples}/text/enwiki_tok/mds/pick_eval_samples.py (100%) rename {streaming/examples => examples}/text/enwiki_tok/mds/tokenization.py (100%) rename {streaming/examples => examples}/text/enwiki_tok/mds/vocab.txt (100%) rename {streaming/examples => examples}/text/enwiki_tok/tfrecord/__init__.py (100%) rename {streaming/examples => examples}/text/enwiki_tok/tfrecord/count_samples.py (100%) rename {streaming/examples => examples}/text/enwiki_tok/tfrecord/create_pretraining_data.py (100%) rename {streaming/examples => examples}/text/enwiki_tok/tfrecord/make_eval.sh (100%) rename {streaming/examples => examples}/text/enwiki_tok/tfrecord/make_train.sh (100%) rename {streaming/examples => examples}/text/enwiki_tok/tfrecord/make_train_parallel.py (100%) rename {streaming/examples => examples}/text/enwiki_tok/tfrecord/pick_eval_samples.py (100%) rename {streaming/examples => examples}/text/enwiki_tok/tfrecord/tokenization.py (100%) rename {streaming/examples => examples}/text/enwiki_tok/tfrecord/vocab.txt (100%) rename {streaming/examples => examples}/text/enwiki_txt/README.md (100%) rename {streaming/examples => examples}/text/enwiki_txt/enwiki.py (100%) rename {streaming/examples => examples}/text/enwiki_txt/write.py (100%) rename {streaming/examples => examples}/text/pile/README.md (100%) rename {streaming/examples => examples}/text/pile/read.py (100%) rename {streaming/examples => examples}/text/pile/write.py (100%) rename {streaming/examples => 
examples}/vision/ade20k/README.md (100%) rename {streaming/examples => examples}/vision/ade20k/read.py (100%) rename {streaming/examples => examples}/vision/ade20k/write.py (100%) rename {streaming/examples => examples}/vision/cifar10/README.md (100%) rename {streaming/examples => examples}/vision/cifar10/read.py (100%) rename {streaming/examples => examples}/vision/cifar10/write.py (100%) rename {streaming/examples => examples}/vision/cifar10/write_fake.py (100%) rename {streaming/examples => examples}/vision/coco/README.md (100%) rename {streaming/examples => examples}/vision/coco/read.py (100%) rename {streaming/examples => examples}/vision/coco/write.py (100%) rename {streaming/examples => examples}/vision/imagenet/README.md (100%) rename {streaming/examples => examples}/vision/imagenet/read.py (100%) rename {streaming/examples => examples}/vision/imagenet/write.py (100%) diff --git a/streaming/examples/__init__py b/examples/__init__py similarity index 100% rename from streaming/examples/__init__py rename to examples/__init__py diff --git a/streaming/examples/multimodal/laion400m/README.md b/examples/multimodal/laion400m/README.md similarity index 100% rename from streaming/examples/multimodal/laion400m/README.md rename to examples/multimodal/laion400m/README.md diff --git a/streaming/examples/multimodal/laion400m/__init__.py b/examples/multimodal/laion400m/__init__.py similarity index 100% rename from streaming/examples/multimodal/laion400m/__init__.py rename to examples/multimodal/laion400m/__init__.py diff --git a/streaming/examples/multimodal/laion400m/convert_and_upload.py b/examples/multimodal/laion400m/convert_and_upload.py similarity index 100% rename from streaming/examples/multimodal/laion400m/convert_and_upload.py rename to examples/multimodal/laion400m/convert_and_upload.py diff --git a/streaming/examples/multimodal/laion400m/convert_and_upload.sh b/examples/multimodal/laion400m/convert_and_upload.sh similarity index 100% rename from streaming/examples/multimodal/laion400m/convert_and_upload.sh rename to examples/multimodal/laion400m/convert_and_upload.sh diff --git a/streaming/examples/multimodal/laion400m/download_data.sh b/examples/multimodal/laion400m/download_data.sh similarity index 100% rename from streaming/examples/multimodal/laion400m/download_data.sh rename to examples/multimodal/laion400m/download_data.sh diff --git a/streaming/examples/multimodal/laion400m/download_meta.sh b/examples/multimodal/laion400m/download_meta.sh similarity index 100% rename from streaming/examples/multimodal/laion400m/download_meta.sh rename to examples/multimodal/laion400m/download_meta.sh diff --git a/streaming/examples/multimodal/webvid/read.py b/examples/multimodal/webvid/read.py similarity index 100% rename from streaming/examples/multimodal/webvid/read.py rename to examples/multimodal/webvid/read.py diff --git a/streaming/examples/multimodal/webvid/webvid/bench_inside.py b/examples/multimodal/webvid/webvid/bench_inside.py similarity index 100% rename from streaming/examples/multimodal/webvid/webvid/bench_inside.py rename to examples/multimodal/webvid/webvid/bench_inside.py diff --git a/streaming/examples/multimodal/webvid/webvid/bench_outside_dt.py b/examples/multimodal/webvid/webvid/bench_outside_dt.py similarity index 100% rename from streaming/examples/multimodal/webvid/webvid/bench_outside_dt.py rename to examples/multimodal/webvid/webvid/bench_outside_dt.py diff --git a/streaming/examples/multimodal/webvid/webvid/bench_outside_gi.py 
b/examples/multimodal/webvid/webvid/bench_outside_gi.py similarity index 100% rename from streaming/examples/multimodal/webvid/webvid/bench_outside_gi.py rename to examples/multimodal/webvid/webvid/bench_outside_gi.py diff --git a/streaming/examples/multimodal/webvid/webvid/plot.py b/examples/multimodal/webvid/webvid/plot.py similarity index 100% rename from streaming/examples/multimodal/webvid/webvid/plot.py rename to examples/multimodal/webvid/webvid/plot.py diff --git a/streaming/examples/multimodal/webvid/write/README.md b/examples/multimodal/webvid/write/README.md similarity index 100% rename from streaming/examples/multimodal/webvid/write/README.md rename to examples/multimodal/webvid/write/README.md diff --git a/streaming/examples/multimodal/webvid/write/__init__.py b/examples/multimodal/webvid/write/__init__.py similarity index 100% rename from streaming/examples/multimodal/webvid/write/__init__.py rename to examples/multimodal/webvid/write/__init__.py diff --git a/streaming/examples/multimodal/webvid/write/crawl_webvid.py b/examples/multimodal/webvid/write/crawl_webvid.py similarity index 100% rename from streaming/examples/multimodal/webvid/write/crawl_webvid.py rename to examples/multimodal/webvid/write/crawl_webvid.py diff --git a/streaming/examples/multimodal/webvid/write/crawl_webvid_subsets.py b/examples/multimodal/webvid/write/crawl_webvid_subsets.py similarity index 100% rename from streaming/examples/multimodal/webvid/write/crawl_webvid_subsets.py rename to examples/multimodal/webvid/write/crawl_webvid_subsets.py diff --git a/streaming/examples/multimodal/webvid/write/extract_webvid_videos.py b/examples/multimodal/webvid/write/extract_webvid_videos.py similarity index 100% rename from streaming/examples/multimodal/webvid/write/extract_webvid_videos.py rename to examples/multimodal/webvid/write/extract_webvid_videos.py diff --git a/streaming/examples/text/c4/README.md b/examples/text/c4/README.md similarity index 100% rename from streaming/examples/text/c4/README.md rename to examples/text/c4/README.md diff --git a/streaming/examples/text/c4/read.py b/examples/text/c4/read.py similarity index 100% rename from streaming/examples/text/c4/read.py rename to examples/text/c4/read.py diff --git a/streaming/examples/text/c4/write.py b/examples/text/c4/write.py similarity index 100% rename from streaming/examples/text/c4/write.py rename to examples/text/c4/write.py diff --git a/streaming/examples/text/enwiki_tok/__init__.py b/examples/text/enwiki_tok/__init__.py similarity index 100% rename from streaming/examples/text/enwiki_tok/__init__.py rename to examples/text/enwiki_tok/__init__.py diff --git a/streaming/examples/text/enwiki_tok/mds/README.md b/examples/text/enwiki_tok/mds/README.md similarity index 100% rename from streaming/examples/text/enwiki_tok/mds/README.md rename to examples/text/enwiki_tok/mds/README.md diff --git a/streaming/examples/text/enwiki_tok/mds/__init__.py b/examples/text/enwiki_tok/mds/__init__.py similarity index 100% rename from streaming/examples/text/enwiki_tok/mds/__init__.py rename to examples/text/enwiki_tok/mds/__init__.py diff --git a/streaming/examples/text/enwiki_tok/mds/create_pretraining_data.py b/examples/text/enwiki_tok/mds/create_pretraining_data.py similarity index 100% rename from streaming/examples/text/enwiki_tok/mds/create_pretraining_data.py rename to examples/text/enwiki_tok/mds/create_pretraining_data.py diff --git a/streaming/examples/text/enwiki_tok/mds/make_eval.sh b/examples/text/enwiki_tok/mds/make_eval.sh similarity index 
100% rename from streaming/examples/text/enwiki_tok/mds/make_eval.sh rename to examples/text/enwiki_tok/mds/make_eval.sh diff --git a/streaming/examples/text/enwiki_tok/mds/make_train_parallel.py b/examples/text/enwiki_tok/mds/make_train_parallel.py similarity index 100% rename from streaming/examples/text/enwiki_tok/mds/make_train_parallel.py rename to examples/text/enwiki_tok/mds/make_train_parallel.py diff --git a/streaming/examples/text/enwiki_tok/mds/merge_shard_groups.py b/examples/text/enwiki_tok/mds/merge_shard_groups.py similarity index 100% rename from streaming/examples/text/enwiki_tok/mds/merge_shard_groups.py rename to examples/text/enwiki_tok/mds/merge_shard_groups.py diff --git a/streaming/examples/text/enwiki_tok/mds/pick_eval_samples.py b/examples/text/enwiki_tok/mds/pick_eval_samples.py similarity index 100% rename from streaming/examples/text/enwiki_tok/mds/pick_eval_samples.py rename to examples/text/enwiki_tok/mds/pick_eval_samples.py diff --git a/streaming/examples/text/enwiki_tok/mds/tokenization.py b/examples/text/enwiki_tok/mds/tokenization.py similarity index 100% rename from streaming/examples/text/enwiki_tok/mds/tokenization.py rename to examples/text/enwiki_tok/mds/tokenization.py diff --git a/streaming/examples/text/enwiki_tok/mds/vocab.txt b/examples/text/enwiki_tok/mds/vocab.txt similarity index 100% rename from streaming/examples/text/enwiki_tok/mds/vocab.txt rename to examples/text/enwiki_tok/mds/vocab.txt diff --git a/streaming/examples/text/enwiki_tok/tfrecord/__init__.py b/examples/text/enwiki_tok/tfrecord/__init__.py similarity index 100% rename from streaming/examples/text/enwiki_tok/tfrecord/__init__.py rename to examples/text/enwiki_tok/tfrecord/__init__.py diff --git a/streaming/examples/text/enwiki_tok/tfrecord/count_samples.py b/examples/text/enwiki_tok/tfrecord/count_samples.py similarity index 100% rename from streaming/examples/text/enwiki_tok/tfrecord/count_samples.py rename to examples/text/enwiki_tok/tfrecord/count_samples.py diff --git a/streaming/examples/text/enwiki_tok/tfrecord/create_pretraining_data.py b/examples/text/enwiki_tok/tfrecord/create_pretraining_data.py similarity index 100% rename from streaming/examples/text/enwiki_tok/tfrecord/create_pretraining_data.py rename to examples/text/enwiki_tok/tfrecord/create_pretraining_data.py diff --git a/streaming/examples/text/enwiki_tok/tfrecord/make_eval.sh b/examples/text/enwiki_tok/tfrecord/make_eval.sh similarity index 100% rename from streaming/examples/text/enwiki_tok/tfrecord/make_eval.sh rename to examples/text/enwiki_tok/tfrecord/make_eval.sh diff --git a/streaming/examples/text/enwiki_tok/tfrecord/make_train.sh b/examples/text/enwiki_tok/tfrecord/make_train.sh similarity index 100% rename from streaming/examples/text/enwiki_tok/tfrecord/make_train.sh rename to examples/text/enwiki_tok/tfrecord/make_train.sh diff --git a/streaming/examples/text/enwiki_tok/tfrecord/make_train_parallel.py b/examples/text/enwiki_tok/tfrecord/make_train_parallel.py similarity index 100% rename from streaming/examples/text/enwiki_tok/tfrecord/make_train_parallel.py rename to examples/text/enwiki_tok/tfrecord/make_train_parallel.py diff --git a/streaming/examples/text/enwiki_tok/tfrecord/pick_eval_samples.py b/examples/text/enwiki_tok/tfrecord/pick_eval_samples.py similarity index 100% rename from streaming/examples/text/enwiki_tok/tfrecord/pick_eval_samples.py rename to examples/text/enwiki_tok/tfrecord/pick_eval_samples.py diff --git a/streaming/examples/text/enwiki_tok/tfrecord/tokenization.py 
b/examples/text/enwiki_tok/tfrecord/tokenization.py similarity index 100% rename from streaming/examples/text/enwiki_tok/tfrecord/tokenization.py rename to examples/text/enwiki_tok/tfrecord/tokenization.py diff --git a/streaming/examples/text/enwiki_tok/tfrecord/vocab.txt b/examples/text/enwiki_tok/tfrecord/vocab.txt similarity index 100% rename from streaming/examples/text/enwiki_tok/tfrecord/vocab.txt rename to examples/text/enwiki_tok/tfrecord/vocab.txt diff --git a/streaming/examples/text/enwiki_txt/README.md b/examples/text/enwiki_txt/README.md similarity index 100% rename from streaming/examples/text/enwiki_txt/README.md rename to examples/text/enwiki_txt/README.md diff --git a/streaming/examples/text/enwiki_txt/enwiki.py b/examples/text/enwiki_txt/enwiki.py similarity index 100% rename from streaming/examples/text/enwiki_txt/enwiki.py rename to examples/text/enwiki_txt/enwiki.py diff --git a/streaming/examples/text/enwiki_txt/write.py b/examples/text/enwiki_txt/write.py similarity index 100% rename from streaming/examples/text/enwiki_txt/write.py rename to examples/text/enwiki_txt/write.py diff --git a/streaming/examples/text/pile/README.md b/examples/text/pile/README.md similarity index 100% rename from streaming/examples/text/pile/README.md rename to examples/text/pile/README.md diff --git a/streaming/examples/text/pile/read.py b/examples/text/pile/read.py similarity index 100% rename from streaming/examples/text/pile/read.py rename to examples/text/pile/read.py diff --git a/streaming/examples/text/pile/write.py b/examples/text/pile/write.py similarity index 100% rename from streaming/examples/text/pile/write.py rename to examples/text/pile/write.py diff --git a/streaming/examples/vision/ade20k/README.md b/examples/vision/ade20k/README.md similarity index 100% rename from streaming/examples/vision/ade20k/README.md rename to examples/vision/ade20k/README.md diff --git a/streaming/examples/vision/ade20k/read.py b/examples/vision/ade20k/read.py similarity index 100% rename from streaming/examples/vision/ade20k/read.py rename to examples/vision/ade20k/read.py diff --git a/streaming/examples/vision/ade20k/write.py b/examples/vision/ade20k/write.py similarity index 100% rename from streaming/examples/vision/ade20k/write.py rename to examples/vision/ade20k/write.py diff --git a/streaming/examples/vision/cifar10/README.md b/examples/vision/cifar10/README.md similarity index 100% rename from streaming/examples/vision/cifar10/README.md rename to examples/vision/cifar10/README.md diff --git a/streaming/examples/vision/cifar10/read.py b/examples/vision/cifar10/read.py similarity index 100% rename from streaming/examples/vision/cifar10/read.py rename to examples/vision/cifar10/read.py diff --git a/streaming/examples/vision/cifar10/write.py b/examples/vision/cifar10/write.py similarity index 100% rename from streaming/examples/vision/cifar10/write.py rename to examples/vision/cifar10/write.py diff --git a/streaming/examples/vision/cifar10/write_fake.py b/examples/vision/cifar10/write_fake.py similarity index 100% rename from streaming/examples/vision/cifar10/write_fake.py rename to examples/vision/cifar10/write_fake.py diff --git a/streaming/examples/vision/coco/README.md b/examples/vision/coco/README.md similarity index 100% rename from streaming/examples/vision/coco/README.md rename to examples/vision/coco/README.md diff --git a/streaming/examples/vision/coco/read.py b/examples/vision/coco/read.py similarity index 100% rename from streaming/examples/vision/coco/read.py rename to 
examples/vision/coco/read.py diff --git a/streaming/examples/vision/coco/write.py b/examples/vision/coco/write.py similarity index 100% rename from streaming/examples/vision/coco/write.py rename to examples/vision/coco/write.py diff --git a/streaming/examples/vision/imagenet/README.md b/examples/vision/imagenet/README.md similarity index 100% rename from streaming/examples/vision/imagenet/README.md rename to examples/vision/imagenet/README.md diff --git a/streaming/examples/vision/imagenet/read.py b/examples/vision/imagenet/read.py similarity index 100% rename from streaming/examples/vision/imagenet/read.py rename to examples/vision/imagenet/read.py diff --git a/streaming/examples/vision/imagenet/write.py b/examples/vision/imagenet/write.py similarity index 100% rename from streaming/examples/vision/imagenet/write.py rename to examples/vision/imagenet/write.py diff --git a/pyproject.toml b/pyproject.toml index a32957f39..c99c91a84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,8 +18,8 @@ include = [ exclude = [ "build/**", "node_modules/**", - "streaming/examples/text/enwiki_tok/**", - "docs/source/conf.py" + "docs/source/conf.py", + "examples/text/enwiki_tok/**", ] # Disable checks for missing imports, as a conditional install of streaming will not include them From ff90826edda5d9372062a482efcb9de458804f3c Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 14:51:30 -0700 Subject: [PATCH 25/45] Update multimodal. --- examples/multimodal/webvid/read.py | 241 ++--------------------------- 1 file changed, 12 insertions(+), 229 deletions(-) diff --git a/examples/multimodal/webvid/read.py b/examples/multimodal/webvid/read.py index d3f74c2d9..6742d430d 100644 --- a/examples/multimodal/webvid/read.py +++ b/examples/multimodal/webvid/read.py @@ -5,7 +5,7 @@ import os from time import sleep -from typing import Any, Optional +from typing import Any, Dict, Optional from streaming import StreamingDataset from streaming.dataset import TICK, _Iterator @@ -16,58 +16,6 @@ class StreamingInsideWebVid(StreamingDataset): """Streaming WebVid dataset. Videos are stored "inside" the shards, as is typically done. - - Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. 
- Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. """ def get_item(self, idx: int) -> Any: @@ -91,101 +39,19 @@ class StreamingOutsideGIWebVid(StreamingDataset): get_item ("GI"), when samples are requested by the dataloader. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. 
- epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. extra_local (str, optional): Base destination of extra local sample downloads. extra_remote (str, optional): Base source of extra remote sample downloads. + **kwargs (Dict[str, Any]): Keyword arguments. 
""" def __init__(self, *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - cache_limit: Optional[int] = None, - partition_algo: str = 'orig', - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, extra_local: Optional[str] = None, - extra_remote: Optional[str] = None) -> None: - super().__init__(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) + extra_remote: Optional[str] = None, + **kwargs: Dict[str, Any]) -> None: + super().__init__(**kwargs) # Videos are stored outside of their shards here. - self.download_timeout = download_timeout self.extra_local = extra_local self.extra_remote = extra_remote @@ -205,7 +71,7 @@ def get_item(self, idx: int) -> Any: local = os.path.join(self.extra_local, rel_path) remote = os.path.join(self.extra_remote, rel_path) if not os.path.exists(local): - download_file(remote, local, self.download_timeout) + download_file(remote, local, self.streams[0].download_timeout) with open(local, 'rb') as fp: content = fp.read() obj['content'] = content @@ -222,101 +88,18 @@ class StreamingOutsideDTWebVid(StreamingDataset): _download_thread ("DT"), when the download thread prefetches the sample. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. 
Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. extra_local (str, optional): Base destination of extra local sample downloads. extra_remote (str, optional): Base source of extra remote sample downloads. + **kwargs (Dict[str, Any]): Keyword arguments. """ def __init__(self, - *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - cache_limit: Optional[int] = None, - partition_algo: str = 'orig', - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, extra_local: Optional[str] = None, - extra_remote: Optional[str] = None) -> None: - super().__init__(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) + extra_remote: Optional[str] = None, + **kwargs: Dict[str, Any]) -> None: + super().__init__(**kwargs) # Videos are stored outside of their shards here. 
- self.download_timeout = download_timeout self.extra_local = extra_local self.extra_remote = extra_remote @@ -336,7 +119,7 @@ def get_item(self, idx: int) -> Any: local = os.path.join(self.extra_local, rel_path) remote = os.path.join(self.extra_remote, rel_path) if not os.path.exists(local): - download_file(remote, local, self.download_timeout) + download_file(remote, local, self.streams[0].download_timeout) with open(local, 'rb') as fp: content = fp.read() obj['content'] = content @@ -394,7 +177,7 @@ def _download_thread(self, it: _Iterator) -> None: local = os.path.join(self.extra_local, rel_path) remote = os.path.join(self.extra_remote, rel_path) if not os.path.exists(local): - download_file(remote, local, self.download_timeout) + download_file(remote, local, self.streams[0].download_timeout) # Step forward one sample. it.prepare_index += 1 From a7808aedd626490962a8a8197a789e5bdafc4cfe Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 15:03:38 -0700 Subject: [PATCH 26/45] Update vision dataset sexamples -> kwargs. --- examples/text/c4/read.py | 95 +------------------ examples/text/enwiki_txt/enwiki.py | 144 ----------------------------- examples/text/enwiki_txt/read.py | 62 +++++++++++++ examples/text/pile/read.py | 95 +------------------ 4 files changed, 72 insertions(+), 324 deletions(-) delete mode 100644 examples/text/enwiki_txt/enwiki.py create mode 100644 examples/text/enwiki_txt/read.py diff --git a/examples/text/c4/read.py b/examples/text/c4/read.py index d30340f97..1cb407307 100644 --- a/examples/text/c4/read.py +++ b/examples/text/c4/read.py @@ -7,7 +7,7 @@ the `Common Crawl `_ dataset. """ -from typing import Any, Dict, Optional +from typing import Any, Dict from transformers.models.auto.tokenization_auto import AutoTokenizer @@ -20,104 +20,19 @@ class StreamingC4(StreamingDataset): """Implementation of the C4 (Colossal Cleaned Common Crawl) dataset using StreamingDataset. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. 
Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. tokenizer_name (str): The name of the HuggingFace tokenizer to use to tokenize samples. max_seq_len (int): The max sequence length of each token sample. group_method (str): How to group text samples into token samples. Currently only supporting ``'truncate'``. + **kwargs (Dict[str, Any]): Keyword arguments. 
""" - def __init__(self, - *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - cache_limit: Optional[int] = None, - partition_algo: str = 'orig', - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, - tokenizer_name: str, - max_seq_len: int, - group_method: str) -> None: + def __init__(self, *, tokenizer_name: str, max_seq_len: int, group_method: str, + **kwargs: Dict[str, Any]) -> None: if group_method not in {'truncate'}: raise ValueError(f"group_method='{group_method}' must be one of {'truncate'}.") - super().__init__(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) + super().__init__(**kwargs) self.tokenizer_name = tokenizer_name self.max_seq_len = max_seq_len diff --git a/examples/text/enwiki_txt/enwiki.py b/examples/text/enwiki_txt/enwiki.py deleted file mode 100644 index 4385e7394..000000000 --- a/examples/text/enwiki_txt/enwiki.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""English Wikipedia 2020-01-01 streaming dataset.""" - -from typing import Any, Optional - -import numpy as np - -from streaming import StreamingDataset - -__all__ = ['StreamingEnWiki'] - - -class StreamingEnWiki(StreamingDataset): - """Implementation of the English Wikipedia 2020-01-01 streaming dataset. - - Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. 
Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. 
- """ - - def __init__(self, - *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - cache_limit: Optional[int] = None, - partition_algo: str = 'orig', - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18) -> None: - super().__init__(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) - self.field_dtypes = { - 'input_ids': np.int32, - 'input_mask': np.int32, - 'attention_mask': np.int32, - 'segment_ids': np.int32, - 'token_type_ids': np.int32, - 'masked_lm_positions': np.int32, - 'masked_lm_ids': np.int32, - 'masked_lm_weights': np.float32, - 'next_sentence_labels': np.int32, - 'labels': np.int32, - } - - def get_item(self, idx: int) -> Any: - """Get sample by global index, blocking to load its shard if missing. - - Args: - idx (int): Sample index. - - Returns: - Any: Sample data. - """ - obj = super().get_item(idx) - - for key, value in obj.items(): - dtype = self.field_dtypes[key] - obj[key] = np.frombuffer(value, dtype) - - input_len = len(obj['input_ids']) - labels = np.full((input_len,), -100) - labels[obj['masked_lm_positions']] = obj['masked_lm_ids'] - - return { - 'input_ids': obj['input_ids'].copy(), - 'token_type_ids': obj['segment_ids'].copy(), - 'attention_mask': obj['input_mask'].copy(), - 'labels': labels, - } diff --git a/examples/text/enwiki_txt/read.py b/examples/text/enwiki_txt/read.py new file mode 100644 index 000000000..192516b52 --- /dev/null +++ b/examples/text/enwiki_txt/read.py @@ -0,0 +1,62 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""English Wikipedia 2020-01-01 streaming dataset.""" + +from typing import Any, Dict + +import numpy as np + +from streaming import StreamingDataset + +__all__ = ['StreamingEnWiki'] + + +class StreamingEnWiki(StreamingDataset): + """Implementation of the English Wikipedia 2020-01-01 streaming dataset. + + Args: + **kwargs (Dict[str, Any]): Keyword arguments. + """ + + def __init__(self, *args: Any, **kwargs: Dict[str, Any]) -> None: + super().__init__(**kwargs) + + self.field_dtypes = { + 'input_ids': np.int32, + 'input_mask': np.int32, + 'attention_mask': np.int32, + 'segment_ids': np.int32, + 'token_type_ids': np.int32, + 'masked_lm_positions': np.int32, + 'masked_lm_ids': np.int32, + 'masked_lm_weights': np.float32, + 'next_sentence_labels': np.int32, + 'labels': np.int32, + } + + def get_item(self, idx: int) -> Any: + """Get sample by global index, blocking to load its shard if missing. + + Args: + idx (int): Sample index. + + Returns: + Any: Sample data. 
+ """ + obj = super().get_item(idx) + + for key, value in obj.items(): + dtype = self.field_dtypes[key] + obj[key] = np.frombuffer(value, dtype) + + input_len = len(obj['input_ids']) + labels = np.full((input_len,), -100) + labels[obj['masked_lm_positions']] = obj['masked_lm_ids'] + + return { + 'input_ids': obj['input_ids'].copy(), + 'token_type_ids': obj['segment_ids'].copy(), + 'attention_mask': obj['input_mask'].copy(), + 'labels': labels, + } diff --git a/examples/text/pile/read.py b/examples/text/pile/read.py index 58c4afc68..4b70a960c 100644 --- a/examples/text/pile/read.py +++ b/examples/text/pile/read.py @@ -7,7 +7,7 @@ high-quality datasets combined together. """ -from typing import Any, Dict, Optional +from typing import Any, Dict from transformers.models.auto.tokenization_auto import AutoTokenizer @@ -20,104 +20,19 @@ class StreamingPile(StreamingDataset): """Implementation of the the Pile using StreamingDataset. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. 
The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. tokenizer_name (str): The name of the HuggingFace tokenizer to use to tokenize samples. max_seq_len (int): The max sequence length of each token sample. group_method (str): How to group text samples into token samples. Currently only supporting ``'truncate'``. + **kwargs (Dict[str, Any]): Keyword arguments. """ - def __init__(self, - *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - cache_limit: Optional[int] = None, - partition_algo: str = 'orig', - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, - tokenizer_name: str, - max_seq_len: int, - group_method: str) -> None: + def __init__(self, *, tokenizer_name: str, max_seq_len: int, group_method: str, + **kwargs: Dict[str, Any]) -> None: if group_method not in ['truncate']: raise ValueError(f'Only group_method="truncate" is supported at this time.') - super().__init__(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) + super().__init__(**kwargs) self.tokenizer_name = tokenizer_name self.max_seq_len = max_seq_len From c857ed65e4d341cb09d394fbfe5444c8207a762e Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 15:08:19 -0700 Subject: [PATCH 27/45] Update vision datasets to use kwargs (save us from bitrot, o kwargs). --- examples/vision/ade20k/read.py | 91 ++----------------------------- examples/vision/cifar10/read.py | 56 +------------------ examples/vision/coco/read.py | 92 ++------------------------------ examples/vision/imagenet/read.py | 56 +------------------ 4 files changed, 11 insertions(+), 284 deletions(-) diff --git a/examples/vision/ade20k/read.py b/examples/vision/ade20k/read.py index f04fc423f..47ac475b0 100644 --- a/examples/vision/ade20k/read.py +++ b/examples/vision/ade20k/read.py @@ -7,7 +7,7 @@ more details about this dataset. 
""" -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple from streaming import StreamingDataset @@ -18,103 +18,22 @@ class StreamingADE20K(StreamingDataset): """Implementation of the ADE20K dataset using StreamingDataset. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. 
- shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. joint_transform (callable, optional): A function/transforms that takes in an image and a target and returns the transformed versions of both. Defaults to ``None``. transform (callable, optional): A function/transform that takes in an image and returns a transformed version. Defaults to ``None``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. Defaults to ``None``. + **kwargs (Dict[str, Any]): Keyword arguments. """ def __init__(self, *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - partition_algo: str = 'orig', - cache_limit: Optional[int] = None, - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, joint_transform: Optional[Callable] = None, transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None) -> None: - super().__init__(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) + target_transform: Optional[Callable] = None, + **kwargs: Dict[str, Any]) -> None: + super().__init__(**kwargs) self.joint_transform = joint_transform self.transform = transform self.target_transform = target_transform diff --git a/examples/vision/cifar10/read.py b/examples/vision/cifar10/read.py index 75c2c36ac..95b356f0c 100644 --- a/examples/vision/cifar10/read.py +++ b/examples/vision/cifar10/read.py @@ -15,59 +15,5 @@ class StreamingCIFAR10(StreamingVisionDataset): """Implementation of the CIFAR-10 dataset using StreamingDataset. - Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. 
If ``False``, keep iff remote is local or no remote. Defaults to
-            ``False``.
-        epoch_size (int, optional): Number of samples to draw per epoch balanced across all
-            streams. If ``None``, takes its value from the total number of underlying samples.
-            Provide this field if you are weighting streams relatively to target a larger or
-            smaller epoch size. Defaults to ``None``.
-        predownload (int, optional): Target number of samples to download per worker in advance
-            of current sample. Workers will attempt to download ahead by this many samples during,
-            but not before, training. Recommendation is to provide a value greater than per device
-            batch size to ensure at-least per device batch size number of samples cached locally.
-            If ``None``, its value gets derived using per device batch size and number of
-            canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``.
-            Defaults to ``None``.
-        cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache.
-            Before downloading a shard, the least recently used resident shard(s) may be evicted
-            (deleted from the local cache) in order to stay under the limit. Set to ``None`` to
-            disable shard eviction. Defaults to ``None``.
-        partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``.
-        num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
-            resumption. The sample space is divided evenly according to the number of canonical
-            nodes. The higher the value, the more independent non-overlapping paths the
-            StreamingDataset replicas take through the shards per model replica (increasing data
-            source diversity). Defaults to ``None``, which is interpreted as 64 times the number
-            of nodes of the initial run.
-
-            .. note::
-
-                For sequential sample ordering, set ``shuffle`` to ``False`` and
-                ``num_canonical_nodes`` to the number of physical nodes of the initial run.
-        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
-            partitioned over the workers. Defaults to ``None``.
-        shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
-            ``False``.
-        shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``.
-        shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``.
-        shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``.
-        transform (callable, optional): A function/transform that takes in an image and returns a
-            transformed version. Defaults to ``None``.
-        target_transform (callable, optional): A function/transform that takes in a target and
-            returns a transformed version. Defaults to ``None``.
+    No custom work is needed.
     """
diff --git a/examples/vision/coco/read.py b/examples/vision/coco/read.py
index a9622eab3..c52ed7284 100644
--- a/examples/vision/coco/read.py
+++ b/examples/vision/coco/read.py
@@ -7,7 +7,7 @@
 `COCO dataset `_ for more details.
 """
 
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Dict, Optional
 
 from streaming import StreamingDataset
 
@@ -18,97 +18,13 @@ class StreamingCOCO(StreamingDataset):
     """Implementation of the COCO dataset using StreamingDataset.
 
     Args:
-        remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
-            its data must exist locally. StreamingDataset uses either ``streams`` or
-            ``remote``/``local``. Defaults to ``None``.
-        local (str, optional): Local working directory to download shards to.
This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. transform (callable, optional): A function/transform that takes in an image and bboxes and returns a transformed version. Defaults to ``None``. + **kwargs (Dict[str, Any]): Keyword arguments. 
""" - def __init__(self, - *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - partition_algo: str = 'orig', - cache_limit: Optional[int] = None, - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, - transform: Optional[Callable] = None) -> None: - super().__init__(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) + def __init__(self, *, transform: Optional[Callable] = None, **kwargs: Dict[str, Any]) -> None: + super().__init__(**kwargs) self.transform = transform def get_item(self, idx: int) -> Any: diff --git a/examples/vision/imagenet/read.py b/examples/vision/imagenet/read.py index 8ed47af54..53d796073 100644 --- a/examples/vision/imagenet/read.py +++ b/examples/vision/imagenet/read.py @@ -15,59 +15,5 @@ class StreamingImageNet(StreamingVisionDataset): """Implementation of the ImageNet dataset using StreamingDataset. - Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. 
- If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. - transform (callable, optional): A function/transform that takes in an image and returns a - transformed version. Defaults to ``None``. - target_transform (callable, optional): A function/transform that takes in a target and - returns a transformed version. Defaults to ``None``. + No custom work is needed. """ From 89d5719e892a10556734659b069c3a90bc113684 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 15:12:18 -0700 Subject: [PATCH 28/45] Generalize `keep_zip` argument to `keep_packed`. 
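
This renames the user-facing `keep_zip` argument to the more general `keep_packed`,
which also covers packed forms other than compression (e.g. Parquet shards converted
to MDS). A minimal usage sketch of the renamed argument (illustrative only; the remote
and local paths below are hypothetical):

    from streaming import StreamingDataset

    # Keep the packed (e.g. compressed) shard files resident locally even after
    # they have been unpacked into their raw form.
    dataset = StreamingDataset(remote='s3://my-bucket/my-dataset',
                               local='/tmp/my-dataset',
                               keep_packed=True)
    for sample in dataset:
        pass
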
--- regression/utils.py | 4 ++-- streaming/dataset.py | 17 ++++++++++------- streaming/format/reader.py | 19 ++++++++++--------- streaming/stream.py | 35 +++++++++++++++++++---------------- tests/test_eviction.py | 30 +++++++++++++++--------------- 5 files changed, 56 insertions(+), 49 deletions(-) diff --git a/regression/utils.py b/regression/utils.py index e66d5db8e..aeca53935 100644 --- a/regression/utils.py +++ b/regression/utils.py @@ -76,8 +76,8 @@ def get_streaming_dataset_params(kwargs: dict[str, str]) -> dict[str, Any]: dataset_params['download_timeout'] = float(kwargs['download_timeout']) if 'validate_hash' in kwargs: dataset_params['validate_hash'] = kwargs['validate_hash'] - if 'keep_zip' in kwargs: - dataset_params['keep_zip'] = kwargs['keep_zip'].lower().capitalize() == 'True' + if 'keep_packed' in kwargs: + dataset_params['keep_packed'] = kwargs['keep_packed'].lower().capitalize() == 'True' if 'epoch_size' in kwargs: dataset_params['epoch_size'] = kwargs['epoch_size'] if 'predownload' in kwargs: diff --git a/streaming/dataset.py b/streaming/dataset.py index a5934c3bf..b17ce99f7 100644 --- a/streaming/dataset.py +++ b/streaming/dataset.py @@ -3,6 +3,9 @@ """A mid-epoch-resumable streaming/caching pytorch IterableDataset.""" +# TODO: we have generalized `keep_zip` to `keep_packed` in StreamingDataaset arguments, but must +# still accept `keep_zip`. + import json import logging import os @@ -200,7 +203,7 @@ class StreamingDataset(Array, IterableDataset): * ``download_retry`` * ``download_timeout`` * ``validate_hash`` - * ``keep_zip`` + * ``keep_packed`` * Absolute dataset size, if streams were weighted relatively: @@ -253,9 +256,9 @@ class StreamingDataset(Array, IterableDataset): an exception. Defaults to ``60``. validate_hash (str, optional): Optional hash or checksum algorithm to use to validate shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. + keep_packed (bool): Whether to keep or drop the packed form of shards after unpacking, e.g. + compressed shards after decompression or Parquet shards after conversion to MDS. If + ``False``, keep iff remote is local or no remote. Defaults to ``False``. epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all streams. If ``None``, takes its value from the total number of underlying samples. 
Provide this field if you are weighting streams relatively to target a larger @@ -314,7 +317,7 @@ def __init__(self, download_retry: int = 2, download_timeout: float = 60, validate_hash: Optional[str] = None, - keep_zip: bool = False, + keep_packed: bool = False, epoch_size: Optional[Union[int, str]] = None, predownload: Optional[int] = None, cache_limit: Optional[Union[int, str]] = None, @@ -410,7 +413,7 @@ def __init__(self, 'download_retry': download_retry, 'download_timeout': download_timeout, 'validate_hash': validate_hash, - 'keep_zip': keep_zip, + 'keep_packed': keep_packed, } for stream in streams: stream.apply_default(default) @@ -421,7 +424,7 @@ def __init__(self, download_retry=download_retry, download_timeout=download_timeout, validate_hash=validate_hash, - keep_zip=keep_zip) + keep_packed=keep_packed) streams = [default] # Validate the stream weighting scheme (relative or absolute) to catch errors before we go diff --git a/streaming/format/reader.py b/streaming/format/reader.py index 3fa9c6b3a..88c72efe3 100644 --- a/streaming/format/reader.py +++ b/streaming/format/reader.py @@ -122,14 +122,15 @@ def evict(self) -> int: """ return self._evict_raw() + self._evict_zip() - def set_up_local(self, listing: Set[str], safe_keep_zip: bool) -> int: + def set_up_local(self, listing: Set[str], safe_keep_packed: bool) -> int: """Bring what shard files are present to a consistent state, returning whether present. Args: listing (Set[str]): The listing of all files under dirname/[split/]. This is listed once and then saved because there could potentially be very many shard files. - safe_keep_zip (bool): Whether to keep zip files when decompressing. Possible when - compression was used. Necessary when local is the remote or there is no remote. + safe_keep_packed (bool): Whether to keep the packed form of shards after unpacking, + e.g. compressed shards after decompression or Parquet shards after conversion to + MDS, after taking into account whether local is remote. Returns: bool: Whether the shard is present. @@ -168,7 +169,7 @@ def set_up_local(self, listing: Set[str], safe_keep_zip: bool) -> int: # Enumerate cases of raw/zip presence. if self.compression: - if safe_keep_zip: + if safe_keep_packed: if has_raw: if has_zip: # Present (normalized). @@ -242,7 +243,7 @@ def get_max_size(self) -> int: "Max" in this case means both the raw (decompressed) and zip (compressed) versions are resident (assuming it has a zip form). This is the maximum disk usage the shard can reach. - When compressed was used, even if keep_zip is ``False``, the zip form must still be + When compressed was used, even if seep_zip is ``False``, the zip form must still be resident at the same time as the raw form during shard decompression. Returns: @@ -250,21 +251,21 @@ def get_max_size(self) -> int: """ return self.get_raw_size() + (self.get_zip_size() or 0) - def get_persistent_size(self, keep_zip: bool) -> int: + def get_persistent_size(self, keep_packed: bool) -> int: """Get the persistent size of this shard. "Persistent" in this case means whether both raw and zip are present is subject to - keep_zip. If we are not keeping zip files after decompression, they don't count to the + keep_packed. If we are not keeping zip files after decompression, they don't count to the shard's persistent size on disk. Args: - keep_zip (bool): Whether to keep zip files after decompressing. + keep_packed (bool): Whether to keep zip files after decompressing. Returns: int: Size in bytes. 
""" if self.compression: - if keep_zip: + if keep_packed: size = self.get_max_size() else: size = self.get_raw_size() diff --git a/streaming/stream.py b/streaming/stream.py index 6d970cde3..092a9dd4b 100644 --- a/streaming/stream.py +++ b/streaming/stream.py @@ -3,6 +3,9 @@ """A dataset, or sub-dataset if mixing, from which we stream/cache samples.""" +# TODO: we have generalized `keep_zip` to `keep_packed` in Stream arguments, but must still accept +# `keep_zip`. + import hashlib import json import os @@ -57,7 +60,7 @@ class Stream: * ``download_retry`` * ``download_timeout`` * ``validate_hash`` - * ``keep_zip`` + * ``keep_packed`` Args: remote (str, optional): Remote path or directory to download the dataset from. If ``None``, @@ -84,9 +87,9 @@ class Stream: before raising an exception. Defaults to ``None``. validate_hash (str, optional): Optional hash or checksum algorithm to use to validate shards. Defaults to ``None``. - keep_zip (bool, optional): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep if and only if remote is local or no remote. - Defaults to ``None``. + keep_packed (bool): Whether to keep or drop the packed form of shards after unpacking, e.g. + compressed shards after decompression or Parquet shards after conversion to MDS. If + ``False``, keep iff remote is local or no remote. Defaults to ``False``. """ def __init__(self, @@ -100,7 +103,7 @@ def __init__(self, download_retry: Optional[int] = None, download_timeout: Optional[float] = None, validate_hash: Optional[str] = None, - keep_zip: Optional[bool] = None) -> None: + keep_packed: Optional[bool] = None) -> None: self.remote = remote self._local = local self.split = split or '' @@ -157,10 +160,10 @@ def __init__(self, else: self.local = local - self._keep_zip = keep_zip - if keep_zip is not None: - self.keep_zip = keep_zip - self.safe_keep_zip = self.keep_zip or self.remote in {None, self.local} + self._keep_packed = keep_packed + if keep_packed is not None: + self.keep_packed = keep_packed + self.safe_keep_packed = self.keep_packed or self.remote in {None, self.local} def _get_temporary_directory(self) -> str: """Construct a path to a temporary directory based on remote and split.""" @@ -189,9 +192,9 @@ def apply_default(self, default: dict) -> None: self.download_timeout = default['download_timeout'] if self.validate_hash is None: self.validate_hash = default['validate_hash'] or None - if self._keep_zip is None: - self.keep_zip = default['keep_zip'] - self.safe_keep_zip = default['keep_zip'] or self.remote in {None, self.local} + if self._keep_packed is None: + self.keep_packed = default['keep_packed'] + self.safe_keep_packed = default['keep_packed'] or self.remote in {None, self.local} @classmethod def validate_weights(cls, streams: Sequence[Self]) -> Tuple[bool, bool]: @@ -344,7 +347,7 @@ def _decompress_shard_part(self, zip_info: FileInfo, zip_filename: str, raw_file os.rename(tmp_filename, raw_filename) # Maybe remove compressed to save space. - if not self.safe_keep_zip: + if not self.safe_keep_packed: os.remove(zip_filename) def _prepare_shard_part(self, @@ -371,7 +374,7 @@ def _prepare_shard_part(self, raw_filename = os.path.join(self.local, self.split, raw_info.basename) if os.path.isfile(raw_filename): # Has raw. - if zip_info and not self.safe_keep_zip: + if zip_info and not self.safe_keep_packed: zip_filename = os.path.join(self.local, self.split, zip_info.basename) if os.path.isfile(zip_filename): # If don't keep zip and it has a zip, drop the zip. 
@@ -389,7 +392,7 @@ def _prepare_shard_part(self, # Validate and decompress. self._decompress_shard_part(zip_info, zip_filename, raw_filename, compression) delta += raw_info.bytes - if not self.safe_keep_zip: + if not self.safe_keep_packed: delta -= zip_info.bytes else: # Download raw. @@ -491,7 +494,7 @@ def set_up_local(self, shards: List[Reader], cache_usage_per_shard: NDArray[np.i # Determine which shards are present, making local dir consistent. for i, shard in enumerate(shards): - cache_usage_per_shard[i] = shard.set_up_local(listing, self.safe_keep_zip) + cache_usage_per_shard[i] = shard.set_up_local(listing, self.safe_keep_packed) def get_index_size(self) -> int: """Get the size of the index file in bytes. diff --git a/tests/test_eviction.py b/tests/test_eviction.py index 5afb12473..ce0f7c23a 100644 --- a/tests/test_eviction.py +++ b/tests/test_eviction.py @@ -13,14 +13,14 @@ from tests.common.utils import convert_to_mds -def validate(remote: str, local: str, dataset: StreamingDataset, keep_zip: bool, +def validate(remote: str, local: str, dataset: StreamingDataset, keep_packed: bool, is_shard_evicted: bool): """Validate the number of files in a local directory in comparison to remote directory.""" if is_shard_evicted: ops = operator.lt else: ops = operator.eq - if keep_zip: + if keep_packed: if dataset.shards[0].compression: # Local has raw + zip, remote has zip. assert ops( @@ -40,57 +40,57 @@ def validate(remote: str, local: str, dataset: StreamingDataset, keep_zip: bool, assert ops(set(os.listdir(local)), set(os.listdir(remote))) -def shard_eviction_disabled(remote: str, local: str, keep_zip: bool): +def shard_eviction_disabled(remote: str, local: str, keep_packed: bool): """ With shard eviction disabled. """ - dataset = StreamingDataset(remote=remote, local=local, keep_zip=keep_zip) + dataset = StreamingDataset(remote=remote, local=local, keep_packed=keep_packed) for _ in range(2): for sample in dataset: # pyright: ignore pass - validate(remote, local, dataset, keep_zip, False) + validate(remote, local, dataset, keep_packed, False) rmtree(local, ignore_errors=False) -def shard_eviction_too_high(remote: str, local: str, keep_zip: bool): +def shard_eviction_too_high(remote: str, local: str, keep_packed: bool): """ With no shard evictions because cache_limit is bigger than the dataset. """ dataset = StreamingDataset(remote=remote, local=local, - keep_zip=keep_zip, + keep_packed=keep_packed, cache_limit=1_000_000) dataloader = DataLoader(dataset=dataset, num_workers=8) for _ in range(2): for _ in dataloader: pass - validate(remote, local, dataset, keep_zip, False) + validate(remote, local, dataset, keep_packed, False) rmtree(local, ignore_errors=False) -def shard_eviction(remote: str, local: str, keep_zip: bool): +def shard_eviction(remote: str, local: str, keep_packed: bool): """ With shard eviction because cache_limit is smaller than the whole dataset. 
""" - cache_limit = '120kb' if keep_zip else '100kb' + cache_limit = '120kb' if keep_packed else '100kb' dataset = StreamingDataset(remote=remote, local=local, - keep_zip=keep_zip, + keep_packed=keep_packed, cache_limit=cache_limit) dataloader = DataLoader(dataset=dataset, num_workers=8) for _ in range(2): for _ in dataloader: pass - validate(remote, local, dataset, keep_zip, True) + validate(remote, local, dataset, keep_packed, True) rmtree(local, ignore_errors=False) -def manual_shard_eviction(remote: str, local: str, keep_zip: bool): +def manual_shard_eviction(remote: str, local: str, keep_packed: bool): """ Manually downloading and evicting shards. """ - dataset = StreamingDataset(remote=remote, local=local, keep_zip=keep_zip) + dataset = StreamingDataset(remote=remote, local=local, keep_packed=keep_packed) for shard_id in range(dataset.num_shards): dataset.prepare_shard(shard_id) @@ -109,7 +109,7 @@ def manual_shard_eviction(remote: str, local: str, keep_zip: bool): rmtree(local, ignore_errors=False) -def cache_limit_too_low(remote: str, local: str, keep_zip: bool): +def cache_limit_too_low(remote: str, local: str, keep_packed: bool): """ With impossible shard eviction settings because cache_limit is set too low. """ From c09248c19d570649969bf9f780ba910c6de81829 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 15:49:47 -0700 Subject: [PATCH 29/45] Add graceful migration from keep_zip to keep_packed. --- streaming/dataset.py | 22 ++++++++++++++-------- streaming/stream.py | 19 ++++++++++++------- streaming/util/migration.py | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 15 deletions(-) create mode 100644 streaming/util/migration.py diff --git a/streaming/dataset.py b/streaming/dataset.py index b17ce99f7..c80e001af 100644 --- a/streaming/dataset.py +++ b/streaming/dataset.py @@ -3,9 +3,6 @@ """A mid-epoch-resumable streaming/caching pytorch IterableDataset.""" -# TODO: we have generalized `keep_zip` to `keep_packed` in StreamingDataaset arguments, but must -# still accept `keep_zip`. - import json import logging import os @@ -38,6 +35,7 @@ from streaming.spanner import Spanner from streaming.stream import Stream from streaming.util import normalize_bytes, normalize_count +from streaming.util.migration import get_keep_packed from streaming.world import World # An arbitrary time in the future, used for cold shard eviction. @@ -256,9 +254,11 @@ class StreamingDataset(Array, IterableDataset): an exception. Defaults to ``60``. validate_hash (str, optional): Optional hash or checksum algorithm to use to validate shards. Defaults to ``None``. - keep_packed (bool): Whether to keep or drop the packed form of shards after unpacking, e.g. - compressed shards after decompression or Parquet shards after conversion to MDS. If - ``False``, keep iff remote is local or no remote. Defaults to ``False``. + keep_packed (bool, optional): Whether to keep or drop the packed form of shards after + unpacking, e.g. compressed shards after decompression or Parquet shards after + conversion to MDS. If ``False``, keep iff remote is local or no remote. Defaults to + ``None``, which is normalized to ``False``, in order to distinguish setting it on + purpose from receiving the default. epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced across all streams. If ``None``, takes its value from the total number of underlying samples. 
Provide this field if you are weighting streams relatively to target a larger @@ -306,6 +306,9 @@ class StreamingDataset(Array, IterableDataset): ``None``. batching_method (str): Which batching method to use, either ``random``, ``stratified``, or ``per_stream``. Defaults to ``random``. + keep_zip (bool, optional): This argument is deprecated. It has been replaced by + ``keep_packed``, which is more general, for which it serves as a fallback. Defaults to + ``None``. """ def __init__(self, @@ -317,7 +320,7 @@ def __init__(self, download_retry: int = 2, download_timeout: float = 60, validate_hash: Optional[str] = None, - keep_packed: bool = False, + keep_packed: Optional[bool] = None, epoch_size: Optional[Union[int, str]] = None, predownload: Optional[int] = None, cache_limit: Optional[Union[int, str]] = None, @@ -330,7 +333,8 @@ def __init__(self, shuffle_algo: str = 'py1e', shuffle_seed: int = 9176, shuffle_block_size: Optional[int] = None, - batching_method: str = 'random') -> None: + batching_method: str = 'random', + keep_zip: Optional[bool] = None) -> None: # Global arguments (which do not live in Streams). self.predownload = predownload self.cache_limit = cache_limit @@ -345,6 +349,8 @@ def __init__(self, self.shuffle_block_size = shuffle_block_size self.batching_method = batching_method + keep_packed = get_keep_packed(keep_packed, keep_zip) + # Initialize initial_physical_nodes to None. If we are resuming, then we will set it to the # number of physical nodes of the initial run in the _resume function. self.initial_physical_nodes = None diff --git a/streaming/stream.py b/streaming/stream.py index 092a9dd4b..edc1f3835 100644 --- a/streaming/stream.py +++ b/streaming/stream.py @@ -3,9 +3,6 @@ """A dataset, or sub-dataset if mixing, from which we stream/cache samples.""" -# TODO: we have generalized `keep_zip` to `keep_packed` in Stream arguments, but must still accept -# `keep_zip`. - import hashlib import json import os @@ -23,6 +20,7 @@ from streaming.hashing import get_hash from streaming.storage import download_file from streaming.storage.extra import wait_for_file_to_exist +from streaming.util.migration import get_keep_packed from streaming.util.retrying import retry from streaming.world import World @@ -87,9 +85,14 @@ class Stream: before raising an exception. Defaults to ``None``. validate_hash (str, optional): Optional hash or checksum algorithm to use to validate shards. Defaults to ``None``. - keep_packed (bool): Whether to keep or drop the packed form of shards after unpacking, e.g. - compressed shards after decompression or Parquet shards after conversion to MDS. If - ``False``, keep iff remote is local or no remote. Defaults to ``False``. + keep_packed (bool, optional): Whether to keep or drop the packed form of shards after + unpacking, e.g. compressed shards after decompression or Parquet shards after + conversion to MDS. If ``False``, keep iff remote is local or no remote. Defaults to + ``None``, which is normalized to ``False``, in order to distinguish setting it on + purpose from receiving the default. + keep_zip (bool, optional): This argument is deprecated. It has been replaced by + ``keep_packed``, which is more general, for which it serves as a fallback. Defaults to + ``None``. 
""" def __init__(self, @@ -103,7 +106,8 @@ def __init__(self, download_retry: Optional[int] = None, download_timeout: Optional[float] = None, validate_hash: Optional[str] = None, - keep_packed: Optional[bool] = None) -> None: + keep_packed: Optional[bool] = None, + keep_zip: Optional[bool] = None) -> None: self.remote = remote self._local = local self.split = split or '' @@ -160,6 +164,7 @@ def __init__(self, else: self.local = local + keep_packed = get_keep_packed(keep_packed, keep_zip) self._keep_packed = keep_packed if keep_packed is not None: self.keep_packed = keep_packed diff --git a/streaming/util/migration.py b/streaming/util/migration.py new file mode 100644 index 000000000..9a9ab543e --- /dev/null +++ b/streaming/util/migration.py @@ -0,0 +1,36 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Graceful migration of StreamingDataset arguments.""" + +import logging +from typing import Optional + +__all__ = ['get_keep_packed'] + +logger = logging.getLogger(__name__) + + +def get_keep_packed(keep_packed: Optional[bool], keep_zip: Optional[bool]) -> bool: + """Get the value of ``keep_packed`` given both old aand new arguments. + + Warns if the deprecated argument ``keep_zip`` is used. + + Args: + keep_packed (bool, optinoal): New argument. + keep_zip (bool, optional): Old argument. + + Returns: + bool: Normalized argument. + """ + if keep_zip is not None: + logger.warning('StreamingDataset/Stream argument `keep_zip` is deprecated, please use ' + + 'the new `keep_packed` argument instead, which is more general.') + + if keep_packed is not None: + return keep_packed + + if keep_zip is not None: + return keep_zip + + return False From 9befaa6a3399327b68eb26d12e7821e7eca4131e Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sat, 28 Oct 2023 23:18:37 -0700 Subject: [PATCH 30/45] First take on a MDS write_dataset(). --- streaming/format/mds/writer.py | 54 +++++++++++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/streaming/format/mds/writer.py b/streaming/format/mds/writer.py index 950c60f20..394d93ec8 100644 --- a/streaming/format/mds/writer.py +++ b/streaming/format/mds/writer.py @@ -4,9 +4,10 @@ """:class:`MDSWriter` writes samples to ``.mds`` files that can be read by :class:`MDSReader`.""" import json -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np +from tqdm import tqdm from streaming.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, is_mds_encoding, mds_encode) @@ -139,3 +140,54 @@ def encode_joint_shard(self) -> bytes: offsets += len(num_samples.tobytes()) + len(offsets.tobytes()) + len(self.config_data) sample_data = b''.join(self.new_samples) return num_samples.tobytes() + offsets.tobytes() + self.config_data + sample_data + + +def write_dataset(samples: Iterable[Dict[str, Any]], + out: Union[str, Tuple[str, str]], + *, + num_samples: Optional[int] = None, + keep_local: bool = False, + columns: Optional[Dict[str, str]] = None, + compression: Optional[str] = None, + hashes: Optional[List[str]] = None, + max_file_bytes: Optional[Union[int, str]] = '32mib', + num_upload_threads: Optional[int] = None, + upload_retry: int = 2, + show_write_progress: bool = True, + show_upload_progress: bool = True) -> None: + """Write the samples as an MDS dataset. + + Args: + samples (Iterable[Dict[str, Any]]): Iterable of sample dicts. 
+ out (str | Tuple[str, str]): Dataaset save directory, or pair of (local, remote). + num_samples ((int, optional): If ``samples`` is a generator, specify ``num_samples``to + still get a useful progress bar. Defaults to ``None``. + keep_local (bool): Whether to keep local files after upload. Defaults to ``False``. + columns (Dict[str, str], optional): Any column types to override, given by column name. + Defaults to ``None``. + compression (str, optional): What compression scheme to use, if any. Defaults to ``None``. + hashes (List[str], optional): List of hashes to apply to dataset files. + max_file_bytes (int | str, optional): Maximum shard size ,in bytes. Defaults to ``32mib``. + num_upload_threads (int, optional): Number of threads used to upload shards. Defaults to + ``None``, which means to take the default, which is scaled for CPU count, etc. + upload_retry (int): Number of upload reattempts before bailing. Defaults to ``2``. + show_write_progress (bool): Show a progress bar for write progress. Defaults to ``True``. + show_upload_progress (bool): Show a progress bar for upload progress. Defaults to ``True``. + """ + # TODO:borrow the first sample to derive any inferred columns, then return it to its Iterable.. + # TODO: Use the part.00000/ subdir trick to make datasets easily appendable to. + total = len(samples) if hasattr(samples, '__len__') else num_samples # pyright: ignore + if show_write_progress: + samples = tqdm(samples, total=total, leave=False) + columns = columns or {} + with MDSWriter(columns=columns, + out=out, + keep_local=keep_local, + compression=compression, + hashes=hashes, + size_limit=max_file_bytes, + progress_bar=show_upload_progress, + max_workers=num_upload_threads, + retry=upload_retry) as writer: + for sample in samples: + writer.write(sample) From c4a509416a4948ab127e8e4728b950fc5cd3be87 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 29 Oct 2023 00:06:46 -0700 Subject: [PATCH 31/45] Add enough column inference to keep going. --- streaming/format/mds/writer.py | 58 +++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 5 deletions(-) diff --git a/streaming/format/mds/writer.py b/streaming/format/mds/writer.py index 394d93ec8..d76151806 100644 --- a/streaming/format/mds/writer.py +++ b/streaming/format/mds/writer.py @@ -4,6 +4,7 @@ """:class:`MDSWriter` writes samples to ``.mds`` files that can be read by :class:`MDSReader`.""" import json +from itertools import chain from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np @@ -142,6 +143,41 @@ def encode_joint_shard(self) -> bytes: return num_samples.tobytes() + offsets.tobytes() + self.config_data + sample_data +_type2enc = { + int: 'int', + str: 'str', +} + + +def infer_column(field: Any) -> str: + """Infer the best MDS encoding for a column, given an example field. + + Args: + field (Any): The example. + + Returns: + MDS encoding signature. + """ + ty = type(field) + return _type2enc[ty] + + +def infer_columns(sample: Dict[str, Any]) -> Dict[str, str]: + """Infer dataset columns given a sample. + + Args: + sample (Dict[str, Any]): Mapping of field name to value. + + Returns: + Dict[str, str]: Mapping of field name to type. 
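+
+    Example:
+        An illustrative call (as written, only ``int`` and ``str`` field types are
+        inferable):
+
+            >>> infer_columns({'num': 7, 'txt': 'seven'})
+            {'num': 'int', 'txt': 'str'}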
+ """ + ret = {} + for key in sorted(sample): + val = sample[key] + ret[key] = infer_column(val) + return ret + + def write_dataset(samples: Iterable[Dict[str, Any]], out: Union[str, Tuple[str, str]], *, @@ -163,8 +199,7 @@ def write_dataset(samples: Iterable[Dict[str, Any]], num_samples ((int, optional): If ``samples`` is a generator, specify ``num_samples``to still get a useful progress bar. Defaults to ``None``. keep_local (bool): Whether to keep local files after upload. Defaults to ``False``. - columns (Dict[str, str], optional): Any column types to override, given by column name. - Defaults to ``None``. + columns (Dict[str, str], optional): Inferred column overrides. Defaults to ``None``. compression (str, optional): What compression scheme to use, if any. Defaults to ``None``. hashes (List[str], optional): List of hashes to apply to dataset files. max_file_bytes (int | str, optional): Maximum shard size ,in bytes. Defaults to ``32mib``. @@ -174,12 +209,25 @@ def write_dataset(samples: Iterable[Dict[str, Any]], show_write_progress (bool): Show a progress bar for write progress. Defaults to ``True``. show_upload_progress (bool): Show a progress bar for upload progress. Defaults to ``True``. """ - # TODO:borrow the first sample to derive any inferred columns, then return it to its Iterable.. # TODO: Use the part.00000/ subdir trick to make datasets easily appendable to. + + # First, count the number of samples to write from the input Iterable, falling back to the + # user-provided hint if it has no size. total = len(samples) if hasattr(samples, '__len__') else num_samples # pyright: ignore + + # If user did not tell us the schema, pop a sample off the front of the iterator, infer + # columns, then put it back lol. + it = iter(samples) + if not columns: + head = next(it) + columns = infer_columns(head) + it = chain([head], it) + + # Now that we have an iteator for reals, wrap it with the "write" progress bar. if show_write_progress: - samples = tqdm(samples, total=total, leave=False) - columns = columns or {} + it = tqdm(it, total=total, leave=False) + + # Finally walk/write the samples. with MDSWriter(columns=columns, out=out, keep_local=keep_local, From 48dce5c3547d3d3a74f96ebdc808c98b097867f0 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 29 Oct 2023 02:06:19 -0700 Subject: [PATCH 32/45] WWriting all given samples as one indexless MDS shard, returning its metadata. --- streaming/format/mds/writer.py | 75 +++++++++++++++++++++++++++++++--- 1 file changed, 70 insertions(+), 5 deletions(-) diff --git a/streaming/format/mds/writer.py b/streaming/format/mds/writer.py index d76151806..39d4d92c5 100644 --- a/streaming/format/mds/writer.py +++ b/streaming/format/mds/writer.py @@ -4,12 +4,17 @@ """:class:`MDSWriter` writes samples to ``.mds`` files that can be read by :class:`MDSReader`.""" import json +import os from itertools import chain +from shutil import rmtree +from tempfile import mkdtemp from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from urllib.parse import urlparse import numpy as np from tqdm import tqdm +from streaming.format.index import get_index_basename from streaming.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, is_mds_encoding, mds_encode) from streaming.format.writer import JointWriter @@ -202,7 +207,8 @@ def write_dataset(samples: Iterable[Dict[str, Any]], columns (Dict[str, str], optional): Inferred column overrides. Defaults to ``None``. compression (str, optional): What compression scheme to use, if any. 
Defaults to ``None``. hashes (List[str], optional): List of hashes to apply to dataset files. - max_file_bytes (int | str, optional): Maximum shard size ,in bytes. Defaults to ``32mib``. + max_file_bytes (int | str, optional): Optional maximum shard size, in bytes. If no limit, + we will write exactly one (potentially very large) shard. Defaults to ``32mib``. num_upload_threads (int, optional): Number of threads used to upload shards. Defaults to ``None``, which means to take the default, which is scaled for CPU count, etc. upload_retry (int): Number of upload reattempts before bailing. Defaults to ``2``. @@ -219,9 +225,9 @@ def write_dataset(samples: Iterable[Dict[str, Any]], # columns, then put it back lol. it = iter(samples) if not columns: - head = next(it) - columns = infer_columns(head) - it = chain([head], it) + sample = next(it) # If samples is empty, user goofed. + columns = infer_columns(sample) + it = chain([sample], it) # Now that we have an iteator for reals, wrap it with the "write" progress bar. if show_write_progress: @@ -237,5 +243,64 @@ def write_dataset(samples: Iterable[Dict[str, Any]], progress_bar=show_upload_progress, max_workers=num_upload_threads, retry=upload_retry) as writer: - for sample in samples: + for sample in it: writer.write(sample) + + +def write_shard(*args: Any, + tmp_dir: Optional[str] = None, + shard_basename: str = 'shard.00000.mds', + **kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Write the samples as a single MDS shard. + + Args: + *args (Any): Positional arguments for ``write_dataset()``. + tmp_dir (str, optional): Write the MDS dataset to this specific directory instead of + lettting python tempfile pick one for us. Empties and removes the diretory when done. + This argument is useful if your shard is very large and your system's standard temp + root is across a filesystem boundary from the local cache dir you are using. + shard_basename (str): Path to shard, relative to dataset. Defaults to ``shard.00000.mds``. + **kwargs (Dict[str, Any]): Keyword arguments for ``write_dataset()``. + + Returns: + Dict[str, Any]: JSON dict of the shard metadata. + """ + # We happen to only have a need for this restricted use case. + shard_dest = kwargs.get('out') + if not isinstance(shard_dest, str) or urlparse(shard_dest).scheme: + raise ValueError(f'Streaming is restricted to only writing MDS datasets of one unlimited' + + f'shard when the output is just a local abs/rel path with no file:// ' + + f'prefix, but got: {shard_dest}.') + + # Verify our actions are aligned with our goals, which is one shard of technically unlimited + # size because of specific weird reasons (i.e., mirroring Parquets to MDS). + if kwargs.get('max_file_bytes'): + raise ValueError('We question your values.') + kwargs.__dict__['max_file_bytes'] = None + + # Fall back to using python tempfile. + if not tmp_dir: + tmp_dir = mkdtemp() + + # Verify scratch dir not present. + if os.path.exists(tmp_dir): + raise ValueError(f'Scratch path already exists: {tmp_dir}.') + + # Serialize a uni-shard dataset to the temp directory. + kwargs.__dict__['out'] = tmp_dir + write_dataset(*args, **kwargs) + + # Move the shard from its dataset to the desired location. + shard_source = os.path.join(tmp_dir, shard_basename) + os.rename(shard_source, shard_dest) + + # Get the shard metadata from the index (could also get it from the MDS shard itself). + index_path = os.path.join(tmp_dir, get_index_basename()) + obj = json.load(open(index_path)) + info, = obj['shards'] + + # Cleanup. 
+ rmtree(tmp_dir) + + # Return shard metadata. + return info From b0d1543d54e23fa7d668572555330201fcb9710a Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 29 Oct 2023 02:22:42 -0700 Subject: [PATCH 33/45] Naming. --- streaming/format/mds/writer.py | 38 +++++++++++++++++----------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/streaming/format/mds/writer.py b/streaming/format/mds/writer.py index 39d4d92c5..221b8d712 100644 --- a/streaming/format/mds/writer.py +++ b/streaming/format/mds/writer.py @@ -19,7 +19,7 @@ is_mds_encoding, mds_encode) from streaming.format.writer import JointWriter -__all__ = ['MDSWriter'] +__all__ = ['MDSWriter', 'write_mds_dataset', 'write_bare_mds_shard'] class MDSWriter(JointWriter): @@ -183,19 +183,19 @@ def infer_columns(sample: Dict[str, Any]) -> Dict[str, str]: return ret -def write_dataset(samples: Iterable[Dict[str, Any]], - out: Union[str, Tuple[str, str]], - *, - num_samples: Optional[int] = None, - keep_local: bool = False, - columns: Optional[Dict[str, str]] = None, - compression: Optional[str] = None, - hashes: Optional[List[str]] = None, - max_file_bytes: Optional[Union[int, str]] = '32mib', - num_upload_threads: Optional[int] = None, - upload_retry: int = 2, - show_write_progress: bool = True, - show_upload_progress: bool = True) -> None: +def write_mds_dataset(samples: Iterable[Dict[str, Any]], + out: Union[str, Tuple[str, str]], + *, + num_samples: Optional[int] = None, + keep_local: bool = False, + columns: Optional[Dict[str, str]] = None, + compression: Optional[str] = None, + hashes: Optional[List[str]] = None, + max_file_bytes: Optional[Union[int, str]] = '32mib', + num_upload_threads: Optional[int] = None, + upload_retry: int = 2, + show_write_progress: bool = True, + show_upload_progress: bool = True) -> None: """Write the samples as an MDS dataset. Args: @@ -247,11 +247,11 @@ def write_dataset(samples: Iterable[Dict[str, Any]], writer.write(sample) -def write_shard(*args: Any, - tmp_dir: Optional[str] = None, - shard_basename: str = 'shard.00000.mds', - **kwargs: Dict[str, Any]) -> Dict[str, Any]: - """Write the samples as a single MDS shard. +def write_bare_mds_shard(*args: Any, + tmp_dir: Optional[str] = None, + shard_basename: str = 'shard.00000.mds', + **kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Write the samples as a single MDS shard, returning shard metadata. Args: *args (Any): Positional arguments for ``write_dataset()``. From 99ad0c060362c487d96d63509e45da0d27141cc1 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 29 Oct 2023 02:38:24 -0700 Subject: [PATCH 34/45] Fixes. --- .pre-commit-config.yaml | 2 +- benchmarks/serialization/compare.py | 3 ++- docs/source/examples | 1 - docs/source/index.md | 10 +++++----- docs/source/notebooks | 1 + streaming/format/mds/writer.py | 2 +- 6 files changed, 10 insertions(+), 9 deletions(-) delete mode 120000 docs/source/examples create mode 120000 docs/source/notebooks diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c2f07b0c..2b1dd1bb1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ default_language_version: python: python3 # Skip the pre-commit check for below directories to have # a consistency with the official tfrecord preprocessing scripts -exclude: "^(streaming/text/convert/enwiki/)" +exclude: "^(examples/text/enwiki_tok/)" repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. 
diff --git a/benchmarks/serialization/compare.py b/benchmarks/serialization/compare.py
index d9341178b..62942d08d 100644
--- a/benchmarks/serialization/compare.py
+++ b/benchmarks/serialization/compare.py
@@ -29,7 +29,8 @@
 import numpy as np
 import pandas as pd
-from datasets import Dataset, disable_caching, load_dataset, load_from_disk  # pyright: ignore
+from datasets import disable_caching  # pyright: ignore
+from datasets import Dataset, load_dataset, load_from_disk
 from matplotlib import pyplot as plt
 from tqdm import tqdm
diff --git a/docs/source/examples b/docs/source/examples
deleted file mode 120000
index d15735c1d..000000000
--- a/docs/source/examples
+++ /dev/null
@@ -1 +0,0 @@
-../../examples
\ No newline at end of file
diff --git a/docs/source/index.md b/docs/source/index.md
index 6728583ad..ea8716cf8 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -74,11 +74,11 @@ If you have any questions, please feel free to reach out to us on [Twitter](htt
    :maxdepth: 1
    :caption: Examples
-   examples/cifar10.ipynb
-   examples/facesynthetics.ipynb
-   examples/synthetic_nlp.ipynb
-   examples/multiprocess_dataset_conversion.ipynb
-   examples/spark_dataframe_to_MDS.ipynb
+   notebooks/cifar10.ipynb
+   notebooks/facesynthetics.ipynb
+   notebooks/synthetic_nlp.ipynb
+   notebooks/multiprocess_dataset_conversion.ipynb
+   notebooks/spark_dataframe_to_MDS.ipynb
 .. toctree::
    :hidden:
diff --git a/docs/source/notebooks b/docs/source/notebooks
new file mode 120000
index 000000000..d4082256d
--- /dev/null
+++ b/docs/source/notebooks
@@ -0,0 +1 @@
+../../notebooks/
\ No newline at end of file
diff --git a/streaming/format/mds/writer.py b/streaming/format/mds/writer.py
index 221b8d712..0c5655ff6 100644
--- a/streaming/format/mds/writer.py
+++ b/streaming/format/mds/writer.py
@@ -288,7 +288,7 @@ def write_bare_mds_shard(*args: Any,
     # Serialize a uni-shard dataset to the temp directory.
     kwargs.__dict__['out'] = tmp_dir
-    write_dataset(*args, **kwargs)
+    write_mds_dataset(*args, **kwargs)

From b38fce03977244a81fd9337480882fd4d4e5002c Mon Sep 17 00:00:00 2001
From: James Knighton
Date: Sun, 29 Oct 2023 08:02:50 -0700
Subject: [PATCH 35/45] cli/hash.py.

---
 examples/__init__py   |  0
 streaming/cli/hash.py | 42 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 42 insertions(+)
 delete mode 100644 examples/__init__py
 create mode 100644 streaming/cli/hash.py

diff --git a/examples/__init__py b/examples/__init__py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/streaming/cli/hash.py b/streaming/cli/hash.py
new file mode 100644
index 000000000..d6af04b76
--- /dev/null
+++ b/streaming/cli/hash.py
@@ -0,0 +1,42 @@
+# Copyright 2023 MosaicML Streaming authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Calculate one or more hashes of the data of a given file."""
+
+from argparse import ArgumentParser, Namespace
+
+from streaming.hashing import get_hash, get_hashes
+from streaming.util.pretty import unpack_strs
+
+
+def parse_args() -> Namespace:
+    """Parse command-line arguments.
+
+    Returns:
+        Namespace: Command-line arguments.
+    """
+    supported = sorted(get_hashes())
+    args = ArgumentParser()
+    args.add_argument('--file', type=str, required=True, help='Path to file to hash.')
+    args.add_argument('--hash',
+                      type=str,
+                      required=True,
+                      help=f'Comma-delimited names of hash algorithms. Must be in this list: ' +
+                      f'{supported}.
Names and hex digests will be listed one per line.') + return args.parse_args() + + +def main(args: Namespace): + """Calculate one or more hashes of the data of the given file. + + Args: + args (Namespace): Command-line arguments. + """ + data = open(args.file, 'rb').read() + for algo in unpack_strs(args.hash): + hex_digest = get_hash(algo, data) + print(f'{algo} {hex_digest}') + + +if __name__ == '__main__': + main(parse_args()) From 7a9fc90e07a6127866a493803464634adc80860c Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 29 Oct 2023 14:56:54 -0700 Subject: [PATCH 36/45] walk_prefix() including local fs. --- streaming/storage/extra.py | 88 +++++++++++++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 2 deletions(-) diff --git a/streaming/storage/extra.py b/streaming/storage/extra.py index 7f993edfd..01cd3ffe2 100644 --- a/streaming/storage/extra.py +++ b/streaming/storage/extra.py @@ -10,14 +10,17 @@ import re from re import Pattern from time import sleep, time -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union from urllib.parse import urlparse from streaming.hashing import get_hash from streaming.storage import CloudUploader, download_file from streaming.util.pretty import normalize_bytes, normalize_duration -__all__ = ['wait_for_file_to_exist', 'walk_dir', 'list_dataset_files', 'smart_download_file'] +__all__ = [ + 'wait_for_file_to_exist', 'walk_prefix', 'walk_dir', 'list_dataset_files', + 'smart_download_file' +] def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float, @@ -44,6 +47,86 @@ def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float, raise RuntimeError(f'{err_msg}' + f'{timeout:.3f} < {dt:.3f} secs.') +def _normalize_path(path: str) -> Tuple[str, bool]: + """Analyze the path, returning normalized form and whether it is local. + + Args: + path (str): Path to analyze. + + Returns: + Tuple[str, bool]: Normalized path, and whether it is local. + """ + obj = urlparse(path) + if obj.scheme == '': + is_local = True + elif obj.scheme == 'file': + is_local = True + path = obj.path + else: + is_local = False + return path, is_local + + +def _normalize_dir(dirname: str) -> str: + """Normalize a dirname to contain one trailing slash. + + Args: + dirname (str): Directory path. + + Returns: + str: Normalized directory path. + """ + return dirname.rstrip(os.path.sep) + os.path.sep + + +def walk_prefix(prefix: str) -> List[str]: + """Walk all the files under a path prefix in sorted order. + + Notes: + * If you choose a non-directory as a prefix, returned paths will indeed be relative to your + non-directory, which may seem funky. + * There is some special case handling so that if your path is a local directory with or + without a trailing slash, returned paths will nevertheless never start with a slash, lest + they assume "absolute" power. + + Args: + prefix (str): Path prefix. + + Returns: + List[str]: All file paths under the prefix, which are all relative to the given prefix. + """ + prefix, is_local = _normalize_path(prefix) + + if is_local: + # Prefix points to local filesystem. + prefix_rel_files = [] + if os.path.isdir(prefix): + # Prefix is a directory, so include everything under the directory. 
+ root = _normalize_dir(prefix) + for abs_dir, _, file_bases in os.walk(root): + root_rel_dir = abs_dir.lstrip(root) + for base in file_bases: + root_rel_file = os.path.join(root_rel_dir, base) + prefix_rel_files.append(root_rel_file) + else: + # Prefix has other stuff tacked onto it after the directory, so include everything + # under the prefix's parent directory which also matches the prefix's basename. + root = os.path.dirname(prefix) + for abs_dir, _, file_bases in os.walk(root): + for base in file_bases: + abs_file = os.path.join(abs_dir, base) + if abs_file.startswith(prefix): + prefix_rel_file = abs_file.lstrip(prefix) + prefix_rel_files.append(prefix_rel_file) + else: + # Prefix points to some non-local storage. + neither = CloudUploader.get(prefix, exist_ok=True) + prefix_rel_files = neither.list_objects(prefix) + + # TODO: verify all implementations do a global sort on returned paths, then remove this line. + return sorted(prefix_rel_files) + + def walk_dir(root: str) -> List[str]: """Recursively list the given directory in sorted order. @@ -124,6 +207,7 @@ def _get_overlap(want: Set[str], have: Set[str]) -> Dict[str, Any]: def list_dataset_files( + *, local: str, remote: Optional[str] = None, split: Optional[str] = None, From 5247bfe7b14102458a55a43563c59463ceb0fd76 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 5 Nov 2023 01:39:48 -0800 Subject: [PATCH 37/45] generate_datasets.py: Tabulator. --- benchmarks/backends/generate_datasets.py | 320 +++++++++++++++++------ 1 file changed, 239 insertions(+), 81 deletions(-) diff --git a/benchmarks/backends/generate_datasets.py b/benchmarks/backends/generate_datasets.py index ad658860d..eb9f974d0 100644 --- a/benchmarks/backends/generate_datasets.py +++ b/benchmarks/backends/generate_datasets.py @@ -8,7 +8,7 @@ from functools import partial from shutil import rmtree from time import time -from typing import List, Optional +from typing import Dict, List, Optional, Tuple import lance import pyarrow as pa @@ -19,25 +19,36 @@ from pyspark.sql.types import IntegerType, StringType, StructField, StructType from task import generate_dataset from tqdm import tqdm +from typing_extensions import Self from wurlitzer import pipes from streaming import CSVWriter, JSONWriter, MDSWriter -def parse_args() -> Namespace: +def _parse_args() -> Namespace: """Parse command-line arguments. Returns: Namespace: Command-line arguments. """ args = ArgumentParser() - args.add_argument('--show_progress', type=int, default=1) + # Reproducibility. args.add_argument('--seed', type=int, default=1337) + + # Dataset and shard sizes. args.add_argument('--num_train', type=int, default=1 << 21) args.add_argument('--num_val', type=int, default=1 << 17) + args.add_argument('--size_limit', type=int, default=1 << 23) + args.add_argument('--samples_per_shard', type=int, default=1 << 17) - args.add_argument('--data_root', type=str, default='data/backendss/') + # Output root. + args.add_argument('--data_root', type=str, default='data/backends/') + + # Formats to output. + args.add_argument('--formats', type=str, default='csv,jsonl,lance,mds,parquet,delta') + + # Output subdir per format. 
args.add_argument('--csv', type=str, default='csv') args.add_argument('--jsonl', type=str, default='jsonl') args.add_argument('--lance', type=str, default='lance') @@ -45,17 +56,18 @@ def parse_args() -> Namespace: args.add_argument('--parquet', type=str, default='parquet') args.add_argument('--delta', type=str, default='delta') - args.add_argument('--size_limit', type=int, default=1 << 23) - args.add_argument('--samples_per_shard', type=int, default=1 << 17) + # Logging. + args.add_argument('--show_progress', type=int, default=1) args.add_argument('--quiet_delta', type=int, default=1) + return args.parse_args() -def _save_csv(nums: List[int], - txts: List[str], - root: str, - size_limit: Optional[int], - show_progress: bool = True) -> None: +def _write_csv(nums: List[int], + txts: List[str], + root: str, + size_limit: Optional[int], + show_progress: bool = True) -> None: """Save the dataset in Streaming CSV form. Args: @@ -65,21 +77,27 @@ def _save_csv(nums: List[int], size_limit (int, optional): Maximum shard size in bytes, or no limit. show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``. """ - columns = {'num': 'int', 'txt': 'str'} + columns = { + 'num': 'int', + 'txt': 'str', + } with CSVWriter(out=root, columns=columns, size_limit=size_limit) as out: each_sample = zip(nums, txts) if show_progress: each_sample = tqdm(each_sample, total=len(nums), leave=False) for num, txt in each_sample: - sample = {'num': num, 'txt': txt} + sample = { + 'num': num, + 'txt': txt, + } out.write(sample) -def _save_jsonl(nums: List[int], - txts: List[str], - root: str, - size_limit: Optional[int], - show_progress: bool = True) -> None: +def _write_jsonl(nums: List[int], + txts: List[str], + root: str, + size_limit: Optional[int], + show_progress: bool = True) -> None: """Save the dataset Streaming JSONL form. Args: @@ -89,21 +107,27 @@ def _save_jsonl(nums: List[int], size_limit (int, optional): Maximum shard size in bytes, or no limit. show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``. """ - columns = {'num': 'int', 'txt': 'str'} + columns = { + 'num': 'int', + 'txt': 'str', + } with JSONWriter(out=root, columns=columns, size_limit=size_limit) as out: each_sample = zip(nums, txts) if show_progress: each_sample = tqdm(each_sample, total=len(nums), leave=False) for num, txt in each_sample: - sample = {'num': num, 'txt': txt} + sample = { + 'num': num, + 'txt': txt, + } out.write(sample) -def _save_mds(nums: List[int], - txts: List[str], - root: str, - size_limit: Optional[int], - show_progress: bool = True) -> None: +def _write_mds(nums: List[int], + txts: List[str], + root: str, + size_limit: Optional[int], + show_progress: bool = True) -> None: """Save the dataset in Streaming MDS form. Args: @@ -113,21 +137,27 @@ def _save_mds(nums: List[int], size_limit (int, optional): Maximum shard size in bytes, or no limit. show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``. 
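
    Example:
        An illustrative call (the output directory is hypothetical):

            _write_mds([7, 42], ['seven', 'forty two'], '/tmp/bench/mds/train', 1 << 23)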
""" - columns = {'num': 'int', 'txt': 'str'} + columns = { + 'num': 'int', + 'txt': 'str', + } with MDSWriter(out=root, columns=columns, size_limit=size_limit) as out: each_sample = zip(nums, txts) if show_progress: each_sample = tqdm(each_sample, total=len(nums), leave=False) for num, txt in each_sample: - sample = {'num': num, 'txt': txt} + sample = { + 'num': num, + 'txt': txt, + } out.write(sample) -def _save_parquet(nums: List[int], - txts: List[str], - root: str, - samples_per_shard: int, - show_progress: bool = True) -> None: +def _write_parquet(nums: List[int], + txts: List[str], + root: str, + samples_per_shard: int, + show_progress: bool = True) -> None: """Save the dataset in Streaming MDS form. Args: @@ -158,8 +188,7 @@ def _save_parquet(nums: List[int], pq.write_table(table, path) -def _wrapped_save_delta(nums: List[int], txts: List[str], root: str, - samples_per_shard: int) -> None: +def _write_delta(nums: List[int], txts: List[str], root: str, samples_per_shard: int) -> None: """Save the dataset in Streaming MDS form. Args: @@ -182,29 +211,29 @@ def _wrapped_save_delta(nums: List[int], txts: List[str], root: str, df.write.format('delta').option('maxRecordsPerFile', samples_per_shard).save(root) -def _save_delta(nums: List[int], - txts: List[str], - root: str, - samples_per_shard: int, - quiet: bool = True) -> None: - """Save the dataset in Streaming MDS form. +def _do_write_delta(nums: List[int], + txts: List[str], + root: str, + samples_per_shard: int, + quietly: bool = True) -> None: + """Save the dataset in Streaming MDS form, possibly capturing stdout/stderr. Args: nums (List[int]): The sample numbers. txts (List[str]): The sample texts. root (str): Root directory. samples_per_shard (int): Maximum numbero of samples per shard. - quiet (bool): Whether to capture the Delta logging. Defaults to ``True``. + quietly (bool): Whether to capture the Delta logging. Defaults to ``True``. """ - bang_on_pipes = lambda: _wrapped_save_delta(nums, txts, root, samples_per_shard) - if quiet: + write = lambda: _write_delta(nums, txts, root, samples_per_shard) + if quietly: with pipes(): - bang_on_pipes() + write() else: - bang_on_pipes() + write() -def _save_lance(nums: List[int], txts: List[str], root: str, samples_per_shard: int) -> None: +def _write_lance(nums: List[int], txts: List[str], root: str, samples_per_shard: int) -> None: """Save the dataset in Lance form. Args: @@ -219,7 +248,7 @@ def _save_lance(nums: List[int], txts: List[str], root: str, samples_per_shard: lance.write_dataset(table, root, mode='create', max_rows_per_file=samples_per_shard) -def _stat(root: str): +def _get_file_sizes(root: str) -> List[int]: """Inventory what was written, collecting total files and total bytes. Args: @@ -228,14 +257,108 @@ def _stat(root: str): Returns: Tuple[int, int]: Total files and total bytes written. """ - rf = 0 - rz = 0 - for p, _, ff in os.walk(root): - rf += len(ff) - for f in ff: - g = os.path.join(p, f) - rz += os.stat(g).st_size - return rf, rz + sizes = [] + for parent, _, file_basenames in sorted(os.walk(root)): + for basename in sorted(file_basenames): + path = os.path.join(parent, basename) + size = os.stat(path).st_size + sizes.append(size) + return sizes + + +class Tabulator: + """Line by line text table printer. 
+ + Example: + conf = ''' + < format 8 + > sec 6 + > samples 12 + > usec/sp 8 + > bytes 14 + > files 6 + > bytes/file 12 + > max bytes/file 14 + ''' + left = 4 * ' ' + tab = Tabulator.from_conf(conf, left) + + Args: + cols (List[Tuple[str, str, int]]: Each column config (i.e., just, name, width). + left (str, optional): Optional string that is printed before each line (e.g., indents). + """ + + def __init__(self, cols: List[Tuple[str, str, int]], left: Optional[str] = None) -> None: + self.cols = cols + self.col_justs = [] + self.col_names = [] + self.col_widths = [] + for just, name, width in cols: + if just not in {'<', '>'}: + raise ValueError(f'Invalid justify (must be one of "<" or ">"): {just}.') + + if not name: + raise ValueError('Name must be non-empty.') + elif width < len(name): + raise ValueError(f'Name is too wide for its column width: {width} vs {name}.') + + if width <= 0: + raise ValueError(f'Width must be positive, but got: {width}.') + + self.col_justs.append(just) + self.col_names.append(name) + self.col_widths.append(width) + + self.left = left + + @classmethod + def from_conf(cls, conf: str, left: Optional[str] = None) -> Self: + """Initialize a Tabulator from a text table defining its columns. + + Args: + conf (str): The table config. + left (str, optional): Optional string that is printed before each line (e.g., indents). + """ + cols = [] + for line in conf.strip().split('\n'): + words = line.split() + + if len(words) < 3: + raise ValueError(f'Invalid col config (must be "just name width"): {line}.') + + just = words[0] + name = ' '.join(words[1:-1]) + width = int(words[-1]) + cols.append((just, name, width)) + return cls(cols) + + def draw_row(self, info: Dict[str, str]) -> str: + fields = [] + for just, name, width in self.cols: + val = info[name] + txt = str(val) + if width < len(txt): + raise ValueError(f'Field is too wide for its column: column (just: {just}, ' + + f'name: {name}, width: {width}) vs field {txt}.') + if just == '<': + txt = txt.ljust(width) + else: + txt = txt.rjust(width) + fields.append(txt) + + left_txt = self.left or '' + fields_txt = ' | '.join(fields) + return f'{left_txt} | {fields_txt} |' + + def draw_header(self) -> str: + info = dict(zip(self.col_names, self.col_names)) + return self.draw_row(info) + + def draw_divider(self) -> str: + seps = ('-' * width for width in self.col_widths) + info = dict(zip(self.col_names, seps)) + text = self.draw_row(info) + return text.replace('|', '+') def main(args: Namespace) -> None: @@ -244,50 +367,85 @@ def main(args: Namespace) -> None: Args: args (Namespace): Command-line arguments. """ - if os.path.exists(args.data_root): - rmtree(args.data_root) - - kinds = 'csv', 'jsonl', 'lance', 'mds', 'parquet', 'delta' - + # Normalize arguments. + format_names = args.formats.split(',') if args.formats else [] show_progress = bool(args.show_progress) quiet_delta = bool(args.quiet_delta) - kind2save = { + # Wipe output directory if exists. + if os.path.exists(args.data_root): + rmtree(args.data_root) + + # Given args, now we know how to configure saving the dataset in each format. 
+ format2write = { 'csv': - partial(_save_csv, size_limit=args.size_limit, show_progress=show_progress), + partial(_write_csv, size_limit=args.size_limit, show_progress=show_progress), 'delta': - partial(_save_delta, samples_per_shard=args.samples_per_shard, quiet=quiet_delta), + partial(_do_write_delta, quietly=quiet_delta, + samples_per_shard=args.samples_per_shard), 'jsonl': - partial(_save_jsonl, size_limit=args.size_limit, show_progress=show_progress), + partial(_write_jsonl, size_limit=args.size_limit, show_progress=show_progress), 'lance': - partial(_save_lance, samples_per_shard=args.samples_per_shard), + partial(_write_lance, samples_per_shard=args.samples_per_shard), 'mds': - partial(_save_mds, size_limit=args.size_limit, show_progress=show_progress), + partial(_write_mds, size_limit=args.size_limit, show_progress=show_progress), 'parquet': - partial(_save_parquet, + partial(_write_parquet, samples_per_shard=args.samples_per_shard, show_progress=show_progress), } - start = time() + # Now, generate the dataset. + t0 = time() dataset = generate_dataset(args.num_train, args.num_val, show_progress) - elapsed = time() - start + elapsed = time() - t0 print(f'Dataset generation: {elapsed:.3f} sec.') + # Confgure the text table printer for dataset writing info. + conf = ''' + < format 8 + > sec 6 + > samples 12 + > usec/sp 8 + > bytes 14 + > files 6 + > bytes/file 12 + > max bytes/file 14 + ''' + left = 4 * ' ' + tab = Tabulator.from_conf(conf, left) + + # Write each split in each desired format. for split, nums, txts in dataset: + print() print(f'Split {split}:') - for kind in kinds: - kind_subdir = getattr(args, kind) - split_root = os.path.join(args.data_root, 'gold', kind_subdir, split) - save = kind2save[kind] - start = time() - save(nums, txts, split_root) - elapsed = time() - start - num_files, num_bytes = _stat(split_root) - bytes_per_file = num_bytes // num_files - print(f'* Saving dataset as {kind:8}: {elapsed:8.3f} sec; {num_files:3,} files; ' + - f'{num_bytes:12,} bytes; {bytes_per_file:12,} bytes/file.') + print(tab.draw_divider()) + print(tab.draw_header()) + print(tab.draw_divider()) + for format_name in format_names: + format_subdir = getattr(args, format_name) + split_root = os.path.join(args.data_root, 'gold', format_subdir, split) + write = format2write[format_name] + + t0 = time() + write(nums, txts, split_root) + elapsed = time() - t0 + + file_sizes = _get_file_sizes(split_root) + pretty_int = lambda num: f'{num:,}' + obj = { + 'format': format_name, + 'sec': f'{elapsed:.3f}', + 'samples': pretty_int(len(nums)), + 'usec/sp': f'{1e6 * elapsed / len(nums):.3f}', + 'bytes': pretty_int(sum(file_sizes)), + 'files': pretty_int(len(file_sizes)), + 'bytes/file': pretty_int(sum(file_sizes) // len(file_sizes)), + 'max bytes/file': pretty_int(max(file_sizes)), + } + print(tab.draw_row(obj)) + print(tab.draw_divider()) if __name__ == '__main__': - main(parse_args()) + main(_parse_args()) From 6dc5e227e8da4a6295ed6b87bfd17d8a8d7041bd Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 5 Nov 2023 01:51:08 -0800 Subject: [PATCH 38/45] Fix (passing `left`, and spacing). 
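
For reference, a rough usage sketch of `Tabulator` as of this change (the column spec,
indent string, and row values below are made up):

```
conf = '''
    < format 8
    > sec 6
'''
tab = Tabulator.from_conf(conf, left=4 * ' ')
print(tab.draw_divider())
print(tab.draw_header())
print(tab.draw_divider())
print(tab.draw_row({'format': 'mds', 'sec': '8.649'}))
```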
--- benchmarks/backends/generate_datasets.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/backends/generate_datasets.py b/benchmarks/backends/generate_datasets.py index eb9f974d0..5448a94a2 100644 --- a/benchmarks/backends/generate_datasets.py +++ b/benchmarks/backends/generate_datasets.py @@ -330,7 +330,7 @@ def from_conf(cls, conf: str, left: Optional[str] = None) -> Self: name = ' '.join(words[1:-1]) width = int(words[-1]) cols.append((just, name, width)) - return cls(cols) + return cls(cols, left) def draw_row(self, info: Dict[str, str]) -> str: fields = [] @@ -347,8 +347,8 @@ def draw_row(self, info: Dict[str, str]) -> str: fields.append(txt) left_txt = self.left or '' - fields_txt = ' | '.join(fields) - return f'{left_txt} | {fields_txt} |' + fields_txt = f' | '.join(fields) + return f'{left_txt}| {fields_txt} |' def draw_header(self) -> str: info = dict(zip(self.col_names, self.col_names)) From 18f64746200f13b4beb7fc80977f1da557be284d Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 5 Nov 2023 01:52:17 -0800 Subject: [PATCH 39/45] Switch to box-drawing chars in Tabulator. Example: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``` ─ ──────── ─ ────── ─ ──────────── ─ ──────── ─ ────────────── ─ ────── ─ ──────────── ─ ────────────── ─ │ format │ sec │ samples │ usec/sp │ bytes │ files │ bytes/file │ max bytes/file │ ─ ──────── ─ ────── ─ ──────────── ─ ──────── ─ ────────────── ─ ────── ─ ──────────── ─ ────────────── ─ │ csv │ 5.131 │ 2,097,152 │ 2.446 │ 171,899,840 │ 41 │ 4,192,679 │ 8,388,616 │ │ jsonl │ 12.535 │ 2,097,152 │ 5.977 │ 211,747,148 │ 51 │ 4,151,904 │ 8,388,607 │ │ lance │ 1.074 │ 2,097,152 │ 0.512 │ 176,961,928 │ 19 │ 9,313,785 │ 11,067,536 │ │ mds │ 8.649 │ 2,097,152 │ 4.124 │ 176,880,177 │ 23 │ 7,690,442 │ 8,388,604 │ │ parquet │ 1.323 │ 2,097,152 │ 0.631 │ 63,528,364 │ 16 │ 3,970,522 │ 3,973,860 │ │ delta │ 16.881 │ 2,097,152 │ 8.050 │ 55,106,514 │ 66 │ 834,947 │ 1,710,970 │ ─ ──────── ─ ────── ─ ──────────── ─ ──────── ─ ────────────── ─ ────── ─ ──────────── ─ ────────────── ─ ``` --- benchmarks/backends/generate_datasets.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/benchmarks/backends/generate_datasets.py b/benchmarks/backends/generate_datasets.py index 5448a94a2..565a9c6d0 100644 --- a/benchmarks/backends/generate_datasets.py +++ b/benchmarks/backends/generate_datasets.py @@ -311,6 +311,9 @@ def __init__(self, cols: List[Tuple[str, str, int]], left: Optional[str] = None) self.left = left + self.box_horiz = chr(0x2500) + self.box_vert = chr(0x2502) + @classmethod def from_conf(cls, conf: str, left: Optional[str] = None) -> Self: """Initialize a Tabulator from a text table defining its columns. 
@@ -347,18 +350,18 @@ def draw_row(self, info: Dict[str, str]) -> str: fields.append(txt) left_txt = self.left or '' - fields_txt = f' | '.join(fields) - return f'{left_txt}| {fields_txt} |' + fields_txt = f' {self.box_vert} '.join(fields) + return f'{left_txt}{self.box_vert} {fields_txt} {self.box_vert}' def draw_header(self) -> str: info = dict(zip(self.col_names, self.col_names)) return self.draw_row(info) def draw_divider(self) -> str: - seps = ('-' * width for width in self.col_widths) + seps = (self.box_horiz * width for width in self.col_widths) info = dict(zip(self.col_names, seps)) text = self.draw_row(info) - return text.replace('|', '+') + return text.replace(self.box_vert, self.box_horiz) def main(args: Namespace) -> None: From 52af2cbf83ac5aeec6311d306bfda9a15722ae37 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 5 Nov 2023 07:53:30 -0800 Subject: [PATCH 40/45] Rewrite task.py. --- benchmarks/backends/task.py | 225 +++++++++++++++++++++++++----------- 1 file changed, 160 insertions(+), 65 deletions(-) diff --git a/benchmarks/backends/task.py b/benchmarks/backends/task.py index 0bf550f55..fbc805c69 100644 --- a/benchmarks/backends/task.py +++ b/benchmarks/backends/task.py @@ -3,105 +3,200 @@ """Generate infinite samples for a 'saying numbers as words' task.""" -from typing import List, Tuple +from typing import Dict, List, Tuple, TypeVar import numpy as np +from numpy.random import Generator from tqdm import tqdm + +def _generate_int(rng: Generator, + pos_prob: float = 0.75, + low: int = -1_000_000_000, + high: int = 1_000_000_000) -> int: + """Pick a random integer to say in words. + + This is a synthetic dataset whose random numbers need to be distinct, deterministic given a + seed, and little else. We choose a distribution that seems the most pleasing to us. + + Properties: + * About 80% positive and 20% negative. + * Magnitude of up to a billion on either side of zero. + * Strongly skewed toward the origin, i.e. chosen uniformly across base-10 digit lengths (at + least until running out of integers of that length anyway). + + Args: + rng (Generator): NumPy random number generator. + pos_prob (float): Probability of output being positive. Defaults to ``0.75``. + low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``. + high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``. + """ + if not 0 <= pos_prob <= 1: + raise ValueError(f'Invalid positive probability ``pos_prob``: 0 <= {pos_prob} <= 1.') + + if not low < 0 < high: + raise ValueError(f'Invalid sampling range ``low`` and/or ``high``: {low} < 0 < {high}.') + + is_pos = rng.uniform() < pos_prob + max_digits = np.log10(high) if is_pos else np.log10(-low) + power = rng.uniform(0, max_digits) + magnitude = int(10**power) + sign = is_pos * 2 - 1 + return sign * magnitude + + +def _generate_ints(count: int, + seed: int = 0x1337, + pos_prob: float = 0.75, + low: int = -1_000_000_000, + high: int = 1_000_000_000, + show_progress: bool = True) -> List[int]: + """Sample until we have the given number of distinct integers. + + Args: + count (int): How many samples to draw. + seed (int): Seed for the random number generator. Defaults to ``0x1337``. + pos_prob (float): Probability of output being positive. Defaults to ``0.75``. + low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``. + high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``. + show_progress (bool): Whether to display a progress bar. 
Defaults to ``True``. + + Returns: + List[int]: The integers that were drawn. + """ + rng = np.random.default_rng(seed) + nums = set() + progress_bar = tqdm(total=count, leave=False) if show_progress else None + while len(nums) < count: + num = _generate_int(rng) + if num in nums: + continue + + nums.add(num) + if progress_bar: + progress_bar.update(1) + if progress_bar: + progress_bar.close() + + nums = sorted(nums) + rng.shuffle(nums) + return nums + + _ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen ' 'fifteen sixteen seventeen eighteen nineteen').split() _tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split() -def _say(i: int) -> List[str]: - """Get the word form of a number. +def _int_to_words(num: int) -> List[str]: + """Say an integer as a list of words. Args: - i (int): The number. + num (int): The integer. Returns: - List[str]: The number in word form. + List[str]: The integer as a list of words. """ - if i < 0: - return ['negative'] + _say(-i) - elif i <= 19: - return [_ones[i]] - elif i < 100: - return [_tens[i // 10 - 2]] + ([_ones[i % 10]] if i % 10 else []) - elif i < 1_000: - return [_ones[i // 100], 'hundred'] + (_say(i % 100) if i % 100 else []) - elif i < 1_000_000: - return _say(i // 1_000) + ['thousand'] + (_say(i % 1_000) if i % 1_000 else []) - elif i < 1_000_000_000: - return _say(i // 1_000_000) + ['million'] + (_say(i % 1_000_000) if i % 1_000_000 else []) + if num < 0: + return ['negative'] + _int_to_words(-num) + elif num <= 19: + return [_ones[num]] + elif num < 100: + tens = [_tens[num // 10 - 2]] + ones = [_ones[num % 10]] if num % 10 else [] + return tens + ones + elif num < 1_000: + hundreds = [_ones[num // 100], 'hundred'] + etc = _int_to_words(num % 100) if num % 100 else [] + return hundreds + etc + elif num < 1_000_000: + thousands = _int_to_words(num // 1_000) + ['thousand'] + etc = _int_to_words(num % 1_000) if num % 1_000 else [] + return thousands + etc + elif num < 1_000_000_000: + millions = _int_to_words(num // 1_000_000) + ['million'] + etc = _int_to_words(num % 1_000_000) if num % 1_000_000 else [] + return millions + etc else: - raise ValueError('Integer must be less than a billion, but got: {i}') + raise ValueError('Integer out of range: -1,000,000,000 < {num} < +1,000,000,000.') + +def _int_to_text(num: int) -> str: + """Say an integer as text. -def _generate_number() -> int: - """Generate a random integer to say. + Args: + num (int): The integer. Returns: - int: The integer. + str: The integer as text. """ - sign = (np.random.uniform() < 0.8) * 2 - 1 - expt = np.random.uniform(0, 9) - mag = int(10**expt) - return sign * mag + words = _int_to_words(num) + return ' '.join(words) + +T = TypeVar('T') -def _generate_numbers(num_train: int, num_val: int, - show_progress: bool) -> Tuple[List[int], List[int]]: - """Get two non-overlapping splits of integers to say. + +def _split(items: List[T], sizes: List[int]) -> List[List[T]]: + """Divide the given items across the splits given by their sizes. Args: - num_train (int): Number of training samples. - num_val (int): Number of validation samples. - show_progress (bool): Whether to display a progress bar. + items (List[Any]): The items to divide across the spans. + sizes (List[int]): Number of items per split. Returns: - Tuple[List[int], List[int]]: The two generated splits. + List[List[Any]]: Each split of items. 
""" - total = num_train + num_val - nums = set() - pbar = tqdm(total=total, leave=False) if show_progress else None - while len(nums) < total: - num = _generate_number() - if num in nums: - continue - nums.add(num) - if pbar: - pbar.update(1) - if pbar: - pbar.close() - nums = sorted(nums) - np.random.shuffle(nums) - train_nums = nums[:num_train] - val_nums = nums[num_train:] - return train_nums, val_nums + arr = np.asarray(sizes, np.int64) + ends = arr.cumsum() + begins = ends - arr[0] + + if len(items) != ends[-1]: + raise ValueError(f'Number of items must match the combined size of the splits: ' + + f'{len(items)} items vs splits of size {sizes} = {ends[-1]}.') + splits = [] + for begin, end in zip(begins, ends): + split = items[begin:end] + splits.append(split) -_split_type = Tuple[str, List[int], List[str]] + return splits -def generate_dataset(num_train: int, num_val: int, show_progress: bool) -> List[_split_type]: - """Generate the dataset, which will be saved in different forms for comparison. +def generate(splits: Dict[str, int], + seed: int = 0x1337, + pos_prob: float = 0.75, + low: int = -1_000_000_000, + high: int = 1_000_000_000, + show_progress: bool = True) -> Dict[str, Tuple[List[int], List[str]]]: + """Generate a dataset, made of splits, to be saved in different forms for comparison. Args: - num_train (int): Number of train samples. - num_val (int): Number of val samples. - show_progress (bool): Whether to show a progress bar. + splits (Dict[str, int]): Mapping of split name to size in samples. + seed (int): Seed for the random number generator. Defaults to ``0x1337``. + pos_prob (float): Probability of output being positive. Defaults to ``0.75``. + low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``. + high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``. + show_progress (bool): Whether to show a progress bar. Defaults to ``True``. Returns: - List[Tuple[str, List[int], List[str]]]: List of dataset splits. + Dict[str, Tuple[List[int], List[str]]]: Mapping of split name to nums and texts. """ - train_nums, val_nums = _generate_numbers(num_train, num_val, show_progress) - - train_txts = [' '.join(_say(num)) for num in train_nums] - val_txts = [' '.join(_say(num)) for num in val_nums] - - return [ - ('train', train_nums, train_txts), - ('val', val_nums, val_txts), - ] + split_sizes = [] + total = 0 + for name in sorted(splits): + size = splits[name] + split_sizes.append(size) + total += size + + nums = _generate_ints(total, seed, low, high, show_progress) + nums_per_split = _split(nums, split_sizes) + + texts = list(map(_int_to_text, nums)) + texts_per_split = _split(texts, split_sizes) + + dataset = {} + for index, name in enumerate(sorted(splits)): + dataset[name] = nums_per_split[index], texts_per_split[index] + return dataset From bc125b481c45fc6e23c5dda8e1fc7057452c9497 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 5 Nov 2023 08:59:05 -0800 Subject: [PATCH 41/45] Fixes. 
--- benchmarks/backends/generate_datasets.py | 57 ++++++++++++++++++------ benchmarks/backends/task.py | 30 ++++++------- 2 files changed, 58 insertions(+), 29 deletions(-) diff --git a/benchmarks/backends/generate_datasets.py b/benchmarks/backends/generate_datasets.py index 565a9c6d0..a4685a785 100644 --- a/benchmarks/backends/generate_datasets.py +++ b/benchmarks/backends/generate_datasets.py @@ -5,10 +5,11 @@ import os from argparse import ArgumentParser, Namespace +from collections import defaultdict from functools import partial from shutil import rmtree from time import time -from typing import Dict, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple import lance import pyarrow as pa @@ -17,7 +18,7 @@ from delta import configure_spark_with_delta_pip from pyarrow import parquet as pq from pyspark.sql.types import IntegerType, StringType, StructField, StructType -from task import generate_dataset +from task import generate from tqdm import tqdm from typing_extensions import Self from wurlitzer import pipes @@ -37,10 +38,11 @@ def _parse_args() -> Namespace: args.add_argument('--seed', type=int, default=1337) # Dataset and shard sizes. - args.add_argument('--num_train', type=int, default=1 << 21) - args.add_argument('--num_val', type=int, default=1 << 17) + args.add_argument('--small', type=int, default=1 << 16) + args.add_argument('--medium', type=int, default=1 << 20) + args.add_argument('--large', type=int, default=1 << 24) args.add_argument('--size_limit', type=int, default=1 << 23) - args.add_argument('--samples_per_shard', type=int, default=1 << 17) + args.add_argument('--samples_per_shard', type=int, default=1 << 18) # Output root. args.add_argument('--data_root', type=str, default='data/backends/') @@ -197,7 +199,7 @@ def _write_delta(nums: List[int], txts: List[str], root: str, samples_per_shard: root (str): Root directory. samples_per_shard (int): Maximum numbero of samples per shard. """ - builder = pyspark.sql.SparkSession.builder.appName('deltatorch-example') # pyright: ignore + builder = pyspark.sql.SparkSession.builder.appName('prolix') # pyright: ignore builder = builder.config('spark.sql.extensions', 'io.delta.sql.DeltaSparkSessionExtension') builder = builder.config('spark.sql.catalog.spark_catalog', 'org.apache.spark.sql.delta.catalog.DeltaCatalog') @@ -364,6 +366,27 @@ def draw_divider(self) -> str: return text.replace(self.box_vert, self.box_horiz) +def _splits_by_size(dataset: Dict[str, Tuple[List[int], List[str]]]) -> Iterable[str]: + """Order a dataset's splits by their size in samples, then by name. + + Argxs: + dataset (Dict[str, Tuple[List[int], List[str]]]): Mapping of split name to split data. + + Returns: + Iterable[str]: Ordered split names. + """ + size2splits = defaultdict(list) + for split, (nums, _) in dataset.items(): + size2splits[len(nums)].append(split) + + splits_by_size = [] + for size in sorted(size2splits): + for split in sorted(size2splits[size]): + splits_by_size.append(split) + + return splits_by_size + + def main(args: Namespace) -> None: """Generate identical datasets in various formats for performance comparison. @@ -374,6 +397,11 @@ def main(args: Namespace) -> None: format_names = args.formats.split(',') if args.formats else [] show_progress = bool(args.show_progress) quiet_delta = bool(args.quiet_delta) + split2size = { + 'small': args.small, + 'medium': args.medium, + 'large': args.large, + } # Wipe output directory if exists. 
if os.path.exists(args.data_root): @@ -400,9 +428,9 @@ def main(args: Namespace) -> None: # Now, generate the dataset. t0 = time() - dataset = generate_dataset(args.num_train, args.num_val, show_progress) + dataset = generate(split2size, show_progress) elapsed = time() - t0 - print(f'Dataset generation: {elapsed:.3f} sec.') + print(f'Generate: {elapsed:.3f} sec.') # Confgure the text table printer for dataset writing info. conf = ''' @@ -418,13 +446,14 @@ def main(args: Namespace) -> None: left = 4 * ' ' tab = Tabulator.from_conf(conf, left) - # Write each split in each desired format. - for split, nums, txts in dataset: + # Write each split in each desired formats, in order of size. + for split in _splits_by_size(dataset): print() - print(f'Split {split}:') - print(tab.draw_divider()) + print(f'Write split: {split}') + print(tab.draw_line()) print(tab.draw_header()) - print(tab.draw_divider()) + print(tab.draw_line()) + nums, txts = dataset[split] for format_name in format_names: format_subdir = getattr(args, format_name) split_root = os.path.join(args.data_root, 'gold', format_subdir, split) @@ -447,7 +476,7 @@ def main(args: Namespace) -> None: 'max bytes/file': pretty_int(max(file_sizes)), } print(tab.draw_row(obj)) - print(tab.draw_divider()) + print(tab.draw_line()) if __name__ == '__main__': diff --git a/benchmarks/backends/task.py b/benchmarks/backends/task.py index fbc805c69..ab83aab5a 100644 --- a/benchmarks/backends/task.py +++ b/benchmarks/backends/task.py @@ -39,8 +39,8 @@ def _generate_int(rng: Generator, is_pos = rng.uniform() < pos_prob max_digits = np.log10(high) if is_pos else np.log10(-low) - power = rng.uniform(0, max_digits) - magnitude = int(10**power) + exponent = rng.uniform(0, max_digits) + magnitude = int(10**exponent) sign = is_pos * 2 - 1 return sign * magnitude @@ -148,23 +148,22 @@ def _split(items: List[T], sizes: List[int]) -> List[List[T]]: Returns: List[List[Any]]: Each split of items. """ - arr = np.asarray(sizes, np.int64) - ends = arr.cumsum() - begins = ends - arr[0] - - if len(items) != ends[-1]: + total = sum(sizes) + if len(items) != total: raise ValueError(f'Number of items must match the combined size of the splits: ' + - f'{len(items)} items vs splits of size {sizes} = {ends[-1]}.') + f'{len(items)} items vs splits of size {sizes} = {total}.') splits = [] - for begin, end in zip(begins, ends): - split = items[begin:end] + begin = 0 + for size in sizes: + split = items[begin:begin + size] splits.append(split) + begin += size return splits -def generate(splits: Dict[str, int], +def generate(split2size: Dict[str, int], seed: int = 0x1337, pos_prob: float = 0.75, low: int = -1_000_000_000, @@ -173,7 +172,7 @@ def generate(splits: Dict[str, int], """Generate a dataset, made of splits, to be saved in different forms for comparison. Args: - splits (Dict[str, int]): Mapping of split name to size in samples. + split2size (Dict[str, int]): Mapping of split name to size in samples. seed (int): Seed for the random number generator. Defaults to ``0x1337``. pos_prob (float): Probability of output being positive. Defaults to ``0.75``. low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``. 
@@ -185,8 +184,8 @@ def generate(splits: Dict[str, int], """ split_sizes = [] total = 0 - for name in sorted(splits): - size = splits[name] + for split in sorted(split2size): + size = split2size[split] split_sizes.append(size) total += size @@ -197,6 +196,7 @@ def generate(splits: Dict[str, int], texts_per_split = _split(texts, split_sizes) dataset = {} - for index, name in enumerate(sorted(splits)): + for index, name in enumerate(sorted(split2size)): dataset[name] = nums_per_split[index], texts_per_split[index] + return dataset From a2ff86f572abc511b511201d6a8344c083be34d1 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 5 Nov 2023 09:10:58 -0800 Subject: [PATCH 42/45] Fix. --- benchmarks/backends/generate_datasets.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/benchmarks/backends/generate_datasets.py b/benchmarks/backends/generate_datasets.py index a4685a785..7b3b0d264 100644 --- a/benchmarks/backends/generate_datasets.py +++ b/benchmarks/backends/generate_datasets.py @@ -341,14 +341,13 @@ def draw_row(self, info: Dict[str, str]) -> str: fields = [] for just, name, width in self.cols: val = info[name] - txt = str(val) + + txt = val if isinstance(val, str) else str(val) if width < len(txt): raise ValueError(f'Field is too wide for its column: column (just: {just}, ' + f'name: {name}, width: {width}) vs field {txt}.') - if just == '<': - txt = txt.ljust(width) - else: - txt = txt.rjust(width) + + txt = txt.ljust(width) if just == '<' else txt.rjust(width) fields.append(txt) left_txt = self.left or '' @@ -359,7 +358,7 @@ def draw_header(self) -> str: info = dict(zip(self.col_names, self.col_names)) return self.draw_row(info) - def draw_divider(self) -> str: + def draw_line(self) -> str: seps = (self.box_horiz * width for width in self.col_widths) info = dict(zip(self.col_names, seps)) text = self.draw_row(info) From 57e7571208c7dbda5fe29aae7be3fbf65c4dceb2 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 5 Nov 2023 09:41:01 -0800 Subject: [PATCH 43/45] Misc. --- benchmarks/backends/generate_datasets.py | 54 ++++++++++++++++-------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/benchmarks/backends/generate_datasets.py b/benchmarks/backends/generate_datasets.py index 7b3b0d264..9758bad6c 100644 --- a/benchmarks/backends/generate_datasets.py +++ b/benchmarks/backends/generate_datasets.py @@ -9,7 +9,7 @@ from functools import partial from shutil import rmtree from time import time -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import lance import pyarrow as pa @@ -48,7 +48,7 @@ def _parse_args() -> Namespace: args.add_argument('--data_root', type=str, default='data/backends/') # Formats to output. - args.add_argument('--formats', type=str, default='csv,jsonl,lance,mds,parquet,delta') + args.add_argument('--formats', type=str, default='csv,delta,jsonl,lance,mds,parquet') # Output subdir per format. args.add_argument('--csv', type=str, default='csv') @@ -287,7 +287,7 @@ class Tabulator: Args: cols (List[Tuple[str, str, int]]: Each column config (i.e., just, name, width). - left (str, optional): Optional string that is printed before each line (e.g., indents). + left (str, optional): Print this before each line (e.g., indenting). Defaults to ``None``. 
""" def __init__(self, cols: List[Tuple[str, str, int]], left: Optional[str] = None) -> None: @@ -313,8 +313,8 @@ def __init__(self, cols: List[Tuple[str, str, int]], left: Optional[str] = None) self.left = left - self.box_horiz = chr(0x2500) - self.box_vert = chr(0x2502) + self.box_chr_horiz = chr(0x2500) + self.box_chr_vert = chr(0x2502) @classmethod def from_conf(cls, conf: str, left: Optional[str] = None) -> Self: @@ -337,10 +337,18 @@ def from_conf(cls, conf: str, left: Optional[str] = None) -> Self: cols.append((just, name, width)) return cls(cols, left) - def draw_row(self, info: Dict[str, str]) -> str: + def draw_row(self, row: Dict[str, Any]) -> str: + """Draw a row, given a mapping of column name to field value. + + Args: + row (Dict[str, Any]): Mapping of column name to field value. + + Returns: + str: Text line. + """ fields = [] for just, name, width in self.cols: - val = info[name] + val = row[name] txt = val if isinstance(val, str) else str(val) if width < len(txt): @@ -351,18 +359,28 @@ def draw_row(self, info: Dict[str, str]) -> str: fields.append(txt) left_txt = self.left or '' - fields_txt = f' {self.box_vert} '.join(fields) - return f'{left_txt}{self.box_vert} {fields_txt} {self.box_vert}' + fields_txt = f' {self.box_chr_vert} '.join(fields) + return f'{left_txt}{self.box_chr_vert} {fields_txt} {self.box_chr_vert}' def draw_header(self) -> str: - info = dict(zip(self.col_names, self.col_names)) - return self.draw_row(info) + """Draw a header row. + + Returns: + str: Text line. + """ + row = dict(zip(self.col_names, self.col_names)) + return self.draw_row(row) def draw_line(self) -> str: - seps = (self.box_horiz * width for width in self.col_widths) - info = dict(zip(self.col_names, seps)) - text = self.draw_row(info) - return text.replace(self.box_vert, self.box_horiz) + """Draw a divider row. + + Returns: + str: Text line. + """ + seps = (self.box_chr_horiz * width for width in self.col_widths) + row = dict(zip(self.col_names, seps)) + line = self.draw_row(row) + return line.replace(self.box_chr_vert, self.box_chr_horiz) def _splits_by_size(dataset: Dict[str, Tuple[List[int], List[str]]]) -> Iterable[str]: @@ -446,6 +464,7 @@ def main(args: Namespace) -> None: tab = Tabulator.from_conf(conf, left) # Write each split in each desired formats, in order of size. + pretty_int = lambda num: f'{num:,}' for split in _splits_by_size(dataset): print() print(f'Write split: {split}') @@ -463,8 +482,7 @@ def main(args: Namespace) -> None: elapsed = time() - t0 file_sizes = _get_file_sizes(split_root) - pretty_int = lambda num: f'{num:,}' - obj = { + row = { 'format': format_name, 'sec': f'{elapsed:.3f}', 'samples': pretty_int(len(nums)), @@ -474,7 +492,7 @@ def main(args: Namespace) -> None: 'bytes/file': pretty_int(sum(file_sizes) // len(file_sizes)), 'max bytes/file': pretty_int(max(file_sizes)), } - print(tab.draw_row(obj)) + print(tab.draw_row(row)) print(tab.draw_line()) From f1e10bb62912ed2c17fe9fe18608f430b6f6ab2e Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 5 Nov 2023 11:56:52 -0800 Subject: [PATCH 44/45] Split out Tabulator. 
--- benchmarks/backends/generate_datasets.py | 119 +--------------------- benchmarks/backends/task.py | 4 +- streaming/util/tabulator.py | 123 +++++++++++++++++++++++ 3 files changed, 127 insertions(+), 119 deletions(-) create mode 100644 streaming/util/tabulator.py diff --git a/benchmarks/backends/generate_datasets.py b/benchmarks/backends/generate_datasets.py index 9758bad6c..9cdb511e6 100644 --- a/benchmarks/backends/generate_datasets.py +++ b/benchmarks/backends/generate_datasets.py @@ -9,7 +9,7 @@ from functools import partial from shutil import rmtree from time import time -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple import lance import pyarrow as pa @@ -20,10 +20,10 @@ from pyspark.sql.types import IntegerType, StringType, StructField, StructType from task import generate from tqdm import tqdm -from typing_extensions import Self from wurlitzer import pipes from streaming import CSVWriter, JSONWriter, MDSWriter +from streaming.base.util.tabulator import Tabulator def _parse_args() -> Namespace: @@ -268,121 +268,6 @@ def _get_file_sizes(root: str) -> List[int]: return sizes -class Tabulator: - """Line by line text table printer. - - Example: - conf = ''' - < format 8 - > sec 6 - > samples 12 - > usec/sp 8 - > bytes 14 - > files 6 - > bytes/file 12 - > max bytes/file 14 - ''' - left = 4 * ' ' - tab = Tabulator.from_conf(conf, left) - - Args: - cols (List[Tuple[str, str, int]]: Each column config (i.e., just, name, width). - left (str, optional): Print this before each line (e.g., indenting). Defaults to ``None``. - """ - - def __init__(self, cols: List[Tuple[str, str, int]], left: Optional[str] = None) -> None: - self.cols = cols - self.col_justs = [] - self.col_names = [] - self.col_widths = [] - for just, name, width in cols: - if just not in {'<', '>'}: - raise ValueError(f'Invalid justify (must be one of "<" or ">"): {just}.') - - if not name: - raise ValueError('Name must be non-empty.') - elif width < len(name): - raise ValueError(f'Name is too wide for its column width: {width} vs {name}.') - - if width <= 0: - raise ValueError(f'Width must be positive, but got: {width}.') - - self.col_justs.append(just) - self.col_names.append(name) - self.col_widths.append(width) - - self.left = left - - self.box_chr_horiz = chr(0x2500) - self.box_chr_vert = chr(0x2502) - - @classmethod - def from_conf(cls, conf: str, left: Optional[str] = None) -> Self: - """Initialize a Tabulator from a text table defining its columns. - - Args: - conf (str): The table config. - left (str, optional): Optional string that is printed before each line (e.g., indents). - """ - cols = [] - for line in conf.strip().split('\n'): - words = line.split() - - if len(words) < 3: - raise ValueError(f'Invalid col config (must be "just name width"): {line}.') - - just = words[0] - name = ' '.join(words[1:-1]) - width = int(words[-1]) - cols.append((just, name, width)) - return cls(cols, left) - - def draw_row(self, row: Dict[str, Any]) -> str: - """Draw a row, given a mapping of column name to field value. - - Args: - row (Dict[str, Any]): Mapping of column name to field value. - - Returns: - str: Text line. 
- """ - fields = [] - for just, name, width in self.cols: - val = row[name] - - txt = val if isinstance(val, str) else str(val) - if width < len(txt): - raise ValueError(f'Field is too wide for its column: column (just: {just}, ' + - f'name: {name}, width: {width}) vs field {txt}.') - - txt = txt.ljust(width) if just == '<' else txt.rjust(width) - fields.append(txt) - - left_txt = self.left or '' - fields_txt = f' {self.box_chr_vert} '.join(fields) - return f'{left_txt}{self.box_chr_vert} {fields_txt} {self.box_chr_vert}' - - def draw_header(self) -> str: - """Draw a header row. - - Returns: - str: Text line. - """ - row = dict(zip(self.col_names, self.col_names)) - return self.draw_row(row) - - def draw_line(self) -> str: - """Draw a divider row. - - Returns: - str: Text line. - """ - seps = (self.box_chr_horiz * width for width in self.col_widths) - row = dict(zip(self.col_names, seps)) - line = self.draw_row(row) - return line.replace(self.box_chr_vert, self.box_chr_horiz) - - def _splits_by_size(dataset: Dict[str, Tuple[List[int], List[str]]]) -> Iterable[str]: """Order a dataset's splits by their size in samples, then by name. diff --git a/benchmarks/backends/task.py b/benchmarks/backends/task.py index ab83aab5a..d354f144c 100644 --- a/benchmarks/backends/task.py +++ b/benchmarks/backends/task.py @@ -196,7 +196,7 @@ def generate(split2size: Dict[str, int], texts_per_split = _split(texts, split_sizes) dataset = {} - for index, name in enumerate(sorted(split2size)): - dataset[name] = nums_per_split[index], texts_per_split[index] + for index, split in enumerate(sorted(split2size)): + dataset[split] = nums_per_split[index], texts_per_split[index] return dataset diff --git a/streaming/util/tabulator.py b/streaming/util/tabulator.py new file mode 100644 index 000000000..a6cd9454a --- /dev/null +++ b/streaming/util/tabulator.py @@ -0,0 +1,123 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Line by line text table printer.""" + +from typing import Any, Dict, List, Optional, Tuple + +from typing_extensions import Self + + +class Tabulator: + """Line by line text table printer. + + Example: + conf = ''' + < format 8 + > sec 6 + > samples 12 + > usec/sp 8 + > bytes 14 + > files 6 + > bytes/file 12 + > max bytes/file 14 + ''' + left = 4 * ' ' + tab = Tabulator.from_conf(conf, left) + + Args: + cols (List[Tuple[str, str, int]]: Each column config (i.e., just, name, width). + left (str, optional): Print this before each line (e.g., indenting). Defaults to ``None``. + """ + + def __init__(self, cols: List[Tuple[str, str, int]], left: Optional[str] = None) -> None: + self.cols = cols + self.col_justs = [] + self.col_names = [] + self.col_widths = [] + for just, name, width in cols: + if just not in {'<', '>'}: + raise ValueError(f'Invalid justify (must be one of "<" or ">"): {just}.') + + if not name: + raise ValueError('Name must be non-empty.') + elif width < len(name): + raise ValueError(f'Name is too wide for its column width: {width} vs {name}.') + + if width <= 0: + raise ValueError(f'Width must be positive, but got: {width}.') + + self.col_justs.append(just) + self.col_names.append(name) + self.col_widths.append(width) + + self.left = left + + self.box_chr_horiz = chr(0x2500) + self.box_chr_vert = chr(0x2502) + + @classmethod + def from_conf(cls, conf: str, left: Optional[str] = None) -> Self: + """Initialize a Tabulator from a text table defining its columns. + + Args: + conf (str): The table config. 
+ left (str, optional): Optional string that is printed before each line (e.g., indents). + """ + cols = [] + for line in conf.strip().split('\n'): + words = line.split() + + if len(words) < 3: + raise ValueError(f'Invalid col config (must be "just name width"): {line}.') + + just = words[0] + name = ' '.join(words[1:-1]) + width = int(words[-1]) + cols.append((just, name, width)) + return cls(cols, left) + + def draw_row(self, row: Dict[str, Any]) -> str: + """Draw a row, given a mapping of column name to field value. + + Args: + row (Dict[str, Any]): Mapping of column name to field value. + + Returns: + str: Text line. + """ + fields = [] + for just, name, width in self.cols: + val = row[name] + + txt = val if isinstance(val, str) else str(val) + if width < len(txt): + raise ValueError(f'Field is too wide for its column: column (just: {just}, ' + + f'name: {name}, width: {width}) vs field {txt}.') + + txt = txt.ljust(width) if just == '<' else txt.rjust(width) + fields.append(txt) + + left_txt = self.left or '' + fields_txt = f' {self.box_chr_vert} '.join(fields) + return f'{left_txt}{self.box_chr_vert} {fields_txt} {self.box_chr_vert}' + + def draw_header(self) -> str: + """Draw a header row. + + Returns: + str: Text line. + """ + row = dict(zip(self.col_names, self.col_names)) + return self.draw_row(row) + + def draw_line(self) -> str: + """Draw a divider row. + + Returns: + str: Text line. + """ + seps = (self.box_chr_horiz * width for width in self.col_widths) + row = dict(zip(self.col_names, seps)) + line = self.draw_row(row) + return line.replace(self.box_chr_vert, self.box_chr_horiz) From cbfcab39e1dea9eefd161aa944105fb9201f3202 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Sun, 5 Nov 2023 13:46:58 -0800 Subject: [PATCH 45/45] Refactor. 
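The refactor folds the task.py generators into the benchmark script. For reference, the sampling idea they carry over, sketched standalone with the patch's default seed and bounds (illustrative, not part of the diff):

    import numpy as np

    rng = np.random.default_rng(0x1337)
    pos_prob, low, high = 0.75, -1_000_000_000, 1_000_000_000

    # Pick a uniform base-10 exponent rather than a uniform value, so drawn
    # magnitudes are spread evenly across digit lengths and skew toward zero.
    is_pos = rng.uniform() < pos_prob
    max_digits = np.log10(high) if is_pos else np.log10(-low)
    magnitude = int(10 ** rng.uniform(0, max_digits))
    num = magnitude if is_pos else -magnitude

With pos_prob at 0.75, roughly a quarter of draws come out negative, and small magnitudes dominate the draws.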
--- benchmarks/__init__.py | 4 + .../{generate_datasets.py => generate.py} | 292 +++++++++++++++--- benchmarks/backends/task.py | 202 ------------ 3 files changed, 248 insertions(+), 250 deletions(-) create mode 100644 benchmarks/__init__.py rename benchmarks/backends/{generate_datasets.py => generate.py} (59%) delete mode 100644 benchmarks/backends/task.py diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 000000000..62e9d1f0e --- /dev/null +++ b/benchmarks/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Streaming benchmarks.""" diff --git a/benchmarks/backends/generate_datasets.py b/benchmarks/backends/generate.py similarity index 59% rename from benchmarks/backends/generate_datasets.py rename to benchmarks/backends/generate.py index 9cdb511e6..e1f18b015 100644 --- a/benchmarks/backends/generate_datasets.py +++ b/benchmarks/backends/generate.py @@ -1,29 +1,29 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Generate a parquet dataset for testing.""" - +"""Generate copies of the same dataset in different Streaming formats.""" import os from argparse import ArgumentParser, Namespace from collections import defaultdict from functools import partial from shutil import rmtree from time import time -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional, Tuple, TypeVar import lance +import numpy as np import pyarrow as pa import pyspark import pyspark.sql from delta import configure_spark_with_delta_pip +from numpy.random import Generator from pyarrow import parquet as pq from pyspark.sql.types import IntegerType, StringType, StructField, StructType -from task import generate from tqdm import tqdm from wurlitzer import pipes from streaming import CSVWriter, JSONWriter, MDSWriter -from streaming.base.util.tabulator import Tabulator +from streaming.util.tabulator import Tabulator def _parse_args() -> Namespace: @@ -37,10 +37,15 @@ def _parse_args() -> Namespace: # Reproducibility. args.add_argument('--seed', type=int, default=1337) - # Dataset and shard sizes. - args.add_argument('--small', type=int, default=1 << 16) + # Dataset properties. + args.add_argument('--data_pos_prob', type=float, default=0.75) + args.add_argument('--data_low', type=int, default=-1_000_000_000) + args.add_argument('--data_high', type=int, default=1_000_000_000) + + # Sizes of dataset splits and shards. + args.add_argument('--small', type=int, default=1 << 15) args.add_argument('--medium', type=int, default=1 << 20) - args.add_argument('--large', type=int, default=1 << 24) + args.add_argument('--large', type=int, default=1 << 25) args.add_argument('--size_limit', type=int, default=1 << 23) args.add_argument('--samples_per_shard', type=int, default=1 << 18) @@ -50,14 +55,6 @@ def _parse_args() -> Namespace: # Formats to output. args.add_argument('--formats', type=str, default='csv,delta,jsonl,lance,mds,parquet') - # Output subdir per format. - args.add_argument('--csv', type=str, default='csv') - args.add_argument('--jsonl', type=str, default='jsonl') - args.add_argument('--lance', type=str, default='lance') - args.add_argument('--mds', type=str, default='mds') - args.add_argument('--parquet', type=str, default='parquet') - args.add_argument('--delta', type=str, default='delta') - # Logging. 
args.add_argument('--show_progress', type=int, default=1) args.add_argument('--quiet_delta', type=int, default=1) @@ -65,6 +62,198 @@ def _parse_args() -> Namespace: return args.parse_args() +def _generate_int(rng: Generator, + pos_prob: float = 0.75, + low: int = -1_000_000_000, + high: int = 1_000_000_000) -> int: + """Pick a random integer to say in words. + + This is a synthetic dataset whose random numbers need to be distinct, deterministic given a + seed, and little else. We choose a distribution that seems the most pleasing to us. + + Properties: + * About 80% positive and 20% negative. + * Magnitude of up to a billion on either side of zero. + * Strongly skewed toward the origin, i.e. chosen uniformly across base-10 digit lengths (at + least until running out of integers of that length anyway). + + Args: + rng (Generator): NumPy random number generator. + pos_prob (float): Probability of output being positive. Defaults to ``0.75``. + low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``. + high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``. + """ + if not 0 <= pos_prob <= 1: + raise ValueError(f'Invalid positive probability ``pos_prob``: 0 <= {pos_prob} <= 1.') + + if not low < 0 < high: + raise ValueError(f'Invalid sampling range ``low`` and/or ``high``: {low} < 0 < {high}.') + + is_pos = rng.uniform() < pos_prob + max_digits = np.log10(high) if is_pos else np.log10(-low) + exponent = rng.uniform(0, max_digits) + magnitude = int(10**exponent) + sign = is_pos * 2 - 1 + return sign * magnitude + + +def _generate_ints(count: int, + seed: int = 0x1337, + pos_prob: float = 0.75, + low: int = -1_000_000_000, + high: int = 1_000_000_000, + show_progress: bool = True) -> List[int]: + """Sample until we have the given number of distinct integers. + + Args: + count (int): How many samples to draw. + seed (int): Seed for the random number generator. Defaults to ``0x1337``. + pos_prob (float): Probability of output being positive. Defaults to ``0.75``. + low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``. + high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``. + show_progress (bool): Whether to display a progress bar. Defaults to ``True``. + + Returns: + List[int]: The integers that were drawn. + """ + rng = np.random.default_rng(seed) + nums = set() + progress_bar = tqdm(total=count, leave=False) if show_progress else None + while len(nums) < count: + num = _generate_int(rng) + if num in nums: + continue + + nums.add(num) + if progress_bar: + progress_bar.update(1) + if progress_bar: + progress_bar.close() + + nums = sorted(nums) + rng.shuffle(nums) + return nums + + +_ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen ' + 'fifteen sixteen seventeen eighteen nineteen').split() + +_tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split() + + +def _int_to_words(num: int) -> List[str]: + """Say an integer as a list of words. + + Args: + num (int): The integer. + + Returns: + List[str]: The integer as a list of words. 
+ """ + if num < 0: + return ['negative'] + _int_to_words(-num) + elif num <= 19: + return [_ones[num]] + elif num < 100: + tens = [_tens[num // 10 - 2]] + ones = [_ones[num % 10]] if num % 10 else [] + return tens + ones + elif num < 1_000: + hundreds = [_ones[num // 100], 'hundred'] + etc = _int_to_words(num % 100) if num % 100 else [] + return hundreds + etc + elif num < 1_000_000: + thousands = _int_to_words(num // 1_000) + ['thousand'] + etc = _int_to_words(num % 1_000) if num % 1_000 else [] + return thousands + etc + elif num < 1_000_000_000: + millions = _int_to_words(num // 1_000_000) + ['million'] + etc = _int_to_words(num % 1_000_000) if num % 1_000_000 else [] + return millions + etc + else: + raise ValueError('Integer out of range: -1,000,000,000 < {num} < +1,000,000,000.') + + +def _int_to_text(num: int) -> str: + """Say an integer as text. + + Args: + num (int): The integer. + + Returns: + str: The integer as text. + """ + words = _int_to_words(num) + return ' '.join(words) + + +T = TypeVar('T') + + +def _split(items: List[T], sizes: List[int]) -> List[List[T]]: + """Divide the given items across the splits given by their sizes. + + Args: + items (List[Any]): The items to divide across the spans. + sizes (List[int]): Number of items per split. + + Returns: + List[List[Any]]: Each split of items. + """ + total = sum(sizes) + if len(items) != total: + raise ValueError(f'Number of items must match the combined size of the splits: ' + + f'{len(items)} items vs splits of size {sizes} = {total}.') + + splits = [] + begin = 0 + for size in sizes: + split = items[begin:begin + size] + splits.append(split) + begin += size + + return splits + + +def _generate_dataset(split2size: Dict[str, int], + seed: int = 0x1337, + pos_prob: float = 0.75, + low: int = -1_000_000_000, + high: int = 1_000_000_000, + show_progress: bool = True) -> Dict[str, Tuple[List[int], List[str]]]: + """Generate a dataset, made of splits, to be saved in different forms for comparison. + + Args: + split2size (Dict[str, int]): Mapping of split name to size in samples. + seed (int): Seed for the random number generator. Defaults to ``0x1337``. + pos_prob (float): Probability of output being positive. Defaults to ``0.75``. + low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``. + high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``. + show_progress (bool): Whether to show a progress bar. Defaults to ``True``. + + Returns: + Dict[str, Tuple[List[int], List[str]]]: Mapping of split name to nums and texts. + """ + split_sizes = [] + total = 0 + for split in sorted(split2size): + size = split2size[split] + split_sizes.append(size) + total += size + + nums = _generate_ints(total, seed, low, high, show_progress) + nums_per_split = _split(nums, split_sizes) + + texts = list(map(_int_to_text, nums)) + texts_per_split = _split(texts, split_sizes) + + dataset = {} + for index, split in enumerate(sorted(split2size)): + dataset[split] = nums_per_split[index], texts_per_split[index] + + return dataset + + def _write_csv(nums: List[int], txts: List[str], root: str, @@ -295,19 +484,24 @@ def main(args: Namespace) -> None: Args: args (Namespace): Command-line arguments. """ + # Confgure the dataset writing statistics table printer. 
+ table_columns = ''' + < format 8 + > sec 6 + > samples 12 + > usec/sp 8 + > bytes 14 + > files 6 + > bytes/file 12 + > max bytes/file 14 + ''' + table_indent = 4 + table = Tabulator.from_conf(table_columns, table_indent * ' ') + # Normalize arguments. format_names = args.formats.split(',') if args.formats else [] show_progress = bool(args.show_progress) quiet_delta = bool(args.quiet_delta) - split2size = { - 'small': args.small, - 'medium': args.medium, - 'large': args.large, - } - - # Wipe output directory if exists. - if os.path.exists(args.data_root): - rmtree(args.data_root) # Given args, now we know how to configure saving the dataset in each format. format2write = { @@ -328,42 +522,44 @@ def main(args: Namespace) -> None: show_progress=show_progress), } - # Now, generate the dataset. + # Collect sizes of the splits to generate. + split2size = { + 'small': args.small, + 'medium': args.medium, + 'large': args.large, + } + + # Generate the dataset samples. t0 = time() - dataset = generate(split2size, show_progress) + dataset = _generate_dataset(split2size, args.seed, args.data_pos_prob, args.data_low, + args.data_high, show_progress) elapsed = time() - t0 print(f'Generate: {elapsed:.3f} sec.') - # Confgure the text table printer for dataset writing info. - conf = ''' - < format 8 - > sec 6 - > samples 12 - > usec/sp 8 - > bytes 14 - > files 6 - > bytes/file 12 - > max bytes/file 14 - ''' - left = 4 * ' ' - tab = Tabulator.from_conf(conf, left) + # Wipe output directory if exists. + if os.path.exists(args.data_root): + print(f'Found directory at {args.data_root}, wiping it for reuse') + rmtree(args.data_root) # Write each split in each desired formats, in order of size. pretty_int = lambda num: f'{num:,}' for split in _splits_by_size(dataset): print() print(f'Write split: {split}') - print(tab.draw_line()) - print(tab.draw_header()) - print(tab.draw_line()) + print(table.draw_line()) + print(table.draw_header()) + print(table.draw_line()) + nums, txts = dataset[split] for format_name in format_names: - format_subdir = getattr(args, format_name) - split_root = os.path.join(args.data_root, 'gold', format_subdir, split) + split_root = os.path.join(args.data_root, 'gold', format_name, split) write = format2write[format_name] t0 = time() - write(nums, txts, split_root) + try: + write(nums, txts, split_root) + except: + continue # Getting Delta Java OOMs at gigabyte size. elapsed = time() - t0 file_sizes = _get_file_sizes(split_root) @@ -377,8 +573,8 @@ def main(args: Namespace) -> None: 'bytes/file': pretty_int(sum(file_sizes) // len(file_sizes)), 'max bytes/file': pretty_int(max(file_sizes)), } - print(tab.draw_row(row)) - print(tab.draw_line()) + print(table.draw_row(row)) + print(table.draw_line()) if __name__ == '__main__': diff --git a/benchmarks/backends/task.py b/benchmarks/backends/task.py deleted file mode 100644 index d354f144c..000000000 --- a/benchmarks/backends/task.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Generate infinite samples for a 'saying numbers as words' task.""" - -from typing import Dict, List, Tuple, TypeVar - -import numpy as np -from numpy.random import Generator -from tqdm import tqdm - - -def _generate_int(rng: Generator, - pos_prob: float = 0.75, - low: int = -1_000_000_000, - high: int = 1_000_000_000) -> int: - """Pick a random integer to say in words. - - This is a synthetic dataset whose random numbers need to be distinct, deterministic given a - seed, and little else. 
We choose a distribution that seems the most pleasing to us. - - Properties: - * About 80% positive and 20% negative. - * Magnitude of up to a billion on either side of zero. - * Strongly skewed toward the origin, i.e. chosen uniformly across base-10 digit lengths (at - least until running out of integers of that length anyway). - - Args: - rng (Generator): NumPy random number generator. - pos_prob (float): Probability of output being positive. Defaults to ``0.75``. - low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``. - high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``. - """ - if not 0 <= pos_prob <= 1: - raise ValueError(f'Invalid positive probability ``pos_prob``: 0 <= {pos_prob} <= 1.') - - if not low < 0 < high: - raise ValueError(f'Invalid sampling range ``low`` and/or ``high``: {low} < 0 < {high}.') - - is_pos = rng.uniform() < pos_prob - max_digits = np.log10(high) if is_pos else np.log10(-low) - exponent = rng.uniform(0, max_digits) - magnitude = int(10**exponent) - sign = is_pos * 2 - 1 - return sign * magnitude - - -def _generate_ints(count: int, - seed: int = 0x1337, - pos_prob: float = 0.75, - low: int = -1_000_000_000, - high: int = 1_000_000_000, - show_progress: bool = True) -> List[int]: - """Sample until we have the given number of distinct integers. - - Args: - count (int): How many samples to draw. - seed (int): Seed for the random number generator. Defaults to ``0x1337``. - pos_prob (float): Probability of output being positive. Defaults to ``0.75``. - low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``. - high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``. - show_progress (bool): Whether to display a progress bar. Defaults to ``True``. - - Returns: - List[int]: The integers that were drawn. - """ - rng = np.random.default_rng(seed) - nums = set() - progress_bar = tqdm(total=count, leave=False) if show_progress else None - while len(nums) < count: - num = _generate_int(rng) - if num in nums: - continue - - nums.add(num) - if progress_bar: - progress_bar.update(1) - if progress_bar: - progress_bar.close() - - nums = sorted(nums) - rng.shuffle(nums) - return nums - - -_ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen ' - 'fifteen sixteen seventeen eighteen nineteen').split() - -_tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split() - - -def _int_to_words(num: int) -> List[str]: - """Say an integer as a list of words. - - Args: - num (int): The integer. - - Returns: - List[str]: The integer as a list of words. - """ - if num < 0: - return ['negative'] + _int_to_words(-num) - elif num <= 19: - return [_ones[num]] - elif num < 100: - tens = [_tens[num // 10 - 2]] - ones = [_ones[num % 10]] if num % 10 else [] - return tens + ones - elif num < 1_000: - hundreds = [_ones[num // 100], 'hundred'] - etc = _int_to_words(num % 100) if num % 100 else [] - return hundreds + etc - elif num < 1_000_000: - thousands = _int_to_words(num // 1_000) + ['thousand'] - etc = _int_to_words(num % 1_000) if num % 1_000 else [] - return thousands + etc - elif num < 1_000_000_000: - millions = _int_to_words(num // 1_000_000) + ['million'] - etc = _int_to_words(num % 1_000_000) if num % 1_000_000 else [] - return millions + etc - else: - raise ValueError('Integer out of range: -1,000,000,000 < {num} < +1,000,000,000.') - - -def _int_to_text(num: int) -> str: - """Say an integer as text. 
- - Args: - num (int): The integer. - - Returns: - str: The integer as text. - """ - words = _int_to_words(num) - return ' '.join(words) - - -T = TypeVar('T') - - -def _split(items: List[T], sizes: List[int]) -> List[List[T]]: - """Divide the given items across the splits given by their sizes. - - Args: - items (List[Any]): The items to divide across the spans. - sizes (List[int]): Number of items per split. - - Returns: - List[List[Any]]: Each split of items. - """ - total = sum(sizes) - if len(items) != total: - raise ValueError(f'Number of items must match the combined size of the splits: ' + - f'{len(items)} items vs splits of size {sizes} = {total}.') - - splits = [] - begin = 0 - for size in sizes: - split = items[begin:begin + size] - splits.append(split) - begin += size - - return splits - - -def generate(split2size: Dict[str, int], - seed: int = 0x1337, - pos_prob: float = 0.75, - low: int = -1_000_000_000, - high: int = 1_000_000_000, - show_progress: bool = True) -> Dict[str, Tuple[List[int], List[str]]]: - """Generate a dataset, made of splits, to be saved in different forms for comparison. - - Args: - split2size (Dict[str, int]): Mapping of split name to size in samples. - seed (int): Seed for the random number generator. Defaults to ``0x1337``. - pos_prob (float): Probability of output being positive. Defaults to ``0.75``. - low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``. - high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``. - show_progress (bool): Whether to show a progress bar. Defaults to ``True``. - - Returns: - Dict[str, Tuple[List[int], List[str]]]: Mapping of split name to nums and texts. - """ - split_sizes = [] - total = 0 - for split in sorted(split2size): - size = split2size[split] - split_sizes.append(size) - total += size - - nums = _generate_ints(total, seed, low, high, show_progress) - nums_per_split = _split(nums, split_sizes) - - texts = list(map(_int_to_text, nums)) - texts_per_split = _split(texts, split_sizes) - - dataset = {} - for index, split in enumerate(sorted(split2size)): - dataset[split] = nums_per_split[index], texts_per_split[index] - - return dataset
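As a closing spot check on the number-to-words helpers that move into generate.py, this is the kind of text the samples contain (a sketch that assumes it runs inside benchmarks/backends/generate.py, where _int_to_text is defined):

    # Illustrative only: exercise the word-spelling helpers.
    for num in (7, -45, 123, 20_000, 1_000_000):
        print(num, '->', _int_to_text(num))
    # 7 -> seven
    # -45 -> negative forty five
    # 123 -> one hundred twenty three
    # 20000 -> twenty thousand
    # 1000000 -> one million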