diff --git a/snakedeploy/deploy.py b/snakedeploy/deploy.py index eb8166a..c3668d6 100644 --- a/snakedeploy/deploy.py +++ b/snakedeploy/deploy.py @@ -2,9 +2,10 @@ from pathlib import Path import os import shutil -from typing import Optional +from typing import Dict, Optional from jinja2 import Environment, PackageLoader +import yaml from snakedeploy.providers import get_provider from snakedeploy.logger import logger @@ -12,11 +13,27 @@ class WorkflowDeployer: - def __init__(self, source: str, dest: Path, force=False): + def __init__( + self, + source: str, + dest: Path, + tag: Optional[str] = None, + branch: Optional[str] = None, + force=False, + ): self.provider = get_provider(source) self.env = Environment(loader=PackageLoader("snakedeploy")) self.dest_path = dest self.force = force + self._cloned = None + self.tag = tag + self.branch = branch + + def __enter__(self): + return self + + def __exit__(self, exc, value, tb): + self._cloned.cleanup() @property def snakefile(self): @@ -26,14 +43,14 @@ def snakefile(self): def config(self): return self.dest_path / "config" - def deploy_config(self, tmpdir: str): + def deploy_config(self): """ Deploy the config directory, either using an existing or creating a dummy. returns a boolean "no_config" to indicate if there is not a config (True) """ # Handle the config/ - config_dir = Path(tmpdir) / "config" + config_dir = Path(self.repo_clone) / "config" no_config = not config_dir.exists() if no_config: logger.warning( @@ -53,24 +70,30 @@ def deploy_config(self, tmpdir: str): shutil.copytree(config_dir, self.config, dirs_exist_ok=self.force) return no_config - def deploy(self, name: str, tag: str, branch: str): + @property + def repo_clone(self): + if self._cloned is None: + logger.info("Obtaining source repository...") + self._cloned = tempfile.TemporaryDirectory() + self.provider.clone(self._cloned.name) + if self.tag is not None: + self.provider.checkout(self._cloned.name, self.tag) + elif self.branch is not None: + self.provider.checkout(self._cloned.name, self.branch) + + return self._cloned.name + + def deploy(self, name: str): """ Deploy a source to a destination. """ self.check() - # Create a temporary directory to grab config directory and snakefile - with tempfile.TemporaryDirectory() as tmpdir: - logger.info("Obtaining source repository...") - - # Clone temporary directory to find assets - self.provider.clone(tmpdir) + # Either copy existing config or create a dummy config + no_config = self.deploy_config() - # Either copy existing config or create a dummy config - no_config = self.deploy_config(tmpdir) - - # Inspect repository to find existing snakefile - self.deploy_snakefile(tmpdir, name, tag, branch) + # Inspect repository to find existing snakefile + self.deploy_snakefile(self.repo_clone, name) logger.info( self.env.get_template("post-instructions.txt.jinja").render( @@ -92,7 +115,7 @@ def check(self): f"{self.config} already exists, aborting (use --force to overwrite)" ) - def deploy_snakefile(self, tmpdir: str, name: str, tag: str, branch: str): + def deploy_snakefile(self, tmpdir: str, name: str): """ Deploy the Snakefile to workflow/Snakefile """ @@ -120,12 +143,24 @@ def deploy_snakefile(self, tmpdir: str, name: str, tag: str, branch: str): os.makedirs(self.dest_path / "workflow", exist_ok=True) module_deployment = template.render( name=name, - snakefile=self.provider.get_source_file_declaration(snakefile, tag, branch), + snakefile=self.provider.get_source_file_declaration( + snakefile, self.tag, self.branch + ), repo=self.provider.source_url, ) with open(self.snakefile, "w") as f: print(module_deployment, file=f) + def get_json_schema(self, item: str) -> Dict: + """Get schema under workflow/schemas/{item}.schema.{yaml|yml|json} as + python dict.""" + clone = Path(self.repo_clone) + for ext in ["yaml", "yml", "json"]: + path = clone / "workflow" / "schemas" / f"{item}.schema.{ext}" + if path.exists(): + return yaml.safe_load(path.read_text()) + raise UserError(f"Schema {item} not found in repository.") + def deploy( source_url: str, @@ -138,5 +173,7 @@ def deploy( """ Deploy a given workflow to the local machine, using the Snakemake module system. """ - sd = WorkflowDeployer(source=source_url, dest=dest_path, force=force) - sd.deploy(name=name, tag=tag, branch=branch) + with WorkflowDeployer( + source=source_url, dest=dest_path, tag=tag, branch=branch, force=force + ) as sd: + sd.deploy(name=name) diff --git a/snakedeploy/providers.py b/snakedeploy/providers.py index 618900c..728d467 100644 --- a/snakedeploy/providers.py +++ b/snakedeploy/providers.py @@ -30,16 +30,16 @@ def __init__(self, source_url): @classmethod @abstractmethod - def matches(cls, source_url: str): - pass + def matches(cls, source_url: str): ... @abstractmethod - def clone(self, path: str): - pass + def clone(self, path: str): ... @abstractmethod - def get_raw_file(self, path: str, tag: str): - pass + def checkout(self, path: str, ref: str): ... + + @abstractmethod + def get_raw_file(self, path: str, tag: str): ... def get_repo_name(self): return self.source_url.split("/")[-1] @@ -56,6 +56,10 @@ def clone(self, tmpdir: str): """ copy_tree(self.source_url, tmpdir) + def checkout(self, path: str, ref: str): + # Local repositories don't need to check out anything + pass + def get_raw_file(self, path: str, tag: str): if tag: print( @@ -77,15 +81,21 @@ def matches(cls, source_url: str): def name(self): return self.__class__.__name__.lower() - def clone(self, tmpdir: str): + def clone(self, path: str): """ Clone the known source URL to a temporary directory """ try: - sp.run(["git", "clone", self.source_url, "."], cwd=tmpdir, check=True) + sp.run(["git", "clone", self.source_url, "."], cwd=path, check=True) except sp.CalledProcessError as e: raise UserError("Failed to clone repository {}:\n{}", self.source_url, e) + def checkout(self, path: str, ref: str): + try: + sp.run(["git", "checkout", ref], cwd=path, check=True) + except sp.CalledProcessError as e: + raise UserError("Failed to checkout ref {}:\n{}", ref, e) + def get_raw_file(self, path: str, tag: str): return f"{self.source_url}/raw/{tag}/{path}" @@ -98,7 +108,6 @@ def get_source_file_declaration(self, path: str, tag: str, branch: str): class Gitlab(Github): - @classmethod def get_raw_file(self, path: str, tag: str): return f"{self.source_url}/-/raw/{tag}/{path}"