From 376fa78a94322c15b3a2dbeaa1c393163e6b273c Mon Sep 17 00:00:00 2001 From: Joway Date: Mon, 18 Sep 2023 14:15:32 +0800 Subject: [PATCH] fix: protect operator detach twice (#283) --- .github/workflows/pr-check.yml | 4 +- connection_onevent.go | 23 ++++++-- connection_reactor.go | 9 ++-- connection_test.go | 98 ++++++++++++++++++++++++++++++++++ fd_operator.go | 7 +++ nocopy_linkbuffer.go | 8 +-- nocopy_linkbuffer_race.go | 8 +-- 7 files changed, 138 insertions(+), 19 deletions(-) diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index d61d604b..a923b418 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -1,12 +1,12 @@ name: Push and Pull Request Check -on: [ push ] +on: [ push, pull_request ] jobs: compatibility-test: strategy: matrix: - go: [ 1.15, "1.20" ] + go: [ 1.15, "1.21" ] os: [ X64, ARM64 ] runs-on: ${{ matrix.os }} steps: diff --git a/connection_onevent.go b/connection_onevent.go index 6ddbf31e..9b87f01b 100644 --- a/connection_onevent.go +++ b/connection_onevent.go @@ -195,13 +195,26 @@ func (c *connection) onProcess(isProcessable func(c *connection) bool, process f if isProcessable(c) { process(c) } - for !c.isCloseBy(user) && isProcessable(c) { + // `process` must either eventually read all the input data or actively Close the connection, + // otherwise the goroutine will fall into a dead loop. + var closedBy who + for { + closedBy = c.status(closing) + // close by user or no processable + if closedBy == user || !isProcessable(c) { + break + } process(c) } // Handling callback if connection has been closed. - if !c.IsActive() { - // connection if closed by user when processing, so it needs detach - c.closeCallback(false, true) + if closedBy != none { + // if closed by user when processing, it "may" needs detach + needDetach := closedBy == user + // Here is a conor case that operator will be detached twice: + // If server closed the connection(client OnHup will detach op first and closeBy=poller), + // and then client's OnRequest function also closed the connection(closeBy=user). + // But operator already prevent that detach twice will not cause any problem + c.closeCallback(false, needDetach) panicked = false return } @@ -229,7 +242,7 @@ func (c *connection) closeCallback(needLock bool, needDetach bool) (err error) { if needDetach && c.operator.poll != nil { // If Close is called during OnPrepare, poll is not registered. // PollDetach only happen when user call conn.Close() or poller detect error if err := c.operator.Control(PollDetach); err != nil { - logger.Printf("NETPOLL: onClose detach operator failed: %v", err) + logger.Printf("NETPOLL: closeCallback[%v,%v] detach operator failed: %v", needLock, needDetach, err) } } var latest = c.closeCallbacks.Load() diff --git a/connection_reactor.go b/connection_reactor.go index 2acd45ce..fa485be1 100644 --- a/connection_reactor.go +++ b/connection_reactor.go @@ -31,11 +31,12 @@ func (c *connection) onHup(p Poll) error { c.triggerRead(Exception(ErrEOF, "peer close")) c.triggerWrite(Exception(ErrConnClosed, "peer close")) // It depends on closing by user if OnConnect and OnRequest is nil, otherwise it needs to be released actively. - // It can be confirmed that the OnRequest goroutine has been exited before closecallback executing, + // It can be confirmed that the OnRequest goroutine has been exited before closeCallback executing, // and it is safe to close the buffer at this time. - var onConnect, _ = c.onConnectCallback.Load().(OnConnect) - var onRequest, _ = c.onRequestCallback.Load().(OnRequest) - if onConnect != nil || onRequest != nil { + var onConnect = c.onConnectCallback.Load() + var onRequest = c.onRequestCallback.Load() + var needCloseByUser = onConnect == nil && onRequest == nil + if !needCloseByUser { // already PollDetach when call OnHup c.closeCallback(true, false) } diff --git a/connection_test.go b/connection_test.go index ec72a0c9..6de6f017 100644 --- a/connection_test.go +++ b/connection_test.go @@ -540,3 +540,101 @@ func TestParallelShortConnection(t *testing.T) { time.Sleep(time.Millisecond * 100) } } + +func TestConnectionServerClose(t *testing.T) { + ln, err := createTestListener("tcp", ":12345") + MustNil(t, err) + defer ln.Close() + + /* + Client Server + - Client --- connect --> Server + - Client <-- [ping] --- Server + - Client --- [pong] --> Server + - Client <-- close --- Server + - Client --- close --> Server + */ + const PING, PONG = "ping", "pong" + var wg sync.WaitGroup + el, err := NewEventLoop( + func(ctx context.Context, connection Connection) error { + t.Logf("server.OnRequest: addr=%s", connection.RemoteAddr()) + defer wg.Done() + buf, err := connection.Reader().Next(len(PONG)) // pong + Equal(t, string(buf), PONG) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + err = connection.Close() + MustNil(t, err) + return err + }, + WithOnConnect(func(ctx context.Context, connection Connection) context.Context { + t.Logf("server.OnConnect: addr=%s", connection.RemoteAddr()) + defer wg.Done() + // check OnPrepare + v := ctx.Value("prepare").(string) + Equal(t, v, "true") + + _, err := connection.Writer().WriteBinary([]byte(PING)) + MustNil(t, err) + err = connection.Writer().Flush() + MustNil(t, err) + connection.AddCloseCallback(func(connection Connection) error { + t.Logf("server.CloseCallback: addr=%s", connection.RemoteAddr()) + wg.Done() + return nil + }) + return ctx + }), + WithOnPrepare(func(connection Connection) context.Context { + t.Logf("server.OnPrepare: addr=%s", connection.RemoteAddr()) + defer wg.Done() + return context.WithValue(context.Background(), "prepare", "true") + }), + ) + defer el.Shutdown(context.Background()) + go func() { + err := el.Serve(ln) + if err != nil { + t.Logf("servce end with error: %v", err) + } + }() + + var clientOnRequest OnRequest = func(ctx context.Context, connection Connection) error { + t.Logf("client.OnRequest: addr=%s", connection.LocalAddr()) + defer wg.Done() + buf, err := connection.Reader().Next(len(PING)) + MustNil(t, err) + Equal(t, string(buf), PING) + + _, err = connection.Writer().WriteBinary([]byte(PONG)) + MustNil(t, err) + err = connection.Writer().Flush() + MustNil(t, err) + + _, err = connection.Reader().Next(1) // server will not send any data, just wait for server close + MustTrue(t, errors.Is(err, ErrEOF)) // should get EOF when server close + + return connection.Close() + } + conns := 100 + // server: OnPrepare, OnConnect, OnRequest, CloseCallback + // client: OnRequest, CloseCallback + wg.Add(conns * 6) + for i := 0; i < conns; i++ { + go func() { + conn, err := DialConnection("tcp", ":12345", time.Second) + MustNil(t, err) + err = conn.SetOnRequest(clientOnRequest) + MustNil(t, err) + conn.AddCloseCallback(func(connection Connection) error { + t.Logf("client.CloseCallback: addr=%s", connection.LocalAddr()) + defer wg.Done() + return nil + }) + }() + } + //time.Sleep(time.Second) + wg.Wait() +} diff --git a/fd_operator.go b/fd_operator.go index 4132fe9c..1ac843a9 100644 --- a/fd_operator.go +++ b/fd_operator.go @@ -42,6 +42,9 @@ type FDOperator struct { // poll is the registered location of the file descriptor. poll Poll + // protect only detach once + detached int32 + // private, used by operatorCache next *FDOperator state int32 // CAS: 0(unused) 1(inuse) 2(do-done) @@ -49,6 +52,9 @@ type FDOperator struct { } func (op *FDOperator) Control(event PollEvent) error { + if event == PollDetach && atomic.AddInt32(&op.detached, 1) > 1 { + return nil + } return op.poll.Control(op, event) } @@ -92,4 +98,5 @@ func (op *FDOperator) reset() { op.Inputs, op.InputAck = nil, nil op.Outputs, op.OutputAck = nil, nil op.poll = nil + op.detached = 0 } diff --git a/nocopy_linkbuffer.go b/nocopy_linkbuffer.go index 59cc6530..555ba5ce 100644 --- a/nocopy_linkbuffer.go +++ b/nocopy_linkbuffer.go @@ -461,9 +461,8 @@ func (b *LinkBuffer) WriteBinary(p []byte) (n int, err error) { } // here will copy b.growth(n) - malloc := b.write.malloc - b.write.malloc += n - return copy(b.write.buf[malloc:b.write.malloc], p), nil + buf := b.write.Malloc(n) + return copy(buf, p), nil } // WriteDirect cannot be mixed with WriteString or WriteBinary functions. @@ -578,7 +577,8 @@ func (b *LinkBuffer) GetBytes(p [][]byte) (vs [][]byte) { // // bookSize: The size of data that can be read at once. // maxSize: The maximum size of data between two Release(). In some cases, this can -// guarantee all data allocated in one node to reduce copy. +// +// guarantee all data allocated in one node to reduce copy. func (b *LinkBuffer) book(bookSize, maxSize int) (p []byte) { l := cap(b.write.buf) - b.write.malloc // grow linkBuffer diff --git a/nocopy_linkbuffer_race.go b/nocopy_linkbuffer_race.go index a785aa15..4b3635d0 100644 --- a/nocopy_linkbuffer_race.go +++ b/nocopy_linkbuffer_race.go @@ -497,9 +497,8 @@ func (b *LinkBuffer) WriteBinary(p []byte) (n int, err error) { } // here will copy b.growth(n) - malloc := b.write.malloc - b.write.malloc += n - return copy(b.write.buf[malloc:b.write.malloc], p), nil + buf := b.write.Malloc(n) + return copy(buf, p), nil } // WriteDirect cannot be mixed with WriteString or WriteBinary functions. @@ -622,7 +621,8 @@ func (b *LinkBuffer) GetBytes(p [][]byte) (vs [][]byte) { // // bookSize: The size of data that can be read at once. // maxSize: The maximum size of data between two Release(). In some cases, this can -// guarantee all data allocated in one node to reduce copy. +// +// guarantee all data allocated in one node to reduce copy. func (b *LinkBuffer) book(bookSize, maxSize int) (p []byte) { b.Lock() defer b.Unlock()