refactor: simplify shard queue
joway committed May 30, 2024
1 parent 3e411b1 commit f6ae9e0
Showing 2 changed files with 144 additions and 130 deletions.
196 changes: 91 additions & 105 deletions mux/shard_queue.go
@@ -17,93 +17,89 @@ package mux
import (
"fmt"
"runtime"
"sync"
"sync/atomic"

"github.com/bytedance/gopkg/util/gopool"

"github.com/cloudwego/netpoll"
)

/* DOC:
* ShardQueue uses the netpoll's nocopy API to merge and send data.
* The Data Flush is passively triggered by ShardQueue.Add and does not require user operations.
* If there is an error in the data transmission, the connection will be closed.
*
* ShardQueue.Add: add the data to be sent.
* NewShardQueue: create a queue with netpoll.Connection.
* ShardSize: the recommended number of shards is 32.
*/
var ShardSize int

func init() {
ShardSize = runtime.GOMAXPROCS(0)
}
// ShardQueue uses the netpoll nocopy Writer API to merge multiple packets and send them at once.
// Flushing is triggered passively by ShardQueue.Add and requires no user operation.
// If an error occurs during transmission, the connection will be closed.

// NewShardQueue .
func NewShardQueue(size int, conn netpoll.Connection) (queue *ShardQueue) {
// NewShardQueue creates a queue bound to a netpoll.Connection.
func NewShardQueue(shardsize int, conn netpoll.Connection) (queue *ShardQueue) {
queue = &ShardQueue{
conn: conn,
size: int32(size),
getters: make([][]WriterGetter, size),
swap: make([]WriterGetter, 0, 64),
locks: make([]int32, size),
conn: conn,
shardsize: uint32(shardsize),
shards: make([][]WriterGetter, shardsize),
locks: make([]int32, shardsize),
}
for i := range queue.getters {
queue.getters[i] = make([]WriterGetter, 0, 64)
for i := range queue.shards {
queue.shards[i] = make([]WriterGetter, 0, 64)
}
queue.list = make([]int32, size)
queue.shard = make([]WriterGetter, 0, 64)
return queue
}
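
For orientation, a minimal caller-side sketch of the refactored API (not part of the diff; the mux import path and the dialed endpoint are assumptions):

package main

import (
	"runtime"
	"time"

	"github.com/cloudwego/netpoll"
	"github.com/cloudwego/netpoll/mux" // assumed import path for this package
)

func main() {
	// Assumes a server is listening on this address.
	conn, err := netpoll.DialConnection("tcp", "127.0.0.1:8080", time.Second)
	if err != nil {
		panic(err)
	}
	queue := mux.NewShardQueue(runtime.GOMAXPROCS(0), conn)

	payload := []byte("hello")
	queue.Add(func() (netpoll.Writer, bool) {
		buf := netpoll.NewLinkBuffer(len(payload))
		data, _ := buf.Malloc(len(payload))
		copy(data, payload)
		return buf, false // false: the buffer is valid and should be sent
	})

	// Close flushes everything added so far, then shuts the queue down.
	if err := queue.Close(); err != nil {
		panic(err)
	}
}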

// WriterGetter is used to get a netpoll.Writer.
type WriterGetter func() (buf netpoll.Writer, isNil bool)
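
A WriterGetter is evaluated lazily, at drain time rather than at Add time, which lets a caller back out of a send that is no longer wanted. A sketch of that contract, written as if inside this package and reusing its imports (cancelled() is a hypothetical application-side check):

// cancelled is a hypothetical stand-in for "this frame is no longer wanted".
func cancelled() bool { return false }

// lazyGetter shows the isNil contract: returning (nil, true) at drain time
// makes deal() skip this packet entirely.
var lazyGetter WriterGetter = func() (netpoll.Writer, bool) {
	if cancelled() {
		return nil, true
	}
	buf := netpoll.NewLinkBuffer(4)
	data, _ := buf.Malloc(4)
	copy(data, "ping")
	return buf, false
}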

// ShardQueue uses the netpoll's nocopy API to merge and send data.
// ShardQueue uses the netpoll's nocopy API to merge and send data.
// The Data Flush is passively triggered by ShardQueue.Add and does not require user operations.
// If there is an error in the data transmission, the connection will be closed.
// ShardQueue.Add: add the data to be sent.
type ShardQueue struct {
// state definition:
// active : only in the active state can users Add new data
// closing : ShardQueue.Close has been called; shutting down gracefully, no new data can be added
// closed : graceful shutdown has finished
state int32

conn netpoll.Connection
idx, size int32
getters [][]WriterGetter // len(getters) = size
swap []WriterGetter // use for swap
locks []int32 // len(locks) = size
queueTrigger
size uint32 // the total number of pending getters across all shards
shardsize uint32 // the number of shards
shards [][]WriterGetter // per-shard getter queues, len(shards) = shardsize
shard []WriterGetter // the shard currently being dealt with; reused as swap space
locks []int32 // per-shard spinlocks, len(locks) = shardsize
// trigger is used to avoid re-entering the triggering function.
// trigger == 0: nothing to do
// trigger == 1: we should start a new triggering()
// trigger >= 2: triggering() already started
trigger int32
}

const (
// queueTrigger state
// ShardQueue state
active = 0
closing = 1
closed = 2
)

// here for trigger
type queueTrigger struct {
trigger int32
state int32 // 0: active, 1: closing, 2: closed
runNum int32
w, r int32 // ptr of list
list []int32 // record the triggered shard
listLock sync.Mutex // list total lock
}
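// idgen is a global counter; Add uses it to pick shards in round-robin order.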
var idgen uint32

// Add adds to q.getters[shard]
func (q *ShardQueue) Add(gts ...WriterGetter) {
if atomic.LoadInt32(&q.state) != active {
return
// Add adds gts to the ShardQueue and reports whether they were accepted.
func (q *ShardQueue) Add(gts ...WriterGetter) bool {
size := uint32(len(gts))
if size == 0 || atomic.LoadInt32(&q.state) != active {
return false
}
shard := atomic.AddInt32(&q.idx, 1) % q.size
q.lock(shard)
trigger := len(q.getters[shard]) == 0
q.getters[shard] = append(q.getters[shard], gts...)
q.unlock(shard)
if trigger {
q.triggering(shard)

// get current shard id
shardid := atomic.AddUint32(&idgen, 1) % q.shardsize
// append the new getters into the chosen shard
q.lock(shardid)
q.shards[shardid] = append(q.shards[shardid], gts...)
// the size update must happen inside the lock: when q.shards[shardid] is unlocked, the worker must be able to observe the correct size
_ = atomic.AddUint32(&q.size, size)
q.unlock(shardid)

if atomic.AddInt32(&q.trigger, 1) == 1 {
go q.triggering(shardid)
}
return true
}
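
The key line in Add is atomic.AddInt32(&q.trigger, 1) == 1: the caller that moves the counter from 0 to 1 is elected to start the single draining goroutine, while every later caller merely records that more work exists. A free-standing sketch of the election (simplified exit; the real exit protocol is the double check in triggering below):

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

var trigger int32 // 0: no worker; >= 1: a worker is running or being started

func submit() {
	// The first bump from 0 to 1 wins the election and starts the worker;
	// later bumps just piggyback on the already-running worker.
	if atomic.AddInt32(&trigger, 1) == 1 {
		go func() {
			fmt.Println("one worker per burst")
			atomic.StoreInt32(&trigger, 0) // simplified on purpose; see triggering()
		}()
	}
}

func main() {
	for i := 0; i < 8; i++ {
		submit()
	}
	time.Sleep(100 * time.Millisecond)
}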

// Close gracefully shuts down the ShardQueue, flushing all data added beforehand.
func (q *ShardQueue) Close() error {
if !atomic.CompareAndSwapInt32(&q.state, active, closing) {
return fmt.Errorf("shardQueue has been closed")
@@ -120,73 +116,63 @@ func (q *ShardQueue) Close() error {
}

// triggering drains the shards and flushes the connection.
func (q *ShardQueue) triggering(shard int32) {
q.listLock.Lock()
q.w = (q.w + 1) % q.size
q.list[q.w] = shard
q.listLock.Unlock()

if atomic.AddInt32(&q.trigger, 1) > 1 {
return
func (q *ShardQueue) triggering(shardid uint32) {
WORKER:
for atomic.LoadUint32(&q.size) > 0 {
// lock & swap the shard
q.lock(shardid)
shard := q.shards[shardid]
q.shards[shardid] = q.shard[:0]
q.shard = shard[:0] // reuse current shard's space for next round
q.unlock(shardid)

if len(shard) > 0 {
// deal with this shard
q.deal(shard)
// only decrease q.size after the shard has been dealt with
atomic.AddUint32(&q.size, -uint32(len(shard)))
}
// if there is any new data, the next shard must not be empty
shardid = (shardid + 1) % q.shardsize
}
q.foreach()
}

// foreach swap r & w. It's not concurrency safe.
func (q *ShardQueue) foreach() {
if atomic.AddInt32(&q.runNum, 1) > 1 {
return
// flush connection
q.flush()

// [IMPORTANT] Atomic Double Check:
// ShardQueue.Add guarantees that 'size' is always updated before 'trigger'.
// - If CAS(q.trigger, oldTrigger, 0) succeeds, no triggering() request happened during the size check,
//   so it is safe to exit triggering(), and any new Add() call will start triggering() successfully.
// - If the CAS fails, a concurrent Add() may have bumped q.trigger between Load(q.trigger) and the CAS
//   (its own triggering() election failed), so we must re-check q.size from the beginning.
oldTrigger := atomic.LoadInt32(&q.trigger)
if atomic.LoadUint32(&q.size) > 0 {
goto WORKER
}
if !atomic.CompareAndSwapInt32(&q.trigger, oldTrigger, 0) {
goto WORKER
}
gopool.CtxGo(nil, func() {
var negNum int32 // is negative number of triggerNum
for triggerNum := atomic.LoadInt32(&q.trigger); triggerNum > 0; {
q.r = (q.r + 1) % q.size
shared := q.list[q.r]

// lock & swap
q.lock(shared)
tmp := q.getters[shared]
q.getters[shared] = q.swap[:0]
q.swap = tmp
q.unlock(shared)

// deal
q.deal(q.swap)
negNum--
if triggerNum+negNum == 0 {
triggerNum = atomic.AddInt32(&q.trigger, negNum)
negNum = 0
}
}
q.flush()

// quit & check again
atomic.StoreInt32(&q.runNum, 0)
if atomic.LoadInt32(&q.trigger) > 0 {
q.foreach()
return
}
// if state is closing, change it to closed
atomic.CompareAndSwapInt32(&q.state, closing, closed)
})
// if state is closing, change it to closed
atomic.CompareAndSwapInt32(&q.state, closing, closed)
return
}
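
The exit protocol above is worth restating. A free-standing sketch of the same double check, written as if inside this package (drainOnce is a hypothetical stand-in for one lock/swap/deal round):

// drainLoop exits only when it can prove no producer raced with its final
// size check. Producer invariant (from Add): size is bumped before trigger.
func drainLoop(size *uint32, trigger *int32, drainOnce func()) {
	for {
		for atomic.LoadUint32(size) > 0 {
			drainOnce() // consumes one shard and decreases *size
		}
		old := atomic.LoadInt32(trigger)
		if atomic.LoadUint32(size) > 0 {
			continue // new work arrived between the drain and the Load
		}
		if atomic.CompareAndSwapInt32(trigger, old, 0) {
			return // trigger untouched since the Load: no racing Add, safe to quit
		}
		// CAS lost: a concurrent Add bumped trigger; re-check size from the top
	}
}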

// deal is used to get deal of netpoll.Writer.
// deal appends all getters into the connection's writer.
func (q *ShardQueue) deal(gts []WriterGetter) {
writer := q.conn.Writer()
for _, gt := range gts {
buf, isNil := gt()
if !isNil {
err := writer.Append(buf)
if err != nil {
if err != nil { // should never happen
q.conn.Close()
return
}
}
}
}
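
deal and flush together implement the merge-then-flush idea: each getter's buffer is linked into the connection's writer without copying, and one Flush sends the whole batch. A standalone sketch of that idea against the netpoll nocopy API (assumes conn is a live netpoll.Connection):

func mergeAndFlush(conn netpoll.Connection, payloads [][]byte) error {
	w := conn.Writer()
	for _, p := range payloads {
		buf := netpoll.NewLinkBuffer(len(p))
		data, err := buf.Malloc(len(p))
		if err != nil {
			return err
		}
		copy(data, p)
		if err := w.Append(buf); err != nil { // nocopy: buf's nodes are linked into w
			return err
		}
	}
	return w.Flush() // a single flush writes out everything appended above
}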

// flush is used to flush netpoll.Writer.
// flush flushes the connection, sending all appended data.
func (q *ShardQueue) flush() {
err := q.conn.Writer().Flush()
if err != nil {
@@ -196,13 +182,13 @@ func (q *ShardQueue) flush() {
}

// lock shard.
func (q *ShardQueue) lock(shard int32) {
func (q *ShardQueue) lock(shard uint32) {
for !atomic.CompareAndSwapInt32(&q.locks[shard], 0, 1) {
runtime.Gosched()
}
}

// unlock shard.
func (q *ShardQueue) unlock(shard int32) {
func (q *ShardQueue) unlock(shard uint32) {
atomic.StoreInt32(&q.locks[shard], 0)
}
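
lock and unlock form a tiny per-shard spinlock: CAS to acquire, runtime.Gosched while contended so the spinner yields its processor instead of burning it. The same pattern extracted as a reusable type (a sketch, not part of the diff):

type spinLock int32

func (l *spinLock) Lock() {
	for !atomic.CompareAndSwapInt32((*int32)(l), 0, 1) {
		runtime.Gosched() // yield while another goroutine holds the lock
	}
}

func (l *spinLock) Unlock() {
	atomic.StoreInt32((*int32)(l), 0)
}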
78 changes: 53 additions & 25 deletions mux/shard_queue_test.go
@@ -18,7 +18,10 @@
package mux

import (
"io"
"net"
"runtime"
"sync/atomic"
"testing"
"time"

@@ -27,18 +30,22 @@ import (

func TestShardQueue(t *testing.T) {
var svrConn net.Conn
var cliConn netpoll.Connection
accepted := make(chan struct{})
stopped := make(chan struct{})
streams, framesize := 128, 1024
totalsize := int32(streams * framesize)
var send, read int32

// create server connection
network, address := "tcp", ":18888"
ln, err := net.Listen("tcp", ":18888")
MustNil(t, err)
stop := make(chan int, 1)
defer close(stop)
go func() {
var err error
for {
select {
case <-stop:
case <-stopped:
err = ln.Close()
MustNil(t, err)
return
@@ -47,35 +54,56 @@ func TestShardQueue(t *testing.T) {
svrConn, err = ln.Accept()
MustNil(t, err)
accepted <- struct{}{}
go func() {
recv := make([]byte, 10240)
for {
n, err := svrConn.Read(recv)
atomic.AddInt32(&read, int32(n))
for i := 0; i < n; i++ {
MustTrue(t, recv[i] == 'a')
}
if err == io.EOF {
return
}
MustNil(t, err)
}
}()
}
}()

conn, err := netpoll.DialConnection(network, address, time.Second)
// create client connection
cliConn, err = netpoll.DialConnection(network, address, time.Second)
MustNil(t, err)
<-accepted
<-accepted // wait until svrConn is accepted

// test
queue := NewShardQueue(4, conn)
count, pkgsize := 16, 11
for i := 0; i < int(count); i++ {
var getter WriterGetter = func() (buf netpoll.Writer, isNil bool) {
buf = netpoll.NewLinkBuffer(pkgsize)
buf.Malloc(pkgsize)
return buf, false
}
queue.Add(getter)
// cliConn flushes packets to svrConn through the ShardQueue
queue := NewShardQueue(4, cliConn)
for i := 0; i < streams; i++ {
go func() {
var getter WriterGetter = func() (buf netpoll.Writer, isNil bool) {
buf = netpoll.NewLinkBuffer(framesize)
data, err := buf.Malloc(framesize)
MustNil(t, err)
for b := 0; b < framesize; b++ {
data[b] = 'a'
}
return buf, false
}
if queue.Add(getter) {
atomic.AddInt32(&send, int32(framesize))
}
}()
}

// graceful close: the ShardQueue should flush all added data correctly
for atomic.LoadInt32(&send) < totalsize/2 {
t.Logf("waiting send all packets: send=%d", atomic.LoadInt32(&send))
runtime.Gosched()
}
err = queue.Close()
MustNil(t, err)
total := count * pkgsize
recv := make([]byte, total)
rn, err := svrConn.Read(recv)
MustNil(t, err)
Equal(t, rn, total)
}

// TODO: need mock flush
func BenchmarkShardQueue(b *testing.B) {
b.Skip()
for atomic.LoadInt32(&read) != atomic.LoadInt32(&send) {
t.Logf("waiting read all packets: read=%d", atomic.LoadInt32(&read))
runtime.Gosched()
}
}
