Skip to content

Commit

Permalink
feat: go: added middleware for elevated check (#489)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbarrosop authored Mar 22, 2024
1 parent 04b75ab commit 5472bed
Show file tree
Hide file tree
Showing 13 changed files with 580 additions and 48 deletions.
10 changes: 8 additions & 2 deletions go/api/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ paths:
tags:
- pat
security:
- BearerAuth: []
- BearerAuthElevated: []
requestBody:
content:
application/json:
Expand Down Expand Up @@ -170,7 +170,7 @@ paths:
- user
- email
security:
- BearerAuth: []
- BearerAuthElevated: []
requestBody:
content:
application/json:
Expand Down Expand Up @@ -253,6 +253,12 @@ components:
BearerAuth:
type: http
scheme: bearer
BearerAuthElevated:
type: http
scheme: bearer
description: >-
This endpoint may require elevated permissions, depending on server settings.
For details see https://docs.nhost.io/guides/auth/elevated-permissions
schemas:
CreatePATRequest:
Expand Down
83 changes: 42 additions & 41 deletions go/api/server.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion go/api/types.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion go/cmd/jwt_getter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/urfave/cli/v2"
)

func getJWTGetter(cCtx *cli.Context) (*controller.JWTGetter, error) {
func getJWTGetter(cCtx *cli.Context, db controller.DBClient) (*controller.JWTGetter, error) {
var customClaimer controller.CustomClaimer
var err error
if cCtx.String(flagCustomClaims) != "" {
Expand All @@ -28,6 +28,8 @@ func getJWTGetter(cCtx *cli.Context) (*controller.JWTGetter, error) {
[]byte(cCtx.String(flagHasuraGraphqlJWTSecret)),
time.Duration(cCtx.Int(flagAccessTokensExpiresIn))*time.Second,
customClaimer,
cCtx.String(flagRequireElevatedClaim),
db,
)
if err != nil {
return nil, fmt.Errorf("error creating jwt getter: %w", err)
Expand Down
17 changes: 16 additions & 1 deletion go/cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ const (
flagAllowedEmailDomains = "allowed-email-domains"
flagAllowedEmails = "allowed-emails"
flagEmailPasswordlessEnabled = "email-passwordless-enabled"
flagRequireElevatedClaim = "require-elevated-claim"
)

func CommandServe() *cli.Command { //nolint:funlen,maintidx
Expand Down Expand Up @@ -376,6 +377,20 @@ func CommandServe() *cli.Command { //nolint:funlen,maintidx
Category: "signin",
EnvVars: []string{"AUTH_EMAIL_PASSWORDLESS_ENABLED"},
},
&cli.GenericFlag{ //nolint: exhaustruct
Name: flagRequireElevatedClaim,
Value: &EnumValue{ //nolint: exhaustruct
Enum: []string{
"disabled",
"recommended",
"required",
},
Default: "disabled",
},
Usage: "Require x-hasura-auth-elevated claim to perform certain actions: create PATs, change email and/or password, enable/disable MFA and add security keys. If set to `recommended` the claim check is only performed if the user has a security key attached. If set to `required` the only action that won't require the claim is setting a security key for the first time.",
Category: "security",
EnvVars: []string{"AUTH_REQUIRE_ELEVATED_CLAIM"},
},
},
Action: serve,
}
Expand Down Expand Up @@ -441,7 +456,7 @@ func getGoServer( //nolint:funlen
return nil, fmt.Errorf("problem creating config: %w", err)
}

jwtGetter, err := getJWTGetter(cCtx)
jwtGetter, err := getJWTGetter(cCtx, db)
if err != nil {
return nil, fmt.Errorf("problem creating jwt getter: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion go/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ type Emailer interface {
) error
}

type DBClient interface {
type DBClient interface { //nolint:interfacebloat
CountSecurityKeysUser(ctx context.Context, userID uuid.UUID) (int64, error)
GetUser(ctx context.Context, id uuid.UUID) (sql.AuthUser, error)
GetUserByEmail(ctx context.Context, email pgtype.Text) (sql.AuthUser, error)
GetUserByRefreshTokenHash(
Expand Down
3 changes: 3 additions & 0 deletions go/controller/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package controller

import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
Expand All @@ -17,6 +18,8 @@ func (e *APIError) Error() string {
return fmt.Sprintf("API error: %s", e.t)
}

var ErrElevatedClaimRequired = errors.New("elevated-claim-required")

var (
ErrUserEmailNotFound = &APIError{api.InvalidEmailPassword}
ErrEmailAlreadyInUse = &APIError{api.EmailAlreadyInUse}
Expand Down
61 changes: 61 additions & 0 deletions go/controller/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,16 @@ type JWTGetter struct {
method jwt.SigningMethod
customClaimer CustomClaimer
accessTokenExpiresIn time.Duration
elevatedClaimMode string
db DBClient
}

func NewJWTGetter(
jwtSecretb []byte,
accessTokenExpiresIn time.Duration,
customClaimer CustomClaimer,
elevatedClaimMode string,
db DBClient,
) (*JWTGetter, error) {
jwtSecret, err := decodeJWTSecret(jwtSecretb)
if err != nil {
Expand All @@ -74,6 +78,8 @@ func NewJWTGetter(
method: method,
customClaimer: customClaimer,
accessTokenExpiresIn: accessTokenExpiresIn,
elevatedClaimMode: elevatedClaimMode,
db: db,
}, nil
}

Expand Down Expand Up @@ -182,13 +188,58 @@ func (j *JWTGetter) Validate(accessToken string) (*jwt.Token, error) {

func (j *JWTGetter) FromContext(ctx context.Context) (*jwt.Token, bool) {
token, ok := ctx.Value(jwtContextKey).(*jwt.Token)
if !ok { //nolint:nestif
c := ginmiddleware.GetGinContext(ctx)
if c != nil {
a, ok := c.Get(jwtContextKey)
if !ok {
return nil, false
}

token, ok = a.(*jwt.Token)
if !ok {
return nil, false
}
return token, true
}
}
return token, ok
}

func (j *JWTGetter) ToContext(ctx context.Context, jwtToken *jwt.Token) context.Context {
return context.WithValue(ctx, jwtContextKey, jwtToken) //nolint:revive,staticcheck
}

func (j *JWTGetter) verifyElevatedClaim(ctx context.Context, token *jwt.Token) (bool, error) {
if j.elevatedClaimMode == "disabled" {
return true, nil
}

u, err := token.Claims.GetSubject()
if err != nil {
return false, fmt.Errorf("error getting user id from subject: %w", err)
}

if j.elevatedClaimMode == "recommended" {
userID, err := uuid.Parse(u)
if err != nil {
return false, fmt.Errorf("error parsing user id: %w", err)
}
n, err := j.db.CountSecurityKeysUser(ctx, userID)
if err != nil {
return false, fmt.Errorf("error checking if user has security keys: %w", err)
}

if n == 0 {
return true, nil
}
}

elevatedClaim := j.GetCustomClaim(token, "x-hasura-auth-elevated")

return elevatedClaim == u, nil
}

func (j *JWTGetter) MiddlewareFunc(
ctx context.Context, input *openapi3filter.AuthenticationInput,
) error {
Expand All @@ -207,6 +258,16 @@ func (j *JWTGetter) MiddlewareFunc(
return fmt.Errorf("invalid token") //nolint:goerr113
}

if input.SecuritySchemeName == "BearerAuthElevated" {
found, err := j.verifyElevatedClaim(ctx, jwtToken)
if err != nil {
return fmt.Errorf("error verifying elevated claim: %w", err)
}
if !found {
return ErrElevatedClaimRequired
}
}

c := ginmiddleware.GetGinContext(ctx)
c.Set(jwtContextKey, jwtToken)

Expand Down
Loading

0 comments on commit 5472bed

Please sign in to comment.