Skip to content

Commit

Permalink
Merge pull request #7 from Kyash/fix-unsafe-shutdown
Browse files Browse the repository at this point in the history
fix unsafe shutdown
  • Loading branch information
KeiichiHirobe authored Nov 22, 2022
2 parents 8ef9c04 + 52ea676 commit 64e2bcc
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 99 deletions.
43 changes: 25 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
# Async-retry
[![Go Reference](https://pkg.go.dev/badge/github.com/Kyash/async-retry.svg)](https://pkg.go.dev/github.com/Kyash/async-retry)

Async-retry controls asynchronous retries in Go, and can be shutdown gracefully.

Main features of Async-retry are
* Disable cancellation of context when pass to function as an argument
* Keep value of context when pass to function as an argument
* Disable cancellation of context passed to function as an argument
* Keep value of context passed to function as an argument
* Set timeout(default 10s) for each function call
* Recover from panic
* Control retry with delegating to https://github.com/avast/retry-go
* Shutdown safely anytime
* Gracefully shutdown

You can find other features or settings in options.go

# Example

```
package asyncretry_test
import (
"context"
"fmt"
Expand All @@ -33,21 +37,24 @@ func ExampleAsyncRetry() {
"/hello",
func(w http.ResponseWriter, r *http.Request) {
var ctx = r.Context()
go func() {
err := asyncRetry.Do(
ctx,
func(ctx context.Context) error {
// do task
// ...
return nil
},
asyncretry.Attempts(5),
asyncretry.Timeout(8*time.Second),
)
if err != nil {
log.Println(err.Error())
}
}()
if err := asyncRetry.Do(
ctx,
func(ctx context.Context) error {
// do task
// ...
return nil
},
func(err error) {
if err != nil {
log.Println(err.Error())
}
},
asyncretry.Attempts(5),
asyncretry.Timeout(8*time.Second),
); err != nil {
// asyncRetry is in shutdown
log.Println(err.Error())
}
fmt.Fprintf(w, "Hello")
},
)
Expand Down
54 changes: 31 additions & 23 deletions async_retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ import (
)

type RetryableFunc func(ctx context.Context) error
type FinishFunc func(error)

type AsyncRetry interface {
// Do calls f and retry if necessary.
// In most cases, you should call Do in a new goroutine.
Do(ctx context.Context, f RetryableFunc, opts ...Option) error
// Shutdown shutdowns gracefully
// Do calls f in a new goroutine, and retry if necessary. When finished, `finish` is called regardless of success or failure exactly once.
// Non-nil error is always `ErrInShutdown` that will be returned when AsyncRetry is in shutdown.
Do(ctx context.Context, f RetryableFunc, finish FinishFunc, opts ...Option) error

// Shutdown gracefully shuts down AsyncRetry without interrupting any active `Do`.
// Shutdown works by first stopping to accept new `Do` request, and then waiting for all active `Do`'s background goroutines to be finished.
// Multiple call of Shutdown is OK.
Shutdown(ctx context.Context) error
}

Expand All @@ -33,30 +37,37 @@ func NewAsyncRetry() AsyncRetry {

var ErrInShutdown = fmt.Errorf("AsyncRetry is in shutdown")

func (a *asyncRetry) Do(ctx context.Context, f RetryableFunc, opts ...Option) (retErr error) {
func (a *asyncRetry) Do(ctx context.Context, f RetryableFunc, finish FinishFunc, opts ...Option) error {
config := DefaultConfig
for _, opt := range opts {
opt(&config)
}

a.mu.RLock()
select {
case <-a.shutdownChan:
a.mu.RUnlock()
return ErrInShutdown
default:
}
// notice that this line should be in lock so that shutdown would not go ahead
a.wg.Add(1)
a.wg.Add(1) // notice that this line should be in lock so that shutdown would not go ahead
a.mu.RUnlock()
defer a.wg.Done()

config := DefaultConfig
for _, opt := range opts {
opt(&config)
}

defer func() {
if err := recover(); err != nil {
retErr = fmt.Errorf("panicking while AsyncRetry err: %v", err)
}
go func() {
defer a.wg.Done() // Done should be called after `finish` returns
defer func() {
if recovered := recover(); recovered != nil {
var err = fmt.Errorf("panicking while AsyncRetry err: %v", recovered)
finish(err)
}
}()
var err = a.call(ctx, f, &config)
finish(err)
}()
return nil
}

func (a *asyncRetry) call(ctx context.Context, f RetryableFunc, config *Config) error {
ctx, cancel := context.WithCancel(WithoutCancel(ctx))
defer cancel()
noMoreRetryCtx, noMoreRetry := context.WithCancel(config.context)
Expand All @@ -76,8 +87,7 @@ func (a *asyncRetry) Do(ctx context.Context, f RetryableFunc, opts ...Option) (r
if config.cancelWhenConfigContextCanceled {
cancel()
}
// release resources
case <-done:
case <-done: // release resources
}
}()

Expand All @@ -102,10 +112,8 @@ func (a *asyncRetry) Do(ctx context.Context, f RetryableFunc, opts ...Option) (r
func (a *asyncRetry) Shutdown(ctx context.Context) error {
a.mu.Lock()
select {
case <-a.shutdownChan:
// Already closed.
default:
// Guarded by a.mu
case <-a.shutdownChan: // Already closed.
default: // Guarded by a.mu
close(a.shutdownChan)
}
a.mu.Unlock()
Expand Down
96 changes: 53 additions & 43 deletions async_retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,13 @@ func Test_asyncRetry_Do(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
counter = 0
a := NewAsyncRetry()
ch := make(chan error)
var err error
// Be careful not call Do synchronously when actually using
if err = a.Do(tt.args.ctx(), tt.args.f, tt.args.opts...); (err != nil) != tt.wantErr {
if err = a.Do(tt.args.ctx(), tt.args.f, func(err error) { ch <- err }, tt.args.opts...); err != nil {
t.Errorf("Do() failed %v", err)
}
err = <-ch
if (err != nil) != tt.wantErr {
t.Errorf("Do() error = %v, wantErr %v", err, tt.wantErr)
}
if err != nil {
Expand Down Expand Up @@ -324,9 +328,13 @@ func Test_asyncRetry_DoWithConfigContext(t *testing.T) {
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
a := NewAsyncRetry()
ch := make(chan error)
var err error
// Be careful not call Do synchronously when actually using
if err = a.Do(tt.args.ctx(), tt.args.f, tt.args.opts()...); (err != nil) != tt.wantErr {
if err = a.Do(tt.args.ctx(), tt.args.f, func(err error) { ch <- err }, tt.args.opts()...); err != nil {
t.Errorf("Do() failed %v", err)
}
err = <-ch
if (err != nil) != tt.wantErr {
t.Errorf("Do() error = %v, wantErr %v", err, tt.wantErr)
}
if err != nil {
Expand Down Expand Up @@ -477,19 +485,18 @@ func Test_asyncRetry_DoAndShutdown(t *testing.T) {
counter = 0
a := NewAsyncRetry()

var doErr = make(chan error)
var doErr = make(chan error, 1)
var shutdownErr = make(chan error)
go func() {
doErr <- a.Do(
tt.args.ctx(), tt.args.f, tt.args.opts()...,
)
}()
var err error
if err = a.Do(tt.args.ctx(), tt.args.f, func(err error) { doErr <- err }, tt.args.opts()...); err != nil {
t.Errorf("Do() failed %v", err)
}

go func() {
<-ch
shutdownErr <- a.Shutdown(context.Background())
}()

var err error
select {
case err = <-shutdownErr:
case <-time.After(time.Second * 10):
Expand Down Expand Up @@ -542,21 +549,22 @@ func Test_ShutdownOrder(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < szDo; i++ {
wg.Add(1)
go func() {
err := a.Do(
context.Background(),
func(ctx context.Context) error {
wg.Done()
time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000)))
results <- 1
return nil
},
Timeout(0),
)
if err != nil {
t.Errorf("Do() error = %v, wantErr %v", err, nil)
}
}()
err := a.Do(
context.Background(),
func(ctx context.Context) error {
wg.Done()
time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000)))
return nil
},
func(error) {
time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000)))
results <- 1
},
Timeout(0),
)
if err != nil {
t.Errorf("Do() error = %v, wantErr %v", err, nil)
}
}
for i := 0; i < szShutdown; i++ {
go func() {
Expand Down Expand Up @@ -589,6 +597,7 @@ func Test_ShutdownOrder(t *testing.T) {
func(ctx context.Context) error {
return nil
},
func(error) {},
)
if err == nil || err.Error() != ErrInShutdown.Error() {
t.Errorf("call of Do after shudown must returns InShutdownErr")
Expand All @@ -606,19 +615,20 @@ func benchmarkDo(tasks int, concurrency int, b *testing.B) {
for c := 0; c < concurrency; c++ {
wg.Add(1)
go func() {
var dummy int
defer wg.Done()
for range ch {
for i := 0; i < 10000; i++ {
dummy /= dummy + 1
}
_ = a.Do(context.Background(), func(ctx context.Context) error {
var dummy int
for i := 0; i < 10000; i++ {
dummy /= dummy + 1
}
return nil
})
_ = a.Do(
context.Background(),
func(ctx context.Context) error {
var dummy int
for i := 0; i < 100; i++ {
dummy /= dummy + 1
}
return nil
},
func(err error) {
},
)
}
}()
}
Expand All @@ -630,9 +640,9 @@ func benchmarkDo(tasks int, concurrency int, b *testing.B) {
}
}

func BenchmarkDo10000With2(b *testing.B) { benchmarkDo(3000, 2, b) }
func BenchmarkDo10000With4(b *testing.B) { benchmarkDo(3000, 4, b) }
func BenchmarkDo10000With8(b *testing.B) { benchmarkDo(3000, 8, b) }
func BenchmarkDo10000With16(b *testing.B) { benchmarkDo(3000, 16, b) }
func BenchmarkDo10000With32(b *testing.B) { benchmarkDo(3000, 32, b) }
func BenchmarkDo10000With64(b *testing.B) { benchmarkDo(3000, 64, b) }
func BenchmarkDo10000With2(b *testing.B) { benchmarkDo(10000, 2, b) }
func BenchmarkDo10000With4(b *testing.B) { benchmarkDo(10000, 4, b) }
func BenchmarkDo10000With8(b *testing.B) { benchmarkDo(10000, 8, b) }
func BenchmarkDo10000With16(b *testing.B) { benchmarkDo(10000, 16, b) }
func BenchmarkDo10000With32(b *testing.B) { benchmarkDo(10000, 32, b) }
func BenchmarkDo10000With64(b *testing.B) { benchmarkDo(10000, 64, b) }
33 changes: 18 additions & 15 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,24 @@ func ExampleAsyncRetry() {
"/hello",
func(w http.ResponseWriter, r *http.Request) {
var ctx = r.Context()
go func() {
err := asyncRetry.Do(
ctx,
func(ctx context.Context) error {
// do task
// ...
return nil
},
asyncretry.Attempts(5),
asyncretry.Timeout(8*time.Second),
)
if err != nil {
log.Println(err.Error())
}
}()
if err := asyncRetry.Do(
ctx,
func(ctx context.Context) error {
// do task
// ...
return nil
},
func(err error) {
if err != nil {
log.Println(err.Error())
}
},
asyncretry.Attempts(5),
asyncretry.Timeout(8*time.Second),
); err != nil {
// asyncRetry is in shutdown
log.Println(err.Error())
}
fmt.Fprintf(w, "Hello")
},
)
Expand Down

0 comments on commit 64e2bcc

Please sign in to comment.