diff --git a/README.md b/README.md index 452160b..e005111 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,23 @@ # Async-retry +[![Go Reference](https://pkg.go.dev/badge/github.com/Kyash/async-retry.svg)](https://pkg.go.dev/github.com/Kyash/async-retry) + Async-retry controls asynchronous retries in Go, and can be shutdown gracefully. Main features of Async-retry are -* Disable cancellation of context when pass to function as an argument -* Keep value of context when pass to function as an argument +* Disable cancellation of context passed to function as an argument +* Keep value of context passed to function as an argument * Set timeout(default 10s) for each function call * Recover from panic * Control retry with delegating to https://github.com/avast/retry-go -* Shutdown safely anytime +* Gracefully shutdown You can find other features or settings in options.go # Example ``` +package asyncretry_test + import ( "context" "fmt" @@ -33,21 +37,24 @@ func ExampleAsyncRetry() { "/hello", func(w http.ResponseWriter, r *http.Request) { var ctx = r.Context() - go func() { - err := asyncRetry.Do( - ctx, - func(ctx context.Context) error { - // do task - // ... - return nil - }, - asyncretry.Attempts(5), - asyncretry.Timeout(8*time.Second), - ) - if err != nil { - log.Println(err.Error()) - } - }() + if err := asyncRetry.Do( + ctx, + func(ctx context.Context) error { + // do task + // ... + return nil + }, + func(err error) { + if err != nil { + log.Println(err.Error()) + } + }, + asyncretry.Attempts(5), + asyncretry.Timeout(8*time.Second), + ); err != nil { + // asyncRetry is in shutdown + log.Println(err.Error()) + } fmt.Fprintf(w, "Hello") }, ) diff --git a/async_retry.go b/async_retry.go index 4f71467..a20f1eb 100644 --- a/async_retry.go +++ b/async_retry.go @@ -9,12 +9,16 @@ import ( ) type RetryableFunc func(ctx context.Context) error +type FinishFunc func(error) type AsyncRetry interface { - // Do calls f and retry if necessary. - // In most cases, you should call Do in a new goroutine. - Do(ctx context.Context, f RetryableFunc, opts ...Option) error - // Shutdown shutdowns gracefully + // Do calls f in a new goroutine, and retry if necessary. When finished, `finish` is called regardless of success or failure exactly once. + // Non-nil error is always `ErrInShutdown` that will be returned when AsyncRetry is in shutdown. + Do(ctx context.Context, f RetryableFunc, finish FinishFunc, opts ...Option) error + + // Shutdown gracefully shuts down AsyncRetry without interrupting any active `Do`. + // Shutdown works by first stopping to accept new `Do` request, and then waiting for all active `Do`'s background goroutines to be finished. + // Multiple call of Shutdown is OK. Shutdown(ctx context.Context) error } @@ -33,7 +37,12 @@ func NewAsyncRetry() AsyncRetry { var ErrInShutdown = fmt.Errorf("AsyncRetry is in shutdown") -func (a *asyncRetry) Do(ctx context.Context, f RetryableFunc, opts ...Option) (retErr error) { +func (a *asyncRetry) Do(ctx context.Context, f RetryableFunc, finish FinishFunc, opts ...Option) error { + config := DefaultConfig + for _, opt := range opts { + opt(&config) + } + a.mu.RLock() select { case <-a.shutdownChan: @@ -41,22 +50,24 @@ func (a *asyncRetry) Do(ctx context.Context, f RetryableFunc, opts ...Option) (r return ErrInShutdown default: } - // notice that this line should be in lock so that shutdown would not go ahead - a.wg.Add(1) + a.wg.Add(1) // notice that this line should be in lock so that shutdown would not go ahead a.mu.RUnlock() - defer a.wg.Done() - config := DefaultConfig - for _, opt := range opts { - opt(&config) - } - - defer func() { - if err := recover(); err != nil { - retErr = fmt.Errorf("panicking while AsyncRetry err: %v", err) - } + go func() { + defer a.wg.Done() // Done should be called after `finish` returns + defer func() { + if recovered := recover(); recovered != nil { + var err = fmt.Errorf("panicking while AsyncRetry err: %v", recovered) + finish(err) + } + }() + var err = a.call(ctx, f, &config) + finish(err) }() + return nil +} +func (a *asyncRetry) call(ctx context.Context, f RetryableFunc, config *Config) error { ctx, cancel := context.WithCancel(WithoutCancel(ctx)) defer cancel() noMoreRetryCtx, noMoreRetry := context.WithCancel(config.context) @@ -76,8 +87,7 @@ func (a *asyncRetry) Do(ctx context.Context, f RetryableFunc, opts ...Option) (r if config.cancelWhenConfigContextCanceled { cancel() } - // release resources - case <-done: + case <-done: // release resources } }() @@ -102,10 +112,8 @@ func (a *asyncRetry) Do(ctx context.Context, f RetryableFunc, opts ...Option) (r func (a *asyncRetry) Shutdown(ctx context.Context) error { a.mu.Lock() select { - case <-a.shutdownChan: - // Already closed. - default: - // Guarded by a.mu + case <-a.shutdownChan: // Already closed. + default: // Guarded by a.mu close(a.shutdownChan) } a.mu.Unlock() diff --git a/async_retry_test.go b/async_retry_test.go index cdb6e72..70ef6c0 100644 --- a/async_retry_test.go +++ b/async_retry_test.go @@ -167,9 +167,13 @@ func Test_asyncRetry_Do(t *testing.T) { t.Run(tt.name, func(t *testing.T) { counter = 0 a := NewAsyncRetry() + ch := make(chan error) var err error - // Be careful not call Do synchronously when actually using - if err = a.Do(tt.args.ctx(), tt.args.f, tt.args.opts...); (err != nil) != tt.wantErr { + if err = a.Do(tt.args.ctx(), tt.args.f, func(err error) { ch <- err }, tt.args.opts...); err != nil { + t.Errorf("Do() failed %v", err) + } + err = <-ch + if (err != nil) != tt.wantErr { t.Errorf("Do() error = %v, wantErr %v", err, tt.wantErr) } if err != nil { @@ -324,9 +328,13 @@ func Test_asyncRetry_DoWithConfigContext(t *testing.T) { ctx, cancel = context.WithCancel(context.Background()) defer cancel() a := NewAsyncRetry() + ch := make(chan error) var err error - // Be careful not call Do synchronously when actually using - if err = a.Do(tt.args.ctx(), tt.args.f, tt.args.opts()...); (err != nil) != tt.wantErr { + if err = a.Do(tt.args.ctx(), tt.args.f, func(err error) { ch <- err }, tt.args.opts()...); err != nil { + t.Errorf("Do() failed %v", err) + } + err = <-ch + if (err != nil) != tt.wantErr { t.Errorf("Do() error = %v, wantErr %v", err, tt.wantErr) } if err != nil { @@ -477,19 +485,18 @@ func Test_asyncRetry_DoAndShutdown(t *testing.T) { counter = 0 a := NewAsyncRetry() - var doErr = make(chan error) + var doErr = make(chan error, 1) var shutdownErr = make(chan error) - go func() { - doErr <- a.Do( - tt.args.ctx(), tt.args.f, tt.args.opts()..., - ) - }() + var err error + if err = a.Do(tt.args.ctx(), tt.args.f, func(err error) { doErr <- err }, tt.args.opts()...); err != nil { + t.Errorf("Do() failed %v", err) + } + go func() { <-ch shutdownErr <- a.Shutdown(context.Background()) }() - var err error select { case err = <-shutdownErr: case <-time.After(time.Second * 10): @@ -542,21 +549,22 @@ func Test_ShutdownOrder(t *testing.T) { var wg sync.WaitGroup for i := 0; i < szDo; i++ { wg.Add(1) - go func() { - err := a.Do( - context.Background(), - func(ctx context.Context) error { - wg.Done() - time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000))) - results <- 1 - return nil - }, - Timeout(0), - ) - if err != nil { - t.Errorf("Do() error = %v, wantErr %v", err, nil) - } - }() + err := a.Do( + context.Background(), + func(ctx context.Context) error { + wg.Done() + time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000))) + return nil + }, + func(error) { + time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000))) + results <- 1 + }, + Timeout(0), + ) + if err != nil { + t.Errorf("Do() error = %v, wantErr %v", err, nil) + } } for i := 0; i < szShutdown; i++ { go func() { @@ -589,6 +597,7 @@ func Test_ShutdownOrder(t *testing.T) { func(ctx context.Context) error { return nil }, + func(error) {}, ) if err == nil || err.Error() != ErrInShutdown.Error() { t.Errorf("call of Do after shudown must returns InShutdownErr") @@ -606,19 +615,20 @@ func benchmarkDo(tasks int, concurrency int, b *testing.B) { for c := 0; c < concurrency; c++ { wg.Add(1) go func() { - var dummy int defer wg.Done() for range ch { - for i := 0; i < 10000; i++ { - dummy /= dummy + 1 - } - _ = a.Do(context.Background(), func(ctx context.Context) error { - var dummy int - for i := 0; i < 10000; i++ { - dummy /= dummy + 1 - } - return nil - }) + _ = a.Do( + context.Background(), + func(ctx context.Context) error { + var dummy int + for i := 0; i < 100; i++ { + dummy /= dummy + 1 + } + return nil + }, + func(err error) { + }, + ) } }() } @@ -630,9 +640,9 @@ func benchmarkDo(tasks int, concurrency int, b *testing.B) { } } -func BenchmarkDo10000With2(b *testing.B) { benchmarkDo(3000, 2, b) } -func BenchmarkDo10000With4(b *testing.B) { benchmarkDo(3000, 4, b) } -func BenchmarkDo10000With8(b *testing.B) { benchmarkDo(3000, 8, b) } -func BenchmarkDo10000With16(b *testing.B) { benchmarkDo(3000, 16, b) } -func BenchmarkDo10000With32(b *testing.B) { benchmarkDo(3000, 32, b) } -func BenchmarkDo10000With64(b *testing.B) { benchmarkDo(3000, 64, b) } +func BenchmarkDo10000With2(b *testing.B) { benchmarkDo(10000, 2, b) } +func BenchmarkDo10000With4(b *testing.B) { benchmarkDo(10000, 4, b) } +func BenchmarkDo10000With8(b *testing.B) { benchmarkDo(10000, 8, b) } +func BenchmarkDo10000With16(b *testing.B) { benchmarkDo(10000, 16, b) } +func BenchmarkDo10000With32(b *testing.B) { benchmarkDo(10000, 32, b) } +func BenchmarkDo10000With64(b *testing.B) { benchmarkDo(10000, 64, b) } diff --git a/example_test.go b/example_test.go index 20e9119..32fffc3 100644 --- a/example_test.go +++ b/example_test.go @@ -19,21 +19,24 @@ func ExampleAsyncRetry() { "/hello", func(w http.ResponseWriter, r *http.Request) { var ctx = r.Context() - go func() { - err := asyncRetry.Do( - ctx, - func(ctx context.Context) error { - // do task - // ... - return nil - }, - asyncretry.Attempts(5), - asyncretry.Timeout(8*time.Second), - ) - if err != nil { - log.Println(err.Error()) - } - }() + if err := asyncRetry.Do( + ctx, + func(ctx context.Context) error { + // do task + // ... + return nil + }, + func(err error) { + if err != nil { + log.Println(err.Error()) + } + }, + asyncretry.Attempts(5), + asyncretry.Timeout(8*time.Second), + ); err != nil { + // asyncRetry is in shutdown + log.Println(err.Error()) + } fmt.Fprintf(w, "Hello") }, )