Skip to content

Commit

Permalink
add rate limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
dunkbing authored Mar 1, 2024
1 parent ac1234d commit 78eb9b3
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 9 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ require (
github.com/foobaz/lossypng v0.0.0-20200814224715-48fa8819852a
github.com/google/uuid v1.6.0
)

require golang.org/x/time v0.5.0
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ github.com/foobaz/lossypng v0.0.0-20200814224715-48fa8819852a h1:0TYY/syyvt/+y5P
github.com/foobaz/lossypng v0.0.0-20200814224715-48fa8819852a/go.mod h1:wRxTcIExb9GZAgOr1wrQuOZBkyoZNQi7znUmeyKTciA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
88 changes: 79 additions & 9 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,22 @@ package main
import (
"encoding/json"
"fmt"
"github.com/dunkbing/tinyimg/converter/config"
"github.com/dunkbing/tinyimg/converter/image"
"io"
"log"
"log/slog"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"

"github.com/dunkbing/tinyimg/converter/config"
"github.com/dunkbing/tinyimg/converter/image"

"github.com/google/uuid"
"golang.org/x/time/rate"
)

type RequestBody struct {
Expand Down Expand Up @@ -217,11 +221,77 @@ func isImage(mimeType string) bool {
return strings.HasPrefix(mimeType, "image/")
}

// Create a custom visitor struct which holds the rate limiter for each
// visitor and the last time that the visitor was seen.
type visitor struct {
limiter *rate.Limiter
lastSeen time.Time
}

// Change the map to hold values of the type visitor.
var visitors = make(map[string]*visitor)
var mu sync.Mutex

// Run a background goroutine to remove old entries from the visitors map.
func init() {
go cleanupVisitors()
}

func getVisitor(ip string) *rate.Limiter {
mu.Lock()
defer mu.Unlock()

v, exists := visitors[ip]
fmt.Println(ip)
if !exists {
limiter := rate.NewLimiter(30, 120)
// Include the current time when creating a new visitor.
visitors[ip] = &visitor{limiter, time.Now()}
return limiter
}

// Update the last seen time for the visitor.
v.lastSeen = time.Now()
return v.limiter
}

// Every minute check the map for visitors that haven't been seen for
// more than 3 minutes and delete the entries.
func cleanupVisitors() {
for {
time.Sleep(time.Minute)

mu.Lock()
for ip, v := range visitors {
if time.Since(v.lastSeen) > 3*time.Minute {
delete(visitors, ip)
}
}
mu.Unlock()
}
}

func limit(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
log.Print(err.Error())
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}

limiter := getVisitor(ip)
if !limiter.Allow() {
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
return
}

next.ServeHTTP(w, r)
})
}

func main() {
mux := http.NewServeMux()
uploadHandlerWithCors := enableCors(http.HandlerFunc(uploadHandler))
downloadZipHandlerWithCors := enableCors(http.HandlerFunc(downloadZipHandler))
serveImageHandlerWithCors := enableCors(http.HandlerFunc(serveImgHandler))
mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// send hello message
w.Header().Set("Content-Type", "application/json")
Expand All @@ -230,14 +300,14 @@ func main() {
"message": "Hello",
})
}))
mux.Handle("/upload", uploadHandlerWithCors)
mux.Handle("/download-all", downloadZipHandlerWithCors)
mux.Handle("/image", serveImageHandlerWithCors)
mux.HandleFunc("/upload", uploadHandler)
mux.HandleFunc("/download-all", downloadZipHandler)
mux.HandleFunc("/image", serveImgHandler)
fs := http.FileServer(http.Dir("./output"))
http.Handle("/static/", http.StripPrefix("/static/", fs))

log.Println("Server started on port 8080")
err := http.ListenAndServe(":8080", mux)
err := http.ListenAndServe(":8080", enableCors((mux)))
if err != nil {
log.Fatal("Server failed to start:", err)
}
Expand Down

0 comments on commit 78eb9b3

Please sign in to comment.