diff --git a/go/api/openapi.yaml b/go/api/openapi.yaml index ac0181acc..11578a045 100644 --- a/go/api/openapi.yaml +++ b/go/api/openapi.yaml @@ -41,7 +41,7 @@ paths: tags: - pat security: - - BearerAuth: [] + - BearerAuthElevated: [] requestBody: content: application/json: @@ -170,7 +170,7 @@ paths: - user - email security: - - BearerAuth: [] + - BearerAuthElevated: [] requestBody: content: application/json: @@ -253,6 +253,12 @@ components: BearerAuth: type: http scheme: bearer + BearerAuthElevated: + type: http + scheme: bearer + description: >- + This endpoint may require elevated permissions, depending on server settings. + For details see https://docs.nhost.io/guides/auth/elevated-permissions schemas: CreatePATRequest: diff --git a/go/api/server.gen.go b/go/api/server.gen.go index e17a07fe2..496ab38a5 100644 --- a/go/api/server.gen.go +++ b/go/api/server.gen.go @@ -95,7 +95,7 @@ func (siw *ServerInterfaceWrapper) HeadHealthz(c *gin.Context) { // PostPat operation middleware func (siw *ServerInterfaceWrapper) PostPat(c *gin.Context) { - c.Set(BearerAuthScopes, []string{}) + c.Set(BearerAuthElevatedScopes, []string{}) for _, middleware := range siw.HandlerMiddlewares { middleware(c) @@ -162,7 +162,7 @@ func (siw *ServerInterfaceWrapper) PostSignupEmailPassword(c *gin.Context) { // PostUserEmailChange operation middleware func (siw *ServerInterfaceWrapper) PostUserEmailChange(c *gin.Context) { - c.Set(BearerAuthScopes, []string{}) + c.Set(BearerAuthElevatedScopes, []string{}) for _, middleware := range siw.HandlerMiddlewares { middleware(c) @@ -847,45 +847,46 @@ func (sh *strictHandler) GetVersion(ctx *gin.Context) { // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+RaW3fbuBH+KzjY9qQ9h5QU293Geqo2dTfei60Ty9sHRw8QMSKRkAAXAKWorv57Dy6k", - "SJGSZdd2nOybhMtgZr65YcBbHIksFxy4Vnh4i1WUQEbsz7cSiIbxaPIefi9AaTNGKGWaCU7SsRQ5SM1A", - "4eGcpAoCnNeGbjF8zpkENbL7KKhIstxsxUN8ZqaI+YMo0YDEHOkE0Hg0wQGeC5kRjYfYTIWaZYADrFc5", - "4CFWWjIe43WAM9CEEk12M6VlAQGGzyTLUzDLOMkMjWwV5kTjABcKaDhbuSGS52GUMkPanyVmHyHSeL0O", - "sITfCyaB4uFNTaxpa2lQ15nKBVdwT6Ux2tbW+T+bCqpEwkfR8d9m38+Pw+hkdhqevIHj8PTvb0hIT+hg", - "/pqeHMHRCQ5wTrQGaUh9+DC7GYSnJJxPb9+sP3yYhdXfk/XO3/Vdr4/Mti5EcpDKyDiKIlBqIj4Bb8vy", - "giXYwplR3C1TF+xnUgr5QMjB7O3wETOMIkEB6YRoxChwzeYMlDUFkucpi5wPOQoBBl5khnUKc1KkOpQi", - "hTArlA5nEDIekjQVS6B2XOEAU6bILAUaAqe5YFzXxwoFlmZGWBqSVAKhK0OkUNAaXoA0nFHnvTNGKfCQ", - "cMFXmSjMSYwb+EgaKpALkGHJMeMLkjIaOnI5UWopJK1NSB96ApyKiKQQcqFLOaxduB2hFiJUiZC6Psh4", - "mLBZHpo4MSOWbwmUSYgmYouQVVVzSLGYF3lYKsQEDF4KWmqnZNMElGlnlFKKxNBG912REY7mkgGn6coh", - "iMrVHYSUJrpQHXQmkzFyk56IsZgNBaP4GGTLuj29DYeBt8Mu6/71X6O3CUlT4DGMySoVhN7TxjWLPoFN", - "BPudzq/rYuLy54P9q/SDy587Qbm0ylPvnSnoibinMLKxcRPLEq1zNez3XTbpRSLrR0RHSVhuMJB1hZ2W", - "rFeglMX3XnyRZuBtyV2bP3NJ7NwurBIu4/r7kw7bMRjNJahkR1B/72aRNtOIFuZARAqdmIjlY5SQaJkA", - "R56SWWHC2E//fsEZzTr58Bb/ScIcD/F3/U2p1Pd1Uv9adXhXHYkdet/S6XS3FTzM49TGhPaxX1patxmy", - "mJ/zMxObxz6mPrAUNCTadjNCNnoiN103g48i4T2VMZ38gydC6R4T9cqw3NCuQcoM0nFWOWeKqYxxlhUZ", - "OkZRQiSJNEjVYOBKywGPrdTfmVRyevLfP5toyfgvwGOd4OHxXQVEyWTF0/RQFT+ojMjm5C6wu+K4yS2P", - "ZioPvit8Y7XjoWWj15oHPgWlrCW8cCcTLn/eaS8s5te5T7Y7vGOXUq7zrynqPEwhX2m02khwz+rEFdbv", - "7d2jXjfdYHvDt8l2GmCmIVOd9YsfIFKSlfnvrzmGYrMQ88V5iwBlKk/J6sL2AeobfhIJR1cG+K5t7ubR", - "BZJeinADCfIL68jYAiAjn0scjhqoHD1OY2POpNJOKiuKuS2RasTJ1RW0n76QvfYV1H3sZEE0kdcy7bSB", - "yPZXqGsqHdYpei4zebZwYyd+K6/cGy3NhEiBcLOks41EyzZSebV/mVU3U6Oqc9Ap3Dfrj3kiOFwU2cw5", - "Tbu63czvh18+VpDdvtdUvln3xKaLNf0nqMy4abWB66/Vsa6Arem8qZNuDZTiTnfEH1tJvE0Ij+FhdQSH", - "5dkLKyXaLYxtqCqm96rlCjh1inTX9G+o/LxbRbBXP7XrGOg/ukrsLTEqJNOrK0PMyfcDEAlyVJjQ6V+O", - "bCCywxs+TTmB14YG43NbakSCaxLpmo6wKvJcSF2X27/XXJiRVwpduRUmipnioKpSqh0219cVfQVywSJA", - "WqDRphtlAkzKIvBXbH/KKCdRAuioN2gdsFwue8RO94SM+36v6v9y/vbs4uosPOoNeonOUhs+QWbqcu5P", - "9kSG/b5akjgG2WOib5f0jXqYTisBLYc4wAuQ7jKOX/cGvYFDFzjJGR7iYztkE3FiEegnQFKd/Mf8jl2D", - "1Rig9eVziof4R9Dv/BKDtOss2K1Hg0EJBXBn2Zsnhf5H5RoCznTuNKxNb9bi3A0DU8ix6xKLKrKMyBUe", - "YschihKIPhm9kFgZY3SL8XQdmJ+0Ldw7IHS/dI/MyDrA/ZxYZeVCdah7LJQe28dF/3Dxg6CrR1Nz60F2", - "3fRfU4isnxDm9uNmF9qFbXrMizRdIV8kIILGviWCXE8EuaZIPbDg4U0zpNxM19M6Ou74XbTQX8ajyV9r", - "qNk3GQuZYjFnvL/1yLQXxCu7pdGEeCJQ9zRZnxnefb3Iu4A2KgaKGO+hiyJNke8pogwIV2hyORmjqGw9", - "Gv/jABTolvsZBhDjaMl04rIjIpyi2rNgia1DdPMOyekG1wbmea231q/y8V24tzpyT4r9zv7fM+O/P4r/", - "SmIWoZTxT0gB1yapmtvEK+WRUjVz2Idrtp9OD53P7QCiAhR/pRF8ZkoHiGm0ZGmKZlBGlR6aJIB87YNc", - "qWJsy40QZzwR4WZLoYCaoxTo6gr8SrknKWcqMSpyRBCHpZ3soXNLjMVcSGPZm6sz8s/ejjPV67JL93hc", - "66cZcLdNUx9mi/ppre+LpZOt962DQ8w+6zosNVQwNXNEkd87RxT5c+WIHS3xl42ZhJgpDdI4a0decJFj", - "UbuAGpfLgVN/sTkZHD8a681vdLo4t3iiDKKEcKYyw0v19Ydl5vT5mDFXUGfSMVsAL9NhI/J0OEKRH5g9", - "bXDamT1NmHO+0I9s52S/J2y1WZ7IC3Y0c9beBb5IVrT8IKcj5IU2xj4qEbNokKaN2+SXEIVmYNKPz4Im", - "vZjsQyiVJlncrzZ2HNj0VLUJPNz1L7na+CrgNKyzFx5QJu1vID01/nu7Vi+qaDprBzhfLRnQ95VMwGlX", - "eDwU29Kf+xIU6LvBbHS7nhC/zq7aF/XgkiNkNbXx4S1APK+IbF5m7YZDXL2sW+ue7gvQ0tlbiG4F5Kot", - "tLvP85tf8n9q8h79zRpTmx7moDfoDUIKizu/kii3d7QbWyh54coXLOX7a02MfgSNFpUWSoVWxxgz+18A", - "AAD//3cUeO3vLgAA", + "H4sIAAAAAAAC/+RaW3fbuBH+KzjY9qQ9h5S0trtN9FRt6t14L4lOrGwfEj9AxIhEQgJcAJSiuvrvPQOA", + "EiVSsuzajjd9s3HjzHxzwwdd00QVpZIgraHDa2qSDArm/nypgVkYjyZv4fcKjMUxxrmwQkmWj7UqQVsB", + "hg5nLDcQ0bIxdE3hcyk0mJHbx8EkWpS4lQ7pOU4x/IdwZoGoGbEZkPFoQiM6U7pglg4pTsVWFEAjapcl", + "0CE1VguZ0lVEC7CMM8v2C2V1BRGFz6woc8BlkhV4RrGMS2ZpRCsDPJ4u/RAryzjJBR4dvqWmHyGxdLWK", + "qIbfK6GB0+H7hlpXraVR02amVNLALY0meNtaF//cNtBaJXqSnP5t+t3sNE7Opi/is+dwGr/4+3MW8zM+", + "mH3Lz07g5IxGtGTWgsajPnyYvh/EL1g8u7p+vvrwYRqv/z1b7f27uevbE9zWhUgJ2qCOoyQBYybqE8i2", + "Lk9Ygx2cBafdOnXBfq610neEHHBvR4zgMEkUB2IzZongIK2YCTDOFVhZ5iLxMeRPiCjIqkDROcxYldtY", + "qxziojI2nkIsZMzyXC2Au3FDI8qFYdMceAySl0pI2xyrDLgzCybymOUaGF/iIZWB1vAcNErGffROBecg", + "YyaVXBaqwi8JifCxPDag56DjWmIh5ywXPPbHlcyYhdK8MaFD6olorhKWQyyVrfVwfuF3xFap2GRK2+ag", + "kHEmpmWMeWLKnNwauNCQTNTOQc5U20NGpLIq49ogmDBkrWhtnVpMTChXnVnKGJZCG91XVcEkmWkBkudL", + "jyCpV3ccZCyzlek4ZzIZEz8ZDkGP2ZyAhk9Bt7w7nLeRMAp+2OXdv/4wepmxPAeZwpgtc8X4LX3ciuQT", + "uEJwOOjCui4h3vx8dHzVcfDm505Q3jjjmbfeFexE3VIZvbVxk8sya0sz7Pd9NeklqugnzCZZXG9AyLrS", + "TkvXSzDG4Xsrudh24m3p3Zg/90Xswi1cF1wh7XdnHb6DGM00mGxPUn/rZ4nFacIr/CBhlc0wY4UcpTRZ", + "ZCBJOAlXYBr76V9PuKK5IB9e0z9pmNEh/aa/aZX6oU/qvzMd0dVEYo/dd2x6td8L7hZxZuNCh8SvPa3b", + "DUUqL+Q55uZxyKl3bAXxiLbfjIjLnsRPN93go8pkzxTCZv+QmTK2J1SzM6w3tHuQuoJ0fKuew2aqEFIU", + "VUFOSZIxzRIL2mwJcGn1QKZO62+wlLw4+8+fMVsK+QvI1GZ0eHpTA1ELuZbp6lgT36mNKGbsJrC78jjW", + "lntzlTvfFb6y3vHYtjFYLQCfgzHOE554kClfP2/0F5HKd2UotnuiY59R3pV/pKxzN4P8QbPVRoNbdie+", + "sX7r7h7Nvuk9dTd8V2yvIiosFKazfwkDTGu2xP/DNQdP3G7EQnPeOoALU+Zs+drxAM0NP6lMkksEvmub", + "v3l0gWQXKt5AQsLCJjKuASjY5xqHky1UTu6H2JgJbazXyqmCtyW2HvF6dSXth29k34UO6jZ+MmeW6Xc6", + "7/SBxPEr3JNKxzFFj+Umj5Zu3MRv9ZV7Y6WpUjkwiUs6aSRe00j11f5pdt3CjNbMQadyX208lpmS8Loq", + "pj5o2t3tZv4w/Pq+kuzuvWYdm81I3A6x7fiJ1m687bWR59eaWK+Bbdh82ybdFqjVvdqTf1wn8TJjMoW7", + "9RESFudPrJVoUxi7UK2FPmiWS5DcG9Jf07+i9vNmE8FB+zSuY2D/303ibolJpYVdXuJhXr/vgWnQowpT", + "Z3g5conIDW/kxHYCpdwsP89hjsmjbZtJJgyp6WhSsCUJ4hEIe0gJuhDuKmoiwqEEyYVMiZLEk8vEgLVC", + "pqZHflCacLBM5IYYAFI3NlwlpldbuJ9WgoPps8pm/forceMrNLpJN7SPkDPXRiVKWpbYBv7UVGWptG1i", + "Gt6iXuPIM0Mu/QrM0Nj4rDuw9Q7XxzQNdQl6LhIgVpHRhmnD5JmLBAJ9EL4yKlmSATnpDVofWCwWPeam", + "e0qn/bDX9H+5eHn++vI8PukNepktclcaQBfmzSx8ORwy7PfNgqUpaDSlW9JH8wibrxV0EtKIzkF7ooF+", + "2xv0Bt5zQbJS0CE9dUOuycicd/UzYLnN/o1/p548xuByeeqC0yH9EeyrsAS92LMmbuvJYFBDAdJH7ea5", + "pP/ReLLDh8WNQbPhnR3O3TAIQ7y4vmiaqiiYXtIh9RKSJIPkE9qFpQYDzS+mV6sI/+Rt5V4B44e1u2dB", + "VhHtl8wZq1Smw9xjZezYPZyGR5nvFV/em5lbj82r7dyETdbqAWFuP9x2oV05QmdW5fmShAaIMDIOdA/x", + "fA/xhE8zadLh++vO/Pf+anXVRMmLse9M8pfxaPLXBnru3clBZ0QqhezvPKQdBPPSbdkiWh4I3ANE8iPD", + "fIhvvQlwNDFwImSPvK7ynATelBTApCGTN5MxSWp6FeNQAnDgO2GIAhAhyULYzHcAhElOGk+fNbYe0c1b", + "q+QbXLcwLxv8YX/dc9yEe4t1fFDs93Kcj4z/4Wz+K0tFQnIhPxED0mJxxRvTMxOQMg13OIRrcficHrmY", + "uQHCFRj5zBL4LIyNiLBkIfKcTKHOLj0yyYCE/o74dgx9y48w7zwJk7ilMsDxUwbs+pr/zPhnN+8qKalK", + "woiEhZvskQt3mEil0ujZG3qAhKd9L5npdfmlfyBvcIYI7q5r2uN80T6s932xsrLzhnd0ijnkXceVhjVM", + "2zWiKm9dI6rysWrEHtr/aWOmIRXGgsZg7agLPnPMG5dsDLlwa8EW+Gxwem+ib/8OqUtyhycpIMmYFKZA", + "Wda/cHHCvHg8YfCa7V06FXOQdTncyjwdgVCVR1ZPl5z2Vk9Mcz4W+oljhw5Hwg6V9EBRsIewWoUQ+CJV", + "0clDvI1IUBqdfVQj5tBg2z7uil/GDJkClp9QBbG8YPVhnGssFnfrkb0krkytKZEAe/NXa22cDUgeN8WM", + "j2iXDpNlD+0HBxm6J9U8nbcTXeiaEPxDrRNI3pUmj8W2juu+BgP2ZjC3mL0HxK+TQfyikVxLRJylNrG8", + "A0iQlbDNK7TbcEzI1/1rM+JDI1oHfQvRncS8pon28z6/hSX/oyVvweU2hNrwtYPeoDeIOcxv/EVIvb2D", + "Wm2hFJSrX+tM4Nu2MfoRLJmvrVAbdP0ZdLP/BgAA///4O3pv2y8AAA==", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/go/api/types.gen.go b/go/api/types.gen.go index a6faeaf8f..c0c3a0a54 100644 --- a/go/api/types.gen.go +++ b/go/api/types.gen.go @@ -10,7 +10,7 @@ import ( ) const ( - BearerAuthScopes = "BearerAuth.Scopes" + BearerAuthElevatedScopes = "BearerAuthElevated.Scopes" ) // Defines values for ErrorResponseError. diff --git a/go/cmd/jwt_getter.go b/go/cmd/jwt_getter.go index 113593e01..1b1d2afe2 100644 --- a/go/cmd/jwt_getter.go +++ b/go/cmd/jwt_getter.go @@ -9,7 +9,7 @@ import ( "github.com/urfave/cli/v2" ) -func getJWTGetter(cCtx *cli.Context) (*controller.JWTGetter, error) { +func getJWTGetter(cCtx *cli.Context, db controller.DBClient) (*controller.JWTGetter, error) { var customClaimer controller.CustomClaimer var err error if cCtx.String(flagCustomClaims) != "" { @@ -28,6 +28,8 @@ func getJWTGetter(cCtx *cli.Context) (*controller.JWTGetter, error) { []byte(cCtx.String(flagHasuraGraphqlJWTSecret)), time.Duration(cCtx.Int(flagAccessTokensExpiresIn))*time.Second, customClaimer, + cCtx.String(flagRequireElevatedClaim), + db, ) if err != nil { return nil, fmt.Errorf("error creating jwt getter: %w", err) diff --git a/go/cmd/serve.go b/go/cmd/serve.go index 118cc320e..db22bb334 100644 --- a/go/cmd/serve.go +++ b/go/cmd/serve.go @@ -67,6 +67,7 @@ const ( flagAllowedEmailDomains = "allowed-email-domains" flagAllowedEmails = "allowed-emails" flagEmailPasswordlessEnabled = "email-passwordless-enabled" + flagRequireElevatedClaim = "require-elevated-claim" ) func CommandServe() *cli.Command { //nolint:funlen,maintidx @@ -376,6 +377,20 @@ func CommandServe() *cli.Command { //nolint:funlen,maintidx Category: "signin", EnvVars: []string{"AUTH_EMAIL_PASSWORDLESS_ENABLED"}, }, + &cli.GenericFlag{ //nolint: exhaustruct + Name: flagRequireElevatedClaim, + Value: &EnumValue{ //nolint: exhaustruct + Enum: []string{ + "disabled", + "recommended", + "required", + }, + Default: "disabled", + }, + Usage: "Require x-hasura-auth-elevated claim to perform certain actions: create PATs, change email and/or password, enable/disable MFA and add security keys. If set to `recommended` the claim check is only performed if the user has a security key attached. If set to `required` the only action that won't require the claim is setting a security key for the first time.", + Category: "security", + EnvVars: []string{"AUTH_REQUIRE_ELEVATED_CLAIM"}, + }, }, Action: serve, } @@ -441,7 +456,7 @@ func getGoServer( //nolint:funlen return nil, fmt.Errorf("problem creating config: %w", err) } - jwtGetter, err := getJWTGetter(cCtx) + jwtGetter, err := getJWTGetter(cCtx, db) if err != nil { return nil, fmt.Errorf("problem creating jwt getter: %w", err) } diff --git a/go/controller/controller.go b/go/controller/controller.go index 8ae42e7a5..f4ceadb43 100644 --- a/go/controller/controller.go +++ b/go/controller/controller.go @@ -36,7 +36,8 @@ type Emailer interface { ) error } -type DBClient interface { +type DBClient interface { //nolint:interfacebloat + CountSecurityKeysUser(ctx context.Context, userID uuid.UUID) (int64, error) GetUser(ctx context.Context, id uuid.UUID) (sql.AuthUser, error) GetUserByEmail(ctx context.Context, email pgtype.Text) (sql.AuthUser, error) GetUserByRefreshTokenHash( diff --git a/go/controller/errors.go b/go/controller/errors.go index 72d3e3e4f..a05dfbdb1 100644 --- a/go/controller/errors.go +++ b/go/controller/errors.go @@ -2,6 +2,7 @@ package controller import ( "encoding/json" + "errors" "fmt" "log/slog" "net/http" @@ -17,6 +18,8 @@ func (e *APIError) Error() string { return fmt.Sprintf("API error: %s", e.t) } +var ErrElevatedClaimRequired = errors.New("elevated-claim-required") + var ( ErrUserEmailNotFound = &APIError{api.InvalidEmailPassword} ErrEmailAlreadyInUse = &APIError{api.EmailAlreadyInUse} diff --git a/go/controller/jwt.go b/go/controller/jwt.go index fa8853aeb..70fa14e20 100644 --- a/go/controller/jwt.go +++ b/go/controller/jwt.go @@ -53,12 +53,16 @@ type JWTGetter struct { method jwt.SigningMethod customClaimer CustomClaimer accessTokenExpiresIn time.Duration + elevatedClaimMode string + db DBClient } func NewJWTGetter( jwtSecretb []byte, accessTokenExpiresIn time.Duration, customClaimer CustomClaimer, + elevatedClaimMode string, + db DBClient, ) (*JWTGetter, error) { jwtSecret, err := decodeJWTSecret(jwtSecretb) if err != nil { @@ -74,6 +78,8 @@ func NewJWTGetter( method: method, customClaimer: customClaimer, accessTokenExpiresIn: accessTokenExpiresIn, + elevatedClaimMode: elevatedClaimMode, + db: db, }, nil } @@ -182,6 +188,21 @@ func (j *JWTGetter) Validate(accessToken string) (*jwt.Token, error) { func (j *JWTGetter) FromContext(ctx context.Context) (*jwt.Token, bool) { token, ok := ctx.Value(jwtContextKey).(*jwt.Token) + if !ok { //nolint:nestif + c := ginmiddleware.GetGinContext(ctx) + if c != nil { + a, ok := c.Get(jwtContextKey) + if !ok { + return nil, false + } + + token, ok = a.(*jwt.Token) + if !ok { + return nil, false + } + return token, true + } + } return token, ok } @@ -189,6 +210,36 @@ func (j *JWTGetter) ToContext(ctx context.Context, jwtToken *jwt.Token) context. return context.WithValue(ctx, jwtContextKey, jwtToken) //nolint:revive,staticcheck } +func (j *JWTGetter) verifyElevatedClaim(ctx context.Context, token *jwt.Token) (bool, error) { + if j.elevatedClaimMode == "disabled" { + return true, nil + } + + u, err := token.Claims.GetSubject() + if err != nil { + return false, fmt.Errorf("error getting user id from subject: %w", err) + } + + if j.elevatedClaimMode == "recommended" { + userID, err := uuid.Parse(u) + if err != nil { + return false, fmt.Errorf("error parsing user id: %w", err) + } + n, err := j.db.CountSecurityKeysUser(ctx, userID) + if err != nil { + return false, fmt.Errorf("error checking if user has security keys: %w", err) + } + + if n == 0 { + return true, nil + } + } + + elevatedClaim := j.GetCustomClaim(token, "x-hasura-auth-elevated") + + return elevatedClaim == u, nil +} + func (j *JWTGetter) MiddlewareFunc( ctx context.Context, input *openapi3filter.AuthenticationInput, ) error { @@ -207,6 +258,16 @@ func (j *JWTGetter) MiddlewareFunc( return fmt.Errorf("invalid token") //nolint:goerr113 } + if input.SecuritySchemeName == "BearerAuthElevated" { + found, err := j.verifyElevatedClaim(ctx, jwtToken) + if err != nil { + return fmt.Errorf("error verifying elevated claim: %w", err) + } + if !found { + return ErrElevatedClaimRequired + } + } + c := ginmiddleware.GetGinContext(ctx) c.Set(jwtContextKey, jwtToken) diff --git a/go/controller/jwt_test.go b/go/controller/jwt_test.go index c26f57556..cfd04d7d3 100644 --- a/go/controller/jwt_test.go +++ b/go/controller/jwt_test.go @@ -3,16 +3,21 @@ package controller_test import ( "context" "crypto" + "errors" "log/slog" + "net/http" "testing" "time" + "github.com/getkin/kin-openapi/openapi3filter" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" "github.com/nhost/hasura-auth/go/controller" "github.com/nhost/hasura-auth/go/controller/mock" + ginmiddleware "github.com/oapi-codegen/gin-middleware" "go.uber.org/mock/gomock" ) @@ -224,7 +229,7 @@ func TestGetJWTFunc(t *testing.T) { if tc.customClaimer != nil { customClaimer = tc.customClaimer(ctrl) } - jwtGetter, err := controller.NewJWTGetter(tc.key, tc.expiresIn, customClaimer) + jwtGetter, err := controller.NewJWTGetter(tc.key, tc.expiresIn, customClaimer, "", nil) if err != nil { t.Fatalf("GetJWTFunc() err = %v; want nil", err) } @@ -254,3 +259,408 @@ func TestGetJWTFunc(t *testing.T) { }) } } + +//nolint:dupl,goconst +func TestMiddlewareFunc(t *testing.T) { //nolint:maintidx + t.Parallel() + + userID := uuid.MustParse("8b3107c8-8b1c-4f14-a403-c1c446e36ec3") + + //nolint:lll,gosec + nonElevatedToken := `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjEwNzExMTEyMzA4LCJodHRwczovL2hhc3VyYS5pby9qd3QvY2xhaW1zIjp7IngtaGFzdXJhLWFsbG93ZWQtcm9sZXMiOlsibWUiLCJ1c2VyIiwiZWRpdG9yIl0sIngtaGFzdXJhLWRlZmF1bHQtcm9sZSI6InVzZXIiLCJ4LWhhc3VyYS11c2VyLWlkIjoiOGIzMTA3YzgtOGIxYy00ZjE0LWE0MDMtYzFjNDQ2ZTM2ZWMzIiwieC1oYXN1cmEtdXNlci1pc0Fub255bW91cyI6ImZhbHNlIn0sImlhdCI6MTcxMTExMjMwOCwiaXNzIjoiaGFzdXJhLWF1dGgiLCJzdWIiOiI4YjMxMDdjOC04YjFjLTRmMTQtYTQwMy1jMWM0NDZlMzZlYzMifQ.vryKygEgosBsRZDQDxpAdbpU_HEA4E8p6Rg0KOtrLV4` + //nolint:lll + elevatedToken := `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjEwNzExMTEyMzk1LCJodHRwczovL2hhc3VyYS5pby9qd3QvY2xhaW1zIjp7IngtaGFzdXJhLWFsbG93ZWQtcm9sZXMiOlsibWUiLCJ1c2VyIiwiZWRpdG9yIl0sIngtaGFzdXJhLWF1dGgtZWxldmF0ZWQiOiI4YjMxMDdjOC04YjFjLTRmMTQtYTQwMy1jMWM0NDZlMzZlYzMiLCJ4LWhhc3VyYS1kZWZhdWx0LXJvbGUiOiJ1c2VyIiwieC1oYXN1cmEtdXNlci1pZCI6IjhiMzEwN2M4LThiMWMtNGYxNC1hNDAzLWMxYzQ0NmUzNmVjMyIsIngtaGFzdXJhLXVzZXItaXNBbm9ueW1vdXMiOiJmYWxzZSJ9LCJpYXQiOjE3MTExMTIzOTUsImlzcyI6Imhhc3VyYS1hdXRoIiwic3ViIjoiOGIzMTA3YzgtOGIxYy00ZjE0LWE0MDMtYzFjNDQ2ZTM2ZWMzIn0.ySwnKlt5_7R112OMkJrzUi5v9jE3nbaAbZTmLILKCYE` //nolint:gosec + + cases := []struct { + name string + elevatedMode string + db func(ctrl *gomock.Controller) *mock.MockDBClient + request *openapi3filter.AuthenticationInput + expected *jwt.Token + expectedErr error + }{ + { + name: "BearerAuth: elevated disabled", + elevatedMode: "disabled", + db: func(ctrl *gomock.Controller) *mock.MockDBClient { + mock := mock.NewMockDBClient(ctrl) + return mock + }, + //nolint:exhaustruct + request: &openapi3filter.AuthenticationInput{ + RequestValidationInput: &openapi3filter.RequestValidationInput{ + Request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer " + nonElevatedToken}, + }, + }, + }, + SecuritySchemeName: "BearerAuth", + SecurityScheme: nil, + Scopes: []string{}, + }, + expected: &jwt.Token{ + Raw: nonElevatedToken, + Method: jwt.SigningMethodHS256, + Header: map[string]any{"alg": string("HS256"), "typ": string("JWT")}, + Claims: jwt.MapClaims{ + "exp": float64(1.0711112308e+10), + "https://hasura.io/jwt/claims": map[string]any{ + "x-hasura-allowed-roles": []any{"me", "user", "editor"}, + "x-hasura-default-role": string("user"), + "x-hasura-user-id": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + "x-hasura-user-isAnonymous": string("false"), + }, + "iat": float64(1.711112308e+09), + "iss": string("hasura-auth"), + "sub": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + }, + Signature: []uint8{}, + Valid: true, + }, + expectedErr: nil, + }, + + { + name: "BearerAuth: elevated recommended, no security keys, claim not present", + elevatedMode: "recommended", + db: func(ctrl *gomock.Controller) *mock.MockDBClient { + mock := mock.NewMockDBClient(ctrl) + return mock + }, + //nolint:exhaustruct + request: &openapi3filter.AuthenticationInput{ + RequestValidationInput: &openapi3filter.RequestValidationInput{ + Request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer " + nonElevatedToken}, + }, + }, + }, + SecuritySchemeName: "BearerAuth", + SecurityScheme: nil, + Scopes: []string{}, + }, + expected: &jwt.Token{ + Raw: nonElevatedToken, + Method: jwt.SigningMethodHS256, + Header: map[string]any{"alg": string("HS256"), "typ": string("JWT")}, + Claims: jwt.MapClaims{ + "exp": float64(1.0711112308e+10), + "https://hasura.io/jwt/claims": map[string]any{ + "x-hasura-allowed-roles": []any{"me", "user", "editor"}, + "x-hasura-default-role": string("user"), + "x-hasura-user-id": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + "x-hasura-user-isAnonymous": string("false"), + }, + "iat": float64(1.711112308e+09), + "iss": string("hasura-auth"), + "sub": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + }, + Signature: []uint8{}, + Valid: true, + }, + expectedErr: nil, + }, + + { + name: "BearerAuth: elevated required, no security keys, claim not present", + elevatedMode: "required", + db: func(ctrl *gomock.Controller) *mock.MockDBClient { + mock := mock.NewMockDBClient(ctrl) + return mock + }, + //nolint:exhaustruct + request: &openapi3filter.AuthenticationInput{ + RequestValidationInput: &openapi3filter.RequestValidationInput{ + Request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer " + nonElevatedToken}, + }, + }, + }, + SecuritySchemeName: "BearerAuth", + SecurityScheme: nil, + Scopes: []string{}, + }, + expected: &jwt.Token{ + Raw: nonElevatedToken, + Method: jwt.SigningMethodHS256, + Header: map[string]any{"alg": string("HS256"), "typ": string("JWT")}, + Claims: jwt.MapClaims{ + "exp": float64(1.0711112308e+10), + "https://hasura.io/jwt/claims": map[string]any{ + "x-hasura-allowed-roles": []any{"me", "user", "editor"}, + "x-hasura-default-role": string("user"), + "x-hasura-user-id": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + "x-hasura-user-isAnonymous": string("false"), + }, + "iat": float64(1.711112308e+09), + "iss": string("hasura-auth"), + "sub": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + }, + Signature: []uint8{}, + Valid: true, + }, + expectedErr: nil, + }, + + { + name: "BearerAuthElevated: elevated disabled", + elevatedMode: "disabled", + db: func(ctrl *gomock.Controller) *mock.MockDBClient { + mock := mock.NewMockDBClient(ctrl) + return mock + }, + //nolint:exhaustruct + request: &openapi3filter.AuthenticationInput{ + RequestValidationInput: &openapi3filter.RequestValidationInput{ + Request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer " + nonElevatedToken}, + }, + }, + }, + SecuritySchemeName: "BearerAuthElevated", + SecurityScheme: nil, + Scopes: []string{}, + }, + expected: &jwt.Token{ + Raw: nonElevatedToken, + Method: jwt.SigningMethodHS256, + Header: map[string]any{"alg": string("HS256"), "typ": string("JWT")}, + Claims: jwt.MapClaims{ + "exp": float64(1.0711112308e+10), + "https://hasura.io/jwt/claims": map[string]any{ + "x-hasura-allowed-roles": []any{"me", "user", "editor"}, + "x-hasura-default-role": string("user"), + "x-hasura-user-id": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + "x-hasura-user-isAnonymous": string("false"), + }, + "iat": float64(1.711112308e+09), + "iss": string("hasura-auth"), + "sub": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + }, + Signature: []uint8{}, + Valid: true, + }, + expectedErr: nil, + }, + + { + name: "BearerAuthElevated: elevated recommended, no security keys, claim not present", + elevatedMode: "recommended", + db: func(ctrl *gomock.Controller) *mock.MockDBClient { + mock := mock.NewMockDBClient(ctrl) + mock.EXPECT().CountSecurityKeysUser(gomock.Any(), userID).Return(int64(0), nil) + return mock + }, + //nolint:exhaustruct + request: &openapi3filter.AuthenticationInput{ + RequestValidationInput: &openapi3filter.RequestValidationInput{ + Request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer " + nonElevatedToken}, + }, + }, + }, + SecuritySchemeName: "BearerAuthElevated", + SecurityScheme: nil, + Scopes: []string{}, + }, + expected: &jwt.Token{ + Raw: nonElevatedToken, + Method: jwt.SigningMethodHS256, + Header: map[string]any{"alg": string("HS256"), "typ": string("JWT")}, + Claims: jwt.MapClaims{ + "exp": float64(1.0711112308e+10), + "https://hasura.io/jwt/claims": map[string]any{ + "x-hasura-allowed-roles": []any{"me", "user", "editor"}, + "x-hasura-default-role": string("user"), + "x-hasura-user-id": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + "x-hasura-user-isAnonymous": string("false"), + }, + "iat": float64(1.711112308e+09), + "iss": string("hasura-auth"), + "sub": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + }, + Signature: []uint8{}, + Valid: true, + }, + expectedErr: nil, + }, + + { + name: "BearerAuthElevated: elevated recommended, security keys, claim not present", + elevatedMode: "recommended", + db: func(ctrl *gomock.Controller) *mock.MockDBClient { + mock := mock.NewMockDBClient(ctrl) + mock.EXPECT().CountSecurityKeysUser(gomock.Any(), userID).Return(int64(1), nil) + return mock + }, + //nolint:exhaustruct + request: &openapi3filter.AuthenticationInput{ + RequestValidationInput: &openapi3filter.RequestValidationInput{ + Request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer " + nonElevatedToken}, + }, + }, + }, + SecuritySchemeName: "BearerAuthElevated", + SecurityScheme: nil, + Scopes: []string{}, + }, + expected: nil, + expectedErr: controller.ErrElevatedClaimRequired, + }, + + { + name: "BearerAuthElevated: elevated required, no security keys, claim not present", + elevatedMode: "required", + db: func(ctrl *gomock.Controller) *mock.MockDBClient { + mock := mock.NewMockDBClient(ctrl) + return mock + }, + //nolint:exhaustruct + request: &openapi3filter.AuthenticationInput{ + RequestValidationInput: &openapi3filter.RequestValidationInput{ + Request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer " + nonElevatedToken}, + }, + }, + }, + SecuritySchemeName: "BearerAuthElevated", + SecurityScheme: nil, + Scopes: []string{}, + }, + expected: nil, + expectedErr: controller.ErrElevatedClaimRequired, + }, + + { + name: "BearerAuthElevated: elevated recommended, security keys, claim present", + elevatedMode: "recommended", + db: func(ctrl *gomock.Controller) *mock.MockDBClient { + mock := mock.NewMockDBClient(ctrl) + mock.EXPECT().CountSecurityKeysUser(gomock.Any(), userID).Return(int64(1), nil) + return mock + }, + //nolint:exhaustruct + request: &openapi3filter.AuthenticationInput{ + RequestValidationInput: &openapi3filter.RequestValidationInput{ + Request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer " + elevatedToken}, + }, + }, + }, + SecuritySchemeName: "BearerAuthElevated", + SecurityScheme: nil, + Scopes: []string{}, + }, + expected: &jwt.Token{ + Raw: elevatedToken, + Method: jwt.SigningMethodHS256, + Header: map[string]any{"alg": string("HS256"), "typ": string("JWT")}, + Claims: jwt.MapClaims{ + "exp": float64(1.0711112395e+10), + "https://hasura.io/jwt/claims": map[string]any{ + "x-hasura-allowed-roles": []any{"me", "user", "editor"}, + "x-hasura-auth-elevated": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + "x-hasura-default-role": string("user"), + "x-hasura-user-id": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + "x-hasura-user-isAnonymous": string("false"), + }, + "iat": float64(1.711112395e+09), + "iss": string("hasura-auth"), + "sub": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + }, + Signature: []uint8{}, + Valid: true, + }, + expectedErr: nil, + }, + + { + name: "BearerAuthElevated: elevated required, security keys, claim present", + elevatedMode: "required", + db: func(ctrl *gomock.Controller) *mock.MockDBClient { + mock := mock.NewMockDBClient(ctrl) + return mock + }, + //nolint:exhaustruct + request: &openapi3filter.AuthenticationInput{ + RequestValidationInput: &openapi3filter.RequestValidationInput{ + Request: &http.Request{ + Header: http.Header{ + "Authorization": []string{"Bearer " + elevatedToken}, + }, + }, + }, + SecuritySchemeName: "BearerAuthElevated", + SecurityScheme: nil, + Scopes: []string{}, + }, + expected: &jwt.Token{ + Raw: elevatedToken, + Method: jwt.SigningMethodHS256, + Header: map[string]any{"alg": string("HS256"), "typ": string("JWT")}, + Claims: jwt.MapClaims{ + "exp": float64(1.0711112395e+10), + "https://hasura.io/jwt/claims": map[string]any{ + "x-hasura-allowed-roles": []any{"me", "user", "editor"}, + "x-hasura-auth-elevated": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + "x-hasura-default-role": string("user"), + "x-hasura-user-id": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + "x-hasura-user-isAnonymous": string("false"), + }, + "iat": float64(1.711112395e+09), + "iss": string("hasura-auth"), + "sub": string("8b3107c8-8b1c-4f14-a403-c1c446e36ec3"), + }, + Signature: []uint8{}, + Valid: true, + }, + expectedErr: nil, + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + tc := tc + + ctrl := gomock.NewController(t) + + jwtGetter, err := controller.NewJWTGetter( + jwtSecret, time.Hour, nil, tc.elevatedMode, tc.db(ctrl), + ) + if err != nil { + t.Fatalf("GetJWTFunc() err = %v; want nil", err) + } + + //nolint + ctx := context.WithValue( + context.Background(), + ginmiddleware.GinContextKey, + &gin.Context{}, + ) + err = jwtGetter.MiddlewareFunc(ctx, tc.request) + if !errors.Is(err, tc.expectedErr) { + t.Errorf("err = %v; want %v", err, tc.expectedErr) + } + + got, _ := jwtGetter.FromContext(ctx) + + cmpopts := []cmp.Option{ + cmpopts.IgnoreFields(jwt.Token{}, "Signature"), //nolint:exhaustruct + } + if diff := cmp.Diff(got, tc.expected, cmpopts...); diff != "" { + t.Errorf("got mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/go/controller/main_test.go b/go/controller/main_test.go index 55932c1a9..ca88553e5 100644 --- a/go/controller/main_test.go +++ b/go/controller/main_test.go @@ -178,6 +178,8 @@ func getController( jwtSecret, time.Second*time.Duration(config().AccessTokenExpiresIn), cc, + "", + nil, ) if err != nil { t.Fatalf("failed to create jwt getter: %v", err) diff --git a/go/controller/mock/controller.go b/go/controller/mock/controller.go index 769808c41..5bfed5513 100644 --- a/go/controller/mock/controller.go +++ b/go/controller/mock/controller.go @@ -80,6 +80,21 @@ func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder { return m.recorder } +// CountSecurityKeysUser mocks base method. +func (m *MockDBClient) CountSecurityKeysUser(ctx context.Context, userID uuid.UUID) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountSecurityKeysUser", ctx, userID) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountSecurityKeysUser indicates an expected call of CountSecurityKeysUser. +func (mr *MockDBClientMockRecorder) CountSecurityKeysUser(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountSecurityKeysUser", reflect.TypeOf((*MockDBClient)(nil).CountSecurityKeysUser), ctx, userID) +} + // GetUser mocks base method. func (m *MockDBClient) GetUser(ctx context.Context, id uuid.UUID) (sql.AuthUser, error) { m.ctrl.T.Helper() diff --git a/go/sql/query.sql b/go/sql/query.sql index 898e7040f..d7c0aa971 100644 --- a/go/sql/query.sql +++ b/go/sql/query.sql @@ -95,3 +95,7 @@ UPDATE auth.users SET (ticket, ticket_expires_at, new_email) = ($2, $3, $4) WHERE id = $1 RETURNING *; + +-- name: CountSecurityKeysUser :one +SELECT COUNT(*) FROM auth.user_security_keys +WHERE user_id = $1; diff --git a/go/sql/query.sql.go b/go/sql/query.sql.go index 19e78b46d..bef379f1a 100644 --- a/go/sql/query.sql.go +++ b/go/sql/query.sql.go @@ -12,6 +12,18 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) +const countSecurityKeysUser = `-- name: CountSecurityKeysUser :one +SELECT COUNT(*) FROM auth.user_security_keys +WHERE user_id = $1 +` + +func (q *Queries) CountSecurityKeysUser(ctx context.Context, userID uuid.UUID) (int64, error) { + row := q.db.QueryRow(ctx, countSecurityKeysUser, userID) + var count int64 + err := row.Scan(&count) + return count, err +} + const getUser = `-- name: GetUser :one SELECT id, created_at, updated_at, last_seen, disabled, display_name, avatar_url, locale, email, phone_number, password_hash, email_verified, phone_number_verified, new_email, otp_method_last_used, otp_hash, otp_hash_expires_at, default_role, is_anonymous, totp_secret, active_mfa_type, ticket, ticket_expires_at, metadata, webauthn_current_challenge FROM auth.users WHERE id = $1 LIMIT 1