Skip to content

Commit

Permalink
first pass at supporting middlewares
Browse files Browse the repository at this point in the history
  • Loading branch information
karlseguin committed Aug 16, 2024
1 parent e1df408 commit 02819fc
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 136 deletions.
8 changes: 0 additions & 8 deletions src/config.zig
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ pub const Config = struct {
request: Request = .{},
response: Response = .{},
timeout: Timeout = .{},
cors: ?CORS = null,
thread_pool: ThreadPool = .{},
websocket: Websocket = .{},

Expand Down Expand Up @@ -52,13 +51,6 @@ pub const Config = struct {
request_count: ?u32 = null,
};

pub const CORS = struct {
origin: []const u8,
headers: ?[]const u8 = null,
methods: ?[]const u8 = null,
max_age: ?[]const u8 = null,
};

pub const Websocket = struct {
max_message_size: ?usize = null,
small_buffer_size: ?usize = null,
Expand Down
211 changes: 144 additions & 67 deletions src/httpz.zig
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub const routing = @import("router.zig");
pub const request = @import("request.zig");
pub const response = @import("response.zig");
pub const key_value = @import("key_value.zig");
pub const middleware = @import("middleware/middleware.zig");

pub const Router = routing.Router;
pub const Request = request.Request;
Expand Down Expand Up @@ -168,6 +169,38 @@ pub fn DispatchableAction(comptime Handler: type, comptime ActionArg: type) type
handler: Handler,
action: ActionArg,
dispatcher: Dispatcher(Handler, ActionArg),
middlewares: []const Middleware(Handler) = &.{},
};
}

pub fn Middleware(comptime H: type) type {
return struct {
ptr: *anyopaque,
executeFn: *const fn (ptr: *anyopaque, req: *Request, res: *Response, executor: *Server(H).Executor) anyerror!void,

const Self = @This();

pub fn init(ptr: anytype) Self {
const T = @TypeOf(ptr);
const ptr_info = @typeInfo(T);

const gen = struct {
pub fn execute(pointer: *anyopaque, req: *Request, res: *Response, executor: *Server(H).Executor) anyerror!void {
const self: T = @ptrCast(@alignCast(pointer));
return ptr_info.Pointer.child.execute(self, req, res, executor);
}
};

return .{
.ptr = ptr,
.executeFn = gen.execute,
};
}

// This is the same as before
pub fn execute(self: Self, req: *Request, res: *Response, executor: *Server(H).Executor) !void {
return self.executeFn(self.ptr, req, res, executor);
}
};
}

