diff --git a/drivers/shared/executor/procstats/list_test.go b/drivers/shared/executor/procstats/list_test.go new file mode 100644 index 00000000000..9e9588d347b --- /dev/null +++ b/drivers/shared/executor/procstats/list_test.go @@ -0,0 +1,110 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package procstats + +import ( + "math/rand" + "testing" + + "github.com/mitchellh/go-ps" + "github.com/shoenig/test/must" +) + +type mockProcess struct { + pid int + ppid int +} + +func (p *mockProcess) Pid() int { + return p.pid +} + +func (p *mockProcess) PPid() int { + return p.ppid +} + +func (p *mockProcess) Executable() string { + return "" +} + +func mockProc(pid, ppid int) *mockProcess { + return &mockProcess{pid: pid, ppid: ppid} +} + +func genMockProcs(needles, haystack int) ([]ps.Process, []ProcessID) { + + procs := []ps.Process{mockProc(1, 1), mockProc(42, 1)} + expect := []ProcessID{42} + + // TODO: make this into a tree structure, not just a linear tree + for i := 0; i < needles; i++ { + parent := 42 + i + pid := parent + 1 + procs = append(procs, mockProc(pid, parent)) + expect = append(expect, pid) + } + + for i := 0; i < haystack; i++ { + parent := 200 + i + pid := parent + 1 + procs = append(procs, mockProc(pid, parent)) + } + + rand.Shuffle(len(procs), func(i, j int) { + procs[i], procs[j] = procs[j], procs[i] + }) + + return procs, expect +} + +func Test_list(t *testing.T) { + cases := []struct { + name string + needles int + haystack int + expect int + }{ + { + name: "minimal", + needles: 2, + haystack: 10, + expect: 16, + }, + { + name: "small needles small haystack", + needles: 5, + haystack: 200, + expect: 212, + }, + { + name: "small needles large haystack", + needles: 10, + haystack: 1000, + expect: 1022, + }, + { + name: "moderate needles giant haystack", + needles: 20, + haystack: 2000, + expect: 2042, + }, + } + + for _, tc := range cases { + const executorPID = 42 + t.Run(tc.name, func(t *testing.T) { + + procs, expect := genMockProcs(tc.needles, tc.haystack) + lister := func() ([]ps.Process, error) { + return procs, nil + } + + result, examined := list(executorPID, lister) + must.SliceContainsAll(t, expect, result.Slice(), + must.Sprintf("exp: %v; got: %v", expect, result), + ) + must.Eq(t, tc.expect, examined) + }) + } +} diff --git a/drivers/shared/executor/procstats/list_windows.go b/drivers/shared/executor/procstats/list_windows.go index c6affafab33..9e4cd96c456 100644 --- a/drivers/shared/executor/procstats/list_windows.go +++ b/drivers/shared/executor/procstats/list_windows.go @@ -10,49 +10,6 @@ import ( "github.com/mitchellh/go-ps" ) -func gather(procs map[int]ps.Process, family set.Collection[int], root int, candidate ps.Process) bool { - if candidate == nil { - return false - } - pid := candidate.Pid() - if pid == 0 || pid == 1 { - return false - } - if pid == root { - return true - } - parent := procs[candidate.PPid()] - result := gather(procs, family, root, parent) - if result { - family.Insert(pid) - } - return result -} - -func mapping(all []ps.Process) map[int]ps.Process { - result := make(map[int]ps.Process) - for _, process := range all { - result[process.Pid()] = process - } - return result -} - -func list(executorPID int, processes func() ([]ps.Process, error)) set.Collection[ProcessID] { - family := set.From([]int{executorPID}) - - all, err := processes() - if err != nil { - return family - } - - m := mapping(all) - for _, candidate := range all { - gather(m, family, executorPID, candidate) - } - - return family -} - // List will scan the process table and return a set of the process family // tree starting with executorPID as the root. // @@ -69,5 +26,6 @@ func list(executorPID int, processes func() ([]ps.Process, error)) set.Collectio // happens when you use syscalls to work your way from the root down to its // descendants. func List(executorPID int) set.Collection[ProcessID] { - return list(executorPID, ps.Processes) + procs, _ := list(executorPID, ps.Processes) + return procs } diff --git a/drivers/shared/executor/procstats/list_windows_test.go b/drivers/shared/executor/procstats/list_windows_test.go deleted file mode 100644 index cb6ae70223c..00000000000 --- a/drivers/shared/executor/procstats/list_windows_test.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright (c) HashiCorp, Inc. -// SPDX-License-Identifier: MPL-2.0 - -//go:build windows - -package procstats - -import ( - "testing" - - "github.com/mitchellh/go-ps" - "github.com/shoenig/test/must" -) - -type mockProcess struct { - pid int - ppid int -} - -func (p *mockProcess) Pid() int { - return p.pid -} - -func (p *mockProcess) PPid() int { - return p.ppid -} - -func (p *mockProcess) Executable() string { - return "" -} - -func mockProc(pid, ppid int) *mockProcess { - return &mockProcess{pid: pid, ppid: ppid} -} - -var ( - executorOnly = []ps.Process{ - mockProc(1, 1), - mockProc(42, 1), - } - - simpleLine = []ps.Process{ - mockProc(1, 1), - mockProc(50, 42), - mockProc(42, 1), - mockProc(51, 50), - mockProc(101, 100), - mockProc(60, 51), - mockProc(100, 1), - } - - bigTree = []ps.Process{ - mockProc(1, 1), - mockProc(25, 50), - mockProc(100, 1), - mockProc(75, 50), - mockProc(10, 25), - mockProc(80, 75), - mockProc(81, 75), - mockProc(51, 50), - mockProc(42, 1), - mockProc(101, 100), - mockProc(52, 51), - mockProc(50, 42), - } -) - -func Test_list(t *testing.T) { - cases := []struct { - name string - procs []ps.Process - exp []ProcessID - }{ - { - name: "executor only", - procs: executorOnly, - exp: []ProcessID{42}, - }, - { - name: "simple line", - procs: simpleLine, - exp: []ProcessID{42, 50, 51, 60}, - }, - { - name: "big tree", - procs: bigTree, - exp: []ProcessID{42, 50, 25, 75, 10, 80, 81, 51, 52}, - }, - } - - for _, tc := range cases { - const executorPID = 42 - t.Run(tc.name, func(t *testing.T) { - lister := func() ([]ps.Process, error) { - return tc.procs, nil - } - result := list(executorPID, lister) - must.SliceContainsAll(t, tc.exp, result.Slice(), - must.Sprintf("exp: %v; got: %v", tc.exp, result), - ) - }) - } -} diff --git a/drivers/shared/executor/procstats/procstats.go b/drivers/shared/executor/procstats/procstats.go index 1f7e87ea0ce..a21e97cefa6 100644 --- a/drivers/shared/executor/procstats/procstats.go +++ b/drivers/shared/executor/procstats/procstats.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/go-set/v3" "github.com/hashicorp/nomad/client/lib/cpustats" "github.com/hashicorp/nomad/plugins/drivers" + "github.com/mitchellh/go-ps" ) var ( @@ -80,3 +81,52 @@ func Aggregate(systemStats *cpustats.Tracker, procStats ProcUsages) *drivers.Tas Pids: procStats, } } + +func list(executorPID int, processes func() ([]ps.Process, error)) (set.Collection[ProcessID], int) { + family := set.From([]int{executorPID}) + + all, err := processes() + if err != nil { + return family, 0 + } + + parents, examined := mapping(all) + examined += gather(family, parents, executorPID) + + return family, examined +} + +func gather(family set.Collection[int], parents map[int]set.Collection[int], parent int) int { + examined := 0 + candidates, ok := parents[parent] + if !ok { + return examined + } + for _, candidate := range candidates.Slice() { + examined++ + family.Insert(candidate) + examined += gather(family, parents, candidate) + } + + return examined +} + +// mapping builds a reverse map of parent to children +func mapping(all []ps.Process) (map[int]set.Collection[int], int) { + + parents := map[int]set.Collection[int]{} + examined := 0 + + for _, candidate := range all { + if candidate != nil { + examined++ + if children, ok := parents[candidate.PPid()]; ok { + children.Insert(candidate.Pid()) + } else { + parents[candidate.PPid()] = set.From([]int{candidate.Pid()}) + } + } + } + + return parents, examined +}