From 5f89ecb494e94de18d9c317cebf759bc8c65ef60 Mon Sep 17 00:00:00 2001 From: Seth Hoenig Date: Thu, 16 May 2024 18:40:33 +0000 Subject: [PATCH] windows: fix inefficient gathering of task processes --- .changelog/20619.txt | 3 + drivers/shared/executor/executor_basic.go | 2 +- drivers/shared/executor/executor_linux.go | 2 +- .../executor/executor_universal_linux.go | 2 +- .../shared/executor/procstats/list_default.go | 4 +- .../shared/executor/procstats/list_windows.go | 73 +++++++++++++ .../executor/procstats/list_windows_test.go | 103 ++++++++++++++++++ .../shared/executor/procstats/procstats.go | 2 +- 8 files changed, 185 insertions(+), 6 deletions(-) create mode 100644 .changelog/20619.txt create mode 100644 drivers/shared/executor/procstats/list_windows.go create mode 100644 drivers/shared/executor/procstats/list_windows_test.go diff --git a/.changelog/20619.txt b/.changelog/20619.txt new file mode 100644 index 00000000000..a00cf6a9c53 --- /dev/null +++ b/.changelog/20619.txt @@ -0,0 +1,3 @@ +```release-note:bug +windows: Fixed a regression where scanning task processes was inefficient +``` diff --git a/drivers/shared/executor/executor_basic.go b/drivers/shared/executor/executor_basic.go index 2fee50d1c48..e7443167abf 100644 --- a/drivers/shared/executor/executor_basic.go +++ b/drivers/shared/executor/executor_basic.go @@ -36,7 +36,7 @@ func withNetworkIsolation(f func() error, _ *drivers.NetworkIsolationSpec) error func setCmdUser(*exec.Cmd, string) error { return nil } -func (e *UniversalExecutor) ListProcesses() *set.Set[int] { +func (e *UniversalExecutor) ListProcesses() set.Collection[int] { return procstats.List(e.childCmd.Process.Pid) } diff --git a/drivers/shared/executor/executor_linux.go b/drivers/shared/executor/executor_linux.go index ceb80f2c474..7e25237c2ff 100644 --- a/drivers/shared/executor/executor_linux.go +++ b/drivers/shared/executor/executor_linux.go @@ -120,7 +120,7 @@ func NewExecutorWithIsolation(logger hclog.Logger, compute 
cpustats.Compute) Exe return le } -func (l *LibcontainerExecutor) ListProcesses() *set.Set[int] { +func (l *LibcontainerExecutor) ListProcesses() set.Collection[int] { return procstats.List(l.command) } diff --git a/drivers/shared/executor/executor_universal_linux.go b/drivers/shared/executor/executor_universal_linux.go index f947ed8b24a..b42a4423547 100644 --- a/drivers/shared/executor/executor_universal_linux.go +++ b/drivers/shared/executor/executor_universal_linux.go @@ -102,7 +102,7 @@ func (e *UniversalExecutor) setSubCmdCgroup(cmd *exec.Cmd, cgroup string) (func( } } -func (e *UniversalExecutor) ListProcesses() *set.Set[procstats.ProcessID] { +func (e *UniversalExecutor) ListProcesses() set.Collection[procstats.ProcessID] { return procstats.List(e.command) } diff --git a/drivers/shared/executor/procstats/list_default.go b/drivers/shared/executor/procstats/list_default.go index 2b152934ded..04553e73a95 100644 --- a/drivers/shared/executor/procstats/list_default.go +++ b/drivers/shared/executor/procstats/list_default.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 -//go:build !linux +//go:build !linux && !windows package procstats @@ -15,7 +15,7 @@ import ( ) // List the process tree starting at the given executorPID -func List(executorPID int) *set.Set[ProcessID] { +func List(executorPID int) set.Collection[ProcessID] { result := set.New[ProcessID](10) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) diff --git a/drivers/shared/executor/procstats/list_windows.go b/drivers/shared/executor/procstats/list_windows.go new file mode 100644 index 00000000000..dca427ada3c --- /dev/null +++ b/drivers/shared/executor/procstats/list_windows.go @@ -0,0 +1,73 @@ +// Copyright (c) HashiCorp, Inc. 
+// SPDX-License-Identifier: MPL-2.0
+
+//go:build windows
+
+package procstats
+
+import (
+	"github.com/hashicorp/go-set/v2"
+	"github.com/mitchellh/go-ps"
+)
+
+// gather walks from candidate up its chain of parents and reports whether root
+// is reached, inserting every PID passed on the way (root excluded) into family.
+// A repeated PID ends the walk, so a parent cycle (e.g. a reused PID) returns false.
+func gather(procs map[int]ps.Process, family set.Collection[int], root int, candidate ps.Process) bool {
+	var path []int // PIDs visited so far; only committed to family once root is reached
+	for seen := set.New[int](0); candidate != nil; candidate = procs[candidate.PPid()] {
+		pid := candidate.Pid()
+		if pid == 0 || pid == 1 || !seen.Insert(pid) { // PID 0/1, or a PID already visited
+			return false
+		}
+		if pid == root {
+			family.InsertSlice(path)
+			return true
+		}
+		path = append(path, pid)
+	}
+	return false // ran off the top of the table: a parent is missing from procs
+}
+
+// mapping indexes the process snapshot by PID so parents can be looked up.
+func mapping(all []ps.Process) map[int]ps.Process {
+	result := make(map[int]ps.Process, len(all))
+	for _, process := range all {
+		result[process.Pid()] = process
+	}
+	return result
+}
+
+// list returns executorPID plus every process in the snapshot from processes
+// that descends from it, or an empty set if the snapshot cannot be taken.
+func list(executorPID int, processes func() ([]ps.Process, error)) set.Collection[ProcessID] {
+	family := set.From([]int{executorPID})
+	all, err := processes()
+	if err != nil {
+		return set.New[ProcessID](0)
+	}
+	m := mapping(all)
+	for _, candidate := range all {
+		gather(m, family, executorPID, candidate)
+	}
+	return family
+}
+
+// List will scan the process table and return a set of the process family
+// tree starting with executorPID as the root.
+//
+// The implementation here specifically avoids using more than one system
+// call. Unlike on Linux where we just read a cgroup, on Windows we must build
+// the tree manually. We do so knowing only the child->parent relationships.
+//
+// So this turns into a fun leet code problem, where we invert the tree using
+// only a bucket of edges pointing in the wrong direction. Basically we just
+// iterate every process, walk up its chain of parents, and determine whether
+// executorPID is an ancestor.
+//
+// See https://github.com/hashicorp/nomad/issues/20042 as an example of what
+// happens when you use syscalls to work your way from the root down to its
+// descendants.
+func List(executorPID int) set.Collection[ProcessID] {
+	return list(executorPID, ps.Processes) // ps.Processes is the real process source; tests inject their own
+}
diff --git a/drivers/shared/executor/procstats/list_windows_test.go b/drivers/shared/executor/procstats/list_windows_test.go
new file mode 100644
index 00000000000..cb6ae70223c
--- /dev/null
+++ b/drivers/shared/executor/procstats/list_windows_test.go
@@ -0,0 +1,103 @@
+// Copyright (c) HashiCorp, Inc.
+// SPDX-License-Identifier: MPL-2.0
+
+//go:build windows
+
+package procstats
+
+import (
+	"testing"
+
+	"github.com/mitchellh/go-ps"
+	"github.com/shoenig/test/must"
+)
+
+type mockProcess struct { // test double used as a ps.Process with fixed values
+	pid  int // returned by Pid
+	ppid int // returned by PPid
+}
+
+func (p *mockProcess) Pid() int {
+	return p.pid
+}
+
+func (p *mockProcess) PPid() int {
+	return p.ppid
+}
+
+func (p *mockProcess) Executable() string {
+	return "" // list and gather only call Pid and PPid
+}
+
+func mockProc(pid, ppid int) *mockProcess { // shorthand for building fixture entries
+	return &mockProcess{pid: pid, ppid: ppid}
+}
+
+var (
+	executorOnly = []ps.Process{ // no entry has 42 as an ancestor
+		mockProc(1, 1),
+		mockProc(42, 1),
+	}
+
+	simpleLine = []ps.Process{ // single chain 42 <- 50 <- 51 <- 60; 100 and 101 hang off PID 1
+		mockProc(1, 1),
+		mockProc(50, 42),
+		mockProc(42, 1),
+		mockProc(51, 50),
+		mockProc(101, 100),
+		mockProc(60, 51),
+		mockProc(100, 1),
+	}
+
+	bigTree = []ps.Process{ // branching descendants of 42, listed out of parent-first order
+		mockProc(1, 1),
+		mockProc(25, 50),
+		mockProc(100, 1),
+		mockProc(75, 50),
+		mockProc(10, 25),
+		mockProc(80, 75),
+		mockProc(81, 75),
+		mockProc(51, 50),
+		mockProc(42, 1),
+		mockProc(101, 100),
+		mockProc(52, 51),
+		mockProc(50, 42),
+	}
+)
+
+func Test_list(t *testing.T) { // table-driven: runs list over each canned process table
+	cases := []struct {
+		name  string
+		procs []ps.Process // process table handed to list
+		exp   []ProcessID  // PIDs expected in the result, executor included
+	}{
+		{
+			name:  "executor only",
+			procs: executorOnly,
+			exp:   []ProcessID{42},
+		},
+		{
+			name:  "simple line",
+			procs: simpleLine,
+			exp:   []ProcessID{42, 50, 51, 60},
+		},
+		{
+			name:  "big tree",
+			procs: bigTree,
+			exp:   []ProcessID{42, 50, 25, 75, 10, 80, 81, 51, 52},
+		},
+	}
+
+	for _, tc := range cases {
+		const executorPID = 42 // root PID present in every fixture
+		t.Run(tc.name, func(t *testing.T) {
+			lister := func() ([]ps.Process, error) { // replaces ps.Processes with the canned table
+				return tc.procs, nil
+			}
+			result := list(executorPID, lister)
+			must.SliceContainsAll(t, tc.exp, result.Slice(),
+				must.Sprintf("exp: %v; got: %v", tc.exp, result),
+			)
+		})
+	}
+}
diff --git a/drivers/shared/executor/procstats/procstats.go b/drivers/shared/executor/procstats/procstats.go
index a5e3fd6e74d..04d157d4ee0 100644
--- a/drivers/shared/executor/procstats/procstats.go
+++ b/drivers/shared/executor/procstats/procstats.go
@@ -36,7 +36,7 @@ type ProcessStats interface {
 // A ProcessList is anything (i.e. a task driver) that implements ListProcesses
 // for gathering the list of process IDs associated with a task.
 type ProcessList interface {
-	ListProcesses() *set.Set[ProcessID]
+	ListProcesses() set.Collection[ProcessID]
 }
 
 // Aggregate combines a given ProcUsages with the Tracker for the Client.