diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index add69d1e..bff66b1b 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -39,23 +39,23 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.16 + go-version: "1.20" - - uses: actions/cache@v2 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- +# - uses: actions/cache@v2 +# with: +# path: ~/go/pkg/mod +# key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} +# restore-keys: | +# ${{ runner.os }}-go- # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v1 + uses: github/codeql-action/init@v2 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -66,7 +66,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@v1 + uses: github/codeql-action/autobuild@v2 # ℹī¸ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl @@ -80,4 +80,4 @@ jobs: # make release - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v1 + uses: github/codeql-action/analyze@v2 diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index abd6cf13..d61d604b 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -1,6 +1,6 @@ name: Push and Pull Request Check -on: [ push, pull_request ] +on: [ push ] jobs: compatibility-test: @@ -15,12 +15,12 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - - uses: actions/cache@v2 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- +# - uses: actions/cache@v2 +# with: +# path: ~/go/pkg/mod +# key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} +# restore-keys: | +# ${{ runner.os }}-go- - name: Unit Test run: go test -v -race -covermode=atomic -coverprofile=coverage.out ./... - name: Benchmark @@ -33,12 +33,12 @@ jobs: uses: actions/setup-go@v3 with: go-version: "1.20" - - uses: actions/cache@v2 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- +# - uses: actions/cache@v2 +# with: +# path: ~/go/pkg/mod +# key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} +# restore-keys: | +# ${{ runner.os }}-go- - name: Build Test run: go vet -v ./... style-test: diff --git a/connection_impl.go b/connection_impl.go index 19cb2847..1fa1a8e4 100644 --- a/connection_impl.go +++ b/connection_impl.go @@ -36,7 +36,7 @@ type connection struct { operator *FDOperator readTimeout time.Duration readTimer *time.Timer - readTrigger chan struct{} + readTrigger chan error waitReadSize int64 writeTimeout time.Duration writeTimer *time.Timer @@ -319,7 +319,7 @@ var barrierPool = sync.Pool{ // init initialize the connection with options func (c *connection) init(conn Conn, opts *options) (err error) { // init buffer, barrier, finalizer - c.readTrigger = make(chan struct{}, 1) + c.readTrigger = make(chan error, 1) c.writeTrigger = make(chan error, 1) c.bookSize, c.maxSize = pagesize, pagesize c.inputBuffer, c.outputBuffer = NewLinkBuffer(pagesize), NewLinkBuffer() @@ -357,19 +357,12 @@ func (c *connection) initNetFD(conn Conn) { } func (c *connection) initFDOperator() { - var op *FDOperator - if c.pd != nil && c.pd.operator != nil { - // reuse operator created at connect step - op = c.pd.operator - } else { - poll := pollmanager.Pick() - op = poll.Alloc() - } + poll := pollmanager.Pick() + op := poll.Alloc() op.FD = c.fd op.OnRead, op.OnWrite, op.OnHup = nil, nil, c.onHup op.Inputs, op.InputAck = c.inputs, c.inputAck op.Outputs, op.OutputAck = c.outputs, c.outputAck - c.operator = op } @@ -385,9 +378,9 @@ func (c *connection) initFinalizer() { }) } -func (c *connection) triggerRead() { +func (c *connection) triggerRead(err error) { select { - case c.readTrigger <- struct{}{}: + case c.readTrigger <- err: default: } } @@ -411,10 +404,17 @@ func (c *connection) waitRead(n int) (err error) { } // wait full n for c.inputBuffer.Len() < n { - if !c.IsActive() { + switch c.status(closing) { + case poller: + return Exception(ErrEOF, "wait read") + case user: return Exception(ErrConnClosed, "wait read") + default: + err = <-c.readTrigger + if err != nil { + return err + } } - <-c.readTrigger } return nil } @@ -429,24 +429,32 @@ func (c *connection) waitReadWithTimeout(n int) (err error) { } for c.inputBuffer.Len() < n { - if !c.IsActive() { - // cannot return directly, stop timer before ! + switch c.status(closing) { + case poller: + // cannot return directly, stop timer first! + err = Exception(ErrEOF, "wait read") + goto RET + case user: + // cannot return directly, stop timer first! err = Exception(ErrConnClosed, "wait read") - break - } - - select { - case <-c.readTimer.C: - // double check if there is enough data to be read - if c.inputBuffer.Len() >= n { - return nil + goto RET + default: + select { + case <-c.readTimer.C: + // double check if there is enough data to be read + if c.inputBuffer.Len() >= n { + return nil + } + return Exception(ErrReadTimeout, c.remoteAddr.String()) + case err = <-c.readTrigger: + if err != nil { + return err + } + continue } - return Exception(ErrReadTimeout, c.remoteAddr.String()) - case <-c.readTrigger: - continue } } - +RET: // clean timer.C if !c.readTimer.Stop() { <-c.readTimer.C diff --git a/connection_lock.go b/connection_lock.go index 2dce6622..4b0f7360 100644 --- a/connection_lock.go +++ b/connection_lock.go @@ -19,7 +19,7 @@ import ( "sync/atomic" ) -type who int32 +type who = int32 const ( none who = iota @@ -65,6 +65,14 @@ func (l *locker) isCloseBy(w who) (yes bool) { return atomic.LoadInt32(&l.keychain[closing]) == int32(w) } +func (l *locker) status(k key) int32 { + return atomic.LoadInt32(&l.keychain[k]) +} + +func (l *locker) force(k key, v int32) { + atomic.StoreInt32(&l.keychain[k], v) +} + func (l *locker) lock(k key) (success bool) { return atomic.CompareAndSwapInt32(&l.keychain[k], 0, 1) } diff --git a/connection_onevent.go b/connection_onevent.go index 8f134306..28045645 100644 --- a/connection_onevent.go +++ b/connection_onevent.go @@ -195,7 +195,7 @@ func (c *connection) onProcess(isProcessable func(c *connection) bool, process f if isProcessable(c) { process(c) } - for c.IsActive() && isProcessable(c) { + for !c.isCloseBy(user) && isProcessable(c) { process(c) } // Handling callback if connection has been closed. @@ -225,12 +225,6 @@ func (c *connection) closeCallback(needLock bool) (err error) { if needLock && !c.lock(processing) { return nil } - // If Close is called during OnPrepare, poll is not registered. - if c.isCloseBy(user) && c.operator.poll != nil { - if err = c.operator.Control(PollDetach); err != nil { - logger.Printf("NETPOLL: closeCallback detach operator failed: %v", err) - } - } var latest = c.closeCallbacks.Load() if latest == nil { return nil @@ -243,14 +237,7 @@ func (c *connection) closeCallback(needLock bool) (err error) { // register only use for connection register into poll. func (c *connection) register() (err error) { - if c.operator.isUnused() { - // operator is not registered - err = c.operator.Control(PollReadable) - } else { - // operator is already registered - // change event to wait read new data - err = c.operator.Control(PollModReadable) - } + err = c.operator.Control(PollReadable) if err != nil { logger.Printf("NETPOLL: connection register failed: %v", err) c.Close() diff --git a/connection_reactor.go b/connection_reactor.go index 65621b93..6fca76e3 100644 --- a/connection_reactor.go +++ b/connection_reactor.go @@ -25,17 +25,19 @@ import ( // onHup means close by poller. func (c *connection) onHup(p Poll) error { - if c.closeBy(poller) { - c.triggerRead() - c.triggerWrite(ErrConnClosed) - // It depends on closing by user if OnConnect and OnRequest is nil, otherwise it needs to be released actively. - // It can be confirmed that the OnRequest goroutine has been exited before closecallback executing, - // and it is safe to close the buffer at this time. - var onConnect, _ = c.onConnectCallback.Load().(OnConnect) - var onRequest, _ = c.onRequestCallback.Load().(OnRequest) - if onConnect != nil || onRequest != nil { - c.closeCallback(true) - } + if !c.closeBy(poller) { + return nil + } + // already PollDetach when call OnHup + c.triggerRead(Exception(ErrEOF, "peer close")) + c.triggerWrite(Exception(ErrConnClosed, "peer close")) + // It depends on closing by user if OnConnect and OnRequest is nil, otherwise it needs to be released actively. + // It can be confirmed that the OnRequest goroutine has been exited before closecallback executing, + // and it is safe to close the buffer at this time. + var onConnect, _ = c.onConnectCallback.Load().(OnConnect) + var onRequest, _ = c.onRequestCallback.Load().(OnRequest) + if onConnect != nil || onRequest != nil { + c.closeCallback(true) } return nil } @@ -43,14 +45,24 @@ func (c *connection) onHup(p Poll) error { // onClose means close by user. func (c *connection) onClose() error { if c.closeBy(user) { - c.triggerRead() - c.triggerWrite(ErrConnClosed) + // If Close is called during OnPrepare, poll is not registered. + if c.operator.poll != nil { + if err := c.operator.Control(PollDetach); err != nil { + logger.Printf("NETPOLL: onClose detach operator failed: %v", err) + } + } + c.triggerRead(Exception(ErrConnClosed, "self close")) + c.triggerWrite(Exception(ErrConnClosed, "self close")) c.closeCallback(true) return nil } - if c.isCloseBy(poller) { - // Connection with OnRequest of nil - // relies on the user to actively close the connection to recycle resources. + + closedByPoller := c.isCloseBy(poller) + // force change closed by user + c.force(closing, user) + + // If OnRequest is nil, relies on the user to actively close the connection to recycle resources. + if closedByPoller { c.closeCallback(true) } return nil @@ -103,7 +115,7 @@ func (c *connection) inputAck(n int) (err error) { needTrigger = c.onRequest() } if needTrigger && length >= int(atomic.LoadInt64(&c.waitReadSize)) { - c.triggerRead() + c.triggerRead(nil) } return nil } diff --git a/connection_test.go b/connection_test.go index 3d8fe160..ec72a0c9 100644 --- a/connection_test.go +++ b/connection_test.go @@ -211,7 +211,7 @@ func writeAll(fd int, buf []byte) error { // Large packet write test. The socket buffer is 2MB by default, here to verify // whether Connection.Close can be executed normally after socket output buffer is full. func TestLargeBufferWrite(t *testing.T) { - ln, err := CreateListener("tcp", ":1234") + ln, err := createTestListener("tcp", ":12345") MustNil(t, err) trigger := make(chan int) @@ -230,40 +230,43 @@ func TestLargeBufferWrite(t *testing.T) { } }() - conn, err := DialConnection("tcp", ":1234", time.Second) + conn, err := DialConnection("tcp", ":12345", time.Second) MustNil(t, err) rfd := <-trigger var wg sync.WaitGroup wg.Add(1) - bufferSize := 2 * 1024 * 1024 + bufferSize := 2 * 1024 * 1024 // 2MB + round := 128 //start large buffer writing go func() { defer wg.Done() - for i := 0; i < 129; i++ { + for i := 1; i <= round+1; i++ { _, err := conn.Writer().Malloc(bufferSize) MustNil(t, err) err = conn.Writer().Flush() - if i < 128 { + if i <= round { MustNil(t, err) } } }() - time.Sleep(time.Millisecond * 50) + // wait socket buffer full + time.Sleep(time.Millisecond * 100) buf := make([]byte, 1024) - for i := 0; i < 128*bufferSize/1024; i++ { - _, err := syscall.Read(rfd, buf) - MustNil(t, err) + for received := 0; received < round*bufferSize; { + n, _ := syscall.Read(rfd, buf) + received += n } // close success err = conn.Close() MustNil(t, err) wg.Wait() + trigger <- 1 } func TestWriteTimeout(t *testing.T) { - ln, err := CreateListener("tcp", ":1234") + ln, err := createTestListener("tcp", ":1234") MustNil(t, err) interval := time.Millisecond * 100 @@ -397,7 +400,7 @@ func TestConnectionUntil(t *testing.T) { buf, err := rconn.Reader().Until('\n') Equal(t, len(buf), 100) - Assert(t, errors.Is(err, ErrConnClosed), err) + Assert(t, errors.Is(err, ErrEOF), err) } func TestBookSizeLargerThanMaxSize(t *testing.T) { @@ -436,7 +439,7 @@ func TestBookSizeLargerThanMaxSize(t *testing.T) { } func TestConnDetach(t *testing.T) { - ln, err := CreateListener("tcp", ":1234") + ln, err := createTestListener("tcp", ":1234") MustNil(t, err) go func() { @@ -491,3 +494,49 @@ func TestConnDetach(t *testing.T) { err = ln.Close() MustNil(t, err) } + +func TestParallelShortConnection(t *testing.T) { + ln, err := createTestListener("tcp", ":12345") + MustNil(t, err) + defer ln.Close() + + var received int64 + el, err := NewEventLoop(func(ctx context.Context, connection Connection) error { + data, err := connection.Reader().Next(connection.Reader().Len()) + if err != nil { + return err + } + atomic.AddInt64(&received, int64(len(data))) + //t.Logf("conn[%s] received: %d, active: %v", connection.RemoteAddr(), len(data), connection.IsActive()) + return nil + }) + go func() { + el.Serve(ln) + }() + + conns := 100 + sizePerConn := 1024 * 100 + totalSize := conns * sizePerConn + var wg sync.WaitGroup + for i := 0; i < conns; i++ { + wg.Add(1) + go func() { + defer wg.Done() + conn, err := DialConnection("tcp", ":12345", time.Second) + MustNil(t, err) + n, err := conn.Writer().WriteBinary(make([]byte, sizePerConn)) + MustNil(t, err) + MustTrue(t, n == sizePerConn) + err = conn.Writer().Flush() + MustNil(t, err) + err = conn.Close() + MustNil(t, err) + }() + } + wg.Wait() + + for atomic.LoadInt64(&received) < int64(totalSize) { + t.Logf("received: %d, except: %d", atomic.LoadInt64(&received), totalSize) + time.Sleep(time.Millisecond * 100) + } +} diff --git a/net_dialer_test.go b/net_dialer_test.go index 3d08ed89..7383fd0d 100644 --- a/net_dialer_test.go +++ b/net_dialer_test.go @@ -167,7 +167,7 @@ func TestFDClose(t *testing.T) { // fd data package race test, use two servers and two dialers. func TestDialerThenClose(t *testing.T) { // server 1 - ln1, _ := CreateListener("tcp", ":1231") + ln1, _ := createTestListener("tcp", ":1231") el1 := mockDialerEventLoop(1) go func() { el1.Serve(ln1) @@ -177,7 +177,7 @@ func TestDialerThenClose(t *testing.T) { defer el1.Shutdown(ctx1) // server 2 - ln2, _ := CreateListener("tcp", ":1232") + ln2, _ := createTestListener("tcp", ":1232") el2 := mockDialerEventLoop(2) go func() { el2.Serve(ln2) diff --git a/net_polldesc.go b/net_polldesc.go index 0b78c653..dfd95de1 100644 --- a/net_polldesc.go +++ b/net_polldesc.go @@ -21,16 +21,15 @@ import ( "context" ) -// TODO: recycle *pollDesc func newPollDesc(fd int) *pollDesc { pd := &pollDesc{} poll := pollmanager.Pick() - op := poll.Alloc() - op.FD = fd - op.OnWrite = pd.onwrite - op.OnHup = pd.onhup - - pd.operator = op + pd.operator = &FDOperator{ + poll: poll, + FD: fd, + OnWrite: pd.onwrite, + OnHup: pd.onhup, + } pd.writeTrigger = make(chan struct{}) pd.closeTrigger = make(chan struct{}) return pd @@ -45,13 +44,6 @@ type pollDesc struct { // WaitWrite . func (pd *pollDesc) WaitWrite(ctx context.Context) (err error) { - defer func() { - // if return err != nil, upper caller function will close the connection - if err != nil { - pd.operator.Free() - } - }() - if pd.operator.isUnused() { // add ET|Write|Hup if err = pd.operator.Control(PollWritable); err != nil { @@ -84,6 +76,7 @@ func (pd *pollDesc) onwrite(p Poll) error { select { case <-pd.writeTrigger: default: + pd.detach() close(pd.writeTrigger) } return nil diff --git a/netpoll_test.go b/netpoll_test.go index b97bf176..0467e879 100644 --- a/netpoll_test.go +++ b/netpoll_test.go @@ -251,6 +251,41 @@ func TestCloseCallbackWhenOnConnect(t *testing.T) { MustNil(t, err) } +func TestCloseConnWhenOnConnect(t *testing.T) { + var network, address = "tcp", ":8888" + conns := 10 + var wg sync.WaitGroup + wg.Add(conns) + var loop = newTestEventLoop(network, address, + nil, + WithOnConnect(func(ctx context.Context, connection Connection) context.Context { + defer wg.Done() + err := connection.Close() + MustNil(t, err) + return ctx + }), + ) + + for i := 0; i < conns; i++ { + wg.Add(1) + go func() { + defer wg.Done() + var conn, err = DialConnection(network, address, time.Second) + if err != nil { + return + } + _, err = conn.Reader().Next(1) + Assert(t, errors.Is(err, ErrEOF)) + err = conn.Close() + MustNil(t, err) + }() + } + + wg.Wait() + err := loop.Shutdown(context.Background()) + MustNil(t, err) +} + func TestServerReadAndClose(t *testing.T) { var network, address = "tcp", ":18888" var sendMsg = []byte("hello") @@ -362,8 +397,18 @@ func TestClientWriteAndClose(t *testing.T) { MustNil(t, err) } +func createTestListener(network, address string) (Listener, error) { + for { + ln, err := CreateListener(network, address) + if err == nil { + return ln, nil + } + time.Sleep(time.Millisecond * 100) + } +} + func newTestEventLoop(network, address string, onRequest OnRequest, opts ...Option) EventLoop { - ln, err := CreateListener(network, address) + ln, err := createTestListener(network, address) if err != nil { panic(err) } diff --git a/poll.go b/poll.go index 1d5c42fb..c494ffd6 100644 --- a/poll.go +++ b/poll.go @@ -57,10 +57,6 @@ const ( // PollDetach is used to remove the FDOperator from poll. PollDetach PollEvent = 0x3 - // PollModReadable is used to re-register the readable monitor for the FDOperator created by the dialer. - // It is only used when calling the dialer's conn init. - PollModReadable PollEvent = 0x4 - // PollR2RW is used to monitor writable for FDOperator, // which is only called when the socket write buffer is full. PollR2RW PollEvent = 0x5 diff --git a/poll_default.go b/poll_default.go index e926311b..b35ff5a6 100644 --- a/poll_default.go +++ b/poll_default.go @@ -55,21 +55,22 @@ func (p *defaultPoll) onhups() { } // readall read all left data before close connection -func readall(op *FDOperator, br barrier) (err error) { +func readall(op *FDOperator, br barrier) (total int, err error) { var bs = br.bs var ivs = br.ivs var n int for { bs = op.Inputs(br.bs) if len(bs) == 0 { - return nil + return total, nil } TryRead: n, err = ioread(op.FD, bs, ivs) op.InputAck(n) + total += n if err != nil { - return err + return total, err } if n == 0 { goto TryRead diff --git a/poll_default_bsd.go b/poll_default_bsd.go index a69d23f6..9c8aa8c9 100644 --- a/poll_default_bsd.go +++ b/poll_default_bsd.go @@ -90,6 +90,7 @@ func (p *defaultPoll) Wait() error { continue } + var totalRead int evt := events[i] triggerRead = evt.Filter == syscall.EVFILT_READ && evt.Flags&syscall.EV_ENABLE != 0 triggerWrite = evt.Filter == syscall.EVFILT_WRITE && evt.Flags&syscall.EV_ENABLE != 0 @@ -105,6 +106,7 @@ func (p *defaultPoll) Wait() error { if len(bs) > 0 { var n, err = ioread(operator.FD, bs, barriers[i].ivs) operator.InputAck(n) + totalRead += n if err != nil { p.appendHup(operator) continue @@ -112,14 +114,20 @@ func (p *defaultPoll) Wait() error { } } } - if triggerHup && triggerRead && operator.Inputs != nil { // read all left data if peer send and close - if err = readall(operator, barriers[i]); err != nil && !errors.Is(err, ErrEOF) { - logger.Printf("NETPOLL: readall(fd=%d) before close: %s", operator.FD, err.Error()) - } - } if triggerHup { - p.appendHup(operator) - continue + if triggerRead && operator.Inputs != nil { + var leftRead int + // read all left data if peer send and close + if leftRead, err = readall(operator, barriers[i]); err != nil && !errors.Is(err, ErrEOF) { + logger.Printf("NETPOLL: readall(fd=%d)=%d before close: %s", operator.FD, total, err.Error()) + } + totalRead += leftRead + } + // only close connection if no further read bytes + if totalRead == 0 { + p.appendHup(operator) + continue + } } if triggerWrite { if operator.OnWrite != nil { @@ -172,19 +180,23 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { evs[0].Ident = uint64(operator.FD) p.setOperator(unsafe.Pointer(&evs[0].Udata), operator) switch event { - case PollReadable, PollModReadable: + case PollReadable: operator.inuse() evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE case PollWritable: operator.inuse() - evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE|syscall.EV_ONESHOT + evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollDetach: + if operator.OnWrite != nil { // means WaitWrite finished + evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE + } else { + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE + } p.delOperator(operator) - evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE|syscall.EV_ONESHOT case PollR2RW: evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollRW2R: - evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE|syscall.EV_ONESHOT + evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE } _, err := syscall.Kevent(p.fd, evs, nil, nil) return err diff --git a/poll_default_linux.go b/poll_default_linux.go index 72ddf664..a0087ee0 100644 --- a/poll_default_linux.go +++ b/poll_default_linux.go @@ -117,12 +117,14 @@ func (p *defaultPoll) Wait() (err error) { func (p *defaultPoll) handler(events []epollevent) (closed bool) { var triggerRead, triggerWrite, triggerHup, triggerError bool + var err error for i := range events { operator := p.getOperator(0, unsafe.Pointer(&events[i].data)) if operator == nil || !operator.do() { continue } + var totalRead int evt := events[i].events triggerRead = evt&syscall.EPOLLIN != 0 triggerWrite = evt&syscall.EPOLLOUT != 0 @@ -155,6 +157,7 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) { if len(bs) > 0 { var n, err = ioread(operator.FD, bs, p.barriers[i].ivs) operator.InputAck(n) + totalRead += n if err != nil { p.appendHup(operator) continue @@ -164,14 +167,21 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) { logger.Printf("NETPOLL: operator has critical problem! event=%d operator=%v", evt, operator) } } - if triggerHup && triggerRead && operator.Inputs != nil { // read all left data if peer send and close - if err := readall(operator, p.barriers[i]); err != nil && !errors.Is(err, ErrEOF) { - logger.Printf("NETPOLL: readall(fd=%d) before close: %s", operator.FD, err.Error()) - } - } if triggerHup { - p.appendHup(operator) - continue + if triggerRead && operator.Inputs != nil { + // read all left data if peer send and close + var leftRead int + // read all left data if peer send and close + if leftRead, err = readall(operator, p.barriers[i]); err != nil && !errors.Is(err, ErrEOF) { + logger.Printf("NETPOLL: readall(fd=%d)=%d before close: %s", operator.FD, total, err.Error()) + } + totalRead += leftRead + } + // only close connection if no further read bytes + if totalRead == 0 { + p.appendHup(operator) + continue + } } if triggerError { // Under block-zerocopy, the kernel may give an error callback, which is not a real error, just an EAGAIN. @@ -238,8 +248,6 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { case PollWritable: // client create a new connection and wait connect finished operator.inuse() op, evt.events = syscall.EPOLL_CTL_ADD, EPOLLET|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR - case PollModReadable: // client wait read/write - op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollDetach: // deregister p.delOperator(operator) op, evt.events = syscall.EPOLL_CTL_DEL, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR diff --git a/poll_default_linux_test.go b/poll_default_linux_test.go index acd0afc9..072963d7 100644 --- a/poll_default_linux_test.go +++ b/poll_default_linux_test.go @@ -62,7 +62,7 @@ func TestEpollEvent(t *testing.T) { MustNil(t, err) _, err = syscall.Write(wfd, send) MustNil(t, err) - n, err := EpollWait(epollfd, events, -1) + n, err := epollWaitUntil(epollfd, events, -1) MustNil(t, err) Equal(t, n, 1) Equal(t, events[0].data, eventdata2) @@ -80,7 +80,7 @@ func TestEpollEvent(t *testing.T) { MustNil(t, err) _, err = syscall.Write(wfd, send) MustNil(t, err) - n, err = EpollWait(epollfd, events, -1) + n, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Equal(t, events[0].data, eventdata3) _, err = syscall.Read(rfd, recv) @@ -112,7 +112,7 @@ func TestEpollWait(t *testing.T) { } err = EpollCtl(epollfd, unix.EPOLL_CTL_ADD, rfd, event) MustNil(t, err) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLIN == 0) Assert(t, events[0].events&syscall.EPOLLOUT != 0) @@ -120,7 +120,7 @@ func TestEpollWait(t *testing.T) { // EPOLL: readable _, err = syscall.Write(wfd, send) MustNil(t, err) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLIN != 0) Assert(t, events[0].events&syscall.EPOLLOUT != 0) @@ -128,7 +128,7 @@ func TestEpollWait(t *testing.T) { MustTrue(t, err == nil && string(recv) == string(send)) // EPOLL: read finished - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLIN == 0) Assert(t, events[0].events&syscall.EPOLLOUT != 0) @@ -136,7 +136,7 @@ func TestEpollWait(t *testing.T) { // EPOLL: close peer fd err = syscall.Close(wfd) MustNil(t, err) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLIN != 0) Assert(t, events[0].events&syscall.EPOLLOUT != 0) @@ -149,7 +149,7 @@ func TestEpollWait(t *testing.T) { err = EpollCtl(epollfd, unix.EPOLL_CTL_ADD, rfd2, event) err = syscall.Close(rfd2) MustNil(t, err) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLIN != 0) Assert(t, events[0].events&syscall.EPOLLOUT != 0) @@ -174,7 +174,7 @@ func TestEpollETClose(t *testing.T) { // EPOLL: init state err = EpollCtl(epollfd, unix.EPOLL_CTL_ADD, rfd, event) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLIN == 0) Assert(t, events[0].events&syscall.EPOLLOUT != 0) @@ -185,7 +185,7 @@ func TestEpollETClose(t *testing.T) { // nothing will happen err = syscall.Close(rfd) MustNil(t, err) - n, err := EpollWait(epollfd, events, 100) + n, err := epollWaitUntil(epollfd, events, 100) MustNil(t, err) Assert(t, n == 0, n) err = syscall.Close(wfd) @@ -197,7 +197,7 @@ func TestEpollETClose(t *testing.T) { err = EpollCtl(epollfd, unix.EPOLL_CTL_ADD, rfd, event) err = syscall.Close(wfd) MustNil(t, err) - n, err = EpollWait(epollfd, events, 100) + n, err = epollWaitUntil(epollfd, events, 100) MustNil(t, err) Assert(t, n == 1, n) Assert(t, events[0].events&syscall.EPOLLIN != 0) @@ -231,7 +231,7 @@ func TestEpollETDel(t *testing.T) { MustNil(t, err) _, err = syscall.Write(wfd, send) MustNil(t, err) - _, err = EpollWait(epollfd, events, 100) + _, err = epollWaitUntil(epollfd, events, 100) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLIN == 0) Assert(t, events[0].events&syscall.EPOLLRDHUP == 0) @@ -272,11 +272,11 @@ func TestEpollConnectSameFD(t *testing.T) { MustNil(t, err) err = syscall.Connect(fd1, &addr) t.Log(err) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLOUT != 0) - Assert(t, events[0].events&syscall.EPOLLRDHUP == 0) - Assert(t, events[0].events&syscall.EPOLLERR == 0) + //Assert(t, events[0].events&syscall.EPOLLRDHUP == 0) + //Assert(t, events[0].events&syscall.EPOLLERR == 0) // forget to del fd //err = EpollCtl(epollfd, unix.EPOLL_CTL_DEL, fd1, event1) //MustNil(t, err) @@ -293,7 +293,7 @@ func TestEpollConnectSameFD(t *testing.T) { MustNil(t, err) err = syscall.Connect(fd2, &addr) t.Log(err) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLOUT != 0) Assert(t, events[0].events&syscall.EPOLLRDHUP == 0) @@ -314,7 +314,7 @@ func TestEpollConnectSameFD(t *testing.T) { MustNil(t, err) err = syscall.Connect(fd3, &addr) t.Log(err) - _, err = EpollWait(epollfd, events, -1) + _, err = epollWaitUntil(epollfd, events, -1) MustNil(t, err) Assert(t, events[0].events&syscall.EPOLLOUT != 0) Assert(t, events[0].events&syscall.EPOLLRDHUP == 0) @@ -324,7 +324,16 @@ func TestEpollConnectSameFD(t *testing.T) { MustNil(t, err) err = syscall.Close(fd3) // close fd3 MustNil(t, err) - n, err := EpollWait(epollfd, events, 100) + n, err := epollWaitUntil(epollfd, events, 100) MustNil(t, err) Assert(t, n == 0) } + +func epollWaitUntil(epfd int, events []epollevent, msec int) (n int, err error) { +WAIT: + n, err = EpollWait(epfd, events, msec) + if err == syscall.EINTR { + goto WAIT + } + return n, err +}