
Commit

Add implementation of readinto to decrease the number of copy operations
IsaevIlya committed Apr 4, 2024
1 parent 0b3b63a commit 407ffce
Showing 4 changed files with 143 additions and 11 deletions.
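
The point of readinto is to cut one copy out of every read: read() has to allocate a fresh bytes object that the caller then copies into its own buffer, while readinto() writes directly into a caller-supplied buffer. A minimal, library-agnostic sketch of that difference, using io.BytesIO as a stand-in for the S3-backed reader (buffer sizes and data are illustrative):

import io

source = io.BytesIO(b"x" * 1024)  # stand-in for an S3-backed, file-like reader

# read(): allocates a new bytes object, which the caller then copies into its buffer
dest = bytearray(1024)
chunk = source.read(1024)   # first copy: stream -> new bytes object
dest[: len(chunk)] = chunk  # second copy: bytes object -> destination buffer

# readinto(): fills the caller-provided buffer directly, one copy in total
source.seek(0)
dest2 = bytearray(1024)
n = source.readinto(memoryview(dest2))  # bytes land straight in dest2
assert dest2[:n] == dest[:n]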
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@

### New features
* Update S3ClientConfig to pass in the configuration for allowing unsigned requests, under boolean flag `unsigned`.
* Improve the performance of `s3reader` when utilized with `pytorch.load` by incorporating support for the `readinto` method.


## v1.2.2 (March 22, 2024)
45 changes: 39 additions & 6 deletions s3torchconnector/src/s3torchconnector/s3reader.py
@@ -2,6 +2,7 @@
# // SPDX-License-Identifier: BSD

import io
import time
from functools import cached_property
from io import SEEK_CUR, SEEK_END, SEEK_SET
from typing import Callable, Optional, Iterator
@@ -13,11 +14,11 @@ class S3Reader(io.BufferedIOBase):
"""A read-only, file like representation of a single object stored in S3."""

def __init__(
        self,
        bucket: str,
        key: str,
        get_object_info: Callable[[], ObjectInfo],
        get_stream: Callable[[], GetObjectStream],
):
if not bucket:
raise ValueError("Bucket should be specified")
@@ -53,6 +54,36 @@ def prefetch(self) -> None:
if self._stream is None:
self._stream = self._get_stream()

    def readinto(self, buf: memoryview) -> Optional[int]:
        """Read up to len(buf) bytes into a pre-allocated, writable bytes-like object buf.
        Return the number of bytes read. If no bytes are available, None is returned.

        Args:
            buf : writable bytes-like object

        Returns:
            int : number of bytes read, or None if no bytes are available
        """
buf_size = len(buf)
if self._position_at_end() or buf_size == 0:
# If no bytes are available or no place to write data, None should be returned
return None

self.prefetch()
assert self._stream is not None

cur_pos = self._position
        # preload enough bytes into the internal buffer
        self.seek(buf_size, SEEK_CUR)
        # restore the position before writing into buf
self._buffer.seek(cur_pos)
size = self._buffer.readinto(buf)
self._position = self._buffer.tell()

if size == 0:
return None
return size

def read(self, size: Optional[int] = None) -> bytes:
"""Read up to size bytes from the object and return them.
@@ -82,7 +113,9 @@ def read(self, size: Optional[int] = None) -> bytes:
if size is None or size < 0:
# Special case read() all to use O(n) algorithm
self._buffer.seek(0, SEEK_END)
self._buffer.write(b"".join(self._stream))
for batch in self._stream:
self._buffer.write(batch)

# Once we've emptied the buffer, we'll always be at EOF!
self._size = self._buffer.tell()
else:
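For reference, a small usage sketch of the new method, constructing S3Reader the same way the unit tests below do (the bucket, key, and chunk contents are illustrative; the last two constructor arguments are callables returning the object info and the byte-chunk stream):

from s3torchconnector.s3reader import S3Reader

chunks = [b"hello ", b"world"]  # illustrative stream of chunks
reader = S3Reader("test-bucket", "test-key", lambda: None, lambda: iter(chunks))

buf = memoryview(bytearray(8))  # pre-allocated destination buffer
n = reader.readinto(buf)        # fills buf in place, no intermediate bytes object for the caller
assert n == 8 and bytes(buf[:n]) == b"hello wo"

# A second call continues from the current position.
n = reader.readinto(buf)
assert n == 3 and bytes(buf[:n]) == b"rld"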
9 changes: 7 additions & 2 deletions s3torchconnector/tst/e2e/test_e2e_s3checkpoint.py
@@ -2,13 +2,18 @@
# // SPDX-License-Identifier: BSD

import torch
import pytest

from s3torchconnector import S3Checkpoint
from models.net import Net


-def test_general_checkpointing(checkpoint_directory):
-    tensor = torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])
+@pytest.mark.parametrize(
+    "tensor_dimensions",
+    [[3, 2], [10, 1024, 1024]],
+)
+def test_general_checkpointing(checkpoint_directory, tensor_dimensions):
+    tensor = torch.rand(tensor_dimensions)
checkpoint_name = "general_checkpoint.pt"
checkpoint = S3Checkpoint(region=checkpoint_directory.region)
s3_uri = f"{checkpoint_directory.s3_uri}/{checkpoint_name}"
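The changelog entry above ties this change to pytorch.load (presumably torch.load); the path that benefits is loading a checkpoint back through the connector's reader, sketched below (region and S3 URI are placeholders; in the test they come from the checkpoint_directory fixture):

import torch
from s3torchconnector import S3Checkpoint

checkpoint = S3Checkpoint(region="us-east-1")    # placeholder region
s3_uri = "s3://my-bucket/general_checkpoint.pt"  # placeholder URI

tensor = torch.rand([3, 2])
with checkpoint.writer(s3_uri) as writer:
    torch.save(tensor, writer)

# Per the changelog, torch.load benefits from the new readinto() support on this
# reader, which fills torch's buffers directly instead of copying through
# intermediate bytes objects.
with checkpoint.reader(s3_uri) as reader:
    loaded = torch.load(reader)

assert torch.equal(tensor, loaded)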
99 changes: 96 additions & 3 deletions s3torchconnector/tst/unit/test_s3reader.py
@@ -21,7 +21,6 @@

log = logging.getLogger(__name__)


TEST_BUCKET = "test-bucket"
TEST_KEY = "test-key"
MOCK_OBJECT_INFO = Mock(ObjectInfo)
@@ -137,7 +136,7 @@ def test_s3reader_read(stream_and_positions: Tuple[List[bytes], List[int]]):

assert s3reader.read(size) == bytesio.read(size)
assert (
        s3reader.tell() == s3reader._buffer.tell() == bytesio.tell() == new_position
)


@@ -182,14 +181,17 @@ def test_over_read(stream: List[bytes], overread: int):
def test_seeks_end():
s3reader = S3Reader(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter([]))
s3reader._size = 10
buf = memoryview(bytearray(10))

assert s3reader.seek(0, SEEK_END) == 10
assert s3reader.tell() == 10
assert s3reader.read() == b""
assert s3reader.readinto(buf) is None

assert s3reader.seek(0, SEEK_CUR) == 10
assert s3reader.tell() == 10
assert s3reader.read() == b""
assert s3reader.readinto(buf) is None


def test_not_writable():
@@ -268,7 +270,7 @@ def test_s3reader_relative_seek(stream_and_positions: Tuple[List[bytes], List[int]]):
bytesio.seek(new_position)

assert (
        s3reader.tell() == s3reader._buffer.tell() == bytesio.tell() == new_position
)
assert s3reader.read() == bytesio.read()

@@ -301,3 +303,94 @@ def test_s3reader_writes_size_after_read_all_explicit(stream: List[bytes]):
assert s3reader.read(1) == b""
# Once we've read past the end, we know how big the file is
assert s3reader._size == total_length


@given(
lists(binary(min_size=2, max_size=3), min_size=0, max_size=3),
integers(min_value=0, max_value=1)
)
def test_s3reader_readinto_buffer_smaller_than_chunks(stream: List[bytes], buf_size: int):
s3reader = S3Reader(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter(stream))
assert s3reader._size is None
total_length = sum(map(len, stream))
buf = memoryview(bytearray(buf_size))
# We're able to read all the available data or the data that can be accommodated in buf
if buf_size > 0 and total_length > 0:
assert s3reader.readinto(buf) == buf_size
assert s3reader.tell() == buf_size
# We haven't reached the end yet
assert s3reader._size is None
# confirm that read data is the same as in source
assert buf[:buf_size] == (b"".join(stream))[:buf_size]
else:
assert s3reader.readinto(buf) is None
assert s3reader.tell() == 0

@given(
lists(binary(min_size=2, max_size=3), min_size=1, max_size=3),
integers(min_value=3, max_value=10)
)
def test_s3reader_readinto_buffer_bigger_than_chunks(stream: List[bytes], buf_size: int):
s3reader = S3Reader(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter(stream))
assert s3reader._size is None
total_length = sum(map(len, stream))
buf = memoryview(bytearray(buf_size))
should_read_bytes_in_first_pass = min(buf_size, total_length)
# We're able to read all the available data or the data that can be accommodated in buf
assert s3reader.readinto(buf) == should_read_bytes_in_first_pass
assert s3reader.tell() == should_read_bytes_in_first_pass
all_data = b"".join(stream)
# confirm that read data is the same as in source
assert buf[:should_read_bytes_in_first_pass] == all_data[:should_read_bytes_in_first_pass]
if total_length < buf_size:
assert s3reader._size == total_length

should_read_bytes_in_second_pass = max(min(buf_size, total_length - should_read_bytes_in_first_pass), 0)
if should_read_bytes_in_second_pass > 0:
# We're able to read all the available data or the data that can be accommodated in buf
assert s3reader.readinto(buf) == should_read_bytes_in_second_pass
total_read = should_read_bytes_in_first_pass + should_read_bytes_in_second_pass
assert s3reader.tell() == total_read
# confirm that read data is the same as in source
assert buf[:should_read_bytes_in_second_pass] == all_data[should_read_bytes_in_first_pass:total_read]
        if total_length < should_read_bytes_in_first_pass + buf_size:
            # the second seek went past the end of the object, so its size is known by now
            assert s3reader._size == total_length

@given(
lists(binary(min_size=2, max_size=12), min_size=1, max_size=5),
integers(min_value=3, max_value=10),
integers(min_value=0, max_value=1)
)
def test_s3reader_mixing_readinto_and_read(stream: List[bytes], buf_size: int, flip: int):
position = 0
loops_count = 20
all_data = b"".join(stream)
total_length = len(all_data)
buf = memoryview(bytearray(buf_size))
s3reader = S3Reader(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter(stream))
for i in range(0, loops_count):
if position >= total_length:
break

if (i + flip) % 2 == 0:
result = s3reader.read(buf_size)
# confirm that read data is the same as in source
if position + buf_size < total_length:
assert result[:buf_size] == all_data[position:position + buf_size]
else:
read_bytes = total_length - position
assert result[:read_bytes] == all_data[position:total_length]
position += buf_size
else:
read_bytes = s3reader.readinto(buf)
# confirm that read data is the same as in source
            assert buf[:read_bytes] == all_data[position:position + read_bytes]
position += read_bytes

if position > total_length:
# we read all the data, it is time to stop
assert s3reader.tell() == total_length
break
else:
# confirm that position is as expected
assert s3reader.tell() == position
