diff --git a/libvirt/uri/ssh.go b/libvirt/uri/ssh.go index 6c617ceff..54135589a 100644 --- a/libvirt/uri/ssh.go +++ b/libvirt/uri/ssh.go @@ -5,7 +5,6 @@ import ( "log" "net" "os" - "os/user" "path/filepath" "strings" @@ -39,19 +38,22 @@ func (u *ConnectionURI) parseAuthMethods(target string, sshcfg *ssh_config.Confi // 2. load override as specified in ssh config // 3. load default ssh keyfile path sshKeyPaths := []string{} + sshKeyPath := q.Get("keyfile") if sshKeyPath != "" { sshKeyPaths = append(sshKeyPaths, sshKeyPath) } - keyPaths, err := sshcfg.GetAll(target, "IdentityFile") - if err != nil { - log.Printf("[WARN] unable to get IdentityFile values - ignoring") - } else { - sshKeyPaths = append(sshKeyPaths, keyPaths...) + if sshcfg != nil { + keyPaths, err := sshcfg.GetAll(target, "IdentityFile") + if err != nil { + log.Printf("[WARN] unable to get IdentityFile values - ignoring") + } else { + sshKeyPaths = append(sshKeyPaths, keyPaths...) + } } - if len(keyPaths) == 0 { + if len(sshKeyPaths) == 0 { log.Printf("[DEBUG] found no ssh keys, using default keypath") sshKeyPaths = []string{defaultSSHKeyPath} } @@ -116,14 +118,17 @@ func (u *ConnectionURI) parseAuthMethods(target string, sshcfg *ssh_config.Confi // construct the whole ssh connection, which can consist of multiple hops if using proxy jumps, // the ssh configuration file is loaded once and passed along to each host connection. func (u *ConnectionURI) dialSSH() (net.Conn, error) { + var sshcfg* ssh_config.Config = nil + sshConfigFile, err := os.Open(os.ExpandEnv(defaultSSHConfigFile)) if err != nil { log.Printf("[WARN] Failed to open ssh config file: %v", err) - } + } else { + sshcfg, err = ssh_config.Decode(sshConfigFile) + if err != nil { + log.Printf("[WARN] Failed to parse ssh config file: '%v' - sshconfig will be ignored.", err) + } - sshcfg, err := ssh_config.Decode(sshConfigFile) - if err != nil { - log.Printf("[WARN] Failed to parse ssh config file: %v", err) } // configuration loaded, build tunnel @@ -164,11 +169,11 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth log.Printf("[DEBUG] ssh Port is overridden to: '%s'", port) } - hostName, err := sshcfg.Get(target, "HostName") - if err == nil { - if hostName == "" { - hostName = target - } else { + hostName := target + if sshcfg != nil { + host, err := sshcfg.Get(target, "HostName") + if err == nil && host != "" { + hostName = host log.Printf("[DEBUG] HostName is overridden to: '%s'", hostName) } } @@ -182,18 +187,22 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth if knownHostsVerify == "ignore" { skipVerify = true } else { - strictCheck, err := sshcfg.Get(target, "StrictHostKeyChecking") - if err != nil && strictCheck == "yes" { - skipVerify = false + if sshcfg != nil { + strictCheck, err := sshcfg.Get(target, "StrictHostKeyChecking") + if err != nil && strictCheck == "yes" { + skipVerify = false + } } } if knownHostsPath == "" { - knownHosts, err := sshcfg.Get(target, "UserKnownHostsFile") - if err == nil && knownHosts != "" { - knownHostsPath = knownHosts - } else { - knownHostsPath = defaultSSHKnownHostsPath + knownHostsPath = defaultSSHKnownHostsPath + + if sshcfg != nil { + knownHosts, err := sshcfg.Get(target, "UserKnownHostsFile") + if err == nil && knownHosts != "" { + knownHostsPath = knownHosts + } } } @@ -226,10 +235,12 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth return err } - keyAlgs, err := sshcfg.Get(target, "HostKeyAlgorithms") - if err == nil && keyAlgs != "" { - log.Printf("Got host key algorithms '%s'", keyAlgs) - hostKeyAlgorithms = strings.Split(keyAlgs, ",") + if sshcfg != nil { + keyAlgs, err := sshcfg.Get(target, "HostKeyAlgorithms") + if err == nil && keyAlgs != "" { + log.Printf("[DEBUG] HostKeyAlgorithms is overridden to '%s'", keyAlgs) + hostKeyAlgorithms = strings.Split(keyAlgs, ",") + } } } @@ -240,38 +251,39 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth HostKeyAlgorithms: hostKeyAlgorithms, Timeout: dialTimeout, } + var bastion *ssh.Client = nil + var bastion_proxy string = "" - proxy, err := sshcfg.Get(target, "ProxyCommand") - if err == nil && proxy != "" { - log.Printf("[WARNING] unsupported ssh ProxyCommand '%v'", proxy) + if sshcfg != nil { + command, err := sshcfg.Get(target, "ProxyCommand") + if err == nil && command != "" { + log.Printf("[WARNING] unsupported ssh ProxyCommand '%v' - ignoring", command) + } } - proxy, err = sshcfg.Get(target, "ProxyJump") - var bastion *ssh.Client - if err == nil && proxy != "" { - log.Printf("[DEBUG] found ProxyJump '%v'", proxy) - - // this is a proxy jump: we recurse into that proxy - bastion, err = u.dialHost(proxy, sshcfg, depth+1) - if err != nil { - return nil, fmt.Errorf("failed to connect to bastion host '%v': %w", proxy, err) + if sshcfg != nil { + proxy, err := sshcfg.Get(target, "ProxyJump") + if err == nil && proxy != "" { + log.Printf("[DEBUG] found ProxyJump '%v'", proxy) + // this is a proxy jump: we recurse into that proxy + bastion, err = u.dialHost(proxy, sshcfg, depth+1) + bastion_proxy = proxy + if err != nil { + return nil, fmt.Errorf("failed to connect to bastion host '%v': %w", proxy, err) + } } } - if cfg.User == "" { + // cfg.User value defaults to u.User.Username() + if sshcfg != nil { sshu, err := sshcfg.Get(target, "User") - log.Printf("[DEBUG] SSH User for target '%v' is '%v'", target, sshu) if err != nil { - log.Printf("[DEBUG] ssh user: using current login") - u, err := user.Current() - if err != nil { - return nil, fmt.Errorf("unable to get username: %w", err) - } - sshu = u.Username + log.Printf("[DEBUG] ssh user for target '%v' is overridden to '%v'", target, sshu) + cfg.User = sshu } - cfg.User = sshu } + cfg.Auth = u.parseAuthMethods(target, sshcfg) if len(cfg.Auth) < 1 { return nil, fmt.Errorf("could not configure SSH authentication methods") @@ -279,7 +291,7 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth if bastion != nil { // if this is a proxied connection, we want to dial through the bastion host - log.Printf("[INFO] SSH connecting to '%v' (%v) through bastion host '%v'", target, hostName, proxy) + log.Printf("[INFO] SSH connecting to '%v' (%v) through bastion host '%v'", target, hostName, bastion_proxy) // Dial a connection to the service host, from the bastion conn, err := bastion.Dial("tcp", net.JoinHostPort(hostName, port)) if err != nil {