From 911beb1bdcd4c8b806728290e297419520a8d805 Mon Sep 17 00:00:00 2001 From: Nicholas Carriero Date: Mon, 23 Oct 2023 16:20:00 -0400 Subject: [PATCH] Move from per-task to per-engine heart beat messages The driver was sometimes struggling to keep up with messages for runs with 1000s of concurrent tasks. We now send one heart beat message per engine per beat. The message has a list of task heart beats in it. NoPulse has been increased to account for an extra time out delay incurred by the engine while gathering the task heart beats. Polling for task completion is also done at a lower rate. Also some minor typo fixes and variable renaming. --- disbatchc/disBatch.py | 85 ++++++++++++++++++++++++++----------------- 1 file changed, 51 insertions(+), 34 deletions(-) diff --git a/disbatchc/disBatch.py b/disbatchc/disBatch.py index 144d03b..e059623 100644 --- a/disbatchc/disBatch.py +++ b/disbatchc/disBatch.py @@ -52,7 +52,7 @@ # Heart beat info. PulseTime = 30 -NoPulse = 3*PulseTime + 1 # A task is considered dead if we don't hear from it after 3 heart beat cycles. +NoPulse = 4*PulseTime + 1 # A task is considered dead if we don't hear from it after 4 heart beat cycles. Counting is approximate because of the interaction of various timeouts. logger = logging.getLogger('DisBatch') warnings.formatwarning = lambda msg, cat, *args, **kwargs: f'{cat.__name__}: {msg}\n' # a friendlier warning format @@ -232,7 +232,7 @@ def retireEnv(self, nodeList, retList): return env def retireNodeList(self, nodeList, retList): - '''Called when one or mode nodes has exited. May be overridden to release resources.''' + '''Called when one or more nodes has exited. May be overridden to release resources.''' for ret in retList: if ret: self.error = True @@ -1092,7 +1092,7 @@ def stopped(self, status): def run(self): self.kvs.put('DisBatch status', '', False) - cRank2taskCount, cylKey2eRank, engineHeartBeat, enginesDone, finishedTasks, hbFails = DD(int), {}, {}, False, {}, set() + cRank2taskCount, cylKey2eRank, enginesDone, finishedTasks, hbFails = DD(int), {}, False, {}, set() notifiedAllDone, outstanding, pending, retired = False, {}, [], -1 assignedCylinders, freeCylinders = set(), set() while not (self.tasksDone and enginesDone): @@ -1138,14 +1138,11 @@ def run(self): enginesDone = True for e in self.engines.values(): if e.status != 'stopped': - if e.rank not in engineHeartBeat: - engineHeartBeat[e.rank] = now - last = engineHeartBeat[e.rank] - if (now - last) > NoPulse: + if (now - e.last) > NoPulse: logger.info('Heart beat failure for engine %s.', e) e.status = 'heart beat failure' # This doesn't mean much at the moment. else: - logger.debug('Engine %d in ICU (%.1f)', e.rank, (now - last)) + logger.debug('Engine %d in ICU (%.1f)', e.rank, (now - e.last)) enginesDone = False else: for tinfo, ckey, start, ts in outstanding.values(): @@ -1154,6 +1151,15 @@ def run(self): if tinfo.taskId not in hbFails: # Guard against a pile up of heart beat msgs. hbFails.add(tinfo.taskId) self.kvs.put('.controller', ('task hb fail', (tinfo, ckey, start, ts))) + elif msg == 'engine heart beats': + now = time.time() + engineRank, taskHeartBeats = o + self.engines[engineRank].last = now + for tEngineRank, taskId in taskHeartBeats: + assert engineRank == tEngineRank + if taskId != -1: + if taskId in outstanding: + outstanding[taskId][3] = now elif msg == 'engine started': #TODO: reject if no more tasks or in shutdown? rank, cRank, hn, pid, start = o @@ -1173,7 +1179,7 @@ def run(self): self.noMoreTasks = True logger.info('No more tasks: %d accepted', o) self.barriers.append(TaskReport(TaskInfo(o, -1, -1, b'#NO MORE TASKS BARRIER', None, kind='D'), start=time.time())) - # If no tasks where actually processed, we won't + # If no tasks were actually processed, we won't # notice we are now done until the next heart beat, so # send one now to speed things along. self.kvs.put('.controller', ('driver heart beat', None)) @@ -1288,15 +1294,6 @@ def run(self): self.kvs.put(self.trackResults + b' done tasks', str(tinfo.taskId), False) if self.db_info.args.mailTo and self.finished%self.db_info.args.mailFreq == 0: self.sendNotification() - elif msg == 'task heart beat': - engineRank, taskId = o - now = time.time() - engineHeartBeat[engineRank] = now - if taskId != -1: - if taskId not in outstanding: - logger.info('Unexpected heart beat for task %d.', taskId) - else: - outstanding[taskId][3] = now else: raise Exception('Weird message: ' + msg) @@ -1451,10 +1448,10 @@ def fetch(self, kvs): return self.kvsOp(kvs, next(self.keySeq)) class Cylinder(Thread): - def __init__(self, context, env, envres, kvs, engineRank, cylinderRank, fetchTask): + def __init__(self, context, env, envres, kvs, hbQueue, engineRank, cylinderRank, fetchTask): super(EngineBlock.Cylinder, self).__init__() self.daemon = True - self.context, self.engineRank, self.cylinderRank, self.fetchTask = context, engineRank, cylinderRank, fetchTask + self.context, self.hbQueue, self.engineRank, self.cylinderRank, self.fetchTask = context, hbQueue, engineRank, cylinderRank, fetchTask self.localEnv = env.copy() self.localEnv['DISBATCH_CORES_PER_TASK'] = str(self.context.cores_per_cylinder[self.context.nodeId]) logger.info('Cylinder %d initializing, %s cores', self.cylinderRank, self.localEnv['DISBATCH_CORES_PER_TASK']) @@ -1511,14 +1508,16 @@ def main(self): obp = OutputCollector(self.taskProc.stdout, 40, 40) ebp = OutputCollector(self.taskProc.stderr, 40, 40) ct = 0.0 + pollInterval = 1 # in seconds. while True: # Popen.wait with timeout is resource intensive, so let's roll our own. r = self.context.poll_task(self.taskProc) if r is not None: break - time.sleep(.1) - ct += .1 + time.sleep(pollInterval) + ct += pollInterval if ct >= PulseTime: - self.kvs.put('.controller', ('task heart beat', (self.engineRank, -1 if self.cylinderRank == -1 else ti.taskId))) + # We won't track hb info for per engine tasks, since they may occur multiple times, so don't need a real taskId for those. + self.hbQueue.put((self.engineRank, -1 if self.cylinderRank == -1 else ti.taskId)) ct = 0.0 pid = self.taskProc.pid self.taskProc = None @@ -1541,6 +1540,7 @@ def __init__(self, kvs, context, rank): super(EngineBlock, self).__init__(name='EngineBlock') self.daemon = True self.context = context + self.hbQueue = Queue() self.rank = rank cylinders = context.cylinders[context.nodeId] @@ -1564,35 +1564,49 @@ def __init__(self, kvs, context, rank): self.kvs.put('.controller', ('engine started', (self.rank, context.rank, myHostname, myPid, time.time()))) env.update(self.kvs.view('.common env')) - def indexKeyGen(temp): + def indexKeyGen(template): c = 0 while 1: - yield temp%c + yield template%c c += 1 - def constantKeyGen(temp): - while 1: yield temp + def constantKeyGen(template): + while 1: yield template logger.info('Engine %d running start tasks', self.rank) - peStart = self.Cylinder(context, env, envres, kvs, self.rank, -1, self.FetchTask('.per engine START %d', indexKeyGen, kvsstcp.KVSClient.view)) - peStart.join() + peStart = self.Cylinder(context, env, envres, kvs, self.hbQueue, self.rank, -1, self.FetchTask('.per engine START %d', indexKeyGen, kvsstcp.KVSClient.view)) + self.joinWithHB(peStart) logger.info('Engine %d completed start tasks', self.rank) logger.info('Engine %d running normal tasks, %d-way concurrency', self.rank, cylinders) - self.cylinders = [self.Cylinder(context, env, envres, kvs, self.rank, x, self.FetchTask('.cylinder %d %d'%(self.rank, x), constantKeyGen, kvsstcp.KVSClient.get)) for x in range(cylinders)] + self.cylinders = [self.Cylinder(context, env, envres, kvs, self.hbQueue, self.rank, x, self.FetchTask('.cylinder %d %d'%(self.rank, x), constantKeyGen, kvsstcp.KVSClient.get)) for x in range(cylinders)] self.finished, self.inFlight, self.liveCylinders = 0, 0, len(self.cylinders) self.start() - self.join() + self.joinWithHB(self) logger.info('Engine %d completed normal tasks', self.rank) logger.info('Engine %d running stop tasks', self.rank) - peStop = self.Cylinder(context, env, envres, kvs, self.rank, -1, self.FetchTask('.per engine STOP %d', indexKeyGen, kvsstcp.KVSClient.view)) - peStop.join() + peStop = self.Cylinder(context, env, envres, kvs, self.hbQueue, self.rank, -1, self.FetchTask('.per engine STOP %d', indexKeyGen, kvsstcp.KVSClient.view)) + self.joinWithHB(peStop) logger.info('Engine %d completed stop tasks', self.rank) self.kvs.put('.controller', ('engine stopped', (self.finalStatus, self.rank))) self.kvs.close() + def joinWithHB(self, thr): + while True: + thr.join(timeout=PulseTime) + if not thr.isAlive(): break + # We are still running, collect and transmit heart beat data. + hbs = [] + while True: + try: + o = self.hbQueue.get(block=False) + hbs.append(o) + except Empty: + break + self.kvs.put('.controller', ('engine heart beats', (self.rank, hbs))) + def run(self): #TODO: not currently checking for a per engine clean up # task. Probably need to explicitly join pec, which means @@ -1657,6 +1671,8 @@ def main(kvsq=None): os.chdir(dbInfo.wd) except Exception as e: print('Failed to change working directory to "%s".'%dbInfo.wd, file=sys.stderr) + #TODO: Fail here? + context.setNode(args.node) logger = logging.getLogger('DisBatch Engine') lconf = {'format': '%(asctime)s %(levelname)-8s %(name)-15s: %(message)s', 'level': dbInfo.args.loglevel} @@ -1728,7 +1744,8 @@ def shutdown(s=None, f=None): os.chdir(dbInfo.wd) except Exception as e: print('Failed to change working directory to "%s".'%dbInfo.wd, file=sys.stderr) - + #TODO: Fail here? + # Try to find a batch context. if args.ssh_node: context = SSHContext(dbInfo, rank, args)