diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index a3f7938..956a2ee 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -5,9 +5,10 @@ on: workflow_call: jobs: - integration-test: - name: Run Integration Tests - runs-on: [ self-hosted, linux, x64, large ] + integration-test-legacy: + name: Run Legacy Integration Tests + runs-on: [ self-hosted, linux, x64, edge ] + timeout-minutes: 30 steps: - uses: actions/checkout@v2 @@ -24,11 +25,14 @@ jobs: - name: Install Aproxy Snap run: | + sudo snap remove aproxy || : sudo snap install --dangerous aproxy_*_amd64.snap - name: Configure Aproxy run: | sudo snap set aproxy proxy=squid.internal:3128 listen=:23403 + + sudo nft flush ruleset sudo nft -f - << EOF define default-ip = $(ip route get $(ip route show 0.0.0.0/0 | grep -oP 'via \K\S+') | grep -oP 'src \K\S+') define private-ips = { 10.0.0.0/8, 127.0.0.1/8, 172.16.0.0/12, 192.168.0.0/16 } @@ -49,14 +53,108 @@ jobs: - name: Test HTTP run: | - curl --noproxy "*" http://example.com -svS -o /dev/null + curl --noproxy "*" --max-time 30 http://canonical.com -svS -o /dev/null - name: Test HTTPS run: | - curl --noproxy "*" https://example.com -svS -o /dev/null + curl --noproxy "*" --max-time 30 https://canonical.com -svS -o /dev/null - name: Test Access Logs run: | sudo snap logs aproxy.aproxy - sudo snap logs aproxy.aproxy | grep -Fq "example.com:80" - sudo snap logs aproxy.aproxy | grep -Fq "example.com:443" + sudo snap logs aproxy.aproxy | grep -Fq "canonical.com:80" + sudo snap logs aproxy.aproxy | grep -Fq "canonical.com:443" + + integration-test: + name: Run Integration Tests + runs-on: [ self-hosted, linux, x64, edge ] + timeout-minutes: 30 + + steps: + - uses: actions/checkout@v2 + + - name: Install Tinyproxy + run: | + sudo apt update + sudo apt install tinyproxy -y + + - name: Build Aproxy Snap + id: snapcraft-build + uses: snapcore/action-build@v1 + + - name: Upload Aproxy Snap + uses: actions/upload-artifact@v3 + with: + name: snap + path: aproxy*.snap + + - name: Install Aproxy Snap + run: | + sudo snap remove aproxy || : + sudo snap install --dangerous aproxy_*_amd64.snap + + - name: Configure Aproxy + run: | + sudo snap connect aproxy:network-control + sudo snap set aproxy fwmark=7316 listen=:23403 + + sudo nft flush ruleset + sudo nft -f - << EOF + define default-ip = $(ip route get $(ip route show 0.0.0.0/0 | grep -oP 'via \K\S+') | grep -oP 'src \K\S+') + define private-ips = { 10.0.0.0/8, 127.0.0.1/8, 172.16.0.0/12, 192.168.0.0/16 } + table ip aproxy + flush table ip aproxy + table ip aproxy { + chain prerouting { + type nat hook prerouting priority dstnat; policy accept; + meta skuid != tinyproxy mark != 7316 ip daddr != \$private-ips tcp dport { 80, 443 } counter dnat to \$default-ip:23403 + } + + chain output { + type nat hook output priority -100; policy accept; + meta skuid != tinyproxy mark != 7316 ip daddr != \$private-ips tcp dport { 80, 443 } counter dnat to \$default-ip:23403 + } + } + EOF + + - name: Test Passthrough HTTP + run: | + curl --noproxy "*" --max-time 30 http://www.canonical.com -svS -o /dev/null + sudo snap logs aproxy.aproxy -n 1 | grep -qi "passthrough.*host=www.canonical.com" + + - name: Test Passthrough HTTPS + run: | + curl --noproxy "*" --max-time 30 https://canonical.com -svS -o /dev/null + sudo snap logs aproxy.aproxy -n 1 | grep -qi "passthrough.*host=canonical.com" + + - name: Set HTTP Proxy + run: | + sudo snap set aproxy http.proxy=http://localhost:8888 + + - name: Test Proxy HTTP + run: | + curl --noproxy "*" --max-time 30 http://www.ubuntu.com -svS -o /dev/null + sudo snap logs aproxy.aproxy -n 1 | grep -qi "http.*proxy.*host=www.ubuntu.com" + + - name: Test Passthrough HTTPS + run: | + curl --noproxy "*" --max-time 30 https://ubuntu.com -svS -o /dev/null + sudo snap logs aproxy.aproxy -n 1 | grep -qi "passthrough.*host=ubuntu.com" + + - name: Set HTTPS Proxy + run: | + sudo snap set aproxy https.proxy=http://localhost:8888 + + - name: Test Proxy HTTP + run: | + curl --noproxy "*" --max-time 30 http://www.ubuntu.net -svS -o /dev/null + sudo snap logs aproxy.aproxy -n 1 | grep -qi "http.*proxy.*host=www.ubuntu.net" + + - name: Test Proxy HTTPS + run: | + curl --noproxy "*" --max-time 30 https://ubuntu.net -svS -o /dev/null + sudo snap logs aproxy.aproxy -n 1 | grep -qi "tls.*proxy.*host=ubuntu.net" + + - name: Print Aproxy Logs + if: always() + run: sudo snap logs aproxy -n all diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index c59b4d7..f6ddfd1 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -7,7 +7,8 @@ on: jobs: test: name: Run Tests - runs-on: ubuntu-latest + runs-on: [ self-hosted, linux, x64, large ] + timeout-minutes: 30 steps: - uses: actions/checkout@v2 diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000..844a0e5 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @canonical/is-charms diff --git a/README.md b/README.md index 8495a12..d40b8bd 100644 --- a/README.md +++ b/README.md @@ -7,11 +7,13 @@ requiring destination hostname for auditing or access control. ## Usage -Install aproxy using snap, and configure the upstream http proxy. +Install aproxy using snap, and configure the upstream http proxy and the forward +traffic firewall mark. ```bash sudo snap install aproxy --edge -sudo snap set aproxy proxy=squid.internal:3128 +sudo snap connect aproxy:network-control +sudo snap set aproxy http.proxy=http://squid.internal:3128 https.proxy=http://squid.internal:3128 fwmark=7316 ``` Create the following nftables rules to redirect outbound traffic to aproxy on @@ -27,12 +29,12 @@ flush table ip aproxy table ip aproxy { chain prerouting { type nat hook prerouting priority dstnat; policy accept; - ip daddr != \$private-ips tcp dport { 80, 443 } counter dnat to \$default-ip:8443 + mark != 7316 ip daddr != \$private-ips tcp dport { 80, 443 } counter dnat to \$default-ip:8443 } chain output { type nat hook output priority -100; policy accept; - ip daddr != \$private-ips tcp dport { 80, 443 } counter dnat to \$default-ip:8443 + mark != 7316 ip daddr != \$private-ips tcp dport { 80, 443 } counter dnat to \$default-ip:8443 } } EOF @@ -55,5 +57,5 @@ Follow these steps to get started: git clone https://github.com/canonical/aproxy.git cd aproxy go mod download -go run . --proxy=squid.internal:3128 +go run . --http-proxy=http://squid.internal:3128 --https-proxy=http://squid.internal:3128 ``` diff --git a/aproxy.go b/aproxy.go index 8885dbf..9c2e794 100644 --- a/aproxy.go +++ b/aproxy.go @@ -3,195 +3,21 @@ package main import ( "bufio" "context" - "encoding/binary" "errors" "flag" "fmt" - "io" "log" - "log/slog" "net" "net/http" - "net/url" "os" "os/signal" "strings" - "sync" - "sync/atomic" - "syscall" - - "golang.org/x/crypto/cryptobyte" ) -var version = "0.2.2" - -// PrereadConn is a wrapper around net.Conn that supports pre-reading from the underlying connection. -// Any Read before the EndPreread can be undone and read again by calling the EndPreread function. -type PrereadConn struct { - ended bool - buf []byte - mu sync.Mutex - conn net.Conn -} - -// EndPreread ends the pre-reading phase. Any Read before will be undone and data in the stream can be read again. -// EndPreread can be only called once. -func (c *PrereadConn) EndPreread() { - c.mu.Lock() - defer c.mu.Unlock() - if c.ended { - panic("call EndPreread after preread has ended or hasn't started") - } - c.ended = true -} - -// Read reads from the underlying connection. Read during the pre-reading phase can be undone by EndPreread. -func (c *PrereadConn) Read(p []byte) (n int, err error) { - c.mu.Lock() - defer c.mu.Unlock() - if c.ended { - n = copy(p, c.buf) - bufLen := len(c.buf) - c.buf = c.buf[n:] - if n == len(p) || (bufLen > 0 && bufLen == n) { - return n, nil - } - rn, err := c.conn.Read(p[n:]) - return rn + n, err - } else { - n, err = c.conn.Read(p) - c.buf = append(c.buf, p[:n]...) - return n, err - } -} - -// Write writes data to the underlying connection. -func (c *PrereadConn) Write(p []byte) (n int, err error) { - return c.conn.Write(p) -} - -// NewPrereadConn wraps the network connection and return a *PrereadConn. -// It's recommended to not touch the original connection after wrapped. -func NewPrereadConn(conn net.Conn) *PrereadConn { - return &PrereadConn{conn: conn} -} - -// PrereadSNI pre-reads the Server Name Indication (SNI) from a TLS connection. -func PrereadSNI(conn *PrereadConn) (_ string, err error) { - defer conn.EndPreread() - defer func() { - if err != nil { - err = fmt.Errorf("failed to preread TLS client hello: %w", err) - } - }() - typeVersionLen := make([]byte, 5) - n, err := conn.Read(typeVersionLen) - if n != 5 { - return "", errors.New("too short") - } - if err != nil { - return "", err - } - if typeVersionLen[0] != 22 { - return "", errors.New("not a TCP handshake") - } - msgLen := binary.BigEndian.Uint16(typeVersionLen[3:]) - buf := make([]byte, msgLen+5) - n, err = conn.Read(buf[5:]) - if n != int(msgLen) { - return "", errors.New("too short") - } - if err != nil { - return "", err - } - copy(buf[:5], typeVersionLen) - return extractSNI(buf) -} - -func extractSNI(data []byte) (string, error) { - s := cryptobyte.String(data) - var ( - version uint16 - random []byte - sessionId []byte - ) - - if !s.Skip(9) || - !s.ReadUint16(&version) || !s.ReadBytes(&random, 32) || - !s.ReadUint8LengthPrefixed((*cryptobyte.String)(&sessionId)) { - return "", fmt.Errorf("failed to parse TLS client hello version, random or session id") - } - - var cipherSuitesData cryptobyte.String - if !s.ReadUint16LengthPrefixed(&cipherSuitesData) { - return "", fmt.Errorf("failed to parse TLS client hello cipher suites") - } - - var cipherSuites []uint16 - for !cipherSuitesData.Empty() { - var suite uint16 - if !cipherSuitesData.ReadUint16(&suite) { - return "", fmt.Errorf("failed to parse TLS client hello cipher suites") - } - cipherSuites = append(cipherSuites, suite) - } - - var compressionMethods []byte - if !s.ReadUint8LengthPrefixed((*cryptobyte.String)(&compressionMethods)) { - return "", fmt.Errorf("failed to parse TLS client hello compression methods") - } - - if s.Empty() { - // ClientHello is optionally followed by extension data - return "", fmt.Errorf("no extension data in TLS client hello") - } - - var extensions cryptobyte.String - if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { - return "", fmt.Errorf("failed to parse TLS client hello extensions") - } - - finalServerName := "" - for !extensions.Empty() { - var extension uint16 - var extData cryptobyte.String - if !extensions.ReadUint16(&extension) || - !extensions.ReadUint16LengthPrefixed(&extData) { - return "", fmt.Errorf("failed to parse TLS client hello extension") - } - if extension != 0 { - continue - } - var nameList cryptobyte.String - if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() { - return "", fmt.Errorf("failed to parse server name extension") - } - - for !nameList.Empty() { - var nameType uint8 - var serverName cryptobyte.String - if !nameList.ReadUint8(&nameType) || - !nameList.ReadUint16LengthPrefixed(&serverName) || - serverName.Empty() { - return "", fmt.Errorf("failed to parse server name indication extension") - } - if nameType != 0 { - continue - } - if len(finalServerName) != 0 { - return "", fmt.Errorf("multiple names of the same name_type are prohibited in server name extension") - } - finalServerName = string(serverName) - if strings.HasSuffix(finalServerName, ".") { - return "", fmt.Errorf("SNI name ends with a trailing dot") - } - } - } - return finalServerName, nil -} +var version = "1.0.0" -// PrereadHttpHost pre-reads the HTTP Host header from an HTTP connection. -func PrereadHttpHost(conn *PrereadConn) (_ string, err error) { +// PrereadHTTPHost pre-reads the HTTP Host header from an HTTP connection. +func PrereadHTTPHost(conn *PrereadConn) (_ string, err error) { defer func() { if err != nil { err = fmt.Errorf("failed to preread HTTP request: %w", err) @@ -210,172 +36,66 @@ func PrereadHttpHost(conn *PrereadConn) (_ string, err error) { return host, nil } -// DialProxy dials the TCP connection to the proxy. -func DialProxy(proxy string) (net.Conn, error) { - proxyAddr, err := net.ResolveTCPAddr("tcp", proxy) - if err != nil { - return nil, fmt.Errorf("failed to resolve proxy address: %w", err) - } - conn, err := net.DialTCP("tcp", nil, proxyAddr) - if err != nil { - return nil, fmt.Errorf("failed to connect to proxy: %w", err) - } - return conn, nil -} - -// DialProxyConnect dials the TCP connection and finishes the HTTP CONNECT handshake with the proxy. -func DialProxyConnect(proxy string, dst string) (net.Conn, error) { - conn, err := DialProxy(proxy) - if err != nil { - return nil, err - } - request := http.Request{ - Method: "CONNECT", - URL: &url.URL{ - Host: dst, - }, - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: map[string][]string{ - "User-Agent": {fmt.Sprintf("aproxy/%s", version)}, - }, - Host: dst, - } - err = request.Write(conn) - if err != nil { - return nil, fmt.Errorf("failed to send connect request to http proxy: %w", err) - } - response, err := http.ReadResponse(bufio.NewReaderSize(conn, 0), &request) - if response.StatusCode != 200 { - return nil, fmt.Errorf("proxy return %d response for connect request", response.StatusCode) - } - if err != nil { - return nil, fmt.Errorf("failed to receive http connect response from proxy: %w", err) - } - return conn, nil -} - -// GetOriginalDst get the original destination address of a TCP connection before dstnat. -func GetOriginalDst(conn *net.TCPConn) (*net.TCPAddr, error) { - file, err := conn.File() - defer func(file *os.File) { - err := file.Close() - if err != nil { - slog.Error("failed to close the duplicated TCP socket file descriptor") - } - }(file) +// HandleTLSConn handles one incoming TCP connection +func HandleTLSConn(ctx context.Context, consigned *ConsignedConn, forwarder *Forwarder) { + sni, err := PrereadSNI(consigned.PrereadConn) if err != nil { - return nil, fmt.Errorf("failed to convert connection to file: %w", err) - } - return GetsockoptIPv4OriginalDst( - int(file.Fd()), - syscall.SOL_IP, - 80, // SO_ORIGINAL_DST - ) -} - -// RelayTCP relays data between the incoming TCP connection and the proxy connection. -func RelayTCP(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) { - var closed atomic.Bool - go func() { - _, err := io.Copy(proxyConn, conn) - if err != nil && !closed.Load() { - logger.Error("failed to relay network traffic to proxy", "error", err) - } - closed.Store(true) - _ = proxyConn.Close() - }() - _, err := io.Copy(conn, proxyConn) - if err != nil && !closed.Load() { - logger.Error("failed to relay network traffic from proxy", "error", err) + logger.ErrorContext(ctx, "failed to preread SNI from connection", "error", err) + return } - closed.Store(true) + host := fmt.Sprintf("%s:%d", sni, consigned.OriginalDst.Port) + consigned.Host = host + forwarder.ForwardHTTPS(ctx, consigned) } -// RelayHTTP relays a single HTTP request and response between a local connection and a proxy. -// It modifies the Connection header to "close" in both the request and response. -func RelayHTTP(conn io.ReadWriter, proxyConn io.ReadWriteCloser, logger *slog.Logger) { - defer proxyConn.Close() - req, err := http.ReadRequest(bufio.NewReader(conn)) +// HandleHTTPConn handles one incoming HTTP connection +func HandleHTTPConn(ctx context.Context, consigned *ConsignedConn, forwarder *Forwarder) { + host, err := PrereadHTTPHost(consigned.PrereadConn) if err != nil { - logger.Error("failed to read HTTP request from connection", "error", err) + logger.ErrorContext(ctx, "failed to preread HTTP host from connection", "error", err) return } - req.URL.Host = req.Host - req.URL.Scheme = "http" - req.Header.Set("Connection", "close") - if err := req.WriteProxy(proxyConn); err != nil { - logger.Error("failed to send HTTP request to proxy", "error", err) - return - } - resp, err := http.ReadResponse(bufio.NewReader(proxyConn), req) - if err != nil { - logger.Error("failed to read HTTP response from proxy", "error", err) - return - } - resp.Header.Set("Connection", "close") - if err := resp.Write(conn); err != nil { - logger.Error("failed to send HTTP response to connection", "error", err) - return + if !strings.Contains(host, ":") { + host = fmt.Sprintf("%s:%d", host, consigned.OriginalDst.Port) } + consigned.Host = host + forwarder.ForwardHTTP(ctx, consigned) } // HandleConn manages the incoming connections. -func HandleConn(conn net.Conn, proxy string) { +func HandleConn(ctx context.Context, conn *net.TCPConn, forwarder *Forwarder) { defer conn.Close() - logger := slog.With("src", conn.RemoteAddr()) - dst, err := GetOriginalDst(conn.(*net.TCPConn)) + dst, err := GetSocketIPv4OriginalDst(conn) if err != nil { - slog.Error("failed to get connection original destination", "error", err) + logger.ErrorContext(ctx, "failed to get connection original destination", "error", err) return } - logger = logger.With("original_dst", dst) - consigned := NewPrereadConn(conn) + consigned := NewConsignedConn(conn) + consigned.OriginalDst = dst + ctx = ContextWithConsignedConn(ctx, consigned) switch dst.Port { case 443: - sni, err := PrereadSNI(consigned) - if err != nil { - logger.Error("failed to preread SNI from connection", "error", err) - return - } else { - host := fmt.Sprintf("%s:%d", sni, dst.Port) - logger = logger.With("host", host) - proxyConn, err := DialProxyConnect(proxy, host) - if err != nil { - logger.Error("failed to connect to http proxy", "error", err) - return - } - logger.Info("relay TLS connection to proxy") - RelayTCP(consigned, proxyConn, logger) - } + HandleTLSConn(ctx, consigned, forwarder) case 80: - host, err := PrereadHttpHost(consigned) - if err != nil { - logger.Error("failed to preread HTTP host from connection", "error", err) - return - } - if !strings.Contains(host, ":") { - host = fmt.Sprintf("%s:%d", host, dst.Port) - } - logger = logger.With("host", host) - proxyConn, err := DialProxy(proxy) - if err != nil { - logger.Error("failed to connect to http proxy", "error", err) - return - } - logger.Info("relay HTTP connection to proxy") - RelayHTTP(consigned, proxyConn, logger) + HandleHTTPConn(ctx, consigned, forwarder) default: - logger.Error(fmt.Sprintf("unknown destination port: %d", dst.Port)) + logger.ErrorContext(ctx, fmt.Sprintf("unknown destination port: %d", dst.Port)) return } } func main() { - proxyFlag := flag.String("proxy", "", "upstream HTTP proxy address in the 'host:port' format") + httpProxyFlag := flag.String("http-proxy", "", "upstream HTTP proxy URL") + httpsProxyFlag := flag.String("https-proxy", "", "upstream HTTPS proxy URL") listenFlag := flag.String("listen", ":8443", "the address and port on which the server will listen") + fwmarkFlag := flag.Uint("fwmark", 0, "set firewall mark for outgoing traffic") flag.Parse() + httpProxy := *httpProxyFlag + httpsProxy := *httpsProxyFlag + forwarder, err := NewForwarder(*httpProxyFlag, *httpsProxyFlag, *fwmarkFlag) + if err != nil { + log.Fatal(err) + } listenAddr := *listenFlag ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) defer stop() @@ -384,17 +104,22 @@ func main() { if err != nil { log.Fatalf("failed to listen on %#v", listenAddr) } - slog.Info(fmt.Sprintf("start listening on %s", listenAddr)) - proxy := *proxyFlag - if proxy == "" { - log.Fatalf("no upstearm proxy specified") + logger.InfoContext(ctx, fmt.Sprintf("start listening on %s", listenAddr)) + if httpProxy != "" { + logger.InfoContext(ctx, fmt.Sprintf("start forwarding HTTP connection to proxy %s", httpProxy)) + } else { + logger.InfoContext(ctx, "start passthrough HTTP connection") + } + if httpsProxy != "" { + logger.InfoContext(ctx, fmt.Sprintf("start forwarding HTTPS connection to proxy %s", httpsProxy)) + } else { + logger.InfoContext(ctx, "start passthrough HTTPS connection") } - slog.Info(fmt.Sprintf("start forwarding to proxy %s", proxy)) go func() { for { conn, err := listener.Accept() if err != nil { - slog.Error("failed to accept connection", "error", err) + logger.ErrorContext(ctx, "failed to accept connection", "error", err) continue } go HandleConn(conn, proxy) diff --git a/aproxy_test.go b/aproxy_test.go index b598b5e..8ed6e91 100644 --- a/aproxy_test.go +++ b/aproxy_test.go @@ -2,36 +2,10 @@ package main import ( "encoding/hex" - "io" "net" "testing" ) -func TestPrereadConn(t *testing.T) { - remote, local := net.Pipe() - go remote.Write([]byte("hello, world")) - preread := &PrereadConn{conn: local} - buf := make([]byte, 5) - _, err := preread.Read(buf) - if err != nil { - t.Fatalf("Read failed during preread: %s", err) - } - buf = make([]byte, 3) - _, err = preread.Read(buf) - if err != nil { - t.Fatalf("Read failed during preread: %s", err) - } - preread.EndPreread() - buf2 := make([]byte, 12) - _, err = io.ReadFull(preread, buf2) - if err != nil { - t.Fatalf("Read failed after preread: %s", err) - } - if string(buf2) != "hello, world" { - t.Fatalf("preread altered the read state: got %s", string(buf2)) - } -} - func TestPrereadSNI(t *testing.T) { remote, local := net.Pipe() // data obtained from https://gitlab.com/wireshark/wireshark/-/blob/master/test/captures/tls12-aes256gcm.pcap @@ -49,11 +23,11 @@ func TestPrereadSNI(t *testing.T) { func TestPrereadHttpHost(t *testing.T) { remote, local := net.Pipe() go remote.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\nAccept: */*\r\n\r\n")) - host, err := PrereadHttpHost(NewPrereadConn(local)) + host, err := PrereadHTTPHost(NewPrereadConn(local)) if err != nil { - t.Fatalf("PrereadHttpHost failed: %s", err) + t.Fatalf("PrereadHTTPHost failed: %s", err) } if host != "example.com" { - t.Fatalf("PrereadHttpHost returns incorrect host: expected: example.com, got %s", host) + t.Fatalf("PrereadHTTPHost returns incorrect host: expected: example.com, got %s", host) } } diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..77ad062 --- /dev/null +++ b/conn.go @@ -0,0 +1,21 @@ +package main + +import ( + "net" +) + +// ConsignedConn wraps the PrereadConn and provides some slots to attach information related to the connection. +type ConsignedConn struct { + *PrereadConn + OriginalDst *net.TCPAddr + Host string +} + +// NewConsignedConn creates a new *ConsignedConn from the connection. +func NewConsignedConn(conn net.Conn) *ConsignedConn { + return &ConsignedConn{ + PrereadConn: NewPrereadConn(conn), + OriginalDst: nil, + Host: "", + } +} diff --git a/forwarder.go b/forwarder.go new file mode 100644 index 0000000..6826873 --- /dev/null +++ b/forwarder.go @@ -0,0 +1,236 @@ +package main + +import ( + "bufio" + "context" + "fmt" + "io" + "net" + "net/http" + "net/url" + "sync/atomic" + "syscall" +) + +type Forwarder struct { + fwmark uint32 + httpProxy string + httpsProxy string + dialFunc func(f *Forwarder, addr string) (net.Conn, error) // use dialFunc instead of dialTCP if not nil +} + +// parseProxyUrl parses a proxy URL to a TCP address in the format of 'host:port'. +func verifyProxyUrl(proxyUrl string) (err error) { + defer func() { + if err != nil { + err = fmt.Errorf("failed to parse proxy URL '%v': %w", proxyUrl, err) + } + }() + u, err := url.Parse(proxyUrl) + if err != nil { + return err + } + if u.Scheme != "http" { + return fmt.Errorf("proxy protocol %s not supported", u.Scheme) + } + if u.User != nil { + return fmt.Errorf("proxy authencation not supported") + } + if u.Port() == "" { + return fmt.Errorf("proxy URL doesn't contain a port") + } + return nil +} + +func NewForwarder(httpProxy, httpsProxy string, fwmark uint) (*Forwarder, error) { + if err := verifyProxyUrl(httpProxy); err != nil && httpProxy != "" { + return nil, err + } + if err := verifyProxyUrl(httpsProxy); err != nil && httpsProxy != "" { + return nil, err + } + if fwmark > 4294967295 { + return nil, fmt.Errorf("invalid fwmark %d", fwmark) + } + return &Forwarder{ + fwmark: uint32(fwmark), + httpProxy: httpProxy, + httpsProxy: httpsProxy, + }, nil +} + +func (f *Forwarder) proxyAddr(proxyUrl string) string { + u, err := url.Parse(proxyUrl) + if err != nil { + panic(err) + } + return u.Host +} + +// dialTCP dials the TCP connection to the remote address. +// dialTCP sets the fwmark of the underlying socket if the fwmark argument is not 0. +func (f *Forwarder) dialTCP(addr string) (net.Conn, error) { + var fwmarkErr error + dialer := &net.Dialer{ + Control: func(_, _ string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + if f.fwmark > 0 { + err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, int(f.fwmark)) + if err != nil { + fwmarkErr = fmt.Errorf("failed to set mark on socket: %w", err) + } + } + }) + }, + } + conn, err := dialer.Dial("tcp", addr) + if err != nil { + return nil, fmt.Errorf("failed to connect to '%v': %w", addr, err) + } + if fwmarkErr != nil { + return nil, fmt.Errorf("failed to set mark on socket: %w", fwmarkErr) + } + return conn, nil +} + +// dial dials the connection to the remote address. +// if dialFunc is not nil, it will be used, or else dialTCP will be used. +func (f *Forwarder) dial(addr string) (net.Conn, error) { + if f.dialFunc != nil { + return f.dialFunc(f, addr) + } + return f.dialTCP(addr) +} + +// proxyConnect dials the TCP connection and finishes the HTTP CONNECT handshake with the proxy. +// The dst argument is used during the handshake as the destination. +func (f *Forwarder) proxyConnect(dst string) (net.Conn, error) { + conn, err := f.dial(f.proxyAddr(f.httpsProxy)) + if err != nil { + return nil, err + } + request := http.Request{ + Method: "CONNECT", + URL: &url.URL{ + Host: dst, + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: map[string][]string{ + "User-Agent": {fmt.Sprintf("aproxy/%s", version)}, + }, + Host: dst, + } + err = request.Write(conn) + if err != nil { + return nil, fmt.Errorf("failed to send connect request to http proxy: %w", err) + } + response, err := http.ReadResponse(bufio.NewReaderSize(conn, 0), &request) + if err != nil { + return nil, fmt.Errorf("failed to receive http connect response from proxy: %w", err) + } + if response.StatusCode != 200 { + return nil, fmt.Errorf("proxy return %d response for connect request", response.StatusCode) + } + return conn, nil +} + +// relayTCP relays data between the incoming TCP connection and the outgoing connection. +func (f *Forwarder) relayTCP(ctx context.Context, in io.ReadWriter, out io.ReadWriteCloser) { + var closed atomic.Bool + go func() { + _, err := io.Copy(out, in) + if err != nil && !closed.Load() { + logger.ErrorContext(ctx, "failed to relay network traffic to outgoing connection", "error", err) + } + closed.Store(true) + _ = out.Close() + }() + _, err := io.Copy(in, out) + if err != nil && !closed.Load() { + logger.ErrorContext(ctx, "failed to relay network traffic to incoming connection", "error", err) + } + closed.Store(true) +} + +// relayHTTP relays a single HTTP request and response between a local connection and a proxy. +// It modifies the Connection header to "close" in both the request and response. +func (f *Forwarder) relayHTTP(ctx context.Context, conn io.ReadWriter, proxyConn io.ReadWriteCloser) { + defer proxyConn.Close() + req, err := http.ReadRequest(bufio.NewReader(conn)) + if err != nil { + logger.ErrorContext(ctx, "failed to read HTTP request from connection", "error", err) + return + } + req.URL.Host = req.Host + req.URL.Scheme = "http" + req.Header.Set("Connection", "close") + if err := req.WriteProxy(proxyConn); err != nil { + logger.ErrorContext(ctx, "failed to send HTTP request to proxy", "error", err) + return + } + resp, err := http.ReadResponse(bufio.NewReader(proxyConn), req) + if err != nil { + logger.ErrorContext(ctx, "failed to read HTTP response from proxy", "error", err) + return + } + resp.Header.Set("Connection", "close") + if err := resp.Write(conn); err != nil { + logger.ErrorContext(ctx, "failed to send HTTP response to connection", "error", err) + return + } +} + +// passthrough forwards the connection to the original destination. +func (f *Forwarder) passthrough(ctx context.Context, conn *ConsignedConn) { + out, err := f.dial(conn.OriginalDst.String()) + if err != nil { + logger.ErrorContext(ctx, "failed to dial original src address for passthrough connection", "error", err) + return + } + logger.InfoContext(ctx, "passthrough connection") + f.relayTCP(ctx, conn, out) +} + +// proxyHTTP forwards the connection to an upstream HTTP proxy. +func (f *Forwarder) proxyHTTP(ctx context.Context, conn *ConsignedConn) { + out, err := f.dial(f.proxyAddr(f.httpProxy)) + if err != nil { + logger.ErrorContext(ctx, "failed to dial http proxy", "error", err) + return + } + logger.InfoContext(ctx, "relay HTTP connection to proxy", "http_proxy", f.httpProxy) + f.relayHTTP(ctx, conn, out) +} + +// proxyHTTPS forwards the connection to an upstream HTTPS proxy. +func (f *Forwarder) proxyHTTPS(ctx context.Context, conn *ConsignedConn) { + out, err := f.proxyConnect(conn.Host) + if err != nil { + logger.ErrorContext(ctx, "failed to connect to https proxy", "error", err) + return + } + logger.InfoContext(ctx, "relay TLS connection to proxy", "https_proxy", f.httpsProxy) + f.relayTCP(ctx, conn, out) +} + +// ForwardHTTP forwards the given HTTP connection to upstream http proxy or passthrough to original destination +// base on the configuration. It's the duty of the caller to close the input connection. +func (f *Forwarder) ForwardHTTP(ctx context.Context, conn *ConsignedConn) { + if f.httpProxy == "" { + f.passthrough(ctx, conn) + } else { + f.proxyHTTP(ctx, conn) + } +} + +// ForwardHTTPS forwards the given HTTPS/TLS connection to upstream https proxy or passthrough to original destination +// base on the configuration. It's the duty of the caller to close the input connection. +func (f *Forwarder) ForwardHTTPS(ctx context.Context, conn *ConsignedConn) { + if f.httpsProxy == "" { + f.passthrough(ctx, conn) + } else { + f.proxyHTTPS(ctx, conn) + } +} diff --git a/forwarder_test.go b/forwarder_test.go new file mode 100644 index 0000000..f133451 --- /dev/null +++ b/forwarder_test.go @@ -0,0 +1,148 @@ +package main + +import ( + "context" + "io" + "net" + "sync" + "testing" +) + +func TestVerifyProxyUrl(t *testing.T) { + tests := []struct { + name string + proxyUrl string + wantErr bool + }{ + {"host and port", "http://example.com:123", false}, + {"ip and port", "http://10.30.74.14:8888", false}, + // surprisingly this is okay, at least for curl + {"with path", "http://example.com:1234/test", false}, + {"no port", "http://example.com", true}, + {"no protocol", "example.com:1234", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := verifyProxyUrl(tt.proxyUrl) + if (err != nil) != tt.wantErr { + t.Errorf("parseProxyUrl() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func TestForwarderForwardHTTP(t *testing.T) { + egressIn, egressOut := net.Pipe() + ingressIn, ingressOut := net.Pipe() + f := Forwarder{ + httpProxy: "http://http-proxy:1234", + dialFunc: func(f *Forwarder, addr string) (net.Conn, error) { + if addr != "http-proxy:1234" { + panic(addr) + } + return egressIn, nil + }, + } + wg := sync.WaitGroup{} + defer func() { + _ = egressIn.Close() + _ = egressOut.Close() + _ = ingressIn.Close() + _ = ingressOut.Close() + wg.Wait() + }() + wg.Add(1) + go func() { + f.ForwardHTTP(context.Background(), &ConsignedConn{ + PrereadConn: NewPrereadConn(ingressOut), + OriginalDst: &net.TCPAddr{}, + Host: "example.com", + }) + wg.Done() + }() + wg.Add(1) + go func() { + _, _ = ingressIn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: curl/8.4.0\r\n\r\n")) + wg.Done() + }() + buf := make([]byte, 1000) + n, _ := egressOut.Read(buf) + expected := "GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\nUser-Agent: curl/8.4.0\r\nConnection: close\r\n\r\n" + got := string(buf[:n]) + if expected != got { + t.Fatalf("expected HTTP request sent by aproxy %#v, got %#v", expected, got) + } + wg.Add(1) + go func() { + _, _ = egressOut.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")) + wg.Done() + }() + expected = "HTTP/1.1 200 OK\r\nConnection: close\r\nContent-Length: 0\r\n\r\n" + n, _ = io.ReadAtLeast(ingressIn, buf, len(expected)) + got = string(buf[:n]) + if expected != got { + t.Fatalf("expected HTTP response sent by aproxy %#v, got %#v", expected, got) + } +} + +func TestForwarderForwardHTTPS(t *testing.T) { + egressIn, egressOut := net.Pipe() + ingressIn, ingressOut := net.Pipe() + f := Forwarder{ + httpsProxy: "http://http-proxy:1234", + dialFunc: func(f *Forwarder, addr string) (net.Conn, error) { + if addr != "http-proxy:1234" { + panic(addr) + } + return egressIn, nil + }, + } + wg := sync.WaitGroup{} + defer func() { + _ = egressIn.Close() + _ = egressOut.Close() + _ = ingressIn.Close() + _ = ingressOut.Close() + wg.Wait() + }() + wg.Add(1) + go func() { + f.ForwardHTTPS(context.Background(), &ConsignedConn{ + PrereadConn: NewPrereadConn(ingressOut), + OriginalDst: &net.TCPAddr{}, + Host: "example.com", + }) + wg.Done() + }() + expected := "CONNECT example.com HTTP/1.1\r\nHost: example.com\r\nUser-Agent: aproxy/1.0.0\r\n\r\n" + buf := make([]byte, 1000) + n, _ := egressOut.Read(buf) + got := string(buf[:n]) + if expected != got { + t.Fatalf("expected HTTP CONNECT request sent by aproxy %#v, got %#v", expected, got) + } + + wg.Add(1) + go func() { + _, _ = ingressIn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\nUser-Agent: curl/8.4.0\r\n\r\n")) + wg.Done() + }() + n, _ = egressOut.Read(buf) + expected = "GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\nUser-Agent: curl/8.4.0\r\nConnection: close\r\n\r\n" + got = string(buf[:n]) + if expected != got { + t.Fatalf("expected HTTP request sent by aproxy %#v, got %#v", expected, got) + } + wg.Add(1) + go func() { + _, _ = egressOut.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")) + wg.Done() + }() + expected = "HTTP/1.1 200 OK\r\nConnection: close\r\nContent-Length: 0\r\n\r\n" + n, _ = io.ReadAtLeast(ingressIn, buf, len(expected)) + got = string(buf[:n]) + if expected != got { + t.Fatalf("expected HTTP response sent by aproxy %#v, got %#v", expected, got) + } +} diff --git a/go.mod b/go.mod index 42bd3bd..4c01d82 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module aproxy go 1.21 -require golang.org/x/crypto v0.17.0 +require golang.org/x/crypto v0.19.0 diff --git a/go.sum b/go.sum index c09097a..b3dc059 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= diff --git a/logging.go b/logging.go new file mode 100644 index 0000000..d6a9955 --- /dev/null +++ b/logging.go @@ -0,0 +1,41 @@ +package main + +import ( + "context" + "log/slog" +) + +type connContextKey string + +var ( + connContextConsignedConn connContextKey = "consigned_conn" +) + +type aproxyHandler struct { + slog.Handler +} + +func ConsignedConnFromContext(ctx context.Context) (*ConsignedConn, bool) { + conn, ok := ctx.Value(connContextConsignedConn).(*ConsignedConn) + return conn, ok +} + +func ContextWithConsignedConn(ctx context.Context, conn *ConsignedConn) context.Context { + return context.WithValue(ctx, connContextConsignedConn, conn) +} + +func (h *aproxyHandler) Handle(ctx context.Context, r slog.Record) error { + conn, ok := ConsignedConnFromContext(ctx) + if !ok { + return h.Handler.Handle(ctx, r) + } + if conn.OriginalDst != nil { + r.Add("original_dst", conn.OriginalDst) + } + if conn.Host != "" { + r.Add("host", conn.Host) + } + return h.Handler.Handle(ctx, r) +} + +var logger = slog.New(&aproxyHandler{Handler: slog.Default().Handler()}) diff --git a/relay.go b/relay.go new file mode 100644 index 0000000..e6dd718 --- /dev/null +++ b/relay.go @@ -0,0 +1,54 @@ +package main + +import ( + "errors" + "io" + "net" + "time" +) + +type tcpForwarder struct { + Fwmark uint32 + ReadTimeout time.Duration + WriteTimeout time.Duration +} + +func (f *tcpForwarder) copyBuffer(dst net.Conn, src net.Conn) (written int64, err error) { + buf := make([]byte, 32*1024) + for { + err = src.SetReadDeadline(time.Now().Add(f.ReadTimeout)) + if err != nil { + break + } + nr, er := src.Read(buf) + if nr > 0 { + err = src.SetWriteDeadline(time.Now().Add(f.ReadTimeout)) + if err != nil { + break + } + nw, ew := dst.Write(buf[0:nr]) + if nw < 0 || nr < nw { + nw = 0 + if ew == nil { + ew = errors.New("invalid write result") + } + } + written += int64(nw) + if ew != nil { + err = ew + break + } + if nr != nw { + err = io.ErrShortWrite + break + } + } + if er != nil { + if er != io.EOF { + err = er + } + break + } + } + return written, err +} diff --git a/snap/hooks/configure b/snap/hooks/configure index 9ff0bc8..3e5e807 100755 --- a/snap/hooks/configure +++ b/snap/hooks/configure @@ -1,40 +1,22 @@ #!/bin/bash set -e +[ -z "$(snapctl get fwmark)" ] && snapctl set fwmark="0" [ -z "$(snapctl get listen)" ] && snapctl set listen=":8443" -validate_proxy() { - local hostport="$1" - local host - local port - - host="${hostport%:*}" - port="${hostport#*:}" - - if [[ ! "$host" =~ ^[a-zA-Z0-9.-]+$ ]]; then - echo "invalid proxy: '$hostport'" - return 1 - fi - - if ! [[ "$port" =~ ^[0-9]+$ ]] || (( port <= 0 || port > 65535 )); then - echo "invalid proxy: '$hostport'" - return 1 - fi - - return 0 -} - -proxy="$(snapctl get proxy)" +fwmark="$(snapctl get fwmark)" +http_proxy="$(snapctl get http.proxy)" +https_proxy="$(snapctl get https.proxy)" listen="$(snapctl get listen)" +proxy="$(snapctl get proxy)" -if [ -z "${proxy}" ]; then - echo "set upstream proxy using \`snap set aproxy proxy=example:1234\`" - exit 0 -fi +[ -n "${proxy}" ] && echo "proxy configuration is deprecated, use http.proxy and https.proxy instead" 1>&2 -validate_proxy "$proxy" +# for backward compatability +[ -z "${http_proxy}" ] && [ -n "${proxy}" ] && http_proxy="http://${proxy}" +[ -z "${https_proxy}" ] && [ -n "${proxy}" ] && https_proxy="http://${proxy}" -echo "--proxy $proxy --listen $listen" > $SNAP_DATA/args +echo "--http-proxy='${http_proxy}' --https-proxy='${https_proxy}' --listen='${listen}' --fwmark='${fwmark}'" > $SNAP_DATA/args snapctl stop ${SNAP_NAME}.aproxy snapctl start ${SNAP_NAME}.aproxy --enable diff --git a/snap/snapcraft.yaml b/snap/snapcraft.yaml index 6a5e9cb..1761253 100644 --- a/snap/snapcraft.yaml +++ b/snap/snapcraft.yaml @@ -1,5 +1,5 @@ name: aproxy -version: 0.2.2 +version: 0.3.0 summary: Transparent proxy for HTTP and HTTPS/TLS connections. description: | Aproxy is a transparent proxy for HTTP and HTTPS/TLS connections. By @@ -26,6 +26,7 @@ apps: plugs: - network - network-bind + - network-control parts: aproxy: diff --git a/stream.go b/stream.go new file mode 100644 index 0000000..1508b9e --- /dev/null +++ b/stream.go @@ -0,0 +1,306 @@ +package main + +import ( + "bufio" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/http" + "strconv" + "strings" + "sync" + "syscall" + + "golang.org/x/crypto/cryptobyte" +) + +// Stream represents the incoming connections to aproxy. +type Stream interface { + Host() string + Src() *net.TCPAddr + Dst() *net.TCPAddr + OriginalDst() *net.TCPAddr + io.ReadWriteCloser +} + +type ConnInfo struct { + src *net.TCPAddr + dst *net.TCPAddr + originalDst *net.TCPAddr +} + +func (i *ConnInfo) Src() *net.TCPAddr { + return i.src +} + +func (i *ConnInfo) Dst() *net.TCPAddr { + return i.dst +} + +func (i *ConnInfo) OriginalDst() *net.TCPAddr { + return i.originalDst +} + +// GetConnInfo retrieve information from the TCP connection. +func GetConnInfo(conn *net.TCPConn) (info *ConnInfo, err error) { + originalDst, err := GetSocketIPv4OriginalDst(conn) + var errno syscall.Errno + // errno 92: connection didn't go through NAT on this machine + if err != nil && !errors.As(err, &errno) && errno != 92 { + return nil, fmt.Errorf("getsockopt SO_ORIGINAL_DST failed: %s", err) + } + return &ConnInfo{ + src: conn.RemoteAddr().(*net.TCPAddr), + dst: conn.LocalAddr().(*net.TCPAddr), + originalDst: originalDst, + }, nil +} + +// PrereadConn is a wrapper around net.Conn that supports pre-reading from the underlying connection. +// Any Read before the EndPreread can be undone and read again by calling the EndPreread function. +type PrereadConn struct { + ended bool + buf []byte + mu sync.Mutex + conn net.Conn +} + +// EndPreread ends the pre-reading phase. Any Read before will be undone and data in the stream can be read again. +// EndPreread can be only called once. +func (c *PrereadConn) EndPreread() { + c.mu.Lock() + defer c.mu.Unlock() + if c.ended { + panic("call EndPreread after preread has ended or hasn't started") + } + c.ended = true +} + +// Read reads from the underlying connection. Read during the pre-reading phase can be undone by EndPreread. +func (c *PrereadConn) Read(p []byte) (n int, err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.ended { + n = copy(p, c.buf) + bufLen := len(c.buf) + c.buf = c.buf[n:] + if n == len(p) || (bufLen > 0 && bufLen == n) { + return n, nil + } + rn, err := c.conn.Read(p[n:]) + return rn + n, err + } else { + n, err = c.conn.Read(p) + c.buf = append(c.buf, p[:n]...) + return n, err + } +} + +// Write writes data to the underlying connection. +func (c *PrereadConn) Write(p []byte) (n int, err error) { + return c.conn.Write(p) +} + +// Close closes the underlying connection. +func (c *PrereadConn) Close() error { + return c.conn.Close() +} + +// NewPrereadConn wraps the network connection and return a *PrereadConn. +// It's recommended to not touch the original connection after wrapped. +func NewPrereadConn(conn net.Conn) *PrereadConn { + return &PrereadConn{conn: conn} +} + +// addPort adds the port from connection info to host if host doesn't have one +func addPort(host string, info *ConnInfo) (string, error) { + _, _, err := net.SplitHostPort(host) + if err != nil { + if strings.Contains(err.Error(), "missing port in address") { + if info.OriginalDst() != nil { + return net.JoinHostPort(host, strconv.Itoa(info.OriginalDst().Port)), nil + } + + return net.JoinHostPort(host, strconv.Itoa(info.Dst().Port)), nil + } + return "", err + } + return host, nil +} + +type HttpStream struct { + *PrereadConn + host string + *ConnInfo +} + +func (s *HttpStream) Host() string { + return s.host +} + +func NewHttpStream(conn net.Conn, info *ConnInfo) (s *HttpStream, err error) { + preread := NewPrereadConn(conn) + defer func() { + if err != nil { + err = fmt.Errorf("failed to preread HTTP request: %w", err) + } + }() + defer preread.EndPreread() + req, err := http.ReadRequest(bufio.NewReader(preread)) + if err != nil { + return nil, err + } + host := req.Host + if host != "" { + host, err = addPort(host, info) + if err != nil { + return nil, fmt.Errorf("failed to parse HTTP Host %#v: %w", host, err) + } + } + return &HttpStream{PrereadConn: preread, host: host, ConnInfo: info}, nil +} + +// PrereadSNI pre-reads the Server Name Indication (SNI) from a TLS connection. +func PrereadSNI(conn *PrereadConn) (_ string, err error) { + defer conn.EndPreread() + defer func() { + if err != nil { + err = fmt.Errorf("failed to preread TLS client hello: %w", err) + } + }() + typeVersionLen := make([]byte, 5) + n, err := conn.Read(typeVersionLen) + if n != 5 { + return "", errors.New("too short") + } + if err != nil { + return "", err + } + if typeVersionLen[0] != 22 { + return "", errors.New("not a TCP handshake") + } + msgLen := binary.BigEndian.Uint16(typeVersionLen[3:]) + buf := make([]byte, msgLen+5) + n, err = conn.Read(buf[5:]) + if n != int(msgLen) { + return "", errors.New("too short") + } + if err != nil { + return "", err + } + copy(buf[:5], typeVersionLen) + return extractSNI(buf) +} + +func extractSNI(data []byte) (string, error) { + s := cryptobyte.String(data) + var ( + version uint16 + random []byte + sessionId []byte + ) + + if !s.Skip(9) || + !s.ReadUint16(&version) || !s.ReadBytes(&random, 32) || + !s.ReadUint8LengthPrefixed((*cryptobyte.String)(&sessionId)) { + return "", fmt.Errorf("failed to parse TLS client hello version, random or session id") + } + + var cipherSuitesData cryptobyte.String + if !s.ReadUint16LengthPrefixed(&cipherSuitesData) { + return "", fmt.Errorf("failed to parse TLS client hello cipher suites") + } + + var cipherSuites []uint16 + for !cipherSuitesData.Empty() { + var suite uint16 + if !cipherSuitesData.ReadUint16(&suite) { + return "", fmt.Errorf("failed to parse TLS client hello cipher suites") + } + cipherSuites = append(cipherSuites, suite) + } + + var compressionMethods []byte + if !s.ReadUint8LengthPrefixed((*cryptobyte.String)(&compressionMethods)) { + return "", fmt.Errorf("failed to parse TLS client hello compression methods") + } + + if s.Empty() { + // ClientHello is optionally followed by extension data + return "", fmt.Errorf("no extension data in TLS client hello") + } + + var extensions cryptobyte.String + if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return "", fmt.Errorf("failed to parse TLS client hello extensions") + } + + finalServerName := "" + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return "", fmt.Errorf("failed to parse TLS client hello extension") + } + if extension != 0 { + continue + } + var nameList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() { + return "", fmt.Errorf("failed to parse server name extension") + } + + for !nameList.Empty() { + var nameType uint8 + var serverName cryptobyte.String + if !nameList.ReadUint8(&nameType) || + !nameList.ReadUint16LengthPrefixed(&serverName) || + serverName.Empty() { + return "", fmt.Errorf("failed to parse server name indication extension") + } + if nameType != 0 { + continue + } + if len(finalServerName) != 0 { + return "", fmt.Errorf("multiple names of the same name_type are prohibited in server name extension") + } + finalServerName = string(serverName) + if strings.HasSuffix(finalServerName, ".") { + return "", fmt.Errorf("SNI name ends with a trailing dot") + } + } + } + return finalServerName, nil +} + +type TlsStream struct { + *PrereadConn + host string + *ConnInfo +} + +func (s *TlsStream) Host() string { + return s.host +} + +func NewTlsStream(conn net.Conn, info *ConnInfo) (*TlsStream, error) { + preread := NewPrereadConn(conn) + sni, err := PrereadSNI(preread) + if err != nil { + return nil, err + } + if sni != "" { + sni, err = addPort(sni, info) + if err != nil { + return nil, fmt.Errorf("failed to parse SNI %#v as host: %w", sni, err) + } + } + return &TlsStream{ + PrereadConn: preread, + host: sni, + ConnInfo: info, + }, nil +} diff --git a/stream_test.go b/stream_test.go new file mode 100644 index 0000000..005bdc0 --- /dev/null +++ b/stream_test.go @@ -0,0 +1,146 @@ +package main + +import ( + "bytes" + "encoding/hex" + "io" + "net" + "testing" +) + +func TestPrereadConn(t *testing.T) { + remote, local := net.Pipe() + go remote.Write([]byte("hello, world")) + preread := &PrereadConn{conn: local} + buf := make([]byte, 5) + _, err := preread.Read(buf) + if err != nil { + t.Fatalf("Read failed during preread: %s", err) + } + buf = make([]byte, 3) + _, err = preread.Read(buf) + if err != nil { + t.Fatalf("Read failed during preread: %s", err) + } + preread.EndPreread() + buf2 := make([]byte, 12) + _, err = io.ReadFull(preread, buf2) + if err != nil { + t.Fatalf("Read failed after preread: %s", err) + } + if string(buf2) != "hello, world" { + t.Fatalf("preread altered the read state: got %s", string(buf2)) + } +} + +func TestNewHttpStream(t *testing.T) { + remote, local := net.Pipe() + payload := []byte("GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 0\r\n\r\n") + go remote.Write(payload) + s, err := NewHttpStream(local, &ConnInfo{ + src: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8443}, + dst: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}, + originalDst: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 80}, + }) + if err != nil { + t.Fatalf("NewHttpStream failed: %s", err) + } + if s.Host() != "example.com:80" { + t.Fatalf("incorrect host in HttpStream, expect: \"example.com:80\", got: %#v", s.Host()) + } + buf := make([]byte, len(payload)) + _, err = io.ReadFull(s, buf) + if err != nil { + t.Fatalf("HttpStream.Read failed: %s", err) + } + if !bytes.Equal(payload, buf) { + t.Fatalf("HttpStream.Read failed, expect: %#v, got: %#v", string(payload), string(buf)) + } +} + +func TestNewHttpStreamNonDefaultPort(t *testing.T) { + remote, local := net.Pipe() + payload := []byte("GET / HTTP/1.1\r\nHost: example.com:8080\r\nContent-Length: 0\r\n\r\n") + go remote.Write(payload) + s, err := NewHttpStream(local, &ConnInfo{ + src: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8443}, + dst: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}, + originalDst: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, + }) + if err != nil { + t.Fatalf("NewHttpStream failed: %s", err) + } + if s.Host() != "example.com:8080" { + t.Fatalf("incorrect host in HttpStream, expect: \"example.com:8080\", got: %#v", s.Host()) + } +} + +func TestNewTlsStream(t *testing.T) { + remote, local := net.Pipe() + // data obtained from https://gitlab.com/wireshark/wireshark/-/blob/master/test/captures/tls12-aes256gcm.pcap + clientHello, _ := hex.DecodeString( + "160301004f0100004b0303588e60d1d96bad5f1fcf0b8818466257d73385bdaaed0ac4bfd7228a6da059ad00000200a901000020" + + "0005000501000000000000000e000c0000096c6f63616c686f7374ff01000100") + go remote.Write(clientHello) + s, err := NewTlsStream(local, &ConnInfo{ + src: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8443}, + dst: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}, + originalDst: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 443}, + }) + if err != nil { + t.Fatalf("NewTlsStream failed: %s", err) + } + if s.Host() != "localhost:443" { + t.Fatalf("incorrect host in TlsStream, expect: \"localhost:443\", got: %#v", s.Host()) + } + buf := make([]byte, len(clientHello)) + _, err = io.ReadFull(s, buf) + if err != nil { + t.Fatalf("TlsStream.Read failed: %s", err) + } + if !bytes.Equal(clientHello, buf) { + t.Fatalf("TlsStream.Read failed, expect: %#v, got: %#v", string(clientHello), string(buf)) + } +} + +func TestNewTlsStreamWithoutSNI(t *testing.T) { + remote, local := net.Pipe() + clientHello, _ := hex.DecodeString("160301012801000124030315a03a6cbea1ff32d0fb9af5d6d94988e212b6bcf15a3e672ed" + + "7d31f6d946edd20f8879d969a75d1da26560c92a942f13458a0cd2a96e690c0fa628ff6357119de0062130313021301cca9cca8ccaa" + + "c030c02cc028c024c014c00a009f006b0039ff8500c400880081009d003d003500c00084c02fc02bc027c023c013c009009e0067003" + + "300be0045009c003c002f00ba0041c011c00700050004c012c0080016000a00ff01000079002b000908030403030302030100330026" + + "0024001d00203754ae4e94f3a5fb69709af119b982db1322c5da9299f7ce0da661a05f06ce35000b00020100000a000a0008001d001" + + "700180019000d00180016080606010603080505010503080404010403020102030010000e000c02683208687474702f312e31") + go remote.Write(clientHello) + s, err := NewTlsStream(local, &ConnInfo{ + src: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8443}, + dst: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}, + originalDst: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 443}, + }) + if err != nil { + t.Fatalf("NewTlsStream failed: %s", err) + } + if s.Host() != "" { + t.Fatalf("incorrect host in TlsStream, expect: \"\", got: %#v", s.Host()) + } +} + +func TestNewTlsStreamNonDefaultPort(t *testing.T) { + remote, local := net.Pipe() + // data obtained from https://gitlab.com/wireshark/wireshark/-/blob/master/test/captures/tls12-aes256gcm.pcap + clientHello, _ := hex.DecodeString( + "160301004f0100004b0303588e60d1d96bad5f1fcf0b8818466257d73385bdaaed0ac4bfd7228a6da059ad00000200a901000020" + + "0005000501000000000000000e000c0000096c6f63616c686f7374ff01000100") + go remote.Write(clientHello) + s, err := NewTlsStream(local, &ConnInfo{ + src: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8443}, + dst: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}, + originalDst: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1443}, + }) + if err != nil { + t.Fatalf("NewTlsStream failed: %s", err) + } + if s.Host() != "localhost:1443" { + t.Fatalf("incorrect host in TlsStream, expect: \"localhost:1443\", got: %#v", s.Host()) + } +} diff --git a/syscall_linux.go b/syscall_linux.go index 1b15e4a..6f6649f 100644 --- a/syscall_linux.go +++ b/syscall_linux.go @@ -10,20 +10,26 @@ import ( "unsafe" ) -func GetsockoptIPv4OriginalDst(fd, level, opt int) (*net.TCPAddr, error) { +// GetSocketIPv4OriginalDst get the original destination address of a TCP connection before dstnat. +func GetSocketIPv4OriginalDst(conn *net.TCPConn) (*net.TCPAddr, error) { + file, err := conn.File() + defer file.Close() + if err != nil { + return nil, fmt.Errorf("failed to get file decriptor of given TCP connection: %w", err) + } var sockaddr [16]byte size := 16 _, _, e := syscall.Syscall6( syscall.SYS_GETSOCKOPT, - uintptr(fd), - uintptr(level), - uintptr(opt), + file.Fd(), + syscall.SOL_IP, + 80, // SO_ORIGINAL_DST uintptr(unsafe.Pointer(&sockaddr)), uintptr(unsafe.Pointer(&size)), 0, ) if e != 0 { - return nil, fmt.Errorf("getsockopt SO_ORIGINAL_DST failed: errno %d", e) + return nil, e } return &net.TCPAddr{ IP: sockaddr[4:8],