Skip to content

Commit

Permalink
BugFix & Upgrade: TimeWheel
Browse files Browse the repository at this point in the history
  • Loading branch information
shenghui0779 committed Dec 22, 2023
1 parent f54dbce commit eba922d
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 105 deletions.
8 changes: 4 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ require (
github.com/jackc/pgx/v5 v5.5.1
github.com/jmoiron/sqlx v1.3.5
github.com/joho/godotenv v1.5.1
github.com/mattn/go-sqlite3 v1.14.18
github.com/mattn/go-sqlite3 v1.14.19
github.com/nsqio/go-nsq v1.1.0
github.com/redis/go-redis/v9 v9.3.0
github.com/redis/go-redis/v9 v9.3.1
github.com/stretchr/testify v1.8.4
go.uber.org/zap v1.26.0
golang.org/x/crypto v0.16.0
golang.org/x/crypto v0.17.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
)

Expand All @@ -33,7 +33,7 @@ require (
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/net v0.19.0 // indirect
golang.org/x/sync v0.5.0 // indirect
Expand Down
16 changes: 8 additions & 8 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,16 @@ github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNa
github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI=
github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/nsqio/go-nsq v1.1.0 h1:PQg+xxiUjA7V+TLdXw7nVrJ5Jbl3sN86EhGCQj4+FYE=
github.com/nsqio/go-nsq v1.1.0/go.mod h1:vKq36oyeVXgsS5Q8YEO7WghqidAVXQlcFxzQbQTuDEY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0=
github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/redis/go-redis/v9 v9.3.1 h1:KqdY8U+3X6z+iACvumCNxnoluToB+9Me+TvyFa21Mds=
github.com/redis/go-redis/v9 v9.3.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
Expand All @@ -72,8 +72,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo=
go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so=
golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY=
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
Expand Down
120 changes: 54 additions & 66 deletions timewheel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,46 @@ package yiigo

import (
"context"
"fmt"
"strconv"
"sync"
"time"
)

type ctxTWKey int
type ctxKey int

// CtxTaskAddedAt Context存储任务入队时间的Key
const CtxTaskAddedAt ctxTWKey = 0
// ctxTaskAddedAt Context存储任务入队时间的Key
const ctxTaskAddedAt ctxKey = 0

// TWTask 时间轮任务
type TWTask struct {
// TaskAddedAt 返回任务的添加时间
func TaskAddedAt(ctx context.Context) time.Time {
v := ctx.Value(ctxTaskAddedAt)
if v == nil {
return time.Time{}
}

t, ok := v.(time.Time)
if !ok {
return time.Time{}
}

return t
}

// Task 时间轮任务
type Task struct {
ctx context.Context
uniqID string // 任务唯一标识
round int // 延迟执行的轮数
attempts uint16 // 当前尝试的次数
maxAttempts uint16 // 最大尝试次数
remainder time.Duration // 任务执行前的剩余延迟(小于时间轮精度)
cumulative int64 // 多次重试的累计时长(单位:ns)
deferFn func(attempts uint16) time.Duration // 返回任务下一次延迟执行的时间
callback func(ctx context.Context, taskID string) error // 任务回调函数
}

// TimeWheel 单时间轮
type TimeWheel interface {
// AddTask 添加一个任务,到期被执行,默认仅执行一次;若指定了重试次数,则在发生错误后重试
// 注意:任务是异步执行的,故Context应该是一个克隆的且不带超时时间的
// AddTask 添加一个任务,到期被执行,默认仅执行一次;若指定了重试次数,则在发生错误后重试
// 注意:任务是异步执行的,`ctx`一旦被取消,则任务也随之取消,故需考虑是否应该克隆一个不带「取消」的`ctx`
AddTask(ctx context.Context, taskID string, handler func(ctx context.Context, taskID string) error, options ...TaskOption)
// Run 运行时间轮
Run()
Expand All @@ -43,12 +55,11 @@ type timewheel struct {
size int
bucket []sync.Map
stop chan struct{}
log func(ctx context.Context, v ...any)
}

func (tw *timewheel) AddTask(ctx context.Context, taskID string, handler func(ctx context.Context, taskID string) error, options ...TaskOption) {
task := &TWTask{
ctx: context.WithValue(ctx, CtxTaskAddedAt, time.Now()),
task := &Task{
ctx: ctx,
uniqID: taskID,
callback: handler,
maxAttempts: 1,
Expand All @@ -70,53 +81,49 @@ func (tw *timewheel) Run() {

func (tw *timewheel) Stop() {
select {
case <-tw.stop:
tw.log(context.Background(), "timingwheel stopped")
case <-tw.stop: // 时间轮已停止
return
default:
}

close(tw.stop)

tw.log(context.Background(), "timewheel stopped", "at="+time.Now().String())
}

func (tw *timewheel) requeue(task *TWTask) {
if task.attempts >= task.maxAttempts {
return
}

func (tw *timewheel) requeue(task *Task) {
select {
case <-tw.stop:
tw.log(task.ctx, "task requeue failed because of timewheel has stopped", "task_id="+task.uniqID, "attempts="+strconv.Itoa(int(task.attempts+1)))
case <-tw.stop: // 时间轮已停止
return
default:
}

task.ctx = context.WithValue(task.ctx, CtxTaskAddedAt, time.Now())
// 任务已达到最大尝试次数
if task.attempts >= task.maxAttempts {
return
}

task.attempts++
task.ctx = context.WithValue(task.ctx, ctxTaskAddedAt, time.Now())

tick := tw.tick.Nanoseconds()
delay := task.deferFn(task.attempts)
duration := delay.Nanoseconds()

task.cumulative += duration
// 圈数
task.round = int(duration / (tick * int64(tw.size)))

slot := int(task.cumulative/tick) % tw.size
// 槽位
slot := (int(duration/tick)%tw.size + tw.slot) % tw.size
if slot == tw.slot {
if task.round == 0 {
task.remainder = delay
go tw.run(task)
go tw.do(task)

return
}

task.round--
}

task.remainder = time.Duration(task.cumulative % tick)
// 剩余延迟
task.remainder = time.Duration(duration % tick)

tw.bucket[slot].Store(task.uniqID, task)
}
Expand All @@ -127,7 +134,7 @@ func (tw *timewheel) scheduler() {

for {
select {
case <-tw.stop:
case <-tw.stop: // 时间轮已停止
return
case <-ticker.C:
tw.slot = (tw.slot + 1) % tw.size
Expand All @@ -139,30 +146,33 @@ func (tw *timewheel) scheduler() {
func (tw *timewheel) process(slot int) {
tw.bucket[slot].Range(func(key, value any) bool {
select {
case <-tw.stop:
case <-tw.stop: // 时间轮已停止
return false
default:
}

task := value.(*TWTask)

task := value.(*Task)
if task.round > 0 {
task.round--
return true
}

go tw.run(task)
select {
case <-task.ctx.Done(): // 任务被取消
default:
go tw.do(task)
}

tw.bucket[slot].Delete(key)

return true
})
}

func (tw *timewheel) run(task *TWTask) {
func (tw *timewheel) do(task *Task) {
defer func() {
if v := recover(); v != nil {
tw.log(task.ctx, "task do panic", "task_id="+task.uniqID, fmt.Sprintf("error=%v", v))
if recover() != nil {
tw.requeue(task)
}
}()

Expand All @@ -171,31 +181,16 @@ func (tw *timewheel) run(task *TWTask) {
}

if err := task.callback(task.ctx, task.uniqID); err != nil {
tw.log(task.ctx, "task do error", "task_id="+task.uniqID, "error="+err.Error())
tw.requeue(task)

return
}
}

// TWOption 时间轮选项
type TWOption func(tw *timewheel)

// WithTWErrLog 设置时间轮错误日志
func WithTWErrLog(fn func(ctx context.Context, v ...any)) TWOption {
return func(tw *timewheel) {
if fn != nil {
tw.log = fn
}
}
}

// TaskOption 时间轮任务选项
type TaskOption func(t *TWTask)
type TaskOption func(t *Task)

// WithTaskAttempts 指定任务重试次数;默认:1
func WithTaskAttempts(attempts uint16) TaskOption {
return func(t *TWTask) {
return func(t *Task) {
if attempts > 0 {
t.maxAttempts = attempts
}
Expand All @@ -204,26 +199,19 @@ func WithTaskAttempts(attempts uint16) TaskOption {

// WithTaskDefer 指定任务延迟执行时间;默认:立即执行
func WithTaskDefer(fn func(attempts uint16) time.Duration) TaskOption {
return func(t *TWTask) {
return func(t *Task) {
if fn != nil {
t.deferFn = fn
}
}
}

// NewTimeWheel 返回一个时间轮实例
func NewTimeWheel(tick time.Duration, size int, options ...TWOption) TimeWheel {
tw := &timewheel{
func NewTimeWheel(tick time.Duration, size int) TimeWheel {
return &timewheel{
tick: tick,
size: size,
bucket: make([]sync.Map, size),
stop: make(chan struct{}),
log: func(ctx context.Context, v ...any) {},
}

for _, f := range options {
f(tw)
}

return tw
}
50 changes: 23 additions & 27 deletions timewheel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@ func TestTimeWheel_1(t *testing.T) {
ch := make(chan string)
defer close(ch)

tw := NewTimeWheel(time.Second, 60)
tw := NewTimeWheel(time.Second, 7)

for i := 0; i < 10; i++ {
n := i + 1
now := time.Now()

tw.AddTask(context.Background(), "task#"+strconv.Itoa(n), func(ctx context.Context, taskID string) error {
ch <- fmt.Sprintf("%s - %ds", taskID, int64(time.Since(now).Seconds()))

ch <- fmt.Sprintf("%s run after %ds", taskID, int64(time.Since(TaskAddedAt(ctx)).Seconds()))
return nil
}, WithTaskDefer(func(attempts uint16) time.Duration {
return time.Second * time.Duration(n+i)
Expand All @@ -42,16 +40,16 @@ func TestTimeWheel_1(t *testing.T) {
}

assert.Equal(t, []string{
"task#1 - 1s",
"task#2 - 3s",
"task#3 - 5s",
"task#4 - 7s",
"task#5 - 9s",
"task#6 - 11s",
"task#7 - 13s",
"task#8 - 15s",
"task#9 - 17s",
"task#10 - 19s",
"task#1 run after 1s",
"task#2 run after 3s",
"task#3 run after 5s",
"task#4 run after 7s",
"task#5 run after 9s",
"task#6 run after 11s",
"task#7 run after 13s",
"task#8 run after 15s",
"task#9 run after 17s",
"task#10 run after 19s",
}, ret)
}

Expand All @@ -65,11 +63,9 @@ func TestTimeWheel_2(t *testing.T) {

for i := 0; i < 10; i++ {
n := i + 1
now := time.Now()

tw.AddTask(context.Background(), "task#"+strconv.Itoa(n), func(ctx context.Context, taskID string) error {
ch <- fmt.Sprintf("%s - %ds", taskID, int64(time.Since(now).Seconds()))

ch <- fmt.Sprintf("%s run after %ds", taskID, int64(time.Since(TaskAddedAt(ctx)).Seconds()))
return nil
}, WithTaskDefer(func(attempts uint16) time.Duration {
return time.Second * time.Duration(n+i)
Expand All @@ -86,15 +82,15 @@ func TestTimeWheel_2(t *testing.T) {
}

assert.Equal(t, []string{
"task#1 - 1s",
"task#2 - 3s",
"task#3 - 5s",
"task#4 - 7s",
"task#5 - 9s",
"task#6 - 11s",
"task#7 - 13s",
"task#8 - 15s",
"task#9 - 17s",
"task#10 - 19s",
"task#1 run after 1s",
"task#2 run after 3s",
"task#3 run after 5s",
"task#4 run after 7s",
"task#5 run after 9s",
"task#6 run after 11s",
"task#7 run after 13s",
"task#8 run after 15s",
"task#9 run after 17s",
"task#10 run after 19s",
}, ret)
}

0 comments on commit eba922d

Please sign in to comment.