diff --git a/changes/20231016114710.feature b/changes/20231016114710.feature new file mode 100644 index 0000000000..55d5a5a1ba --- /dev/null +++ b/changes/20231016114710.feature @@ -0,0 +1 @@ +:sparkles: `[parallelisation]` Run action with interrupt handling diff --git a/utils/parallelisation/parallelisation.go b/utils/parallelisation/parallelisation.go index 2baaf9e8db..9c4d4dd098 100644 --- a/utils/parallelisation/parallelisation.go +++ b/utils/parallelisation/parallelisation.go @@ -8,7 +8,11 @@ package parallelisation import ( "context" + "golang.org/x/sync/errgroup" + "os" + "os/signal" "reflect" + "syscall" "time" "go.uber.org/atomic" @@ -214,7 +218,7 @@ func RunActionWithTimeoutAndCancelStore(ctx context.Context, timeout time.Durati } // RunActionWithParallelCheck runs an action with a check in parallel -// The function performing the check should return true if the check was favorable; false otherwise. If the check did not have the expected result and the whole function would be cancelled. +// The function performing the check should return true if the check was favorable; false otherwise. If the check did not have the expected result, the whole function would be cancelled. func RunActionWithParallelCheck(ctx context.Context, action func(ctx context.Context) error, checkAction func(ctx context.Context) bool, checkPeriod time.Duration) error { err := DetermineContextError(ctx) if err != nil { @@ -246,3 +250,48 @@ func RunActionWithParallelCheck(ctx context.Context, action func(ctx context.Con } return err } + +// RunActionWithInterruptCancellation runs an action listening to interrupt signals such as SIGTERM or SIGINT +// On interrupt, any cancellation functions in store are called followed by actionOnInterrupt. These functions are not called if no interrupts were raised but action completed. +func RunActionWithInterruptCancellation(ctx context.Context, cancelStore *CancelFunctionStore, action func(ctx context.Context) error, actionOnInterrupt func(ctx context.Context) error) error { + err := DetermineContextError(ctx) + if err != nil { + return err + } + if cancelStore == nil { + cancelStore = NewCancelFunctionsStore() + } + defer cancelStore.Cancel() + // Listening to the following interrupt signals https://www.man7.org/linux/man-pages/man7/signal.7.html + interruptableCtx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM, os.Interrupt, os.Kill, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGABRT) + cancelStore.RegisterCancelFunction(cancel) + g, groupCancellableCtx := errgroup.WithContext(ctx) + groupCancellableCtx, cancelOnSuccess := context.WithCancel(groupCancellableCtx) + g.Go(func() error { + select { + case <-interruptableCtx.Done(): + case <-groupCancellableCtx.Done(): + } + err = DetermineContextError(interruptableCtx) + if err != nil { + // An interrupt was raised. + cancelStore.Cancel() + return actionOnInterrupt(ctx) + } + return err + }) + g.Go(func() error { + err := action(interruptableCtx) + if err == nil { + cancelOnSuccess() + } + return err + }) + return g.Wait() +} + +// RunActionWithGracefulShutdown carries out an action until asked to gracefully shutdown on which the shutdownOnSignal is executed. +// if the action is completed before the shutdown request is performed, shutdownOnSignal will not be executed. +func RunActionWithGracefulShutdown(ctx context.Context, action func(ctx context.Context) error, shutdownOnSignal func(ctx context.Context) error) error { + return RunActionWithInterruptCancellation(ctx, NewCancelFunctionsStore(), action, shutdownOnSignal) +} diff --git a/utils/parallelisation/parallelisation_test.go b/utils/parallelisation/parallelisation_test.go index af4ca3d1a7..0bb08c4968 100644 --- a/utils/parallelisation/parallelisation_test.go +++ b/utils/parallelisation/parallelisation_test.go @@ -9,7 +9,9 @@ import ( "errors" "fmt" "math/rand" + "os" "reflect" + "syscall" "testing" "time" @@ -20,10 +22,11 @@ import ( "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/commonerrors/errortest" + "github.com/ARM-software/golang-utils/utils/platform" ) var ( - random = rand.New(rand.NewSource(time.Now().Unix())) //nolint:gosec //causes G404: Use of weak random number generator (math/rand instead of crypto/rand) (gosec), So disable gosec as this is just for + random = rand.New(rand.NewSource(time.Now().Unix())) //nolint:gosec //causes G404: Use of weak random number generator (math/rand instead of crypto/rand) (gosec), So disable gosec as this is just for tests ) func TestParallelisationWithResults(t *testing.T) { @@ -411,3 +414,132 @@ func runActionWithParallelCheckFailAtRandom(t *testing.T, ctx context.Context) { require.Error(t, err) errortest.AssertError(t, err, commonerrors.ErrCancelled) } + +func TestRunActionWithGracefulShutdown(t *testing.T) { + if platform.IsWindows() { + // Sending Interrupt on Windows is not implemented - https://golang.org/pkg/os/#Process.Signal + t.Skip("Skipping test on Windows as sending interrupt is not implemented on [this platform](https://golang.org/pkg/os/#Process.Signal)") + } + ctx := context.Background() + + defer goleak.VerifyNone(t) + tests := []struct { + name string + signal os.Signal + }{ + { + name: "SIGTERM", + signal: syscall.SIGTERM, + }, + { + name: "SIGINT", + signal: syscall.SIGINT, + }, + { + name: "SIGHUP", + signal: syscall.SIGHUP, + }, + { + name: "SIGQUIT", + signal: syscall.SIGQUIT, + }, + { + name: "SIGABRT", + signal: syscall.SIGABRT, + }, + { + name: "Interrupt", + signal: os.Interrupt, + }, + } + + process := os.Process{Pid: os.Getpid()} + longAction := func(ctx context.Context) error { + SleepWithContext(ctx, 150*time.Millisecond) + return ctx.Err() + } + shortAction := func(ctx context.Context) error { + return ctx.Err() + } + shortActionWithError := func(_ context.Context) error { + return commonerrors.ErrUnexpected + } + + t.Run("cancelled context", func(t *testing.T) { + defer goleak.VerifyNone(t) + cctx, cancel := context.WithCancel(ctx) + cancel() + err := RunActionWithGracefulShutdown(cctx, longAction, func(ctx context.Context) error { + return nil + }) + require.Error(t, err) + errortest.AssertError(t, err, commonerrors.ErrTimeout, commonerrors.ErrCancelled) + }) + + for i := range tests { + test := tests[i] + t.Run(fmt.Sprintf("interrupt [%v] before longAction completion", test.name), func(t *testing.T) { + defer goleak.VerifyNone(t) + called := atomic.NewBool(false) + shutdownAction := func(ctx2 context.Context) error { + err := DetermineContextError(ctx2) + if err == nil { + called.Store(true) + } + return err + } + require.False(t, called.Load()) + ScheduleAfter(ctx, time.Duration(random.Intn(100))*time.Millisecond, func(ti time.Time) { //nolint:gosec //causes G404: Use of weak random number generator (math/rand instead of crypto/rand) (gosec), So disable gosec as this is just for tests + if err := process.Signal(test.signal); err != nil { + t.Error("failed sending interrupt signal") + } + }) + err := RunActionWithGracefulShutdown(ctx, longAction, shutdownAction) + require.Error(t, err) + errortest.AssertError(t, err, commonerrors.ErrTimeout, commonerrors.ErrCancelled) + require.True(t, called.Load()) + }) + t.Run(fmt.Sprintf("interrupt [%v] after shortAction completion", test.name), func(t *testing.T) { + defer goleak.VerifyNone(t) + called := atomic.NewBool(false) + shutdownAction := func(ctx2 context.Context) error { + err := DetermineContextError(ctx2) + if err == nil { + called.Store(true) + } + return err + } + require.False(t, called.Load()) + ScheduleAfter(ctx, time.Duration(50+random.Intn(100))*time.Millisecond, func(ti time.Time) { //nolint:gosec //causes G404: Use of weak random number generator (math/rand instead of crypto/rand) (gosec), So disable gosec as this is just for tests + if err := process.Signal(test.signal); err != nil { + t.Error("failed sending interrupt signal") + } + }) + err := RunActionWithGracefulShutdown(ctx, shortAction, shutdownAction) + require.NoError(t, err) + require.False(t, called.Load()) + }) + t.Run(fmt.Sprintf("interrupt [%v] after shortActionWithError completion", test.name), func(t *testing.T) { + defer goleak.VerifyNone(t) + called := atomic.NewBool(false) + shutdownAction := func(ctx2 context.Context) error { + err := DetermineContextError(ctx2) + if err == nil { + called.Store(true) + } + return err + } + require.False(t, called.Load()) + ScheduleAfter(ctx, time.Duration(50+random.Intn(100))*time.Millisecond, func(ti time.Time) { //nolint:gosec //causes G404: Use of weak random number generator (math/rand instead of crypto/rand) (gosec), So disable gosec as this is just for tests + if err := process.Signal(test.signal); err != nil { + t.Error("failed sending interrupt signal") + } + }) + err := RunActionWithGracefulShutdown(ctx, shortActionWithError, shutdownAction) + require.Error(t, err) + errortest.AssertError(t, err, commonerrors.ErrUnexpected) + require.False(t, called.Load()) + }) + } + +}