Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retry ssh connection on failure #1060

Merged
merged 2 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions pyinfra/connectors/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from distutils.spawn import find_executable
from getpass import getpass
from os import path
from random import uniform
from socket import error as socket_error, gaierror
from time import sleep
from typing import TYPE_CHECKING, Type, Union

import click
Expand Down Expand Up @@ -71,6 +73,10 @@ class DataKeys:
known_hosts_file = "Custom SSH known hosts file"
strict_host_key_checking = "Override strict host keys check setting"

connect_retries = "Number of tries to connect via ssh"
connect_retry_min_delay = "Lower bound for random delay between retries"
connect_retry_max_delay = "Upper bound for random delay between retries"

paramiko_connect_kwargs = (
"Override keyword arguments passed into paramiko's `SSHClient.connect`"
)
Expand Down Expand Up @@ -226,7 +232,15 @@ def connect(state: "State", host: "Host"):
Connect to a single host. Returns the SSH client if successful. Stateless by
design so can be run in parallel.
"""
retries = host.data.get(DATA_KEYS.connect_retries, 0)

for tries_left in range(retries, -1, -1):
con = _connect(state, host, tries_left)
if con:
return con


def _connect(state: "State", host: "Host", tries_left: int):
kwargs = _make_paramiko_kwargs(state, host)
logger.debug("Connecting to: %s (%r)", host.name, kwargs)

Expand Down Expand Up @@ -273,16 +287,26 @@ def connect(state: "State", host: "Host"):
)

except SSHException as e:
_raise_connect_error(host, "SSH error", e)
if tries_left == 0:
_raise_connect_error(host, "SSH error", e)

except gaierror:
_raise_connect_error(host, "Could not resolve hostname", hostname)
if tries_left == 0:
_raise_connect_error(host, "Could not resolve hostname", hostname)

except socket_error as e:
_raise_connect_error(host, "Could not connect", e)
if tries_left == 0:
_raise_connect_error(host, "Could not connect", e)

except EOFError as e:
_raise_connect_error(host, "EOF error", e)
if tries_left == 0:
_raise_connect_error(host, "EOF error", e)

min_delay = host.data.get(DATA_KEYS.connect_retry_min_delay, 0.1)
max_delay = host.data.get(DATA_KEYS.connect_retry_max_delay, 0.5)
sleep(uniform(min_delay, max_delay))

return None


def run_shell_command(
Expand Down
52 changes: 52 additions & 0 deletions tests/test_connectors/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,3 +1042,55 @@ def test_get_sftp_fail(self, fake_sftp_client, fake_ssh_client):
"not-another-file",
print_output=True,
)

@patch("pyinfra.connectors.ssh.SSHClient")
@patch("pyinfra.connectors.ssh.sleep")
def test_ssh_connect_fail_retry(self, fake_sleep, fake_ssh_client):
    """
    A host that fails on every attempt is retried and then raises.

    With ``ssh_connect_retries`` set to 1 the connector makes two attempts:
    the first failure is swallowed and followed by a delay, the second
    failure is surfaced as ``ConnectError``.

    ``sleep`` is patched where the connector looks it up (it is imported
    into ``pyinfra.connectors.ssh`` via ``from time import sleep``), so the
    test neither really sleeps nor misses the call.
    """
    for exception_class in (
        SSHException,
        gaierror,
        socket_error,
        EOFError,
    ):
        # The sleep mock is shared by every iteration; clear its history so
        # the per-exception call count below is meaningful.
        fake_sleep.reset_mock()

        inventory = make_inventory(
            hosts=("unresposivehost",), override_data={"ssh_connect_retries": 1}
        )
        State(inventory, Config())

        host = inventory.get_host("unresposivehost")
        assert host.data.ssh_connect_retries == 1

        # Every connect attempt on the client instance raises.
        fake_ssh = MagicMock()
        fake_ssh.connect.side_effect = exception_class()
        fake_ssh_client.return_value = fake_ssh

        with self.assertRaises(ConnectError):
            host.connect(show_errors=False, raise_exceptions=True)

        # One delay between the two attempts, none after the final failure.
        fake_sleep.assert_called_once()
        # Check the instance returned by SSHClient(), not the class mock.
        assert fake_ssh.connect.call_count == 2

@patch("pyinfra.connectors.ssh.SSHClient")
@patch("pyinfra.connectors.ssh.sleep")
def test_ssh_connect_fail_success(self, fake_sleep, fake_ssh_client):
    """
    A host that fails once and then succeeds connects without raising.

    With ``ssh_connect_retries`` set to 1 the first attempt raises one of
    the retryable exceptions, the connector waits, and the second attempt
    succeeds.

    ``sleep`` is patched where the connector looks it up (it is imported
    into ``pyinfra.connectors.ssh`` via ``from time import sleep``), so the
    test neither really sleeps nor misses the call.
    """
    for exception_class in (
        SSHException,
        gaierror,
        socket_error,
        EOFError,
    ):
        # The sleep mock is shared by every iteration; clear its history so
        # the per-exception call count below is meaningful.
        fake_sleep.reset_mock()

        inventory = make_inventory(
            hosts=("unresposivehost",), override_data={"ssh_connect_retries": 1}
        )
        State(inventory, Config())

        host = inventory.get_host("unresposivehost")
        assert host.data.ssh_connect_retries == 1

        # First connect attempt raises, second one returns normally.
        connection = MagicMock()
        fake_ssh = MagicMock()
        fake_ssh.connect.side_effect = [exception_class(), connection]
        fake_ssh_client.return_value = fake_ssh

        host.connect(show_errors=False, raise_exceptions=True)

        # Exactly one delay: after the failed attempt, before the retry.
        fake_sleep.assert_called_once()
        # Check the instance returned by SSHClient(), not the class mock.
        assert fake_ssh.connect.call_count == 2
Loading