Revert S3IterableDataset to share dataset across workers by default
IsaevIlya committed Oct 29, 2024
1 parent 736cde4 commit d3baaed
Showing 6 changed files with 42 additions and 15 deletions.
4 changes: 1 addition & 3 deletions CHANGELOG.md
@@ -2,9 +2,7 @@
* Add support of distributed training to S3IterableDataset

### Breaking changes
- * When using the S3IterableDataset with multiple workers, each worker will have access to a distinct subset
- of the data by default. If you require all workers within the same process to have access to the entire dataset,
- you should set the `share_dataset_within_process` parameter to True during the creation of the S3IterableDataset.
+ * No breaking changes.

## v1.2.7 (October 29, 2024)

28 changes: 28 additions & 0 deletions README.md
@@ -115,7 +115,35 @@ For example, assuming the following directory bucket name `my-test-bucket--usw2-
usw2-az1, then the URI used will look like: `s3://my-test-bucket--usw2-az1--x-s3/<PREFIX>` (**please note that the
prefix for Amazon S3 Express One Zone should end with '/'**), paired with region us-west-2.

## Parallel/Distributed Training

Amazon S3 Connector for PyTorch supports parallel and distributed training,
allowing you to leverage multiple processes and nodes for efficient data loading and training.
Both S3IterableDataset and S3MapDataset can be used for this purpose.

### S3IterableDataset

The S3IterableDataset can be directly passed to PyTorch's DataLoader for parallel and distributed training.
By default, all worker processes will share the same list of training objects. However,
if you need each worker to have access to a unique portion of the dataset for better parallelization,
you can enable dataset sharding using the `enable_sharding` parameter.
```
dataset = S3IterableDataset.from_prefix(DATASET_URI, region=REGION, enable_sharding=True)
dataloader = DataLoader(dataset, num_workers=4)
```
When `enable_sharding` is set to True, the dataset is automatically sharded across the available workers.
This sharding mechanism supports both parallel training on a single host and distributed training across multiple hosts.
Each worker, regardless of its host, loads and processes a distinct subset of the dataset, as illustrated by the sketch below.
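
The snippet below is a simplified, hypothetical sketch of the round-robin sharding idea only; it is not the connector's implementation, and the object keys, host count, and worker count are made up for illustration.
```
# Simplified sketch: each (rank, worker) pair sees a distinct slice of the keys.
# Hypothetical values; not the connector's actual sharding code.
keys = [f"obj_{i:02d}" for i in range(8)]   # hypothetical object keys

world_size, num_workers = 2, 2              # e.g. 2 hosts, 2 DataLoader workers each
num_shards = world_size * num_workers
for rank in range(world_size):
    for worker_id in range(num_workers):
        shard_id = rank * num_workers + worker_id
        shard = keys[shard_id::num_shards]  # every num_shards-th key
        print(f"rank={rank} worker={worker_id} -> {shard}")
```
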
### S3MapDataset

For the S3MapDataset, wrap it in a DistributedSampler and pass both the dataset and the sampler to the DataLoader.
The DistributedSampler ensures that each worker or node receives a unique subset of the dataset,
enabling efficient parallel and distributed training.
```
dataset = S3MapDataset.from_prefix(DATASET_URI, region=REGION)
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, sampler=sampler, num_workers=4)
```
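
When shuffling across multiple epochs with a DistributedSampler, the usual PyTorch pattern is to call `sampler.set_epoch(epoch)` at the start of every epoch so the shuffle order differs between epochs. A minimal sketch continuing the example above; the epoch count and training step are placeholders.
```
# Standard PyTorch pattern: reseed the DistributedSampler each epoch so the
# shuffle order changes. Epoch count and training step are placeholders.
for epoch in range(3):
    sampler.set_epoch(epoch)
    for batch in dataloader:
        ...  # training step goes here
```
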
## Lightning Integration

Amazon S3 Connector for PyTorch includes an integration for PyTorch Lightning, featuring S3LightningCheckpoint, an
18 changes: 9 additions & 9 deletions s3torchconnector/src/s3torchconnector/s3iterable_dataset.py
@@ -33,15 +33,15 @@ def __init__(
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
- share_dataset_within_process: bool = False,
+ enable_sharding: bool = False,
):
self._get_dataset_objects = get_dataset_objects
self._transform = transform
self._region = region
self._endpoint = endpoint
self._s3client_config = s3client_config
self._client = None
- self._share_dataset_within_process = share_dataset_within_process
+ self._enable_sharding = enable_sharding

self._rank = 0
self._world_size = 1
@@ -66,7 +66,7 @@ def from_objects(
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
- share_dataset_within_process: bool = False,
+ enable_sharding: bool = False,
):
"""Returns an instance of S3IterableDataset using the S3 URI(s) provided.
@@ -76,7 +76,7 @@
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3client_config: Optional S3ClientConfig with parameters for S3 client.
- share_dataset_within_process: share the dataset across workers within the same process, but use different datasets for different processes. Turned off by default.
+ enable_sharding: If True, shard the dataset across multiple workers for parallel data loading. If False (default), each worker loads the entire dataset independently.
Returns:
S3IterableDataset: An IterableStyle dataset created from S3 objects.
@@ -91,7 +91,7 @@
endpoint,
transform=transform,
s3client_config=s3client_config,
- share_dataset_within_process=share_dataset_within_process,
+ enable_sharding=enable_sharding,
)

@classmethod
@@ -103,7 +103,7 @@ def from_prefix(
endpoint: Optional[str] = None,
transform: Callable[[S3Reader], Any] = identity,
s3client_config: Optional[S3ClientConfig] = None,
- share_dataset_within_process: bool = False,
+ enable_sharding: bool = False,
):
"""Returns an instance of S3IterableDataset using the S3 URI provided.
@@ -113,7 +113,7 @@
endpoint(str): AWS endpoint of the S3 bucket where the objects are stored.
transform: Optional callable which is used to transform an S3Reader into the desired type.
s3client_config: Optional S3ClientConfig with parameters for S3 client.
- share_dataset_within_process: share the dataset across workers within the same process, but use different datasets for different processes. Turned off by default.
+ enable_sharding: If True, shard the dataset across multiple workers for parallel data loading. If False (default), each worker loads the entire dataset independently.
Returns:
S3IterableDataset: An IterableStyle dataset created from S3 objects.
@@ -128,7 +128,7 @@
endpoint,
transform=transform,
s3client_config=s3client_config,
- share_dataset_within_process=share_dataset_within_process,
+ enable_sharding=enable_sharding,
)

def _get_client(self):
@@ -150,7 +150,7 @@ def _get_transformed_object(self, bucket_key: S3BucketKeyData) -> Any:
def __iter__(self) -> Iterator[Any]:
worker_id = 0
num_workers = 1
- if not self._share_dataset_within_process:
+ if self._enable_sharding:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
worker_id = worker_info.id
1 change: 1 addition & 0 deletions s3torchconnector/tst/e2e/test_distributed_training.py
@@ -70,6 +70,7 @@ def dataloader_for_iterable(dataset_builder, image_directory, num_workers, batch
dataset = dataset_builder(
cls=S3IterableDataset,
image_directory=image_directory,
+ enable_sharding=True,
)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
return dataloader
3 changes: 1 addition & 2 deletions s3torchconnector/tst/e2e/test_multiprocess_dataloading.py
@@ -52,7 +52,7 @@ def test_s3iterable_dataset_multiprocess_torchdata(
):
_set_start_method(start_method)
dataset = dataset_builder(
- S3IterableDataset, image_directory, share_dataset_within_process=True
+ S3IterableDataset, image_directory
)

dataset = IterableWrapper(dataset, deepcopy=False).sharding_filter().map(_read_data)
@@ -90,7 +90,6 @@ def test_s3iterable_dataset_multiprocess(
S3IterableDataset,
image_directory,
transform=_extract_object_data,
- share_dataset_within_process=True,
)

num_workers = 3
3 changes: 2 additions & 1 deletion s3torchconnector/tst/unit/test_s3iterable_dataset.py
@@ -351,7 +351,7 @@ def test_dataset_creation_from_objects_against_multiple_workers(
get_world_size_mock.return_value = world_size

object_uris = [f"{S3_PREFIX}/{key}" for key in all_keys]
- dataset = S3IterableDataset.from_objects(object_uris, region=TEST_REGION)
+ dataset = S3IterableDataset.from_objects(object_uris, region=TEST_REGION, enable_sharding=True)

# use mock client for unit testing
client = _create_mock_client_with_dummy_objects(TEST_BUCKET, all_keys)
@@ -579,6 +579,7 @@ def test_dataset_creation_from_prefix_against_multiple_workers(
dataset = S3IterableDataset.from_prefix(
s3_uri=prefix,
region=TEST_REGION,
+ enable_sharding=True,
)

# use mock client for unit testing
