-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Arnaud (Arhuman) ASSAD
committed
Feb 22, 2024
1 parent
159d6a1
commit 9fc2e6c
Showing
4 changed files
with
302 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
// Package token handles JWT tokens manipulation | ||
package token | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"net/http" | ||
|
||
"github.com/gin-gonic/gin" | ||
jwt "github.com/golang-jwt/jwt/v5" | ||
"go.uber.org/zap" | ||
) | ||
|
||
// Authenticater is the interface that wraps the Authenticate method | ||
type Authenticater interface { | ||
Authenticate(login, pass string) (CustomClaims, error) | ||
} | ||
|
||
// CustomClaims is the struct that represents the claims of a JWT token in EPFL context | ||
type CustomClaims struct { | ||
Sciper string `json:"sciper"` | ||
jwt.RegisteredClaims | ||
} | ||
|
||
// Validate validates the claims of a JWT token | ||
func (m CustomClaims) Validate() error { | ||
if m.Sciper == "" { | ||
return errors.New("sciper must be set") | ||
} | ||
return nil | ||
} | ||
|
||
// Token is the struct that represents a JWT token | ||
type Token struct { | ||
JWT *jwt.Token | ||
} | ||
|
||
// New creates a new JWT token | ||
func New(claims CustomClaims) *Token { | ||
jwt := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) | ||
|
||
return &Token{JWT: jwt} | ||
} | ||
|
||
// Parse parses a JWT token | ||
func Parse(tokenString string, secret []byte) (*Token, error) { | ||
t, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { | ||
// Don't forget to validate the alg is what you expect: | ||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { | ||
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) | ||
} | ||
|
||
return secret, nil | ||
}) | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return &Token{t}, nil | ||
} | ||
|
||
// Sign signs a JWT token | ||
func (t *Token) Sign(secret []byte) (string, error) { | ||
return t.JWT.SignedString([]byte(secret)) | ||
} | ||
|
||
// Claims returns the claims of a JWT token | ||
func (t *Token) Claims() jwt.MapClaims { | ||
return t.JWT.Claims.(jwt.MapClaims) | ||
} | ||
|
||
// Set sets a claim in a JWT token | ||
func (t *Token) Set(key string, value interface{}) { | ||
t.Claims()[key] = value | ||
} | ||
|
||
// Get gets a claim from a JWT token | ||
func (t *Token) Get(key string) interface{} { | ||
return t.Claims()[key] | ||
} | ||
|
||
// GetString gets a claim from a JWT token as a string | ||
func (t *Token) GetString(key string) string { | ||
return t.Claims()[key].(string) | ||
} | ||
|
||
// ToJSON converts a JWT token to JSON | ||
func (t *Token) ToJSON() (string, error) { | ||
return t.JWT.Raw, nil | ||
} | ||
|
||
// PostLoginHandler is the handler that checks the login and password and returns a JWT token | ||
func PostLoginHandler(log *zap.Logger, auth Authenticater, secret []byte) gin.HandlerFunc { | ||
log.Info("Creating login handler") | ||
return func(c *gin.Context) { | ||
login := c.PostForm("login") | ||
pass := c.PostForm("pass") | ||
|
||
log.Info("Login attempt", zap.String("login", login)) | ||
|
||
claims, err := auth.Authenticate(login, pass) | ||
if err != nil { | ||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) | ||
return | ||
} | ||
|
||
t := New(claims) | ||
encoded, err := t.Sign(secret) | ||
if err != nil { | ||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||
return | ||
} | ||
|
||
c.JSON(http.StatusOK, gin.H{"access_token": encoded}) | ||
} | ||
} | ||
|
||
// GinMiddleware is the middleware that checks the JWT token | ||
func GinMiddleware(secret []byte) gin.HandlerFunc { | ||
return func(c *gin.Context) { | ||
authorizationHeaderString := c.GetHeader("Authorization") | ||
if authorizationHeaderString == "" { | ||
c.JSON(http.StatusUnauthorized, gin.H{"error": "No token provided"}) | ||
c.Abort() | ||
return | ||
} | ||
|
||
// Check that the authorization header starts with "Bearer" | ||
if len(authorizationHeaderString) < 7 || authorizationHeaderString[:7] != "Bearer " { | ||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"}) | ||
c.Abort() | ||
return | ||
} | ||
|
||
// Extract the token from the authorization header | ||
tokenString := authorizationHeaderString[7:] | ||
|
||
t, err := Parse(tokenString, secret) | ||
if err != nil { | ||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) | ||
c.Abort() | ||
return | ||
} | ||
|
||
c.Set("token", t) | ||
c.Next() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
package token_test | ||
|
||
import ( | ||
"fmt" | ||
|
||
"github.com/epfl-si/go-toolkit/token" | ||
) | ||
|
||
func Example_creating() { | ||
t := token.New(token.CustomClaims{Sciper: "321014"}) | ||
encoded, err := t.Sign([]byte("secret")) | ||
if err != nil { | ||
fmt.Println(err) | ||
return | ||
} | ||
|
||
fmt.Println(encoded) | ||
// Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzY2lwZXIiOiIzMjEwMTQifQ.7Nf7BUmLmN2RGXwf2nr-cOwkcsCkWO2i6YgLZdItrek | ||
} | ||
|
||
func Example_decoding() { | ||
tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzY2lwZXIiOiIzMjEwMTQifQ.7Nf7BUmLmN2RGXwf2nr-cOwkcsCkWO2i6YgLZdItrek" | ||
decoded, err := token.Parse(tokenString, []byte("secret")) | ||
if err != nil { | ||
fmt.Println(err) | ||
return | ||
} | ||
|
||
fmt.Println(decoded.JWT.Raw) | ||
fmt.Println(decoded.JWT.Header["alg"]) | ||
fmt.Println(decoded.JWT.Header["typ"]) | ||
// fmt.Printf("%+v", decoded.JWT) | ||
// Output: | ||
// eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzY2lwZXIiOiIzMjEwMTQifQ.7Nf7BUmLmN2RGXwf2nr-cOwkcsCkWO2i6YgLZdItrek | ||
// HS256 | ||
// JWT | ||
} |