Skip to content

Commit

Permalink
Merge pull request #40 from antoniomika/tcp_aliases
Browse files Browse the repository at this point in the history
Added tcp alias
  • Loading branch information
antoniomika authored Nov 10, 2019
2 parents ab9a2fe + bb7b59b commit 5ba74a7
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 28 deletions.
4 changes: 3 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
"-sish.addr=localhost:2222",
"-sish.domain=testing.ssi.sh",
"-sish.forcerandomsubdomain=false",
"-sish.bindrandom=false"
"-sish.bindrandom=false",
"-sish.tcpalias=true",
"-sish.proxyprotoenabled=false"
]
}
]
Expand Down
3 changes: 0 additions & 3 deletions .vscode/settings.json

This file was deleted.

6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,19 +128,23 @@ Usage of ./sish:
-sish.proxyprotoenabled
Whether or not to enable the use of the proxy protocol
-sish.proxyprotoversion string
What version of the proxy protocol to use. Can either be 1, 2, or userdefined. If userdefined, the user needs to add a command to SSH called proxy:version (ie proxy:1) (default "1")
What version of the proxy protocol to use. Can either be 1, 2, or userdefined. If userdefined, the user needs to add a command to SSH called proxyproto:version (ie proxyproto:1) (default "1")
-sish.redirectroot
Whether or not to redirect the root domain (default true)
-sish.redirectrootlocation string
Where to redirect the root domain to (default "https://github.com/antoniomika/sish")
-sish.subdomainlen int
The length of the random subdomain to generate (default 3)
-sish.tcpalias
Whether or not to allow the use of TCP aliasing
-sish.usegeodb
Whether or not to use the maxmind geodb
-sish.verifyorigin
Whether or not to verify origin on websocket connection (default true)
-sish.verifyssl
Whether or not to verify SSL on proxy connection (default true)
-sish.version
Print version and exit
-sish.whitelistedcountries string
A comma separated list of whitelisted countries
-sish.whitelistedips string
Expand Down
44 changes: 44 additions & 0 deletions channels.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"io"
"log"
"net"
"strings"

"golang.org/x/crypto/ssh"
Expand All @@ -14,6 +15,7 @@ var proxyProtoPrefix = "proxyproto:"
func handleSession(newChannel ssh.NewChannel, sshConn *SSHConnection, state *State) {
connection, requests, err := newChannel.Accept()
if err != nil {
sshConn.CleanUp(state)
return
}

Expand Down Expand Up @@ -89,6 +91,48 @@ func handleSession(newChannel ssh.NewChannel, sshConn *SSHConnection, state *Sta
}()
}

func handleAlias(newChannel ssh.NewChannel, sshConn *SSHConnection, state *State) {
connection, requests, err := newChannel.Accept()
if err != nil {
sshConn.CleanUp(state)
return
}

go ssh.DiscardRequests(requests)

if *debug {
log.Println("Handling alias connection for:", connection)
}

check := &forwardedTCPPayload{}
err = ssh.Unmarshal(newChannel.ExtraData(), check)
if err != nil {
log.Println("Error unmarshaling information:", err)
sshConn.CleanUp(state)
return
}

tcpAliasToConnect := fmt.Sprintf("%s:%d", check.Addr, check.Port)
loc, ok := state.TCPListeners.Load(tcpAliasToConnect)
if !ok {
log.Println("Unable to load tcp alias:", tcpAliasToConnect)
sshConn.CleanUp(state)
return
}

conn, err := net.Dial("unix", loc.(string))
if err != nil {
log.Println("Error connecting to alias:", err)
sshConn.CleanUp(state)
return
}

sshConn.Listeners.Store(conn.RemoteAddr(), nil)

copyBoth(conn, connection, false)
sshConn.CleanUp(state)
}

