From 7f5d160a19f13d2b1b6b5a5652cbd8a2cec3f021 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Fri, 17 Nov 2023 17:51:33 -0800 Subject: [PATCH 01/12] Move examples out, merge base/ upward (#494) * scripts/ -> benchmarks/. * examples/ -> notebooks/. * streaming/multimodal/ -> examples/multimodal/ (reorganized). * streaming/text/ -> examples/text (reorganized). * streaming/vision/base.py -> streaming/base/vision.py. * Switch streaming/base/vision.py to kwargs. * streaming/vision/ -> examples/vision/. * Update pyproject.toml. * And .pre-commit-config.yaml. * Fix headers. * Collapse "base/": streaming/base/ -> streaming/. * Fix imports re: collapsing the `base/` dirs upward. * Fixes (imports and indentation). * Update test_streaming_remote.py to not rely on any specific SD example subclasses. * Fix pyproject config. * Update paths. * Fix. * More examples/ moves. * Comma-tailing args. * Fix links. * More fixes. * Fix missing license. * How about this for import redirects... * Or this... * Improve redirect deprecation warning. * examples/ tree: __init__ imports and __all__'s. * benchmarks/ tree: __init__ imports and __all__'s. * notebooks/ tree: __init__ imports and __all__'s. * Add notebooks/ symlink to docs/source. * Add benchmarks, examples, and notebooks trees to document_modules. * Also add benchmarks symlink. Or should we only symlink to notebooks/? --- .pre-commit-config.yaml | 2 +- README.md | 35 +- STYLE_GUIDE.md | 14 +- benchmarks/__init__.py | 14 + {scripts => benchmarks}/compression/bench.py | 2 +- {scripts => benchmarks}/compression/plot.py | 0 {scripts => benchmarks}/epoch/bench.py | 4 +- {scripts => benchmarks}/hashing/bench.py | 2 +- {scripts => benchmarks}/hashing/plot.py | 0 {scripts => benchmarks}/partition/bench.py | 2 +- {scripts => benchmarks}/partition/diff.py | 2 +- {scripts => benchmarks}/partition/plot.py | 0 {scripts => benchmarks}/partition/txt.py | 2 +- {scripts => benchmarks}/partition/web.py | 2 +- .../samples/bench_and_plot.py | 0 .../serialization/compare.py | 0 .../serialization/survey_fixed_decimals.py | 0 {scripts => benchmarks}/shuffle/bench.py | 4 +- {scripts => benchmarks}/shuffle/plot.py | 0 {scripts => benchmarks}/shuffle/vis.py | 2 +- docs/source/benchmarks | 1 + docs/source/conf.py | 46 +- .../fundamentals/dataset_conversion_guide.md | 2 +- docs/source/getting_started/quick_start.md | 2 +- docs/source/getting_started/user_guide.md | 4 +- .../dataset_conversion_to_mds_format.md | 16 +- docs/source/index.md | 10 +- docs/source/notebooks | 1 + examples/__init__.py | 10 + examples/multimodal/__init__.py | 9 + .../multimodal}/laion400m/README.md | 6 +- .../multimodal/laion400m}/__init__.py | 2 +- .../laion400m/convert_and_upload.py | 2 +- .../laion400m/convert_and_upload.sh | 2 +- .../multimodal}/laion400m/download_data.sh | 0 .../multimodal}/laion400m/download_meta.sh | 0 .../multimodal/webvid}/__init__.py | 2 +- .../multimodal/webvid/read.py | 6 +- .../webvid/scripts}/bench_inside.py | 2 +- .../webvid/scripts}/bench_outside_dt.py | 16 +- .../webvid/scripts}/bench_outside_gi.py | 16 +- .../multimodal/webvid/scripts}/plot.py | 0 .../multimodal/webvid/write}/README.md | 6 +- .../multimodal/webvid/write}/__init__.py | 0 .../multimodal/webvid/write}/crawl_webvid.py | 0 .../webvid/write}/crawl_webvid_subsets.py | 0 .../webvid/write}/extract_webvid_videos.py | 0 examples/text/__init__.py | 11 + examples/text/c4/README.md | 7 + .../text/c4}/__init__.py | 2 +- .../text/c4.py => examples/text/c4/read.py | 2 +- .../c4.py => examples/text/c4/write.py | 4 +-
examples/text/enwiki_tok/__init__.py | 1 + .../text/enwiki_tok}/mds/README.md | 0 .../text/enwiki_tok/mds}/__init__.py | 0 .../mds/create_pretraining_data.py | 0 .../text/enwiki_tok}/mds/make_eval.sh | 0 .../enwiki_tok}/mds/make_train_parallel.py | 0 .../enwiki_tok}/mds/merge_shard_groups.py | 0 .../text/enwiki_tok}/mds/pick_eval_samples.py | 0 .../text/enwiki_tok}/mds/tokenization.py | 0 .../text/enwiki_tok}/mds/vocab.txt | 0 .../text/enwiki_tok/tfrecord}/__init__.py | 0 .../enwiki_tok}/tfrecord/count_samples.py | 0 .../tfrecord/create_pretraining_data.py | 0 .../text/enwiki_tok}/tfrecord/make_eval.sh | 0 .../text/enwiki_tok}/tfrecord/make_train.sh | 0 .../tfrecord/make_train_parallel.py | 0 .../enwiki_tok}/tfrecord/pick_eval_samples.py | 0 .../text/enwiki_tok}/tfrecord/tokenization.py | 0 .../text/enwiki_tok}/tfrecord/vocab.txt | 0 examples/text/enwiki_txt/README.txt | 26 + examples/text/enwiki_txt/__init__.py | 4 + .../text/enwiki_txt/read.py | 2 +- .../text/enwiki_txt/write.py | 4 +- examples/text/pile/README.md | 26 + examples/text/pile/__init__.py | 4 + .../pile.py => examples/text/pile/read.py | 2 +- .../pile.py => examples/text/pile/write.py | 4 +- examples/vision/__init__.py | 11 + examples/vision/ade20k/__init__.py | 4 + .../vision/ade20k/read.py | 2 +- .../vision/ade20k/write.py | 4 +- examples/vision/cifar10/__init__.py | 4 + .../vision/cifar10/read.py | 2 +- .../vision/cifar10/write.py | 4 +- .../vision/cifar10/write_fake.py | 2 +- examples/vision/coco/__init__.py | 4 + .../coco.py => examples/vision/coco/read.py | 2 +- .../coco.py => examples/vision/coco/write.py | 4 +- examples/vision/imagenet/__init__.py | 4 + .../vision/imagenet/read.py | 2 +- .../vision/imagenet/write.py | 4 +- .../convert/laion => notebooks}/__init__.py | 2 +- {examples => notebooks}/cifar10.ipynb | 0 {examples => notebooks}/facesynthetics.ipynb | 0 .../multiprocess_dataset_conversion.ipynb | 4 +- .../spark_dataframe_to_MDS.ipynb | 6 +- {examples => notebooks}/synthetic_nlp.ipynb | 0 pyproject.toml | 5 +- regression/iterate_data.py | 3 +- regression/synthetic_dataset.py | 2 +- simulation/core/create_index.py | 2 +- simulation/core/node_tracker.py | 2 +- simulation/core/shuffle_quality.py | 4 +- simulation/core/sim_dataset.py | 10 +- simulation/core/sim_spanner.py | 2 +- simulation/core/sim_world.py | 2 +- simulation/core/yaml_processing.py | 2 +- simulation/interfaces/interface_utils.py | 2 +- simulation/interfaces/sim_cli.py | 2 +- simulation/interfaces/sim_script.py | 2 +- simulation/interfaces/sim_ui.py | 2 +- simulation/interfaces/widgets.py | 2 +- simulation/testing/wandb_testing.py | 2 +- streaming/__init__.py | 25 +- streaming/{base => }/array.py | 0 streaming/base/__init__.py | 20 +- streaming/base/converters/README.md | 7 - streaming/base/format/base/__init__.py | 8 - streaming/base/shared/__init__.py | 17 - streaming/base/storage/__init__.py | 33 -- streaming/base/util.py | 524 +---------------- streaming/{base => }/batching/__init__.py | 10 +- streaming/{base => }/batching/per_stream.py | 8 +- streaming/{base => }/batching/random.py | 8 +- streaming/{base => }/batching/stratified.py | 8 +- streaming/{base => }/compression.py | 0 streaming/{base => }/constant.py | 0 streaming/converters/README.md | 7 + streaming/{base => }/converters/__init__.py | 4 +- .../{base => }/converters/dataframe_to_mds.py | 10 +- streaming/{base => }/dataloader.py | 4 +- streaming/{base => }/dataset.py | 28 +- streaming/{base => }/distributed.py | 0 streaming/{base => }/format/__init__.py | 11 +- streaming/{base => 
}/format/index.py | 0 streaming/{base => }/format/json/README.md | 0 streaming/{base => }/format/json/__init__.py | 4 +- streaming/{base => }/format/json/encodings.py | 0 streaming/{base => }/format/json/reader.py | 2 +- streaming/{base => }/format/json/writer.py | 4 +- streaming/{base => }/format/mds/README.md | 0 streaming/{base => }/format/mds/__init__.py | 4 +- streaming/{base => }/format/mds/encodings.py | 0 streaming/{base => }/format/mds/reader.py | 4 +- streaming/{base => }/format/mds/writer.py | 6 +- .../{base/format/base => format}/reader.py | 4 +- .../{base/format/base => format}/writer.py | 10 +- streaming/{base => }/format/xsv/README.md | 0 streaming/{base => }/format/xsv/__init__.py | 4 +- streaming/{base => }/format/xsv/encodings.py | 0 streaming/{base => }/format/xsv/reader.py | 4 +- streaming/{base => }/format/xsv/writer.py | 4 +- streaming/{base => }/hashing.py | 0 streaming/{base => }/local.py | 6 +- streaming/multimodal/__init__.py | 8 - streaming/multimodal/convert/__init__.py | 4 - streaming/{base => }/partition/__init__.py | 4 +- streaming/{base => }/partition/orig.py | 0 streaming/{base => }/partition/relaxed.py | 2 +- streaming/{base => }/sampling.py | 0 streaming/shared/__init__.py | 17 + streaming/{base => }/shared/array.py | 2 +- streaming/{base => }/shared/barrier.py | 4 +- streaming/{base => }/shared/memory.py | 2 +- streaming/{base => }/shared/prefix.py | 8 +- streaming/{base => }/shared/scalar.py | 2 +- streaming/{base => }/shuffle/__init__.py | 12 +- streaming/{base => }/shuffle/naive.py | 0 streaming/{base => }/shuffle/py1b.py | 2 +- streaming/{base => }/shuffle/py1br.py | 2 +- streaming/{base => }/shuffle/py1e.py | 2 +- streaming/{base => }/shuffle/py1s.py | 0 streaming/{base => }/shuffle/py2s.py | 0 streaming/{base => }/spanner.py | 0 streaming/storage/__init__.py | 32 + streaming/{base => }/storage/download.py | 2 +- streaming/{base => }/storage/upload.py | 5 +- streaming/{base => }/stream.py | 16 +- streaming/text/__init__.py | 10 - streaming/text/convert/README.md | 69 --- .../text/convert/enwiki/tfrecord/__init__.py | 0 streaming/util.py | 551 ++++++++++++++++++ streaming/vision.py | 154 +++++ streaming/vision/__init__.py | 11 - streaming/vision/base.py | 176 ------ streaming/vision/convert/README.md | 113 ---- streaming/vision/convert/base.py | 68 --- streaming/{base => }/world.py | 2 +- .../base/converters/test_dataframe_to_mds.py | 2 +- tests/common/datasets.py | 2 +- tests/test_array.py | 2 +- tests/test_barrier.py | 2 +- tests/test_compression.py | 6 +- tests/test_distributed.py | 4 +- tests/test_download.py | 25 +- tests/test_encodings.py | 6 +- tests/test_hashing.py | 4 +- tests/test_local.py | 2 +- tests/test_partition.py | 2 +- tests/test_reader.py | 2 +- tests/test_sampling.py | 2 +- tests/test_shared.py | 4 +- tests/test_shuffle.py | 4 +- tests/test_spanner.py | 2 +- tests/test_stream.py | 2 +- tests/test_streaming.py | 4 +- tests/test_streaming_remote.py | 21 +- tests/test_upload.py | 37 +- tests/test_util.py | 16 +- 211 files changed, 1283 insertions(+), 1394 deletions(-) create mode 100644 benchmarks/__init__.py rename {scripts => benchmarks}/compression/bench.py (96%) rename {scripts => benchmarks}/compression/plot.py (100%) rename {scripts => benchmarks}/epoch/bench.py (97%) rename {scripts => benchmarks}/hashing/bench.py (97%) rename {scripts => benchmarks}/hashing/plot.py (100%) rename {scripts => benchmarks}/partition/bench.py (98%) rename {scripts => benchmarks}/partition/diff.py (98%) rename {scripts => 
benchmarks}/partition/plot.py (100%) rename {scripts => benchmarks}/partition/txt.py (98%) rename {scripts => benchmarks}/partition/web.py (99%) rename {scripts => benchmarks}/samples/bench_and_plot.py (100%) rename {scripts => benchmarks}/serialization/compare.py (100%) rename {scripts => benchmarks}/serialization/survey_fixed_decimals.py (100%) rename {scripts => benchmarks}/shuffle/bench.py (97%) rename {scripts => benchmarks}/shuffle/plot.py (100%) rename {scripts => benchmarks}/shuffle/vis.py (98%) create mode 120000 docs/source/benchmarks create mode 120000 docs/source/notebooks create mode 100644 examples/__init__.py create mode 100644 examples/multimodal/__init__.py rename {streaming/multimodal/convert/laion => examples/multimodal}/laion400m/README.md (90%) rename {streaming/vision/convert => examples/multimodal/laion400m}/__init__.py (61%) rename {streaming/multimodal/convert/laion => examples/multimodal}/laion400m/convert_and_upload.py (99%) rename {streaming/multimodal/convert/laion => examples/multimodal}/laion400m/convert_and_upload.sh (59%) rename {streaming/multimodal/convert/laion => examples/multimodal}/laion400m/download_data.sh (100%) rename {streaming/multimodal/convert/laion => examples/multimodal}/laion400m/download_meta.sh (100%) rename {streaming/text/convert => examples/multimodal/webvid}/__init__.py (56%) rename streaming/multimodal/webvid.py => examples/multimodal/webvid/read.py (99%) rename {scripts/webvid => examples/multimodal/webvid/scripts}/bench_inside.py (94%) rename {scripts/webvid => examples/multimodal/webvid/scripts}/bench_outside_dt.py (81%) rename {scripts/webvid => examples/multimodal/webvid/scripts}/bench_outside_gi.py (81%) rename {scripts/webvid => examples/multimodal/webvid/scripts}/plot.py (100%) rename {streaming/multimodal/convert/webvid => examples/multimodal/webvid/write}/README.md (83%) rename {streaming/multimodal/convert/webvid => examples/multimodal/webvid/write}/__init__.py (100%) rename {streaming/multimodal/convert/webvid => examples/multimodal/webvid/write}/crawl_webvid.py (100%) rename {streaming/multimodal/convert/webvid => examples/multimodal/webvid/write}/crawl_webvid_subsets.py (100%) rename {streaming/multimodal/convert/webvid => examples/multimodal/webvid/write}/extract_webvid_videos.py (100%) create mode 100644 examples/text/__init__.py create mode 100644 examples/text/c4/README.md rename {streaming/multimodal/convert/laion/laion400m => examples/text/c4}/__init__.py (69%) rename streaming/text/c4.py => examples/text/c4/read.py (99%) rename streaming/text/convert/c4.py => examples/text/c4/write.py (98%) create mode 100644 examples/text/enwiki_tok/__init__.py rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/mds/README.md (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok/mds}/__init__.py (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/mds/create_pretraining_data.py (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/mds/make_eval.sh (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/mds/make_train_parallel.py (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/mds/merge_shard_groups.py (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/mds/pick_eval_samples.py (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/mds/tokenization.py (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/mds/vocab.txt (100%) rename 
{streaming/text/convert/enwiki/mds => examples/text/enwiki_tok/tfrecord}/__init__.py (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/count_samples.py (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/create_pretraining_data.py (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/make_eval.sh (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/make_train.sh (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/make_train_parallel.py (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/pick_eval_samples.py (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/tokenization.py (100%) rename {streaming/text/convert/enwiki => examples/text/enwiki_tok}/tfrecord/vocab.txt (100%) create mode 100644 examples/text/enwiki_txt/README.txt create mode 100644 examples/text/enwiki_txt/__init__.py rename streaming/text/enwiki.py => examples/text/enwiki_txt/read.py (99%) rename streaming/text/convert/enwiki_text.py => examples/text/enwiki_txt/write.py (97%) create mode 100644 examples/text/pile/README.md create mode 100644 examples/text/pile/__init__.py rename streaming/text/pile.py => examples/text/pile/read.py (99%) rename streaming/text/convert/pile.py => examples/text/pile/write.py (98%) create mode 100644 examples/vision/__init__.py create mode 100644 examples/vision/ade20k/__init__.py rename streaming/vision/ade20k.py => examples/vision/ade20k/read.py (99%) rename streaming/vision/convert/ade20k.py => examples/vision/ade20k/write.py (98%) create mode 100644 examples/vision/cifar10/__init__.py rename streaming/vision/cifar10.py => examples/vision/cifar10/read.py (98%) rename streaming/vision/convert/cifar10.py => examples/vision/cifar10/write.py (95%) rename streaming/vision/convert/fake_cifar10.py => examples/vision/cifar10/write_fake.py (95%) create mode 100644 examples/vision/coco/__init__.py rename streaming/vision/coco.py => examples/vision/coco/read.py (99%) rename streaming/vision/convert/coco.py => examples/vision/coco/write.py (98%) create mode 100644 examples/vision/imagenet/__init__.py rename streaming/vision/imagenet.py => examples/vision/imagenet/read.py (98%) rename streaming/vision/convert/imagenet.py => examples/vision/imagenet/write.py (98%) rename {streaming/multimodal/convert/laion => notebooks}/__init__.py (73%) rename {examples => notebooks}/cifar10.ipynb (100%) rename {examples => notebooks}/facesynthetics.ipynb (100%) rename {examples => notebooks}/multiprocess_dataset_conversion.ipynb (98%) rename {examples => notebooks}/spark_dataframe_to_MDS.ipynb (99%) rename {examples => notebooks}/synthetic_nlp.ipynb (100%) rename streaming/{base => }/array.py (100%) delete mode 100644 streaming/base/converters/README.md delete mode 100644 streaming/base/format/base/__init__.py delete mode 100644 streaming/base/shared/__init__.py delete mode 100644 streaming/base/storage/__init__.py rename streaming/{base => }/batching/__init__.py (79%) rename streaming/{base => }/batching/per_stream.py (97%) rename streaming/{base => }/batching/random.py (94%) rename streaming/{base => }/batching/stratified.py (98%) rename streaming/{base => }/compression.py (100%) rename streaming/{base => }/constant.py (100%) create mode 100644 streaming/converters/README.md rename streaming/{base => }/converters/__init__.py (57%) rename streaming/{base => }/converters/dataframe_to_mds.py (97%) rename 
streaming/{base => }/dataloader.py (96%) rename streaming/{base => }/dataset.py (98%) rename streaming/{base => }/distributed.py (100%) rename streaming/{base => }/format/__init__.py (71%) rename streaming/{base => }/format/index.py (100%) rename streaming/{base => }/format/json/README.md (100%) rename streaming/{base => }/format/json/__init__.py (61%) rename streaming/{base => }/format/json/encodings.py (100%) rename streaming/{base => }/format/json/reader.py (98%) rename streaming/{base => }/format/json/writer.py (97%) rename streaming/{base => }/format/mds/README.md (100%) rename streaming/{base => }/format/mds/__init__.py (62%) rename streaming/{base => }/format/mds/encodings.py (100%) rename streaming/{base => }/format/mds/reader.py (97%) rename streaming/{base => }/format/mds/writer.py (96%) rename streaming/{base/format/base => format}/reader.py (99%) rename streaming/{base/format/base => format}/writer.py (98%) rename streaming/{base => }/format/xsv/README.md (100%) rename streaming/{base => }/format/xsv/__init__.py (60%) rename streaming/{base => }/format/xsv/encodings.py (100%) rename streaming/{base => }/format/xsv/reader.py (98%) rename streaming/{base => }/format/xsv/writer.py (98%) rename streaming/{base => }/hashing.py (100%) rename streaming/{base => }/local.py (93%) delete mode 100644 streaming/multimodal/__init__.py delete mode 100644 streaming/multimodal/convert/__init__.py rename streaming/{base => }/partition/__init__.py (94%) rename streaming/{base => }/partition/orig.py (100%) rename streaming/{base => }/partition/relaxed.py (98%) rename streaming/{base => }/sampling.py (100%) create mode 100644 streaming/shared/__init__.py rename streaming/{base => }/shared/array.py (97%) rename streaming/{base => }/shared/barrier.py (97%) rename streaming/{base => }/shared/memory.py (99%) rename streaming/{base => }/shared/prefix.py (97%) rename streaming/{base => }/shared/scalar.py (94%) rename streaming/{base => }/shuffle/__init__.py (82%) rename streaming/{base => }/shuffle/naive.py (100%) rename streaming/{base => }/shuffle/py1b.py (98%) rename streaming/{base => }/shuffle/py1br.py (98%) rename streaming/{base => }/shuffle/py1e.py (99%) rename streaming/{base => }/shuffle/py1s.py (100%) rename streaming/{base => }/shuffle/py2s.py (100%) rename streaming/{base => }/spanner.py (100%) create mode 100644 streaming/storage/__init__.py rename streaming/{base => }/storage/download.py (99%) rename streaming/{base => }/storage/upload.py (99%) rename streaming/{base => }/stream.py (98%) delete mode 100644 streaming/text/__init__.py delete mode 100644 streaming/text/convert/README.md delete mode 100644 streaming/text/convert/enwiki/tfrecord/__init__.py create mode 100644 streaming/util.py create mode 100644 streaming/vision.py delete mode 100644 streaming/vision/__init__.py delete mode 100644 streaming/vision/base.py delete mode 100644 streaming/vision/convert/README.md delete mode 100644 streaming/vision/convert/base.py rename streaming/{base => }/world.py (97%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1b2a8116b..7f7fb3c53 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ default_language_version: python: python3 # Skip the pre-commit check for below directories to have # a consistency with the official tfrecord preprocessing scripts -exclude: "^(streaming/text/convert/enwiki/)" +exclude: "^(examples/text/enwiki_tok/)" repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. 
diff --git a/README.md b/README.md index 374d0f5ef..ad7cf7946 100644 --- a/README.md +++ b/README.md @@ -143,44 +143,45 @@ dataloader = DataLoader(dataset) ### 📚 What next? -Getting started guides, examples, API references, and other useful information can be found in our [docs](https://streaming.docs.mosaicml.com/). +Getting started guides, example notebooks, API references, and other useful information can be found in our [docs](https://streaming.docs.mosaicml.com/). We have end-to-end tutorials for training a model on: -- [CIFAR-10](https://streaming.docs.mosaicml.com/en/stable/examples/cifar10.html) -- [FaceSynthetics](https://streaming.docs.mosaicml.com/en/stable/examples/facesynthetics.html) -- [SyntheticNLP](https://streaming.docs.mosaicml.com/en/stable/examples/synthetic_nlp.html) +- [CIFAR-10](https://streaming.docs.mosaicml.com/en/stable/notebooks/cifar10.html) +- [FaceSynthetics](https://streaming.docs.mosaicml.com/en/stable/notebooks/facesynthetics.html) +- [SyntheticNLP](https://streaming.docs.mosaicml.com/en/stable/notebooks/synthetic_nlp.html) -We also have starter code for the following popular datasets, which can be found in the `streaming` [directory](https://github.com/mosaicml/streaming/tree/main/streaming): +We also have starter code for the following popular datasets, which can be found under [`examples/`](https://github.com/mosaicml/streaming/tree/main/examples) organized by modality: | Dataset | Task | Read | Write | | --- | --- | --- | --- | -| LAION-400M | Text and image | [Read](https://github.com/mosaicml/diffusion-benchmark/blob/main/data.py) | [Write](https://github.com/mosaicml/streaming/tree/main/streaming/multimodal/convert/laion/laion400m) | -| WebVid | Text and video | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/multimodal/webvid.py) | [Write](https://github.com/mosaicml/streaming/blob/main/streaming/multimodal/convert/webvid.py) | -| C4 | Text | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/text/c4.py) | [Write](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/c4.py) | -| EnWiki | Text | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/text/enwiki.py) | [Write](https://github.com/mosaicml/streaming/tree/main/streaming/text/convert/enwiki) | -| Pile | Text | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/text/pile.py) | [Write](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/pile.py) -| ADE20K | Image segmentation | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/vision/ade20k.py) | [Write](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/ade20k.py) -| CIFAR10 | Image classification | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/vision/cifar10.py) | [Write](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/cifar10.py) | -| COCO | Image classification | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/vision/coco.py) | [Write](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/coco.py) | -| ImageNet | Image classification | [Read](https://github.com/mosaicml/streaming/blob/main/streaming/vision/imagenet.py) | [Write](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/imagenet.py) | +| LAION-400M | Text and image | [Read](https://github.com/mosaicml/diffusion-benchmark/blob/main/data.py) | [Write](https://github.com/mosaicml/streaming/tree/main/examples/multimodal/laion400m) | +| WebVid | Text and 
video | [Read](https://github.com/mosaicml/streaming/blob/main/examples/multimodal/webvid/read.py) | [Write](https://github.com/mosaicml/streaming/tree/main/examples/multimodal/webvid/write/) | +| C4 | Text | [Read](https://github.com/mosaicml/streaming/blob/main/examples/text/c4/read.py) | [Write](https://github.com/mosaicml/streaming/blob/main/examples/text/c4/write.py) | +| EnWiki | Text | [Read](https://github.com/mosaicml/streaming/blob/main/examples/text/enwiki_txt/read.py) | [Write](https://github.com/mosaicml/streaming/blob/main/examples/text/enwiki_txt/write.py) | +| Pile | Text | [Read](https://github.com/mosaicml/streaming/blob/main/examples/text/pile/read.py) | [Write](https://github.com/mosaicml/streaming/blob/main/examples/text/pile/write.py) +| ADE20K | Image segmentation | [Read](https://github.com/mosaicml/streaming/blob/main/examples/vision/ade20k/read.py) | [Write](https://github.com/mosaicml/streaming/blob/main/examples/vision/ade20k/write.py) +| CIFAR10 | Image classification | [Read](https://github.com/mosaicml/streaming/blob/main/examples/vision/cifar10/read.py) | [Write](https://github.com/mosaicml/streaming/blob/main/examples/vision/cifar10/write.py) | +| COCO | Image classification | [Read](https://github.com/mosaicml/streaming/blob/main/examples/vision/coco/read.py) | [Write](https://github.com/mosaicml/streaming/blob/main/examples/vision/coco/write.py) | +| ImageNet | Image classification | [Read](https://github.com/mosaicml/streaming/blob/main/examples/vision/imagenet/read.py) | [Write](https://github.com/mosaicml/streaming/blob/main/examples/vision/imagenet/write.py) | **To start training on these datasets:** -1. Convert raw data into .mds format using the corresponding script from the `convert` directory. +1. Convert raw data into .mds format using the corresponding `write.py` script. For example: ```bash -$ python -m streaming.multimodal.convert.webvid --in --out +$ python -m examples.multimodal.webvid.write.crawl_webvid --in --out_root ``` 2. Import dataset class to start training the model. ```python -from streaming.multimodal import StreamingInsideWebVid +from examples.multimodal.webvid.read import StreamingInsideWebVid + dataset = StreamingInsideWebVid(local=local, remote=remote, shuffle=True) ``` diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md index 265ea7d24..67156e2a0 100644 --- a/STYLE_GUIDE.md +++ b/STYLE_GUIDE.md @@ -142,10 +142,10 @@ so other contributors will know why this error was silenced. A public API, generally speaking, can be invoked by a user without a leading underscore in any portion of the path. The following are examples of public APIs: -* Standalone functions in public modules (e.g. `streaming.base.distributed.get_world_size`) -* Classes in public modules (e.g. `streaming.base.format.MDSWriter`) -* Public methods in public classes (e.g. `streaming.base.format.MDSWriter.write`) -* Public modules (e.g. `streaming.base.dataset`) +* Standalone functions in public modules (e.g. `streaming.distributed.get_world_size`) +* Classes in public modules (e.g. `streaming.format.MDSWriter`) +* Public methods in public classes (e.g. `streaming.format.MDSWriter.write`) +* Public modules (e.g. `streaming.dataset`) The following rules apply to public APIs: 1.
All public APIs must have a docstring (see the Documentation section below) @@ -201,14 +201,14 @@ All public modules must define `__all__` to be the list of members that should b The variable is necessary to 1) limit what `from XXX import *` imports, and 2) ensure that the documentation only includes exported members, not unrelated re-imports. -For example, from [streaming/base/dataset.py](streaming/base/dataset.py) +For example, from [streaming/dataset.py](streaming/dataset.py) ```python """The :class:`Dataset` class, used for building streaming iterable datasets.""" from torch.utils.data import IterableDataset -from streaming.base.format import reader_from_json -from streaming.base.spanner import Spanner +from streaming.format import reader_from_json +from streaming.spanner import Spanner __all__ = ["Dataset"] # export only the Dataset, not other imports like `Spanner` or `reader_from_json` diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 000000000..bf78c3635 --- /dev/null +++ b/benchmarks/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Streaming benchmarking.""" + +from benchmarks import compression as compression +from benchmarks import epoch as epoch +from benchmarks import hashing as hashing +from benchmarks import partition as partition +from benchmarks import samples as samples +from benchmarks import serialization as serialization +from benchmarks import shuffle as shuffle + +__all__ = ['compression', 'epoch', 'hashing', 'partition', 'samples', 'serialization', 'shuffle'] diff --git a/scripts/compression/bench.py b/benchmarks/compression/bench.py similarity index 96% rename from scripts/compression/bench.py rename to benchmarks/compression/bench.py index 7fff5149b..d3740e335 100644 --- a/scripts/compression/bench.py +++ b/benchmarks/compression/bench.py @@ -9,7 +9,7 @@ import numpy as np -from streaming.base.compression import compress, decompress, get_compressions +from streaming.compression import compress, decompress, get_compressions def parse_args() -> Namespace: diff --git a/scripts/compression/plot.py b/benchmarks/compression/plot.py similarity index 100% rename from scripts/compression/plot.py rename to benchmarks/compression/plot.py diff --git a/scripts/epoch/bench.py b/benchmarks/epoch/bench.py similarity index 97% rename from scripts/epoch/bench.py rename to benchmarks/epoch/bench.py index 393ea66af..a1c8b73e0 100644 --- a/scripts/epoch/bench.py +++ b/benchmarks/epoch/bench.py @@ -9,8 +9,8 @@ import numpy as np -from streaming.base.partition import get_partitions -from streaming.base.shuffle import get_shuffle +from streaming.partition import get_partitions +from streaming.shuffle import get_shuffle def parse_args() -> Namespace: diff --git a/scripts/hashing/bench.py b/benchmarks/hashing/bench.py similarity index 97% rename from scripts/hashing/bench.py rename to benchmarks/hashing/bench.py index 6be145006..45e4d4855 100644 --- a/scripts/hashing/bench.py +++ b/benchmarks/hashing/bench.py @@ -9,7 +9,7 @@ import numpy as np -from streaming.base.hashing import get_hash, get_hashes +from streaming.hashing import get_hash, get_hashes def parse_args() -> Namespace: diff --git a/scripts/hashing/plot.py b/benchmarks/hashing/plot.py similarity index 100% rename from scripts/hashing/plot.py rename to benchmarks/hashing/plot.py diff --git a/scripts/partition/bench.py b/benchmarks/partition/bench.py similarity index 98% rename from scripts/partition/bench.py rename to 
benchmarks/partition/bench.py index 3d83d3b63..d52629d25 100644 --- a/scripts/partition/bench.py +++ b/benchmarks/partition/bench.py @@ -6,7 +6,7 @@ from argparse import ArgumentParser, Namespace from time import time -from streaming.base.partition import get_partitions +from streaming.partition import get_partitions def parse_args() -> Namespace: diff --git a/scripts/partition/diff.py b/benchmarks/partition/diff.py similarity index 98% rename from scripts/partition/diff.py rename to benchmarks/partition/diff.py index 43c10224b..0c6f68171 100644 --- a/scripts/partition/diff.py +++ b/benchmarks/partition/diff.py @@ -10,7 +10,7 @@ import numpy as np from tqdm import tqdm -from streaming.base.partition import get_partitions +from streaming.partition import get_partitions def parse_args() -> Namespace: diff --git a/scripts/partition/plot.py b/benchmarks/partition/plot.py similarity index 100% rename from scripts/partition/plot.py rename to benchmarks/partition/plot.py diff --git a/scripts/partition/txt.py b/benchmarks/partition/txt.py similarity index 98% rename from scripts/partition/txt.py rename to benchmarks/partition/txt.py index 4f6793825..8d71f6294 100644 --- a/scripts/partition/txt.py +++ b/benchmarks/partition/txt.py @@ -6,7 +6,7 @@ import math from argparse import ArgumentParser, Namespace -from streaming.base.partition import get_partitions +from streaming.partition import get_partitions def parse_args() -> Namespace: diff --git a/scripts/partition/web.py b/benchmarks/partition/web.py similarity index 99% rename from scripts/partition/web.py rename to benchmarks/partition/web.py index c37a849f2..f961b06ba 100644 --- a/scripts/partition/web.py +++ b/benchmarks/partition/web.py @@ -16,7 +16,7 @@ from fastapi.responses import HTMLResponse from pydantic import BaseModel -from streaming.base.partition import get_partitions +from streaming.partition import get_partitions INDEX = ''' diff --git a/scripts/samples/bench_and_plot.py b/benchmarks/samples/bench_and_plot.py similarity index 100% rename from scripts/samples/bench_and_plot.py rename to benchmarks/samples/bench_and_plot.py diff --git a/scripts/serialization/compare.py b/benchmarks/serialization/compare.py similarity index 100% rename from scripts/serialization/compare.py rename to benchmarks/serialization/compare.py diff --git a/scripts/serialization/survey_fixed_decimals.py b/benchmarks/serialization/survey_fixed_decimals.py similarity index 100% rename from scripts/serialization/survey_fixed_decimals.py rename to benchmarks/serialization/survey_fixed_decimals.py diff --git a/scripts/shuffle/bench.py b/benchmarks/shuffle/bench.py similarity index 97% rename from scripts/shuffle/bench.py rename to benchmarks/shuffle/bench.py index 74ec02021..ac15f641a 100644 --- a/scripts/shuffle/bench.py +++ b/benchmarks/shuffle/bench.py @@ -11,8 +11,8 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.shuffle import (get_shuffle_naive, get_shuffle_py1b, get_shuffle_py1s, - get_shuffle_py2s) +from streaming.shuffle import (get_shuffle_naive, get_shuffle_py1b, get_shuffle_py1s, + get_shuffle_py2s) def parse_args() -> Namespace: diff --git a/scripts/shuffle/plot.py b/benchmarks/shuffle/plot.py similarity index 100% rename from scripts/shuffle/plot.py rename to benchmarks/shuffle/plot.py diff --git a/scripts/shuffle/vis.py b/benchmarks/shuffle/vis.py similarity index 98% rename from scripts/shuffle/vis.py rename to benchmarks/shuffle/vis.py index 7819e6b2a..1b7f387d5 100644 --- a/scripts/shuffle/vis.py +++ 
b/benchmarks/shuffle/vis.py @@ -8,7 +8,7 @@ import numpy as np -from streaming.base.shuffle import algos, get_shuffle +from streaming.shuffle import algos, get_shuffle def parse_args() -> Namespace: diff --git a/docs/source/benchmarks b/docs/source/benchmarks new file mode 120000 index 000000000..8fea9e7de --- /dev/null +++ b/docs/source/benchmarks @@ -0,0 +1 @@ +../../benchmarks \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index e25dc24ba..851808798 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -35,6 +35,9 @@ from sphinx.ext.autodoc import ClassDocumenter, _ from sphinx.writers.html5 import HTML5Translator +import benchmarks +import examples +import notebooks import streaming if not shutil.which('pandoc'): @@ -363,18 +366,41 @@ def _auto_rst_for_module(module: types.ModuleType, exclude_members: List[Any]) - def _modules_to_rst() -> List[types.ModuleType]: """Return the list of modules for which to generate API reference rst files.""" document_modules: List[types.Module] = [ + benchmarks, + benchmarks.compression, + benchmarks.epoch, + benchmarks.hashing, + benchmarks.partition, + benchmarks.samples, + benchmarks.serialization, + benchmarks.shuffle, + examples, + examples.multimodal, + examples.multimodal.laion400m, + examples.multimodal.webvid, + examples.text, + examples.text.c4, + examples.text.enwiki_tok, + examples.text.enwiki_txt, + examples.text.pile, + examples.vision, + examples.vision.ade20k, + examples.vision.cifar10, + examples.vision.coco, + examples.vision.imagenet, + notebooks, streaming, - streaming.base.compression, - streaming.base.format, - streaming.base.hashing, - streaming.base.partition, - streaming.base.shared, - streaming.base.shuffle, - streaming.base.storage, - streaming.base.util, - streaming.base.world, + streaming.compression, + streaming.format, + streaming.hashing, + streaming.partition, + streaming.shared, + streaming.shuffle, + streaming.storage, + streaming.util, + streaming.world, ] - exclude_modules: List[types.Module] = [streaming.base, streaming._version] + exclude_modules: List[types.Module] = [streaming, streaming._version] for name in streaming.__dict__: obj = streaming.__dict__[name] if isinstance(obj, types.ModuleType) and obj not in exclude_modules: diff --git a/docs/source/fundamentals/dataset_conversion_guide.md b/docs/source/fundamentals/dataset_conversion_guide.md index c2f750ac6..e480f2724 100644 --- a/docs/source/fundamentals/dataset_conversion_guide.md +++ b/docs/source/fundamentals/dataset_conversion_guide.md @@ -42,7 +42,7 @@ column = { import numpy as np from typing import Any -from streaming.base.format.mds.encodings import Encoding, _encodings +from streaming.format.mds.encodings import Encoding, _encodings class Int32(Encoding): def encode(self, obj: Any) -> bytes: diff --git a/docs/source/getting_started/quick_start.md b/docs/source/getting_started/quick_start.md index 28b90e742..cf7b38b4d 100644 --- a/docs/source/getting_started/quick_start.md +++ b/docs/source/getting_started/quick_start.md @@ -65,6 +65,6 @@ Start training your model with the Streaming dataset in a few steps! dataloader = DataLoader(dataset) ``` -That's it! For additional details on using {mod}`streaming`, please check out our [User Guide](user_guide.md) and [Examples](../examples/cifar10.ipynb). +That's it! For additional details on using {mod}`streaming`, please check out our [User Guide](user_guide.md) and [Examples](../notebooks/cifar10.ipynb). Happy training! 
diff --git a/docs/source/getting_started/user_guide.md b/docs/source/getting_started/user_guide.md index 0226b94a6..4b0d88bfd 100644 --- a/docs/source/getting_started/user_guide.md +++ b/docs/source/getting_started/user_guide.md @@ -106,7 +106,7 @@ def each(samples): It's time to call the {class}`streaming.MDSWriter` with the above initialized parameters and write the samples by iterating over a dataset. ```python -from streaming.base import MDSWriter +from streaming import MDSWriter dataset = RandomClassificationDataset() with MDSWriter(out=output_dir, columns=columns, compression=compression, hashes=hashes, size_limit=limit) as out: @@ -169,7 +169,7 @@ from torch.utils.data import DataLoader dataloader = DataLoader(dataset=dataset) ``` -You've now seen an in-depth look at how to prepare and use streaming datasets with PyTorch. To continue learning about Streaming, please continue to explore our [examples](../examples/cifar10.ipynb/)! +You've now seen an in-depth look at how to prepare and use streaming datasets with PyTorch. To continue learning about Streaming, please continue to explore our [examples](../notebooks/cifar10.ipynb/)! ## Other options diff --git a/docs/source/how_to_guides/dataset_conversion_to_mds_format.md b/docs/source/how_to_guides/dataset_conversion_to_mds_format.md index 10d045c7c..48ed67aec 100644 --- a/docs/source/how_to_guides/dataset_conversion_to_mds_format.md +++ b/docs/source/how_to_guides/dataset_conversion_to_mds_format.md @@ -13,34 +13,30 @@ Let's look at the steps one needs to perform to convert their raw data into an M 3. Convert the raw sample in the form of `column` field. 4. Instantiate MDSWriter and call the `write` method to write a raw sample one at a time. -Checkout the [user guide](../getting_started/user_guide.md) section which contains a simplistic example for the data conversion using single process. For multiprocess dataset conversion example, checkout [this](../examples/multiprocess_dataset_conversion.ipynb) tutorial. +Checkout the [user guide](../getting_started/user_guide.md) section which contains a simplistic example for the data conversion using single process. For multiprocess dataset conversion example, checkout [this](../notebooks/multiprocess_dataset_conversion.ipynb) tutorial. We've already created conversion scripts that can be used to convert popular public datasets to MDS format. Please see below for usage instructions. 
## Spark Dataframe Conversion Examples -```{include} ../../../streaming/base/converters/README.md +```{include} ../../../streaming/converters/README.md :start-line: 2 ``` ## NLP Dataset Conversion Examples -```{include} ../../../streaming/text/convert/README.md -:start-line: 8 -``` +[examples/text/](../examples/text/) ## Vision Dataset Conversion Examples -```{include} ../../../streaming/vision/convert/README.md -:start-line: 8 -``` +[examples/vision/](../examples/vision/) ## Multimodal Dataset Conversion Examples ### [LAION-400M](https://laion.ai/blog/laion-400-open-dataset/) -```{include} ../../../streaming/multimodal/convert/laion/laion400m/README.md +```{include} ../../../examples/multimodal/laion400m/README.md :start-line: 8 ``` ### [WebVid](https://m-bain.github.io/webvid-dataset/) -```{include} ../../../streaming/multimodal/convert/webvid/README.md +```{include} ../../../examples/multimodal/webvid/write/README.md :start-line: 12 ``` diff --git a/docs/source/index.md b/docs/source/index.md index 0bc2cda55..f04d25a93 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -75,11 +75,11 @@ If you have any questions, please feel free to reach out to us on [Twitter](htt :maxdepth: 1 :caption: Examples - examples/cifar10.ipynb - examples/facesynthetics.ipynb - examples/synthetic_nlp.ipynb - examples/multiprocess_dataset_conversion.ipynb - examples/spark_dataframe_to_MDS.ipynb + notebooks/cifar10.ipynb + notebooks/facesynthetics.ipynb + notebooks/synthetic_nlp.ipynb + notebooks/multiprocess_dataset_conversion.ipynb + notebooks/spark_dataframe_to_MDS.ipynb .. toctree:: :hidden: diff --git a/docs/source/notebooks b/docs/source/notebooks new file mode 120000 index 000000000..d4082256d --- /dev/null +++ b/docs/source/notebooks @@ -0,0 +1 @@ +../../notebooks/ \ No newline at end of file diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 000000000..96c5c674c --- /dev/null +++ b/examples/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Example streaming datasets.""" + +from examples import multimodal as multimodal +from examples import text as text +from examples import vision as vision + +__all__ = ['multimodal', 'text', 'vision'] diff --git a/examples/multimodal/__init__.py b/examples/multimodal/__init__.py new file mode 100644 index 000000000..53ac41f89 --- /dev/null +++ b/examples/multimodal/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Example multimodal streaming datasets.""" + +from examples.multimodal import laion400m as laion400m +from examples.multimodal import webvid as webvid + +__all__ = ['laion400m', 'webvid'] diff --git a/streaming/multimodal/convert/laion/laion400m/README.md b/examples/multimodal/laion400m/README.md similarity index 90% rename from streaming/multimodal/convert/laion/laion400m/README.md rename to examples/multimodal/laion400m/README.md index abfe45211..2869dc2a5 100644 --- a/streaming/multimodal/convert/laion/laion400m/README.md +++ b/examples/multimodal/laion400m/README.md @@ -27,7 +27,7 @@ cd streaming/ **3. Download metadata from the-eye.eu (parquet format)** ``` -./streaming/multimodal/convert/laion/laion400m/download_meta.sh +./examples/multimodal/laion400m/download_meta.sh ``` **4. Download data from the web (into parquet format, converting to mds format)** @@ -35,13 +35,13 @@ cd streaming/ The img2dataset download script saves samples in parquet files. 
``` -./streaming/multimodal/convert/laion/laion400m/download_data.sh +./examples/multimodal/laion400m/download_data.sh ``` At the same time, do our conversion and uploading which uses MDS (you will want to run them at the same time, or disk usage can get excessive): ``` -./streaming/multimodal/convert/laion/laion400m/convert_and_upload.sh +./examples/multimodal/laion400m/convert_and_upload.sh ``` **Optional** diff --git a/streaming/vision/convert/__init__.py b/examples/multimodal/laion400m/__init__.py similarity index 61% rename from streaming/vision/convert/__init__.py rename to examples/multimodal/laion400m/__init__.py index fcea5a2a2..044e80e3e 100644 --- a/streaming/vision/convert/__init__.py +++ b/examples/multimodal/laion400m/__init__.py @@ -1,4 +1,4 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Data conversion scripts for Computer Vision.""" +"""LAION-400M streaming dataset example.""" diff --git a/streaming/multimodal/convert/laion/laion400m/convert_and_upload.py b/examples/multimodal/laion400m/convert_and_upload.py similarity index 99% rename from streaming/multimodal/convert/laion/laion400m/convert_and_upload.py rename to examples/multimodal/laion400m/convert_and_upload.py index 8af84a3d1..ddc1e4cb5 100644 --- a/streaming/multimodal/convert/laion/laion400m/convert_and_upload.py +++ b/examples/multimodal/laion400m/convert_and_upload.py @@ -13,7 +13,7 @@ from pyarrow import parquet as pq from streaming import MDSWriter -from streaming.base.storage import CloudUploader +from streaming.storage import CloudUploader def parse_args() -> Namespace: diff --git a/streaming/multimodal/convert/laion/laion400m/convert_and_upload.sh b/examples/multimodal/laion400m/convert_and_upload.sh similarity index 59% rename from streaming/multimodal/convert/laion/laion400m/convert_and_upload.sh rename to examples/multimodal/laion400m/convert_and_upload.sh index ef74612b1..b17b9f78b 100755 --- a/streaming/multimodal/convert/laion/laion400m/convert_and_upload.sh +++ b/examples/multimodal/laion400m/convert_and_upload.sh @@ -2,7 +2,7 @@ REMOTE=$1 -python3 -m streaming.multimodal.convert.laion.laion400m.convert_and_upload \ +python3 -m examples.multimodal.laion400m.convert_and_upload \ --local laion400m-data \ --remote $REMOTE \ --keep_parquet 0 \ diff --git a/streaming/multimodal/convert/laion/laion400m/download_data.sh b/examples/multimodal/laion400m/download_data.sh similarity index 100% rename from streaming/multimodal/convert/laion/laion400m/download_data.sh rename to examples/multimodal/laion400m/download_data.sh diff --git a/streaming/multimodal/convert/laion/laion400m/download_meta.sh b/examples/multimodal/laion400m/download_meta.sh similarity index 100% rename from streaming/multimodal/convert/laion/laion400m/download_meta.sh rename to examples/multimodal/laion400m/download_meta.sh diff --git a/streaming/text/convert/__init__.py b/examples/multimodal/webvid/__init__.py similarity index 56% rename from streaming/text/convert/__init__.py rename to examples/multimodal/webvid/__init__.py index a807b9660..fa78101db 100644 --- a/streaming/text/convert/__init__.py +++ b/examples/multimodal/webvid/__init__.py @@ -1,4 +1,4 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Data conversion scripts for Natural Language Processing.""" +"""WebVid streaming dataset example.""" diff --git a/streaming/multimodal/webvid.py b/examples/multimodal/webvid/read.py similarity index 99% rename from streaming/multimodal/webvid.py rename to 
examples/multimodal/webvid/read.py index 260ae7bc3..d3f74c2d9 100644 --- a/streaming/multimodal/webvid.py +++ b/examples/multimodal/webvid/read.py @@ -7,9 +7,9 @@ from time import sleep from typing import Any, Optional -from streaming.base import StreamingDataset -from streaming.base.dataset import TICK, _Iterator -from streaming.base.storage import download_file +from streaming import StreamingDataset +from streaming.dataset import TICK, _Iterator +from streaming.storage import download_file class StreamingInsideWebVid(StreamingDataset): diff --git a/scripts/webvid/bench_inside.py b/examples/multimodal/webvid/scripts/bench_inside.py similarity index 94% rename from scripts/webvid/bench_inside.py rename to examples/multimodal/webvid/scripts/bench_inside.py index 2560522a6..73fd4e8a9 100644 --- a/scripts/webvid/bench_inside.py +++ b/examples/multimodal/webvid/scripts/bench_inside.py @@ -8,7 +8,7 @@ import numpy as np -from streaming.multimodal.webvid import StreamingInsideWebVid +from examples.multimodal.webvid.read import StreamingInsideWebVid def parse_args() -> Namespace: diff --git a/scripts/webvid/bench_outside_dt.py b/examples/multimodal/webvid/scripts/bench_outside_dt.py similarity index 81% rename from scripts/webvid/bench_outside_dt.py rename to examples/multimodal/webvid/scripts/bench_outside_dt.py index 7acda873c..83fe905eb 100644 --- a/scripts/webvid/bench_outside_dt.py +++ b/examples/multimodal/webvid/scripts/bench_outside_dt.py @@ -8,7 +8,7 @@ import numpy as np -from streaming.multimodal.webvid import StreamingOutsideDTWebVid +from examples.multimodal.webvid.read import StreamingOutsideDTWebVid def parse_args() -> Namespace: @@ -18,12 +18,22 @@ def parse_args() -> Namespace: Namespace: Command-line arguments. """ args = ArgumentParser() - args.add_argument('--local', type=str, required=True, help='Streaming dataset local') + args.add_argument( + '--local', + type=str, + required=True, + help='Streaming dataset local', + ) args.add_argument('--extra_local', type=str, required=True, help='Streaming dataset extra local') - args.add_argument('--remote', type=str, required=True, help='Streaming dataset remote') + args.add_argument( + '--remote', + type=str, + required=True, + help='Streaming dataset remote', + ) args.add_argument('--extra_remote', type=str, required=True, diff --git a/scripts/webvid/bench_outside_gi.py b/examples/multimodal/webvid/scripts/bench_outside_gi.py similarity index 81% rename from scripts/webvid/bench_outside_gi.py rename to examples/multimodal/webvid/scripts/bench_outside_gi.py index b95efa71d..12072be8e 100644 --- a/scripts/webvid/bench_outside_gi.py +++ b/examples/multimodal/webvid/scripts/bench_outside_gi.py @@ -8,7 +8,7 @@ import numpy as np -from streaming.multimodal.webvid import StreamingOutsideGIWebVid +from examples.multimodal.webvid.read import StreamingOutsideGIWebVid def parse_args() -> Namespace: @@ -18,12 +18,22 @@ def parse_args() -> Namespace: Namespace: Command-line arguments. 
""" args = ArgumentParser() - args.add_argument('--local', type=str, required=True, help='Streaming dataset local') + args.add_argument( + '--local', + type=str, + required=True, + help='Streaming dataset local', + ) args.add_argument('--extra_local', type=str, required=True, help='Streaming dataset extra local') - args.add_argument('--remote', type=str, required=True, help='Streaming dataset remote') + args.add_argument( + '--remote', + type=str, + required=True, + help='Streaming dataset remote', + ) args.add_argument('--extra_remote', type=str, required=True, diff --git a/scripts/webvid/plot.py b/examples/multimodal/webvid/scripts/plot.py similarity index 100% rename from scripts/webvid/plot.py rename to examples/multimodal/webvid/scripts/plot.py diff --git a/streaming/multimodal/convert/webvid/README.md b/examples/multimodal/webvid/write/README.md similarity index 83% rename from streaming/multimodal/convert/webvid/README.md rename to examples/multimodal/webvid/write/README.md index 50b2dbf60..6c5a640de 100644 --- a/streaming/multimodal/convert/webvid/README.md +++ b/examples/multimodal/webvid/write/README.md @@ -15,7 +15,7 @@ Check out the steps below for information on converting WebVid datasets to MDS f Create an MDS dataset from a CSV file containing video URLs (downloads the videos). 1. Navigate to the [WebVid download section](https://m-bain.github.io/webvid-dataset/), where you will find 2.5M and 10M dataset splits. Download each CSV split you want to process. -2. Run the [crawl_webvid.py](https://github.com/mosaicml/streaming/blob/main/streaming/multimodal/convert/webvid/crawl_webvid.py) script with minimum required arguments as shown below   +2. Run the [crawl_webvid.py](https://github.com/mosaicml/streaming/blob/main/examples/multimodal/webvid/crawl_webvid.py) script with minimum required arguments as shown below     ``` python crawl_webvid.py --in --out_root   @@ -27,7 +27,7 @@ Create multiple MDS sub-datasets from a CSV file containing video URLs and a lis 1. Navigate to the [WebVid download section](https://m-bain.github.io/webvid-dataset/), where you will find 2.5M and 10M dataset splits. Download each CSV split you want to process. -2. Run the [crawl_webvid_subsets.py](https://github.com/mosaicml/streaming/blob/main/streaming/multimodal/convert/webvid/crawl_webvid_subsets.py) script with minimum required arguments as shown below. The script also supports an optional arg `filter`, which takes a comma-separated list of keywords to filter into sub-datasets. +2. Run the [crawl_webvid_subsets.py](https://github.com/mosaicml/streaming/blob/main/examples/multimodal/webvid/crawl_webvid_subsets.py) script with minimum required arguments as shown below. The script also supports an optional arg `filter`, which takes a comma-separated list of keywords to filter into sub-datasets. ``` python crawl_webvid_subsets.py --in --out_root @@ -36,7 +36,7 @@ Create multiple MDS sub-datasets from a CSV file containing video URLs and a lis #### Split out MDS datasets column Iterate an existing MDS dataset containing videos, creating a new MDS dataset without video contents embedded in it, instead, add a video filepath in a new MDS dataset where the video files (MP4) are stored separately. -1. Run the [extract_webvid_videos.py](https://github.com/mosaicml/streaming/blob/main/streaming/multimodal/convert/webvid/extract_webvid_videos.py) script with minimum required arguments as shown below +1. 
Run the [extract_webvid_videos.py](https://github.com/mosaicml/streaming/blob/main/examples/multimodal/webvid/write/extract_webvid_videos.py) script with minimum required arguments as shown below ``` python extract_webvid_videos.py --in --out_mds --out_mp4 ``` diff --git a/streaming/multimodal/convert/webvid/__init__.py b/examples/multimodal/webvid/write/__init__.py similarity index 100% rename from streaming/multimodal/convert/webvid/__init__.py rename to examples/multimodal/webvid/write/__init__.py diff --git a/streaming/multimodal/convert/webvid/crawl_webvid.py b/examples/multimodal/webvid/write/crawl_webvid.py similarity index 100% rename from streaming/multimodal/convert/webvid/crawl_webvid.py rename to examples/multimodal/webvid/write/crawl_webvid.py diff --git a/streaming/multimodal/convert/webvid/crawl_webvid_subsets.py b/examples/multimodal/webvid/write/crawl_webvid_subsets.py similarity index 100% rename from streaming/multimodal/convert/webvid/crawl_webvid_subsets.py rename to examples/multimodal/webvid/write/crawl_webvid_subsets.py diff --git a/streaming/multimodal/convert/webvid/extract_webvid_videos.py b/examples/multimodal/webvid/write/extract_webvid_videos.py similarity index 100% rename from streaming/multimodal/convert/webvid/extract_webvid_videos.py rename to examples/multimodal/webvid/write/extract_webvid_videos.py diff --git a/examples/text/__init__.py b/examples/text/__init__.py new file mode 100644 index 000000000..59ba7324f --- /dev/null +++ b/examples/text/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Example text streaming datasets.""" + +from examples.text import c4 as c4 +from examples.text import enwiki_tok as enwiki_tok +from examples.text import enwiki_txt as enwiki_txt +from examples.text import pile as pile + +__all__ = ['c4', 'enwiki_tok', 'enwiki_txt', 'pile'] diff --git a/examples/text/c4/README.md b/examples/text/c4/README.md new file mode 100644 index 000000000..38576d19f --- /dev/null +++ b/examples/text/c4/README.md @@ -0,0 +1,7 @@ +### [C4: Colossal, Cleaned, Common Crawl dataset](https://huggingface.co/datasets/c4) + +1. Run the [write.py](https://github.com/mosaicml/streaming/blob/main/examples/text/c4/write.py) script as shown below. The script downloads the raw format with `train` and `val` splits from HuggingFace hub and converts to StreamingDataset MDS format into their own split directories. For more advanced use cases, please see the supported arguments for [write.py](https://github.com/mosaicml/streaming/blob/main/examples/text/c4/write.py) and modify as necessary.
+ + ``` + python write.py --out_root + ``` diff --git a/streaming/multimodal/convert/laion/laion400m/__init__.py b/examples/text/c4/__init__.py similarity index 69% rename from streaming/multimodal/convert/laion/laion400m/__init__.py rename to examples/text/c4/__init__.py index 22968ec92..cb109ce66 100644 --- a/streaming/multimodal/convert/laion/laion400m/__init__.py +++ b/examples/text/c4/__init__.py @@ -1,4 +1,4 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""LAION-400M dataset creation.""" +"""C4 streaming dataset example.""" diff --git a/streaming/text/c4.py b/examples/text/c4/read.py similarity index 99% rename from streaming/text/c4.py rename to examples/text/c4/read.py index 82a24a255..d30340f97 100644 --- a/streaming/text/c4.py +++ b/examples/text/c4/read.py @@ -11,7 +11,7 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer -from streaming.base import StreamingDataset +from streaming import StreamingDataset __all__ = ['StreamingC4'] diff --git a/streaming/text/convert/c4.py b/examples/text/c4/write.py similarity index 98% rename from streaming/text/convert/c4.py rename to examples/text/c4/write.py index 5dc186c52..941e77cb6 100644 --- a/streaming/text/convert/c4.py +++ b/examples/text/c4/write.py @@ -12,8 +12,8 @@ from torch.utils.data import DataLoader, IterableDataset, get_worker_info from tqdm import tqdm -from streaming.base import MDSWriter -from streaming.base.util import get_list_arg +from streaming import MDSWriter +from streaming.util import get_list_arg def parse_args() -> Namespace: diff --git a/examples/text/enwiki_tok/__init__.py b/examples/text/enwiki_tok/__init__.py new file mode 100644 index 000000000..a1809c220 --- /dev/null +++ b/examples/text/enwiki_tok/__init__.py @@ -0,0 +1 @@ +"""English Wikipedia (tokenized) streaming dataset example.""" diff --git a/streaming/text/convert/enwiki/mds/README.md b/examples/text/enwiki_tok/mds/README.md similarity index 100% rename from streaming/text/convert/enwiki/mds/README.md rename to examples/text/enwiki_tok/mds/README.md diff --git a/streaming/text/convert/enwiki/__init__.py b/examples/text/enwiki_tok/mds/__init__.py similarity index 100% rename from streaming/text/convert/enwiki/__init__.py rename to examples/text/enwiki_tok/mds/__init__.py diff --git a/streaming/text/convert/enwiki/mds/create_pretraining_data.py b/examples/text/enwiki_tok/mds/create_pretraining_data.py similarity index 100% rename from streaming/text/convert/enwiki/mds/create_pretraining_data.py rename to examples/text/enwiki_tok/mds/create_pretraining_data.py diff --git a/streaming/text/convert/enwiki/mds/make_eval.sh b/examples/text/enwiki_tok/mds/make_eval.sh similarity index 100% rename from streaming/text/convert/enwiki/mds/make_eval.sh rename to examples/text/enwiki_tok/mds/make_eval.sh diff --git a/streaming/text/convert/enwiki/mds/make_train_parallel.py b/examples/text/enwiki_tok/mds/make_train_parallel.py similarity index 100% rename from streaming/text/convert/enwiki/mds/make_train_parallel.py rename to examples/text/enwiki_tok/mds/make_train_parallel.py diff --git a/streaming/text/convert/enwiki/mds/merge_shard_groups.py b/examples/text/enwiki_tok/mds/merge_shard_groups.py similarity index 100% rename from streaming/text/convert/enwiki/mds/merge_shard_groups.py rename to examples/text/enwiki_tok/mds/merge_shard_groups.py diff --git a/streaming/text/convert/enwiki/mds/pick_eval_samples.py b/examples/text/enwiki_tok/mds/pick_eval_samples.py similarity index 100% rename from 
streaming/text/convert/enwiki/mds/pick_eval_samples.py rename to examples/text/enwiki_tok/mds/pick_eval_samples.py diff --git a/streaming/text/convert/enwiki/mds/tokenization.py b/examples/text/enwiki_tok/mds/tokenization.py similarity index 100% rename from streaming/text/convert/enwiki/mds/tokenization.py rename to examples/text/enwiki_tok/mds/tokenization.py diff --git a/streaming/text/convert/enwiki/mds/vocab.txt b/examples/text/enwiki_tok/mds/vocab.txt similarity index 100% rename from streaming/text/convert/enwiki/mds/vocab.txt rename to examples/text/enwiki_tok/mds/vocab.txt diff --git a/streaming/text/convert/enwiki/mds/__init__.py b/examples/text/enwiki_tok/tfrecord/__init__.py similarity index 100% rename from streaming/text/convert/enwiki/mds/__init__.py rename to examples/text/enwiki_tok/tfrecord/__init__.py diff --git a/streaming/text/convert/enwiki/tfrecord/count_samples.py b/examples/text/enwiki_tok/tfrecord/count_samples.py similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/count_samples.py rename to examples/text/enwiki_tok/tfrecord/count_samples.py diff --git a/streaming/text/convert/enwiki/tfrecord/create_pretraining_data.py b/examples/text/enwiki_tok/tfrecord/create_pretraining_data.py similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/create_pretraining_data.py rename to examples/text/enwiki_tok/tfrecord/create_pretraining_data.py diff --git a/streaming/text/convert/enwiki/tfrecord/make_eval.sh b/examples/text/enwiki_tok/tfrecord/make_eval.sh similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/make_eval.sh rename to examples/text/enwiki_tok/tfrecord/make_eval.sh diff --git a/streaming/text/convert/enwiki/tfrecord/make_train.sh b/examples/text/enwiki_tok/tfrecord/make_train.sh similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/make_train.sh rename to examples/text/enwiki_tok/tfrecord/make_train.sh diff --git a/streaming/text/convert/enwiki/tfrecord/make_train_parallel.py b/examples/text/enwiki_tok/tfrecord/make_train_parallel.py similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/make_train_parallel.py rename to examples/text/enwiki_tok/tfrecord/make_train_parallel.py diff --git a/streaming/text/convert/enwiki/tfrecord/pick_eval_samples.py b/examples/text/enwiki_tok/tfrecord/pick_eval_samples.py similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/pick_eval_samples.py rename to examples/text/enwiki_tok/tfrecord/pick_eval_samples.py diff --git a/streaming/text/convert/enwiki/tfrecord/tokenization.py b/examples/text/enwiki_tok/tfrecord/tokenization.py similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/tokenization.py rename to examples/text/enwiki_tok/tfrecord/tokenization.py diff --git a/streaming/text/convert/enwiki/tfrecord/vocab.txt b/examples/text/enwiki_tok/tfrecord/vocab.txt similarity index 100% rename from streaming/text/convert/enwiki/tfrecord/vocab.txt rename to examples/text/enwiki_tok/tfrecord/vocab.txt diff --git a/examples/text/enwiki_txt/README.txt b/examples/text/enwiki_txt/README.txt new file mode 100644 index 000000000..8ae0c36bf --- /dev/null +++ b/examples/text/enwiki_txt/README.txt @@ -0,0 +1,26 @@ +### [Wikipedia](https://huggingface.co/datasets/wikipedia) + +1. Download English Wikipedia 2020-01-01 from [here](https://drive.google.com/drive/folders/1cywmDnAsrP5-2vsr8GDc6QUc7VWe-M3v). +2. Unzip the file `results_text.zip` as shown below. 
+ + ```bash + unzip results_text.zip + ``` + + Listing the output should show the following directory structure: + + ```bash + ├── eval.txt + ├── part-00000-of-00500 + ├── part-00001-of-00500 + ├── part-00002-of-00500 + ├── ..... + ├── part-00498-of-00500 + └── part-00499-of-00500 + ``` + +3. Run the [write.py](https://github.com/mosaicml/streaming/blob/main/examples/text/enwiki_txt/write.py) script. The script converts the `train` and `val` dataset splits into their own split directories. For more advanced use cases, please see the supported arguments for [write.py](https://github.com/mosaicml/streaming/blob/main/examples/text/enwiki_txt/write.py) and modify as necessary. + + ``` + python write.py --in_root --out_root + ``` diff --git a/examples/text/enwiki_txt/__init__.py b/examples/text/enwiki_txt/__init__.py new file mode 100644 index 000000000..01d883a01 --- /dev/null +++ b/examples/text/enwiki_txt/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""English Wikipedia (text) streaming dataset example.""" diff --git a/streaming/text/enwiki.py b/examples/text/enwiki_txt/read.py similarity index 99% rename from streaming/text/enwiki.py rename to examples/text/enwiki_txt/read.py index 63c24a5a3..4385e7394 100644 --- a/streaming/text/enwiki.py +++ b/examples/text/enwiki_txt/read.py @@ -7,7 +7,7 @@ import numpy as np -from streaming.base import StreamingDataset +from streaming import StreamingDataset __all__ = ['StreamingEnWiki'] diff --git a/streaming/text/convert/enwiki_text.py b/examples/text/enwiki_txt/write.py similarity index 97% rename from streaming/text/convert/enwiki_text.py rename to examples/text/enwiki_txt/write.py index 97f428d11..b31da8a50 100644 --- a/streaming/text/convert/enwiki_text.py +++ b/examples/text/enwiki_txt/write.py @@ -9,8 +9,8 @@ from tqdm import tqdm -from streaming.base import MDSWriter -from streaming.base.util import get_list_arg +from streaming import MDSWriter +from streaming.util import get_list_arg def parse_args() -> Namespace: diff --git a/examples/text/pile/README.md b/examples/text/pile/README.md new file mode 100644 index 000000000..14a895b0b --- /dev/null +++ b/examples/text/pile/README.md @@ -0,0 +1,26 @@ +### [Pile](https://pile.eleuther.ai/) + +1. Download the Pile dataset from [here](https://the-eye.eu/public/AI/pile/). + + Listing the output should show the following directory structure: + + ```bash + ├── SHA256SUMS.txt + ├── test.jsonl.zst + ├── train + │   ├── 00.jsonl.zst + │   ├── 01.jsonl.zst + │   ├── 02.jsonl.zst + │   ├── 03.jsonl.zst + │   ├── ..... + │   ├── 28.jsonl.zst + │   └── 29.jsonl.zst + └── val.jsonl.zst + ``` + +2. Run the [write.py](https://github.com/mosaicml/streaming/blob/main/examples/text/pile/write.py) script. The script converts the `train`, `test`, and `val` dataset splits into their own split directories. For more advanced use cases, please see the supported arguments for [write.py](https://github.com/mosaicml/streaming/blob/main/examples/text/pile/write.py) and modify as necessary. 
+ + + ```bash + python write.py --in_root --out_root + ``` diff --git a/examples/text/pile/__init__.py b/examples/text/pile/__init__.py new file mode 100644 index 000000000..0126cd8db --- /dev/null +++ b/examples/text/pile/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Pile streaming dataset example.""" diff --git a/streaming/text/pile.py b/examples/text/pile/read.py similarity index 99% rename from streaming/text/pile.py rename to examples/text/pile/read.py index f2f06113b..58c4afc68 100644 --- a/streaming/text/pile.py +++ b/examples/text/pile/read.py @@ -11,7 +11,7 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer -from streaming.base import StreamingDataset +from streaming import StreamingDataset __all__ = ['StreamingPile'] diff --git a/streaming/text/convert/pile.py b/examples/text/pile/write.py similarity index 98% rename from streaming/text/convert/pile.py rename to examples/text/pile/write.py index b01fb8027..99f78ab91 100644 --- a/streaming/text/convert/pile.py +++ b/examples/text/pile/write.py @@ -11,8 +11,8 @@ from multiprocessing import Pool from typing import Dict, Iterator, List, Tuple -from streaming.base import MDSWriter -from streaming.base.util import get_list_arg +from streaming import MDSWriter +from streaming.util import get_list_arg def parse_args() -> Namespace: diff --git a/examples/vision/__init__.py b/examples/vision/__init__.py new file mode 100644 index 000000000..4b89ee3b3 --- /dev/null +++ b/examples/vision/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Example computer vision streaming datasets.""" + +from examples.vision import ade20k as ade20k +from examples.vision import cifar10 as cifar10 +from examples.vision import coco as coco +from examples.vision import imagenet as imagenet + +__all__ = ['ade20k', 'cifar10', 'coco', 'imagenet'] diff --git a/examples/vision/ade20k/__init__.py b/examples/vision/ade20k/__init__.py new file mode 100644 index 000000000..ed70f2235 --- /dev/null +++ b/examples/vision/ade20k/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""ADE20K streaming dataset example.""" diff --git a/streaming/vision/ade20k.py b/examples/vision/ade20k/read.py similarity index 99% rename from streaming/vision/ade20k.py rename to examples/vision/ade20k/read.py index bba847115..f04fc423f 100644 --- a/streaming/vision/ade20k.py +++ b/examples/vision/ade20k/read.py @@ -9,7 +9,7 @@ from typing import Any, Callable, Optional, Tuple -from streaming.base import StreamingDataset +from streaming import StreamingDataset __all__ = ['StreamingADE20K'] diff --git a/streaming/vision/convert/ade20k.py b/examples/vision/ade20k/write.py similarity index 98% rename from streaming/vision/convert/ade20k.py rename to examples/vision/ade20k/write.py index 8d0598666..ed02f1f70 100644 --- a/streaming/vision/convert/ade20k.py +++ b/examples/vision/ade20k/write.py @@ -11,8 +11,8 @@ from tqdm import tqdm -from streaming.base import MDSWriter -from streaming.base.util import get_list_arg +from streaming import MDSWriter +from streaming.util import get_list_arg def parse_args() -> Namespace: diff --git a/examples/vision/cifar10/__init__.py b/examples/vision/cifar10/__init__.py new file mode 100644 index 000000000..058b2b30d --- /dev/null +++ b/examples/vision/cifar10/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2023 MosaicML Streaming authors +# 
SPDX-License-Identifier: Apache-2.0 + +"""CIFAR10 streaming dataset example.""" diff --git a/streaming/vision/cifar10.py b/examples/vision/cifar10/read.py similarity index 98% rename from streaming/vision/cifar10.py rename to examples/vision/cifar10/read.py index 75c2c36ac..c2f97d8ee 100644 --- a/streaming/vision/cifar10.py +++ b/examples/vision/cifar10/read.py @@ -7,7 +7,7 @@ `CIFAR-10 Dataset `_ for more details. """ -from streaming.vision.base import StreamingVisionDataset +from streaming.vision import StreamingVisionDataset __all__ = ['StreamingCIFAR10'] diff --git a/streaming/vision/convert/cifar10.py b/examples/vision/cifar10/write.py similarity index 95% rename from streaming/vision/convert/cifar10.py rename to examples/vision/cifar10/write.py index 1251338b1..0935da8a4 100644 --- a/streaming/vision/convert/cifar10.py +++ b/examples/vision/cifar10/write.py @@ -7,8 +7,8 @@ from torchvision.datasets import CIFAR10 -from streaming.base.util import get_list_arg -from streaming.vision.convert.base import convert_image_class_dataset +from streaming.util import get_list_arg +from streaming.vision import convert_image_class_dataset def parse_args() -> Namespace: diff --git a/streaming/vision/convert/fake_cifar10.py b/examples/vision/cifar10/write_fake.py similarity index 95% rename from streaming/vision/convert/fake_cifar10.py rename to examples/vision/cifar10/write_fake.py index 53ee00c03..a83e8b1ef 100644 --- a/streaming/vision/convert/fake_cifar10.py +++ b/examples/vision/cifar10/write_fake.py @@ -7,7 +7,7 @@ import numpy as np from PIL import Image -from streaming.vision.convert.base import convert_image_class_dataset +from streaming.vision import convert_image_class_dataset def parse_args() -> Namespace: diff --git a/examples/vision/coco/__init__.py b/examples/vision/coco/__init__.py new file mode 100644 index 000000000..f3533853f --- /dev/null +++ b/examples/vision/coco/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""COCO streaming dataset example.""" diff --git a/streaming/vision/coco.py b/examples/vision/coco/read.py similarity index 99% rename from streaming/vision/coco.py rename to examples/vision/coco/read.py index 162b17581..a9622eab3 100644 --- a/streaming/vision/coco.py +++ b/examples/vision/coco/read.py @@ -9,7 +9,7 @@ from typing import Any, Callable, Optional -from streaming.base import StreamingDataset +from streaming import StreamingDataset __all__ = ['StreamingCOCO'] diff --git a/streaming/vision/convert/coco.py b/examples/vision/coco/write.py similarity index 98% rename from streaming/vision/convert/coco.py rename to examples/vision/coco/write.py index 2456fc953..eb85cc61a 100644 --- a/streaming/vision/convert/coco.py +++ b/examples/vision/coco/write.py @@ -14,8 +14,8 @@ from torch.utils.data import Dataset from tqdm import tqdm -from streaming.base import MDSWriter -from streaming.base.util import get_list_arg +from streaming import MDSWriter +from streaming.util import get_list_arg def parse_args() -> Namespace: diff --git a/examples/vision/imagenet/__init__.py b/examples/vision/imagenet/__init__.py new file mode 100644 index 000000000..68b8d41fa --- /dev/null +++ b/examples/vision/imagenet/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""ImageNet streaming dataset example.""" diff --git a/streaming/vision/imagenet.py b/examples/vision/imagenet/read.py similarity index 98% rename from streaming/vision/imagenet.py rename to 
examples/vision/imagenet/read.py index 8ed47af54..ab993fe36 100644 --- a/streaming/vision/imagenet.py +++ b/examples/vision/imagenet/read.py @@ -7,7 +7,7 @@ 2012 Classification Dataset `_ for more details. """ -from streaming.vision.base import StreamingVisionDataset +from streaming.vision import StreamingVisionDataset __all__ = ['StreamingImageNet'] diff --git a/streaming/vision/convert/imagenet.py b/examples/vision/imagenet/write.py similarity index 98% rename from streaming/vision/convert/imagenet.py rename to examples/vision/imagenet/write.py index d350e9029..a69609167 100644 --- a/streaming/vision/convert/imagenet.py +++ b/examples/vision/imagenet/write.py @@ -12,8 +12,8 @@ from PIL import Image from tqdm import tqdm -from streaming.base import MDSWriter -from streaming.base.util import get_list_arg +from streaming import MDSWriter +from streaming.util import get_list_arg def parse_args() -> Namespace: diff --git a/streaming/multimodal/convert/laion/__init__.py b/notebooks/__init__.py similarity index 73% rename from streaming/multimodal/convert/laion/__init__.py rename to notebooks/__init__.py index dc40547ef..c7158a756 100644 --- a/streaming/multimodal/convert/laion/__init__.py +++ b/notebooks/__init__.py @@ -1,4 +1,4 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""LAION dataset creation.""" +"""Streaming notebooks.""" diff --git a/examples/cifar10.ipynb b/notebooks/cifar10.ipynb similarity index 100% rename from examples/cifar10.ipynb rename to notebooks/cifar10.ipynb diff --git a/examples/facesynthetics.ipynb b/notebooks/facesynthetics.ipynb similarity index 100% rename from examples/facesynthetics.ipynb rename to notebooks/facesynthetics.ipynb diff --git a/examples/multiprocess_dataset_conversion.ipynb b/notebooks/multiprocess_dataset_conversion.ipynb similarity index 98% rename from examples/multiprocess_dataset_conversion.ipynb rename to notebooks/multiprocess_dataset_conversion.ipynb index d0ce9f134..9ddf23dc1 100644 --- a/examples/multiprocess_dataset_conversion.ipynb +++ b/notebooks/multiprocess_dataset_conversion.ipynb @@ -424,7 +424,7 @@ }, "outputs": [], "source": [ - "from streaming.base.util import merge_index\n", + "from streaming.util import merge_index\n", "merge_index(out_root, keep_local=True)" ] }, @@ -508,7 +508,7 @@ "\n", "## What next?\n", "\n", - "You've now seen an in-depth tutorial on converting a dataset into MDS format using multiple process. If you are interested in the real world example, then, checkout the [WebVid](https://github.com/mosaicml/streaming/blob/main/streaming/multimodal/convert/webvid/crawl_webvid.py) and [Pile](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/pile.py) dataset conversion scripts which converts the dataset into MDS format via multiprocessing." + "You've now seen an in-depth tutorial on converting a dataset into MDS format using multiple processes. If you are interested in a real-world example, check out the [WebVid](https://github.com/mosaicml/streaming/blob/main/examples/multimodal/webvid/write/crawl_webvid.py) and [Pile](https://github.com/mosaicml/streaming/blob/main/examples/text/pile/write.py) dataset conversion scripts, which convert their datasets into MDS format via multiprocessing." 
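For reference, the multiprocess-conversion pattern that this notebook cell points to can be sketched with the flattened import paths this patch introduces (`streaming` and `streaming.util` instead of `streaming.base`). The column schema, sample counts, and output root below are illustrative placeholders, not values taken from the notebook:

```python
# Minimal sketch: each worker writes its own MDS sub-directory, then the per-partition
# index.json files are merged and the result is read back with StreamingDataset.
import os
from multiprocessing import Pool

from streaming import MDSWriter, StreamingDataset
from streaming.util import merge_index

out_root = '/tmp/mds_demo'  # hypothetical output root
columns = {'id': 'int', 'text': 'str'}  # hypothetical sample schema


def write_partition(task):
    """Write one partition of synthetic samples as its own MDS sub-directory."""
    part, num_samples = task
    sub_dir = os.path.join(out_root, f'part_{part:03}')
    with MDSWriter(out=sub_dir, columns=columns) as writer:
        for i in range(num_samples):
            writer.write({'id': part * num_samples + i, 'text': f'sample {i} of part {part}'})
    return sub_dir


if __name__ == '__main__':
    tasks = [(part, 100) for part in range(4)]
    with Pool(processes=4) as pool:
        pool.map(write_partition, tasks)

    # Merge the per-partition index.json files into a single index at the root.
    merge_index(out_root, keep_local=True)

    # Read the merged dataset back using the new top-level import.
    dataset = StreamingDataset(local=out_root, remote=None, batch_size=2)
    print(len(dataset), dataset[0]['text'])
```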
] }, { diff --git a/examples/spark_dataframe_to_MDS.ipynb b/notebooks/spark_dataframe_to_MDS.ipynb similarity index 99% rename from examples/spark_dataframe_to_MDS.ipynb rename to notebooks/spark_dataframe_to_MDS.ipynb index c5617d464..72c72961b 100644 --- a/examples/spark_dataframe_to_MDS.ipynb +++ b/notebooks/spark_dataframe_to_MDS.ipynb @@ -137,7 +137,7 @@ { "cell_type": "code", "source": [ - "from streaming.base.converters import dataframeToMDS" + "from streaming.converters import dataframeToMDS" ], "metadata": { "id": "uzYHe6yYRzyV" @@ -500,7 +500,7 @@ "from streaming import StreamingDataset\n", "\n", "# clean stale shared memory if any\n", - "streaming.base.util.clean_stale_shared_memory()\n", + "streaming.util.clean_stale_shared_memory()\n", "\n", "dataset = StreamingDataset(local=out_path, remote=None, batch_size=2, predownload=4)\n", "\n", @@ -773,7 +773,7 @@ "from streaming import StreamingDataset\n", "\n", "# clean stale shared memory if any\n", - "streaming.base.util.clean_stale_shared_memory()\n", + "streaming.util.clean_stale_shared_memory()\n", "\n", "dataset = StreamingDataset(local=out_path, remote=None, batch_size=2, predownload=4)\n", "\n", diff --git a/examples/synthetic_nlp.ipynb b/notebooks/synthetic_nlp.ipynb similarity index 100% rename from examples/synthetic_nlp.ipynb rename to notebooks/synthetic_nlp.ipynb diff --git a/pyproject.toml b/pyproject.toml index 742d541ec..a1cde9d45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,9 +17,10 @@ include = [ ] exclude = [ "build/**", + "docs/source/conf.py", + "docs/source/examples/text/enwiki_tok/**", + "examples/text/enwiki_tok/**", "node_modules/**", - "streaming/text/convert/enwiki/**", - "docs/source/conf.py" ] # Disable checks for missing imports, as a conditional install of streaming will not include them diff --git a/regression/iterate_data.py b/regression/iterate_data.py index eab9131b5..bdbc77a10 100644 --- a/regression/iterate_data.py +++ b/regression/iterate_data.py @@ -17,8 +17,7 @@ get_streaming_dataset_params) from streaming import StreamingDataset -from streaming.base.distributed import (all_gather, barrier, get_rank, get_world_size, - maybe_init_dist) +from streaming.distributed import all_gather, barrier, get_rank, get_world_size, maybe_init_dist logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) diff --git a/regression/synthetic_dataset.py b/regression/synthetic_dataset.py index c90cbb888..3e8f44d78 100644 --- a/regression/synthetic_dataset.py +++ b/regression/synthetic_dataset.py @@ -14,7 +14,7 @@ import torch from utils import delete_gcs, delete_oci, delete_s3, get_kwargs, get_writer_params -from streaming.base import MDSWriter +from streaming import MDSWriter _DATASET_MAP = { 'sequencedataset': 'SequenceDataset', diff --git a/simulation/core/create_index.py b/simulation/core/create_index.py index 41e356c71..1809f1527 100644 --- a/simulation/core/create_index.py +++ b/simulation/core/create_index.py @@ -10,7 +10,7 @@ import string from typing import Optional -from streaming.base.format import get_index_basename +from streaming.format import get_index_basename logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) diff --git a/simulation/core/node_tracker.py b/simulation/core/node_tracker.py index 05933227c..f6ce55947 100644 --- a/simulation/core/node_tracker.py +++ b/simulation/core/node_tracker.py @@ -11,7 +11,7 @@ from numpy.typing import NDArray from sortedcollections import OrderedSet -from streaming.base.spanner import Spanner +from 
streaming.spanner import Spanner class NodeTracker(): diff --git a/simulation/core/shuffle_quality.py b/simulation/core/shuffle_quality.py index 38cb16dab..2e593d75d 100644 --- a/simulation/core/shuffle_quality.py +++ b/simulation/core/shuffle_quality.py @@ -10,8 +10,8 @@ from core.utils import remove_padded_samples from numpy.typing import NDArray -from streaming.base.partition.orig import get_partitions_orig -from streaming.base.shuffle import get_shuffle +from streaming.partition.orig import get_partitions_orig +from streaming.shuffle import get_shuffle logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) diff --git a/simulation/core/sim_dataset.py b/simulation/core/sim_dataset.py index 3be35fd1a..57ad2f5d0 100644 --- a/simulation/core/sim_dataset.py +++ b/simulation/core/sim_dataset.py @@ -16,11 +16,11 @@ from core.sim_world import SimulationWorld from numpy.typing import NDArray -from streaming.base import Stream, StreamingDataset -from streaming.base.batching import generate_work -from streaming.base.format import get_index_basename -from streaming.base.spanner import Spanner -from streaming.base.util import bytes_to_int, number_abbrev_to_int +from streaming import Stream, StreamingDataset +from streaming.batching import generate_work +from streaming.format import get_index_basename +from streaming.spanner import Spanner +from streaming.util import bytes_to_int, number_abbrev_to_int logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) diff --git a/simulation/core/sim_spanner.py b/simulation/core/sim_spanner.py index 7a1858267..dd5996734 100644 --- a/simulation/core/sim_spanner.py +++ b/simulation/core/sim_spanner.py @@ -5,7 +5,7 @@ from typing import Tuple -from streaming.base.spanner import Spanner +from streaming.spanner import Spanner class SimulationSpanner(Spanner): diff --git a/simulation/core/sim_world.py b/simulation/core/sim_world.py index 6c607b8ad..3449f04df 100644 --- a/simulation/core/sim_world.py +++ b/simulation/core/sim_world.py @@ -3,7 +3,7 @@ """Contains info about the nodes, ranks, and workers of the run for simulation purposes.""" -from streaming.base.world import World +from streaming.world import World class SimulationWorld(World): diff --git a/simulation/core/yaml_processing.py b/simulation/core/yaml_processing.py index b16ec4a09..e1ddefab2 100644 --- a/simulation/core/yaml_processing.py +++ b/simulation/core/yaml_processing.py @@ -10,7 +10,7 @@ from omegaconf import DictConfig from omegaconf import OmegaConf as om -from streaming.base import Stream +from streaming import Stream def ingest_yaml(yaml_dict: Optional[dict] = None, diff --git a/simulation/interfaces/interface_utils.py b/simulation/interfaces/interface_utils.py index 863ced6d6..1588d6ab7 100644 --- a/simulation/interfaces/interface_utils.py +++ b/simulation/interfaces/interface_utils.py @@ -14,7 +14,7 @@ from core.utils import get_rolling_avg_throughput from numpy.typing import NDArray -from streaming.base.util import number_abbrev_to_int +from streaming.util import number_abbrev_to_int def plot_simulation(step_times: NDArray, step_downloads: NDArray, window: int = 10): diff --git a/simulation/interfaces/sim_cli.py b/simulation/interfaces/sim_cli.py index c7606b1af..521053604 100644 --- a/simulation/interfaces/sim_cli.py +++ b/simulation/interfaces/sim_cli.py @@ -16,7 +16,7 @@ from core.yaml_processing import create_simulation_dataset, ingest_yaml from interfaces.interface_utils import plot_simulation -from streaming.base.util import bytes_to_int +from 
streaming.util import bytes_to_int if __name__ == '__main__': parser = argparse.ArgumentParser(description='Simulate your training yaml from the command \ diff --git a/simulation/interfaces/sim_script.py b/simulation/interfaces/sim_script.py index 4d7b0596b..18291df5e 100644 --- a/simulation/interfaces/sim_script.py +++ b/simulation/interfaces/sim_script.py @@ -16,7 +16,7 @@ from core.utils import get_simulation_stats from interfaces.interface_utils import plot_simulation -from streaming.base import Stream +from streaming import Stream # Input Parameters diff --git a/simulation/interfaces/sim_ui.py b/simulation/interfaces/sim_ui.py index 77848dd79..5da4e42b7 100644 --- a/simulation/interfaces/sim_ui.py +++ b/simulation/interfaces/sim_ui.py @@ -28,7 +28,7 @@ from interfaces.widgets import (display_shuffle_quality_graph, display_simulation_stats, get_line_chart, param_inputs) -from streaming.base.util import bytes_to_int, number_abbrev_to_int +from streaming.util import bytes_to_int, number_abbrev_to_int # set up page st.set_page_config(layout='wide') diff --git a/simulation/interfaces/widgets.py b/simulation/interfaces/widgets.py index d9959befe..cc600f3bd 100644 --- a/simulation/interfaces/widgets.py +++ b/simulation/interfaces/widgets.py @@ -20,7 +20,7 @@ from numpy.typing import NDArray from streamlit.delta_generator import DeltaGenerator -from streaming.base.util import bytes_to_int +from streaming.util import bytes_to_int def get_line_chart(data: pd.DataFrame, diff --git a/simulation/testing/wandb_testing.py b/simulation/testing/wandb_testing.py index 3944be3f9..4d1d1ec46 100644 --- a/simulation/testing/wandb_testing.py +++ b/simulation/testing/wandb_testing.py @@ -21,7 +21,7 @@ from core.sim_time import TimeUnit, ensure_time from numpy.typing import NDArray -from streaming.base import Stream +from streaming import Stream logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) diff --git a/streaming/__init__.py b/streaming/__init__.py index 7023580ce..45ca3f1cf 100644 --- a/streaming/__init__.py +++ b/streaming/__init__.py @@ -3,24 +3,15 @@ """MosaicML Streaming Datasets for cloud-native model training.""" -import streaming.multimodal as multimodal -import streaming.text as text -import streaming.vision as vision from streaming._version import __version__ -from streaming.base import (CSVWriter, JSONWriter, LocalDataset, MDSWriter, Stream, - StreamingDataLoader, StreamingDataset, TSVWriter, XSVWriter) +from streaming.dataloader import StreamingDataLoader +from streaming.dataset import StreamingDataset +from streaming.format import CSVWriter, JSONWriter, MDSWriter, TSVWriter, XSVWriter +from streaming.local import LocalDataset +from streaming.stream import Stream +from streaming.util import clean_stale_shared_memory __all__ = [ - 'StreamingDataLoader', - 'Stream', - 'StreamingDataset', - 'CSVWriter', - 'JSONWriter', - 'MDSWriter', - 'TSVWriter', - 'XSVWriter', - 'LocalDataset', - 'multimodal', - 'vision', - 'text', + 'StreamingDataLoader', 'Stream', 'StreamingDataset', 'CSVWriter', 'JSONWriter', 'LocalDataset', + 'MDSWriter', 'TSVWriter', 'XSVWriter', 'clean_stale_shared_memory' ] diff --git a/streaming/base/array.py b/streaming/array.py similarity index 100% rename from streaming/base/array.py rename to streaming/array.py diff --git a/streaming/base/__init__.py b/streaming/base/__init__.py index 8834b9bea..ec0fc8af5 100644 --- a/streaming/base/__init__.py +++ b/streaming/base/__init__.py @@ -1,15 +1,11 @@ # Copyright 2023 MosaicML Streaming authors # 
SPDX-License-Identifier: Apache-2.0 -"""MosaicML Streaming Datasets for cloud-native model training.""" - -from streaming.base.dataloader import StreamingDataLoader -from streaming.base.dataset import StreamingDataset -from streaming.base.format import CSVWriter, JSONWriter, MDSWriter, TSVWriter, XSVWriter -from streaming.base.local import LocalDataset -from streaming.base.stream import Stream - -__all__ = [ - 'StreamingDataLoader', 'Stream', 'StreamingDataset', 'CSVWriter', 'JSONWriter', 'LocalDataset', - 'MDSWriter', 'TSVWriter', 'XSVWriter' -] +"""This module has moved. + +Please update your imports to ``streaming``. +""" + +from streaming.util import redirect_imports + +redirect_imports('streaming') diff --git a/streaming/base/converters/README.md b/streaming/base/converters/README.md deleted file mode 100644 index cf275bfc0..000000000 --- a/streaming/base/converters/README.md +++ /dev/null @@ -1,7 +0,0 @@ -### Spark Dataframe Conversion - -Users can read datasets of any formats that Spark supports and convert the Spark dataframe to a Mosaic Streaming dataset. More specifically, - -1. We enable converting a Spark DataFrame into an MDS format via the utility function [dataframeToMDS](https://github.com/mosaicml/streaming/blob/main/streaming/base/converters/dataframe_to_mds.py). This utility function is flexible and supports a callable function, allowing modifications to the original data format. The function iterates over the callable, processes the modified data, and writes it in MDS format. For instance, it can be used with a tokenizer callable function that yields tokens as output. - -2. Users are recommended to refer to the starting example [Jupyter notebook](https://github.com/mosaicml/streaming/blob/main/examples/spark_dataframe_to_MDS.ipynb) which demonstrates a complete workflow. It illustrates how to use Spark to read raw data into a Spark DataFrame and then convert it into the MDS format via the `dataframeToMDS` function. In that tutorial, we also demonstrate the option to pass in a preprocessing tokenization job to the converter, which can be useful if materializing the intermediate dataframe is time consuming or taking extra development. diff --git a/streaming/base/format/base/__init__.py b/streaming/base/format/base/__init__.py deleted file mode 100644 index 46bf9f730..000000000 --- a/streaming/base/format/base/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Base module for dataset reader and writer.""" - -from streaming.base.format.base.reader import FileInfo, Reader - -__all__ = ['FileInfo', 'Reader'] diff --git a/streaming/base/shared/__init__.py b/streaming/base/shared/__init__.py deleted file mode 100644 index cf507c4fe..000000000 --- a/streaming/base/shared/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Objects that live in shared memory. - -For when using `threading` or `multiprocessing` from the python standard library won't do, because -we are coordinating separately instantiated pytorch worker processes. 
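The redirect shim above (`redirect_imports('streaming')` in `streaming/base/__init__.py`) relies on a `streaming.util.redirect_imports` helper whose body is not shown in this patch. The sketch below is one plausible way such a deprecation shim could work, assuming a warn-and-alias approach; it is an illustration, not the library's actual implementation:

```python
# Hypothetical sketch of a module-redirect helper like the `redirect_imports` calls above.
import importlib
import sys
import warnings


def redirect_imports(new_module_name: str) -> None:
    """Re-export ``new_module_name`` from the deprecated calling module, with a warning."""
    # The caller is the deprecated module (e.g. ``streaming.base``); look one frame up.
    old_module_name = sys._getframe(1).f_globals['__name__']
    warnings.warn(
        f'{old_module_name} has moved. Please update your imports to {new_module_name}.',
        DeprecationWarning,
        stacklevel=3,
    )
    new_module = importlib.import_module(new_module_name)
    old_module = sys.modules[old_module_name]
    # Copy the new module's public attributes into the old namespace so that, e.g.,
    # ``from streaming.base import StreamingDataset`` keeps working for now.
    for name in getattr(new_module, '__all__', None) or dir(new_module):
        if not name.startswith('_'):
            setattr(old_module, name, getattr(new_module, name))
```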
-""" - -from streaming.base.shared.array import SharedArray as SharedArray -from streaming.base.shared.barrier import SharedBarrier as SharedBarrier -from streaming.base.shared.memory import SharedMemory as SharedMemory -from streaming.base.shared.prefix import _get_path as _get_path -from streaming.base.shared.prefix import get_shm_prefix as get_shm_prefix -from streaming.base.shared.scalar import SharedScalar as SharedScalar - -__all__ = ['SharedArray', 'SharedBarrier', 'SharedMemory', 'get_shm_prefix', 'SharedScalar'] diff --git a/streaming/base/storage/__init__.py b/streaming/base/storage/__init__.py deleted file mode 100644 index d3658656b..000000000 --- a/streaming/base/storage/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Base module for downloading/uploading files from/to cloud storage.""" - -from streaming.base.storage.download import (download_file, download_from_azure, - download_from_azure_datalake, - download_from_databricks_unity_catalog, - download_from_dbfs, download_from_gcs, - download_from_local, download_from_oci, - download_from_s3, download_from_sftp) -from streaming.base.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, - GCSUploader, LocalUploader, OCIUploader, S3Uploader) - -__all__ = [ - 'download_file', - 'CloudUploader', - 'S3Uploader', - 'GCSUploader', - 'OCIUploader', - 'LocalUploader', - 'AzureUploader', - 'AzureDataLakeUploader', - 'download_from_s3', - 'download_from_sftp', - 'download_from_gcs', - 'download_from_oci', - 'download_from_azure', - 'download_from_azure_datalake', - 'download_from_databricks_unity_catalog', - 'download_from_dbfs', - 'download_from_local', -] diff --git a/streaming/base/util.py b/streaming/base/util.py index e86876ee1..2c290770a 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -1,525 +1,11 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Utility and helper functions for datasets.""" +"""This module has moved. -import collections.abc -import functools -import json -import logging -import os -import random -import shutil -import tempfile -import urllib.parse -from collections import OrderedDict -from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory -from pathlib import Path -from time import sleep, time -from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, cast, overload +Please update your imports to ``streaming.util``. +""" -import torch.distributed as dist +from streaming.util import redirect_imports -from streaming.base.constant import SHM_TO_CLEAN -from streaming.base.distributed import get_local_rank, maybe_init_dist -from streaming.base.format.index import get_index_basename -from streaming.base.shared.prefix import _get_path - -logger = logging.getLogger(__name__) - -TCallable = TypeVar('TCallable', bound=Callable) - -__all__ = [ - 'get_list_arg', 'wait_for_file_to_exist', 'bytes_to_int', 'number_abbrev_to_int', - 'clean_stale_shared_memory', 'get_import_exception_message', 'merge_index', 'retry' -] - - -def get_list_arg(text: str) -> List[str]: - """Pass a list as a command-line flag. - - Args: - text (str): Text to split. - - Returns: - List[str]: Splits, if any. - """ - return text.split(',') if text else [] - - -def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float, - err_msg: str) -> None: - """Wait for the file to exist till timeout seconds. Raise an Exception after that. 
- - Args: - filename (str): A file name - poll_interval (float): Number of seconds to wait before next polling - timeout (float): Number of seconds to wait for a file to exist before raising an exception - err_msg (str): Error message description for an exception - - Raises: - RuntimeError: Raise an Exception if file does not exist after timeout - """ - start_time = time() - while True: - sleep(poll_interval) - if os.path.exists(filename): - sleep(poll_interval) - break - dt = time() - start_time - if dt > timeout: - raise RuntimeError(f'{err_msg}' + f'{timeout:.3f} < {dt:.3f} secs.') - - -def bytes_to_int(bytes_str: Union[int, str]) -> int: - """Convert human readable byte format to an integer. - - Args: - bytes_str (Union[int, str]): Value to convert. - - Raises: - ValueError: Invalid byte suffix. - - Returns: - int: Integer value of bytes. - """ - #input is already an int - if isinstance(bytes_str, int) or isinstance(bytes_str, float): - return int(bytes_str) - - units = { - 'kb': 1024, - 'mb': 1024**2, - 'gb': 1024**3, - 'tb': 1024**4, - 'pb': 1024**5, - 'eb': 1024**6, - 'zb': 1024**7, - 'yb': 1024**8, - } - # Convert a various byte types to an integer - for suffix in units: - bytes_str = bytes_str.lower().strip() - if bytes_str.lower().endswith(suffix): - try: - return int(float(bytes_str[0:-len(suffix)]) * units[suffix]) - except ValueError: - raise ValueError(''.join([ - f'Unsupported value/suffix {bytes_str}. Supported suffix are ', - f'{["b"] + list(units.keys())}.' - ])) - else: - # Convert bytes to an integer - if bytes_str.endswith('b') and bytes_str[0:-1].isdigit(): - return int(bytes_str[0:-1]) - # Convert string representation of a number to an integer - elif bytes_str.isdigit(): - return int(bytes_str) - else: - raise ValueError(''.join([ - f'Unsupported value/suffix {bytes_str}. Supported suffix are ', - f'{["b"] + list(units.keys())}.' - ])) - - -def number_abbrev_to_int(abbrev_str: Union[int, str]) -> int: - """Convert human readable number abbreviations to an integer. - - Args: - abbrev_str (Union[int, str]): Value to convert. - - Raises: - ValueError: Invalid number suffix. - - Returns: - int: Integer value of number abbreviation. - """ - #input is already an int - if isinstance(abbrev_str, int) or isinstance(abbrev_str, float): - return int(abbrev_str) - - units = { - 'k': 10**3, - 'm': 10**6, - 'b': 10**9, - 't': 10**12, - } - # Convert a various abbreviation types to an integer - for suffix in units: - abbrev_str = abbrev_str.lower().strip() - if abbrev_str.lower().endswith(suffix): - try: - return int(float(abbrev_str[0:-len(suffix)]) * units[suffix]) - except ValueError: - raise ValueError(''.join([ - f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', - f'{list(units.keys())}.' - ])) - else: - # Convert string representation of a number to an integer - if abbrev_str.isdigit(): - return int(abbrev_str) - else: - raise ValueError(''.join([ - f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', - f'{list(units.keys())}.' - ])) - - -def clean_stale_shared_memory() -> None: - """Clean up all the leaked shared memory. - - In case of a distributed run, clean up happens on local rank 0 while other local ranks wait for - the local rank 0 to finish. - """ - # Initialize torch.distributed ourselves, if necessary. 
- destroy_dist = maybe_init_dist() - - # Perform clean up on local rank 0 - if get_local_rank() == 0: - for prefix_int in range(1000000): - leaked_shm = False - for shm_name in SHM_TO_CLEAN: - name = _get_path(prefix_int, shm_name) - try: - shm = BuiltinSharedMemory(name, True, 4) - except FileExistsError: - shm = BuiltinSharedMemory(name, False, 4) - leaked_shm = True - finally: - shm.close() # pyright: ignore - shm.unlink() - # Come out of loop if no leaked shared memory - if not leaked_shm: - break - - # Sync all ranks - if dist.is_available() and dist.is_initialized(): - dist.barrier() - - # Delete the process group if Streaming initialized it. - if destroy_dist: - dist.destroy_process_group() - - -def get_import_exception_message(package_name: str, extra_deps: str) -> str: - """Get import exception message. - - Args: - package_name (str): Package name. - - Returns: - str: Exception message. - """ - return f'Streaming was installed without {package_name} support. ' + \ - f'To use {package_name} related packages with Streaming, run ' + \ - f'`pip install \'mosaicml-streaming[{package_name}]\'`.' - - -def merge_index(*args: Any, **kwargs: Any): - r"""Merge index.json from partitions to form a global index.json. - - This can be called as - - merge_index(index_file_urls, out, keep_local, download_timeout) - - merge_index(out, keep_local, download_timeout) - - The first signature takes in a list of index files URLs of MDS partitions. - The second takes the root of a MDS dataset and parse the partition folders from there. - - Args: - index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the partitions. - Each element can take the form of a single path string or a tuple string. - - 1. If ``index_file_urls`` is a List of local URLs, merge locally without download. - 2. If ``index_file_urls`` is a List of tuple (local, remote) URLs, check if local index.json are missing, download before merging. - 3. If ``index_file_urls`` is a List of remote URLs, download all and merge. - - out (Union[str, Tuple[str,str]]): folder that contain MDS partitions and to put the merged index file - - 1. A local directory, merge index happens locally. - 2. A remote directory, download all the sub-directories index.json, merge locally and upload. - 3. A tuple (local_dir, remote_dir), check if local index.json exist, download if not. - - keep_local (bool): Keep local copy of the merged index file. Defaults to ``True``. - download_timeout (int): The allowed time for downloading each json file. Defaults to 60. - """ - if isinstance(args[0], list) and len(args) + len(kwargs) in [2, 3, 4]: - return _merge_index_from_list(*args, **kwargs) - elif (isinstance(args[0], str) or - isinstance(args[0], tuple)) and len(args) + len(kwargs) in [1, 2, 3]: - return _merge_index_from_root(*args, **kwargs) - raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}') - - -def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], - out: Union[str, Tuple[str, str]], - keep_local: bool = True, - download_timeout: int = 60) -> None: - """Merge index.json from a list of index files of MDS directories to create joined index. - - Args: - index_file_urls (Union[str, Tuple[str,str]]): index.json from all the partitions - each element can take the form of a single path string or a tuple string. - - The pattern of index_file_urls and corresponding reaction is one of: - 1. All URLS are str (local). All URLS are accessible locally -> no download - 2. All URLS are tuple (local, remote). 
All URLS are accessible locally -> no download - 3. All URLS are tuple (local, remote). Download URL that is not accessible locally - 4. All URLS are str (remote) -> download all - - out (Union[str, Tuple[str, str]]): path to put the merged index file - keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` - download_timeout (int): The allowed time for downloading each json file. Defaults to 60. - """ - from streaming.base.storage.download import download_file - from streaming.base.storage.upload import CloudUploader - - if not index_file_urls or not out: - logger.warning('Either index_file_urls or out are None. ' + - 'Need to specify both `index_file_urls` and `out`. ' + 'No index merged') - return - - # This is the index json file name, e.g., it is index.json as of 0.6.0 - index_basename = get_index_basename() - - cu = CloudUploader.get(out, keep_local=True, exist_ok=True) - - # Remove duplicates, and strip '/' from right if any - index_file_urls = list(OrderedDict.fromkeys(index_file_urls)) - urls = [] - for url in index_file_urls: - if isinstance(url, str): - urls.append(url.rstrip('/').strip()) - else: - urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip())) - - # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. - with tempfile.TemporaryDirectory() as temp_root: - logging.warning(f'A temporary folder {temp_root} is created to store index files') - - # Copy files to a temporary directory. Download if necessary - partitions = [] - for url in urls: - if isinstance(url, tuple): - src = url[0] if os.path.exists(url[0]) else url[1] - else: - src = url - - obj = urllib.parse.urlparse(src) - scheme, bucket, path = obj.scheme, obj.netloc, obj.path - if scheme == '' and bucket == '' and path == '': - raise FileNotFoundError( - f'Check data availability! local index {url[0]} is not accessible.' + - f'remote index {url[1]} does not have a valid url format') - dest = os.path.join(temp_root, path.lstrip('/')) - - try: - download_file(src, dest, download_timeout) - except Exception as ex: - raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex - - if not os.path.exists(dest): - raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.') - - partitions.append(dest) - - # merge shards from all index files - shards = [] - for partition_index in partitions: - p = Path(partition_index) - obj = json.load(open(partition_index)) - for i in range(len(obj['shards'])): - shard = obj['shards'][i] - for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): - if shard.get(key): - basename = shard[key]['basename'] - obj['shards'][i][key]['basename'] = os.path.join( - os.path.basename(p.parent), basename) - shards += obj['shards'] - - # Save merged index locally - obj = { - 'version': 2, - 'shards': shards, - } - merged_index_path = os.path.join(temp_root, index_basename) - with open(merged_index_path, 'w') as outfile: - json.dump(obj, outfile) - - # Move merged index from temp path to local part in out - # Upload merged index to remote if out has remote part - shutil.move(merged_index_path, cu.local) - if cu.remote is not None: - cu.upload_file(index_basename) - - # Clean up - if not keep_local: - shutil.rmtree(cu.local, ignore_errors=True) - - -def _merge_index_from_root(out: Union[str, Tuple[str, str]], - keep_local: bool = True, - download_timeout: int = 60) -> None: - """Merge index.json given the root of MDS dataset. Write merged index to the root folder. 
- - Args: - out (Union[str, Tuple[str,str]]): folder that contain MDS partitions. - :A local directory, merge index happens locally - :A remote directory, download all the sub-directories index.json in a temporary - sub-directories, merge locally, and then upload it to out location - :A (local_dir, remote_dir), check if sub-directories index.json file present locally - If yes, then merge locally and upload to remote_dir . - If not, download all the sub-directories index.json from remote to local, - merge locally, and upload to remote_dir . - keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` - download_timeout (int): The allowed time for downloading each json file. Defaults to 60. - """ - from streaming.base.storage.upload import CloudUploader - - def not_merged_index(index_file_path: str, out: str): - """Check if index_file_path is the merged index at folder out. - - Args: - index_file_path (str): the path to index.json file - out (str): remote or local url of a folder - Return: - (bool): no if index.json sits in out instead of in the subfolders of out - """ - prefix = str(urllib.parse.urlparse(out).path) - return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') - - if not out: - logger.warning('No MDS dataset folder specified, no index merged') - return - - cu = CloudUploader.get(out, exist_ok=True, keep_local=True) - - local_index_files = [] - cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True) - for file in cl.list_objects(): - if file.endswith('.json') and not_merged_index(file, cu.local): - local_index_files.append(file) - - if cu.remote: - obj = urllib.parse.urlparse(cu.remote) - remote_index_files = [] - for file in cu.list_objects(): - if file.endswith(get_index_basename()) and not_merged_index(file, cu.remote): - join_char = '//' - if obj.scheme == 'dbfs': - path = Path(cu.remote) - prefix = os.path.join(path.parts[0], path.parts[1]) - if prefix == 'dbfs:/Volumes': - join_char = '/' - remote_index_files.append(obj.scheme + join_char + os.path.join(obj.netloc, file)) - if len(local_index_files) == len(remote_index_files): - _merge_index_from_list(list(zip(local_index_files, remote_index_files)), - out, - keep_local=keep_local, - download_timeout=download_timeout) - else: - _merge_index_from_list(remote_index_files, - out, - keep_local=keep_local, - download_timeout=download_timeout) - return - - _merge_index_from_list(local_index_files, - out, - keep_local=keep_local, - download_timeout=download_timeout) - - -@overload -def retry( - exc_class: Union[Type[Exception], Sequence[Type[Exception]]] = ..., - num_attempts: int = ..., - initial_backoff: float = ..., - max_jitter: float = ..., -) -> Callable[[TCallable], TCallable]: - ... - - -@overload -def retry(exc_class: TCallable) -> TCallable: - # Use the decorator without parenthesis - ... - - -# error: Type "(TCallable@retry) -> TCallable@retry" cannot be assigned to type -# "(func: Never) -> Never" -def retry( # type: ignore - exc_class: Union[TCallable, Type[Exception], Sequence[Type[Exception]]] = Exception, - num_attempts: int = 3, - initial_backoff: float = 1.0, - max_jitter: float = 0.5, -): - """Decorator to retry a function with backoff and jitter. - - Attempts are spaced out with - ``initial_backoff * 2**num_attempts + random.random() * max_jitter`` seconds. - - Example: - .. 
testcode:: - - from streaming.base.util import retry - - num_tries = 0 - - @retry(RuntimeError, num_attempts=3, initial_backoff=0.1) - def flaky_function(): - global num_tries - if num_tries < 2: - num_tries += 1 - raise RuntimeError("Called too soon!") - return "Third time's a charm." - - print(flaky_function()) - - .. testoutput:: - - Third time's a charm. - - Args: - exc_class (Type[Exception] | Sequence[Type[Exception]]], optional): The exception class or - classes to retry. Defaults to Exception. - num_attempts (int, optional): The total number of attempts to make. Defaults to 3. - initial_backoff (float, optional): The initial backoff, in seconds. Defaults to 1.0. - max_jitter (float, optional): The maximum amount of random jitter to add. Defaults to 0.5. - - Increasing the ``max_jitter`` can help prevent overloading a resource when multiple - processes in parallel are calling the same underlying function. - """ - if num_attempts < 1: - raise ValueError('num_attempts must be at-least 1') - - def wrapped_func(func: TCallable) -> TCallable: - - @functools.wraps(func) - def new_func(*args: Any, **kwargs: Any): - i = 0 - while True: - try: - return func(*args, **kwargs) - except exc_class as e: - if i + 1 == num_attempts: - logger.debug(f'Attempt {i + 1}/{num_attempts} failed with: {e}') - raise e - else: - sleep(initial_backoff * 2**i + random.random() * max_jitter) - logger.debug(f'Attempt {i + 1}/{num_attempts} failed with: {e}') - i += 1 - - return cast(TCallable, new_func) - - if not isinstance(exc_class, collections.abc.Sequence) and not (isinstance( - exc_class, type) and issubclass(exc_class, Exception)): - # Using the decorator without (), like @retry_with_backoff - func = cast(TCallable, exc_class) - exc_class = Exception - - return wrapped_func(func) - - return wrapped_func +redirect_imports('streaming.util') diff --git a/streaming/base/batching/__init__.py b/streaming/batching/__init__.py similarity index 79% rename from streaming/base/batching/__init__.py rename to streaming/batching/__init__.py index f4fd7f788..a95de0147 100644 --- a/streaming/base/batching/__init__.py +++ b/streaming/batching/__init__.py @@ -9,13 +9,13 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.batching.per_stream import generate_work_per_stream_batching -from streaming.base.batching.random import generate_work_random_batching -from streaming.base.batching.stratified import generate_work_stratified_batching -from streaming.base.world import World +from streaming.batching.per_stream import generate_work_per_stream_batching +from streaming.batching.random import generate_work_random_batching +from streaming.batching.stratified import generate_work_stratified_batching +from streaming.world import World if TYPE_CHECKING: - from streaming.base.dataset import StreamingDataset + from streaming.dataset import StreamingDataset batching_methods = { 'random': generate_work_random_batching, diff --git a/streaming/base/batching/per_stream.py b/streaming/batching/per_stream.py similarity index 97% rename from streaming/base/batching/per_stream.py rename to streaming/batching/per_stream.py index 1686720b9..e955b0114 100644 --- a/streaming/base/batching/per_stream.py +++ b/streaming/batching/per_stream.py @@ -10,12 +10,12 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.partition import get_partitions -from streaming.base.shuffle import get_shuffle -from streaming.base.world import World +from streaming.partition import get_partitions +from streaming.shuffle 
import get_shuffle +from streaming.world import World if TYPE_CHECKING: - from streaming.base.dataset import StreamingDataset + from streaming.dataset import StreamingDataset logger = logging.getLogger(__name__) diff --git a/streaming/base/batching/random.py b/streaming/batching/random.py similarity index 94% rename from streaming/base/batching/random.py rename to streaming/batching/random.py index 48e803acb..a716e0515 100644 --- a/streaming/base/batching/random.py +++ b/streaming/batching/random.py @@ -10,12 +10,12 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.partition import get_partitions -from streaming.base.shuffle import get_shuffle -from streaming.base.world import World +from streaming.partition import get_partitions +from streaming.shuffle import get_shuffle +from streaming.world import World if TYPE_CHECKING: - from streaming.base.dataset import StreamingDataset + from streaming.dataset import StreamingDataset logger = logging.getLogger(__name__) diff --git a/streaming/base/batching/stratified.py b/streaming/batching/stratified.py similarity index 98% rename from streaming/base/batching/stratified.py rename to streaming/batching/stratified.py index 2eef06fd5..aff18eba6 100644 --- a/streaming/base/batching/stratified.py +++ b/streaming/batching/stratified.py @@ -11,12 +11,12 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.partition import get_partitions -from streaming.base.shuffle import get_shuffle -from streaming.base.world import World +from streaming.partition import get_partitions +from streaming.shuffle import get_shuffle +from streaming.world import World if TYPE_CHECKING: - from streaming.base.dataset import StreamingDataset + from streaming.dataset import StreamingDataset logger = logging.getLogger(__name__) diff --git a/streaming/base/compression.py b/streaming/compression.py similarity index 100% rename from streaming/base/compression.py rename to streaming/compression.py diff --git a/streaming/base/constant.py b/streaming/constant.py similarity index 100% rename from streaming/base/constant.py rename to streaming/constant.py diff --git a/streaming/converters/README.md b/streaming/converters/README.md new file mode 100644 index 000000000..64a7c0659 --- /dev/null +++ b/streaming/converters/README.md @@ -0,0 +1,7 @@ +### Spark Dataframe Conversion + +Users can read datasets of any formats that Spark supports and convert the Spark dataframe to a Mosaic Streaming dataset. More specifically, + +1. We enable converting a Spark DataFrame into an MDS format via the utility function [dataframeToMDS](https://github.com/mosaicml/streaming/blob/main/streaming/converters/dataframe_to_mds.py). This utility function is flexible and supports a callable function, allowing modifications to the original data format. The function iterates over the callable, processes the modified data, and writes it in MDS format. For instance, it can be used with a tokenizer callable function that yields tokens as output. + +2. Users are recommended to refer to the starting example [Jupyter notebook](https://github.com/mosaicml/streaming/blob/main/notebooks/spark_dataframe_to_MDS.ipynb) which demonstrates a complete workflow. It illustrates how to use Spark to read raw data into a Spark DataFrame and then convert it into the MDS format via the `dataframeToMDS` function. 
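A minimal usage sketch of the relocated converter follows; the Spark session, schema, output path, and keyword arguments are illustrative assumptions based on the linked notebook's pattern, so treat the notebook as the authoritative reference:

```python
# Sketch: convert a small Spark DataFrame to MDS with the relocated converters module.
from pyspark.sql import SparkSession

from streaming.converters import dataframe_to_mds

spark = SparkSession.builder.appName('mds-conversion-sketch').getOrCreate()

# A tiny in-memory DataFrame standing in for real data read from Parquet/CSV/Delta.
df = spark.createDataFrame(
    [(0, 'hello world'), (1, 'streaming datasets')],
    schema='id int, text string',
)

# Write each Spark partition as an MDS sub-directory under the output root and merge
# the per-partition index files; the MDS column encodings are derived from the schema
# unless supplied explicitly via `mds_kwargs`.
dataframe_to_mds(
    df.repartition(2),
    merge_index=True,
    mds_kwargs={'out': '/tmp/spark_mds_demo'},
)
```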
diff --git a/streaming/base/converters/__init__.py b/streaming/converters/__init__.py
similarity index 57%
rename from streaming/base/converters/__init__.py
rename to streaming/converters/__init__.py
index 8fbbed094..d12602e57 100644
--- a/streaming/base/converters/__init__.py
+++ b/streaming/converters/__init__.py
@@ -3,7 +3,7 @@
 """Utility function for converting spark dataframe to MDS dataset."""
-from streaming.base.converters.dataframe_to_mds import (MAPPING_SPARK_TO_MDS, dataframe_to_mds,
-                                                         dataframeToMDS)
+from streaming.converters.dataframe_to_mds import (MAPPING_SPARK_TO_MDS, dataframe_to_mds,
+                                                    dataframeToMDS)
 __all__ = ['dataframeToMDS', 'dataframe_to_mds', 'MAPPING_SPARK_TO_MDS']
diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/converters/dataframe_to_mds.py
similarity index 97%
rename from streaming/base/converters/dataframe_to_mds.py
rename to streaming/converters/dataframe_to_mds.py
index c74460b3f..5093a9d7f 100644
--- a/streaming/base/converters/dataframe_to_mds.py
+++ b/streaming/converters/dataframe_to_mds.py
@@ -11,8 +11,8 @@
 import pandas as pd
-from streaming.base.util import get_import_exception_message
-from streaming.base.util import merge_index as do_merge_index
+from streaming.util import get_import_exception_message
+from streaming.util import merge_index as do_merge_index
 try:
     from pyspark import TaskContext
@@ -26,9 +26,9 @@
     raise e
 from streaming import MDSWriter
-from streaming.base.format.index import get_index_basename
-from streaming.base.format.mds.encodings import _encodings
-from streaming.base.storage.upload import CloudUploader
+from streaming.format.index import get_index_basename
+from streaming.format.mds.encodings import _encodings
+from streaming.storage.upload import CloudUploader
 logger = logging.getLogger(__name__)
diff --git a/streaming/base/dataloader.py b/streaming/dataloader.py
similarity index 96%
rename from streaming/base/dataloader.py
rename to streaming/dataloader.py
index 89cdb0026..a0c881f34 100644
--- a/streaming/base/dataloader.py
+++ b/streaming/dataloader.py
@@ -9,8 +9,8 @@
 from torch.utils.data import DataLoader
 from transformers import BatchEncoding, BatchFeature
-from streaming.base.dataset import StreamingDataset
-from streaming.base.world import World
+from streaming.dataset import StreamingDataset
+from streaming.world import World
 class StreamingDataLoader(DataLoader):
diff --git a/streaming/base/dataset.py b/streaming/dataset.py
similarity index 98%
rename from streaming/base/dataset.py
rename to streaming/dataset.py
index 6ef4b7a80..844f88d10 100644
--- a/streaming/base/dataset.py
+++ b/streaming/dataset.py
@@ -22,20 +22,20 @@
 from torch import distributed as dist
 from torch.utils.data import IterableDataset
-from streaming.base.array import Array
-from streaming.base.batching import generate_work
-from streaming.base.constant import (BARRIER, BARRIER_FILELOCK, CACHE_FILELOCK, CACHE_USAGE,
-                                     EPOCH_DATA, EPOCH_SHAPE, NEXT_EPOCH, RESUME,
-                                     SHARD_ACCESS_TIMES, SHARD_STATES, TICK)
-from streaming.base.distributed import maybe_init_dist
-from streaming.base.format import get_index_basename
-from streaming.base.sampling import get_sampling
-from streaming.base.shared import (SharedArray, SharedBarrier, SharedMemory, SharedScalar,
-                                   _get_path, get_shm_prefix)
-from
streaming.base.spanner import Spanner -from streaming.base.stream import Stream -from streaming.base.util import bytes_to_int, number_abbrev_to_int -from streaming.base.world import World +from streaming.array import Array +from streaming.batching import generate_work +from streaming.constant import (BARRIER, BARRIER_FILELOCK, CACHE_FILELOCK, CACHE_USAGE, EPOCH_DATA, + EPOCH_SHAPE, NEXT_EPOCH, RESUME, SHARD_ACCESS_TIMES, SHARD_STATES, + TICK) +from streaming.distributed import maybe_init_dist +from streaming.format import get_index_basename +from streaming.sampling import get_sampling +from streaming.shared import (SharedArray, SharedBarrier, SharedMemory, SharedScalar, _get_path, + get_shm_prefix) +from streaming.spanner import Spanner +from streaming.stream import Stream +from streaming.util import bytes_to_int, number_abbrev_to_int +from streaming.world import World # An arbitrary time in the future, used for cold shard eviction. NEVER = np.iinfo(np.uint64).max diff --git a/streaming/base/distributed.py b/streaming/distributed.py similarity index 100% rename from streaming/base/distributed.py rename to streaming/distributed.py diff --git a/streaming/base/format/__init__.py b/streaming/format/__init__.py similarity index 71% rename from streaming/base/format/__init__.py rename to streaming/format/__init__.py index 962828ae2..bbec4927e 100644 --- a/streaming/base/format/__init__.py +++ b/streaming/format/__init__.py @@ -5,12 +5,11 @@ from typing import Any, Dict, Optional -from streaming.base.format.base import FileInfo, Reader -from streaming.base.format.index import get_index_basename -from streaming.base.format.json import JSONReader, JSONWriter -from streaming.base.format.mds import MDSReader, MDSWriter -from streaming.base.format.xsv import (CSVReader, CSVWriter, TSVReader, TSVWriter, XSVReader, - XSVWriter) +from streaming.format.index import get_index_basename +from streaming.format.json import JSONReader, JSONWriter +from streaming.format.mds import MDSReader, MDSWriter +from streaming.format.reader import FileInfo, Reader +from streaming.format.xsv import CSVReader, CSVWriter, TSVReader, TSVWriter, XSVReader, XSVWriter __all__ = [ 'CSVWriter', 'FileInfo', 'get_index_basename', 'JSONWriter', 'MDSWriter', 'Reader', diff --git a/streaming/base/format/index.py b/streaming/format/index.py similarity index 100% rename from streaming/base/format/index.py rename to streaming/format/index.py diff --git a/streaming/base/format/json/README.md b/streaming/format/json/README.md similarity index 100% rename from streaming/base/format/json/README.md rename to streaming/format/json/README.md diff --git a/streaming/base/format/json/__init__.py b/streaming/format/json/__init__.py similarity index 61% rename from streaming/base/format/json/__init__.py rename to streaming/format/json/__init__.py index fe37c8570..47e8be8f6 100644 --- a/streaming/base/format/json/__init__.py +++ b/streaming/format/json/__init__.py @@ -3,7 +3,7 @@ """Module to write and read the dataset in JSON format.""" -from streaming.base.format.json.reader import JSONReader -from streaming.base.format.json.writer import JSONWriter +from streaming.format.json.reader import JSONReader +from streaming.format.json.writer import JSONWriter __all__ = ['JSONReader', 'JSONWriter'] diff --git a/streaming/base/format/json/encodings.py b/streaming/format/json/encodings.py similarity index 100% rename from streaming/base/format/json/encodings.py rename to streaming/format/json/encodings.py diff --git a/streaming/base/format/json/reader.py 
b/streaming/format/json/reader.py similarity index 98% rename from streaming/base/format/json/reader.py rename to streaming/format/json/reader.py index 4aaeb91cc..698783d71 100644 --- a/streaming/base/format/json/reader.py +++ b/streaming/format/json/reader.py @@ -11,7 +11,7 @@ import numpy as np from typing_extensions import Self -from streaming.base.format.base.reader import FileInfo, SplitReader +from streaming.format.reader import FileInfo, SplitReader __all__ = ['JSONReader'] diff --git a/streaming/base/format/json/writer.py b/streaming/format/json/writer.py similarity index 97% rename from streaming/base/format/json/writer.py rename to streaming/format/json/writer.py index aae9d1d28..b0117a47f 100644 --- a/streaming/base/format/json/writer.py +++ b/streaming/format/json/writer.py @@ -8,8 +8,8 @@ import numpy as np -from streaming.base.format.base.writer import SplitWriter -from streaming.base.format.json.encodings import is_json_encoded, is_json_encoding +from streaming.format.json.encodings import is_json_encoded, is_json_encoding +from streaming.format.writer import SplitWriter __all__ = ['JSONWriter'] diff --git a/streaming/base/format/mds/README.md b/streaming/format/mds/README.md similarity index 100% rename from streaming/base/format/mds/README.md rename to streaming/format/mds/README.md diff --git a/streaming/base/format/mds/__init__.py b/streaming/format/mds/__init__.py similarity index 62% rename from streaming/base/format/mds/__init__.py rename to streaming/format/mds/__init__.py index 2c18ca0e7..67a5be56f 100644 --- a/streaming/base/format/mds/__init__.py +++ b/streaming/format/mds/__init__.py @@ -3,7 +3,7 @@ """Module to write and read the dataset in MDS format.""" -from streaming.base.format.mds.reader import MDSReader -from streaming.base.format.mds.writer import MDSWriter +from streaming.format.mds.reader import MDSReader +from streaming.format.mds.writer import MDSWriter __all__ = ['MDSReader', 'MDSWriter'] diff --git a/streaming/base/format/mds/encodings.py b/streaming/format/mds/encodings.py similarity index 100% rename from streaming/base/format/mds/encodings.py rename to streaming/format/mds/encodings.py diff --git a/streaming/base/format/mds/reader.py b/streaming/format/mds/reader.py similarity index 97% rename from streaming/base/format/mds/reader.py rename to streaming/format/mds/reader.py index 275f01192..245458bf4 100644 --- a/streaming/base/format/mds/reader.py +++ b/streaming/format/mds/reader.py @@ -10,8 +10,8 @@ import numpy as np from typing_extensions import Self -from streaming.base.format.base.reader import FileInfo, JointReader -from streaming.base.format.mds.encodings import mds_decode +from streaming.format.mds.encodings import mds_decode +from streaming.format.reader import FileInfo, JointReader __all__ = ['MDSReader'] diff --git a/streaming/base/format/mds/writer.py b/streaming/format/mds/writer.py similarity index 96% rename from streaming/base/format/mds/writer.py rename to streaming/format/mds/writer.py index e82fc02a8..950c60f20 100644 --- a/streaming/base/format/mds/writer.py +++ b/streaming/format/mds/writer.py @@ -8,9 +8,9 @@ import numpy as np -from streaming.base.format.base.writer import JointWriter -from streaming.base.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, - is_mds_encoding, mds_encode) +from streaming.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, + is_mds_encoding, mds_encode) +from streaming.format.writer import JointWriter __all__ = ['MDSWriter'] diff --git 
a/streaming/base/format/base/reader.py b/streaming/format/reader.py similarity index 99% rename from streaming/base/format/base/reader.py rename to streaming/format/reader.py index 80ec45231..5d1401d55 100644 --- a/streaming/base/format/base/reader.py +++ b/streaming/format/reader.py @@ -8,8 +8,8 @@ from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional, Set, Union -from streaming.base.array import Array -from streaming.base.util import bytes_to_int +from streaming.array import Array +from streaming.util import bytes_to_int __all__ = ['FileInfo', 'Reader', 'JointReader', 'SplitReader'] diff --git a/streaming/base/format/base/writer.py b/streaming/format/writer.py similarity index 98% rename from streaming/base/format/base/writer.py rename to streaming/format/writer.py index 7cc3add3d..8be4cb33f 100644 --- a/streaming/base/format/base/writer.py +++ b/streaming/format/writer.py @@ -18,11 +18,11 @@ from typing_extensions import Self -from streaming.base.compression import compress, get_compression_extension, is_compression -from streaming.base.format.index import get_index_basename -from streaming.base.hashing import get_hash, is_hash -from streaming.base.storage.upload import CloudUploader -from streaming.base.util import bytes_to_int +from streaming.compression import compress, get_compression_extension, is_compression +from streaming.format.index import get_index_basename +from streaming.hashing import get_hash, is_hash +from streaming.storage.upload import CloudUploader +from streaming.util import bytes_to_int __all__ = ['JointWriter', 'SplitWriter'] diff --git a/streaming/base/format/xsv/README.md b/streaming/format/xsv/README.md similarity index 100% rename from streaming/base/format/xsv/README.md rename to streaming/format/xsv/README.md diff --git a/streaming/base/format/xsv/__init__.py b/streaming/format/xsv/__init__.py similarity index 60% rename from streaming/base/format/xsv/__init__.py rename to streaming/format/xsv/__init__.py index 6d5ca2489..985010a42 100644 --- a/streaming/base/format/xsv/__init__.py +++ b/streaming/format/xsv/__init__.py @@ -3,7 +3,7 @@ """Module to write and read the dataset in Tabular format.""" -from streaming.base.format.xsv.reader import CSVReader, TSVReader, XSVReader -from streaming.base.format.xsv.writer import CSVWriter, TSVWriter, XSVWriter +from streaming.format.xsv.reader import CSVReader, TSVReader, XSVReader +from streaming.format.xsv.writer import CSVWriter, TSVWriter, XSVWriter __all__ = ['CSVReader', 'CSVWriter', 'TSVReader', 'TSVWriter', 'XSVReader', 'XSVWriter'] diff --git a/streaming/base/format/xsv/encodings.py b/streaming/format/xsv/encodings.py similarity index 100% rename from streaming/base/format/xsv/encodings.py rename to streaming/format/xsv/encodings.py diff --git a/streaming/base/format/xsv/reader.py b/streaming/format/xsv/reader.py similarity index 98% rename from streaming/base/format/xsv/reader.py rename to streaming/format/xsv/reader.py index 896d9cda9..f43ee6f5d 100644 --- a/streaming/base/format/xsv/reader.py +++ b/streaming/format/xsv/reader.py @@ -10,8 +10,8 @@ import numpy as np from typing_extensions import Self -from streaming.base.format.base.reader import FileInfo, SplitReader -from streaming.base.format.xsv.encodings import xsv_decode +from streaming.format.reader import FileInfo, SplitReader +from streaming.format.xsv.encodings import xsv_decode __all__ = ['XSVReader', 'CSVReader', 'TSVReader'] diff --git a/streaming/base/format/xsv/writer.py b/streaming/format/xsv/writer.py 
similarity index 98% rename from streaming/base/format/xsv/writer.py rename to streaming/format/xsv/writer.py index 2888597b2..b1ab720d3 100644 --- a/streaming/base/format/xsv/writer.py +++ b/streaming/format/xsv/writer.py @@ -8,8 +8,8 @@ import numpy as np -from streaming.base.format.base.writer import SplitWriter -from streaming.base.format.xsv.encodings import is_xsv_encoding, xsv_encode +from streaming.format.writer import SplitWriter +from streaming.format.xsv.encodings import is_xsv_encoding, xsv_encode __all__ = ['XSVWriter', 'CSVWriter', 'TSVWriter'] diff --git a/streaming/base/hashing.py b/streaming/hashing.py similarity index 100% rename from streaming/base/hashing.py rename to streaming/hashing.py diff --git a/streaming/base/local.py b/streaming/local.py similarity index 93% rename from streaming/base/local.py rename to streaming/local.py index 48eea91a5..47dd8134f 100644 --- a/streaming/base/local.py +++ b/streaming/local.py @@ -10,9 +10,9 @@ import numpy as np from torch.utils.data import Dataset -from streaming.base.array import Array -from streaming.base.format import get_index_basename, reader_from_json -from streaming.base.spanner import Spanner +from streaming.array import Array +from streaming.format import get_index_basename, reader_from_json +from streaming.spanner import Spanner __all__ = ['LocalDataset'] diff --git a/streaming/multimodal/__init__.py b/streaming/multimodal/__init__.py deleted file mode 100644 index cac23533f..000000000 --- a/streaming/multimodal/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Natively supported multimodal datasets.""" - -from streaming.multimodal.webvid import StreamingInsideWebVid as StreamingInsideWebVid -from streaming.multimodal.webvid import StreamingOutsideDTWebVid as StreamingOutsideDTWebVid -from streaming.multimodal.webvid import StreamingOutsideGIWebVid as StreamingOutsideGIWebVid diff --git a/streaming/multimodal/convert/__init__.py b/streaming/multimodal/convert/__init__.py deleted file mode 100644 index 36f008387..000000000 --- a/streaming/multimodal/convert/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Dataset conversion for natively supported multimodal datasets.""" diff --git a/streaming/base/partition/__init__.py b/streaming/partition/__init__.py similarity index 94% rename from streaming/base/partition/__init__.py rename to streaming/partition/__init__.py index ad1edefa2..5e67e485c 100644 --- a/streaming/base/partition/__init__.py +++ b/streaming/partition/__init__.py @@ -8,8 +8,8 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.partition.orig import get_partitions_orig -from streaming.base.partition.relaxed import get_partitions_relaxed +from streaming.partition.orig import get_partitions_orig +from streaming.partition.relaxed import get_partitions_relaxed algos = { 'orig': get_partitions_orig, diff --git a/streaming/base/partition/orig.py b/streaming/partition/orig.py similarity index 100% rename from streaming/base/partition/orig.py rename to streaming/partition/orig.py diff --git a/streaming/base/partition/relaxed.py b/streaming/partition/relaxed.py similarity index 98% rename from streaming/base/partition/relaxed.py rename to streaming/partition/relaxed.py index c2f0d83a8..f57529874 100644 --- a/streaming/base/partition/relaxed.py +++ b/streaming/partition/relaxed.py @@ -9,7 +9,7 @@ import numpy as np from numpy.typing 
import NDArray -from streaming.base.partition.orig import get_partitions_orig +from streaming.partition.orig import get_partitions_orig logger = logging.getLogger(__name__) diff --git a/streaming/base/sampling.py b/streaming/sampling.py similarity index 100% rename from streaming/base/sampling.py rename to streaming/sampling.py diff --git a/streaming/shared/__init__.py b/streaming/shared/__init__.py new file mode 100644 index 000000000..8d599d4fe --- /dev/null +++ b/streaming/shared/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Objects that live in shared memory. + +For when using `threading` or `multiprocessing` from the python standard library won't do, because +we are coordinating separately instantiated pytorch worker processes. +""" + +from streaming.shared.array import SharedArray as SharedArray +from streaming.shared.barrier import SharedBarrier as SharedBarrier +from streaming.shared.memory import SharedMemory as SharedMemory +from streaming.shared.prefix import _get_path as _get_path +from streaming.shared.prefix import get_shm_prefix as get_shm_prefix +from streaming.shared.scalar import SharedScalar as SharedScalar + +__all__ = ['SharedArray', 'SharedBarrier', 'SharedMemory', 'get_shm_prefix', 'SharedScalar'] diff --git a/streaming/base/shared/array.py b/streaming/shared/array.py similarity index 97% rename from streaming/base/shared/array.py rename to streaming/shared/array.py index 20689d125..cd69db85f 100644 --- a/streaming/base/shared/array.py +++ b/streaming/shared/array.py @@ -8,7 +8,7 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.shared.memory import SharedMemory +from streaming.shared.memory import SharedMemory class SharedArray: diff --git a/streaming/base/shared/barrier.py b/streaming/shared/barrier.py similarity index 97% rename from streaming/base/shared/barrier.py rename to streaming/shared/barrier.py index ceeb3ec43..b4adda46e 100644 --- a/streaming/base/shared/barrier.py +++ b/streaming/shared/barrier.py @@ -11,8 +11,8 @@ import numpy as np from filelock import FileLock -from streaming.base.constant import TICK -from streaming.base.shared.array import SharedArray +from streaming.constant import TICK +from streaming.shared.array import SharedArray # Time out to wait before raising exception TIMEOUT = 60 diff --git a/streaming/base/shared/memory.py b/streaming/shared/memory.py similarity index 99% rename from streaming/base/shared/memory.py rename to streaming/shared/memory.py index b5b70f55e..b235b7e32 100644 --- a/streaming/base/shared/memory.py +++ b/streaming/shared/memory.py @@ -9,7 +9,7 @@ from time import sleep from typing import Any, Optional -from streaming.base.constant import TICK +from streaming.constant import TICK class SharedMemory: diff --git a/streaming/base/shared/prefix.py b/streaming/shared/prefix.py similarity index 97% rename from streaming/base/shared/prefix.py rename to streaming/shared/prefix.py index 48d2aaa6c..f51f9f1a6 100644 --- a/streaming/base/shared/prefix.py +++ b/streaming/shared/prefix.py @@ -14,9 +14,9 @@ import numpy as np from torch import distributed as dist -from streaming.base.constant import LOCALS, TICK -from streaming.base.shared import SharedMemory -from streaming.base.world import World +from streaming.constant import LOCALS, TICK +from streaming.shared import SharedMemory +from streaming.world import World def _each_prefix_int() -> Iterator[int]: @@ -128,7 +128,7 @@ def _check_and_find(streams_local: List[str], 
streams_remote: List[Union[str, No f'Reused local directory: {streams_local} vs ' + f'{their_locals}. Provide a different one. If using ' + f'a unique local directory, try deleting the local directory and ' + - f'call `streaming.base.util.clean_stale_shared_memory()` only once ' + + f'call `streaming.util.clean_stale_shared_memory()` only once ' + f'in your script to clean up the stale shared memory before ' + f'instantiation of `StreamingDataset`.') return prefix_int diff --git a/streaming/base/shared/scalar.py b/streaming/shared/scalar.py similarity index 94% rename from streaming/base/shared/scalar.py rename to streaming/shared/scalar.py index 14cd5e7fa..c9714befc 100644 --- a/streaming/base/shared/scalar.py +++ b/streaming/shared/scalar.py @@ -5,7 +5,7 @@ from typing import Any -from streaming.base.shared.array import SharedArray +from streaming.shared.array import SharedArray class SharedScalar: diff --git a/streaming/base/shuffle/__init__.py b/streaming/shuffle/__init__.py similarity index 82% rename from streaming/base/shuffle/__init__.py rename to streaming/shuffle/__init__.py index e5e529c42..d34eb71fd 100644 --- a/streaming/base/shuffle/__init__.py +++ b/streaming/shuffle/__init__.py @@ -6,12 +6,12 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.shuffle.naive import get_shuffle_naive -from streaming.base.shuffle.py1b import get_shuffle_py1b -from streaming.base.shuffle.py1br import get_shuffle_py1br -from streaming.base.shuffle.py1e import get_shuffle_py1e -from streaming.base.shuffle.py1s import get_shuffle_py1s -from streaming.base.shuffle.py2s import get_shuffle_py2s +from streaming.shuffle.naive import get_shuffle_naive +from streaming.shuffle.py1b import get_shuffle_py1b +from streaming.shuffle.py1br import get_shuffle_py1br +from streaming.shuffle.py1e import get_shuffle_py1e +from streaming.shuffle.py1s import get_shuffle_py1s +from streaming.shuffle.py2s import get_shuffle_py2s algos = { 'py1b': get_shuffle_py1b, diff --git a/streaming/base/shuffle/naive.py b/streaming/shuffle/naive.py similarity index 100% rename from streaming/base/shuffle/naive.py rename to streaming/shuffle/naive.py diff --git a/streaming/base/shuffle/py1b.py b/streaming/shuffle/py1b.py similarity index 98% rename from streaming/base/shuffle/py1b.py rename to streaming/shuffle/py1b.py index bb59f0c73..fdfaf9dd0 100644 --- a/streaming/base/shuffle/py1b.py +++ b/streaming/shuffle/py1b.py @@ -10,7 +10,7 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.shuffle.py1s import divide_spans +from streaming.shuffle.py1s import divide_spans def get_shuffle_py1b(shard_sizes: NDArray[np.int64], diff --git a/streaming/base/shuffle/py1br.py b/streaming/shuffle/py1br.py similarity index 98% rename from streaming/base/shuffle/py1br.py rename to streaming/shuffle/py1br.py index eff32210c..bc4c5053a 100644 --- a/streaming/base/shuffle/py1br.py +++ b/streaming/shuffle/py1br.py @@ -10,7 +10,7 @@ import numpy as np from numpy.typing import NDArray -from streaming.base.shuffle.py1s import divide_spans +from streaming.shuffle.py1s import divide_spans def get_shuffle_py1br(shard_sizes: NDArray[np.int64], diff --git a/streaming/base/shuffle/py1e.py b/streaming/shuffle/py1e.py similarity index 99% rename from streaming/base/shuffle/py1e.py rename to streaming/shuffle/py1e.py index 3583caa22..e5dfc6291 100644 --- a/streaming/base/shuffle/py1e.py +++ b/streaming/shuffle/py1e.py @@ -13,7 +13,7 @@ import numpy as np from numpy.typing import NDArray -from 
streaming.base.shuffle.py1s import divide_spans +from streaming.shuffle.py1s import divide_spans def get_shuffle_py1e(shard_sizes: NDArray[np.int64], diff --git a/streaming/base/shuffle/py1s.py b/streaming/shuffle/py1s.py similarity index 100% rename from streaming/base/shuffle/py1s.py rename to streaming/shuffle/py1s.py diff --git a/streaming/base/shuffle/py2s.py b/streaming/shuffle/py2s.py similarity index 100% rename from streaming/base/shuffle/py2s.py rename to streaming/shuffle/py2s.py diff --git a/streaming/base/spanner.py b/streaming/spanner.py similarity index 100% rename from streaming/base/spanner.py rename to streaming/spanner.py diff --git a/streaming/storage/__init__.py b/streaming/storage/__init__.py new file mode 100644 index 000000000..674d4fbad --- /dev/null +++ b/streaming/storage/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Base module for downloading/uploading files from/to cloud storage.""" + +from streaming.storage.download import (download_file, download_from_azure, + download_from_azure_datalake, + download_from_databricks_unity_catalog, download_from_dbfs, + download_from_gcs, download_from_local, download_from_oci, + download_from_s3, download_from_sftp) +from streaming.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, + GCSUploader, LocalUploader, OCIUploader, S3Uploader) + +__all__ = [ + 'download_file', + 'CloudUploader', + 'S3Uploader', + 'GCSUploader', + 'OCIUploader', + 'LocalUploader', + 'AzureUploader', + 'AzureDataLakeUploader', + 'download_from_s3', + 'download_from_sftp', + 'download_from_gcs', + 'download_from_oci', + 'download_from_azure', + 'download_from_azure_datalake', + 'download_from_databricks_unity_catalog', + 'download_from_dbfs', + 'download_from_local', +] diff --git a/streaming/base/storage/download.py b/streaming/storage/download.py similarity index 99% rename from streaming/base/storage/download.py rename to streaming/storage/download.py index 9db4af328..edb88943c 100644 --- a/streaming/base/storage/download.py +++ b/streaming/storage/download.py @@ -10,7 +10,7 @@ from time import sleep, time from typing import Any, Dict, Optional -from streaming.base.util import get_import_exception_message +from streaming.util import get_import_exception_message __all__ = [ 'download_from_s3', diff --git a/streaming/base/storage/upload.py b/streaming/storage/upload.py similarity index 99% rename from streaming/base/storage/upload.py rename to streaming/storage/upload.py index dab805bf5..9723315ef 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/storage/upload.py @@ -15,9 +15,8 @@ import tqdm -from streaming.base.storage.download import (BOTOCORE_CLIENT_ERROR_CODES, - GCS_ERROR_NO_AUTHENTICATION) -from streaming.base.util import get_import_exception_message, retry +from streaming.storage.download import BOTOCORE_CLIENT_ERROR_CODES, GCS_ERROR_NO_AUTHENTICATION +from streaming.util import get_import_exception_message, retry __all__ = [ 'CloudUploader', diff --git a/streaming/base/stream.py b/streaming/stream.py similarity index 98% rename from streaming/base/stream.py rename to streaming/stream.py index d707f9a6b..200ba83e5 100644 --- a/streaming/base/stream.py +++ b/streaming/stream.py @@ -13,14 +13,14 @@ from numpy.typing import NDArray from typing_extensions import Self -from streaming.base.compression import decompress -from streaming.base.constant import TICK -from streaming.base.distributed import barrier, get_local_rank -from 
streaming.base.format import FileInfo, Reader, get_index_basename, reader_from_json -from streaming.base.hashing import get_hash -from streaming.base.storage import download_file -from streaming.base.util import retry, wait_for_file_to_exist -from streaming.base.world import World +from streaming.compression import decompress +from streaming.constant import TICK +from streaming.distributed import barrier, get_local_rank +from streaming.format import FileInfo, Reader, get_index_basename, reader_from_json +from streaming.hashing import get_hash +from streaming.storage import download_file +from streaming.util import retry, wait_for_file_to_exist +from streaming.world import World class Stream: diff --git a/streaming/text/__init__.py b/streaming/text/__init__.py deleted file mode 100644 index 0452f4430..000000000 --- a/streaming/text/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Natively supported NLP datasets.""" - -from streaming.text.c4 import StreamingC4 as StreamingC4 -from streaming.text.enwiki import StreamingEnWiki as StreamingEnWiki -from streaming.text.pile import StreamingPile as StreamingPile - -__all__ = ['StreamingPile', 'StreamingC4', 'StreamingEnWiki'] diff --git a/streaming/text/convert/README.md b/streaming/text/convert/README.md deleted file mode 100644 index 029ddae09..000000000 --- a/streaming/text/convert/README.md +++ /dev/null @@ -1,69 +0,0 @@ -# Dataset preparation - -To use Streaming Dataset we must first convert the dataset from its native format to MosaicML's Streaming Dataset format called Mosaic Dataset Shard (MDS). Once in MDS format, we can access the dataset from the local file system (disk network attached storage, etc.) or object store (GCS, OCS, S3, etc.). From object store, data can be streamed to train deep learning models and it all just works efficiently. - -Check out steps below for information on converting common NLP datasets to MDS format. Please see [MDSWriter()](https://streaming.docs.mosaicml.com/en/latest/api_reference/generated/streaming.MDSWriter.html) parameters for details on advanced usage. - -## NLP Dataset Conversion Examples - -### [C4: Colossal, Cleaned, Common Crawl dataset](https://huggingface.co/datasets/c4) - -1. Run the [c4.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/c4.py) script as shown below. The script downloads the raw format with `train` and `val` splits from HuggingFace hub and converts to StreamingDataset MDS format into their own split directories. For more advanced use cases, please see the supported arguments for [c4.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/c4.py) and modify as necessary. - - ``` - python c4.py --out_root - ``` - -### [Wikipedia](https://huggingface.co/datasets/wikipedia) - -1. Download English Wikipedia 2020-01-01 from [here](https://drive.google.com/drive/folders/1cywmDnAsrP5-2vsr8GDc6QUc7VWe-M3v). -2. Unzip the file `results_text.zip` as shown below. - - ```bash - unzip results_text.zip - ``` - - Listing the output should show the following directory structure: - - ```bash - ├── eval.txt - ├── part-00000-of-00500 - ├── part-00001-of-00500 - ├── part-00002-of-00500 - ├── ..... - ├── part-00498-of-00500 - └── part-00499-of-00500 - ``` - -3. Run the [enwiki_text.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/enwiki_text.py) script. The script converts the `train` and `val` dataset splits into their own split directories. 
For more advanced use cases, please see the supported arguments for [enwiki_text.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/enwiki_text.py) and modify as necessary. - - ``` - python enwiki_text.py --in_root --out_root - ``` - -### [Pile](https://pile.eleuther.ai/) - -1. Download the Pile dataset from [here](https://the-eye.eu/public/AI/pile/). - - Listing the output should show the following directory structure: - - ```bash - ├── SHA256SUMS.txt - ├── test.jsonl.zst - ├── train - │   ├── 00.jsonl.zst - │   ├── 01.jsonl.zst - │   ├── 02.jsonl.zst - │   ├── 03.jsonl.zst - │   ├── ..... - │   ├── 28.jsonl.zst - │   └── 29.jsonl.zst - └── val.jsonl.zst - ``` - -2. Run the [pile.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/pile.py) script. The script converts the `train`, `test`, and `val` dataset splits into their own split directories. For more advanced use cases, please see the supported arguments for [pile.py](https://github.com/mosaicml/streaming/blob/main/streaming/text/convert/pile.py) and modify as necessary. - - - ```bash - python pile.py --in_root --out_root - ``` diff --git a/streaming/text/convert/enwiki/tfrecord/__init__.py b/streaming/text/convert/enwiki/tfrecord/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/streaming/util.py b/streaming/util.py new file mode 100644 index 000000000..a9c1a0bab --- /dev/null +++ b/streaming/util.py @@ -0,0 +1,551 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Utility and helper functions for datasets.""" + +import collections.abc +import functools +import inspect +import json +import logging +import os +import random +import shutil +import sys +import tempfile +import urllib.parse +from collections import OrderedDict +from importlib import import_module +from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory +from pathlib import Path +from time import sleep, time +from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, cast, overload +from warnings import warn + +import torch.distributed as dist + +from streaming.constant import SHM_TO_CLEAN +from streaming.distributed import get_local_rank, maybe_init_dist +from streaming.format.index import get_index_basename +from streaming.shared.prefix import _get_path + +logger = logging.getLogger(__name__) + +TCallable = TypeVar('TCallable', bound=Callable) + +__all__ = [ + 'get_list_arg', 'wait_for_file_to_exist', 'bytes_to_int', 'number_abbrev_to_int', + 'clean_stale_shared_memory', 'get_import_exception_message', 'merge_index', 'retry', + 'redirect_imports' +] + + +def get_list_arg(text: str) -> List[str]: + """Pass a list as a command-line flag. + + Args: + text (str): Text to split. + + Returns: + List[str]: Splits, if any. + """ + return text.split(',') if text else [] + + +def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float, + err_msg: str) -> None: + """Wait for the file to exist till timeout seconds. Raise an Exception after that. 
+ + Args: + filename (str): A file name + poll_interval (float): Number of seconds to wait before next polling + timeout (float): Number of seconds to wait for a file to exist before raising an exception + err_msg (str): Error message description for an exception + + Raises: + RuntimeError: Raise an Exception if file does not exist after timeout + """ + start_time = time() + while True: + sleep(poll_interval) + if os.path.exists(filename): + sleep(poll_interval) + break + dt = time() - start_time + if dt > timeout: + raise RuntimeError(f'{err_msg}' + f'{timeout:.3f} < {dt:.3f} secs.') + + +def bytes_to_int(bytes_str: Union[int, str]) -> int: + """Convert human readable byte format to an integer. + + Args: + bytes_str (Union[int, str]): Value to convert. + + Raises: + ValueError: Invalid byte suffix. + + Returns: + int: Integer value of bytes. + """ + #input is already an int + if isinstance(bytes_str, int) or isinstance(bytes_str, float): + return int(bytes_str) + + units = { + 'kb': 1024, + 'mb': 1024**2, + 'gb': 1024**3, + 'tb': 1024**4, + 'pb': 1024**5, + 'eb': 1024**6, + 'zb': 1024**7, + 'yb': 1024**8, + } + # Convert a various byte types to an integer + for suffix in units: + bytes_str = bytes_str.lower().strip() + if bytes_str.lower().endswith(suffix): + try: + return int(float(bytes_str[0:-len(suffix)]) * units[suffix]) + except ValueError: + raise ValueError(''.join([ + f'Unsupported value/suffix {bytes_str}. Supported suffix are ', + f'{["b"] + list(units.keys())}.' + ])) + else: + # Convert bytes to an integer + if bytes_str.endswith('b') and bytes_str[0:-1].isdigit(): + return int(bytes_str[0:-1]) + # Convert string representation of a number to an integer + elif bytes_str.isdigit(): + return int(bytes_str) + else: + raise ValueError(''.join([ + f'Unsupported value/suffix {bytes_str}. Supported suffix are ', + f'{["b"] + list(units.keys())}.' + ])) + + +def number_abbrev_to_int(abbrev_str: Union[int, str]) -> int: + """Convert human readable number abbreviations to an integer. + + Args: + abbrev_str (Union[int, str]): Value to convert. + + Raises: + ValueError: Invalid number suffix. + + Returns: + int: Integer value of number abbreviation. + """ + #input is already an int + if isinstance(abbrev_str, int) or isinstance(abbrev_str, float): + return int(abbrev_str) + + units = { + 'k': 10**3, + 'm': 10**6, + 'b': 10**9, + 't': 10**12, + } + # Convert a various abbreviation types to an integer + for suffix in units: + abbrev_str = abbrev_str.lower().strip() + if abbrev_str.lower().endswith(suffix): + try: + return int(float(abbrev_str[0:-len(suffix)]) * units[suffix]) + except ValueError: + raise ValueError(''.join([ + f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', + f'{list(units.keys())}.' + ])) + else: + # Convert string representation of a number to an integer + if abbrev_str.isdigit(): + return int(abbrev_str) + else: + raise ValueError(''.join([ + f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', + f'{list(units.keys())}.' + ])) + + +def clean_stale_shared_memory() -> None: + """Clean up all the leaked shared memory. + + In case of a distributed run, clean up happens on local rank 0 while other local ranks wait for + the local rank 0 to finish. + """ + # Initialize torch.distributed ourselves, if necessary. 
+ destroy_dist = maybe_init_dist() + + # Perform clean up on local rank 0 + if get_local_rank() == 0: + for prefix_int in range(1000000): + leaked_shm = False + for shm_name in SHM_TO_CLEAN: + name = _get_path(prefix_int, shm_name) + try: + shm = BuiltinSharedMemory(name, True, 4) + except FileExistsError: + shm = BuiltinSharedMemory(name, False, 4) + leaked_shm = True + finally: + shm.close() # pyright: ignore + shm.unlink() + # Come out of loop if no leaked shared memory + if not leaked_shm: + break + + # Sync all ranks + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + # Delete the process group if Streaming initialized it. + if destroy_dist: + dist.destroy_process_group() + + +def get_import_exception_message(package_name: str, extra_deps: str) -> str: + """Get import exception message. + + Args: + package_name (str): Package name. + + Returns: + str: Exception message. + """ + return f'Streaming was installed without {package_name} support. ' + \ + f'To use {package_name} related packages with Streaming, run ' + \ + f'`pip install \'mosaicml-streaming[{package_name}]\'`.' + + +def merge_index(*args: Any, **kwargs: Any): + r"""Merge index.json from partitions to form a global index.json. + + This can be called as + + merge_index(index_file_urls, out, keep_local, download_timeout) + + merge_index(out, keep_local, download_timeout) + + The first signature takes in a list of index files URLs of MDS partitions. + The second takes the root of a MDS dataset and parse the partition folders from there. + + Args: + index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the partitions. + Each element can take the form of a single path string or a tuple string. + + 1. If ``index_file_urls`` is a List of local URLs, merge locally without download. + 2. If ``index_file_urls`` is a List of tuple (local, remote) URLs, check if local index.json are missing, download before merging. + 3. If ``index_file_urls`` is a List of remote URLs, download all and merge. + + out (Union[str, Tuple[str,str]]): folder that contain MDS partitions and to put the merged index file + + 1. A local directory, merge index happens locally. + 2. A remote directory, download all the sub-directories index.json, merge locally and upload. + 3. A tuple (local_dir, remote_dir), check if local index.json exist, download if not. + + keep_local (bool): Keep local copy of the merged index file. Defaults to ``True``. + download_timeout (int): The allowed time for downloading each json file. Defaults to 60. + """ + if isinstance(args[0], list) and len(args) + len(kwargs) in [2, 3, 4]: + return _merge_index_from_list(*args, **kwargs) + elif (isinstance(args[0], str) or + isinstance(args[0], tuple)) and len(args) + len(kwargs) in [1, 2, 3]: + return _merge_index_from_root(*args, **kwargs) + raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}') + + +def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], + out: Union[str, Tuple[str, str]], + keep_local: bool = True, + download_timeout: int = 60) -> None: + """Merge index.json from a list of index files of MDS directories to create joined index. + + Args: + index_file_urls (Union[str, Tuple[str,str]]): index.json from all the partitions + each element can take the form of a single path string or a tuple string. + + The pattern of index_file_urls and corresponding reaction is one of: + 1. All URLS are str (local). All URLS are accessible locally -> no download + 2. All URLS are tuple (local, remote). 
All URLS are accessible locally -> no download + 3. All URLS are tuple (local, remote). Download URL that is not accessible locally + 4. All URLS are str (remote) -> download all + + out (Union[str, Tuple[str, str]]): path to put the merged index file + keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` + download_timeout (int): The allowed time for downloading each json file. Defaults to 60. + """ + from streaming.storage.download import download_file + from streaming.storage.upload import CloudUploader + + if not index_file_urls or not out: + logger.warning('Either index_file_urls or out are None. ' + + 'Need to specify both `index_file_urls` and `out`. ' + 'No index merged') + return + + # This is the index json file name, e.g., it is index.json as of 0.6.0 + index_basename = get_index_basename() + + cu = CloudUploader.get(out, keep_local=True, exist_ok=True) + + # Remove duplicates, and strip '/' from right if any + index_file_urls = list(OrderedDict.fromkeys(index_file_urls)) + urls = [] + for url in index_file_urls: + if isinstance(url, str): + urls.append(url.rstrip('/').strip()) + else: + urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip())) + + # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. + with tempfile.TemporaryDirectory() as temp_root: + logging.warning(f'A temporary folder {temp_root} is created to store index files') + + # Copy files to a temporary directory. Download if necessary + partitions = [] + for url in urls: + if isinstance(url, tuple): + src = url[0] if os.path.exists(url[0]) else url[1] + else: + src = url + + obj = urllib.parse.urlparse(src) + scheme, bucket, path = obj.scheme, obj.netloc, obj.path + if scheme == '' and bucket == '' and path == '': + raise FileNotFoundError( + f'Check data availability! local index {url[0]} is not accessible.' + + f'remote index {url[1]} does not have a valid url format') + dest = os.path.join(temp_root, path.lstrip('/')) + + try: + download_file(src, dest, download_timeout) + except Exception as ex: + raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex + + if not os.path.exists(dest): + raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.') + + partitions.append(dest) + + # merge shards from all index files + shards = [] + for partition_index in partitions: + p = Path(partition_index) + obj = json.load(open(partition_index)) + for i in range(len(obj['shards'])): + shard = obj['shards'][i] + for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): + if shard.get(key): + basename = shard[key]['basename'] + obj['shards'][i][key]['basename'] = os.path.join( + os.path.basename(p.parent), basename) + shards += obj['shards'] + + # Save merged index locally + obj = { + 'version': 2, + 'shards': shards, + } + merged_index_path = os.path.join(temp_root, index_basename) + with open(merged_index_path, 'w') as outfile: + json.dump(obj, outfile) + + # Move merged index from temp path to local part in out + # Upload merged index to remote if out has remote part + shutil.move(merged_index_path, cu.local) + if cu.remote is not None: + cu.upload_file(index_basename) + + # Clean up + if not keep_local: + shutil.rmtree(cu.local, ignore_errors=True) + + +def _merge_index_from_root(out: Union[str, Tuple[str, str]], + keep_local: bool = True, + download_timeout: int = 60) -> None: + """Merge index.json given the root of MDS dataset. Write merged index to the root folder. 
+ + Args: + out (Union[str, Tuple[str,str]]): folder that contain MDS partitions. + :A local directory, merge index happens locally + :A remote directory, download all the sub-directories index.json in a temporary + sub-directories, merge locally, and then upload it to out location + :A (local_dir, remote_dir), check if sub-directories index.json file present locally + If yes, then merge locally and upload to remote_dir . + If not, download all the sub-directories index.json from remote to local, + merge locally, and upload to remote_dir . + keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` + download_timeout (int): The allowed time for downloading each json file. Defaults to 60. + """ + from streaming.storage.upload import CloudUploader + + def not_merged_index(index_file_path: str, out: str): + """Check if index_file_path is the merged index at folder out. + + Args: + index_file_path (str): the path to index.json file + out (str): remote or local url of a folder + Return: + (bool): no if index.json sits in out instead of in the subfolders of out + """ + prefix = str(urllib.parse.urlparse(out).path) + return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') + + if not out: + logger.warning('No MDS dataset folder specified, no index merged') + return + + cu = CloudUploader.get(out, exist_ok=True, keep_local=True) + + local_index_files = [] + cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True) + for file in cl.list_objects(): + if file.endswith('.json') and not_merged_index(file, cu.local): + local_index_files.append(file) + + if cu.remote: + obj = urllib.parse.urlparse(cu.remote) + remote_index_files = [] + for file in cu.list_objects(): + if file.endswith(get_index_basename()) and not_merged_index(file, cu.remote): + join_char = '//' + if obj.scheme == 'dbfs': + path = Path(cu.remote) + prefix = os.path.join(path.parts[0], path.parts[1]) + if prefix == 'dbfs:/Volumes': + join_char = '/' + remote_index_files.append(obj.scheme + join_char + os.path.join(obj.netloc, file)) + if len(local_index_files) == len(remote_index_files): + _merge_index_from_list(list(zip(local_index_files, remote_index_files)), + out, + keep_local=keep_local, + download_timeout=download_timeout) + else: + _merge_index_from_list(remote_index_files, + out, + keep_local=keep_local, + download_timeout=download_timeout) + return + + _merge_index_from_list(local_index_files, + out, + keep_local=keep_local, + download_timeout=download_timeout) + + +@overload +def retry( + exc_class: Union[Type[Exception], Sequence[Type[Exception]]] = ..., + num_attempts: int = ..., + initial_backoff: float = ..., + max_jitter: float = ..., +) -> Callable[[TCallable], TCallable]: + ... + + +@overload +def retry(exc_class: TCallable) -> TCallable: + # Use the decorator without parenthesis + ... + + +# error: Type "(TCallable@retry) -> TCallable@retry" cannot be assigned to type +# "(func: Never) -> Never" +def retry( # type: ignore + exc_class: Union[TCallable, Type[Exception], Sequence[Type[Exception]]] = Exception, + num_attempts: int = 3, + initial_backoff: float = 1.0, + max_jitter: float = 0.5, +): + """Decorator to retry a function with backoff and jitter. + + Attempts are spaced out with + ``initial_backoff * 2**num_attempts + random.random() * max_jitter`` seconds. + + Example: + .. 
testcode:: + + from streaming.util import retry + + num_tries = 0 + + @retry(RuntimeError, num_attempts=3, initial_backoff=0.1) + def flaky_function(): + global num_tries + if num_tries < 2: + num_tries += 1 + raise RuntimeError("Called too soon!") + return "Third time's a charm." + + print(flaky_function()) + + .. testoutput:: + + Third time's a charm. + + Args: + exc_class (Type[Exception] | Sequence[Type[Exception]]], optional): The exception class or + classes to retry. Defaults to Exception. + num_attempts (int, optional): The total number of attempts to make. Defaults to 3. + initial_backoff (float, optional): The initial backoff, in seconds. Defaults to 1.0. + max_jitter (float, optional): The maximum amount of random jitter to add. Defaults to 0.5. + + Increasing the ``max_jitter`` can help prevent overloading a resource when multiple + processes in parallel are calling the same underlying function. + """ + if num_attempts < 1: + raise ValueError('num_attempts must be at-least 1') + + def wrapped_func(func: TCallable) -> TCallable: + + @functools.wraps(func) + def new_func(*args: Any, **kwargs: Any): + i = 0 + while True: + try: + return func(*args, **kwargs) + except exc_class as e: + if i + 1 == num_attempts: + logger.debug(f'Attempt {i + 1}/{num_attempts} failed with: {e}') + raise e + else: + sleep(initial_backoff * 2**i + random.random() * max_jitter) + logger.debug(f'Attempt {i + 1}/{num_attempts} failed with: {e}') + i += 1 + + return cast(TCallable, new_func) + + if not isinstance(exc_class, collections.abc.Sequence) and not (isinstance( + exc_class, type) and issubclass(exc_class, Exception)): + # Using the decorator without (), like @retry_with_backoff + func = cast(TCallable, exc_class) + exc_class = Exception + + return wrapped_func(func) + + return wrapped_func + + +def redirect_imports(new_fqdn: str) -> None: + """Overlay the members of the target module onto the module of the caller. + + Args: + new_fqdn (str): Fully-qualified dot-separated target module path. + """ + frame = inspect.stack()[1] + module = inspect.getmodule(frame[0]) + if module is None: + raise RuntimeError('Module was None.') + old_fqdn = module.__name__ + + # old = sys.modules[old_fqdn] + new = import_module(new_fqdn) + sys.modules[old_fqdn].__dict__.update(new.__dict__) + + warn(f'Please update your imports: {old_fqdn} has moved to {new_fqdn}.', + DeprecationWarning, + stacklevel=2) diff --git a/streaming/vision.py b/streaming/vision.py new file mode 100644 index 000000000..bfc8a2800 --- /dev/null +++ b/streaming/vision.py @@ -0,0 +1,154 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Base classes for computer vision :class:`StreamingDataset`s.""" + +import os +from typing import Any, Callable, Dict, List, Optional, Tuple + +import numpy as np +from torch.utils.data import Dataset +from torchvision.datasets import VisionDataset +from torchvision.transforms.functional import to_tensor +from tqdm import tqdm + +from streaming import MDSWriter +from streaming.dataset import StreamingDataset + +__all__ = ['StreamingVisionDataset'] + + +class StandardTransform: + """Individual input and output transforms called jointly, following torchvision. + + Args: + transform (Callable, optional): Input transform. Defaults to ``None``. + target_transform (Callable, optional): Output transform. Defaults to ``None``. 
+ """ + + def __init__(self, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None) -> None: + self.transform = transform + self.target_transform = target_transform + + def __call__(self, x: Any, y: Any) -> Tuple[Any, Any]: + """Apply the transforms to input and output. + + Args: + x (Any): Input. + y (Any): Output. + + Returns: + Tuple[Any, Any]: Transformed input and output. + """ + if self.transform: + x = self.transform(x) + else: + x = to_tensor(x) + if self.target_transform: + y = self.target_transform(y) + return x, y + + +class StreamingVisionDataset(StreamingDataset, VisionDataset): + """A streaming, iterable, torchvision VisionDataset. + + Args: + transforms (callable, optional): A function/transforms that takes in an image and a label + and returns the transformed versions of both. Defaults to ``None``. + transform (callable, optional): A function/transform that takes in an image and returns a + transformed version. Defaults to ``None``. + target_transform (callable, optional): A function/transform that takes in a target and + returns a transformed version. Defaults to ``None``. + **kwargs (Dict[str, Any]): Keyword arguments. + """ + + def __init__(self, + *, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + **kwargs: Dict[str, Any]) -> None: + StreamingDataset.__init__(self, **kwargs) + + has_transforms = transforms is not None + has_separate_transform = transform is not None or target_transform is not None + if has_transforms and has_separate_transform: + raise ValueError( + 'Only transforms or transform/target_transform can be passed as an argument') + + self.transform = transform + self.target_transform = target_transform + if not has_transforms: + transforms = StandardTransform(transform, target_transform) + self.transforms = transforms + + def get_item(self, idx: int) -> Any: + """Get sample by global index, blocking to load its shard if missing. + + Args: + idx (int): Sample index. + + Returns: + Any: Sample data. + """ + obj = super().get_item(idx) + x = obj['x'] + y = obj['y'] + return self.transforms(x, y) + + +def convert_image_class_dataset(dataset: Dataset, + out_root: str, + split: Optional[str] = None, + compression: Optional[str] = None, + hashes: Optional[List[str]] = None, + size_limit: int = 1 << 24, + progress_bar: bool = True, + leave: bool = False, + encoding: str = 'pil') -> None: + """Convert an image classification Dataset. + + Args: + dataset (Dataset): The dataset object to convert. + out_root (str): Output directory where shards are cached by split. + remote (str, optional): Remote dataset directory where shards are uploaded by split. + split (str, optional): Which dataset split to use, if any. Defaults to ``None``. + compression (str, optional): Optional compression. Defaults to ``None``. + hashes (List[str], optional): Optional list of hash algorithms to apply to shard files. + Defaults to ``None``. + size_limit (int): Uncompressed shard size limit, at which point it flushes the shard and + starts a new one. Defaults to ``1 << 26``. + progress_bar (bool): Whether to display a progress bar while converting. + Defaults to ``True``. + leave (bool): Whether to leave the progress bar in the console when done. Defaults to + ``False``. + encoding (str): MDS encoding to use for the image data. Defaults to ``pil``. 
+ """ + split = split or '' + columns = { + 'i': 'int', + 'x': encoding, + 'y': 'int', + } + hashes = hashes or [] + indices = np.random.permutation(len(dataset)).tolist() # pyright: ignore + if progress_bar: + indices = tqdm(indices, leave=leave) + + out_split_dir = os.path.join(out_root, split) + + with MDSWriter(out=out_split_dir, + columns=columns, + compression=compression, + hashes=hashes, + size_limit=size_limit, + progress_bar=progress_bar) as out: + for i in indices: + x, y = dataset[i] + out.write({ + 'i': i, + 'x': x, + 'y': y, + }) diff --git a/streaming/vision/__init__.py b/streaming/vision/__init__.py deleted file mode 100644 index f7ceab7b7..000000000 --- a/streaming/vision/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Natively supported CV datasets.""" - -from streaming.vision.ade20k import StreamingADE20K as StreamingADE20K -from streaming.vision.cifar10 import StreamingCIFAR10 as StreamingCIFAR10 -from streaming.vision.coco import StreamingCOCO as StreamingCOCO -from streaming.vision.imagenet import StreamingImageNet as StreamingImageNet - -__all__ = ['StreamingADE20K', 'StreamingCIFAR10', 'StreamingCOCO', 'StreamingImageNet'] diff --git a/streaming/vision/base.py b/streaming/vision/base.py deleted file mode 100644 index 564305849..000000000 --- a/streaming/vision/base.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Base classes for computer vision :class:`StreamingDataset`s.""" - -from typing import Any, Callable, Optional, Tuple - -from torchvision.datasets import VisionDataset -from torchvision.transforms.functional import to_tensor - -from streaming.base import StreamingDataset - -__all__ = ['StreamingVisionDataset'] - - -class StandardTransform: - """Individual input and output transforms called jointly, following torchvision. - - Args: - transform (Callable, optional): Input transform. Defaults to ``None``. - target_transform (Callable, optional): Output transform. Defaults to ``None``. - """ - - def __init__(self, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None) -> None: - self.transform = transform - self.target_transform = target_transform - - def __call__(self, x: Any, y: Any) -> Tuple[Any, Any]: - """Apply the transforms to input and output. - - Args: - x (Any): Input. - y (Any): Output. - - Returns: - Tuple[Any, Any]: Transformed input and output. - """ - if self.transform: - x = self.transform(x) - else: - x = to_tensor(x) - if self.target_transform: - y = self.target_transform(y) - return x, y - - -class StreamingVisionDataset(StreamingDataset, VisionDataset): - """A streaming, iterable, torchvision VisionDataset. - - Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. 
- download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. - transforms (callable, optional): A function/transforms that takes in an image and a label - and returns the transformed versions of both. Defaults to ``None``. - transform (callable, optional): A function/transform that takes in an image and returns a - transformed version. Defaults to ``None``. - target_transform (callable, optional): A function/transform that takes in a target and - returns a transformed version. Defaults to ``None``. 
- """ - - def __init__(self, - *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - cache_limit: Optional[int] = None, - partition_algo: str = 'orig', - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, - transforms: Optional[Callable] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None) -> None: - StreamingDataset.__init__(self, - remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) - - has_transforms = transforms is not None - has_separate_transform = transform is not None or target_transform is not None - if has_transforms and has_separate_transform: - raise ValueError( - 'Only transforms or transform/target_transform can be passed as an argument') - - self.transform = transform - self.target_transform = target_transform - if not has_transforms: - transforms = StandardTransform(transform, target_transform) - self.transforms = transforms - - def get_item(self, idx: int) -> Any: - """Get sample by global index, blocking to load its shard if missing. - - Args: - idx (int): Sample index. - - Returns: - Any: Sample data. - """ - obj = super().get_item(idx) - x = obj['x'] - y = obj['y'] - return self.transforms(x, y) diff --git a/streaming/vision/convert/README.md b/streaming/vision/convert/README.md deleted file mode 100644 index 58eda5148..000000000 --- a/streaming/vision/convert/README.md +++ /dev/null @@ -1,113 +0,0 @@ -# Dataset preparation - -To use Streaming Dataset we must first convert the dataset from its native format to MosaicML's Streaming Dataset format called Mosaic Dataset Shard (MDS). Once in MDS format, we can access the dataset from the local file system (disk network attached storage, etc.) or object store (GCS, OCS, S3, etc.). From object store, data can be streamed to train deep learning models and it all just works efficiently. - -Check out steps below for information on converting common Computer Vision datasets to MDS format. Please see [MDSWriter()](https://streaming.docs.mosaicml.com/en/latest/api_reference/generated/streaming.MDSWriter.html) parameters for details on advanced usage. - -## Vision Datasets Conversion Examples - -### [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/) - -1. Download the ADE20K dataset from [here](https://groups.csail.mit.edu/vision/datasets/ADE20K/). -2. Listing the output should show the following directory structure: - - ```bash - ├── annotations - │ ├── training - │ └── validation - └── images - ├── training - └── validation - ``` - -3. Run the [ade20k.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/ade20k.py) script as shown below. The script converts the `train` and `val` dataset splits into their own directories. 
For advanced use cases, please see the supported arguments for [ade20k.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/ade20k.py) and modify as necessary. - - ``` - python ade20k.py --in_root --out_root - ``` - -### [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) - -1. Run the [cifar10.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/cifar10.py) script as shown below. The CIFAR10 dataset will be automatically downloaded if it doesn't exist locally. For advanced use cases, please see the supported arguments for [cifar10.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/cifar10.py) and modify as necessary. - - ``` - python cifar10.py --in_root --out_root - ``` - -### [MS-COCO](https://cocodataset.org/#home) - -1. Download the COCO 2017 dataset from [here](https://cocodataset.org/#download). Please download both the COCO images and annotations and unzip the files as shown below. - - ```bash - mkdir coco - wget -c http://images.cocodataset.org/annotations/annotations_trainval2017.zip - wget -c http://images.cocodataset.org/zips/train2017.zip - wget -c http://images.cocodataset.org/zips/val2017.zip - - unzip annotations_trainval2017.zip - unzip train2017.zip - unzip val2017.zip - - rm annotations_trainval2017.zip - rm train2017.zip - rm val2017.zip - ``` - - Listing the output should show the following directory structure: - - ```bash - ├── annotations - │ ├── instances_train2017.json - │ └── instances_val2017.json - ├── train2017 - │ ├── 000000391895.jpg - | |── ... - └── val2017 - │ ├── 000000000139.jpg - | |── ... - ``` - -2. Run the [coco.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/coco.py) script as shown below. The script converts the `train` and `val` dataset splits into their own directories. For advanced use cases, please see the supported arguments for [coco.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/coco.py) and modify as necessary. - - ``` - python coco.py --in_root --out_root - ``` - -### [ImageNet](https://www.image-net.org/) - -1. Download the ImageNet dataset from [here](https://image-net.org/download.php). Two files are needed: `ILSVRC2012_img_train.tar` for training and `ILSVRC2012_img_val.tar` for validation. Next, untar both files as shown below. - - ```bash - mkdir val - mv ILSVRC2012_img_val.tar val/ - tar -xvf ILSVRC2012_img_val.tar -C val/ - rm ILSVRC2012_img_val.tar - - mkdir train - mv ILSVRC2012_img_train.tar train/ - tar -xvf ILSVRC2012_img_train.tar -C train/ - rm ILSVRC2012_img_train.tar - ``` - - Listing the output should show the following directory structure: - - ```bash - ├── train/ - ├── n01440764 - │ ├── n01440764_10026.JPEG - │ ├── n01440764_10027.JPEG - │ ├── ...... - ├── ...... - ├── val/ - ├── n01440764 - │ ├── ILSVRC2012_val_00000293.JPEG - │ ├── ILSVRC2012_val_00002138.JPEG - │ ├── ...... - ├── ...... - ``` - -2. Run the [imagenet.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/imagenet.py) script as shown below. The script converts the `train` and `val` dataset splits into their own directories. For advanced use cases, please see the supported arguments for [imagenet.py](https://github.com/mosaicml/streaming/blob/main/streaming/vision/convert/imagenet.py) and modify as needed.
- - ``` - python imagenet.py --in_root --out_root - ``` diff --git a/streaming/vision/convert/base.py b/streaming/vision/convert/base.py deleted file mode 100644 index 5194816fd..000000000 --- a/streaming/vision/convert/base.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Utility and helper functions to convert CV datasets.""" - -import os -from typing import List, Optional - -import numpy as np -from torch.utils.data import Dataset -from tqdm import tqdm - -from streaming.base import MDSWriter - - -def convert_image_class_dataset(dataset: Dataset, - out_root: str, - split: Optional[str] = None, - compression: Optional[str] = None, - hashes: Optional[List[str]] = None, - size_limit: int = 1 << 24, - progress_bar: bool = True, - leave: bool = False, - encoding: str = 'pil') -> None: - """Convert an image classification Dataset. - - Args: - dataset (Dataset): The dataset object to convert. - out_root (str): Output directory where shards are cached by split. - remote (str, optional): Remote dataset directory where shards are uploaded by split. - split (str, optional): Which dataset split to use, if any. Defaults to ``None``. - compression (str, optional): Optional compression. Defaults to ``None``. - hashes (List[str], optional): Optional list of hash algorithms to apply to shard files. - Defaults to ``None``. - size_limit (int): Uncompressed shard size limit, at which point it flushes the shard and - starts a new one. Defaults to ``1 << 26``. - progress_bar (bool): Whether to display a progress bar while converting. - Defaults to ``True``. - leave (bool): Whether to leave the progress bar in the console when done. Defaults to - ``False``. - encoding (str): MDS encoding to use for the image data. Defaults to ``pil``. 
- """ - split = split or '' - columns = { - 'i': 'int', - 'x': encoding, - 'y': 'int', - } - hashes = hashes or [] - indices = np.random.permutation(len(dataset)).tolist() # pyright: ignore - if progress_bar: - indices = tqdm(indices, leave=leave) - - out_split_dir = os.path.join(out_root, split) - - with MDSWriter(out=out_split_dir, - columns=columns, - compression=compression, - hashes=hashes, - size_limit=size_limit, - progress_bar=progress_bar) as out: - for i in indices: - x, y = dataset[i] - out.write({ - 'i': i, - 'x': x, - 'y': y, - }) diff --git a/streaming/base/world.py b/streaming/world.py similarity index 97% rename from streaming/base/world.py rename to streaming/world.py index c787c2f97..b512b4132 100644 --- a/streaming/base/world.py +++ b/streaming/world.py @@ -5,7 +5,7 @@ from torch.utils.data import get_worker_info -from streaming.base import distributed as dist +from streaming import distributed as dist class World: diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index a99ea973a..e46f1cff5 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -11,7 +11,7 @@ from pyspark.sql.functions import col from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType -from streaming.base.converters import dataframe_to_mds +from streaming.converters import dataframe_to_mds os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls diff --git a/tests/common/datasets.py b/tests/common/datasets.py index dbaefcd38..ce76010c3 100644 --- a/tests/common/datasets.py +++ b/tests/common/datasets.py @@ -5,7 +5,7 @@ import numpy as np -from streaming.base import MDSWriter +from streaming import MDSWriter class SequenceDataset: diff --git a/tests/test_array.py b/tests/test_array.py index 30816665f..7cfeb3f42 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -7,7 +7,7 @@ import pytest from numpy.typing import NDArray -from streaming.base.array import Array +from streaming.array import Array class Range(Array): diff --git a/tests/test_barrier.py b/tests/test_barrier.py index fdc5eb87d..72fcb6d13 100644 --- a/tests/test_barrier.py +++ b/tests/test_barrier.py @@ -11,7 +11,7 @@ import pytest -from streaming.base.shared import SharedArray, SharedBarrier +from streaming.shared import SharedArray, SharedBarrier class TestSharedBarrier: diff --git a/tests/test_compression.py b/tests/test_compression.py index 24c72a158..ff62ecb9d 100644 --- a/tests/test_compression.py +++ b/tests/test_compression.py @@ -7,9 +7,9 @@ import numpy as np import pytest -from streaming.base import StreamingDataset -from streaming.base.compression import (Brotli, Bzip2, Gzip, Snappy, Zstandard, compress, - decompress, get_compression_extension, is_compression) +from streaming import StreamingDataset +from streaming.compression import (Brotli, Bzip2, Gzip, Snappy, Zstandard, compress, decompress, + get_compression_extension, is_compression) from tests.common.datasets import SequenceDataset, write_mds_dataset diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 73fc1726b..8da2b8673 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -12,8 +12,8 @@ import torch.distributed as dist from torch.utils.data import DataLoader -import streaming.base.distributed as ms_dist -from streaming.base import StreamingDataset +import streaming.distributed as ms_dist +from streaming import 
StreamingDataset from tests.common.datasets import SequenceDataset, write_mds_dataset from tests.common.distributed import DistributedTest diff --git a/tests/test_download.py b/tests/test_download.py index 2d33cefb3..26814641d 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -10,11 +10,10 @@ import pytest from botocore.exceptions import ClientError -from streaming.base.storage.download import (download_file, download_from_azure, - download_from_azure_datalake, - download_from_databricks_unity_catalog, - download_from_dbfs, download_from_gcs, - download_from_local, download_from_s3) +from streaming.storage.download import (download_file, download_from_azure, + download_from_azure_datalake, + download_from_databricks_unity_catalog, download_from_dbfs, + download_from_gcs, download_from_local, download_from_s3) from tests.conftest import GCS_URL, MY_BUCKET, R2_URL MY_PREFIX = 'train' @@ -167,7 +166,7 @@ def test_download_from_local(): class TestDownload: - @patch('streaming.base.storage.download.download_from_s3') + @patch('streaming.storage.download.download_from_s3') @pytest.mark.usefixtures('remote_local_file') def test_download_from_s3_gets_called(self, mocked_requests: Mock, remote_local_file: Any): mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='s3://') @@ -175,7 +174,7 @@ def test_download_from_s3_gets_called(self, mocked_requests: Mock, remote_local_ mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath, 60) - @patch('streaming.base.storage.download.download_from_gcs') + @patch('streaming.storage.download.download_from_gcs') @pytest.mark.usefixtures('remote_local_file') def test_download_from_gcs_gets_called(self, mocked_requests: Mock, remote_local_file: Any): mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='gs://') @@ -183,7 +182,7 @@ def test_download_from_gcs_gets_called(self, mocked_requests: Mock, remote_local mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) - @patch('streaming.base.storage.download.download_from_azure') + @patch('streaming.storage.download.download_from_azure') @pytest.mark.usefixtures('remote_local_file') def test_download_from_azure_gets_called(self, mocked_requests: Mock, remote_local_file: Any): mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='azure://') @@ -191,7 +190,7 @@ def test_download_from_azure_gets_called(self, mocked_requests: Mock, remote_loc mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) - @patch('streaming.base.storage.download.download_from_azure_datalake') + @patch('streaming.storage.download.download_from_azure_datalake') @pytest.mark.usefixtures('remote_local_file') def test_download_from_azure_datalake_gets_called(self, mocked_requests: Mock, remote_local_file: Any): @@ -200,7 +199,7 @@ def test_download_from_azure_datalake_gets_called(self, mocked_requests: Mock, mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) - @patch('streaming.base.storage.download.download_from_sftp') + @patch('streaming.storage.download.download_from_sftp') @pytest.mark.usefixtures('remote_local_file') def test_download_from_sftp_gets_called(self, mocked_requests: Mock, remote_local_file: Any): mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='sftp://') @@ -208,7 
+207,7 @@ def test_download_from_sftp_gets_called(self, mocked_requests: Mock, remote_loca mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) - @patch('streaming.base.storage.download.download_from_databricks_unity_catalog') + @patch('streaming.storage.download.download_from_databricks_unity_catalog') @pytest.mark.usefixtures('remote_local_file') def test_download_from_databricks_unity_catalog_gets_called(self, mocked_requests: Mock, remote_local_file: Any): @@ -217,7 +216,7 @@ def test_download_from_databricks_unity_catalog_gets_called(self, mocked_request mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) - @patch('streaming.base.storage.download.download_from_dbfs') + @patch('streaming.storage.download.download_from_dbfs') @pytest.mark.usefixtures('remote_local_file') def test_download_from_dbfs_gets_called(self, mocked_requests: Mock, remote_local_file: Any): mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='dbfs:/') @@ -225,7 +224,7 @@ def test_download_from_dbfs_gets_called(self, mocked_requests: Mock, remote_loca mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath, mock_local_filepath) - @patch('streaming.base.storage.download.download_from_local') + @patch('streaming.storage.download.download_from_local') @pytest.mark.usefixtures('remote_local_file') def test_download_from_local_gets_called(self, mocked_requests: Mock, remote_local_file: Any): mock_remote_filepath, mock_local_filepath = remote_local_file() diff --git a/tests/test_encodings.py b/tests/test_encodings.py index 88e6ba203..70d048647 100644 --- a/tests/test_encodings.py +++ b/tests/test_encodings.py @@ -10,9 +10,9 @@ import pytest from PIL import Image -import streaming.base.format.json.encodings as jsonEnc -import streaming.base.format.mds.encodings as mdsEnc -import streaming.base.format.xsv.encodings as xsvEnc +import streaming.format.json.encodings as jsonEnc +import streaming.format.mds.encodings as mdsEnc +import streaming.format.xsv.encodings as xsvEnc class TestMDSEncodings: diff --git a/tests/test_hashing.py b/tests/test_hashing.py index 225ce7458..a9558f493 100644 --- a/tests/test_hashing.py +++ b/tests/test_hashing.py @@ -6,8 +6,8 @@ import pytest -import streaming.base.hashing as shash -from streaming.base import StreamingDataset +import streaming.hashing as shash +from streaming import StreamingDataset from tests.common.utils import convert_to_mds logger = logging.getLogger(__name__) diff --git a/tests/test_local.py b/tests/test_local.py index df6fb5f05..4292f6c92 100644 --- a/tests/test_local.py +++ b/tests/test_local.py @@ -7,7 +7,7 @@ from torch.utils.data import DataLoader from streaming import MDSWriter -from streaming.base.local import LocalDataset +from streaming.local import LocalDataset def test_local_dataset(): diff --git a/tests/test_partition.py b/tests/test_partition.py index 37da79ce1..2c4af319e 100644 --- a/tests/test_partition.py +++ b/tests/test_partition.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from streaming.base.partition import get_partitions +from streaming.partition import get_partitions @pytest.mark.parametrize('partition_algo', ['orig', 'relaxed']) diff --git a/tests/test_reader.py b/tests/test_reader.py index fbe7ff723..3d43e7aad 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -12,7 +12,7 @@ import pytest from numpy.typing import NDArray -from 
streaming.base import StreamingDataset +from streaming import StreamingDataset from tests.common.utils import convert_to_mds, copy_all_files logger = logging.getLogger(__name__) diff --git a/tests/test_sampling.py b/tests/test_sampling.py index b8b661d8a..e2be7484c 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -3,7 +3,7 @@ import numpy as np -from streaming.base.sampling import get_sampling +from streaming.sampling import get_sampling def test_choose_per_shard_adds_up(): diff --git a/tests/test_shared.py b/tests/test_shared.py index c28229472..02a4c531b 100644 --- a/tests/test_shared.py +++ b/tests/test_shared.py @@ -5,8 +5,8 @@ import pytest -from streaming.base.shared import get_shm_prefix -from streaming.base.world import World +from streaming.shared import get_shm_prefix +from streaming.world import World @pytest.mark.usefixtures('local_remote_dir') diff --git a/tests/test_shuffle.py b/tests/test_shuffle.py index 76eeb7dd9..221a8793f 100644 --- a/tests/test_shuffle.py +++ b/tests/test_shuffle.py @@ -5,8 +5,8 @@ import numpy as np -from streaming.base.shuffle import (get_shuffle_py1b, get_shuffle_py1br, get_shuffle_py1e, - get_shuffle_py1s, get_shuffle_py2s) +from streaming.shuffle import (get_shuffle_py1b, get_shuffle_py1br, get_shuffle_py1e, + get_shuffle_py1s, get_shuffle_py2s) def check(get_shuffle: Callable) -> None: diff --git a/tests/test_spanner.py b/tests/test_spanner.py index 340facd4d..46b802d5c 100644 --- a/tests/test_spanner.py +++ b/tests/test_spanner.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from streaming.base.spanner import Spanner +from streaming.spanner import Spanner def test_spanner_success(): diff --git a/tests/test_stream.py b/tests/test_stream.py index 9a7f64af4..818c19ae8 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -11,7 +11,7 @@ from _pytest.monkeypatch import MonkeyPatch from streaming import Stream, StreamingDataset -from streaming.base.distributed import barrier +from streaming.distributed import barrier from tests.common.utils import convert_to_mds diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 9a32e303c..926dc67e0 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -10,8 +10,8 @@ import pytest from torch.utils.data import DataLoader -from streaming.base import Stream, StreamingDataLoader, StreamingDataset -from streaming.base.util import clean_stale_shared_memory +from streaming import Stream, StreamingDataLoader, StreamingDataset +from streaming.util import clean_stale_shared_memory from tests.common.utils import convert_to_mds diff --git a/tests/test_streaming_remote.py b/tests/test_streaming_remote.py index 7e3dd7fc9..b59498f8a 100644 --- a/tests/test_streaming_remote.py +++ b/tests/test_streaming_remote.py @@ -7,9 +7,7 @@ import pytest -from streaming.base import StreamingDataset -from streaming.text import StreamingC4 -from streaming.vision import StreamingADE20K, StreamingCIFAR10, StreamingCOCO, StreamingImageNet +from streaming import StreamingDataset def get_dataset(name: str, @@ -26,7 +24,7 @@ def get_dataset(name: str, 'train': 20206, 'val': 2000, }, - 'class': StreamingADE20K, + 'class': StreamingDataset, 'kwargs': {}, }, 'imagenet1k': { @@ -35,7 +33,7 @@ def get_dataset(name: str, 'train': 1281167, 'val': 50000, }, - 'class': StreamingImageNet, + 'class': StreamingDataset, 'kwargs': {}, }, 'coco': { @@ -44,7 +42,7 @@ def get_dataset(name: str, 'train': 117266, 'val': 4952, }, - 'class': StreamingCOCO, + 'class': StreamingDataset, 'kwargs': {}, }, 'c4': 
{ @@ -53,11 +51,12 @@ def get_dataset(name: str, 'train': 364868892, 'val': 364608, }, - 'class': StreamingC4, + 'class': StreamingDataset, 'kwargs': { - 'tokenizer_name': 'bert-base-uncased', - 'max_seq_len': 512, - 'group_method': 'truncate' + # Use kwargs if creating a `StreamingC4`, but not needed for `StreamingDataset`. + # 'tokenizer_name': 'bert-base-uncased', + # 'max_seq_len': 512, + # 'group_method': 'truncate' }, }, 'cifar10': { @@ -66,7 +65,7 @@ def get_dataset(name: str, 'train': 50000, 'val': 10000, }, - 'class': StreamingCIFAR10, + 'class': StreamingDataset, 'kwargs': {}, }, 'test_streaming_upload': { diff --git a/tests/test_upload.py b/tests/test_upload.py index 57c0046f0..3e4056dbb 100644 --- a/tests/test_upload.py +++ b/tests/test_upload.py @@ -10,10 +10,9 @@ import boto3 import pytest -from streaming.base.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, - DatabricksUnityCatalogUploader, DBFSUploader, - GCSAuthentication, GCSUploader, LocalUploader, - S3Uploader) +from streaming.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, + DatabricksUnityCatalogUploader, DBFSUploader, + GCSAuthentication, GCSUploader, LocalUploader, S3Uploader) from tests.conftest import MY_BUCKET, R2_URL MY_PREFIX = 'train' @@ -37,8 +36,8 @@ def _method(cloud_prefix: str = '') -> Tuple[str, str]: class TestCloudUploader: - @patch('streaming.base.storage.upload.S3Uploader.check_bucket_exists') - @patch('streaming.base.storage.upload.GCSUploader.check_bucket_exists') + @patch('streaming.storage.upload.S3Uploader.check_bucket_exists') + @patch('streaming.storage.upload.GCSUploader.check_bucket_exists') @pytest.mark.parametrize( 'mapping', [ @@ -111,7 +110,7 @@ def test_check_bucket_exists_exception(self, out: str): with pytest.raises(botocore.exceptions.ClientError): _ = CloudUploader.get(out=out) - @patch('streaming.base.storage.LocalUploader.list_objects') + @patch('streaming.storage.LocalUploader.list_objects') @pytest.mark.usefixtures('remote_local_dir') def test_list_objects_from_local_gets_called(self, mocked_requests: Mock, remote_local_dir: Any): @@ -123,7 +122,7 @@ def test_list_objects_from_local_gets_called(self, mocked_requests: Mock, class TestS3Uploader: - @patch('streaming.base.storage.upload.S3Uploader.check_bucket_exists') + @patch('streaming.storage.upload.S3Uploader.check_bucket_exists') @pytest.mark.parametrize('out', ['s3://bucket/dir', ('./dir1', 's3://bucket/dir/')]) def test_instantiation(self, mocked_requests: Mock, out: Any): mocked_requests.side_effect = None @@ -215,7 +214,7 @@ def test_invalid_cloud_prefix(self, remote_local_dir: Any): class TestGCSUploader: - @patch('streaming.base.storage.upload.GCSUploader.check_bucket_exists') + @patch('streaming.storage.upload.GCSUploader.check_bucket_exists') @pytest.mark.parametrize('out', ['gs://bucket/dir', ('./dir1', 'gs://bucket/dir/')]) @pytest.mark.usefixtures('gcs_hmac_credentials') def test_instantiation(self, mocked_requests: Mock, out: Any): @@ -268,7 +267,7 @@ def test_check_bucket_exists_exception(self, out: str): with pytest.raises(botocore.exceptions.ClientError): _ = GCSUploader(out=out) - @patch('streaming.base.storage.upload.GCSUploader.check_bucket_exists') + @patch('streaming.storage.upload.GCSUploader.check_bucket_exists') @pytest.mark.usefixtures('gcs_hmac_credentials') @pytest.mark.parametrize('out', ['gs://bucket/dir']) def test_hmac_authentication(self, mocked_requests: Mock, out: str): @@ -284,7 +283,7 @@ def test_service_account_authentication(self, 
mock_client: Mock, mock_default: M uploader = GCSUploader(out=out) assert uploader.authentication == GCSAuthentication.SERVICE_ACCOUNT - @patch('streaming.base.storage.upload.GCSUploader.check_bucket_exists') + @patch('streaming.storage.upload.GCSUploader.check_bucket_exists') @patch('google.auth.default') @patch('google.cloud.storage.Client') @pytest.mark.usefixtures('gcs_service_account_credentials', 'gcs_hmac_credentials') @@ -324,7 +323,7 @@ def test_no_credentials_error(self, remote_local_dir: Any): class TestAzureUploader: - @patch('streaming.base.storage.upload.AzureUploader.check_bucket_exists') + @patch('streaming.storage.upload.AzureUploader.check_bucket_exists') @pytest.mark.usefixtures('azure_credentials') @pytest.mark.parametrize('out', ['azure://bucket/dir', ('./dir1', 'azure://bucket/dir/')]) def test_instantiation(self, mocked_requests: Mock, out: Any): @@ -356,7 +355,7 @@ def test_local_directory_is_empty(self, local_remote_dir: Tuple[str, str]): class TestAzureDataLakeUploader: - @patch('streaming.base.storage.upload.AzureDataLakeUploader.check_container_exists') + @patch('streaming.storage.upload.AzureDataLakeUploader.check_container_exists') @pytest.mark.usefixtures('azure_credentials') @pytest.mark.parametrize('out', ['azure://container/dir', ('./dir1', 'azure://container/dir/')]) @@ -389,7 +388,7 @@ def test_local_directory_is_empty(self, local_remote_dir: Tuple[str, str]): class TestDatabricksUnityCatalogUploader: - @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') + @patch('streaming.storage.upload.DatabricksUploader._create_workspace_client') @pytest.mark.parametrize( 'out', ['dbfs:/Volumes/container/dir', ('./dir1', 'dbfs:/Volumes/container/dir/')]) def test_instantiation(self, mock_create_client: Mock, out: Any): @@ -398,14 +397,14 @@ def test_instantiation(self, mock_create_client: Mock, out: Any): if not isinstance(out, str): shutil.rmtree(out[0], ignore_errors=True) - @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') + @patch('streaming.storage.upload.DatabricksUploader._create_workspace_client') @pytest.mark.parametrize('out', ['ss4://bucket/dir', ('./dir1', 'gcs://bucket/dir/')]) def test_invalid_remote_list(self, mock_create_client: Mock, out: Any): mock_create_client.side_effect = None with pytest.raises(ValueError, match=f'Invalid Cloud provider prefix.*'): _ = DatabricksUnityCatalogUploader(out=out) - @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') + @patch('streaming.storage.upload.DatabricksUploader._create_workspace_client') def test_local_directory_is_empty(self, mock_create_client: Mock, local_remote_dir: Tuple[str, str]): mock_create_client.side_effect = None @@ -421,7 +420,7 @@ def test_local_directory_is_empty(self, mock_create_client: Mock, class TestDBFSUploader: - @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') + @patch('streaming.storage.upload.DatabricksUploader._create_workspace_client') @pytest.mark.parametrize('out', ['dbfs:/container/dir', ('./dir1', 'dbfs:/container/dir/')]) def test_instantiation(self, mock_create_client: Mock, out: Any): mock_create_client.side_effect = None @@ -429,14 +428,14 @@ def test_instantiation(self, mock_create_client: Mock, out: Any): if not isinstance(out, str): shutil.rmtree(out[0], ignore_errors=True) - @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') + @patch('streaming.storage.upload.DatabricksUploader._create_workspace_client') 
@pytest.mark.parametrize('out', ['ss4://bucket/dir', ('./dir1', 'gcs://bucket/dir/')]) def test_invalid_remote_list(self, mock_create_client: Mock, out: Any): mock_create_client.side_effect = None with pytest.raises(ValueError, match=f'Invalid Cloud provider prefix.*'): _ = DBFSUploader(out=out) - @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') + @patch('streaming.storage.upload.DatabricksUploader._create_workspace_client') def test_local_directory_is_empty(self, mock_create_client: Mock, local_remote_dir: Tuple[str, str]): with pytest.raises(FileExistsError, match=f'Directory is not empty.*'): diff --git a/tests/test_util.py b/tests/test_util.py index e59f75911..eecbe4138 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -11,12 +11,12 @@ import pytest -from streaming.base.constant import RESUME -from streaming.base.shared.prefix import _get_path -from streaming.base.storage.download import download_file -from streaming.base.storage.upload import CloudUploader -from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, get_list_arg, - merge_index, number_abbrev_to_int, retry) +from streaming.constant import RESUME +from streaming.shared.prefix import _get_path +from streaming.storage.download import download_file +from streaming.storage.upload import CloudUploader +from streaming.util import (bytes_to_int, clean_stale_shared_memory, get_list_arg, merge_index, + number_abbrev_to_int, retry) MY_PREFIX = 'train_' + str(time.time()) MY_BUCKET = { @@ -187,7 +187,7 @@ def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_loc from pyspark.sql import SparkSession from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType - from streaming.base.converters import dataframeToMDS + from streaming.converters import dataframeToMDS def not_merged_index(index_file_path: str, out: str): """Check if index_file_path is the merged index at folder out.""" @@ -254,7 +254,7 @@ def test_merge_index_from_root_local(local_remote_dir: Tuple[str, str], n_partit from pyspark.sql import SparkSession from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType - from streaming.base.converters import dataframeToMDS + from streaming.converters import dataframeToMDS out, _ = local_remote_dir From 3cd8a22673cdfb0c3e9a85c94677e0bc676cd314 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Thu, 7 Dec 2023 14:32:14 -0800 Subject: [PATCH 02/12] Dataset kwargs switchover. (#523) * Dataset kwargs switchover. * Docstrings: **kwargs not kwargs. * Docstrings: Callable not callable. * Add dev to workflows. 
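For illustration only (this sketch is not part of the diff below, and all paths/values are placeholders): after the switchover, dataset-specific arguments stay explicit while every StreamingDataset argument is forwarded through `**kwargs`, so a subclass such as `StreamingC4` from `examples/text/c4/read.py` would be constructed roughly as follows, assuming the `examples` tree is importable:

```python
# Sketch of the post-switchover calling convention (placeholder paths/values).
from examples.text.c4.read import StreamingC4

dataset = StreamingC4(
    # Dataset-specific arguments remain explicit, keyword-only parameters.
    tokenizer_name='bert-base-uncased',
    max_seq_len=512,
    group_method='truncate',
    # Everything below is forwarded unchanged via **kwargs to StreamingDataset.__init__.
    remote='s3://my-bucket/c4',  # placeholder remote location
    local='/tmp/c4-cache',       # placeholder local cache directory
    split='train',
    shuffle=True,
    batch_size=8,
)
```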
--- .github/workflows/install.yaml | 2 + .github/workflows/linting.yaml | 2 + .github/workflows/pytest.yaml | 2 + examples/multimodal/webvid/read.py | 242 ++--------------------------- examples/text/c4/read.py | 95 +---------- examples/text/enwiki_txt/read.py | 91 +---------- examples/text/pile/read.py | 105 ++----------- examples/vision/ade20k/read.py | 97 +----------- examples/vision/cifar10/read.py | 60 +------ examples/vision/coco/read.py | 96 +----------- examples/vision/imagenet/read.py | 60 +------ streaming/vision.py | 6 +- 12 files changed, 68 insertions(+), 790 deletions(-) diff --git a/.github/workflows/install.yaml b/.github/workflows/install.yaml index 167156917..58e867fa8 100644 --- a/.github/workflows/install.yaml +++ b/.github/workflows/install.yaml @@ -3,10 +3,12 @@ name: Installation on: push: branches: + - dev - main - release/* pull_request: branches: + - dev - main - release/* workflow_dispatch: {} diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml index 4a85fa0de..213fe6ef9 100644 --- a/.github/workflows/linting.yaml +++ b/.github/workflows/linting.yaml @@ -3,10 +3,12 @@ name: Linting on: push: branches: + - dev - main - release/* pull_request: branches: + - dev - main - release/* workflow_call: diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index 8cb994a89..cab9c9448 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -3,10 +3,12 @@ name: Test on: push: branches: + - dev - main - release/* pull_request: branches: + - dev - main - release/* workflow_call: diff --git a/examples/multimodal/webvid/read.py b/examples/multimodal/webvid/read.py index d3f74c2d9..89f7525b2 100644 --- a/examples/multimodal/webvid/read.py +++ b/examples/multimodal/webvid/read.py @@ -5,7 +5,7 @@ import os from time import sleep -from typing import Any, Optional +from typing import Any, Dict, Optional from streaming import StreamingDataset from streaming.dataset import TICK, _Iterator @@ -18,58 +18,12 @@ class StreamingInsideWebVid(StreamingDataset): Videos are stored "inside" the shards, as is typically done. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. 
- Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. + **kwargs (Dict[str, Any]): Keyword arguments. """ + def __init__(self, **kwargs: Dict[str, Any]) -> None: + super().__init__(**kwargs) + def get_item(self, idx: int) -> Any: """Get the sample at the index. @@ -91,101 +45,18 @@ class StreamingOutsideGIWebVid(StreamingDataset): get_item ("GI"), when samples are requested by the dataloader. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. 
- keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. extra_local (str, optional): Base destination of extra local sample downloads. extra_remote (str, optional): Base source of extra remote sample downloads. + **kwargs (Dict[str, Any]): Keyword arguments. 
""" def __init__(self, - *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - cache_limit: Optional[int] = None, - partition_algo: str = 'orig', - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, extra_local: Optional[str] = None, - extra_remote: Optional[str] = None) -> None: - super().__init__(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) - + extra_remote: Optional[str] = None, + **kwargs: Dict[str, Any]) -> None: + super().__init__(**kwargs) # Videos are stored outside of their shards here. - self.download_timeout = download_timeout + self.download_timeout = self.streams[0].download_timeout self.extra_local = extra_local self.extra_remote = extra_remote @@ -222,101 +93,18 @@ class StreamingOutsideDTWebVid(StreamingDataset): _download_thread ("DT"), when the download thread prefetches the sample. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. 
- If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. extra_local (str, optional): Base destination of extra local sample downloads. extra_remote (str, optional): Base source of extra remote sample downloads. + **kwargs (Dict[str, Any]): Keyword arguments. """ def __init__(self, - *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - cache_limit: Optional[int] = None, - partition_algo: str = 'orig', - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, extra_local: Optional[str] = None, - extra_remote: Optional[str] = None) -> None: - super().__init__(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) - + extra_remote: Optional[str] = None, + **kwargs: Dict[str, Any]) -> None: + super().__init__(**kwargs) # Videos are stored outside of their shards here. 
- self.download_timeout = download_timeout + self.download_timeout = self.streams[0].download_timeout self.extra_local = extra_local self.extra_remote = extra_remote diff --git a/examples/text/c4/read.py b/examples/text/c4/read.py index d30340f97..1cb407307 100644 --- a/examples/text/c4/read.py +++ b/examples/text/c4/read.py @@ -7,7 +7,7 @@ the `Common Crawl `_ dataset. """ -from typing import Any, Dict, Optional +from typing import Any, Dict from transformers.models.auto.tokenization_auto import AutoTokenizer @@ -20,104 +20,19 @@ class StreamingC4(StreamingDataset): """Implementation of the C4 (Colossal Cleaned Common Crawl) dataset using StreamingDataset. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. 
note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. tokenizer_name (str): The name of the HuggingFace tokenizer to use to tokenize samples. max_seq_len (int): The max sequence length of each token sample. group_method (str): How to group text samples into token samples. Currently only supporting ``'truncate'``. + **kwargs (Dict[str, Any]): Keyword arguments. """ - def __init__(self, - *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - cache_limit: Optional[int] = None, - partition_algo: str = 'orig', - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, - tokenizer_name: str, - max_seq_len: int, - group_method: str) -> None: + def __init__(self, *, tokenizer_name: str, max_seq_len: int, group_method: str, + **kwargs: Dict[str, Any]) -> None: if group_method not in {'truncate'}: raise ValueError(f"group_method='{group_method}' must be one of {'truncate'}.") - super().__init__(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) + super().__init__(**kwargs) self.tokenizer_name = tokenizer_name self.max_seq_len = max_seq_len diff --git a/examples/text/enwiki_txt/read.py b/examples/text/enwiki_txt/read.py index 4385e7394..9e1c11b09 100644 --- a/examples/text/enwiki_txt/read.py +++ b/examples/text/enwiki_txt/read.py @@ -3,7 +3,7 @@ """English Wikipedia 2020-01-01 streaming dataset.""" -from typing import Any, Optional +from typing import Any, Dict import numpy as np @@ -16,94 +16,11 @@ class StreamingEnWiki(StreamingDataset): """Implementation of the English Wikipedia 2020-01-01 streaming dataset. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. 
- download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. + **kwargs (Dict[str, Any]): Keyword arguments. 
""" - def __init__(self, - *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - cache_limit: Optional[int] = None, - partition_algo: str = 'orig', - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18) -> None: - super().__init__(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) + def __init__(self, **kwargs: Dict[str, Any]) -> None: + super().__init__(**kwargs) self.field_dtypes = { 'input_ids': np.int32, 'input_mask': np.int32, diff --git a/examples/text/pile/read.py b/examples/text/pile/read.py index 58c4afc68..440517c3f 100644 --- a/examples/text/pile/read.py +++ b/examples/text/pile/read.py @@ -7,7 +7,7 @@ high-quality datasets combined together. """ -from typing import Any, Dict, Optional +from typing import Any, Dict from transformers.models.auto.tokenization_auto import AutoTokenizer @@ -20,104 +20,19 @@ class StreamingPile(StreamingDataset): """Implementation of the the Pile using StreamingDataset. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. 
- If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. tokenizer_name (str): The name of the HuggingFace tokenizer to use to tokenize samples. max_seq_len (int): The max sequence length of each token sample. group_method (str): How to group text samples into token samples. Currently only supporting ``'truncate'``. + **kwargs (Dict[str, Any]): Keyword arguments. 
""" - def __init__(self, - *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - cache_limit: Optional[int] = None, - partition_algo: str = 'orig', - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, - tokenizer_name: str, - max_seq_len: int, - group_method: str) -> None: - if group_method not in ['truncate']: - raise ValueError(f'Only group_method="truncate" is supported at this time.') - - super().__init__(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) + def __init__(self, *, tokenizer_name: str, max_seq_len: int, group_method: str, + **kwargs: Dict[str, Any]) -> None: + if group_method not in {'truncate'}: + raise ValueError(f"group_method='{group_method}' must be one of {'truncate'}.") + + super().__init__(**kwargs) self.tokenizer_name = tokenizer_name self.max_seq_len = max_seq_len @@ -140,9 +55,7 @@ def _tokenize(self, text_sample: Dict[str, Any]): padding = 'max_length' max_length = self.max_seq_len else: - truncation = False - padding = False - max_length = None + raise ValueError(f'Got unknown group_method={self.group_method}.') return self.tokenizer(text_sample['text'], truncation=truncation, padding=padding, diff --git a/examples/vision/ade20k/read.py b/examples/vision/ade20k/read.py index f04fc423f..64fcdcf4e 100644 --- a/examples/vision/ade20k/read.py +++ b/examples/vision/ade20k/read.py @@ -7,7 +7,7 @@ more details about this dataset. """ -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple from streaming import StreamingDataset @@ -18,103 +18,22 @@ class StreamingADE20K(StreamingDataset): """Implementation of the ADE20K dataset using StreamingDataset. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. 
If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. - joint_transform (callable, optional): A function/transforms that takes in an image and a + joint_transform (Callable, optional): A function/transforms that takes in an image and a target and returns the transformed versions of both. Defaults to ``None``. - transform (callable, optional): A function/transform that takes in an image and returns a + transform (Callable, optional): A function/transform that takes in an image and returns a transformed version. Defaults to ``None``. - target_transform (callable, optional): A function/transform that takes in the target and + target_transform (Callable, optional): A function/transform that takes in the target and transforms it. Defaults to ``None``. + **kwargs (Dict[str, Any]): Keyword arguments. 
""" def __init__(self, *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - partition_algo: str = 'orig', - cache_limit: Optional[int] = None, - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, joint_transform: Optional[Callable] = None, transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None) -> None: - super().__init__(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) + target_transform: Optional[Callable] = None, + **kwargs: Dict[str, Any]) -> None: + super().__init__(**kwargs) self.joint_transform = joint_transform self.transform = transform self.target_transform = target_transform diff --git a/examples/vision/cifar10/read.py b/examples/vision/cifar10/read.py index c2f97d8ee..b81c29514 100644 --- a/examples/vision/cifar10/read.py +++ b/examples/vision/cifar10/read.py @@ -7,6 +7,8 @@ `CIFAR-10 Dataset `_ for more details. """ +from typing import Any, Dict + from streaming.vision import StreamingVisionDataset __all__ = ['StreamingCIFAR10'] @@ -16,58 +18,8 @@ class StreamingCIFAR10(StreamingVisionDataset): """Implementation of the CIFAR-10 dataset using StreamingDataset. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. 
Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. - transform (callable, optional): A function/transform that takes in an image and returns a - transformed version. Defaults to ``None``. - target_transform (callable, optional): A function/transform that takes in a target and - returns a transformed version. Defaults to ``None``. + **kwargs (Dict[str, Any]): Keyword arguments. """ + + def __init__(self, **kwargs: Dict[str, Any]) -> None: + super().__init__(**kwargs) diff --git a/examples/vision/coco/read.py b/examples/vision/coco/read.py index a9622eab3..b18ebea79 100644 --- a/examples/vision/coco/read.py +++ b/examples/vision/coco/read.py @@ -7,7 +7,7 @@ `COCO dataset `_ for more details. """ -from typing import Any, Callable, Optional +from typing import Any, Callable, Dict, Optional from streaming import StreamingDataset @@ -18,97 +18,13 @@ class StreamingCOCO(StreamingDataset): """Implementation of the COCO dataset using StreamingDataset. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. 
Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. - transform (callable, optional): A function/transform that takes in an image and bboxes and - returns a transformed version. Defaults to ``None``. + transform (Callable, optional): A function/transform that takes in an image and returns a + transformed version. Defaults to ``None``. + **kwargs (Dict[str, Any]): Keyword arguments. 
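A hedged usage sketch for the refactored vision readers: the import path, bucket name, and identity transform below are assumptions made for illustration, not taken from this patch.

.. code-block:: python

    from examples.vision.coco.read import StreamingCOCO  # assumed import path


    def identity(img):
        """Identity transform, used purely for illustration."""
        return img


    dataset = StreamingCOCO(transform=identity,
                            remote='s3://my-bucket/mds/coco',  # illustrative remote
                            local='/tmp/coco',
                            shuffle=True,
                            batch_size=16)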
""" - def __init__(self, - *, - remote: Optional[str] = None, - local: Optional[str] = None, - split: Optional[str] = None, - download_retry: int = 2, - download_timeout: float = 60, - validate_hash: Optional[str] = None, - keep_zip: bool = False, - epoch_size: Optional[int] = None, - predownload: Optional[int] = None, - partition_algo: str = 'orig', - cache_limit: Optional[int] = None, - num_canonical_nodes: Optional[int] = None, - batch_size: Optional[int] = None, - shuffle: bool = False, - shuffle_algo: str = 'py1s', - shuffle_seed: int = 9176, - shuffle_block_size: int = 1 << 18, - transform: Optional[Callable] = None) -> None: - super().__init__(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip, - epoch_size=epoch_size, - predownload=predownload, - cache_limit=cache_limit, - partition_algo=partition_algo, - num_canonical_nodes=num_canonical_nodes, - batch_size=batch_size, - shuffle=shuffle, - shuffle_algo=shuffle_algo, - shuffle_seed=shuffle_seed, - shuffle_block_size=shuffle_block_size) + def __init__(self, *, transform: Optional[Callable] = None, **kwargs: Dict[str, Any]) -> None: + super().__init__(**kwargs) self.transform = transform def get_item(self, idx: int) -> Any: diff --git a/examples/vision/imagenet/read.py b/examples/vision/imagenet/read.py index ab993fe36..998e35052 100644 --- a/examples/vision/imagenet/read.py +++ b/examples/vision/imagenet/read.py @@ -7,6 +7,8 @@ 2012 Classification Dataset `_ for more details. """ +from typing import Any, Dict + from streaming.vision import StreamingVisionDataset __all__ = ['StreamingImageNet'] @@ -16,58 +18,8 @@ class StreamingImageNet(StreamingVisionDataset): """Implementation of the ImageNet dataset using StreamingDataset. Args: - remote (str, optional): Remote path or directory to download the dataset from. If ``None``, - its data must exist locally. StreamingDataset uses either ``streams`` or - ``remote``/``local``. Defaults to ``None``. - local (str, optional): Local working directory to download shards to. This is where shards - are cached while they are being used. Uses a temp directory if not set. - StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``. - split (str, optional): Which dataset split to use, if any. If provided, we stream from/to - the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``. - download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``. - download_timeout (float): Number of seconds to wait for a shard to download before raising - an exception. Defaults to ``60``. - validate_hash (str, optional): Optional hash or checksum algorithm to use to validate - shards. Defaults to ``None``. - keep_zip (bool): Whether to keep or delete the compressed form when decompressing - downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to - ``False``. - epoch_size (int, optional): Number of samples to draw per epoch balanced across all - streams. If ``None``, takes its value from the total number of underlying samples. - Provide this field if you are weighting streams relatively to target a larger or - smaller epoch size. Defaults to ``None``. - predownload (int, optional): Target number of samples to download per worker in advance - of current sample. Workers will attempt to download ahead by this many samples during, - but not before, training. 
Recommendation is to provide a value greater than per device - batch size to ensure at-least per device batch size number of samples cached locally. - If ``None``, its value gets derived using per device batch size and number of - canonical nodes ``max(batch_size, 256 * batch_size // num_canonical_nodes)``. - Defaults to ``None``. - cache_limit (int, optional): Maximum size in bytes of this StreamingDataset's shard cache. - Before downloading a shard, the least recently used resident shard(s) may be evicted - (deleted from the local cache) in order to stay under the limit. Set to ``None`` to - disable shard eviction. Defaults to ``None``. - partition_algo (str): Which partitioning algorithm to use. Defaults to ``orig``. - num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with - resumption. The sample space is divided evenly according to the number of canonical - nodes. The higher the value, the more independent non-overlapping paths the - StreamingDataset replicas take through the shards per model replica (increasing data - source diversity). Defaults to ``None``, which is interpreted as 64 times the number - of nodes of the initial run. - - .. note:: - - For sequential sample ordering, set ``shuffle`` to ``False`` and - ``num_canonical_nodes`` to the number of physical nodes of the initial run. - batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is - partitioned over the workers. Defaults to ``None``. - shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to - ``False``. - shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1s``. - shuffle_seed (int): Seed for Deterministic data shuffling. Defaults to ``9176``. - shuffle_block_size (int): Unit of shuffle. Defaults to ``1 << 18``. - transform (callable, optional): A function/transform that takes in an image and returns a - transformed version. Defaults to ``None``. - target_transform (callable, optional): A function/transform that takes in a target and - returns a transformed version. Defaults to ``None``. + **kwargs (Dict[str, Any]): Keyword arguments. """ + + def __init__(self, **kwargs: Dict[str, Any]) -> None: + super().__init__(**kwargs) diff --git a/streaming/vision.py b/streaming/vision.py index bfc8a2800..5a1bb1e04 100644 --- a/streaming/vision.py +++ b/streaming/vision.py @@ -55,11 +55,11 @@ class StreamingVisionDataset(StreamingDataset, VisionDataset): """A streaming, iterable, torchvision VisionDataset. Args: - transforms (callable, optional): A function/transforms that takes in an image and a label + transforms (Callable, optional): A function/transforms that takes in an image and a label and returns the transformed versions of both. Defaults to ``None``. - transform (callable, optional): A function/transform that takes in an image and returns a + transform (Callable, optional): A function/transform that takes in an image and returns a transformed version. Defaults to ``None``. - target_transform (callable, optional): A function/transform that takes in a target and + target_transform (Callable, optional): A function/transform that takes in a target and returns a transformed version. Defaults to ``None``. **kwargs (Dict[str, Any]): Keyword arguments. """ From bf81f6b21122a1f4ccb127f527b6280de8fa4285 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Mon, 11 Dec 2023 17:13:04 -0800 Subject: [PATCH 03/12] Organize utils. (#524) * Break up util.py. 
* Update streaming/util/importing.py Co-authored-by: Karan Jariwala * Update streaming/util/importing.py Co-authored-by: Karan Jariwala * Add basic import redirect test. --------- Co-authored-by: Karan Jariwala --- streaming/storage/__init__.py | 4 +- streaming/storage/download.py | 34 ++- streaming/stream.py | 4 +- streaming/util.py | 551 ---------------------------------- streaming/util/__init__.py | 15 + streaming/util/importing.py | 45 +++ streaming/util/merging.py | 236 +++++++++++++++ streaming/util/retrying.py | 109 +++++++ streaming/util/shared.py | 50 +++ streaming/util/shorthand.py | 115 +++++++ tests/test_importing.py | 8 + 11 files changed, 605 insertions(+), 566 deletions(-) delete mode 100644 streaming/util.py create mode 100644 streaming/util/__init__.py create mode 100644 streaming/util/importing.py create mode 100644 streaming/util/merging.py create mode 100644 streaming/util/retrying.py create mode 100644 streaming/util/shared.py create mode 100644 streaming/util/shorthand.py create mode 100644 tests/test_importing.py diff --git a/streaming/storage/__init__.py b/streaming/storage/__init__.py index 674d4fbad..5d6d599e0 100644 --- a/streaming/storage/__init__.py +++ b/streaming/storage/__init__.py @@ -7,7 +7,8 @@ download_from_azure_datalake, download_from_databricks_unity_catalog, download_from_dbfs, download_from_gcs, download_from_local, download_from_oci, - download_from_s3, download_from_sftp) + download_from_s3, download_from_sftp, + wait_for_file_to_exist) from streaming.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, GCSUploader, LocalUploader, OCIUploader, S3Uploader) @@ -29,4 +30,5 @@ 'download_from_databricks_unity_catalog', 'download_from_dbfs', 'download_from_local', + 'wait_for_file_to_exist', ] diff --git a/streaming/storage/download.py b/streaming/storage/download.py index edb88943c..51a1b4e16 100644 --- a/streaming/storage/download.py +++ b/streaming/storage/download.py @@ -22,6 +22,7 @@ 'download_from_databricks_unity_catalog', 'download_from_dbfs', 'download_from_local', + 'wait_for_file_to_exist', ] BOTOCORE_CLIENT_ERROR_CODES = {'403', '404', 'NoSuchKey'} @@ -474,18 +475,27 @@ def download_file(remote: Optional[str], local: str, timeout: float): download_from_local(remote, local) -def wait_for_download(local: str, timeout: float = 60) -> None: - """Wait for a download by another thread/process to complete. +def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float, + err_msg: str) -> None: + """Wait for the file to exist till timeout seconds. Raise an Exception after that. Args: - local (str): Local path (local filesystem). - timeout (float): How long to wait for file to download before raising an exception. - Defaults to ``60``. + filename (str): A file name + poll_interval (float): Number of seconds to wait before next polling + timeout (float): Number of seconds to wait for a file to exist before raising an exception + err_msg (str): Error message description for an exception + + Raises: + RuntimeError: Raise an Exception if file does not exist after timeout """ - t0 = time() - while not os.path.exists(local): - elapsed = time() - t0 - if timeout < elapsed: - raise TimeoutError( - f'Waited longer than {timeout}s for other worker to download {local}.') - sleep(0.25) + start_time = time() + while True: + sleep(poll_interval) + if os.path.exists(filename): + sleep(poll_interval) + break + dt = time() - start_time + if dt > timeout: + raise RuntimeError( + f'{err_msg} due to timeout. 
Waited {dt:.3f} sec, which is longer than the ' + + f'timeout limit of {timeout:.3f} sec.') diff --git a/streaming/stream.py b/streaming/stream.py index 200ba83e5..7948e9b65 100644 --- a/streaming/stream.py +++ b/streaming/stream.py @@ -18,8 +18,8 @@ from streaming.distributed import barrier, get_local_rank from streaming.format import FileInfo, Reader, get_index_basename, reader_from_json from streaming.hashing import get_hash -from streaming.storage import download_file -from streaming.util import retry, wait_for_file_to_exist +from streaming.storage import download_file, wait_for_file_to_exist +from streaming.util import retry from streaming.world import World diff --git a/streaming/util.py b/streaming/util.py deleted file mode 100644 index a9c1a0bab..000000000 --- a/streaming/util.py +++ /dev/null @@ -1,551 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Utility and helper functions for datasets.""" - -import collections.abc -import functools -import inspect -import json -import logging -import os -import random -import shutil -import sys -import tempfile -import urllib.parse -from collections import OrderedDict -from importlib import import_module -from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory -from pathlib import Path -from time import sleep, time -from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, cast, overload -from warnings import warn - -import torch.distributed as dist - -from streaming.constant import SHM_TO_CLEAN -from streaming.distributed import get_local_rank, maybe_init_dist -from streaming.format.index import get_index_basename -from streaming.shared.prefix import _get_path - -logger = logging.getLogger(__name__) - -TCallable = TypeVar('TCallable', bound=Callable) - -__all__ = [ - 'get_list_arg', 'wait_for_file_to_exist', 'bytes_to_int', 'number_abbrev_to_int', - 'clean_stale_shared_memory', 'get_import_exception_message', 'merge_index', 'retry', - 'redirect_imports' -] - - -def get_list_arg(text: str) -> List[str]: - """Pass a list as a command-line flag. - - Args: - text (str): Text to split. - - Returns: - List[str]: Splits, if any. - """ - return text.split(',') if text else [] - - -def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float, - err_msg: str) -> None: - """Wait for the file to exist till timeout seconds. Raise an Exception after that. - - Args: - filename (str): A file name - poll_interval (float): Number of seconds to wait before next polling - timeout (float): Number of seconds to wait for a file to exist before raising an exception - err_msg (str): Error message description for an exception - - Raises: - RuntimeError: Raise an Exception if file does not exist after timeout - """ - start_time = time() - while True: - sleep(poll_interval) - if os.path.exists(filename): - sleep(poll_interval) - break - dt = time() - start_time - if dt > timeout: - raise RuntimeError(f'{err_msg}' + f'{timeout:.3f} < {dt:.3f} secs.') - - -def bytes_to_int(bytes_str: Union[int, str]) -> int: - """Convert human readable byte format to an integer. - - Args: - bytes_str (Union[int, str]): Value to convert. - - Raises: - ValueError: Invalid byte suffix. - - Returns: - int: Integer value of bytes. 
- """ - #input is already an int - if isinstance(bytes_str, int) or isinstance(bytes_str, float): - return int(bytes_str) - - units = { - 'kb': 1024, - 'mb': 1024**2, - 'gb': 1024**3, - 'tb': 1024**4, - 'pb': 1024**5, - 'eb': 1024**6, - 'zb': 1024**7, - 'yb': 1024**8, - } - # Convert a various byte types to an integer - for suffix in units: - bytes_str = bytes_str.lower().strip() - if bytes_str.lower().endswith(suffix): - try: - return int(float(bytes_str[0:-len(suffix)]) * units[suffix]) - except ValueError: - raise ValueError(''.join([ - f'Unsupported value/suffix {bytes_str}. Supported suffix are ', - f'{["b"] + list(units.keys())}.' - ])) - else: - # Convert bytes to an integer - if bytes_str.endswith('b') and bytes_str[0:-1].isdigit(): - return int(bytes_str[0:-1]) - # Convert string representation of a number to an integer - elif bytes_str.isdigit(): - return int(bytes_str) - else: - raise ValueError(''.join([ - f'Unsupported value/suffix {bytes_str}. Supported suffix are ', - f'{["b"] + list(units.keys())}.' - ])) - - -def number_abbrev_to_int(abbrev_str: Union[int, str]) -> int: - """Convert human readable number abbreviations to an integer. - - Args: - abbrev_str (Union[int, str]): Value to convert. - - Raises: - ValueError: Invalid number suffix. - - Returns: - int: Integer value of number abbreviation. - """ - #input is already an int - if isinstance(abbrev_str, int) or isinstance(abbrev_str, float): - return int(abbrev_str) - - units = { - 'k': 10**3, - 'm': 10**6, - 'b': 10**9, - 't': 10**12, - } - # Convert a various abbreviation types to an integer - for suffix in units: - abbrev_str = abbrev_str.lower().strip() - if abbrev_str.lower().endswith(suffix): - try: - return int(float(abbrev_str[0:-len(suffix)]) * units[suffix]) - except ValueError: - raise ValueError(''.join([ - f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', - f'{list(units.keys())}.' - ])) - else: - # Convert string representation of a number to an integer - if abbrev_str.isdigit(): - return int(abbrev_str) - else: - raise ValueError(''.join([ - f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', - f'{list(units.keys())}.' - ])) - - -def clean_stale_shared_memory() -> None: - """Clean up all the leaked shared memory. - - In case of a distributed run, clean up happens on local rank 0 while other local ranks wait for - the local rank 0 to finish. - """ - # Initialize torch.distributed ourselves, if necessary. - destroy_dist = maybe_init_dist() - - # Perform clean up on local rank 0 - if get_local_rank() == 0: - for prefix_int in range(1000000): - leaked_shm = False - for shm_name in SHM_TO_CLEAN: - name = _get_path(prefix_int, shm_name) - try: - shm = BuiltinSharedMemory(name, True, 4) - except FileExistsError: - shm = BuiltinSharedMemory(name, False, 4) - leaked_shm = True - finally: - shm.close() # pyright: ignore - shm.unlink() - # Come out of loop if no leaked shared memory - if not leaked_shm: - break - - # Sync all ranks - if dist.is_available() and dist.is_initialized(): - dist.barrier() - - # Delete the process group if Streaming initialized it. - if destroy_dist: - dist.destroy_process_group() - - -def get_import_exception_message(package_name: str, extra_deps: str) -> str: - """Get import exception message. - - Args: - package_name (str): Package name. - - Returns: - str: Exception message. - """ - return f'Streaming was installed without {package_name} support. 
' + \ - f'To use {package_name} related packages with Streaming, run ' + \ - f'`pip install \'mosaicml-streaming[{package_name}]\'`.' - - -def merge_index(*args: Any, **kwargs: Any): - r"""Merge index.json from partitions to form a global index.json. - - This can be called as - - merge_index(index_file_urls, out, keep_local, download_timeout) - - merge_index(out, keep_local, download_timeout) - - The first signature takes in a list of index files URLs of MDS partitions. - The second takes the root of a MDS dataset and parse the partition folders from there. - - Args: - index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the partitions. - Each element can take the form of a single path string or a tuple string. - - 1. If ``index_file_urls`` is a List of local URLs, merge locally without download. - 2. If ``index_file_urls`` is a List of tuple (local, remote) URLs, check if local index.json are missing, download before merging. - 3. If ``index_file_urls`` is a List of remote URLs, download all and merge. - - out (Union[str, Tuple[str,str]]): folder that contain MDS partitions and to put the merged index file - - 1. A local directory, merge index happens locally. - 2. A remote directory, download all the sub-directories index.json, merge locally and upload. - 3. A tuple (local_dir, remote_dir), check if local index.json exist, download if not. - - keep_local (bool): Keep local copy of the merged index file. Defaults to ``True``. - download_timeout (int): The allowed time for downloading each json file. Defaults to 60. - """ - if isinstance(args[0], list) and len(args) + len(kwargs) in [2, 3, 4]: - return _merge_index_from_list(*args, **kwargs) - elif (isinstance(args[0], str) or - isinstance(args[0], tuple)) and len(args) + len(kwargs) in [1, 2, 3]: - return _merge_index_from_root(*args, **kwargs) - raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}') - - -def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], - out: Union[str, Tuple[str, str]], - keep_local: bool = True, - download_timeout: int = 60) -> None: - """Merge index.json from a list of index files of MDS directories to create joined index. - - Args: - index_file_urls (Union[str, Tuple[str,str]]): index.json from all the partitions - each element can take the form of a single path string or a tuple string. - - The pattern of index_file_urls and corresponding reaction is one of: - 1. All URLS are str (local). All URLS are accessible locally -> no download - 2. All URLS are tuple (local, remote). All URLS are accessible locally -> no download - 3. All URLS are tuple (local, remote). Download URL that is not accessible locally - 4. All URLS are str (remote) -> download all - - out (Union[str, Tuple[str, str]]): path to put the merged index file - keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` - download_timeout (int): The allowed time for downloading each json file. Defaults to 60. - """ - from streaming.storage.download import download_file - from streaming.storage.upload import CloudUploader - - if not index_file_urls or not out: - logger.warning('Either index_file_urls or out are None. ' + - 'Need to specify both `index_file_urls` and `out`. 
' + 'No index merged') - return - - # This is the index json file name, e.g., it is index.json as of 0.6.0 - index_basename = get_index_basename() - - cu = CloudUploader.get(out, keep_local=True, exist_ok=True) - - # Remove duplicates, and strip '/' from right if any - index_file_urls = list(OrderedDict.fromkeys(index_file_urls)) - urls = [] - for url in index_file_urls: - if isinstance(url, str): - urls.append(url.rstrip('/').strip()) - else: - urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip())) - - # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. - with tempfile.TemporaryDirectory() as temp_root: - logging.warning(f'A temporary folder {temp_root} is created to store index files') - - # Copy files to a temporary directory. Download if necessary - partitions = [] - for url in urls: - if isinstance(url, tuple): - src = url[0] if os.path.exists(url[0]) else url[1] - else: - src = url - - obj = urllib.parse.urlparse(src) - scheme, bucket, path = obj.scheme, obj.netloc, obj.path - if scheme == '' and bucket == '' and path == '': - raise FileNotFoundError( - f'Check data availability! local index {url[0]} is not accessible.' + - f'remote index {url[1]} does not have a valid url format') - dest = os.path.join(temp_root, path.lstrip('/')) - - try: - download_file(src, dest, download_timeout) - except Exception as ex: - raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex - - if not os.path.exists(dest): - raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.') - - partitions.append(dest) - - # merge shards from all index files - shards = [] - for partition_index in partitions: - p = Path(partition_index) - obj = json.load(open(partition_index)) - for i in range(len(obj['shards'])): - shard = obj['shards'][i] - for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): - if shard.get(key): - basename = shard[key]['basename'] - obj['shards'][i][key]['basename'] = os.path.join( - os.path.basename(p.parent), basename) - shards += obj['shards'] - - # Save merged index locally - obj = { - 'version': 2, - 'shards': shards, - } - merged_index_path = os.path.join(temp_root, index_basename) - with open(merged_index_path, 'w') as outfile: - json.dump(obj, outfile) - - # Move merged index from temp path to local part in out - # Upload merged index to remote if out has remote part - shutil.move(merged_index_path, cu.local) - if cu.remote is not None: - cu.upload_file(index_basename) - - # Clean up - if not keep_local: - shutil.rmtree(cu.local, ignore_errors=True) - - -def _merge_index_from_root(out: Union[str, Tuple[str, str]], - keep_local: bool = True, - download_timeout: int = 60) -> None: - """Merge index.json given the root of MDS dataset. Write merged index to the root folder. - - Args: - out (Union[str, Tuple[str,str]]): folder that contain MDS partitions. - :A local directory, merge index happens locally - :A remote directory, download all the sub-directories index.json in a temporary - sub-directories, merge locally, and then upload it to out location - :A (local_dir, remote_dir), check if sub-directories index.json file present locally - If yes, then merge locally and upload to remote_dir . - If not, download all the sub-directories index.json from remote to local, - merge locally, and upload to remote_dir . - keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` - download_timeout (int): The allowed time for downloading each json file. Defaults to 60. 
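To make the two calling conventions of ``merge_index`` concrete, here is a hedged usage sketch; the bucket and folder names are illustrative only.

.. code-block:: python

    from streaming.util import merge_index

    # Form 1: pass the per-partition index.json URLs explicitly.
    index_files = [
        's3://my-bucket/mds/group-0/index.json',
        's3://my-bucket/mds/group-1/index.json',
    ]
    merge_index(index_files, 's3://my-bucket/mds', keep_local=False, download_timeout=120)

    # Form 2: pass the dataset root (here as a (local, remote) tuple) and let
    # merge_index discover the partition folders itself.
    merge_index(('/tmp/mds', 's3://my-bucket/mds'), keep_local=True)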
- """ - from streaming.storage.upload import CloudUploader - - def not_merged_index(index_file_path: str, out: str): - """Check if index_file_path is the merged index at folder out. - - Args: - index_file_path (str): the path to index.json file - out (str): remote or local url of a folder - Return: - (bool): no if index.json sits in out instead of in the subfolders of out - """ - prefix = str(urllib.parse.urlparse(out).path) - return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') - - if not out: - logger.warning('No MDS dataset folder specified, no index merged') - return - - cu = CloudUploader.get(out, exist_ok=True, keep_local=True) - - local_index_files = [] - cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True) - for file in cl.list_objects(): - if file.endswith('.json') and not_merged_index(file, cu.local): - local_index_files.append(file) - - if cu.remote: - obj = urllib.parse.urlparse(cu.remote) - remote_index_files = [] - for file in cu.list_objects(): - if file.endswith(get_index_basename()) and not_merged_index(file, cu.remote): - join_char = '//' - if obj.scheme == 'dbfs': - path = Path(cu.remote) - prefix = os.path.join(path.parts[0], path.parts[1]) - if prefix == 'dbfs:/Volumes': - join_char = '/' - remote_index_files.append(obj.scheme + join_char + os.path.join(obj.netloc, file)) - if len(local_index_files) == len(remote_index_files): - _merge_index_from_list(list(zip(local_index_files, remote_index_files)), - out, - keep_local=keep_local, - download_timeout=download_timeout) - else: - _merge_index_from_list(remote_index_files, - out, - keep_local=keep_local, - download_timeout=download_timeout) - return - - _merge_index_from_list(local_index_files, - out, - keep_local=keep_local, - download_timeout=download_timeout) - - -@overload -def retry( - exc_class: Union[Type[Exception], Sequence[Type[Exception]]] = ..., - num_attempts: int = ..., - initial_backoff: float = ..., - max_jitter: float = ..., -) -> Callable[[TCallable], TCallable]: - ... - - -@overload -def retry(exc_class: TCallable) -> TCallable: - # Use the decorator without parenthesis - ... - - -# error: Type "(TCallable@retry) -> TCallable@retry" cannot be assigned to type -# "(func: Never) -> Never" -def retry( # type: ignore - exc_class: Union[TCallable, Type[Exception], Sequence[Type[Exception]]] = Exception, - num_attempts: int = 3, - initial_backoff: float = 1.0, - max_jitter: float = 0.5, -): - """Decorator to retry a function with backoff and jitter. - - Attempts are spaced out with - ``initial_backoff * 2**num_attempts + random.random() * max_jitter`` seconds. - - Example: - .. testcode:: - - from streaming.util import retry - - num_tries = 0 - - @retry(RuntimeError, num_attempts=3, initial_backoff=0.1) - def flaky_function(): - global num_tries - if num_tries < 2: - num_tries += 1 - raise RuntimeError("Called too soon!") - return "Third time's a charm." - - print(flaky_function()) - - .. testoutput:: - - Third time's a charm. - - Args: - exc_class (Type[Exception] | Sequence[Type[Exception]]], optional): The exception class or - classes to retry. Defaults to Exception. - num_attempts (int, optional): The total number of attempts to make. Defaults to 3. - initial_backoff (float, optional): The initial backoff, in seconds. Defaults to 1.0. - max_jitter (float, optional): The maximum amount of random jitter to add. Defaults to 0.5. 
- - Increasing the ``max_jitter`` can help prevent overloading a resource when multiple - processes in parallel are calling the same underlying function. - """ - if num_attempts < 1: - raise ValueError('num_attempts must be at-least 1') - - def wrapped_func(func: TCallable) -> TCallable: - - @functools.wraps(func) - def new_func(*args: Any, **kwargs: Any): - i = 0 - while True: - try: - return func(*args, **kwargs) - except exc_class as e: - if i + 1 == num_attempts: - logger.debug(f'Attempt {i + 1}/{num_attempts} failed with: {e}') - raise e - else: - sleep(initial_backoff * 2**i + random.random() * max_jitter) - logger.debug(f'Attempt {i + 1}/{num_attempts} failed with: {e}') - i += 1 - - return cast(TCallable, new_func) - - if not isinstance(exc_class, collections.abc.Sequence) and not (isinstance( - exc_class, type) and issubclass(exc_class, Exception)): - # Using the decorator without (), like @retry_with_backoff - func = cast(TCallable, exc_class) - exc_class = Exception - - return wrapped_func(func) - - return wrapped_func - - -def redirect_imports(new_fqdn: str) -> None: - """Overlay the members of the target module onto the module of the caller. - - Args: - new_fqdn (str): Fully-qualified dot-separated target module path. - """ - frame = inspect.stack()[1] - module = inspect.getmodule(frame[0]) - if module is None: - raise RuntimeError('Module was None.') - old_fqdn = module.__name__ - - # old = sys.modules[old_fqdn] - new = import_module(new_fqdn) - sys.modules[old_fqdn].__dict__.update(new.__dict__) - - warn(f'Please update your imports: {old_fqdn} has moved to {new_fqdn}.', - DeprecationWarning, - stacklevel=2) diff --git a/streaming/util/__init__.py b/streaming/util/__init__.py new file mode 100644 index 000000000..55a54e10e --- /dev/null +++ b/streaming/util/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Utilities for streaming.""" + +from streaming.util.importing import get_import_exception_message, redirect_imports +from streaming.util.merging import merge_index +from streaming.util.retrying import retry +from streaming.util.shared import clean_stale_shared_memory +from streaming.util.shorthand import bytes_to_int, get_list_arg, number_abbrev_to_int + +__all__ = [ + 'get_import_exception_message', 'redirect_imports', 'merge_index', 'retry', + 'clean_stale_shared_memory', 'get_list_arg', 'bytes_to_int', 'number_abbrev_to_int' +] diff --git a/streaming/util/importing.py b/streaming/util/importing.py new file mode 100644 index 000000000..56935d990 --- /dev/null +++ b/streaming/util/importing.py @@ -0,0 +1,45 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Utilities for importing.""" + +import inspect +import sys +from importlib import import_module +from warnings import warn + +__all__ = ['get_import_exception_message', 'redirect_imports'] + + +def get_import_exception_message(package_name: str, extra_deps: str) -> str: + """Get import exception message. + + Args: + package_name (str): Package name. + + Returns: + str: Exception message. + """ + return f'Streaming was installed without {package_name} support. ' + \ + f'To use {package_name} related packages with Streaming, run ' + \ + f'`pip install \'mosaicml-streaming[{package_name}]\'`.' + + +def redirect_imports(new_fqdn: str) -> None: + """Overlay the members of the target module onto the module of the caller. + + Args: + new_fqdn (str): Fully-qualified dot-separated target module path. 
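The intended use of ``redirect_imports`` is a one-line legacy shim: the old module overlays the new one and emits a ``DeprecationWarning`` to callers. A sketch under that assumption; the shim module path below is hypothetical.

.. code-block:: python

    # streaming/legacy_util.py -- hypothetical shim kept only for backwards compatibility.
    from streaming.util.importing import redirect_imports

    # Copies the public members of streaming.util onto this module and warns callers
    # to update their imports to the new location.
    redirect_imports('streaming.util')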
+ """ + frame = inspect.stack()[1] + module = inspect.getmodule(frame[0]) + if module is None: + raise RuntimeError('Module was None.') + old_fqdn = module.__name__ + + new = import_module(new_fqdn) + sys.modules[old_fqdn].__dict__.update(new.__dict__) + + warn(f'Please update your imports: {old_fqdn} has moved to {new_fqdn}.', + DeprecationWarning, + stacklevel=2) diff --git a/streaming/util/merging.py b/streaming/util/merging.py new file mode 100644 index 000000000..944eef56b --- /dev/null +++ b/streaming/util/merging.py @@ -0,0 +1,236 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Utilities for merging datasets.""" + +import json +import logging +import os +import shutil +import tempfile +import urllib.parse +from collections import OrderedDict +from pathlib import Path +from typing import Any, List, Tuple, Union + +from streaming.format.index import get_index_basename + +logger = logging.getLogger(__name__) + +__all__ = ['merge_index'] + + +def merge_index(*args: Any, **kwargs: Any): + r"""Merge index.json from partitions to form a global index.json. + + This can be called as + + merge_index(index_file_urls, out, keep_local, download_timeout) + + merge_index(out, keep_local, download_timeout) + + The first signature takes in a list of index files URLs of MDS partitions. + The second takes the root of a MDS dataset and parse the partition folders from there. + + Args: + index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the partitions. + Each element can take the form of a single path string or a tuple string. + + 1. If ``index_file_urls`` is a List of local URLs, merge locally without download. + 2. If ``index_file_urls`` is a List of tuple (local, remote) URLs, check if local index.json are missing, download before merging. + 3. If ``index_file_urls`` is a List of remote URLs, download all and merge. + + out (Union[str, Tuple[str,str]]): folder that contain MDS partitions and to put the merged index file + + 1. A local directory, merge index happens locally. + 2. A remote directory, download all the sub-directories index.json, merge locally and upload. + 3. A tuple (local_dir, remote_dir), check if local index.json exist, download if not. + + keep_local (bool): Keep local copy of the merged index file. Defaults to ``True``. + download_timeout (int): The allowed time for downloading each json file. Defaults to 60. + """ + if isinstance(args[0], list) and len(args) + len(kwargs) in [2, 3, 4]: + return _merge_index_from_list(*args, **kwargs) + elif (isinstance(args[0], str) or + isinstance(args[0], tuple)) and len(args) + len(kwargs) in [1, 2, 3]: + return _merge_index_from_root(*args, **kwargs) + raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}') + + +def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], + out: Union[str, Tuple[str, str]], + keep_local: bool = True, + download_timeout: int = 60) -> None: + """Merge index.json from a list of index files of MDS directories to create joined index. + + Args: + index_file_urls (Union[str, Tuple[str,str]]): index.json from all the partitions + each element can take the form of a single path string or a tuple string. + + The pattern of index_file_urls and corresponding reaction is one of: + 1. All URLS are str (local). All URLS are accessible locally -> no download + 2. All URLS are tuple (local, remote). All URLS are accessible locally -> no download + 3. All URLS are tuple (local, remote). 
Download URL that is not accessible locally + 4. All URLS are str (remote) -> download all + + out (Union[str, Tuple[str, str]]): path to put the merged index file + keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` + download_timeout (int): The allowed time for downloading each json file. Defaults to 60. + """ + from streaming.storage.download import download_file + from streaming.storage.upload import CloudUploader + + if not index_file_urls or not out: + logger.warning('Either index_file_urls or out are None. ' + + 'Need to specify both `index_file_urls` and `out`. ' + 'No index merged') + return + + # This is the index json file name, e.g., it is index.json as of 0.6.0 + index_basename = get_index_basename() + + cu = CloudUploader.get(out, keep_local=True, exist_ok=True) + + # Remove duplicates, and strip '/' from right if any + index_file_urls = list(OrderedDict.fromkeys(index_file_urls)) + urls = [] + for url in index_file_urls: + if isinstance(url, str): + urls.append(url.rstrip('/').strip()) + else: + urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip())) + + # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. + with tempfile.TemporaryDirectory() as temp_root: + logging.warning(f'A temporary folder {temp_root} is created to store index files') + + # Copy files to a temporary directory. Download if necessary + partitions = [] + for url in urls: + if isinstance(url, tuple): + src = url[0] if os.path.exists(url[0]) else url[1] + else: + src = url + + obj = urllib.parse.urlparse(src) + scheme, bucket, path = obj.scheme, obj.netloc, obj.path + if scheme == '' and bucket == '' and path == '': + raise FileNotFoundError( + f'Check data availability! local index {url[0]} is not accessible.' + + f'remote index {url[1]} does not have a valid url format') + dest = os.path.join(temp_root, path.lstrip('/')) + + try: + download_file(src, dest, download_timeout) + except Exception as ex: + raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex + + if not os.path.exists(dest): + raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.') + + partitions.append(dest) + + # merge shards from all index files + shards = [] + for partition_index in partitions: + p = Path(partition_index) + obj = json.load(open(partition_index)) + for i in range(len(obj['shards'])): + shard = obj['shards'][i] + for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): + if shard.get(key): + basename = shard[key]['basename'] + obj['shards'][i][key]['basename'] = os.path.join( + os.path.basename(p.parent), basename) + shards += obj['shards'] + + # Save merged index locally + obj = { + 'version': 2, + 'shards': shards, + } + merged_index_path = os.path.join(temp_root, index_basename) + with open(merged_index_path, 'w') as outfile: + json.dump(obj, outfile) + + # Move merged index from temp path to local part in out + # Upload merged index to remote if out has remote part + shutil.move(merged_index_path, cu.local) + if cu.remote is not None: + cu.upload_file(index_basename) + + # Clean up + if not keep_local: + shutil.rmtree(cu.local, ignore_errors=True) + + +def _merge_index_from_root(out: Union[str, Tuple[str, str]], + keep_local: bool = True, + download_timeout: int = 60) -> None: + """Merge index.json given the root of MDS dataset. Write merged index to the root folder. + + Args: + out (Union[str, Tuple[str,str]]): folder that contain MDS partitions. 
+ :A local directory, merge index happens locally + :A remote directory, download all the sub-directories index.json in a temporary + sub-directories, merge locally, and then upload it to out location + :A (local_dir, remote_dir), check if sub-directories index.json file present locally + If yes, then merge locally and upload to remote_dir . + If not, download all the sub-directories index.json from remote to local, + merge locally, and upload to remote_dir . + keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` + download_timeout (int): The allowed time for downloading each json file. Defaults to 60. + """ + from streaming.storage.upload import CloudUploader + + def not_merged_index(index_file_path: str, out: str): + """Check if index_file_path is the merged index at folder out. + + Args: + index_file_path (str): the path to index.json file + out (str): remote or local url of a folder + Return: + (bool): no if index.json sits in out instead of in the subfolders of out + """ + prefix = str(urllib.parse.urlparse(out).path) + return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') + + if not out: + logger.warning('No MDS dataset folder specified, no index merged') + return + + cu = CloudUploader.get(out, exist_ok=True, keep_local=True) + + local_index_files = [] + cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True) + for file in cl.list_objects(): + if file.endswith('.json') and not_merged_index(file, cu.local): + local_index_files.append(file) + + if cu.remote: + obj = urllib.parse.urlparse(cu.remote) + remote_index_files = [] + for file in cu.list_objects(): + if file.endswith(get_index_basename()) and not_merged_index(file, cu.remote): + join_char = '//' + if obj.scheme == 'dbfs': + path = Path(cu.remote) + prefix = os.path.join(path.parts[0], path.parts[1]) + if prefix == 'dbfs:/Volumes': + join_char = '/' + remote_index_files.append(obj.scheme + join_char + os.path.join(obj.netloc, file)) + if len(local_index_files) == len(remote_index_files): + _merge_index_from_list(list(zip(local_index_files, remote_index_files)), + out, + keep_local=keep_local, + download_timeout=download_timeout) + else: + _merge_index_from_list(remote_index_files, + out, + keep_local=keep_local, + download_timeout=download_timeout) + return + + _merge_index_from_list(local_index_files, + out, + keep_local=keep_local, + download_timeout=download_timeout) diff --git a/streaming/util/retrying.py b/streaming/util/retrying.py new file mode 100644 index 000000000..e2c78a8c6 --- /dev/null +++ b/streaming/util/retrying.py @@ -0,0 +1,109 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Decorator that retries the wrapped function with backoff.""" + +import collections.abc +import functools +import logging +import random +from time import sleep +from typing import Any, Callable, Sequence, Type, TypeVar, Union, cast, overload + +__all__ = ['retry'] + +logger = logging.getLogger(__name__) +TCallable = TypeVar('TCallable', bound=Callable) + + +@overload +def retry( + exc_class: Union[Type[Exception], Sequence[Type[Exception]]] = ..., + num_attempts: int = ..., + initial_backoff: float = ..., + max_jitter: float = ..., +) -> Callable[[TCallable], TCallable]: + ... + + +@overload +def retry(exc_class: TCallable) -> TCallable: + # Use the decorator without parenthesis + ... 
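# A small illustrative sketch of this bare-decorator form (``do_work`` is a hypothetical
# function name): ``@retry`` with no parentheses wraps the function directly, using the
# default arguments of the implementation below, e.g.
#
#     @retry
#     def do_work():
#         ...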
+ + +# error: Type "(TCallable@retry) -> TCallable@retry" cannot be assigned to type +# "(func: Never) -> Never" +def retry( # type: ignore + exc_class: Union[TCallable, Type[Exception], Sequence[Type[Exception]]] = Exception, + num_attempts: int = 3, + initial_backoff: float = 1.0, + max_jitter: float = 0.5, +): + """Decorator to retry a function with backoff and jitter. + + Attempts are spaced out with + ``initial_backoff * 2**num_attempts + random.random() * max_jitter`` seconds. + + Example: + .. testcode:: + + from streaming.util import retry + + num_tries = 0 + + @retry(RuntimeError, num_attempts=3, initial_backoff=0.1) + def flaky_function(): + global num_tries + if num_tries < 2: + num_tries += 1 + raise RuntimeError("Called too soon!") + return "Third time's a charm." + + print(flaky_function()) + + .. testoutput:: + + Third time's a charm. + + Args: + exc_class (Type[Exception] | Sequence[Type[Exception]]], optional): The exception class or + classes to retry. Defaults to Exception. + num_attempts (int, optional): The total number of attempts to make. Defaults to 3. + initial_backoff (float, optional): The initial backoff, in seconds. Defaults to 1.0. + max_jitter (float, optional): The maximum amount of random jitter to add. Defaults to 0.5. + + Increasing the ``max_jitter`` can help prevent overloading a resource when multiple + processes in parallel are calling the same underlying function. + """ + if num_attempts < 1: + raise ValueError('num_attempts must be at-least 1') + + def wrapped_func(func: TCallable) -> TCallable: + + @functools.wraps(func) + def new_func(*args: Any, **kwargs: Any): + i = 0 + while True: + try: + return func(*args, **kwargs) + except exc_class as e: + if i + 1 == num_attempts: + logger.debug(f'Attempt {i + 1}/{num_attempts} failed with: {e}') + raise e + else: + sleep(initial_backoff * 2**i + random.random() * max_jitter) + logger.debug(f'Attempt {i + 1}/{num_attempts} failed with: {e}') + i += 1 + + return cast(TCallable, new_func) + + if not isinstance(exc_class, collections.abc.Sequence) and not (isinstance( + exc_class, type) and issubclass(exc_class, Exception)): + # Using the decorator without (), like @retry_with_backoff + func = cast(TCallable, exc_class) + exc_class = Exception + + return wrapped_func(func) + + return wrapped_func diff --git a/streaming/util/shared.py b/streaming/util/shared.py new file mode 100644 index 000000000..813e485b9 --- /dev/null +++ b/streaming/util/shared.py @@ -0,0 +1,50 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Utilities for shared memory.""" + +from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory + +import torch.distributed as dist + +from streaming.constant import SHM_TO_CLEAN +from streaming.distributed import get_local_rank, maybe_init_dist +from streaming.shared.prefix import _get_path + +__all__ = ['clean_stale_shared_memory'] + + +def clean_stale_shared_memory() -> None: + """Clean up all the leaked shared memory. + + In case of a distributed run, clean up happens on local rank 0 while other local ranks wait for + the local rank 0 to finish. + """ + # Initialize torch.distributed ourselves, if necessary. 
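    # (maybe_init_dist() is expected to return a truthy value only when it creates the
    # process group itself, which is what allows it to be destroyed again at the end
    # without tearing down a process group that the caller set up.)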
+ destroy_dist = maybe_init_dist() + + # Perform clean up on local rank 0 + if get_local_rank() == 0: + for prefix_int in range(1000000): + leaked_shm = False + for shm_name in SHM_TO_CLEAN: + name = _get_path(prefix_int, shm_name) + try: + shm = BuiltinSharedMemory(name, True, 4) + except FileExistsError: + shm = BuiltinSharedMemory(name, False, 4) + leaked_shm = True + finally: + shm.close() # pyright: ignore + shm.unlink() + # Come out of loop if no leaked shared memory + if not leaked_shm: + break + + # Sync all ranks + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + # Delete the process group if Streaming initialized it. + if destroy_dist: + dist.destroy_process_group() diff --git a/streaming/util/shorthand.py b/streaming/util/shorthand.py new file mode 100644 index 000000000..c96d9dca4 --- /dev/null +++ b/streaming/util/shorthand.py @@ -0,0 +1,115 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Utilities for human-friendly argument shorthand.""" + +from typing import List, Union + +__all__ = ['get_list_arg', 'bytes_to_int', 'number_abbrev_to_int'] + + +def get_list_arg(text: str) -> List[str]: + """Pass a list as a command-line flag. + + Args: + text (str): Text to split. + + Returns: + List[str]: Splits, if any. + """ + return text.split(',') if text else [] + + +def bytes_to_int(bytes_str: Union[int, str]) -> int: + """Convert human readable byte format to an integer. + + Args: + bytes_str (Union[int, str]): Value to convert. + + Raises: + ValueError: Invalid byte suffix. + + Returns: + int: Integer value of bytes. + """ + #input is already an int + if isinstance(bytes_str, int) or isinstance(bytes_str, float): + return int(bytes_str) + + units = { + 'kb': 1024, + 'mb': 1024**2, + 'gb': 1024**3, + 'tb': 1024**4, + 'pb': 1024**5, + 'eb': 1024**6, + 'zb': 1024**7, + 'yb': 1024**8, + } + # Convert a various byte types to an integer + for suffix in units: + bytes_str = bytes_str.lower().strip() + if bytes_str.lower().endswith(suffix): + try: + return int(float(bytes_str[0:-len(suffix)]) * units[suffix]) + except ValueError: + raise ValueError(''.join([ + f'Unsupported value/suffix {bytes_str}. Supported suffix are ', + f'{["b"] + list(units.keys())}.' + ])) + else: + # Convert bytes to an integer + if bytes_str.endswith('b') and bytes_str[0:-1].isdigit(): + return int(bytes_str[0:-1]) + # Convert string representation of a number to an integer + elif bytes_str.isdigit(): + return int(bytes_str) + else: + raise ValueError(''.join([ + f'Unsupported value/suffix {bytes_str}. Supported suffix are ', + f'{["b"] + list(units.keys())}.' + ])) + + +def number_abbrev_to_int(abbrev_str: Union[int, str]) -> int: + """Convert human readable number abbreviations to an integer. + + Args: + abbrev_str (Union[int, str]): Value to convert. + + Raises: + ValueError: Invalid number suffix. + + Returns: + int: Integer value of number abbreviation. + """ + #input is already an int + if isinstance(abbrev_str, int) or isinstance(abbrev_str, float): + return int(abbrev_str) + + units = { + 'k': 10**3, + 'm': 10**6, + 'b': 10**9, + 't': 10**12, + } + # Convert a various abbreviation types to an integer + for suffix in units: + abbrev_str = abbrev_str.lower().strip() + if abbrev_str.lower().endswith(suffix): + try: + return int(float(abbrev_str[0:-len(suffix)]) * units[suffix]) + except ValueError: + raise ValueError(''.join([ + f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', + f'{list(units.keys())}.' 
+ ])) + else: + # Convert string representation of a number to an integer + if abbrev_str.isdigit(): + return int(abbrev_str) + else: + raise ValueError(''.join([ + f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', + f'{list(units.keys())}.' + ])) diff --git a/tests/test_importing.py b/tests/test_importing.py new file mode 100644 index 000000000..45283b383 --- /dev/null +++ b/tests/test_importing.py @@ -0,0 +1,8 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + + +def test_redirect_imports(): + from streaming.base.util import get_import_exception_message # pyright: ignore + + # Import successful. From 0ecf06f5381d089f4473bce4954bf45e2c9d05ce Mon Sep 17 00:00:00 2001 From: James Knighton Date: Mon, 11 Dec 2023 18:03:12 -0800 Subject: [PATCH 04/12] Fancy overly long lines command. (#529) --- Makefile | 2 +- scripts/long_lines.py | 188 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 189 insertions(+), 1 deletion(-) create mode 100644 scripts/long_lines.py diff --git a/Makefile b/Makefile index 1015ec416..6ad87b9ce 100644 --- a/Makefile +++ b/Makefile @@ -20,7 +20,7 @@ style: $(PYTHON) -m docformatter -ri $(dirs) longlines: - find streaming tests -type f -name "*.py" | xargs grep -x '.\{100,\}' + python3 scripts/long_lines.py test: $(PYTHON) -m $(PYTEST) $(EXTRA_ARGS) diff --git a/scripts/long_lines.py b/scripts/long_lines.py new file mode 100644 index 000000000..60f70bd2a --- /dev/null +++ b/scripts/long_lines.py @@ -0,0 +1,188 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Note long lines.""" + +import os +import re +import sys +from argparse import ArgumentParser, Namespace +from re import Pattern +from typing import IO, Iterator, Optional + + +def parse_args() -> Namespace: + """Parse command-line arguments. + + Returns: + Namespace: Command-line arguments. + """ + args = ArgumentParser() + args.add_argument( + '--root', + type=str, + default='.', + help='Start with all the files under this directory.', + ) + args.add_argument( + '--include', + type=str, + default=r'.*\.py$', + help='Drop all files whose paths fail to match this pattern, if given.', + ) + args.add_argument( + '--exclude', + type=str, + default='', + help='Among the remaining files, drop any whose paths match this pattern, if given.', + ) + args.add_argument( + '--max_len', + type=int, + default=100, + help='Maximum line length, excluding any trailing newline.', + ) + args.add_argument( + '--non_text', + type=str, + default='error', + help='What to do if we encounter a binary (non-text) file while processing the files. ' + + 'Generally, this would suggest ``include`` was set too loose, or data corruption. ' + + 'Options: ``error``, ``warn``, ``ignore``.', + ) + args.add_argument( + '--color', + type=str, + default='light', + help='Whether to output in color. Supported options: none, light.', + ) + return args.parse_args() + + +non_text_behaviors = {'error', 'warn', 'ignore'} + + +def each_path(root: str, + include: Optional[Pattern] = None, + exclude: Optional[Pattern] = None) -> Iterator[str]: + """Get each file path under root, in order, possibly included and excluded. + + Args: + root (str): Evaluate for inclusion every file under the given root dir. + include (Pattern, optional): First, check if the include pattern matches against ecah file + path. If no include pattern was provided, we match all files. Defaults to ``None``. 
+ exclude (Pattern, optional): Second, for each of the included file paths, check if the + exclude pattern matches it. If no exclude pattern, we do nothing. Defaults to ``None``. + + Returns: + Iterator[str]: Each file path, in order. + """ + for parent, _, file_basenames in os.walk(root): + for basename in file_basenames: + path = os.path.join(parent, basename) + + if include: + if not include.match(path): + continue + + if exclude: + if exclude.match(path): + continue + + yield path + + +def handle_non_text(behavior: str, path: str) -> None: + """Handle having received a binary file instead of a text file. + + Args: + behavior (str): Which non-text behavior to employ. + path (str): Path to file. + """ + if behavior == 'error': + raise ValueError(f'Encountered non-text file: {path}.') + elif behavior == 'warn': + print(f'{path}:binary') + elif behavior == 'ignore': + pass + else: + txt = ', '.join(sorted(non_text_behaviors)) + raise ValueError(f'Unknown non-text behavior (must be one of: {txt}): {behavior}.') + + +def open_text(path: str, non_text_behavior: str = 'warn') -> Optional[IO[str]]: + """Open the file as text (for reading line by line), with handling for binary files. + + Args: + path (str): Path to text file. + non_text_behavior (str): What to do when we got a binary file instead. + + Returns: + IO[str], optional: On success, IO in mode 'r'. + """ + try: + return open(path) + except: + handle_non_text(non_text_behavior, path) + + +def drop_newline(line: str) -> str: + """Remove the line's optional trailing newline. + + Args: + line (str): Original line. + + Returns: + str: Normalized line. + """ + if line.endswith('\n'): + return line[:-1] + elif line.endswith('\r\n'): + return line[:-2] + else: + return line + + +def main(args: Namespace) -> int: + """Note long lines. + + Args: + args (Namespace): Command-line arguments. + """ + colors = ['none', 'light'] + if args.color not in colors: + raise ValueError('Color option must be one of {colors}, but got: {args.color}.') + + include = re.compile(args.include) if args.include else None + exclude = re.compile(args.exclude) if args.exclude else None + + if args.max_len < 0: + raise ValueError(f'max_len must be non-negative, but got: {args.max_len}') + + if args.non_text not in non_text_behaviors: + txt = ', '.join(sorted(non_text_behaviors)) + raise ValueError(f'Unknown non-text behavior (must be one of: {txt}): {args.non_text}.') + + count = 0 + for path in sorted(each_path(args.root, include, exclude)): + if not (file := open_text(path, args.non_text)): + continue + + lines = map(drop_newline, file) + for line_no, line in enumerate(lines): + if args.max_len < len(line): + good_line = line[:args.max_len] + bad_line = line[args.max_len:] + if args.color == 'light': + path = f'\033[0;97m{path}\033[0;0m' + line_no = f'\033[0;92m{line_no}\033[0;0m' + good_line = f'\033[0;94m{good_line}\033[0;0m' + bad_line = f'\033[0;91m{bad_line}\033[0;0m' + print(f'{path}:{line_no}:{good_line}{bad_line}') + count += 1 + + return 1 if count else 0 + + +if __name__ == '__main__': + sys.exit(main(parse_args())) From 7c3fa058f614b990d32bab3c15612041c07761f9 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Wed, 13 Dec 2023 20:55:53 -0800 Subject: [PATCH 05/12] Redo/generalize/tighten args shorthand (#530) * Redo/generalize/tighten args shorthand, clean up usage, update tests. * Fix (cruft). * Fix (typo). * Fix (reference to member). * Tweak. * Divide tests/test_util.py into tests/util/....py. * Fix. * Error messages. * Lowercase, no space. 
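A rough sketch of the new shorthand entry points this patch introduces (expected values taken
from the updated tests further below):

    from streaming.util.shorthand import normalize_bytes, normalize_count

    normalize_bytes('75mb')    # 75_000_000 (decimal suffix)
    normalize_bytes('100kib')  # 102_400 (binary suffix)
    normalize_count('1.39b')   # 1_390_000_000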
--- simulation/core/sim_dataset.py | 18 +- simulation/interfaces/interface_utils.py | 4 +- simulation/interfaces/sim_cli.py | 4 +- simulation/interfaces/sim_ui.py | 8 +- simulation/interfaces/widgets.py | 13 +- streaming/dataset.py | 16 +- streaming/format/reader.py | 12 +- streaming/format/writer.py | 11 +- streaming/util/__init__.py | 7 +- streaming/util/shorthand.py | 425 +++++++++++++++---- tests/util/__init__.py | 2 + tests/{test_util.py => util/test_merging.py} | 127 +----- tests/util/test_retrying.py | 30 ++ tests/util/test_shared.py | 23 + tests/util/test_shorthand.py | 106 +++++ 15 files changed, 539 insertions(+), 267 deletions(-) create mode 100644 tests/util/__init__.py rename tests/{test_util.py => util/test_merging.py} (65%) create mode 100644 tests/util/test_retrying.py create mode 100644 tests/util/test_shared.py create mode 100644 tests/util/test_shorthand.py diff --git a/simulation/core/sim_dataset.py b/simulation/core/sim_dataset.py index 57ad2f5d0..cd1b7340f 100644 --- a/simulation/core/sim_dataset.py +++ b/simulation/core/sim_dataset.py @@ -20,7 +20,7 @@ from streaming.batching import generate_work from streaming.format import get_index_basename from streaming.spanner import Spanner -from streaming.util import bytes_to_int, number_abbrev_to_int +from streaming.util.shorthand import normalize_bytes, normalize_count logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) @@ -134,7 +134,6 @@ def __init__(self, self.nodes = nodes self.devices = devices self.workers = workers - self.cache_limit = cache_limit self.partition_algo = partition_algo self.predownload = predownload self.batch_size = batch_size @@ -189,11 +188,7 @@ def __init__(self, self.batch_size = batch_size or 1 # Convert epoch size from string to int, if needed. Cannot be negative. - epoch_size_value = None - if epoch_size: - epoch_size_value = number_abbrev_to_int(epoch_size) - if epoch_size_value < 0: - raise ValueError(f'Epoch size cannot be negative. Received {epoch_size_value}.') + epoch_size_value = normalize_count(epoch_size) if epoch_size else None # Initialize the Stream defaults and normalize to a list of Streams. if streams: @@ -284,14 +279,15 @@ def __init__(self, self.num_shards = len(self.shards) # Check that cache limit is possible. - if self.cache_limit: - if isinstance(self.cache_limit, str): - self.cache_limit = bytes_to_int(self.cache_limit) + if cache_limit: + self.cache_limit = normalize_bytes(cache_limit) min_cache_usage = sum((stream.get_index_size() for stream in streams)) if self.cache_limit <= min_cache_usage: raise ValueError(f'Minimum cache usage ({min_cache_usage} bytes) is larger than ' + f'the cache limit ({self.cache_limit} bytes). Please raise ' + f'`cache_limit`.') + else: + self.cache_limit = None for stream_idx, index_filename in enumerate(index_filenames): if indices_created[stream_idx] == 0: @@ -465,8 +461,6 @@ def get_cache_limit(self) -> Optional[int]: Returns: Optional[int]: The dataset's cache limit. 
""" - if isinstance(self.cache_limit, str): - self.cache_limit = bytes_to_int(self.cache_limit) return self.cache_limit def get_instantiation_time(self) -> float: diff --git a/simulation/interfaces/interface_utils.py b/simulation/interfaces/interface_utils.py index 1588d6ab7..e7e04eb58 100644 --- a/simulation/interfaces/interface_utils.py +++ b/simulation/interfaces/interface_utils.py @@ -14,7 +14,7 @@ from core.utils import get_rolling_avg_throughput from numpy.typing import NDArray -from streaming.util import number_abbrev_to_int +from streaming.util.shorthand import normalize_count def plot_simulation(step_times: NDArray, step_downloads: NDArray, window: int = 10): @@ -88,7 +88,7 @@ def get_train_dataset_params(input_params: dict, old_params: Optional[dict] = No train_dataset_params['cache_limit'] = input_params['cache_limit'] train_dataset_params['shuffle'] = input_params['shuffle'] train_dataset_params['shuffle_algo'] = input_params['shuffle_algo'] - train_dataset_params['shuffle_block_size'] = number_abbrev_to_int( + train_dataset_params['shuffle_block_size'] = normalize_count( input_params['shuffle_block_size']) if input_params['shuffle_block_size'] is not None \ else None train_dataset_params['shuffle_seed'] = input_params['seed'] diff --git a/simulation/interfaces/sim_cli.py b/simulation/interfaces/sim_cli.py index 521053604..cf6d1be8f 100644 --- a/simulation/interfaces/sim_cli.py +++ b/simulation/interfaces/sim_cli.py @@ -16,7 +16,7 @@ from core.yaml_processing import create_simulation_dataset, ingest_yaml from interfaces.interface_utils import plot_simulation -from streaming.util import bytes_to_int +from streaming.util.shorthand import normalize_bytes if __name__ == '__main__': parser = argparse.ArgumentParser(description='Simulate your training yaml from the command \ @@ -74,7 +74,7 @@ node_network_bandwidth = str(bandwidth_input) # Convert strings into numbers for applicable args - node_network_bandwidth = bytes_to_int(node_network_bandwidth) + node_network_bandwidth = normalize_bytes(node_network_bandwidth) # Create SimulationDataset print('Constructing SimulationDataset...') diff --git a/simulation/interfaces/sim_ui.py b/simulation/interfaces/sim_ui.py index 5da4e42b7..2567395a8 100644 --- a/simulation/interfaces/sim_ui.py +++ b/simulation/interfaces/sim_ui.py @@ -28,7 +28,7 @@ from interfaces.widgets import (display_shuffle_quality_graph, display_simulation_stats, get_line_chart, param_inputs) -from streaming.util import bytes_to_int, number_abbrev_to_int +from streaming.util.shorthand import normalize_bytes, normalize_count # set up page st.set_page_config(layout='wide') @@ -60,7 +60,7 @@ def submit_jobs(shuffle_quality: bool, dataset: SimulationDataset, time_per_samp max_duration (Time): Maximum duration of simulation. 
""" total_batches = get_total_batches(dataset=dataset, max_duration=max_duration) - node_internet_bandwidth = bytes_to_int(node_internet_bandwidth) + node_internet_bandwidth = normalize_bytes(node_internet_bandwidth) cache_limit = dataset.get_cache_limit() gen_sim = simulate(dataset, time_per_sample, @@ -92,7 +92,7 @@ def submit_jobs(shuffle_quality: bool, dataset: SimulationDataset, time_per_samp devices = input_params['devices'] workers = input_params['workers'] device_batch_size = input_params['device_batch_size'] - shuffle_block_size = number_abbrev_to_int(input_params['shuffle_block_size']) \ + shuffle_block_size = normalize_count(input_params['shuffle_block_size']) \ if input_params['shuffle_block_size'] is not None \ else dataset.get_shuffle_block_size() samples_per_shard = dataset.get_avg_samples_per_shard() @@ -279,7 +279,7 @@ def get_input_params_initial(physical_nodes: int, devices: int, workers: int, help='time for one device to process one sample from your dataset.') time_per_sample = float(time_per_sample) node_internet_bandwidth = col1.text_input('network bandwidth per node (bytes/s)', - value='500MB', + value='500mb', help='network bandwidth available to each \ node. in practice, network bandwidth is \ variable and is affected by many factors, \ diff --git a/simulation/interfaces/widgets.py b/simulation/interfaces/widgets.py index cc600f3bd..2f08c6a89 100644 --- a/simulation/interfaces/widgets.py +++ b/simulation/interfaces/widgets.py @@ -20,7 +20,7 @@ from numpy.typing import NDArray from streamlit.delta_generator import DeltaGenerator -from streaming.util import bytes_to_int +from streaming.util.shorthand import normalize_bytes def get_line_chart(data: pd.DataFrame, @@ -121,18 +121,18 @@ def stream_entry(component: DeltaGenerator, in each shard.', key=str(key) + 'samples') avg_raw_shard_size = component.text_input('avg raw shard size (bytes)', - value='60MB', + value='60mb', help='average raw size, in bytes, \ of a single shard.', key=str(key) + 'rawsize') - avg_raw_shard_size = bytes_to_int(avg_raw_shard_size) + avg_raw_shard_size = normalize_bytes(avg_raw_shard_size) avg_zip_shard_size = component.text_input('avg compressed shard size (bytes)', value='None', help='average compressed size, \ in bytes, of a single shard.', key=str(key) + 'zipsize') avg_zip_shard_size = None if avg_zip_shard_size == 'None' \ - else bytes_to_int(avg_zip_shard_size) + else normalize_bytes(avg_zip_shard_size) stream_entries['shards'] = shards stream_entries['samples_per_shard'] = samples_per_shard stream_entries['avg_raw_shard_size'] = avg_raw_shard_size @@ -256,7 +256,7 @@ def param_inputs(component: DeltaGenerator, input_params: dict, defaults: dict = sample from your dataset.') node_network_bandwidth = col_m.text_input( 'network bandwidth per node (bytes/s)', - value='500MB' + value='500mb' if 'node_network_bandwidth' not in defaults else defaults['node_network_bandwidth'], help='network bandwidth available to \ each node. in practice, network bandwidth is \ @@ -326,7 +326,8 @@ def param_inputs(component: DeltaGenerator, input_params: dict, defaults: dict = value='None' if 'cache_limit' not in defaults else defaults['cache_limit'], help='cache limit per node for this run. 
\ setting cache limit too low will impact throughput.') - cache_limit = None if cache_limit == '' or cache_limit == 'None' else bytes_to_int(cache_limit) + cache_limit = None if cache_limit == '' or cache_limit == 'None' else \ + normalize_bytes(cache_limit) sampling_methods = ['balanced', 'fixed'] sampling_method = col_r.selectbox('sampling method', sampling_methods, diff --git a/streaming/dataset.py b/streaming/dataset.py index 844f88d10..a9500163a 100644 --- a/streaming/dataset.py +++ b/streaming/dataset.py @@ -34,7 +34,7 @@ get_shm_prefix) from streaming.spanner import Spanner from streaming.stream import Stream -from streaming.util import bytes_to_int, number_abbrev_to_int +from streaming.util.shorthand import normalize_bytes, normalize_count from streaming.world import World # An arbitrary time in the future, used for cold shard eviction. @@ -330,7 +330,6 @@ def __init__(self, batching_method: str = 'random') -> None: # Global arguments (which do not live in Streams). self.predownload = predownload - self.cache_limit = cache_limit self.sampling_method = sampling_method self.sampling_granularity = sampling_granularity self.partition_algo = partition_algo @@ -396,11 +395,7 @@ def __init__(self, self.predownload = 8 * self.batch_size if self.batch_size is not None else 64 # Convert epoch size from string to int, if needed. Cannot be negative. - epoch_size_value = None - if epoch_size: - epoch_size_value = number_abbrev_to_int(epoch_size) - if epoch_size_value < 0: - raise ValueError(f'Epoch size cannot be negative. Received {epoch_size_value}.') + epoch_size_value = normalize_count(epoch_size) if epoch_size else None # Initialize torch dist ourselves, if necessary. destroy_dist = maybe_init_dist() @@ -468,9 +463,8 @@ def __init__(self, self.num_shards = len(self.shards) # Check that cache limit is possible. - if self.cache_limit: - if isinstance(self.cache_limit, str): - self.cache_limit = bytes_to_int(self.cache_limit) + if cache_limit: + self.cache_limit = normalize_bytes(cache_limit) min_cache_usage = sum((stream.get_index_size() for stream in streams)) if self.cache_limit <= min_cache_usage: raise ValueError(f'Minimum cache usage ({min_cache_usage} bytes) is larger than ' + @@ -486,6 +480,8 @@ def __init__(self, f'bytes) which includes raw (decompressed) and zip ' + f'(compressed) file size. Recommendation is to provide a ' + f'`cache_limit` as high as possible to avoid thrashing.') + else: + self.cache_limit = None # Build the shard index (for partitioning and mapping samples to shards). 
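        # (samples_per_shard holds the per-shard sample counts; it is what lets a flat,
        # global sample index be resolved back to a particular shard.)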
self.samples_per_shard = np.array([shard.samples for shard in self.shards], np.int64) diff --git a/streaming/format/reader.py b/streaming/format/reader.py index 5d1401d55..cc55f205a 100644 --- a/streaming/format/reader.py +++ b/streaming/format/reader.py @@ -9,7 +9,7 @@ from typing import Any, Dict, Iterator, List, Optional, Set, Union from streaming.array import Array -from streaming.util import bytes_to_int +from streaming.util.shorthand import normalize_bytes __all__ = ['FileInfo', 'Reader', 'JointReader', 'SplitReader'] @@ -52,20 +52,12 @@ def __init__( samples: int, size_limit: Optional[Union[int, str]], ) -> None: - - if size_limit: - if (isinstance(size_limit, str)): - size_limit = bytes_to_int(size_limit) - if size_limit < 0: - raise ValueError(f'`size_limit` must be greater than zero, instead, ' + - f'found as {size_limit}.') - self.dirname = dirname self.split = split or '' self.compression = compression self.hashes = hashes self.samples = samples - self.size_limit = size_limit + self.size_limit = normalize_bytes(size_limit) if size_limit else None self.file_pairs = [] diff --git a/streaming/format/writer.py b/streaming/format/writer.py index 8be4cb33f..4b98b93d4 100644 --- a/streaming/format/writer.py +++ b/streaming/format/writer.py @@ -22,7 +22,7 @@ from streaming.format.index import get_index_basename from streaming.hashing import get_hash, is_hash from streaming.storage.upload import CloudUploader -from streaming.util import bytes_to_int +from streaming.util.shorthand import normalize_bytes __all__ = ['JointWriter', 'SplitWriter'] @@ -91,13 +91,6 @@ def __init__(self, if not is_hash(algo): raise ValueError(f'Invalid hash: {algo}.') - size_limit_value = None - if size_limit: - size_limit_value = bytes_to_int(size_limit) - if size_limit_value < 0: - raise ValueError(f'`size_limit` must be greater than zero, instead, ' + - f'found as {size_limit_value}.') - # Validate keyword arguments invalid_kwargs = [ arg for arg in kwargs.keys() if arg not in ('progress_bar', 'max_workers', 'retry') @@ -108,7 +101,7 @@ def __init__(self, self.keep_local = keep_local self.compression = compression self.hashes = hashes - self.size_limit = size_limit_value + self.size_limit = normalize_bytes(size_limit) if size_limit else None self.extra_bytes_per_shard = extra_bytes_per_shard self.extra_bytes_per_sample = extra_bytes_per_sample self.new_samples: List[bytes] diff --git a/streaming/util/__init__.py b/streaming/util/__init__.py index 55a54e10e..209bf5662 100644 --- a/streaming/util/__init__.py +++ b/streaming/util/__init__.py @@ -7,9 +7,12 @@ from streaming.util.merging import merge_index from streaming.util.retrying import retry from streaming.util.shared import clean_stale_shared_memory -from streaming.util.shorthand import bytes_to_int, get_list_arg, number_abbrev_to_int +from streaming.util.shorthand import (get_list_arg, get_str2str_arg, normalize_bin_bytes, + normalize_bytes, normalize_count, normalize_dec_bytes, + normalize_duration) __all__ = [ 'get_import_exception_message', 'redirect_imports', 'merge_index', 'retry', - 'clean_stale_shared_memory', 'get_list_arg', 'bytes_to_int', 'number_abbrev_to_int' + 'clean_stale_shared_memory', 'get_list_arg', 'get_str2str_arg', 'normalize_dec_bytes', + 'normalize_bin_bytes', 'normalize_bytes', 'normalize_count', 'normalize_duration' ] diff --git a/streaming/util/shorthand.py b/streaming/util/shorthand.py index c96d9dca4..11eb856fd 100644 --- a/streaming/util/shorthand.py +++ b/streaming/util/shorthand.py @@ -1,115 +1,370 @@ # Copyright 2023 
MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Utilities for human-friendly argument shorthand.""" +"""Conversions between human-friendly string forms and int/float.""" -from typing import List, Union +from collections import defaultdict +from typing import Dict, List, Union -__all__ = ['get_list_arg', 'bytes_to_int', 'number_abbrev_to_int'] +__all__ = [ + 'get_list_arg', 'get_str2str_arg', 'normalize_dec_bytes', 'normalize_bin_bytes', + 'normalize_bytes', 'normalize_count', 'normalize_duration' +] -def get_list_arg(text: str) -> List[str]: - """Pass a list as a command-line flag. +def get_list_arg(text: str, sep: str = ',') -> List[str]: + """Pass a list as a comma-delimited string. Args: - text (str): Text to split. + text (str): Text to parse. Returns: - List[str]: Splits, if any. + List[str]: List of items. """ - return text.split(',') if text else [] + if not text: + return [] + return text.split(sep) -def bytes_to_int(bytes_str: Union[int, str]) -> int: - """Convert human readable byte format to an integer. + +def get_str2str_arg(text: str, sep: str = ',', eq: str = '=') -> Dict[str, str]: + """Pass a dict as a comma- and equals-delimited string. Args: - bytes_str (Union[int, str]): Value to convert. + text (str): Text to parse. + sep (str): Separator text. Defaults to ``,``. + eq (str): Assignment text. Deffaults to ``=``. + + Returns: + Dict[str, str]: Mapping of str to str. + """ + if not text: + return {} - Raises: - ValueError: Invalid byte suffix. + ret = {} + parts = text.split(sep) + for part in parts: + key, val = part.split(eq) + if key in ret: + raise ValueError(f'Repeated key: {key} (text: {text}).') + ret[key] = val + return ret + + +def _normalize_arg(text: str, units: Dict[str, int], to_type: type) -> Union[int, float]: + """Normalize a human-friendly unit string to number. + + Args: + text (str): Human-friendly string. + units (Dict[str, Any]): Mapping of unit name to value. + to_type (Union[int, float]): The return type. Returns: - int: Integer value of bytes. + type: Computer-friendly number. """ - #input is already an int - if isinstance(bytes_str, int) or isinstance(bytes_str, float): - return int(bytes_str) - - units = { - 'kb': 1024, - 'mb': 1024**2, - 'gb': 1024**3, - 'tb': 1024**4, - 'pb': 1024**5, - 'eb': 1024**6, - 'zb': 1024**7, - 'yb': 1024**8, - } - # Convert a various byte types to an integer - for suffix in units: - bytes_str = bytes_str.lower().strip() - if bytes_str.lower().endswith(suffix): - try: - return int(float(bytes_str[0:-len(suffix)]) * units[suffix]) - except ValueError: - raise ValueError(''.join([ - f'Unsupported value/suffix {bytes_str}. Supported suffix are ', - f'{["b"] + list(units.keys())}.' - ])) + # Must be non-empty. + if not text: + raise ValueError(f'Attempted to normalize an empty string to some value.') + + # Drop commas and underscores (useful to demarcate thousands '1,337' or '1_337'). + text = text.replace(',', '') + text = text.replace('_', '') + + # Must start with a digit. + char = text[0] + if not char.isdigit(): + raise ValueError(f'Text must start with a digit, but got {text[0]} instead (input: ' + + f'{text}).') + + # Must alternative between numbers and units, starting with a number. + in_num = True + part = [] + parts = [] + for char in text: + is_digit = char.isdigit() or char == '.' 
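            # (Accumulate characters into alternating digit-runs and unit-runs; switching
            # between the two flushes the accumulated characters as one part.)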
+ if in_num: + if is_digit: + part.append(char) + else: + part = ''.join(part) + parts.append(part) + part = [char] + in_num = False + else: + if is_digit: + part = ''.join(part) + parts.append(part) + part = [char] + in_num = True + else: + part.append(char) + part = ''.join(part) + parts.append(part) + + # If just a number, that's it. + if len(parts) == 1: + part, = parts + try: + return to_type(part) + except: + raise ValueError(f'Input must be numeric, but got {part} instead (input: {text}).') + + # Pair up numbers and units. + if len(parts) % 2: + if '' in units: + # Special case where the implied unit is the empty string, i.e. the smallest unit. + parts.append('') + else: + # If not just a number, each number must be paired with a corresponding unit. + raise ValueError(f'Text must contain pairs of number and unit, but got an odd ' + + f'number of parts instead: {parts} (input: {text}).') + + # Assign parts as numbers and units. + part_nums = [] + part_units = [] + for i, part in enumerate(parts): + if i % 2: + part_units.append(part) + else: + part_nums.append(part) + + # Each number before the last one must be integral. + for i, num in enumerate(part_nums[:-0]): + try: + part_nums[i] = int(num) + except: + raise ValueError(f'Non-final numbers must be integral, but got part {i} as {num} ' + + f'instead (input: {text}).') + + # Parse out the digits of the final number, which may be fractional. + txt = part_nums[-1] + num_dots = txt.count('.') + if not num_dots: + last_num_dec_shift = 0 + elif num_dots == 1: + idx = txt.rindex('.') + last_num_dec_shift = len(txt) - idx - 1 + txt = txt.replace('.', '') + elif 1 < num_dots: + raise ValueError(f'Final number must not contain multiple decimal points, but got ' + + f'{part_nums[-1]} instead (input: {text}).') else: - # Convert bytes to an integer - if bytes_str.endswith('b') and bytes_str[0:-1].isdigit(): - return int(bytes_str[0:-1]) - # Convert string representation of a number to an integer - elif bytes_str.isdigit(): - return int(bytes_str) + raise ValueError(f'Handling for str.count(".") returning negative required by lint.') + + # Parse the digits as an integer for exact precision, no float nonsense. + try: + part_nums[-1] = int(txt) + except: + raise ValueError(f'Final number must be numeric, but got {part_nums[-1]} instead ' + + f'(input: {text}).') + + # Each unit must be known to us. + part_muls = [] + for i, unit in enumerate(part_units): + mul = units.get(unit) + if mul is None: + raise ValueError(f'Unit is unknown: {unit} in part {i} (input: {text}).') + part_muls.append(mul) + + # Each unit must be used at most once. + unit2count = defaultdict(int) + for i, unit in enumerate(part_units): + unit2count[unit] += 1 + for unit in sorted(unit2count): + count = unit2count[unit] + if count != 1: + raise ValueError(f'Unit is reused: {unit} is used {count} times (input: {text}).') + + # Units must be listed in descending order of size. + prev_mul = part_muls[0] + for i in range(1, len(part_muls)): + mul = part_muls[i] + if mul < prev_mul: + prev_mul = mul else: - raise ValueError(''.join([ - f'Unsupported value/suffix {bytes_str}. Supported suffix are ', - f'{["b"] + list(units.keys())}.' - ])) + unit = part_units[i] + raise ValueError(f'Units are out of order: {unit} in part {i} (input: {text}).') + # The number of any given part must not exceed the size of the next biggest part's unit. + # + # (Otherwise you would just roll its overage into the next biggest part.) 
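    # For example, '1tb1500gb' is rejected by this check, since 1500 of the gb unit
    # exceeds a single tb.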
+ for i in range(1, len(part_muls)): + parent_mul = part_muls[i - 1] + mul = part_muls[i] + num = part_nums[i] + if parent_mul < mul * num: + parent_unit = part_units[i - 1] + unit = part_units[i] + raise ValueError(f'The number of any non-initial part must not exceed the ratio of ' + + f'the unit of the next biggest part to its own unit (otherwise it ' + + f'should have been rolled into the bigger part): part {i} having ' + + f'{num} of {unit} ({mul}x) vs parent part {i - 1} in units of ' + + f'{parent_unit} ({parent_mul}x) (input: {text}).') -def number_abbrev_to_int(abbrev_str: Union[int, str]) -> int: - """Convert human readable number abbreviations to an integer. + # Collect parts, with last part being possibly scaled down to account for a decimal point. + ret = 0 + for num, mul in zip(part_nums[:-1], part_muls[:-1]): + ret += num * mul + ret += part_nums[-1] * part_muls[-1] // 10**last_num_dec_shift + return ret - Args: - abbrev_str (Union[int, str]): Value to convert. - Raises: - ValueError: Invalid number suffix. +def _normalize_num(arg: Union[int, float, str], units: Dict[str, int], + to_type: type) -> Union[int, float]: + """Normalize from human-friendly argument to number. + + Args: + arg (Union[int, float, str]): Human-friendly argument. + units (Dict[str, Any]): Mapping of unit name to value. + to_type (type): The return type. Returns: - int: Integer value of number abbreviation. + Union[int, float]: Numeric argument. """ - #input is already an int - if isinstance(abbrev_str, int) or isinstance(abbrev_str, float): - return int(abbrev_str) - - units = { - 'k': 10**3, - 'm': 10**6, - 'b': 10**9, - 't': 10**12, - } - # Convert a various abbreviation types to an integer - for suffix in units: - abbrev_str = abbrev_str.lower().strip() - if abbrev_str.lower().endswith(suffix): - try: - return int(float(abbrev_str[0:-len(suffix)]) * units[suffix]) - except ValueError: - raise ValueError(''.join([ - f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', - f'{list(units.keys())}.' - ])) + if isinstance(arg, (int, float)): + return to_type(arg) else: - # Convert string representation of a number to an integer - if abbrev_str.isdigit(): - return int(abbrev_str) - else: - raise ValueError(''.join([ - f'Unsupported value/suffix {abbrev_str}. Supported suffix are ', - f'{list(units.keys())}.' - ])) + return _normalize_arg(arg, units, to_type) + + +def _normalize_int(arg: Union[int, str], units: Dict[str, int]) -> int: + """Normalize from human-friendly argument to int. + + Args: + arg (Union[int, str]): Human-friendly argument. + units (Dict[str, int]): Mapping of unit name to value. + + Returns: + int: Integral argument. + """ + return _normalize_num(arg, units, int) # pyright: ignore + + +def _normalize_nonneg_int(arg: Union[int, str], units: Dict[str, int]) -> int: + """Normalize from human-friendly argument to non-negative int. + + Args: + arg (Union[int, str]): Human-friendly argument. + units (Dict[str, int]): Mapping of unit name to value. + + Returns: + int: Non-negative integral argument. + """ + ret = _normalize_int(arg, units) # pyright: ignore + if ret < 0: + raise ValueError(f'Value cannot be negative, but got {ret} (input: {arg}).') + return ret + + +def _normalize_float(arg: Union[int, float, str], units: Dict[str, int]) -> int: + """Normalize from human-friendly argument to float. + + Args: + arg (Union[int, float, str]): Human-friendly argument. + units (Dict[str, int]): Mapping of unit name to value. + + Returns: + float: Floating argument. 
+ """ + return _normalize_num(arg, units, float) # pyright: ignore + + +def _get_units(base: int, names: List[str]) -> Dict[str, int]: + """Generate units mapping given a base and names of powers of that base. + + Args: + base (int): Base to exponentiate. + names (List[str]): Name of each power of base. + + Returns: + Dic[str, int]: Mapping of unit name to value. + """ + units = {} + for i, name in enumerate(names): + if name in units: + raise ValueError(f'Reused unit name: {name}.') + units[name] = base**i + return units + + +_dec_bytes_units = _get_units(1000, 'b kb mb gb tb pb eb zb yb rb qb'.split()) + + +def normalize_dec_bytes(size: Union[int, str]) -> int: + """Normalize from human-friendly base-1000 bytes to int. + + Args: + size (Union[int, str]): Human-friendly base-1000 bytes. + + Returns: + int: Integral bytes. + """ + return _normalize_nonneg_int(size, _dec_bytes_units) + + +_bin_bytes_units = _get_units(1024, 'ib kib mib gib tib pib eib zib yib rib qib'.split()) + + +def normalize_bin_bytes(size: Union[int, str]) -> int: + """Normalize from human-friendly base-1024 bytes to int. + + Args: + size (Union[int, str]): Human-friendly base-1024 bytes. + + Returns: + int: Integral bytes. + """ + return _normalize_nonneg_int(size, _bin_bytes_units) + + +def normalize_bytes(size: Union[int, str]) -> int: + """Normalize from human-friendly base-1000 or base-1024 bytes to int. + + Args: + size (Union[int, str]): Human-friendly base-1000 or base-1024 bytes. + + Returns: + int: Integral bytes. + """ + errors = [] + for norm in [normalize_dec_bytes, normalize_bin_bytes]: + try: + return norm(size) + except Exception as e: + errors.append(e) + raise ValueError(f'Invalid bytes: {size}. Reasons: {[str(e) for e in errors]}.') + + +_count_units = _get_units(1000, ' k m b t'.split(' ')) + + +def normalize_count(count: Union[int, str]) -> int: + """Normalize from human-friendly count to int. + + Args: + count (Union[int, str]): Human-friendly count. + + Returns: + int: Integral count. + """ + return _normalize_nonneg_int(count, _count_units) + + +_duration_units = { + 's': 1, + 'm': 60, + 'h': 60 * 60, + 'd': 24 * 60 * 60, +} + + +def normalize_duration(duration: Union[int, float, str]) -> float: + """Normalize from human-friendly duration to float. + + Args: + duration (Union[int, float, str]): Human-friendly duration. + + Returns: + float: Float duration. 
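    For instance, with the unit table above (seconds as the base unit), ``'2h'`` maps to
    ``7200``, while a plain ``90`` passes through as ``90.0``.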
+ """ + return _normalize_float(duration, _duration_units) diff --git a/tests/util/__init__.py b/tests/util/__init__.py new file mode 100644 index 000000000..dd3d19b22 --- /dev/null +++ b/tests/util/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/tests/test_util.py b/tests/util/test_merging.py similarity index 65% rename from tests/test_util.py rename to tests/util/test_merging.py index eecbe4138..861a42d69 100644 --- a/tests/test_util.py +++ b/tests/util/test_merging.py @@ -6,17 +6,13 @@ import tempfile import time import urllib.parse -from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory -from typing import List, Optional, Tuple, Union +from typing import Tuple, Union import pytest -from streaming.constant import RESUME -from streaming.shared.prefix import _get_path from streaming.storage.download import download_file from streaming.storage.upload import CloudUploader -from streaming.util import (bytes_to_int, clean_stale_shared_memory, get_list_arg, merge_index, - number_abbrev_to_int, retry) +from streaming.util import merge_index MY_PREFIX = 'train_' + str(time.time()) MY_BUCKET = { @@ -28,101 +24,6 @@ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls -@pytest.mark.parametrize(('text', 'expected_output'), [('hello,world', ['hello', 'world']), - ('hello', ['hello']), ('', [])]) -def test_get_list_arg(text: str, expected_output: List[Optional[str]]): - output = get_list_arg(text) - assert output == expected_output - - -def test_bytes_to_int(): - input_to_expected = [ - ('1234', 1234), - ('1b', 1), - ('50b', 50), - ('50B', 50), - ('100kb', 102400), - (' 100 kb', 102400), - ('75mb', 78643200), - ('75MB', 78643200), - ('75 mb ', 78643200), - ('1.39gb', 1492501135), - ('1.39Gb', 1492501135), - ('2tb', 2199023255552), - ('3pb', 3377699720527872), - ('1.11eb', 1279742870113600256), - ('1.09zb', 1286844866581978415104), - ('2.0yb', 2417851639229258349412352), - (1234, 1234), - (1, 1), - (0.5 * 1024, 512), - (100 * 1024, 102400), - (75 * 1024**2, 78643200), - (75 * 1024 * 1024, 78643200), - (35.78, 35), - (325388903.203984, 325388903), - ] - for size_pair in input_to_expected: - output = bytes_to_int(size_pair[0]) - assert output == size_pair[1] - - -def test_bytes_to_int_Exception(): - input_data = ['', '12kbb', '27mxb', '79kkb'] - for value in input_data: - with pytest.raises(ValueError, match=f'Unsupported value/suffix.*'): - _ = bytes_to_int(value) - - -def test_number_abbrev_to_int(): - input_to_expected = [ - ('1234', 1234), - ('1k', 1000), - ('50k', 50000), - ('50K', 50000), - ('100k', 100000), - (' 100 k', 100000), - ('75m', 75000000), - ('75M', 75000000), - ('75 m ', 75000000), - ('1.39b', 1390000000), - ('1.39B', 1390000000), - ('2t', 2000000000000), - ('3 T', 3000000000000), - (1234, 1234), - (1, 1), - (0.5 * 1000, 500), - (100 * 1000, 100000), - (75 * 1000**2, 75000000), - (75 * 1000 * 1000, 75000000), - (35.78, 35), - (325388903.203984, 325388903), - ] - for size_pair in input_to_expected: - output = number_abbrev_to_int(size_pair[0]) - assert output == size_pair[1] - - -def test_number_abbrev_to_int_Exception(): - input_data = ['', '12kbb', '27mxb', '79bk', '79bb', '79 b m', 'p 64', '64p'] - for value in input_data: - with pytest.raises(ValueError, match=f'Unsupported value/suffix.*'): - _ = number_abbrev_to_int(value) - - -def test_clean_stale_shared_memory(): - # Create a leaked shared memory - name = _get_path(0, RESUME) - _ = 
BuiltinSharedMemory(name, True, 64) - - # Clean up the stale shared memory - clean_stale_shared_memory() - - # If clean up is successful, it should raise FileNotFoundError Exception - with pytest.raises(FileNotFoundError): - _ = BuiltinSharedMemory(name, False, 64) - - def integrity_check(out: Union[str, Tuple[str, str]], keep_local: bool, expected_n_shard_files: int = -1): @@ -275,27 +176,3 @@ def test_merge_index_from_root_local(local_remote_dir: Tuple[str, str], n_partit mds_path, _ = dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) merge_index(mds_path, keep_local=keep_local) integrity_check(mds_path, keep_local=keep_local) - - -@pytest.mark.parametrize('with_args', [True, False]) -def test_retry(with_args: bool): - num_tries = 0 - return_after = 2 - - if with_args: - decorator = retry(RuntimeError, num_attempts=3, initial_backoff=0.01, max_jitter=0.01) - return_after = 2 - else: - decorator = retry - # Need to return immediately to avoid timeouts - return_after = 0 - - @decorator - def flaky_function(): - nonlocal num_tries - if num_tries < return_after: - num_tries += 1 - raise RuntimeError('Called too soon!') - return "Third time's a charm" - - assert flaky_function() == "Third time's a charm" diff --git a/tests/util/test_retrying.py b/tests/util/test_retrying.py new file mode 100644 index 000000000..77740dd6a --- /dev/null +++ b/tests/util/test_retrying.py @@ -0,0 +1,30 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from streaming.util.retrying import retry + + +@pytest.mark.parametrize('with_args', [True, False]) +def test_retry(with_args: bool): + num_tries = 0 + return_after = 2 + + if with_args: + decorator = retry(RuntimeError, num_attempts=3, initial_backoff=0.01, max_jitter=0.01) + return_after = 2 + else: + decorator = retry + # Need to return immediately to avoid timeouts + return_after = 0 + + @decorator + def flaky_function(): + nonlocal num_tries + if num_tries < return_after: + num_tries += 1 + raise RuntimeError('Called too soon!') + return "Third time's a charm" + + assert flaky_function() == "Third time's a charm" diff --git a/tests/util/test_shared.py b/tests/util/test_shared.py new file mode 100644 index 000000000..a66f8e047 --- /dev/null +++ b/tests/util/test_shared.py @@ -0,0 +1,23 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory + +import pytest + +from streaming.constant import RESUME +from streaming.shared.prefix import _get_path +from streaming.util.shared import clean_stale_shared_memory + + +def test_clean_stale_shared_memory(): + # Create a leaked shared memory + name = _get_path(0, RESUME) + _ = BuiltinSharedMemory(name, True, 64) + + # Clean up the stale shared memory + clean_stale_shared_memory() + + # If clean up is successful, it should raise FileNotFoundError Exception + with pytest.raises(FileNotFoundError): + _ = BuiltinSharedMemory(name, False, 64) diff --git a/tests/util/test_shorthand.py b/tests/util/test_shorthand.py new file mode 100644 index 000000000..0f6b36254 --- /dev/null +++ b/tests/util/test_shorthand.py @@ -0,0 +1,106 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +import pytest + +from streaming.util.shorthand import get_list_arg, normalize_bytes, normalize_count + + +@pytest.mark.parametrize(('text', 'expected_output'), [('hello,world', ['hello', 'world']), + 
('hello', ['hello']), ('', [])]) +def test_get_list_arg(text: str, expected_output: List[Optional[str]]): + output = get_list_arg(text) + assert output == expected_output + + +def test_normalize_bytes(): + input_to_expected = [ + ('1234', 1234), + ('1b', 1), + ('50b', 50), + ('100kib', 102400), + ('75mb', 75000000), + ('75mib', 78643200), + ('1.39gib', 1492501135), + ('2tib', 2199023255552), + ('3pib', 3377699720527872), + ('1.11eib', 1279742870113600143), + ('1.09zib', 1286844866581978320732), + ('2.0yib', 2417851639229258349412352), + ('7yb', 7000000000000000000000000), + (1234, 1234), + (1, 1), + (0.5 * 1024, 512), + (100 * 1024, 102400), + (75 * 1024**2, 78643200), + (75 * 1024 * 1024, 78643200), + (35.78, 35), + (325388903.203984, 325388903), + ] + for size_pair in input_to_expected: + output = normalize_bytes(size_pair[0]) + assert output == size_pair[1] + + +def test_normalize_bytes_except(): + input_data = [ + '', + '12kbb', + '27mxb', + '79kkb', + '50B', + ' 100 kb', + '75MB', + '75 mb', + '1.39Gb', + ] + for value in input_data: + with pytest.raises(ValueError): + _ = normalize_bytes(value) + + +def test_normalize_count(): + input_to_expected = [ + ('1234', 1234), + ('1k', 1000), + ('50k', 50000), + ('100k', 100000), + ('75m', 75000000), + ('1.39b', 1390000000), + ('2t', 2000000000000), + (1234, 1234), + (1, 1), + (0.5 * 1000, 500), + (100 * 1000, 100000), + (75 * 1000**2, 75000000), + (75 * 1000 * 1000, 75000000), + (35.78, 35), + (325388903.203984, 325388903), + ] + for size_pair in input_to_expected: + output = normalize_count(size_pair[0]) + assert output == size_pair[1] + + +def test_normalize_count_except(): + input_data = [ + '', + '12kbb', + '27mxb', + '79bk', + '79bb', + '79 b m', + 'p 64', + '64p', + '50K', + ' 100 k', + '75M', + '75 m', + '1.39B', + '3 T', + ] + for value in input_data: + with pytest.raises(ValueError): + _ = normalize_count(value) From d969cd62cdf3713ddc4c22fee284356664a7defd Mon Sep 17 00:00:00 2001 From: James Knighton Date: Thu, 14 Dec 2023 20:03:15 -0800 Subject: [PATCH 06/12] Add benchmarking suite for all backends and formats (#533) * Benchmarking all backends and formats. * Fix (missing docstrings). --- benchmarks/backends/datagen.py | 204 +++++++++++++++ benchmarks/backends/plot.py | 90 +++++++ benchmarks/backends/read.py | 440 +++++++++++++++++++++++++++++++++ benchmarks/backends/write.py | 387 +++++++++++++++++++++++++++++ streaming/util/__init__.py | 3 +- streaming/util/tabulation.py | 125 ++++++++++ 6 files changed, 1248 insertions(+), 1 deletion(-) create mode 100644 benchmarks/backends/datagen.py create mode 100644 benchmarks/backends/plot.py create mode 100644 benchmarks/backends/read.py create mode 100644 benchmarks/backends/write.py create mode 100644 streaming/util/tabulation.py diff --git a/benchmarks/backends/datagen.py b/benchmarks/backends/datagen.py new file mode 100644 index 000000000..3a3e8d32e --- /dev/null +++ b/benchmarks/backends/datagen.py @@ -0,0 +1,204 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Generate a synthetic dataset.""" + +from typing import Dict, List, Tuple, TypeVar + +import numpy as np +from numpy.random import Generator +from tqdm import tqdm + +__all__ = ['generate'] + + +def _generate_int(rng: Generator, + pos_prob: float = 0.75, + low: int = -1_000_000_000, + high: int = 1_000_000_000) -> int: + """Pick a random integer to say in words. + + This is a synthetic dataset whose random numbers need to be distinct, deterministic given a + seed, and little else. 
We choose a distribution that seems the most pleasing to us. + + Properties: + * About 80% positive and 20% negative. + * Magnitude of up to a billion on either side of zero. + * Strongly skewed toward the origin, i.e. chosen uniformly across base-10 digit lengths (at + least until running out of integers of that length anyway). + + Args: + rng (Generator): NumPy random number generator. + pos_prob (float): Probability of output being positive. Defaults to ``0.75``. + low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``. + high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``. + """ + if not 0 <= pos_prob <= 1: + raise ValueError(f'Invalid positive probability ``pos_prob``: 0 <= {pos_prob} <= 1.') + + if not low < 0 < high: + raise ValueError(f'Invalid sampling range ``low`` and/or ``high``: {low} < 0 < {high}.') + + is_pos = rng.uniform() < pos_prob + max_digits = np.log10(high) if is_pos else np.log10(-low) + exponent = rng.uniform(0, max_digits) + magnitude = int(10**exponent) + sign = is_pos * 2 - 1 + return sign * magnitude + + +def _generate_ints(count: int, + seed: int = 0x1337, + pos_prob: float = 0.75, + low: int = -1_000_000_000, + high: int = 1_000_000_000, + show_progress: bool = True) -> List[int]: + """Sample until we have the given number of distinct integers. + + Args: + count (int): How many samples to draw. + seed (int): Seed for the random number generator. Defaults to ``0x1337``. + pos_prob (float): Probability of output being positive. Defaults to ``0.75``. + low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``. + high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``. + show_progress (bool): Whether to display a progress bar. Defaults to ``True``. + + Returns: + List[int]: The integers that were drawn. + """ + rng = np.random.default_rng(seed) + nums = set() + progress_bar = tqdm(total=count, leave=False) if show_progress else None + while len(nums) < count: + num = _generate_int(rng) + if num in nums: + continue + + nums.add(num) + if progress_bar: + progress_bar.update(1) + if progress_bar: + progress_bar.close() + + nums = sorted(nums) + rng.shuffle(nums) + return nums + + +_ones = ('zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen ' + 'fifteen sixteen seventeen eighteen nineteen').split() + +_tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split() + + +def _int_to_words(num: int) -> List[str]: + """Say an integer as a list of words. + + Args: + num (int): The integer. + + Returns: + List[str]: The integer as a list of words. + """ + if num < 0: + return ['negative'] + _int_to_words(-num) + elif num <= 19: + return [_ones[num]] + elif num < 100: + tens = [_tens[num // 10 - 2]] + ones = [_ones[num % 10]] if num % 10 else [] + return tens + ones + elif num < 1_000: + hundreds = [_ones[num // 100], 'hundred'] + etc = _int_to_words(num % 100) if num % 100 else [] + return hundreds + etc + elif num < 1_000_000: + thousands = _int_to_words(num // 1_000) + ['thousand'] + etc = _int_to_words(num % 1_000) if num % 1_000 else [] + return thousands + etc + elif num < 1_000_000_000: + millions = _int_to_words(num // 1_000_000) + ['million'] + etc = _int_to_words(num % 1_000_000) if num % 1_000_000 else [] + return millions + etc + else: + raise ValueError('Integer out of range: -1,000,000,000 < {num} < +1,000,000,000.') + + +def _int_to_text(num: int) -> str: + """Say an integer as text. 
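+
+    A couple of spot checks (illustrative only, derived from the helpers above):
+        _int_to_text(-45_017) == 'negative forty five thousand seventeen'
+        _int_to_text(8_000_123) == 'eight million one hundred twenty three'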
+ + Args: + num (int): The integer. + + Returns: + str: The integer as text. + """ + words = _int_to_words(num) + return ' '.join(words) + + +T = TypeVar('T') + + +def _split(items: List[T], sizes: List[int]) -> List[List[T]]: + """Divide the given items across the splits given by their sizes. + + Args: + items (List[Any]): The items to divide across the spans. + sizes (List[int]): Number of items per split. + + Returns: + List[List[Any]]: Each split of items. + """ + total = sum(sizes) + if len(items) != total: + raise ValueError(f'Number of items must match the combined size of the splits: ' + + f'{len(items)} items vs splits of size {sizes} = {total}.') + + splits = [] + begin = 0 + for size in sizes: + split = items[begin:begin + size] + splits.append(split) + begin += size + + return splits + + +def generate(split2size: Dict[str, int], + seed: int = 0x1337, + pos_prob: float = 0.75, + low: int = -1_000_000_000, + high: int = 1_000_000_000, + show_progress: bool = True) -> Dict[str, Tuple[List[int], List[str]]]: + """Generate a dataset, made of splits, to be saved in different forms for comparison. + + Args: + split2size (Dict[str, int]): Mapping of split name to size in samples. + seed (int): Seed for the random number generator. Defaults to ``0x1337``. + pos_prob (float): Probability of output being positive. Defaults to ``0.75``. + low (int): Minimum of output range. Must be negative. Defaults to ``-1_000_000_000``. + high (int): Maximum of output range. Must be positive. Defaults to ``1_000_000_000``. + show_progress (bool): Whether to show a progress bar. Defaults to ``True``. + + Returns: + Dict[str, Tuple[List[int], List[str]]]: Mapping of split name to nums and texts. + """ + split_sizes = [] + total = 0 + for split in sorted(split2size): + size = split2size[split] + split_sizes.append(size) + total += size + + nums = _generate_ints(total, seed, low, high, show_progress) + nums_per_split = _split(nums, split_sizes) + + texts = list(map(_int_to_text, nums)) + texts_per_split = _split(texts, split_sizes) + + dataset = {} + for index, split in enumerate(sorted(split2size)): + dataset[split] = nums_per_split[index], texts_per_split[index] + + return dataset diff --git a/benchmarks/backends/plot.py b/benchmarks/backends/plot.py new file mode 100644 index 000000000..74a933437 --- /dev/null +++ b/benchmarks/backends/plot.py @@ -0,0 +1,90 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Plot dataset iteration time.""" + +import json +from argparse import ArgumentParser, Namespace + +import numpy as np +from matplotlib import pyplot as plt + + +def _parse_args() -> Namespace: + """Parse command-line arguments. + + Returns: + Namespace: Command-line arguments. + """ + args = ArgumentParser() + args.add_argument('--stats', type=str, default='data/backends/stats.json') + args.add_argument('--plot', type=str, default='data/backends/plot.png') + return args.parse_args() + + +def main(args: Namespace) -> None: + """Randomly iterate over a Parquet dataset with Streaming. + + Args: + args (Namespace): Command-line arguments. 
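+
+    Example invocation (the paths shown are just the argparse defaults above):
+        python benchmarks/backends/plot.py \
+            --stats data/backends/stats.json --plot data/backends/plot.png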
+ """ + streaming_colors = { + 'csv': '#c00', + 'jsonl': '#a00', + 'mds': '#800', + } + + parquet_colors = { + 'native': 'green', + 'cold': 'blue', + 'warm': 'red', + } + + lance_take_counts = 2**np.arange(11) + lance_colors = '#730', '#840', '#950', '#a60', '#b70', '#c80', '#d90', '#ea0', '#fb1', \ + '#fc4', '#fd7' + lance_colors = dict(zip(map(str, lance_take_counts), lance_colors)) + + colors = { + 'streaming': streaming_colors, + 'parquet': parquet_colors, + 'lance': lance_colors, + } + + stats = json.load(open(args.stats)) + + plt.rc('legend', fontsize=5) + plt.title('Throughput') + plt.xlabel('Seconds') + plt.ylabel('Samples') + line_width = 0.75 + + for backend in sorted(colors): + keys = sorted(colors[backend]) + if backend == 'lance': + keys = sorted(map(int, keys)) + keys = list(map(str, keys)) + for key in keys: + for ordering in ['seq', 'rand']: + color = colors[backend][key] + try: + obj = stats[backend][key][ordering] + except: + continue + times = np.array(obj['times']) / 1e9 + line_style = '-' if ordering == 'seq' else ':' + label = obj['label'] + plt.plot(times, + np.arange(len(times)), + c=color, + ls=line_style, + lw=line_width, + label=label) + + plt.legend() + plt.grid(which='major', ls='--', c='#ddd') + plt.savefig(args.plot, dpi=600) + + +if __name__ == '__main__': + main(_parse_args()) diff --git a/benchmarks/backends/read.py b/benchmarks/backends/read.py new file mode 100644 index 000000000..7bb86a0b0 --- /dev/null +++ b/benchmarks/backends/read.py @@ -0,0 +1,440 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Benchmark dataset iteration time.""" + +import json +import os +from argparse import ArgumentParser, Namespace +from time import time +from typing import Any, Dict, Iterator, List, Tuple + +import lance +import numpy as np +from lance import LanceDataset +from numpy.typing import NDArray +from pyarrow import parquet as pq +from pyarrow.parquet import ParquetFile +from tqdm import tqdm, trange + +from streaming import StreamingDataset + + +def _parse_args() -> Namespace: + """Parse command-line arguments. + + Returns: + Namespace: Command-line arguments. + """ + args = ArgumentParser() + args.add_argument('--data_root', type=str, default='data/backends/gold/') + args.add_argument('--split', type=str, default='medium') + args.add_argument('--lance_pow_interval', type=int, default=4) + args.add_argument('--parquet_suffix', type=str, default='.parquet') + args.add_argument('--progress_bar', type=int, default=1) + args.add_argument('--time_limit', type=float, default=180) + args.add_argument('--out', type=str, default='data/backends/stats.json') + return args.parse_args() + + +def _bench_lance_seq(dataset: LanceDataset, take_count: int, show_progress: bool, + time_limit: float) -> NDArray[np.float64]: + """Benchmark iterating a Lance dataset in sequential order. + + Args: + dataset (LanceDataset): The Lance dataset to iterate. + take_count (int): How many samples to take per sequential access. + show_progress (bool): Whether to show a progress bar. + time_limit (float): Benchmarking cutoff time. + + Returns: + NDArray[np.float64]: Time taken to process that many dataset samples. 
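+
+    Example (a rough sketch; assumes the Lance copy of the split written by write.py):
+        dataset = lance.dataset('data/backends/gold/lance/medium')
+        times = _bench_lance_seq(dataset, take_count=256, show_progress=True, time_limit=60.0)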
+ """ + num_samples = dataset.count_rows() + if num_samples % take_count: + raise ValueError(f'`num_samples` ({num_samples}) must be divisible by `take_count` ' + + f'({take_count}).') + num_batches = num_samples // take_count + shape = num_batches, take_count + times = np.zeros(shape, np.float64) + sample, = dataset.head(1).to_pylist() + columns = sorted(sample) + each_batch = enumerate(dataset.to_batches(columns=columns, batch_size=take_count)) + if show_progress: + each_batch = tqdm(each_batch, total=num_batches, leave=False) + t0 = time() + for i, samples in each_batch: + samples.to_pylist() + assert len(samples) == take_count + if num_batches < i: # ??? + break + times[i] = t = time() - t0 + if time_limit <= t: + times = times[:i] + break + return times.flatten() + + +def _bench_lance_rand(dataset: LanceDataset, take_count: int, show_progress: bool, + time_limit: float) -> NDArray[np.float64]: + """Benchmark iterating a Lance dataset in random order. + + Args: + dataset (LanceDataset): The Lance dataset to iterate. + take_count (int): How many samples to take per random access. + show_progress (bool): Whether to show a progress bar. + time_limit (float): Benchmarking cutoff time. + + Returns: + NDArray[np.float64]: Time taken to process that many dataset samples. + """ + num_samples = dataset.count_rows() + if num_samples % take_count: + raise ValueError(f'`num_samples` ({num_samples}) must be divisible by `take_count` ' + + f'({take_count}).') + shape = num_samples // take_count, take_count + times = np.zeros(shape, np.float64) + batches = np.random.permutation(num_samples).reshape(shape) + if show_progress: + batches = tqdm(batches, leave=False) + t0 = time() + for i, sample_ids in enumerate(batches): + dataset.take(sample_ids).to_pylist() + times[i] = t = time() - t0 + if time_limit <= t: + times = times[:i] + break + return times.flatten() + + +def _each_parquet(dataset_root: str, parquet_suffix: str) -> Iterator[str]: + """Iteracte over each Parquet shard file of the dataset in order. + + Args: + dataset_root (str): Dataset root directory. + parquet_suffix (str): Parquet shard file suffix. + + Returns: + Iterator[str]: Each Parquet shard file. + """ + ret = [] + for parent, _, file_basenames in os.walk(dataset_root): + file_basenames = filter(lambda basename: basename.endswith(parquet_suffix), file_basenames) + ret += [os.path.join(parent, basename) for basename in file_basenames] + yield from sorted(ret) + + +def _get_parquet_mapping(dataset_dir: str, parquet_suffix: str) -> \ + Tuple[List[str], NDArray[np.int64]]: + """Get a mapping of sample ID to (shard ID, relative sample ID within that shard). + + Args: + dataset_dir (str): Parquet dataset directory. + parquet_suffix (str): Parquet shard file suffix. + + Returns: + Tuple[List[str], NDArray[np.int64]]: Filenames and mapping. + """ + filenames = list(_each_parquet(dataset_dir, parquet_suffix)) + mapping = [] + for file_id, filename in enumerate(filenames): + file = ParquetFile(filename) + mapping += list(zip([file_id] * file.metadata.num_rows, range(file.metadata.num_rows))) + mapping = np.array(mapping) + return filenames, mapping + + +def _bench_parquet_seq(dataset_dir: str, parquet_suffix: str, show_progress: bool, + time_limit: float) -> NDArray[np.float64]: + """Benchmark iterating a StreamingDataset in sequential order. + + Args: + dataset_dir (str): Parquet dataset directory. + parquet_suffix (str): Parquet shard file suffix. + show_progress (bool): Whether to show a progress bar. 
+ time_limit (float): Benchmarking cutoff time. + + Returns: + NDArray[np.float64]: Time taken to process that many dataset samples. + """ + filenames, mapping = _get_parquet_mapping(dataset_dir, parquet_suffix) + num_samples = len(mapping) + times = np.zeros(num_samples, np.float64) + progress_bar = tqdm(total=num_samples, leave=False) if show_progress else None + i = 0 + t0 = time() + for filename in filenames: + table = pq.read_table(filename) + for _ in table.to_pylist(): + times[i] = t = time() - t0 + if time_limit <= t: + return times[:i] + i += 1 + if progress_bar: + progress_bar.update(1) + return times + + +def _bench_parquet_rand(dataset_dir: str, parquet_suffix: str, show_progress: bool, + time_limit: float) -> NDArray[np.float64]: + """Benchmark iterating a StreamingDataset in random order. + + Args: + dataset_dir (str): Parquet dataset directory. + parquet_suffix (str): Parquet shard file suffix. + show_progress (bool): Whether to show a progress bar. + time_limit (float): Benchmarking cutoff time. + + Returns: + NDArray[np.float64]: Time taken to process that many dataset samples. + """ + filenames, mapping = _get_parquet_mapping(dataset_dir, parquet_suffix) + num_samples = len(mapping) + indices = np.random.permutation(num_samples) + times = np.zeros(num_samples, np.float64) + progress_bar = tqdm(total=num_samples, leave=False) if show_progress else None + t0 = time() + for i, sample_id in enumerate(indices): + file_id, shard_sample_id = mapping[sample_id] + filename = filenames[file_id] + table = pq.read_table(filename) + shard_samples = table.to_pylist() + shard_samples[shard_sample_id] + times[i] = t = time() - t0 + if progress_bar: + progress_bar.update(1) + if time_limit <= t: + times = times[:i] + break + return times + + +def _clear_mds_files(dataset_root: str) -> None: # pyright: ignore + """Clear the intermediate MDS shard files. + + Args: + dataset_root (str): Dataset root directoyr. + """ + for parent, _, file_basenames in os.walk(dataset_root): + for basename in file_basenames: + if basename.endswith('.mds'): + filename = os.path.join(parent, basename) + os.remove(filename) + + +def _bench_streaming_seq(dataset: StreamingDataset, show_progress: bool, + time_limit: float) -> NDArray[np.float64]: + """Benchmark iterating a StreamingDataset in sequential order. + + Args: + dataset (StreamingDataset): The streaming dataset to iterate. + show_progress (bool): Whether to show a progress bar. + time_limit (float): Benchmarking cutoff time. + + Returns: + NDArray[np.float64]: Time taken to process that many dataset samples. + """ + times = np.zeros(dataset.num_samples, np.float64) + xrange = trange(dataset.num_samples, leave=False) if show_progress else \ + range(dataset.num_samples) + t0 = time() + for i in xrange: + dataset[i] + times[i] = t = time() - t0 + if time_limit <= t: + times = times[:i] + break + return times + + +def _bench_streaming_rand(dataset: StreamingDataset, show_progress: bool, + time_limit: float) -> NDArray[np.float64]: + """Benchmark iterating a StreamingDataset in random order. + + Args: + dataset (StreamingDataset): The streaming dataset to iterate. + show_progress (bool): Whether to show a progress bar. + time_limit (float): Benchmarking cutoff time. + + Returns: + NDArray[np.float64]: Time taken to process that many dataset samples. 
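+
+    Example (a rough sketch; assumes the MDS copy of the split written by write.py):
+        dataset = StreamingDataset(local='data/backends/gold/mds/medium')
+        times = _bench_streaming_rand(dataset, show_progress=True, time_limit=60.0)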
+ """ + indices = np.random.permutation(dataset.num_samples) + times = np.zeros(dataset.num_samples) + if show_progress: + indices = tqdm(indices, leave=False) + t0 = time() + for i, sample_id in enumerate(indices): + dataset[sample_id] + times[i] = t = time() - t0 + if time_limit <= t: + times = times[:i] + break + return times + + +def _to_dict(label: str, times: NDArray[np.float64]) -> Dict[str, Any]: + """Convert a label and sample latencies ndarray into an interpretable JSON dict. + + Args: + label (str): Name of this run. + times (NDArray[np.float64]): Sample access times ndarray. + + Returns: + Dict[str, Any]: JSON dict of interpretable metadata. + """ + rate = int(len(times) / times[-1]) + label = f'{label}: {rate:,}/s' + print(label) + return { + 'label': label, + 'rate': rate, + 'times': (times * 1e9).astype(np.int64).tolist(), + } + + +def _bench_streaming_format(data_root: str, shard_format: str, split: str, show_progress: bool, + time_limit: float) -> Dict[str, Any]: + """Benchmark the performance of a native Stremaing format (e.g., MDS, JSONL, CSV). + + Args: + data_root (str): Data root directory. + shard_format (str): Streaming format name. + split (str): Split name. + show_progress (bool): Whether to show a progress bar. + time_limit (float): Benchmarking cutoff time. + + Returns: + Dict[str, Any]: Mapping of ordering name to benchmark metadata JSON dict. + """ + dataset_dir = os.path.join(data_root, shard_format, split) + dataset = StreamingDataset(local=dataset_dir) + + times = _bench_streaming_seq(dataset, show_progress, time_limit) + seq = _to_dict(f'Streaming {shard_format.upper()} seq', times) + + times = _bench_streaming_rand(dataset, show_progress, time_limit) + rand = _to_dict(f'Streaming {shard_format.upper()} rand', times) + + return {'seq': seq, 'rand': rand} + + +def _bench_streaming(data_root: str, split: str, show_progress: bool, + time_limit: float) -> Dict[str, Any]: + """Benchmark the performance of all native Streaming formats. + + Args: + data_root (str): Data root directory. + split (str): Split name. + show_progress (bool): Whether to show a progress bar. + time_limit (float): Benchmarking cutoff time. + + Returns: + Dict[str, Any]: Mapping of format to ordering to benchmark metadata JSON dict. + """ + mds = _bench_streaming_format(data_root, 'mds', split, show_progress, time_limit) + csv = _bench_streaming_format(data_root, 'csv', split, show_progress, time_limit) + jsonl = _bench_streaming_format(data_root, 'jsonl', split, show_progress, time_limit) + return {'mds': mds, 'csv': csv, 'jsonl': jsonl} + + +def _bench_parquet(data_root: str, split: str, parquet_suffix: str, show_progress: bool, + time_limit: float) -> Dict[str, Any]: + """Benchmark the performance of Parquet and Streaming Parquet. + + Args: + data_root (str): Data root directory. + split (str): Split name. + parquet_suffix (str): Parquet filename suffix. + show_progress (bool): Whether to show a progress bar. + time_limit (float): Benchmarking cutoff time. + + Returns: + Dict[str, Any]: Mapping of benchmark name to ordering to benchmark metadata JSON dict. 
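+
+    Shape of the returned dict (values elided; the Streaming Parquet cold/warm passes are
+    disabled in the body below):
+        {'native': {'seq': {'label': ..., 'rate': ..., 'times': [...]},
+                    'rand': {'label': ..., 'rate': ..., 'times': [...]}}}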
+ """ + dataset_dir = os.path.join(data_root, 'parquet', split) + + times = _bench_parquet_seq(dataset_dir, parquet_suffix, show_progress, time_limit) + seq = _to_dict('Parquet seq (in mem)', times) + times = _bench_parquet_rand(dataset_dir, parquet_suffix, show_progress, time_limit) + rand = _to_dict('Parquet rand (in mem)', times) + native = {'seq': seq, 'rand': rand} + """ + streaming_dataset = StreamingDataset(local=dataset_dir) + + times = _bench_streaming_seq(streaming_dataset, show_progress, time_limit) + seq = _to_dict('Streaming Parquet seq (cold)', times) + _clear_mds_files(dataset_dir) + times = _bench_streaming_rand(streaming_dataset, show_progress, time_limit) + rand = _to_dict('Streaming Parquet rand (cold)', times) + cold = {'seq': seq, 'rand': rand} + + times = _bench_streaming_seq(streaming_dataset, show_progress, time_limit) + seq = _to_dict('Streaming Parquet seq (cached)', times) + times = _bench_streaming_rand(streaming_dataset, show_progress, time_limit) + rand = _to_dict('Streaming Parquet rand (cached)', times) + warm = {'seq': seq, 'rand': rand} + """ + + return {'native': native} + + +def _bench_lance(data_root: str, split: str, show_progress: bool, time_limit: float, + pow_interval: int) -> Dict[str, Any]: + """Benchmark the performance of Lance and, someday, Streaming Lance. + + Args: + data_root (str): Data root directory. + split (str): Split name. + show_progress (bool): Whether to show a progress bar. + time_limit (float): Benchmarking cutoff time. + pow_interval (int): Take count exponent interval. Must be either ``2`` or ``4``. + + Returns: + Dict[str, Any]: Mapping of take count to ordering to benchmark metadata JSON dict. + """ + if pow_interval == 4: + take_counts = 1, 4, 16, 64, 256, 1024 + elif pow_interval == 2: + take_counts = 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024 + else: + raise ValueError(f'Unsupported --lance_pow_interval: {pow_interval} (must be 2 or 4).') + + dataset_dir = os.path.join(data_root, 'lance', split) + lance_dataset = lance.dataset(dataset_dir) + + ret = {} + + for take_count in take_counts: + times = _bench_lance_seq(lance_dataset, take_count, show_progress, time_limit) + ret[take_count] = {} + ret[take_count]['seq'] = _to_dict(f'Lance seq x{take_count:04}', times) + + for take_count in take_counts: + times = _bench_lance_rand(lance_dataset, take_count, show_progress, time_limit) + ret[take_count]['rand'] = _to_dict(f'Lance rand x{take_count:04}', times) + + return ret + + +def main(args: Namespace) -> None: + """Randomly iterate over a Parquet dataset with Streaming. + + Args: + args (Namespace): Command-line arguments. 
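+
+    Example invocation (the values shown are just the argparse defaults above):
+        python benchmarks/backends/read.py --data_root data/backends/gold/ \
+            --split medium --time_limit 180 --out data/backends/stats.json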
+ """ + show_progress = bool(args.progress_bar) + + streaming_info = _bench_streaming(args.data_root, args.split, show_progress, args.time_limit) + parquet_info = _bench_parquet(args.data_root, args.split, args.parquet_suffix, show_progress, + args.time_limit) + lance_info = _bench_lance(args.data_root, args.split, show_progress, args.time_limit, + args.lance_pow_interval) + info = {'streaming': streaming_info, 'parquet': parquet_info, 'lance': lance_info} + + with open(args.out, 'w') as out: + json.dump(info, out) + + +if __name__ == '__main__': + main(_parse_args()) diff --git a/benchmarks/backends/write.py b/benchmarks/backends/write.py new file mode 100644 index 000000000..9404e5d0a --- /dev/null +++ b/benchmarks/backends/write.py @@ -0,0 +1,387 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Generate a synthetic dataset and serialize it using each Streaming format/backend.""" + +import os +from argparse import ArgumentParser, Namespace +from collections import defaultdict +from functools import partial +from shutil import rmtree +from time import time +from typing import Dict, Iterable, List, Optional, Tuple + +import lance +import pyarrow as pa +import pyspark +import pyspark.sql +from delta import configure_spark_with_delta_pip +from pyarrow import parquet as pq +from pyspark.sql.types import IntegerType, StringType, StructField, StructType +from tqdm import tqdm +from wurlitzer import pipes + +from benchmarks.backends.datagen import generate +from streaming import CSVWriter, JSONWriter, MDSWriter +from streaming.util.tabulation import Tabulator + + +def _parse_args() -> Namespace: + """Parse command-line arguments. + + Returns: + Namespace: Command-line arguments. + """ + args = ArgumentParser() + + # Reproducibility. + args.add_argument('--seed', type=int, default=1337) + + # Dataset data distribution. + args.add_argument('--data_pos_prob', type=float, default=0.75) + args.add_argument('--data_low', type=int, default=-1_000_000_000) + args.add_argument('--data_high', type=int, default=1_000_000_000) + + # Sizes of datasets splits and shards. + args.add_argument('--small', type=int, default=1 << 15) + args.add_argument('--medium', type=int, default=1 << 20) + args.add_argument('--large', type=int, default=1 << 25) + args.add_argument('--size_limit', type=int, default=1 << 23) + args.add_argument('--samples_per_shard', type=int, default=1 << 18) + + # Outputs. + args.add_argument('--data_root', type=str, default='data/backends/') + args.add_argument('--formats', type=str, default='csv,delta,jsonl,lance,mds,parquet') + + # Introspection. + args.add_argument('--show_progress', type=int, default=1) + args.add_argument('--quiet_delta', type=int, default=1) + + return args.parse_args() + + +def _write_csv(nums: List[int], + txts: List[str], + root: str, + size_limit: Optional[int], + show_progress: bool = True) -> None: + """Save the dataset in Streaming CSV form. + + Args: + nums (List[int]): The sample numbers. + txts (List[str]): The sample texts. + root (str): Root directory. + size_limit (int, optional): Maximum shard size in bytes, or no limit. + show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``. 
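+
+    Example (minimal sketch; the output directory is hypothetical):
+        _write_csv([7, 42], ['seven', 'forty two'], '/tmp/backends_demo/csv',
+                   size_limit=1 << 23, show_progress=False)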
+ """ + columns = { + 'num': 'int', + 'txt': 'str', + } + with CSVWriter(out=root, columns=columns, size_limit=size_limit) as out: + each_sample = zip(nums, txts) + if show_progress: + each_sample = tqdm(each_sample, total=len(nums), leave=False) + for num, txt in each_sample: + sample = { + 'num': num, + 'txt': txt, + } + out.write(sample) + + +def _write_jsonl(nums: List[int], + txts: List[str], + root: str, + size_limit: Optional[int], + show_progress: bool = True) -> None: + """Save the dataset Streaming JSONL form. + + Args: + nums (List[int]): The sample numbers. + txts (List[str]): The sample texts. + root (str): Root directory. + size_limit (int, optional): Maximum shard size in bytes, or no limit. + show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``. + """ + columns = { + 'num': 'int', + 'txt': 'str', + } + with JSONWriter(out=root, columns=columns, size_limit=size_limit) as out: + each_sample = zip(nums, txts) + if show_progress: + each_sample = tqdm(each_sample, total=len(nums), leave=False) + for num, txt in each_sample: + sample = { + 'num': num, + 'txt': txt, + } + out.write(sample) + + +def _write_mds(nums: List[int], + txts: List[str], + root: str, + size_limit: Optional[int], + show_progress: bool = True) -> None: + """Save the dataset in Streaming MDS form. + + Args: + nums (List[int]): The sample numbers. + txts (List[str]): The sample texts. + root (str): Root directory. + size_limit (int, optional): Maximum shard size in bytes, or no limit. + show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``. + """ + columns = { + 'num': 'int', + 'txt': 'str', + } + with MDSWriter(out=root, columns=columns, size_limit=size_limit) as out: + each_sample = zip(nums, txts) + if show_progress: + each_sample = tqdm(each_sample, total=len(nums), leave=False) + for num, txt in each_sample: + sample = { + 'num': num, + 'txt': txt, + } + out.write(sample) + + +def _write_parquet(nums: List[int], + txts: List[str], + root: str, + samples_per_shard: int, + show_progress: bool = True) -> None: + """Save the dataset in Streaming MDS form. + + Args: + nums (List[int]): The sample numbers. + txts (List[str]): The sample texts. + root (str): Root directory. + samples_per_shard (int): Maximum numbero of samples per shard. + show_progress (bool): Whether to show a progress bar while saving. Defaults to ``True``. + """ + if not os.path.exists(root): + os.makedirs(root) + num_samples = len(nums) + num_shards = (num_samples + samples_per_shard - 1) // samples_per_shard + each_shard = range(num_shards) + if show_progress: + each_shard = tqdm(each_shard, total=num_shards, leave=False) + for i in each_shard: + begin = i * samples_per_shard + end = min(begin + samples_per_shard, num_samples) + shard_nums = nums[begin:end] + shard_txts = txts[begin:end] + path = os.path.join(root, f'{i:05}.parquet') + obj = { + 'num': shard_nums, + 'txt': shard_txts, + } + table = pa.Table.from_pydict(obj) + pq.write_table(table, path) + + +def _write_delta(nums: List[int], txts: List[str], root: str, samples_per_shard: int) -> None: + """Save the dataset in Streaming MDS form. + + Args: + nums (List[int]): The sample numbers. + txts (List[str]): The sample texts. + root (str): Root directory. + samples_per_shard (int): Maximum numbero of samples per shard. 
+ """ + builder = pyspark.sql.SparkSession.builder.appName('prolix') # pyright: ignore + builder = builder.config('spark.sql.extensions', 'io.delta.sql.DeltaSparkSessionExtension') + builder = builder.config('spark.sql.catalog.spark_catalog', + 'org.apache.spark.sql.delta.catalog.DeltaCatalog') + spark = configure_spark_with_delta_pip(builder).getOrCreate() + schema = StructType([ + StructField('num', IntegerType(), False), + StructField('txt', StringType(), False), + ]) + samples = list(zip(nums, txts)) + df = spark.createDataFrame(samples, schema) + df.write.format('delta').option('maxRecordsPerFile', samples_per_shard).save(root) + + +def _do_write_delta(nums: List[int], + txts: List[str], + root: str, + samples_per_shard: int, + quietly: bool = True) -> None: + """Save the dataset in Streaming MDS form, possibly capturing stdout/stderr. + + Args: + nums (List[int]): The sample numbers. + txts (List[str]): The sample texts. + root (str): Root directory. + samples_per_shard (int): Maximum numbero of samples per shard. + quietly (bool): Whether to capture the Delta logging. Defaults to ``True``. + """ + write = lambda: _write_delta(nums, txts, root, samples_per_shard) + if quietly: + with pipes(): + write() + else: + write() + + +def _write_lance(nums: List[int], txts: List[str], root: str, samples_per_shard: int) -> None: + """Save the dataset in Lance form. + + Args: + nums (List[int]): The sample numbers. + txts (List[str]): The sample texts. + root (str): Root directory. + samples_per_shard (int): Maximum numbero of samples per shard. + """ + column_names = 'num', 'txt' + column_values = nums, txts + table = pa.Table.from_arrays(column_values, column_names) + lance.write_dataset(table, root, mode='create', max_rows_per_file=samples_per_shard) + + +def _get_file_sizes(root: str) -> List[int]: + """Inventory what was written, collecting total files and total bytes. + + Args: + root (str): Dataset root. + + Returns: + Tuple[int, int]: Total files and total bytes written. + """ + sizes = [] + for parent, _, file_basenames in sorted(os.walk(root)): + for basename in sorted(file_basenames): + path = os.path.join(parent, basename) + size = os.stat(path).st_size + sizes.append(size) + return sizes + + +def _splits_by_size(dataset: Dict[str, Tuple[List[int], List[str]]]) -> Iterable[str]: + """Order a dataset's splits by their size in samples, then by name. + + Argxs: + dataset (Dict[str, Tuple[List[int], List[str]]]): Mapping of split name to split data. + + Returns: + Iterable[str]: Ordered split names. + """ + size2splits = defaultdict(list) + for split, (nums, _) in dataset.items(): + size2splits[len(nums)].append(split) + + splits_by_size = [] + for size in sorted(size2splits): + for split in sorted(size2splits[size]): + splits_by_size.append(split) + + return splits_by_size + + +def main(args: Namespace) -> None: + """Generate identical datasets in various formats for performance comparison. + + Args: + args (Namespace): Command-line arguments. + """ + # Confgure the dataset writing statistics table printer. + table_columns = ''' + < format 8 + > sec 7 + > samples 12 + > usec/sp 8 + > bytes 14 + > files 6 + > bytes/file 12 + > max bytes/file 14 + ''' + table_indent = 4 + table = Tabulator.from_conf(table_columns, table_indent * ' ') + + # Normalize arguments. + format_names = args.formats.split(',') if args.formats else [] + show_progress = bool(args.show_progress) + quiet_delta = bool(args.quiet_delta) + + # Given args, now we know how to configure saving the dataset in each format. 
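+    # Each writer is wrapped in functools.partial so it can be invoked uniformly below as
+    # ``write(nums, txts, split_root)``; the format-specific knobs (size_limit,
+    # samples_per_shard, progress/quiet flags) are bound here.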
+ format2write = { + 'csv': + partial(_write_csv, size_limit=args.size_limit, show_progress=show_progress), + 'delta': + partial(_do_write_delta, quietly=quiet_delta, + samples_per_shard=args.samples_per_shard), + 'jsonl': + partial(_write_jsonl, size_limit=args.size_limit, show_progress=show_progress), + 'lance': + partial(_write_lance, samples_per_shard=args.samples_per_shard), + 'mds': + partial(_write_mds, size_limit=args.size_limit, show_progress=show_progress), + 'parquet': + partial(_write_parquet, + samples_per_shard=args.samples_per_shard, + show_progress=show_progress), + } + + # Collect sizes of the splits to generate. + split2size = { + 'small': args.small, + 'medium': args.medium, + 'large': args.large, + } + + # Generate the dataset samples. + t0 = time() + dataset = generate(split2size, args.seed, args.data_pos_prob, args.data_low, args.data_high, + show_progress) + elapsed = time() - t0 + print(f'Generate: {elapsed:.3f} sec.') + + # Wipe output directory if exists. + if os.path.exists(args.data_root): + print(f'Found directory at {args.data_root}, wiping it for reuse') + rmtree(args.data_root) + + # Write each split in each desired formats, in order of size. + pretty_int = lambda num: f'{num:,}' + for split in _splits_by_size(dataset): + print() + print(f'Write split: {split}') + print(table.draw_line()) + print(table.draw_header()) + print(table.draw_line()) + + nums, txts = dataset[split] + for format_name in format_names: + split_root = os.path.join(args.data_root, 'gold', format_name, split) + write = format2write[format_name] + + t0 = time() + try: + write(nums, txts, split_root) + except: + continue # Getting Delta Java OOMs at gigabyte size. + elapsed = time() - t0 + + file_sizes = _get_file_sizes(split_root) + row = { + 'format': format_name, + 'sec': f'{elapsed:.3f}', + 'samples': pretty_int(len(nums)), + 'usec/sp': f'{1e6 * elapsed / len(nums):.3f}', + 'bytes': pretty_int(sum(file_sizes)), + 'files': pretty_int(len(file_sizes)), + 'bytes/file': pretty_int(sum(file_sizes) // len(file_sizes)), + 'max bytes/file': pretty_int(max(file_sizes)), + } + print(table.draw_row(row)) + print(table.draw_line()) + + +if __name__ == '__main__': + main(_parse_args()) diff --git a/streaming/util/__init__.py b/streaming/util/__init__.py index 209bf5662..d17cf214d 100644 --- a/streaming/util/__init__.py +++ b/streaming/util/__init__.py @@ -10,9 +10,10 @@ from streaming.util.shorthand import (get_list_arg, get_str2str_arg, normalize_bin_bytes, normalize_bytes, normalize_count, normalize_dec_bytes, normalize_duration) +from streaming.util.tabulation import Tabulator __all__ = [ 'get_import_exception_message', 'redirect_imports', 'merge_index', 'retry', 'clean_stale_shared_memory', 'get_list_arg', 'get_str2str_arg', 'normalize_dec_bytes', - 'normalize_bin_bytes', 'normalize_bytes', 'normalize_count', 'normalize_duration' + 'normalize_bin_bytes', 'normalize_bytes', 'normalize_count', 'normalize_duration', 'Tabulator' ] diff --git a/streaming/util/tabulation.py b/streaming/util/tabulation.py new file mode 100644 index 000000000..179194d6c --- /dev/null +++ b/streaming/util/tabulation.py @@ -0,0 +1,125 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Line by line text table printer.""" + +from typing import Any, Dict, List, Optional, Tuple + +from typing_extensions import Self + +__all__ = ['Tabulator'] + + +class Tabulator: + """Line by line text table printer. 
+ + Example: + conf = ''' + < format 8 + > sec 6 + > samples 12 + > usec/sp 8 + > bytes 14 + > files 6 + > bytes/file 12 + > max bytes/file 14 + ''' + left = 4 * ' ' + tab = Tabulator.from_conf(conf, left) + + Args: + cols (List[Tuple[str, str, int]]: Each column config (i.e., just, name, width). + left (str, optional): Print this before each line (e.g., indenting). Defaults to ``None``. + """ + + def __init__(self, cols: List[Tuple[str, str, int]], left: Optional[str] = None) -> None: + self.cols = cols + self.col_justs = [] + self.col_names = [] + self.col_widths = [] + for just, name, width in cols: + if just not in {'<', '>'}: + raise ValueError(f'Invalid justify (must be one of "<" or ">"): {just}.') + + if not name: + raise ValueError('Name must be non-empty.') + elif width < len(name): + raise ValueError(f'Name is too wide for its column width: {width} vs {name}.') + + if width <= 0: + raise ValueError(f'Width must be positive, but got: {width}.') + + self.col_justs.append(just) + self.col_names.append(name) + self.col_widths.append(width) + + self.left = left + + self.box_chr_horiz = chr(0x2500) + self.box_chr_vert = chr(0x2502) + + @classmethod + def from_conf(cls, conf: str, left: Optional[str] = None) -> Self: + """Initialize a Tabulator from a text table defining its columns. + + Args: + conf (str): The table config. + left (str, optional): Optional string that is printed before each line (e.g., indents). + """ + cols = [] + for line in conf.strip().split('\n'): + words = line.split() + + if len(words) < 3: + raise ValueError(f'Invalid col config (must be "just name width"): {line}.') + + just = words[0] + name = ' '.join(words[1:-1]) + width = int(words[-1]) + cols.append((just, name, width)) + return cls(cols, left) + + def draw_row(self, row: Dict[str, Any]) -> str: + """Draw a row, given a mapping of column name to field value. + + Args: + row (Dict[str, Any]): Mapping of column name to field value. + + Returns: + str: Text line. + """ + fields = [] + for just, name, width in self.cols: + val = row[name] + + txt = val if isinstance(val, str) else str(val) + if width < len(txt): + raise ValueError(f'Field is too wide for its column: column (just: {just}, ' + + f'name: {name}, width: {width}) vs field {txt}.') + + txt = txt.ljust(width) if just == '<' else txt.rjust(width) + fields.append(txt) + + left_txt = self.left or '' + fields_txt = f' {self.box_chr_vert} '.join(fields) + return f'{left_txt}{self.box_chr_vert} {fields_txt} {self.box_chr_vert}' + + def draw_header(self) -> str: + """Draw a header row. + + Returns: + str: Text line. + """ + row = dict(zip(self.col_names, self.col_names)) + return self.draw_row(row) + + def draw_line(self) -> str: + """Draw a divider row. + + Returns: + str: Text line. + """ + seps = (self.box_chr_horiz * width for width in self.col_widths) + row = dict(zip(self.col_names, seps)) + line = self.draw_row(row) + return line.replace(self.box_chr_vert, self.box_chr_horiz) From 78c150ec926fb938994f4d153a4c464097a6e8ec Mon Sep 17 00:00:00 2001 From: James Knighton Date: Thu, 14 Dec 2023 20:30:30 -0800 Subject: [PATCH 07/12] Replicate allow_unsafe_types in dev, where it will be used more. 
(#534) --- simulation/core/sim_dataset.py | 9 ++++-- streaming/dataset.py | 9 ++++-- streaming/format/mds/encodings.py | 21 +++++++++++- streaming/format/mds/reader.py | 17 +++++++++- streaming/format/reader.py | 10 ++++++ streaming/stream.py | 6 +++- tests/test_unsafe_types.py | 53 +++++++++++++++++++++++++++++++ 7 files changed, 118 insertions(+), 7 deletions(-) create mode 100644 tests/test_unsafe_types.py diff --git a/simulation/core/sim_dataset.py b/simulation/core/sim_dataset.py index cd1b7340f..c3859fb6d 100644 --- a/simulation/core/sim_dataset.py +++ b/simulation/core/sim_dataset.py @@ -99,6 +99,9 @@ class SimulationDataset(StreamingDataset): Defaults to ``1``. batching_method (str): Which batching method to use, either ``random``, ``stratified``, or ``per_stream``. Defaults to ``random``. + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an error + if ``False``. Defaults to ``False``. """ def __init__(self, @@ -125,7 +128,8 @@ def __init__(self, shuffle_block_size: Optional[int] = None, sampling_method: str = 'balanced', sampling_granularity: int = 1, - batching_method: str = 'random') -> None: + batching_method: str = 'random', + allow_unsafe_types: bool = False) -> None: # Time how long it takes for StreamingDataset instantiation t0 = time.time() @@ -145,6 +149,7 @@ def __init__(self, self.sampling_granularity = sampling_granularity self.batching_method = batching_method self.num_canonical_nodes = num_canonical_nodes + self.allow_unsafe_types = allow_unsafe_types self.initial_physical_nodes = nodes @@ -260,7 +265,7 @@ def __init__(self, local_foldernames = [] for stream_id, stream in enumerate(self.streams): logger.info(f' Processing index file for stream {stream_id + 1}') - stream_shards = stream.get_shards(self.world) + stream_shards = stream.get_shards(self.world, self.allow_unsafe_types) num_stream_samples = sum(map(len, stream_shards)) index_filename = os.path.join(stream.local, stream.split, get_index_basename()) index_filenames.append(index_filename) diff --git a/streaming/dataset.py b/streaming/dataset.py index a9500163a..e320d7426 100644 --- a/streaming/dataset.py +++ b/streaming/dataset.py @@ -303,6 +303,9 @@ class StreamingDataset(Array, IterableDataset): ``None``. batching_method (str): Which batching method to use, either ``random``, ``stratified``, or ``per_stream``. Defaults to ``random``. + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an error + if ``False``. Defaults to ``False``. """ def __init__(self, @@ -327,7 +330,8 @@ def __init__(self, shuffle_algo: str = 'py1e', shuffle_seed: int = 9176, shuffle_block_size: Optional[int] = None, - batching_method: str = 'random') -> None: + batching_method: str = 'random', + allow_unsafe_types: bool = False) -> None: # Global arguments (which do not live in Streams). self.predownload = predownload self.sampling_method = sampling_method @@ -340,6 +344,7 @@ def __init__(self, self.shuffle_seed = shuffle_seed self.shuffle_block_size = shuffle_block_size self.batching_method = batching_method + self.allow_unsafe_types = allow_unsafe_types # Initialize initial_physical_nodes to None. If we are resuming, then we will set it to the # number of physical nodes of the initial run in the _resume function. 
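For context, the flag threads through to dataset construction; a minimal opt-in sketch
(the local path is hypothetical, and mirrors the tests added at the end of this patch):

    from streaming import StreamingDataset

    # Only enable this for shards whose provenance you trust: 'pkl' columns can run
    # arbitrary code when deserialized.
    dataset = StreamingDataset(local='/tmp/my_mds_dataset', allow_unsafe_types=True)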
@@ -447,7 +452,7 @@ def __init__(self, self.sample_offset_per_stream = np.zeros(self.num_streams, np.int64) self.samples_per_stream = np.zeros(self.num_streams, np.int64) for stream_id, stream in enumerate(self.streams): - stream_shards = stream.get_shards(world) + stream_shards = stream.get_shards(world, self.allow_unsafe_types) num_stream_samples = sum(map(len, stream_shards)) if not num_stream_samples: index_filename = os.path.join(stream.local, stream.split, get_index_basename()) diff --git a/streaming/format/mds/encodings.py b/streaming/format/mds/encodings.py index eede8cd47..11df25f1c 100644 --- a/streaming/format/mds/encodings.py +++ b/streaming/format/mds/encodings.py @@ -17,7 +17,12 @@ from typing_extensions import Self __all__ = [ - 'get_mds_encoded_size', 'get_mds_encodings', 'is_mds_encoding', 'mds_decode', 'mds_encode' + 'get_mds_encoded_size', + 'get_mds_encodings', + 'is_mds_encoding', + 'mds_decode', + 'mds_encode', + 'is_mds_encoding_safe', ] @@ -543,6 +548,8 @@ def _is_valid(self, original: Any, converted: Any) -> None: 'json': JSON, } +_unsafe_encodings = {'pkl'} + def get_mds_encodings() -> Set[str]: """List supported encodings. @@ -586,6 +593,18 @@ def is_mds_encoding(encoding: str) -> bool: return coder is not None +def is_mds_encoding_safe(encoding: str) -> bool: + """Get whether the given encoding is safe (does not allow arbitrary code execution). + + Args: + encoding (str): Encoding. + + Returns: + bool: Whether the encoding is safe. + """ + return encoding not in _unsafe_encodings + + def mds_encode(encoding: str, obj: Any) -> bytes: """Encode the given data from the original object to bytes. diff --git a/streaming/format/mds/reader.py b/streaming/format/mds/reader.py index 245458bf4..7ec93c98a 100644 --- a/streaming/format/mds/reader.py +++ b/streaming/format/mds/reader.py @@ -10,7 +10,7 @@ import numpy as np from typing_extensions import Self -from streaming.format.mds.encodings import mds_decode +from streaming.format.mds.encodings import is_mds_encoding_safe, mds_decode from streaming.format.reader import FileInfo, JointReader __all__ = ['MDSReader'] @@ -84,6 +84,21 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S args[key] = FileInfo(**arg) if arg else None return cls(**args) + def validate(self, allow_unsafe_types: bool) -> None: + """Check whether this shard is acceptable to be part of some Stream. + + Args: + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an + error if ``False``. + """ + if not allow_unsafe_types: + for column_id, encoding in enumerate(self.column_encodings): + if not is_mds_encoding_safe(encoding): + name = self.column_names[column_id] + raise ValueError(f'Column {name} contains an unsafe type: {encoding}. To ' + + f'proceed anyway, set ``allow_unsafe_types=True``.') + def decode_sample(self, data: bytes) -> Dict[str, Any]: """Decode a sample dict from bytes. diff --git a/streaming/format/reader.py b/streaming/format/reader.py index cc55f205a..e2e5271fc 100644 --- a/streaming/format/reader.py +++ b/streaming/format/reader.py @@ -61,6 +61,16 @@ def __init__( self.file_pairs = [] + def validate(self, allow_unsafe_types: bool) -> None: + """Check whether this shard is acceptable to be part of some Stream. 
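+
+        The base implementation accepts every shard as-is; formats with potentially unsafe
+        encodings (for example, MDS ``pkl`` columns) override this to raise unless the caller
+        has opted in.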
+ + Args: + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an + error if ``False``. + """ + pass + @property def size(self): """Get the number of samples in this shard. diff --git a/streaming/stream.py b/streaming/stream.py index 7948e9b65..974ceaac7 100644 --- a/streaming/stream.py +++ b/streaming/stream.py @@ -421,11 +421,14 @@ def prepare_shard(self, shard: Reader) -> int: delta += self._prepare_shard_part(raw_info, zip_info, shard.compression) return delta - def get_shards(self, world: World) -> List[Reader]: + def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: """Load this Stream's index, retrieving its shard readers. Args: world (World): Distributed context. + allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code + execution during deserialization, whether to keep going if ``True`` or raise an + error. Returns: `List[Reader]: Shard readers. @@ -469,6 +472,7 @@ def get_shards(self, world: World) -> List[Reader]: shards = [] for info in obj['shards']: shard = reader_from_json(self.local, self.split, info) + shard.validate(allow_unsafe_types) shards.append(shard) return shards diff --git a/tests/test_unsafe_types.py b/tests/test_unsafe_types.py new file mode 100644 index 000000000..40b3a5e1a --- /dev/null +++ b/tests/test_unsafe_types.py @@ -0,0 +1,53 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Tuple + +import pytest + +from streaming import MDSWriter, StreamingDataset + + +@pytest.mark.usefixtures('local_remote_dir') +def test_do_allow_unsafe_types_safe(local_remote_dir: Tuple[str, str]): + local, _ = local_remote_dir + columns = {'num': 'int'} + with MDSWriter(out=local, columns=columns) as out: + for num in range(100): + sample = {'num': num} + out.write(sample) + StreamingDataset(local=local, allow_unsafe_types=True) + + +@pytest.mark.usefixtures('local_remote_dir') +def test_do_allow_unsafe_types_unsafe(local_remote_dir: Tuple[str, str]): + local, _ = local_remote_dir + columns = {'num': 'pkl'} + with MDSWriter(out=local, columns=columns) as out: + for num in range(100): + sample = {'num': num} + out.write(sample) + StreamingDataset(local=local, allow_unsafe_types=True) + + +@pytest.mark.usefixtures('local_remote_dir') +def test_dont_allow_unsafe_types_safe(local_remote_dir: Tuple[str, str]): + local, _ = local_remote_dir + columns = {'num': 'int'} + with MDSWriter(out=local, columns=columns) as out: + for num in range(100): + sample = {'num': num} + out.write(sample) + StreamingDataset(local=local) + + +@pytest.mark.usefixtures('local_remote_dir') +def test_dont_allow_unsafe_types_unsafe(local_remote_dir: Tuple[str, str]): + local, _ = local_remote_dir + columns = {'num': 'pkl'} + with MDSWriter(out=local, columns=columns) as out: + for num in range(100): + sample = {'num': num} + out.write(sample) + with pytest.raises(ValueError, match='.*contains an unsafe type.*'): + StreamingDataset(local=local) From 02bd910c052fefa46d0b4b7f879b75bc8723b472 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Fri, 15 Dec 2023 02:42:25 -0800 Subject: [PATCH 08/12] New storage APIs (#536) * New storage APIs. * Potentially fix import issue. * Fix (path). * Fix (paths). * Fix (paths). 
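A rough sketch of how the new helpers compose (signatures taken from the diff below;
the bucket and local paths are hypothetical):

    from streaming.storage import file_exists, list_dataset_files, smart_download_file

    # List every file of the dataset; the remote listing is authoritative when given.
    paths = list_dataset_files(local='/tmp/my_dataset', remote='s3://my-bucket/my_dataset')

    # Check for one specific file across local and remote.
    have_index = file_exists(path='index.json', local='/tmp/my_dataset',
                             remote='s3://my-bucket/my_dataset')

    # Download a file with a timeout and a size sanity check.
    smart_download_file(remote='s3://my-bucket/my_dataset/index.json',
                        local='/tmp/my_dataset/index.json', timeout=60, max_size='100mb')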
--- streaming/storage/__init__.py | 10 +- streaming/storage/download.py | 28 --- streaming/storage/extra.py | 365 ++++++++++++++++++++++++++++++++++ streaming/storage/upload.py | 17 +- 4 files changed, 382 insertions(+), 38 deletions(-) create mode 100644 streaming/storage/extra.py diff --git a/streaming/storage/__init__.py b/streaming/storage/__init__.py index 5d6d599e0..c5f2b9526 100644 --- a/streaming/storage/__init__.py +++ b/streaming/storage/__init__.py @@ -7,8 +7,9 @@ download_from_azure_datalake, download_from_databricks_unity_catalog, download_from_dbfs, download_from_gcs, download_from_local, download_from_oci, - download_from_s3, download_from_sftp, - wait_for_file_to_exist) + download_from_s3, download_from_sftp) +from streaming.storage.extra import (file_exists, list_dataset_files, smart_download_file, + wait_for_file_to_exist, walk_dir, walk_prefix) from streaming.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, GCSUploader, LocalUploader, OCIUploader, S3Uploader) @@ -31,4 +32,9 @@ 'download_from_dbfs', 'download_from_local', 'wait_for_file_to_exist', + 'walk_prefix', + 'walk_dir', + 'list_dataset_files', + 'smart_download_file', + 'file_exists', ] diff --git a/streaming/storage/download.py b/streaming/storage/download.py index 51a1b4e16..ee5e1bb90 100644 --- a/streaming/storage/download.py +++ b/streaming/storage/download.py @@ -7,7 +7,6 @@ import pathlib import shutil import urllib.parse -from time import sleep, time from typing import Any, Dict, Optional from streaming.util import get_import_exception_message @@ -22,7 +21,6 @@ 'download_from_databricks_unity_catalog', 'download_from_dbfs', 'download_from_local', - 'wait_for_file_to_exist', ] BOTOCORE_CLIENT_ERROR_CODES = {'403', '404', 'NoSuchKey'} @@ -473,29 +471,3 @@ def download_file(remote: Optional[str], local: str, timeout: float): download_from_dbfs(remote, local) else: download_from_local(remote, local) - - -def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float, - err_msg: str) -> None: - """Wait for the file to exist till timeout seconds. Raise an Exception after that. - - Args: - filename (str): A file name - poll_interval (float): Number of seconds to wait before next polling - timeout (float): Number of seconds to wait for a file to exist before raising an exception - err_msg (str): Error message description for an exception - - Raises: - RuntimeError: Raise an Exception if file does not exist after timeout - """ - start_time = time() - while True: - sleep(poll_interval) - if os.path.exists(filename): - sleep(poll_interval) - break - dt = time() - start_time - if dt > timeout: - raise RuntimeError( - f'{err_msg} due to timeout. 
Waited {dt:.3f} sec, which is longer than the ' + - f'timeout limit of {timeout:.3f} sec.') diff --git a/streaming/storage/extra.py b/streaming/storage/extra.py new file mode 100644 index 000000000..3dc4210d4 --- /dev/null +++ b/streaming/storage/extra.py @@ -0,0 +1,365 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Some extras which wrap and/or complement the Streaming storage API.""" + +import os +import re +from re import Pattern +from time import sleep, time +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union +from urllib.parse import urlparse + +from streaming.hashing import get_hash +from streaming.storage.download import download_file +from streaming.storage.upload import CloudUploader +from streaming.util.shorthand import normalize_bytes, normalize_duration + +__all__ = [ + 'wait_for_file_to_exist', 'walk_prefix', 'walk_dir', 'list_dataset_files', + 'smart_download_file', 'file_exists' +] + + +def wait_for_file_to_exist(filename: str, poll_interval: float, timeout: float, + err_msg: str) -> None: + """Wait for the file to exist till timeout seconds. Raise an Exception after that. + + File must be local. + + Args: + filename (str): A file name + poll_interval (float): Number of seconds to wait before next polling + timeout (float): Number of seconds to wait for a file to exist before raising an exception + err_msg (str): Error message description for an exception + + Raises: + RuntimeError: Raise an Exception if file does not exist after timeout + """ + start_time = time() + while True: + sleep(poll_interval) + if os.path.exists(filename): + sleep(poll_interval) + break + dt = time() - start_time + if dt > timeout: + raise RuntimeError(f'{err_msg}' + f'{timeout:.3f} < {dt:.3f} secs.') + + +def _normalize_path(path: str) -> Tuple[str, bool]: + """Analyze the path, returning URI scheme-normalized form and whether is on the local fs. + + Args: + path (str): Path to analyze. + + Returns: + Tuple[str, bool]: Normalized path, and whether it is local. + """ + obj = urlparse(path) + if obj.scheme == '': + is_local = True + elif obj.scheme == 'file': + is_local = True + path = obj.path + else: + is_local = False + return path, is_local + + +def _normalize_dir(dirname: str) -> str: + """Normalize a dirname to contain one trailing slash. + + Args: + dirname (str): Directory path. + + Returns: + str: Normalized directory path. + """ + return dirname.rstrip(os.path.sep) + os.path.sep + + +def walk_prefix(prefix: str) -> List[str]: + """Recursively list all file paths matching a prefix in sorted order. + + Notes: + * If you choose a non-directory as a prefix, returned paths will indeed be relative to your + non-directory, which may seem funky. + * There is some special case handling so that if your path is a local directory with or + without a trailing slash, returned paths will nevertheless never start with a slash, lest + they assume "absolute" power. + + Args: + prefix (str): Path prefix. + + Returns: + List[str]: All file paths under the prefix, which are all relative to the given prefix. + """ + prefix, is_local = _normalize_path(prefix) + + if is_local: + # Prefix points to local filesystem. + prefix_rel_files = [] + if os.path.isdir(prefix): + # Prefix is a directory, so include everything under the directory. 
+ root = _normalize_dir(prefix) + for abs_dir, _, file_bases in os.walk(root): + root_rel_dir = abs_dir.lstrip(root) + for base in file_bases: + root_rel_file = os.path.join(root_rel_dir, base) + prefix_rel_files.append(root_rel_file) + else: + # Prefix has other stuff tacked onto it after the directory, so include everything + # under the prefix's parent directory which also matches the prefix's basename. + root = os.path.dirname(prefix) + for abs_dir, _, file_bases in os.walk(root): + for base in file_bases: + abs_file = os.path.join(abs_dir, base) + if abs_file.startswith(prefix): + prefix_rel_file = abs_file.lstrip(prefix) + prefix_rel_files.append(prefix_rel_file) + else: + # Prefix points to some non-local storage. + neither = CloudUploader.get(prefix, exist_ok=True) + prefix_rel_files = neither.list_objects(prefix) + + # TODO: verify all implementations do a global sort on returned paths, then remove this line. + return sorted(prefix_rel_files) + + +def walk_dir(root: str) -> List[str]: + """Recursively list the given directory in sorted order. + + Notes: + * Supported across various storage backends, including local filesystem. + * Root must be a directory, not a generic path prefix, to make the local case nicer. + * There seems to be inconsistency in list_objects() about what the returned paths are + relative to: cwd, the given root, some local... let's just wrap it for our purposes. + + Args: + root (str): Root directory to walk. + + Returns: + List[str]: File paths, which are relative to the given root. + """ + obj = urlparse(root) + if obj.scheme == '': + is_local = True + elif obj.scheme == 'file': + is_local = True + root = obj.path + else: + is_local = False + + if is_local: + if not os.path.isdir(root): + raise ValueError(f'Path is not a directory: {root}.') + paths = [] + for sub_root, _, file_basenames in os.walk(root): + sub_path = sub_root.lstrip(root) + paths += [os.path.join(sub_path, name) for name in file_basenames] + else: + neither = CloudUploader.get(root, exist_ok=True) + paths = neither.list_objects(root) + + return sorted(paths) + + +def _filter(keep: Optional[Union[str, Pattern, Callable[[str], bool]]], + paths: Optional[Iterable[str]]) -> Iterable[str]: + """Filter the given paths according to the pattern or predicate. + + Args: + keep (Union[str, Pattern, Callable[[str], bool]], optional): A regex or Callable which is + applied to each path, keeping or dropping it. If not provided, do no filtering. + paths (Iterable[str], optional): Iterable of paths to filter. If empty, is the empty list. + """ + paths = paths or [] + if not keep: + pass + elif isinstance(keep, str): + keep_regex = re.compile(keep) + paths = filter(keep_regex.match, paths) + elif isinstance(keep, Pattern): + paths = filter(keep.match, paths) + elif isinstance(keep, Callable): + paths = filter(keep, paths) + else: + raise ValueError(f'Unsupported type of keep: {keep}.') + yield from paths + + +def _get_overlap(want: Set[str], have: Set[str]) -> Dict[str, Any]: + """Get the overlap between two sets for informational/debugging purposes. + + Args: + want (Set[str]): What we want. + have (Set[str]): What we have. + + Returns: + Dict[str, Any]: Information about overlaps. 
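+
+    Example (illustrative):
+        _get_overlap({'a', 'b'}, {'b', 'c'}) == {'present': 1, 'missing': 1, 'ignored': 1}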
+ """ + return { + 'present': len(want & have), + 'missing': len(want.difference(have)), + 'ignored': len(have.difference(want)), + } + + +def list_dataset_files( + *, + local: str, + remote: Optional[str] = None, + split: Optional[str] = None, + paths: Optional[Iterable[str]] = None, + keep: Optional[Union[str, Pattern, Callable[[str], bool]]] = None) -> List[str]: + """Collect all/certain local/remote dataset files, which are then filtered. + + Args: + local (str): Local dataset root. + remote (str, optional): Remote dataset root, if we have a remote. + split (str, optional): Split subdir, if used. + paths (Iterable[str], optional): Iterable of paths relative to dataset root (i.e., + local/remote + split). These are then filtered by the keep predicate, if any. If not + provided, defaults to a sorted, recursive listing of all dataset files. Such a listing + treats remote as authoritative if provided, else uses local. Defaults to ``None``. + keep (Union[str, Pattern, Callable[[str], bool]], optional): A regex or Callable which is + applied to each path in order to keep or drop it from the listing. If not provided, no + filtering is performed to paths. Defaults to ``None``. + + Returns: + List[str]: List of paths, relative to dataset root, ordered by ``paths``. + """ + # Tack on the split dir, if any. + if split: + local = os.path.join(local, split) + if remote: + remote = os.path.join(remote, split) + + # If no paths Iterable was not provided, list all the files, filter, and we're done. + if paths is None: + root = remote if remote else local + paths = walk_dir(root) + return list(_filter(keep, paths)) + + # If we were indeed provided explicit paths, cross-check those against a listing of local + # before we start assuming everything is fine. + want_paths = list(_filter(keep, paths)) + want_paths_set = set(want_paths) + have_local_paths_set = set(walk_dir(local)) + if want_paths_set.issubset(have_local_paths_set): # All exist in local? + return want_paths + + # If local is incomplete, and there is no remote, give up. + if not remote: + obj = _get_overlap(want_paths_set, have_local_paths_set) + raise ValueError(f'Local does not contain all listed shards, and no remote was ' + + f'provided. Overlap of listed vs local: {obj["present"]} present, ' + + f'{obj["missing"]} missing, {obj["ignored"]} ignored.') + + # Explicit paths, incomplete local, but we do have a remote to fall back to. Let's cross-check + # against that. + have_remote_paths_set = set(walk_dir(remote)) + if want_paths_set.issubset(have_remote_paths_set): + return want_paths + + # Both local and remote do not contain all the needed files, so give up. + l_obj = _get_overlap(want_paths_set, have_local_paths_set) + r_obj = _get_overlap(want_paths_set, have_remote_paths_set) + raise ValueError(f'Neither local nor remote contains all shards listed. Overlap of listed ' + + f'vs local: {l_obj["present"]} present, {l_obj["missing"]} missing, ' + + f'{l_obj["ignored"]} ignored. Overlap of listed vs remote: ' + + f'{r_obj["present"]} present, {r_obj["missing"]} missing, ' + + f'{r_obj["ignored"]} ignored.') + + +def smart_download_file(*, + remote: str, + local: str, + timeout: Union[float, str] = 60, + size: Optional[Union[int, str]] = None, + max_size: Optional[Union[int, str]] = None, + hashes: Optional[Dict[str, str]] = None) -> None: + """Download a file from the remote path to the local path, with size/hash checks. + + Args: + remote (str): Remote path. + local (str): Local path. 
+ timeout (Union[float, str]): Maximum time to download, in seconds. Defaults to ``60``. + size (Union[int, str], optional): Expected file size. This check is a weak but fast/cheap + way to detect overwrites, truncation, tampering, and corruption. Defaults to ``None``. + max_size (Union[int, str], optional): Maximum file size. This check is a fast/cheap way to + prevent the user from inadvertently using shards that are far too large for Streaming + purposes, which is non-obvious and would result in a terrible user experience. Defaults + to ``None``. + hashes (Dict[str, str], optional): Hashes to check, as a dict of hash algo name to expected + hex digest. These checks are a very strong but slow/expensive way to detect changes to + data. See our benchmarks for more details. Defaults to ``None``. + """ + # Download. + want_timeout = normalize_duration(timeout) + download_file(remote, local, want_timeout) + + # Size checks. + if size is not None or max_size is not None: + have_size = os.stat(local).st_size + + # Exact size check. + if size is not None: + want_size = normalize_bytes(size) + if want_size != have_size: + raise ValueError( + f'The file as downloaded does not match the expected size: remote path = ' + + f'{remote}, local path = {local}, expected size = {want_size}, got size = ' + + f'{have_size}.') + + # Size limit check. + if max_size is not None: + want_max_size = normalize_bytes(max_size) + if want_max_size < have_size: + raise ValueError( + f'The file is too large for efficient use by Streaming, please reduce shard ' + + f'size: remote path = {remote}, local path = {local}, maximum size = ' + + f'{want_max_size}, got size = {have_size}.') + + # Hash checks. + if hashes: + data = open(local, 'rb').read() + for hash_algo in sorted(hashes): + want_hex_digest = hashes[hash_algo] + have_hex_digest = get_hash(hash_algo, data) + if want_hex_digest != have_hex_digest: + raise ValueError( + f'The file as downloaded does not match the expected hash: remote path = ' + + f'{remote}, local path = {local}, hash algo = {hash_algo}, expected hex ' + + f'digest = {want_hex_digest}, got digest = {have_hex_digest}.') + + +def file_exists(*, + path: str, + local: str, + remote: Optional[str] = None, + split: Optional[str] = None) -> bool: + """Determine whether the file path exists across local and/or remote. + + Args: + path (str): File path relative to local and/or remote. + local (str): Local root. + remote (str, optional): Remote root. + split (str, optional): Dataset split, if applicable. + + Returns: + bool: Whether file exists locally and/or remotely. + """ + local_filename = os.path.join(local, split or '', path) + filenames = walk_prefix(local_filename) + if filenames and filenames[0] == local_filename: + return True + + if remote: + remote_path = os.path.join(remote, split or '', path) + paths = walk_prefix(remote_path) + if paths and paths[0] == remote_path: + return True + + return False diff --git a/streaming/storage/upload.py b/streaming/storage/upload.py index 9723315ef..c2a2f40bd 100644 --- a/streaming/storage/upload.py +++ b/streaming/storage/upload.py @@ -74,14 +74,15 @@ def get(cls, progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. 
+ exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` + already exists and has contents. Defaults to ``False``. Returns: CloudUploader: An instance of sub-class. """ cls._validate(cls, out) - obj = urllib.parse.urlparse(out) if isinstance(out, str) else urllib.parse.urlparse(out[1]) + obj = urllib.parse.urlparse(out) if isinstance(out, str) else \ + urllib.parse.urlparse(out[1]) provider_prefix = obj.scheme if obj.scheme == 'dbfs': path = pathlib.Path(out) if isinstance(out, str) else pathlib.Path(out[1]) @@ -141,8 +142,8 @@ def __init__(self, progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` + already exists and has contents. Defaults to ``False``. Raises: FileExistsError: Local directory must be empty. @@ -170,8 +171,8 @@ def __init__(self, raise FileExistsError(f'Directory is not empty: {self.local}') else: logger.warning( - f'Directory {self.local} exists and not empty. But continue to mkdir since exist_ok is set to be True.' - ) + f'Directory {self.local} exists and not empty. But continue to mkdir since ' + + f'exist_ok is set to be True.') os.makedirs(self.local, exist_ok=True) @@ -773,7 +774,7 @@ def check_container_exists(self, remote: str): error: Container does not exist. """ container_name = urllib.parse.urlparse(remote).netloc - if self.azure_service.get_file_system_client(file_system=container_name).exists() is False: + if not self.azure_service.get_file_system_client(file_system=container_name).exists(): raise FileNotFoundError( f'Either container `{container_name}` does not exist! ' + f'or check the container permission.',) From 3972c9d4f98b677d27465f9fac629c97efefc2d7 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Fri, 15 Dec 2023 07:47:16 -0800 Subject: [PATCH 09/12] Improve naming: JSON shards are actually JSONL, etc. (#537) * Stdize docstrings, also fix ordering of get_sample_data, decode_sample. * Terminology: "joint" -> "mono". * "split" -> "dual" to stop confusing people (SplitWriter != dataaset splits) * "Reader" -> "Shard". They manage shards. They do more than read. * Fix filenames accordingly. * Finally, JSON -> JSONL. * Switch order of decorators... * Fix markdown code. 
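For downstream code the renames are mechanical: JSONWriter -> JSONLWriter, JSONReader -> JSONLShard, Reader -> Shard, JointWriter/JointReader -> MonoWriter/MonoShard, SplitWriter/SplitReader -> DualWriter/DualShard. Indexes written before this change that still declare "format": "json" continue to load, since shard_from_json() maps them to JSONLShard. A minimal sketch of the new spellings (the output path and columns below are made up for illustration):

    from streaming import JSONLWriter, StreamingDataset

    columns = {'num': 'int', 'txt': 'str'}
    with JSONLWriter(out='/tmp/jsonl_demo', columns=columns, size_limit=1 << 20) as writer:
        for i in range(10):
            writer.write({'num': i, 'txt': f'sample {i}'})

    dataset = StreamingDataset(local='/tmp/jsonl_demo')
    print(dataset[0])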
--- STYLE_GUIDE.md | 4 +- benchmarks/backends/write.py | 4 +- benchmarks/samples/bench_and_plot.py | 4 +- streaming/__init__.py | 6 +- streaming/format/__init__.py | 47 ++++-- streaming/format/json/__init__.py | 9 -- streaming/format/{json => jsonl}/README.md | 10 +- streaming/format/jsonl/__init__.py | 9 ++ streaming/format/{json => jsonl}/encodings.py | 10 +- .../format/{json/reader.py => jsonl/shard.py} | 42 ++--- streaming/format/{json => jsonl}/writer.py | 22 +-- streaming/format/mds/__init__.py | 6 +- streaming/format/mds/{reader.py => shard.py} | 56 +++---- streaming/format/mds/writer.py | 10 +- streaming/format/{reader.py => shard.py} | 45 ++++-- streaming/format/writer.py | 24 +-- streaming/format/xsv/__init__.py | 6 +- streaming/format/xsv/{reader.py => shard.py} | 52 +++---- streaming/format/xsv/writer.py | 10 +- streaming/local.py | 4 +- streaming/stream.py | 21 ++- tests/test_encodings.py | 146 +++++++++--------- tests/test_writer.py | 26 ++-- 23 files changed, 301 insertions(+), 272 deletions(-) delete mode 100644 streaming/format/json/__init__.py rename streaming/format/{json => jsonl}/README.md (83%) create mode 100644 streaming/format/jsonl/__init__.py rename streaming/format/{json => jsonl}/encodings.py (86%) rename streaming/format/{json/reader.py => jsonl/shard.py} (88%) rename streaming/format/{json => jsonl}/writer.py (88%) rename streaming/format/mds/{reader.py => shard.py} (95%) rename streaming/format/{reader.py => shard.py} (94%) rename streaming/format/xsv/{reader.py => shard.py} (96%) diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md index 67156e2a0..4b888accd 100644 --- a/STYLE_GUIDE.md +++ b/STYLE_GUIDE.md @@ -207,10 +207,10 @@ For example, from [streaming/dataset.py](streaming/dataset.py) """The :class:`Dataset` class, used for building streaming iterable datasets.""" from torch.utils.data import IterableDataset -from streaming.format import reader_from_json +from streaming.format import shard_from_json from streaming.spanner import Spanner -__all__ = ["Dataset"] # export only the Dataset, not other imports like `Spanner` or `reader_from_json` +__all__ = ["Dataset"] # Export `Dataset` only, not the others e.g. `Spanner` or `shard_from_json`. 
class Dataset(IterableDataset): diff --git a/benchmarks/backends/write.py b/benchmarks/backends/write.py index 9404e5d0a..78c85bbfb 100644 --- a/benchmarks/backends/write.py +++ b/benchmarks/backends/write.py @@ -22,7 +22,7 @@ from wurlitzer import pipes from benchmarks.backends.datagen import generate -from streaming import CSVWriter, JSONWriter, MDSWriter +from streaming import CSVWriter, JSONLWriter, MDSWriter from streaming.util.tabulation import Tabulator @@ -108,7 +108,7 @@ def _write_jsonl(nums: List[int], 'num': 'int', 'txt': 'str', } - with JSONWriter(out=root, columns=columns, size_limit=size_limit) as out: + with JSONLWriter(out=root, columns=columns, size_limit=size_limit) as out: each_sample = zip(nums, txts) if show_progress: each_sample = tqdm(each_sample, total=len(nums), leave=False) diff --git a/benchmarks/samples/bench_and_plot.py b/benchmarks/samples/bench_and_plot.py index 31307ff32..306049875 100644 --- a/benchmarks/samples/bench_and_plot.py +++ b/benchmarks/samples/bench_and_plot.py @@ -17,7 +17,7 @@ from numpy.typing import DTypeLike, NDArray from tqdm import trange -from streaming import CSVWriter, JSONWriter, MDSWriter, StreamingDataset +from streaming import CSVWriter, JSONLWriter, MDSWriter, StreamingDataset def parse_args() -> Namespace: @@ -244,7 +244,7 @@ def bench(args: Namespace, bench_name: str, desc: str, generate: Callable, format_infos = [ ('mds', MDSWriter, args.mds_color), - ('jsonl', JSONWriter, args.jsonl_color), + ('jsonl', JSONLWriter, args.jsonl_color), ('csv', CSVWriter, args.csv_color), ] format_infos = list(filter(lambda info: info[0] in formats, format_infos)) diff --git a/streaming/__init__.py b/streaming/__init__.py index 45ca3f1cf..c8efa5a36 100644 --- a/streaming/__init__.py +++ b/streaming/__init__.py @@ -6,12 +6,12 @@ from streaming._version import __version__ from streaming.dataloader import StreamingDataLoader from streaming.dataset import StreamingDataset -from streaming.format import CSVWriter, JSONWriter, MDSWriter, TSVWriter, XSVWriter +from streaming.format import CSVWriter, JSONLWriter, MDSWriter, TSVWriter, XSVWriter from streaming.local import LocalDataset from streaming.stream import Stream from streaming.util import clean_stale_shared_memory __all__ = [ - 'StreamingDataLoader', 'Stream', 'StreamingDataset', 'CSVWriter', 'JSONWriter', 'LocalDataset', - 'MDSWriter', 'TSVWriter', 'XSVWriter', 'clean_stale_shared_memory' + 'StreamingDataLoader', 'Stream', 'StreamingDataset', 'CSVWriter', 'JSONLWriter', + 'LocalDataset', 'MDSWriter', 'TSVWriter', 'XSVWriter', 'clean_stale_shared_memory' ] diff --git a/streaming/format/__init__.py b/streaming/format/__init__.py index bbec4927e..dec5ac15c 100644 --- a/streaming/format/__init__.py +++ b/streaming/format/__init__.py @@ -1,32 +1,45 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Individual dataset writer for every format.""" +"""Streaming serialization format, consisting of an index and multiple types of shards.""" from typing import Any, Dict, Optional from streaming.format.index import get_index_basename -from streaming.format.json import JSONReader, JSONWriter -from streaming.format.mds import MDSReader, MDSWriter -from streaming.format.reader import FileInfo, Reader -from streaming.format.xsv import CSVReader, CSVWriter, TSVReader, TSVWriter, XSVReader, XSVWriter +from streaming.format.jsonl import JSONLShard, JSONLWriter +from streaming.format.mds import MDSShard, MDSWriter +from streaming.format.shard import FileInfo, Shard +from 
streaming.format.xsv import CSVShard, CSVWriter, TSVShard, TSVWriter, XSVShard, XSVWriter __all__ = [ - 'CSVWriter', 'FileInfo', 'get_index_basename', 'JSONWriter', 'MDSWriter', 'Reader', - 'reader_from_json', 'TSVWriter', 'XSVWriter' + 'CSVWriter', 'FileInfo', 'get_index_basename', 'JSONLWriter', 'MDSWriter', 'Shard', + 'shard_from_json', 'TSVWriter', 'XSVWriter' ] -_readers = { - 'csv': CSVReader, - 'json': JSONReader, - 'mds': MDSReader, - 'tsv': TSVReader, - 'xsv': XSVReader +# Mapping of shard metadata dict "format" field to what type of Shard it is. +_shards = { + 'csv': CSVShard, + 'jsonl': JSONLShard, + 'mds': MDSShard, + 'tsv': TSVShard, + 'xsv': XSVShard, } -def reader_from_json(dirname: str, split: Optional[str], obj: Dict[str, Any]) -> Reader: - """Initialize the reader from JSON object. +def _get_shard_class(format_name: str) -> Shard: + """Get the associated Shard class given a Shard format name. + + Args: + format_name (str): Shard format name. + """ + # JSONL shards were originally called JSON shards (while containing JSONL). + if format_name == 'json': + format_name = 'jsonl' + return _shards[format_name] + + +def shard_from_json(dirname: str, split: Optional[str], obj: Dict[str, Any]) -> Shard: + """Create a shard from a JSON config. Args: dirname (str): Local directory containing shards. @@ -34,8 +47,8 @@ def reader_from_json(dirname: str, split: Optional[str], obj: Dict[str, Any]) -> obj (Dict[str, Any]): JSON object to load. Returns: - Reader: Loaded Reader of `format` type + Shard: The loaded Shard. """ assert obj['version'] == 2 - cls = _readers[obj['format']] + cls = _get_shard_class(obj['format']) return cls.from_json(dirname, split, obj) diff --git a/streaming/format/json/__init__.py b/streaming/format/json/__init__.py deleted file mode 100644 index 47e8be8f6..000000000 --- a/streaming/format/json/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -"""Module to write and read the dataset in JSON format.""" - -from streaming.format.json.reader import JSONReader -from streaming.format.json.writer import JSONWriter - -__all__ = ['JSONReader', 'JSONWriter'] diff --git a/streaming/format/json/README.md b/streaming/format/jsonl/README.md similarity index 83% rename from streaming/format/json/README.md rename to streaming/format/jsonl/README.md index 13cd1fd99..59def4e38 100644 --- a/streaming/format/json/README.md +++ b/streaming/format/jsonl/README.md @@ -7,14 +7,14 @@ Example: "words": "str" }, "compression": "zstd:7", - "format": "json", + "format": "jsonl", "hashes": [ "sha1", "xxh3_64" ], "newline": "\n", "raw_data": { - "basename": "shard.00000.json", + "basename": "shard.00000.jsonl", "bytes": 1048546, "hashes": { "sha1": "bfb6509ba6f041726943ce529b36a1cb74e33957", @@ -22,7 +22,7 @@ Example: } }, "raw_meta": { - "basename": "shard.00000.json.meta", + "basename": "shard.00000.jsonl.meta", "bytes": 53590, "hashes": { "sha1": "15ae80e002fe625b0b18f1a45058532ee867fa9b", @@ -33,7 +33,7 @@ Example: "size_limit": 1048576, "version": 2, "zip_data": { - "basename": "shard.00000.json.zstd", + "basename": "shard.00000.jsonl.zstd", "bytes": 149268, "hashes": { "sha1": "7d45c600a71066ca8d43dbbaa2ffce50a91b735e", @@ -41,7 +41,7 @@ Example: } }, "zip_meta": { - "basename": "shard.00000.json.meta.zstd", + "basename": "shard.00000.jsonl.meta.zstd", "bytes": 42180, "hashes": { "sha1": "f64477cca5d27fc3a0301eeb4452ef7310cbf670", diff --git a/streaming/format/jsonl/__init__.py 
b/streaming/format/jsonl/__init__.py new file mode 100644 index 000000000..53d630a3e --- /dev/null +++ b/streaming/format/jsonl/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Streaming JSONL shards.""" + +from streaming.format.jsonl.shard import JSONLShard +from streaming.format.jsonl.writer import JSONLWriter + +__all__ = ['JSONLShard', 'JSONLWriter'] diff --git a/streaming/format/json/encodings.py b/streaming/format/jsonl/encodings.py similarity index 86% rename from streaming/format/json/encodings.py rename to streaming/format/jsonl/encodings.py index 215b8ee36..2f3048e8f 100644 --- a/streaming/format/json/encodings.py +++ b/streaming/format/jsonl/encodings.py @@ -1,16 +1,16 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Check whether sample encoding is of supported JSON types.""" +"""Check whether sample encoding is of supported JSONL types.""" from abc import ABC, abstractmethod from typing import Any -__all__ = ['is_json_encoded', 'is_json_encoding'] +__all__ = ['is_jsonl_encoded', 'is_jsonl_encoding'] class Encoding(ABC): - """Encoding of an object of JSON type.""" + """Encoding of an object of JSONL type.""" @classmethod @abstractmethod @@ -60,7 +60,7 @@ def is_encoded(cls, obj: Any) -> bool: _encodings = {'str': Str, 'int': Int, 'float': Float} -def is_json_encoded(encoding: str, value: Any) -> bool: +def is_jsonl_encoded(encoding: str, value: Any) -> bool: """Get whether the given object is of this encoding type. Args: @@ -74,7 +74,7 @@ def is_json_encoded(encoding: str, value: Any) -> bool: return cls.is_encoded(value) -def is_json_encoding(encoding: str) -> bool: +def is_jsonl_encoding(encoding: str) -> bool: """Get whether the given encoding is supported. Args: diff --git a/streaming/format/json/reader.py b/streaming/format/jsonl/shard.py similarity index 88% rename from streaming/format/json/reader.py rename to streaming/format/jsonl/shard.py index 698783d71..985b75684 100644 --- a/streaming/format/json/reader.py +++ b/streaming/format/jsonl/shard.py @@ -1,7 +1,7 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -""":class:`JSONReader` reads samples from `.json` files that were written by :class:`MDSWriter`.""" +"""Streaming JSONL shard reading.""" import json import os @@ -11,13 +11,13 @@ import numpy as np from typing_extensions import Self -from streaming.format.reader import FileInfo, SplitReader +from streaming.format.shard import DualShard, FileInfo -__all__ = ['JSONReader'] +__all__ = ['JSONLShard'] -class JSONReader(SplitReader): - """Provides random access to the samples of a JSON shard. +class JSONLShard(DualShard): + """Provides random access to the samples of a JSONL shard. Args: dirname (str): Local dataset directory. @@ -68,7 +68,7 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S obj (Dict[str, Any]): JSON object to load. Returns: - Self: Loaded JSONReader. + Self: Loaded JSONLShard. """ args = deepcopy(obj) # Version check. @@ -77,9 +77,9 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S f'Expected version 2.') del args['version'] # Check format. - if args['format'] != 'json': - raise ValueError(f'Unsupported data format: {args["format"]}. 
' + - f'Expected to be `json`.') + if args['format'] not in {'json', 'jsonl'}: + raise ValueError(f'Unsupported data format: got {args["format"]}, but expected ' + + f'"jsonl" (or "json").') del args['format'] args['dirname'] = dirname args['split'] = split @@ -88,18 +88,6 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S args[key] = FileInfo(**arg) if arg else None return cls(**args) - def decode_sample(self, data: bytes) -> Dict[str, Any]: - """Decode a sample dict from bytes. - - Args: - data (bytes): The sample encoded as bytes. - - Returns: - Dict[str, Any]: Sample dict. - """ - text = data.decode('utf-8') - return json.loads(text) - def get_sample_data(self, idx: int) -> bytes: """Get the raw sample data at the index. @@ -120,3 +108,15 @@ def get_sample_data(self, idx: int) -> bytes: fp.seek(begin) data = fp.read(end - begin) return data + + def decode_sample(self, data: bytes) -> Dict[str, Any]: + """Decode a sample dict from bytes. + + Args: + data (bytes): The sample encoded as bytes. + + Returns: + Dict[str, Any]: Sample dict. + """ + text = data.decode('utf-8') + return json.loads(text) diff --git a/streaming/format/json/writer.py b/streaming/format/jsonl/writer.py similarity index 88% rename from streaming/format/json/writer.py rename to streaming/format/jsonl/writer.py index b0117a47f..99a01d14f 100644 --- a/streaming/format/json/writer.py +++ b/streaming/format/jsonl/writer.py @@ -1,21 +1,21 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -""":class:`JSONWriter` writes samples to `.json` files that can be read by :class:`JSONReader`.""" +"""Streaming JSONL shard writing.""" import json from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -from streaming.format.json.encodings import is_json_encoded, is_json_encoding -from streaming.format.writer import SplitWriter +from streaming.format.jsonl.encodings import is_jsonl_encoded, is_jsonl_encoding +from streaming.format.writer import DualWriter -__all__ = ['JSONWriter'] +__all__ = ['JSONLWriter'] -class JSONWriter(SplitWriter): - r"""Writes a streaming JSON dataset. +class JSONLWriter(DualWriter): + r"""Writes a streaming JSONL dataset. Args: columns (Dict[str, str]): Sample columns. @@ -47,7 +47,7 @@ class JSONWriter(SplitWriter): file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``. """ - format = 'json' + format = 'jsonl' def __init__(self, *, @@ -66,7 +66,7 @@ def __init__(self, size_limit=size_limit, **kwargs) for encoding in columns.values(): - assert is_json_encoding(encoding) + assert is_jsonl_encoding(encoding) self.columns = columns self.newline = newline @@ -83,7 +83,7 @@ def encode_sample(self, sample: Dict[str, Any]) -> bytes: obj = {} for key, encoding in self.columns.items(): value = sample[key] - assert is_json_encoded(encoding, value) + assert is_jsonl_encoded(encoding, value) obj[key] = value text = json.dumps(obj, sort_keys=True) + self.newline return text.encode('utf-8') @@ -98,8 +98,8 @@ def get_config(self) -> Dict[str, Any]: obj.update({'columns': self.columns, 'newline': self.newline}) return obj - def encode_split_shard(self) -> Tuple[bytes, bytes]: - """Encode a split shard out of the cached samples (data, meta files). + def encode_dual_shard(self) -> Tuple[bytes, bytes]: + """Encode a dual shard out of the cached samples (data, meta files). Returns: Tuple[bytes, bytes]: Data file, meta file. 
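On disk the JSONL data file stays human-readable. A sketch of what encode_sample() produces for a hypothetical columns={'num': 'int', 'txt': 'str'} schema (values invented):

    import json

    sample = {'num': 0, 'txt': 'hello'}
    line = json.dumps(sample, sort_keys=True) + '\n'
    # line == '{"num": 0, "txt": "hello"}\n'
    # One such line per sample lands in e.g. shard.00000.jsonl, while the paired
    # shard.00000.jsonl.meta file carries the index that
    # JSONLShard.get_sample_data() seeks into for random access.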
diff --git a/streaming/format/mds/__init__.py b/streaming/format/mds/__init__.py index 67a5be56f..5136f7efd 100644 --- a/streaming/format/mds/__init__.py +++ b/streaming/format/mds/__init__.py @@ -1,9 +1,9 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Module to write and read the dataset in MDS format.""" +"""MDS shards.""" -from streaming.format.mds.reader import MDSReader +from streaming.format.mds.shard import MDSShard from streaming.format.mds.writer import MDSWriter -__all__ = ['MDSReader', 'MDSWriter'] +__all__ = ['MDSShard', 'MDSWriter'] diff --git a/streaming/format/mds/reader.py b/streaming/format/mds/shard.py similarity index 95% rename from streaming/format/mds/reader.py rename to streaming/format/mds/shard.py index 7ec93c98a..956bb069b 100644 --- a/streaming/format/mds/reader.py +++ b/streaming/format/mds/shard.py @@ -1,7 +1,7 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -""":class:`MDSReader` reads samples in `.mds` files written by :class:`StreamingDatasetWriter`.""" +"""MDS shard reading.""" import os from copy import deepcopy @@ -11,12 +11,12 @@ from typing_extensions import Self from streaming.format.mds.encodings import is_mds_encoding_safe, mds_decode -from streaming.format.reader import FileInfo, JointReader +from streaming.format.shard import FileInfo, MonoShard -__all__ = ['MDSReader'] +__all__ = ['MDSShard'] -class MDSReader(JointReader): +class MDSShard(MonoShard): """Provides random access to the samples of an MDS shard. Args: @@ -66,7 +66,7 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S obj (Dict[str, Any]): JSON object to load. Returns: - Self: Loaded MDSReader. + Self: Loaded MDSShard. """ args = deepcopy(obj) if args['version'] != 2: @@ -99,6 +99,29 @@ def validate(self, allow_unsafe_types: bool) -> None: raise ValueError(f'Column {name} contains an unsafe type: {encoding}. To ' + f'proceed anyway, set ``allow_unsafe_types=True``.') + def get_sample_data(self, idx: int) -> bytes: + """Get the raw sample data at the index. + + Args: + idx (int): Sample index. + + Returns: + bytes: Sample data. + """ + filename = os.path.join(self.dirname, self.split, self.raw_data.basename) + offset = (1 + idx) * 4 + with open(filename, 'rb', 0) as fp: + fp.seek(offset) + pair = fp.read(8) + begin, end = np.frombuffer(pair, np.uint32) + fp.seek(begin) + data = fp.read(end - begin) + if not data: + raise IndexError( + f'Relative sample index {idx} is not present in the {self.raw_data.basename} file.' + ) + return data + def decode_sample(self, data: bytes) -> Dict[str, Any]: """Decode a sample dict from bytes. @@ -123,26 +146,3 @@ def decode_sample(self, data: bytes) -> Dict[str, Any]: sample[key] = mds_decode(encoding, value) idx += size return sample - - def get_sample_data(self, idx: int) -> bytes: - """Get the raw sample data at the index. - - Args: - idx (int): Sample index. - - Returns: - bytes: Sample data. - """ - filename = os.path.join(self.dirname, self.split, self.raw_data.basename) - offset = (1 + idx) * 4 - with open(filename, 'rb', 0) as fp: - fp.seek(offset) - pair = fp.read(8) - begin, end = np.frombuffer(pair, np.uint32) - fp.seek(begin) - data = fp.read(end - begin) - if not data: - raise IndexError( - f'Relative sample index {idx} is not present in the {self.raw_data.basename} file.' 
- ) - return data diff --git a/streaming/format/mds/writer.py b/streaming/format/mds/writer.py index 950c60f20..e7fc9ef4c 100644 --- a/streaming/format/mds/writer.py +++ b/streaming/format/mds/writer.py @@ -1,7 +1,7 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -""":class:`MDSWriter` writes samples to ``.mds`` files that can be read by :class:`MDSReader`.""" +"""MDS shard writing.""" import json from typing import Any, Dict, List, Optional, Tuple, Union @@ -10,12 +10,12 @@ from streaming.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings, is_mds_encoding, mds_encode) -from streaming.format.writer import JointWriter +from streaming.format.writer import MonoWriter __all__ = ['MDSWriter'] -class MDSWriter(JointWriter): +class MDSWriter(MonoWriter): """Writes a streaming MDS dataset. Args: @@ -127,8 +127,8 @@ def get_config(self) -> Dict[str, Any]: }) return obj - def encode_joint_shard(self) -> bytes: - """Encode a joint shard out of the cached samples (single file). + def encode_mono_shard(self) -> bytes: + """Encode a mono shard out of the cached samples (single file). Returns: bytes: File data. diff --git a/streaming/format/reader.py b/streaming/format/shard.py similarity index 94% rename from streaming/format/reader.py rename to streaming/format/shard.py index e2e5271fc..818fc036f 100644 --- a/streaming/format/reader.py +++ b/streaming/format/shard.py @@ -8,10 +8,12 @@ from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional, Set, Union +from typing_extensions import Self + from streaming.array import Array from streaming.util.shorthand import normalize_bytes -__all__ = ['FileInfo', 'Reader', 'JointReader', 'SplitReader'] +__all__ = ['FileInfo', 'Shard', 'MonoShard', 'DualShard'] @dataclass @@ -28,7 +30,7 @@ class FileInfo(object): hashes: Dict[str, str] -class Reader(Array, ABC): +class Shard(Array, ABC): """Provides random access to the samples of a shard. Args: @@ -61,6 +63,21 @@ def __init__( self.file_pairs = [] + @classmethod + @abstractmethod + def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> Self: + """Initialize from JSON object. + + Args: + dirname (str): Local directory containing shards. + split (str, optional): Which dataset split to use, if any. + obj (Dict[str, Any]): JSON object to load. + + Returns: + Self: Loaded Shard. + """ + raise NotImplementedError + def validate(self, allow_unsafe_types: bool) -> None: """Check whether this shard is acceptable to be part of some Stream. @@ -276,26 +293,26 @@ def get_persistent_size(self, keep_zip: bool) -> int: return size @abstractmethod - def decode_sample(self, data: bytes) -> Dict[str, Any]: - """Decode a sample dict from bytes. + def get_sample_data(self, idx: int) -> bytes: + """Get the raw sample data at the index. Args: - data (bytes): The sample encoded as bytes. + idx (int): Sample index. Returns: - Dict[str, Any]: Sample dict. + bytes: Sample data. """ raise NotImplementedError @abstractmethod - def get_sample_data(self, idx: int) -> bytes: - """Get the raw sample data at the index. + def decode_sample(self, data: bytes) -> Dict[str, Any]: + """Decode a sample dict from bytes. Args: - idx (int): Sample index. + data (bytes): The sample encoded as bytes. Returns: - bytes: Sample data. + Dict[str, Any]: Sample dict. 
""" raise NotImplementedError @@ -321,8 +338,8 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: yield self[i] -class JointReader(Reader): - """Provides random access to the samples of a joint shard. +class MonoShard(Shard): + """Provides random access to the samples of a mono shard. Args: dirname (str): Local dataset directory. @@ -353,8 +370,8 @@ def __init__( self.file_pairs.append((raw_data, zip_data)) -class SplitReader(Reader): - """Provides random access to the samples of a split shard. +class DualShard(Shard): + """Provides random access to the samples of a dual shard. Args: dirname (str): Local dataset directory. diff --git a/streaming/format/writer.py b/streaming/format/writer.py index 4b98b93d4..7cc606034 100644 --- a/streaming/format/writer.py +++ b/streaming/format/writer.py @@ -24,7 +24,7 @@ from streaming.storage.upload import CloudUploader from streaming.util.shorthand import normalize_bytes -__all__ = ['JointWriter', 'SplitWriter'] +__all__ = ['MonoWriter', 'DualWriter'] logger = logging.getLogger(__name__) @@ -340,8 +340,8 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseEx self.finish() -class JointWriter(Writer): - """Writes a streaming dataset with joint shards. +class MonoWriter(Writer): + """Writes a streaming dataset with mono shards. Args: out (str | Tuple[str, str]): Output dataset directory to save shard files. @@ -395,8 +395,8 @@ def __init__(self, **kwargs) @abstractmethod - def encode_joint_shard(self) -> bytes: - """Encode a joint shard out of the cached samples (single file). + def encode_mono_shard(self) -> bytes: + """Encode a mono shard out of the cached samples (single file). Returns: bytes: File data. @@ -411,7 +411,7 @@ def flush_shard(self) -> None: return raw_data_basename, zip_data_basename = self._name_next_shard() - raw_data = self.encode_joint_shard() + raw_data = self.encode_mono_shard() raw_data_info, zip_data_info = self._process_file(raw_data, raw_data_basename, zip_data_basename) obj = { @@ -428,10 +428,10 @@ def flush_shard(self) -> None: future.add_done_callback(self.exception_callback) -class SplitWriter(Writer): - """Writes a streaming dataset with split shards. +class DualWriter(Writer): + """Writes a streaming dataset with dual shards. - Split shards refer to raw data (csv, json, etc.) paired with an index into it. + Dual shards refer to raw data (csv, json, etc.) paired with an index into it. Args: out (str | Tuple[str, str]): Output dataset directory to save shard files. @@ -482,8 +482,8 @@ def __init__(self, **kwargs) @abstractmethod - def encode_split_shard(self) -> Tuple[bytes, bytes]: - """Encode a split shard out of the cached samples (data, meta files). + def encode_dual_shard(self) -> Tuple[bytes, bytes]: + """Encode a dual shard out of the cached samples (data, meta files). Returns: Tuple[bytes, bytes]: Data file, meta file. 
@@ -499,7 +499,7 @@ def flush_shard(self) -> None: raw_data_basename, zip_data_basename = self._name_next_shard() raw_meta_basename, zip_meta_basename = self._name_next_shard('meta') - raw_data, raw_meta = self.encode_split_shard() + raw_data, raw_meta = self.encode_dual_shard() raw_data_info, zip_data_info = self._process_file(raw_data, raw_data_basename, zip_data_basename) raw_meta_info, zip_meta_info = self._process_file(raw_meta, raw_meta_basename, diff --git a/streaming/format/xsv/__init__.py b/streaming/format/xsv/__init__.py index 985010a42..8532c1013 100644 --- a/streaming/format/xsv/__init__.py +++ b/streaming/format/xsv/__init__.py @@ -1,9 +1,9 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Module to write and read the dataset in Tabular format.""" +"""Streaming XSV shards, with specializations for CSV and TSV.""" -from streaming.format.xsv.reader import CSVReader, TSVReader, XSVReader +from streaming.format.xsv.shard import CSVShard, TSVShard, XSVShard from streaming.format.xsv.writer import CSVWriter, TSVWriter, XSVWriter -__all__ = ['CSVReader', 'CSVWriter', 'TSVReader', 'TSVWriter', 'XSVReader', 'XSVWriter'] +__all__ = ['CSVShard', 'CSVWriter', 'TSVShard', 'TSVWriter', 'XSVShard', 'XSVWriter'] diff --git a/streaming/format/xsv/reader.py b/streaming/format/xsv/shard.py similarity index 96% rename from streaming/format/xsv/reader.py rename to streaming/format/xsv/shard.py index f43ee6f5d..426954638 100644 --- a/streaming/format/xsv/reader.py +++ b/streaming/format/xsv/shard.py @@ -1,7 +1,7 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -"""Reads and decode samples from tabular formatted files such as XSV, CSV, and TSV.""" +"""Streaming XSV shard reading, with specializations for CSV and TSV.""" import os from copy import deepcopy @@ -10,13 +10,13 @@ import numpy as np from typing_extensions import Self -from streaming.format.reader import FileInfo, SplitReader +from streaming.format.shard import DualShard, FileInfo from streaming.format.xsv.encodings import xsv_decode -__all__ = ['XSVReader', 'CSVReader', 'TSVReader'] +__all__ = ['XSVShard', 'CSVShard', 'TSVShard'] -class XSVReader(SplitReader): +class XSVShard(DualShard): """Provides random access to the samples of an XSV shard. Args: @@ -73,7 +73,7 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S obj (Dict[str, Any]): JSON object to load. Returns: - Self: Loaded XSVReader. + Self: Loaded XSVShard. """ args = deepcopy(obj) if args['version'] != 2: @@ -91,23 +91,6 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S args[key] = FileInfo(**arg) if arg else None return cls(**args) - def decode_sample(self, data: bytes) -> Dict[str, Any]: - """Decode a sample dict from bytes. - - Args: - data (bytes): The sample encoded as bytes. - - Returns: - Dict[str, Any]: Sample dict. - """ - text = data.decode('utf-8') - text = text[:-len(self.newline)] - parts = text.split(self.separator) - sample = {} - for name, encoding, part in zip(self.column_names, self.column_encodings, parts): - sample[name] = xsv_decode(encoding, part) - return sample - def get_sample_data(self, idx: int) -> bytes: """Get the raw sample data at the index. @@ -129,8 +112,25 @@ def get_sample_data(self, idx: int) -> bytes: data = fp.read(end - begin) return data + def decode_sample(self, data: bytes) -> Dict[str, Any]: + """Decode a sample dict from bytes. + + Args: + data (bytes): The sample encoded as bytes. 
+ + Returns: + Dict[str, Any]: Sample dict. + """ + text = data.decode('utf-8') + text = text[:-len(self.newline)] + parts = text.split(self.separator) + sample = {} + for name, encoding, part in zip(self.column_names, self.column_encodings, parts): + sample[name] = xsv_decode(encoding, part) + return sample + -class CSVReader(XSVReader): +class CSVShard(XSVShard): """Provides random access to the samples of a CSV shard. Args: @@ -182,7 +182,7 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S obj (Dict[str, Any]): JSON object to load. Returns: - Self: Loaded CSVReader. + Self: Loaded CSVShard. """ args = deepcopy(obj) if args['version'] != 2: @@ -201,7 +201,7 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S return cls(**args) -class TSVReader(XSVReader): +class TSVShard(XSVShard): """Provides random access to the samples of an XSV shard. Args: @@ -253,7 +253,7 @@ def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> S obj (Dict[str, Any]): JSON object to load. Returns: - Self: Loaded TSVReader. + Self: Loaded TSVShard. """ args = deepcopy(obj) if args['version'] != 2: diff --git a/streaming/format/xsv/writer.py b/streaming/format/xsv/writer.py index b1ab720d3..519ec881b 100644 --- a/streaming/format/xsv/writer.py +++ b/streaming/format/xsv/writer.py @@ -1,20 +1,20 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -""":class:`XSVWriter` writes samples to `.xsv` files that can be read by :class:`XSVReader`.""" +"""Streaming XSV shard writing, with specializations for CSV and TSV.""" import json from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -from streaming.format.writer import SplitWriter +from streaming.format.writer import DualWriter from streaming.format.xsv.encodings import is_xsv_encoding, xsv_encode __all__ = ['XSVWriter', 'CSVWriter', 'TSVWriter'] -class XSVWriter(SplitWriter): +class XSVWriter(DualWriter): r"""Writes a streaming XSV dataset. Args: @@ -114,8 +114,8 @@ def get_config(self) -> Dict[str, Any]: }) return obj - def encode_split_shard(self) -> Tuple[bytes, bytes]: - """Encode a split shard out of the cached samples (data, meta files). + def encode_dual_shard(self) -> Tuple[bytes, bytes]: + """Encode a dual shard out of the cached samples (data, meta files). Returns: Tuple[bytes, bytes]: Data file, meta file. 
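CSVWriter and TSVWriter are the comma- and tab-separated specializations of this dual-shard format. A minimal usage sketch (output path, columns, and values are hypothetical):

    from streaming import CSVWriter

    columns = {'name': 'str', 'count': 'int', 'score': 'float'}
    with CSVWriter(out='/tmp/csv_demo', columns=columns, size_limit=1 << 20) as writer:
        writer.write({'name': 'alpha', 'count': 3, 'score': 0.5})
    # Each sample becomes one separator-joined line in the data file;
    # CSVShard.decode_sample() splits it back and applies xsv_decode() per column.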
diff --git a/streaming/local.py b/streaming/local.py index 47dd8134f..ed6e99469 100644 --- a/streaming/local.py +++ b/streaming/local.py @@ -11,7 +11,7 @@ from torch.utils.data import Dataset from streaming.array import Array -from streaming.format import get_index_basename, reader_from_json +from streaming.format import get_index_basename, shard_from_json from streaming.spanner import Spanner __all__ = ['LocalDataset'] @@ -39,7 +39,7 @@ def __init__(self, local: str, split: Optional[str] = None): self.shards = [] for info in obj['shards']: - shard = reader_from_json(local, split, info) + shard = shard_from_json(local, split, info) self.shards.append(shard) self.num_samples = sum([shard.samples for shard in self.shards]) diff --git a/streaming/stream.py b/streaming/stream.py index 974ceaac7..3c3735e9f 100644 --- a/streaming/stream.py +++ b/streaming/stream.py @@ -16,7 +16,7 @@ from streaming.compression import decompress from streaming.constant import TICK from streaming.distributed import barrier, get_local_rank -from streaming.format import FileInfo, Reader, get_index_basename, reader_from_json +from streaming.format import FileInfo, Shard, get_index_basename, shard_from_json from streaming.hashing import get_hash from streaming.storage import download_file, wait_for_file_to_exist from streaming.util import retry @@ -352,9 +352,8 @@ def _prepare_shard_part(self, compression: Optional[str] = None) -> int: """Get shard data given metadata for the raw and compressed versions of it. - MDS format uses joint shards (ie, one file per shard). Other formats supported by streaming - use split shards (ie, shard data lives in two files per shard: the raw data itself and - metadata in a separate file). + Shards are either mono shards (one file per shard, like MDS) or dual shards (a pair of data + and meta files per shard, like the Streaming JSONL and XSV shard formats). Args: raw_info (FileInfo): Raw file info. @@ -407,11 +406,11 @@ def _prepare_shard_part(self, raise ValueError(f'Checksum failure: {raw_filename}') return delta - def prepare_shard(self, shard: Reader) -> int: + def prepare_shard(self, shard: Shard) -> int: """Ensure (download, validate, extract, etc.) that we have the given shard. Args: - shard (Reader): Which shard. + shard (Shard): Which shard. Returns: int: Change in cache usage. @@ -421,7 +420,7 @@ def prepare_shard(self, shard: Reader) -> int: delta += self._prepare_shard_part(raw_info, zip_info, shard.compression) return delta - def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: + def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Shard]: """Load this Stream's index, retrieving its shard readers. Args: @@ -431,7 +430,7 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: error. Returns: - `List[Reader]: Shard readers. + `List[Shard]: Shard readers. """ # Download the index file if it does not exist locally. basename = get_index_basename() @@ -471,17 +470,17 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]: # Initialize shard readers according to the loaded info. 
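         # shard_from_json() dispatches on each entry's 'format' field ('mds',
         # 'jsonl', 'csv', 'tsv', 'xsv'); legacy indexes that still say 'json'
         # are mapped to JSONLShard for backward compatibility.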
shards = [] for info in obj['shards']: - shard = reader_from_json(self.local, self.split, info) + shard = shard_from_json(self.local, self.split, info) shard.validate(allow_unsafe_types) shards.append(shard) return shards - def set_up_local(self, shards: List[Reader], cache_usage_per_shard: NDArray[np.int64]) -> None: + def set_up_local(self, shards: List[Shard], cache_usage_per_shard: NDArray[np.int64]) -> None: """Bring a local directory into a consistent state, getting which shards are present. Args: - shards (List[Reader]): List of this stream's shards. + shards (List[Shard]): List of this stream's shards. cache_usage_per_shard (NDArray[np.int64]): Cache usage per shard of this stream. """ # List the cache directory (so that we hit the filesystem once). diff --git a/tests/test_encodings.py b/tests/test_encodings.py index 70d048647..dce91d8cd 100644 --- a/tests/test_encodings.py +++ b/tests/test_encodings.py @@ -10,16 +10,16 @@ import pytest from PIL import Image -import streaming.format.json.encodings as jsonEnc -import streaming.format.mds.encodings as mdsEnc -import streaming.format.xsv.encodings as xsvEnc +import streaming.format.jsonl.encodings as jsonl_enc +import streaming.format.mds.encodings as mds_enc +import streaming.format.xsv.encodings as xsv_enc class TestMDSEncodings: @pytest.mark.parametrize('data', [b'5', b'\x00\x00']) def test_byte_encode_decode(self, data: bytes): - byte_enc = mdsEnc.Bytes() + byte_enc = mds_enc.Bytes() assert byte_enc.size is None output = byte_enc.encode(data) assert output == data @@ -29,13 +29,13 @@ def test_byte_encode_decode(self, data: bytes): @pytest.mark.parametrize('data', ['9', 25]) def test_byte_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): - byte_enc = mdsEnc.Bytes() + byte_enc = mds_enc.Bytes() _ = byte_enc.encode(data) @pytest.mark.parametrize(('data', 'encode_data'), [('99', b'99'), ('streaming dataset', b'streaming dataset')]) def test_str_encode_decode(self, data: str, encode_data: bytes): - str_enc = mdsEnc.Str() + str_enc = mds_enc.Str() assert str_enc.size is None # Test encode @@ -51,13 +51,13 @@ def test_str_encode_decode(self, data: str, encode_data: bytes): @pytest.mark.parametrize('data', [b'9', 25]) def test_str_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): - str_enc = mdsEnc.Str() + str_enc = mds_enc.Str() _ = str_enc.encode(data) @pytest.mark.parametrize(('data', 'encode_data'), [(99, b'c\x00\x00\x00\x00\x00\x00\x00'), (987654321, b'\xb1h\xde:\x00\x00\x00\x00')]) def test_int_encode_decode(self, data: int, encode_data: bytes): - int_enc = mdsEnc.Int() + int_enc = mds_enc.Int() assert int_enc.size == 8 # Test encode @@ -73,7 +73,7 @@ def test_int_encode_decode(self, data: int, encode_data: bytes): @pytest.mark.parametrize('data', [b'9', 25.9]) def test_int_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): - int_enc = mdsEnc.Int() + int_enc = mds_enc.Int() _ = int_enc.encode(data) @pytest.mark.parametrize('dtype_str', [ @@ -103,28 +103,28 @@ def test_ndarray_encode_decode(self, dtype_str: str, shape: Tuple[int]): a = np.random.randint(0, 1000, shape).astype(dtype) encoding = 'ndarray' - assert mdsEnc.is_mds_encoding(encoding) - assert mdsEnc.get_mds_encoded_size(encoding) is None - b = mdsEnc.mds_encode(encoding, a) - c = mdsEnc.mds_decode(encoding, b) + assert mds_enc.is_mds_encoding(encoding) + assert mds_enc.get_mds_encoded_size(encoding) is None + b = mds_enc.mds_encode(encoding, a) + c = mds_enc.mds_decode(encoding, b) assert (a == c).all() 
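        # The three ndarray spellings trade schema detail for per-sample bytes:
        # bare 'ndarray' stores dtype and shape alongside every sample,
        # 'ndarray:<dtype>' moves the dtype into the schema, and
        # 'ndarray:<dtype>:<shape>' is fully fixed, which is why only that form
        # reports a non-None encoded size below.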
b1_len = len(b) encoding = f'ndarray:{dtype.__name__}' - assert mdsEnc.is_mds_encoding(encoding) - assert mdsEnc.get_mds_encoded_size(encoding) is None - b = mdsEnc.mds_encode(encoding, a) - c = mdsEnc.mds_decode(encoding, b) + assert mds_enc.is_mds_encoding(encoding) + assert mds_enc.get_mds_encoded_size(encoding) is None + b = mds_enc.mds_encode(encoding, a) + c = mds_enc.mds_decode(encoding, b) assert (a == c).all() b2_len = len(b) shape_str = ','.join(map(str, shape)) encoding = f'ndarray:{dtype.__name__}:{shape_str}' - assert mdsEnc.is_mds_encoding(encoding) - b_size = mdsEnc.get_mds_encoded_size(encoding) + assert mds_enc.is_mds_encoding(encoding) + b_size = mds_enc.get_mds_encoded_size(encoding) assert b_size is not None - b = mdsEnc.mds_encode(encoding, a) - c = mdsEnc.mds_decode(encoding, b) + b = mds_enc.mds_encode(encoding, a) + c = mds_enc.mds_decode(encoding, b) assert (a == c).all() assert len(b) == b_size b3_len = len(b) @@ -134,7 +134,7 @@ def test_ndarray_encode_decode(self, dtype_str: str, shape: Tuple[int]): @pytest.mark.parametrize('mode', ['I', 'L', 'RGB']) def test_pil_encode_decode(self, mode: str): - pil_enc = mdsEnc.PIL() + pil_enc = mds_enc.PIL() assert pil_enc.size is None # Creating the (32 x 32) NumPy Array with random values @@ -158,12 +158,12 @@ def test_pil_encode_decode(self, mode: str): @pytest.mark.parametrize('data', [b'9', 25.9]) def test_pil_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): - pil_enc = mdsEnc.PIL() + pil_enc = mds_enc.PIL() _ = pil_enc.encode(data) @pytest.mark.parametrize('mode', ['L', 'RGB']) def test_jpeg_encode_decode(self, mode: str): - jpeg_enc = mdsEnc.JPEG() + jpeg_enc = mds_enc.JPEG() assert jpeg_enc.size is None # Creating the (32 x 32) NumPy Array with random values @@ -182,7 +182,7 @@ def test_jpeg_encode_decode(self, mode: str): @pytest.mark.parametrize('mode', ['L', 'RGB']) def test_jpegfile_encode_decode(self, mode: str): - jpeg_enc = mdsEnc.JPEG() + jpeg_enc = mds_enc.JPEG() assert jpeg_enc.size is None # Creating the (32 x 32) NumPy Array with random values @@ -208,12 +208,12 @@ def test_jpegfile_encode_decode(self, mode: str): @pytest.mark.parametrize('data', [b'99', 12.5]) def test_jpeg_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): - jpeg_enc = mdsEnc.JPEG() + jpeg_enc = mds_enc.JPEG() _ = jpeg_enc.encode(data) @pytest.mark.parametrize('mode', ['I', 'L', 'RGB']) def test_png_encode_decode(self, mode: str): - png_enc = mdsEnc.PNG() + png_enc = mds_enc.PNG() assert png_enc.size is None # Creating the (32 x 32) NumPy Array with random values @@ -237,12 +237,12 @@ def test_png_encode_decode(self, mode: str): @pytest.mark.parametrize('data', [b'123', 77.7]) def test_png_encode_invalid_data(self, data: Any): with pytest.raises(AttributeError): - png_enc = mdsEnc.PNG() + png_enc = mds_enc.PNG() _ = png_enc.encode(data) @pytest.mark.parametrize('data', [25, 'streaming', np.array(7)]) def test_pickle_encode_decode(self, data: Any): - pkl_enc = mdsEnc.Pickle() + pkl_enc = mds_enc.Pickle() assert pkl_enc.size is None # Test encode @@ -258,7 +258,7 @@ def test_pickle_encode_decode(self, data: Any): @pytest.mark.parametrize('data', [25, 'streaming', {'alpha': 1, 'beta': 2}]) def test_json_encode_decode(self, data: Any): - json_enc = mdsEnc.JSON() + json_enc = mds_enc.JSON() assert json_enc.size is None # Test encode @@ -275,12 +275,12 @@ def test_json_encode_decode(self, data: Any): def test_json_invalid_data(self): wrong_json_with_single_quotes = "{'name': 'streaming'}" with 
pytest.raises(json.JSONDecodeError): - json_enc = mdsEnc.JSON() + json_enc = mds_enc.JSON() json_enc._is_valid(wrong_json_with_single_quotes, wrong_json_with_single_quotes) @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*')]) def test_mds_uint8(self, decoded: int, encoded: bytes): - coder = mdsEnc.UInt8() + coder = mds_enc.UInt8() assert coder.size == 1 enc = coder.encode(decoded) @@ -293,7 +293,7 @@ def test_mds_uint8(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*\0')]) def test_mds_uint16(self, decoded: int, encoded: bytes): - coder = mdsEnc.UInt16() + coder = mds_enc.UInt16() assert coder.size == 2 enc = coder.encode(decoded) @@ -306,7 +306,7 @@ def test_mds_uint16(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*\0\0\0')]) def test_mds_uint32(self, decoded: int, encoded: bytes): - coder = mdsEnc.UInt32() + coder = mds_enc.UInt32() assert coder.size == 4 enc = coder.encode(decoded) @@ -319,7 +319,7 @@ def test_mds_uint32(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*\0\0\0\0\0\0\0')]) def test_mds_uint64(self, decoded: int, encoded: bytes): - coder = mdsEnc.UInt64() + coder = mds_enc.UInt64() assert coder.size == 8 enc = coder.encode(decoded) @@ -332,7 +332,7 @@ def test_mds_uint64(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*')]) def test_mds_int8(self, decoded: int, encoded: bytes): - coder = mdsEnc.Int8() + coder = mds_enc.Int8() assert coder.size == 1 enc = coder.encode(decoded) @@ -345,7 +345,7 @@ def test_mds_int8(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*\0')]) def test_mds_int16(self, decoded: int, encoded: bytes): - coder = mdsEnc.Int16() + coder = mds_enc.Int16() assert coder.size == 2 enc = coder.encode(decoded) @@ -358,7 +358,7 @@ def test_mds_int16(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*\0\0\0')]) def test_mds_int32(self, decoded: int, encoded: bytes): - coder = mdsEnc.Int32() + coder = mds_enc.Int32() assert coder.size == 4 enc = coder.encode(decoded) @@ -371,7 +371,7 @@ def test_mds_int32(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'*\0\0\0\0\0\0\0')]) def test_mds_int64(self, decoded: int, encoded: bytes): - coder = mdsEnc.Int64() + coder = mds_enc.Int64() assert coder.size == 8 enc = coder.encode(decoded) @@ -384,7 +384,7 @@ def test_mds_int64(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42.0, b'@Q')]) def test_mds_float16(self, decoded: float, encoded: bytes): - coder = mdsEnc.Float16() + coder = mds_enc.Float16() assert coder.size == 2 enc = coder.encode(decoded) @@ -397,7 +397,7 @@ def test_mds_float16(self, decoded: float, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42.0, b'\0\0(B')]) def test_mds_float32(self, decoded: float, encoded: bytes): - coder = mdsEnc.Float32() + coder = mds_enc.Float32() assert coder.size == 4 enc = coder.encode(decoded) @@ -410,7 +410,7 @@ def test_mds_float32(self, decoded: float, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42.0, b'\0\0\0\0\0\0E@')]) def test_mds_float64(self, decoded: float, encoded: bytes): - coder = mdsEnc.Float64() + coder = mds_enc.Float64() assert coder.size == 8 enc = coder.encode(decoded) @@ -423,7 +423,7 @@ def test_mds_float64(self, decoded: float, encoded: bytes): 
@pytest.mark.parametrize(('decoded', 'encoded'), [(42, b'42'), (-42, b'-42')]) def test_mds_StrInt(self, decoded: int, encoded: bytes): - coder = mdsEnc.StrInt() + coder = mds_enc.StrInt() enc = coder.encode(decoded) assert isinstance(enc, bytes) assert enc == encoded @@ -434,7 +434,7 @@ def test_mds_StrInt(self, decoded: int, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(42.0, b'42.0'), (-42.0, b'-42.0')]) def test_mds_StrFloat(self, decoded: float, encoded: bytes): - coder = mdsEnc.StrFloat() + coder = mds_enc.StrFloat() enc = coder.encode(decoded) assert isinstance(enc, bytes) assert enc == encoded @@ -446,7 +446,7 @@ def test_mds_StrFloat(self, decoded: float, encoded: bytes): @pytest.mark.parametrize(('decoded', 'encoded'), [(Decimal('4E15'), b'4E+15'), (Decimal('-4E15'), b'-4E+15')]) def test_mds_StrDecimal(self, decoded: Decimal, encoded: bytes): - coder = mdsEnc.StrDecimal() + coder = mds_enc.StrDecimal() enc = coder.encode(decoded) assert isinstance(enc, bytes) assert enc == encoded @@ -463,14 +463,14 @@ def test_get_mds_encodings(self): expected_encodings = { 'int', 'bytes', 'json', 'ndarray', 'png', 'jpeg', 'str', 'pil', 'pkl' } | scalars - enc = mdsEnc.get_mds_encodings() + enc = mds_enc.get_mds_encodings() assert len(enc) == len(expected_encodings) assert enc == expected_encodings @pytest.mark.parametrize(('enc_name', 'expected_output'), [('jpeg', True), ('', False), ('pngg', False)]) def test_is_mds_encoding(self, enc_name: str, expected_output: bool): - is_supported = mdsEnc.is_mds_encoding(enc_name) + is_supported = mds_enc.is_mds_encoding(enc_name) assert is_supported is expected_output @pytest.mark.parametrize(('encoding', 'decoded', 'encoded'), @@ -480,35 +480,35 @@ def test_is_mds_encoding(self, enc_name: str, expected_output: bool): ('int64', 42, b'*\0\0\0\0\0\0\0'), ('float16', 42.0, b'@Q'), ('float32', 42.0, b'\0\0(B'), ('float64', 42.0, b'\0\0\0\0\0\0E@')]) def test_mds_scalar(self, encoding: str, decoded: Union[int, float], encoded: bytes): - enc = mdsEnc.mds_encode(encoding, decoded) + enc = mds_enc.mds_encode(encoding, decoded) assert isinstance(enc, bytes) assert enc == encoded - dec = mdsEnc.mds_decode(encoding, enc) + dec = mds_enc.mds_decode(encoding, enc) assert dec == decoded - dec = mdsEnc.mds_decode(encoding, encoded) + dec = mds_enc.mds_decode(encoding, encoded) assert dec == decoded @pytest.mark.parametrize(('enc_name', 'data'), [('bytes', b'9'), ('int', 27), ('str', 'mosaicml')]) def test_mds_encode(self, enc_name: str, data: Any): - output = mdsEnc.mds_encode(enc_name, data) + output = mds_enc.mds_encode(enc_name, data) assert isinstance(output, bytes) @pytest.mark.parametrize(('enc_name', 'data'), [('bytes', 9), ('int', '27'), ('str', 12.5)]) def test_mds_encode_invalid_data(self, enc_name: str, data: Any): with pytest.raises(AttributeError): - _ = mdsEnc.mds_encode(enc_name, data) + _ = mds_enc.mds_encode(enc_name, data) @pytest.mark.parametrize(('enc_name', 'data', 'expected_data_type'), [('bytes', b'c\x00\x00\x00\x00\x00\x00\x00', bytes), ('str', b'mosaicml', str)]) def test_mds_decode(self, enc_name: str, data: Any, expected_data_type: Any): - output = mdsEnc.mds_decode(enc_name, data) + output = mds_enc.mds_decode(enc_name, data) assert isinstance(output, expected_data_type) @pytest.mark.parametrize(('enc_name', 'expected_size'), [('bytes', None), ('int', 8)]) def test_get_mds_encoded_size(self, enc_name: str, expected_size: Any): - output = mdsEnc.get_mds_encoded_size(enc_name) + output = 
mds_enc.get_mds_encoded_size(enc_name) assert output is expected_size @@ -517,7 +517,7 @@ class TestXSVEncodings: @pytest.mark.parametrize(('data', 'encode_data'), [('99', '99'), ('streaming dataset', 'streaming dataset')]) def test_str_encode_decode(self, data: str, encode_data: str): - str_enc = xsvEnc.Str() + str_enc = xsv_enc.Str() # Test encode enc_data = str_enc.encode(data) @@ -532,12 +532,12 @@ def test_str_encode_decode(self, data: str, encode_data: str): @pytest.mark.parametrize('data', [99, b'streaming dataset', 123.45]) def test_str_encode_invalid_data(self, data: Any): with pytest.raises(Exception): - str_enc = xsvEnc.Str() + str_enc = xsv_enc.Str() _ = str_enc.encode(data) @pytest.mark.parametrize(('data', 'encode_data'), [(99, '99'), (987675432, '987675432')]) def test_int_encode_decode(self, data: int, encode_data: str): - int_enc = xsvEnc.Int() + int_enc = xsv_enc.Int() # Test encode enc_data = int_enc.encode(data) @@ -552,12 +552,12 @@ def test_int_encode_decode(self, data: int, encode_data: str): @pytest.mark.parametrize('data', ['99', b'streaming dataset', 123.45]) def test_int_encode_invalid_data(self, data: Any): with pytest.raises(Exception): - int_enc = xsvEnc.Int() + int_enc = xsv_enc.Int() _ = int_enc.encode(data) @pytest.mark.parametrize(('data', 'encode_data'), [(1.24, '1.24'), (9.0, '9.0')]) def test_float_encode_decode(self, data: int, encode_data: str): - float_enc = xsvEnc.Float() + float_enc = xsv_enc.Float() # Test encode enc_data = float_enc.encode(data) @@ -572,7 +572,7 @@ def test_float_encode_decode(self, data: int, encode_data: str): @pytest.mark.parametrize('data', ['99', b'streaming dataset', 12]) def test_float_encode_invalid_data(self, data: Any): with pytest.raises(Exception): - float_enc = xsvEnc.Float() + float_enc = xsv_enc.Float() _ = float_enc.encode(data) @pytest.mark.parametrize(('enc_name', 'expected_output'), [ @@ -582,14 +582,14 @@ def test_float_encode_invalid_data(self, data: Any): ('', False), ]) def test_is_xsv_encoding(self, enc_name: str, expected_output: bool): - is_supported = xsvEnc.is_xsv_encoding(enc_name) + is_supported = xsv_enc.is_xsv_encoding(enc_name) assert is_supported is expected_output @pytest.mark.parametrize(('enc_name', 'data', 'expected_data'), [('str', 'mosaicml', 'mosaicml'), ('int', 27, '27'), ('float', 1.25, '1.25')]) def test_xsv_encode(self, enc_name: str, data: Any, expected_data: str): - output = xsvEnc.xsv_encode(enc_name, data) + output = xsv_enc.xsv_encode(enc_name, data) assert isinstance(output, str) assert output == expected_data @@ -597,7 +597,7 @@ def test_xsv_encode(self, enc_name: str, data: Any, expected_data: str): [('str', 'mosaicml', 'mosaicml'), ('int', '27', 27), ('float', '1.25', 1.25)]) def test_xsv_decode(self, enc_name: str, data: str, expected_data: Any): - output = xsvEnc.xsv_decode(enc_name, data) + output = xsv_enc.xsv_decode(enc_name, data) assert isinstance(output, type(expected_data)) assert output == expected_data @@ -606,7 +606,7 @@ class TestJSONEncodings: @pytest.mark.parametrize('data', ['99', 'mosaicml']) def test_str_is_encoded(self, data: str): - json_enc = jsonEnc.Str() + json_enc = jsonl_enc.Str() # Test encode enc_data = json_enc.is_encoded(data) @@ -615,12 +615,12 @@ def test_str_is_encoded(self, data: str): @pytest.mark.parametrize('data', [99, b'mosaicml']) def test_str_is_encoded_invalid_data(self, data: Any): with pytest.raises(AttributeError): - json_enc = jsonEnc.Str() + json_enc = jsonl_enc.Str() _ = json_enc.is_encoded(data) @pytest.mark.parametrize('data', 
[99, 987675432]) def test_int_is_encoded(self, data: int): - int_enc = jsonEnc.Int() + int_enc = jsonl_enc.Int() # Test encode enc_data = int_enc.is_encoded(data) @@ -629,12 +629,12 @@ def test_int_is_encoded(self, data: int): @pytest.mark.parametrize('data', ['99', b'mosaicml', 1.25]) def test_int_is_encoded_invalid_data(self, data: Any): with pytest.raises(AttributeError): - int_enc = jsonEnc.Int() + int_enc = jsonl_enc.Int() _ = int_enc.is_encoded(data) @pytest.mark.parametrize('data', [1.25]) def test_float_is_encoded(self, data: int): - float_enc = jsonEnc.Float() + float_enc = jsonl_enc.Float() # Test encode enc_data = float_enc.is_encoded(data) @@ -643,7 +643,7 @@ def test_float_is_encoded(self, data: int): @pytest.mark.parametrize('data', ['99', b'mosaicml', 25]) def test_float_is_encoded_invalid_data(self, data: Any): with pytest.raises(AttributeError): - float_enc = jsonEnc.Float() + float_enc = jsonl_enc.Float() _ = float_enc.is_encoded(data) @pytest.mark.parametrize(('enc_name', 'expected_output'), [ @@ -652,13 +652,13 @@ def test_float_is_encoded_invalid_data(self, data: Any): ('float', True), ('', False), ]) - def test_is_json_encoding(self, enc_name: str, expected_output: bool): - is_supported = jsonEnc.is_json_encoding(enc_name) + def test_is_jsonl_encoding(self, enc_name: str, expected_output: bool): + is_supported = jsonl_enc.is_jsonl_encoding(enc_name) assert is_supported is expected_output @pytest.mark.parametrize(('enc_name', 'data', 'expected_output'), [('str', 'hello', True), ('int', 10, True), ('float', 9.9, True)]) - def test_is_json_encoded(self, enc_name: str, data: Any, expected_output: bool): - is_supported = jsonEnc.is_json_encoded(enc_name, data) + def test_is_jsonl_encoded(self, enc_name: str, data: Any, expected_output: bool): + is_supported = jsonl_enc.is_jsonl_encoded(enc_name, data) assert is_supported is expected_output diff --git a/tests/test_writer.py b/tests/test_writer.py index 188a6b40b..d2aa691a1 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -9,7 +9,7 @@ import numpy as np import pytest -from streaming import CSVWriter, JSONWriter, MDSWriter, StreamingDataset, TSVWriter, XSVWriter +from streaming import CSVWriter, JSONLWriter, MDSWriter, StreamingDataset, TSVWriter, XSVWriter from tests.common.datasets import NumberAndSayDataset, SequenceDataset from tests.common.utils import get_config_in_bytes @@ -122,7 +122,7 @@ def test_dataset_iter_determinism(self, local_remote_dir: Tuple[str, str], num_s assert before == after -class TestJSONWriter: +class TestJSONLWriter: @pytest.mark.parametrize('num_samples', [100]) @pytest.mark.parametrize('size_limit', [32]) @@ -133,18 +133,18 @@ def test_config(self, local_remote_dir: Tuple[str, str], num_samples: int, columns = dict(zip(dataset.column_names, dataset.column_encodings)) expected_config = { 'version': 2, - 'format': 'json', + 'format': 'jsonl', 'compression': None, 'hashes': [], 'size_limit': size_limit, 'columns': columns, 'newline': '\n' } - writer = JSONWriter(out=local, - columns=columns, - compression=None, - hashes=None, - size_limit=size_limit) + writer = JSONLWriter(out=local, + columns=columns, + compression=None, + hashes=None, + size_limit=size_limit) assert writer.get_config() == expected_config @pytest.mark.parametrize('num_samples', [50000]) @@ -158,11 +158,11 @@ def test_dataset_iter_determinism(self, local_remote_dir: Tuple[str, str], num_s local, _ = local_remote_dir dataset = NumberAndSayDataset(num_samples, seed=seed) columns = dict(zip(dataset.column_names, 
dataset.column_encodings)) - with JSONWriter(out=local, - columns=columns, - compression=compression, - hashes=hashes, - size_limit=size_limit) as out: + with JSONLWriter(out=local, + columns=columns, + compression=compression, + hashes=hashes, + size_limit=size_limit) as out: for sample in dataset: out.write(sample) From 3dcf500a32d47d9db1b14e0890ead27997e6dee8 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Fri, 15 Dec 2023 08:25:14 -0800 Subject: [PATCH 10/12] Parquet ingestion, Streaming Parquet shards (draft). --- streaming/format/index.py | 116 ++++++++++++++ streaming/format/parquet/__init__.py | 9 ++ streaming/format/parquet/indexing.py | 220 +++++++++++++++++++++++++++ streaming/format/parquet/shard.py | 203 ++++++++++++++++++++++++ streaming/format/shard.py | 36 +++-- 5 files changed, 575 insertions(+), 9 deletions(-) create mode 100644 streaming/format/parquet/__init__.py create mode 100644 streaming/format/parquet/indexing.py create mode 100644 streaming/format/parquet/shard.py diff --git a/streaming/format/index.py b/streaming/format/index.py index f1b0cdd2a..87a6cc8e1 100644 --- a/streaming/format/index.py +++ b/streaming/format/index.py @@ -3,6 +3,16 @@ """Methods having to do with streaming dataset indexes.""" +import json +import os +from re import Pattern +from typing import Callable, Dict, Iterable, Optional, Union +from warnings import warn + +from streaming.format.parquet.indexing import index_parquet +from streaming.storage import CloudUploader, download_file, file_exists +from streaming.util.shorthand import normalize_duration + __all__ = ['get_index_basename'] @@ -13,3 +23,109 @@ def get_index_basename() -> str: str: Index basename. """ return 'index.json' + + +Predicate = Union[str, Pattern, Callable[[str], bool]] + + +def materialize_index(*, + local: str, + remote: Optional[str] = None, + split: Optional[str] = None, + backend: str = 'streaming', + files: Optional[Iterable[str]] = None, + keep: Optional[Predicate] = r'^.*\.parquet$', + num_procs: Optional[int] = None, + show_progress: bool = True, + columns: Optional[Dict[str, Dict[str, str]]] = None, + match_columns: bool = True, + download_timeout: Union[float, str] = '5m', + max_file_size: Optional[Union[int, str]] = '200mb', + save_index_to_remote: bool = True) -> None: + r"""Either download or generate the Streaming index for the given dataset. + + Args: + local (str): Where the dataset is cached on the local filesystem. + remote (str, optional): Where the dataset is downloaded from. Defaults to ``None``. + split (str, optional): Which dataset split to use. Defaults to ``None``. + files (Iterable[str], optional): An Iterable of file paths relative to dataset root. These + paths are filtered for the Parquets constituting this dataset by ``keep``. If not set, we + default to a sorted listing of all the files under dataset root. We list the remote if + provided, else we assume local is complete. Defaults to ``None``. + keep (Union[str, Pattern, Callable[[str], bool]], optional): Iterating ``files``, we keep + the ones this regex matches (if str) or predicate accepts (if Callable). Defaults to + ``^.*\.parquet$``, i.e. include every file that ends with ".parquet". + num_procs (int, optional): Number of processes for download/processing of potentially many + large Parquet files. ``0`` means single-process; ``None`` means + processes; positive int means that number of processes. Defaults to ``None``. + show_progress (bool): Show progress bar for download/processing. Defaults to ``True``.
+ columns (Dict[str, str], optional): For field names and types specified here, override the + inferred columns to configure it manually. Defaults to ``None``. + match_columns (bool): Whether to require that all the dataset Parquets have exactly the same + column configuration. This is a correctness guard rail, preventing non-dataset Parquet + shards from sneaking into our dataset. Streaming for its part is fine with shards being + "incompatible"; assumes client will handle it. Defaults to ``True``. + download_timeout (Union[float, str]): For each Parquet file. Defaults to ``5m``. + max_file_size (Union[int, str], optional): File size limit, above which we raise an error. + This is a performance guard rail, as choppiness increases linearly with shard size. The + sweet spot is typically around 32mb. Defaults to ``200mb``. + save_index_to_remote (bool): If we are indexing a third-party dataset and have a remote, + whether to save the generated index to the remote in order to prevent having to index + the dataset again in the future. + """ + index_rel_path = get_index_basename() + if backend == 'streaming': + # First option: this is explicitly a Streaming dataset. + # + # Ensure the index.json is local and we're done. + local_filename = os.path.join(local, split or '', index_rel_path) + if not os.path.exists(local_filename): + if remote: + # Download the `index.json` to `index.json.tmp` then rename to `index.json`. This + # is because only one process performs the downloading, while others wait for it + # to complete. + remote_path = os.path.join(remote, split or '', index_rel_path) + temp_local_filename = local_filename + '.tmp' + norm_download_timeout = normalize_duration(download_timeout) + download_file(remote_path, temp_local_filename, norm_download_timeout) + os.rename(temp_local_filename, local_filename) + else: + raise ValueError(f'No `remote` provided, but local file {local_filename} does ' + + f'not exist either.') + elif file_exists(local=local, remote=remote, split=split, path=index_rel_path): + # Second option: this is a Streaming dataset, but the backend is set wrong. + # + # Note: Streaming datasets are datasets that Streaming can use -- they need a Streaming + # index.json, but their shards can be in other formats, e.g. Parquet files or Delta tables. + warn(f'Specified a non-Streaming backend ({backend}), but a Streaming index.json was ' + + f'found (which makes this technically a Streaming dataset). Will use this ' + + f'already-existing Streaming index instead of re-indexing the dataset.') + else: + # Third option: This is not a Streaming dataset. + # + # We call out to backend-specific indexing functions to index this third-party dataset, + # resulting in a perfectly normal and valid index.json. May want to save that to remote. + if backend == 'parquet': + obj = index_parquet(local=local, + remote=remote, + split=split, + files=files, + keep=keep, + num_procs=num_procs, + show_progress=show_progress, + columns=columns, + match_columns=match_columns, + download_timeout=download_timeout, + max_file_size=max_file_size) + else: + raise ValueError(f'Unsupported backend: {backend}.') + + # Save index to local. + index_filename = os.path.join(local, split or '', index_rel_path) + with open(index_filename, 'w') as out: + json.dump(obj, out) + + # Maybe save index to remote.
+ if save_index_to_remote and remote: + uploader = CloudUploader.get((local, remote), True) + uploader.upload_file(index_rel_path) diff --git a/streaming/format/parquet/__init__.py b/streaming/format/parquet/__init__.py new file mode 100644 index 000000000..170429408 --- /dev/null +++ b/streaming/format/parquet/__init__.py @@ -0,0 +1,9 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Streaming Parquet shards.""" + +from streaming.format.parquet.indexing import index_parquet +from streaming.format.parquet.reader import ParquetShard + +__all__ = ['index_parquet', 'ParquetShard'] diff --git a/streaming/format/parquet/indexing.py b/streaming/format/parquet/indexing.py new file mode 100644 index 000000000..16efa8380 --- /dev/null +++ b/streaming/format/parquet/indexing.py @@ -0,0 +1,220 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Indexing a Parquet dataset for use by Streaming.""" + +import os +from re import Pattern +from typing import Any, Callable, Dict, Iterable, Optional, Union + +from pyarrow import parquet as pq +from tqdm import tqdm + +from streaming.format.mds.encodings import get_mds_encoded_size +from streaming.storage.extra import list_dataset_files, smart_download_file +from streaming.util.shorthand import normalize_duration + +__all__ = ['index_parquet'] + + +def _get_mds_column(val: Any) -> str: + """Get the MDS column encoding of one field. + + Args: + val (Any): The field. + + Returns: + str: Its corresponding MDS encoding. + """ + if isinstance(val, int): + return 'int' + elif isinstance(val, str): + return 'str' + else: + raise ValueError(f'Unsupported column type: {type(val)}.') + + +def _sample_to_columns(sample: Dict[str, Any]) -> Dict[str, Any]: + """Get column names, encodings, and sizes. + + Args: + sample (Dict[str, Any]): A sample to derive column info from. + + Returns: + Dict[str, Any]: MDS column names, encodings, and sizes. + """ + col_names = sorted(sample) + col_encs = [] + for name in col_names: + val = sample[name] + enc = _get_mds_column(val) + col_encs.append(enc) + col_sizes = list(map(get_mds_encoded_size, col_encs)) + return { + 'column_names': col_names, + 'column_encodings': col_encs, + 'column_sizes': col_sizes, + } + + +def _index_file(local: str, + remote: Optional[str], + split: Optional[str], + rel_path: str, + download_timeout: Union[float, str] = '2m', + max_file_bytes: Optional[Union[int, str]] = '200mb', + want_mds_columns: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """Get the info a Streaming index needs about a Parquet shard. + + Args: + local (str): Local dataset root. + remote (str, optional): Remote dataset root, if remote is provided. + split (str, optional): Split, if used. + rel_path (str): Path to file, relative to serialized dataset root. + download_timeout (Union[float, str]): Maximum download time. Defaults to ``2m``. + max_file_bytes (Union[int, str], optional): Maximum file size. This is to catch people + trying to stream gigantic Parquet shards. Defaults to ``200mb``. + want_mds_columns (Dict[str, Any], optional): If provided, MDS schema that this Parquet + shard must match upon conversion to MDS. + + Returns: + Dict[str, Any]: Shard info.
+ """ + local_path = os.path.join(local, split or '', rel_path) + if not os.path.exists(local_path): + if not remote: + raise ValueError('Remote was needed, but not provided.') + + remote_path = os.path.join(remote, split or '', rel_path) + smart_download_file(remote=remote_path, + local=local_path, + timeout=download_timeout, + max_size=max_file_bytes) + + num_bytes = os.stat(local_path).st_size + + table = pq.read_table(local_path) + samples = table.to_pylist() + num_samples = len(samples) + mds_columns = _sample_to_columns(samples[0]) + if want_mds_columns and want_mds_columns != mds_columns: + raise ValueError(f'MDS column mismatch: required {want_mds_columns}, but got ' + + f'{mds_columns}.') + + ret = { + 'version': 2, + 'format': 'parquet', + 'raw_parquet': { + 'basename': rel_path, + 'bytes': num_bytes, + }, + 'raw_data': { + 'basename': rel_path + '.mds', + }, + 'samples': num_samples, + } + ret.update(mds_columns) + return ret + + +def _shard_metadata_to_columns(info: Dict[str, Any]) -> Dict[str, Any]: + """Extract MDS column information from the info for a shard. + + Args: + info (Dict[str, Any]): Shard info. + + Returns: + Dict[str, Any]: MDS columns. + """ + ret = {} + for key in ['column_names', 'column_encodings', 'column_sizes']: + ret[key] = info[key] + return ret + + +Predicate = Union[str, Pattern, Callable[[str], bool]] + + +def index_parquet(*, + local: str, + remote: Optional[str] = None, + split: Optional[str] = None, + files: Optional[Iterable[str]] = None, + keep: Optional[Predicate] = r'.*\.parquet$', + num_procs: Optional[int] = None, + show_progress: bool = True, + columns: Optional[Dict[str, Dict[str, str]]] = None, + match_columns: bool = True, + download_timeout: Union[float, str] = '5m', + max_file_size: Optional[Union[int, str]] = '200mb') -> Dict[str, Any]: + r"""Index a local and/or remote Parquet dataset directory for use by Streaming. + + "Parquet dataset" means the samples live in a collection of naked Parquet files. There is not + any kind of index or manifest we can count on existing, so we will have to create one. + + Assumptions: + * Samples live in a collection of naked Parquet files. + * There is not any kind of index or manifest that we can count on existing. + * Files are all found under a common root directory, which local/remote point to. + * This root directory may contain other files, which we ignore. + * Ideally, but not necessarily, the Parquets all have the same columns. + + Locality: + * If we are given an explicit list of Parquet files, we try local first, then remote. Both + are cross-checked for completeness. + * If we are default listing all files instead, and just have a local, it is assumed to be + complete. + * If we are listing files, and remote is provided, the remote must be authoritative. + + TODO: use num_procs. + TODO: use columns. + + Args: + local (str): Where the dataset is cached on the local filesystem. + remote (str, optional): Where the dataset is downloaded from. Defaults to ``None``. + split (str, optional): Which dataset split to use. Defaults to ``None``. + files (Iterable[str], optional): An Iterable of file paths relative to dataset root. These + paths are filtered for the Parquets constituting this dataset by ``keep``. If not set, we + default to a sorted listing of all the files under dataset root. We list the remote if + provided, else we assume local is complete. Defaults to ``None``.
+ keep (Union[str, Pattern, Callable[[str], bool]], optional): Iterating ``files``, we keep + the ones this regex matches (if str) or predicate accepts (if Callable). Defaults to + ``.*\.parquet$``, i.e. include every file that ends with ".parquet". + num_procs (int, optional): Number of processes for download/processing of potentially many + large Parquet files. ``0`` means single-process; ``None`` means + processes; positive int means that number of processes. Defaults to ``None``. + show_progress (bool): Show progress bar for download/processing. Defaults to ``True``. + columns (Dict[str, str], optional): For field names and types specified here, override the + inferred columns to configure it manually. Defaults to ``None``. + match_columns (bool): Whether to require that all the dataset Parquets have exactly the same + Parquet columns. This is a correctness guard rail, preventing non-dataset Parquet shards + from sneaking into our dataset. Streaming for its part is fine with shards being + "incompatible"; assumes client will handle it. Defaults to ``True``. + download_timeout (Union[float, str]): For each Parquet file. Defaults to ``5m``. + max_file_size (Union[int, str], optional): File size limit, above which we raise an error. + This is a performance guard rail, as choppiness increases linearly with shard size. The + sweet spot is typically around 32mb. Defaults to ``200mb``. + + Returns: + Dict[str, Any]: StreamingDataset index configuration to stream this Parquet dataset. + """ + norm_download_timeout = normalize_duration(download_timeout) + + rel_paths = list_dataset_files(local=local, remote=remote, split=split, paths=files, keep=keep) + if show_progress: + rel_paths = tqdm(rel_paths, leave=False) + + want_mds_columns = None + infos = [] + for rel_path in rel_paths: + info = _index_file(local, remote, split, rel_path, norm_download_timeout, max_file_size, + want_mds_columns) + infos.append(info) + + if match_columns and not want_mds_columns: + want_mds_columns = _shard_metadata_to_columns(info) + + return { + 'version': 2, + 'shards': infos, + } diff --git a/streaming/format/parquet/shard.py b/streaming/format/parquet/shard.py new file mode 100644 index 000000000..bbaaded0d --- /dev/null +++ b/streaming/format/parquet/shard.py @@ -0,0 +1,203 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +"""Streaming Parquet shard reading.""" + +import os +from copy import deepcopy +from tempfile import TemporaryDirectory +from typing import Any, Dict, List, Optional, Set + +from pyarrow import parquet as pq +from typing_extensions import Self + +from streaming.format.mds.shard import MDSShard +from streaming.format.mds.writer import MDSWriter +from streaming.format.shard import FileInfo + + +# TODO: This approach is close, but wrong. +class ParquetShard(MDSShard): + """Provides random access to the samples of a Parquet shard (via MDS internally). + + Args: + dirname (str): Local dataset directory. + split (str, optional): Which dataset split to use, if any. + column_encodings (List[str]): Column encodings. + column_names (List[str]): Column names. + column_sizes (List[Optional[int]]): Column fixed sizes, if any. + raw_parquet (FileInfo): Non-compressed Parquet file info. + raw_data (FileInfo): Uncompressed data file info. + samples (int): Number of samples in this shard.
+ """ + + def __init__( + self, + dirname: str, + split: Optional[str], + column_encodings: List[str], + column_names: List[str], + column_sizes: List[Optional[int]], + raw_parquet: FileInfo, + raw_data: FileInfo, + samples: int, + ) -> None: + super().__init__(dirname=dirname, + split=split, + column_encodings=column_encodings, + column_names=column_names, + column_sizes=column_sizes, + compression=None, + hashes=[], + raw_data=raw_data, + samples=samples, + size_limit=None, + zip_data=None) + self.raw_parquet = raw_parquet + self.file_pairs.append((raw_parquet, None)) + + @classmethod + def from_json(cls, dirname: str, split: Optional[str], obj: Dict[str, Any]) -> Self: + """Initialize from JSON object. + + Args: + dirname (str): Local directory containing shards. + split (str, optional): Which dataset split to use, if any. + obj (Dict[str, Any]): JSON object to load. + + Returns: + Self: Loaded ParquetShard. + """ + args = deepcopy(obj) + + if args['version'] != 2: + raise ValueError(f'Unsupported streaming data version: {args["version"]}. ' + + f'Expected version 2.') + del args['version'] + + if args['format'] != 'parquet': + raise ValueError(f'Unsupported data format: {args["format"]}. ' + + f'Expected to be `parquet`.') + del args['format'] + + args['dirname'] = dirname + args['split'] = split + for key in ['raw_parquet', 'raw_data', 'zip_data']: + arg = args.get(key) + if arg: + args[key] = FileInfo(**arg) + + return cls(**args) + + def set_up_local(self, listing: Set[str], safe_keep_zip: bool, safe_keep_parquet: bool) -> int: + """Bring what shard files are present to a consistent state, returning whether present. + + Args: + listing (Set[str]): The listing of all files under dirname/[split/]. This is listed + once and then saved because there could potentially be very many shard files. + safe_keep_zip (bool): Whether to keep or drop the zip form after decompression, if + applicable, safely taking into account whether this directory is the official copy. + safe_keep_parquet (bool): Whether to keep or drop the Parquet form after MDS + conversion, if applicable, safely taking into account whether this directory is the + official copy. + + Returns: + int: This shard's current contribution to cache usage after normalization. + """ + parquet_filename = os.path.join(self.dirname, self.split, self.raw_parquet.basename) + mds_filename = os.path.join(self.dirname, self.split, self.raw_data.basename) + if os.path.exists(mds_filename): + if os.path.exists(parquet_filename): + if safe_keep_parquet: + # Present: keep both (because of safe_keep_parquet). + size = os.stat(mds_filename).st_size + os.stat(parquet_filename).st_size + else: + # Present: keep MDS, drop Parquet (because of safe_keep_parquet). + os.remove(parquet_filename) + size = os.stat(mds_filename).st_size + else: + if safe_keep_parquet: + # Normalize to missing, because safe_keep_parquet requires that we keep the + # Parquet. + os.remove(mds_filename) + size = 0 + else: + # Present: have MDS, don't have or want Parquet. + size = os.stat(mds_filename).st_size + else: + if os.path.exists(parquet_filename): + # Present: Parquet hasn't been converted to MDS yet; that happens later, in prepare(). + size = os.stat(parquet_filename).st_size + else: + # Missing: both Parquet and MDS are not there. + size = 0 + return size + + def get_column(self, val: Any) -> str: + """Get the MDS column encoding of one field. + + Args: + val (Any): The field. + + Returns: + str: Its corresponding MDS encoding.
+ """ + if isinstance(val, int): + return 'int' + elif isinstance(val, str): + return 'str' + else: + raise ValueError(f'Unsupported column type: {type(val)}.') + + def get_columns(self, sample: Dict[str, Any]) -> Dict[str, str]: + """Get the MDS columns given one sample. + + Args: + sample (Dict[str, Any]): Mapping of column name to value. + + Returns: + Dict[str, str]: Mapping of column name to MDS encoding. + """ + col_names = sorted(sample) + col_encs = [] + for name in col_names: + val = sample[name] + enc = self.get_column(val) + col_encs.append(enc) + return dict(zip(col_names, col_encs)) + + def prepare(self, safe_keep_zip: bool, safe_keep_parquet: bool) -> int: + """Prepare this shard for fast random access by converting to MDS. + + Args: + safe_keep_zip (bool): Whether to keep or drop the zip form after decompression, if + applicable, safely taking into account whether this directory is the official copy. + safe_keep_parquet (bool): Whether to keep or drop the Parquet form after MDS + conversion, if applicable, safely taking into account whether this directory is the + official copy. + + Returns: + int: Change in cache usage in bytes resulting from Parquet to MDS conversion. + """ + # Read the samples from Parquet. + parquet_filename = os.path.join(self.dirname, self.split, self.raw_parquet.basename) + table = pq.read_table(parquet_filename) + samples = table.to_pylist() + + # Write the samples to MDS. + columns = dict(zip(self.column_names, self.column_encodings)) + with TemporaryDirectory() as temp_dir: + with MDSWriter(columns=columns, out=temp_dir, size_limit=None) as out: + for sample in samples: + out.write(sample) + temp_mds_filename = os.path.join(temp_dir, 'shard.00000.mds') + mds_filename = os.path.join(self.dirname, self.split, self.raw_data.basename) + os.rename(temp_mds_filename, mds_filename) + delta = os.stat(mds_filename).st_size + + # Maybe drop the original Parquet (measure its size before removing it). + if not safe_keep_parquet: + delta -= os.stat(parquet_filename).st_size + os.remove(parquet_filename) + + return delta diff --git a/streaming/format/shard.py b/streaming/format/shard.py index 818fc036f..3ff0635c5 100644 --- a/streaming/format/shard.py +++ b/streaming/format/shard.py @@ -17,13 +17,13 @@ @dataclass -class FileInfo(object): - """File validation info. +class FileInfo: + """Per-file metadata, by which we know exactly what to expect. Args: - basename (str): File basename. - bytes (int): File size in bytes. - hashes (Dict[str, str]): Mapping of hash algorithm to hash value. + basename (str): File path relative to the root of this dataset or split. + bytes (int): File size in bytes. + hashes (Dict[str, str]): Map of hash algo to hash value. """ basename: str bytes: int @@ -142,17 +142,20 @@ def evict(self) -> int: """ return self._evict_raw() + self._evict_zip() - def set_up_local(self, listing: Set[str], safe_keep_zip: bool) -> int: + def set_up_local(self, listing: Set[str], safe_keep_zip: bool, safe_keep_parquet: bool) -> int: """Bring what shard files are present to a consistent state, returning whether present. Args: listing (Set[str]): The listing of all files under dirname/[split/]. This is listed once and then saved because there could potentially be very many shard files. - safe_keep_zip (bool): Whether to keep zip files when decompressing. Possible when - compression was used. Necessary when local is the remote or there is no remote.
+ safe_keep_zip (bool): Whether to keep or drop the zip form after decompression, if + applicable, safely taking into account whether this directory is the official copy. + safe_keep_parquet (bool): Whether to keep or drop the Parquet form after MDS + conversion, if applicable, safely taking into account whether this directory is the + official copy. Returns: - bool: Whether the shard is present. + int: This shard's current contribution to cache usage after normalization. """ # For raw/zip to be considered present, each raw/zip file must be present. raw_files_present = 0 @@ -337,6 +340,21 @@ def __iter__(self) -> Iterator[Dict[str, Any]]: for i in range(len(self)): yield self[i] + def prepare(self, safe_keep_zip: bool, safe_keep_parquet: bool) -> int: + """Any work that needs to happen between shard download and shard access. + + Args: + safe_keep_zip (bool): Whether to keep or drop the zip form after decompression, if + applicable, safely taking into account whether this directory is the official copy. + safe_keep_parquet (bool): Whether to keep or drop the Parquet form after MDS + conversion, if applicable, safely taking into account whether this directory is the + official copy. + + Returns: + int: Resulting change in cache usage in bytes. + """ + return 0 + class MonoShard(Shard): """Provides random access to the samples of a mono shard. From 6d4bb55c7e30175354c55947e569898311ec3f61 Mon Sep 17 00:00:00 2001 From: James Knighton Date: Fri, 15 Dec 2023 08:30:16 -0800 Subject: [PATCH 11/12] Fix (docstrings). --- streaming/format/index.py | 4 ++-- streaming/format/parquet/indexing.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/streaming/format/index.py b/streaming/format/index.py index 87a6cc8e1..c111d746a 100644 --- a/streaming/format/index.py +++ b/streaming/format/index.py @@ -59,8 +59,8 @@ def materialize_index(*, large Parquet files. ``0`` means single-process; ``None`` means processes; positive int means that number of processes. Defaults to ``None``. show_progress (bool): Show progress bar for download/processing. Defaults to ``True``. - columns (Dict[str, str], optional): For field names and types specified here, override the - inferred columns to configure it manually. Defaults to ``None``. + columns (Dict[str, Dict[str, str]], optional): For field names and types specified here, + override the inferred columns to configure it manually. Defaults to ``None``. match_columns (bool): Whether to require that all the dataset Parquets have exactly the same column configuration. This is a correctness guard rail, preventing non-dataset Parquet shards from sneaking into our dataset. Streaming for its part is fine with shards being diff --git a/streaming/format/parquet/indexing.py b/streaming/format/parquet/indexing.py index 16efa8380..8dfcd21b0 100644 --- a/streaming/format/parquet/indexing.py +++ b/streaming/format/parquet/indexing.py @@ -184,8 +184,8 @@ def index_parquet(*, large Parquet files. ``0`` means single-process; ``None`` means processes; positive int means that number of processes. Defaults to ``None``. show_progress (bool): Show progress bar for download/processing. Defaults to ``True``. - columns (Dict[str, str], optional): For field names and types specified here, override the - inferred columns to configure it manually. Defaults to ``None``. + columns (Dict[str, Dict[str, str]], optional): For field names and types specified here, + override the inferred columns to configure it manually. Defaults to ``None``. 
match_columns (bool): Whether to require that all the dataset Parquets have exactly the same Parquet columns. This is a correctness guard rail, preventing non-dataset Parquet shards from sneaking into our dataset. Streaming for its part is fine with shards being From 2a383b64e2f23a642c93737d0f470862373f3edb Mon Sep 17 00:00:00 2001 From: James Knighton Date: Fri, 15 Dec 2023 08:31:15 -0800 Subject: [PATCH 12/12] Fix (import). --- streaming/format/parquet/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/format/parquet/__init__.py b/streaming/format/parquet/__init__.py index 170429408..0a813c746 100644 --- a/streaming/format/parquet/__init__.py +++ b/streaming/format/parquet/__init__.py @@ -4,6 +4,6 @@ """Streaming Parquet shards.""" from streaming.format.parquet.indexing import index_parquet -from streaming.format.parquet.reader import ParquetShard +from streaming.format.parquet.shard import ParquetShard __all__ = ['index_parquet', 'ParquetShard']
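For reviewers following along, here is a minimal usage sketch of the indexing entry points introduced in patches 10-12 above (materialize_index in streaming/format/index.py and index_parquet in streaming/format/parquet/). It only calls signatures shown in the diffs; the bucket URL, local cache directory, and split name are hypothetical placeholders, and the behavior assumes the draft code lands as written.

# Illustrative sketch only -- not part of the patch. Paths, bucket, and split are hypothetical.
from streaming.format.index import materialize_index
from streaming.format.parquet import index_parquet

# Index a directory of naked Parquet files: writes index.json under local/split and,
# because save_index_to_remote defaults to True, uploads it back to the remote.
materialize_index(
    local='/tmp/pq_cache',            # hypothetical local cache directory
    remote='s3://my-bucket/pq_data',  # hypothetical remote dataset root
    split='train',
    backend='parquet',
    keep=r'.*\.parquet$',
    download_timeout='5m',
    max_file_size='200mb',
)

# Alternatively, build the index object in memory without writing index.json.
obj = index_parquet(local='/tmp/pq_cache',
                    remote='s3://my-bucket/pq_data',
                    split='train')
print(obj['version'], len(obj['shards']))

Per the ParquetShard draft, the actual Parquet-to-MDS conversion is deferred to each shard's prepare() call at access time; indexing itself only records file sizes, sample counts, and the inferred MDS columns.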