From 8b0854ad27efad28732d1b8acdd4962f44d23a7f Mon Sep 17 00:00:00 2001 From: Lorain Date: Fri, 19 Jul 2024 18:00:16 +0800 Subject: [PATCH 1/4] feat(ws_reverse_proxy): support dynamic route --- ws_reverse_proxy.go | 9 ++- ws_reverse_proxy_option.go | 16 +++- ws_reverse_proxy_option_test.go | 3 + ws_reverse_proxy_test.go | 133 ++++++++++++++++++++++++++++++++ 4 files changed, 155 insertions(+), 6 deletions(-) diff --git a/ws_reverse_proxy.go b/ws_reverse_proxy.go index 22faffe..31aa6ee 100644 --- a/ws_reverse_proxy.go +++ b/ws_reverse_proxy.go @@ -47,7 +47,6 @@ import ( "net/http" "github.com/bytedance/gopkg/util/gopool" - "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/protocol" @@ -81,9 +80,13 @@ func (w *WSReverseProxy) ServeHTTP(ctx context.Context, c *app.RequestContext) { if w.options.Director != nil { w.options.Director(ctx, c, forwardHeader) } - connBackend, respBackend, err := w.options.Dialer.Dial(w.target, forwardHeader) + target := w.target + if w.options.DynamicRoute { + target = w.target + b2s(c.Path()) + } + connBackend, respBackend, err := w.options.Dialer.Dial(target, forwardHeader) if err != nil { - hlog.CtxErrorf(ctx, "can not dial to remote backend(%v): %v", w.target, err) + hlog.CtxErrorf(ctx, "can not dial to remote backend(%v): %v", target, err) if respBackend != nil { if err = wsCopyResponse(&c.Response, respBackend); err != nil { hlog.CtxErrorf(ctx, "can not copy response: %v", err) diff --git a/ws_reverse_proxy_option.go b/ws_reverse_proxy_option.go index a24c6da..7b68c02 100644 --- a/ws_reverse_proxy_option.go +++ b/ws_reverse_proxy_option.go @@ -28,9 +28,10 @@ type Director func(ctx context.Context, c *app.RequestContext, forwardHeader htt type Option func(o *Options) type Options struct { - Director Director - Dialer *websocket.Dialer - Upgrader *hzws.HertzUpgrader + Director Director + Dialer *websocket.Dialer + Upgrader *hzws.HertzUpgrader + DynamicRoute bool } var DefaultOptions = &Options{ @@ -40,6 +41,7 @@ var DefaultOptions = &Options{ ReadBufferSize: 1024, WriteBufferSize: 1024, }, + DynamicRoute: false, } func newOptions(opts ...Option) *Options { @@ -79,3 +81,11 @@ func WithUpgrader(upgrader *hzws.HertzUpgrader) Option { o.Upgrader = upgrader } } + +// WithDynamicRoute enable dynamic route +// backend url = handler url + proxy url +func WithDynamicRoute() Option { + return func(o *Options) { + o.DynamicRoute = true + } +} diff --git a/ws_reverse_proxy_option_test.go b/ws_reverse_proxy_option_test.go index bdf65a6..3b06d92 100644 --- a/ws_reverse_proxy_option_test.go +++ b/ws_reverse_proxy_option_test.go @@ -39,10 +39,12 @@ func TestOptions(t *testing.T) { WithDirector(director), WithDialer(dialer), WithUpgrader(upgrader), + WithDynamicRoute(), ) assert.DeepEqual(t, fmt.Sprintf("%p", director), fmt.Sprintf("%p", options.Director)) assert.DeepEqual(t, fmt.Sprintf("%p", dialer), fmt.Sprintf("%p", options.Dialer)) assert.DeepEqual(t, fmt.Sprintf("%p", upgrader), fmt.Sprintf("%p", options.Upgrader)) + assert.DeepEqual(t, true, options.DynamicRoute) } func TestDefaultOptions(t *testing.T) { @@ -50,4 +52,5 @@ func TestDefaultOptions(t *testing.T) { assert.Nil(t, options.Director) assert.DeepEqual(t, DefaultOptions.Dialer, options.Dialer) assert.DeepEqual(t, DefaultOptions.Upgrader, options.Upgrader) + assert.DeepEqual(t, DefaultOptions.DynamicRoute, options.DynamicRoute) } diff --git a/ws_reverse_proxy_test.go b/ws_reverse_proxy_test.go index 83fd506..62dd5e0 100644 --- a/ws_reverse_proxy_test.go +++ b/ws_reverse_proxy_test.go @@ -124,3 +124,136 @@ func TestProxy(t *testing.T) { assert.DeepEqual(t, websocket.TextMessage, msgType) assert.DeepEqual(t, msg, string(data)) } + +var ( + dynamicBackendURL = "ws://127.0.0.1:8888/api" +) + +func TestProxyWithDynamicRoute(t *testing.T) { + // websocket proxy + supportedSubProtocols := []string{"test-protocol"} + upgrader := &hzws.HertzUpgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + CheckOrigin: func(c *app.RequestContext) bool { + return true + }, + Subprotocols: supportedSubProtocols, + } + + // enable dynamic route option + proxy := NewWSReverseProxy(dynamicBackendURL, WithUpgrader(upgrader), WithDynamicRoute()) + + // proxy server + ps := server.Default(server.WithHostPorts(":7777")) + ps.NoHijackConnPool = true + ps.GET("/test", proxy.ServeHTTP) + ps.GET("/test2", proxy.ServeHTTP) + go ps.Spin() + + time.Sleep(time.Millisecond * 100) + + go func() { + // backend server + bs := server.Default() + bs.NoHijackConnPool = true + bs.GET("/api/test", func(ctx context.Context, c *app.RequestContext) { + // Don't upgrade if original host header isn't preserved + host := string(c.Host()) + if host != "127.0.0.1:7777" { + hlog.Errorf("Host header set incorrectly. Expecting 127.0.0.1:7777 got %s", host) + return + } + + if err := upgrader.Upgrade(c, func(conn *hzws.Conn) { + msgType, msg, err := conn.ReadMessage() + assert.Nil(t, err) + + if err = conn.WriteMessage(msgType, msg); err != nil { + return + } + }); err != nil { + hlog.Errorf("upgrade error: %v", err) + return + } + }) + bs.GET("/api/test2", func(ctx context.Context, c *app.RequestContext) { + // Don't upgrade if original host header isn't preserved + host := string(c.Host()) + if host != "127.0.0.1:7777" { + hlog.Errorf("Host header set incorrectly. Expecting 127.0.0.1:7777 got %s", host) + return + } + + if err := upgrader.Upgrade(c, func(conn *hzws.Conn) { + msgType, msg, err := conn.ReadMessage() + assert.Nil(t, err) + + if err = conn.WriteMessage(msgType, msg); err != nil { + return + } + }); err != nil { + hlog.Errorf("upgrade error: %v", err) + return + } + }) + bs.Spin() + }() + + time.Sleep(time.Millisecond * 100) + + // only one is supported by the server + clientSubProtocols := []string{"test-protocol", "test-notsupported"} + h := http.Header{} + for _, subproto := range clientSubProtocols { + h.Add("Sec-WebSocket-Protocol", subproto) + } + + // client + conn, resp, err := websocket.DefaultDialer.Dial("ws://127.0.0.1:7777/test", h) + assert.Nil(t, err) + conn2, resp2, err := websocket.DefaultDialer.Dial("ws://127.0.0.1:7777/test2", h) + assert.Nil(t, err) + + // check if the server really accepted the correct protocol + in := func(desired string) bool { + for _, proto := range resp.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] { + if desired == proto { + return true + } + } + return false + } + in2 := func(desired string) bool { + for _, proto := range resp2.Header[http.CanonicalHeaderKey("Sec-WebSocket-Protocol")] { + if desired == proto { + return true + } + } + return false + } + + assert.True(t, in("test-protocol")) + assert.True(t, in2("test-protocol")) + assert.False(t, in("test-notsupported")) + assert.False(t, in2("test-notsupported")) + + // now write a message and send it to the proxy + msg := "hello world" + err = conn.WriteMessage(websocket.TextMessage, []byte(msg)) + assert.Nil(t, err) + + msg2 := "hello world2" + err = conn2.WriteMessage(websocket.TextMessage, []byte(msg2)) + assert.Nil(t, err) + + msgType, data, err := conn.ReadMessage() + assert.Nil(t, err) + assert.DeepEqual(t, websocket.TextMessage, msgType) + assert.DeepEqual(t, msg, string(data)) + + msgType2, data2, err := conn2.ReadMessage() + assert.Nil(t, err) + assert.DeepEqual(t, websocket.TextMessage, msgType2) + assert.DeepEqual(t, msg2, string(data2)) +} From c92b929007d11997d70f4c0c7c0584755caf6d14 Mon Sep 17 00:00:00 2001 From: Lorain Date: Fri, 19 Jul 2024 18:07:26 +0800 Subject: [PATCH 2/4] update doc --- README.md | 11 ++++++----- ws_reverse_proxy_option.go | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 93a7e5d..9922d21 100644 --- a/README.md +++ b/README.md @@ -111,11 +111,12 @@ func main() { } ``` -| Configuration | Default | Description | -|----------------|---------------------------|------------------------------| -| `WithDirector` | `nil` | customize the forward header | -| `WithDialer` | `gorillaws.DefaultDialer` | for dialer customization | -| `WithUpgrader` | `hzws.HertzUpgrader` | for upgrader customization | +| Configuration | Default | Description | +|--------------------|---------------------------|-------------------------------------------------------------| +| `WithDirector` | `nil` | customize the forward header | +| `WithDialer` | `gorillaws.DefaultDialer` | for dialer customization | +| `WithUpgrader` | `hzws.HertzUpgrader` | for upgrader customization | +| `WithDynamicRoute` | `false` | enable dynamic route (proxy url = handler url + target url) | ### More info See [example](https://github.com/cloudwego/hertz-examples) diff --git a/ws_reverse_proxy_option.go b/ws_reverse_proxy_option.go index 7b68c02..f8aa7da 100644 --- a/ws_reverse_proxy_option.go +++ b/ws_reverse_proxy_option.go @@ -83,7 +83,7 @@ func WithUpgrader(upgrader *hzws.HertzUpgrader) Option { } // WithDynamicRoute enable dynamic route -// backend url = handler url + proxy url +// proxy url = handler url + target url func WithDynamicRoute() Option { return func(o *Options) { o.DynamicRoute = true From d1e198c2ba29139229c1e84e7c53b21b18583166 Mon Sep 17 00:00:00 2001 From: Lorain Date: Fri, 19 Jul 2024 18:16:19 +0800 Subject: [PATCH 3/4] fmt --- ws_reverse_proxy_test.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ws_reverse_proxy_test.go b/ws_reverse_proxy_test.go index 62dd5e0..4408170 100644 --- a/ws_reverse_proxy_test.go +++ b/ws_reverse_proxy_test.go @@ -125,9 +125,7 @@ func TestProxy(t *testing.T) { assert.DeepEqual(t, msg, string(data)) } -var ( - dynamicBackendURL = "ws://127.0.0.1:8888/api" -) +var dynamicBackendURL = "ws://127.0.0.1:8888/api" func TestProxyWithDynamicRoute(t *testing.T) { // websocket proxy From 5ee680920baf4a3b5d9d66ce90cb837c5195754d Mon Sep 17 00:00:00 2001 From: Lorain Date: Mon, 22 Jul 2024 22:06:20 +0800 Subject: [PATCH 4/4] change port --- ws_reverse_proxy_test.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/ws_reverse_proxy_test.go b/ws_reverse_proxy_test.go index 4408170..262f08e 100644 --- a/ws_reverse_proxy_test.go +++ b/ws_reverse_proxy_test.go @@ -125,7 +125,7 @@ func TestProxy(t *testing.T) { assert.DeepEqual(t, msg, string(data)) } -var dynamicBackendURL = "ws://127.0.0.1:8888/api" +var dynamicBackendURL = "ws://127.0.0.1:9001/api" func TestProxyWithDynamicRoute(t *testing.T) { // websocket proxy @@ -143,23 +143,23 @@ func TestProxyWithDynamicRoute(t *testing.T) { proxy := NewWSReverseProxy(dynamicBackendURL, WithUpgrader(upgrader), WithDynamicRoute()) // proxy server - ps := server.Default(server.WithHostPorts(":7777")) + ps := server.Default(server.WithHostPorts(":9000")) ps.NoHijackConnPool = true ps.GET("/test", proxy.ServeHTTP) ps.GET("/test2", proxy.ServeHTTP) go ps.Spin() - time.Sleep(time.Millisecond * 100) + time.Sleep(time.Second * 1) go func() { // backend server - bs := server.Default() + bs := server.Default(server.WithHostPorts(":9001")) bs.NoHijackConnPool = true bs.GET("/api/test", func(ctx context.Context, c *app.RequestContext) { // Don't upgrade if original host header isn't preserved host := string(c.Host()) - if host != "127.0.0.1:7777" { - hlog.Errorf("Host header set incorrectly. Expecting 127.0.0.1:7777 got %s", host) + if host != "127.0.0.1:9000" { + hlog.Errorf("Host header set incorrectly. Expecting 127.0.0.1:9000 got %s", host) return } @@ -178,8 +178,8 @@ func TestProxyWithDynamicRoute(t *testing.T) { bs.GET("/api/test2", func(ctx context.Context, c *app.RequestContext) { // Don't upgrade if original host header isn't preserved host := string(c.Host()) - if host != "127.0.0.1:7777" { - hlog.Errorf("Host header set incorrectly. Expecting 127.0.0.1:7777 got %s", host) + if host != "127.0.0.1:9000" { + hlog.Errorf("Host header set incorrectly. Expecting 127.0.0.1:9000 got %s", host) return } @@ -198,7 +198,7 @@ func TestProxyWithDynamicRoute(t *testing.T) { bs.Spin() }() - time.Sleep(time.Millisecond * 100) + time.Sleep(time.Second * 1) // only one is supported by the server clientSubProtocols := []string{"test-protocol", "test-notsupported"} @@ -208,9 +208,9 @@ func TestProxyWithDynamicRoute(t *testing.T) { } // client - conn, resp, err := websocket.DefaultDialer.Dial("ws://127.0.0.1:7777/test", h) + conn, resp, err := websocket.DefaultDialer.Dial("ws://127.0.0.1:9000/test", h) assert.Nil(t, err) - conn2, resp2, err := websocket.DefaultDialer.Dial("ws://127.0.0.1:7777/test2", h) + conn2, resp2, err := websocket.DefaultDialer.Dial("ws://127.0.0.1:9000/test2", h) assert.Nil(t, err) // check if the server really accepted the correct protocol