diff --git a/handlers.go b/handlers.go index 6d0a319..8d0c1cb 100644 --- a/handlers.go +++ b/handlers.go @@ -343,9 +343,6 @@ func (s *Server) handleMessage(ctx context.Context, ws *WebSocket, message []byt } func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - store := s.relay.Storage(ctx) - conn, err := upgrader.Upgrade(w, r, nil) if err != nil { s.Log.Errorf("failed to upgrade websocket: %v", err) @@ -355,7 +352,6 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) { defer s.clientsMu.Unlock() s.clients[conn] = struct{}{} ticker := time.NewTicker(pingPeriod) - stop := make(chan struct{}) ip := conn.RemoteAddr().String() if realIP := r.Header.Get("X-Forwarded-For"); realIP != "" { @@ -374,12 +370,15 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) { ) } + ctx, cancel := context.WithCancel(context.Background()) + + store := s.relay.Storage(ctx) + // reader go func() { defer func() { + cancel() ticker.Stop() - stop <- struct{}{} - close(stop) s.clientsMu.Lock() if _, ok := s.clients[conn]; ok { conn.Close() @@ -388,8 +387,6 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) { } s.clientsMu.Unlock() s.Log.Infof("disconnected from %s", ip) - - ctx.Done() }() conn.SetReadLimit(maxMessageSize) @@ -432,17 +429,16 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) { continue } - go s.handleMessage(context.TODO(), ws, message, store) + go s.handleMessage(ctx, ws, message, store) } }() // writer go func() { defer func() { + cancel() ticker.Stop() conn.Close() - for range stop { - } }() for { @@ -454,7 +450,7 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) { return } s.Log.Infof("pinging for %s", ip) - case <-stop: + case <-ctx.Done(): return } } diff --git a/listener.go b/listener.go index 5ef362a..c6ef9f9 100644 --- a/listener.go +++ b/listener.go @@ -76,6 +76,7 @@ func removeListenerId(ws *WebSocket, id string) { func removeListener(ws *WebSocket) { listenersMutex.Lock() defer listenersMutex.Unlock() + clear(listeners[ws]) delete(listeners, ws) }