Skip to content

Commit

Permalink
Add support for watching processes with IOCP (#73)
Browse files Browse the repository at this point in the history
These changes add support for receiving events from JobObjects
(https://learn.microsoft.com/en-us/windows/win32/procthread/job-objects)
in the IOCP backend.

This feature is then used to implement an IOCP process watcher.

JobObjects are unusual in how they interact with the
`OVERLAPPED_ENTRY` result structure: they repurpose all of its
fields to mean different things. Because of this, I fill in the completion
result directly before `perform` is called. Let me know if this is an
issue; I could rework `perform` to accept the entry itself.

Remaining TODOs:
- [x] Figure out how to address the race condition that exists if the
process exits before `wait` is called - there will be no completion
events on the JobObject in this case
  • Loading branch information
mitchellh authored Oct 31, 2023
2 parents 5ecbc87 + 908155e commit 1b46c2d
Show file tree
Hide file tree
Showing 3 changed files with 406 additions and 14 deletions.
90 changes: 82 additions & 8 deletions src/backend/iocp.zig
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,30 @@ pub const Loop = struct {

// Go through the entries and perform completions callbacks.
for (entries[0..count]) |entry| {
// We retrieve the Completion from the OVERLAPPED pointer as we know it's a part of
// the Completion struct.
const overlapped_ptr: ?*windows.OVERLAPPED = @as(?*windows.OVERLAPPED, @ptrCast(entry.lpOverlapped));
if (overlapped_ptr == null) {
// Probably an async wakeup
continue;
}
var completion = @fieldParentPtr(Completion, "overlapped", overlapped_ptr.?);
const completion: *Completion = if (entry.lpCompletionKey == 0) completion: {
// We retrieve the Completion from the OVERLAPPED pointer as we know it's a part of
// the Completion struct.
const overlapped_ptr: ?*windows.OVERLAPPED = @as(?*windows.OVERLAPPED, @ptrCast(entry.lpOverlapped));
if (overlapped_ptr == null) {
// Probably an async wakeup
continue;
}

break :completion @fieldParentPtr(Completion, "overlapped", overlapped_ptr.?);
} else completion: {
// JobObjects are a special case where the OVERLAPPED_ENTRY fields are interpreted differently.
// When JOBOBJECT_ASSOCIATE_COMPLETION_PORT is used, lpOverlapped actually contains the message
// value, and not the address of the overlapped structure. The Completion pointer is passed
// as the completion key instead.
const completion: *Completion = @ptrFromInt(entry.lpCompletionKey);
completion.result = .{ .job_object = .{
.message = .{
.type = @enumFromInt(entry.dwNumberOfBytesTransferred),
.value = @intFromPtr(entry.lpOverlapped),
},
} };
break :completion completion;
};

wait_rem -|= 1;

Expand Down Expand Up @@ -699,6 +715,34 @@ pub const Loop = struct {
self.asyncs.push(completion);
break :action .{ .async_wait = {} };
},

.job_object => |*v| action: {
if (!v.associated) {
var port = windows.exp.JOBOBJECT_ASSOCIATE_COMPLETION_PORT{
.CompletionKey = @intFromPtr(completion),
.CompletionPort = self.iocp_handle,
};

windows.exp.SetInformationJobObject(
v.job,
.JobObjectAssociateCompletionPortInformation,
&port,
@sizeOf(windows.exp.JOBOBJECT_ASSOCIATE_COMPLETION_PORT),
) catch |err| break :action .{ .result = .{ .job_object = err } };

v.associated = true;
const action = completion.callback(completion.userdata, self, completion, .{ .job_object = .{ .associated = {} } });
switch (action) {
.disarm => {
completion.flags.state = .dead;
return;
},
.rearm => break :action .{ .submitted = {} },
}
}

break :action .{ .submitted = {} };
},
};

switch (action) {
Expand Down Expand Up @@ -1071,6 +1115,8 @@ pub const Completion = struct {
},

.async_wait => .{ .async_wait = {} },

.job_object => self.result.?,
};
}

Expand Down Expand Up @@ -1137,6 +1183,9 @@ pub const OperationType = enum {

/// Wait for an async event to be posted.
async_wait,

/// Receive a notification from a job object associated with a completion port
job_object,
};

/// All the supported operations of this event loop. These are always
Expand Down Expand Up @@ -1225,6 +1274,15 @@ pub const Operation = union(OperationType) {
async_wait: struct {
wakeup: std.atomic.Atomic(bool) = .{ .value = false },
},

/// Receive notifications from a job object by associating it with the
/// loop's completion port.
job_object: struct {
/// Handle of the job object that gets associated with the completion port.
job: windows.HANDLE,
/// Opaque pointer stored with the operation. The process watcher keeps its
/// process handle here; NOTE(review): the loop itself does not appear to
/// read this field in the code visible here — confirm.
userdata: ?*anyopaque,

/// Tracks if the job has been associated with the completion port.
/// Do not use this, it is used internally.
associated: bool = false,
},
};

/// The result type based on the operation type. For a callback, the
Expand All @@ -1246,6 +1304,7 @@ pub const Result = union(OperationType) {
timer: TimerError!TimerTrigger,
cancel: CancelError!void,
async_wait: AsyncError!void,
job_object: JobObjectError!JobObjectResult,
};

pub const CancelError = error{
Expand Down Expand Up @@ -1306,6 +1365,21 @@ pub const TimerTrigger = enum {
cancel,
};

/// Errors that a `job_object` operation can report through its result.
pub const JobObjectError = error{
Unexpected,
};

/// Successful result of a `job_object` operation, as delivered to the
/// completion callback.
pub const JobObjectResult = union(enum) {
/// The job object was associated with the completion port.
associated: void,

/// A message was received on the completion port for this job object.
message: struct {
/// Kind of job object notification.
type: windows.exp.JOB_OBJECT_MSG_TYPE,
/// Message-specific value; the process watcher compares it against a
/// process id for process exit messages.
value: usize,
},
};

/// ReadBuffer are the various options for reading.
pub const ReadBuffer = union(enum) {
/// Read into this slice.
Expand Down
188 changes: 182 additions & 6 deletions src/watcher/process.zig
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ pub fn Process(comptime xev: type) type {

.kqueue => ProcessKqueue(xev),

.iocp => ProcessIocp(xev),

// Unsupported
.wasi_poll => struct {},
.iocp => struct {},
};
}

Expand Down Expand Up @@ -130,7 +131,7 @@ fn ProcessPidFd(comptime xev: type) type {
}

/// Common tests
pub usingnamespace ProcessTests(xev, Self);
pub usingnamespace ProcessTests(xev, Self, &.{ "sh", "-c", "exit 0" }, &.{ "sh", "-c", "exit 42" });
};
}

Expand Down Expand Up @@ -200,17 +201,157 @@ fn ProcessKqueue(comptime xev: type) type {
}

/// Common tests
pub usingnamespace ProcessTests(xev, Self);
pub usingnamespace ProcessTests(xev, Self, &.{ "sh", "-c", "exit 0" }, &.{ "sh", "-c", "exit 42" });
};
}

