Skip to content

Commit

Permalink
windows: fix inefficient gathering of task processes
Browse files Browse the repository at this point in the history
  • Loading branch information
shoenig committed May 17, 2024
1 parent e9d6c39 commit 5f89ecb
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 6 deletions.
3 changes: 3 additions & 0 deletions .changelog/20619.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
windows: Fixed a regression where scanning task processes was inefficient
```
2 changes: 1 addition & 1 deletion drivers/shared/executor/executor_basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func withNetworkIsolation(f func() error, _ *drivers.NetworkIsolationSpec) error

func setCmdUser(*exec.Cmd, string) error { return nil }

func (e *UniversalExecutor) ListProcesses() *set.Set[int] {
func (e *UniversalExecutor) ListProcesses() set.Collection[int] {
return procstats.List(e.childCmd.Process.Pid)
}

Expand Down
2 changes: 1 addition & 1 deletion drivers/shared/executor/executor_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func NewExecutorWithIsolation(logger hclog.Logger, compute cpustats.Compute) Exe
return le
}

func (l *LibcontainerExecutor) ListProcesses() *set.Set[int] {
func (l *LibcontainerExecutor) ListProcesses() set.Collection[int] {
return procstats.List(l.command)
}

Expand Down
2 changes: 1 addition & 1 deletion drivers/shared/executor/executor_universal_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (e *UniversalExecutor) setSubCmdCgroup(cmd *exec.Cmd, cgroup string) (func(
}
}

func (e *UniversalExecutor) ListProcesses() *set.Set[procstats.ProcessID] {
func (e *UniversalExecutor) ListProcesses() set.Collection[procstats.ProcessID] {
return procstats.List(e.command)
}

Expand Down
4 changes: 2 additions & 2 deletions drivers/shared/executor/procstats/list_default.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build !linux
//go:build !linux && !windows

package procstats

Expand All @@ -15,7 +15,7 @@ import (
)

// List the process tree starting at the given executorPID
func List(executorPID int) *set.Set[ProcessID] {
func List(executorPID int) set.Collection[ProcessID] {
result := set.New[ProcessID](10)

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
Expand Down
73 changes: 73 additions & 0 deletions drivers/shared/executor/procstats/list_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build windows

package procstats

import (
"github.com/hashicorp/go-set/v2"
"github.com/mitchellh/go-ps"
)

// gather reports whether candidate is root itself or a descendant of root,
// and records every descendant discovered along the way in family.
//
// procs maps a PID to its process so that a PPid can be resolved to the
// parent process. Starting at candidate, the chain of parents is followed
// until one of the following happens:
//
//   - the chain reaches root: every PID visited before root is inserted into
//     family and true is returned (root itself is not inserted here);
//   - the chain reaches PID 0 or 1, or a parent that is missing from procs
//     (or candidate is nil): false is returned and family is left untouched;
//   - more processes have been visited than can exist without repetition:
//     the parent links form a cycle that never reaches root, so false is
//     returned and family is left untouched.
//
// The walk is iterative and bounded rather than recursive. Parent PIDs are
// only a snapshot and PIDs can be reused, so the child->parent links are not
// guaranteed to form a tree; an unbounded recursive walk over a self-parented
// process or a cycle would never terminate and would overflow the stack.
func gather(procs map[int]ps.Process, family set.Collection[int], root int, candidate ps.Process) bool {
	// PIDs seen so far on the way up; they only join family once we know the
	// chain really ends at root.
	var chain []int

	// candidate itself need not be in procs, but every later hop is, so an
	// acyclic chain has at most len(procs)+1 links.
	for steps := 0; candidate != nil && steps <= len(procs); steps++ {
		pid := candidate.Pid()
		if pid == 0 || pid == 1 {
			// Treated as the top of the process table; never part of a family.
			return false
		}
		if pid == root {
			for _, descendant := range chain {
				family.Insert(descendant)
			}
			return true
		}
		chain = append(chain, pid)
		candidate = procs[candidate.PPid()]
	}

	// Either the parent is unknown (nil) or the step bound tripped on a cycle.
	return false
}

// mapping indexes the given processes by PID so that a parent process can be
// looked up from a child's PPid. If several entries report the same PID, the
// last one in all wins.
func mapping(all []ps.Process) map[int]ps.Process {
	// The final size is known up front; pre-size so the map does not have to
	// grow repeatedly while the whole process table is inserted.
	byPID := make(map[int]ps.Process, len(all))
	for _, proc := range all {
		byPID[proc.Pid()] = proc
	}
	return byPID
}

// list computes the process family rooted at executorPID using the process
// table returned by processes.
//
// processes is injected so the table can be supplied by something other than
// the real operating system lookup. If it returns an error, an empty set is
// returned. Otherwise the result always contains executorPID plus every
// process for which gather finds executorPID among its ancestors.
func list(executorPID int, processes func() ([]ps.Process, error)) set.Collection[ProcessID] {
	table, err := processes()
	if err != nil {
		return set.New[ProcessID](0)
	}

	// The executor is always a member of its own family.
	members := set.From([]int{executorPID})

	// Index once, then test each process in the table for membership.
	byPID := mapping(table)
	for i := range table {
		gather(byPID, members, executorPID, table[i])
	}

	return members
}

// List scans the process table and returns the set of PIDs in the process
// family tree that has executorPID as its root.
//
// The implementation here specifically avoids using more than one system
// call. Unlike on Linux where we just read a cgroup, on Windows we must build
// the tree manually, and we do so knowing only the child->parent
// relationships.
//
// In other words, the tree has to be inverted using only a bucket of edges
// that point in the wrong direction: every process in the table is visited,
// its chain of parents is followed, and the process is kept if executorPID
// turns out to be one of its ancestors.
//
// See https://github.com/hashicorp/nomad/issues/20042 as an example of what
// happens when you use syscalls to work your way from the root down to its
// descendants.
func List(executorPID int) set.Collection[ProcessID] {
return list(executorPID, ps.Processes)
}
103 changes: 103 additions & 0 deletions drivers/shared/executor/procstats/list_windows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build windows

package procstats

import (
"testing"

"github.com/mitchellh/go-ps"
"github.com/shoenig/test/must"
)

// mockProcess is a minimal stand-in for an entry of the process table. It
// carries only a PID and a parent PID.
type mockProcess struct {
	pid, ppid int
}

// Pid returns the process ID the mock was created with.
func (m *mockProcess) Pid() int { return m.pid }

// PPid returns the parent process ID the mock was created with.
func (m *mockProcess) PPid() int { return m.ppid }

// Executable always returns the empty string.
func (m *mockProcess) Executable() string { return "" }

// mockProc creates a mockProcess with the given pid whose parent is ppid.
func mockProc(pid, ppid int) *mockProcess {
	proc := new(mockProcess)
	proc.pid, proc.ppid = pid, ppid
	return proc
}

// Process table fixtures for Test_list, which uses PID 42 as the executor.
// Entries are deliberately listed out of parent/child order. PID 1 is its own
// parent in every fixture.
var (
// executorOnly contains PID 1 and PID 42 (child of 1); 42 has no children.
executorOnly = []ps.Process{
mockProc(1, 1),
mockProc(42, 1),
}

// simpleLine is a single chain 42 -> 50 -> 51 -> 60, plus the branch
// 100 -> 101 which hangs off PID 1 and has no link to 42.
simpleLine = []ps.Process{
mockProc(1, 1),
mockProc(50, 42),
mockProc(42, 1),
mockProc(51, 50),
mockProc(101, 100),
mockProc(60, 51),
mockProc(100, 1),
}

// bigTree branches below 42: 42 -> 50, then 50 -> {25, 75, 51},
// 25 -> 10, 75 -> {80, 81} and 51 -> 52. The branch 100 -> 101 hangs off
// PID 1 and has no link to 42.
bigTree = []ps.Process{
mockProc(1, 1),
mockProc(25, 50),
mockProc(100, 1),
mockProc(75, 50),
mockProc(10, 25),
mockProc(80, 75),
mockProc(81, 75),
mockProc(51, 50),
mockProc(42, 1),
mockProc(101, 100),
mockProc(52, 51),
mockProc(50, 42),
}
)

// Test_list runs list against each fixture process table, using PID 42 as the
// executor, and checks that the returned family holds the expected PIDs.
func Test_list(t *testing.T) {
	const executorPID = 42

	type testCase struct {
		name  string
		procs []ps.Process
		exp   []ProcessID
	}

	for _, tc := range []testCase{
		{name: "executor only", procs: executorOnly, exp: []ProcessID{42}},
		{name: "simple line", procs: simpleLine, exp: []ProcessID{42, 50, 51, 60}},
		{name: "big tree", procs: bigTree, exp: []ProcessID{42, 50, 25, 75, 10, 80, 81, 51, 52}},
	} {
		t.Run(tc.name, func(t *testing.T) {
			// Serve the fixture in place of the real process table lookup.
			got := list(executorPID, func() ([]ps.Process, error) {
				return tc.procs, nil
			})
			must.SliceContainsAll(t, tc.exp, got.Slice(),
				must.Sprintf("exp: %v; got: %v", tc.exp, got),
			)
		})
	}
}
2 changes: 1 addition & 1 deletion drivers/shared/executor/procstats/procstats.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type ProcessStats interface {
// A ProcessList is anything (i.e. a task driver) that implements ListProcesses
// for gathering the list of process IDs associated with a task.
type ProcessList interface {
ListProcesses() *set.Set[ProcessID]
ListProcesses() set.Collection[ProcessID]
}

// Aggregate combines a given ProcUsages with the Tracker for the Client.
Expand Down

0 comments on commit 5f89ecb

Please sign in to comment.