Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add retries to SCPMover.copy() #153

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 73 additions & 41 deletions trollmoves/movers.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ def move(self):
def get_connection(self, hostname, port, username=None):
"""Get the connection."""
with self.active_connection_lock:
LOGGER.debug("Destination username and passwd: %s %s",
self._dest_username, self._dest_password)
LOGGER.debug("Destination username: %s", self._dest_username)
LOGGER.debug('Getting connection to %s@%s:%s',
username, hostname, port)
try:
Expand Down Expand Up @@ -297,38 +296,57 @@ class ScpMover(Mover):

def open_connection(self):
"""Open a connection."""
from paramiko import SSHClient, SSHException
ssh_connection = self._run_with_retries(self._open_connection, "ssh connect")
if ssh_connection is None:
raise IOError("Failed to ssh connect after 3 attempts")
return ssh_connection

def _open_connection(self):
from paramiko import SSHException

try:
ssh_connection = self._create_ssh_connection()
except SSHException as sshe:
LOGGER.exception("Failed to init SSHClient: %s", str(sshe))
except socket.timeout as sto:
LOGGER.exception("SSH connection timed out: %s", str(sto))
except Exception as err:
LOGGER.exception("Unknown exception at init SSHClient: %s", str(err))
else:
return ssh_connection

return None

def _create_ssh_connection(self):
from paramiko import SSHClient

retries = 3
ssh_key_filename = self.attrs.get("ssh_key_filename", None)
timeout = self.attrs.get("ssh_connection_timeout", None)
while retries > 0:
retries -= 1
try:
ssh_connection = SSHClient()
ssh_connection.load_system_host_keys()
ssh_connection.connect(self.destination.hostname,
username=self._dest_username,
port=self.destination.port or 22,
key_filename=ssh_key_filename,
timeout=timeout)
LOGGER.debug("Successfully connected to %s:%s as %s",
self.destination.hostname,
self.destination.port or 22,
self._dest_username)
except SSHException as sshe:
LOGGER.exception("Failed to init SSHClient: %s", str(sshe))
except socket.timeout as sto:
LOGGER.exception("SSH connection timed out: %s", str(sto))
except Exception as err:
LOGGER.exception("Unknown exception at init SSHClient: %s", str(err))
else:
return ssh_connection

ssh_connection.close()
time.sleep(2)
LOGGER.debug("Retrying ssh connect ...")
raise IOError("Failed to ssh connect after 3 attempts")
ssh_connection = SSHClient()
ssh_connection.load_system_host_keys()
ssh_connection.connect(self.destination.hostname,
username=self._dest_username,
port=self.destination.port or 22,
key_filename=ssh_key_filename,
timeout=timeout)
LOGGER.debug("Successfully connected to %s:%s as %s",
self.destination.hostname,
self.destination.port or 22,
self._dest_username)
return ssh_connection

def _run_with_retries(self, func, name):
num_retries = self.attrs.get("num_ssh_retries", 3)
res = None
for _ in range(num_retries):
res = func()
if res:
break
time.sleep(2)
LOGGER.debug(f"Retrying {name} ...")

return res

@staticmethod
def is_connected(connection):
Expand Down Expand Up @@ -357,21 +375,16 @@ def move(self):

def copy(self):
"""Upload the file."""
from scp import SCPClient
_ = self._run_with_retries(self._copy, "SCP copy")

ssh_connection = self.get_connection(self.destination.hostname,
self.destination.port or 22,
self._dest_username)

try:
scp = SCPClient(ssh_connection.get_transport())
except Exception as err:
LOGGER.error("Failed to initiate SCPClient: %s", str(err))
ssh_connection.close()
raise
def _copy(self):
from scp import SCPException

success = False
try:
scp = self._get_scp_client()
scp.put(self.origin, self.destination.path)
success = True
except OSError as osex:
if osex.errno == 2:
LOGGER.error("No such file or directory. File not transfered: "
Expand All @@ -380,6 +393,8 @@ def copy(self):
else:
LOGGER.error("OSError in scp.put: %s", str(osex))
raise
except SCPException as err:
LOGGER.error("SCP failed: %s", str(err))
except Exception as err:
LOGGER.error("Something went wrong with scp: %s", str(err))
LOGGER.error("Exception name %s", type(err).__name__)
Expand All @@ -388,6 +403,23 @@ def copy(self):
finally:
scp.close()

return success

def _get_scp_client(self):
from scp import SCPClient

ssh_connection = self.get_connection(self.destination.hostname,
self.destination.port or 22,
self._dest_username)

try:
scp = SCPClient(ssh_connection.get_transport())
except Exception as err:
LOGGER.error("Failed to initiate SCPClient: %s", str(err))
ssh_connection.close()
raise
return scp


class SftpMover(Mover):
"""Move files over sftp."""
Expand Down
8 changes: 8 additions & 0 deletions trollmoves/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,7 @@ def _read_ini_config(filename):
_parse_nameserver(res[section], cp_[section])
_parse_addresses(res[section])
_parse_delete(res[section], cp_[section])
_parse_ssh_retries(res[section], cp_[section])
if not _check_origin_and_listen(res, section):
continue
if not _check_topic(res, section):
Expand All @@ -594,6 +595,7 @@ def _set_config_defaults(conf):
conf.setdefault("transfer_req_timeout", 10 * DEFAULT_REQ_TIMEOUT)
conf.setdefault("ssh_key_filename", None)
conf.setdefault("delete", False)
conf.setdefault("num_ssh_retries", 3)


def _parse_nameserver(conf, raw_conf):
Expand All @@ -617,6 +619,12 @@ def _parse_delete(conf, raw_conf):
conf["delete"] = val


def _parse_ssh_retries(conf, raw_conf):
val = raw_conf.getint("num_ssh_retries")
if val is not None:
conf["num_ssh_retries"] = val


def _check_origin_and_listen(res, section):
if ("origin" not in res[section]) and ('listen' not in res[section]):
LOGGER.warning("Incomplete section %s: add an 'origin' or 'listen' item.", section)
Expand Down
48 changes: 48 additions & 0 deletions trollmoves/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,51 @@ def test_requestmanager_is_delete_set_True(patch_validate_file_pattern):
port = 9876
req_man = RequestManager(port, attrs={'delete': True})
assert req_man._is_delete_set() is True


CONFIG_MINIMAL = """
[test]
origin = foo
listen = bar
"""
CONFIG_NUM_SSH_RETRIES = CONFIG_MINIMAL + """
num_ssh_retries = 5
"""


def test_config_defaults():
"""Test that config defaults are set."""
from trollmoves.server import read_config

with NamedTemporaryFile(mode='w') as tmp_file:
tmp_file.write(CONFIG_MINIMAL)
tmp_file.file.flush()

config = read_config(tmp_file.name)

test_section = config["test"]
assert "origin" in test_section
assert "listen" in test_section
assert test_section["working_directory"] is None
assert test_section["compression"] is False
assert test_section["req_timeout"] == 1
assert test_section["transfer_req_timeout"] == 10
assert test_section["ssh_key_filename"] is None
assert test_section["delete"] is False
assert test_section["num_ssh_retries"] == 3
assert test_section["nameserver"] is None
assert test_section["addresses"] is None


def test_config_num_ssh_retries():
"""Test that config defaults are set."""
from trollmoves.server import read_config

with NamedTemporaryFile(mode='w') as tmp_file:
tmp_file.write(CONFIG_NUM_SSH_RETRIES)
tmp_file.file.flush()

config = read_config(tmp_file.name)

test_section = config["test"]
assert test_section["num_ssh_retries"] == 5