From 791ba66ba3ba9944278705e7153dbbca3504ef0b Mon Sep 17 00:00:00 2001 From: Olof Salberger Date: Wed, 29 May 2024 11:12:41 +0200 Subject: [PATCH] Add ssh tunneling support --- configparser.go | 20 ++++++++--------- pgxjob.go | 13 ++++++++++- ssh_tunnel.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 11 deletions(-) create mode 100644 ssh_tunnel.go diff --git a/configparser.go b/configparser.go index 613a2b4..2be7398 100644 --- a/configparser.go +++ b/configparser.go @@ -14,31 +14,31 @@ type DatabaseConfig struct { ConnString string PasswordVar string JustUsePgPass bool + Sshconfig SSHConnConfig } -func DecodeDatabases(crontab io.Reader, usepgpass bool) (map[string]string, error) { +func DecodeDatabases(crontab io.Reader, usepgpass bool) (map[string]DatabaseConfig, error) { var configs map[string]DatabaseConfig decoder := toml.NewDecoder(crontab) err := decoder.Decode(&configs) if err != nil { return nil, err } - databases := map[string]string{} + databases := map[string]DatabaseConfig{} for key, config := range configs { if config.ConnString == "" { return nil, fmt.Errorf("Missing connstring in database %s", key) } if usepgpass || config.JustUsePgPass { - databases[key] = strings.Replace(config.ConnString, ":$password", "", 1) - } else if config.PasswordVar == "" { - databases[key] = config.ConnString - } else { + config.ConnString = strings.Replace(config.ConnString, ":$password", "", 1) + } else if config.PasswordVar != "" { password := os.Getenv(config.PasswordVar) if password == "" { return nil, fmt.Errorf("Injected passwordvar %s is empty!", config.PasswordVar) } - databases[key] = strings.Replace(config.ConnString, "$password", password, 1) + config.ConnString = strings.Replace(config.ConnString, "$password", password, 1) } + databases[key] = config } return databases, nil } @@ -65,18 +65,18 @@ func DecodeJobs(crontab io.Reader) (jobconfigs map[string]JobConfig, err error) return jobconfigs, nil } -func CreateJobs(configs map[string]JobConfig, databases map[string]string, monitor Monitor) ([]Job, error) { +func CreateJobs(configs map[string]JobConfig, databases map[string]DatabaseConfig, monitor Monitor) ([]Job, error) { jobs := []Job{} for name, config := range configs { schedule, err := cron.ParseStandard(config.CronSchedule) if err != nil { return nil, fmt.Errorf("Cron schedule error: %w", err) } - connstr, ok := databases[config.Database] + dbconfig, ok := databases[config.Database] if !ok { return nil, fmt.Errorf("Missing Db: The database %s specified by job %s does not seem to exist!", config.Database, name) } - job, err := CreateJob(name, config.Database, schedule, connstr, config.Query, config.JobMiscOptions, monitor) + job, err := CreateJob(name, config.Database, schedule, dbconfig.ConnString, config.Query, dbconfig.Sshconfig, config.JobMiscOptions, monitor) if err != nil { return nil, err } diff --git a/pgxjob.go b/pgxjob.go index 941a6bb..5d33c67 100644 --- a/pgxjob.go +++ b/pgxjob.go @@ -1,8 +1,10 @@ package main import ( + "context" "fmt" "log" + "net" "slices" "strings" "time" @@ -26,7 +28,7 @@ type Job struct { valid bool } -func CreateJob(jobname, dbname string, s Schedule, target, query string, misc JobMiscOptions, monitor Monitor) (j Job, err error) { +func CreateJob(jobname, dbname string, s Schedule, target, query string, ssh SSHConnConfig, misc JobMiscOptions, monitor Monitor) (j Job, err error) { if jobname == "" || dbname == "" || s == nil { return j, fmt.Errorf("Received nil input(s) when creating %s", jobname) } @@ -47,6 +49,15 @@ func CreateJob(jobname, dbname string, s Schedule, target, query string, misc Jo if config.ConnectTimeout == time.Duration(0) { // Default to 50 seconds if no finite timeout is provided config.ConnectTimeout = 50 * time.Second // via the standard pgx & psql PGCONNECT_TIMEOUT env var } + if ssh.Host != "" { + client, err := NewSSHClient(&ssh) + if err != nil { + return j, err + } + config.DialFunc = func(ctx context.Context, network string, addr string) (net.Conn, error) { + return client.Dial(network, addr) + } + } return Job{ JobName: jobname, diff --git a/ssh_tunnel.go b/ssh_tunnel.go new file mode 100644 index 0000000..87708d1 --- /dev/null +++ b/ssh_tunnel.go @@ -0,0 +1,58 @@ +package main + +import ( + "net" + "os" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + "golang.org/x/crypto/ssh/knownhosts" +) + +type SSHConnConfig struct { + Host string + Port string + User string + Knownhosts string + Keyfile string +} + +func NewSSHClient(config *SSHConnConfig) (*ssh.Client, error) { + sshConfig := &ssh.ClientConfig{ + User: config.User, + } + + if auth := SSHAgent(); auth != nil { + sshConfig.Auth = append(sshConfig.Auth, auth) + } + + if hostKeyCallback, err := knownhosts.New(config.Knownhosts); err == nil { + sshConfig.HostKeyCallback = hostKeyCallback + } + if config.Keyfile != "" { + if auth := PrivateKey(config.Keyfile); auth != nil { + sshConfig.Auth = append(sshConfig.Auth, auth) + } + } + + return ssh.Dial("tcp", net.JoinHostPort(config.Host, config.Port), sshConfig) +} + +func SSHAgent() ssh.AuthMethod { + if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers) + } + return nil +} + +func PrivateKey(path string) ssh.AuthMethod { + key, err := os.ReadFile(path) + if err != nil { + return nil + } + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + return nil + } + return ssh.PublicKeys(signer) +}