Skip to content

Commit

Permalink
RSDK-9591 - Add KillGroup to ManagedProcess (#399)
Browse files Browse the repository at this point in the history
Co-authored-by: Benjamin Rewis <[email protected]>
  • Loading branch information
cheukt and benjirewis authored Jan 9, 2025
1 parent 8961a20 commit 95db6b6
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 5 deletions.
31 changes: 31 additions & 0 deletions pexec/managed_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ type ManagedProcess interface {
// there's any system level issue stopping the process.
Stop() error

// KillGroup will attempt to kill the process group and not wait for completion. Only use this if
// comfortable with leaking resources (in cases where exiting the program as quickly as possible is desired).
KillGroup()

// Status returns nil when the process is both alive and owned.
// If err is non-nil, process may be a) alive but not owned or b) dead.
Status() error
Expand Down Expand Up @@ -432,3 +436,30 @@ func (p *managedProcess) Stop() error {
}
return errors.Errorf("non-successful exit code: %d", p.cmd.ProcessState.ExitCode())
}

// KillGroup signals the management goroutine to stop (if that has not
// already happened) and then force kills the process group. It does not
// wait for the kill to complete and discards any error from the kill attempt.
// It is a no-op beyond the stop signal when no command has been created.
func (p *managedProcess) KillGroup() {
	// Hold the lock only long enough to signal the management goroutine
	// and to observe whether a command exists. The kill below is attempted
	// even when p.stopped was already true.
	p.mu.Lock()
	if !p.stopped {
		close(p.killCh)
		p.stopped = true
	}
	hasCmd := p.cmd != nil
	p.mu.Unlock()

	if !hasCmd {
		return
	}

	// p.cmd is mutex guarded and the manage goroutine has just been
	// signaled to stop, so no new Start can happen and p.cmd can no
	// longer be modified; reading it without the lock is therefore safe.
	// The error is intentionally ignored: we are already in a bad state.
	//nolint:errcheck,gosec
	p.forceKillGroup()
}
102 changes: 98 additions & 4 deletions pexec/managed_process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"strings"
"sync"
"sync/atomic"
"syscall"
"testing"
"time"

Expand Down Expand Up @@ -411,10 +412,10 @@ func TestManagedProcessStop(t *testing.T) {
bashScriptBuilder.WriteString("\n")
}
bashScriptBuilder.WriteString(fmt.Sprintf(`echo hello >> '%s'
while true
do echo hey
sleep 1
done`, tempFile.Name()))
while true
do echo hey
sleep 1
done`, tempFile.Name()))
bashScriptBuilder.WriteString("\n")

bashScript := bashScriptBuilder.String()
Expand Down Expand Up @@ -571,6 +572,97 @@ done`, tempFile.Name()))
})
}

// TestManagedProcessKillGroup verifies that KillGroup terminates a managed
// process together with the child processes it spawned. Liveness is observed
// indirectly: every process appends to its own file, so once all three files
// stop growing the whole group is considered dead.
func TestManagedProcessKillGroup(t *testing.T) {
	t.Run("kill signaling with children", func(t *testing.T) {
		if runtime.GOOS == "windows" {
			t.Skip("cannot test this on windows")
		}
		logger := golog.NewTestLogger(t)

		watcher1, tempFile1 := testutils.WatchedFile(t)
		watcher2, tempFile2 := testutils.WatchedFile(t)
		watcher3, tempFile3 := testutils.WatchedFile(t)

		// this script writes a string to the specified file every 100ms
		script := `
while true
do echo hello >> '%s'
sleep 0.1
done
`

		bashScript1 := fmt.Sprintf(script, tempFile1.Name())
		bashScript2 := fmt.Sprintf(script, tempFile2.Name())
		// The parent format string has exactly three verbs: one per child
		// script plus the one inside the appended writer loop. Passing any
		// additional argument would append "%!(EXTRA ...)" to the script.
		bashScriptParent := fmt.Sprintf(`
bash -c '%s' &
bash -c '%s' &
`+script,
			bashScript1,
			bashScript2,
			tempFile3.Name(),
		)

		proc := NewManagedProcess(ProcessConfig{
			Name: "bash",
			Args: []string{"-c", bashScriptParent},
		}, logger)

		// To confirm that the processes have died, confirm that the size of the file stopped increasing.
		// The testing.TB is a parameter so the helper can also be used inside the retry closure below.
		getSize := func(tb testing.TB, file *os.File) int64 {
			tb.Helper()
			info, err := file.Stat()
			test.That(tb, err, test.ShouldBeNil)
			return info.Size()
		}

		file1SizeBeforeStart := getSize(t, tempFile1)
		file2SizeBeforeStart := getSize(t, tempFile2)
		file3SizeBeforeStart := getSize(t, tempFile3)

		test.That(t, proc.Start(context.Background()), test.ShouldBeNil)

		// Wait until each of the three processes has written at least once.
		<-watcher1.Events
		<-watcher2.Events
		<-watcher3.Events

		proc.KillGroup()

		file1SizeAfterKill := getSize(t, tempFile1)
		file2SizeAfterKill := getSize(t, tempFile2)
		file3SizeAfterKill := getSize(t, tempFile3)

		test.That(t, file1SizeAfterKill, test.ShouldBeGreaterThan, file1SizeBeforeStart)
		test.That(t, file2SizeAfterKill, test.ShouldBeGreaterThan, file2SizeBeforeStart)
		test.That(t, file3SizeAfterKill, test.ShouldBeGreaterThan, file3SizeBeforeStart)

		// since KillGroup does not wait, we might have to check file size a few times as the kill
		// might take a little to propagate. We want to make sure that the file size stops increasing.
		testutils.WaitForAssertionWithSleep(t, 300*time.Millisecond, 50, func(tb testing.TB) {
			tb.Helper()

			// Record the new sizes as the baseline BEFORE asserting, so that an
			// attempt whose assertion fails still advances the baseline for the
			// next attempt. Each file is tracked against its own previous size.
			prevSize1, prevSize2, prevSize3 := file1SizeAfterKill, file2SizeAfterKill, file3SizeAfterKill
			file1SizeAfterKill = getSize(tb, tempFile1)
			file2SizeAfterKill = getSize(tb, tempFile2)
			file3SizeAfterKill = getSize(tb, tempFile3)

			// Assert against tb (not the outer t) so that a mismatch only fails
			// this attempt and the assertion can be retried.
			test.That(tb, file1SizeAfterKill, test.ShouldEqual, prevSize1)
			test.That(tb, file2SizeAfterKill, test.ShouldEqual, prevSize2)
			test.That(tb, file3SizeAfterKill, test.ShouldEqual, prevSize3)
		})

		// in CI, we have to send another signal to make sure the cmd.Wait() in
		// the manage goroutine actually returns.
		// We do not care about the error if it is expected, i.e. the process
		// has already finished; any other error fails the test.
		// maybe related to https://github.com/golang/go/issues/18874
		if err := proc.(*managedProcess).cmd.Process.Signal(syscall.SIGTERM); err != nil {
			test.That(t, errors.Is(err, os.ErrProcessDone), test.ShouldBeTrue)
		}

		// wait on the managingCh to close
		<-proc.(*managedProcess).managingCh
	})
}

