From 6f1c8f88b3172d1afcd500d6a7579a4cc1ffe677 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 5 Feb 2024 04:21:05 +0000 Subject: [PATCH] security: consider entities before vanities to avoid hijacking --- routes/vanity/assets/resolve.go | 50 ++++++++++++++++----------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/routes/vanity/assets/resolve.go b/routes/vanity/assets/resolve.go index c4cb4d01..dccbb0ac 100644 --- a/routes/vanity/assets/resolve.go +++ b/routes/vanity/assets/resolve.go @@ -37,45 +37,45 @@ func resolveImpl(ctx context.Context, code string, src string) (*types.Vanity, e } func ResolveVanity(ctx context.Context, code string) (*types.Vanity, error) { - var v *types.Vanity - var err error - for _, src := range []string{"code", "target_id"} { - v, err = resolveImpl(ctx, code, src) + // First check bot_id and client_id to avoid vanity stealing + var botId string - if err != nil { - return nil, err - } + err := state.Pool.QueryRow(ctx, "SELECT bot_id FROM bots WHERE client_id = $1", code).Scan(&botId) - if v == nil { - continue - } - - break + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return nil, err } - // If all fails, try checking client_id of bots - if v == nil { - var count int64 + if botId != "" { + return resolveImpl(ctx, botId, "target_id") + } - err = state.Pool.QueryRow(ctx, "SELECT COUNT(*) FROM bots WHERE client_id = $1", code).Scan(&count) + // Then check server id + var serverId string - if err != nil { - return nil, err - } + err = state.Pool.QueryRow(ctx, "SELECT server_id FROM servers WHERE server_id = $1", code).Scan(&serverId) - if count == 0 { - return nil, nil - } + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return nil, err + } - var botId string + if serverId != "" { + return resolveImpl(ctx, serverId, "target_id") + } - err = state.Pool.QueryRow(ctx, "SELECT bot_id FROM bots WHERE client_id = $1", code).Scan(&botId) + var v *types.Vanity + for _, src := range []string{"code", "target_id"} { + v, err = resolveImpl(ctx, code, src) if err != nil { return nil, err } - return resolveImpl(ctx, botId, "target_id") + if v == nil { + continue + } + + break } return v, nil