Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
jsstevenson committed Dec 20, 2024
1 parent f31f1f3 commit 6fa3022
Showing 1 changed file with 19 additions and 30 deletions.
49 changes: 19 additions & 30 deletions scripts/download_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,41 +24,30 @@
)


def download_s3(uri: str, outfile_path: Path, tqdm_params: dict | None = None) -> None:
if not tqdm_params:
tqdm_params = {}
_logger.info("Downloading %s from %s...", outfile_path.name, uri)

bucket, key = uri.removeprefix("s3://").split("/", 1)

s3 = boto3.client("s3")
try:
response = s3.head_object(Bucket=bucket, Key=key)
except ClientError as e:
_logger.error("Encountered ClientError downloading %s: %s", uri, e.response)
raise e

file_size = response["ContentLength"]

with tqdm(total=file_size, **tqdm_params) as progress_bar:
s3.download_file(
Bucket=bucket,
Key=key,
Filename=outfile_path,
Callback=lambda bytes_amount: progress_bar.update(bytes_amount),
)


class UnversionedS3Data(UnversionedDataSource):
_datatype = "claims"
_filetype = "tsv" # most of this data is TSV, can manually set otherwise

def _download_data(self, version: str, outfile: Path) -> None:
download_s3(
f"s3://nch-igm-wagner-lab/dgidb/source_data/{self._src_name}/{self._src_name}_{self._datatype}.{self._filetype}",
outfile,
self._tqdm_params,
)
uri = f"s3://nch-igm-wagner-lab/dgidb/source_data/{self._src_name}/{self._src_name}_{self._datatype}.{self._filetype}"
_logger.info("Downloading %s from %s...", outfile.name, uri)
bucket, key = uri.removeprefix("s3://").split("/", 1)
s3 = boto3.client("s3")
try:
response = s3.head_object(Bucket=bucket, Key=key)
except ClientError as e:
_logger.error("Encountered ClientError downloading %s: %s", uri, e.response)
raise e

file_size = response["ContentLength"]

with tqdm(total=file_size, **self._tqdm_params) as progress_bar:
s3.download_file(
Bucket=bucket,
Key=key,
Filename=outfile,
Callback=lambda bytes_amount: progress_bar.update(bytes_amount),
)

def get_latest(
self, from_local: bool = False, force_refresh: bool = False
Expand Down

0 comments on commit 6fa3022

Please sign in to comment.