Skip to content

Commit

Permalink
Add ssh tunneling support
Browse files Browse the repository at this point in the history
  • Loading branch information
saolof committed May 29, 2024
1 parent 6e5e6ab commit 791ba66
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 11 deletions.
20 changes: 10 additions & 10 deletions configparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,31 @@ type DatabaseConfig struct {
ConnString string
PasswordVar string
JustUsePgPass bool
Sshconfig SSHConnConfig
}

func DecodeDatabases(crontab io.Reader, usepgpass bool) (map[string]string, error) {
func DecodeDatabases(crontab io.Reader, usepgpass bool) (map[string]DatabaseConfig, error) {
var configs map[string]DatabaseConfig
decoder := toml.NewDecoder(crontab)
err := decoder.Decode(&configs)
if err != nil {
return nil, err
}
databases := map[string]string{}
databases := map[string]DatabaseConfig{}
for key, config := range configs {
if config.ConnString == "" {
return nil, fmt.Errorf("Missing connstring in database %s", key)
}
if usepgpass || config.JustUsePgPass {
databases[key] = strings.Replace(config.ConnString, ":$password", "", 1)
} else if config.PasswordVar == "" {
databases[key] = config.ConnString
} else {
config.ConnString = strings.Replace(config.ConnString, ":$password", "", 1)
} else if config.PasswordVar != "" {
password := os.Getenv(config.PasswordVar)
if password == "" {
return nil, fmt.Errorf("Injected passwordvar %s is empty!", config.PasswordVar)
}
databases[key] = strings.Replace(config.ConnString, "$password", password, 1)
config.ConnString = strings.Replace(config.ConnString, "$password", password, 1)
}
databases[key] = config
}
return databases, nil
}
Expand All @@ -65,18 +65,18 @@ func DecodeJobs(crontab io.Reader) (jobconfigs map[string]JobConfig, err error)
return jobconfigs, nil
}

func CreateJobs(configs map[string]JobConfig, databases map[string]string, monitor Monitor) ([]Job, error) {
func CreateJobs(configs map[string]JobConfig, databases map[string]DatabaseConfig, monitor Monitor) ([]Job, error) {
jobs := []Job{}
for name, config := range configs {
schedule, err := cron.ParseStandard(config.CronSchedule)
if err != nil {
return nil, fmt.Errorf("Cron schedule error: %w", err)
}
connstr, ok := databases[config.Database]
dbconfig, ok := databases[config.Database]
if !ok {
return nil, fmt.Errorf("Missing Db: The database %s specified by job %s does not seem to exist!", config.Database, name)
}
job, err := CreateJob(name, config.Database, schedule, connstr, config.Query, config.JobMiscOptions, monitor)
job, err := CreateJob(name, config.Database, schedule, dbconfig.ConnString, config.Query, dbconfig.Sshconfig, config.JobMiscOptions, monitor)
if err != nil {
return nil, err
}
Expand Down
13 changes: 12 additions & 1 deletion pgxjob.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package main

import (
"context"
"fmt"
"log"
"net"
"slices"
"strings"
"time"
Expand All @@ -26,7 +28,7 @@ type Job struct {
valid bool
}

func CreateJob(jobname, dbname string, s Schedule, target, query string, misc JobMiscOptions, monitor Monitor) (j Job, err error) {
func CreateJob(jobname, dbname string, s Schedule, target, query string, ssh SSHConnConfig, misc JobMiscOptions, monitor Monitor) (j Job, err error) {
if jobname == "" || dbname == "" || s == nil {
return j, fmt.Errorf("Received nil input(s) when creating %s", jobname)
}
Expand All @@ -47,6 +49,15 @@ func CreateJob(jobname, dbname string, s Schedule, target, query string, misc Jo
if config.ConnectTimeout == time.Duration(0) { // Default to 50 seconds if no finite timeout is provided
config.ConnectTimeout = 50 * time.Second // via the standard pgx & psql PGCONNECT_TIMEOUT env var
}
if ssh.Host != "" {
client, err := NewSSHClient(&ssh)
if err != nil {
return j, err
}
config.DialFunc = func(ctx context.Context, network string, addr string) (net.Conn, error) {
return client.Dial(network, addr)
}
}

return Job{
JobName: jobname,
Expand Down
58 changes: 58 additions & 0 deletions ssh_tunnel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package main

import (
"net"
"os"

"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/knownhosts"
)

type SSHConnConfig struct {
Host string
Port string
User string
Knownhosts string
Keyfile string
}

func NewSSHClient(config *SSHConnConfig) (*ssh.Client, error) {
sshConfig := &ssh.ClientConfig{
User: config.User,
}

if auth := SSHAgent(); auth != nil {
sshConfig.Auth = append(sshConfig.Auth, auth)
}

if hostKeyCallback, err := knownhosts.New(config.Knownhosts); err == nil {
sshConfig.HostKeyCallback = hostKeyCallback
}
if config.Keyfile != "" {
if auth := PrivateKey(config.Keyfile); auth != nil {
sshConfig.Auth = append(sshConfig.Auth, auth)
}
}

return ssh.Dial("tcp", net.JoinHostPort(config.Host, config.Port), sshConfig)
}

func SSHAgent() ssh.AuthMethod {
if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers)
}
return nil
}

func PrivateKey(path string) ssh.AuthMethod {
key, err := os.ReadFile(path)
if err != nil {
return nil
}
signer, err := ssh.ParsePrivateKey(key)
if err != nil {
return nil
}
return ssh.PublicKeys(signer)
}

0 comments on commit 791ba66

Please sign in to comment.