Move from per-task to per-engine heart beat messages
The driver was sometimes struggling to keep up with messages for runs with thousands of concurrent tasks.

We now send one heart beat message per engine per beat; each message carries a list of task heart beats. NoPulse has been increased to account for the extra timeout delay the engine incurs while gathering the task heart beats.
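
As a rough illustration of the batching only (the diff below is authoritative; hbQueue matches the new attribute, while kvs_put and engine_rank are placeholder names):

import queue

PulseTime = 30               # seconds between engine heart beats
hbQueue = queue.Queue()      # filled by the per-task cylinder threads

def send_engine_heart_beat(kvs_put, engine_rank):
    # Drain whatever task heart beats have been queued since the last
    # beat and forward them to the driver as one message for this engine.
    beats = []
    while True:
        try:
            beats.append(hbQueue.get(block=False))
        except queue.Empty:
            break
    kvs_put('.controller', ('engine heart beats', (engine_rank, beats)))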

Polling for task completion is also done at a lower rate.
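
For the polling change, a minimal sketch under the same caveat (poll_task and enqueue_heart_beat are placeholders): the completion loop now sleeps a full second between polls instead of 0.1s, and queues a heart beat entry roughly once per PulseTime.

import time

PulseTime = 30      # seconds between heart beats
pollInterval = 1    # seconds between completion polls (previously 0.1)

def wait_for_completion(poll_task, enqueue_heart_beat):
    # poll_task() returns the task's exit status once it has finished, else None.
    elapsed = 0.0
    while True:
        status = poll_task()
        if status is not None:
            return status
        time.sleep(pollInterval)
        elapsed += pollInterval
        if elapsed >= PulseTime:
            enqueue_heart_beat()
            elapsed = 0.0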

Also some minor typo fixes and variable renaming.
njcarriero committed Oct 23, 2023
1 parent 57d1130 commit 911beb1
Showing 1 changed file with 51 additions and 34 deletions.
85 changes: 51 additions & 34 deletions disbatchc/disBatch.py
@@ -52,7 +52,7 @@

# Heart beat info.
PulseTime = 30
NoPulse = 3*PulseTime + 1 # A task is considered dead if we don't hear from it after 3 heart beat cycles.
NoPulse = 4*PulseTime + 1 # A task is considered dead if we don't hear from it after 4 heart beat cycles. Counting is approximate because of the interaction of various timeouts.

logger = logging.getLogger('DisBatch')
warnings.formatwarning = lambda msg, cat, *args, **kwargs: f'{cat.__name__}: {msg}\n' # a friendlier warning format
@@ -232,7 +232,7 @@ def retireEnv(self, nodeList, retList):
return env

def retireNodeList(self, nodeList, retList):
'''Called when one or mode nodes has exited. May be overridden to release resources.'''
'''Called when one or more nodes has exited. May be overridden to release resources.'''
for ret in retList:
if ret:
self.error = True
@@ -1092,7 +1092,7 @@ def stopped(self, status):
def run(self):
self.kvs.put('DisBatch status', '<Starting...>', False)

cRank2taskCount, cylKey2eRank, engineHeartBeat, enginesDone, finishedTasks, hbFails = DD(int), {}, {}, False, {}, set()
cRank2taskCount, cylKey2eRank, enginesDone, finishedTasks, hbFails = DD(int), {}, False, {}, set()
notifiedAllDone, outstanding, pending, retired = False, {}, [], -1
assignedCylinders, freeCylinders = set(), set()
while not (self.tasksDone and enginesDone):
@@ -1138,14 +1138,11 @@ def run(self):
enginesDone = True
for e in self.engines.values():
if e.status != 'stopped':
if e.rank not in engineHeartBeat:
engineHeartBeat[e.rank] = now
last = engineHeartBeat[e.rank]
if (now - last) > NoPulse:
if (now - e.last) > NoPulse:
logger.info('Heart beat failure for engine %s.', e)
e.status = 'heart beat failure' # This doesn't mean much at the moment.
else:
logger.debug('Engine %d in ICU (%.1f)', e.rank, (now - last))
logger.debug('Engine %d in ICU (%.1f)', e.rank, (now - e.last))
enginesDone = False
else:
for tinfo, ckey, start, ts in outstanding.values():
@@ -1154,6 +1151,15 @@
if tinfo.taskId not in hbFails: # Guard against a pile up of heart beat msgs.
hbFails.add(tinfo.taskId)
self.kvs.put('.controller', ('task hb fail', (tinfo, ckey, start, ts)))
elif msg == 'engine heart beats':
now = time.time()
engineRank, taskHeartBeats = o
self.engines[engineRank].last = now
for tEngineRank, taskId in taskHeartBeats:
assert engineRank == tEngineRank
if taskId != -1:
if taskId in outstanding:
outstanding[taskId][3] = now
elif msg == 'engine started':
#TODO: reject if no more tasks or in shutdown?
rank, cRank, hn, pid, start = o
@@ -1173,7 +1179,7 @@
self.noMoreTasks = True
logger.info('No more tasks: %d accepted', o)
self.barriers.append(TaskReport(TaskInfo(o, -1, -1, b'#NO MORE TASKS BARRIER', None, kind='D'), start=time.time()))
# If no tasks where actually processed, we won't
# If no tasks were actually processed, we won't
# notice we are now done until the next heart beat, so
# send one now to speed things along.
self.kvs.put('.controller', ('driver heart beat', None))
@@ -1288,15 +1294,6 @@ def run(self):
self.kvs.put(self.trackResults + b' done tasks', str(tinfo.taskId), False)
if self.db_info.args.mailTo and self.finished%self.db_info.args.mailFreq == 0:
self.sendNotification()
elif msg == 'task heart beat':
engineRank, taskId = o
now = time.time()
engineHeartBeat[engineRank] = now
if taskId != -1:
if taskId not in outstanding:
logger.info('Unexpected heart beat for task %d.', taskId)
else:
outstanding[taskId][3] = now
else:
raise Exception('Weird message: ' + msg)

@@ -1451,10 +1448,10 @@ def fetch(self, kvs):
return self.kvsOp(kvs, next(self.keySeq))

class Cylinder(Thread):
def __init__(self, context, env, envres, kvs, engineRank, cylinderRank, fetchTask):
def __init__(self, context, env, envres, kvs, hbQueue, engineRank, cylinderRank, fetchTask):
super(EngineBlock.Cylinder, self).__init__()
self.daemon = True
self.context, self.engineRank, self.cylinderRank, self.fetchTask = context, engineRank, cylinderRank, fetchTask
self.context, self.hbQueue, self.engineRank, self.cylinderRank, self.fetchTask = context, hbQueue, engineRank, cylinderRank, fetchTask
self.localEnv = env.copy()
self.localEnv['DISBATCH_CORES_PER_TASK'] = str(self.context.cores_per_cylinder[self.context.nodeId])
logger.info('Cylinder %d initializing, %s cores', self.cylinderRank, self.localEnv['DISBATCH_CORES_PER_TASK'])
@@ -1511,14 +1508,16 @@ def main(self):
obp = OutputCollector(self.taskProc.stdout, 40, 40)
ebp = OutputCollector(self.taskProc.stderr, 40, 40)
ct = 0.0
pollInterval = 1 # in seconds.
while True:
# Popen.wait with timeout is resource intensive, so let's roll our own.
r = self.context.poll_task(self.taskProc)
if r is not None: break
time.sleep(.1)
ct += .1
time.sleep(pollInterval)
ct += pollInterval
if ct >= PulseTime:
self.kvs.put('.controller', ('task heart beat', (self.engineRank, -1 if self.cylinderRank == -1 else ti.taskId)))
# We won't track hb info for per engine tasks, since they may occur multiple times, so don't need a real taskId for those.
self.hbQueue.put((self.engineRank, -1 if self.cylinderRank == -1 else ti.taskId))
ct = 0.0
pid = self.taskProc.pid
self.taskProc = None
@@ -1541,6 +1540,7 @@ def __init__(self, kvs, context, rank):
super(EngineBlock, self).__init__(name='EngineBlock')
self.daemon = True
self.context = context
self.hbQueue = Queue()
self.rank = rank
cylinders = context.cylinders[context.nodeId]

@@ -1564,35 +1564,49 @@ def __init__(self, kvs, context, rank):
self.kvs.put('.controller', ('engine started', (self.rank, context.rank, myHostname, myPid, time.time())))
env.update(self.kvs.view('.common env'))

def indexKeyGen(temp):
def indexKeyGen(template):
c = 0
while 1:
yield temp%c
yield template%c
c += 1

def constantKeyGen(temp):
while 1: yield temp
def constantKeyGen(template):
while 1: yield template

logger.info('Engine %d running start tasks', self.rank)
peStart = self.Cylinder(context, env, envres, kvs, self.rank, -1, self.FetchTask('.per engine START %d', indexKeyGen, kvsstcp.KVSClient.view))
peStart.join()
peStart = self.Cylinder(context, env, envres, kvs, self.hbQueue, self.rank, -1, self.FetchTask('.per engine START %d', indexKeyGen, kvsstcp.KVSClient.view))
self.joinWithHB(peStart)
logger.info('Engine %d completed start tasks', self.rank)

logger.info('Engine %d running normal tasks, %d-way concurrency', self.rank, cylinders)
self.cylinders = [self.Cylinder(context, env, envres, kvs, self.rank, x, self.FetchTask('.cylinder %d %d'%(self.rank, x), constantKeyGen, kvsstcp.KVSClient.get)) for x in range(cylinders)]
self.cylinders = [self.Cylinder(context, env, envres, kvs, self.hbQueue, self.rank, x, self.FetchTask('.cylinder %d %d'%(self.rank, x), constantKeyGen, kvsstcp.KVSClient.get)) for x in range(cylinders)]
self.finished, self.inFlight, self.liveCylinders = 0, 0, len(self.cylinders)
self.start()
self.join()
self.joinWithHB(self)
logger.info('Engine %d completed normal tasks', self.rank)

logger.info('Engine %d running stop tasks', self.rank)
peStop = self.Cylinder(context, env, envres, kvs, self.rank, -1, self.FetchTask('.per engine STOP %d', indexKeyGen, kvsstcp.KVSClient.view))
peStop.join()
peStop = self.Cylinder(context, env, envres, kvs, self.hbQueue, self.rank, -1, self.FetchTask('.per engine STOP %d', indexKeyGen, kvsstcp.KVSClient.view))
self.joinWithHB(peStop)
logger.info('Engine %d completed stop tasks', self.rank)

self.kvs.put('.controller', ('engine stopped', (self.finalStatus, self.rank)))
self.kvs.close()

def joinWithHB(self, thr):
while True:
thr.join(timeout=PulseTime)
if not thr.isAlive(): break
# We are still running, collect and transmit heart beat data.
hbs = []
while True:
try:
o = self.hbQueue.get(block=False)
hbs.append(o)
except Empty:
break
self.kvs.put('.controller', ('engine heart beats', (self.rank, hbs)))

def run(self):
#TODO: not currently checking for a per engine clean up
# task. Probably need to explicitly join pec, which means
@@ -1657,6 +1671,8 @@ def main(kvsq=None):
os.chdir(dbInfo.wd)
except Exception as e:
print('Failed to change working directory to "%s".'%dbInfo.wd, file=sys.stderr)
#TODO: Fail here?

context.setNode(args.node)
logger = logging.getLogger('DisBatch Engine')
lconf = {'format': '%(asctime)s %(levelname)-8s %(name)-15s: %(message)s', 'level': dbInfo.args.loglevel}
@@ -1728,7 +1744,8 @@ def shutdown(s=None, f=None):
os.chdir(dbInfo.wd)
except Exception as e:
print('Failed to change working directory to "%s".'%dbInfo.wd, file=sys.stderr)

#TODO: Fail here?

# Try to find a batch context.
if args.ssh_node:
context = SSHContext(dbInfo, rank, args)
