From b366bbb466eab381b436feeb20d176102ffb7ece Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Wed, 8 Nov 2023 12:30:57 -1000 Subject: [PATCH 1/4] save state Signed-off-by: Ayush Kamat --- latch_cli/snakemake/serialize.py | 18 ++---------------- latch_cli/snakemake/serialize_utils.py | 12 ++++++++++++ latch_cli/snakemake/workflow.py | 2 +- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/latch_cli/snakemake/serialize.py b/latch_cli/snakemake/serialize.py index 63354480..b44c62e8 100644 --- a/latch_cli/snakemake/serialize.py +++ b/latch_cli/snakemake/serialize.py @@ -363,18 +363,11 @@ def generate_snakemake_entrypoint( from latch.types.file import LatchFile from latch_cli.utils import get_parameter_json_value, urljoins, check_exists_and_rename + from latch_cli.snakemake.serialize_utils import update_mapping sys.stdout.reconfigure(line_buffering=True) sys.stderr.reconfigure(line_buffering=True) - def update_mapping(local: Path, remote: str, mapping: Dict[str, str]): - if local.is_file(): - mapping[str(local)] = remote - return - - for p in local.iterdir(): - update_mapping(p, urljoins(remote, p.name), mapping) - def si_unit(num, base: float = 1000.0): for unit in (" ", "k", "M", "G", "T", "P", "E", "Z"): @@ -458,6 +451,7 @@ def generate_jit_register_code( ) from latch_cli.utils import get_parameter_json_value, check_exists_and_rename import latch_cli.snakemake + from latch_cli.snakemake.serialize_utils import update_mapping from latch_cli.utils import urljoins from latch import small_task @@ -473,14 +467,6 @@ def generate_jit_register_code( sys.stdout.reconfigure(line_buffering=True) sys.stderr.reconfigure(line_buffering=True) - def update_mapping(local: Path, remote: str, mapping: Dict[str, str]): - if local.is_file(): - mapping[str(local)] = remote - return - - for p in local.iterdir(): - update_mapping(p, urljoins(remote, p.name), mapping) - def si_unit(num, base: float = 1000.0): for unit in (" ", "k", "M", "G", "T", "P", "E", "Z"): diff --git a/latch_cli/snakemake/serialize_utils.py b/latch_cli/snakemake/serialize_utils.py index 4064c1a1..136a95de 100644 --- a/latch_cli/snakemake/serialize_utils.py +++ b/latch_cli/snakemake/serialize_utils.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import Dict, Union from flytekit import LaunchPlan @@ -17,6 +18,8 @@ from flytekit.models.core.workflow import TaskNodeOverrides from typing_extensions import TypeAlias +from latch_cli.utils import urljoins + FlyteLocalEntity: TypeAlias = Union[ PythonTask, Node, @@ -210,3 +213,12 @@ def get_serializable_workflow( admin_wf = admin_workflow_models.WorkflowSpec(template=wf_t, sub_workflows=[]) cache[entity] = admin_wf return admin_wf + + +def update_mapping(cur: Path, stem: Path, remote: str, mapping: Dict[str, str]): + if cur.is_file(): + mapping[str(stem / cur.name)] = remote + return + + for p in cur.iterdir(): + update_mapping(p, stem / cur.name, urljoins(remote, p.name), mapping) diff --git a/latch_cli/snakemake/workflow.py b/latch_cli/snakemake/workflow.py index 197abb42..720ef81c 100644 --- a/latch_cli/snakemake/workflow.py +++ b/latch_cli/snakemake/workflow.py @@ -427,7 +427,7 @@ def get_fn_code( {param}_dst_p ) - update_mapping({param}_dst_p, {param}.remote_path, local_to_remote_path_mapping) + update_mapping({param}_p, {param}_dst_p.parent, {param}.remote_path, local_to_remote_path_mapping) """, 1, From b0cb35cf90849fe7996d051a3008d2c2b0157bbd Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Wed, 8 Nov 2023 14:21:53 -1000 Subject: [PATCH 2/4] upload intermediate task outputs Signed-off-by: Ayush Kamat --- latch_cli/snakemake/serialize_utils.py | 4 ++-- latch_cli/snakemake/workflow.py | 24 ++++++++++++++++++++---- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/latch_cli/snakemake/serialize_utils.py b/latch_cli/snakemake/serialize_utils.py index 136a95de..83baf528 100644 --- a/latch_cli/snakemake/serialize_utils.py +++ b/latch_cli/snakemake/serialize_utils.py @@ -217,8 +217,8 @@ def get_serializable_workflow( def update_mapping(cur: Path, stem: Path, remote: str, mapping: Dict[str, str]): if cur.is_file(): - mapping[str(stem / cur.name)] = remote + mapping[str(stem)] = remote return for p in cur.iterdir(): - update_mapping(p, stem / cur.name, urljoins(remote, p.name), mapping) + update_mapping(p, stem / p.name, urljoins(remote, p.name), mapping) diff --git a/latch_cli/snakemake/workflow.py b/latch_cli/snakemake/workflow.py index 720ef81c..79c79e5e 100644 --- a/latch_cli/snakemake/workflow.py +++ b/latch_cli/snakemake/workflow.py @@ -422,13 +422,13 @@ def get_fn_code( code_block += reindent( rf""" print(f"Moving {param} to {{{param}_dst_p}}") + + update_mapping({param}_p, {param}_dst_p, {param}.remote_path, local_to_remote_path_mapping) check_exists_and_rename( {param}_p, {param}_dst_p ) - update_mapping({param}_p, {param}_dst_p.parent, {param}.remote_path, local_to_remote_path_mapping) - """, 1, ) @@ -492,6 +492,9 @@ def get_fn_code( print(f"JIT Workflow Version: {{jit_wf_version}}") print(f"JIT Execution Display Name: {{jit_exec_display_name}}") + import json + print(json.dumps(local_to_remote_path_mapping, indent=2)) + wf = extract_snakemake_workflow(pkg_root, snakefile, jit_wf_version, jit_exec_display_name, local_to_remote_path_mapping, non_blob_parameters) wf_name = wf.name generate_snakemake_entrypoint(wf, pkg_root, snakefile, {repr(remote_output_url)}, non_blob_parameters) @@ -581,7 +584,7 @@ class _WorkflowInfoNode(TypedDict): _interface_request = { "workflow_id": wf_id, "params": params, - "snakemake_jit": True, + # "snakemake_jit": True, } response = requests.post(urljoin(config.nucleus_url, "/api/create-execution"), headers=headers, json=_interface_request) @@ -1261,11 +1264,12 @@ def get_fn_code( } if remote_output_url is None: - remote_path = Path("/Snakemake Outputs") / self.wf.name + remote_path = Path("/Snakemake Outputs") / self.wf.name / self.job.name else: remote_path = Path(urlparse(remote_output_url).path) log_files = self.job.log if self.job.log is not None else [] + output_files = self.job.output if self.job.output is not None else [] code_block += reindent( rf""" @@ -1344,6 +1348,18 @@ def get_fn_code( lp.upload(local, remote) print(" Done") + print("Uploading outputs:") + for x in {repr(output_files)}: + local = Path(x) + remote = f"latch://{remote_path}/{{str(local).removeprefix('/')}}" + print(f" {{file_name_and_size(local)}} -> {{remote}}") + if not local.exists(): + print(" Does not exist") + continue + + lp.upload(local, remote) + print(" Done") + benchmark_file = {repr(self.job.benchmark)} if benchmark_file is not None: print("\nUploading benchmark:") From 36c99f88675cf65f7614f752da0b2e1c1e32ef50 Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Mon, 13 Nov 2023 10:32:16 -0800 Subject: [PATCH 3/4] comments Signed-off-by: Ayush Kamat --- latch_cli/snakemake/workflow.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/latch_cli/snakemake/workflow.py b/latch_cli/snakemake/workflow.py index 79c79e5e..90363cd8 100644 --- a/latch_cli/snakemake/workflow.py +++ b/latch_cli/snakemake/workflow.py @@ -492,9 +492,6 @@ def get_fn_code( print(f"JIT Workflow Version: {{jit_wf_version}}") print(f"JIT Execution Display Name: {{jit_exec_display_name}}") - import json - print(json.dumps(local_to_remote_path_mapping, indent=2)) - wf = extract_snakemake_workflow(pkg_root, snakefile, jit_wf_version, jit_exec_display_name, local_to_remote_path_mapping, non_blob_parameters) wf_name = wf.name generate_snakemake_entrypoint(wf, pkg_root, snakefile, {repr(remote_output_url)}, non_blob_parameters) @@ -584,7 +581,7 @@ class _WorkflowInfoNode(TypedDict): _interface_request = { "workflow_id": wf_id, "params": params, - # "snakemake_jit": True, + "snakemake_jit": True, } response = requests.post(urljoin(config.nucleus_url, "/api/create-execution"), headers=headers, json=_interface_request) From 17d0cb01b0177db077df3b1eafb267b430f680d4 Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Mon, 13 Nov 2023 10:36:36 -0800 Subject: [PATCH 4/4] type hint Signed-off-by: Ayush Kamat --- latch_cli/snakemake/serialize_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/latch_cli/snakemake/serialize_utils.py b/latch_cli/snakemake/serialize_utils.py index c76ec5fc..637522e6 100644 --- a/latch_cli/snakemake/serialize_utils.py +++ b/latch_cli/snakemake/serialize_utils.py @@ -228,5 +228,5 @@ def update_mapping(cur: Path, stem: Path, remote: str, mapping: Dict[str, str]): underscores = re.compile(r"_+") -def best_effort_display_name(x: str): +def best_effort_display_name(x: str) -> str: return underscores.sub(" ", x).title().strip()