Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom message replicator #30

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 45 additions & 22 deletions websocketproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,13 @@ var (

// DefaultDialer is a dialer with all fields set to the default zero values.
DefaultDialer = websocket.DefaultDialer

// DefaultReplicator is a simple message passthrough
DefaultReplicator = passthroughReplicator
)

type MessageReplicatorFunc func(dst, src *websocket.Conn, errc chan error)

// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket
// connection and proxies it to another server.
type WebsocketProxy struct {
Expand All @@ -33,6 +38,14 @@ type WebsocketProxy struct {
// which will be forwarded to another server.
Director func(incoming *http.Request, out http.Header)

// IncomeReplicator is a function that forward messages incoming from origin
// into the backend. If nil, passthroughsReplicator is used.
IncomeReplicator MessageReplicatorFunc

// BackendReplicator is a function that forwards messages from backend into
// origin. If nil, passthroughsReplicator is used.
BackendReplicator MessageReplicatorFunc

// Backend returns the backend URL which the proxy uses to reverse proxy
// the incoming WebSocket connection. Request is the initial incoming and
// unmodified request.
Expand Down Expand Up @@ -177,30 +190,18 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {

errClient := make(chan error, 1)
errBackend := make(chan error, 1)
replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) {
for {
msgType, msg, err := src.ReadMessage()
if err != nil {
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err))
if e, ok := err.(*websocket.CloseError); ok {
if e.Code != websocket.CloseNoStatusReceived {
m = websocket.FormatCloseMessage(e.Code, e.Text)
}
}
errc <- err
dst.WriteMessage(websocket.CloseMessage, m)
break
}
err = dst.WriteMessage(msgType, msg)
if err != nil {
errc <- err
break
}
}

incomingReplicator := w.IncomeReplicator
if w.IncomeReplicator == nil {
incomingReplicator = DefaultReplicator
}
go incomingReplicator(connPub, connBackend, errClient)

go replicateWebsocketConn(connPub, connBackend, errClient)
go replicateWebsocketConn(connBackend, connPub, errBackend)
backendReplicator := w.BackendReplicator
if w.BackendReplicator == nil {
backendReplicator = DefaultReplicator
}
go backendReplicator(connBackend, connPub, errBackend)

var message string
select {
Expand Down Expand Up @@ -231,3 +232,25 @@ func copyResponse(rw http.ResponseWriter, resp *http.Response) error {
_, err := io.Copy(rw, resp.Body)
return err
}

func passthroughReplicator(dst, src *websocket.Conn, errc chan error) {
for {
msgType, msg, err := src.ReadMessage()
if err != nil {
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err))
if e, ok := err.(*websocket.CloseError); ok {
if e.Code != websocket.CloseNoStatusReceived {
m = websocket.FormatCloseMessage(e.Code, e.Text)
}
}
errc <- err
dst.WriteMessage(websocket.CloseMessage, m)
break
}
err = dst.WriteMessage(msgType, msg)
if err != nil {
errc <- err
break
}
}
}