Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor of mila init command #146

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def init():
if running_inside_WSL():
setup_windows_ssh_config_from_wsl(linux_ssh_config=ssh_config)

setup_keys_on_login_node()
setup_keys_on_login_node(cluster="mila")
setup_vscode_settings()
print_welcome_message()

Expand Down
152 changes: 73 additions & 79 deletions milatools/cli/init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import difflib
import functools
import json
import shlex
import shutil
import subprocess
import sys
Expand All @@ -14,10 +15,11 @@

import questionary as qn
from invoke.exceptions import UnexpectedExit
from paramiko.config import SSHConfig as SSHConfigReader

from milatools.utils.remote_v2 import SSH_CONFIG_FILE
from milatools.utils.local_v2 import LocalV2

from ..utils.local_v1 import LocalV1, check_passwordless, display
from ..utils.local_v1 import check_passwordless, display
from ..utils.remote_v1 import RemoteV1
from ..utils.vscode_utils import (
get_expected_vscode_settings_json_path,
Expand Down Expand Up @@ -239,101 +241,90 @@ def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig):
_copy_if_needed(linux_key_file, windows_key_file)


def setup_passwordless_ssh_access(ssh_config: SSHConfig) -> bool:
"""Sets up passwordless ssh access to the Mila and optionally also to DRAC.
def get_identityfile_from_ssh_config(
ssh_config: SSHConfig, hostname: str
) -> Path | None:
ssh_config_reader = SSHConfigReader.from_path(ssh_config.path)
private_key_path = ssh_config_reader.lookup(hostname).get("identityfile")
if private_key_path is None:
return None
# Seems to be a list for some reason?
if isinstance(private_key_path, list):
assert private_key_path
private_key_path = private_key_path[0]
return Path(private_key_path)


def setup_passwordless_ssh_access(
ssh_config: SSHConfig,
clusters: list[str] | tuple[str, ...] = ("mila", *DRAC_CLUSTERS),
) -> bool:
"""Sets up passwordless ssh access to the given clusters.

Sets up ssh connection to the DRAC clusters if they are present in the SSH config
file.

Returns whether the operation completed successfully or not.
"""
print("Checking passwordless authentication")
clusters = list(clusters)
if not clusters:
print("No clusters to setup.")
return True

here = LocalV1()
sshdir = Path.home() / ".ssh"

# Check if there is a public key file in ~/.ssh
if not list(sshdir.glob("id*.pub")):
if yn("You have no public keys. Generate one?"):
# Run ssh-keygen with the given location and no passphrase.
ssh_private_key_path = Path.home() / ".ssh" / "id_rsa"
create_ssh_keypair(ssh_private_key_path, here)
else:
print("No public keys.")
return False

# TODO: This uses the public key set in the SSH config file, which may (or may not)
# be the random id*.pub file that was just checked for above.
success = setup_passwordless_ssh_access_to_cluster("mila")

if not success:
return False
setup_keys_on_login_node("mila")
printed_drac_warning = False

drac_clusters_in_ssh_config: list[str] = []
hosts_in_config = ssh_config.hosts()
for cluster in DRAC_CLUSTERS:
if any(cluster in hostname for hostname in hosts_in_config):
drac_clusters_in_ssh_config.append(cluster)
for cluster in clusters:
private_key_path = get_identityfile_from_ssh_config(ssh_config, cluster)
if private_key_path is None:
# todo: if the cluster doesn't have an `IdentityFile` set in the config,
# should we set a `IdentityFile` based on the cluster name? Or use the
# default key?
# For now, we just create the default ~/.ssh/id_rsa key if needed.
private_key_path = Path.home() / ".ssh" / "id_rsa"

if not drac_clusters_in_ssh_config:
logger.debug(
f"There are no DRAC clusters in the SSH config at {ssh_config.path}."
)
return True
if not private_key_path.exists():
# Run ssh-keygen with the given location and no passphrase.
print(
f"You don't have an SSH key for the {cluster!r} cluster. "
f"Generating one at {private_key_path}."
)
create_ssh_keypair(private_key_path)

if cluster in DRAC_CLUSTERS and not printed_drac_warning:
print(
"Setting up passwordless ssh access to the DRAC clusters with ssh-copy-id.\n"
"\n"
"Please note that you can also setup passwordless SSH access to all the DRAC "
"clusters by visiting https://ccdb.alliancecan.ca/ssh_authorized_keys and "
"copying in the content of your public key in the box.\n"
"See https://docs.alliancecan.ca/wiki/SSH_Keys#Using_CCDB for more info."
)
printed_drac_warning = True
success = run_ssh_copy_id(cluster, private_key_path)

print(
"Setting up passwordless ssh access to the DRAC clusters with ssh-copy-id.\n"
"\n"
"Please note that you can also setup passwordless SSH access to all the DRAC "
"clusters by visiting https://ccdb.alliancecan.ca/ssh_authorized_keys and "
"copying in the content of your public key in the box.\n"
"See https://docs.alliancecan.ca/wiki/SSH_Keys#Using_CCDB for more info."
)
for drac_cluster in drac_clusters_in_ssh_config:
success = setup_passwordless_ssh_access_to_cluster(drac_cluster)
setup_keys_on_login_node(cluster)
if not success:
return False
setup_keys_on_login_node(drac_cluster)

return True


def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
def run_ssh_copy_id(cluster: str, identity_file: Path) -> bool:
"""Sets up passwordless SSH access to the given hostname.

