Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: slacktest GetSeenOutboundMessages race condition #1362

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions slacktest/funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@ func (sts *Server) queueForWebsocket(s, hubname string) {
channel, err := getHubForServer(hubname)
if err != nil {
log.Printf("Unable to get server's channels: %s", err.Error())
} else {
channel.sent <- s
}
sts.seenOutboundMessages.Lock()
sts.seenOutboundMessages.messages = append(sts.seenOutboundMessages.messages, s)
sts.seenOutboundMessages.Unlock()
channel.sent <- s
}

func handlePendingMessages(c *websocket.Conn, hubname string) {
Expand All @@ -43,9 +41,7 @@ func (sts *Server) postProcessMessage(m, hubname string) {
log.Printf("Unable to get server's channels: %s", err.Error())
return
}
sts.seenInboundMessages.Lock()
sts.seenInboundMessages.messages = append(sts.seenInboundMessages.messages, m)
sts.seenInboundMessages.Unlock()
sts.seenInboundMessages.observe(m)
// send to firehose
channel.seen <- m
}
Expand Down
53 changes: 27 additions & 26 deletions slacktest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,53 +106,45 @@ func (sts *Server) GetGroups() []slack.Group {

// GetSeenInboundMessages returns all messages seen via websocket excluding pings
func (sts *Server) GetSeenInboundMessages() []string {
sts.seenInboundMessages.RLock()
m := sts.seenInboundMessages.messages
sts.seenInboundMessages.RUnlock()
return m
return sts.seenInboundMessages.get()
}

// GetSeenOutboundMessages returns all messages seen via websocket excluding pings
func (sts *Server) GetSeenOutboundMessages() []string {
sts.seenOutboundMessages.RLock()
m := sts.seenOutboundMessages.messages
sts.seenOutboundMessages.RUnlock()
return m
return sts.seenOutboundMessages.get()
}

// SawOutgoingMessage checks if a message was sent to connected websocket clients
func (sts *Server) SawOutgoingMessage(msg string) bool {
sts.seenOutboundMessages.RLock()
defer sts.seenOutboundMessages.RUnlock()
for _, m := range sts.seenOutboundMessages.messages {
for _, m := range sts.seenOutboundMessages.get() {
evt := &slack.MessageEvent{}
jErr := json.Unmarshal([]byte(m), evt)
if jErr != nil {
err := json.Unmarshal([]byte(m), evt)
if err != nil {
continue
}

if evt.Text == msg {
return true
}
}

return false
}

// SawMessage checks if an incoming message was seen
func (sts *Server) SawMessage(msg string) bool {
sts.seenInboundMessages.RLock()
defer sts.seenInboundMessages.RUnlock()
for _, m := range sts.seenInboundMessages.messages {
for _, m := range sts.seenInboundMessages.get() {
evt := &slack.MessageEvent{}
jErr := json.Unmarshal([]byte(m), evt)
if jErr != nil {
err := json.Unmarshal([]byte(m), evt)
if err != nil {
// This event isn't a message event so we'll skip it
continue
}
if evt.Text == msg {
return true
}
}

return false
}

Expand Down Expand Up @@ -184,11 +176,14 @@ func (sts *Server) SendMessageToBot(channel, msg string) {
m.User = defaultNonBotUserID
m.Text = fmt.Sprintf("<@%s> %s", sts.BotID, msg)
m.Timestamp = fmt.Sprintf("%d", time.Now().Unix())
j, jErr := json.Marshal(m)
if jErr != nil {
log.Printf("Unable to marshal message for bot: %s", jErr.Error())

j, err := json.Marshal(m)
if err != nil {
log.Printf("Unable to marshal message for bot: %s", err.Error())
return
}

sts.seenOutboundMessages.observe(string(j))
go sts.queueForWebsocket(string(j), sts.ServerAddr)
}

Expand All @@ -200,11 +195,14 @@ func (sts *Server) SendDirectMessageToBot(msg string) {
m.User = defaultNonBotUserID
m.Text = msg
m.Timestamp = fmt.Sprintf("%d", time.Now().Unix())
j, jErr := json.Marshal(m)
if jErr != nil {
log.Printf("Unable to marshal private message for bot: %s", jErr.Error())

j, err := json.Marshal(m)
if err != nil {
log.Printf("Unable to marshal private message for bot: %s", err.Error())
return
}

sts.seenOutboundMessages.observe(string(j))
go sts.queueForWebsocket(string(j), sts.ServerAddr)
}

Expand All @@ -216,18 +214,21 @@ func (sts *Server) SendMessageToChannel(channel, msg string) {
m.Text = msg
m.User = defaultNonBotUserID
m.Timestamp = fmt.Sprintf("%d", time.Now().Unix())

j, jErr := json.Marshal(m)
if jErr != nil {
log.Printf("Unable to marshal message for channel: %s", jErr.Error())
return
}
stringMsg := string(j)
go sts.queueForWebsocket(stringMsg, sts.ServerAddr)

sts.seenOutboundMessages.observe(string(j))
go sts.queueForWebsocket(string(j), sts.ServerAddr)
}

// SendToWebsocket send `s` as is to connected clients.
// This is useful for sending your own custom json to the websocket
func (sts *Server) SendToWebsocket(s string) {
sts.seenOutboundMessages.observe(s)
go sts.queueForWebsocket(s, sts.ServerAddr)
}

Expand Down
24 changes: 12 additions & 12 deletions slacktest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestCustomNewServer(t *testing.T) {

func TestServerSendMessageToChannel(t *testing.T) {
s := NewTestServer()
go s.Start()
s.Start()
s.SendMessageToChannel("C123456789", "some text")
time.Sleep(2 * time.Second)
assert.True(t, s.SawOutgoingMessage("some text"))
Expand All @@ -36,7 +36,7 @@ func TestServerSendMessageToChannel(t *testing.T) {

func TestServerSendMessageToBot(t *testing.T) {
s := NewTestServer()
go s.Start()
s.Start()
s.SendMessageToBot("C123456789", "some text")
expectedMsg := fmt.Sprintf("<@%s> %s", s.BotID, "some text")
time.Sleep(2 * time.Second)
Expand All @@ -46,7 +46,7 @@ func TestServerSendMessageToBot(t *testing.T) {

func TestBotDirectMessageBotHandler(t *testing.T) {
s := NewTestServer()
go s.Start()
s.Start()
s.SendDirectMessageToBot("some text")
expectedMsg := "some text"
time.Sleep(2 * time.Second)
Expand All @@ -55,14 +55,14 @@ func TestBotDirectMessageBotHandler(t *testing.T) {
}

func TestGetSeenOutboundMessages(t *testing.T) {
maxWait := 5 * time.Second
s := NewTestServer()
go s.Start()
s.Start()

s.SendMessageToChannel("foo", "should see this message")
time.Sleep(maxWait)

seenOutbound := s.GetSeenOutboundMessages()
assert.True(t, len(seenOutbound) > 0)
assert.Len(t, seenOutbound, 1)

hadMessage := false
for _, msg := range seenOutbound {
var m = slack.Message{}
Expand All @@ -79,7 +79,7 @@ func TestGetSeenOutboundMessages(t *testing.T) {
func TestGetSeenInboundMessages(t *testing.T) {
maxWait := 5 * time.Second
s := NewTestServer()
go s.Start()
s.Start()

api := slack.New("ABCDEFG", slack.OptionAPIURL(s.GetAPIURL()))
rtm := api.NewRTM()
Expand Down Expand Up @@ -108,7 +108,7 @@ func TestGetSeenInboundMessages(t *testing.T) {
func TestSendChannelInvite(t *testing.T) {
maxWait := 5 * time.Second
s := NewTestServer()
go s.Start()
s.Start()
rtm := s.GetTestRTMInstance()
go rtm.ManageConnection()
evChan := make(chan (slack.Channel), 1)
Expand Down Expand Up @@ -137,7 +137,7 @@ func TestSendChannelInvite(t *testing.T) {
func TestSendGroupInvite(t *testing.T) {
maxWait := 5 * time.Second
s := NewTestServer()
go s.Start()
s.Start()
rtm := s.GetTestRTMInstance()
go rtm.ManageConnection()
evChan := make(chan (slack.Channel), 1)
Expand Down Expand Up @@ -165,12 +165,12 @@ func TestSendGroupInvite(t *testing.T) {

func TestServerSawMessage(t *testing.T) {
s := NewTestServer()
go s.Start()
s.Start()
assert.False(t, s.SawMessage("foo"), "should not have seen any message")
}

func TestServerSawOutgoingMessage(t *testing.T) {
s := NewTestServer()
go s.Start()
s.Start()
assert.False(t, s.SawOutgoingMessage("foo"), "should not have seen any message")
}
22 changes: 18 additions & 4 deletions slacktest/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,29 @@ type hub struct {
}

type messageChannels struct {
seen chan (string)
sent chan (string)
posted chan (slack.Message)
seen chan string
sent chan string
posted chan slack.Message
}
type messageCollection struct {
sync.RWMutex
messages []string
}

func (mc *messageCollection) observe(msg string) {
mc.Lock()
defer mc.Unlock()
mc.messages = append(mc.messages, msg)
}

func (mc *messageCollection) get() []string {
mc.RLock()
defer mc.RUnlock()

m := mc.messages
return m
}

type serverChannels struct {
sync.RWMutex
channels []slack.Channel
Expand All @@ -68,7 +82,7 @@ type Server struct {
BotName string
BotID string
ServerAddr string
SeenFeed chan (string)
SeenFeed chan string
channels *serverChannels
groups *serverGroups
seenInboundMessages *messageCollection
Expand Down
Loading