Skip to content

Commit

Permalink
Release websocket.Client deadlock fix to all users (#1242)
Browse files Browse the repository at this point in the history
* Release websocket.Client deadlock fix to all users

* Remove flaky websocket client test
  • Loading branch information
tomas-stripe authored Sep 17, 2024
1 parent d5651bf commit acb55dc
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 175 deletions.
111 changes: 32 additions & 79 deletions pkg/websocket/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,11 +339,7 @@ func (c *Client) changeConnection(conn *ws.Conn) {
c.stopReadPumpMutex.Lock()
defer c.stopReadPumpMutex.Unlock()
c.conn = conn
if os.Getenv("STRIPE_CLI_CANARY") == "true" {
c.notifyClose = make(chan error, 1)
} else {
c.notifyClose = make(chan error)
}
c.notifyClose = make(chan error, 1)
c.stopReadPump = make(chan struct{})
c.stopWritePump = make(chan struct{})
}
Expand Down Expand Up @@ -386,66 +382,33 @@ func (c *Client) readPump() {

_, data, err := c.conn.ReadMessage()
if err != nil {
if os.Getenv("STRIPE_CLI_CANARY") == "true" {
select {
case <-c.stopReadPump:
select {
case <-c.stopReadPump:
c.cfg.Log.WithFields(log.Fields{
"prefix": "websocket.Client.readPump",
}).Debug("stopReadPump")
case c.notifyClose <- err:
switch {
case !ws.IsCloseError(err):
// read errors do not prevent websocket reconnects in the CLI so we should
// only display this on debug-level logging
c.cfg.Log.WithFields(log.Fields{
"prefix": "websocket.Client.readPump",
}).Debug("stopReadPump")
case c.notifyClose <- err:
switch {
case !ws.IsCloseError(err):
// read errors do not prevent websocket reconnects in the CLI so we should
// only display this on debug-level logging
c.cfg.Log.WithFields(log.Fields{
"prefix": "websocket.Client.readPump",
}).Debug("read error: ", err)
case ws.IsUnexpectedCloseError(err, ws.CloseNormalClosure):
c.cfg.Log.WithFields(log.Fields{
"prefix": "websocket.Client.readPump",
}).Error("close error: ", err)
c.cfg.Log.WithFields(log.Fields{
"prefix": "stripecli.ADDITIONAL_INFO",
}).Error("If you run into issues, please re-run with `--log-level debug` and share the output with the Stripe team on GitHub.")
default:
c.cfg.Log.WithFields(log.Fields{
"prefix": "stripecli.ADDITIONAL_INFO",
}).Error("other error: ", err)
c.cfg.Log.WithFields(log.Fields{
"prefix": "stripecli.ADDITIONAL_INFO",
}).Error("If you run into issues, please re-run with `--log-level debug` and share the output with the Stripe team on GitHub.")
}
}
} else {
select {
case <-c.stopReadPump:
}).Debug("read error: ", err)
case ws.IsUnexpectedCloseError(err, ws.CloseNormalClosure):
c.cfg.Log.WithFields(log.Fields{
"prefix": "websocket.Client.readPump",
}).Debug("stopReadPump")
}).Error("close error: ", err)
c.cfg.Log.WithFields(log.Fields{
"prefix": "stripecli.ADDITIONAL_INFO",
}).Error("If you run into issues, please re-run with `--log-level debug` and share the output with the Stripe team on GitHub.")
default:
switch {
case !ws.IsCloseError(err):
// read errors do not prevent websocket reconnects in the CLI so we should
// only display this on debug-level logging
c.cfg.Log.WithFields(log.Fields{
"prefix": "websocket.Client.readPump",
}).Debug("read error: ", err)
case ws.IsUnexpectedCloseError(err, ws.CloseNormalClosure):
c.cfg.Log.WithFields(log.Fields{
"prefix": "websocket.Client.readPump",
}).Error("close error: ", err)
c.cfg.Log.WithFields(log.Fields{
"prefix": "stripecli.ADDITIONAL_INFO",
}).Error("If you run into issues, please re-run with `--log-level debug` and share the output with the Stripe team on GitHub.")
default:
c.cfg.Log.WithFields(log.Fields{
"prefix": "stripecli.ADDITIONAL_INFO",
}).Error("other error: ", err)
c.cfg.Log.WithFields(log.Fields{
"prefix": "stripecli.ADDITIONAL_INFO",
}).Error("If you run into issues, please re-run with `--log-level debug` and share the output with the Stripe team on GitHub.")
}
c.notifyClose <- err
c.cfg.Log.WithFields(log.Fields{
"prefix": "stripecli.ADDITIONAL_INFO",
}).Error("other error: ", err)
c.cfg.Log.WithFields(log.Fields{
"prefix": "stripecli.ADDITIONAL_INFO",
}).Error("If you run into issues, please re-run with `--log-level debug` and share the output with the Stripe team on GitHub.")
}
}

Expand Down Expand Up @@ -529,19 +492,15 @@ func (c *Client) writePump() {
// Requeue the message to be processed when writePump restarts
c.send <- outMsg

if os.Getenv("STRIPE_CLI_CANARY") == "true" {
select {
case <-c.stopWritePump:
c.cfg.Log.WithFields(log.Fields{
"prefix": "websocket.Client.writePump",
}).Debug("stopWritePump - Failed to WriteJSON; connection is resetting")
case c.notifyClose <- err:
c.cfg.Log.WithFields(log.Fields{
"prefix": "websocket.Client.writePump",
}).Debug("Failed to WriteJSON; closing connection")
}
} else {
c.notifyClose <- err
select {
case <-c.stopWritePump:
c.cfg.Log.WithFields(log.Fields{
"prefix": "websocket.Client.writePump",
}).Debug("stopWritePump - Failed to WriteJSON; connection is resetting")
case c.notifyClose <- err:
c.cfg.Log.WithFields(log.Fields{
"prefix": "websocket.Client.writePump",
}).Debug("Failed to WriteJSON; closing connection")
}

return
Expand Down Expand Up @@ -592,12 +551,6 @@ func (c *Client) writePump() {
}
}

// terminateReadPump asks the read pump goroutine to stop by sending on
// the stopReadPump channel. Taking stopReadPumpMutex serializes the send
// with changeConnection, which replaces stopReadPump under the same lock.
func (c *Client) terminateReadPump() {
	c.stopReadPumpMutex.Lock()
	defer c.stopReadPumpMutex.Unlock()

	var signal struct{}
	c.stopReadPump <- signal
}

//
// Public functions
//
Expand Down
96 changes: 0 additions & 96 deletions pkg/websocket/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package websocket
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
Expand All @@ -13,7 +12,6 @@ import (

ws "github.com/gorilla/websocket"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -240,97 +238,3 @@ func TestClientExpiredError(t *testing.T) {
require.FailNow(t, "Timed out waiting for response from test server")
}
}

// This test is a regression test for deadlocks that can be encountered
// when the write pump is interrupted by closed connections at inopportune
// times.
//
// The goal is to simulate a scenario where the read pump is shut down but the
// client still has messages to send. The read pump should be shut down because
// in the majority of cases it is how the client ends up stopped. However, there's
// no hard synchronization between the read and write pumps so we have to defend
// against race conditions where the read side is shut down, hence this test.
func TestWritePumpInterruptionRequeued(t *testing.T) {
	// Buffered so the server handler can record messages without blocking
	// on the test goroutine.
	serverReceivedMessages := make(chan string, 10)
	wg := sync.WaitGroup{}

	upgrader := ws.Upgrader{}
	// The handler runs once per (re)connection: it accepts the upgrade,
	// reads a single message, then force-closes to trigger a reconnect.
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		wg.Add(1)

		require.NotEmpty(t, r.UserAgent())
		require.NotEmpty(t, r.Header.Get("X-Stripe-Client-User-Agent"))
		require.Equal(t, "websocket-random-id", r.Header.Get("Websocket-Id"))
		c, err := upgrader.Upgrade(w, r, nil)
		require.NoError(t, err)

		require.Equal(t, "websocket_feature=webhook-payloads", r.URL.RawQuery)

		defer c.Close()

		msgType, msg, err := c.ReadMessage()
		require.NoError(t, err)
		require.Equal(t, msgType, ws.TextMessage)
		serverReceivedMessages <- string(msg)

		// To simulate a forced reconnection, the server closes the connection
		// after receiving any messages
		c.WriteControl(ws.CloseMessage, ws.FormatCloseMessage(ws.CloseNormalClosure, ""), time.Now().Add(5*time.Second))
		c.Close()
		wg.Done()
	}))

	defer ts.Close()

	// Rewrite http://host -> ws://host for the websocket dial.
	url := "ws" + strings.TrimPrefix(ts.URL, "http")

	client := NewClient(
		url,
		"websocket-random-id",
		"webhook-payloads",
		&Config{
			EventHandler: EventHandlerFunc(func(msg IncomingMessage) {}),
			WriteWait:    10 * time.Second,
			// PingPeriod is set absurdly high so pings never interfere
			// with the writes under test.
			PongWait:   60 * time.Second,
			PingPeriod: 60 * time.Hour,
		},
	)

	go client.Run(context.Background())

	defer client.Stop()

	actualMessages := []string{}
	connectedChan := client.Connected()
	<-connectedChan
	// Kill the read pump while the client still has messages to send —
	// the race this regression test is defending against.
	go func() { client.terminateReadPump() }()

	for i := 0; i < 2; i++ {
		client.SendMessage(NewEventAck(fmt.Sprintf("event_%d", i), fmt.Sprintf("event_%d", i), fmt.Sprintf("event_%d", i)))
		// Needed to deflake the test from racing against itself
		// Something to do with the buffering
		time.Sleep(100 * time.Millisecond)

		msg := <-serverReceivedMessages
		actualMessages = append(actualMessages, msg)
		// Wait for the handler of the current connection to finish before
		// sending the next message over a fresh connection.
		wg.Wait()
	}

	wg.Wait()

	// Drain anything else the server recorded; a deadlocked/duplicated
	// requeue would surface here as extra messages.
	for {
		exhausted := false
		select {
		case msg := <-serverReceivedMessages:
			actualMessages = append(actualMessages, msg)
		default:
			exhausted = true
		}

		if exhausted {
			break
		}
	}

	// Exactly one delivery per SendMessage despite the forced reconnects.
	assert.Len(t, actualMessages, 2)
}

0 comments on commit acb55dc

Please sign in to comment.