Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Basic support for hashbangs #77

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 63 additions & 7 deletions ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net"
"os"
"os/user"
"path/filepath"
"strings"
"sync"

Expand All @@ -28,6 +29,7 @@ type SSHClient struct {
running bool
env string //export FOO="bar"; export BAR="baz";
color string
shell string
}

type ErrConnect struct {
Expand Down Expand Up @@ -149,8 +151,8 @@ func (c *SSHClient) ConnectWith(host string, dialer SSHDialFunc) error {
return nil
}

// Run runs the task.Run command remotely on c.host.
func (c *SSHClient) Run(task *Task) error {
// NewSession creates a new session within a SSH connection and connects IN/OUT pipes
func (c *SSHClient) NewSession() error {
if c.running {
return fmt.Errorf("Session already running")
}
Expand Down Expand Up @@ -181,8 +183,18 @@ func (c *SSHClient) Run(task *Task) error {
c.sess = sess
c.sessOpened = true

return nil
}

// Run runs the task.Run command remotely on c.host.
func (c *SSHClient) Run(task *Task) error {
// Start a new session
if err := c.NewSession(); err != nil {
return err
}

// Start the remote command.
if err := c.sess.Start(c.env + "set -x;" + task.Run); err != nil {
if err := c.sess.Start(c.buildComand(task)); err != nil {
return ErrTask{task, err.Error()}
}

Expand All @@ -206,16 +218,16 @@ func (c *SSHClient) Wait() error {
}

// DialThrough will create a new connection from the ssh server sc is connected to. DialThrough is an SSHDialer.
func (sc *SSHClient) DialThrough(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
conn, err := sc.conn.Dial(net, addr)
func (c *SSHClient) DialThrough(net, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
conn, err := c.conn.Dial(net, addr)
if err != nil {
return nil, err
}
c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
cl, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
if err != nil {
return nil, err
}
return ssh.NewClient(c, chans, reqs), nil
return ssh.NewClient(cl, chans, reqs), nil

}

Expand Down Expand Up @@ -260,3 +272,47 @@ func (c *SSHClient) Write(p []byte) (n int, err error) {
func (c *SSHClient) WriteClose() error {
return c.remoteStdin.Close()
}

func (c *SSHClient) getShell() error {
// Start a new session
if err := c.NewSession(); err != nil {
return err
}

if err := c.sess.Start(`echo "$SHELL"`); err != nil {
return err
}
c.running = true

resp, err := ioutil.ReadAll(c.Stdout())
if err != nil && err != io.EOF {
return err
}
if err := c.Wait(); err != nil {
return err
}

c.shell = string(resp)

return nil
}

func (c *SSHClient) buildComand(task *Task) string {
inter := strings.TrimSpace(task.Interpreter)
if inter == "" {
inter = strings.TrimSpace(c.shell)
}
_, i := filepath.Split(inter)

switch i {
case "bash", "sh", "zsh", "ksh", "tcsh":
return fmt.Sprintf("/usr/bin/env %s %s -c 'set -x\n%s'", c.env, i, strings.Replace(task.Run, `'`, `'"'"'`, -1))

// TODO: add support for python already called via env
case "python", "python3":
return fmt.Sprintf("/usr/bin/env %s %s -c '%s'", c.env, i, strings.Replace(task.Run, `'`, `'"'"'`, -1))

default:
return fmt.Sprintf("/usr/bin/env %s bash -c 'set -x\n%s'", c.env, strings.Replace(task.Run, `'`, `'"'"'`, -1))
}
}
6 changes: 4 additions & 2 deletions sup.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (sup *Stackup) Run(network *Network, commands ...*Command) error {
// `export FOO="bar"; export BAR="baz";`.
env := ``
for _, v := range append(sup.conf.Env, network.Env...) {
env += v.AsExport() + " "
env += v.QuotedString() + " "
}

// Create clients for every host (either SSH or Localhost).
Expand Down Expand Up @@ -67,7 +67,7 @@ func (sup *Stackup) Run(network *Network, commands ...*Command) error {

// SSH client.
remote := &SSHClient{
env: env + `export SUP_HOST="` + host + `";`,
env: env + `SUP_HOST="` + host + `" PATH=$PATH`,
color: Colors[i%len(Colors)],
}

Expand All @@ -81,6 +81,8 @@ func (sup *Stackup) Run(network *Network, commands ...*Command) error {
}
}
defer remote.Close()
remote.getShell()

clients = append(clients, remote)
}

Expand Down
4 changes: 4 additions & 0 deletions supfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ func (e EnvVar) String() string {
return e.Key + `=` + e.Value
}

func (e EnvVar) QuotedString() string {
return e.Key + `="` + e.Value + `"`
}

// AsExport returns the environment variable as a bash export statement
func (e EnvVar) AsExport() string {
return `export ` + e.Key + `="` + e.Value + `";`
Expand Down
35 changes: 29 additions & 6 deletions task.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
package sup

import (
"bufio"
"fmt"
"io"
"io/ioutil"
"os"
)

// Task represents a set of commands to be run.
type Task struct {
Run string
Input io.Reader
Clients []Client
Interpreter string
Run string
Input io.Reader
Clients []Client
}

func CreateTasks(cmd *Command, clients []Client, env string) ([]*Task, error) {
Expand Down Expand Up @@ -50,13 +51,35 @@ func CreateTasks(cmd *Command, clients []Client, env string) ([]*Task, error) {
if err != nil {
return nil, err
}
data, err := ioutil.ReadAll(f)
rd := bufio.NewReader(f)
fLine, _, err := rd.ReadLine()
if err != nil {
return nil, err
}
var hashbang, data []byte
if len(fLine) > 2 && fLine[0] == '#' && fLine[1] == '!' {
hashbang = fLine[2:len(fLine)]
} else {
rd.Reset(f)
}
for {
line, _, err := rd.ReadLine()
if len(line) > 0 {
data = append(data, line...)
data = append(data, '\n')
}

if err != nil {
if err != io.EOF {
return nil, err
}
break
}
}

task := Task{
Run: string(data),
Interpreter: string(hashbang),
Run: string(data),
}
if cmd.Stdin {
task.Input = os.Stdin
Expand Down