Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Rate Limiting to API Endpoints #76

Merged
merged 8 commits into from
Feb 6, 2025
6 changes: 6 additions & 0 deletions backend/controllers/add_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ func AddTaskHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Due Date is required, and cannot be empty!", http.StatusBadRequest)
return
}

if priority != "" && priority != "H" && priority != "M" && priority != "L" {
http.Error(w, "Priority must be either 'H' (High), 'M' (Medium), or 'L' (Low)", http.StatusBadRequest)
return
}

job := Job{
Name: "Add Task",
Execute: func() error {
Expand Down
28 changes: 26 additions & 2 deletions backend/controllers/get_tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,49 @@ import (
"encoding/json"
"net/http"
"os"
"sort"
)

// Priority value mapping
func getPriorityValue(priority string) int {
switch priority {
case "H":
return 3
case "M":
return 2
case "L":
return 1
default:
return 0
}
}

// helps to fetch tasks using '/tasks' route
func TasksHandler(w http.ResponseWriter, r *http.Request) {
email := r.URL.Query().Get("email")
encryptionSecret := r.URL.Query().Get("encryptionSecret")
UUID := r.URL.Query().Get("UUID")
sortBy := r.URL.Query().Get("sort") // New query parameter for sorting

origin := os.Getenv("CONTAINER_ORIGIN")
if email == "" || encryptionSecret == "" || UUID == "" {
http.Error(w, "Missing required parameters", http.StatusBadRequest)
return
}

if r.Method == http.MethodGet {
tasks, _ := tw.FetchTasksFromTaskwarrior(email, encryptionSecret, origin, UUID)
if tasks == nil {
tasks, err := tw.FetchTasksFromTaskwarrior(email, encryptionSecret, origin, UUID)
if err != nil || tasks == nil {
http.Error(w, "Failed to fetch tasks at backend", http.StatusInternalServerError)
return
}

if sortBy == "priority" {
sort.Slice(tasks, func(i, j int) bool {
return getPriorityValue(tasks[i].Priority) > getPriorityValue(tasks[j].Priority)
})
}

w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(tasks)
return
Expand Down
31 changes: 25 additions & 6 deletions backend/main.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package main

import (
"ccsync_backend/controllers"
"encoding/gob"
"log"
"net/http"
"os"
"time"

"github.com/gorilla/sessions"
"github.com/joho/godotenv"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"

"ccsync_backend/controllers"
"ccsync_backend/middleware"
)

func main() {
Expand Down Expand Up @@ -44,11 +47,27 @@ func main() {
app := controllers.App{Config: conf, SessionStore: store}
mux := http.NewServeMux()

// API endpoints
mux.HandleFunc("/auth/oauth", app.OAuthHandler)
mux.HandleFunc("/auth/callback", app.OAuthCallbackHandler)
mux.HandleFunc("/api/user", app.UserInfoHandler)
mux.HandleFunc("/auth/logout", app.LogoutHandler)
// Allow 50 requests per 30 seconds per IP for testing
rateLimitedHandler := middleware.RateLimitMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/auth/oauth":
app.OAuthHandler(w, r)
case "/auth/callback":
app.OAuthCallbackHandler(w, r)
case "/api/user":
app.UserInfoHandler(w, r)
case "/auth/logout":
app.LogoutHandler(w, r)
}
}), 30*time.Second, 50)

// API endpoints with rate limiting
mux.Handle("/auth/oauth", rateLimitedHandler)
mux.Handle("/auth/callback", rateLimitedHandler)
mux.Handle("/api/user", rateLimitedHandler)
mux.Handle("/auth/logout", rateLimitedHandler)

// API endpoints without rate limiting
its-me-abhishek marked this conversation as resolved.
Show resolved Hide resolved
mux.HandleFunc("/tasks", controllers.TasksHandler)
mux.HandleFunc("/add-task", controllers.AddTaskHandler)
mux.HandleFunc("/edit-task", controllers.EditTaskHandler)
Expand Down
98 changes: 98 additions & 0 deletions backend/middleware/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package middleware

import (
"net/http"
"strings"
"sync"
"time"
)

type RateLimiter struct {
sync.RWMutex
requests map[string][]time.Time
windowSize time.Duration
maxRequests int
}

func NewRateLimiter(windowSize time.Duration, maxRequests int) *RateLimiter {
limiter := &RateLimiter{
requests: make(map[string][]time.Time),
windowSize: windowSize,
maxRequests: maxRequests,
}

go func() {
for {
time.Sleep(windowSize)
limiter.cleanup()
}
}()

return limiter
}

func (rl *RateLimiter) cleanup() {
rl.Lock()
defer rl.Unlock()

now := time.Now()
for ip, times := range rl.requests {
var valid []time.Time
for _, t := range times {
if now.Sub(t) <= rl.windowSize {
valid = append(valid, t)
}
}
if len(valid) > 0 {
rl.requests[ip] = valid
} else {
delete(rl.requests, ip)
}
}
}

func (rl *RateLimiter) IsAllowed(ip string) bool {
rl.Lock()
defer rl.Unlock()

now := time.Now()
times := rl.requests[ip]

var valid []time.Time
for _, t := range times {
if now.Sub(t) <= rl.windowSize {
valid = append(valid, t)
}
}

if len(valid) >= rl.maxRequests {
rl.requests[ip] = valid
return false
}

valid = append(valid, now)
rl.requests[ip] = valid
return true
}

func RateLimitMiddleware(next http.Handler, windowSize time.Duration, maxRequests int) http.Handler {
limiter := NewRateLimiter(windowSize, maxRequests)

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := r.Header.Get("X-Forwarded-For")
if ip == "" {
ip = r.RemoteAddr
}
if idx := strings.Index(ip, ":"); idx != -1 {
ip = ip[:idx]
}

if !limiter.IsAllowed(ip) {
w.Header().Set("Retry-After", windowSize.String())
http.Error(w, "Rate limit exceeded. Please try again later.", http.StatusTooManyRequests)
return
}

next.ServeHTTP(w, r)
})
}