# Post-change state of S3Assets from scripts/s3_downloader.py.
# NOTE(review): boto3, ClientError, os, logging, the AWS_* settings,
# verify_ssl_certs and DESTINATION are expected to come from the part of the
# module above this class, which is not visible here.
class S3Assets:
    """Copy the objects stored under an S3 prefix into a local directory."""

    def __init__(self,
                 bucket: str,
                 prefix: str,
                 destination: str):
        """
        Create the S3 client and remember what to download and where to.

        :param bucket: name of the bucket to read from.
        :param prefix: key prefix selecting the objects to download.
        :param destination: local directory that receives the files.
        """
        self._client = boto3.client(
            's3',
            endpoint_url=AWS_S3_ENDPOINT,
            region_name=AWS_DEFAULT_REGION,
            aws_access_key_id=AWS_ACCESS_KEY_ID,
            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
            verify=verify_ssl_certs
        )
        self.bucket = bucket
        self.prefix = prefix
        self.destination = destination

    @staticmethod
    def get_local_values(destination, basename):
        """
        Map a prefix-relative object name to its local file and folder.

        :param destination: local root directory.
        :param basename: object name relative to the prefix.
        :return: ``(local_file, local_folder)`` where ``local_folder`` is the
            directory part of ``local_file``.
        """
        # os.path.join drops every earlier component when a later one is
        # absolute, so a name starting with '/' would escape `destination`.
        relative_name = basename.lstrip('/')
        local_file = os.path.join(destination, relative_name)
        local_folder = os.path.dirname(local_file)
        return local_file, local_folder

    @staticmethod
    def _get_basename(key, prefix) -> str:
        """Return `key` without the leading `prefix`, or `key` unchanged when it does not start with it."""
        if key.startswith(prefix):
            return key[len(prefix):]
        return key

    def _is_within_destination(self, path) -> bool:
        """Tell whether `path` resolves to a location inside `self.destination`."""
        root = os.path.abspath(self.destination)
        target = os.path.abspath(path)
        try:
            return os.path.commonpath([root, target]) == root
        except ValueError:
            # commonpath refuses paths it cannot compare (e.g. different
            # drives); such a target is certainly not under the root.
            return False

    def process_object(self, _object):
        """
        Download one listed object to its place under the destination.

        :param _object: one entry of a listing page; only ``'Key'`` is read.
        :return: None. Keys ending in '/' and keys that would be written
            outside the destination are skipped.
        """
        key = _object['Key']

        # A key ending in '/' carries no file name to write to (such keys are
        # conventionally used as folder markers), so there is nothing to fetch.
        if key.endswith('/'):
            return

        prefix = self.prefix if self.prefix.endswith('/') else self.prefix + '/'
        relative_name = self._get_basename(key, prefix)
        local_file, local_folder = self.get_local_values(self.destination, relative_name)

        # Keys are remote data: refuse '..' segments that climb out of the
        # destination instead of writing to an arbitrary local path.
        if not self._is_within_destination(local_file):
            logging.warning(
                f"Skipping s3://{self.bucket}/{key}: '{local_file}' is outside '{self.destination}'."
            )
            return

        # An empty folder means "current directory"; os.makedirs('') raises.
        if local_folder:
            os.makedirs(local_folder, exist_ok=True)

        logging.info(f"Downloading s3://{self.bucket}/{key} -> {local_file}")
        self._client.download_file(self.bucket, key, local_file)

    def process_page(self, page):
        """
        Download every object listed in one page of results.

        :param page: one page yielded by the listing paginator. A page
            without a ``'Contents'`` entry is logged as an error and ignored.
        """
        if 'Contents' not in page:
            logging.error(f"No objects found with prefix '{self.prefix}' in bucket '{self.bucket}'.")
            return

        for _object in page['Contents']:
            self.process_object(_object)

    def download(self) -> None:
        """
        Download the contents of a bucket, with a specific prefix, locally.

        Pages through ``list_objects_v2`` for ``self.bucket`` / ``self.prefix``
        and hands each page to :meth:`process_page`. A ``ClientError`` is
        logged and ends the run; it is not re-raised.
        """
        paginator = self._client.get_paginator('list_objects_v2')
        try:
            for page in paginator.paginate(Bucket=self.bucket, Prefix=self.prefix):
                self.process_page(page)
        except ClientError as e:
            logging.error(f"Error: {e}")


if __name__ == "__main__":
    s3Assets = S3Assets(bucket=AWS_S3_BUCKET, prefix=AWS_PATH, destination=DESTINATION)
    s3Assets.download()