From 99d6360229938edc99693e58342a20398a6ef4d2 Mon Sep 17 00:00:00 2001 From: Kyle Smith Date: Fri, 13 Dec 2024 12:55:01 -0500 Subject: [PATCH] fix: slacktest GetSeenOutboundMessages race condition This patch addresses issue #1361. The storage backing GetSeenOutboundMessages is updated in a goroutine that may or may not be executed in time for assertions against this method. This patch updates the storage to be updated synchronously in the handler, while maintaining the asynchronous queue behavior for websockets handlers. The test case has been updated to remove the sleep statement that likely worked around this very issue. I opted to change the locking behavior to be more closely related with the messageCollection type. There is a smaller version of this fix that move the lock/update/unlock block into the callsite of each queue update, if that is preferable. --- slacktest/funcs.go | 10 +++----- slacktest/server.go | 53 ++++++++++++++++++++-------------------- slacktest/server_test.go | 24 +++++++++--------- slacktest/types.go | 22 ++++++++++++++--- 4 files changed, 60 insertions(+), 49 deletions(-) diff --git a/slacktest/funcs.go b/slacktest/funcs.go index d083af2b5..7be54fe8d 100644 --- a/slacktest/funcs.go +++ b/slacktest/funcs.go @@ -15,11 +15,9 @@ func (sts *Server) queueForWebsocket(s, hubname string) { channel, err := getHubForServer(hubname) if err != nil { log.Printf("Unable to get server's channels: %s", err.Error()) + } else { + channel.sent <- s } - sts.seenOutboundMessages.Lock() - sts.seenOutboundMessages.messages = append(sts.seenOutboundMessages.messages, s) - sts.seenOutboundMessages.Unlock() - channel.sent <- s } func handlePendingMessages(c *websocket.Conn, hubname string) { @@ -43,9 +41,7 @@ func (sts *Server) postProcessMessage(m, hubname string) { log.Printf("Unable to get server's channels: %s", err.Error()) return } - sts.seenInboundMessages.Lock() - sts.seenInboundMessages.messages = append(sts.seenInboundMessages.messages, m) - sts.seenInboundMessages.Unlock() + sts.seenInboundMessages.observe(m) // send to firehose channel.seen <- m } diff --git a/slacktest/server.go b/slacktest/server.go index 6d9849451..ba01939fa 100644 --- a/slacktest/server.go +++ b/slacktest/server.go @@ -106,28 +106,20 @@ func (sts *Server) GetGroups() []slack.Group { // GetSeenInboundMessages returns all messages seen via websocket excluding pings func (sts *Server) GetSeenInboundMessages() []string { - sts.seenInboundMessages.RLock() - m := sts.seenInboundMessages.messages - sts.seenInboundMessages.RUnlock() - return m + return sts.seenInboundMessages.get() } // GetSeenOutboundMessages returns all messages seen via websocket excluding pings func (sts *Server) GetSeenOutboundMessages() []string { - sts.seenOutboundMessages.RLock() - m := sts.seenOutboundMessages.messages - sts.seenOutboundMessages.RUnlock() - return m + return sts.seenOutboundMessages.get() } // SawOutgoingMessage checks if a message was sent to connected websocket clients func (sts *Server) SawOutgoingMessage(msg string) bool { - sts.seenOutboundMessages.RLock() - defer sts.seenOutboundMessages.RUnlock() - for _, m := range sts.seenOutboundMessages.messages { + for _, m := range sts.seenOutboundMessages.get() { evt := &slack.MessageEvent{} - jErr := json.Unmarshal([]byte(m), evt) - if jErr != nil { + err := json.Unmarshal([]byte(m), evt) + if err != nil { continue } @@ -135,17 +127,16 @@ func (sts *Server) SawOutgoingMessage(msg string) bool { return true } } + return false } // SawMessage checks if an incoming message was seen func (sts *Server) SawMessage(msg string) bool { - sts.seenInboundMessages.RLock() - defer sts.seenInboundMessages.RUnlock() - for _, m := range sts.seenInboundMessages.messages { + for _, m := range sts.seenInboundMessages.get() { evt := &slack.MessageEvent{} - jErr := json.Unmarshal([]byte(m), evt) - if jErr != nil { + err := json.Unmarshal([]byte(m), evt) + if err != nil { // This event isn't a message event so we'll skip it continue } @@ -153,6 +144,7 @@ func (sts *Server) SawMessage(msg string) bool { return true } } + return false } @@ -184,11 +176,14 @@ func (sts *Server) SendMessageToBot(channel, msg string) { m.User = defaultNonBotUserID m.Text = fmt.Sprintf("<@%s> %s", sts.BotID, msg) m.Timestamp = fmt.Sprintf("%d", time.Now().Unix()) - j, jErr := json.Marshal(m) - if jErr != nil { - log.Printf("Unable to marshal message for bot: %s", jErr.Error()) + + j, err := json.Marshal(m) + if err != nil { + log.Printf("Unable to marshal message for bot: %s", err.Error()) return } + + sts.seenOutboundMessages.observe(string(j)) go sts.queueForWebsocket(string(j), sts.ServerAddr) } @@ -200,11 +195,14 @@ func (sts *Server) SendDirectMessageToBot(msg string) { m.User = defaultNonBotUserID m.Text = msg m.Timestamp = fmt.Sprintf("%d", time.Now().Unix()) - j, jErr := json.Marshal(m) - if jErr != nil { - log.Printf("Unable to marshal private message for bot: %s", jErr.Error()) + + j, err := json.Marshal(m) + if err != nil { + log.Printf("Unable to marshal private message for bot: %s", err.Error()) return } + + sts.seenOutboundMessages.observe(string(j)) go sts.queueForWebsocket(string(j), sts.ServerAddr) } @@ -216,18 +214,21 @@ func (sts *Server) SendMessageToChannel(channel, msg string) { m.Text = msg m.User = defaultNonBotUserID m.Timestamp = fmt.Sprintf("%d", time.Now().Unix()) + j, jErr := json.Marshal(m) if jErr != nil { log.Printf("Unable to marshal message for channel: %s", jErr.Error()) return } - stringMsg := string(j) - go sts.queueForWebsocket(stringMsg, sts.ServerAddr) + + sts.seenOutboundMessages.observe(string(j)) + go sts.queueForWebsocket(string(j), sts.ServerAddr) } // SendToWebsocket send `s` as is to connected clients. // This is useful for sending your own custom json to the websocket func (sts *Server) SendToWebsocket(s string) { + sts.seenOutboundMessages.observe(s) go sts.queueForWebsocket(s, sts.ServerAddr) } diff --git a/slacktest/server_test.go b/slacktest/server_test.go index 45f620858..01a603a46 100644 --- a/slacktest/server_test.go +++ b/slacktest/server_test.go @@ -27,7 +27,7 @@ func TestCustomNewServer(t *testing.T) { func TestServerSendMessageToChannel(t *testing.T) { s := NewTestServer() - go s.Start() + s.Start() s.SendMessageToChannel("C123456789", "some text") time.Sleep(2 * time.Second) assert.True(t, s.SawOutgoingMessage("some text")) @@ -36,7 +36,7 @@ func TestServerSendMessageToChannel(t *testing.T) { func TestServerSendMessageToBot(t *testing.T) { s := NewTestServer() - go s.Start() + s.Start() s.SendMessageToBot("C123456789", "some text") expectedMsg := fmt.Sprintf("<@%s> %s", s.BotID, "some text") time.Sleep(2 * time.Second) @@ -46,7 +46,7 @@ func TestServerSendMessageToBot(t *testing.T) { func TestBotDirectMessageBotHandler(t *testing.T) { s := NewTestServer() - go s.Start() + s.Start() s.SendDirectMessageToBot("some text") expectedMsg := "some text" time.Sleep(2 * time.Second) @@ -55,14 +55,14 @@ func TestBotDirectMessageBotHandler(t *testing.T) { } func TestGetSeenOutboundMessages(t *testing.T) { - maxWait := 5 * time.Second s := NewTestServer() - go s.Start() + s.Start() s.SendMessageToChannel("foo", "should see this message") - time.Sleep(maxWait) + seenOutbound := s.GetSeenOutboundMessages() - assert.True(t, len(seenOutbound) > 0) + assert.Len(t, seenOutbound, 1) + hadMessage := false for _, msg := range seenOutbound { var m = slack.Message{} @@ -79,7 +79,7 @@ func TestGetSeenOutboundMessages(t *testing.T) { func TestGetSeenInboundMessages(t *testing.T) { maxWait := 5 * time.Second s := NewTestServer() - go s.Start() + s.Start() api := slack.New("ABCDEFG", slack.OptionAPIURL(s.GetAPIURL())) rtm := api.NewRTM() @@ -108,7 +108,7 @@ func TestGetSeenInboundMessages(t *testing.T) { func TestSendChannelInvite(t *testing.T) { maxWait := 5 * time.Second s := NewTestServer() - go s.Start() + s.Start() rtm := s.GetTestRTMInstance() go rtm.ManageConnection() evChan := make(chan (slack.Channel), 1) @@ -137,7 +137,7 @@ func TestSendChannelInvite(t *testing.T) { func TestSendGroupInvite(t *testing.T) { maxWait := 5 * time.Second s := NewTestServer() - go s.Start() + s.Start() rtm := s.GetTestRTMInstance() go rtm.ManageConnection() evChan := make(chan (slack.Channel), 1) @@ -165,12 +165,12 @@ func TestSendGroupInvite(t *testing.T) { func TestServerSawMessage(t *testing.T) { s := NewTestServer() - go s.Start() + s.Start() assert.False(t, s.SawMessage("foo"), "should not have seen any message") } func TestServerSawOutgoingMessage(t *testing.T) { s := NewTestServer() - go s.Start() + s.Start() assert.False(t, s.SawOutgoingMessage("foo"), "should not have seen any message") } diff --git a/slacktest/types.go b/slacktest/types.go index 4c90ed9ba..c6e776f72 100644 --- a/slacktest/types.go +++ b/slacktest/types.go @@ -40,15 +40,29 @@ type hub struct { } type messageChannels struct { - seen chan (string) - sent chan (string) - posted chan (slack.Message) + seen chan string + sent chan string + posted chan slack.Message } type messageCollection struct { sync.RWMutex messages []string } +func (mc *messageCollection) observe(msg string) { + mc.Lock() + defer mc.Unlock() + mc.messages = append(mc.messages, msg) +} + +func (mc *messageCollection) get() []string { + mc.RLock() + defer mc.RUnlock() + + m := mc.messages + return m +} + type serverChannels struct { sync.RWMutex channels []slack.Channel @@ -68,7 +82,7 @@ type Server struct { BotName string BotID string ServerAddr string - SeenFeed chan (string) + SeenFeed chan string channels *serverChannels groups *serverGroups seenInboundMessages *messageCollection