Addressing feedback
IsaevIlya committed Oct 29, 2024
1 parent 0d7a4ba commit 736cde4
Showing 6 changed files with 152 additions and 33 deletions.
39 changes: 30 additions & 9 deletions s3torchconnector/src/s3torchconnector/s3iterable_dataset.py
@@ -156,18 +156,39 @@ def __iter__(self) -> Iterator[Any]:
worker_id = worker_info.id
num_workers = worker_info.num_workers

"""
In a multi-process setting (e.g., distributed training), the dataset needs to be
sharded across multiple processes. The following variables control this sharding:
_rank: The rank (index) of the current process within the world (group of processes).
_world_size: The total number of processes in the world (group).
In addition, within each process, the dataset may be further sharded across multiple
worker threads or processes (e.g., for data loading). The following variables control
this intra-process sharding:
worker_id: The ID of the current worker thread/process within the process.
num_workers: The total number of worker threads/processes within the process.
The _shard_index and _shard_count variables are computed based on the above values,
and they determine which subset of the dataset objects should be processed by the
current worker thread/process in the current process rank.
"""

self._shard_index = num_workers * self._rank + worker_id
self._shard_count = num_workers * self._world_size

if self._shard_count > 1:
    # we have more than one shard, so we need to distribute the dataset between shards
    sharded_objects = (
        obj
        for idx, obj in enumerate(self._get_dataset_objects(self._get_client()))
        if idx % self._shard_count == self._shard_index
    )
    return map(self._get_transformed_object, sharded_objects)

# only one shard, so return the entire dataset
return map(
    self._get_transformed_object,
    self._get_dataset_objects(self._get_client()),
)
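For illustration, here is a minimal sketch (not part of the commit, using made-up object keys) of how the _shard_index / _shard_count arithmetic above splits the dataset across two training processes that each run two dataloader workers:

# Sketch only: emulate the sharding arithmetic with eight fake object keys.
objects = [f"s3://bucket/prefix/img_{i}.jpg" for i in range(8)]

world_size, num_workers = 2, 2            # two training processes, two workers per process
shard_count = num_workers * world_size    # four shards in total

for rank in range(world_size):
    for worker_id in range(num_workers):
        shard_index = num_workers * rank + worker_id
        shard = [key for idx, key in enumerate(objects) if idx % shard_count == shard_index]
        print(f"rank={rank} worker={worker_id}: {shard}")

Each key lands in exactly one shard, so no object is read twice across workers or ranks.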
12 changes: 9 additions & 3 deletions s3torchconnector/tst/e2e/conftest.py
@@ -70,7 +70,13 @@ def add(self, key: str, contents: bytes, **kwargs):
self.s3.put_object(Bucket=self.bucket, Key=full_key, Body=contents, **kwargs)
self.contents[full_key] = contents

def get_data_snapshot(self):
"""
Returns a read-only copy of the current instance's data.
The returned object cannot modify the actual S3 bucket.
Useful when passing data to another process without serializing the S3 client.
"""
return BucketPrefixData(
self.region, self.bucket, self.prefix, self.storage_class, self.contents
)
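As a rough illustration of why this matters (my own sketch; BucketPrefixSnapshot below is a hypothetical stand-in, not the real BucketPrefixData from conftest.py), a frozen container that holds only plain data pickles cleanly and can be handed to processes spawned by torch.multiprocessing, unlike the fixture that carries an S3 client:

# Sketch only: a client-free data holder is safe to send to spawned worker processes.
from dataclasses import dataclass, field
from typing import Dict, Optional

@dataclass(frozen=True)
class BucketPrefixSnapshot:  # hypothetical stand-in for BucketPrefixData
    region: str
    bucket: str
    prefix: str
    storage_class: Optional[str] = None
    contents: Dict[str, bytes] = field(default_factory=dict)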
@@ -102,8 +108,8 @@ def image_directory_for_dp(request) -> BucketPrefixFixture:
# When conducting distributed training tests, be cautious about the number of files (images) in the test dataset.
# If the total number of images cannot be evenly divided by the number of workers,
# the DistributedSampler will duplicate a subset of the images across workers to ensure an equal
# distribution of data among all processes. This duplication of images will cause
# the distributed-training integration test to fail.
NUM_IMAGES = 36
IMAGE_SIZE = 100
return _create_image_directory_fixture(NUM_IMAGES, IMAGE_SIZE, request.node.name)
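As a quick cross-check (my own sketch, reusing the worker and process counts from the test matrix in test_distributed_training.py below), 36 images shard evenly in every configuration the tests exercise:

# Sketch only: 36 is divisible by every worker-count x process-count product under test,
# so neither DistributedSampler nor iterable-dataset sharding has to duplicate images.
from itertools import product

NUM_IMAGES = 36
for num_workers, num_processes in product([1, 2, 3], [1, 2, 3]):
    shards = num_workers * num_processes
    assert NUM_IMAGES % shards == 0, f"{NUM_IMAGES} images do not split evenly across {shards} shards"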
9 changes: 7 additions & 2 deletions s3torchconnector/tst/e2e/test_common.py
@@ -8,10 +8,15 @@
from typing import Tuple


def _get_fork_methods() -> set:
"""
Get a set of valid start methods for PyTorch's multiprocessing.
On macOS, the 'fork' and 'forkserver' start methods are known to crash,
despite being reported as usable by PyTorch. This function filters out
those methods for macOS systems.
Returns:
set: A set of valid start methods for the current platform.
"""
methods = set(torch.multiprocessing.get_all_start_methods())

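The remainder of the function is collapsed in this view; a plausible sketch of the macOS filtering the docstring describes (an assumption on my part, not code from this commit) might look like:

# Sketch only (assumed implementation): drop the fork-based start methods on macOS,
# where they are known to crash even though PyTorch reports them as available.
import platform
import torch

def _get_fork_methods() -> set:
    methods = set(torch.multiprocessing.get_all_start_methods())
    if platform.system() == "Darwin":
        methods -= {"fork", "forkserver"}
    return methods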
22 changes: 13 additions & 9 deletions s3torchconnector/tst/e2e/test_distributed_training.py
@@ -6,6 +6,7 @@
from collections import Counter
from itertools import product
from typing import Callable, TYPE_CHECKING
import hashlib

import pytest
import torch.multiprocessing as mp
@@ -17,10 +18,10 @@
from .conftest import BucketPrefixFixture, BucketPrefixData


from test_common import _get_fork_methods, _read_data, _set_start_method


start_methods = _get_fork_methods()

import torch.distributed as dist

@@ -80,8 +81,8 @@ def dataloader_for_iterable(dataset_builder, image_directory, num_workers, batch
# Allow us to construct dataloaders in test with either S3MapDataset or S3IterableDataset
dataloader_builders = (dataloader_for_iterable, dataloader_for_map)

num_workers_to_test = [1, 2, 3]
num_processes_to_test = [1, 2, 3]
test_args = list(
product(
sorted(start_methods),
@@ -108,7 +109,9 @@ def test_distributed_training(
):
# Generate a unique port number in the range [2000, 62000) based on the test name
# to ensure that different test workers use different ports
test_name = str(request.node)
test_name_hash = hashlib.sha256(test_name.encode()).hexdigest()
unique_port = int(test_name_hash, 16) % 60000 + 2000

manager = mp.Manager()
result_queue = manager.Queue()
@@ -122,7 +125,7 @@
start_method,
dataset_builder,
dataloader_builder,
image_directory_for_dp.get_data_snapshot(),
result_queue,
),
nprocs=num_processes,
@@ -137,14 +140,15 @@
for uris_seen in results:
combined_uris_seen.update(uris_seen)

# Check that all items in image_directory were seen
expected_uris = set(image_directory_for_dp.contents.keys())
assert set(combined_uris_seen.keys()) == expected_uris

# When conducting distributed training tests, be cautious about the number of files (images) in the test dataset.
# If the total number of images cannot be evenly divided by the number of workers,
# the DistributedSampler will duplicate a subset of the images across workers to ensure an equal
# distribution of data among all processes. This duplication of images will cause
# the distributed-training integration test to fail.
assert all(count == 1 for count in combined_uris_seen.values())


4 changes: 2 additions & 2 deletions s3torchconnector/tst/e2e/test_multiprocess_dataloading.py
@@ -16,10 +16,10 @@
if TYPE_CHECKING:
from .conftest import BucketPrefixFixture

from test_common import _get_fork_methods, _read_data, _set_start_method


start_methods = _get_fork_methods()


def from_prefix(cls, image_directory: BucketPrefixFixture, **kwargs):