Skip to content

Commit

Permalink
Switch to nonblocking socket, falling back to blocking on write if ne…
Browse files Browse the repository at this point in the history
…cessary

Avoids a syscall
  • Loading branch information
karlseguin committed Oct 22, 2024
1 parent be9a673 commit 302943e
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 53 deletions.
27 changes: 8 additions & 19 deletions src/response.zig
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ pub const Response = struct {
const conn = self.conn;
const stream = conn.stream;

// caller probably expects this to be in blocking mode
try conn.blockingMode();
const header_buf = try self.prepareHeader();
try stream.writeAll(header_buf);

Expand All @@ -123,11 +125,11 @@ pub const Response = struct {
}

pub fn chunk(self: *Response, data: []const u8) !void {
const stream = self.conn.stream;
const conn = self.conn;
if (!self.chunked) {
self.chunked = true;
const header_buf = try self.prepareHeader();
try stream.writeAll(header_buf);
try conn.writeAll(header_buf);
}

// enough for a 1TB chunk
Expand All @@ -143,7 +145,7 @@ pub const Response = struct {
.{ .len = len + 2, .base = &buf },
.{ .len = data.len, .base = data.ptr },
};
try writeAllIOVec(stream.handle, &vec);
try conn.writeAllIOVec(&vec);
}

pub fn clearWriter(self: *Response) void {
Expand All @@ -164,12 +166,12 @@ pub const Response = struct {
}
self.written = true;

const stream = self.conn.stream;
const conn = self.conn;
if (self.chunked) {
// If the response was chunked, then we've already written the header
// the connection is already in blocking mode, but the trailing chunk
// hasn't bene written yet. We'll write that now, and that's it.
return stream.writeAll("\r\n0\r\n\r\n");
return conn.writeAll("\r\n0\r\n\r\n");
}

const header_buf = try self.prepareHeader();
Expand All @@ -181,7 +183,7 @@ pub const Response = struct {
.{ .len = header_buf.len, .base = header_buf.ptr },
.{ .len = body.len, .base = body.ptr },
};
try writeAllIOVec(stream.handle, &vec);
return conn.writeAllIOVec(&vec);
}

fn prepareHeader(self: *Response) ![]const u8 {
Expand Down Expand Up @@ -405,19 +407,6 @@ pub const Response = struct {
};
};

fn writeAllIOVec(socket: std.posix.socket_t, vec: []std.posix.iovec_const) !void {
var i: usize = 0;
while (true) {
var n = try std.posix.writev(socket, vec[i..]);
while (n >= vec[i].len) {
n -= vec[i].len;
i += 1;
if (i >= vec.len) return;
}
vec[i].base += n;
vec[i].len -= n;
}
}

// All the upfront memory allocation that we can do. Gets re-used from request
// to request.
Expand Down
1 change: 1 addition & 0 deletions src/t.zig
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ pub const Context = struct {
.timeout = 0,
.request_count = 0,
.close = false,
.socket_flags = 0,
.ws_worker = undefined,
.conn_arena = ctx_arena,
.req_arena = std.heap.ArenaAllocator.init(aa),
Expand Down
142 changes: 108 additions & 34 deletions src/worker.zig
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type {
var address: net.Address = undefined;
var address_len: posix.socklen_t = @sizeOf(net.Address);

const socket = posix.accept(listener, &address.any, &address_len, posix.SOCK.CLOEXEC) catch |err| {
const socket = posix.accept(listener, &address.any, &address_len, posix.SOCK.CLOEXEC | posix.SOCK.NONBLOCK) catch |err| {
// On BSD, REUSEPORT_LB means that only 1 worker should get notified
// of a connetion. On Linux, however, we only have REUSEPORT, which will
// notify all workers. However, we monitor the listener using EPOLLEXCLUSIVE.
Expand All @@ -653,24 +653,18 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type {
errdefer posix.close(socket);
metrics.connection();

{
// socket is _probably_ in NONBLOCKING mode (it inherits
// the flag from the listening socket).
const flags = try posix.fcntl(socket, posix.F.GETFL, 0);
const nonblocking = @as(u32, @bitCast(posix.O{ .NONBLOCK = true }));
if (flags & nonblocking == nonblocking) {
// Yup, it's in nonblocking mode. Disable that flag to
// put it in blocking mode.
_ = try posix.fcntl(socket, posix.F.SETFL, flags & ~nonblocking);
}
}
const socket_flags = try posix.fcntl(socket, posix.F.GETFL, 0);
const nonblocking = @as(u32, @bitCast(posix.O{ .NONBLOCK = true }));
std.debug.assert(socket_flags & nonblocking == nonblocking);

const conn = try self.conn_mem_pool.create();
errdefer self.conn_mem_pool.destroy(conn);

const http_conn = try self.http_conn_pool.acquire();
http_conn.request_count = 1;
http_conn._state = .request;
http_conn.address = address;
http_conn.socket_flags = socket_flags;
http_conn.stream = .{ .handle = socket };
http_conn.timeout = now + self.timeout_request;

Expand Down Expand Up @@ -823,27 +817,24 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type {

// Enforces timeouts, and returns when the next timeout should be checked.
fn prepareToWait(self: *Self, now: u32) struct {bool, ?i32} {
const request_count, const request_timeout = self.enforceTimeout(&self.request_list, now);
const request_timed_out, const request_count, const request_timeout = collectTimedOut(&self.request_list, now);

// I'm pretty sure this is safe, and pretty neat. The only concurrent
// action that can happen on the keepalive_list is appending a node
// (at the tail). That means we can snapshot the keepalive list under
// a short-lived lock, and then safely iterate the snapshot.
const keepalive_snapshot: List(Conn(WSH)) = blk: {
const keepalive_timed_out, const keepalive_count, const keepalive_timeout = blk: {
const list = &self.keepalive_list;
list.mut.lock();
defer list.mut.unlock();
break :blk .{.head = list.inner.head, .tail = list.inner.tail};
break :blk collectTimedOut(&list.inner, now);
};
const keepalive_count, const keepalive_timeout = self.enforceTimeout(&keepalive_snapshot, now);

var closed = false;
if (request_count > 0) {
closed = true;
self.closeList(request_timed_out);
metrics.timeoutRequest(request_count);
}
if (keepalive_count > 0) {
closed = true;
self.closeList(keepalive_timed_out);
metrics.timeoutKeepalive(keepalive_count);
}

Expand All @@ -864,19 +855,30 @@ pub fn NonBlocking(comptime S: type, comptime WSH: type) type {
// lists are ordered from soonest to timeout to last, as soon as we find
// a connection that isn't timed out, we can break;
// This returns the next timeout.
fn enforceTimeout(self: *Self, list: *const List(Conn(WSH)), now: u32) struct { usize, ?u32 } {
fn collectTimedOut(list: *List(Conn(WSH)), now: u32) struct { List(Conn(WSH)), usize, ?u32 } {
var conn = list.head;
var count: usize = 0;
var timed_out: List(Conn(WSH)) = .{};

while (conn) |c| {
const timeout = c.protocol.http.timeout;
if (timeout > now) {
return .{ count, timeout };
return .{ timed_out, count, timeout };
}
count += 1;
conn = c.next;
list.remove(c);
timed_out.insert(c);
}
return .{ timed_out, count, null };
}

fn closeList(self: *Self, list: List(Conn(WSH))) void {
var conn = list.head;
while (conn) |c| {
conn = c.next;
self.closeConn(c);
}
return .{ count, null };
}

fn shutdownList(self: *Self, list: *List(Conn(WSH))) void {
Expand Down Expand Up @@ -1491,6 +1493,7 @@ pub const HTTPConn = struct {

stream: net.Stream,
address: net.Address,
socket_flags: usize,

// Data needed to parse a request. This contains pre-allocated memory, e.g.
// as a read buffer and to store parsed headers. It also contains the state
Expand Down Expand Up @@ -1540,6 +1543,7 @@ pub const HTTPConn = struct {
.stream = undefined,
.address = undefined,
.request_count = 0,
.socket_flags = 0,
.ws_worker = ws_worker,
.req_state = req_state,
.res_state = res_state,
Expand Down Expand Up @@ -1572,11 +1576,88 @@ pub const HTTPConn = struct {
// getting put back into the pool
pub fn reset(self: *HTTPConn) void {
self.close = false;
self._state = .request;
self.handover = .unknown;
self.stream = undefined;
self.address = undefined;
self.request_count = 0;
}

pub fn writeAll(self: *const HTTPConn, data: []const u8) !void {
const socket = self.stream.handle;

var i: usize = 0;
var blocking = false;

while (i < data.len) {
const n = posix.write(socket, data[i..]) catch |err| switch (err) {
error.WouldBlock => {
try self.blockingMode();
blocking = true;
continue;
},
else => return err,
};

// shouldn't be posssible on a correct posix implementation
// but let's assert to make sure
std.debug.assert(n != 0);
i += n;
}

// if write fails, and we're in blocking, it doesn't really matter
// we're going to be closing connction anyways
if (blocking) {
try self.nonblockingMode();
}
}

pub fn writeAllIOVec(self: *const HTTPConn, vec: []posix.iovec_const) !void {
const socket = self.stream.handle;

var i: usize = 0;
var blocking = false;

while (true) {
var n = posix.writev(socket, vec[i..]) catch |err| switch (err) {
error.WouldBlock => {
try self.blockingMode();
blocking = true;
continue;
},
else => return err,
};

while (n >= vec[i].len) {
n -= vec[i].len;
i += 1;
if (i >= vec.len) {
if (blocking) {
try self.nonblockingMode();
}
return;
}
}
vec[i].base += n;
vec[i].len -= n;
}

}

pub fn blockingMode(self: *const HTTPConn) !void {
if (comptime httpz.blockingMode() == true) {
// When httpz is in blocking mode, than we always keep the socket in
// blocking mode
return;
}
_ = try posix.fcntl(self.stream.handle, posix.F.SETFL, self.socket_flags & ~@as(u32, @bitCast(posix.O{ .NONBLOCK = true })));
}

pub fn nonblockingMode(self: *const HTTPConn) !void {
if (comptime httpz.blockingMode() == true) {
// When httpz is in blocking mode, than we always keep the socket in
// blocking mode
return;
}
_ = try posix.fcntl(self.stream.handle, posix.F.SETFL, self.socket_flags);
}
};

Expand Down Expand Up @@ -1635,16 +1716,9 @@ fn requestError(conn: *HTTPConn, err: anyerror) !void {
fn writeError(socket: posix.socket_t, comptime status: u16, comptime msg: []const u8) !void {
const response = std.fmt.comptimePrint("HTTP/1.1 {d} \r\nConnection: Close\r\nContent-Length: {d}\r\n\r\n{s}", .{ status, msg.len, msg });

// Zig doesn't have the BSD/Darwin values for this.
const DONTWAIT = if (posix.MSG != void and @hasDecl(posix.MSG, "DONTWAIT")) posix.MSG.DONTWAIT else 0x00080;
var i: usize = 0;
while (i < response.len) {
const n = try posix.sendto(socket, response[i..], DONTWAIT, null, 0);
if (n == 0) {
return error.Closed;
}

i += n;
while (i < msg.len) {
i += try posix.write(socket, response[i..]);
}
}

Expand Down

0 comments on commit 302943e

Please sign in to comment.