diff --git a/conn.go b/conn.go index 4f0516a..a1b4bca 100644 --- a/conn.go +++ b/conn.go @@ -21,27 +21,31 @@ const DefaultMsgSendTimeout = 10 * time.Second // Default receipt timeout in Conn.Send function const DefaultRcvReceiptTimeout = 30 * time.Second +// Default receipt timeout in Conn.Disconnect function +const DefaultDisconnectReceiptTimeout = 30 * time.Second + // Reply-To header used for temporary queues/RPC with rabbit. const ReplyToHeader = "reply-to" // A Conn is a connection to a STOMP server. Create a Conn using either // the Dial or Connect function. type Conn struct { - conn io.ReadWriteCloser - readCh chan *frame.Frame - writeCh chan writeRequest - version Version - session string - server string - readTimeout time.Duration - writeTimeout time.Duration - msgSendTimeout time.Duration - rcvReceiptTimeout time.Duration - hbGracePeriodMultiplier float64 - closed bool - closeMutex *sync.Mutex - options *connOptions - log Logger + conn io.ReadWriteCloser + readCh chan *frame.Frame + writeCh chan writeRequest + version Version + session string + server string + readTimeout time.Duration + writeTimeout time.Duration + msgSendTimeout time.Duration + rcvReceiptTimeout time.Duration + disconnectReceiptTimeout time.Duration + hbGracePeriodMultiplier float64 + closed bool + closeMutex *sync.Mutex + options *connOptions + log Logger } type writeRequest struct { @@ -195,6 +199,7 @@ func Connect(conn io.ReadWriteCloser, opts ...func(*Conn) error) (*Conn, error) c.msgSendTimeout = options.MsgSendTimeout c.rcvReceiptTimeout = options.RcvReceiptTimeout + c.disconnectReceiptTimeout = options.DisconnectReceiptTimeout if options.ResponseHeadersCallback != nil { options.ResponseHeadersCallback(response.Header) @@ -421,13 +426,18 @@ func (c *Conn) Disconnect() error { C: ch, } - response := <-ch - if response.Command != frame.RECEIPT { - return newError(response) + err := readReceiptWithTimeout(ch, c.disconnectReceiptTimeout, ErrDisconnectReceiptTimeout) + if err == nil { + c.closed = true + return c.conn.Close() } - c.closed = true - return c.conn.Close() + if err == ErrDisconnectReceiptTimeout { + c.closed = true + _ = c.conn.Close() + } + + return err } // MustDisconnect will disconnect 'ungracefully' from the STOMP server. @@ -480,7 +490,7 @@ func (c *Conn) Send(destination, contentType string, body []byte, opts ...func(* return err } - err = readReceiptWithTimeout(request, c.rcvReceiptTimeout) + err = readReceiptWithTimeout(request.C, c.rcvReceiptTimeout, ErrMsgReceiptTimeout) if err != nil { return err } @@ -497,7 +507,7 @@ func (c *Conn) Send(destination, contentType string, body []byte, opts ...func(* return nil } -func readReceiptWithTimeout(request writeRequest, timeout time.Duration) error { +func readReceiptWithTimeout(responseChan chan *frame.Frame, timeout time.Duration, timeoutErr error) error { var timeoutChan <-chan time.Time if timeout > 0 { timeoutChan = time.After(timeout) @@ -505,8 +515,8 @@ func readReceiptWithTimeout(request writeRequest, timeout time.Duration) error { select { case <-timeoutChan: - return ErrMsgReceiptTimeout - case response := <-request.C: + return timeoutErr + case response := <-responseChan: if response.Command != frame.RECEIPT { return newError(response) } diff --git a/conn_options.go b/conn_options.go index f0bf4d4..9daddf7 100644 --- a/conn_options.go +++ b/conn_options.go @@ -19,6 +19,7 @@ type connOptions struct { HeartBeatError time.Duration MsgSendTimeout time.Duration RcvReceiptTimeout time.Duration + DisconnectReceiptTimeout time.Duration HeartBeatGracePeriodMultiplier float64 Login, Passcode string AcceptVersions []string @@ -38,6 +39,7 @@ func newConnOptions(conn *Conn, opts []func(*Conn) error) (*connOptions, error) HeartBeatError: DefaultHeartBeatError, MsgSendTimeout: DefaultMsgSendTimeout, RcvReceiptTimeout: DefaultRcvReceiptTimeout, + DisconnectReceiptTimeout: DefaultDisconnectReceiptTimeout, Logger: log.StdLogger{}, } @@ -146,9 +148,14 @@ var ConnOpt struct { // RcvReceiptTimeout is a connect option that allows the client to specify // how long to wait for a receipt in the Conn.Send function. This helps - // avoid deadlocks. If this is not specified, the default is 10 seconds. + // avoid deadlocks. If this is not specified, the default is 30 seconds. RcvReceiptTimeout func(rcvReceiptTimeout time.Duration) func(*Conn) error + // DisconnectReceiptTimeout is a connect option that allows the client to specify + // how long to wait for a receipt in the Conn.Disconnect function. This helps + // avoid deadlocks. If this is not specified, the default is 30 seconds. + DisconnectReceiptTimeout func(disconnectReceiptTimeout time.Duration) func(*Conn) error + // HeartBeatGracePeriodMultiplier is used to calculate the effective read heart-beat timeout // the broker will enforce for each client’s connection. The multiplier is applied to // the read-timeout interval the client specifies in its CONNECT frame @@ -248,6 +255,13 @@ func init() { } } + ConnOpt.DisconnectReceiptTimeout = func(disconnectReceiptTimeout time.Duration) func(*Conn) error { + return func(c *Conn) error { + c.options.DisconnectReceiptTimeout = disconnectReceiptTimeout + return nil + } + } + ConnOpt.HeartBeatGracePeriodMultiplier = func(multiplier float64) func(*Conn) error { return func(c *Conn) error { c.options.HeartBeatGracePeriodMultiplier = multiplier diff --git a/conn_test.go b/conn_test.go index e61a7fb..cbb5a9d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -312,6 +312,35 @@ func (s *StompSuite) Test_connect_not_panic_on_empty_response(c *C) { <-stop } +func (s *StompSuite) Test_successful_disconnect_with_receipt_timeout(c *C) { + resetId() + fc1, fc2 := testutil.NewFakeConn(c) + + defer func() { + fc2.Close() + }() + + go func() { + reader := frame.NewReader(fc2) + writer := frame.NewWriter(fc2) + + f1, err := reader.Read() + c.Assert(err, IsNil) + c.Assert(f1.Command, Equals, "CONNECT") + connectedFrame := frame.New("CONNECTED") + err = writer.Write(connectedFrame) + c.Assert(err, IsNil) + }() + + client, err := Connect(fc1, ConnOpt.DisconnectReceiptTimeout(1 * time.Nanosecond)) + c.Assert(err, IsNil) + c.Assert(client, NotNil) + + err = client.Disconnect() + c.Assert(err, Equals, ErrDisconnectReceiptTimeout) + c.Assert(client.closed, Equals, true) +} + // Sets up a connection for testing func connectHelper(c *C, version Version) (*Conn, *fakeReaderWriter) { fc1, fc2 := testutil.NewFakeConn(c) @@ -697,7 +726,7 @@ func (s *StompSuite) Test_TimeoutTriggers(c *C) { C: make(chan *frame.Frame), } - err := readReceiptWithTimeout(request, timeout) + err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout) c.Assert(err, NotNil) } @@ -715,7 +744,7 @@ func (s *StompSuite) Test_ChannelReceviesReceipt(c *C) { } go sendFrameHelper(&receipt, request.C) - err := readReceiptWithTimeout(request, timeout) + err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout) c.Assert(err, IsNil) } @@ -733,7 +762,7 @@ func (s *StompSuite) Test_ChannelReceviesNonReceipt(c *C) { } go sendFrameHelper(&receipt, request.C) - err := readReceiptWithTimeout(request, timeout) + err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout) c.Assert(err, NotNil) } @@ -751,7 +780,7 @@ func (s *StompSuite) Test_ZeroTimeout(c *C) { } go sendFrameHelper(&receipt, request.C) - err := readReceiptWithTimeout(request, timeout) + err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout) c.Assert(err, IsNil) } diff --git a/errors.go b/errors.go index e078ceb..36bee38 100644 --- a/errors.go +++ b/errors.go @@ -6,19 +6,20 @@ import ( // Error values var ( - ErrInvalidCommand = newErrorMessage("invalid command") - ErrInvalidFrameFormat = newErrorMessage("invalid frame format") - ErrUnsupportedVersion = newErrorMessage("unsupported version") - ErrCompletedTransaction = newErrorMessage("transaction is completed") - ErrNackNotSupported = newErrorMessage("NACK not supported in STOMP 1.0") - ErrNotReceivedMessage = newErrorMessage("cannot ack/nack a message, not from server") - ErrCannotNackAutoSub = newErrorMessage("cannot send NACK for a subscription with ack:auto") - ErrCompletedSubscription = newErrorMessage("subscription is unsubscribed") - ErrClosedUnexpectedly = newErrorMessage("connection closed unexpectedly") - ErrAlreadyClosed = newErrorMessage("connection already closed") - ErrMsgSendTimeout = newErrorMessage("msg send timeout") - ErrMsgReceiptTimeout = newErrorMessage("msg receipt timeout") - ErrNilOption = newErrorMessage("nil option") + ErrInvalidCommand = newErrorMessage("invalid command") + ErrInvalidFrameFormat = newErrorMessage("invalid frame format") + ErrUnsupportedVersion = newErrorMessage("unsupported version") + ErrCompletedTransaction = newErrorMessage("transaction is completed") + ErrNackNotSupported = newErrorMessage("NACK not supported in STOMP 1.0") + ErrNotReceivedMessage = newErrorMessage("cannot ack/nack a message, not from server") + ErrCannotNackAutoSub = newErrorMessage("cannot send NACK for a subscription with ack:auto") + ErrCompletedSubscription = newErrorMessage("subscription is unsubscribed") + ErrClosedUnexpectedly = newErrorMessage("connection closed unexpectedly") + ErrAlreadyClosed = newErrorMessage("connection already closed") + ErrMsgSendTimeout = newErrorMessage("msg send timeout") + ErrMsgReceiptTimeout = newErrorMessage("msg receipt timeout") + ErrDisconnectReceiptTimeout = newErrorMessage("disconnect receipt timeout") + ErrNilOption = newErrorMessage("nil option") ) // StompError implements the Error interface, and provides