From cf51aaab7308320b986d39a378dd9079b7ead849 Mon Sep 17 00:00:00 2001
From: Baudouin Raoult
Date: Thu, 20 Jun 2024 20:48:25 +0100
Subject: [PATCH] Fix bug in s3 resume

---
 src/anemoi/registry/s3.py | 82 ++++++++++++++++++++++++++++-----------
 1 file changed, 60 insertions(+), 22 deletions(-)

diff --git a/src/anemoi/registry/s3.py b/src/anemoi/registry/s3.py
index 00d8810..36d42e3 100644
--- a/src/anemoi/registry/s3.py
+++ b/src/anemoi/registry/s3.py
@@ -22,6 +22,7 @@
 import logging
 import os
 import threading
+from copy import deepcopy
 
 import tqdm
 
@@ -43,7 +44,9 @@ def _s3_client():
     return thread_local.s3_client
 
 
-def _upload_file(source, target, overwrite=False, resume=False, verbosity=1):
+def _upload_file(source, target, overwrite=False, resume=False, verbosity=1, config=None):
+    # from boto3.s3.transfer import TransferConfig
+    # TransferConfig(use_threads=False)
     from botocore.exceptions import ClientError
 
     assert target.startswith("s3://")
@@ -67,21 +70,21 @@ def _upload_file(source, target, overwrite=False, resume=False, verbosity=1):
 
     if remote_size is not None:
         if remote_size != size:
-            LOGGER.warning(f"{target} already exists, but with different size, re-uploading")
-            overwrite = True
-
-        if resume:
-            LOGGER.info(f"{target} already exists, skipping")
-            return
-
-    if remote_size is not None and not overwrite:
+            LOGGER.warning(
+                f"{target} already exists, but with different size, re-uploading (remote={remote_size}, local={size})"
+            )
+        elif resume:
+            # LOGGER.info(f"{target} already exists, skipping")
+            return size
+
+    if remote_size is not None and not overwrite and not resume:
         raise ValueError(f"{target} already exists, use 'overwrite' to replace or 'resume' to skip")
 
     if verbosity > 0:
         with tqdm.tqdm(total=size, unit="B", unit_scale=True, leave=False) as pbar:
-            s3_client.upload_file(source, bucket, key, Callback=lambda x: pbar.update(x))
+            s3_client.upload_file(source, bucket, key, Callback=lambda x: pbar.update(x), Config=config)
     else:
-        s3_client.upload_file(source, bucket, key)
+        s3_client.upload_file(source, bucket, key, Config=config)
 
     return size
 
@@ -167,11 +170,20 @@ def upload(source, target, overwrite=False, resume=False, threads=1, verbosity=T
         _upload_file(source, target, overwrite, resume)
 
 
-def _download_file(source, target, overwrite=False, resume=False, verbosity=0):
+def _download_file(source, target, overwrite=False, resume=False, verbosity=0, config=None):
+    # from boto3.s3.transfer import TransferConfig
+
     s3_client = _s3_client()
     _, _, bucket, key = source.split("/", 3)
 
-    response = s3_client.head_object(Bucket=bucket, Key=key)
+    try:
+        response = s3_client.head_object(Bucket=bucket, Key=key)
+    except s3_client.exceptions.ClientError as e:
+        LOGGER.warning(f"head_object failed: {e.response['Error']['Code']} {e.response['Error']['Message']} ({bucket}, {key})")
+        if e.response["Error"]["Code"] == "404":
+            raise ValueError(f"{source} does not exist ({bucket}, {key})")
+        raise
+
     size = int(response["ContentLength"])
 
     if verbosity > 0:
@@ -182,21 +194,22 @@ def _download_file(source, target, overwrite=False, resume=False, verbosity=0):
 
     if resume:
         if os.path.exists(target):
-            if os.path.getsize(target) != size:
-                LOGGER.warning(f"{target} already with different size, re-downloading")
+            local_size = os.path.getsize(target)
+            if local_size != size:
+                LOGGER.warning(f"{target} already exists, but with different size, re-downloading (remote={size}, local={local_size})")
             else:
-                if verbosity > 0:
-                    LOGGER.info(f"{target} already exists, skipping")
-                return
+                # if verbosity > 0:
+                #     LOGGER.info(f"{target} already exists, skipping")
+                return size
 
-    if os.path.exists(target) and not overwrite:
+    if os.path.exists(target) and not overwrite and not resume:
         raise ValueError(f"{target} already exists, use 'overwrite' to replace or 'resume' to skip")
 
     if verbosity > 0:
         with tqdm.tqdm(total=size, unit="B", unit_scale=True, leave=False) as pbar:
-            s3_client.download_file(bucket, key, target, Callback=lambda x: pbar.update(x))
+            s3_client.download_file(bucket, key, target, Callback=lambda x: pbar.update(x), Config=config)
     else:
-        s3_client.download_file(bucket, key, target)
+        s3_client.download_file(bucket, key, target, Config=config)
 
     return size
 
@@ -299,7 +312,7 @@ def _list_objects(target, batch=False):
 
     for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
         if "Contents" in page:
-            objects = page["Contents"]
+            objects = deepcopy(page["Contents"])
             if batch:
                 yield objects
             else:
@@ -382,3 +395,28 @@ def list_folders(folder):
     for page in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/"):
         if "CommonPrefixes" in page:
             yield from [folder + _["Prefix"] for _ in page.get("CommonPrefixes")]
+
+
+def object_info(target):
+    """Get information about an object on S3.
+
+    Parameters
+    ----------
+    target : str
+        The URL of a file or a folder on S3. The URL should start with 's3://'.
+
+    Returns
+    -------
+    dict
+        A dictionary with information about the object.
+    """
+
+    s3_client = _s3_client()
+    _, _, bucket, key = target.split("/", 3)
+
+    try:
+        return s3_client.head_object(Bucket=bucket, Key=key)
+    except s3_client.exceptions.ClientError as e:
+        if e.response["Error"]["Code"] == "404":
+            raise ValueError(f"{target} does not exist")
+        raise
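
Note on the new config parameter: _upload_file and _download_file both forward
it to boto3's Config= argument, as the commented-out
boto3.s3.transfer.TransferConfig import hints. A minimal sketch of how a
caller might use it follows; note that these are module-internal helpers, and
the threshold and concurrency values, the paths and the bucket name are
illustrative assumptions, not part of this patch:

    from boto3.s3.transfer import TransferConfig

    from anemoi.registry.s3 import _download_file, _upload_file

    # Illustrative tuning: switch to multipart transfers above 64 MiB and
    # run up to four parts in parallel.
    config = TransferConfig(
        multipart_threshold=64 * 1024 * 1024,
        max_concurrency=4,
        use_threads=True,
    )

    # With resume=True, a file that already exists at the destination with
    # the matching size is skipped; both helpers now return the size in
    # bytes either way, instead of returning None on the resume path.
    size = _upload_file("data.zarr", "s3://my-bucket/data.zarr", resume=True, config=config)
    size = _download_file("s3://my-bucket/data.zarr", "data.zarr", resume=True, config=config)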
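
The new object_info helper raises ValueError for a missing object instead of
leaking botocore's ClientError to the caller. A short usage sketch; the URL
is a placeholder, and ContentLength/LastModified are standard fields of a
head_object response:

    from anemoi.registry.s3 import object_info

    try:
        info = object_info("s3://my-bucket/data.zarr")
        print(info["ContentLength"], info["LastModified"])
    except ValueError as error:
        print(error)  # e.g. "s3://my-bucket/data.zarr does not exist"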