Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for watching processes with IOCP #73

Merged
merged 3 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 82 additions & 8 deletions src/backend/iocp.zig
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,30 @@ pub const Loop = struct {

// Go through the entries and perform completions callbacks.
for (entries[0..count]) |entry| {
// We retrieve the Completion from the OVERLAPPED pointer as we know it's a part of
// the Completion struct.
const overlapped_ptr: ?*windows.OVERLAPPED = @as(?*windows.OVERLAPPED, @ptrCast(entry.lpOverlapped));
if (overlapped_ptr == null) {
// Probably an async wakeup
continue;
}
var completion = @fieldParentPtr(Completion, "overlapped", overlapped_ptr.?);
const completion: *Completion = if (entry.lpCompletionKey == 0) completion: {
// We retrieve the Completion from the OVERLAPPED pointer as we know it's a part of
// the Completion struct.
const overlapped_ptr: ?*windows.OVERLAPPED = @as(?*windows.OVERLAPPED, @ptrCast(entry.lpOverlapped));
if (overlapped_ptr == null) {
// Probably an async wakeup
continue;
}

break :completion @fieldParentPtr(Completion, "overlapped", overlapped_ptr.?);
} else completion: {
// JobObjects are a special case where the OVERLAPPED_ENTRY fields are interpreted differently.
// When JOBOBJECT_ASSOCIATE_COMPLETION_PORT is used, lpOverlapped actually contains the message
// value, and not the address of the overlapped structure. The Completion pointer is passed
// as the completion key instead.
const completion: *Completion = @ptrFromInt(entry.lpCompletionKey);
completion.result = .{ .job_object = .{
.message = .{
.type = @enumFromInt(entry.dwNumberOfBytesTransferred),
.value = @intFromPtr(entry.lpOverlapped),
},
} };
break :completion completion;
};

wait_rem -|= 1;

Expand Down Expand Up @@ -699,6 +715,34 @@ pub const Loop = struct {
self.asyncs.push(completion);
break :action .{ .async_wait = {} };
},

.job_object => |*v| action: {
if (!v.associated) {
var port = windows.exp.JOBOBJECT_ASSOCIATE_COMPLETION_PORT{
.CompletionKey = @intFromPtr(completion),
.CompletionPort = self.iocp_handle,
};

windows.exp.SetInformationJobObject(
v.job,
.JobObjectAssociateCompletionPortInformation,
&port,
@sizeOf(windows.exp.JOBOBJECT_ASSOCIATE_COMPLETION_PORT),
) catch |err| break :action .{ .result = .{ .job_object = err } };

v.associated = true;
const action = completion.callback(completion.userdata, self, completion, .{ .job_object = .{ .associated = {} } });
switch (action) {
.disarm => {
completion.flags.state = .dead;
return;
},
.rearm => break :action .{ .submitted = {} },
}
}

break :action .{ .submitted = {} };
},
};

switch (action) {
Expand Down Expand Up @@ -1071,6 +1115,8 @@ pub const Completion = struct {
},

.async_wait => .{ .async_wait = {} },

.job_object => self.result.?,
};
}

Expand Down Expand Up @@ -1137,6 +1183,9 @@ pub const OperationType = enum {

/// Wait for an async event to be posted.
async_wait,

/// Receive a notification from a job object associated with a completion port
job_object,
};

/// All the supported operations of this event loop. These are always
Expand Down Expand Up @@ -1225,6 +1274,15 @@ pub const Operation = union(OperationType) {
async_wait: struct {
wakeup: std.atomic.Atomic(bool) = .{ .value = false },
},

job_object: struct {
job: windows.HANDLE,
userdata: ?*anyopaque,

/// Tracks if the job has been associated with the completion port.
/// Do not use this, it is used internally.
associated: bool = false,
},
};

