Skip to content

Commit

Permalink
Sshconfig missing bugfix - addresses issue #1105 (#1109)
Browse files Browse the repository at this point in the history
* update log output to indicate intentions moving forward

* if we have failed to read the ssh_config file, ignore it entirely

* add defensive code path that dies if for some reason even this fails

* added this to prevent double warnings in case of missing file

* allow for sshcfg to be nil, which implies no ssh_config found

* update keypath configuration to allow for nil sshcfg

* update HostName configuration to allow for nil sshcfg

* update StrictHostKeyChecking to allow for nil sshcfg

* update UserKnownHostsFile to allow for nil sshcfg

* update HostKeyAlgorithms to allow for nil sshcfg

* remove ProxyCommand from main codepath - while still warning of use

* update ProxyJump config to allow for nil sshcfg

* update user config to allow for nil sshcfg while also simplifying

the code would previously get the current user (current execution's context)
whereas the cfg.User value already takes the ConnectionURI's set username which
should be sufficient as a default value check.

* fix spelling mistake

* why commit once when twice would suffice
  • Loading branch information
memetb authored Oct 19, 2024
1 parent b56a61c commit 0263f35
Showing 1 changed file with 62 additions and 50 deletions.
112 changes: 62 additions & 50 deletions libvirt/uri/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"log"
"net"
"os"
"os/user"
"path/filepath"
"strings"

Expand Down Expand Up @@ -39,19 +38,22 @@ func (u *ConnectionURI) parseAuthMethods(target string, sshcfg *ssh_config.Confi
// 2. load override as specified in ssh config
// 3. load default ssh keyfile path
sshKeyPaths := []string{}

sshKeyPath := q.Get("keyfile")
if sshKeyPath != "" {
sshKeyPaths = append(sshKeyPaths, sshKeyPath)
}

keyPaths, err := sshcfg.GetAll(target, "IdentityFile")
if err != nil {
log.Printf("[WARN] unable to get IdentityFile values - ignoring")
} else {
sshKeyPaths = append(sshKeyPaths, keyPaths...)
if sshcfg != nil {
keyPaths, err := sshcfg.GetAll(target, "IdentityFile")
if err != nil {
log.Printf("[WARN] unable to get IdentityFile values - ignoring")
} else {
sshKeyPaths = append(sshKeyPaths, keyPaths...)
}
}

if len(keyPaths) == 0 {
if len(sshKeyPaths) == 0 {
log.Printf("[DEBUG] found no ssh keys, using default keypath")
sshKeyPaths = []string{defaultSSHKeyPath}
}
Expand Down Expand Up @@ -116,14 +118,17 @@ func (u *ConnectionURI) parseAuthMethods(target string, sshcfg *ssh_config.Confi
// construct the whole ssh connection, which can consist of multiple hops if using proxy jumps,
// the ssh configuration file is loaded once and passed along to each host connection.
func (u *ConnectionURI) dialSSH() (net.Conn, error) {
var sshcfg* ssh_config.Config = nil

sshConfigFile, err := os.Open(os.ExpandEnv(defaultSSHConfigFile))
if err != nil {
log.Printf("[WARN] Failed to open ssh config file: %v", err)
}
} else {
sshcfg, err = ssh_config.Decode(sshConfigFile)
if err != nil {
log.Printf("[WARN] Failed to parse ssh config file: '%v' - sshconfig will be ignored.", err)
}

sshcfg, err := ssh_config.Decode(sshConfigFile)
if err != nil {
log.Printf("[WARN] Failed to parse ssh config file: %v", err)
}

// configuration loaded, build tunnel
Expand Down Expand Up @@ -164,11 +169,11 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
log.Printf("[DEBUG] ssh Port is overridden to: '%s'", port)
}

hostName, err := sshcfg.Get(target, "HostName")
if err == nil {
if hostName == "" {
hostName = target
} else {
hostName := target
if sshcfg != nil {
host, err := sshcfg.Get(target, "HostName")
if err == nil && host != "" {
hostName = host
log.Printf("[DEBUG] HostName is overridden to: '%s'", hostName)
}
}
Expand All @@ -182,18 +187,22 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
if knownHostsVerify == "ignore" {
skipVerify = true
} else {
strictCheck, err := sshcfg.Get(target, "StrictHostKeyChecking")
if err != nil && strictCheck == "yes" {
skipVerify = false
if sshcfg != nil {
strictCheck, err := sshcfg.Get(target, "StrictHostKeyChecking")
if err != nil && strictCheck == "yes" {
skipVerify = false
}
}
}

if knownHostsPath == "" {
knownHosts, err := sshcfg.Get(target, "UserKnownHostsFile")
if err == nil && knownHosts != "" {
knownHostsPath = knownHosts
} else {
knownHostsPath = defaultSSHKnownHostsPath
knownHostsPath = defaultSSHKnownHostsPath

if sshcfg != nil {
knownHosts, err := sshcfg.Get(target, "UserKnownHostsFile")
if err == nil && knownHosts != "" {
knownHostsPath = knownHosts
}
}
}

Expand Down Expand Up @@ -226,10 +235,12 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
return err
}

keyAlgs, err := sshcfg.Get(target, "HostKeyAlgorithms")
if err == nil && keyAlgs != "" {
log.Printf("Got host key algorithms '%s'", keyAlgs)
hostKeyAlgorithms = strings.Split(keyAlgs, ",")
if sshcfg != nil {
keyAlgs, err := sshcfg.Get(target, "HostKeyAlgorithms")
if err == nil && keyAlgs != "" {
log.Printf("[DEBUG] HostKeyAlgorithms is overridden to '%s'", keyAlgs)
hostKeyAlgorithms = strings.Split(keyAlgs, ",")
}
}

}
Expand All @@ -240,46 +251,47 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
HostKeyAlgorithms: hostKeyAlgorithms,
Timeout: dialTimeout,
}
var bastion *ssh.Client = nil
var bastion_proxy string = ""

proxy, err := sshcfg.Get(target, "ProxyCommand")
if err == nil && proxy != "" {
log.Printf("[WARNING] unsupported ssh ProxyCommand '%v'", proxy)
if sshcfg != nil {
command, err := sshcfg.Get(target, "ProxyCommand")
if err == nil && command != "" {
log.Printf("[WARNING] unsupported ssh ProxyCommand '%v' - ignoring", command)
}
}

proxy, err = sshcfg.Get(target, "ProxyJump")
var bastion *ssh.Client
if err == nil && proxy != "" {
log.Printf("[DEBUG] found ProxyJump '%v'", proxy)

// this is a proxy jump: we recurse into that proxy
bastion, err = u.dialHost(proxy, sshcfg, depth+1)
if err != nil {
return nil, fmt.Errorf("failed to connect to bastion host '%v': %w", proxy, err)
if sshcfg != nil {
proxy, err := sshcfg.Get(target, "ProxyJump")
if err == nil && proxy != "" {
log.Printf("[DEBUG] found ProxyJump '%v'", proxy)
// this is a proxy jump: we recurse into that proxy
bastion, err = u.dialHost(proxy, sshcfg, depth+1)
bastion_proxy = proxy
if err != nil {
return nil, fmt.Errorf("failed to connect to bastion host '%v': %w", proxy, err)
}
}
}

if cfg.User == "" {
// cfg.User value defaults to u.User.Username()
if sshcfg != nil {
sshu, err := sshcfg.Get(target, "User")
log.Printf("[DEBUG] SSH User for target '%v' is '%v'", target, sshu)
if err != nil {
log.Printf("[DEBUG] ssh user: using current login")
u, err := user.Current()
if err != nil {
return nil, fmt.Errorf("unable to get username: %w", err)
}
sshu = u.Username
log.Printf("[DEBUG] ssh user for target '%v' is overridden to '%v'", target, sshu)
cfg.User = sshu
}
cfg.User = sshu
}


cfg.Auth = u.parseAuthMethods(target, sshcfg)
if len(cfg.Auth) < 1 {
return nil, fmt.Errorf("could not configure SSH authentication methods")
}

if bastion != nil {
// if this is a proxied connection, we want to dial through the bastion host
log.Printf("[INFO] SSH connecting to '%v' (%v) through bastion host '%v'", target, hostName, proxy)
log.Printf("[INFO] SSH connecting to '%v' (%v) through bastion host '%v'", target, hostName, bastion_proxy)
// Dial a connection to the service host, from the bastion
conn, err := bastion.Dial("tcp", net.JoinHostPort(hostName, port))
if err != nil {
Expand Down

0 comments on commit 0263f35

Please sign in to comment.