From 508b5390e3d9e2d68887c1d8360dfd07fcacee94 Mon Sep 17 00:00:00 2001 From: Maciej Lisiewski Date: Thu, 19 May 2016 10:27:18 -0400 Subject: [PATCH] Basic support for hashbangs --- ssh.go | 70 +++++++++++++++++++++++++++++++++++++++++++++++++++------ sup.go | 6 +++-- task.go | 34 +++++++++++++++++++++++----- 3 files changed, 95 insertions(+), 15 deletions(-) diff --git a/ssh.go b/ssh.go index 106c415..d2cad3f 100644 --- a/ssh.go +++ b/ssh.go @@ -7,6 +7,7 @@ import ( "net" "os" "os/user" + "path/filepath" "strings" "sync" @@ -28,6 +29,7 @@ type SSHClient struct { running bool env string //export FOO="bar"; export BAR="baz"; color string + shell string } type ErrConnect struct { @@ -149,8 +151,8 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error { return nil } -// Run runs the task.Run command remotely on c.host. -func (c *SSHClient) Run(task *Task) error { +// NewSession creates a new session within a SSH connection and connects IN/OUT pipes +func (c *SSHClient) NewSession() error { if c.running { return fmt.Errorf("Session already running") } @@ -181,8 +183,18 @@ func (c *SSHClient) Run(task *Task) error { c.sess = sess c.sessOpened = true + return nil +} + +// Run runs the task.Run command remotely on c.host. +func (c *SSHClient) Run(task *Task) error { + // Start a new session + if err := c.NewSession(); err != nil { + return err + } + // Start the remote command. - if err := c.sess.Start(c.env + "set -x;" + task.Run); err != nil { + if err := c.sess.Start(c.buildComand(task)); err != nil { return ErrTask{task, err.Error()} } @@ -206,16 +218,16 @@ func (c *SSHClient) Wait() error { } // DialThrough will create a new connection from the ssh server sc is connected to. DialThrough is an SSHDialer. -func (sc *SSHClient) DialThrough(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { - conn, err := sc.conn.Dial(net, addr) +func (c *SSHClient) DialThrough(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { + conn, err := c.conn.Dial(net, addr) if err != nil { return nil, err } - c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + cl, chans, reqs, err := ssh.NewClientConn(conn, addr, config) if err != nil { return nil, err } - return ssh.NewClient(c, chans, reqs), nil + return ssh.NewClient(cl, chans, reqs), nil } @@ -260,3 +272,47 @@ func (c *SSHClient) Write(p []byte) (n int, err error) { func (c *SSHClient) WriteClose() error { return c.remoteStdin.Close() } + +func (c *SSHClient) getShell() error { + // Start a new session + if err := c.NewSession(); err != nil { + return err + } + + if err := c.sess.Start(`echo "$SHELL"`); err != nil { + return err + } + c.running = true + + resp, err := ioutil.ReadAll(c.Stdout()) + if err != nil && err != io.EOF { + return err + } + if err := c.Wait(); err != nil { + return err + } + + c.shell = string(resp) + + return nil +} + +func (c *SSHClient) buildComand(task *Task) string { + inter := strings.TrimSpace(task.Interpreter) + if inter == "" { + inter = strings.TrimSpace(c.shell) + } + _, i := filepath.Split(inter) + + switch inter { + case "bash", "sh", "zsh", "ksh", "tcsh": + return fmt.Sprintf("/usr/bin/env %s %s -c \"set -x\n%s\"", c.env, i, strings.Replace(task.Run, `"`, `\"`, -1)) + + // TODO: add support for python already called via env + case "python", "python3": + return fmt.Sprintf("/usr/bin/env %s %s -c \"%s\"", c.env, i, strings.Replace(task.Run, `"`, `\"`, -1)) + + default: + return fmt.Sprintf("/usr/bin/env %s bash -c \"set -x\n%s\"", c.env, strings.Replace(task.Run, `"`, `\"`, -1)) + } +} diff --git a/sup.go b/sup.go index 678a278..6ee7ed0 100644 --- a/sup.go +++ b/sup.go @@ -36,7 +36,7 @@ func (sup *Stackup) Run(network *Network, commands ...*Command) error { // `export FOO="bar"; export BAR="baz";`. env := `` for _, v := range append(sup.conf.Env, network.Env...) { - env += v.AsExport() + " " + env += v.String() + " " } // Create clients for every host (either SSH or Localhost). @@ -67,7 +67,7 @@ func (sup *Stackup) Run(network *Network, commands ...*Command) error { // SSH client. remote := &SSHClient{ - env: env + `export SUP_HOST="` + host + `";`, + env: env + `SUP_HOST="` + host + `"`, color: Colors[i%len(Colors)], } @@ -81,6 +81,8 @@ func (sup *Stackup) Run(network *Network, commands ...*Command) error { } } defer remote.Close() + remote.getShell() + clients = append(clients, remote) } diff --git a/task.go b/task.go index ec055c0..8b434ff 100644 --- a/task.go +++ b/task.go @@ -1,17 +1,18 @@ package sup import ( + "bufio" "fmt" "io" - "io/ioutil" "os" ) // Task represents a set of commands to be run. type Task struct { - Run string - Input io.Reader - Clients []Client + Interpreter string + Run string + Input io.Reader + Clients []Client } func CreateTasks(cmd *Command, clients []Client, env string) ([]*Task, error) { @@ -50,13 +51,34 @@ func CreateTasks(cmd *Command, clients []Client, env string) ([]*Task, error) { if err != nil { return nil, err } - data, err := ioutil.ReadAll(f) + rd := bufio.NewReader(f) + fLine, _, err := rd.ReadLine() if err != nil { return nil, err } + var hashbang, data []byte + if len(fLine) > 2 && fLine[0] == '#' && fLine[1] == '!' { + hashbang = fLine[2:len(fLine)] + } else { + rd.Reset(f) + } + for { + line, err := rd.ReadSlice('\n') + if len(line) > 0 { + data = append(data, line...) + } + + if err != nil { + if err != io.EOF { + return nil, err + } + break + } + } task := Task{ - Run: string(data), + Interpreter: string(hashbang), + Run: string(data), } if cmd.Stdin { task.Input = os.Stdin