diff --git a/latch_cli/snakemake/serialize.py b/latch_cli/snakemake/serialize.py index 473c9c30..c341a420 100644 --- a/latch_cli/snakemake/serialize.py +++ b/latch_cli/snakemake/serialize.py @@ -363,18 +363,11 @@ def generate_snakemake_entrypoint( from latch.types.file import LatchFile from latch_cli.utils import get_parameter_json_value, urljoins, check_exists_and_rename + from latch_cli.snakemake.serialize_utils import update_mapping sys.stdout.reconfigure(line_buffering=True) sys.stderr.reconfigure(line_buffering=True) - def update_mapping(local: Path, remote: str, mapping: Dict[str, str]): - if local.is_file(): - mapping[str(local)] = remote - return - - for p in local.iterdir(): - update_mapping(p, urljoins(remote, p.name), mapping) - def si_unit(num, base: float = 1000.0): for unit in (" ", "k", "M", "G", "T", "P", "E", "Z"): @@ -459,6 +452,7 @@ def generate_jit_register_code( ) from latch_cli.utils import get_parameter_json_value, check_exists_and_rename import latch_cli.snakemake + from latch_cli.snakemake.serialize_utils import update_mapping from latch_cli.utils import urljoins from latch import small_task @@ -474,14 +468,6 @@ def generate_jit_register_code( sys.stdout.reconfigure(line_buffering=True) sys.stderr.reconfigure(line_buffering=True) - def update_mapping(local: Path, remote: str, mapping: Dict[str, str]): - if local.is_file(): - mapping[str(local)] = remote - return - - for p in local.iterdir(): - update_mapping(p, urljoins(remote, p.name), mapping) - def si_unit(num, base: float = 1000.0): for unit in (" ", "k", "M", "G", "T", "P", "E", "Z"): diff --git a/latch_cli/snakemake/serialize_utils.py b/latch_cli/snakemake/serialize_utils.py index 9d793cfa..637522e6 100644 --- a/latch_cli/snakemake/serialize_utils.py +++ b/latch_cli/snakemake/serialize_utils.py @@ -1,4 +1,5 @@ import re +from pathlib import Path from typing import Dict, Union from flytekit import LaunchPlan @@ -18,6 +19,8 @@ from flytekit.models.core.workflow import TaskNodeOverrides from typing_extensions import TypeAlias +from latch_cli.utils import urljoins + FlyteLocalEntity: TypeAlias = Union[ PythonTask, Node, @@ -213,6 +216,17 @@ def get_serializable_workflow( return admin_wf -def best_effort_display_name(x: str): - expr = re.compile(r"_+") - return expr.sub(" ", x).title().strip() +def update_mapping(cur: Path, stem: Path, remote: str, mapping: Dict[str, str]): + if cur.is_file(): + mapping[str(stem)] = remote + return + + for p in cur.iterdir(): + update_mapping(p, stem / p.name, urljoins(remote, p.name), mapping) + + +underscores = re.compile(r"_+") + + +def best_effort_display_name(x: str) -> str: + return underscores.sub(" ", x).title().strip() diff --git a/latch_cli/snakemake/workflow.py b/latch_cli/snakemake/workflow.py index 197abb42..90363cd8 100644 --- a/latch_cli/snakemake/workflow.py +++ b/latch_cli/snakemake/workflow.py @@ -422,13 +422,13 @@ def get_fn_code( code_block += reindent( rf""" print(f"Moving {param} to {{{param}_dst_p}}") + + update_mapping({param}_p, {param}_dst_p, {param}.remote_path, local_to_remote_path_mapping) check_exists_and_rename( {param}_p, {param}_dst_p ) - update_mapping({param}_dst_p, {param}.remote_path, local_to_remote_path_mapping) - """, 1, ) @@ -1261,11 +1261,12 @@ def get_fn_code( } if remote_output_url is None: - remote_path = Path("/Snakemake Outputs") / self.wf.name + remote_path = Path("/Snakemake Outputs") / self.wf.name / self.job.name else: remote_path = Path(urlparse(remote_output_url).path) log_files = self.job.log if self.job.log is not None else [] + output_files = self.job.output if self.job.output is not None else [] code_block += reindent( rf""" @@ -1344,6 +1345,18 @@ def get_fn_code( lp.upload(local, remote) print(" Done") + print("Uploading outputs:") + for x in {repr(output_files)}: + local = Path(x) + remote = f"latch://{remote_path}/{{str(local).removeprefix('/')}}" + print(f" {{file_name_and_size(local)}} -> {{remote}}") + if not local.exists(): + print(" Does not exist") + continue + + lp.upload(local, remote) + print(" Done") + benchmark_file = {repr(self.job.benchmark)} if benchmark_file is not None: print("\nUploading benchmark:")