On Mac/Linux, uses `ssh-copy-id`. Performs the steps of ssh-copy-id manually on
Windows.

Returns whether the operation completed successfully or not.
"""
here = LocalV1()
# Check that it is possible to connect without using a password.
print(f"Checking if passwordless SSH access is setup for the {cluster} cluster.")
# TODO: Potentially use a custom key like `~/.ssh/id_milatools.pub` instead of
# the default.

from paramiko.config import SSHConfig

config = SSHConfig.from_path(str(SSH_CONFIG_FILE))
identity_file = config.lookup(cluster).get("identityfile", "~/.ssh/id_rsa")
# Seems to be a list for some reason?
if isinstance(identity_file, list):
assert identity_file
identity_file = identity_file[0]
ssh_private_key_path = Path(identity_file).expanduser()
ssh_public_key_path = ssh_private_key_path.with_suffix(".pub")
assert ssh_public_key_path.exists()

# TODO: This will fail on Windows for clusters with 2FA.
# if check_passwordless(cluster):
# logger.info(f"Passwordless SSH access to {cluster} is already setup correctly.")
# return True
# if not yn(
# f"Your public key does not appear be registered on the {cluster} cluster. "
# "Register it?"
# ):
# print("No passwordless login.")
# return False
print("Please enter your password if prompted.")
if sys.platform == "win32":
# NOTE: This is to remove extra '^M' characters that would be added at the end
Expand All @@ -356,14 +347,15 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
f.seek(0)
subprocess.run(command, check=True, text=False, stdin=f)
else:
here.run(
"ssh-copy-id",
"-i",
str(ssh_private_key_path),
"-o",
"StrictHostKeyChecking=no",
cluster,
check=True,
LocalV2.run(
(
"ssh-copy-id",
"-i",
str(ssh_private_key_path),
"-o",
"StrictHostKeyChecking=no",
cluster,
),
)

# double-check that this worked.
Expand All @@ -373,6 +365,10 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
return True


def run_in_bash(cmd: str) -> str:
return shlex.join(["bash", "-c", cmd])


def setup_keys_on_login_node(cluster: str = "mila"):
#####################################
# Step 3: Set up keys on login node #
Expand All @@ -396,8 +392,8 @@ def setup_keys_on_login_node(cluster: str = "mila"):
else:
exit("Cannot proceed because there is no public key")

common = remote.with_bash().get_output(
"comm -12 <(sort ~/.ssh/authorized_keys) <(sort ~/.ssh/*.pub)"
common = remote.get_output(
run_in_bash("comm -12 <(sort ~/.ssh/authorized_keys) <(sort ~/.ssh/*.pub)")
)
if common:
print("# OK")
Expand Down Expand Up @@ -465,7 +461,6 @@ def get_windows_home_path_in_wsl() -> Path:

def create_ssh_keypair(
ssh_private_key_path: Path,
local: LocalV1 | None = None,
passphrase: str | None = "",
) -> None:
"""Creates a public/private key pair at the given path using ssh-keygen.
Expand All @@ -474,7 +469,6 @@ def create_ssh_keypair(
Otherwise, if passphrase is an empty string, no passphrase will be used (default).
If a string is passed, it is passed to ssh-keygen and used as the passphrase.
"""
local = local or LocalV1()
command = [
"ssh-keygen",
"-f",
Expand Down
20 changes: 11 additions & 9 deletions tests/cli/test_init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
create_ssh_keypair,
get_windows_home_path_in_wsl,
has_passphrase,
run_ssh_copy_id,
setup_keys_on_login_node,
setup_passwordless_ssh_access,
setup_passwordless_ssh_access_to_cluster,
setup_ssh_config,
setup_vscode_settings,
setup_windows_ssh_config_from_wsl,
Expand All @@ -42,7 +42,7 @@
SSHConfig,
running_inside_WSL,
)
from milatools.utils.local_v1 import LocalV1, check_passwordless
from milatools.utils.local_v1 import check_passwordless
from milatools.utils.remote_v1 import RemoteV1
from milatools.utils.remote_v2 import (
SSH_CACHE_DIR,
Expand Down Expand Up @@ -1473,7 +1473,7 @@ def _mock_subprocess_run(command: tuple[str], *args, **kwargs):
return subprocess_run(command, *args, **kwargs)

mock_subprocess_run = mocker.patch("subprocess.run", wraps=_mock_subprocess_run)
success = setup_passwordless_ssh_access_to_cluster(cluster)
success = run_ssh_copy_id(cluster)
if passwordless_ssh_was_previously_setup:
# We already had access to the cluster.
assert success is True
Expand Down Expand Up @@ -1501,7 +1501,7 @@ def _mock_subprocess_run(command: tuple[str], *args, **kwargs):
]
regression_text = "\n".join(
[
f"Calling {function_call_string(setup_passwordless_ssh_access_to_cluster, cluster)}",
f"Calling {function_call_string(run_ssh_copy_id, cluster)}",
]
+ [
f"with passwordless SSH access to {cluster} already setup"
Expand Down Expand Up @@ -1582,9 +1582,7 @@ def test_setup_passwordless_ssh_access(
else:
# There should be an ssh key in the .ssh dir.
# Won't ask to generate a key.
create_ssh_keypair(
ssh_private_key_path=ssh_dir / "id_rsa_milatools", local=LocalV1()
)
create_ssh_keypair(ssh_private_key_path=ssh_dir / "id_rsa_milatools")
if drac_clusters_in_ssh_config:
# We should get a prompt asking if we want to register the public key
# on the DRAC clusters or not.
Expand All @@ -1609,14 +1607,14 @@ def test_setup_passwordless_ssh_access(
# It's okay because we have a good test for it above. Therefore we just test how it
# gets called here.
mock_setup_passwordless_ssh_access_to_cluster = Mock(
spec=setup_passwordless_ssh_access_to_cluster,
spec=run_ssh_copy_id,
side_effect=[accept_mila, *(accept_drac for _ in drac_clusters_in_ssh_config)],
)
import milatools.cli.init_command

monkeypatch.setattr(
milatools.cli.init_command,
setup_passwordless_ssh_access_to_cluster.__name__,
run_ssh_copy_id.__name__,
mock_setup_passwordless_ssh_access_to_cluster,
)

Expand Down Expand Up @@ -1666,3 +1664,7 @@ def test_setup_passwordless_ssh_access(
for drac_cluster in drac_clusters_in_ssh_config:
mock_setup_passwordless_ssh_access_to_cluster.assert_any_call(drac_cluster)
assert result is True


def test_inaccessible_cluster_is_skipped_in_mila_init():
...
Loading