diff --git a/apps/deployment/src/db/schema.ts b/apps/deployment/src/db/schema.ts index e7a3e4a..76481d5 100644 --- a/apps/deployment/src/db/schema.ts +++ b/apps/deployment/src/db/schema.ts @@ -10,7 +10,7 @@ import { } from "drizzle-orm/sqlite-core"; import { v7 } from "uuid"; -import type { ChannelMessageSchema } from "@sinkr/validators"; +import type { ChannelMessagesSendRequestSchema } from "@sinkr/validators"; export const apps = sqliteTable( "app", @@ -161,7 +161,7 @@ export const storedChannelMessages = sqliteTable( .notNull() .default(sql`(CURRENT_TIMESTAMP)`), data: blob({ mode: "json" }) - .$type>() + .$type["request"]>() .notNull(), }, (storedChannelMessage) => [ diff --git a/apps/deployment/src/hooks.ts b/apps/deployment/src/hooks.ts index 89fe2ef..52f3155 100644 --- a/apps/deployment/src/hooks.ts +++ b/apps/deployment/src/hooks.ts @@ -4,7 +4,7 @@ import { v7 } from "uuid"; import { ClientRequestStoredMessagesSchema, - ServerEndpointSchema, + ServerRequestSchema, } from "@sinkr/validators"; import { peers } from "./db/schema"; @@ -73,7 +73,7 @@ export const hooks = { data: unknown; id: string; }>(); - const parsed = ServerEndpointSchema.safeParse(body.data); + const parsed = ServerRequestSchema.safeParse(body.data); if (!parsed.success) { peer.send({ status: 400, @@ -84,9 +84,8 @@ export const hooks = { } const res = await handleSource(body.id, parsed.data, peerInfo.appId); peer.send({ - status: res.status, id: body.id, - data: res.data, + ...res, }); }, async close(peer) { diff --git a/apps/deployment/src/index.ts b/apps/deployment/src/index.ts index e374ec4..9837512 100644 --- a/apps/deployment/src/index.ts +++ b/apps/deployment/src/index.ts @@ -1,8 +1,12 @@ import { AsyncLocalStorage } from "async_hooks"; -import type { z } from "zod"; import { DurableObject } from "cloudflare:workers"; -import { ServerEndpointSchema } from "@sinkr/validators"; +import type { + RouteRequestSchema, + RouteResponseSchema, + ServerRoute, +} from "@sinkr/validators"; +import { ServerRequestSchema } from "@sinkr/validators"; import { getCoordinatorInstance, getPeers, handleSource, ws } from "./server"; import { getDB, init } from "./utils"; @@ -45,18 +49,17 @@ export class ObjectCoordinator extends DurableObject { return obj; } - async distribute({ + async distribute({ id, appId, data, }: { id: string; appId: string; - data: z.infer; - }) { + data: RouteRequestSchema; + }): Promise<({ id: string } & RouteResponseSchema["response"])[]> { const cursor = this.sql.exec<{ id: string }>("SELECT id FROM handler;"); - const promises: Promise<{ status: number; data?: unknown; id: string }>[] = - []; + const promises = []; for (const { id: handlerId } of cursor) { const handler = this.getBinding(handlerId); promises.push(handler.process({ id, appId, data })); @@ -143,9 +146,15 @@ export class ObjectCoordinator extends DurableObject { id: string; }; const { data: body, id } = bodyBuf; - const parsed = ServerEndpointSchema.safeParse(body); - if (!parsed.success) { - return new Response("Invalid request", { status: 400 }); + const parsed = ServerRequestSchema.safeParse(body); + if (!parsed.success || "response" in parsed.data) { + return new Response( + JSON.stringify({ + success: false, + error: "Invalid request", + }), + { status: 400 }, + ); } const data = parsed.data; const info = await handleSource(id, data, appId); @@ -155,7 +164,7 @@ export class ObjectCoordinator extends DurableObject { ...info, }), { - status: info.status, + status: info.success ? 200 : 400, }, ); } @@ -204,16 +213,16 @@ export class SocketHandler extends DurableObject { await this.coordinator.updateConnections(this.ctx.id.toString(), conns); } - async process({ + async process({ id, data, appId, }: { id: string; appId: string; - data: z.infer; - }) { - const info = await handleSource(id, data, appId); + data: RouteRequestSchema; + }): Promise["response"] & { id: string }> { + const info = await handleSource(id, data, appId); return { id, ...info, @@ -260,14 +269,20 @@ export class SocketHandler extends DurableObject { id: string; }; const { data: body, id } = bodyBuf; - const parsed = ServerEndpointSchema.safeParse(body); + const parsed = ServerRequestSchema.safeParse(body); if (!parsed.success) { - return new Response("Invalid request", { status: 400 }); + return new Response( + JSON.stringify({ + success: false, + error: "Invalid request", + }), + { status: 400 }, + ); } const data = parsed.data; const info = await this.process({ id, data, appId }); return new Response(JSON.stringify(info), { - status: info.status, + status: info.success ? 200 : 400, }); } diff --git a/apps/deployment/src/server.ts b/apps/deployment/src/server.ts index 4d4b070..e396e7e 100644 --- a/apps/deployment/src/server.ts +++ b/apps/deployment/src/server.ts @@ -6,7 +6,9 @@ import { and, eq, inArray, not } from "drizzle-orm"; import type { ClientReceiveSchema, - ServerEndpointSchema, + RouteRequestSchema, + RouteResponseSchema, + ServerRoute, } from "@sinkr/validators"; import { getCoordinator } from "."; @@ -19,8 +21,6 @@ import { import { hooks } from "./hooks"; import { getDB } from "./utils"; -type ServerMessage = z.infer; - type ClientReception = z.infer; export function getCoordinatorInstance(env: Env) { @@ -58,21 +58,20 @@ export function sendToPeer(peer: Peer, message: ClientReception) { peer.send(message); } -export async function handleSource( +export async function handleSource( id: string, - data: ServerMessage, + input: RouteRequestSchema, appId: string, -): Promise<{ - status: number; - data?: unknown; -}> { +): Promise["response"]> { const db = getDB(); - switch (data.route) { - case "authenticate": { + const { route, request: data } = input; + switch (route) { + case "user.authenticate": { const coordInst = getCoordinator(); if (!coordInst) { return { - status: 401, + success: false, + error: "Invalid connection", }; } const peer = await db.query.peers.findFirst({ @@ -81,7 +80,8 @@ export async function handleSource( }); if (!peer) { return { - status: 404, + success: false, + error: "Peer not found", }; } await db @@ -89,14 +89,15 @@ export async function handleSource( .set({ userInfo: data.userInfo, authenticatedUserId: data.id }) .where(eq(peers.id, peer.id)); return { - status: 200, + success: true, }; } - case "createChannel": { + case "channel.create": { const coordInst = getCoordinator(); if (!coordInst) { return { - status: 401, + success: false, + error: "Invalid connection", }; } const existing = await db.query.channels.findFirst({ @@ -114,12 +115,13 @@ export async function handleSource( .returning(); if (!updated) { return { - status: 500, + success: false, + error: "Unknown error", }; } return { - status: 200, - data: updated.id, + success: true, + channelId: updated.id, }; } else { const [inserted] = await db @@ -133,20 +135,22 @@ export async function handleSource( .returning(); if (!inserted) { return { - status: 500, + success: false, + error: "Unknown error", }; } return { - status: 200, - data: inserted.id, + success: true, + channelId: inserted.id, }; } } - case "deleteChannel": { + case "channel.delete": { const coordInst = getCoordinator(); if (!coordInst) { return { - status: 401, + success: false, + error: "Invalid connection", }; } const channel = await db.query.channels.findFirst({ @@ -155,21 +159,23 @@ export async function handleSource( }); if (!channel) { return { - status: 404, + success: false, + error: "Channel not found", }; } await db .delete(channels) .where(and(eq(channels.id, data.channelId), eq(channels.appId, appId))); return { - status: 200, + success: true, }; } - case "deleteMessages": { + case "channel.messages.delete": { const coordInst = getCoordinator(); if (!coordInst) { return { - status: 401, + success: false, + error: "Invalid connection", }; } if (data.messageIds?.length) { @@ -193,17 +199,28 @@ export async function handleSource( ); } return { - status: 200, + success: true, }; } - case "broadcast": { + case "global.messages.send": { const coordInst = getCoordinator(); if (coordInst) { - const res = await coordInst.distribute({ id, appId, data }); - const maxStatus = res.reduce((a, v) => Math.max(a, v.status), 0); - return { - status: maxStatus, - }; + const res = await coordInst.distribute({ + id, + appId, + data: input, + }); + const success = res.every((r) => r.success); + if (success) { + return { + success, + }; + } else { + return { + success, + error: "Invalid request", + }; + } } const peers = getPeers(); const dbPeers = await db.query.peers.findMany({ @@ -227,16 +244,20 @@ export async function handleSource( } }); return { - status: 200, + success: true, }; } - case "channel": { + case "channel.messages.send": { const coordInst = getCoordinator(); if (coordInst) { - const res = await coordInst.distribute({ id, appId, data }); - const maxStatus = res.reduce((a, v) => Math.max(a, v.status), 0); + const res = await coordInst.distribute({ id, appId, data: input }); + const success = res.every((r) => r.success); + if (success) { + return { success }; + } return { - status: maxStatus, + success, + error: "Invalid request", }; } const ch = await db.query.channels.findFirst({ @@ -245,7 +266,8 @@ export async function handleSource( }); if (!ch) { return { - status: 404, + success: false, + error: "Channel not found", }; } const subscriptions = await db.query.peerChannelSubscriptions.findMany({ @@ -279,16 +301,20 @@ export async function handleSource( } }); return { - status: 200, + success: true, }; } - case "direct": { + case "user.messages.send": { const coordInst = getCoordinator(); if (coordInst) { - const res = await coordInst.distribute({ id, appId, data }); - const maxStatus = res.reduce((a, v) => Math.max(a, v.status), 0); + const res = await coordInst.distribute({ id, appId, data: input }); + const success = res.some((r) => r.success); + if (success) { + return { success }; + } return { - status: maxStatus, + success, + error: "Peer not found", }; } const dbPeer = await db.query.peers.findFirst({ @@ -303,13 +329,15 @@ export async function handleSource( }); if (!dbPeer) { return { - status: 404, + success: false, + error: "Peer not found", }; } const peer = getPeerMap().get(dbPeer.id); if (!peer) { return { - status: 404, + success: false, + error: "Peer not found", }; } sendToPeer(peer, { @@ -324,10 +352,10 @@ export async function handleSource( }, }); return { - status: 200, + success: true, }; } - case "subscribe": { + case "channel.subscribers.add": { const coordInst = getCoordinator(); const ch = await db.query.channels.findFirst({ where: (c, ops) => @@ -344,7 +372,8 @@ export async function handleSource( }); if (!ch) { return { - status: 404, + success: false, + error: "Channel not found", }; } const dbPeer = await db.query.peers.findFirst({ @@ -359,13 +388,15 @@ export async function handleSource( }); if (!dbPeer) { return { - status: 404, + success: false, + error: "Peer not found", }; } if (ch.auth === "public") { if (!dbPeer.authenticatedUserId) { return { - status: 401, + success: false, + error: "Peer not authenticated", }; } } @@ -390,10 +421,14 @@ export async function handleSource( channelId: ch.id, }); - const res = await coordInst.distribute({ id, appId, data }); - const maxStatus = res.reduce((a, v) => Math.max(a, v.status), 0); + const res = await coordInst.distribute({ id, appId, data: input }); + const success = res.every((r) => r.success); + if (success) { + return { success }; + } return { - status: maxStatus, + success, + error: "Invalid request", }; } const peerMap = getPeerMap(); @@ -436,10 +471,10 @@ export async function handleSource( } }); return { - status: 200, + success: true, }; } - case "unsubscribe": { + case "channel.subscribers.remove": { const coordInst = getCoordinator(); const ch = await db.query.channels.findFirst({ where: (c, ops) => @@ -447,7 +482,8 @@ export async function handleSource( }); if (!ch) { return { - status: 404, + success: false, + error: "Channel not found", }; } const dbPeer = await db.query.peers.findFirst({ @@ -462,7 +498,8 @@ export async function handleSource( }); if (!dbPeer) { return { - status: 404, + success: false, + error: "Peer not found", }; } if (coordInst) { @@ -476,16 +513,21 @@ export async function handleSource( }); if (!isInChannel) { return { - status: 404, + success: false, + error: "Peer is not subscribed to channel", }; } await db .delete(peerChannelSubscriptions) .where(eq(peerChannelSubscriptions.id, isInChannel.id)); - const res = await coordInst.distribute({ id, appId, data }); - const maxStatus = res.reduce((a, v) => Math.max(a, v.status), 0); + const res = await coordInst.distribute({ id, appId, data: input }); + const success = res.every((r) => r.success); + if (success) { + return { success }; + } return { - status: maxStatus, + success, + error: "Invalid request", }; } const remainingSubs = await db @@ -531,7 +573,7 @@ export async function handleSource( } }); return { - status: 200, + success: true, }; } } diff --git a/packages/validators/jsr.json b/packages/validators/jsr.json index b351b8c..780db69 100644 --- a/packages/validators/jsr.json +++ b/packages/validators/jsr.json @@ -1,6 +1,6 @@ { "name": "@sinkr/validators", - "version": "0.5.1", + "version": "0.6.0", "license": "MIT", "exports": "./mod.ts" } \ No newline at end of file diff --git a/packages/validators/package.json b/packages/validators/package.json index 0f80e31..5f0523f 100644 --- a/packages/validators/package.json +++ b/packages/validators/package.json @@ -1,6 +1,6 @@ { "name": "@sinkr/validators", - "version": "0.5.1", + "version": "0.6.0", "type": "module", "main": "src/index.ts", "exports": { diff --git a/packages/validators/src/server.ts b/packages/validators/src/server.ts index 6931ca4..fa75c5c 100644 --- a/packages/validators/src/server.ts +++ b/packages/validators/src/server.ts @@ -2,71 +2,252 @@ import { z } from "zod"; import { MessageTypeSchema } from "./message"; -export const AuthenticateRouteSchema = z.object({ - route: z.literal("authenticate"), - peerId: z.string(), - id: z.string(), - userInfo: z.unknown(), -}); +export const STANDARD_ERRORS = [ + "Invalid connection", + "Invalid request", + "Unknown error", +] as const; -export const CreateChannelRouteSchema = z.object({ - route: z.literal("createChannel"), - name: z.string(), - authMode: z.enum(["public", "private", "presence"]), - storeMessages: z.boolean().default(false), -}); +function makeSchema< + TRoute extends string, + TReq extends z.ZodTypeAny, + TRes extends z.ZodTypeAny, +>( + route: TRoute, + { + request, + response, + }: { + request: TReq; + response: TRes; + }, +) { + const schema = { + request: z.object({ + route: z.literal(route), + request, + }), + response: z.object({ + route: z.literal(route), + response, + }), + } as const; + return { + [route]: schema, + } as Record; +} -export const DeleteChannelRouteSchema = z.object({ - route: z.literal("deleteChannel"), - channelId: z.string(), -}); +export const ALL_ROUTES = { + ...makeSchema("user.authenticate", { + request: z.object({ + peerId: z.string(), + id: z.string(), + userInfo: z.unknown(), + }), + response: z.discriminatedUnion("success", [ + z.object({ + success: z.literal(true), + }), + z.object({ + success: z.literal(false), + error: z.enum([...STANDARD_ERRORS, "Peer not found"]), + }), + ]), + }), + ...makeSchema("channel.create", { + request: z.object({ + name: z.string(), + authMode: z.enum(["public", "private", "presence"]), + storeMessages: z.boolean().default(false), + }), + response: z.discriminatedUnion("success", [ + z.object({ + success: z.literal(true), + channelId: z.string(), + }), + z.object({ + success: z.literal(false), + error: z.enum(STANDARD_ERRORS), + }), + ]), + }), + ...makeSchema("channel.delete", { + request: z.object({ + channelId: z.string(), + }), + response: z.discriminatedUnion("success", [ + z.object({ + success: z.literal(true), + }), + z.object({ + success: z.literal(false), + error: z.enum([...STANDARD_ERRORS, "Channel not found"]), + }), + ]), + }), + ...makeSchema("channel.messages.delete", { + request: z.object({ + channelId: z.string(), + messageIds: z.array(z.string()).nullish(), + }), + response: z.discriminatedUnion("success", [ + z.object({ + success: z.literal(true), + }), + z.object({ + success: z.literal(false), + error: z.enum(STANDARD_ERRORS), + }), + ]), + }), + ...makeSchema("channel.subscribers.add", { + request: z.object({ subscriberId: z.string(), channelId: z.string() }), + response: z.discriminatedUnion("success", [ + z.object({ + success: z.literal(true), + }), + z.object({ + success: z.literal(false), + error: z.enum([ + ...STANDARD_ERRORS, + "Channel not found", + "Peer not found", + "Peer not authenticated", + ]), + }), + ]), + }), + ...makeSchema("channel.subscribers.remove", { + request: z.object({ subscriberId: z.string(), channelId: z.string() }), + response: z.discriminatedUnion("success", [ + z.object({ + success: z.literal(true), + }), + z.object({ + success: z.literal(false), + error: z.enum([ + ...STANDARD_ERRORS, + "Channel not found", + "Peer not found", + "Peer is not subscribed to channel", + ]), + }), + ]), + }), + ...makeSchema("channel.messages.send", { + request: z.object({ + channelId: z.string(), + event: z.string(), + message: MessageTypeSchema, + }), + response: z.discriminatedUnion("success", [ + z.object({ + success: z.literal(true), + }), + z.object({ + success: z.literal(false), + error: z.enum([...STANDARD_ERRORS, "Channel not found"]), + }), + ]), + }), + ...makeSchema("user.messages.send", { + request: z.object({ + recipientId: z.string(), + event: z.string(), + message: MessageTypeSchema, + }), + response: z.discriminatedUnion("success", [ + z.object({ + success: z.literal(true), + }), + z.object({ + success: z.literal(false), + error: z.enum([...STANDARD_ERRORS, "Recipient not found"]), + }), + ]), + }), + ...makeSchema("global.messages.send", { + request: z.object({ + event: z.string(), + message: MessageTypeSchema, + }), + response: z.discriminatedUnion("success", [ + z.object({ + success: z.literal(true), + }), + z.object({ + success: z.literal(false), + error: z.enum(STANDARD_ERRORS), + }), + ]), + }), +} as const; -export const SubscribeRouteSchema = z.object({ - route: z.literal("subscribe"), - subscriberId: z.string(), - channelId: z.string(), -}); +export type ServerRoute = keyof typeof ALL_ROUTES; -export const UnsubscribeRouteSchema = z.object({ - route: z.literal("unsubscribe"), - subscriberId: z.string(), - channelId: z.string(), -}); +export type RouteRequestSchema = z.infer< + (typeof ALL_ROUTES)[TRoute]["request"] +>; -export const DeleteStoredMessagesSchema = z.object({ - route: z.literal("deleteMessages"), - channelId: z.string(), - messageIds: z.array(z.string()).optional(), -}); +export type RouteResponseSchema = z.infer< + (typeof ALL_ROUTES)[TRoute]["response"] +>; -export const ChannelMessageSchema = z.object({ - route: z.literal("channel"), - channelId: z.string(), - event: z.string(), - message: MessageTypeSchema, -}); +export const UserAuthenticateRequestSchema = + ALL_ROUTES["user.authenticate"].request; +export const UserAuthenticateResponseSchema = + ALL_ROUTES["user.authenticate"].response; +export const ChannelCreateRequestSchema = ALL_ROUTES["channel.create"].request; +export const ChannelCreateResponseSchema = + ALL_ROUTES["channel.create"].response; +export const ChannelDeleteRequestSchema = ALL_ROUTES["channel.delete"].request; +export const ChannelDeleteResponseSchema = + ALL_ROUTES["channel.delete"].response; +export const ChannelMessagesDeleteRequestSchema = + ALL_ROUTES["channel.messages.delete"].request; +export const ChannelMessagesDeleteResponseSchema = + ALL_ROUTES["channel.messages.delete"].response; +export const ChannelSubscribersAddRequestSchema = + ALL_ROUTES["channel.subscribers.add"].request; +export const ChannelSubscribersAddResponseSchema = + ALL_ROUTES["channel.subscribers.add"].response; +export const ChannelSubscribersRemoveRequestSchema = + ALL_ROUTES["channel.subscribers.remove"].request; +export const ChannelSubscribersRemoveResponseSchema = + ALL_ROUTES["channel.subscribers.remove"].response; +export const ChannelMessagesSendRequestSchema = + ALL_ROUTES["channel.messages.send"].request; +export const ChannelMessagesSendResponseSchema = + ALL_ROUTES["channel.messages.send"].response; +export const UserMessagesSendRequestSchema = + ALL_ROUTES["user.messages.send"].request; +export const UserMessagesSendResponseSchema = + ALL_ROUTES["user.messages.send"].response; +export const GlobalMessagesSendRequestSchema = + ALL_ROUTES["global.messages.send"].request; +export const GlobalMessagesSendResponseSchema = + ALL_ROUTES["global.messages.send"].response; -export const DirectMessageSchema = z.object({ - route: z.literal("direct"), - recipientId: z.string(), - event: z.string(), - message: MessageTypeSchema, -}); - -export const BroadcastMessageSchema = z.object({ - route: z.literal("broadcast"), - event: z.string(), - message: MessageTypeSchema, -}); +export const ServerRequestSchema = z.discriminatedUnion("route", [ + UserAuthenticateRequestSchema, + ChannelCreateRequestSchema, + ChannelDeleteRequestSchema, + ChannelMessagesDeleteRequestSchema, + ChannelSubscribersAddRequestSchema, + ChannelSubscribersRemoveRequestSchema, + ChannelMessagesSendRequestSchema, + UserMessagesSendRequestSchema, + GlobalMessagesSendRequestSchema, +]); -export const ServerEndpointSchema = z.discriminatedUnion("route", [ - AuthenticateRouteSchema, - SubscribeRouteSchema, - UnsubscribeRouteSchema, - ChannelMessageSchema, - DirectMessageSchema, - BroadcastMessageSchema, - DeleteStoredMessagesSchema, - CreateChannelRouteSchema, - DeleteChannelRouteSchema, +export const ServerResponseSchema = z.discriminatedUnion("route", [ + UserAuthenticateResponseSchema, + ChannelCreateResponseSchema, + ChannelDeleteResponseSchema, + ChannelMessagesDeleteResponseSchema, + ChannelSubscribersAddResponseSchema, + ChannelSubscribersRemoveResponseSchema, + ChannelMessagesSendResponseSchema, + UserMessagesSendResponseSchema, + GlobalMessagesSendResponseSchema, ]);