Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

windows: fix inefficient gathering of task processes #20619

Merged
merged 2 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changelog/20619.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
windows: Fixed a regression where scanning task processes was inefficient
```
2 changes: 1 addition & 1 deletion drivers/shared/executor/executor_basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func withNetworkIsolation(f func() error, _ *drivers.NetworkIsolationSpec) error

func setCmdUser(*exec.Cmd, string) error { return nil }

func (e *UniversalExecutor) ListProcesses() *set.Set[int] {
func (e *UniversalExecutor) ListProcesses() set.Collection[int] {
return procstats.List(e.childCmd.Process.Pid)
}

Expand Down
2 changes: 1 addition & 1 deletion drivers/shared/executor/executor_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func NewExecutorWithIsolation(logger hclog.Logger, compute cpustats.Compute) Exe
return le
}

func (l *LibcontainerExecutor) ListProcesses() *set.Set[int] {
func (l *LibcontainerExecutor) ListProcesses() set.Collection[int] {
return procstats.List(l.command)
}

Expand Down
2 changes: 1 addition & 1 deletion drivers/shared/executor/executor_universal_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (e *UniversalExecutor) setSubCmdCgroup(cmd *exec.Cmd, cgroup string) (func(
}
}

func (e *UniversalExecutor) ListProcesses() *set.Set[procstats.ProcessID] {
func (e *UniversalExecutor) ListProcesses() set.Collection[procstats.ProcessID] {
return procstats.List(e.command)
}

Expand Down
4 changes: 2 additions & 2 deletions drivers/shared/executor/procstats/list_default.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build !linux
//go:build !linux && !windows

package procstats

Expand All @@ -15,7 +15,7 @@ import (
)

// List the process tree starting at the given executorPID
func List(executorPID int) *set.Set[ProcessID] {
func List(executorPID int) set.Collection[ProcessID] {
result := set.New[ProcessID](10)

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
Expand Down
73 changes: 73 additions & 0 deletions drivers/shared/executor/procstats/list_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build windows

package procstats

import (
"github.com/hashicorp/go-set/v2"
"github.com/mitchellh/go-ps"
)

func gather(procs map[int]ps.Process, family set.Collection[int], root int, candidate ps.Process) bool {
if candidate == nil {
return false
}
pid := candidate.Pid()
if pid == 0 || pid == 1 {
return false
}
if pid == root {
return true
}
parent := procs[candidate.PPid()]
result := gather(procs, family, root, parent)
if result {
family.Insert(pid)
}
return result
}

func mapping(all []ps.Process) map[int]ps.Process {
result := make(map[int]ps.Process)
for _, process := range all {
result[process.Pid()] = process
}
return result
}

func list(executorPID int, processes func() ([]ps.Process, error)) set.Collection[ProcessID] {
family := set.From([]int{executorPID})

all, err := processes()
if err != nil {
return set.New[ProcessID](0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we at least return the executor PID here? Or is the thinking that if this call fails we can't trust any of the results?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh good idea, let me add that

}

m := mapping(all)
for _, candidate := range all {
gather(m, family, executorPID, candidate)
}

return family
}

// List will scan the process table and return a set of the process family
// tree starting with executorPID as the root.
//
// The implementation here specifically avoids using more than one system
// call. Unlike on Linux where we just read a cgroup, on Windows we must build
// the tree manually. We do so knowing only the child->parent relationships.
//
// So this turns into a fun leet code problem, where we invert the tree using
// only a bucket of edges pointing in the wrong direction. Basically we just
// iterate every process, recursively follow its parent, and determine whether
// executorPID is an ancestor.
//
// See https://github.com/hashicorp/nomad/issues/20042 as an example of what
// happens when you use syscalls to work your way from the root down to its
// descendants.
func List(executorPID int) set.Collection[ProcessID] {
return list(executorPID, ps.Processes)
}
103 changes: 103 additions & 0 deletions drivers/shared/executor/procstats/list_windows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build windows

package procstats

import (
"testing"

"github.com/mitchellh/go-ps"
"github.com/shoenig/test/must"
)

type mockProcess struct {
pid int
ppid int
}

func (p *mockProcess) Pid() int {
return p.pid
}

func (p *mockProcess) PPid() int {
return p.ppid
}

func (p *mockProcess) Executable() string {
return ""
}

func mockProc(pid, ppid int) *mockProcess {
return &mockProcess{pid: pid, ppid: ppid}
}

var (
executorOnly = []ps.Process{
mockProc(1, 1),
mockProc(42, 1),
}

simpleLine = []ps.Process{
mockProc(1, 1),
mockProc(50, 42),
mockProc(42, 1),
mockProc(51, 50),
mockProc(101, 100),
mockProc(60, 51),
mockProc(100, 1),
}

bigTree = []ps.Process{
mockProc(1, 1),
mockProc(25, 50),
mockProc(100, 1),
mockProc(75, 50),
mockProc(10, 25),
mockProc(80, 75),
mockProc(81, 75),
mockProc(51, 50),
mockProc(42, 1),
mockProc(101, 100),
mockProc(52, 51),
mockProc(50, 42),
}
)

func Test_list(t *testing.T) {
cases := []struct {
name string
procs []ps.Process
exp []ProcessID
}{
{
name: "executor only",
procs: executorOnly,
exp: []ProcessID{42},
},
{
name: "simple line",
procs: simpleLine,
exp: []ProcessID{42, 50, 51, 60},
},
{
name: "big tree",
procs: bigTree,
exp: []ProcessID{42, 50, 25, 75, 10, 80, 81, 51, 52},
},
}

for _, tc := range cases {
const executorPID = 42
t.Run(tc.name, func(t *testing.T) {
lister := func() ([]ps.Process, error) {
return tc.procs, nil
}
result := list(executorPID, lister)
must.SliceContainsAll(t, tc.exp, result.Slice(),
must.Sprintf("exp: %v; got: %v", tc.exp, result),
)
})
}
}
2 changes: 1 addition & 1 deletion drivers/shared/executor/procstats/procstats.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type ProcessStats interface {
// A ProcessList is anything (i.e. a task driver) that implements ListProcesses
// for gathering the list of process IDs associated with a task.
type ProcessList interface {
ListProcesses() *set.Set[ProcessID]
ListProcesses() set.Collection[ProcessID]
}

// Aggregate combines a given ProcUsages with the Tracker for the Client.
Expand Down
Loading