diff --git a/database/queries/user.sql b/database/queries/user.sql index 7349252..c394582 100644 --- a/database/queries/user.sql +++ b/database/queries/user.sql @@ -4,16 +4,31 @@ VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING *; -- name: GetUserByEmail :one -SELECT id, email, reg_no, password, role, round_qualified, score, name +SELECT * FROM users WHERE email = $1; -- name: GetUserByUsername :one -SELECT id, email, reg_no, password, role, round_qualified, score, name +SELECT * FROM users WHERE name = $1; -- name: GetUserById :one -SELECT id, email, reg_no, password, role, round_qualified, score, name +SELECT * FROM users WHERE id = $1; +-- name: GetAllUsers :many +SELECT * +FROM users; +-- name: UpgradeUsersToRound :batchexec +UPDATE users +SET round_qualified = GREATEST(round_qualified, $2) +WHERE id = $1; +-- name: BanUser :exec +UPDATE users +SET is_banned = TRUE +WHERE id = $1; +-- name: UnbanUser :exec +UPDATE users +SET is_banned = FALSE +WHERE id = $1; \ No newline at end of file diff --git a/database/schema.sql b/database/schema.sql index c895fd9..cb1c96b 100644 --- a/database/schema.sql +++ b/database/schema.sql @@ -7,6 +7,7 @@ CREATE TABLE users ( round_qualified INTEGER NOT NULL DEFAULT 0, score INTEGER DEFAULT 0, name TEXT NOT NULL, + is_banned BOOLEAN NOT NULL DEFAULT false, PRIMARY KEY(id) ); @@ -79,11 +80,3 @@ ON UPDATE NO ACTION ON DELETE CASCADE; ALTER TABLE submissions ADD FOREIGN KEY(user_id) REFERENCES users(id) ON UPDATE NO ACTION ON DELETE CASCADE; - - - - - - - - diff --git a/go.mod b/go.mod index ec057ca..df45549 100644 --- a/go.mod +++ b/go.mod @@ -5,16 +5,19 @@ go 1.22 require ( github.com/go-chi/chi/v5 v5.1.0 github.com/go-chi/jwtauth/v5 v5.3.1 + github.com/go-chi/render v1.0.3 github.com/go-playground/validator v9.31.0+incompatible github.com/golang-jwt/jwt/v4 v4.5.0 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.6.0 github.com/joho/godotenv v1.5.1 + github.com/lib/pq v1.10.9 github.com/redis/go-redis/v9 v9.6.1 go.uber.org/zap v1.27.0 ) require ( + github.com/ajg/form v1.5.1 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/golang/protobuf v1.5.2 // indirect diff --git a/go.sum b/go.sum index b0d6c94..c93e878 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= +github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/bsm/ginkgo/v2 v2.7.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= @@ -23,6 +25,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator v9.31.0+incompatible h1:UA72EPEogEnq76ehGdEDp4Mit+3FDh548oRqwVgNsHA= github.com/go-playground/validator v9.31.0+incompatible/go.mod h1:yrEkQXlcI+PugkyDjY2bRrL/UBU4f3rvrgkN3V8JEig= +github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= +github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= @@ -65,6 +69,8 @@ github.com/lestrrat-go/jwx/v2 v2.0.20 h1:sAgXuWS/t8ykxS9Bi2Qtn5Qhpakw1wrcjxChudj github.com/lestrrat-go/jwx/v2 v2.0.20/go.mod h1:UlCSmKqw+agm5BsOBfEAbTvKsEApaGNqHAEUTv5PJC4= github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.0.3/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= diff --git a/internal/controllers/admin_wishes.go b/internal/controllers/admin_wishes.go new file mode 100644 index 0000000..7a96b71 --- /dev/null +++ b/internal/controllers/admin_wishes.go @@ -0,0 +1,137 @@ +package controllers + +import ( + "net/http" + "log" + "github.com/google/uuid" + "fmt" + "github.com/CodeChefVIT/cookoff-backend/internal/db" + "github.com/CodeChefVIT/cookoff-backend/internal/helpers/database" + httphelpers "github.com/CodeChefVIT/cookoff-backend/internal/helpers/http" +) + +func GetAllUsers(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + users, err := database.Queries.GetAllUsers(ctx) + if err != nil { + httphelpers.WriteError(w, http.StatusInternalServerError, "Unable to fetch users") + return + } + httphelpers.WriteJSON(w, http.StatusOK, users) +} +func UpgradeUserToRound(w http.ResponseWriter, r *http.Request) { + var requestBody struct { + UserIDs []string `json:"user_ids"` + Round float64 `json:"round"` + } + + if err := httphelpers.ParseJSON(r, &requestBody); err != nil { + httphelpers.WriteError(w, http.StatusBadRequest, "Invalid request payload") + return + } + + if len(requestBody.UserIDs) == 0 { + httphelpers.WriteError(w, http.StatusBadRequest, "Invalid user_ids format") + return + } + + var upgradeParams []db.UpgradeUsersToRoundParams + for _, idStr := range requestBody.UserIDs { + id, err := uuid.Parse(idStr) + if err != nil { + httphelpers.WriteError(w, http.StatusBadRequest, "Invalid user_id") + return + } + + upgradeParams = append(upgradeParams, db.UpgradeUsersToRoundParams{ + ID: id, + RoundQualified: int32(requestBody.Round), + }) + } + + ctx := r.Context() + err := database.Queries.UpgradeUsersToRound(ctx, upgradeParams) + if err != nil { + httphelpers.WriteError(w, http.StatusInternalServerError, "Unable to upgrade users to round") + return + } + + httphelpers.WriteJSON(w, http.StatusOK, map[string]string{"message": "Users upgraded successfully"}) +} + +func BanUser(w http.ResponseWriter, r *http.Request) { + var requestBody map[string]string + if err := httphelpers.ParseJSON(r, &requestBody); err != nil { + httphelpers.WriteError(w, http.StatusBadRequest, "Invalid request payload") + return + } + + userIDStr, ok := requestBody["user_id"] + if !ok { + httphelpers.WriteError(w, http.StatusBadRequest, "user_id must be a string") + return + } + + userID, err := uuid.Parse(userIDStr) + if err != nil { + httphelpers.WriteError(w, http.StatusBadRequest, "Invalid user_id") + return + } + + ctx := r.Context() + err = database.Queries.BanUser(ctx, userID) + if err != nil { + httphelpers.WriteError(w, http.StatusInternalServerError, "Unable to ban user") + return + } + + httphelpers.WriteJSON(w, http.StatusOK, map[string]string{"message": "User banned successfully"}) +} +func UnbanUser(w http.ResponseWriter, r *http.Request) { + var requestBody map[string]string + if err := httphelpers.ParseJSON(r, &requestBody); err != nil { + httphelpers.WriteError(w, http.StatusBadRequest, "Invalid request payload") + return + } + + userIDStr, ok := requestBody["user_id"] + if !ok { + httphelpers.WriteError(w, http.StatusBadRequest, "user_id must be a string") + return + } + + userID, err := uuid.Parse(userIDStr) + if err != nil { + httphelpers.WriteError(w, http.StatusBadRequest, "Invalid user_id") + return + } + + ctx := r.Context() + err = database.Queries.UnbanUser(ctx, userID) + if err != nil { + httphelpers.WriteError(w, http.StatusInternalServerError, "Unable to unban user") + return + } + + httphelpers.WriteJSON(w, http.StatusOK, map[string]string{"message": "User unbanned successfully"}) +} +type RoundRequest struct { + RoundID int `json:"round_id"` +} +func SetRoundStatus(w http.ResponseWriter, r *http.Request) { + var reqBody RoundRequest + if err := httphelpers.ParseJSON(r, &reqBody); err != nil { + httphelpers.WriteError(w, http.StatusBadRequest, "Invalid request payload") + return + } + ctx := r.Context() + redisKey := "round:enabled" + roundIDStr := fmt.Sprintf("%d", reqBody.RoundID) + err := database.RedisClient.Set(ctx, redisKey, roundIDStr, 0).Err() + if err != nil { + log.Printf("Failed to enable round: %v\n", err) + httphelpers.WriteError(w, http.StatusInternalServerError, "Failed to enable round") + return + } + httphelpers.WriteJSON(w, http.StatusOK, map[string]string{"message": "Round enabled successfully"}) +} \ No newline at end of file diff --git a/internal/db/batch.go b/internal/db/batch.go new file mode 100644 index 0000000..5d905b7 --- /dev/null +++ b/internal/db/batch.go @@ -0,0 +1,69 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.27.0 +// source: batch.go + +package db + +import ( + "context" + "errors" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" +) + +var ( + ErrBatchAlreadyClosed = errors.New("batch already closed") +) + +const upgradeUsersToRound = `-- name: UpgradeUsersToRound :batchexec +UPDATE users +SET round_qualified = GREATEST(round_qualified, $2) +WHERE id = $1 +` + +type UpgradeUsersToRoundBatchResults struct { + br pgx.BatchResults + tot int + closed bool +} + +type UpgradeUsersToRoundParams struct { + ID uuid.UUID + RoundQualified int32 +} + +func (q *Queries) UpgradeUsersToRound(ctx context.Context, arg []UpgradeUsersToRoundParams) *UpgradeUsersToRoundBatchResults { + batch := &pgx.Batch{} + for _, a := range arg { + vals := []interface{}{ + a.ID, + a.RoundQualified, + } + batch.Queue(upgradeUsersToRound, vals...) + } + br := q.db.SendBatch(ctx, batch) + return &UpgradeUsersToRoundBatchResults{br, len(arg), false} +} + +func (b *UpgradeUsersToRoundBatchResults) Exec(f func(int, error)) { + defer b.br.Close() + for t := 0; t < b.tot; t++ { + if b.closed { + if f != nil { + f(t, ErrBatchAlreadyClosed) + } + continue + } + _, err := b.br.Exec() + if f != nil { + f(t, err) + } + } +} + +func (b *UpgradeUsersToRoundBatchResults) Close() error { + b.closed = true + return b.br.Close() +} diff --git a/internal/db/db.go b/internal/db/db.go index b4a3b78..8c84b4d 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -15,6 +15,7 @@ type DBTX interface { Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) Query(context.Context, string, ...interface{}) (pgx.Rows, error) QueryRow(context.Context, string, ...interface{}) pgx.Row + SendBatch(context.Context, *pgx.Batch) pgx.BatchResults } func New(db DBTX) *Queries { diff --git a/internal/db/models.go b/internal/db/models.go index 8d5aaae..0fb23e9 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -66,4 +66,5 @@ type User struct { RoundQualified int32 Score pgtype.Int4 Name string + IsBanned bool } diff --git a/internal/db/user.sql.go b/internal/db/user.sql.go index fd0055d..77e11ca 100644 --- a/internal/db/user.sql.go +++ b/internal/db/user.sql.go @@ -12,10 +12,21 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +const banUser = `-- name: BanUser :exec +UPDATE users +SET is_banned = TRUE +WHERE id = $1 +` + +func (q *Queries) BanUser(ctx context.Context, id uuid.UUID) error { + _, err := q.db.Exec(ctx, banUser, id) + return err +} + const createUser = `-- name: CreateUser :one INSERT INTO users (id, email, reg_no, password, role, round_qualified, score, name) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) -RETURNING id, email, reg_no, password, role, round_qualified, score, name +RETURNING id, email, reg_no, password, role, round_qualified, score, name, is_banned ` type CreateUserParams struct { @@ -50,12 +61,48 @@ func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, e &i.RoundQualified, &i.Score, &i.Name, + &i.IsBanned, ) return i, err } +const getAllUsers = `-- name: GetAllUsers :many +SELECT id, email, reg_no, password, role, round_qualified, score, name, is_banned +FROM users +` + +func (q *Queries) GetAllUsers(ctx context.Context) ([]User, error) { + rows, err := q.db.Query(ctx, getAllUsers) + if err != nil { + return nil, err + } + defer rows.Close() + var items []User + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.Email, + &i.RegNo, + &i.Password, + &i.Role, + &i.RoundQualified, + &i.Score, + &i.Name, + &i.IsBanned, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getUserByEmail = `-- name: GetUserByEmail :one -SELECT id, email, reg_no, password, role, round_qualified, score, name +SELECT id, email, reg_no, password, role, round_qualified, score, name, is_banned FROM users WHERE email = $1 ` @@ -72,12 +119,13 @@ func (q *Queries) GetUserByEmail(ctx context.Context, email string) (User, error &i.RoundQualified, &i.Score, &i.Name, + &i.IsBanned, ) return i, err } const getUserById = `-- name: GetUserById :one -SELECT id, email, reg_no, password, role, round_qualified, score, name +SELECT id, email, reg_no, password, role, round_qualified, score, name, is_banned FROM users WHERE id = $1 ` @@ -94,12 +142,13 @@ func (q *Queries) GetUserById(ctx context.Context, id uuid.UUID) (User, error) { &i.RoundQualified, &i.Score, &i.Name, + &i.IsBanned, ) return i, err } const getUserByUsername = `-- name: GetUserByUsername :one -SELECT id, email, reg_no, password, role, round_qualified, score, name +SELECT id, email, reg_no, password, role, round_qualified, score, name, is_banned FROM users WHERE name = $1 ` @@ -116,6 +165,18 @@ func (q *Queries) GetUserByUsername(ctx context.Context, name string) (User, err &i.RoundQualified, &i.Score, &i.Name, + &i.IsBanned, ) return i, err } + +const unbanUser = `-- name: UnbanUser :exec +UPDATE users +SET is_banned = FALSE +WHERE id = $1 +` + +func (q *Queries) UnbanUser(ctx context.Context, id uuid.UUID) error { + _, err := q.db.Exec(ctx, unbanUser, id) + return err +} diff --git a/internal/middlewares/ban_check.go b/internal/middlewares/ban_check.go new file mode 100644 index 0000000..d4ff940 --- /dev/null +++ b/internal/middlewares/ban_check.go @@ -0,0 +1,34 @@ +package middlewares + +import ( + "net/http" + "github.com/CodeChefVIT/cookoff-backend/internal/helpers/database" + "github.com/google/uuid" +) + +func BanCheckMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userIDStr, ok := r.Context().Value("user_id").(string) + if !ok { + http.Error(w, "user_id not found in context", http.StatusUnauthorized) + return + } + userID, err := uuid.Parse(userIDStr) + if err != nil { + http.Error(w, "Invalid user ID format", http.StatusBadRequest) + return + } + ctx := r.Context() + user, err := database.Queries.GetUserById(ctx, userID) + if err != nil { + http.Error(w, "Unable to fetch user data", http.StatusInternalServerError) + return + } + + if user.IsBanned { + http.Error(w, "Your account is banned", http.StatusForbidden) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/internal/server/routes.go b/internal/server/routes.go index 3144728..35d704f 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -33,18 +33,23 @@ func (s *Server) RegisterRoutes(taskClient *asynq.Client) http.Handler { protected.Get("/me", controllers.MeHandler) protected.Get("/protected", controllers.ProtectedHandler) - protected.Post("/submit", controllers.SubmitCode) - protected.Post("/runcode", controllers.RunCode) - protected.Get("/result/{submission_id}", controllers.GetResult) - protected.Get("/question/round", controllers.GetQuestionsByRound) + banCheckRoutes := protected.With(middlewares.BanCheckMiddleware) + banCheckRoutes.Post("/submit", controllers.SubmitCode) + banCheckRoutes.Post("/runcode", controllers.RunCode) + banCheckRoutes.Get("/question/round", controllers.GetQuestionsByRound) adminRoutes := protected.With(middlewares.RoleAuthorizationMiddleware("admin")) - adminRoutes.Post("/question/create", controllers.CreateQuestion) adminRoutes.Get("/questions", controllers.GetAllQuestion) adminRoutes.Get("/question/{question_id}", controllers.GetQuestionById) adminRoutes.Delete("/question/{question_id}", controllers.DeleteQuestion) adminRoutes.Patch("/question", controllers.UpdateQuestion) + adminRoutes.Post("/upgrade", controllers.UpgradeUserToRound) + adminRoutes.Post("/roast", controllers.BanUser) + adminRoutes.Post("/unroast", controllers.UnbanUser) + adminRoutes.Post("/round/", controllers.SetRoundStatus) + adminRoutes.Get("/users", controllers.GetAllUsers) + }) adminRoutes.Post("/testcase", controllers.CreateTestCaseHandler) adminRoutes.Put("/testcase/{testcase_id}", controllers.UpdateTestCaseHandler)