func TestManagedProcessEnvironmentVariables(t *testing.T) {
t.Run("set an environment variable on one-shot process", func(t *testing.T) {
logger := golog.NewTestLogger(t)
Expand Down Expand Up @@ -702,3 +794,5 @@ func (fp *fakeProcess) UnixPid() (int, error) {
in reality tests should just depend on the methods they rely on. UnixPid is not one
of those methods (for better or worse)`)
}

// KillGroup is a no-op on the fake process.
func (fp *fakeProcess) KillGroup() {}
10 changes: 10 additions & 0 deletions pexec/managed_process_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package pexec

import (
"os"
"os/exec"
"os/user"
"strconv"
"syscall"
Expand Down Expand Up @@ -126,6 +127,15 @@ func (p *managedProcess) kill() (bool, error) {
return forceKilled, nil
}

// forceKillGroup kills everything in the process group by starting an external
// `kill -9` targeting the negated PID of the managed command. The kill command
// is only started, never waited on, so this returns before the kill completes
// and the spawned kill may itself be left behind as a zombie process.
// The returned error is the one from starting the kill command.
func (p *managedProcess) forceKillGroup() error {
	pid := p.cmd.Process.Pid
	p.logger.Infof("killing entire process group %d", pid)
	// A negative PID argument addresses the whole process group.
	//nolint:gosec
	killCmd := exec.Command("kill", "-9", strconv.Itoa(-pid))
	return killCmd.Start()
}

func isWaitErrUnknown(err string, forceKilled bool) bool {
// This can easily happen if the process does not handle interrupts gracefully
// and it won't provide us any exit code info.
Expand Down
8 changes: 7 additions & 1 deletion pexec/managed_process_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ func parseSignal(sigStr, name string) (syscall.Signal, error) {
return 0, errors.New("signals not supported on Windows")
}


func (p *managedProcess) sysProcAttr() (*syscall.SysProcAttr, error) {
ret := &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
Expand Down Expand Up @@ -107,6 +106,13 @@ func (p *managedProcess) kill() (bool, error) {
return forceKilled, nil
}

// forceKillGroup kills everything in the process tree by starting an external
// `taskkill /t /f` against the PID of the managed command. The taskkill command
// is only started, never waited on, so this returns before the kill completes
// and may result in a zombie process.
// The returned error is the one from starting the taskkill command.
func (p *managedProcess) forceKillGroup() error {
	pid := p.cmd.Process.Pid
	p.logger.Infof("force killing entire process tree %d", pid)
	taskkillCmd := exec.Command("taskkill", "/t", "/f", "/pid", strconv.Itoa(pid))
	return taskkillCmd.Start()
}

func isWaitErrUnknown(err string, forceKilled bool) bool {
if !forceKilled {
return false
Expand Down

0 comments on commit 95db6b6

Please sign in to comment.