Fix bug in s3 resume
b8raoult committed Jun 20, 2024
1 parent 17fbd55 commit cf51aaa
Showing 1 changed file with 59 additions and 21 deletions.
80 changes: 59 additions & 21 deletions src/anemoi/registry/s3.py
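The change below makes `resume` skip an object that already exists with the expected size and return that size (the old code returned `None`), keeps `overwrite` as the way to force a re-transfer, and threads a new optional `config` argument through to boto3. As a rough usage sketch of the public `upload` helper visible in the diff (the bucket and paths here are made up, and the import path simply mirrors `src/anemoi/registry/s3.py`):

```python
from anemoi.registry.s3 import upload

# First run uploads the file; a second run with resume=True is skipped
# when the object already exists on S3 with the same size.
upload("data/checkpoint.ckpt", "s3://my-bucket/experiments/checkpoint.ckpt", resume=True)

# Without resume or overwrite, an existing target raises ValueError;
# overwrite=True forces the file to be uploaded again.
upload("data/checkpoint.ckpt", "s3://my-bucket/experiments/checkpoint.ckpt", overwrite=True)
```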
@@ -22,6 +22,7 @@
import logging
import os
import threading
from copy import deepcopy

import tqdm

@@ -43,7 +44,9 @@ def _s3_client():
return thread_local.s3_client


def _upload_file(source, target, overwrite=False, resume=False, verbosity=1):
def _upload_file(source, target, overwrite=False, resume=False, verbosity=1, config=None):
# from boto3.s3.transfer import TransferConfig
# TransferConfig(use_threads=False)
from botocore.exceptions import ClientError

assert target.startswith("s3://")
@@ -67,21 +70,21 @@ def _upload_file(source, target, overwrite=False, resume=False, verbosity=1):

if remote_size is not None:
if remote_size != size:
LOGGER.warning(f"{target} already exists, but with different size, re-uploading")
overwrite = True

if resume:
LOGGER.info(f"{target} already exists, skipping")
return

if remote_size is not None and not overwrite:
LOGGER.warning(
f"{target} already exists, but with different size, re-uploading (remote={remote_size}, local={size})"
)
elif resume:
# LOGGER.info(f"{target} already exists, skipping")
return size

if remote_size is not None and not overwrite and not resume:
raise ValueError(f"{target} already exists, use 'overwrite' to replace or 'resume' to skip")

if verbosity > 0:
with tqdm.tqdm(total=size, unit="B", unit_scale=True, leave=False) as pbar:
s3_client.upload_file(source, bucket, key, Callback=lambda x: pbar.update(x))
s3_client.upload_file(source, bucket, key, Callback=lambda x: pbar.update(x), Config=config)
else:
s3_client.upload_file(source, bucket, key)
s3_client.upload_file(source, bucket, key, Config=config)

return size

@@ -167,11 +170,20 @@ def upload(source, target, overwrite=False, resume=False, threads=1, verbosity=T
_upload_file(source, target, overwrite, resume)


def _download_file(source, target, overwrite=False, resume=False, verbosity=0):
def _download_file(source, target, overwrite=False, resume=False, verbosity=0, config=None):
# from boto3.s3.transfer import TransferConfig

s3_client = _s3_client()
_, _, bucket, key = source.split("/", 3)

response = s3_client.head_object(Bucket=bucket, Key=key)
try:
response = s3_client.head_object(Bucket=bucket, Key=key)
except s3_client.exceptions.ClientError as e:
print(e.response["Error"]["Code"], e.response["Error"]["Message"], bucket, key)
if e.response["Error"]["Code"] == "404":
raise ValueError(f"{source} does not exist ({bucket}, {key})")
raise

size = int(response["ContentLength"])

if verbosity > 0:
@@ -182,21 +194,22 @@ def _download_file(source, target, overwrite=False, resume=False, verbosity=0):

if resume:
if os.path.exists(target):
if os.path.getsize(target) != size:
LOGGER.warning(f"{target} already with different size, re-downloading")
local_size = os.path.getsize(target)
if local_size != size:
LOGGER.warning(f"{target} already with different size, re-downloading (remote={size}, local={size})")
else:
if verbosity > 0:
LOGGER.info(f"{target} already exists, skipping")
return
# if verbosity > 0:
# LOGGER.info(f"{target} already exists, skipping")
return size

if os.path.exists(target) and not overwrite:
raise ValueError(f"{target} already exists, use 'overwrite' to replace or 'resume' to skip")

if verbosity > 0:
with tqdm.tqdm(total=size, unit="B", unit_scale=True, leave=False) as pbar:
s3_client.download_file(bucket, key, target, Callback=lambda x: pbar.update(x))
s3_client.download_file(bucket, key, target, Callback=lambda x: pbar.update(x), Config=config)
else:
s3_client.download_file(bucket, key, target)
s3_client.download_file(bucket, key, target, Config=config)

return size

@@ -299,7 +312,7 @@ def _list_objects(target, batch=False):

for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
if "Contents" in page:
objects = page["Contents"]
objects = deepcopy(page["Contents"])
if batch:
yield objects
else:
@@ -382,3 +395,28 @@ def list_folders(folder):
for page in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/"):
if "CommonPrefixes" in page:
yield from [folder + _["Prefix"] for _ in page.get("CommonPrefixes")]


def object_info(target):
"""Get information about an object on S3.
Parameters
----------
target : str
The URL of a file or a folder on S3. The url should start with 's3://'.
Returns
-------
dict
A dictionary with information about the object.
"""

s3_client = _s3_client()
_, _, bucket, key = target.split("/", 3)

try:
return s3_client.head_object(Bucket=bucket, Key=key)
except s3_client.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "404":
raise ValueError(f"{target} does not exist")
raise
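The new `object_info` helper is a thin wrapper around `head_object`, so it returns the raw boto3 response dictionary. A minimal usage sketch; the bucket and key are invented, and `ContentLength` / `LastModified` are standard keys of a `head_object` reply:

```python
from anemoi.registry.s3 import object_info

info = object_info("s3://my-bucket/datasets/example/metadata.json")

# The response is the plain head_object dictionary.
print(info["ContentLength"], info["LastModified"])
```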

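The new `config` argument on `_upload_file` and `_download_file` is forwarded as boto3's `Config=`, so a `boto3.s3.transfer.TransferConfig` can be used to tune transfer behaviour, as the commented-out lines in the diff hint. A minimal sketch calling the private helper directly; the chunk size is arbitrary and the paths are invented:

```python
from boto3.s3.transfer import TransferConfig

from anemoi.registry.s3 import _upload_file

# Disable boto3's internal threading and use 64 MB multipart chunks;
# the object is passed straight to s3_client.upload_file(..., Config=config).
config = TransferConfig(use_threads=False, multipart_chunksize=64 * 1024 * 1024)
_upload_file("data/checkpoint.ckpt", "s3://my-bucket/experiments/checkpoint.ckpt", config=config)
```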