const windows = @import("../windows.zig");
/// Process watcher for the IOCP backend.
///
/// The watched process is placed into a freshly created job object. Waiting
/// submits a `job_object` operation so that job notifications are delivered
/// through the loop, and the exit code is read from the process handle once
/// the process is known to have exited.
fn ProcessIocp(comptime xev: type) type {
    return struct {
        const Self = @This();

        pub const WaitError = xev.Sys.JobObjectError;

        job: windows.HANDLE,
        process: windows.HANDLE,

        /// Create a watcher for `process`.
        ///
        /// The handle is duplicated, so the caller may close its own handle
        /// after this returns. On failure every handle acquired here is
        /// closed again before the error is returned.
        pub fn init(process: os.pid_t) !Self {
            const current_process = windows.kernel32.GetCurrentProcess();

            // Duplicate the process handle so we don't rely on the caller keeping it alive
            var dup_process: windows.HANDLE = undefined;
            const dup_result = windows.kernel32.DuplicateHandle(
                current_process,
                process,
                current_process,
                &dup_process,
                0,
                windows.FALSE,
                windows.DUPLICATE_SAME_ACCESS,
            );
            if (dup_result == 0) return windows.unexpectedError(windows.kernel32.GetLastError());
            // Without this the duplicated handle would leak if any of the
            // job object calls below fail.
            errdefer _ = windows.kernel32.CloseHandle(dup_process);

            const job = try windows.exp.CreateJobObject(null, null);
            errdefer _ = windows.kernel32.CloseHandle(job);

            try windows.exp.AssignProcessToJobObject(job, dup_process);

            return .{
                .job = job,
                .process = dup_process,
            };
        }

        /// Close the job object and the duplicated process handle.
        pub fn deinit(self: *Self) void {
            _ = windows.kernel32.CloseHandle(self.job);
            _ = windows.kernel32.CloseHandle(self.process);
        }

        /// Read the exit code of `process`. A failed query is logged and
        /// reported as `WaitError.Unexpected` so that callers never observe
        /// an uninitialized exit code.
        fn queryExitCode(process: windows.HANDLE) WaitError!windows.DWORD {
            var exit_code: windows.DWORD = undefined;
            if (windows.kernel32.GetExitCodeProcess(process, &exit_code) == 0) {
                std.log.warn("unable to get exit code for process={}", .{windows.kernel32.GetLastError()});
                return WaitError.Unexpected;
            }
            return exit_code;
        }

        /// Wait for the process to exit.
        ///
        /// `cb` is invoked with the exit code of the process, or with a
        /// `WaitError` if the job object operation failed or the exit code
        /// could not be determined. Notifications that do not concern the
        /// watched process re-arm the completion without invoking `cb`.
        pub fn wait(
            self: Self,
            loop: *xev.Loop,
            c: *xev.Completion,
            comptime Userdata: type,
            userdata: ?*Userdata,
            comptime cb: *const fn (
                ud: ?*Userdata,
                l: *xev.Loop,
                c: *xev.Completion,
                r: WaitError!u32,
            ) xev.CallbackAction,
        ) void {
            c.* = .{
                .op = .{
                    .job_object = .{
                        .job = self.job,
                        .userdata = self.process,
                    },
                },
                .userdata = userdata,
                .callback = (struct {
                    fn callback(
                        ud: ?*anyopaque,
                        l_inner: *xev.Loop,
                        c_inner: *xev.Completion,
                        r: xev.Result,
                    ) xev.CallbackAction {
                        // Resolve the value handed to the user callback. Paths that
                        // are not about our process exiting return `.rearm` directly.
                        const outcome: WaitError!u32 = blk: {
                            const result = r.job_object catch |err| break :blk err;
                            switch (result) {
                                .associated => {
                                    // There was a period of time between when the job object was created
                                    // and when it was associated with the completion port. We may have
                                    // missed a notification, so check if it's still alive.
                                    const process: windows.HANDLE = @ptrCast(c_inner.op.job_object.userdata);
                                    const exit_code = queryExitCode(process) catch |err| break :blk err;
                                    if (exit_code == windows.exp.STILL_ACTIVE) return .rearm;
                                    break :blk exit_code;
                                },
                                .message => |message| switch (message.type) {
                                    .JOB_OBJECT_MSG_EXIT_PROCESS,
                                    .JOB_OBJECT_MSG_ABNORMAL_EXIT_PROCESS,
                                    => {
                                        const process: windows.HANDLE = @ptrCast(c_inner.op.job_object.userdata);
                                        const pid = windows.exp.kernel32.GetProcessId(process);
                                        if (pid == 0) break :blk WaitError.Unexpected;
                                        // The message value is compared against our process
                                        // id; exits of other processes are ignored.
                                        if (message.value != pid) return .rearm;
                                        break :blk queryExitCode(process);
                                    },
                                    else => return .rearm,
                                },
                            }
                        };

                        return @call(.always_inline, cb, .{
                            common.userdataValue(Userdata, ud),
                            l_inner,
                            c_inner,
                            outcome,
                        });
                    }
                }).callback,
            };
            loop.add(c);
        }

        /// Common tests
        pub usingnamespace ProcessTests(xev, Self, &.{ "cmd.exe", "/C", "exit 0" }, &.{ "cmd.exe", "/C", "exit 42" });
    };
}

