Skip to content

Commit

Permalink
Basic support for hashbangs
Browse files Browse the repository at this point in the history
  • Loading branch information
c2h5oh committed May 19, 2016
1 parent bac1263 commit 508b539
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 15 deletions.
70 changes: 63 additions & 7 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net"
"os"
"os/user"
"path/filepath"
"strings"
"sync"

Expand All @@ -28,6 +29,7 @@ type SSHClient struct {
running bool
env string //export FOO="bar"; export BAR="baz";
color string
shell string
}

type ErrConnect struct {
Expand Down Expand Up @@ -149,8 +151,8 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error {
return nil
}

// Run runs the task.Run command remotely on c.host.
func (c *SSHClient) Run(task *Task) error {
// NewSession creates a new session within a SSH connection and connects IN/OUT pipes
func (c *SSHClient) NewSession() error {
if c.running {
return fmt.Errorf("Session already running")
}
Expand Down Expand Up @@ -181,8 +183,18 @@ func (c *SSHClient) Run(task *Task) error {
c.sess = sess
c.sessOpened = true

return nil
}

// Run runs the task.Run command remotely on c.host.
func (c *SSHClient) Run(task *Task) error {
// Start a new session
if err := c.NewSession(); err != nil {
return err
}

// Start the remote command.
if err := c.sess.Start(c.env + "set -x;" + task.Run); err != nil {
if err := c.sess.Start(c.buildComand(task)); err != nil {
return ErrTask{task, err.Error()}
}

Expand All @@ -206,16 +218,16 @@ func (c *SSHClient) Wait() error {
}

// DialThrough will create a new connection from the ssh server sc is connected to. DialThrough is an SSHDialer.
func (sc *SSHClient) DialThrough(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
conn, err := sc.conn.Dial(net, addr)
func (c *SSHClient) DialThrough(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
conn, err := c.conn.Dial(net, addr)
if err != nil {
return nil, err
}
c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
cl, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
if err != nil {
return nil, err
}
return ssh.NewClient(c, chans, reqs), nil
return ssh.NewClient(cl, chans, reqs), nil

}

Expand Down Expand Up @@ -260,3 +272,47 @@ func (c *SSHClient) Write(p []byte) (n int, err error) {
func (c *SSHClient) WriteClose() error {
return c.remoteStdin.Close()
}

func (c *SSHClient) getShell() error {
// Start a new session
if err := c.NewSession(); err != nil {
return err
}

if err := c.sess.Start(`echo "$SHELL"`); err != nil {
return err
}
c.running = true

resp, err := ioutil.ReadAll(c.Stdout())
if err != nil && err != io.EOF {
return err
}
if err := c.Wait(); err != nil {
return err
}

c.shell = string(resp)

return nil
}

func (c *SSHClient) buildComand(task *Task) string {
inter := strings.TrimSpace(task.Interpreter)
if inter == "" {
inter = strings.TrimSpace(c.shell)
}
_, i := filepath.Split(inter)

switch inter {
case "bash", "sh", "zsh", "ksh", "tcsh":
return fmt.Sprintf("/usr/bin/env %s %s -c \"set -x\n%s\"", c.env, i, strings.Replace(task.Run, `"`, `\"`, -1))

// TODO: add support for python already called via env
case "python", "python3":
return fmt.Sprintf("/usr/bin/env %s %s -c \"%s\"", c.env, i, strings.Replace(task.Run, `"`, `\"`, -1))

default:
return fmt.Sprintf("/usr/bin/env %s bash -c \"set -x\n%s\"", c.env, strings.Replace(task.Run, `"`, `\"`, -1))
}
}
6 changes: 4 additions & 2 deletions sup.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (sup *Stackup) Run(network *Network, commands ...*Command) error {
// `export FOO="bar"; export BAR="baz";`.
env := ``
for _, v := range append(sup.conf.Env, network.Env...) {
env += v.AsExport() + " "
env += v.String() + " "
}

// Create clients for every host (either SSH or Localhost).
Expand Down Expand Up @@ -67,7 +67,7 @@ func (sup *Stackup) Run(network *Network, commands ...*Command) error {

// SSH client.
remote := &SSHClient{
env: env + `export SUP_HOST="` + host + `";`,
env: env + `SUP_HOST="` + host + `"`,
color: Colors[i%len(Colors)],
}

Expand All @@ -81,6 +81,8 @@ func (sup *Stackup) Run(network *Network, commands ...*Command) error {
}
}
defer remote.Close()
remote.getShell()

clients = append(clients, remote)
}

Expand Down
34 changes: 28 additions & 6 deletions task.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
package sup

import (
"bufio"
"fmt"
"io"
"io/ioutil"
"os"
)

// Task represents a set of commands to be run.
type Task struct {
Run string
Input io.Reader
Clients []Client
Interpreter string
Run string
Input io.Reader
Clients []Client
}

func CreateTasks(cmd *Command, clients []Client, env string) ([]*Task, error) {
Expand Down Expand Up @@ -50,13 +51,34 @@ func CreateTasks(cmd *Command, clients []Client, env string) ([]*Task, error) {
if err != nil {
return nil, err
}
data, err := ioutil.ReadAll(f)
rd := bufio.NewReader(f)
fLine, _, err := rd.ReadLine()
if err != nil {
return nil, err
}
var hashbang, data []byte
if len(fLine) > 2 && fLine[0] == '#' && fLine[1] == '!' {
hashbang = fLine[2:len(fLine)]
} else {
rd.Reset(f)
}
for {
line, err := rd.ReadSlice('\n')
if len(line) > 0 {
data = append(data, line...)
}

if err != nil {
if err != io.EOF {
return nil, err
}
break
}
}

task := Task{
Run: string(data),
Interpreter: string(hashbang),
Run: string(data),
}
if cmd.Stdin {
task.Input = os.Stdin
Expand Down

0 comments on commit 508b539

Please sign in to comment.