From 6023887cfec6efc862d7a2994bdcd13d6c67eb60 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 10 Jul 2024 10:47:13 +0000 Subject: [PATCH] fix --- api/uapi.go | 20 ++++++ .../migrate/migrations/migrations.go | 72 ++++++++++++++++++- routes/apps/endpoints/get_apps_list/route.go | 9 +++ .../endpoints/create_oauth2_login/route.go | 12 ++-- routes/auth/endpoints/create_session/route.go | 1 + .../delete_user_notifications/route.go | 34 +++++++-- .../endpoints/delete_user_reminders/route.go | 37 +++++----- .../endpoints/put_user_reminders/route.go | 25 ++++--- routes/vanity/endpoints/patch_vanity/route.go | 3 +- routes/vanity/router.go | 3 + types/auth.go | 6 +- validators/normalize_target_type.go | 4 +- 12 files changed, 177 insertions(+), 49 deletions(-) diff --git a/api/uapi.go b/api/uapi.go index 85b5c941..ee50f0e9 100644 --- a/api/uapi.go +++ b/api/uapi.go @@ -157,6 +157,10 @@ func Authorize(r uapi.Route, req *http.Request) (uapi.AuthData, uapi.HttpRespons }, false } + if len(permLimits) == 0 { + permLimits = []string{} + } + if authPrefix != "" && authPrefix != targetType { return uapi.AuthData{}, uapi.HttpResponse{ Status: http.StatusUnauthorized, @@ -320,6 +324,8 @@ func Authorize(r uapi.Route, req *http.Request) (uapi.AuthData, uapi.HttpRespons }, false } + state.Logger.Info("AuthData", zap.Any("authData", authData)) + pc, ok := r.ExtData[PERMISSION_CHECK_KEY] if !ok { @@ -332,6 +338,13 @@ func Authorize(r uapi.Route, req *http.Request) (uapi.AuthData, uapi.HttpRespons permCheck, ok := pc.(PermissionCheck) if ok { + if permCheck.NeededPermission == nil { + return uapi.AuthData{}, uapi.HttpResponse{ + Status: http.StatusInternalServerError, + Json: types.ApiError{Message: "Internal error: NeededPermission function is nil"}, + }, false + } + neededPerm, err := permCheck.NeededPermission(r, req, authData) if err != nil { @@ -342,6 +355,13 @@ func Authorize(r uapi.Route, req *http.Request) (uapi.AuthData, uapi.HttpRespons } if neededPerm != nil { + if permCheck.GetTarget == nil { + return uapi.AuthData{}, uapi.HttpResponse{ + Status: http.StatusInternalServerError, + Json: types.ApiError{Message: "Internal error: GetTarget function is nil"}, + }, false + } + targetTypeOfEntity, targetIdOfEntity := permCheck.GetTarget(r, req, authData) if targetTypeOfEntity == "" || targetIdOfEntity == "" { diff --git a/cmd/kitehelper/migrate/migrations/migrations.go b/cmd/kitehelper/migrate/migrations/migrations.go index 56d67538..c54fa481 100644 --- a/cmd/kitehelper/migrate/migrations/migrations.go +++ b/cmd/kitehelper/migrate/migrations/migrations.go @@ -945,7 +945,11 @@ var migs = []migrate.Migration{ return nil } - return errors.New("this can only be run with MIGRATE_BOTS_TO_TEAMS_TOKEN set to a valid token") + if os.Getenv("SKIP_MIGRATE_BOTS_TO_TEAMS") == "true" { + return errors.New("env skip_migrate_bots_to_teams is set, continuing") + } + + panic("this can only be run with MIGRATE_BOTS_TO_TEAMS_TOKEN set to a valid token") }, Function: func(pool *common.SandboxPool) { discordSess, err := common.NewDiscordSession(os.Getenv("MIGRATE_BOTS_TO_TEAMS_TOKEN")) @@ -1054,6 +1058,72 @@ var migs = []migrate.Migration{ err = tx.Commit(ctx) + if err != nil { + panic(err) + } + }, + }, + { + ID: "migrate_bots_to_sessions", + Name: "Migrate bot to sessions", + HasMigrated: func(pool *common.SandboxPool) error { + if os.Getenv("SKIP_MIGRATE_BOTS_TO_SESSIONS") == "true" { + return errors.New("env skip_migrate_bots_to_teams is set, continuing") + } + + return nil + }, + Function: func(pool *common.SandboxPool) { + tx, err := pool.Begin(ctx) + + if err != nil { + panic(err) + } + + defer tx.Rollback(ctx) + + rows, err := pool.Query(context.Background(), "SELECT bot_id, api_token FROM bots") + + if err != nil { + panic(err) + } + + defer rows.Close() + + var botTokenMap = map[string]string{} + for rows.Next() { + var botId string + var apiToken string + + err = rows.Scan(&botId, &apiToken) + + if err != nil { + panic(err) + } + + botTokenMap[botId] = apiToken + } + + rows.Close() + + for botId, apiToken := range botTokenMap { + // Delete all sessions existing for this bot currently + err = tx.Exec(context.Background(), "DELETE FROM api_sessions WHERE target_id = $1 AND target_type = 'bot'", botId) + + if err != nil { + panic(err) + } + + // Create new session with expiry set to 100 years from now and type being `api/automigrated` + err = tx.Exec(context.Background(), "INSERT INTO api_sessions (target_id, target_type, token, expiry, type) VALUES ($1, 'bot', $2, $3, 'api/automigrated')", botId, apiToken, time.Now().AddDate(100, 0, 0)) + + if err != nil { + panic(err) + } + } + + err = tx.Commit(ctx) + if err != nil { panic(err) } diff --git a/routes/apps/endpoints/get_apps_list/route.go b/routes/apps/endpoints/get_apps_list/route.go index 40289213..b3981ce7 100644 --- a/routes/apps/endpoints/get_apps_list/route.go +++ b/routes/apps/endpoints/get_apps_list/route.go @@ -1,6 +1,7 @@ package get_apps_list import ( + "errors" "net/http" "popplio/db" "popplio/state" @@ -46,6 +47,14 @@ func Route(d uapi.RouteData, r *http.Request) uapi.HttpResponse { app, err := pgx.CollectRows(row, pgx.RowToStructByName[types.AppResponse]) + if errors.Is(err, pgx.ErrNoRows) { + return uapi.HttpResponse{ + Json: types.AppListResponse{ + Apps: []types.AppResponse{}, + }, + } + } + if err != nil { state.Logger.Error("Failed to fetch application list [collection]", zap.String("userId", d.Auth.ID), zap.Error(err)) return uapi.DefaultResponse(http.StatusNotFound) diff --git a/routes/auth/endpoints/create_oauth2_login/route.go b/routes/auth/endpoints/create_oauth2_login/route.go index 99ca62ea..ad53e929 100644 --- a/routes/auth/endpoints/create_oauth2_login/route.go +++ b/routes/auth/endpoints/create_oauth2_login/route.go @@ -31,7 +31,7 @@ func Docs() *docs.Doc { Summary: "Create OAuth2 Login", Description: "Takes in a ``code`` query parameter and returns a user ``token``. **Cannot be used outside of the site for security reasons but documented in case we wish to allow its use in the future.**", Req: types.AuthorizeRequest{}, - Resp: types.UserLogin{}, + Resp: types.CreateSessionResponse{}, } } @@ -434,7 +434,8 @@ func Route(d uapi.RouteData, r *http.Request) uapi.HttpResponse { } var sessionToken = crypto.RandString(128) - _, err = state.Pool.Exec(d.Context, "INSERT INTO api_sessions (target_type, target_id, type, token, expiry) VALUES ('user', $1, 'login', $2, NOW() + INTERVAL '1 hour')", user.ID, sessionToken) + var sessionId string + err = state.Pool.QueryRow(d.Context, "INSERT INTO api_sessions (target_type, target_id, type, token, expiry) VALUES ('user', $1, 'login', $2, NOW() + INTERVAL '1 hour') RETURNING id", user.ID, sessionToken).Scan(&sessionId) if err != nil { state.Logger.Error("Failed to create session token", zap.Error(err), zap.String("userID", user.ID)) @@ -448,9 +449,10 @@ func Route(d uapi.RouteData, r *http.Request) uapi.HttpResponse { } // Create authUser and send - var authUser = types.UserLogin{ - UserID: user.ID, - Token: sessionToken, + var authUser = types.CreateSessionResponse{ + TargetID: user.ID, + Token: sessionToken, + SessionID: sessionId, } go sendAuthLog(user, req, !exists) diff --git a/routes/auth/endpoints/create_session/route.go b/routes/auth/endpoints/create_session/route.go index b97cfced..5b37e055 100644 --- a/routes/auth/endpoints/create_session/route.go +++ b/routes/auth/endpoints/create_session/route.go @@ -188,6 +188,7 @@ func Route(d uapi.RouteData, r *http.Request) uapi.HttpResponse { return uapi.HttpResponse{ Status: http.StatusCreated, Json: types.CreateSessionResponse{ + TargetID: targetId, Token: sessionToken, SessionID: sessionId, }, diff --git a/routes/notifications/endpoints/delete_user_notifications/route.go b/routes/notifications/endpoints/delete_user_notifications/route.go index 3ac37863..b529b37a 100644 --- a/routes/notifications/endpoints/delete_user_notifications/route.go +++ b/routes/notifications/endpoints/delete_user_notifications/route.go @@ -39,17 +39,43 @@ func Docs() *docs.Doc { func Route(d uapi.RouteData, r *http.Request) uapi.HttpResponse { var id = chi.URLParam(r, "id") + notifId := r.URL.Query().Get("notif_id") // Check for notif_id - if r.URL.Query().Get("notif_id") == "" { - return uapi.DefaultResponse(http.StatusBadRequest) + if notifId == "" { + return uapi.HttpResponse{ + Status: http.StatusBadRequest, + Json: types.ApiError{Message: "`notif_id` is required in query params and must be set to the notification ID to delete"}, + } } - _, err := state.Pool.Exec(d.Context, "DELETE FROM user_notifications WHERE user_id = $1 AND notif_id = $2", id, r.URL.Query().Get("notif_id")) + // Check count of deleted rows + var count int64 + err := state.Pool.QueryRow(d.Context, "SELECT COUNT(*) FROM user_notifications WHERE user_id = $1 AND notif_id = $2", id, r.URL.Query().Get("notif_id")).Scan(&count) + + if err != nil { + state.Logger.Error("Error while checking user notification count", zap.Error(err), zap.String("userID", id), zap.String("notifID", r.URL.Query().Get("notif_id"))) + return uapi.HttpResponse{ + Status: http.StatusInternalServerError, + Json: types.ApiError{Message: "Error while checking user notification count: " + err.Error()}, + } + } + + if count == 0 { + return uapi.HttpResponse{ + Status: http.StatusNotFound, + Json: types.ApiError{Message: "Notification not found"}, + } + } + + _, err = state.Pool.Exec(d.Context, "DELETE FROM user_notifications WHERE user_id = $1 AND notif_id = $2", id, r.URL.Query().Get("notif_id")) if err != nil { state.Logger.Error("Error while deleting user notification", zap.Error(err), zap.String("userID", id), zap.String("notifID", r.URL.Query().Get("notif_id"))) - return uapi.DefaultResponse(http.StatusInternalServerError) + return uapi.HttpResponse{ + Status: http.StatusInternalServerError, + Json: types.ApiError{Message: "Error while deleting user notification: " + err.Error()}, + } } return uapi.DefaultResponse(http.StatusNoContent) diff --git a/routes/reminders/endpoints/delete_user_reminders/route.go b/routes/reminders/endpoints/delete_user_reminders/route.go index 306b3ee5..af1e15f1 100644 --- a/routes/reminders/endpoints/delete_user_reminders/route.go +++ b/routes/reminders/endpoints/delete_user_reminders/route.go @@ -53,38 +53,33 @@ func Route(d uapi.RouteData, r *http.Request) uapi.HttpResponse { return uapi.DefaultResponse(http.StatusBadRequest) } - tx, err := state.Pool.Begin(d.Context) - - if err != nil { - state.Logger.Error("Error beginning transaction", zap.Error(err), zap.String("target_id", targetId), zap.String("target_type", targetType)) - return uapi.DefaultResponse(http.StatusInternalServerError) - } - - var count int - - err = tx.QueryRow(d.Context, "SELECT COUNT(*) FROM user_reminders WHERE user_id = $1 AND target_id = $2 AND target_type = $3", d.Auth.ID, targetId, targetType).Scan(&count) + // Check count of deleted rows + var count int64 + err := state.Pool.QueryRow(d.Context, "SELECT COUNT(*) FROM user_reminders WHERE user_id = $1 AND target_id = $2 AND target_type = $3", d.Auth.ID, targetId, targetType).Scan(&count) if err != nil { state.Logger.Error("Error querying reminders [db count]", zap.Error(err), zap.String("target_id", targetId), zap.String("target_type", targetType)) - return uapi.DefaultResponse(http.StatusInternalServerError) + return uapi.HttpResponse{ + Status: http.StatusInternalServerError, + Json: types.ApiError{Message: "Error while checking user reminder count: " + err.Error()}, + } } if count == 0 { - return uapi.DefaultResponse(http.StatusNotFound) + return uapi.HttpResponse{ + Status: http.StatusNotFound, + Json: types.ApiError{Message: "Reminder not found"}, + } } - _, err = tx.Exec(d.Context, "DELETE FROM user_reminders WHERE user_id = $1 AND target_id = $2 AND target_type = $3", d.Auth.ID, targetId, targetType) + _, err = state.Pool.Exec(d.Context, "DELETE FROM user_reminders WHERE user_id = $1 AND target_id = $2 AND target_type = $3", d.Auth.ID, targetId, targetType) if err != nil { state.Logger.Error("Error deleting reminders", zap.Error(err), zap.String("target_id", targetId), zap.String("target_type", targetType)) - return uapi.DefaultResponse(http.StatusInternalServerError) - } - - err = tx.Commit(d.Context) - - if err != nil { - state.Logger.Error("Error committing transaction", zap.Error(err), zap.String("target_id", targetId), zap.String("target_type", targetType)) - return uapi.DefaultResponse(http.StatusInternalServerError) + return uapi.HttpResponse{ + Status: http.StatusInternalServerError, + Json: types.ApiError{Message: "Error while deleting user reminder: " + err.Error()}, + } } return uapi.DefaultResponse(http.StatusNoContent) diff --git a/routes/reminders/endpoints/put_user_reminders/route.go b/routes/reminders/endpoints/put_user_reminders/route.go index 6fcb4136..3e3c26bc 100644 --- a/routes/reminders/endpoints/put_user_reminders/route.go +++ b/routes/reminders/endpoints/put_user_reminders/route.go @@ -60,29 +60,36 @@ func Route(d uapi.RouteData, r *http.Request) uapi.HttpResponse { if err != nil { state.Logger.Error("Error getting entity info", zap.Error(err), zap.String("target_id", targetId), zap.String("target_type", targetType)) return uapi.HttpResponse{ - Status: http.StatusBadRequest, + Status: http.StatusInternalServerError, Json: types.ApiError{Message: "Error: " + err.Error()}, } } - // Delete old - tx, err := state.Pool.Begin(d.Context) + // Get count of old + var count int64 + err = state.Pool.QueryRow(d.Context, "SELECT COUNT(*) FROM user_reminders WHERE user_id = $1 AND target_id = $2 AND target_type = $3", d.Auth.ID, targetId, targetType).Scan(&count) if err != nil { - state.Logger.Error("Error starting transaction", zap.Error(err), zap.String("target_id", targetId), zap.String("target_type", targetType)) - return uapi.DefaultResponse(http.StatusInternalServerError) + state.Logger.Error("Error selecting count of user_reminders", zap.Error(err), zap.String("target_id", targetId), zap.String("target_type", targetType)) + return uapi.HttpResponse{ + Status: http.StatusInternalServerError, + Json: types.ApiError{Message: "Error getting current user reminder count: " + err.Error()}, + } } - defer tx.Rollback(d.Context) - - tx.Exec(d.Context, "DELETE FROM user_reminders WHERE user_id = $1 AND target_id = $2 AND target_type = $3", d.Auth.ID, targetId, targetType) + if count > 0 { + return uapi.DefaultResponse(http.StatusNoContent) + } // Add new _, err = state.Pool.Exec(d.Context, "INSERT INTO user_reminders (user_id, target_id, target_type) VALUES ($1, $2, $3)", d.Auth.ID, targetId, targetType) if err != nil { state.Logger.Error("Error inserting new reminder", zap.Error(err), zap.String("target_id", targetId), zap.String("target_type", targetType)) - return uapi.DefaultResponse(http.StatusBadRequest) + return uapi.HttpResponse{ + Status: http.StatusInternalServerError, + Json: types.ApiError{Message: "Error adding new reminder: " + err.Error()}, + } } // Fan out notification diff --git a/routes/vanity/endpoints/patch_vanity/route.go b/routes/vanity/endpoints/patch_vanity/route.go index 880f402a..993a9391 100644 --- a/routes/vanity/endpoints/patch_vanity/route.go +++ b/routes/vanity/endpoints/patch_vanity/route.go @@ -46,6 +46,8 @@ func Docs() *docs.Doc { } func Route(d uapi.RouteData, r *http.Request) uapi.HttpResponse { + state.Logger.Info("Patch Vanity", zap.String("userID", d.Auth.ID)) + targetId := chi.URLParam(r, "target_id") targetType := validators.NormalizeTargetType(chi.URLParam(r, "target_type")) @@ -107,7 +109,6 @@ func Route(d uapi.RouteData, r *http.Request) uapi.HttpResponse { Status: http.StatusBadRequest, Json: types.ApiError{Message: "Error while getting word blacklist systems: " + err.Error()}, } - } if slices.Contains(systems, "vanity.code") { diff --git a/routes/vanity/router.go b/routes/vanity/router.go index 27bda7fc..0b6b7824 100644 --- a/routes/vanity/router.go +++ b/routes/vanity/router.go @@ -55,6 +55,9 @@ func (b Router) Routes(r *chi.Mux) { Perm: teams.PermissionSetVanity, }, nil }, + GetTarget: func(d uapi.Route, r *http.Request, authData uapi.AuthData) (string, string) { + return validators.NormalizeTargetType(chi.URLParam(r, "target_type")), chi.URLParam(r, "target_id") + }, }, }, }.Route(r) diff --git a/types/auth.go b/types/auth.go index 2f755d14..192ee746 100644 --- a/types/auth.go +++ b/types/auth.go @@ -14,11 +14,6 @@ type AuthorizeRequest struct { Scope string `json:"scope" validate:"required,oneof=normal ban_exempt external_auth"` } -type UserLogin struct { - Token string `json:"token" description:"The users token"` - UserID string `json:"user_id" description:"The users ID"` -} - type OauthMeta struct { ClientID string `json:"client_id" description:"The client ID"` URL string `json:"url" description:"The URL to redirect the user to for discord oauth2"` @@ -57,6 +52,7 @@ type CreateSession struct { } type CreateSessionResponse struct { + TargetID string `json:"target_id" description:"The ID of the target"` Token string `json:"token" description:"The token of the session"` SessionID string `json:"session_id" description:"The ID of the session"` } diff --git a/validators/normalize_target_type.go b/validators/normalize_target_type.go index e8ec0d1e..0277a051 100644 --- a/validators/normalize_target_type.go +++ b/validators/normalize_target_type.go @@ -1,7 +1,5 @@ package validators -import "strings" - // This function normalizes the target type to its correct form. func NormalizeTargetType(targetType string) string { switch targetType { @@ -29,6 +27,6 @@ func NormalizeTargetType(targetType string) string { case "pack": return "pack" default: - return strings.TrimSuffix(targetType, "s") + return "" } }