Initial attempt to implement a fast-scheduling option using a topological sort, so next tasks can be obtained from the scheduler in O(1) instead of O(n). Deals with issue spotify#1750 for LanguageMachines/LuigiNLP#4. Still contains lots of debug statements and breaks certain stuff.
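The core idea, as a minimal standalone sketch (names here are illustrative, not Luigi's API): instead of ranking every known task on each get_work() request, tasks are kept presorted so the next candidate is a constant-time pop.

```python
from collections import deque

def next_task_slow(pending, rank):
    # stock behaviour: scan/rank all pending tasks per request -> O(n)
    return max(pending, key=rank)

def next_task_fast(presorted):
    # fast-scheduling behaviour: tasks were sorted once up front,
    # so the next candidate is a constant-time pop -> O(1)
    return presorted.popleft()

pending = ["taskC", "taskB", "taskA"]
presorted = deque(["taskA", "taskB", "taskC"])
assert next_task_slow(pending, rank=lambda t: -ord(t[-1])) == "taskA"
assert next_task_fast(presorted) == "taskA"
```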
proycon committed Jan 31, 2017
1 parent 2585496 commit d52359e
Showing 1 changed file with 175 additions and 25 deletions: luigi/scheduler.py
@@ -24,6 +24,7 @@
import collections
import inspect
import json
import sys

from luigi.batch_notifier import BatchNotifier

@@ -147,6 +148,10 @@ class scheduler(Config):

prune_on_get_work = parameter.BoolParameter(default=False)

fast_scheduling = parameter.BoolParameter(default=False, description="Enable fast scheduling using a topological sort. May be incompatible with priorities and resources, and schedules tasks more arbitrarily, i.e. not necessarily in the order they were added. This speeds up scheduling and allows scaling to a larger number of tasks")

task_limit = parameter.IntParameter(default=100, description="Maximum number of running and pending tasks to consider at once in fast scheduling (the lower the faster, but the more arbitrary the order may be, ignoring ranking)")

def _get_retry_policy(self):
return RetryPolicy(self.retry_count, self.disable_hard_timeout, self.disable_window)
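As a usage sketch (an assumption based on how Scheduler builds its config from **kwargs at this commit), the new options could be enabled like this, or equivalently via the [scheduler] section of the Luigi config:

```python
from luigi.scheduler import Scheduler

# Sketch: Scheduler(**kwargs) builds a scheduler(Config) from keyword
# arguments, so the options added in this commit can be passed directly.
sched = Scheduler(fast_scheduling=True, task_limit=50)
```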

@@ -348,7 +353,7 @@ def __init__(self, worker_id, last_active=None):
self.last_active = last_active or time.time() # seconds since epoch
self.last_get_work = None
self.started = time.time() # seconds since epoch
self.tasks = set() # task objects
self.tasks = OrderedSet() # task objects
self.info = {}
self.disabled = False
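The diff does not show where OrderedSet is imported from (possibly an external package or a helper elsewhere in the branch); a minimal insertion-ordered stand-in, purely as an assumption about what is needed here, could look like:

```python
from collections import OrderedDict

class OrderedSet(object):
    """Set that preserves insertion order, so worker.tasks iterates
    in the order tasks were added (a sketch, not the actual import)."""

    def __init__(self, iterable=()):
        self._items = OrderedDict((item, None) for item in iterable)

    def add(self, item):
        self._items[item] = None

    def discard(self, item):
        self._items.pop(item, None)

    def remove(self, item):
        del self._items[item]

    def __contains__(self, item):
        return item in self._items

    def __iter__(self):
        return iter(self._items)

    def __len__(self):
        return len(self._items)
```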

@@ -428,6 +433,9 @@ def set_state(self, state):
self._tasks, self._active_workers = state[:2]
if len(state) >= 3:
self._task_batchers = state[2]
self._status_tasks = collections.defaultdict(dict)
for task in six.itervalues(self._tasks):
self._status_tasks[task.status][task.id] = task

def dump(self):
try:
@@ -450,9 +458,6 @@ def load(self):
return

self.set_state(state)
self._status_tasks = collections.defaultdict(dict)
for task in six.itervalues(self._tasks):
self._status_tasks[task.status][task.id] = task
else:
logger.info("No prior state file exists at %s. Starting with empty state", self._state_path)

@@ -548,8 +553,7 @@ def set_status(self, task, new_status, config=None):
task.scheduler_disable_time = None

if new_status != task.status:
self._status_tasks[task.status].pop(task.id)
self._status_tasks[new_status][task.id] = task
self._update_status_tasks(task, new_status)
task.status = new_status
task.updated = time.time()

@@ -558,6 +562,12 @@
if remove_on_failure:
task.remove = time.time()

def _update_status_tasks(self, task, new_status):
print("DEBUG _update_status_tasks for ", task.id, " status ", task.status,"->",new_status, file=sys.stderr)
self._status_tasks[task.status].pop(task.id)
self._status_tasks[new_status][task.id] = task


def fail_dead_worker_task(self, task, config, assistants):
# If a running worker disconnects, tag all its jobs as FAILED and subject it to the same retry logic
if task.status in (BATCH_RUNNING, RUNNING) and task.worker_running and task.worker_running not in task.stakeholders | assistants:
@@ -634,6 +644,122 @@ def disable_workers(self, worker_ids):
self.get_worker(worker_id).disabled = True


class TopoSortedTaskState(SimpleTaskState):
"""
Keep track of the current state and handle persistence, maintaining per-status task queues in topological order for fast scheduling.
"""

def __init__(self, state_path):
print("DEBUG init TopoSortedTaskState",file=sys.stderr)
self._state_path = state_path
self._active_workers = {} # map from id to a Worker object
self._task_batchers = {}

self._tasks = {} # map from id to a Task object
self._status_tasks = collections.defaultdict(collections.deque) #sorted tasks in topological order, per status
self._unsorted_tasks = set()

self.num_pending = 0
self.num_unique_pending = 0
self.num_pending_last_scheduled = 0

def _sort_needed(self):
return bool(self._unsorted_tasks)

def get_state(self):
if self._sort_needed():
self._sort_tasks()
return self._tasks, self._active_workers, self._status_tasks, self._unsorted_tasks, self.num_pending, self.num_unique_pending, self.num_pending_last_scheduled

def set_state(self, state):
self._tasks, self._active_workers, self._status_tasks, self._unsorted_tasks, self.num_pending, self.num_unique_pending, self.num_pending_last_scheduled = state[:7]

def get_task(self, task_id, default=None, setdefault=None):
if setdefault:
l = len(self._tasks)
task = self._tasks.setdefault(task_id, setdefault)
if l < len(self._tasks):
#this is a new task, add it to _unsorted_tasks; a re-sort will be triggered the next time tasks are queried
self._unsorted_tasks.add(task_id)
return task
else:
return self._tasks.get(task_id, default)

def get_active_tasks(self, limit=None):
return self.get_active_tasks_by_status(DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN, BATCH_RUNNING, limit=limit)

def get_active_tasks_by_status(self, *statuses, limit=None):
if self._sort_needed():
self._sort_tasks()
yielded = 0
for status in statuses:
for task_id in self._status_tasks[status]:
yield self._tasks[task_id]
yielded += 1
if limit is not None and yielded >= limit:
return

def _sort_tasks(self):
"""Sort pending tasks topologically"""
print("***********************\nStarting topological sort",file=sys.stderr)
self._print()
tempmarks = set()
while self._unsorted_tasks:
task_id = self._unsorted_tasks.pop()
self._sort_tasks_visit(task_id, self._unsorted_tasks, tempmarks, popped=True)
print("Done with topological sort",file=sys.stderr)
self._print()

def _sort_tasks_visit(self, task_id, unvisited, tempmarks, popped=False):
"""auxiliary function in topological sort"""
if task_id in tempmarks:
raise Exception("Dependency graph is not acyclic!!")
if popped or task_id in unvisited:
tempmarks.add(task_id)
task = self.get_task(task_id, default=None)
for dep in task.deps:
dep_task = self.get_task(dep, default=None)
if dep_task is not None:
self._sort_tasks_visit(dep, unvisited, tempmarks, False)
if not popped:
unvisited.remove(task_id)
tempmarks.remove(task_id)
print("Sorted task " + task.id + " with status " + task.status,file=sys.stderr)
self._status_tasks[task.status].appendleft(task_id)
if task.status == PENDING:
self.num_pending += 1
self.num_unique_pending += int(len(task.workers) == 1)
#self.num_pending_last_scheduled += int(task.workers.peek(last=True) == worker_id)

def inactivate_tasks(self, delete_tasks):
# The terminology is a bit confusing: we used to "delete" tasks when they became inactive,
# but with a pluggable state storage, you might very well want to keep some history of
# older tasks as well. That's why we call it "inactivate" (as in the verb)
print("Inactivating tasks",file=sys.stderr)
for task in delete_tasks:
task_obj = self._tasks.pop(task)
#also drop the task from the sort bookkeeping so stale ids are not yielded later
self._unsorted_tasks.discard(task)
if task in self._status_tasks[task_obj.status]:
self._status_tasks[task_obj.status].remove(task)


def _update_status_tasks(self, task, new_status):
print("DEBUG _update_status_tasks for ", task.id, " status ", task.status,"->",new_status, file=sys.stderr)
if self._sort_needed():
self._sort_tasks()
if task.status == PENDING:
self.num_pending -= 1
self.num_unique_pending -= int(len(task.workers) == 1)
#self.num_pending_last_scheduled -= int(task.workers.peek(last=True) == worker_id)
elif new_status == PENDING:
self.num_pending += 1
self.num_unique_pending += int(len(task.workers) == 1)

self._status_tasks[task.status].remove(task.id)
self._status_tasks[new_status].append(task.id)

def _print(self):
print("DEBUG state tasks=",",".join(self._tasks.keys()),file=sys.stderr)
for status in self._status_tasks:
print("DEBUG state status=",status, " tasks=", ",".join(self._status_tasks[status]),file=sys.stderr)

class Scheduler(object):
"""
Async scheduler that can handle multiple workers, etc.
@@ -648,8 +774,14 @@ def __init__(self, config=None, resources=None, task_history_impl=None, **kwargs
:param resources: a dict of str->int constraints
:param task_history_impl: ignore config and use this object as the task history
"""
print("DEBUG instantiating scheduler",file=sys.stderr)
self._config = config or scheduler(**kwargs)
self._state = SimpleTaskState(self._config.state_path)
if self._config.fast_scheduling:
print("DEBUG Selecting fast scheduling",file=sys.stderr)
self._state = TopoSortedTaskState(self._config.state_path)
else:
print("DEBUG Selecting normal scheduling",file=sys.stderr)
self._state = SimpleTaskState(self._config.state_path)

if task_history_impl:
self._task_history = task_history_impl
@@ -661,8 +793,6 @@ def __init__(self, config=None, resources=None, task_history_impl=None, **kwargs
self._resources = resources or configuration.get_config().getintdict('resources') # TODO: Can we make this a Parameter?
self._make_task = functools.partial(Task, retry_policy=self._config._get_retry_policy())
self._worker_requests = {}
if self._config.fast_scheduling and not self._config.prune_interval:
self._config.prune_interval = 60

if self._config.batch_emails:
self._email_batcher = BatchNotifier()
@@ -941,6 +1071,7 @@ def _reset_orphaned_batch_running_tasks(self, worker_id):

@rpc_method()
def count_pending(self, worker):
begintime = time.time()
worker_id, worker = worker, self._state.get_worker(worker)

num_pending, num_unique_pending, num_pending_last_scheduled = 0, 0, 0
@@ -958,20 +1089,28 @@ def count_pending(self, worker):
more_info.update(other_worker.info)
running_tasks.append(more_info)

for task in worker.get_tasks(self._state, PENDING, FAILED):
if self._upstream_status(task.id, upstream_status_table) == UPSTREAM_DISABLED:
continue
num_pending += 1
num_unique_pending += int(len(task.workers) == 1)
num_pending_last_scheduled += int(task.workers.peek(last=True) == worker_id)
if self._config.fast_scheduling:
num_pending = self._state.num_pending
num_unique_pending = self._state.num_unique_pending
num_pending_last_scheduled = self._state.num_pending_last_scheduled
else:
for task in worker.get_tasks(self._state, PENDING, FAILED):
if self._upstream_status(task.id, upstream_status_table) == UPSTREAM_DISABLED:
continue
num_pending += 1
num_unique_pending += int(len(task.workers) == 1)
num_pending_last_scheduled += int(task.workers.peek(last=True) == worker_id)

return {
reply = {
'n_pending_tasks': num_pending,
'n_unique_pending': num_unique_pending,
'n_pending_last_scheduled': num_pending_last_scheduled,
'worker_state': worker.state,
'running_tasks': running_tasks,
}
print("Count pending: " + str(time.time() - begintime),file=sys.stderr)
print(reply,file=sys.stderr)
return reply
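The fast path is O(1) because the counters are maintained incrementally on every status transition instead of being recomputed per RPC; as a standalone sketch of that bookkeeping (a hypothetical class, not the commit's code):

```python
PENDING = "PENDING"  # stands in for luigi.task_status.PENDING

class PendingCounters(object):
    """Adjust counters on each status change instead of rescanning
    every task per count_pending() call (illustrative sketch)."""

    def __init__(self):
        self.num_pending = 0
        self.num_unique_pending = 0

    def on_status_change(self, task, old_status, new_status):
        if old_status == PENDING:
            self.num_pending -= 1
            self.num_unique_pending -= int(len(task.workers) == 1)
        if new_status == PENDING:
            self.num_pending += 1
            self.num_unique_pending += int(len(task.workers) == 1)
```

Note that these counters are global, whereas the slow path counts per worker; accepting that approximation is part of the trade-off this option makes.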

@rpc_method(allow_null=False)
def get_work(self, host=None, assistant=False, current_tasks=None, worker=None, **kwargs):
@@ -980,6 +1119,9 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None,
# Algo: iterate over all nodes, find the highest priority node no dependencies and available
# resources.

# TODO #1750 @proycon: Current algorithm is too inefficient, nodes should
# be presorted topologically so next task can be obtained in O(1) time

# Resource checking looks both at currently available resources and at which resources would
# be available if all running tasks died and we rescheduled all workers greedily. We do both
# checks in order to prevent a worker with many low-priority tasks from starving other
@@ -988,9 +1130,10 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None,
# TODO: remove tasks that can't be done, figure out if the worker has absolutely
# nothing it can wait for

if self._config.prune_on_get_work and (self._config.prune_interval == 0 or time.time() - self._state._last_prune >= self._config.prune_interval):
logger.debug("Calling get_work()")

if self._config.prune_on_get_work and not self._config.fast_scheduling: #(prune_on_get_work is incompatible with fast_scheduling)
self.prune()
self._state._last_prune = time.time()

assert worker is not None
worker_id = worker
@@ -1025,7 +1168,12 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None,
greedy_resources = collections.defaultdict(int)

worker = self._state.get_worker(worker_id)
if worker.is_trivial_worker(self._state):
if self._config.fast_scheduling:
#relevant_tasks = self._state.get_active_tasks.get_tasks(self._state, PENDING, RUNNING, limit=self._config.task_limit)
relevant_tasks = self._state.get_active_tasks_by_status(PENDING, RUNNING, limit=self._config.task_limit)
used_resources = collections.defaultdict(int)
greedy_workers = dict() # If there's no resources, then they can grab any task
elif worker.is_trivial_worker(self._state):
relevant_tasks = worker.get_tasks(self._state, PENDING, RUNNING)
used_resources = collections.defaultdict(int)
greedy_workers = dict() # If there's no resources, then they can grab any task
@@ -1037,13 +1185,10 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None,
greedy_workers = dict((worker.id, worker.info.get('workers', 1))
for worker in active_workers)

tasks = list(relevant_tasks)
tasks.sort(key=self._rank, reverse=True)

if self._config.fast_scheduling:
tasks = relevant_tasks
else:
tasks = list(relevant_tasks)
tasks.sort(key=self._rank, reverse=True)

n_unique_pending = 0
for task in tasks:
if (best_task and batched_params and task.family == best_task.family and
len(batched_tasks) < max_batch_size and task.is_batchable() and all(
Expand All @@ -1052,6 +1197,8 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None,
for name, params in batched_params.items():
params.append(task.params.get(name))
batched_tasks.append(task)

n_unique_pending += 1
if best_task:
if self._config.fast_scheduling and n_unique_pending >= 1:
break
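The early exit above relies on the presorted order: the first acceptable PENDING task can be taken without ranking the rest. As a hypothetical helper (schedulable stands in for the resource and worker checks, which are not shown here):

```python
PENDING = "PENDING"  # stands in for luigi.task_status.PENDING

def pick_first_schedulable(presorted_tasks, schedulable):
    """Return the first runnable PENDING task from an already
    topologically sorted iterable, or None (illustrative only)."""
    for task in presorted_tasks:
        if task.status == PENDING and schedulable(task):
            return task  # no full ranking pass needed
    return None
```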
@@ -1139,7 +1286,9 @@ def _upstream_status(self, task_id, upstream_status_table):
elif self._state.has_task(task_id):
task_stack = [task_id]

c = 0
while task_stack:
c += 1
dep_id = task_stack.pop()
dep = self._state.get_task(dep_id)
if dep:
@@ -1159,6 +1308,7 @@ def _upstream_status(self, task_id, upstream_status_table):
for a_task_id in dep.deps),
key=UPSTREAM_SEVERITY_KEY)
upstream_status_table[dep_id] = status
print("upstream_status (" + dep_id + ")=",c,file=sys.stderr)
return upstream_status_table[dep_id]

def _serialize_task(self, task_id, include_deps=True, deps=None):
