Skip to content

Commit

Permalink
Merge pull request #339 from latchbio/ayush/config-fixies
Browse files Browse the repository at this point in the history
  • Loading branch information
ayushkamat authored Nov 6, 2023
2 parents 0502ceb + 9223fcc commit 102fa02
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 10 deletions.
28 changes: 22 additions & 6 deletions latch_cli/snakemake/serialize.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
import os
import sys
import textwrap
import traceback
from pathlib import Path
from textwrap import dedent
from typing import Dict, List, Optional, Set, Union, get_args
from typing import Any, Dict, List, Optional, Set, Union, get_args

import click
from flyteidl.admin.launch_plan_pb2 import LaunchPlan as _idl_admin_LaunchPlan
Expand Down Expand Up @@ -91,12 +92,20 @@ def ensure_snakemake_metadata_exists():

# todo(maximsmol): this needs to run in a subprocess because it pollutes globals
class SnakemakeWorkflowExtractor(Workflow):
def __init__(self, pkg_root: Path, snakefile: Path):
super().__init__(snakefile=snakefile)
def __init__(
self,
pkg_root: Path,
snakefile: Path,
non_blob_parameters: Optional[Dict[str, Any]] = None,
):
super().__init__(snakefile=snakefile, overwrite_config=non_blob_parameters)

self.pkg_root = pkg_root
self._old_cwd = ""

if non_blob_parameters is not None:
print(f"Config: {json.dumps(non_blob_parameters, indent=2)}")

def extract_dag(self):
targets: List[str] = (
[self.default_target] if self.default_target is not None else []
Expand Down Expand Up @@ -169,7 +178,9 @@ def __exit__(self, typ, value, tb):


def snakemake_workflow_extractor(
pkg_root: Path, snakefile: Path
pkg_root: Path,
snakefile: Path,
non_blob_parameters: Optional[Dict[str, Any]] = None,
) -> SnakemakeWorkflowExtractor:
snakefile = snakefile.resolve()

Expand All @@ -184,6 +195,7 @@ def snakemake_workflow_extractor(
extractor = SnakemakeWorkflowExtractor(
pkg_root=pkg_root,
snakefile=snakefile,
non_blob_parameters=non_blob_parameters,
)
with extractor:
extractor.include(
Expand All @@ -201,12 +213,16 @@ def extract_snakemake_workflow(
jit_wf_version: str,
jit_exec_display_name: str,
local_to_remote_path_mapping: Optional[Dict[str, str]] = None,
non_blob_parameters: Optional[Dict[str, Any]] = None,
) -> SnakemakeWorkflow:
extractor = snakemake_workflow_extractor(pkg_root, snakefile)
extractor = snakemake_workflow_extractor(pkg_root, snakefile, non_blob_parameters)
with extractor:
dag = extractor.extract_dag()
wf = SnakemakeWorkflow(
dag, jit_wf_version, jit_exec_display_name, local_to_remote_path_mapping
dag,
jit_wf_version,
jit_exec_display_name,
local_to_remote_path_mapping,
)
wf.compile()

Expand Down
13 changes: 10 additions & 3 deletions latch_cli/snakemake/single_task_snakemake.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Shell,
)
from snakemake.rules import Rule as RRule
from snakemake.workflow import Workflow as WWorkflow

sys.stdout.reconfigure(line_buffering=True)
sys.stderr.reconfigure(line_buffering=True)
Expand All @@ -40,9 +41,15 @@ def eprint(x: str) -> None:

non_blob_parameters = data.get("non_blob_parameters", {})

# todo(ayush): do this without overwriting globals
sw = sys.modules["snakemake.workflow"]
setattr(sw, "config", non_blob_parameters)
old_workflow_init = WWorkflow.__init__


def new_init(self: WWorkflow, *args, **kwargs):
kwargs["overwrite_config"] = non_blob_parameters
old_workflow_init(self, *args, **kwargs)


WWorkflow.__init__ = new_init


def eprint_named_list(xs):
Expand Down
2 changes: 1 addition & 1 deletion latch_cli/snakemake/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def get_fn_code(
print(f"JIT Workflow Version: {{jit_wf_version}}")
print(f"JIT Execution Display Name: {{jit_exec_display_name}}")
wf = extract_snakemake_workflow(pkg_root, snakefile, jit_wf_version, jit_exec_display_name, local_to_remote_path_mapping)
wf = extract_snakemake_workflow(pkg_root, snakefile, jit_wf_version, jit_exec_display_name, local_to_remote_path_mapping, non_blob_parameters)
wf_name = wf.name
generate_snakemake_entrypoint(wf, pkg_root, snakefile, {repr(remote_output_url)}, non_blob_parameters)
Expand Down

0 comments on commit 102fa02

Please sign in to comment.