Skip to content

Commit

Permalink
support different queue labels
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Oct 25, 2024
1 parent c232181 commit 9b58426
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
5 changes: 3 additions & 2 deletions paraffin/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typer

from paraffin.submit import submit_node_graph
from paraffin.utils import get_stage_graph
from paraffin.utils import get_stage_graph, get_custom_queue

app = typer.Typer()

Expand All @@ -25,7 +25,8 @@ def main(
from paraffin.worker import app as celery_app

subgraph = get_stage_graph(names=names, glob=glob)
submit_node_graph(subgraph, shutdown_after_finished=shutdown_after_finished)
custom_queues = get_custom_queue()
submit_node_graph(subgraph, shutdown_after_finished=shutdown_after_finished, custom_queues=custom_queues)

typer.echo(f"Submitted all (n = {len(subgraph)}) tasks.")
if concurrency > 0:
Expand Down
11 changes: 8 additions & 3 deletions paraffin/submit.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import networkx as nx
from celery import chord, group

import typing as t
from paraffin.worker import repro, shutdown_worker
import fnmatch


def submit_node_graph(subgraph: nx.DiGraph, shutdown_after_finished: bool = False):
def submit_node_graph(subgraph: nx.DiGraph, shutdown_after_finished: bool = False, custom_queues: t.Optional[dict] = None):
task_dict = {}
custom_queues = custom_queues or {}
for node in subgraph.nodes:
task_dict[node.name] = repro.s(name=node.name)
if (matched_pattern := next((pattern for pattern in custom_queues if fnmatch.fnmatch(node.name, pattern)), None)):
task_dict[node.name] = repro.s(name=node.name).set(queue=custom_queues[matched_pattern])
else:
task_dict[node.name] = repro.s(name=node.name)

endpoints = []
chords = {}
Expand Down
9 changes: 9 additions & 0 deletions paraffin/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import dvc.api
import networkx as nx
import pathlib
import yaml


def get_subgraph_with_predecessors(G, X, reverse=False):
Expand Down Expand Up @@ -38,3 +40,10 @@ def get_stage_graph(names, glob=False):
subgraph = nx.subgraph_view(subgraph, filter_node=lambda x: hasattr(x, "name"))

return subgraph

def get_custom_queue():
with pathlib.Path("paraffin.yaml").open() as f:
config = yaml.safe_load(f)

return config.get("queue", {})
# TODO: what about lists, shutdown, ?

0 comments on commit 9b58426

Please sign in to comment.