/// The result type based on the operation type. For a callback, the
Expand All @@ -1246,6 +1304,7 @@ pub const Result = union(OperationType) {
timer: TimerError!TimerTrigger,
cancel: CancelError!void,
async_wait: AsyncError!void,
job_object: JobObjectError!JobObjectResult,
};

pub const CancelError = error{
Expand Down Expand Up @@ -1306,6 +1365,21 @@ pub const TimerTrigger = enum {
cancel,
};

pub const JobObjectError = error{
Unexpected,
};

pub const JobObjectResult = union(enum) {
// The job object was associated with the completion port
associated: void,

/// A message was recived on the completion port for this job object
message: struct {
type: windows.exp.JOB_OBJECT_MSG_TYPE,
value: usize,
},
};

/// ReadBuffer are the various options for reading.
pub const ReadBuffer = union(enum) {
/// Read into this slice.
Expand Down
188 changes: 182 additions & 6 deletions src/watcher/process.zig
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ pub fn Process(comptime xev: type) type {

.kqueue => ProcessKqueue(xev),

.iocp => ProcessIocp(xev),

// Unsupported
.wasi_poll => struct {},
.iocp => struct {},
};
}

Expand Down Expand Up @@ -130,7 +131,7 @@ fn ProcessPidFd(comptime xev: type) type {
}

/// Common tests
pub usingnamespace ProcessTests(xev, Self);
pub usingnamespace ProcessTests(xev, Self, &.{ "sh", "-c", "exit 0" }, &.{ "sh", "-c", "exit 42" });
};
}

Expand Down Expand Up @@ -200,17 +201,157 @@ fn ProcessKqueue(comptime xev: type) type {
}

/// Common tests
pub usingnamespace ProcessTests(xev, Self);
pub usingnamespace ProcessTests(xev, Self, &.{ "sh", "-c", "exit 0" }, &.{ "sh", "-c", "exit 42" });
};
}

const windows = @import("../windows.zig");
fn ProcessIocp(comptime xev: type) type {
return struct {
const Self = @This();

pub const WaitError = xev.Sys.JobObjectError;

job: windows.HANDLE,
process: windows.HANDLE,

pub fn init(process: os.pid_t) !Self {
const current_process = windows.kernel32.GetCurrentProcess();

// Duplicate the process handle so we don't rely on the caller keeping it alive
var dup_process: windows.HANDLE = undefined;
const dup_result = windows.kernel32.DuplicateHandle(
current_process,
process,
current_process,
&dup_process,
0,
windows.FALSE,
windows.DUPLICATE_SAME_ACCESS,
);
if (dup_result == 0) return windows.unexpectedError(windows.kernel32.GetLastError());

const job = try windows.exp.CreateJobObject(null, null);
errdefer _ = windows.kernel32.CloseHandle(job);

try windows.exp.AssignProcessToJobObject(job, dup_process);

return .{
.job = job,
.process = dup_process,
};
}

pub fn deinit(self: *Self) void {
_ = windows.kernel32.CloseHandle(self.job);
_ = windows.kernel32.CloseHandle(self.process);
}

pub fn wait(
self: Self,
loop: *xev.Loop,
c: *xev.Completion,
comptime Userdata: type,
userdata: ?*Userdata,
comptime cb: *const fn (
ud: ?*Userdata,
l: *xev.Loop,
c: *xev.Completion,
r: WaitError!u32,
) xev.CallbackAction,
) void {
c.* = .{
.op = .{
.job_object = .{
.job = self.job,
.userdata = self.process,
},
},
.userdata = userdata,
.callback = (struct {
fn callback(
ud: ?*anyopaque,
l_inner: *xev.Loop,
c_inner: *xev.Completion,
r: xev.Result,
) xev.CallbackAction {
if (r.job_object) |result| {
switch (result) {
.associated => {
// There was a period of time between when the job object was created
// and when it was associated with the completion port. We may have
// missed a notification, so check if it's still alive.

var exit_code: windows.DWORD = undefined;
const process: windows.HANDLE = @ptrCast(c_inner.op.job_object.userdata);
const has_code = windows.kernel32.GetExitCodeProcess(process, &exit_code) != 0;
if (!has_code) std.log.warn("unable to get exit code for process={}", .{windows.kernel32.GetLastError()});
if (exit_code == windows.exp.STILL_ACTIVE) return .rearm;

return @call(.always_inline, cb, .{
common.userdataValue(Userdata, ud),
l_inner,
c_inner,
exit_code,
});
},
.message => |message| {
const result_inner = switch (message.type) {
.JOB_OBJECT_MSG_EXIT_PROCESS,
.JOB_OBJECT_MSG_ABNORMAL_EXIT_PROCESS,
=> b: {
const process: windows.HANDLE = @ptrCast(c_inner.op.job_object.userdata);
const pid = windows.exp.kernel32.GetProcessId(process);
if (pid == 0) break :b WaitError.Unexpected;
if (message.value != pid) return .rearm;

var exit_code: windows.DWORD = undefined;
const has_code = windows.kernel32.GetExitCodeProcess(process, &exit_code) != 0;
if (!has_code) std.log.warn("unable to get exit code for process={}", .{windows.kernel32.GetLastError()});
break :b if (has_code) exit_code else WaitError.Unexpected;
},
else => return .rearm,
};

return @call(.always_inline, cb, .{
common.userdataValue(Userdata, ud),
l_inner,
c_inner,
result_inner
});
},
}
} else |err| {
return @call(.always_inline, cb, .{
common.userdataValue(Userdata, ud),
l_inner,
c_inner,
err,
});
}
}
}).callback,
};
loop.add(c);
}

/// Common tests
pub usingnamespace ProcessTests(xev, Self, &.{ "cmd.exe", "/C", "exit 0" }, &.{ "cmd.exe", "/C", "exit 42" });
};
}

