diff --git a/connection_impl.go b/connection_impl.go index 396a4c20..1fa1a8e4 100644 --- a/connection_impl.go +++ b/connection_impl.go @@ -404,10 +404,17 @@ func (c *connection) waitRead(n int) (err error) { } // wait full n for c.inputBuffer.Len() < n { - if !c.IsActive() { + switch c.status(closing) { + case poller: + return Exception(ErrEOF, "wait read") + case user: return Exception(ErrConnClosed, "wait read") + default: + err = <-c.readTrigger + if err != nil { + return err + } } - <-c.readTrigger } return nil } @@ -422,23 +429,32 @@ func (c *connection) waitReadWithTimeout(n int) (err error) { } for c.inputBuffer.Len() < n { - if !c.IsActive() { - // cannot return directly, stop timer before ! + switch c.status(closing) { + case poller: + // cannot return directly, stop timer first! + err = Exception(ErrEOF, "wait read") + goto RET + case user: + // cannot return directly, stop timer first! err = Exception(ErrConnClosed, "wait read") - break - } - select { - case <-c.readTimer.C: - // double check if there is enough data to be read - if c.inputBuffer.Len() >= n { - return nil + goto RET + default: + select { + case <-c.readTimer.C: + // double check if there is enough data to be read + if c.inputBuffer.Len() >= n { + return nil + } + return Exception(ErrReadTimeout, c.remoteAddr.String()) + case err = <-c.readTrigger: + if err != nil { + return err + } + continue } - return Exception(ErrReadTimeout, c.remoteAddr.String()) - case <-c.readTrigger: - continue } } - +RET: // clean timer.C if !c.readTimer.Stop() { <-c.readTimer.C