diff --git a/src/rpc/server/basic.zig b/src/rpc/server/basic.zig new file mode 100644 index 000000000..02e85bff5 --- /dev/null +++ b/src/rpc/server/basic.zig @@ -0,0 +1,46 @@ +const std = @import("std"); +const sig = @import("../../sig.zig"); + +const connection = @import("connection.zig"); +const requests = @import("requests.zig"); + +const ServerCtx = sig.rpc.server.Context; + +pub const AcceptAndServeConnectionError = + connection.AcceptHandledError || + connection.SetSocketSync || + std.mem.Allocator.Error || + std.http.Server.ReceiveHeadError || + requests.HandleRequestError; + +pub fn acceptAndServeConnection(server_ctx: *ServerCtx) !void { + const conn = connection.acceptHandled( + server_ctx.tcp, + .blocking, + ) catch |err| switch (err) { + error.WouldBlock => return, + else => |e| return e, + }; + defer conn.stream.close(); + + if (!connection.have_accept4) { + // make sure the accepted socket is in blocking mode + try connection.setSocketSync(conn.stream.handle, .blocking); + } + + server_ctx.wait_group.start(); + defer server_ctx.wait_group.finish(); + + const buffer = try server_ctx.allocator.alloc(u8, server_ctx.read_buffer_size); + defer server_ctx.allocator.free(buffer); + + var http_server = std.http.Server.init(conn, buffer); + var request = try http_server.receiveHead(); + + try requests.handleRequest( + server_ctx.logger, + &request, + server_ctx.snapshot_dir, + server_ctx.latest_snapshot_gen_info, + ); +} diff --git a/src/rpc/server/lib.zig b/src/rpc/server/lib.zig index 6e83fb0b4..bd14c5b40 100644 --- a/src/rpc/server/lib.zig +++ b/src/rpc/server/lib.zig @@ -7,8 +7,17 @@ const requests = @import("requests.zig"); const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo; +pub const basic = @import("basic.zig"); pub const LinuxIoUring = @import("linux_io_uring.zig").LinuxIoUring; +pub const WorkPool = union(enum) { + basic, + linux_io_uring: switch (LinuxIoUring.can_use) { + .yes, .check => *LinuxIoUring, + .no => noreturn, + }, +}; + pub const Context = struct { allocator: std.mem.Allocator, logger: ScopedLogger, @@ -75,83 +84,36 @@ pub const Context = struct { /// Spawn the serve loop as a separate thread. pub fn serveSpawn( self: *Context, - exit: *std.atomic.Value(bool), /// The pool to dispatch work to. work_pool: WorkPool, + exit: *std.atomic.Value(bool), ) std.Thread.SpawnError!std.Thread { - return try std.Thread.spawn(.{}, serve, .{ self, exit, work_pool }); + return try std.Thread.spawn(.{}, serve, .{ self, work_pool, exit }); } /// Calls `acceptAndServeConnection` in a loop until `exit.load(.acquire)`. pub fn serve( self: *Context, - exit: *std.atomic.Value(bool), /// The pool to dispatch work to. work_pool: WorkPool, - ) WorkPool.AcceptAndServeConnectionError!void { + exit: *std.atomic.Value(bool), + ) AcceptAndServeConnectionError!void { while (!exit.load(.acquire)) { - try work_pool.acceptAndServeConnection(self); + try self.acceptAndServeConnection(work_pool); } } -}; - -pub const WorkPool = union(enum) { - basic, - linux_io_uring: switch (LinuxIoUring.can_use) { - .yes, .check => *LinuxIoUring, - .no => noreturn, - }, - - const BasicAASCError = - connection.AcceptHandledError || - connection.SetSocketSync || - std.mem.Allocator.Error || - std.http.Server.ReceiveHeadError || - requests.HandleRequestError; - const IoUringAASCError = - LinuxIoUring.AcceptAndServeConnectionsError; pub const AcceptAndServeConnectionError = - BasicAASCError || - IoUringAASCError; + basic.AcceptAndServeConnectionError || + LinuxIoUring.AcceptAndServeConnectionsError; pub fn acceptAndServeConnection( - self: WorkPool, - server: *Context, + self: *Context, + work_pool: WorkPool, ) AcceptAndServeConnectionError!void { - switch (self) { - .basic => { - const conn = connection.acceptHandled( - server.tcp, - .blocking, - ) catch |err| switch (err) { - error.WouldBlock => return, - else => |e| return e, - }; - defer conn.stream.close(); - - if (!connection.have_accept4) { - // make sure the accepted socket is in blocking mode - try connection.setSocketSync(conn.stream.handle, .blocking); - } - - server.wait_group.start(); - defer server.wait_group.finish(); - - const buffer = try server.allocator.alloc(u8, server.read_buffer_size); - defer server.allocator.free(buffer); - - var http_server = std.http.Server.init(conn, buffer); - var request = try http_server.receiveHead(); - - try requests.handleRequest( - server.logger, - &request, - server.snapshot_dir, - server.latest_snapshot_gen_info, - ); - }, - .linux_io_uring => |linux| try linux.acceptAndServeConnections(server), + switch (work_pool) { + .basic => try basic.acceptAndServeConnection(self), + .linux_io_uring => |linux| try linux.acceptAndServeConnections(self), } } }; @@ -238,7 +200,7 @@ test Context { defer rpc_server_ctx.joinDeinit(); var exit = std.atomic.Value(bool).init(false); - const serve_thread = try rpc_server_ctx.serveSpawn(&exit, work_pool); + const serve_thread = try rpc_server_ctx.serveSpawn(work_pool, &exit); defer serve_thread.join(); defer exit.store(true, .release); diff --git a/src/rpc/server/linux_io_uring.zig b/src/rpc/server/linux_io_uring.zig index 15baef701..9ec43f5af 100644 --- a/src/rpc/server/linux_io_uring.zig +++ b/src/rpc/server/linux_io_uring.zig @@ -5,8 +5,8 @@ const sig = @import("../../sig.zig"); const connection = @import("connection.zig"); const requests = @import("requests.zig"); -const IoUring = std.os.linux.IoUring; const ServerCtx = sig.rpc.server.Context; +const IoUring = std.os.linux.IoUring; pub const LinuxIoUring = struct { io_uring: IoUring,