def already_staged(self) -> list[Path]:
    """Find copies of this file that are already staged locally.

    Searches the system temporary directory for directories whose names
    start with ``self.tmp_dir`` and which contain a file named
    ``self.name``.

    Returns:
        List of already staged file paths. Empty if the file has not
        been staged; otherwise it contains exactly one path.

    Raises:
        FileExistsError: If the file has already been staged more than once.
            This would cause a name collision in Nextflow.
    """
    # Escape every literal component so glob metacharacters in the temp
    # directory, the prefix or the file name (e.g. "data[1].txt") are matched
    # verbatim. Only the "*" appended to the prefix acts as a wildcard.
    pattern = os.path.join(
        glob.escape(gettempdir()),
        glob.escape(self.tmp_dir) + "*",
        glob.escape(self.name),
    )
    staged_file_paths = [Path(match) for match in glob.glob(pattern)]
    if len(staged_file_paths) > 1:
        message = (
            f"File has already been staged multiple times: {staged_file_paths}"
        )
        raise FileExistsError(message)
    return staged_file_paths
def create_duplicate_files(file_num) -> List[str]:
    """Create duplicate files (empty txt) for testing.

    Each file is named ``test.txt`` and lives in its own
    ``dcqc-staged-test<i>`` directory under the system temporary directory.
    Directories and files that already exist are left in place.

    Args:
        file_num (int): number of files to create

    Returns:
        file_path_list (List[str]): list of file paths
    """
    # Pass the directory and the file name as separate components so the
    # platform's own separator is used instead of a hard-coded "/".
    file_path_list = [
        os.path.join(gettempdir(), f"dcqc-staged-test{i}", "test.txt")
        for i in range(file_num)
    ]

    for file_path in file_path_list:
        # exist_ok avoids the check-then-create race of an explicit
        # os.path.exists() guard.
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        # touch() creates an empty file and keeps an existing one untouched.
        Path(file_path).touch(exist_ok=True)

    return file_path_list


def remove_staged_files():
    """Removes all staged files and their parent directories
    which follow the 'dcqc-staged-*' pattern.

    Only directories under the system temporary directory that contain a
    ``test.txt`` file are removed.

    To be used at the end of all tests which result in such
    files being created.
    """
    # Escape the temp directory so glob metacharacters in its path are taken
    # literally; the only wildcard is the "*" after the directory prefix.
    pattern = os.path.join(glob.escape(gettempdir()), "dcqc-staged-*", "test.txt")
    for staged_file_str in glob.glob(pattern):
        try:
            shutil.rmtree(os.path.dirname(staged_file_str))
        except FileNotFoundError:
            # The directory disappeared between the glob and the removal;
            # there is nothing left to clean up.
            pass