diff --git a/websocketproxy.go b/websocketproxy.go index 63d39ba..bd50e08 100644 --- a/websocketproxy.go +++ b/websocketproxy.go @@ -16,15 +16,51 @@ import ( var ( // DefaultUpgrader specifies the parameters for upgrading an HTTP // connection to a WebSocket connection. - DefaultUpgrader = &websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, + DefaultUpgrader = &WSProxyUpgrader{ + websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + }, } // DefaultDialer is a dialer with all fields set to the default zero values. DefaultDialer = websocket.DefaultDialer ) +// IUpgrader interface define the Upgrade method which is different from +// websocket.Upgrader's by returning interface IConn instead of Conn.IUpgrader +// This allows users of the lib to see traffic and close the websocket +type IUpgrader interface { + Upgrade(http.ResponseWriter, *http.Request, http.Header) (IConn, error) +} + +// WSProxyUpgrader is used a default upgrader which wraps the websocket's Upgrader +type WSProxyUpgrader struct { + websocket.Upgrader +} + +// Upgrade is called when the proxy upgrades an http connection to a websocket +func (wsu *WSProxyUpgrader) Upgrade(w http.ResponseWriter, r *http.Request, requestHeaders http.Header) (IConn, error) { + c, e := wsu.Upgrader.Upgrade(w, r, requestHeaders) + return c, e +} + +// IConn has all function in use by the websocket traffic back and forth +// It is returned from the Upgrade method of the IUpgrade interface +// when using the webscoketproxy, one can provide implementation of IUpgrader, +// That returns own implementation of IConn and the observe or modify it. +// and also can get notified when the proxy closes the web socket +type IConn interface { + // ReadMessage is called when reading message sent from the client + ReadMessage() (int, []byte, error) + + // WriteMessage is called when writing message to the client + WriteMessage(int, []byte) error + + // Close is called when calling close of the connection + Close() error +} + // WebsocketProxy is an HTTP Handler that takes an incoming WebSocket // connection and proxies it to another server. type WebsocketProxy struct { @@ -33,6 +69,8 @@ type WebsocketProxy struct { // which will be forwarded to another server. Director func(incoming *http.Request, out http.Header) + Rewriter func(msg []byte) []byte + // Backend returns the backend URL which the proxy uses to reverse proxy // the incoming WebSocket connection. Request is the initial incoming and // unmodified request. @@ -40,7 +78,7 @@ type WebsocketProxy struct { // Upgrader specifies the parameters for upgrading a incoming HTTP // connection to a WebSocket connection. If nil, DefaultUpgrader is used. - Upgrader *websocket.Upgrader + Upgrader IUpgrader // Dialer contains options for connecting to the backend WebSocket server. // If nil, DefaultDialer is used. @@ -177,7 +215,7 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { errClient := make(chan error, 1) errBackend := make(chan error, 1) - replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) { + replicateWebsocketConn := func(dst, src IConn, errc chan error) { for { msgType, msg, err := src.ReadMessage() if err != nil { diff --git a/websocketproxy_test.go b/websocketproxy_test.go index b90e02b..8927ddc 100644 --- a/websocketproxy_test.go +++ b/websocketproxy_test.go @@ -18,15 +18,16 @@ var ( func TestProxy(t *testing.T) { // websocket proxy supportedSubProtocols := []string{"test-protocol"} - upgrader := &websocket.Upgrader{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, - CheckOrigin: func(r *http.Request) bool { - return true + upgrader := &WSProxyUpgrader{ + websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + CheckOrigin: func(r *http.Request) bool { + return true + }, + Subprotocols: supportedSubProtocols, }, - Subprotocols: supportedSubProtocols, } - u, _ := url.Parse(backendURL) proxy := NewProxy(u) proxy.Upgrader = upgrader @@ -46,7 +47,7 @@ func TestProxy(t *testing.T) { mux2 := http.NewServeMux() mux2.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { // Don't upgrade if original host header isn't preserved - if r.Host != "127.0.0.1:7777" { + if r.Host != "127.0.0.1:7777" { log.Printf("Host header set incorrectly. Expecting 127.0.0.1:7777 got %s", r.Host) return }