Implement multipart file downloads. Closes #131
kyboi committed Apr 29, 2024
1 parent 857b1d2 commit 3fca7d8
Showing 3 changed files with 401 additions and 361 deletions.
98 changes: 75 additions & 23 deletions aioboto3/s3/inject.py
@@ -1,4 +1,5 @@
 import asyncio
+import aiofiles
 import inspect
 import logging
 from typing import Optional, Callable, BinaryIO, Dict, Any
@@ -52,18 +53,18 @@ async def object_summary_load(self, *args, **kwargs):
 async def download_file(
     self, Bucket, Key, Filename, ExtraArgs=None, Callback=None, Config=None
 ):
-    """Download an S3 object to a file.
+    """Download an S3 object to a file asynchronously.
     Usage::
-        import boto3
-        s3 = boto3.resource('s3')
-        s3.meta.client.download_file('mybucket', 'hello.txt', '/tmp/hello.txt')
+        import aioboto3
+        s3 = aioboto3.resource('s3')
+        await s3.meta.client.download_file('mybucket', 'hello.txt', '/tmp/hello.txt')
-    Similar behavior as S3Transfer's download_file() method,
-    except that parameters are capitalized.
+    Similar behaviour as S3Transfer's download_file() method,
+    except that parameters are capitalised.
     """
-    with open(Filename, 'wb') as open_file:
+    async with aiofiles.open(Filename, 'wb') as open_file:
         await download_fileobj(
             self,
             Bucket,
@@ -75,23 +76,47 @@ async def download_file
         )


+async def _download_part(self, bucket, key, headers, start, file, semaphore, callback=None):
+    async with semaphore:  # limit the number of concurrent downloads
+        response = await self.get_object(
+            Bucket=bucket, Key=key, Range=headers['Range']
+        )
+        content = await response['Body'].read()
+
+        # Check whether the file object is async (e.g. an aiofiles handle)
+        if inspect.iscoroutinefunction(file.seek) and inspect.iscoroutinefunction(file.write):
+            await file.seek(start)
+            await file.write(content)
+        else:
+            # Fall back to synchronous operations for file objects that are not async
+            file.seek(start)
+            file.write(content)
+
+        # Call the wrapper callback with the number of bytes written, if provided
+        if callback:
+            try:
+                callback(len(content))
+            except:  # noqa: E722
+                pass
+
+
 async def download_fileobj(
     self, Bucket, Key, Fileobj, ExtraArgs=None, Callback=None, Config=None
 ):
     """Download an object from S3 to a file-like object.
     The file-like object must be in binary mode.
-    This is a managed transfer which will perform a multipart download in
-    multiple threads if necessary.
+    This is a managed transfer which will perform a multipart download
+    with asyncio if necessary.
     Usage::
         import boto3
         s3 = boto3.client('s3')
-        with open('filename', 'wb') as data:
-            s3.download_fileobj('mybucket', 'mykey', data)
+        async with aiofiles.open('filename', 'wb') as data:
+            await s3.download_fileobj('mybucket', 'mykey', data)
     :type Fileobj: a file-like object
     :param Fileobj: A file-like object to download into. At a minimum, it must

@@ -126,24 +151,51 @@ async def download_fileobj(
             raise ClientError({'Error': {'Code': '404', 'Message': 'Not Found'}}, 'HeadObject')
         raise

-    body = resp['Body']
-
-    while True:
-        data = await body.read(4096)
-
-        if data == b'':
-            break
+    # Keep track of total downloaded bytes
+    total_downloaded = 0

+    def wrapper_callback(bytes_transferred):
+        nonlocal total_downloaded
+        total_downloaded += bytes_transferred
         if Callback:
             try:
-                Callback(len(data))
+                Callback(total_downloaded)
             except:  # noqa: E722
                 pass

-        o = Fileobj.write(data)
-        if inspect.isawaitable(o):
-            await o
-        await asyncio.sleep(0.0)
+    # Size of each part (8MB)
+    part_size = 8 * 1024 * 1024
+
+    try:
+        # Get object metadata to determine the total size
+        response = await self.head_object(Bucket=Bucket, Key=Key, **ExtraArgs)
+        total_size = response['ContentLength']
+        total_parts = (total_size + part_size - 1) // part_size
+
+        # Semaphore to limit the number of concurrent downloads
+        semaphore = asyncio.Semaphore(10)
+
+        tasks = []
+        for i in range(total_parts):
+            start = i * part_size
+            end = min(
+                start + part_size, total_size
+            )  # Ensure we don't go beyond the total size
+            headers = {'Range': f'bytes={start}-{end - 1}'}
+            # Create a task for each part download
+            tasks.append(
+                _download_part(self, Bucket, Key, headers, start, Fileobj, semaphore, wrapper_callback)
+            )
+
+        # Run all the download tasks concurrently
+        await asyncio.gather(*tasks)
+
+        logger.info(f'Downloaded file from {Bucket}/{Key}')
+
+    except ClientError as e:
+        raise Exception(
+            f"Couldn't download file from {Bucket}/{Key}"
+        ) from e


 async def upload_fileobj(
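
A minimal usage sketch of the new code path, not part of this commit: it assumes aioboto3's Session-based client API, and the bucket name, key, and local path are placeholders.

import asyncio
import aioboto3

async def main():
    # With this change, download_file fans the object out into
    # concurrent ranged GETs behind the scenes.
    session = aioboto3.Session()
    async with session.client('s3') as s3:
        await s3.download_file('mybucket', 'hello.txt', '/tmp/hello.txt')

asyncio.run(main())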
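The part loop in download_fileobj derives inclusive HTTP Range headers from zero-based byte offsets. A worked sketch of that boundary arithmetic, assuming a hypothetical 20 MB object:

part_size = 8 * 1024 * 1024    # 8 MB parts, as in the commit
total_size = 20 * 1024 * 1024  # assumed object size for illustration
total_parts = (total_size + part_size - 1) // part_size  # ceiling division -> 3 parts

for i in range(total_parts):
    start = i * part_size
    end = min(start + part_size, total_size)  # clamp the final part
    print(f'bytes={start}-{end - 1}')

# Prints:
# bytes=0-8388607
# bytes=8388608-16777215
# bytes=16777216-20971519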

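One behavioural detail of the diff worth noting: wrapper_callback hands the user's Callback the running total (total_downloaded), whereas the old streaming loop passed the size of each chunk. A progress handler written against this commit would look roughly like the sketch below; the call site is hypothetical.

def on_progress(total_bytes_so_far):
    # Receives the cumulative byte count, not a per-chunk delta
    print(f'{total_bytes_so_far} bytes downloaded so far')

# Hypothetical call, inside an async context with an S3 client:
# await s3.download_file('mybucket', 'big.bin', '/tmp/big.bin',
#                        Callback=on_progress)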