fn ProcessTests(comptime xev: type, comptime Impl: type) type {
fn ProcessTests(
comptime xev: type,
comptime Impl: type,
comptime argv_0: []const []const u8,
comptime argv_42: []const []const u8,
) type {
return struct {
test "process wait" {
const testing = std.testing;
const alloc = testing.allocator;

var child = std.ChildProcess.init(&.{ "sh", "-c", "exit 0" }, alloc);
var child = std.ChildProcess.init(argv_0, alloc);
try child.spawn();

var loop = try xev.Loop.init(.{});
Expand Down Expand Up @@ -243,7 +384,7 @@ fn ProcessTests(comptime xev: type, comptime Impl: type) type {
const testing = std.testing;
const alloc = testing.allocator;

var child = std.ChildProcess.init(&.{ "sh", "-c", "exit 42" }, alloc);
var child = std.ChildProcess.init(argv_42, alloc);
try child.spawn();

var loop = try xev.Loop.init(.{});
Expand Down Expand Up @@ -271,5 +412,40 @@ fn ProcessTests(comptime xev: type, comptime Impl: type) type {
try loop.run(.until_done);
try testing.expectEqual(@as(u32, 42), code.?);
}

test "process wait on a process that already exited" {
    const testing = std.testing;

    var child = std.ChildProcess.init(argv_0, testing.allocator);
    try child.spawn();

    var loop = try xev.Loop.init(.{});
    defer loop.deinit();

    // The watcher is created while the child may still be running...
    var watcher = try Impl.init(child.id);
    defer watcher.deinit();

    // ...but the child is reaped before we start waiting on the watcher.
    _ = try child.wait();

    // Stores the reported exit code into the userdata slot and disarms.
    const Handler = struct {
        fn onExit(
            ud: ?*?u32,
            _: *xev.Loop,
            _: *xev.Completion,
            r: Impl.WaitError!u32,
        ) xev.CallbackAction {
            ud.?.* = r catch unreachable;
            return .disarm;
        }
    };

    var exit_code: ?u32 = null;
    var completion: xev.Completion = undefined;
    watcher.wait(&loop, &completion, ?u32, &exit_code, Handler.onExit);

    // Run the loop until the wait completes, then check the exit code.
    try loop.run(.until_done);
    try testing.expectEqual(@as(u32, 0), exit_code.?);
}
};
}
Loading

0 comments on commit 1b46c2d

Please sign in to comment.