diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..6313b56c --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +* text=auto eol=lf diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..f27e6822 --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +# This file is for zig-specific build artifacts. +# If you have OS-specific or editor-specific files to ignore, +# such as *.swp or .DS_Store, put those in your global +# ~/.gitignore and put this in your ~/.gitconfig: +# +# [core] +# excludesfile = ~/.gitignore +# +# Cheers! +# -andrewrk + +.zig-cache/ +zig-cache/ +zig-out/ +/release/ +/debug/ +/build/ +/build-*/ +/docgen_tmp/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..e69de29b diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..f0962b15 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +The MIT License (Expat) + +Copyright (c) Vetoniemi Jari Juhani + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
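The build script that follows exposes `aio` and `coro` as named modules. As a rough sketch of how a downstream project might consume them (assuming the package is declared under a placeholder key `zig_aio` in that project's build.zig.zon; the key and paths here are illustrative, not part of this tree):

const std = @import("std");

pub fn build(b: *std.Build) void {
    const target = b.standardTargetOptions(.{});
    const optimize = b.standardOptimizeOption(.{});

    const exe = b.addExecutable(.{
        .name = "app",
        .root_source_file = b.path("src/main.zig"),
        .target = target,
        .optimize = optimize,
    });

    // Resolve the dependency declared in build.zig.zon and import its modules.
    // "zig_aio" is a placeholder dependency key.
    const zig_aio = b.dependency("zig_aio", .{ .target = target, .optimize = optimize });
    exe.root_module.addImport("aio", zig_aio.module("aio"));
    exe.root_module.addImport("coro", zig_aio.module("coro"));

    b.installArtifact(exe);
}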
diff --git a/build.zig b/build.zig new file mode 100644 index 00000000..6dd313c7 --- /dev/null +++ b/build.zig @@ -0,0 +1,53 @@ +const std = @import("std"); + +pub fn build(b: *std.Build) void { + const target = b.standardTargetOptions(.{}); + const optimize = b.standardOptimizeOption(.{}); + + const aio = b.addModule("aio", .{ + .root_source_file = b.path("src/aio.zig"), + .target = target, + .optimize = optimize, + }); + + const coro = b.addModule("coro", .{ + .root_source_file = b.path("src/coro.zig"), + .target = target, + .optimize = optimize, + }); + coro.addImport("aio", aio); + + const run_all = b.step("run", "Run all examples"); + inline for (.{ + .aio_dynamic, + .aio_static, + .coro, + }) |example| { + const exe = b.addExecutable(.{ + .name = @tagName(example), + .root_source_file = b.path("examples/" ++ @tagName(example) ++ ".zig"), + .target = target, + .optimize = optimize, + }); + exe.root_module.addImport("aio", aio); + exe.root_module.addImport("coro", coro); + const cmd = b.addRunArtifact(exe); + cmd.step.dependOn(b.getInstallStep()); + if (b.args) |args| cmd.addArgs(args); + const run = b.step(@tagName(example), "Run " ++ @tagName(example) ++ " example"); + run.dependOn(&cmd.step); + run_all.dependOn(&cmd.step); + } + + const test_step = b.step("test", "Run unit tests"); + inline for (.{ .aio, .coro }) |mod| { + const tst = b.addTest(.{ + .root_source_file = b.path("src/" ++ @tagName(mod) ++ ".zig"), + .target = target, + .optimize = optimize, + }); + if (mod == .coro) tst.root_module.addImport("aio", aio); + const run = b.addRunArtifact(tst); + test_step.dependOn(&run.step); + } +} diff --git a/build.zig.zon b/build.zig.zon new file mode 100644 index 00000000..fdf1922e --- /dev/null +++ b/build.zig.zon @@ -0,0 +1,13 @@ +.{ + .name = "zig-aio", + .version = "0.0.0", + .dependencies = .{}, + .paths = .{ + "build.zig", + "build.zig.zon", + "src", + "examples", + "LICENSE", + "README.md", + }, +} diff --git a/examples/aio_dynamic.zig b/examples/aio_dynamic.zig new file mode 100644 index 00000000..5973a2e7 --- /dev/null +++ b/examples/aio_dynamic.zig @@ -0,0 +1,38 @@ +const std = @import("std"); +const aio = @import("aio"); +const log = std.log.scoped(.aio_dynamic); + +pub fn main() !void { + var f = try std.fs.cwd().openFile("flake.nix", .{}); + defer f.close(); + var buf: [4096]u8 = undefined; + var len: usize = 0; + + var f2 = try std.fs.cwd().openFile("build.zig.zon", .{}); + defer f2.close(); + var buf2: [4096]u8 = undefined; + var len2: usize = 0; + + const allocator = std.heap.page_allocator; + var work = try aio.Dynamic.init(allocator, 16); + defer work.deinit(allocator); + + try work.queue(.{ + aio.Read{ + .file = f, + .buffer = &buf, + .out_read = &len, + }, + aio.Read{ + .file = f2, + .buffer = &buf2, + .out_read = &len2, + }, + }); + + const ret = try work.complete(.blocking); + + log.info("{s}", .{buf[0..len]}); + log.info("{s}", .{buf2[0..len2]}); + log.info("{}", .{ret}); +} diff --git a/examples/aio_static.zig b/examples/aio_static.zig new file mode 100644 index 00000000..17b9b1bd --- /dev/null +++ b/examples/aio_static.zig @@ -0,0 +1,32 @@ +const std = @import("std"); +const aio = @import("aio"); +const log = std.log.scoped(.aio_static); + +pub fn main() !void { + var f = try std.fs.cwd().openFile("flake.nix", .{}); + defer f.close(); + var buf: [4096]u8 = undefined; + var len: usize = 0; + + var f2 = try std.fs.cwd().openFile("build.zig.zon", .{}); + defer f2.close(); + var buf2: [4096]u8 = undefined; + var len2: usize = 0; + + const ret = try aio.batch(.{ + 
aio.Read{ + .file = f, + .buffer = &buf, + .out_read = &len, + }, + aio.Read{ + .file = f2, + .buffer = &buf2, + .out_read = &len2, + }, + }); + + log.info("{s}", .{buf[0..len]}); + log.info("{s}", .{buf2[0..len2]}); + log.info("{}", .{ret}); +} diff --git a/examples/coro.zig b/examples/coro.zig new file mode 100644 index 00000000..e9577323 --- /dev/null +++ b/examples/coro.zig @@ -0,0 +1,89 @@ +const std = @import("std"); +const aio = @import("aio"); +const coro = @import("coro"); +const log = std.log.scoped(.coro_aio); + +pub const aio_coro_options: coro.Options = .{ + .debug = false, // set to true to enable debug logs +}; + +fn server() !void { + var socket: std.posix.socket_t = undefined; + try coro.io.single(aio.Socket{ + .domain = std.posix.AF.INET, + .flags = std.posix.SOCK.STREAM | std.posix.SOCK.CLOEXEC, + .protocol = std.posix.IPPROTO.TCP, + .out_socket = &socket, + }); + + const address = std.net.Address.initIp4(.{ 0, 0, 0, 0 }, 1327); + try std.posix.setsockopt(socket, std.posix.SOL.SOCKET, std.posix.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); + try std.posix.setsockopt(socket, std.posix.SOL.SOCKET, std.posix.SO.REUSEPORT, &std.mem.toBytes(@as(c_int, 1))); + try std.posix.bind(socket, &address.any, address.getOsSockLen()); + try std.posix.listen(socket, 128); + + var client_sock: std.posix.socket_t = undefined; + try coro.io.single(aio.Accept{ .socket = socket, .out_socket = &client_sock }); + + var buf: [1024]u8 = undefined; + var len: usize = 0; + try coro.io.multi(.{ + aio.Send{ .socket = client_sock, .buffer = "hey ", .link_next = true }, + aio.Send{ .socket = client_sock, .buffer = "I'm doing multiple IO ops at once ", .link_next = true }, + aio.Send{ .socket = client_sock, .buffer = "how cool is that? ", .link_next = true }, + aio.Recv{ .socket = client_sock, .buffer = &buf, .out_read = &len }, + }); + + log.warn("got reply from client: {s}", .{buf[0..len]}); + try coro.io.multi(.{ + aio.Send{ .socket = client_sock, .buffer = "ok bye", .link_next = true }, + aio.CloseSocket{ .socket = client_sock, .link_next = true }, + aio.CloseSocket{ .socket = socket }, + }); +} + +fn client() !void { + log.info("waiting 2 secs, to give time for the server to spin up", .{}); + try coro.io.single(aio.Timeout{ .ts = .{ .sec = 2, .nsec = 0 } }); + + var socket: std.posix.socket_t = undefined; + try coro.io.single(aio.Socket{ + .domain = std.posix.AF.INET, + .flags = std.posix.SOCK.STREAM | std.posix.SOCK.CLOEXEC, + .protocol = std.posix.IPPROTO.TCP, + .out_socket = &socket, + }); + + const address = std.net.Address.initIp4(.{ 127, 0, 0, 1 }, 1327); + try coro.io.single(aio.Connect{ + .socket = socket, + .addr = &address.any, + .addrlen = address.getOsSockLen(), + .link_next = true, + }); + + while (true) { + var buf: [1024]u8 = undefined; + var len: usize = 0; + try coro.io.single(aio.Recv{ .socket = socket, .buffer = &buf, .out_read = &len }); + log.warn("got reply from server: {s}", .{buf[0..len]}); + if (std.mem.indexOf(u8, buf[0..len], "how cool is that?")) |_| break; + } + + try coro.io.single(aio.Send{ .socket = socket, .buffer = "dude, I don't care" }); + + var buf: [1024]u8 = undefined; + var len: usize = 0; + try coro.io.single(aio.Recv{ .socket = socket, .buffer = &buf, .out_read = &len }); + log.warn("got final words from server: {s}", .{buf[0..len]}); +} + +pub fn main() !void { + var gpa: std.heap.GeneralPurposeAllocator(.{}) = .{}; + defer _ = gpa.deinit(); + var scheduler = try coro.Scheduler.init(gpa.allocator(), .{}); + defer scheduler.deinit(); + _ = try 
scheduler.spawn(server, .{}, .{}); + _ = try scheduler.spawn(client, .{}, .{}); + try scheduler.run(); +} diff --git a/flake.lock b/flake.lock new file mode 100644 index 00000000..57ae407a --- /dev/null +++ b/flake.lock @@ -0,0 +1,78 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1717822988, + "narHash": "sha256-HO0j+jyYv5BKsF+QlxLW223Aj4GkRp11jG4XLvAO/zI=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "9e4e5d96f16f7539ae8020b4b1ea71ddba9b6f3d", + "type": "github" + }, + "original": { + "owner": "nixos", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "zig2nix": "zig2nix" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "zig2nix": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1718327908, + "narHash": "sha256-E3/3urkZTEIaV2afNDniC9Z7Ksxyn6AaVRnrrxxMPZE=", + "owner": "Cloudef", + "repo": "zig2nix", + "rev": "d846577330acb4571e837d5c631beeb555988dd0", + "type": "github" + }, + "original": { + "owner": "Cloudef", + "repo": "zig2nix", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 00000000..0855d197 --- /dev/null +++ b/flake.nix @@ -0,0 +1,44 @@ +{ + description = "Zig project flake"; + + inputs = { + zig2nix.url = "github:Cloudef/zig2nix"; + }; + + outputs = { zig2nix, ... }: let + flake-utils = zig2nix.inputs.flake-utils; + in (flake-utils.lib.eachDefaultSystem (system: let + # Zig flake helper + # Check the flake.nix in zig2nix project for more options: + # + env = zig2nix.outputs.zig-env.${system} { zig = zig2nix.outputs.packages.${system}.zig.master.bin; }; + system-triple = env.lib.zigTripleFromString system; + in with builtins; with env.lib; with env.pkgs.lib; rec { + # nix run . + apps.default = env.app [] "zig build run -- \"$@\""; + + # nix run .#build + apps.build = env.app [] "zig build \"$@\""; + + # nix run .#test + apps.test = env.app [] "zig build test -- \"$@\""; + + # nix run .#docs + apps.docs = env.app [] "zig build docs -- \"$@\""; + + # nix run .#deps + apps.deps = env.showExternalDeps; + + # nix run .#zon2json + apps.zon2json = env.app [env.zon2json] "zon2json \"$@\""; + + # nix run .#zon2json-lock + apps.zon2json-lock = env.app [env.zon2json-lock] "zon2json-lock \"$@\""; + + # nix run .#zon2nix + apps.zon2nix = env.app [env.zon2nix] "zon2nix \"$@\""; + + # nix develop + devShells.default = env.mkShell {}; + })); +} diff --git a/src/aio.zig b/src/aio.zig new file mode 100644 index 00000000..b05090b8 --- /dev/null +++ b/src/aio.zig @@ -0,0 +1,124 @@ +//! Basic io-uring -like asynchronous IO API +//! It is possible to both dynamically and statically queue IO work to be executed in a asynchronous fashion +//! 
On linux this is a very shim wrapper around `io_uring`, on other systems there might be more overhead + +const std = @import("std"); + +pub const InitError = error{ + Overflow, + OutOfMemory, + PermissionDenied, + ProcessQuotaExceeded, + SystemQuotaExceeded, + SystemResources, + SystemOutdated, + Unexpected, +}; + +pub const QueueError = error{ + Overflow, + SubmissionQueueFull, +}; + +pub const CompletionError = error{ + CompletionQueueOvercommitted, + SubmissionQueueEntryInvalid, + SystemResources, + Unexpected, +}; + +pub const OperationError = ops.ErrorUnion; + +pub const ImmediateError = InitError || QueueError || CompletionError; + +pub const CompletionResult = struct { + num_completed: u16 = 0, + num_errors: u16 = 0, +}; + +/// Queue operations dynamically and complete them on demand +pub const Dynamic = struct { + io: IO, + + pub inline fn init(allocator: std.mem.Allocator, n: u16) InitError!@This() { + return .{ .io = try IO.init(allocator, n) }; + } + + pub inline fn deinit(self: *@This(), allocator: std.mem.Allocator) void { + self.io.deinit(allocator); + self.* = undefined; + } + + /// Queue operations for future completion + /// The call is atomic, if any of the operations fail to queue, then the given operations are reverted + pub inline fn queue(self: *@This(), operations: anytype) QueueError!void { + const ti = @typeInfo(@TypeOf(operations)); + if (comptime ti == .Struct and ti.Struct.is_tuple) { + return self.io.queue(operations.len, &struct { ops: @TypeOf(operations) }{ .ops = operations }); + } else { + return self.io.queue(1, &struct { ops: @TypeOf(.{operations}) }{ .ops = .{operations} }); + } + } + + pub const CompletionMode = enum { + /// Call to `complete` will block until at least one operation completes + blocking, + /// Call to `complete` will only complete the currently ready operations if any + nonblocking, + }; + + /// Complete operations + /// Returns the number of completed operations, `0` if no operations were completed + pub inline fn complete(self: *@This(), mode: CompletionMode) CompletionError!CompletionResult { + return self.io.complete(mode); + } +}; + +/// Completes a list of operations immediately, blocks until complete +/// For error handling you must check the `out_error` field in the operation +pub inline fn batch(operations: anytype) ImmediateError!CompletionResult { + return IO.immediate(operations.len, &struct { ops: @TypeOf(operations) }{ .ops = operations }); +} + +/// Completes a list of operations immediately, blocks until complete +/// Returns `error.SomeOperationFailed` if any operation failed +pub inline fn multi(operations: anytype) (ImmediateError || error{SomeOperationFailed})!void { + const res = try batch(operations); + if (res.num_errors > 0) return error.SomeOperationFailed; +} + +/// Completes a single operation immediately, blocks until complete +pub inline fn single(operation: anytype) (ImmediateError || OperationError)!void { + var op: @TypeOf(operation) = operation; + var err: @TypeOf(op.out_error.?.*) = error.Success; + op.out_error = &err; + _ = try batch(.{op}); + if (err != error.Success) return err; +} + +const ops = @import("aio/ops.zig"); +pub const Id = ops.Id; +pub const Sync = ops.Sync; +pub const Read = ops.Read; +pub const Write = ops.Write; +pub const Accept = ops.Accept; +pub const Connect = ops.Connect; +pub const Recv = ops.Recv; +pub const Send = ops.Send; +pub const OpenAt = ops.OpenAt; +pub const Close = ops.Close; +pub const Timeout = ops.Timeout; +pub const TimeoutRemove = ops.TimeoutRemove; +pub const 
LinkTimeout = ops.LinkTimeout; +pub const Cancel = ops.Cancel; +pub const RenameAt = ops.RenameAt; +pub const UnlinkAt = ops.UnlinkAt; +pub const MkDirAt = ops.MkDirAt; +pub const SymlinkAt = ops.SymlinkAt; +pub const Socket = ops.Socket; +pub const CloseSocket = ops.CloseSocket; + +const IO = switch (@import("builtin").target.os.tag) { + .linux => @import("aio/linux.zig"), + else => @compileError("unsupported os"), +}; diff --git a/src/aio/linux.zig b/src/aio/linux.zig new file mode 100644 index 00000000..7e8bf4cf --- /dev/null +++ b/src/aio/linux.zig @@ -0,0 +1,573 @@ +const std = @import("std"); +const aio = @import("../aio.zig"); +const Operation = @import("ops.zig").Operation; +const ErrorUnion = @import("ops.zig").ErrorUnion; + +io: std.os.linux.IoUring, +ops: Pool(Operation, u16), + +pub fn init(allocator: std.mem.Allocator, n: u16) aio.InitError!@This() { + const n2 = try std.math.ceilPowerOfTwo(u16, n); + var io = try uring_init(n2); + errdefer io.deinit(); + const ops = try Pool(Operation, u16).init(allocator, n2); + errdefer ops.deinit(allocator); + return .{ .io = io, .ops = ops }; +} + +pub fn deinit(self: *@This(), allocator: std.mem.Allocator) void { + self.io.deinit(); + self.ops.deinit(allocator); + self.* = undefined; +} + +inline fn queueOperation(self: *@This(), op: anytype) aio.QueueError!u16 { + const n = self.ops.next() orelse return error.Overflow; + try uring_queue(&self.io, op, n); + const tag = @tagName(comptime Operation.tagFromPayloadType(@TypeOf(op.*))); + return self.ops.add(@unionInit(Operation, tag, op.*)) catch unreachable; +} + +pub fn queue(self: *@This(), comptime len: u16, work: anytype) aio.QueueError!void { + if (comptime len == 1) { + _ = try self.queueOperation(&work.ops[0]); + } else { + var ids: std.BoundedArray(u16, len) = .{}; + errdefer for (ids.constSlice()) |id| self.ops.remove(id); + inline for (&work.ops) |*op| ids.append(try self.queueOperation(op)) catch unreachable; + } +} + +const NOP = std.math.maxInt(usize); + +pub fn complete(self: *@This(), mode: aio.Dynamic.CompletionMode) aio.CompletionError!aio.CompletionResult { + if ((!self.ops.empty() or self.io.sq_ready() > 0) and mode == .nonblocking) { + _ = self.io.nop(NOP) catch |err| return switch (err) { + error.SubmissionQueueFull => .{}, + }; + } + if (try uring_submit(&self.io) == 0) return .{}; + var result: aio.CompletionResult = .{}; + var cqes: [64]std.os.linux.io_uring_cqe = undefined; + const n = try uring_copy_cqes(&self.io, &cqes, 1); + for (cqes[0..n]) |*cqe| { + if (cqe.user_data == NOP) continue; + switch (self.ops.get(@intCast(cqe.user_data)).*) { + inline else => |op| uring_handle_completion(&op, cqe) catch { + result.num_errors += 1; + }, + } + } + result.num_completed = n; + return result; +} + +pub fn immediate(comptime len: u16, work: anytype) aio.ImmediateError!aio.CompletionResult { + var io = try uring_init(len); + defer io.deinit(); + inline for (&work.ops) |*op| try uring_queue(&io, op, @intFromPtr(op)); + var num = try uring_submit(&io); + const submitted = num; + var result: aio.CompletionResult = .{}; + var cqes: [len]std.os.linux.io_uring_cqe = undefined; + while (num > 0) { + const n = try uring_copy_cqes(&io, &cqes, num); + for (cqes[0..n]) |*cqe| { + const op = blk: { + inline for (&work.ops) |*op| if (@intFromPtr(op) == cqe.user_data) break :blk op; + unreachable; + }; + uring_handle_completion(op, cqe) catch { + result.num_errors += 1; + }; + } + num -= n; + } + result.num_completed = submitted - num; + return result; +} + +inline fn uring_init(n: u16) 
aio.InitError!std.os.linux.IoUring { + return std.os.linux.IoUring.init(n, 0) catch |err| switch (err) { + error.PermissionDenied, error.SystemResources, error.SystemOutdated => |e| e, + error.ProcessFdQuotaExceeded => error.ProcessQuotaExceeded, + error.SystemFdQuotaExceeded => error.SystemQuotaExceeded, + else => error.Unexpected, + }; +} + +fn convertOpenFlags(flags: std.fs.File.OpenFlags) std.posix.O { + var os_flags: std.posix.O = .{ + .ACCMODE = switch (flags.mode) { + .read_only => .RDONLY, + .write_only => .WRONLY, + .read_write => .RDWR, + }, + }; + if (@hasField(std.posix.O, "CLOEXEC")) os_flags.CLOEXEC = true; + if (@hasField(std.posix.O, "LARGEFILE")) os_flags.LARGEFILE = true; + if (@hasField(std.posix.O, "NOCTTY")) os_flags.NOCTTY = !flags.allow_ctty; + + // Use the O locking flags if the os supports them to acquire the lock + // atomically. + const has_flock_open_flags = @hasField(std.posix.O, "EXLOCK"); + if (has_flock_open_flags) { + // Note that the NONBLOCK flag is removed after the openat() call + // is successful. + switch (flags.lock) { + .none => {}, + .shared => { + os_flags.SHLOCK = true; + os_flags.NONBLOCK = flags.lock_nonblocking; + }, + .exclusive => { + os_flags.EXLOCK = true; + os_flags.NONBLOCK = flags.lock_nonblocking; + }, + } + } + return os_flags; +} + +inline fn uring_queue(io: *std.os.linux.IoUring, op: anytype, user_data: u64) aio.QueueError!void { + var sqe = switch (comptime Operation.tagFromPayloadType(@TypeOf(op.*))) { + .fsync => try io.fsync(user_data, op.file.handle, 0), + .read => try io.read(user_data, op.file.handle, .{ .buffer = op.buffer }, op.offset), + .write => try io.write(user_data, op.file.handle, op.buffer, op.offset), + .accept => try io.accept(user_data, op.socket, @ptrCast(@alignCast(op.addr)), op.inout_addrlen, 0), + .connect => try io.connect(user_data, op.socket, @ptrCast(@alignCast(op.addr)), op.addrlen), + .recv => try io.recv(user_data, op.socket, .{ .buffer = op.buffer }, 0), + .send => try io.send(user_data, op.socket, op.buffer, 0), + .open_at => try io.openat(user_data, op.dir.handle, op.path, convertOpenFlags(op.flags)), + .close => try io.close(user_data, op.file.handle), + .timeout => try io.timeout(user_data, @ptrCast(&op.ts), 0, 0), + .timeout_remove => try io.timeout_remove(user_data, @intFromEnum(op.id), 0), + .link_timeout => try io.link_timeout(user_data, @ptrCast(&op.ts), 0), + .cancel => try io.cancel(user_data, @intFromEnum(op.id), 0), + .rename_at => try io.renameat(user_data, op.old_dir.handle, op.old_path, op.new_dir.handle, op.new_path, 0), + .unlink_at => try io.unlinkat(user_data, op.dir.handle, op.path, 0), + .mkdir_at => try io.mkdirat(user_data, op.dir.handle, op.path, op.mode), + .symlink_at => try io.symlinkat(user_data, op.target, op.dir.handle, op.link_path), + // .waitid => try io.waitid(user_data, .PID, op.child, &op._, 0, 0), + .socket => try io.socket(user_data, op.domain, op.flags, op.protocol, 0), + .close_socket => try io.close(user_data, op.socket), + }; + if (op.link_next) sqe.flags |= std.os.linux.IOSQE_IO_LINK; + if (op.out_id) |id| id.* = @enumFromInt(user_data); +} + +inline fn uring_submit(io: *std.os.linux.IoUring) aio.CompletionError!u16 { + while (true) { + const n = io.submit() catch |err| switch (err) { + error.FileDescriptorInvalid => unreachable, + error.FileDescriptorInBadState => unreachable, + error.BufferInvalid => unreachable, + error.OpcodeNotSupported => unreachable, + error.RingShuttingDown => unreachable, + error.SignalInterrupt => continue, + 
error.CompletionQueueOvercommitted, error.SubmissionQueueEntryInvalid, error.Unexpected, error.SystemResources => |e| return e, + }; + return @intCast(n); + } +} + +inline fn uring_copy_cqes(io: *std.os.linux.IoUring, cqes: []std.os.linux.io_uring_cqe, len: u16) aio.CompletionError!u16 { + while (true) { + const n = io.copy_cqes(cqes, len) catch |err| switch (err) { + error.FileDescriptorInvalid => unreachable, + error.FileDescriptorInBadState => unreachable, + error.BufferInvalid => unreachable, + error.OpcodeNotSupported => unreachable, + error.RingShuttingDown => unreachable, + error.SignalInterrupt => continue, + error.CompletionQueueOvercommitted, error.SubmissionQueueEntryInvalid, error.Unexpected, error.SystemResources => |e| return e, + }; + return @intCast(n); + } + unreachable; +} + +fn statusToTerm(status: u32) std.process.Child.Term { + return if (std.posix.W.IFEXITED(status)) + .{ .Exited = std.posix.W.EXITSTATUS(status) } + else if (std.posix.W.IFSIGNALED(status)) + .{ .Signal = std.posix.W.TERMSIG(status) } + else if (std.posix.W.IFSTOPPED(status)) + .{ .Stopped = std.posix.W.STOPSIG(status) } + else + .{ .Unknown = status }; +} + +inline fn uring_handle_completion(op: anytype, cqe: *std.os.linux.io_uring_cqe) !void { + switch (op.counter) { + .dec => |c| c.* -= 1, + .inc => |c| c.* += 1, + .nop => {}, + } + + const err = cqe.err(); + if (err != .SUCCESS) { + var skip_err = false; + if (op.out_error) |out_error| { + out_error.* = switch (comptime Operation.tagFromPayloadType(@TypeOf(op.*))) { + .fsync => switch (err) { + .SUCCESS, .INTR, .INVAL, .FAULT, .AGAIN => unreachable, + .CANCELED => error.OperationCanceled, + else => std.posix.unexpectedErrno(err), + }, + .read => switch (err) { + .SUCCESS, .INTR, .INVAL, .FAULT, .AGAIN => unreachable, + .CANCELED => error.OperationCanceled, + .BADF => error.NotOpenForReading, + .IO => error.InputOutput, + .PERM => error.AccessDenied, + .PIPE => error.BrokenPipe, + .ISDIR => error.IsDir, + .NOBUFS => error.SystemResources, + .NOMEM => error.SystemResources, + .NOTCONN => error.SocketNotConnected, + .CONNRESET => error.ConnectionResetByPeer, + .TIMEDOUT => error.ConnectionTimedOut, + else => std.posix.unexpectedErrno(err), + }, + .write => switch (err) { + .SUCCESS, .INTR, .INVAL, .FAULT, .AGAIN, .DESTADDRREQ => unreachable, + .CANCELED => error.OperationCanceled, + .DQUOT => error.DiskQuota, + .FBIG => error.FileTooBig, + .BADF => error.NotOpenForWriting, + .IO => error.InputOutput, + .NOSPC => error.NoSpaceLeft, + .PERM => error.AccessDenied, + .PIPE => error.BrokenPipe, + .NOBUFS => error.SystemResources, + .NOMEM => error.SystemResources, + .CONNRESET => error.ConnectionResetByPeer, + else => std.posix.unexpectedErrno(err), + }, + .accept => switch (err) { + .SUCCESS, .INTR, .FAULT, .AGAIN, .DESTADDRREQ => unreachable, + .CANCELED => error.OperationCanceled, + .BADF => unreachable, // always a race condition + .CONNABORTED => error.ConnectionAborted, + .INVAL => error.SocketNotListening, + .NOTSOCK => unreachable, + .MFILE => error.ProcessFdQuotaExceeded, + .NFILE => error.SystemFdQuotaExceeded, + .NOBUFS => error.SystemResources, + .NOMEM => error.SystemResources, + .OPNOTSUPP => unreachable, + .PROTO => error.ProtocolFailure, + .PERM => error.BlockedByFirewall, + else => std.posix.unexpectedErrno(err), + }, + .connect => switch (err) { + .SUCCESS, .INTR, .INVAL, .AGAIN, .DESTADDRREQ, .INPROGRESS => unreachable, + .CANCELED => error.OperationCanceled, + .ACCES => error.PermissionDenied, + .PERM => error.PermissionDenied, + 
.ADDRINUSE => error.AddressInUse, + .ADDRNOTAVAIL => error.AddressNotAvailable, + .AFNOSUPPORT => error.AddressFamilyNotSupported, + .ALREADY => error.ConnectionPending, + .BADF => unreachable, // sockfd is not a valid open file descriptor. + .CONNREFUSED => error.ConnectionRefused, + .CONNRESET => error.ConnectionResetByPeer, + .FAULT => unreachable, // The socket structure address is outside the user's address space. + .ISCONN => unreachable, // The socket is already connected. + .HOSTUNREACH => error.NetworkUnreachable, + .NETUNREACH => error.NetworkUnreachable, + .NOTSOCK => unreachable, // The file descriptor sockfd does not refer to a socket. + .PROTOTYPE => unreachable, // The socket type does not support the requested communications protocol. + .TIMEDOUT => error.ConnectionTimedOut, + .NOENT => error.FileNotFound, // Returned when socket is AF.UNIX and the given path does not exist. + .CONNABORTED => unreachable, // Tried to reuse socket that previously received error.ConnectionRefused. + else => std.posix.unexpectedErrno(err), + }, + .recv => switch (err) { + .SUCCESS, .INTR, .INVAL, .FAULT, .AGAIN => unreachable, + .CANCELED => error.OperationCanceled, + .BADF => unreachable, // always a race condition + .NOTCONN => error.SocketNotConnected, + .NOTSOCK => unreachable, + .NOMEM => error.SystemResources, + .CONNREFUSED => error.ConnectionRefused, + .CONNRESET => error.ConnectionResetByPeer, + .TIMEDOUT => error.ConnectionTimedOut, + else => std.posix.unexpectedErrno(err), + }, + .send => switch (err) { + .SUCCESS, .INTR, .INVAL, .AGAIN => unreachable, + .CANCELED => error.OperationCanceled, + .ACCES => error.AccessDenied, + .ALREADY => error.FastOpenAlreadyInProgress, + .BADF => unreachable, // always a race condition + .CONNRESET => error.ConnectionResetByPeer, + .DESTADDRREQ => unreachable, // The socket is not connection-mode, and no peer address is set. + .FAULT => unreachable, // An invalid user space address was specified for an argument. + .ISCONN => unreachable, // connection-mode socket was connected already but a recipient was specified + .MSGSIZE => error.MessageTooBig, + .NOBUFS => error.SystemResources, + .NOMEM => error.SystemResources, + .NOTSOCK => unreachable, // The file descriptor sockfd does not refer to a socket. + .OPNOTSUPP => unreachable, // Some bit in the flags argument is inappropriate for the socket type. 
+ .PIPE => error.BrokenPipe, + .HOSTUNREACH => error.NetworkUnreachable, + .NETUNREACH => error.NetworkUnreachable, + .NETDOWN => error.NetworkSubsystemFailed, + else => std.posix.unexpectedErrno(err), + }, + .open_at => switch (err) { + .SUCCESS, .INTR, .INVAL, .AGAIN => unreachable, + .CANCELED => error.OperationCanceled, + .ACCES => error.AccessDenied, + .FBIG => error.FileTooBig, + .OVERFLOW => error.FileTooBig, + .ISDIR => error.IsDir, + .LOOP => error.SymLinkLoop, + .MFILE => error.ProcessFdQuotaExceeded, + .NAMETOOLONG => error.NameTooLong, + .NFILE => error.SystemFdQuotaExceeded, + .NODEV => error.NoDevice, + .NOENT => error.FileNotFound, + .NOMEM => error.SystemResources, + .NOSPC => error.NoSpaceLeft, + .NOTDIR => error.NotDir, + .PERM => error.AccessDenied, + .EXIST => error.PathAlreadyExists, + .BUSY => error.DeviceBusy, + .ILSEQ => error.InvalidUtf8, + else => std.posix.unexpectedErrno(err), + }, + .close, .close_socket => unreachable, + .timeout => switch (err) { + .SUCCESS, .INTR, .INVAL, .AGAIN => unreachable, + .TIME => blk: { + skip_err = true; + break :blk error.Success; + }, + .CANCELED => error.OperationCanceled, + else => unreachable, + }, + .timeout_remove => switch (err) { + .SUCCESS, .INTR, .INVAL, .AGAIN => unreachable, + .BUSY => error.InProgress, + .NOENT => error.NotFound, + else => unreachable, + }, + .link_timeout => switch (err) { + .SUCCESS, .INTR, .INVAL, .AGAIN => unreachable, + .TIME => blk: { + skip_err = true; + break :blk error.Success; + }, + .CANCELED => error.OperationCanceled, + .ALREADY => error.InProgress, + else => unreachable, + }, + .cancel => switch (err) { + .SUCCESS, .INTR, .INVAL, .AGAIN => unreachable, + .ALREADY => error.InProgress, + .NOENT => error.NotFound, + else => unreachable, + }, + .rename_at => switch (err) { + .SUCCESS, .INTR, .INVAL, .AGAIN => unreachable, + .CANCELED => error.OperationCanceled, + .ACCES => error.AccessDenied, + .PERM => error.AccessDenied, + .BUSY => error.FileBusy, + .DQUOT => error.DiskQuota, + .FAULT => unreachable, + .ISDIR => error.IsDir, + .LOOP => error.SymLinkLoop, + .MLINK => error.LinkQuotaExceeded, + .NAMETOOLONG => error.NameTooLong, + .NOENT => error.FileNotFound, + .NOTDIR => error.NotDir, + .NOMEM => error.SystemResources, + .NOSPC => error.NoSpaceLeft, + .EXIST => error.PathAlreadyExists, + .NOTEMPTY => error.PathAlreadyExists, + .ROFS => error.ReadOnlyFileSystem, + .XDEV => error.RenameAcrossMountPoints, + else => std.posix.unexpectedErrno(err), + }, + .unlink_at => switch (err) { + .SUCCESS, .INTR, .AGAIN => unreachable, + .CANCELED => error.OperationCanceled, + .ACCES => error.AccessDenied, + .PERM => error.AccessDenied, + .BUSY => error.FileBusy, + .FAULT => unreachable, + .IO => error.FileSystem, + .ISDIR => error.IsDir, + .LOOP => error.SymLinkLoop, + .NAMETOOLONG => error.NameTooLong, + .NOENT => error.FileNotFound, + .NOTDIR => error.NotDir, + .NOMEM => error.SystemResources, + .ROFS => error.ReadOnlyFileSystem, + .EXIST => error.DirNotEmpty, + .NOTEMPTY => error.DirNotEmpty, + .INVAL => unreachable, // invalid flags, or pathname has . 
as last component + .BADF => unreachable, // always a race condition + else => std.posix.unexpectedErrno(err), + }, + .mkdir_at => switch (err) { + .SUCCESS, .INTR, .INVAL, .AGAIN => unreachable, + .ACCES => error.AccessDenied, + .BADF => unreachable, + .PERM => error.AccessDenied, + .DQUOT => error.DiskQuota, + .EXIST => error.PathAlreadyExists, + .FAULT => unreachable, + .LOOP => error.SymLinkLoop, + .MLINK => error.LinkQuotaExceeded, + .NAMETOOLONG => error.NameTooLong, + .NOENT => error.FileNotFound, + .NOMEM => error.SystemResources, + .NOSPC => error.NoSpaceLeft, + .NOTDIR => error.NotDir, + .ROFS => error.ReadOnlyFileSystem, + // dragonfly: when dir_fd is unlinked from filesystem + .NOTCONN => error.FileNotFound, + else => std.posix.unexpectedErrno(err), + }, + .symlink_at => switch (err) { + .SUCCESS, .INTR, .INVAL, .AGAIN, .FAULT => unreachable, + .ACCES => error.AccessDenied, + .PERM => error.AccessDenied, + .DQUOT => error.DiskQuota, + .EXIST => error.PathAlreadyExists, + .IO => error.FileSystem, + .LOOP => error.SymLinkLoop, + .NAMETOOLONG => error.NameTooLong, + .NOENT => error.FileNotFound, + .NOTDIR => error.NotDir, + .NOMEM => error.SystemResources, + .NOSPC => error.NoSpaceLeft, + .ROFS => error.ReadOnlyFileSystem, + else => std.posix.unexpectedErrno(err), + }, + // .waitid => unreachable, + .socket => switch (err) { + .SUCCESS, .INTR, .AGAIN, .FAULT => unreachable, + .ACCES => error.PermissionDenied, + .AFNOSUPPORT => error.AddressFamilyNotSupported, + .INVAL => error.ProtocolFamilyNotAvailable, + .MFILE => error.ProcessFdQuotaExceeded, + .NFILE => error.SystemFdQuotaExceeded, + .NOBUFS => error.SystemResources, + .NOMEM => error.SystemResources, + .PROTONOSUPPORT => error.ProtocolNotSupported, + .PROTOTYPE => error.SocketTypeNotSupported, + else => std.posix.unexpectedErrno(err), + }, + }; + } + if (!skip_err) return error.OperationFailed; + } + + if (op.out_error) |out_error| out_error.* = error.Success; + + switch (comptime Operation.tagFromPayloadType(@TypeOf(op.*))) { + .fsync => {}, + .read => op.out_read.* = @intCast(cqe.res), + .write => if (op.out_written) |w| { + w.* = @intCast(cqe.res); + }, + .accept => op.out_socket.* = cqe.res, + .connect => {}, + .recv => op.out_read.* = @intCast(cqe.res), + .send => if (op.out_written) |w| { + w.* = @intCast(cqe.res); + }, + .open_at => op.out_file.handle = cqe.res, + .close, .close_socket => {}, + .timeout, .timeout_remove, .link_timeout => {}, + .cancel => {}, + .rename_at, .unlink_at, .mkdir_at, .symlink_at => {}, + // .waitid => op.out_term.* = statusToTerm(@intCast(op._.fields.common.second.sigchld.status)), + .socket => op.out_socket.* = cqe.res, + } +} + +pub fn Pool(T: type, SZ: type) type { + return struct { + pub const Node = union(enum) { free: ?SZ, used: T }; + nodes: []Node, + free: ?SZ = null, + num_free: SZ = 0, + num_used: SZ = 0, + + pub const Error = error{ + OutOfMemory, + }; + + pub fn init(allocator: std.mem.Allocator, n: SZ) Error!@This() { + return .{ .nodes = try allocator.alloc(Node, n) }; + } + + pub fn deinit(self: *@This(), allocator: std.mem.Allocator) void { + allocator.free(self.nodes); + self.* = undefined; + } + + pub fn empty(self: *@This()) bool { + return self.num_used == self.num_free; + } + + pub fn next(self: *@This()) ?SZ { + if (self.free) |fslot| return fslot; + if (self.num_used >= self.nodes.len) return null; + return self.num_used; + } + + pub fn add(self: *@This(), item: T) Error!SZ { + if (self.free) |fslot| { + self.free = self.nodes[fslot].free; + self.nodes[fslot] = .{ 
.used = item }; + self.num_free -= 1; + return fslot; + } + if (self.num_used >= self.nodes.len) return error.OutOfMemory; + self.nodes[self.num_used] = .{ .used = item }; + defer self.num_used += 1; + return self.num_used; + } + + pub fn remove(self: *@This(), slot: SZ) void { + if (self.free) |fslot| { + self.nodes[slot] = .{ .free = fslot }; + } else { + self.nodes[slot] = .{ .free = null }; + } + self.free = slot; + self.num_free += 1; + } + + pub fn get(self: *@This(), slot: SZ) *T { + return &self.nodes[slot].used; + } + + pub const Iterator = struct { + items: []Node, + index: SZ = 0, + + pub fn next(self: *@This()) *T { + while (self.index < self.items.len) { + defer self.index += 1; + if (self.items[self.index] == .used) { + return &self.items[self.index].used; + } + } + return null; + } + }; + + pub fn iterator(self: *@This()) Iterator { + return .{ .items = self.nodes[0..self.num_used] }; + } + }; +} diff --git a/src/aio/ops.zig b/src/aio/ops.zig new file mode 100644 index 00000000..ec82a591 --- /dev/null +++ b/src/aio/ops.zig @@ -0,0 +1,236 @@ +const std = @import("std"); +const builtin = @import("builtin"); + +// Virtual linked actions are possible with `nop` under io_uring :thinking: + +pub const Id = enum(usize) { _ }; + +/// std.fs.File.sync +pub const Fsync = Define(struct { + file: std.fs.File, +}, std.fs.File.SyncError); + +/// std.fs.File.read +pub const Read = Define(struct { + file: std.fs.File, + buffer: []u8, + offset: u64 = 0, + out_read: *usize, +}, std.fs.File.ReadError); + +/// std.fs.File.write +pub const Write = Define(struct { + file: std.fs.File, + buffer: []const u8, + offset: u64 = 0, + out_written: ?*usize = null, +}, std.fs.File.WriteError); + +// For whatever reason the posix.sockaddr crashes the compiler, so use this +const sockaddr = anyopaque; + +/// std.posix.accept +pub const Accept = Define(struct { + socket: std.posix.socket_t, + addr: ?*sockaddr = null, + inout_addrlen: ?*std.posix.socklen_t = null, + out_socket: *std.posix.socket_t, +}, std.posix.AcceptError); + +/// std.posix.connect +pub const Connect = Define(struct { + socket: std.posix.socket_t, + addr: *const sockaddr, + addrlen: std.posix.socklen_t, +}, std.posix.ConnectError); + +/// std.posix.recv +pub const Recv = Define(struct { + socket: std.posix.socket_t, + buffer: []u8, + out_read: *usize, +}, std.posix.RecvFromError); + +/// std.posix.send +pub const Send = Define(struct { + socket: std.posix.socket_t, + buffer: []const u8, + out_written: ?*usize = null, +}, std.posix.SendError); + +// TODO: recvmsg, sendmsg + +/// std.fs.Dir.openFile +pub const OpenAt = Define(struct { + dir: std.fs.Dir, + path: [*:0]const u8, + flags: std.fs.File.OpenFlags, + out_file: *std.fs.File, +}, std.fs.File.OpenError); + +/// std.fs.File.close +pub const Close = Define(struct { + file: std.fs.File, +}, error{}); + +/// std.time.Timer.start +pub const Timeout = Define(struct { + ts: struct { sec: i64 = 0, nsec: i64 = 0 }, +}, error{}); + +/// std.time.Timer.cancel (if it existed) +/// XXX: Overlap with `Cancel`, is this even needed? 
(io_uring) +pub const TimeoutRemove = Define(struct { + id: Id, +}, error{ InProgress, NotFound }); + +/// Timeout linked to a operation +/// This must be linked last and the operation before must have set `link_next` to `true` +/// If the operation finishes before the timeout the timeout will be canceled +pub const LinkTimeout = Define(struct { + ts: struct { sec: i64 = 0, nsec: i64 = 0 }, + out_expired: ?*bool = null, +}, error{InProgress}); + +/// Cancel a operation +pub const Cancel = Define(struct { + id: Id, +}, error{ InProgress, NotFound }); + +/// std.fs.rename +pub const RenameAt = Define(struct { + old_dir: std.fs.Dir, + old_path: [*:0]const u8, + new_dir: std.fs.Dir, + new_path: [*:0]const u8, +}, std.fs.Dir.RenameError); + +/// std.fs.Dir.deleteFile +pub const UnlinkAt = Define(struct { + dir: std.fs.Dir, + path: [*:0]const u8, +}, std.posix.UnlinkatError); + +/// std.fs.Dir.makeDir +pub const MkDirAt = Define(struct { + dir: std.fs.Dir, + path: [*:0]const u8, + mode: u32 = std.fs.Dir.default_mode, +}, std.fs.Dir.MakeError); + +/// std.fs.Dir.symlink +pub const SymlinkAt = Define(struct { + dir: std.fs.Dir, + target: [*:0]const u8, + link_path: [*:0]const u8, +}, std.posix.SymLinkError); + +// TODO: linkat + +/// std.process.Child.wait +/// TODO: Crashes compiler, doesn't like the std.process fields wut? +pub const WaitId = Define(struct { + child: std.process.Child.Id, + out_term: *std.process.Child.Term, + _: switch (builtin.target.os.tag) { + .linux => std.os.linux.siginfo_t, + else => @compileError("unsupported os"), + }, +}, error{}); + +/// std.posix.socket +pub const Socket = Define(struct { + /// std.posix.AF + domain: u32, + /// std.posix.SOCK + flags: u32, + /// std.posix.IPPROTO + protocol: u32, + out_socket: *std.posix.socket_t, +}, std.posix.SocketError); + +/// std.posix.close +pub const CloseSocket = Define(struct { + socket: std.posix.socket_t, +}, error{}); + +pub const Operation = union(enum) { + fsync: Fsync, + read: Read, + write: Write, + accept: Accept, + connect: Connect, + recv: Recv, + send: Send, + open_at: OpenAt, + close: Close, + timeout: Timeout, + timeout_remove: TimeoutRemove, + link_timeout: LinkTimeout, + cancel: Cancel, + rename_at: RenameAt, + unlink_at: UnlinkAt, + mkdir_at: MkDirAt, + symlink_at: SymlinkAt, + // waitid: WaitId, + socket: Socket, + close_socket: CloseSocket, + + pub fn tagFromPayloadType(Op: type) std.meta.Tag(Operation) { + inline for (std.meta.fields(Operation)) |field| { + if (Op == field.type) { + @setEvalBranchQuota(1_000_0); + return std.meta.stringToEnum(std.meta.Tag(Operation), field.name) orelse unreachable; + } + } + unreachable; + } +}; + +const SharedError = error{ + Success, + OperationCanceled, +}; + +pub const ErrorUnion = SharedError || + std.fs.File.SyncError || + std.fs.File.ReadError || + std.fs.File.WriteError || + std.posix.AcceptError || + std.posix.ConnectError || + std.posix.RecvFromError || + std.posix.SendError || + std.fs.File.OpenError || + error{ InProgress, NotFound } || + std.fs.Dir.RenameError || + std.posix.UnlinkatError || + std.fs.Dir.MakeError || + std.posix.SymLinkError || + std.process.Child.SpawnError || + std.posix.SocketError; + +fn Define(T: type, E: type) type { + // Counter that either increases or decreases a value in a given address + // Reserved when using the coroutines API + const Counter = union(enum) { + inc: *u16, + dec: *u16, + nop: void, + }; + + const Super = struct { + out_id: ?*Id = null, + out_error: ?*(E || SharedError) = null, + counter: Counter = .nop, + 
link_next: bool = false, + }; + + return @Type(.{ + .Struct = .{ + .layout = .auto, + .fields = std.meta.fields(T) ++ std.meta.fields(Super), + .decls = &.{}, + .is_tuple = false, + }, + }); +} diff --git a/src/coro.zig b/src/coro.zig new file mode 100644 index 00000000..be32b3f5 --- /dev/null +++ b/src/coro.zig @@ -0,0 +1,311 @@ +//! Coroutines API +//! This combines the basic aio API with coroutines +//! Coroutines will yield when IO is being performed and get waken up when the IO is complete +//! This allows you to write asynchronous IO tasks with ease + +const std = @import("std"); +const aio = @import("aio"); +const Fiber = @import("coro/zefi.zig"); + +const root = @import("root"); +pub const options: Options = if (@hasDecl(root, "aio_coro_options")) root.aio_coro_options else .{}; + +pub const Options = struct { + /// Enable coroutine debug logs and tracing + debug: bool = false, + /// Default io queue entries + io_queue_entries: u16 = 4096, + /// Default stack size for coroutines + stack_size: usize = 1.049e+6, // 1 MiB + /// Default handler for errors for !void coroutines + error_handler: fn (err: anyerror) void = defaultErrorHandler, +}; + +fn defaultErrorHandler(err: anyerror) void { + std.debug.print("error: {s}\n", .{@errorName(err)}); + if (@errorReturnTrace()) |trace| { + std.debug.dumpStackTrace(trace.*); + } +} + +fn debug(comptime fmt: []const u8, args: anytype) void { + if (comptime !options.debug) return; + const log = std.log.scoped(.coro); + log.debug(fmt, args); +} + +pub const io = struct { + /// Completes a list of operations immediately, blocks the coroutine until complete + /// The IO operations can be cancelled by calling `wakeup` + /// For error handling you must check the `out_error` field in the operation + /// Returns the number of errors occured, 0 if there were no errors + pub inline fn batch(operations: anytype) aio.QueueError!u16 { + if (Fiber.current()) |fiber| { + var task: *Scheduler.TaskState = @ptrFromInt(fiber.getUserDataPtr().*); + + const State = struct { old_err: ?*anyerror, old_id: ?*aio.Id, id: aio.Id, err: anyerror }; + var state: [operations.len]State = undefined; + var work = struct { ops: @TypeOf(operations) }{ .ops = operations }; + inline for (&work.ops, &state) |*op, *s| { + op.counter = .{ .dec = &task.io_counter }; + s.old_id = op.out_id; + op.out_id = &s.id; + s.old_err = op.out_error; + op.out_error = @ptrCast(&s.err); + } + + try task.io.queue(work.ops); + task.io_counter = operations.len; + task.status = .doing_io; + debug("yielding for io: {}", .{task}); + Fiber.yield(); + + if (task.io_counter > 0) { + // wakeup() was called, try cancel the io + inline for (&state) |*s| try task.io.queue(aio.Cancel{ .id = s.id }); + task.status = .cancelling_io; + debug("yielding for io cancellation: {}", .{task}); + Fiber.yield(); + } + + var num_errors: u16 = 0; + inline for (&state) |*s| { + num_errors += @intFromBool(s.err != error.Success); + if (s.old_err) |p| p.* = s.err; + if (s.old_id) |p| p.* = s.id; + } + return num_errors; + } else { + unreachable; // this io function is only meant to be used in coroutines! 
+ } + } + + /// Completes a list of operations immediately, blocks until complete + /// The IO operations can be cancelled by calling `wakeup` + /// Returns `error.SomeOperationFailed` if any operation failed + pub inline fn multi(operations: anytype) (aio.QueueError || error{SomeOperationFailed})!void { + if (try batch(operations) > 0) return error.SomeOperationFailed; + } + + /// Completes a single operation immediately, blocks the coroutine until complete + /// The IO operation can be cancelled by calling `wakeup` + pub fn single(operation: anytype) (aio.QueueError || aio.OperationError)!void { + if (Fiber.current()) |fiber| { + var task: *Scheduler.TaskState = @ptrFromInt(fiber.getUserDataPtr().*); + + var op: @TypeOf(operation) = operation; + var err: @TypeOf(op.out_error.?.*) = error.Success; + var id: aio.Id = undefined; + op.counter = .{ .dec = &task.io_counter }; + op.out_id = &id; + op.out_error = &err; + try task.io.queue(op); + task.io_counter = 1; + task.status = .doing_io; + debug("yielding for io: {}", .{task}); + Fiber.yield(); + + if (task.io_counter > 0) { + // wakeup() was called, try cancel the io + try task.io.queue(aio.Cancel{ .id = id }); + task.status = .cancelling_io; + debug("yielding for io cancellation: {}", .{task}); + Fiber.yield(); + } + + if (err != error.Success) return err; + } else { + unreachable; // this io function is only meant to be used in coroutines! + } + } +}; + +/// Yields current task, can only be called from inside a task +pub inline fn yield() void { + if (Fiber.current()) |fiber| { + var task: *Scheduler.TaskState = @ptrFromInt(fiber.getUserDataPtr().*); + if (task.status == .dead) unreachable; // race condition + if (task.status == .yield) unreachable; // race condition + task.status = .yield; + debug("yielding: {}", .{task}); + Fiber.yield(); + } else { + unreachable; // yield is only meant to be used in coroutines! 
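The `yield`/`wakeup` pair can also be driven directly from the main thread. A minimal sketch, assuming a scheduler set up as in examples/coro.zig:

const std = @import("std");
const coro = @import("coro");

fn waiter(ready: *const bool) void {
    // Parks the task until something calls coro.wakeup() on it.
    while (!ready.*) coro.yield();
    std.log.info("woken up", .{});
}

pub fn main() !void {
    var gpa: std.heap.GeneralPurposeAllocator(.{}) = .{};
    defer _ = gpa.deinit();
    var scheduler = try coro.Scheduler.init(gpa.allocator(), .{});
    defer scheduler.deinit();

    var ready = false;
    const task = try scheduler.spawn(waiter, .{&ready}, .{});
    ready = true;
    coro.wakeup(task); // switches back into waiter, which can now return
    try scheduler.run();
}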
+ } +} + +/// Wakeups a task by either cancelling the io its doing or switching back to it from yielded state +pub inline fn wakeup(task: Scheduler.Task) void { + const node: *Scheduler.TaskNode = @ptrCast(task); + if (node.data.status == .dead) unreachable; // race condition + if (node.data.status == .running) return; // already awake + if (node.data.status == .cancelling_io) return; // can't wake up when cancelling + debug("waking up from yield: {}", .{node.data}); + node.data.status = .running; + node.data.fiber.switchTo(); +} + +/// Runtime for asynchronous IO tasks +pub const Scheduler = struct { + allocator: std.mem.Allocator, + io: aio.Dynamic, + tasks: std.DoublyLinkedList(TaskState) = .{}, + num_dead: usize = 0, + + const TaskState = struct { + fiber: *Fiber, + status: enum { + running, + doing_io, + cancelling_io, + yield, + dead, + } = .running, + stack: ?Fiber.Stack = null, + io: *aio.Dynamic, + io_counter: u16 = 0, + + pub fn format(self: @This(), comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { + if (self.status == .doing_io) { + try writer.print("{x}: {s}, {} ops left", .{ @intFromPtr(self.fiber), @tagName(self.status), self.io_counter }); + } else { + try writer.print("{x}: {s}", .{ @intFromPtr(self.fiber), @tagName(self.status) }); + } + } + + fn deinit(self: *@This(), allocator: std.mem.Allocator) void { + if (Fiber.current()) |_| unreachable; // do not call deinit from a task + if (self.stack) |stack| allocator.free(stack); + self.* = undefined; + } + }; + + const TaskNode = std.DoublyLinkedList(TaskState).Node; + pub const Task = *align(@alignOf(TaskNode)) anyopaque; + + pub const InitOptions = struct { + /// This is a hint, the implementation makes the final call + io_queue_entries: u16 = options.io_queue_entries, + }; + + pub fn init(allocator: std.mem.Allocator, opts: InitOptions) !@This() { + if (Fiber.current()) |_| unreachable; // do not call init from a task + return .{ + .allocator = allocator, + .io = try aio.Dynamic.init(allocator, opts.io_queue_entries), + }; + } + + pub fn reapAll(self: *@This()) void { + if (Fiber.current()) |_| unreachable; // do not call reapAll from a task + while (self.tasks.pop()) |node| { + debug("reaping: {}", .{node.data}); + node.data.deinit(self.allocator); + self.allocator.destroy(node); + } + } + + pub fn reap(self: *@This(), task: Task) void { + if (Fiber.current()) |_| unreachable; // do not call reap from a task + const node: *TaskNode = @ptrCast(task); + debug("reaping: {}", .{node.data}); + self.tasks.remove(node); + node.data.deinit(self.allocator); + self.allocator.destroy(node); + } + + pub fn deinit(self: *@This()) void { + if (Fiber.current()) |_| unreachable; // do not call deinit from a task + self.reapAll(); + self.io.deinit(self.allocator); + self.* = undefined; + } + + fn entrypoint(self: *@This(), comptime func: anytype, args: anytype) void { + if (@typeInfo(@typeInfo(@TypeOf(func)).Fn.return_type.?) 
== .ErrorUnion) { + @call(.auto, func, args) catch |err| options.error_handler(err); + } else { + @call(.auto, func, args); + } + var task: *Scheduler.TaskState = @ptrFromInt(Fiber.current().?.getUserDataPtr().*); + task.status = .dead; + self.num_dead += 1; + debug("finished: {}", .{task}); + } + + pub const SpawnError = error{OutOfMemory} || Fiber.Error; + + pub const SpawnOptions = struct { + stack: union(enum) { + unmanaged: Fiber.Stack, + managed: usize, + } = .{ .managed = options.stack_size }, + }; + + /// Spawns a new task, the task may do local IO operations which will not block the whole process using the `io` namespace functions + pub fn spawn(self: *@This(), comptime func: anytype, args: anytype, opts: SpawnOptions) SpawnError!Task { + if (Fiber.current()) |_| unreachable; // do not call spawn from a task + const stack = switch (opts.stack) { + .unmanaged => |buf| buf, + .managed => |sz| try self.allocator.alignedAlloc(u8, Fiber.stack_alignment, sz), + }; + errdefer if (opts.stack == .managed) self.allocator.free(stack); + var fiber = try Fiber.init(stack, 0, entrypoint, .{ self, func, args }); + const node = try self.allocator.create(TaskNode); + errdefer self.allocator.destroy(node); + node.* = .{ .data = .{ + .fiber = fiber, + .stack = if (opts.stack == .managed) stack else null, + .io = &self.io, + } }; + fiber.getUserDataPtr().* = @intFromPtr(&node.data); + self.tasks.append(node); + errdefer self.tasks.remove(node); + debug("spawned: {}", .{node.data}); + fiber.switchTo(); + return node; + } + + /// Processes pending IO and reaps dead tasks + pub fn tick(self: *@This(), mode: aio.Dynamic.CompletionMode) !void { + if (Fiber.current()) |_| unreachable; // do not call tick from a task + const res = try self.io.complete(mode); + if (res.num_completed > 0) { + var maybe_node = self.tasks.first; + while (maybe_node) |node| { + const next = node.next; + switch (node.data.status) { + .running, .dead, .yield => {}, + .doing_io, .cancelling_io => if (node.data.io_counter == 0) { + debug("waking up from io: {}", .{node.data}); + node.data.status = .running; + node.data.fiber.switchTo(); + }, + } + maybe_node = next; + } + } + while (self.num_dead > 0) { + var maybe_node = self.tasks.first; + while (maybe_node) |node| { + const next = node.next; + switch (node.data.status) { + .running, .doing_io, .cancelling_io, .yield => {}, + .dead => { + debug("reaping: {}", .{node.data}); + node.data.deinit(self.allocator); + self.tasks.remove(node); + self.allocator.destroy(node); + self.num_dead -= 1; + }, + } + maybe_node = next; + } + } + } + + /// Run until all tasks are dead + pub fn run(self: *@This()) !void { + while (self.tasks.len > 0) try self.tick(.blocking); + } +}; diff --git a/src/coro/zefi.zig b/src/coro/zefi.zig new file mode 100644 index 00000000..3c02f4af --- /dev/null +++ b/src/coro/zefi.zig @@ -0,0 +1,311 @@ +// MIT License +// +// Copyright (c) 2023 kprotty +// Modified by Cloudef (2024) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the 
Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +const std = @import("std"); +const builtin = @import("builtin"); +const assert = std.debug.assert; +const Fiber = @This(); + +pub const stack_alignment = 16; +pub const Stack = []align(stack_alignment) u8; + +pub const Error = error{ + /// The stack space provided to the fiber is not large enough to contain required metadata. + StackTooSmall, + /// The stack space provided to the fiber is too big to be tracked by getStack(). + StackTooLarge, +}; + +/// Intrusively allocates a Fiber object (and auxiliary data) inside (specifically, and the end of) the given stack memory. +/// Then, execution of the fiber is setup to invoke the given function with the args on the next call to `switchTo`.'' +pub fn init(stack: Stack, user_data: usize, comptime func: anytype, args: anytype) Error!*Fiber { + const Args = @TypeOf(args); + const state = try State.init(stack, user_data, @sizeOf(Args), struct { + fn entry() callconv(.C) noreturn { + const state = tls_state orelse unreachable; + + // Call the functions with the args. + const args_ptr: *align(1) Args = @ptrFromInt(@intFromPtr(state) - @sizeOf(Args)); + @call(.auto, func, args_ptr.*); + + // Mark the fiber as completed and do one last + zefi_stack_swap(&state.stack_context, &state.caller_context); + unreachable; + } + }.entry); + + const magic_number: usize = 0x5E574D6D; + const magic_number_ptr: *usize = @ptrCast(stack.ptr); + magic_number_ptr.* = magic_number; + + const args_ptr: *align(1) Args = @ptrFromInt(@intFromPtr(state) - @sizeOf(Args)); + args_ptr.* = args; + + return @ptrCast(state); +} + +threadlocal var tls_state: ?*State = null; + +/// Get the currently running fiber of the caller, if any. +pub inline fn current() ?*Fiber { + return @ptrCast(tls_state); +} + +/// Given a fiber, return the stack memory used to initialize it. +/// Calling getStack() on a fiber which has completed is unspecified behavior. +pub fn getStack(fiber: *Fiber) Stack { + const state: *State = @ptrCast(@alignCast(fiber)); + const state_offset: usize = @intCast(state.offset); + const stack_end = @intFromPtr(state) + state_offset; + const stack_base = stack_end - (state.offset >> @bitSizeOf(u8)); + const base: [*]align(stack_alignment) u8 = @ptrFromInt(stack_base); + return base[0..(stack_end - stack_base)]; +} + +/// Given a fiber, return the user_data used to initialize it. +/// A pointer to the user_data is returned to give the caller the ability to modify it on the Fiber. +/// Calling getUserDataPtr() on a fiber which has completed is unspecified behavior. +pub fn getUserDataPtr(fiber: *Fiber) *usize { + const state: *State = @ptrCast(@alignCast(fiber)); + return &state.user_data; +} + +/// Switches the current thread's execution state from the caller's to the fiber's. +/// The fiber will return back to this caller either through yield or completing its init function. +/// The fiber must either be newly initialized or previously yielded. +/// +/// Switching to a fiber that is currently executing is undefined behavior. 
+/// Switching to a fiber that has completed is illegal behavior. +pub fn switchTo(fiber: *Fiber) void { + const state: *State = @ptrCast(@alignCast(fiber)); + + // Temporarily set the current fiber to the one passed in for the duration of the stack swap. + const old_state = tls_state; + assert(old_state != state); + tls_state = state; + defer tls_state = old_state; + + zefi_stack_swap(&state.caller_context, &state.stack_context); +} + +/// Switches the current thread's execution back to the most recent switchTo() called on the currently running fiber. +/// Calling yield from outside a fiber context (`current() == null`) is illegal behavior. +/// Once execution is yielded back, switchTo() on the (now previous) current fiber can be called again +/// to continue the fiber from this yield point. +pub fn yield() void { + const state = tls_state orelse unreachable; + zefi_stack_swap(&state.stack_context, &state.caller_context); +} + +const State = extern struct { + caller_context: *anyopaque, + stack_context: *anyopaque, + user_data: usize, + offset: usize, + + fn init(stack: Stack, user_data: usize, args_size: usize, entry_point: *const fn () callconv(.C) noreturn) Error!*State { + const stack_base = @intFromPtr(stack.ptr); + const stack_end = @intFromPtr(stack.ptr + stack.len); + if (stack.len > (std.math.maxInt(usize) >> @bitSizeOf(u8))) return error.StackTooLarge; + + // Push the State onto the state. + var stack_ptr = std.mem.alignBackward(usize, stack_end - @sizeOf(State), stack_alignment); + if (stack_ptr < stack_base) return error.StackTooSmall; + + const state: *State = @ptrFromInt(stack_ptr); + const end_offset = stack_end - stack_ptr; + + // Push enough bytes for the args onto the stack. + stack_ptr = std.mem.alignBackward(usize, stack_ptr - args_size, stack_alignment); + if (stack_ptr < stack_base) return error.StackTooSmall; + + // Reserve data for the StackContext. + stack_ptr = std.mem.alignBackward(usize, stack_ptr - @sizeOf(usize) * StackContext.word_count, stack_alignment); + assert(std.mem.isAligned(stack_ptr, stack_alignment)); + if (stack_ptr < stack_base) return error.StackTooSmall; + + // Write the entry point into the StackContext. 
+ var entry: [*]@TypeOf(entry_point) = @ptrFromInt(stack_ptr); + entry[StackContext.entry_offset] = entry_point; + + state.* = .{ + .caller_context = undefined, + .stack_context = @ptrFromInt(stack_ptr), + .user_data = user_data, + .offset = (stack.len << @bitSizeOf(u8)) | end_offset, + }; + + return state; + } +}; + +extern fn zefi_stack_swap( + noalias current_context_ptr: **anyopaque, + noalias new_context_ptr: **anyopaque, +) void; + +const StackContext = switch (builtin.cpu.arch) { + .x86_64 => switch (builtin.os.tag) { + .windows => Intel_Microsoft, + else => Intel_SysV, + }, + .aarch64 => Arm_64, + else => @compileError("platform not currently supported"), +}; + +const Intel_Microsoft = struct { + pub const word_count = 31; + + pub const entry_offset = word_count - 1; + + comptime { + asm ( + \\.global zefi_stack_swap + \\zefi_stack_swap: + \\ pushq %gs:0x10 + \\ pushq %gs:0x08 + \\ + \\ pushq %rbx + \\ pushq %rbp + \\ pushq %rdi + \\ pushq %rsi + \\ pushq %r12 + \\ pushq %r13 + \\ pushq %r14 + \\ pushq %r15 + \\ + \\ subq $160, %rsp + \\ movups %xmm6, 0x00(%rsp) + \\ movups %xmm7, 0x10(%rsp) + \\ movups %xmm8, 0x20(%rsp) + \\ movups %xmm9, 0x30(%rsp) + \\ movups %xmm10, 0x40(%rsp) + \\ movups %xmm11, 0x50(%rsp) + \\ movups %xmm12, 0x60(%rsp) + \\ movups %xmm13, 0x70(%rsp) + \\ movups %xmm14, 0x80(%rsp) + \\ movups %xmm15, 0x90(%rsp) + \\ + \\ movq %rsp, (%rcx) + \\ movq (%rdx), %rsp + \\ + \\ movups 0x00(%rsp), %xmm6 + \\ movups 0x10(%rsp), %xmm7 + \\ movups 0x20(%rsp), %xmm8 + \\ movups 0x30(%rsp), %xmm9 + \\ movups 0x40(%rsp), %xmm10 + \\ movups 0x50(%rsp), %xmm11 + \\ movups 0x60(%rsp), %xmm12 + \\ movups 0x70(%rsp), %xmm13 + \\ movups 0x80(%rsp), %xmm14 + \\ movups 0x90(%rsp), %xmm15 + \\ addq $160, %rsp + \\ + \\ popq %r15 + \\ popq %r14 + \\ popq %r13 + \\ popq %r12 + \\ popq %rsi + \\ popq %rdi + \\ popq %rbp + \\ popq %rbx + \\ + \\ popq %gs:0x08 + \\ popq %gs:0x10 + \\ + \\ retq + ); + } +}; + +const Intel_SysV = struct { + pub const word_count = 7; + + pub const entry_offset = word_count - 1; + + comptime { + asm ( + \\.global zefi_stack_swap + \\zefi_stack_swap: + \\ pushq %rbx + \\ pushq %rbp + \\ pushq %r12 + \\ pushq %r13 + \\ pushq %r14 + \\ pushq %r15 + \\ + \\ movq %rsp, (%rdi) + \\ movq (%rsi), %rsp + \\ + \\ popq %r15 + \\ popq %r14 + \\ popq %r13 + \\ popq %r12 + \\ popq %rbp + \\ popq %rbx + \\ + \\ retq + ); + } +}; + +const Arm_64 = struct { + pub const word_count = 20; + + pub const entry_offset = 0; + + comptime { + asm ( + \\.global _zefi_stack_swap + \\_zefi_stack_swap: + \\ stp lr, fp, [sp, #-20*8]! + \\ stp d8, d9, [sp, #2*8] + \\ stp d10, d11, [sp, #4*8] + \\ stp d12, d13, [sp, #6*8] + \\ stp d14, d15, [sp, #8*8] + \\ stp x19, x20, [sp, #10*8] + \\ stp x21, x22, [sp, #12*8] + \\ stp x23, x24, [sp, #14*8] + \\ stp x25, x26, [sp, #16*8] + \\ stp x27, x28, [sp, #18*8] + \\ + \\ mov x9, sp + \\ str x9, [x0] + \\ ldr x9, [x1] + \\ mov sp, x9 + \\ + \\ ldp x27, x28, [sp, #18*8] + \\ ldp x25, x26, [sp, #16*8] + \\ ldp x23, x24, [sp, #14*8] + \\ ldp x21, x22, [sp, #12*8] + \\ ldp x19, x20, [sp, #10*8] + \\ ldp d14, d15, [sp, #8*8] + \\ ldp d12, d13, [sp, #6*8] + \\ ldp d10, d11, [sp, #4*8] + \\ ldp d8, d9, [sp, #2*8] + \\ ldp lr, fp, [sp], #20*8 + \\ + \\ ret + ); + } +};
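For completeness, the zefi fiber primitive can be exercised on its own, without the scheduler. A small sketch; the import path assumes a file sitting next to src/coro, and the 64 KiB stack size is arbitrary:

const std = @import("std");
const Fiber = @import("coro/zefi.zig");

fn worker(value: u32) void {
    std.log.info("step 1: {}", .{value});
    Fiber.yield(); // hand control back to the caller of switchTo()
    std.log.info("step 2: done", .{});
}

pub fn main() !void {
    const allocator = std.heap.page_allocator;
    const stack = try allocator.alignedAlloc(u8, Fiber.stack_alignment, 64 * 1024);
    defer allocator.free(stack);

    const fiber = try Fiber.init(stack, 0, worker, .{@as(u32, 42)});
    fiber.switchTo(); // runs worker until its first yield()
    fiber.switchTo(); // resumes after the yield and runs to completion
}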