fn ProcessTests(comptime xev: type, comptime Impl: type) type {
fn ProcessTests(
comptime xev: type,
comptime Impl: type,
comptime argv_0: []const []const u8,
comptime argv_42: []const []const u8,
) type {
return struct {
test "process wait" {
const testing = std.testing;
const alloc = testing.allocator;

var child = std.ChildProcess.init(&.{ "sh", "-c", "exit 0" }, alloc);
var child = std.ChildProcess.init(argv_0, alloc);
try child.spawn();

var loop = try xev.Loop.init(.{});
Expand Down Expand Up @@ -243,7 +384,7 @@ fn ProcessTests(comptime xev: type, comptime Impl: type) type {
const testing = std.testing;
const alloc = testing.allocator;

var child = std.ChildProcess.init(&.{ "sh", "-c", "exit 42" }, alloc);
var child = std.ChildProcess.init(argv_42, alloc);
try child.spawn();

var loop = try xev.Loop.init(.{});
Expand Down Expand Up @@ -271,5 +412,40 @@ fn ProcessTests(comptime xev: type, comptime Impl: type) type {
try loop.run(.until_done);
try testing.expectEqual(@as(u32, 42), code.?);
}

test "process wait on a process that already exited" {
const testing = std.testing;
const alloc = testing.allocator;

var child = std.ChildProcess.init(argv_0, alloc);
try child.spawn();

var loop = try xev.Loop.init(.{});
defer loop.deinit();

var p = try Impl.init(child.id);
defer p.deinit();

_ = try child.wait();

// Wait
var code: ?u32 = null;
var c_wait: xev.Completion = undefined;
p.wait(&loop, &c_wait, ?u32, &code, (struct {
fn callback(
ud: ?*?u32,
_: *xev.Loop,
_: *xev.Completion,
r: Impl.WaitError!u32,
) xev.CallbackAction {
ud.?.* = r catch unreachable;
return .disarm;
}
}).callback);

// Wait for wake
try loop.run(.until_done);
try testing.expectEqual(@as(u32, 0), code.?);
}
};
}
Loading