Skip to content

Commit

Permalink
Fix WS reading X-Forwarded-For & Add tests (#3546)
Browse files Browse the repository at this point in the history
Fixes #3545

---------

Co-authored-by: mmmray <[email protected]>
  • Loading branch information
Fangliding and mmmray authored Jul 17, 2024
1 parent 9e6d7a3 commit a7e198e
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 11 deletions.
4 changes: 2 additions & 2 deletions transport/internet/httpupgrade/httpupgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
return
}

_, err = c.Write([]byte("Response"))
_, err = c.Write([]byte(c.RemoteAddr().String()))
common.Must(err)
}(conn)
})
Expand All @@ -169,7 +169,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
var b [1024]byte
n, err := conn.Read(b[:])
common.Must(err)
if string(b[:n]) != "Response" {
if string(b[:n]) != "1.1.1.1:0" {
t.Error("response: ", string(b[:n]))
}

Expand Down
4 changes: 2 additions & 2 deletions transport/internet/splithttp/splithttp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
return
}

_, err = c.Write([]byte("Response"))
_, err = c.Write([]byte(c.RemoteAddr().String()))
common.Must(err)
}(conn)
})
Expand All @@ -113,7 +113,7 @@ func TestDialWithRemoteAddr(t *testing.T) {

var b [1024]byte
n, _ := conn.Read(b[:])
if string(b[:n]) != "Response" {
if string(b[:n]) != "1.1.1.1:0" {
t.Error("response: ", string(b[:n]))
}

Expand Down
14 changes: 9 additions & 5 deletions transport/internet/websocket/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@ import (
var _ buf.Writer = (*connection)(nil)

// connection is a wrapper for net.Conn over WebSocket connection.
// remoteAddr is used to pass "virtual" remote IP addresses in X-Forwarded-For.
// so we shouldn't directly read it form conn.
type connection struct {
conn *websocket.Conn
reader io.Reader
conn *websocket.Conn
reader io.Reader
remoteAddr net.Addr
}

func NewConnection(conn *websocket.Conn, remoteAddr net.Addr, extraReader io.Reader) *connection {
return &connection{
conn: conn,
reader: extraReader,
conn: conn,
remoteAddr: remoteAddr,
reader: extraReader,
}
}

Expand Down Expand Up @@ -90,7 +94,7 @@ func (c *connection) LocalAddr() net.Addr {
}

func (c *connection) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
return c.remoteAddr
}

func (c *connection) SetDeadline(t time.Time) error {
Expand Down
4 changes: 2 additions & 2 deletions transport/internet/websocket/ws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
return
}

_, err = c.Write([]byte("Response"))
_, err = c.Write([]byte(c.RemoteAddr().String()))
common.Must(err)
}(conn)
})
Expand All @@ -109,7 +109,7 @@ func TestDialWithRemoteAddr(t *testing.T) {
var b [1024]byte
n, err := conn.Read(b[:])
common.Must(err)
if string(b[:n]) != "Response" {
if string(b[:n]) != "1.1.1.1:0" {
t.Error("response: ", string(b[:n]))
}

Expand Down

0 comments on commit a7e198e

Please sign in to comment.