diff --git a/conn_test.go b/conn_test.go index 169f5ee..88a87c5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -44,7 +44,8 @@ func (s *StompSuite) Test_unsuccessful_connect(c *C) { c.Assert(err, IsNil) c.Assert(f1.Command, Equals, "CONNECT") f2 := frame.New("ERROR", "message", "auth-failed") - writer.Write(f2) + err = writer.Write(f2) + c.Assert(err, IsNil) }() conn, err := Connect(fc1) @@ -112,7 +113,8 @@ func (s *StompSuite) Test_successful_connect_and_disconnect(c *C) { if tc.ExpectedServer != "" { connectedFrame.Header.Add("server", tc.ExpectedServer) } - writer.Write(connectedFrame) + err = writer.Write(connectedFrame) + c.Assert(err, IsNil) f2, err := reader.Read() c.Assert(err, IsNil) @@ -120,7 +122,9 @@ func (s *StompSuite) Test_successful_connect_and_disconnect(c *C) { receipt, _ := f2.Header.Contains("receipt") c.Check(receipt, Equals, "1") - writer.Write(frame.New("RECEIPT", frame.ReceiptId, "1")) + err = writer.Write(frame.New("RECEIPT", frame.ReceiptId, "1")) + c.Assert(err, IsNil) + }() client, err := Connect(fc1, tc.Options...) @@ -137,6 +141,67 @@ func (s *StompSuite) Test_successful_connect_and_disconnect(c *C) { } } +func (s *StompSuite) Test_successful_connect_get_headers(c *C) { + var respHeaders *frame.Header + + testcases := []struct { + Options []func(*Conn) error + Headers map[string]string + }{ + { + Options: []func(*Conn) error{ConnOpt.ResponseHeaders(func(f *frame.Header) { respHeaders = f })}, + Headers: map[string]string{"custom-header": "test", "foo": "bar"}, + }, + } + + for _, tc := range testcases { + resetId() + fc1, fc2 := testutil.NewFakeConn(c) + stop := make(chan struct{}) + + go func() { + defer func() { + fc2.Close() + close(stop) + }() + reader := frame.NewReader(fc2) + writer := frame.NewWriter(fc2) + + f1, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(f1.Command, Equals, "CONNECT") + connectedFrame := frame.New("CONNECTED") + for key, value := range tc.Headers { + connectedFrame.Header.Add(key, value) + } + err = writer.Write(connectedFrame) + c.Assert(err, IsNil) + + f2, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(f2.Command, Equals, "DISCONNECT") + receipt, _ := f2.Header.Contains("receipt") + c.Check(receipt, Equals, "1") + + err = writer.Write(frame.New("RECEIPT", frame.ReceiptId, "1")) + c.Assert(err, IsNil) + + }() + + client, err := Connect(fc1, tc.Options...) + c.Assert(err, IsNil) + c.Assert(client, NotNil) + c.Assert(respHeaders, NotNil) + for key, value := range tc.Headers { + c.Assert(respHeaders.Get(key), Equals, value) + } + err = client.Disconnect() + c.Assert(err, IsNil) + + <-stop + } +} + func (s *StompSuite) Test_successful_connect_with_nonstandard_header(c *C) { resetId() fc1, fc2 := testutil.NewFakeConn(c) @@ -162,7 +227,8 @@ func (s *StompSuite) Test_successful_connect_with_nonstandard_header(c *C) { connectedFrame.Header.Add("heart-beat", "0,0") connectedFrame.Header.Add("server", "RabbitMQ/3.2.1") connectedFrame.Header.Add("version", "1.0") - writer.Write(connectedFrame) + err = writer.Write(connectedFrame) + c.Assert(err, IsNil) f2, err := reader.Read() c.Assert(err, IsNil) @@ -170,7 +236,8 @@ func (s *StompSuite) Test_successful_connect_with_nonstandard_header(c *C) { receipt, _ := f2.Header.Contains("receipt") c.Check(receipt, Equals, "1") - writer.Write(frame.New("RECEIPT", frame.ReceiptId, "1")) + err = writer.Write(frame.New("RECEIPT", frame.ReceiptId, "1")) + c.Assert(err, IsNil) }() client, err := Connect(fc1, @@ -200,8 +267,10 @@ func (s *StompSuite) Test_connect_not_panic_on_empty_response(c *C) { close(stop) }() reader := frame.NewReader(fc2) - reader.Read() - fc2.Write([]byte("\n")) + _, err := reader.Read() + c.Assert(err, IsNil) + _, err = fc2.Write([]byte("\n")) + c.Assert(err, IsNil) }() client, err := Connect(fc1, ConnOpt.Host("the_server")) @@ -225,7 +294,8 @@ func connectHelper(c *C, version Version) (*Conn, *fakeReaderWriter) { c.Assert(err, IsNil) c.Assert(f1.Command, Equals, "CONNECT") f2 := frame.New("CONNECTED", "version", version.String()) - writer.Write(f2) + err = writer.Write(f2) + c.Assert(err, IsNil) close(stop) }() @@ -287,7 +357,8 @@ func subscribeHelper(c *C, ackMode AckMode, version Version, opts ...func(*frame f4.Header.Add(frame.Ack, messageId) } f4.Body = []byte(bodyText) - rw.Write(f4) + err = rw.Write(f4) + c.Assert(err, IsNil) if ackMode.ShouldAck() { f5, _ := rw.Read() @@ -305,13 +376,13 @@ func subscribeHelper(c *C, ackMode AckMode, version Version, opts ...func(*frame c.Assert(f6.Command, Equals, "UNSUBSCRIBE") c.Assert(f6.Header.Get(frame.Receipt), Not(Equals), "") c.Assert(f6.Header.Get(frame.Id), Equals, id) - rw.Write(frame.New(frame.RECEIPT, - frame.ReceiptId, f6.Header.Get(frame.Receipt))) + err = rw.Write(frame.New(frame.RECEIPT, frame.ReceiptId, f6.Header.Get(frame.Receipt))) + c.Assert(err, IsNil) f7, _ := rw.Read() c.Assert(f7.Command, Equals, "DISCONNECT") - rw.Write(frame.New(frame.RECEIPT, - frame.ReceiptId, f7.Header.Get(frame.Receipt))) + err = rw.Write(frame.New(frame.RECEIPT, frame.ReceiptId, f7.Header.Get(frame.Receipt))) + c.Assert(err, IsNil) }() var sub *Subscription @@ -332,14 +403,16 @@ func subscribeHelper(c *C, ackMode AckMode, version Version, opts ...func(*frame c.Assert(msg.ShouldAck(), Equals, ackMode.ShouldAck()) if msg.ShouldAck() { - msg.Conn.Ack(msg) + err = msg.Conn.Ack(msg) + c.Assert(err, IsNil) } } err = sub.Unsubscribe(SubscribeOpt.Header("custom", "true")) c.Assert(err, IsNil) - conn.Disconnect() + err = conn.Disconnect() + c.Assert(err, IsNil) } func (s *StompSuite) TestTransaction(c *C) { @@ -391,7 +464,8 @@ func subscribeTransactionHelper(c *C, ackMode AckMode, version Version, abort bo f4.Header.Add(frame.Ack, messageId) } f4.Body = []byte(bodyText) - rw.Write(f4) + err = rw.Write(f4) + c.Assert(err, IsNil) beginFrame, err := rw.Read() c.Assert(err, IsNil) @@ -436,13 +510,13 @@ func subscribeTransactionHelper(c *C, ackMode AckMode, version Version, abort bo c.Assert(f6.Command, Equals, "UNSUBSCRIBE") c.Assert(f6.Header.Get(frame.Receipt), Not(Equals), "") c.Assert(f6.Header.Get(frame.Id), Equals, id) - rw.Write(frame.New(frame.RECEIPT, - frame.ReceiptId, f6.Header.Get(frame.Receipt))) + err = rw.Write(frame.New(frame.RECEIPT, frame.ReceiptId, f6.Header.Get(frame.Receipt))) + c.Assert(err, IsNil) f7, _ := rw.Read() c.Assert(f7.Command, Equals, "DISCONNECT") - rw.Write(frame.New(frame.RECEIPT, - frame.ReceiptId, f7.Header.Get(frame.Receipt))) + err = rw.Write(frame.New(frame.RECEIPT, frame.ReceiptId, f7.Header.Get(frame.Receipt))) + c.Assert(err, IsNil) }() sub, err := conn.Subscribe("/queue/test-1", ackMode) @@ -463,24 +537,29 @@ func subscribeTransactionHelper(c *C, ackMode AckMode, version Version, abort bo c.Assert(tx.Id(), Not(Equals), "") if msg.ShouldAck() { if nack && version.SupportsNack() { - tx.Nack(msg) + err = tx.Nack(msg) + c.Assert(err, IsNil) } else { - tx.Ack(msg) + err = tx.Ack(msg) + c.Assert(err, IsNil) } } err = tx.Send("/queue/another-queue", "text/plain", []byte(bodyText)) c.Assert(err, IsNil) if abort { - tx.Abort() + err = tx.Abort() + c.Assert(err, IsNil) } else { - tx.Commit() + err = tx.Commit() + c.Assert(err, IsNil) } } err = sub.Unsubscribe() c.Assert(err, IsNil) - conn.Disconnect() + err = conn.Disconnect() + c.Assert(err, IsNil) } func (s *StompSuite) TestHeartBeatReadTimeout(c *C) { @@ -495,7 +574,8 @@ func (s *StompSuite) TestHeartBeatReadTimeout(c *C) { "message-id", "1", "subscription", f1.Header.Get("id")) messageFrame.Body = []byte("Message body") - rw.Write(messageFrame) + err = rw.Write(messageFrame) + c.Assert(err, IsNil) }() sub, err := conn.Subscribe("/queue/test1", AckAuto) @@ -509,6 +589,7 @@ func (s *StompSuite) TestHeartBeatReadTimeout(c *C) { msg, ok = <-sub.C c.Assert(msg, NotNil) + c.Assert(ok, Equals, true) c.Assert(msg.Err, NotNil) c.Assert(msg.Err.Error(), Equals, "read timeout") @@ -529,7 +610,8 @@ func (s *StompSuite) TestHeartBeatWriteTimeout(c *C) { }() time.Sleep(250) - conn.Disconnect() + err := conn.Disconnect() + c.Assert(err, IsNil) } func createHeartBeatConnection( @@ -549,7 +631,8 @@ func createHeartBeatConnection( c.Assert(f1.Header.Get("heart-beat"), Equals, "1,1") f2 := frame.New("CONNECTED", "version", "1.2") f2.Header.Add("heart-beat", fmt.Sprintf("%d,%d", readTimeout, writeTimeout)) - writer.Write(f2) + err = writer.Write(f2) + c.Assert(err, IsNil) close(stop) }()