diff --git a/ssh_para/ssh_para.py b/ssh_para/ssh_para.py index 33ce604..c91b861 100644 --- a/ssh_para/ssh_para.py +++ b/ssh_para/ssh_para.py @@ -12,6 +12,7 @@ import threading import queue import curses +from typing import Optional from glob import glob from re import sub, escape from socket import gethostbyname_ex, gethostbyaddr, inet_aton @@ -19,7 +20,8 @@ from time import time, strftime, sleep from datetime import timedelta, datetime from subprocess import Popen, DEVNULL -from argparse import ArgumentParser, RawTextHelpFormatter +from io import BufferedReader, TextIOWrapper +from argparse import ArgumentParser, Namespace, RawTextHelpFormatter from dataclasses import dataclass from copy import deepcopy import argcomplete @@ -36,22 +38,22 @@ SSH_OPTS = os.environ.get("SSHP_OPTS") or "" MAX_DOTS = int(os.environ.get("SSHP_MAX_DOTS") or 1) INTERRUPT = False - +EXIT_CODE = 0 jobq = queue.Queue() printq = queue.Queue() pauseq = queue.Queue() -def shell_argcomplete(shell="bash"): +def shell_argcomplete(shell: str = "bash") -> None: """produce code to source in shell . <(ssh-para -C bash) ssh-para -C powershell | Out-String | Invoke-Expression """ - print(argcomplete.shellcode(["ssh-para"], shell=shell)) + print(argcomplete.shell_integration.shellcode(["ssh-para"], shell=shell)) sys.exit(0) -def log_choices(**kwargs): +def log_choices(**kwargs) -> tuple: """argcomplete -L choices""" return ( "*.status", @@ -71,7 +73,7 @@ def log_choices(**kwargs): ) -def parse_args(): +def parse_args() -> Namespace: """argument parse""" if len(sys.argv) == 1: sys.argv.append("-h") @@ -143,7 +145,7 @@ def parse_args(): default is latest ssh-para run (use -j -d to access logs if used for run) : [success,failed,timeout,killed,aborted] """, - ).completer = log_choices + ).completer = log_choices # type: ignore parser.add_argument("-s", "--script", help="script to execute") parser.add_argument("-a", "--args", nargs="+", help="script arguments") @@ -153,19 +155,19 @@ def parse_args(): return parser.parse_args() -def sigint_handler(*args): +def sigint_handler(*args) -> None: """exit all threads if signal""" global INTERRUPT INTERRUPT = True -def hometilde(directory): +def hometilde(directory: str) -> str: """substitute home to tilde in dir""" home = os.path.expanduser("~/") return sub(rf"^{escape(home)}", "~/", directory) -def resolve_hostname(host): +def resolve_hostname(host: str) -> Optional[str]: """try get fqdn from DNS""" try: res = gethostbyname_ex(host) @@ -174,7 +176,7 @@ def resolve_hostname(host): return res[0] -def resolve_in_domains(host, domains): +def resolve_in_domains(host: str, domains: list) -> str: """try get fqdn from short hostname in domains""" fqdn = resolve_hostname(host) if fqdn: @@ -187,7 +189,7 @@ def resolve_in_domains(host, domains): return host -def resolve_ip(ip): +def resolve_ip(ip: str) -> str: """try resolve hostname by reverse dns query on ip addr""" try: host = gethostbyaddr(ip) @@ -197,7 +199,7 @@ def resolve_ip(ip): return host[0] -def is_ip(host): +def is_ip(host: str) -> bool: """determine if host is valid ip""" try: inet_aton(host) @@ -206,46 +208,49 @@ def is_ip(host): return False -def resolve(host, domains): +def resolve(host: str, domains: list) -> str: """resolve hostname from ip / hostname""" if is_ip(host): return resolve_ip(host) return resolve_in_domains(host, domains) -def resolve_hosts(hosts, domains): +def resolve_hosts(hosts: list, domains: list) -> list: """try resolve hosts to get fqdn""" return [resolve(host, domains) for host in hosts] -def addstr(stdscr, *args, **kwargs): +def addstr(stdscr: Optional["curses._CursesWindow"], *args, **kwargs) -> None: """curses addstr w/o exception""" - try: - stdscr.addstr(*args, **kwargs) - except (curses.error, ValueError): - pass + if stdscr: + try: + stdscr.addstr(*args, **kwargs) + except (curses.error, ValueError): + pass -def addstrc(stdscr, *args, **kwargs): +def addstrc(stdscr: Optional["curses._CursesWindow"], *args, **kwargs) -> None: """curses addstr and clear eol""" if stdscr: addstr(stdscr, *args, **kwargs) stdscr.clrtoeol() -def tdelta(*args, **kwargs): +def tdelta(*args, **kwargs) -> str: """timedelta without microseconds""" return str(timedelta(*args, **kwargs)).split(".", maxsplit=1)[0] -def print_tee(*args, file=None, color="", **kwargs): +def print_tee( + *args, file: Optional[TextIOWrapper] = None, color: str = "", **kwargs +) -> None: """print stdout + file""" print(" ".join([color] + list(args)), file=sys.stderr, **kwargs) if file: print(*args, file=file, **kwargs) -def decode_line(line): +def decode_line(line: bytes) -> str: """try decode line exception on binary""" try: return line.decode() @@ -253,7 +258,7 @@ def decode_line(line): return "" -def last_line(fd, maxline=1000): +def last_line(fd: BufferedReader, maxline: int = 1000) -> str: """last non empty line of file""" line = "\n" fd.seek(0, os.SEEK_END) @@ -272,11 +277,11 @@ def last_line(fd, maxline=1000): return line.strip() -def short_host(host): +def short_host(host: str) -> str: """remove dns domain from fqdn""" if is_ip(host): return host - return ".".join(host.split(".")[: MAX_DOTS]) + return ".".join(host.split(".")[:MAX_DOTS]) class Segment: @@ -284,12 +289,12 @@ class Segment: def __init__( self, - stdscr, - nbsegments, - bg=None, - fg=None, - style=None, - seg1=True, + stdscr: "curses._CursesWindow", + nbsegments: int, + bg: Optional[list] = None, + fg: Optional[list] = None, + style: Optional[list] = None, + seg1: bool = True, ): """curses inits""" self.stdscr = stdscr @@ -313,7 +318,7 @@ def __init__( curses.init_pair(i * 2 + 2, fg[i], bg[i]) curses.init_pair(i * 2 + 3, bg[i], bg[i + 1]) - def set_segments(self, x, y, segments): + def set_segments(self, x: int, y: int, segments: list) -> None: """display powerline""" addstr(self.stdscr, y, x, SYMBOL_BEGIN, curses.color_pair(1)) for i, segment in enumerate(segments): @@ -327,16 +332,16 @@ class JobStatus: """handle job statuses""" status: str = "IDLE" - start: str = "" + start: float = 0 host: str = "" shorthost: str = "" - duration: int = 0 + duration: float = 0 pid: int = -1 - exit: int = None + exit: Optional[int] = None logfile: str = "" log: str = "" thread_id: int = -1 - fdlog: int = 0 + fdlog: Optional[BufferedReader] = None class JobStatusLog: @@ -346,27 +351,27 @@ class JobStatusLog: class LogStatus: """fd log/count status""" - fd: int = 0 + fd: Optional[TextIOWrapper] = None nb: int = 0 - def __init__(self, dirlog): + def __init__(self, dirlog: str): """open log files for each status""" statuses = ["SUCCESS", "FAILED", "TIMEOUT", "KILLED", "ABORTED"] self.lstatus = {} for status in statuses: self.lstatus[status] = self.LogStatus(fd=self.open(dirlog, status)) - def open(self, dirlog, status): + def open(self, dirlog: str, status: str) -> TextIOWrapper: """open log file for status""" return open(f"{dirlog}/{status.lower()}.status", "w", encoding="UTF-8") - def addhost(self, host, status): + def addhost(self, host: str, status: str) -> None: """add host in status log""" if status in self.lstatus: self.lstatus[status].nb += 1 print(host, file=self.lstatus[status].fd) - def result(self): + def result(self) -> str: """print counts of statuses""" return " ".join([f"{s.lower()[:4]}: {v.nb}" for s, v in self.lstatus.items()]) @@ -395,13 +400,13 @@ class JobPrint(threading.Thread): def __init__( self, - command, - nbthreads, - nbjobs, - dirlog, - timeout=0, - verbose=False, - maxhostlen=15, + command: list, + nbthreads: int, + nbjobs: int, + dirlog: str, + timeout: float = 0, + verbose: bool = False, + maxhostlen: int = 15, ): """init properties / thread""" super().__init__() @@ -414,7 +419,7 @@ def __init__( self.dirlog = dirlog self.aborted = [] self.startsec = time() - self.stdscr = None + self.stdscr: Optional[curses._CursesWindow] = None self.paused = False self.timeout = timeout self.verbose = verbose @@ -426,10 +431,10 @@ def __init__( self.init_curses() super().__init__() - def __del__(self): + def __del__(self) -> None: self.print_summary() - def init_curses(self): + def init_curses(self) -> None: """curses window init""" self.stdscr = curses.initscr() curses.raw() @@ -461,20 +466,13 @@ def init_curses(self): curses.init_pair(self.COLOR_GAUGE, 8, curses.COLOR_BLUE) curses.init_pair(self.COLOR_HOST, curses.COLOR_YELLOW, curses.COLOR_BLACK) - def join(self, *args): - """returns nb failed""" - super().join(*args) - if INTERRUPT: - return 130 - return self.nbfailed > 0 - - def killall(self): + def killall(self) -> None: """kill all running threads pid""" for status in self.th_status: if status.status == "RUNNING": self.kill(status.thread_id) - def run(self): + def run(self) -> None: """get threads status change""" jobsdur = 0 nbsshjobs = 0 @@ -482,7 +480,7 @@ def run(self): if INTERRUPT: self.abort_jobs() try: - jstatus: JobStatus = printq.get(timeout=0.1) + jstatus: Optional[JobStatus] = printq.get(timeout=0.1) except queue.Empty: jstatus = None th_id = None @@ -492,7 +490,7 @@ def run(self): jstatus.log = last_line(jstatus.fdlog) if jstatus.exit is not None: # FINISHED jstatus.fdlog.close() - jstatus.fdlog = 0 + jstatus.fdlog = None nbsshjobs += 1 jobsdur += jstatus.duration if jstatus.status == "FAILED": @@ -523,31 +521,31 @@ def run(self): if len(self.job_status) == self.nbjobs: break self.resume() + global EXIT_CODE + EXIT_CODE = 130 if INTERRUPT else (self.nbfailed > 0) if self.stdscr: addstrc(self.stdscr, curses.LINES - 1, 0, "All jobs finished") self.stdscr.refresh() self.stdscr.getch() curses.endwin() - # self.print_summary() - # if INTERRUPT: - # os._exit(1) - - def check_timeout(self, th_id, duration): + def check_timeout(self, th_id: int, duration: float) -> None: """kill ssh if duration exceeds timeout""" if not self.timeout: return if duration > self.timeout: self.kill(th_id, "TIMEOUT") - def check_timeouts(self): + def check_timeouts(self) -> None: """check threads timemout""" for i, jstatus in enumerate(self.th_status): if jstatus.status == "RUNNING": duration = time() - jstatus.start self.check_timeout(i, duration) - def print_status(self, status, duration=0, avgjobdur=0): + def print_status( + self, status: str, duration: float = 0, avgjobdur: float = 0 + ) -> None: """print thread status""" color = self.status_color[status] addstr(self.stdscr, SYMBOL_BEGIN, curses.color_pair(color + 1)) @@ -563,7 +561,7 @@ def print_status(self, status, duration=0, avgjobdur=0): addstr(self.stdscr, SYMBOL_END, curses.color_pair(color + 1)) addstr(self.stdscr, f" {tdelta(seconds=round(duration))}") - def print_job(self, line_num, jstatus, duration, avgjobdur): + def print_job(self, line_num: int, jstatus, duration: float, avgjobdur: float): """print host runnin on thread and last out line""" th_id = str(jstatus.thread_id).zfill(2) addstr(self.stdscr, line_num, 0, f" {th_id} ") @@ -580,8 +578,11 @@ def print_job(self, line_num, jstatus, duration, avgjobdur): ) addstrc(self.stdscr, jstatus.log) - def display_curses(self, status_id, total_dur, jobsdur, nbsshjobs): + def display_curses( + self, status_id: Optional[int], total_dur: str, jobsdur, nbsshjobs + ) -> None: """display threads statuses""" + assert self.stdscr is not None nbend = len(self.job_status) last_start = 0 avgjobdur = 0 @@ -637,9 +638,10 @@ def display_curses(self, status_id, total_dur, jobsdur, nbsshjobs): addstrc(self.stdscr, curses.LINES - 1, 0, "[a]bort [k]ill [p]ause") self.stdscr.refresh() - def get_key(self): + def get_key(self) -> None: """manage interactive actions""" global INTERRUPT + assert self.stdscr is not None self.stdscr.nodelay(True) ch = self.stdscr.getch() self.stdscr.nodelay(False) @@ -657,9 +659,10 @@ def get_key(self): self.abort_jobs() self.killall() - def curses_kill(self): + def curses_kill(self) -> None: """interactive kill pid of ssh thread""" curses.echo() + assert self.stdscr is not None addstrc(self.stdscr, curses.LINES - 1, 0, "kill job in thread: ") try: th_id = int(self.stdscr.getstr()) @@ -669,7 +672,7 @@ def curses_kill(self): curses.noecho() self.kill(th_id) - def kill(self, th_id, status="KILLED"): + def kill(self, th_id, status="KILLED") -> None: """kill pid of thread id""" th_status = self.th_status[th_id] if th_status.pid > 0: @@ -679,21 +682,22 @@ def kill(self, th_id, status="KILLED"): except ProcessLookupError: pass - def pause(self): + def pause(self) -> None: """pause JobRun threads""" if not self.paused: self.paused = True pauseq.put(True) - def resume(self): + def resume(self) -> None: """resume JobRun threads""" if self.paused: self.paused = False pauseq.get() pauseq.task_done() - def print_finished(self, line_num): + def print_finished(self, line_num: int) -> None: """display finished jobs""" + assert self.stdscr is not None addstr(self.stdscr, curses.LINES - 1, 0, "") inter = self.verbose + 1 for jstatus in self.job_status[::-1]: @@ -715,7 +719,7 @@ def print_finished(self, line_num): line_num += inter self.stdscr.clrtobot() - def abort_jobs(self): + def abort_jobs(self) -> None: """aborts remaining jobs""" if not jobq.qsize(): return @@ -732,7 +736,7 @@ def abort_jobs(self): self.aborted.append(job.host) self.resume() - def print_summary(self): + def print_summary(self) -> None: """print/log summary of jobs""" end = strftime("%X") total_dur = tdelta(seconds=round(time() - self.startsec)) @@ -782,13 +786,13 @@ def print_summary(self): class Job: """manage job execution""" - def __init__(self, host, command): + def __init__(self, host: str, command: list): """job to run on host init""" self.host = host self.command = command self.status = JobStatus(host=host, shorthost=short_host(host)) - def exec(self, th_id, dirlog): + def exec(self, th_id: int, dirlog: str) -> None: """run command on host using ssh""" self.status.start = time() self.status.thread_id = th_id @@ -837,13 +841,13 @@ class JobRun(threading.Thread): Threads launching jobs from rung in parallel """ - def __init__(self, thread_id, dirlog=""): + def __init__(self, thread_id: int, dirlog: str = ""): """constructor""" self.thread_id = thread_id self.dirlog = dirlog super().__init__() - def run(self): + def run(self) -> None: """schedule Jobs / pause / resume""" while True: pauseq.join() @@ -857,7 +861,7 @@ def run(self): jobq.task_done() -def script_command(script, args): +def script_command(script: str, args: list) -> str: """build ssh command to transfer and execute script with args""" try: with open(script, "r", encoding="UTF-8") as fd: @@ -887,7 +891,7 @@ def script_command(script, args): return command -def get_hosts(hostsfile, hosts): +def get_hosts(hostsfile: str, hosts: list) -> list: """returns hosts list from args host or reading hostsfile""" if hosts: return hosts @@ -896,14 +900,14 @@ def get_hosts(hostsfile, hosts): sys.exit(1) try: with open(hostsfile, "r", encoding="UTF-8") as fhosts: - hosts = fhosts.read().splitlines() + hosts = list(filter(len, fhosts.read().splitlines())) except OSError: print(f"ERROR: ssh-para: Cannot open {hostsfile}", file=sys.stderr) sys.exit(1) return hosts -def tstodatetime(ts): +def tstodatetime(ts) -> Optional[str]: """timestamp to datetime""" try: tsi = int(ts) @@ -912,7 +916,7 @@ def tstodatetime(ts): return datetime.fromtimestamp(tsi).strftime("%Y-%m-%d %H:%M:%S") -def printfile(file, text): +def printfile(file: str, text: str) -> bool: """try print text to file""" try: with open(file, "w", encoding="UTF-8") as fd: @@ -922,7 +926,7 @@ def printfile(file, text): return True -def readfile(file): +def readfile(file: str) -> Optional[str]: """try read from file""" try: with open(file, "r", encoding="UTF-8") as fd: @@ -932,7 +936,7 @@ def readfile(file): return text.strip() -def log_results(dirlog, job): +def log_results(dirlog: str, job: str) -> None: """print log results in dirlog/job""" if job: dirlog = f"{dirlog}/{job}" @@ -952,7 +956,7 @@ def log_results(dirlog, job): sys.exit(0) -def log_content(dirlog, wildcard): +def log_content(dirlog: str, wildcard: str) -> None: """print log file content in dirlog matching wildcard""" dirpattern = f"{dirlog}/{wildcard}" files = glob(dirpattern) @@ -974,7 +978,7 @@ def log_content(dirlog, wildcard): print() -def isdir(directory): +def isdir(directory: str) -> bool: """test dir exits""" try: if os.path.isdir(directory): @@ -984,7 +988,7 @@ def isdir(directory): return False -def get_latest_dir(dirlog): +def get_latest_dir(dirlog: str) -> str: """retrieve last log dir""" try: dirs = glob(f"{dirlog}/[0-9]*") @@ -999,7 +1003,7 @@ def get_latest_dir(dirlog): sys.exit(1) -def log_contents(wildcards, dirlog, job): +def log_contents(wildcards: list, dirlog: str, job: str): """print logs content according to wildcards *.out *.success...""" if job: dirlog += f"/{job}" @@ -1016,7 +1020,7 @@ def log_contents(wildcards, dirlog, job): sys.exit(0) -def make_latest(dirlog, dirlogtime): +def make_latest(dirlog: str, dirlogtime: str) -> None: """make symlink to last log directory""" latest = f"{dirlog}/latest" try: @@ -1027,7 +1031,7 @@ def make_latest(dirlog, dirlogtime): pass -def make_logdir(dirlog, job): +def make_logdir(dirlog: str, job: str) -> str: """create log directory""" jobdirlog = dirlog if job: @@ -1045,7 +1049,7 @@ def make_logdir(dirlog, job): return dirlogtime -def main(): +def main() -> None: """argument read / read hosts file / prepare commands / launch jobs""" global MAX_DOTS init(autoreset=True) @@ -1071,7 +1075,7 @@ def main(): else: command = args.ssh_args if not args.ssh_args: - print("ERROR: ssh-para: No ssh command supplied", file=sys.stderr) + print("Error: ssh-para: No ssh command supplied", file=sys.stderr) sys.exit(1) if args.hostsfile: hostsfile = os.path.basename(args.hostsfile) @@ -1114,8 +1118,8 @@ def main(): sleep(args.delay) jobq.join() - exit_code = p.join() - sys.exit(exit_code) + p.join() + sys.exit(EXIT_CODE) if __name__ == "__main__":