diff --git a/pkg/server/server.go b/pkg/server/server.go index e5960a9..1b67581 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -125,7 +125,7 @@ func (s *Server) HandleSubscribe(c echo.Context) error { switch msgType { case websocket.PingMessage: - if err := ws.WriteMessage(websocket.PongMessage, nil); err != nil { + if err := sub.WriteMessage(websocket.PongMessage, nil); err != nil { log.Error("failed to write pong to websocket", "error", err) cancel() return @@ -241,7 +241,7 @@ func (s *Server) HandleSubscribe(c echo.Context) error { // When compression is enabled, the msg is a zstd compressed message if compress { - if err := ws.WriteMessage(websocket.BinaryMessage, *msg); err != nil { + if err := sub.WriteMessage(websocket.BinaryMessage, *msg); err != nil { log.Error("failed to write message to websocket", "error", err) return nil } @@ -249,7 +249,7 @@ func (s *Server) HandleSubscribe(c echo.Context) error { } // Otherwise, the msg is serialized JSON - if err := ws.WriteMessage(websocket.TextMessage, *msg); err != nil { + if err := sub.WriteMessage(websocket.TextMessage, *msg); err != nil { log.Error("failed to write message to websocket", "error", err) return nil } diff --git a/pkg/server/subscriber.go b/pkg/server/subscriber.go index 9fa8620..a7617ba 100644 --- a/pkg/server/subscriber.go +++ b/pkg/server/subscriber.go @@ -21,6 +21,7 @@ type WantedCollections struct { type Subscriber struct { ws *websocket.Conn + conLk sync.Mutex realIP string lk sync.Mutex seq int64 @@ -266,7 +267,14 @@ func (s *Subscriber) UpdateOptions(opts *SubscriberOptions) { // Terminate sends a close message to the subscriber func (s *Subscriber) Terminate(reason string) error { - return s.ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, reason)) + return s.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, reason)) +} + +func (s *Subscriber) WriteMessage(msgType int, data []byte) error { + s.conLk.Lock() + defer s.conLk.Unlock() + + return s.ws.WriteMessage(msgType, data) } func (s *Subscriber) SetCursor(cursor *int64) {