Skip to content

Commit

Permalink
Add functionality for setting custom names
Browse files Browse the repository at this point in the history
of downloaded files
  • Loading branch information
marcellevstek committed Sep 30, 2024
1 parent 74ba4ec commit c6a9d0c
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 14 deletions.
9 changes: 9 additions & 0 deletions docs/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@ Change Log

All notable changes to this project are documented in this file.

===================
21.3.0 - 2024-09-30
===================

Added
-----
- Add an argument for custom file naming of downloaded files
and propagate this change in ``Data`` resource


===================
21.2.0 - 2024-07-10
Expand Down
36 changes: 29 additions & 7 deletions src/resdk/resolwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,11 @@ def get_or_run(self, slug=None, input={}):
return Data(resolwe=self, **model_data)

def _download_files(
self, files: List[Union[str, Path]], download_dir=None, show_progress=True
self,
files: List[Union[str, Path]],
download_dir=None,
show_progress=True,
custom_file_names: Union[List[str], None] = None,
):
"""Download files.
Expand All @@ -471,6 +475,8 @@ def _download_files(
:param files: files to download
:type files: list of file URI
:param custom_file_names: list of file names to save the downloaded files as
:type custom_file_names: list of strings or None
:param download_dir: download directory
:type download_dir: string
:rtype: None
Expand All @@ -484,17 +490,24 @@ def _download_files(
"Download directory does not exist: {}".format(download_dir)
)

if not custom_file_names:
custom_file_names = len(files) * [None]
else:
if not len(files) == len(custom_file_names):
raise ValueError(
"Number of files and their corresponding custom names must be equal."
)

if not files:
self.logger.info("No files to download.")

else:
self.logger.info("Downloading files to %s:", download_dir)
# Store the sizes of files in the given directory.
# Use the dictionary to cache the responses.
sizes: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
checksums: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))

for file_uri in files:
for file_uri, custom_file_name in zip(files, custom_file_names):
file_name = os.path.basename(file_uri)
file_path = os.path.dirname(file_uri)
file_url = urljoin(self.url, "data/{}".format(file_uri))
Expand All @@ -518,12 +531,19 @@ def _download_files(

file_size = sizes[file_directory][file_name]

if custom_file_name:
desc = f"Downloading file {file_name} as {custom_file_name}"
actual_file_name = custom_file_name
else:
desc = f"Downloading file {file_name}"
actual_file_name = file_name

with tqdm.tqdm(
total=file_size,
disable=not show_progress,
desc=f"Downloading file {file_name}",
desc=desc,
) as progress_bar, open(
os.path.join(download_dir, file_path, file_name), "wb"
os.path.join(download_dir, file_path, actual_file_name), "wb"
) as file_handle:
response = self.session.get(file_url, stream=True, auth=self.auth)

Expand All @@ -540,10 +560,12 @@ def _download_files(
# checksums that are difficult to reproduce here.
return
expected_md5 = checksums[file_directory][file_name]
computed_md5 = md5(os.path.join(download_dir, file_path, file_name))
computed_md5 = md5(
os.path.join(download_dir, file_path, actual_file_name)
)
if expected_md5 != computed_md5:
raise ValueError(
f"Checksum of downloaded file {file_name} does not match the expected value."
f"Checksum of downloaded file {actual_file_name} does not match the expected value."
)

def data_usage(self, **query_params):
Expand Down
2 changes: 1 addition & 1 deletion src/resdk/resources/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def download(self, file_name=None, field_name=None, download_dir=None):
data_files = data.files(file_name, field_name)
files.extend("{}/{}".format(data.id, file_name) for file_name in data_files)

self.resolwe._download_files(files, download_dir)
self.resolwe._download_files(files=files, download_dir=download_dir)


class Collection(CollectionRelationsMixin, BaseCollection):
Expand Down
22 changes: 20 additions & 2 deletions src/resdk/resources/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,9 @@ def files(self, file_name=None, field_name=None):

return file_list

def download(self, file_name=None, field_name=None, download_dir=None):
def download(
self, file_name=None, field_name=None, download_dir=None, custom_file_name=None
):
"""Download Data object's files and directories.
Download files and directories from the Resolwe server to the
Expand All @@ -336,6 +338,8 @@ def download(self, file_name=None, field_name=None, download_dir=None):
:type field_name: string
:param download_dir: download path
:type download_dir: string
:param custom_file_name: custom file name
:type custom_file_name: string
:rtype: None
Data objects can contain multiple files and directories. All are
Expand All @@ -353,7 +357,21 @@ def download(self, file_name=None, field_name=None, download_dir=None):
"{}/{}".format(self.id, fname)
for fname in self.files(file_name, field_name)
]
self.resolwe._download_files(files, download_dir)

# Only applies if downloading a single file
custom_file_names = None
if custom_file_name:
if file_name or field_name:
custom_file_names = [custom_file_name] * len(files)
else:
logging.warning(
"Setting a custom file name is not supported "
"without specifying file name or field name."
)

self.resolwe._download_files(
files=files, download_dir=download_dir, custom_file_names=custom_file_names
)

def stdout(self):
"""Return process standard output (stdout.txt file content).
Expand Down
10 changes: 6 additions & 4 deletions tests/unit/test_resolwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,13 +451,15 @@ def test_fail_if_bad_dir(self, resolwe_mock):

message = "Download directory does not exist: .*"
with self.assertRaisesRegex(ValueError, message):
Resolwe._download_files(resolwe_mock, self.file_list, "/does/not/exist/")
Resolwe._download_files(
resolwe_mock, files=self.file_list, download_dir="/does/not/exist/"
)

@patch("resdk.resolwe.Resolwe", spec=True)
def test_empty_file_list(self, resolwe_mock):
resolwe_mock.configure_mock(**self.config)

Resolwe._download_files(resolwe_mock, [], download_dir=self.tmp_dir)
Resolwe._download_files(resolwe_mock, files=[], download_dir=self.tmp_dir)

resolwe_mock.logger.info.assert_called_once_with("No files to download.")

Expand All @@ -474,7 +476,7 @@ def test_bad_response(self, resolwe_mock):
with self.assertRaisesRegex(Exception, "abc"):
Resolwe._download_files(
resolwe_mock,
self.file_list[:1],
files=self.file_list[:1],
download_dir=self.tmp_dir,
)
self.assertEqual(resolwe_mock.logger.info.call_count, 2)
Expand Down Expand Up @@ -508,7 +510,7 @@ def test_good_response(self, resolwe_mock):

Resolwe._download_files(
resolwe_mock,
self.file_list,
files=self.file_list,
download_dir=self.tmp_dir,
show_progress=False,
)
Expand Down

0 comments on commit c6a9d0c

Please sign in to comment.