Skip to content

Commit

Permalink
Merge pull request #126 from extbe/feature/disconnect-timeout
Browse files Browse the repository at this point in the history
Add DisconnectReceiptTimeout conn option
  • Loading branch information
worg authored Apr 11, 2023
2 parents e752104 + 0165750 commit 8ba92d1
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 42 deletions.
58 changes: 34 additions & 24 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,31 @@ const DefaultMsgSendTimeout = 10 * time.Second
// Default receipt timeout in Conn.Send function
const DefaultRcvReceiptTimeout = 30 * time.Second

// Default receipt timeout in Conn.Disconnect function
const DefaultDisconnectReceiptTimeout = 30 * time.Second

// Reply-To header used for temporary queues/RPC with rabbit.
const ReplyToHeader = "reply-to"

// A Conn is a connection to a STOMP server. Create a Conn using either
// the Dial or Connect function.
type Conn struct {
conn io.ReadWriteCloser
readCh chan *frame.Frame
writeCh chan writeRequest
version Version
session string
server string
readTimeout time.Duration
writeTimeout time.Duration
msgSendTimeout time.Duration
rcvReceiptTimeout time.Duration
hbGracePeriodMultiplier float64
closed bool
closeMutex *sync.Mutex
options *connOptions
log Logger
conn io.ReadWriteCloser
readCh chan *frame.Frame
writeCh chan writeRequest
version Version
session string
server string
readTimeout time.Duration
writeTimeout time.Duration
msgSendTimeout time.Duration
rcvReceiptTimeout time.Duration
disconnectReceiptTimeout time.Duration
hbGracePeriodMultiplier float64
closed bool
closeMutex *sync.Mutex
options *connOptions
log Logger
}

