Skip to content

Commit

Permalink
fix: ctx race when disconnect callback run with connect callback (#307)
Browse files Browse the repository at this point in the history
  • Loading branch information
joway authored Feb 20, 2024
1 parent faa5263 commit b193834
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 13 deletions.
6 changes: 4 additions & 2 deletions connection_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ type connection struct {
outputBuffer *LinkBuffer
outputBarrier *barrier
supportZeroCopy bool
maxSize int // The maximum size of data between two Release().
bookSize int // The size of data that can be read at once.
maxSize int // The maximum size of data between two Release().
bookSize int // The size of data that can be read at once.
state int32 // 0: not connected, 1: connected, 2: disconnected. Connection state should be changed sequentially.
}

var (
Expand Down Expand Up @@ -323,6 +324,7 @@ func (c *connection) init(conn Conn, opts *options) (err error) {
c.bookSize, c.maxSize = pagesize, pagesize
c.inputBuffer, c.outputBuffer = NewLinkBuffer(pagesize), NewLinkBuffer()
c.outputBarrier = barrierPool.Get().(*barrier)
c.state = 0

c.initNetFD(conn) // conn must be *netFD{}
c.initFDOperator()
Expand Down
1 change: 1 addition & 0 deletions connection_lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type key int32

const (
closing key = iota
connecting
processing
flushing
// total must be at the bottom.
Expand Down
46 changes: 41 additions & 5 deletions connection_onevent.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,23 +134,36 @@ func (c *connection) onPrepare(opts *options) (err error) {
func (c *connection) onConnect() {
var onConnect, _ = c.onConnectCallback.Load().(OnConnect)
if onConnect == nil {
atomic.StoreInt32(&c.state, 1)
return
}
if !c.lock(connecting) {
// it never happens because onDisconnect will not lock connecting if c.connected == 0
return
}
var onRequest, _ = c.onRequestCallback.Load().(OnRequest)
var connected int32
c.onProcess(
// only process when conn active and have unread data
func(c *connection) bool {
// if onConnect not called
if atomic.LoadInt32(&connected) == 0 {
if atomic.LoadInt32(&c.state) == 0 {
return true
}
// check for onRequest
return onRequest != nil && c.Reader().Len() > 0
},
func(c *connection) {
if atomic.CompareAndSwapInt32(&connected, 0, 1) {
if atomic.CompareAndSwapInt32(&c.state, 0, 1) {
c.ctx = onConnect(c.ctx, c)

if !c.IsActive() && atomic.CompareAndSwapInt32(&c.state, 1, 2) {
// since we hold connecting lock, so we should help to call onDisconnect here
var onDisconnect, _ = c.onDisconnectCallback.Load().(OnDisconnect)
if onDisconnect != nil {
onDisconnect(c.ctx, c)
}
}
c.unlock(connecting)
return
}
if onRequest != nil {
Expand All @@ -160,12 +173,31 @@ func (c *connection) onConnect() {
)
}

// when onDisconnect called, c.IsActive() must return false
func (c *connection) onDisconnect() {
var onDisconnect, _ = c.onDisconnectCallback.Load().(OnDisconnect)
if onDisconnect == nil {
return
}
onDisconnect(c.ctx, c)
var onConnect, _ = c.onConnectCallback.Load().(OnConnect)
if onConnect == nil {
// no need lock if onConnect is nil
atomic.StoreInt32(&c.state, 2)
onDisconnect(c.ctx, c)
return
}
// check if OnConnect finished when onConnect != nil && onDisconnect != nil
if atomic.LoadInt32(&c.state) > 0 && c.lock(connecting) { // means OnConnect already finished
// protect onDisconnect run once
// if CAS return false, means OnConnect already helps to run onDisconnect
if atomic.CompareAndSwapInt32(&c.state, 1, 2) {
onDisconnect(c.ctx, c)
}
c.unlock(connecting)
return
}
// OnConnect is not finished yet, return and let onConnect helps to call onDisconnect
return
}

// onRequest is responsible for executing the closeCallbacks after the connection has been closed.
Expand All @@ -174,6 +206,11 @@ func (c *connection) onRequest() (needTrigger bool) {
if !ok {
return true
}
// wait onConnect finished first
if atomic.LoadInt32(&c.state) == 0 && c.onConnectCallback.Load() != nil {
// let onConnect to call onRequest
return
}
processed := c.onProcess(
// only process when conn active and have unread data
func(c *connection) bool {
Expand Down Expand Up @@ -259,7 +296,6 @@ func (c *connection) onProcess(isProcessable func(c *connection) bool, process f
panicked = false
return
}

runTask(c.ctx, task)
return true
}
Expand Down
7 changes: 5 additions & 2 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,10 +504,10 @@ func TestParallelShortConnection(t *testing.T) {
var received int64
el, err := NewEventLoop(func(ctx context.Context, connection Connection) error {
data, err := connection.Reader().Next(connection.Reader().Len())
atomic.AddInt64(&received, int64(len(data)))
if err != nil {
return err
}
atomic.AddInt64(&received, int64(len(data)))
//t.Logf("conn[%s] received: %d, active: %v", connection.RemoteAddr(), len(data), connection.IsActive())
return nil
})
Expand Down Expand Up @@ -536,10 +536,13 @@ func TestParallelShortConnection(t *testing.T) {
}
wg.Wait()

for atomic.LoadInt64(&received) < int64(totalSize) {
count := 100
for count > 0 && atomic.LoadInt64(&received) < int64(totalSize) {
t.Logf("received: %d, except: %d", atomic.LoadInt64(&received), totalSize)
time.Sleep(time.Millisecond * 100)
count--
}
Equal(t, atomic.LoadInt64(&received), int64(totalSize))
}

func TestConnectionServerClose(t *testing.T) {
Expand Down
5 changes: 5 additions & 0 deletions eventloop.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ type EventLoop interface {
| Read first byte | OnRequest | Conn is ready for read or write
| Peer closed but conn is active | OnDisconnect | Conn access will race with OnRequest function
| Self closed and conn is closed | CloseCallback | Conn is destroyed
Execution Order:
OnPrepare => OnConnect => OnRequest => CloseCallback
OnDisconnect
Note: only OnRequest and OnDisconnect will be executed in parallel
*/

// OnPrepare is used to inject custom preparation at connection initialization,
Expand Down
58 changes: 54 additions & 4 deletions netpoll_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ func TestOnConnectWrite(t *testing.T) {
}

func TestOnDisconnect(t *testing.T) {
var ctxKey = struct{}{}
type ctxKey struct{}
var network, address = "tcp", ":8888"
var canceled, closed int32
var conns int32 = 100
req := "ping"
var loop = newTestEventLoop(network, address,
func(ctx context.Context, connection Connection) error {
cancelFunc, _ := ctx.Value(ctxKey).(context.CancelFunc)
cancelFunc, _ := ctx.Value(ctxKey{}).(context.CancelFunc)
MustTrue(t, cancelFunc != nil)
Assert(t, ctx.Done() != nil)

Expand All @@ -164,10 +164,10 @@ func TestOnDisconnect(t *testing.T) {
return nil
})
ctx, cancel := context.WithCancel(ctx)
return context.WithValue(ctx, ctxKey, cancel)
return context.WithValue(ctx, ctxKey{}, cancel)
}),
WithOnDisconnect(func(ctx context.Context, conn Connection) {
cancelFunc, _ := ctx.Value(ctxKey).(context.CancelFunc)
cancelFunc, _ := ctx.Value(ctxKey{}).(context.CancelFunc)
MustTrue(t, cancelFunc != nil)
cancelFunc()
}),
Expand Down Expand Up @@ -196,6 +196,56 @@ func TestOnDisconnect(t *testing.T) {
MustNil(t, err)
}

func TestOnDisconnectWhenOnConnect(t *testing.T) {
type ctxPrepareKey struct{}
type ctxConnectKey struct{}
var network, address = "tcp", ":8888"
var conns int32 = 100
var wg sync.WaitGroup
wg.Add(int(conns) * 3)
var loop = newTestEventLoop(network, address,
func(ctx context.Context, connection Connection) error {
_, _ = connection.Reader().Next(connection.Reader().Len())
return nil
},
WithOnPrepare(func(connection Connection) context.Context {
defer wg.Done()
var counter int32
return context.WithValue(context.Background(), ctxPrepareKey{}, &counter)
}),
WithOnConnect(func(ctx context.Context, conn Connection) context.Context {
defer wg.Done()
t.Logf("OnConnect: %v", conn.RemoteAddr())
time.Sleep(time.Millisecond * 10) // wait for closed called
counter := ctx.Value(ctxPrepareKey{}).(*int32)
ok := atomic.CompareAndSwapInt32(counter, 0, 1)
Assert(t, ok)
return context.WithValue(ctx, ctxConnectKey{}, "123")
}),
WithOnDisconnect(func(ctx context.Context, conn Connection) {
defer wg.Done()
t.Logf("OnDisconnect: %v", conn.RemoteAddr())
counter, _ := ctx.Value(ctxPrepareKey{}).(*int32)
ok := atomic.CompareAndSwapInt32(counter, 1, 2)
Assert(t, ok)
v := ctx.Value(ctxConnectKey{}).(string)
Equal(t, v, "123")
}),
)

for i := int32(0); i < conns; i++ {
var conn, err = DialConnection(network, address, time.Second)
MustNil(t, err)
err = conn.Close()
t.Logf("Close: %v", conn.LocalAddr())
MustNil(t, err)
}

wg.Wait()
err := loop.Shutdown(context.Background())
MustNil(t, err)
}

func TestGracefulExit(t *testing.T) {
var network, address = "tcp", ":8888"

Expand Down

0 comments on commit b193834

Please sign in to comment.