From 302943e08c00aa6bd1307133b02304c0f2f0b172 Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Tue, 22 Oct 2024 12:51:27 +0800 Subject: [PATCH] Switch to nonblocking socket, falling back to blocking on write if necessary Avoids a syscall --- src/response.zig | 27 +++------ src/t.zig | 1 + src/worker.zig | 142 +++++++++++++++++++++++++++++++++++------------ 3 files changed, 117 insertions(+), 53 deletions(-) diff --git a/src/response.zig b/src/response.zig index 110778d..9ff9dc3 100644 --- a/src/response.zig +++ b/src/response.zig @@ -113,6 +113,8 @@ pub const Response = struct { const conn = self.conn; const stream = conn.stream; + // caller probably expects this to be in blocking mode + try conn.blockingMode(); const header_buf = try self.prepareHeader(); try stream.writeAll(header_buf); @@ -123,11 +125,11 @@ pub const Response = struct { } pub fn chunk(self: *Response, data: []const u8) !void { - const stream = self.conn.stream; + const conn = self.conn; if (!self.chunked) { self.chunked = true; const header_buf = try self.prepareHeader(); - try stream.writeAll(header_buf); + try conn.writeAll(header_buf); } // enough for a 1TB chunk @@ -143,7 +145,7 @@ pub const Response = struct { .{ .len = len + 2, .base = &buf }, .{ .len = data.len, .base = data.ptr }, }; - try writeAllIOVec(stream.handle, &vec); + try conn.writeAllIOVec(&vec); } pub fn clearWriter(self: *Response) void { @@ -164,12 +166,12 @@ pub const Response = struct { } self.written = true; - const stream = self.conn.stream; + const conn = self.conn; if (self.chunked) { // If the response was chunked, then we've already written the header // the connection is already in blocking mode, but the trailing chunk // hasn't bene written yet. We'll write that now, and that's it. 
- return stream.writeAll("\r\n0\r\n\r\n"); + return conn.writeAll("\r\n0\r\n\r\n"); } const header_buf = try self.prepareHeader(); @@ -181,7 +183,7 @@ pub const Response = struct { .{ .len = header_buf.len, .base = header_buf.ptr }, .{ .len = body.len, .base = body.ptr }, }; - try writeAllIOVec(stream.handle, &vec); + return conn.writeAllIOVec(&vec); } fn prepareHeader(self: *Response) ![]const u8 { @@ -405,19 +407,6 @@ pub const Response = struct { }; }; -fn writeAllIOVec(socket: std.posix.socket_t, vec: []std.posix.iovec_const) !void { - var i: usize = 0; - while (true) { - var n = try std.posix.writev(socket, vec[i..]); - while (n >= vec[i].len) { - n -= vec[i].len; - i += 1; - if (i >= vec.len) return; - } - vec[i].base += n; - vec[i].len -= n; - } -} // All the upfront memory allocation that we can do. Gets re-used from request // to request. diff --git a/src/t.zig b/src/t.zig index 81d6500..7b7bb9f 100644 --- a/src/t.zig +++ b/src/t.zig @@ -107,6 +107,7 @@ pub const Context = struct { .timeout = 0, .request_count = 0, .close = false, + .socket_flags = 0, .ws_worker = undefined, .conn_arena = ctx_arena, .req_arena = std.heap.ArenaAllocator.init(aa), diff --git a/src/worker.zig b/src/worker.zig index 4009d76..82be2ae 100644 --- a/src/worker.zig +++ b/src/worker.zig @@ -639,7 +639,7 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type { var address: net.Address = undefined; var address_len: posix.socklen_t = @sizeOf(net.Address); - const socket = posix.accept(listener, &address.any, &address_len, posix.SOCK.CLOEXEC) catch |err| { + const socket = posix.accept(listener, &address.any, &address_len, posix.SOCK.CLOEXEC | posix.SOCK.NONBLOCK) catch |err| { // On BSD, REUSEPORT_LB means that only 1 worker should get notified // of a connetion. On Linux, however, we only have REUSEPORT, which will // notify all workers. However, we monitor the listener using EPOLLEXCLUSIVE. 
@@ -653,17 +653,10 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type { errdefer posix.close(socket); metrics.connection(); - { - // socket is _probably_ in NONBLOCKING mode (it inherits - // the flag from the listening socket). - const flags = try posix.fcntl(socket, posix.F.GETFL, 0); - const nonblocking = @as(u32, @bitCast(posix.O{ .NONBLOCK = true })); - if (flags & nonblocking == nonblocking) { - // Yup, it's in nonblocking mode. Disable that flag to - // put it in blocking mode. - _ = try posix.fcntl(socket, posix.F.SETFL, flags & ~nonblocking); - } - } + const socket_flags = try posix.fcntl(socket, posix.F.GETFL, 0); + const nonblocking = @as(u32, @bitCast(posix.O{ .NONBLOCK = true })); + std.debug.assert(socket_flags & nonblocking == nonblocking); + const conn = try self.conn_mem_pool.create(); errdefer self.conn_mem_pool.destroy(conn); @@ -671,6 +664,7 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type { http_conn.request_count = 1; http_conn._state = .request; http_conn.address = address; + http_conn.socket_flags = socket_flags; http_conn.stream = .{ .handle = socket }; http_conn.timeout = now + self.timeout_request; @@ -823,27 +817,24 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type { // Enforces timeouts, and returns when the next timeout should be checked. fn prepareToWait(self: *Self, now: u32) struct {bool, ?i32} { - const request_count, const request_timeout = self.enforceTimeout(&self.request_list, now); + const request_timed_out, const request_count, const request_timeout = collectTimedOut(&self.request_list, now); - // I'm pretty sure this is safe, and pretty neat. The only concurrent - // action that can happen on the keepalive_list is appending a node - // (at the tail). That means we can snapshot the keepalive list under - // a short-lived lock, and then safely iterate the snapshot. 
- const keepalive_snapshot: List(Conn(WSH)) = blk: { + const keepalive_timed_out, const keepalive_count, const keepalive_timeout = blk: { const list = &self.keepalive_list; list.mut.lock(); defer list.mut.unlock(); - break :blk .{.head = list.inner.head, .tail = list.inner.tail}; + break :blk collectTimedOut(&list.inner, now); }; - const keepalive_count, const keepalive_timeout = self.enforceTimeout(&keepalive_snapshot, now); var closed = false; if (request_count > 0) { closed = true; + self.closeList(request_timed_out); metrics.timeoutRequest(request_count); } if (keepalive_count > 0) { closed = true; + self.closeList(keepalive_timed_out); metrics.timeoutKeepalive(keepalive_count); } @@ -864,19 +855,30 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type { // lists are ordered from soonest to timeout to last, as soon as we find // a connection that isn't timed out, we can break; // This returns the next timeout. - fn enforceTimeout(self: *Self, list: *const List(Conn(WSH)), now: u32) struct { usize, ?u32 } { + fn collectTimedOut(list: *List(Conn(WSH)), now: u32) struct { List(Conn(WSH)), usize, ?u32 } { var conn = list.head; var count: usize = 0; + var timed_out: List(Conn(WSH)) = .{}; + while (conn) |c| { const timeout = c.protocol.http.timeout; if (timeout > now) { - return .{ count, timeout }; + return .{ timed_out, count, timeout }; } count += 1; + conn = c.next; + list.remove(c); + timed_out.insert(c); + } + return .{ timed_out, count, null }; + } + + fn closeList(self: *Self, list: List(Conn(WSH))) void { + var conn = list.head; + while (conn) |c| { conn = c.next; self.closeConn(c); } - return .{ count, null }; } fn shutdownList(self: *Self, list: *List(Conn(WSH))) void { @@ -1491,6 +1493,7 @@ pub const HTTPConn = struct { stream: net.Stream, address: net.Address, + socket_flags: usize, // Data needed to parse a request. This contains pre-allocated memory, e.g. // as a read buffer and to store parsed headers. 
It also contains the state @@ -1540,6 +1543,7 @@ pub const HTTPConn = struct { .stream = undefined, .address = undefined, .request_count = 0, + .socket_flags = 0, .ws_worker = ws_worker, .req_state = req_state, .res_state = res_state, @@ -1572,11 +1576,88 @@ pub const HTTPConn = struct { // getting put back into the pool pub fn reset(self: *HTTPConn) void { self.close = false; - self._state = .request; self.handover = .unknown; self.stream = undefined; self.address = undefined; - self.request_count = 0; + } + + pub fn writeAll(self: *const HTTPConn, data: []const u8) !void { + const socket = self.stream.handle; + + var i: usize = 0; + var blocking = false; + + while (i < data.len) { + const n = posix.write(socket, data[i..]) catch |err| switch (err) { + error.WouldBlock => { + try self.blockingMode(); + blocking = true; + continue; + }, + else => return err, + }; + + // shouldn't be possible on a correct posix implementation + // but let's assert to make sure + std.debug.assert(n != 0); + i += n; + } + + // if write fails, and we're in blocking, it doesn't really matter + // we're going to be closing connection anyways + if (blocking) { + try self.nonblockingMode(); + } + } + + pub fn writeAllIOVec(self: *const HTTPConn, vec: []posix.iovec_const) !void { + const socket = self.stream.handle; + + var i: usize = 0; + var blocking = false; + + while (true) { + var n = posix.writev(socket, vec[i..]) catch |err| switch (err) { + error.WouldBlock => { + try self.blockingMode(); + blocking = true; + continue; + }, + else => return err, + }; + + while (n >= vec[i].len) { + n -= vec[i].len; + i += 1; + if (i >= vec.len) { + if (blocking) { + try self.nonblockingMode(); + } + return; + } + } + vec[i].base += n; + vec[i].len -= n; + } + + } + + pub fn blockingMode(self: *const HTTPConn) !void { + if (comptime httpz.blockingMode() == true) { + // When httpz is in blocking mode, then we always keep the socket in + // blocking mode + return; + } + _ = try 
posix.fcntl(self.stream.handle, posix.F.SETFL, self.socket_flags & ~@as(u32, @bitCast(posix.O{ .NONBLOCK = true }))); + } + + pub fn nonblockingMode(self: *const HTTPConn) !void { + if (comptime httpz.blockingMode() == true) { + // When httpz is in blocking mode, then we always keep the socket in + // blocking mode + return; + } + _ = try posix.fcntl(self.stream.handle, posix.F.SETFL, self.socket_flags); } }; @@ -1635,16 +1716,9 @@ fn requestError(conn: *HTTPConn, err: anyerror) !void { fn writeError(socket: posix.socket_t, comptime status: u16, comptime msg: []const u8) !void { const response = std.fmt.comptimePrint("HTTP/1.1 {d} \r\nConnection: Close\r\nContent-Length: {d}\r\n\r\n{s}", .{ status, msg.len, msg }); - // Zig doesn't have the BSD/Darwin values for this. - const DONTWAIT = if (posix.MSG != void and @hasDecl(posix.MSG, "DONTWAIT")) posix.MSG.DONTWAIT else 0x00080; var i: usize = 0; - while (i < response.len) { - const n = try posix.sendto(socket, response[i..], DONTWAIT, null, 0); - if (n == 0) { - return error.Closed; - } - - i += n; + while (i < response.len) { // bound is the full response (status line + headers + msg), not just msg + i += try posix.write(socket, response[i..]); } }