type writeRequest struct {
Expand Down Expand Up @@ -195,6 +199,7 @@ func Connect(conn io.ReadWriteCloser, opts ...func(*Conn) error) (*Conn, error)

c.msgSendTimeout = options.MsgSendTimeout
c.rcvReceiptTimeout = options.RcvReceiptTimeout
c.disconnectReceiptTimeout = options.DisconnectReceiptTimeout

if options.ResponseHeadersCallback != nil {
options.ResponseHeadersCallback(response.Header)
Expand Down Expand Up @@ -421,13 +426,18 @@ func (c *Conn) Disconnect() error {
C: ch,
}

response := <-ch
if response.Command != frame.RECEIPT {
return newError(response)
err := readReceiptWithTimeout(ch, c.disconnectReceiptTimeout, ErrDisconnectReceiptTimeout)
if err == nil {
c.closed = true
return c.conn.Close()
}

c.closed = true
return c.conn.Close()
if err == ErrDisconnectReceiptTimeout {
c.closed = true
_ = c.conn.Close()
}

return err
}

// MustDisconnect will disconnect 'ungracefully' from the STOMP server.
Expand Down Expand Up @@ -480,7 +490,7 @@ func (c *Conn) Send(destination, contentType string, body []byte, opts ...func(*
return err
}

err = readReceiptWithTimeout(request, c.rcvReceiptTimeout)
err = readReceiptWithTimeout(request.C, c.rcvReceiptTimeout, ErrMsgReceiptTimeout)
if err != nil {
return err
}
Expand All @@ -497,16 +507,16 @@ func (c *Conn) Send(destination, contentType string, body []byte, opts ...func(*
return nil
}

func readReceiptWithTimeout(request writeRequest, timeout time.Duration) error {
func readReceiptWithTimeout(responseChan chan *frame.Frame, timeout time.Duration, timeoutErr error) error {
var timeoutChan <-chan time.Time
if timeout > 0 {
timeoutChan = time.After(timeout)
}

select {
case <-timeoutChan:
return ErrMsgReceiptTimeout
case response := <-request.C:
return timeoutErr
case response := <-responseChan:
if response.Command != frame.RECEIPT {
return newError(response)
}
Expand Down
16 changes: 15 additions & 1 deletion conn_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type connOptions struct {
HeartBeatError time.Duration
MsgSendTimeout time.Duration
RcvReceiptTimeout time.Duration
DisconnectReceiptTimeout time.Duration
HeartBeatGracePeriodMultiplier float64
Login, Passcode string
AcceptVersions []string
Expand All @@ -38,6 +39,7 @@ func newConnOptions(conn *Conn, opts []func(*Conn) error) (*connOptions, error)
HeartBeatError: DefaultHeartBeatError,
MsgSendTimeout: DefaultMsgSendTimeout,
RcvReceiptTimeout: DefaultRcvReceiptTimeout,
DisconnectReceiptTimeout: DefaultDisconnectReceiptTimeout,
Logger: log.StdLogger{},
}

Expand Down Expand Up @@ -146,9 +148,14 @@ var ConnOpt struct {

// RcvReceiptTimeout is a connect option that allows the client to specify
// how long to wait for a receipt in the Conn.Send function. This helps
// avoid deadlocks. If this is not specified, the default is 10 seconds.
// avoid deadlocks. If this is not specified, the default is 30 seconds.
RcvReceiptTimeout func(rcvReceiptTimeout time.Duration) func(*Conn) error

// DisconnectReceiptTimeout is a connect option that allows the client to specify
// how long to wait for a receipt in the Conn.Disconnect function. This helps
// avoid deadlocks. If this is not specified, the default is 30 seconds.
DisconnectReceiptTimeout func(disconnectReceiptTimeout time.Duration) func(*Conn) error

// HeartBeatGracePeriodMultiplier is used to calculate the effective read heart-beat timeout
// the broker will enforce for each client’s connection. The multiplier is applied to
// the read-timeout interval the client specifies in its CONNECT frame
Expand Down Expand Up @@ -248,6 +255,13 @@ func init() {
}
}

ConnOpt.DisconnectReceiptTimeout = func(disconnectReceiptTimeout time.Duration) func(*Conn) error {
return func(c *Conn) error {
c.options.DisconnectReceiptTimeout = disconnectReceiptTimeout
return nil
}
}

ConnOpt.HeartBeatGracePeriodMultiplier = func(multiplier float64) func(*Conn) error {
return func(c *Conn) error {
c.options.HeartBeatGracePeriodMultiplier = multiplier
Expand Down
37 changes: 33 additions & 4 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,35 @@ func (s *StompSuite) Test_connect_not_panic_on_empty_response(c *C) {
<-stop
}

func (s *StompSuite) Test_successful_disconnect_with_receipt_timeout(c *C) {
resetId()
fc1, fc2 := testutil.NewFakeConn(c)

defer func() {
fc2.Close()
}()

go func() {
reader := frame.NewReader(fc2)
writer := frame.NewWriter(fc2)

f1, err := reader.Read()
c.Assert(err, IsNil)
c.Assert(f1.Command, Equals, "CONNECT")
connectedFrame := frame.New("CONNECTED")
err = writer.Write(connectedFrame)
c.Assert(err, IsNil)
}()

client, err := Connect(fc1, ConnOpt.DisconnectReceiptTimeout(1 * time.Nanosecond))
c.Assert(err, IsNil)
c.Assert(client, NotNil)

err = client.Disconnect()
c.Assert(err, Equals, ErrDisconnectReceiptTimeout)
c.Assert(client.closed, Equals, true)
}

// Sets up a connection for testing
func connectHelper(c *C, version Version) (*Conn, *fakeReaderWriter) {
fc1, fc2 := testutil.NewFakeConn(c)
Expand Down Expand Up @@ -697,7 +726,7 @@ func (s *StompSuite) Test_TimeoutTriggers(c *C) {
C: make(chan *frame.Frame),
}

err := readReceiptWithTimeout(request, timeout)
err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout)

c.Assert(err, NotNil)
}
Expand All @@ -715,7 +744,7 @@ func (s *StompSuite) Test_ChannelReceviesReceipt(c *C) {
}

go sendFrameHelper(&receipt, request.C)
err := readReceiptWithTimeout(request, timeout)
err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout)

c.Assert(err, IsNil)
}
Expand All @@ -733,7 +762,7 @@ func (s *StompSuite) Test_ChannelReceviesNonReceipt(c *C) {
}

go sendFrameHelper(&receipt, request.C)
err := readReceiptWithTimeout(request, timeout)
err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout)

c.Assert(err, NotNil)
}
Expand All @@ -751,7 +780,7 @@ func (s *StompSuite) Test_ZeroTimeout(c *C) {
}

go sendFrameHelper(&receipt, request.C)
err := readReceiptWithTimeout(request, timeout)
err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout)

c.Assert(err, IsNil)
}
27 changes: 14 additions & 13 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@ import (

// Error values
var (
ErrInvalidCommand = newErrorMessage("invalid command")
ErrInvalidFrameFormat = newErrorMessage("invalid frame format")
ErrUnsupportedVersion = newErrorMessage("unsupported version")
ErrCompletedTransaction = newErrorMessage("transaction is completed")
ErrNackNotSupported = newErrorMessage("NACK not supported in STOMP 1.0")
ErrNotReceivedMessage = newErrorMessage("cannot ack/nack a message, not from server")
ErrCannotNackAutoSub = newErrorMessage("cannot send NACK for a subscription with ack:auto")
ErrCompletedSubscription = newErrorMessage("subscription is unsubscribed")
ErrClosedUnexpectedly = newErrorMessage("connection closed unexpectedly")
ErrAlreadyClosed = newErrorMessage("connection already closed")
ErrMsgSendTimeout = newErrorMessage("msg send timeout")
ErrMsgReceiptTimeout = newErrorMessage("msg receipt timeout")
ErrNilOption = newErrorMessage("nil option")
ErrInvalidCommand = newErrorMessage("invalid command")
ErrInvalidFrameFormat = newErrorMessage("invalid frame format")
ErrUnsupportedVersion = newErrorMessage("unsupported version")
ErrCompletedTransaction = newErrorMessage("transaction is completed")
ErrNackNotSupported = newErrorMessage("NACK not supported in STOMP 1.0")
ErrNotReceivedMessage = newErrorMessage("cannot ack/nack a message, not from server")
ErrCannotNackAutoSub = newErrorMessage("cannot send NACK for a subscription with ack:auto")
ErrCompletedSubscription = newErrorMessage("subscription is unsubscribed")
ErrClosedUnexpectedly = newErrorMessage("connection closed unexpectedly")
ErrAlreadyClosed = newErrorMessage("connection already closed")
ErrMsgSendTimeout = newErrorMessage("msg send timeout")
ErrMsgReceiptTimeout = newErrorMessage("msg receipt timeout")
ErrDisconnectReceiptTimeout = newErrorMessage("disconnect receipt timeout")
ErrNilOption = newErrorMessage("nil option")
)

// StompError implements the Error interface, and provides
Expand Down

0 comments on commit 8ba92d1

Please sign in to comment.