diff --git a/src/jobflow_remote/remote/host/remote.py b/src/jobflow_remote/remote/host/remote.py index 4d8f8f02..cb9b1007 100644 --- a/src/jobflow_remote/remote/host/remote.py +++ b/src/jobflow_remote/remote/host/remote.py @@ -3,9 +3,11 @@ import io import logging import shlex +import traceback from pathlib import Path import fabric +from paramiko.ssh_exception import SSHException from jobflow_remote.remote.host.base import BaseHost @@ -33,6 +35,7 @@ def __init__( keepalive=60, shell_cmd="bash", login_shell=True, + retry_on_closed_connection=True, ): self.host = host self.user = user @@ -47,6 +50,10 @@ def __init__( self.keepalive = keepalive self.shell_cmd = shell_cmd self.login_shell = login_shell + self.retry_on_closed_connection = retry_on_closed_connection + self._create_connection() + + def _create_connection(self): self._connection = fabric.Connection( host=self.host, user=self.user, @@ -114,8 +121,12 @@ def execute( remote_command = command with self.connection.cd(workdir): - out = self.connection.run( - remote_command, hide=True, warn=True, timeout=timeout + out = self._execute_remote_func( + self.connection.run, + remote_command, + hide=True, + warn=True, + timeout=timeout, ) return out.stdout, out.stderr, out.exited @@ -145,7 +156,7 @@ def write_text_file(self, filepath: str | Path, content: str): f = io.StringIO(content) - self.connection.put(f, str(filepath)) + self._execute_remote_func(self.connection.put, f, str(filepath)) def connect(self): self.connection.open() @@ -166,13 +177,53 @@ def is_connected(self) -> bool: def put(self, src, dst): self._check_connected() - self.connection.put(src, dst) + self._execute_remote_func(self.connection.put, src, dst) def get(self, src, dst): self._check_connected() - self.connection.get(src, dst) + self._execute_remote_func(self.connection.get, src, dst) def copy(self, src, dst): cmd = ["cp", str(src), str(dst)] self.execute(cmd) + + def _execute_remote_func(self, remote_cmd, *args, **kwargs): + if self.retry_on_closed_connection: + try: + return remote_cmd(*args, **kwargs) + except OSError as e: + msg = getattr(e, "message", str(e)) + error = e + if "Socket is closed" not in msg: + raise e + except SSHException as e: + error = e + msg = getattr(e, "message", str(e)) + if "Server connection dropped" not in msg: + raise e + except EOFError as e: + error = e + else: + return remote_cmd(*args, **kwargs) + + # if the code gets here one of the errors that could be due to drop of the + # connection occurred. Try to close and reopen the connection and retry + # one more time + logger.warning( + f"Error while trying to execute a command on host {self.host}:\n" + f"{''.join(traceback.format_exception(error))}" + "Probably due to the connection dropping. " + "Will reopen the connection and retry." + ) + try: + self.connection.close() + except Exception: + logger.warning( + "Error while closing the connection during a retry. " + "Proceeding with the retry.", + exc_info=True, + ) + self._create_connection() + self.connect() + return remote_cmd(*args, **kwargs)