Expand All @@ -194,8 +227,8 @@ pub fn Server(comptime H: type) type {

handler: H,
config: Config,
arena: Allocator,
allocator: Allocator,
_cors_origin: ?[]const u8,
_router: Router(H, ActionArg),
_mut: Thread.Mutex,
_cond: Thread.Condition,
Expand All @@ -207,18 +240,27 @@ pub fn Server(comptime H: type) type {
const Self = @This();

pub fn init(allocator: Allocator, config: Config, handler: H) !Self {
var thread_pool = try TP.init(allocator, .{
// Be mindful about where we pass this arena. Most things are able to
// do dynamic allocation, and need to be able to free when they're
// done with their memory. Only use this for stuff that's created on
// startup and won't dynamically need to grow/shrink.
const arena = try allocator.create(std.heap.ArenaAllocator);
errdefer allocator.destroy(arena);
arena.* = std.heap.ArenaAllocator.init(allocator);
errdefer arena.deinit();

const thread_pool = try TP.init(arena.allocator(), .{
.count = config.threadPoolCount(),
.backlog = config.thread_pool.backlog orelse 500,
.buffer_size = config.thread_pool.buffer_size orelse 32_768,
});
errdefer thread_pool.deinit();

const signals = try allocator.alloc(posix.fd_t, config.workerCount());
errdefer allocator.free(signals);
const signals = try arena.allocator().alloc(posix.fd_t, config.workerCount());

const default_dispatcher = if (comptime Handler == void) defaultDispatcher else defaultDispatcherWithHandler;

// do not pass arena.allocator to WorkerState, it needs to be able to
// allocate and free at will.
var websocket_state = try ws.WorkerState.init(allocator, .{
.max_message_size = config.websocket.max_message_size,
.buffers = .{
Expand All @@ -238,25 +280,27 @@ pub fn Server(comptime H: type) type {
errdefer websocket_state.deinit();

return .{
.handler = handler,
.config = config,
.handler = handler,
.allocator = allocator,
.arena = arena.allocator(),
._mut = .{},
._cond = .{},
._signals = signals,
._thread_pool = thread_pool,
._websocket_state = websocket_state,
._router = try Router(H, ActionArg).init(allocator, default_dispatcher, handler),
._cors_origin = if (config.cors) |cors| cors.origin else null,
._router = try Router(H, ActionArg).init(arena.allocator(), default_dispatcher, handler),
._max_request_per_connection = config.timeout.request_count orelse MAX_REQUEST_COUNT,
};
}

pub fn deinit(self: *Self) void {
self.allocator.free(self._signals);
self._router.deinit();
self._thread_pool.deinit();
self._thread_pool.stop();
self._websocket_state.deinit();

const arena: *std.heap.ArenaAllocator = @ptrCast(@alignCast(self.arena.ptr));
arena.deinit();
self.allocator.destroy(arena);
}

pub fn listen(self: *Self) !void {
Expand Down Expand Up @@ -334,8 +378,8 @@ pub fn Server(comptime H: type) type {
const Worker = worker.NonBlocking(*Self, WebsocketHandler);
var signals = self._signals;
const worker_count = signals.len;
const workers = try allocator.alloc(Worker, worker_count);
const threads = try allocator.alloc(Thread, worker_count);
const workers = try self.arena.alloc(Worker, worker_count);
const threads = try self.arena.alloc(Thread, worker_count);

var started: usize = 0;
errdefer for (0..started) |i| {
Expand All @@ -347,8 +391,6 @@ pub fn Server(comptime H: type) type {
for (0..started) |i| {
workers[i].deinit();
}
allocator.free(workers);
allocator.free(threads);
}

for (0..workers.len) |i| {
Expand Down Expand Up @@ -468,42 +510,56 @@ pub fn Server(comptime H: type) type {
}

fn dispatch(self: *const Self, dispatchable_action: ?DispatchableAction(H, ActionArg), req: *Request, res: *Response) !void {
if (self._cors_origin) |origin| {
res.header("Access-Control-Allow-Origin", origin);
}
if (dispatchable_action) |da| {
if (Handler == void) {
return da.dispatcher(da.action, req, res);
const da = dispatchable_action orelse {
if (comptime std.meta.hasFn(Handler, "notFound")) {
return self.handler.notFound(req, res);
}
return da.dispatcher(da.handler, da.action, req, res);
}
res.status = 404;
res.body = "Not Found";
return;
};

if (req.method == .OPTIONS) {
if (self.config.cors) |config| {
if (req.header("sec-fetch-mode")) |mode| {
if (std.mem.eql(u8, mode, "cors")) {
if (config.headers) |headers| {
res.header("Access-Control-Allow-Headers", headers);
}
if (config.methods) |methods| {
res.header("Access-Control-Allow-Methods", methods);
}
if (config.max_age) |max_age| {
res.header("Access-Control-Max-Age", max_age);
}
res.status = 204;
return;
}
var executor = Executor{
.da = da,
.index = 0,
.req = req,
.res = res,
.middlewares = da.middlewares,
};
return executor.next();
}

pub fn middleware(self: *Self, comptime M: type, config: M.Config) !Middleware(H) {
const m = try self.arena.create(M);
errdefer self.arena.destroy(m);
m.* = M.init(config);
return Middleware(H).init(m);
}

pub const Executor = struct {
index: usize,
req: *Request,
res: *Response,
// pull this out of da since we'll access it a lot (not really, but w/e)
middlewares: []const Middleware(H),
da: DispatchableAction(H, ActionArg),

pub fn next(self: *Executor) !void {
const index = self.index;
const middlewares = self.middlewares;

if (index == middlewares.len) {
const da = self.da;
if (comptime H == void) {
return da.dispatcher(da.action, self.req, self.res);
}
return da.dispatcher(da.handler, da.action, self.req, self.res);
}
}

if (comptime std.meta.hasFn(Handler, "notFound")) {
return self.handler.notFound(req, res);
self.index = index + 1;
return middlewares[index].execute(self.req, self.res, self);
}
res.status = 404;
res.body = "Not Found";
}
};
};
}

Expand Down Expand Up @@ -630,18 +686,26 @@ test "tests:beforeAll" {
// this will leak since the server will run until the process exits. If we use
// our testing allocator, it'll report the leak.
const ga = global_test_allocator.allocator();

{
default_server = try Server(void).init(ga, .{
.port = 5992,
.cors = .{ .origin = "httpz.local", .headers = "content-type", .methods = "GET,POST", .max_age = "300" },
}, {});
default_server = try Server(void).init(ga, .{.port = 5992}, {});

var middlewares = try default_server.arena.alloc(Middleware(void), 1);
middlewares[0] = try default_server.middleware(middleware.Cors, .{
.max_age = "300",
.methods = "GET,POST",
.origin = "httpz.local",
.headers = "content-type",
});

var router = default_server.router();
// router.get("/test/ws", testWS);
router.get("/fail", TestDummyHandler.fail);
router.get("/test/json", TestDummyHandler.jsonRes);
router.get("/test/query", TestDummyHandler.reqQuery);
router.get("/test/stream", TestDummyHandler.eventStream);
router.get("/test/chunked", TestDummyHandler.chunked);
router.allC("/test/cors", TestDummyHandler.jsonRes, .{.middlewares = middlewares});
router.allC("/test/dispatcher", TestDummyHandler.dispatchedAction, .{ .dispatcher = TestDummyHandler.routeSpecificDispacthcer });
test_server_threads[0] = try default_server.listenInNewThread();
}
Expand Down Expand Up @@ -798,7 +862,7 @@ test "httpz: no route" {
try stream.writeAll("GET / HTTP/1.1\r\n\r\n");

var buf: [100]u8 = undefined;
try t.expectString("HTTP/1.1 404 \r\nAccess-Control-Allow-Origin: httpz.local\r\nContent-Length: 9\r\n\r\nNot Found", testReadAll(stream, &buf));
try t.expectString("HTTP/1.1 404 \r\nContent-Length: 9\r\n\r\nNot Found", testReadAll(stream, &buf));
}

test "httpz: no route with custom notFound handler" {
Expand All @@ -819,7 +883,7 @@ test "httpz: unhandled exception" {
try stream.writeAll("GET /fail HTTP/1.1\r\n\r\n");

var buf: [150]u8 = undefined;
try t.expectString("HTTP/1.1 500 \r\nAccess-Control-Allow-Origin: httpz.local\r\nContent-Length: 21\r\n\r\nInternal Server Error", testReadAll(stream, &buf));
try t.expectString("HTTP/1.1 500 \r\nContent-Length: 21\r\n\r\nInternal Server Error", testReadAll(stream, &buf));
}

test "httpz: unhandled exception with custom error handler" {
Expand Down Expand Up @@ -867,7 +931,7 @@ test "httpz: json response" {
try stream.writeAll("GET /test/json HTTP/1.1\r\nContent-Length: 0\r\n\r\n");

var buf: [200]u8 = undefined;
try t.expectString("HTTP/1.1 201 \r\nContent-Type: application/json\r\nAccess-Control-Allow-Origin: httpz.local\r\nContent-Length: 26\r\n\r\n{\"over\":9000,\"teg\":\"soup\"}", testReadAll(stream, &buf));
try t.expectString("HTTP/1.1 201 \r\nContent-Type: application/json\r\nContent-Length: 26\r\n\r\n{\"over\":9000,\"teg\":\"soup\"}", testReadAll(stream, &buf));
}

test "httpz: query" {
Expand All @@ -876,7 +940,7 @@ test "httpz: query" {
try stream.writeAll("GET /test/query?fav=keemun%20te%61%21 HTTP/1.1\r\nContent-Length: 0\r\n\r\n");

var buf: [200]u8 = undefined;
try t.expectString("HTTP/1.1 200 \r\nAccess-Control-Allow-Origin: httpz.local\r\nContent-Length: 11\r\n\r\nkeemun tea!", testReadAll(stream, &buf));
try t.expectString("HTTP/1.1 200 \r\nContent-Length: 11\r\n\r\nkeemun tea!", testReadAll(stream, &buf));
}

test "httpz: chunked" {
Expand All @@ -885,7 +949,7 @@ test "httpz: chunked" {
try stream.writeAll("GET /test/chunked HTTP/1.1\r\nContent-Length: 0\r\n\r\n");

var buf: [1000]u8 = undefined;
try t.expectString("HTTP/1.1 200 \r\nAccess-Control-Allow-Origin: httpz.local\r\nOver: 9000!\r\nTransfer-Encoding: chunked\r\n\r\n7\r\nChunk 1\r\n11\r\nand another chunk\r\n0\r\n\r\n", testReadAll(stream, &buf));
try t.expectString("HTTP/1.1 200 \r\nOver: 9000!\r\nTransfer-Encoding: chunked\r\n\r\n7\r\nChunk 1\r\n11\r\nand another chunk\r\n0\r\n\r\n", testReadAll(stream, &buf));
}

test "httpz: route-specific dispatcher" {
Expand All @@ -894,7 +958,7 @@ test "httpz: route-specific dispatcher" {
try stream.writeAll("HEAD /test/dispatcher HTTP/1.1\r\n\r\n");

var buf: [200]u8 = undefined;
try t.expectString("HTTP/1.1 200 \r\nAccess-Control-Allow-Origin: httpz.local\r\ndispatcher: test-dispatcher-1\r\nContent-Length: 6\r\n\r\naction", testReadAll(stream, &buf));
try t.expectString("HTTP/1.1 200 \r\ndispatcher: test-dispatcher-1\r\nContent-Length: 6\r\n\r\naction", testReadAll(stream, &buf));
}

test "httpz: CORS" {
Expand All @@ -905,34 +969,47 @@ test "httpz: CORS" {
try stream.writeAll("GET /echo HTTP/1.1\r\n\r\n");
var res = testReadParsed(stream);
defer res.deinit();
try t.expectEqual(true, res.headers.get("Access-Control-Max-Age") == null);
try t.expectEqual(true, res.headers.get("Access-Control-Allow-Methods") == null);
try t.expectEqual(true, res.headers.get("Access-Control-Allow-Headers") == null);
try t.expectString("httpz.local", res.headers.get("Access-Control-Allow-Origin").?);
try t.expectEqual(null, res.headers.get("Access-Control-Max-Age"));
try t.expectEqual(null, res.headers.get("Access-Control-Allow-Methods"));
try t.expectEqual(null, res.headers.get("Access-Control-Allow-Headers"));
try t.expectEqual(null, res.headers.get("Access-Control-Allow-Origin"));
}

{
// non-cors options
try stream.writeAll("OPTIONS /echo HTTP/1.1\r\nSec-Fetch-Mode: navigate\r\n\r\n");
// cors endpoint but not cors options
try stream.writeAll("OPTIONS /test/cors HTTP/1.1\r\nSec-Fetch-Mode: navigate\r\n\r\n");
var res = testReadParsed(stream);
defer res.deinit();

try t.expectEqual(true, res.headers.get("Access-Control-Max-Age") == null);
try t.expectEqual(true, res.headers.get("Access-Control-Allow-Methods") == null);
try t.expectEqual(true, res.headers.get("Access-Control-Allow-Headers") == null);
try t.expectEqual(null, res.headers.get("Access-Control-Max-Age"));
try t.expectEqual(null, res.headers.get("Access-Control-Allow-Methods"));
try t.expectEqual(null, res.headers.get("Access-Control-Allow-Headers"));
try t.expectString("httpz.local", res.headers.get("Access-Control-Allow-Origin").?);
}

{
// cors request
try stream.writeAll("OPTIONS /no_route HTTP/1.1\r\nSec-Fetch-Mode: cors\r\n\r\n");
try stream.writeAll("OPTIONS /test/cors HTTP/1.1\r\nSec-Fetch-Mode: cors\r\n\r\n");
var res = testReadParsed(stream);
defer res.deinit();

try t.expectString("httpz.local", res.headers.get("Access-Control-Allow-Origin").?);
try t.expectString("300", res.headers.get("Access-Control-Max-Age").?);
try t.expectString("GET,POST", res.headers.get("Access-Control-Allow-Methods").?);
try t.expectString("content-type", res.headers.get("Access-Control-Allow-Headers").?);
try t.expectString("300", res.headers.get("Access-Control-Max-Age").?);
try t.expectString("httpz.local", res.headers.get("Access-Control-Allow-Origin").?);
}

{
// cors request, non-options
try stream.writeAll("GET /test/cors HTTP/1.1\r\nSec-Fetch-Mode: cors\r\n\r\n");
var res = testReadParsed(stream);
defer res.deinit();


try t.expectEqual(null, res.headers.get("Access-Control-Max-Age"));
try t.expectEqual(null, res.headers.get("Access-Control-Allow-Methods"));
try t.expectEqual(null, res.headers.get("Access-Control-Allow-Headers"));
try t.expectString("httpz.local", res.headers.get("Access-Control-Allow-Origin").?);
}
}

Expand Down
Loading

0 comments on commit 02819fc

Please sign in to comment.