func getProxyProtoVersion(proxyProtoUserVersion string) byte {
if *proxyProtoVersion != "userdefined" {
proxyProtoUserVersion = *proxyProtoVersion
Expand Down
2 changes: 2 additions & 0 deletions handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ func handleChannel(newChannel ssh.NewChannel, sshConn *SSHConnection, state *Sta
switch channel := newChannel.ChannelType(); channel {
case "session":
handleSession(newChannel, sshConn, state)
case "direct-tcpip":
handleAlias(newChannel, sshConn, state)
default:
err := newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", channel))
if err != nil {
Expand Down
5 changes: 4 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type State struct {
SSHConnections *sync.Map
Listeners *sync.Map
HTTPListeners *sync.Map
TCPListeners *sync.Map
IPFilter *ipfilter.IPFilter
}

Expand Down Expand Up @@ -66,9 +67,10 @@ var (
cleanupUnbound = flag.Bool("sish.cleanupunbound", true, "Whether or not to cleanup unbound (forwarded) SSH connections")
bindRandom = flag.Bool("sish.bindrandom", true, "Bind ports randomly (OS chooses)")
proxyProtoEnabled = flag.Bool("sish.proxyprotoenabled", false, "Whether or not to enable the use of the proxy protocol")
proxyProtoVersion = flag.String("sish.proxyprotoversion", "1", "What version of the proxy protocol to use. Can either be 1, 2, or userdefined. If userdefined, the user needs to add a command to SSH called proxy:version (ie proxy:1)")
proxyProtoVersion = flag.String("sish.proxyprotoversion", "1", "What version of the proxy protocol to use. Can either be 1, 2, or userdefined. If userdefined, the user needs to add a command to SSH called proxyproto:version (ie proxyproto:1)")
debug = flag.Bool("sish.debug", false, "Whether or not to print debug information")
versionCheck = flag.Bool("sish.version", false, "Print version and exit")
tcpAlias = flag.Bool("sish.tcpalias", false, "Whether or not to allow the use of TCP aliasing")
bannedSubdomainList = []string{""}
filter *ipfilter.IPFilter
)
Expand Down Expand Up @@ -122,6 +124,7 @@ func main() {
SSHConnections: &sync.Map{},
Listeners: &sync.Map{},
HTTPListeners: &sync.Map{},
TCPListeners: &sync.Map{},
IPFilter: filter,
}

Expand Down
84 changes: 62 additions & 22 deletions requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net"
"os"
"strconv"
"sync"
"time"

"github.com/pires/go-proxyproto"
Expand Down Expand Up @@ -36,22 +37,27 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *SSHConnection, state

bindPort := check.Rport

handleTCPAliasing := false
if bindPort != uint32(80) && bindPort != uint32(443) {
checkedPort, err := checkPort(check.Rport, *bindRange)
if err != nil && !*bindRandom {
err = newRequest.Reply(false, nil)
if err != nil {
log.Println("Error replying to socket request:", err)
if *tcpAlias && check.Addr != "localhost" {
handleTCPAliasing = true
} else {
checkedPort, err := checkPort(check.Rport, *bindRange)
if err != nil && !*bindRandom {
err = newRequest.Reply(false, nil)
if err != nil {
log.Println("Error replying to socket request:", err)
}
return
}
return
}

bindPort = checkedPort
if *bindRandom {
bindPort = 0
bindPort = checkedPort
if *bindRandom {
bindPort = 0

if *bindRange != "" {
bindPort = getRandomPortInRange(*bindRange)
if *bindRange != "" {
bindPort = getRandomPortInRange(*bindRange)
}
}
}
}
Expand All @@ -70,7 +76,7 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *SSHConnection, state
}
os.Remove(tmpfile.Name())

if stringPort == "80" || stringPort == "443" {
if stringPort == "80" || stringPort == "443" || handleTCPAliasing {
listenType = "unix"
listenAddr = tmpfile.Name()
}
Expand Down Expand Up @@ -126,7 +132,16 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *SSHConnection, state
requestMessages += fmt.Sprintf("HTTPS: https://%s:%d", host, *httpsPort)
}
} else {
requestMessages += fmt.Sprintf("TCP: %s:%d", *rootDomain, chanListener.Addr().(*net.TCPAddr).Port)
if handleTCPAliasing {
validAlias := getOpenAlias(check.Addr, stringPort, state, sshConn)

state.TCPListeners.Store(validAlias, chanListener.Addr().String())
defer state.TCPListeners.Delete(validAlias)

requestMessages += fmt.Sprintf("TCP Alias: %s", validAlias)
} else {
requestMessages += fmt.Sprintf("TCP: %s:%d", *rootDomain, chanListener.Addr().(*net.TCPAddr).Port)
}
}

sshConn.Messages <- requestMessages
Expand Down Expand Up @@ -160,9 +175,16 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *SSHConnection, state

defer newChan.Close()

if sshConn.ProxyProto != 0 && listenType != "unix" {
sourceInfo := cl.RemoteAddr().(*net.TCPAddr)
destInfo := cl.LocalAddr().(*net.TCPAddr)
if sshConn.ProxyProto != 0 && (listenType != "unix" || handleTCPAliasing) {
var sourceInfo *net.TCPAddr
var destInfo *net.TCPAddr
if _, ok := cl.RemoteAddr().(*net.TCPAddr); !ok {
sourceInfo = sshConn.SSHConn.RemoteAddr().(*net.TCPAddr)
destInfo = sshConn.SSHConn.LocalAddr().(*net.TCPAddr)
} else {
sourceInfo = cl.RemoteAddr().(*net.TCPAddr)
destInfo = cl.LocalAddr().(*net.TCPAddr)
}

proxyProtoHeader := proxyproto.Header{
Version: sshConn.ProxyProto,
Expand All @@ -180,30 +202,48 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *SSHConnection, state
}
}

go copyBoth(cl, newChan)
go copyBoth(cl, newChan, false)
go ssh.DiscardRequests(newReqs)
}
}

func copyBoth(writer net.Conn, reader ssh.Channel) {
func copyBoth(writer net.Conn, reader ssh.Channel, wait bool) {
closeBoth := func() {
time.Sleep(1 * time.Millisecond)
time.Sleep(100 * time.Millisecond)
writer.Close()
reader.Close()
}

defer closeBoth()
var wg sync.WaitGroup

go func() {
defer closeBoth()
if wait {
wg.Add(1)
defer wg.Done()
} else {
defer closeBoth()
}

_, err := io.Copy(writer, reader)
if err != nil {
log.Println("Error writing to reader:", err)
}
}()

if wait {
wg.Add(1)
} else {
defer closeBoth()
}

_, err := io.Copy(reader, writer)
if err != nil {
log.Println("Error writing to writer:", err)
}
if wait {
wg.Done()
}

wg.Wait()
closeBoth()
}
35 changes: 35 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,41 @@ func getOpenHost(addr string, state *State, sshConn *SSHConnection) string {
return getUnusedHost()
}

func getOpenAlias(addr string, port string, state *State, sshConn *SSHConnection) string {
getUnusedAlias := func() string {
first := true
alias := fmt.Sprintf("%s:%s", strings.ToLower(addr), port)
getRandomAlias := func() string {
return fmt.Sprintf("%s:%s", strings.ToLower(RandStringBytesMaskImprSrc(*domainLen)), port)
}
reportUnavailable := func(unavailable bool) {
if first && unavailable {
sshConn.Messages <- "This alias is unavaible. Assigning a random alias."
}
}

checkAlias := func(checkAlias string) bool {
if *forceRandomSubdomain || !first || inBannedList(alias, bannedSubdomainList) {
reportUnavailable(true)
alias = getRandomAlias()
}

_, ok := state.TCPListeners.Load(alias)
reportUnavailable(ok)

first = false
return ok
}

for checkAlias(alias) {
}

return alias
}

return getUnusedAlias()
}

// RandStringBytesMaskImprSrc creates a random string of length n
// https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-golang
func RandStringBytesMaskImprSrc(n int) string {
Expand Down

0 comments on commit 5ba74a7

Please sign in to comment.