Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 25, 2024
1 parent 9b58426 commit f00845f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
8 changes: 6 additions & 2 deletions paraffin/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typer

from paraffin.submit import submit_node_graph
from paraffin.utils import get_stage_graph, get_custom_queue
from paraffin.utils import get_custom_queue, get_stage_graph

app = typer.Typer()

Expand All @@ -26,7 +26,11 @@ def main(

subgraph = get_stage_graph(names=names, glob=glob)
custom_queues = get_custom_queue()
submit_node_graph(subgraph, shutdown_after_finished=shutdown_after_finished, custom_queues=custom_queues)
submit_node_graph(
subgraph,
shutdown_after_finished=shutdown_after_finished,
custom_queues=custom_queues,
)

typer.echo(f"Submitted all (n = {len(subgraph)}) tasks.")
if concurrency > 0:
Expand Down
25 changes: 20 additions & 5 deletions paraffin/submit.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,31 @@
import fnmatch
import typing as t

import networkx as nx
from celery import chord, group
import typing as t

from paraffin.worker import repro, shutdown_worker
import fnmatch


def submit_node_graph(subgraph: nx.DiGraph, shutdown_after_finished: bool = False, custom_queues: t.Optional[dict] = None):
def submit_node_graph(
subgraph: nx.DiGraph,
shutdown_after_finished: bool = False,
custom_queues: t.Optional[dict] = None,
):
task_dict = {}
custom_queues = custom_queues or {}
for node in subgraph.nodes:
if (matched_pattern := next((pattern for pattern in custom_queues if fnmatch.fnmatch(node.name, pattern)), None)):
task_dict[node.name] = repro.s(name=node.name).set(queue=custom_queues[matched_pattern])
if matched_pattern := next(
(
pattern
for pattern in custom_queues
if fnmatch.fnmatch(node.name, pattern)
),
None,
):
task_dict[node.name] = repro.s(name=node.name).set(
queue=custom_queues[matched_pattern]
)
else:
task_dict[node.name] = repro.s(name=node.name)

Expand Down
9 changes: 6 additions & 3 deletions paraffin/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import fnmatch
import pathlib

import dvc.api
import networkx as nx
import pathlib
import yaml


Expand Down Expand Up @@ -41,9 +41,12 @@ def get_stage_graph(names, glob=False):

return subgraph


def get_custom_queue():
with pathlib.Path("paraffin.yaml").open() as f:
config = yaml.safe_load(f)

return config.get("queue", {})
# TODO: what about lists, shutdown, ?


# TODO: what about lists, shutdown, ?

0 comments on commit f00845f

Please sign in to comment.