Skip to content

Commit

Permalink
restart connection for remote host if connection is dropped
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetretto committed Jul 11, 2023
1 parent a814ca3 commit e3bae86
Showing 1 changed file with 56 additions and 5 deletions.
61 changes: 56 additions & 5 deletions src/jobflow_remote/remote/host/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import io
import logging
import shlex
import traceback
from pathlib import Path

import fabric
from paramiko.ssh_exception import SSHException

from jobflow_remote.remote.host.base import BaseHost

Expand Down Expand Up @@ -33,6 +35,7 @@ def __init__(
keepalive=60,
shell_cmd="bash",
login_shell=True,
retry_on_closed_connection=True,
):
self.host = host
self.user = user
Expand All @@ -47,6 +50,10 @@ def __init__(
self.keepalive = keepalive
self.shell_cmd = shell_cmd
self.login_shell = login_shell
self.retry_on_closed_connection = retry_on_closed_connection
self._create_connection()

def _create_connection(self):
self._connection = fabric.Connection(
host=self.host,
user=self.user,
Expand Down Expand Up @@ -114,8 +121,12 @@ def execute(
remote_command = command

with self.connection.cd(workdir):
out = self.connection.run(
remote_command, hide=True, warn=True, timeout=timeout
out = self._execute_remote_func(
self.connection.run,
remote_command,
hide=True,
warn=True,
timeout=timeout,
)

return out.stdout, out.stderr, out.exited
Expand Down Expand Up @@ -145,7 +156,7 @@ def write_text_file(self, filepath: str | Path, content: str):

f = io.StringIO(content)

self.connection.put(f, str(filepath))
self._execute_remote_func(self.connection.put, f, str(filepath))

def connect(self):
self.connection.open()
Expand All @@ -166,13 +177,53 @@ def is_connected(self) -> bool:
def put(self, src, dst):
self._check_connected()

self.connection.put(src, dst)
self._execute_remote_func(self.connection.put, src, dst)

def get(self, src, dst):
self._check_connected()

self.connection.get(src, dst)
self._execute_remote_func(self.connection.get, src, dst)

def copy(self, src, dst):
cmd = ["cp", str(src), str(dst)]
self.execute(cmd)

def _execute_remote_func(self, remote_cmd, *args, **kwargs):
if self.retry_on_closed_connection:
try:
return remote_cmd(*args, **kwargs)
except OSError as e:
msg = getattr(e, "message", str(e))
error = e
if "Socket is closed" not in msg:
raise e
except SSHException as e:
error = e
msg = getattr(e, "message", str(e))
if "Server connection dropped" not in msg:
raise e
except EOFError as e:
error = e
else:
return remote_cmd(*args, **kwargs)

# if the code gets here one of the errors that could be due to drop of the
# connection occurred. Try to close and reopen the connection and retry
# one more time
logger.warning(
f"Error while trying to execute a command on host {self.host}:\n"
f"{''.join(traceback.format_exception(error))}"
"Probably due to the connection dropping. "
"Will reopen the connection and retry."
)
try:
self.connection.close()
except Exception:
logger.warning(
"Error while closing the connection during a retry. "
"Proceeding with the retry.",
exc_info=True,
)
self._create_connection()
self.connect()
return remote_cmd(*args, **kwargs)

0 comments on commit e3bae86

Please sign in to comment.