Commit
Add implementation of readinto to decrease the number of copy operations
IsaevIlya committed Apr 4, 2024
1 parent 0b3b63a commit af0a03d
Showing 4 changed files with 149 additions and 4 deletions.
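
The change hinges on the difference between `read` and `readinto`: `read()` allocates a fresh `bytes` object that the caller must then copy into its own buffer, while `readinto()` writes directly into a caller-supplied buffer. A minimal sketch of that difference (the `reader`, `buf`, and helper names here are illustrative, not part of the change):

```python
import io

# Hypothetical illustration: 'reader' is any file-like object, e.g. an S3Reader.
def fill_buffer_with_read(reader: io.RawIOBase, buf: bytearray) -> int:
    # read() allocates a new bytes object, which then has to be copied into buf
    data = reader.read(len(buf))
    buf[: len(data)] = data
    return len(data)

def fill_buffer_with_readinto(reader: io.RawIOBase, buf: bytearray) -> int:
    # readinto() writes straight into the caller-supplied buffer: one copy fewer
    return reader.readinto(buf)
```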
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@

### New features
* Update S3ClientConfig to pass in the configuration for allowing unsigned requests, under boolean flag `unsigned`.
* Improve the performance of `s3reader` when utilized with `torch.load` by incorporating support for the `readinto` method.


## v1.2.2 (March 22, 2024)
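The changelog entry refers to the checkpoint-loading path: `torch.load` reads from a file-like object, and once the reader exposes `readinto`, data can be placed straight into the buffers torch allocates. A rough usage sketch following the pattern of the e2e test further below; the region and S3 URI are placeholders:

```python
import torch
from s3torchconnector import S3Checkpoint

# Placeholder region and URI; substitute your own bucket and key.
checkpoint = S3Checkpoint(region="us-east-1")
with checkpoint.reader("s3://my-bucket/general_checkpoint.pt") as reader:
    # torch.load can now pull data via readinto, avoiding intermediate copies
    state = torch.load(reader)
```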
32 changes: 31 additions & 1 deletion s3torchconnector/src/s3torchconnector/s3reader.py
@@ -53,6 +53,34 @@ def prefetch(self) -> None:
        if self._stream is None:
            self._stream = self._get_stream()

    def readinto(self, buf) -> int:
        """Read up to len(buf) bytes into a pre-allocated, writable bytes-like object buf.
        Return the number of bytes read. If no bytes are available, zero is returned.

        Args:
            buf : writable bytes-like object

        Returns:
            int : number of bytes read; zero if no bytes are available
        """
        buf_size = len(buf)
        if self._position_at_end() or buf_size == 0:
            # If no bytes are available or there is no place to write data, return zero
            return 0

        self.prefetch()
        assert self._stream is not None

        cur_pos = self._position
        # preload enough bytes into the internal buffer
        self.seek(buf_size, SEEK_CUR)
        # restore the position before writing into buf
        self._buffer.seek(cur_pos)
        size = self._buffer.readinto(buf)
        self._position = self._buffer.tell()

        return size

    def read(self, size: Optional[int] = None) -> bytes:
        """Read up to size bytes from the object and return them.
@@ -82,7 +110,9 @@ def read(self, size: Optional[int] = None) -> bytes:
         if size is None or size < 0:
             # Special case read() all to use O(n) algorithm
             self._buffer.seek(0, SEEK_END)
-            self._buffer.write(b"".join(self._stream))
+            for batch in self._stream:
+                self._buffer.write(batch)

             # Once we've emptied the buffer, we'll always be at EOF!
             self._size = self._buffer.tell()
         else:
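Internally, the `readinto` added above reuses the existing buffering: it seeks forward by `len(buf)` so the internal `BytesIO` holds at least that many bytes, rewinds to the original position, and delegates to `BytesIO.readinto`. A small sketch of calling it directly, constructing `S3Reader` the same way the unit tests further below do (the bucket, key, and chunk contents are made up):

```python
from s3torchconnector import S3Reader

# Fabricated chunks standing in for the parts of a GetObject response stream.
chunks = [b"hello ", b"world", b"!"]
reader = S3Reader("test-bucket", "test-key", lambda: None, lambda: iter(chunks))

buf = bytearray(8)
n = reader.readinto(buf)       # buffers at least 8 bytes internally, then copies them out
assert n == 8 and bytes(buf) == b"hello wo"
assert reader.tell() == 8      # the read position advances by the bytes written
```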
9 changes: 7 additions & 2 deletions s3torchconnector/tst/e2e/test_e2e_s3checkpoint.py
@@ -2,13 +2,18 @@
# // SPDX-License-Identifier: BSD

import torch
import pytest

from s3torchconnector import S3Checkpoint
from models.net import Net


-def test_general_checkpointing(checkpoint_directory):
-    tensor = torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])
+@pytest.mark.parametrize(
+    "tensor_dimensions",
+    [[3, 2], [10, 1024, 1024]],
+)
+def test_general_checkpointing(checkpoint_directory, tensor_dimensions):
+    tensor = torch.rand(tensor_dimensions)
     checkpoint_name = "general_checkpoint.pt"
     checkpoint = S3Checkpoint(region=checkpoint_directory.region)
     s3_uri = f"{checkpoint_directory.s3_uri}/{checkpoint_name}"
111 changes: 110 additions & 1 deletion s3torchconnector/tst/unit/test_s3reader.py
@@ -21,7 +21,6 @@

log = logging.getLogger(__name__)


TEST_BUCKET = "test-bucket"
TEST_KEY = "test-key"
MOCK_OBJECT_INFO = Mock(ObjectInfo)
@@ -182,14 +181,17 @@ def test_over_read(stream: List[bytes], overread: int):
def test_seeks_end():
    s3reader = S3Reader(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter([]))
    s3reader._size = 10
    buf = memoryview(bytearray(10))

    assert s3reader.seek(0, SEEK_END) == 10
    assert s3reader.tell() == 10
    assert s3reader.read() == b""
    assert s3reader.readinto(buf) == 0

    assert s3reader.seek(0, SEEK_CUR) == 10
    assert s3reader.tell() == 10
    assert s3reader.read() == b""
    assert s3reader.readinto(buf) == 0


def test_not_writable():
@@ -301,3 +303,110 @@ def test_s3reader_writes_size_after_read_all_explicit(stream: List[bytes]):
    assert s3reader.read(1) == b""
    # Once we've read past the end, we know how big the file is
    assert s3reader._size == total_length


@given(
    lists(binary(min_size=2, max_size=3), min_size=0, max_size=3),
    integers(min_value=0, max_value=1),
)
def test_s3reader_readinto_buffer_smaller_then_chunks(
    stream: List[bytes], buf_size: int
):
    s3reader = S3Reader(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter(stream))
    assert s3reader._size is None
    total_length = sum(map(len, stream))
    buf = memoryview(bytearray(buf_size))
    # We're able to read all the available data or the data that can be accommodated in buf
    if buf_size > 0 and total_length > 0:
        assert s3reader.readinto(buf) == buf_size
        assert s3reader.tell() == buf_size
        # We haven't reached the end yet
        assert s3reader._size is None
        # confirm that read data is the same as in source
        assert buf[:buf_size] == (b"".join(stream))[:buf_size]
    else:
        assert s3reader.readinto(buf) == 0
        assert s3reader.tell() == 0


@given(
    lists(binary(min_size=2, max_size=3), min_size=1, max_size=3),
    integers(min_value=3, max_value=10),
)
def test_s3reader_readinto_buffer_bigger_then_chunks(
    stream: List[bytes], buf_size: int
):
    s3reader = S3Reader(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter(stream))
    assert s3reader._size is None
    total_length = sum(map(len, stream))
    buf = memoryview(bytearray(buf_size))
    should_read_bytes_in_first_pass = min(buf_size, total_length)
    # We're able to read all the available data or the data that can be accommodated in buf
    assert s3reader.readinto(buf) == should_read_bytes_in_first_pass
    assert s3reader.tell() == should_read_bytes_in_first_pass
    all_data = b"".join(stream)
    # confirm that read data is the same as in source
    assert (
        buf[:should_read_bytes_in_first_pass]
        == all_data[:should_read_bytes_in_first_pass]
    )
    if total_length < buf_size:
        assert s3reader._size == total_length

    should_read_bytes_in_second_pass = max(
        min(buf_size, total_length - should_read_bytes_in_first_pass), 0
    )
    if should_read_bytes_in_second_pass > 0:
        # We're able to read all the available data or the data that can be accommodated in buf
        assert s3reader.readinto(buf) == should_read_bytes_in_second_pass
        total_read = should_read_bytes_in_first_pass + should_read_bytes_in_second_pass
        assert s3reader.tell() == total_read
        # confirm that read data is the same as in source
        assert (
            buf[:should_read_bytes_in_second_pass]
            == all_data[should_read_bytes_in_first_pass:total_read]
        )
        if total_length < total_read:
            assert s3reader._size == total_read


@given(
    lists(binary(min_size=2, max_size=12), min_size=1, max_size=5),
    integers(min_value=3, max_value=10),
    integers(min_value=0, max_value=1),
)
def test_s3reader_mixing_readinto_and_read(
    stream: List[bytes], buf_size: int, flip: int
):
    position = 0
    loops_count = 20
    all_data = b"".join(stream)
    total_length = len(all_data)
    buf = memoryview(bytearray(buf_size))
    s3reader = S3Reader(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter(stream))
    for i in range(0, loops_count):
        if position >= total_length:
            break

        if (i + flip) % 2 == 0:
            result = s3reader.read(buf_size)
            # confirm that read data is the same as in source
            if position + buf_size < total_length:
                assert result[:buf_size] == all_data[position : position + buf_size]
            else:
                read_bytes = total_length - position
                assert result[:read_bytes] == all_data[position:total_length]
            position += buf_size
        else:
            read_bytes = s3reader.readinto(buf)
            # confirm that read data is the same as in source
            assert buf[position:read_bytes] == all_data[position:read_bytes]
            position += read_bytes

        if position > total_length:
            # we read all the data, it is time to stop
            assert s3reader.tell() == total_length
            break
        else:
            # confirm that position is as expected
            assert s3reader.tell() == position
