feat: global signal handling to cancel ctx for graceful exits #4993

Merged · 1 commit · Jun 10, 2024
7 changes: 6 additions & 1 deletion cli/command/container/attach.go
@@ -105,7 +105,12 @@ func RunAttach(ctx context.Context, dockerCLI command.Cli, containerID string, o

if opts.Proxy && !c.Config.Tty {
sigc := notifyAllSignals()
go ForwardAllSignals(ctx, apiClient, containerID, sigc)
// since we're explicitly setting up signal handling here, and the daemon will
// get notified independently of the client's ctx cancellation, we use this context
// but without cancellation, to keep ForwardAllSignals from returning
// before all signals are forwarded.
bgCtx := context.WithoutCancel(ctx)
go ForwardAllSignals(bgCtx, apiClient, containerID, sigc)
defer signal.StopCatch(sigc)
}

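The comment in this hunk is the heart of the change: the forwarder must keep draining the signal channel even after the CLI's root context has been cancelled by a signal. A minimal sketch (Go 1.21+, not part of this diff) of the `context.WithoutCancel` behaviour it relies on:

```go
// Minimal sketch of context.WithoutCancel: the detached context inherits the
// parent's values but is never cancelled, so a goroutine bound to it outlives
// the parent's cancellation -- analogous to ForwardAllSignals continuing to
// forward until signal.StopCatch closes its channel.
package main

import (
	"context"
	"fmt"
)

func main() {
	parent, cancel := context.WithCancel(context.Background())
	detached := context.WithoutCancel(parent)

	cancel() // simulate the CLI's root context being cancelled by a signal

	fmt.Println(parent.Err())   // context.Canceled
	fmt.Println(detached.Err()) // <nil>: a forwarder using this context keeps running
}
```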
8 changes: 8 additions & 0 deletions cli/command/container/client_test.go
@@ -37,6 +37,7 @@ type fakeClient struct {
containerRemoveFunc func(ctx context.Context, containerID string, options container.RemoveOptions) error
containerKillFunc func(ctx context.Context, containerID, signal string) error
containerPruneFunc func(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error)
containerAttachFunc func(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error)
Version string
}

@@ -173,3 +174,10 @@ func (f *fakeClient) ContainersPrune(ctx context.Context, pruneFilters filters.A
}
return types.ContainersPruneReport{}, nil
}

func (f *fakeClient) ContainerAttach(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error) {
if f.containerAttachFunc != nil {
return f.containerAttachFunc(ctx, containerID, options)
}
return types.HijackedResponse{}, nil
}
7 changes: 6 additions & 1 deletion cli/command/container/run.go
@@ -150,7 +150,12 @@ func runContainer(ctx context.Context, dockerCli command.Cli, runOpts *runOption
}
if runOpts.sigProxy {
sigc := notifyAllSignals()
go ForwardAllSignals(ctx, apiClient, containerID, sigc)
// since we're explicitly setting up signal handling here, and the daemon will
// get notified independently of the client's ctx cancellation, we use this context
// but without cancellation, to keep ForwardAllSignals from returning
// before all signals are forwarded.
bgCtx := context.WithoutCancel(ctx)
go ForwardAllSignals(bgCtx, apiClient, containerID, sigc)
defer signal.StopCatch(sigc)
}

69 changes: 69 additions & 0 deletions cli/command/container/run_test.go
@@ -5,11 +5,18 @@ import (
"errors"
"fmt"
"io"
"net"
"os/signal"
"syscall"
"testing"
"time"

"github.com/creack/pty"
"github.com/docker/cli/cli"
"github.com/docker/cli/cli/streams"
"github.com/docker/cli/internal/test"
"github.com/docker/cli/internal/test/notary"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/network"
specs "github.com/opencontainers/image-spec/specs-go/v1"
@@ -32,6 +39,68 @@ func TestRunLabel(t *testing.T) {
assert.NilError(t, cmd.Execute())
}

func TestRunAttachTermination(t *testing.T) {
p, tty, err := pty.Open()
assert.NilError(t, err)

defer func() {
_ = tty.Close()
_ = p.Close()
}()

killCh := make(chan struct{})
attachCh := make(chan struct{})
fakeCLI := test.NewFakeCli(&fakeClient{
createContainerFunc: func(_ *container.Config, _ *container.HostConfig, _ *network.NetworkingConfig, _ *specs.Platform, _ string) (container.CreateResponse, error) {
return container.CreateResponse{
ID: "id",
}, nil
},
containerKillFunc: func(ctx context.Context, containerID, signal string) error {
killCh <- struct{}{}
return nil
},
containerAttachFunc: func(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error) {
server, client := net.Pipe()
t.Cleanup(func() {
_ = server.Close()
})
attachCh <- struct{}{}
return types.NewHijackedResponse(client, types.MediaTypeRawStream), nil
},
Version: "1.36",
}, func(fc *test.FakeCli) {
fc.SetOut(streams.NewOut(tty))
fc.SetIn(streams.NewIn(tty))
})
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM)
defer cancel()

assert.Equal(t, fakeCLI.In().IsTerminal(), true)
assert.Equal(t, fakeCLI.Out().IsTerminal(), true)

cmd := NewRunCommand(fakeCLI)
cmd.SetArgs([]string{"-it", "busybox"})
cmd.SilenceUsage = true
go func() {
assert.ErrorIs(t, cmd.ExecuteContext(ctx), context.Canceled)
}()

select {
case <-time.After(5 * time.Second):
t.Fatal("containerAttachFunc was not called before the 5 second timeout")
case <-attachCh:
}

assert.NilError(t, syscall.Kill(syscall.Getpid(), syscall.SIGTERM))
select {
case <-time.After(5 * time.Second):
cancel()
t.Fatal("containerKillFunc was not called before the 5 second timeout")
case <-killCh:
}
}

func TestRunCommandWithContentTrustErrors(t *testing.T) {
testCases := []struct {
name string
3 changes: 2 additions & 1 deletion cli/command/container/start.go
@@ -87,7 +87,8 @@ func RunStart(ctx context.Context, dockerCli command.Cli, opts *StartOptions) er
// We always use c.ID instead of container to maintain consistency during `docker start`
if !c.Config.Tty {
sigc := notifyAllSignals()
go ForwardAllSignals(ctx, dockerCli.Client(), c.ID, sigc)
bgCtx := context.WithoutCancel(ctx)
go ForwardAllSignals(bgCtx, dockerCli.Client(), c.ID, sigc)
defer signal.StopCatch(sigc)
}

10 changes: 1 addition & 9 deletions cli/command/utils.go
@@ -9,11 +9,9 @@ import (
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"syscall"

"github.com/docker/cli/cli/streams"
"github.com/docker/docker/api/types/filters"
@@ -103,11 +101,6 @@ func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, m

result := make(chan bool)

// Catch the termination signal and exit the prompt gracefully.
// The caller is responsible for properly handling the termination.
notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
defer notifyCancel()

go func() {
var res bool
scanner := bufio.NewScanner(ins)
@@ -121,8 +114,7 @@
}()

select {
case <-notifyCtx.Done():
// print a newline on termination
case <-ctx.Done():
_, _ = fmt.Fprintln(outs, "")
return false, ErrPromptTerminated
case r := <-result:
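With this change `PromptForConfirmation` no longer installs its own SIGINT/SIGTERM handler; a caller that wants the prompt to abort on a signal must pass in a signal-aware context itself, as the updated test below does. A hedged sketch of that caller-side wiring (the `confirmRemoval` helper and its message are hypothetical, not taken from the PR):

```go
// Sketch of the caller-side wiring the new PromptForConfirmation expects:
// the caller derives a signal-aware context and maps ErrPromptTerminated to
// its own exit behaviour. confirmRemoval is an illustrative call site.
package example

import (
	"context"
	"errors"
	"os"
	"os/signal"
	"syscall"

	"github.com/docker/cli/cli/command"
)

func confirmRemoval(ctx context.Context) (bool, error) {
	// Cancel the prompt on SIGINT/SIGTERM; PromptForConfirmation returns
	// ErrPromptTerminated once the context is done.
	notifyCtx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
	defer cancel()

	ok, err := command.PromptForConfirmation(notifyCtx, os.Stdin, os.Stdout, "Proceed?")
	if errors.Is(err, command.ErrPromptTerminated) {
		// The user interrupted the prompt; treat it as a "no".
		return false, err
	}
	return ok, err
}
```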
6 changes: 5 additions & 1 deletion cli/command/utils_test.go
@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
@@ -135,6 +136,9 @@ func TestPromptForConfirmation(t *testing.T) {
}, promptResult{false, nil}},
} {
t.Run("case="+tc.desc, func(t *testing.T) {
notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
t.Cleanup(notifyCancel)

buf.Reset()
promptReader, promptWriter = io.Pipe()

@@ -145,7 +149,7 @@

result := make(chan promptResult, 1)
go func() {
r, err := command.PromptForConfirmation(ctx, promptReader, promptOut, "")
r, err := command.PromptForConfirmation(notifyCtx, promptReader, promptOut, "")
result <- promptResult{r, err}
}()

91 changes: 58 additions & 33 deletions cmd/docker/docker.go
@@ -28,12 +28,20 @@ import (
)

func main() {
ctx := context.Background()
statusCode := dockerMain()
if statusCode != 0 {
os.Exit(statusCode)
}
}

func dockerMain() int {
ctx, cancelNotify := signal.NotifyContext(context.Background(), platformsignals.TerminationSignals...)
defer cancelNotify()

dockerCli, err := command.NewDockerCli(command.WithBaseContext(ctx))
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
return 1
}
logrus.SetOutput(dockerCli.Err())
otel.SetErrorHandler(debug.OTELErrorHandler)
@@ -46,16 +54,17 @@ func main() {
// StatusError should only be used for errors, and all errors should
// have a non-zero exit status, so never exit with 0
if sterr.StatusCode == 0 {
os.Exit(1)
return 1
}
os.Exit(sterr.StatusCode)
return sterr.StatusCode
}
if errdefs.IsCancelled(err) {
os.Exit(0)
return 0
}
fmt.Fprintln(dockerCli.Err(), err)
os.Exit(1)
return 1
}
return 0
}

func newDockerCommand(dockerCli *command.DockerCli) *cli.TopLevelCommand {
@@ -224,7 +233,7 @@ func setValidateArgs(dockerCli command.Cli, cmd *cobra.Command) {
})
}

func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error {
func tryPluginRun(ctx context.Context, dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error {
plugincmd, err := pluginmanager.PluginRunCommand(dockerCli, subcommand, cmd)
if err != nil {
return err
@@ -242,40 +251,56 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,

// Background signal handling logic: block on the signals channel, and
// notify the plugin via the PluginServer (or signal) as appropriate.
const exitLimit = 3
signals := make(chan os.Signal, exitLimit)
signal.Notify(signals, platformsignals.TerminationSignals...)
const exitLimit = 2

tryTerminatePlugin := func(force bool) {
// If stdin is a TTY, the kernel will forward
// signals to the subprocess because the shared
// pgid makes the TTY a controlling terminal.
//
// The plugin should have its own copy of this
// termination logic, and exit after 3 retries
// on its own.
if dockerCli.Out().IsTerminal() {
return
}

// Terminate the plugin server, which will
// close all connections with plugin
// subprocesses, and signal them to exit.
//
// Repeated invocations will result in EINVAL,
// or EBADF; but that is fine for our purposes.
_ = srv.Close()

// force the process to terminate if it hasn't already
if force {
_ = plugincmd.Process.Kill()
_, _ = fmt.Fprint(dockerCli.Err(), "got 3 SIGTERM/SIGINTs, forcefully exiting\n")
os.Exit(1)
}
}

go func() {
retries := 0
for range signals {
// If stdin is a TTY, the kernel will forward
// signals to the subprocess because the shared
// pgid makes the TTY a controlling terminal.
//
// The plugin should have it's own copy of this
// termination logic, and exit after 3 retries
// on it's own.
if dockerCli.Out().IsTerminal() {
continue
}
force := false
// catch the first signal through context cancellation
<-ctx.Done()
tryTerminatePlugin(force)

// Terminate the plugin server, which will
// close all connections with plugin
// subprocesses, and signal them to exit.
//
// Repeated invocations will result in EINVAL,
// or EBADF; but that is fine for our purposes.
_ = srv.Close()
// register subsequent signals
signals := make(chan os.Signal, exitLimit)
signal.Notify(signals, platformsignals.TerminationSignals...)

for range signals {
retries++
// If we're still running after 3 interruptions
// (SIGINT/SIGTERM), send a SIGKILL to the plugin as a
// final attempt to terminate, and exit.
retries++
if retries >= exitLimit {
_, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries)
_ = plugincmd.Process.Kill()
os.Exit(1)
force = true
}
tryTerminatePlugin(force)
}
}()
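The flow above is a two-stage escalation: the first termination signal arrives through cancellation of the signal-aware root context, later signals through a freshly registered channel, and after `exitLimit` additional signals `tryTerminatePlugin` is asked to force-kill. A standalone sketch of that escalation, with a generic `stop(force bool)` callback standing in for the real `srv`/`plugincmd` pair (names are illustrative, not from the PR):

```go
// Standalone sketch of the two-stage plugin termination: stop(false) on the
// first cancellation, stop(true) once exitLimit further signals have arrived.
package termination

import (
	"context"
	"os"
	"os/signal"
	"syscall"
)

// Watch calls stop(false) when ctx is cancelled, then stop(true) after
// exitLimit additional SIGINT/SIGTERMs have been received.
func Watch(ctx context.Context, exitLimit int, stop func(force bool)) {
	go func() {
		<-ctx.Done() // first signal: the signal-aware root context is cancelled
		stop(false)

		// The context can only be cancelled once, so register a plain
		// signal channel for anything that follows.
		sigs := make(chan os.Signal, exitLimit)
		signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

		retries := 0
		for range sigs {
			retries++
			// Force-terminate once the user has insisted enough times.
			stop(retries >= exitLimit)
		}
	}()
}
```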

@@ -338,7 +363,7 @@ func runDocker(ctx context.Context, dockerCli *command.DockerCli) error {
ccmd, _, err := cmd.Find(args)
subCommand = ccmd
if err != nil || pluginmanager.IsPluginCommand(ccmd) {
err := tryPluginRun(dockerCli, cmd, args[0], envs)
err := tryPluginRun(ctx, dockerCli, cmd, args[0], envs)
if err == nil {
if dockerCli.HooksEnabled() && dockerCli.Out().IsTerminal() && ccmd != nil {
pluginmanager.RunPluginHooks(ctx, dockerCli, cmd, ccmd, args)
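The other notable change in cmd/docker/docker.go is structural: `main` now delegates to `dockerMain`, which returns an exit code instead of calling `os.Exit` at each error site, so deferred cleanup such as `cancelNotify` always runs before the process exits. A simplified sketch of that pattern in isolation (`run`, `work`, and the error handling are illustrative, not the real `cmd/docker` code):

```go
// Simplified sketch of the main/dockerMain split: main owns the single
// os.Exit call, and the work happens in a function returning an exit code so
// its defers (including the signal.NotifyContext stop) always run.
package main

import (
	"context"
	"errors"
	"fmt"
	"os"
	"os/signal"
	"syscall"
)

func main() {
	if code := run(); code != 0 {
		os.Exit(code)
	}
}

func run() int {
	// Cancel the root context on termination signals; the deferred stop()
	// runs because run() returns instead of calling os.Exit directly.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	if err := work(ctx); err != nil {
		if errors.Is(err, context.Canceled) {
			// Mirror the errdefs.IsCancelled branch above: signal-driven
			// cancellation is a graceful, zero exit.
			return 0
		}
		fmt.Fprintln(os.Stderr, err)
		return 1
	}
	return 0
}

func work(ctx context.Context) error {
	<-ctx.Done()
	return ctx.Err()
}
```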
8 changes: 5 additions & 3 deletions internal/test/cmd.go
@@ -3,7 +3,6 @@ package test
import (
"context"
"os"
"syscall"
"testing"
"time"

@@ -32,8 +31,11 @@ func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli
assert.NilError(t, err)
cli.SetIn(streams.NewIn(r))

notifyCtx, notifyCancel := context.WithCancel(ctx)
t.Cleanup(notifyCancel)

go func() {
errChan <- cmd.ExecuteContext(ctx)
errChan <- cmd.ExecuteContext(notifyCtx)
}()

writeCtx, writeCancel := context.WithTimeout(ctx, 100*time.Millisecond)
@@ -66,7 +68,7 @@

// sigint and sigterm are caught by the prompt
// this allows us to gracefully exit the prompt with a 0 exit code
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
notifyCancel()

select {
case <-errCtx.Done():