Skip to content

Commit

Permalink
feat: add method to obtain schemas from workflow repo (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
johanneskoester authored Apr 24, 2024
1 parent bfce991 commit 2834e39
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 29 deletions.
77 changes: 57 additions & 20 deletions snakedeploy/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,38 @@
from pathlib import Path
import os
import shutil
from typing import Optional
from typing import Dict, Optional

from jinja2 import Environment, PackageLoader
import yaml

from snakedeploy.providers import get_provider
from snakedeploy.logger import logger
from snakedeploy.exceptions import UserError


class WorkflowDeployer:
def __init__(self, source: str, dest: Path, force=False):
def __init__(
self,
source: str,
dest: Path,
tag: Optional[str] = None,
branch: Optional[str] = None,
force=False,
):
self.provider = get_provider(source)
self.env = Environment(loader=PackageLoader("snakedeploy"))
self.dest_path = dest
self.force = force
self._cloned = None
self.tag = tag
self.branch = branch

def __enter__(self):
return self

def __exit__(self, exc, value, tb):
self._cloned.cleanup()

@property
def snakefile(self):
Expand All @@ -26,14 +43,14 @@ def snakefile(self):
def config(self):
return self.dest_path / "config"

def deploy_config(self, tmpdir: str):
def deploy_config(self):
"""
Deploy the config directory, either using an existing or creating a dummy.
returns a boolean "no_config" to indicate if there is not a config (True)
"""
# Handle the config/
config_dir = Path(tmpdir) / "config"
config_dir = Path(self.repo_clone) / "config"
no_config = not config_dir.exists()
if no_config:
logger.warning(
Expand All @@ -53,24 +70,30 @@ def deploy_config(self, tmpdir: str):
shutil.copytree(config_dir, self.config, dirs_exist_ok=self.force)
return no_config

def deploy(self, name: str, tag: str, branch: str):
@property
def repo_clone(self):
if self._cloned is None:
logger.info("Obtaining source repository...")
self._cloned = tempfile.TemporaryDirectory()
self.provider.clone(self._cloned.name)
if self.tag is not None:
self.provider.checkout(self._cloned.name, self.tag)
elif self.branch is not None:
self.provider.checkout(self._cloned.name, self.branch)

return self._cloned.name

def deploy(self, name: str):
"""
Deploy a source to a destination.
"""
self.check()

# Create a temporary directory to grab config directory and snakefile
with tempfile.TemporaryDirectory() as tmpdir:
logger.info("Obtaining source repository...")

# Clone temporary directory to find assets
self.provider.clone(tmpdir)
# Either copy existing config or create a dummy config
no_config = self.deploy_config()

# Either copy existing config or create a dummy config
no_config = self.deploy_config(tmpdir)

# Inspect repository to find existing snakefile
self.deploy_snakefile(tmpdir, name, tag, branch)
# Inspect repository to find existing snakefile
self.deploy_snakefile(self.repo_clone, name)

logger.info(
self.env.get_template("post-instructions.txt.jinja").render(
Expand All @@ -92,7 +115,7 @@ def check(self):
f"{self.config} already exists, aborting (use --force to overwrite)"
)

def deploy_snakefile(self, tmpdir: str, name: str, tag: str, branch: str):
def deploy_snakefile(self, tmpdir: str, name: str):
"""
Deploy the Snakefile to workflow/Snakefile
"""
Expand Down Expand Up @@ -120,12 +143,24 @@ def deploy_snakefile(self, tmpdir: str, name: str, tag: str, branch: str):
os.makedirs(self.dest_path / "workflow", exist_ok=True)
module_deployment = template.render(
name=name,
snakefile=self.provider.get_source_file_declaration(snakefile, tag, branch),
snakefile=self.provider.get_source_file_declaration(
snakefile, self.tag, self.branch
),
repo=self.provider.source_url,
)
with open(self.snakefile, "w") as f:
print(module_deployment, file=f)

def get_json_schema(self, item: str) -> Dict:
"""Get schema under workflow/schemas/{item}.schema.{yaml|yml|json} as
python dict."""
clone = Path(self.repo_clone)
for ext in ["yaml", "yml", "json"]:
path = clone / "workflow" / "schemas" / f"{item}.schema.{ext}"
if path.exists():
return yaml.safe_load(path.read_text())
raise UserError(f"Schema {item} not found in repository.")


def deploy(
source_url: str,
Expand All @@ -138,5 +173,7 @@ def deploy(
"""
Deploy a given workflow to the local machine, using the Snakemake module system.
"""
sd = WorkflowDeployer(source=source_url, dest=dest_path, force=force)
sd.deploy(name=name, tag=tag, branch=branch)
with WorkflowDeployer(
source=source_url, dest=dest_path, tag=tag, branch=branch, force=force
) as sd:
sd.deploy(name=name)
27 changes: 18 additions & 9 deletions snakedeploy/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ def __init__(self, source_url):

@classmethod
@abstractmethod
def matches(cls, source_url: str):
pass
def matches(cls, source_url: str): ...

@abstractmethod
def clone(self, path: str):
pass
def clone(self, path: str): ...

@abstractmethod
def get_raw_file(self, path: str, tag: str):
pass
def checkout(self, path: str, ref: str): ...

@abstractmethod
def get_raw_file(self, path: str, tag: str): ...

def get_repo_name(self):
return self.source_url.split("/")[-1]
Expand All @@ -56,6 +56,10 @@ def clone(self, tmpdir: str):
"""
copy_tree(self.source_url, tmpdir)

def checkout(self, path: str, ref: str):
# Local repositories don't need to check out anything
pass

def get_raw_file(self, path: str, tag: str):
if tag:
print(
Expand All @@ -77,15 +81,21 @@ def matches(cls, source_url: str):
def name(self):
return self.__class__.__name__.lower()

def clone(self, tmpdir: str):
def clone(self, path: str):
"""
Clone the known source URL to a temporary directory
"""
try:
sp.run(["git", "clone", self.source_url, "."], cwd=tmpdir, check=True)
sp.run(["git", "clone", self.source_url, "."], cwd=path, check=True)
except sp.CalledProcessError as e:
raise UserError("Failed to clone repository {}:\n{}", self.source_url, e)

def checkout(self, path: str, ref: str):
try:
sp.run(["git", "checkout", ref], cwd=path, check=True)
except sp.CalledProcessError as e:
raise UserError("Failed to checkout ref {}:\n{}", ref, e)

def get_raw_file(self, path: str, tag: str):
return f"{self.source_url}/raw/{tag}/{path}"

Expand All @@ -98,7 +108,6 @@ def get_source_file_declaration(self, path: str, tag: str, branch: str):


class Gitlab(Github):
@classmethod
def get_raw_file(self, path: str, tag: str):
return f"{self.source_url}/-/raw/{tag}/{path}"

Expand Down

0 comments on commit 2834e39

Please sign in to comment.