From 22d7d63bf6d5ed59dc3d38cd891061b227ec8f91 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sat, 4 Jan 2020 14:39:47 +0000 Subject: [PATCH 01/23] Allow providers to validate and refresh sessions This change allows authentication providers to implement two new types of behavior: expiring a session and refreshing a session. Both changes were made with the intent of supporting upcoming changes in the OAuth2 providers, but might find use in other types of authentication providers as well. The use case for expiration is that OAuth2 authentication results in an access token that often has an expiry date. Using the access token to make requests beyond that date will result in an error response. The auth middleware has the opportunity to detect an invalid token early and see the user replaces it, pre-empting a failed request later and so resulting in a better user experience. The use case for refreshing a session is that OAuth2 also provides a mechanism for obtaining a new access token without involving the user, by using a refresh token. When we obbtain a new access token this way we do need to update the session to store it, and this commit makes that possible. To maintain backwards compatibility we add a default behavior to all existing providers that makes them behave like before this change. The default behavior will never cause sessions to be expired or refreshed. --- src/Network/Wai/Middleware/Auth.hs | 28 ++++++++++++++++++--- src/Network/Wai/Middleware/Auth/Provider.hs | 14 ++++++++++- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/Network/Wai/Middleware/Auth.hs b/src/Network/Wai/Middleware/Auth.hs index 5fcfaf0..b018217 100644 --- a/src/Network/Wai/Middleware/Auth.hs +++ b/src/Network/Wai/Middleware/Auth.hs @@ -41,7 +41,8 @@ import GHC.Generics (Generic) import Network.HTTP.Types (Header, status200, status303, status404, status501) -import Network.Wai (Middleware, Request, +import Network.Wai (mapResponseHeaders, + Middleware, Request, pathInfo, rawPathInfo, rawQueryString, responseBuilder, @@ -273,8 +274,29 @@ mkAuthMiddleware AuthSettings {..} = do authState <- loadCookieValue secretKey asStateKey req case authState of Just (AuthLoggedIn user) -> - let req' = req {vault = Vault.insert userKey user $ vault req} - in app req' respond + let providerName = decodeUtf8With lenientDecode (authProviderName user) + in case HM.lookup providerName asProviders of + Nothing -> + -- We can no longer find the provider the user originally + -- authenticated with, and as a result have no way to check if the + -- session is still valid. For backwards compatibility with older + -- versions of this library we'll assume the session remains valid. + let req' = req {vault = Vault.insert userKey user $ vault req} + in app req' respond + Just provider -> do + refreshResult <- refreshLoginState provider user + case refreshResult of + Nothing -> + -- The session has expired, the user needs to re-authenticate. + enforceLogin "/" req respond + Just user' -> + let req' = req {vault = Vault.insert userKey user' $ vault req} + respond' response + | user' == user = respond response + | otherwise = do + cookieHeader <- saveAuthState (AuthLoggedIn user') + respond $ mapResponseHeaders (cookieHeader :) response + in app req' respond' Just (AuthNeedRedirect url) -> enforceLogin url req respond Nothing -> enforceLogin "/" req respond diff --git a/src/Network/Wai/Middleware/Auth/Provider.hs b/src/Network/Wai/Middleware/Auth/Provider.hs index 26cb249..a74aad5 100644 --- a/src/Network/Wai/Middleware/Auth/Provider.hs +++ b/src/Network/Wai/Middleware/Auth/Provider.hs @@ -105,6 +105,18 @@ class AuthProvider ap where -> (Status -> S.ByteString -> IO Response) -> IO Response + -- | Check if the login state in a session is still valid, and have the + -- opportunity to update it. Return `Nothing` to indicate a session has + -- expired, and the user will be directed to re-authenticate. + -- + -- The default implementation never invalidates a session once set. + -- + -- @since X.Y.Z + refreshLoginState + :: ap + -> AuthUser + -> IO (Maybe AuthUser) + refreshLoginState _ loginState = pure (Just loginState) -- | Generic authentication provider wrapper. data Provider where @@ -153,7 +165,7 @@ data AuthUser = AuthUser { authLoginState :: !UserIdentity , authProviderName :: !S.ByteString , authLoginTime :: !Int64 - } deriving (Generic, Show) + } deriving (Eq, Generic, Show) instance Binary AuthUser From 0dfa610785315b1a2c73ba9ea212622f2f4a54ab Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sun, 5 Jan 2020 14:11:48 +0000 Subject: [PATCH 02/23] OAuth2 middleware supports refresh tokens --- src/Network/Wai/Middleware/Auth/OAuth2.hs | 48 ++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/src/Network/Wai/Middleware/Auth/OAuth2.hs b/src/Network/Wai/Middleware/Auth/OAuth2.hs index 38b94ff..ecb6fdd 100644 --- a/src/Network/Wai/Middleware/Auth/OAuth2.hs +++ b/src/Network/Wai/Middleware/Auth/OAuth2.hs @@ -21,6 +21,7 @@ import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8, decodeUtf8With) import Data.Text.Encoding.Error (lenientDecode) +import Foreign.C.Types (CTime (..)) import Network.HTTP.Client.TLS (getGlobalManager) import Network.HTTP.Types (status303, status403, status404, status501) @@ -31,6 +32,7 @@ import Network.Wai.Auth.Internal (encodeToken, decodeToken) import Network.Wai.Auth.Tools (toLowerUnderscore) import qualified Network.Wai.Middleware.Auth as MA import Network.Wai.Middleware.Auth.Provider +import System.PosixCompat.Time (epochTime) import qualified URI.ByteString as U import URI.ByteString (URI) @@ -142,7 +144,51 @@ instance AuthProvider OAuth2 where status404 "Page not found. Please continue with login." _ -> onFailure status404 "Page not found. Please continue with login." - + refreshLoginState OAuth2 {..} user = + let loginState = authLoginState user + in case decodeToken loginState of + Left _ -> pure Nothing + Right tokens -> do + CTime now <- epochTime + if tokenExpired user now tokens then + case OA2.refreshToken tokens of + Nothing -> pure Nothing + Just refreshToken -> do + authEndpointURI <- parseAbsoluteURI' oa2AuthorizeEndpoint + accessTokenEndpointURI <- parseAbsoluteURI' oa2AccessTokenEndpoint + let oauth2 = + OA2.OAuth2 + { oauthClientId = getClientId oa2ClientId + , oauthClientSecret = getClientSecret oa2ClientSecret + , oauthOAuthorizeEndpoint = authEndpointURI + , oauthAccessTokenEndpoint = accessTokenEndpointURI + -- Setting callback endpoint to `Nothing` below is a lie. + -- We do have a callback endpoint but in this context + -- don't have access to the function that can render it. + -- We get away with this because the callback endpoint is + -- not needed for obtaining a refresh token, the only + -- way we use the config here constructed. + , oauthCallback = Nothing + } + man <- getGlobalManager + rRes <- OA2.refreshAccessToken man oauth2 refreshToken + case rRes of + Left _ -> pure Nothing + Right tokens' -> + let user' = + user { + authLoginState = encodeToken tokens', + authLoginTime = fromIntegral now + } + in pure (Just user') + else + pure (Just user) + +tokenExpired :: AuthUser -> Int64 -> OA2.OAuth2Token -> Bool +tokenExpired user now tokens = + case OA2.expiresIn tokens of + Nothing -> False + Just expiresIn -> authLoginTime user + (fromIntegral expiresIn) < now $(deriveJSON defaultOptions { fieldLabelModifier = toLowerUnderscore . drop 3} ''OAuth2) From 51eed3c5312a3682a7cff0c45787778e4969ef48 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sun, 5 Jan 2020 15:03:05 +0000 Subject: [PATCH 03/23] Extract implementation of OAuth2 provider This is preparation for creation of a new Open ID Connect provider, which will now be able to reuse most of the logic for the OAuth2 provider. --- src/Network/Wai/Auth/Internal.hs | 112 +++++++++++++++++ src/Network/Wai/Middleware/Auth/OAuth2.hs | 140 +++++----------------- 2 files changed, 145 insertions(+), 107 deletions(-) diff --git a/src/Network/Wai/Auth/Internal.hs b/src/Network/Wai/Auth/Internal.hs index 896a8ef..60568f9 100644 --- a/src/Network/Wai/Auth/Internal.hs +++ b/src/Network/Wai/Auth/Internal.hs @@ -1,15 +1,37 @@ {-# OPTIONS_HADDOCK hide, not-home #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TupleSections #-} module Network.Wai.Auth.Internal ( OAuth2TokenBinary(..) , encodeToken , decodeToken + , oauth2Login + , oauth2RefreshLogin ) where import Data.Binary (Binary(get, put), encode, decodeOrFail) import qualified Data.ByteString as S +import qualified Data.ByteString.Char8 as S8 (pack) import qualified Data.ByteString.Lazy as SL +import Data.Int +import qualified Data.Text as T +import Data.Text.Encoding (encodeUtf8, + decodeUtf8With) +import Data.Text.Encoding.Error (lenientDecode) +import Foreign.C.Types (CTime (..)) +import Network.HTTP.Client (Manager) +import Network.HTTP.Types (Status, status303, + status403, status404, + status501) import qualified Network.OAuth.OAuth2 as OA2 +import Network.Wai (Request, Response, + queryString, responseLBS) +import Network.Wai.Middleware.Auth.Provider +import System.PosixCompat.Time (epochTime) +import qualified URI.ByteString as U +import URI.ByteString (URI) decodeToken :: S.ByteString -> Either String OA2.OAuth2Token decodeToken bs = @@ -39,3 +61,93 @@ instance Binary OAuth2TokenBinary where idToken <- fmap OA2.IdToken <$> get pure $ OAuth2TokenBinary $ OA2.OAuth2Token accessToken refreshToken expiresIn tokenType idToken + +oauth2Login + :: OA2.OAuth2 + -> Manager + -> Maybe [T.Text] + -> T.Text + -> Request + -> [T.Text] + -> (AuthLoginState -> IO Response) + -> (Status -> S.ByteString -> IO Response) + -> IO Response +oauth2Login oauth2 man oa2Scope providerName req suffix onSuccess onFailure = + case suffix of + [] -> do + let scope = (encodeUtf8 . T.intercalate ",") <$> oa2Scope + let redirectUrl = + getRedirectURI $ + appendQueryParams + (OA2.authorizationUrl oauth2) + (maybe [] ((: []) . ("scope", )) scope) + return $ + responseLBS + status303 + [("Location", redirectUrl)] + "Redirect to OAuth2 Authentication server" + ["complete"] -> + let params = queryString req + in case lookup "code" params of + Just (Just code) -> do + eRes <- OA2.fetchAccessToken man oauth2 $ getExchangeToken code + case eRes of + Left err -> onFailure status501 $ S8.pack $ show err + Right token -> onSuccess $ encodeToken token + _ -> + case lookup "error" params of + (Just (Just "access_denied")) -> + onFailure + status403 + "User rejected access to the application." + (Just (Just error_code)) -> + onFailure status501 $ "Received an error: " <> error_code + (Just Nothing) -> + onFailure status501 $ + "Unknown error connecting to " <> + encodeUtf8 providerName + Nothing -> + onFailure + status404 + "Page not found. Please continue with login." + _ -> onFailure status404 "Page not found. Please continue with login." + +oauth2RefreshLogin :: OA2.OAuth2 -> Manager -> AuthUser -> IO (Maybe AuthUser) +oauth2RefreshLogin oauth2 man user = + let loginState = authLoginState user + in case decodeToken loginState of + Left _ -> pure Nothing + Right tokens -> do + CTime now <- epochTime + if tokenExpired user now tokens then + case OA2.refreshToken tokens of + Nothing -> pure Nothing + Just refreshToken -> do + rRes <- OA2.refreshAccessToken man oauth2 refreshToken + case rRes of + Left _ -> pure Nothing + Right tokens' -> + let user' = + user { + authLoginState = encodeToken tokens', + authLoginTime = fromIntegral now + } + in pure (Just user') + else + pure (Just user) + +tokenExpired :: AuthUser -> Int64 -> OA2.OAuth2Token -> Bool +tokenExpired user now tokens = + case OA2.expiresIn tokens of + Nothing -> False + Just expiresIn -> authLoginTime user + (fromIntegral expiresIn) < now + +getExchangeToken :: S.ByteString -> OA2.ExchangeToken +getExchangeToken = OA2.ExchangeToken . decodeUtf8With lenientDecode + +appendQueryParams :: URI -> [(S.ByteString, S.ByteString)] -> URI +appendQueryParams uri params = + OA2.appendQueryParams params uri + +getRedirectURI :: U.URIRef a -> S.ByteString +getRedirectURI = U.serializeURIRef' diff --git a/src/Network/Wai/Middleware/Auth/OAuth2.hs b/src/Network/Wai/Middleware/Auth/OAuth2.hs index ecb6fdd..53dee86 100644 --- a/src/Network/Wai/Middleware/Auth/OAuth2.hs +++ b/src/Network/Wai/Middleware/Auth/OAuth2.hs @@ -14,27 +14,18 @@ import Control.Monad.Catch import Data.Aeson.TH (defaultOptions, deriveJSON, fieldLabelModifier) -import qualified Data.ByteString as S -import qualified Data.ByteString.Char8 as S8 (pack) import Data.Proxy (Proxy (..)) import qualified Data.Text as T -import Data.Text.Encoding (encodeUtf8, - decodeUtf8With) -import Data.Text.Encoding.Error (lenientDecode) -import Foreign.C.Types (CTime (..)) +import Data.Text.Encoding (encodeUtf8) import Network.HTTP.Client.TLS (getGlobalManager) -import Network.HTTP.Types (status303, status403, - status404, status501) import qualified Network.OAuth.OAuth2 as OA2 -import Network.Wai (Request, queryString, - responseLBS) -import Network.Wai.Auth.Internal (encodeToken, decodeToken) +import Network.Wai (Request) +import Network.Wai.Auth.Internal (decodeToken, oauth2Login, + oauth2RefreshLogin) import Network.Wai.Auth.Tools (toLowerUnderscore) import qualified Network.Wai.Middleware.Auth as MA import Network.Wai.Middleware.Auth.Provider -import System.PosixCompat.Time (epochTime) import qualified URI.ByteString as U -import URI.ByteString (URI) -- | General OAuth2 authentication `Provider`. data OAuth2 = OAuth2 @@ -67,22 +58,12 @@ parseAbsoluteURI urlTxt = do parseAbsoluteURI' :: MonadThrow m => T.Text -> m U.URI parseAbsoluteURI' = parseAbsoluteURI -getExchangeToken :: S.ByteString -> OA2.ExchangeToken -getExchangeToken = OA2.ExchangeToken . decodeUtf8With lenientDecode - -appendQueryParams :: URI -> [(S.ByteString, S.ByteString)] -> URI -appendQueryParams uri params = - OA2.appendQueryParams params uri - getClientId :: T.Text -> T.Text getClientId = id getClientSecret :: T.Text -> T.Text getClientSecret = id -getRedirectURI :: U.URIRef a -> S.ByteString -getRedirectURI = U.serializeURIRef' - -- | Aeson parser for `OAuth2` provider. -- -- @since 0.1.0 @@ -105,90 +86,35 @@ instance AuthProvider OAuth2 where , oauthAccessTokenEndpoint = accessTokenEndpointURI , oauthCallback = Just callbackURI } - case suffix of - [] -> do - let scope = (encodeUtf8 . T.intercalate ",") <$> oa2Scope - let redirectUrl = - getRedirectURI $ - appendQueryParams - (OA2.authorizationUrl oauth2) - (maybe [] ((: []) . ("scope", )) scope) - return $ - responseLBS - status303 - [("Location", redirectUrl)] - "Redirect to OAuth2 Authentication server" - ["complete"] -> - let params = queryString req - in case lookup "code" params of - Just (Just code) -> do - man <- getGlobalManager - eRes <- OA2.fetchAccessToken man oauth2 $ getExchangeToken code - case eRes of - Left err -> onFailure status501 $ S8.pack $ show err - Right token -> onSuccess $ encodeToken token - _ -> - case lookup "error" params of - (Just (Just "access_denied")) -> - onFailure - status403 - "User rejected access to the application." - (Just (Just error_code)) -> - onFailure status501 $ "Received an error: " <> error_code - (Just Nothing) -> - onFailure status501 $ - "Unknown error connecting to " <> - encodeUtf8 (getProviderName oa2) - Nothing -> - onFailure - status404 - "Page not found. Please continue with login." - _ -> onFailure status404 "Page not found. Please continue with login." - refreshLoginState OAuth2 {..} user = - let loginState = authLoginState user - in case decodeToken loginState of - Left _ -> pure Nothing - Right tokens -> do - CTime now <- epochTime - if tokenExpired user now tokens then - case OA2.refreshToken tokens of - Nothing -> pure Nothing - Just refreshToken -> do - authEndpointURI <- parseAbsoluteURI' oa2AuthorizeEndpoint - accessTokenEndpointURI <- parseAbsoluteURI' oa2AccessTokenEndpoint - let oauth2 = - OA2.OAuth2 - { oauthClientId = getClientId oa2ClientId - , oauthClientSecret = getClientSecret oa2ClientSecret - , oauthOAuthorizeEndpoint = authEndpointURI - , oauthAccessTokenEndpoint = accessTokenEndpointURI - -- Setting callback endpoint to `Nothing` below is a lie. - -- We do have a callback endpoint but in this context - -- don't have access to the function that can render it. - -- We get away with this because the callback endpoint is - -- not needed for obtaining a refresh token, the only - -- way we use the config here constructed. - , oauthCallback = Nothing - } - man <- getGlobalManager - rRes <- OA2.refreshAccessToken man oauth2 refreshToken - case rRes of - Left _ -> pure Nothing - Right tokens' -> - let user' = - user { - authLoginState = encodeToken tokens', - authLoginTime = fromIntegral now - } - in pure (Just user') - else - pure (Just user) - -tokenExpired :: AuthUser -> Int64 -> OA2.OAuth2Token -> Bool -tokenExpired user now tokens = - case OA2.expiresIn tokens of - Nothing -> False - Just expiresIn -> authLoginTime user + (fromIntegral expiresIn) < now + man <- getGlobalManager + oauth2Login + oauth2 + man + oa2Scope + (getProviderName oa2) + req + suffix + onSuccess + onFailure + refreshLoginState OAuth2 {..} user = do + authEndpointURI <- parseAbsoluteURI' oa2AuthorizeEndpoint + accessTokenEndpointURI <- parseAbsoluteURI' oa2AccessTokenEndpoint + let oauth2 = + OA2.OAuth2 + { oauthClientId = getClientId oa2ClientId + , oauthClientSecret = getClientSecret oa2ClientSecret + , oauthOAuthorizeEndpoint = authEndpointURI + , oauthAccessTokenEndpoint = accessTokenEndpointURI + -- Setting callback endpoint to `Nothing` below is a lie. + -- We do have a callback endpoint but in this context + -- don't have access to the function that can render it. + -- We get away with this because the callback endpoint is + -- not needed for obtaining a refresh token, the only + -- way we use the config here constructed. + , oauthCallback = Nothing + } + man <- getGlobalManager + oauth2RefreshLogin oauth2 man user $(deriveJSON defaultOptions { fieldLabelModifier = toLowerUnderscore . drop 3} ''OAuth2) From 30c6b6bb61259199b845f34169a30a0914ef3142 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sun, 12 Jan 2020 13:49:41 +0000 Subject: [PATCH 04/23] Add tasty test runner This is preparation for adding additional tests. Our existing tests are hedgehog property tests, the next ones aren't going to be. The tasty wrapper will allow us to run these different types of tests as a single suite. --- test/Main.hs | 56 ++++---------------------- test/Spec/Network/Wai/Auth/Internal.hs | 55 +++++++++++++++++++++++++ wai-middleware-auth.cabal | 4 ++ 3 files changed, 67 insertions(+), 48 deletions(-) create mode 100644 test/Spec/Network/Wai/Auth/Internal.hs diff --git a/test/Main.hs b/test/Main.hs index c9edb08..21a9f17 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -2,53 +2,13 @@ {-# OPTIONS_GHC -fno-warn-orphans #-} module Main (main) where -import Data.Binary (encode, decodeOrFail) -import qualified Data.ByteString.Lazy.Char8 as BSL8 -import qualified Data.Text as T -import Hedgehog -import Hedgehog.Gen as Gen -import Hedgehog.Range as Range -import Network.Wai.Auth.Internal -import qualified Network.OAuth.OAuth2.Internal as OA2 +import Test.Tasty +import qualified Spec.Network.Wai.Auth.Internal -main :: IO Bool -main = - checkParallel $ Group "Main" [ - ("oAuth2TokenBinaryDuality", oAuth2TokenBinaryDuality) - ] - -oAuth2TokenBinaryDuality :: Property -oAuth2TokenBinaryDuality = property $ do - token <- forAll oauth2TokenBinary - let checkUnconsumed ("", _, roundTripToken) = roundTripToken - checkUnconsumed (unconsumed, _, _) = - error $ "Unexpected unconsumed in bytes: " <> BSL8.unpack unconsumed - tripping token encode (fmap checkUnconsumed . decodeOrFail) - tripping token (encodeToken . unOAuth2TokenBinary) (fmap OAuth2TokenBinary . decodeToken) - -oauth2TokenBinary :: Gen OAuth2TokenBinary -oauth2TokenBinary = do - accessToken <- OA2.AccessToken <$> anyText - refreshToken <- Gen.maybe $ OA2.RefreshToken <$> anyText - expiresIn <- Gen.maybe $ Gen.int (Range.linear 0 1000) - tokenType <- Gen.maybe anyText - idToken <- Gen.maybe $ OA2.IdToken <$> anyText - pure $ - OAuth2TokenBinary $ - OA2.OAuth2Token accessToken refreshToken expiresIn tokenType idToken +main :: IO () +main = defaultMain tests -anyText :: Gen T.Text -anyText = Gen.text (Range.linear 0 100) Gen.unicodeAll - --- The `OAuth2Token` type from the `hoauth2` library does not have a `Eq` --- instance, and it's constituent parts don't have a `Generic` instance. Hence --- this orphan instance here. -instance Eq OAuth2TokenBinary where - (OAuth2TokenBinary t1) == (OAuth2TokenBinary t2) = - and - [ OA2.atoken (OA2.accessToken t1) == OA2.atoken (OA2.accessToken t2) - , (OA2.rtoken <$> OA2.refreshToken t1) == (OA2.rtoken <$> OA2.refreshToken t2) - , OA2.expiresIn t1 == OA2.expiresIn t2 - , OA2.tokenType t1 == OA2.tokenType t2 - , (OA2.idtoken <$> OA2.idToken t1) == (OA2.idtoken <$> OA2.idToken t2) - ] +tests :: TestTree +tests = testGroup "wai-middleware-auth" + [ Spec.Network.Wai.Auth.Internal.tests + ] diff --git a/test/Spec/Network/Wai/Auth/Internal.hs b/test/Spec/Network/Wai/Auth/Internal.hs new file mode 100644 index 0000000..afa0e93 --- /dev/null +++ b/test/Spec/Network/Wai/Auth/Internal.hs @@ -0,0 +1,55 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} +module Spec.Network.Wai.Auth.Internal (tests) where + +import Data.Binary (encode, decodeOrFail) +import qualified Data.ByteString.Lazy.Char8 as BSL8 +import qualified Data.Text as T +import Test.Tasty (TestTree, testGroup) +import Test.Tasty.Hedgehog (testProperty) +import Hedgehog +import Hedgehog.Gen as Gen +import Hedgehog.Range as Range +import Network.Wai.Auth.Internal +import qualified Network.OAuth.OAuth2.Internal as OA2 + +tests :: TestTree +tests = testGroup "Network.Wai.Auth.Internal" + [ testProperty "oAuth2TokenBinaryDuality" oAuth2TokenBinaryDuality + ] + +oAuth2TokenBinaryDuality :: Property +oAuth2TokenBinaryDuality = property $ do + token <- forAll oauth2TokenBinary + let checkUnconsumed ("", _, roundTripToken) = roundTripToken + checkUnconsumed (unconsumed, _, _) = + error $ "Unexpected unconsumed in bytes: " <> BSL8.unpack unconsumed + tripping token encode (fmap checkUnconsumed . decodeOrFail) + tripping token (encodeToken . unOAuth2TokenBinary) (fmap OAuth2TokenBinary . decodeToken) + +oauth2TokenBinary :: Gen OAuth2TokenBinary +oauth2TokenBinary = do + accessToken <- OA2.AccessToken <$> anyText + refreshToken <- Gen.maybe $ OA2.RefreshToken <$> anyText + expiresIn <- Gen.maybe $ Gen.int (Range.linear 0 1000) + tokenType <- Gen.maybe anyText + idToken <- Gen.maybe $ OA2.IdToken <$> anyText + pure $ + OAuth2TokenBinary $ + OA2.OAuth2Token accessToken refreshToken expiresIn tokenType idToken + +anyText :: Gen T.Text +anyText = Gen.text (Range.linear 0 100) Gen.unicodeAll + +-- The `OAuth2Token` type from the `hoauth2` library does not have a `Eq` +-- instance, and it's constituent parts don't have a `Generic` instance. Hence +-- this orphan instance here. +instance Eq OAuth2TokenBinary where + (OAuth2TokenBinary t1) == (OAuth2TokenBinary t2) = + and + [ OA2.atoken (OA2.accessToken t1) == OA2.atoken (OA2.accessToken t2) + , (OA2.rtoken <$> OA2.refreshToken t1) == (OA2.rtoken <$> OA2.refreshToken t2) + , OA2.expiresIn t1 == OA2.expiresIn t2 + , OA2.tokenType t1 == OA2.tokenType t2 + , (OA2.idtoken <$> OA2.idToken t1) == (OA2.idtoken <$> OA2.idToken t2) + ] diff --git a/wai-middleware-auth.cabal b/wai-middleware-auth.cabal index ad78177..871020f 100644 --- a/wai-middleware-auth.cabal +++ b/wai-middleware-auth.cabal @@ -78,11 +78,15 @@ test-suite spec type: exitcode-stdio-1.0 main-is: Main.hs hs-source-dirs: test + other-modules: Spec.Network.Wai.Auth.Internal build-depends: base , binary , bytestring , hedgehog , hoauth2 + , tasty + , tasty-hedgehog + , tasty-hunit , text , wai-middleware-auth ghc-options: -Wall -threaded -rtsopts -with-rtsopts=-N From 852626f19568a1f709656c725f9fc58305e1dcc7 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Mon, 13 Jan 2020 08:31:18 +0000 Subject: [PATCH 05/23] Add tests for OAuth2 middleware --- src/Network/Wai/Middleware/Auth/Provider.hs | 1 + test/Main.hs | 2 + .../Network/Wai/Middleware/Auth/OAuth2.hs | 151 ++++++++++++++++++ wai-middleware-auth.cabal | 8 + 4 files changed, 162 insertions(+) create mode 100644 test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs diff --git a/src/Network/Wai/Middleware/Auth/Provider.hs b/src/Network/Wai/Middleware/Auth/Provider.hs index a74aad5..c5155a8 100644 --- a/src/Network/Wai/Middleware/Auth/Provider.hs +++ b/src/Network/Wai/Middleware/Auth/Provider.hs @@ -131,6 +131,7 @@ instance AuthProvider Provider where handleLogin (Provider p) = handleLogin p + refreshLoginState (Provider p) loginState = refreshLoginState p loginState -- | Collection of supported providers. type Providers = HM.HashMap T.Text Provider diff --git a/test/Main.hs b/test/Main.hs index 21a9f17..ce1b291 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -4,6 +4,7 @@ module Main (main) where import Test.Tasty import qualified Spec.Network.Wai.Auth.Internal +import qualified Spec.Network.Wai.Middleware.Auth.OAuth2 main :: IO () main = defaultMain tests @@ -11,4 +12,5 @@ main = defaultMain tests tests :: TestTree tests = testGroup "wai-middleware-auth" [ Spec.Network.Wai.Auth.Internal.tests + , Spec.Network.Wai.Middleware.Auth.OAuth2.tests ] diff --git a/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs b/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs new file mode 100644 index 0000000..f3b0959 --- /dev/null +++ b/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs @@ -0,0 +1,151 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Spec.Network.Wai.Middleware.Auth.OAuth2 (tests) where + +import Control.Monad (void) +import qualified Data.Aeson as Aeson +import Data.ByteString (ByteString) +import Data.Function ((&)) +import qualified Data.Text as T +import GHC.Exts (fromList) +import qualified Network.HTTP.Types.Status as Status +import qualified Network.OAuth.OAuth2 as OA2 +import qualified Network.Wai as Wai +import qualified Network.Wai.Handler.Warp as Warp +import qualified Network.Wai.Middleware.Auth as Auth +import Network.Wai.Middleware.Auth.OAuth2 (OAuth2(..), + getAccessToken) +import Network.Wai.Middleware.Auth.Provider (Provider(..), + ProviderInfo(..)) +import Network.Wai.Test (Session, SResponse, + assertHeader, + assertStatus, + defaultRequest, + request, runSession, + setClientCookie, + setPath) +import Test.Tasty (TestTree, testGroup) +import Test.Tasty.HUnit (testCase) +import qualified Web.Cookie as Cookie + +tests :: TestTree +tests = testGroup "Network.Wai.Auth.OAuth2" + [ testCase "when a request without a session is made then the response redirects to the oauth2 authorize endpoint" $ do + middleware <- Auth.mkAuthMiddleware $ authSettings "http://oauth2provider.com" + let app = middleware const200 + flip runSession app $ do + redirect1 <- get "/hi" + assertStatus 303 redirect1 + assertHeader "Location" "/prefix" redirect1 + redirect2 <- get "/prefix" + assertStatus 303 redirect2 + assertHeader "location" "/prefix/oauth2" redirect2 + redirect3 <- get "/prefix/oauth2" + assertStatus 303 redirect3 + assertHeader "location" "http://oauth2provider.com/authorize?scope=scope1%2Cscope2&client_id=client-id&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%2Fprefix%2Foauth2%2Fcomplete" redirect3 + + , testCase "when a request with an expired session is made then the response redirects to the oauth2 authorize endpoint" $ do + Warp.testWithApplication (pure (fakeProvider (-3600))) $ \port -> do + middleware <- Auth.mkAuthMiddleware $ authSettings ("http://localhost:" <> T.pack (show port)) + let app = middleware const200 + flip runSession app $ do + createSession + response <- get "/some/endpoint" + assertStatus 303 response + + , testCase "when a request with a valid session is made then the middleware passes the request through" $ do + Warp.testWithApplication (pure (fakeProvider 3600)) $ \port -> do + middleware <- Auth.mkAuthMiddleware $ authSettings ("http://localhost:" <> T.pack (show port)) + let app = middleware const200 + flip runSession app $ do + createSession + response <- get "/some/endpoint" + assertStatus 200 response + + , testCase "when a request with an invalid session is made then the response redirects to the oauth2 authorize endpoint" $ do + Warp.testWithApplication (pure (fakeProvider 3600)) $ \port -> do + middleware <- Auth.mkAuthMiddleware $ authSettings ("http://localhost:" <> T.pack (show port)) + let app = middleware const200 + flip runSession app $ do + -- First create a known valid session, so we can see that it's the act + -- of corrupting it that makes the test fail. + createSession + setClientCookie + Cookie.defaultSetCookie + { Cookie.setCookieName = "auth-cookie" + , Cookie.setCookieValue = "garbage" + } + response <- get "/some/endpoint" + assertStatus 303 response + + , testCase "when a request is made to the oauth2 complete endpoint then the middleware fatches an access token and sets a user sesion" $ + Warp.testWithApplication (pure (fakeProvider 3600)) $ \port -> do + middleware <- Auth.mkAuthMiddleware $ authSettings ("http://localhost:" <> T.pack (show port)) + let app = middleware const200 + flip runSession app $ do + response <- get "/prefix/oauth2/complete?code=1234" + assertStatus 303 response + assertHeader "location" "/" response + + , testCase "when a request with a valid session is made then the application can access the session payload" $ do + Warp.testWithApplication (pure (fakeProvider 3600)) $ \port -> do + middleware <- Auth.mkAuthMiddleware $ authSettings ("http://localhost:" <> T.pack (show port)) + let app = middleware $ \req respond -> + case getAccessToken req of + Nothing -> respond $ Wai.responseLBS Status.badRequest400 [] "" + Just _ -> respond $ Wai.responseLBS Status.ok200 [] "" + flip runSession app $ do + createSession + response <- get "/prefix/oauth2/complete?code=1234" + assertStatus 200 response + ] + +get :: ByteString -> Session SResponse +get = request . setPath defaultRequest + +createSession :: Session () +createSession = void $ get "/prefix/oauth2/complete?code=1234" + +authSettings :: T.Text -> Auth.AuthSettings +authSettings host = + Auth.defaultAuthSettings + & Auth.setAuthProviders (fromList [("oauth2", provider host)]) + & Auth.setAuthPrefix "prefix" + & Auth.setAuthCookieName "auth-cookie" + +provider :: T.Text -> Provider +provider host = + Provider + OAuth2 + { oa2ClientId = "client-id" + , oa2ClientSecret = "client-secret" + , oa2AuthorizeEndpoint = host <> "/authorize" + , oa2AccessTokenEndpoint = host <> "/token" + , oa2Scope = Just ["scope1", "scope2"] + , oa2ProviderInfo = + ProviderInfo + { providerTitle = "" + , providerLogoUrl = "" + , providerDescr = "" + } + } + +const200 :: Wai.Application +const200 _ respond = respond $ Wai.responseLBS Status.ok200 [] "" + +fakeProvider :: Int -> Wai.Application +fakeProvider expiresIn req respond = + case Wai.pathInfo req of + ["token"] -> + respond $ Wai.responseLBS Status.ok200 [("Content-Type", "application/json")] body + where + body = + Aeson.encode OA2.OAuth2Token + { OA2.accessToken = OA2.AccessToken "access-granted", + OA2.refreshToken = Nothing, + OA2.expiresIn = Just expiresIn, + OA2.tokenType = Nothing, + OA2.idToken = Nothing + } + _ -> + respond $ Wai.responseLBS Status.notFound404 [] "" diff --git a/wai-middleware-auth.cabal b/wai-middleware-auth.cabal index 871020f..20b2e98 100644 --- a/wai-middleware-auth.cabal +++ b/wai-middleware-auth.cabal @@ -42,6 +42,7 @@ library , http-conduit , http-reverse-proxy , http-types + , jose , regex-posix , safe-exceptions , shakespeare @@ -79,16 +80,23 @@ test-suite spec main-is: Main.hs hs-source-dirs: test other-modules: Spec.Network.Wai.Auth.Internal + , Spec.Network.Wai.Middleware.Auth.OAuth2 build-depends: base + , aeson , binary , bytestring + , cookie , hedgehog , hoauth2 + , http-types , tasty , tasty-hedgehog , tasty-hunit , text + , wai + , wai-extra , wai-middleware-auth + , warp ghc-options: -Wall -threaded -rtsopts -with-rtsopts=-N source-repository head From 391eff28742ddd7ba20e6a87c9510d4235535857 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Mon, 27 Jan 2020 22:00:37 +0000 Subject: [PATCH 06/23] Add OpenID Connect provider --- src/Network/Wai/Auth/Internal.hs | 43 +--- src/Network/Wai/Middleware/Auth.hs | 8 +- src/Network/Wai/Middleware/Auth/OAuth2.hs | 48 +++- .../Wai/Middleware/Auth/OpenIDConnect.hs | 232 ++++++++++++++++++ src/Network/Wai/Middleware/Auth/Provider.hs | 7 +- wai-middleware-auth.cabal | 3 + 6 files changed, 289 insertions(+), 52 deletions(-) create mode 100644 src/Network/Wai/Middleware/Auth/OpenIDConnect.hs diff --git a/src/Network/Wai/Auth/Internal.hs b/src/Network/Wai/Auth/Internal.hs index 60568f9..0462505 100644 --- a/src/Network/Wai/Auth/Internal.hs +++ b/src/Network/Wai/Auth/Internal.hs @@ -7,7 +7,7 @@ module Network.Wai.Auth.Internal , encodeToken , decodeToken , oauth2Login - , oauth2RefreshLogin + , refreshTokens ) where import Data.Binary (Binary(get, put), encode, @@ -15,12 +15,10 @@ import Data.Binary (Binary(get, put), encode, import qualified Data.ByteString as S import qualified Data.ByteString.Char8 as S8 (pack) import qualified Data.ByteString.Lazy as SL -import Data.Int import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8, decodeUtf8With) import Data.Text.Encoding.Error (lenientDecode) -import Foreign.C.Types (CTime (..)) import Network.HTTP.Client (Manager) import Network.HTTP.Types (Status, status303, status403, status404, @@ -29,7 +27,6 @@ import qualified Network.OAuth.OAuth2 as OA2 import Network.Wai (Request, Response, queryString, responseLBS) import Network.Wai.Middleware.Auth.Provider -import System.PosixCompat.Time (epochTime) import qualified URI.ByteString as U import URI.ByteString (URI) @@ -112,35 +109,15 @@ oauth2Login oauth2 man oa2Scope providerName req suffix onSuccess onFailure = "Page not found. Please continue with login." _ -> onFailure status404 "Page not found. Please continue with login." -oauth2RefreshLogin :: OA2.OAuth2 -> Manager -> AuthUser -> IO (Maybe AuthUser) -oauth2RefreshLogin oauth2 man user = - let loginState = authLoginState user - in case decodeToken loginState of - Left _ -> pure Nothing - Right tokens -> do - CTime now <- epochTime - if tokenExpired user now tokens then - case OA2.refreshToken tokens of - Nothing -> pure Nothing - Just refreshToken -> do - rRes <- OA2.refreshAccessToken man oauth2 refreshToken - case rRes of - Left _ -> pure Nothing - Right tokens' -> - let user' = - user { - authLoginState = encodeToken tokens', - authLoginTime = fromIntegral now - } - in pure (Just user') - else - pure (Just user) - -tokenExpired :: AuthUser -> Int64 -> OA2.OAuth2Token -> Bool -tokenExpired user now tokens = - case OA2.expiresIn tokens of - Nothing -> False - Just expiresIn -> authLoginTime user + (fromIntegral expiresIn) < now +refreshTokens :: OA2.OAuth2Token -> Manager -> OA2.OAuth2 -> IO (Maybe OA2.OAuth2Token) +refreshTokens tokens manager oauth2 = + case OA2.refreshToken tokens of + Nothing -> pure Nothing + Just refreshToken -> do + res <- OA2.refreshAccessToken manager oauth2 refreshToken + case res of + Left _ -> pure Nothing + Right newTokens -> pure (Just newTokens) getExchangeToken :: S.ByteString -> OA2.ExchangeToken getExchangeToken = OA2.ExchangeToken . decodeUtf8With lenientDecode diff --git a/src/Network/Wai/Middleware/Auth.hs b/src/Network/Wai/Middleware/Auth.hs index b018217..8249dec 100644 --- a/src/Network/Wai/Middleware/Auth.hs +++ b/src/Network/Wai/Middleware/Auth.hs @@ -284,19 +284,19 @@ mkAuthMiddleware AuthSettings {..} = do let req' = req {vault = Vault.insert userKey user $ vault req} in app req' respond Just provider -> do - refreshResult <- refreshLoginState provider user + refreshResult <- refreshLoginState provider req user case refreshResult of Nothing -> -- The session has expired, the user needs to re-authenticate. enforceLogin "/" req respond - Just user' -> - let req' = req {vault = Vault.insert userKey user' $ vault req} + Just (req', user') -> + let req'' = req' {vault = Vault.insert userKey user' $ vault req'} respond' response | user' == user = respond response | otherwise = do cookieHeader <- saveAuthState (AuthLoggedIn user') respond $ mapResponseHeaders (cookieHeader :) response - in app req' respond' + in app req'' respond' Just (AuthNeedRedirect url) -> enforceLogin url req respond Nothing -> enforceLogin "/" req respond diff --git a/src/Network/Wai/Middleware/Auth/OAuth2.hs b/src/Network/Wai/Middleware/Auth/OAuth2.hs index 53dee86..ef62b7f 100644 --- a/src/Network/Wai/Middleware/Auth/OAuth2.hs +++ b/src/Network/Wai/Middleware/Auth/OAuth2.hs @@ -14,17 +14,21 @@ import Control.Monad.Catch import Data.Aeson.TH (defaultOptions, deriveJSON, fieldLabelModifier) +import Data.Int (Int64) import Data.Proxy (Proxy (..)) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) +import Foreign.C.Types (CTime (..)) import Network.HTTP.Client.TLS (getGlobalManager) import qualified Network.OAuth.OAuth2 as OA2 import Network.Wai (Request) -import Network.Wai.Auth.Internal (decodeToken, oauth2Login, - oauth2RefreshLogin) +import Network.Wai.Auth.Internal (decodeToken, encodeToken, + oauth2Login, + refreshTokens) import Network.Wai.Auth.Tools (toLowerUnderscore) import qualified Network.Wai.Middleware.Auth as MA import Network.Wai.Middleware.Auth.Provider +import System.PosixCompat.Time (epochTime) import qualified URI.ByteString as U -- | General OAuth2 authentication `Provider`. @@ -55,9 +59,6 @@ parseAbsoluteURI urlTxt = do Left err -> throwM $ URIParseException err Right url -> return url -parseAbsoluteURI' :: MonadThrow m => T.Text -> m U.URI -parseAbsoluteURI' = parseAbsoluteURI - getClientId :: T.Text -> T.Text getClientId = id @@ -75,9 +76,9 @@ instance AuthProvider OAuth2 where getProviderName _ = "oauth2" getProviderInfo = oa2ProviderInfo handleLogin oa2@OAuth2 {..} req suffix renderUrl onSuccess onFailure = do - authEndpointURI <- parseAbsoluteURI' oa2AuthorizeEndpoint - accessTokenEndpointURI <- parseAbsoluteURI' oa2AccessTokenEndpoint - callbackURI <- parseAbsoluteURI' $ renderUrl (ProviderUrl ["complete"]) [] + authEndpointURI <- parseAbsoluteURI oa2AuthorizeEndpoint + accessTokenEndpointURI <- parseAbsoluteURI oa2AccessTokenEndpoint + callbackURI <- parseAbsoluteURI $ renderUrl (ProviderUrl ["complete"]) [] let oauth2 = OA2.OAuth2 { oauthClientId = getClientId oa2ClientId @@ -96,9 +97,9 @@ instance AuthProvider OAuth2 where suffix onSuccess onFailure - refreshLoginState OAuth2 {..} user = do - authEndpointURI <- parseAbsoluteURI' oa2AuthorizeEndpoint - accessTokenEndpointURI <- parseAbsoluteURI' oa2AccessTokenEndpoint + refreshLoginState OAuth2 {..} req user = do + authEndpointURI <- parseAbsoluteURI oa2AuthorizeEndpoint + accessTokenEndpointURI <- parseAbsoluteURI oa2AccessTokenEndpoint let oauth2 = OA2.OAuth2 { oauthClientId = getClientId oa2ClientId @@ -114,7 +115,30 @@ instance AuthProvider OAuth2 where , oauthCallback = Nothing } man <- getGlobalManager - oauth2RefreshLogin oauth2 man user + let loginState = authLoginState user + case decodeToken loginState of + Left _ -> pure Nothing + Right tokens -> do + CTime now <- epochTime + if tokenExpired user now tokens then do + rRes <- refreshTokens tokens man oauth2 + case rRes of + Nothing -> pure Nothing + Just newTokens -> + let user' = + user { + authLoginState = encodeToken newTokens, + authLoginTime = fromIntegral now + } + in pure (Just (req, user')) + else + pure (Just (req, user)) + +tokenExpired :: AuthUser -> Int64 -> OA2.OAuth2Token -> Bool +tokenExpired user now tokens = + case OA2.expiresIn tokens of + Nothing -> False + Just expiresIn -> authLoginTime user + (fromIntegral expiresIn) < now $(deriveJSON defaultOptions { fieldLabelModifier = toLowerUnderscore . drop 3} ''OAuth2) diff --git a/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs b/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs new file mode 100644 index 0000000..e406936 --- /dev/null +++ b/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs @@ -0,0 +1,232 @@ +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE OverloadedStrings #-} +module Network.Wai.Middleware.Auth.OpenIDConnect + ( OpenIDConnect + , discover + , getAccessToken + , getIdToken + -- * Customizing an OpenIDConnect provider + , oidcClientId + , oidcClientSecret + , oidcProviderInfo + , oidcManager + , oidcScopes + , oidcAllowedSkew + ) where + +import Control.Applicative ((<|>)) +import qualified Crypto.JOSE as JOSE +import qualified Crypto.JWT as JWT +import Control.Monad.Except (runExceptT) +import qualified Data.Aeson as Aeson +import qualified Data.ByteString.Char8 as S8 +import Data.Function ((&)) +import qualified Data.Time.Clock as Clock +import Data.Traversable (for) +import qualified Data.Text as T +import qualified Data.Text.Lazy as TL +import qualified Data.Text.Lazy.Encoding as TLE +import qualified Data.Vault.Lazy as Vault +import Foreign.C.Types (CTime (..)) +import GHC.Generics (Generic) +import qualified Lens.Micro as Lens +import qualified Lens.Micro.Extras as Lens.Extras +import Network.HTTP.Simple (httpJSON, + getResponseBody, + parseRequestThrow) +import Network.Wai.Middleware.Auth.OAuth2 (parseAbsoluteURI, + getAccessToken) +import qualified Network.OAuth.OAuth2 as OA2 +import Network.HTTP.Client (Manager) +import Network.HTTP.Client.TLS (getGlobalManager) +import Network.Wai (Request, vault) +import Network.Wai.Auth.Internal (decodeToken, encodeToken, + oauth2Login, + refreshTokens) +import Network.Wai.Middleware.Auth.Provider +import System.IO.Unsafe (unsafePerformIO) +import System.PosixCompat.Time (epochTime) +import qualified Text.Hamlet +import qualified URI.ByteString as U + +-- | An Open ID Connect provider. +-- +-- @since X.Y.Z +data OpenIDConnect + = OpenIDConnect + { oidcMetadata :: Metadata + , oidcJwkSet :: JOSE.JWKSet + -- | The client id this application is registered with at the Open ID + -- Connect provider. The default is an empty string, you will need to + -- overwrite this. + , oidcClientId :: T.Text + -- | The client secret of this application. The default is an empty + -- string, you will need to overwrite this. + , oidcClientSecret :: T.Text + -- | The information for this provider. The default contains some + -- placeholder texts. If you're using the provider screen you'll want to + -- overwrite this. + , oidcProviderInfo :: ProviderInfo + -- | The HTTP manager to use. Defaults to the global manager. + , oidcManager :: Manager + -- | The scopes to set. Defaults to only the "openid" scope. + , oidcScopes :: [T.Text] + -- | The amount of clock skew to allow when validating id tokens. Defaults + -- to 0. + , oidcAllowedSkew :: Clock.NominalDiffTime + } + +data Metadata + = Metadata + { issuer :: T.Text + , authorizationEndpoint :: U.URI + , tokenEndpoint :: U.URI + , userinfoEndpoint :: Maybe T.Text + , revocationEndpoint :: Maybe T.Text + , jwksUri :: T.Text + , responseTypesSupported :: [T.Text] + , subjectTypesSupported :: [T.Text] + , idTokenSigningAlgValuesSupported :: [T.Text] + , scopesSupported :: Maybe [T.Text] + , tokenEndpointAuthMethodsSupported :: Maybe [T.Text] + , claimsSupported :: Maybe [T.Text] + } + deriving (Generic) + +instance Aeson.FromJSON Metadata + +instance AuthProvider OpenIDConnect where + getProviderName _ = "oidc" + getProviderInfo = oidcProviderInfo + handleLogin oidc@OpenIDConnect {.. } req suffix renderUrl onSuccess onFailure = do + oauth2 <- mkOauth2 oidc (Just renderUrl) + oauth2Login + oauth2 + oidcManager + (Just oidcScopes) + (getProviderName oidc) + req + suffix + onSuccess + onFailure + refreshLoginState oidc req user = + let loginState = authLoginState user + in case decodeToken loginState of + Left _ -> pure Nothing + Right tokens -> do + vRes <- validateIdToken' oidc tokens + case vRes of + Nothing -> do + oauth2 <- mkOauth2 oidc Nothing + rRes <- refreshTokens tokens (oidcManager oidc) oauth2 + case rRes of + Nothing -> pure Nothing + Just newTokens -> do + v2Res <- validateIdToken' oidc newTokens + case v2Res of + Nothing -> pure Nothing + Just claims -> do + CTime now <- epochTime + let newUser = + user { + authLoginState = encodeToken newTokens, + authLoginTime = fromIntegral now + } + pure (Just (storeClaims claims req, newUser)) + Just claims -> + pure (Just (storeClaims claims req, user)) + +-- | Obtain configuration of an OpenID Connect from its discovery endpoint. +discover :: U.URI -> IO OpenIDConnect +discover base = do + let uri = base { U.uriPath = "/.well-known/openid-configuration" } + metadata <- fetchMetadata uri + jwkset <- fetchJWKSet (jwksUri metadata) + manager <- getGlobalManager + pure OpenIDConnect + { oidcClientId = "" + , oidcClientSecret = "" + , oidcMetadata = metadata + , oidcJwkSet = jwkset + , oidcProviderInfo = ProviderInfo "OpenID Connect Provider" "" "" + , oidcManager = manager + , oidcScopes = ["openid"] + , oidcAllowedSkew = 0 + } + +fetchMetadata :: U.URI -> IO Metadata +fetchMetadata metadataEndpoint = do + req <- parseRequestThrow (S8.unpack $ U.serializeURIRef' metadataEndpoint) + getResponseBody <$> httpJSON req + +fetchJWKSet :: T.Text -> IO JOSE.JWKSet +fetchJWKSet jwkSetEndpoint = do + req <- parseRequestThrow (T.unpack jwkSetEndpoint) + getResponseBody <$> httpJSON req + +mkOauth2 :: OpenIDConnect -> Maybe (Text.Hamlet.Render ProviderUrl) -> IO OA2.OAuth2 +mkOauth2 OpenIDConnect {..} renderUrl = do + callbackURI <- for renderUrl $ \render -> parseAbsoluteURI $ render (ProviderUrl ["complete"]) [] + pure OA2.OAuth2 + { oauthClientId = oidcClientId + , oauthClientSecret = oidcClientSecret + , oauthOAuthorizeEndpoint = authorizationEndpoint oidcMetadata + , oauthAccessTokenEndpoint = tokenEndpoint oidcMetadata + , oauthCallback = callbackURI + } + +validateIdToken :: OpenIDConnect -> OA2.IdToken -> IO (Either JWT.JWTError JWT.ClaimsSet) +validateIdToken oidc (OA2.IdToken idToken) = runExceptT $ do + signedJwt <- JOSE.decodeCompact (TLE.encodeUtf8 $ TL.fromStrict idToken) + JWT.verifyClaims (validationSettings oidc) (oidcJwkSet oidc) signedJwt + +validateIdToken' :: OpenIDConnect -> OA2.OAuth2Token -> IO (Maybe JWT.ClaimsSet) +validateIdToken' oidc tokens = + case OA2.idToken tokens of + Nothing -> pure Nothing + Just idToken -> + either (const Nothing) Just <$> validateIdToken oidc idToken + +validationSettings :: OpenIDConnect -> JWT.JWTValidationSettings +validationSettings oidc = + JWT.defaultJWTValidationSettings (validateAudience oidc) + & Lens.set JWT.jwtValidationSettingsCheckIssuedAt True + & Lens.set JWT.jwtValidationSettingsIssuerPredicate (validateIssuer oidc) + & Lens.set JWT.jwtValidationSettingsAllowedSkew (oidcAllowedSkew oidc) + +validateAudience :: OpenIDConnect -> JWT.StringOrURI -> Bool +validateAudience oidc audClaim = + audienceFromJWT == Just correctClientId + where + correctClientId = oidcClientId oidc + audienceFromJWT = fromStringOrURI audClaim + +validateIssuer :: OpenIDConnect -> JWT.StringOrURI -> Bool +validateIssuer oidc issClaim = + issuerFromJWT == Just correctIssuer + where + correctIssuer = issuer (oidcMetadata oidc) + issuerFromJWT = fromStringOrURI issClaim + +fromStringOrURI :: JWT.StringOrURI -> Maybe T.Text +fromStringOrURI stringOrURI = + Lens.Extras.preview JWT.string stringOrURI + <|> fmap (T.pack . show) (Lens.Extras.preview JWT.uri stringOrURI) + +storeClaims :: JWT.ClaimsSet -> Request -> Request +storeClaims claims req = + req { vault = Vault.insert idTokenKey claims (vault req) } + +-- | Get the @IdToken@ for the current user. +-- +-- If called on a @Request@ behind the middleware, should almost return a +-- @Just@ value. +-- +-- @since X.Y.Z +getIdToken :: Request -> Maybe JWT.ClaimsSet +getIdToken req = Vault.lookup idTokenKey (vault req) + +idTokenKey :: Vault.Key JWT.ClaimsSet +idTokenKey = unsafePerformIO Vault.newKey +{-# NOINLINE idTokenKey #-} diff --git a/src/Network/Wai/Middleware/Auth/Provider.hs b/src/Network/Wai/Middleware/Auth/Provider.hs index c5155a8..ef8bb43 100644 --- a/src/Network/Wai/Middleware/Auth/Provider.hs +++ b/src/Network/Wai/Middleware/Auth/Provider.hs @@ -114,9 +114,10 @@ class AuthProvider ap where -- @since X.Y.Z refreshLoginState :: ap + -> Request -> AuthUser - -> IO (Maybe AuthUser) - refreshLoginState _ loginState = pure (Just loginState) + -> IO (Maybe (Request, AuthUser)) + refreshLoginState _ req loginState = pure (Just (req, loginState)) -- | Generic authentication provider wrapper. data Provider where @@ -131,7 +132,7 @@ instance AuthProvider Provider where handleLogin (Provider p) = handleLogin p - refreshLoginState (Provider p) loginState = refreshLoginState p loginState + refreshLoginState (Provider p) = refreshLoginState p -- | Collection of supported providers. type Providers = HM.HashMap T.Text Provider diff --git a/wai-middleware-auth.cabal b/wai-middleware-auth.cabal index 20b2e98..46df93f 100644 --- a/wai-middleware-auth.cabal +++ b/wai-middleware-auth.cabal @@ -16,6 +16,7 @@ library Network.Wai.Middleware.Auth.OAuth2 Network.Wai.Middleware.Auth.OAuth2.Github Network.Wai.Middleware.Auth.OAuth2.Google + Network.Wai.Middleware.Auth.OpenIDConnect Network.Wai.Middleware.Auth.Provider Network.Wai.Auth.Executable Network.Wai.Auth.Internal @@ -43,6 +44,8 @@ library , http-reverse-proxy , http-types , jose + , microlens + , mtl , regex-posix , safe-exceptions , shakespeare From cb45b17d61e7644f7c09ada53d1a3275183a3c0f Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Fri, 31 Jan 2020 18:46:17 +0000 Subject: [PATCH 07/23] Add tests for OpenIDConnect provider --- .../Wai/Middleware/Auth/OpenIDConnect.hs | 3 + test/Main.hs | 2 + .../Wai/Middleware/Auth/OpenIDConnect.hs | 301 ++++++++++++++++++ wai-middleware-auth.cabal | 6 + 4 files changed, 312 insertions(+) create mode 100644 test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs diff --git a/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs b/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs index e406936..b758001 100644 --- a/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs +++ b/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs @@ -3,6 +3,7 @@ {-# LANGUAGE OverloadedStrings #-} module Network.Wai.Middleware.Auth.OpenIDConnect ( OpenIDConnect + , Metadata(..) , discover , getAccessToken , getIdToken @@ -96,6 +97,8 @@ data Metadata instance Aeson.FromJSON Metadata +instance Aeson.ToJSON Metadata + instance AuthProvider OpenIDConnect where getProviderName _ = "oidc" getProviderInfo = oidcProviderInfo diff --git a/test/Main.hs b/test/Main.hs index ce1b291..a606390 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -5,6 +5,7 @@ module Main (main) where import Test.Tasty import qualified Spec.Network.Wai.Auth.Internal import qualified Spec.Network.Wai.Middleware.Auth.OAuth2 +import qualified Spec.Network.Wai.Middleware.Auth.OpenIDConnect main :: IO () main = defaultMain tests @@ -13,4 +14,5 @@ tests :: TestTree tests = testGroup "wai-middleware-auth" [ Spec.Network.Wai.Auth.Internal.tests , Spec.Network.Wai.Middleware.Auth.OAuth2.tests + , Spec.Network.Wai.Middleware.Auth.OpenIDConnect.tests ] diff --git a/test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs b/test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs new file mode 100644 index 0000000..7611ccb --- /dev/null +++ b/test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs @@ -0,0 +1,301 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Spec.Network.Wai.Middleware.Auth.OpenIDConnect (tests) where + +import Control.Monad (void) +import Control.Monad.IO.Class (liftIO) +import Data.ByteString (ByteString) +import qualified Data.IORef as IORef +import qualified Crypto.JOSE as JOSE +import qualified Crypto.JWT as JWT +import qualified Control.Monad.Except +import qualified Data.Aeson as Aeson +import Data.Function ((&)) +import qualified Data.Text as T +import qualified Data.Text.Encoding +import qualified Data.Text.Lazy +import qualified Data.Text.Lazy.Encoding +import qualified Data.Time.Clock as Clock +import GHC.Exts (fromList, fromString) +import qualified Network.HTTP.Types.Status as Status +import qualified Network.OAuth.OAuth2 as OA2 +import qualified Network.Wai as Wai +import qualified Network.Wai.Handler.Warp as Warp +import qualified Network.Wai.Middleware.Auth as Auth +import Network.Wai.Middleware.Auth.OpenIDConnect +import Network.Wai.Middleware.Auth.Provider (Provider(..)) +import Network.Wai.Test (Session, SResponse, + assertHeader, + assertStatus, + defaultRequest, + request, runSession, + setClientCookie, + setPath) +import qualified Lens.Micro as Lens +import Test.Tasty (TestTree, testGroup) +import Test.Tasty.HUnit (assertBool, testCase) +import qualified URI.ByteString as U +import qualified Web.Cookie as Cookie + +tests :: TestTree +tests = testGroup "Network.Wai.Auth.OpenIDConnect" + [ testCase "when a request without a session is made then redirect to re-authorize" $ do + (provider, _) <- fakeProvider + Warp.testWithApplication (pure provider) $ \port -> do + let host = parseURI $ "http://localhost:" <> T.pack (show port) + middleware <- Auth.mkAuthMiddleware =<< authSettings host + let app = middleware const200 + flip runSession app $ do + redirect1 <- get "/hi" + assertStatus 303 redirect1 + assertHeader "Location" "/prefix" redirect1 + redirect2 <- get "/prefix" + assertStatus 303 redirect2 + assertHeader "location" "/prefix/oidc" redirect2 + redirect3 <- get "/prefix/oidc" + assertStatus 303 redirect3 + assertHeader + "location" + (U.serializeURIRef' host <> "/authorize?scope=openid%2Cscope1&client_id=client-id&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%2Fprefix%2Foidc%2Fcomplete") + redirect3 + + , testCase "when a request is made with a valid session then pass the request through" $ do + (provider, _) <- fakeProvider + Warp.testWithApplication (pure provider) $ \port -> do + let host = parseURI $ "http://localhost:" <> T.pack (show port) + middleware <- Auth.mkAuthMiddleware =<< authSettings host + let app = middleware const200 + flip runSession app $ do + createSession + response <- get "/some/endpoint" + assertStatus 200 response + + , testCase "when an ID token expired and no refresh token is available then redirect to re-authorize" $ do + (provider, changeConfig) <- fakeProvider + changeConfig (\c -> c { jwtExpiresIn = -600 }) + Warp.testWithApplication (pure provider) $ \port -> do + let host = parseURI $ "http://localhost:" <> T.pack (show port) + middleware <- Auth.mkAuthMiddleware =<< authSettings host + let app = middleware const200 + flip runSession app $ do + createSession + response <- get "/some/endpoint" + assertStatus 303 response + + , testCase "when an ID token expired then use a refresh" $ do + (provider, changeConfig) <- fakeProvider + changeConfig (\c -> c { jwtExpiresIn = -600, returnRefreshToken = True }) + Warp.testWithApplication (pure provider) $ \port -> do + let host = parseURI $ "http://localhost:" <> T.pack (show port) + middleware <- Auth.mkAuthMiddleware =<< authSettings host + let app = middleware const200 + flip runSession app $ do + createSession + liftIO $ changeConfig (\c -> c { jwtExpiresIn = 600 }) + response <- get "/some/endpoint" + assertStatus 200 response + + , testCase "when a request is made with an invalid session redirect to re-authorize" $ do + (provider, _) <- fakeProvider + Warp.testWithApplication (pure provider) $ \port -> do + let host = parseURI $ "http://localhost:" <> T.pack (show port) + middleware <- Auth.mkAuthMiddleware =<< authSettings host + let app = middleware const200 + flip runSession app $ do + -- First create a known valid session, so we can see that it's the act + -- of corrupting it that makes the test fail. + createSession + setClientCookie + Cookie.defaultSetCookie + { Cookie.setCookieName = "auth-cookie" + , Cookie.setCookieValue = "garbage" + } + response <- get "/some/endpoint" + assertStatus 303 response + + , testCase "when an ID token has an invalid audience then redirect to re-authorize" $ do + (provider, changeConfig) <- fakeProvider + changeConfig (\c -> c { jwtAudience = fromString "wrong-audience" }) + Warp.testWithApplication (pure provider) $ \port -> do + let host = parseURI $ "http://localhost:" <> T.pack (show port) + middleware <- Auth.mkAuthMiddleware =<< authSettings host + let app = middleware const200 + flip runSession app $ do + createSession + response <- get "/some/endpoint" + assertStatus 303 response + + , testCase "when an ID token has an invalid issuer then redirect to re-authorize" $ do + (provider, changeConfig) <- fakeProvider + Warp.testWithApplication (pure provider) $ \port -> do + let host = parseURI $ "http://localhost:" <> T.pack (show port) + middleware <- Auth.mkAuthMiddleware =<< authSettings host + let app = middleware const200 + flip runSession app $ do + liftIO $ changeConfig (\c -> c { jwtIssuer = "wrong-issuer" }) + createSession + response <- get "/some/endpoint" + assertStatus 303 response + + , testCase "when a session does not contain an ID token then redirect to re-authorize" $ do + (provider, changeConfig) <- fakeProvider + changeConfig (\c -> c { returnIdToken = False }) + Warp.testWithApplication (pure provider) $ \port -> do + let host = parseURI $ "http://localhost:" <> T.pack (show port) + middleware <- Auth.mkAuthMiddleware =<< authSettings host + let app = middleware const200 + flip runSession app $ do + createSession + response <- get "/some/endpoint" + assertStatus 303 response + + , testCase "when an ID token has an invalid signature then redirect to re-authorize" $ do + (provider, changeConfig) <- fakeProvider + Warp.testWithApplication (pure provider) $ \port -> do + let host = parseURI $ "http://localhost:" <> T.pack (show port) + middleware <- Auth.mkAuthMiddleware =<< authSettings host + let app = middleware const200 + flip runSession app $ do + newJWK <- liftIO $ JOSE.genJWK (JOSE.RSAGenParam 256) + liftIO $ changeConfig (\c -> c { jwtJWK = newJWK }) + createSession + response <- get "/some/endpoint" + assertStatus 303 response + ] + +get :: ByteString -> Session SResponse +get = request . setPath defaultRequest + +createSession :: Session () +createSession = void $ get "/prefix/oidc/complete?code=1234" + +const200 :: Wai.Application +const200 _ respond = respond $ Wai.responseLBS Status.ok200 [] "" + +authSettings :: U.URI -> IO Auth.AuthSettings +authSettings host = do + oidc' <- discover host + let oidc = + oidc' + { oidcClientId = "client-id" + , oidcClientSecret = "client-secret" + , oidcScopes = ["openid", "scope1"] + } + pure $ Auth.defaultAuthSettings + & Auth.setAuthProviders (fromList [("oidc", Provider oidc)]) + & Auth.setAuthPrefix "prefix" + & Auth.setAuthCookieName "auth-cookie" + +data FakeOIDCProviderConfig + = FakeOIDCProviderConfig + { jwtExpiresIn :: Clock.NominalDiffTime, + jwtAudience :: JWT.StringOrURI, + jwtIssuer :: T.Text, + jwtJWK :: JOSE.JWK, + jwtSub :: String, + returnIdToken :: Bool, + returnRefreshToken :: Bool + } + +defaultConfig :: IO FakeOIDCProviderConfig +defaultConfig = do + jwk <- JOSE.genJWK (JOSE.RSAGenParam 256) + pure + FakeOIDCProviderConfig + { jwtExpiresIn = 600, + jwtAudience = "client-id", + jwtIssuer = "test-oidc-provider", + jwtJWK = jwk, + jwtSub = "1234", + returnIdToken = True, + returnRefreshToken = False + } + +fakeProvider :: IO (Wai.Application, (FakeOIDCProviderConfig -> FakeOIDCProviderConfig) -> IO ()) +fakeProvider = do + config <- defaultConfig + configRef <- IORef.newIORef config + let changeConfig = IORef.modifyIORef configRef + pure (fakeProvider' configRef, changeConfig) + +fakeProvider' :: IORef.IORef FakeOIDCProviderConfig -> Wai.Application +fakeProvider' configRef req respond = do + config <- IORef.readIORef configRef + case Wai.pathInfo req of + [".well-known", "openid-configuration"] -> + case Data.Text.Encoding.decodeUtf8 <$> Wai.requestHeaderHost req of + Nothing -> + Wai.responseLBS Status.badRequest400 [] "" + & respond + Just host -> + Metadata + { issuer = jwtIssuer config, + authorizationEndpoint = parseURI ("http://" <> host <> "/authorize"), + tokenEndpoint = parseURI ("http://" <> host <> "/token"), + userinfoEndpoint = Nothing, + revocationEndpoint = Nothing, + jwksUri = "http://" <> host <> "/jwks", + responseTypesSupported = ["code"], + subjectTypesSupported = ["public"], + idTokenSigningAlgValuesSupported = ["RS256"], + scopesSupported = Just ["openid"], + tokenEndpointAuthMethodsSupported = Just ["client_secret_basic"], + claimsSupported = Just ["iss", "sub", "aud", "exp", "iat"] + } + & Aeson.encode + & Wai.responseLBS Status.ok200 [("Content-Type", "application/json")] + & respond + ["jwks"] -> + JOSE.JWKSet [jwtJWK config] + & Aeson.encode + & Wai.responseLBS Status.ok200 [("Content-Type", "application/json")] + & respond + ["token"] -> do + now <- Clock.getCurrentTime + let claims = + JWT.emptyClaimsSet + & Lens.set JWT.claimIss (Just (fromString (T.unpack (jwtIssuer config)))) + & Lens.set JWT.claimAud (Just (JWT.Audience [jwtAudience config])) + & Lens.set JWT.claimIat (Just (JWT.NumericDate now)) + & Lens.set JWT.claimExp (Just (JWT.NumericDate (Clock.addUTCTime (jwtExpiresIn config) now))) + & Lens.set JWT.claimSub (Just (fromString (jwtSub config))) + idToken <- doJwtSign (jwtJWK config) claims + OA2.OAuth2Token + { OA2.accessToken = OA2.AccessToken "access-granted", + OA2.refreshToken = + if returnRefreshToken config + then Just (OA2.RefreshToken "refresh-token") + else Nothing, + OA2.expiresIn = Just 3600, + OA2.tokenType = Nothing, + OA2.idToken = + if returnIdToken config + then Just (OA2.IdToken idToken) + else Nothing + } + & Aeson.encode + & Wai.responseLBS Status.ok200 [("Content-Type", "application/json")] + & respond + _ -> + Wai.responseLBS Status.notFound404 [] "" + & respond + +doJwtSign :: JOSE.JWK -> JWT.ClaimsSet -> IO T.Text +doJwtSign jwk claims = do + result <- Control.Monad.Except.runExceptT $ do + alg <- JOSE.bestJWSAlg jwk + JWT.signClaims jwk (JOSE.newJWSHeader ((), alg)) claims + case result of + Left (err :: JOSE.Error) -> fail (show err) + Right bytestring -> + JOSE.encodeCompact bytestring + & Data.Text.Lazy.Encoding.decodeUtf8 + & Data.Text.Lazy.toStrict + & pure + +parseURI :: T.Text -> U.URIRef U.Absolute +parseURI uri = + Data.Text.Encoding.encodeUtf8 uri + & U.parseURI U.laxURIParserOptions + & either (error . show) id diff --git a/wai-middleware-auth.cabal b/wai-middleware-auth.cabal index 46df93f..6cdc662 100644 --- a/wai-middleware-auth.cabal +++ b/wai-middleware-auth.cabal @@ -84,6 +84,7 @@ test-suite spec hs-source-dirs: test other-modules: Spec.Network.Wai.Auth.Internal , Spec.Network.Wai.Middleware.Auth.OAuth2 + , Spec.Network.Wai.Middleware.Auth.OpenIDConnect build-depends: base , aeson , binary @@ -92,10 +93,15 @@ test-suite spec , hedgehog , hoauth2 , http-types + , jose + , microlens + , mtl , tasty , tasty-hedgehog , tasty-hunit , text + , time + , uri-bytestring , wai , wai-extra , wai-middleware-auth From caf30250491fdbf4d76e0df36bb7be8f3cde0f5e Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sun, 2 Feb 2020 10:37:45 +0000 Subject: [PATCH 08/23] Clean up test logic of OAuth2 and OpenID providers --- test/Network/Wai/Auth/Test.hs | 158 ++++++++ .../Network/Wai/Middleware/Auth/OAuth2.hs | 139 +++---- .../Wai/Middleware/Auth/OpenIDConnect.hs | 354 +++++------------- wai-middleware-auth.cabal | 3 +- 4 files changed, 310 insertions(+), 344 deletions(-) create mode 100644 test/Network/Wai/Auth/Test.hs diff --git a/test/Network/Wai/Auth/Test.hs b/test/Network/Wai/Auth/Test.hs new file mode 100644 index 0000000..baa870e --- /dev/null +++ b/test/Network/Wai/Auth/Test.hs @@ -0,0 +1,158 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} + +module Network.Wai.Auth.Test + (ChangeProvider + , FakeProviderConf(..) + , fakeProvider + , const200 + , get + , parseURI + ) where + +import Control.Monad.IO.Class (liftIO) +import Data.ByteString (ByteString) +import qualified Data.IORef as IORef +import qualified Crypto.JOSE as JOSE +import qualified Crypto.JWT as JWT +import qualified Control.Monad.Except +import qualified Data.Aeson as Aeson +import Data.Function ((&)) +import qualified Data.Text as T +import qualified Data.Text.Encoding as TE +import qualified Data.Text.Lazy as TL +import qualified Data.Text.Lazy.Encoding as TLE +import qualified Data.Time.Clock as Clock +import GHC.Exts (fromString) +import qualified Network.HTTP.Types.Status as Status +import qualified Network.OAuth.OAuth2 as OA2 +import qualified Network.Wai as Wai +import Network.Wai.Middleware.Auth.OpenIDConnect +import Network.Wai.Test (Session, SResponse, + defaultRequest, + request, setPath) +import qualified Lens.Micro as Lens +import qualified URI.ByteString as U + +get :: ByteString -> Session SResponse +get = request . setPath defaultRequest + +const200 :: Wai.Application +const200 _ respond = respond $ Wai.responseLBS Status.ok200 [] "" + +data FakeProviderConf + = FakeProviderConf + { jwtExpiresIn :: Clock.NominalDiffTime, + jwtAudience :: JWT.StringOrURI, + jwtIssuer :: T.Text, + jwtJWK :: JOSE.JWK, + jwtSub :: String, + accessTokenExpiresIn :: Int, + returnIdToken :: Bool, + returnRefreshToken :: Bool + } + +defaultConfig :: IO FakeProviderConf +defaultConfig = do + jwk <- JOSE.genJWK (JOSE.RSAGenParam 256) + pure + FakeProviderConf + { jwtExpiresIn = 600, + jwtAudience = "client-id", + jwtIssuer = "test-oidc-provider", + jwtJWK = jwk, + jwtSub = "1234", + accessTokenExpiresIn = 600, + returnIdToken = True, + returnRefreshToken = True + } + +type ChangeProvider = (FakeProviderConf -> FakeProviderConf) -> Session () + +fakeProvider :: IO (Wai.Application, ChangeProvider) +fakeProvider = do + config <- defaultConfig + configRef <- IORef.newIORef config + let changeProvider = IORef.modifyIORef configRef + pure (fakeProvider' configRef, liftIO . changeProvider) + +fakeProvider' :: IORef.IORef FakeProviderConf -> Wai.Application +fakeProvider' configRef req respond = do + config <- IORef.readIORef configRef + case Wai.pathInfo req of + [".well-known", "openid-configuration"] -> + case TE.decodeUtf8 <$> Wai.requestHeaderHost req of + Nothing -> + Wai.responseLBS Status.badRequest400 [] "" + & respond + Just host -> + Metadata + { issuer = jwtIssuer config, + authorizationEndpoint = parseURI ("http://" <> host <> "/authorize"), + tokenEndpoint = parseURI ("http://" <> host <> "/token"), + userinfoEndpoint = Nothing, + revocationEndpoint = Nothing, + jwksUri = "http://" <> host <> "/jwks", + responseTypesSupported = ["code"], + subjectTypesSupported = ["public"], + idTokenSigningAlgValuesSupported = ["RS256"], + scopesSupported = Just ["openid"], + tokenEndpointAuthMethodsSupported = Just ["client_secret_basic"], + claimsSupported = Just ["iss", "sub", "aud", "exp", "iat"] + } + & Aeson.encode + & Wai.responseLBS Status.ok200 [("Content-Type", "application/json")] + & respond + ["jwks"] -> + JOSE.JWKSet [jwtJWK config] + & Aeson.encode + & Wai.responseLBS Status.ok200 [("Content-Type", "application/json")] + & respond + ["token"] -> do + now <- Clock.getCurrentTime + let claims = + JWT.emptyClaimsSet + & Lens.set JWT.claimIss (Just (fromString (T.unpack (jwtIssuer config)))) + & Lens.set JWT.claimAud (Just (JWT.Audience [jwtAudience config])) + & Lens.set JWT.claimIat (Just (JWT.NumericDate now)) + & Lens.set JWT.claimExp (Just (JWT.NumericDate (Clock.addUTCTime (jwtExpiresIn config) now))) + & Lens.set JWT.claimSub (Just (fromString (jwtSub config))) + idToken <- doJwtSign (jwtJWK config) claims + OA2.OAuth2Token + { OA2.accessToken = OA2.AccessToken "access-granted", + OA2.refreshToken = + if returnRefreshToken config + then Just (OA2.RefreshToken "refresh-token") + else Nothing, + OA2.expiresIn = Just (accessTokenExpiresIn config), + OA2.tokenType = Nothing, + OA2.idToken = + if returnIdToken config + then Just (OA2.IdToken idToken) + else Nothing + } + & Aeson.encode + & Wai.responseLBS Status.ok200 [("Content-Type", "application/json")] + & respond + _ -> + Wai.responseLBS Status.notFound404 [] "" + & respond + +doJwtSign :: JOSE.JWK -> JWT.ClaimsSet -> IO T.Text +doJwtSign jwk claims = do + result <- Control.Monad.Except.runExceptT $ do + alg <- JOSE.bestJWSAlg jwk + JWT.signClaims jwk (JOSE.newJWSHeader ((), alg)) claims + case result of + Left (err :: JOSE.Error) -> fail (show err) + Right bytestring -> + JOSE.encodeCompact bytestring + & TLE.decodeUtf8 + & TL.toStrict + & pure + +parseURI :: T.Text -> U.URIRef U.Absolute +parseURI uri = + TE.encodeUtf8 uri + & U.parseURI U.laxURIParserOptions + & either (error . show) id diff --git a/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs b/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs index f3b0959..bb1c28a 100644 --- a/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs +++ b/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs @@ -3,37 +3,34 @@ module Spec.Network.Wai.Middleware.Auth.OAuth2 (tests) where import Control.Monad (void) -import qualified Data.Aeson as Aeson -import Data.ByteString (ByteString) import Data.Function ((&)) import qualified Data.Text as T +import qualified Data.Text.Encoding as TE import GHC.Exts (fromList) import qualified Network.HTTP.Types.Status as Status -import qualified Network.OAuth.OAuth2 as OA2 import qualified Network.Wai as Wai +import Network.Wai.Auth.Test (ChangeProvider, + FakeProviderConf(..), + fakeProvider, + const200, get) import qualified Network.Wai.Handler.Warp as Warp import qualified Network.Wai.Middleware.Auth as Auth import Network.Wai.Middleware.Auth.OAuth2 (OAuth2(..), getAccessToken) import Network.Wai.Middleware.Auth.Provider (Provider(..), ProviderInfo(..)) -import Network.Wai.Test (Session, SResponse, - assertHeader, +import Network.Wai.Test (Session, assertHeader, assertStatus, - defaultRequest, - request, runSession, - setClientCookie, - setPath) + runSession, + setClientCookie) import Test.Tasty (TestTree, testGroup) import Test.Tasty.HUnit (testCase) import qualified Web.Cookie as Cookie tests :: TestTree tests = testGroup "Network.Wai.Auth.OAuth2" - [ testCase "when a request without a session is made then the response redirects to the oauth2 authorize endpoint" $ do - middleware <- Auth.mkAuthMiddleware $ authSettings "http://oauth2provider.com" - let app = middleware const200 - flip runSession app $ do + [ testCase "when a request without a session is made then the response redirects to the oauth2 authorize endpoint" $ + runSessionWithProvider const200 $ \host _ -> do redirect1 <- get "/hi" assertStatus 303 redirect1 assertHeader "Location" "/prefix" redirect1 @@ -42,67 +39,54 @@ tests = testGroup "Network.Wai.Auth.OAuth2" assertHeader "location" "/prefix/oauth2" redirect2 redirect3 <- get "/prefix/oauth2" assertStatus 303 redirect3 - assertHeader "location" "http://oauth2provider.com/authorize?scope=scope1%2Cscope2&client_id=client-id&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%2Fprefix%2Foauth2%2Fcomplete" redirect3 + assertHeader + "location" + (TE.encodeUtf8 host <> "/authorize?scope=scope1%2Cscope2&client_id=client-id&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%2Fprefix%2Foauth2%2Fcomplete") + redirect3 - , testCase "when a request with an expired session is made then the response redirects to the oauth2 authorize endpoint" $ do - Warp.testWithApplication (pure (fakeProvider (-3600))) $ \port -> do - middleware <- Auth.mkAuthMiddleware $ authSettings ("http://localhost:" <> T.pack (show port)) - let app = middleware const200 - flip runSession app $ do - createSession - response <- get "/some/endpoint" - assertStatus 303 response + , testCase "when a request with an expired session is made then the response redirects to the oauth2 authorize endpoint" $ + runSessionWithProvider const200 $ \_ changeProvider -> do + changeProvider (\c -> c { accessTokenExpiresIn = -600, returnRefreshToken = False }) + createSession + response <- get "/some/endpoint" + assertStatus 303 response - , testCase "when a request with a valid session is made then the middleware passes the request through" $ do - Warp.testWithApplication (pure (fakeProvider 3600)) $ \port -> do - middleware <- Auth.mkAuthMiddleware $ authSettings ("http://localhost:" <> T.pack (show port)) - let app = middleware const200 - flip runSession app $ do - createSession - response <- get "/some/endpoint" - assertStatus 200 response + , testCase "when a request with a valid session is made then the middleware passes the request through" $ + runSessionWithProvider const200 $ \_ _ -> do + createSession + response <- get "/some/endpoint" + assertStatus 200 response - , testCase "when a request with an invalid session is made then the response redirects to the oauth2 authorize endpoint" $ do - Warp.testWithApplication (pure (fakeProvider 3600)) $ \port -> do - middleware <- Auth.mkAuthMiddleware $ authSettings ("http://localhost:" <> T.pack (show port)) - let app = middleware const200 - flip runSession app $ do - -- First create a known valid session, so we can see that it's the act - -- of corrupting it that makes the test fail. - createSession - setClientCookie - Cookie.defaultSetCookie - { Cookie.setCookieName = "auth-cookie" - , Cookie.setCookieValue = "garbage" - } - response <- get "/some/endpoint" - assertStatus 303 response + , testCase "when a request with an invalid session is made then the response redirects to the oauth2 authorize endpoint" $ + runSessionWithProvider const200 $ \_ _ -> do + -- First create a known valid session, so we can see that it's the act + -- of corrupting it that makes the test fail. + createSession + setClientCookie + Cookie.defaultSetCookie + { Cookie.setCookieName = "auth-cookie" + , Cookie.setCookieValue = "garbage" + } + response <- get "/some/endpoint" + assertStatus 303 response - , testCase "when a request is made to the oauth2 complete endpoint then the middleware fatches an access token and sets a user sesion" $ - Warp.testWithApplication (pure (fakeProvider 3600)) $ \port -> do - middleware <- Auth.mkAuthMiddleware $ authSettings ("http://localhost:" <> T.pack (show port)) - let app = middleware const200 - flip runSession app $ do - response <- get "/prefix/oauth2/complete?code=1234" - assertStatus 303 response - assertHeader "location" "/" response + , testCase "when a request is made to the oauth2 complete endpoint then the middleware fatches an access token and sets a user sesion" $ + runSessionWithProvider const200 $ \_ _ -> do + response <- get "/prefix/oauth2/complete?code=1234" + assertStatus 303 response + assertHeader "location" "/" response - , testCase "when a request with a valid session is made then the application can access the session payload" $ do - Warp.testWithApplication (pure (fakeProvider 3600)) $ \port -> do - middleware <- Auth.mkAuthMiddleware $ authSettings ("http://localhost:" <> T.pack (show port)) - let app = middleware $ \req respond -> - case getAccessToken req of - Nothing -> respond $ Wai.responseLBS Status.badRequest400 [] "" - Just _ -> respond $ Wai.responseLBS Status.ok200 [] "" - flip runSession app $ do + , testCase "when a request with a valid session is made then the application can access the session payload" $ + let app req respond = + case getAccessToken req of + Nothing -> respond $ Wai.responseLBS Status.badRequest400 [] "" + Just _ -> respond $ Wai.responseLBS Status.ok200 [] "" + in runSessionWithProvider app $ \_ _ -> do createSession response <- get "/prefix/oauth2/complete?code=1234" assertStatus 200 response ] -get :: ByteString -> Session SResponse -get = request . setPath defaultRequest - createSession :: Session () createSession = void $ get "/prefix/oauth2/complete?code=1234" @@ -130,22 +114,11 @@ provider host = } } -const200 :: Wai.Application -const200 _ respond = respond $ Wai.responseLBS Status.ok200 [] "" - -fakeProvider :: Int -> Wai.Application -fakeProvider expiresIn req respond = - case Wai.pathInfo req of - ["token"] -> - respond $ Wai.responseLBS Status.ok200 [("Content-Type", "application/json")] body - where - body = - Aeson.encode OA2.OAuth2Token - { OA2.accessToken = OA2.AccessToken "access-granted", - OA2.refreshToken = Nothing, - OA2.expiresIn = Just expiresIn, - OA2.tokenType = Nothing, - OA2.idToken = Nothing - } - _ -> - respond $ Wai.responseLBS Status.notFound404 [] "" +runSessionWithProvider :: Wai.Application -> (T.Text -> ChangeProvider -> Session a) -> IO a +runSessionWithProvider app session = do + (p, changeProvider) <- fakeProvider + Warp.testWithApplication (pure p) $ \port -> do + let host = "http://localhost:" <> T.pack (show port) + middleware <- Auth.mkAuthMiddleware $ authSettings host + let app' = middleware app + runSession (session host changeProvider) app' diff --git a/test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs b/test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs index 7611ccb..edd9113 100644 --- a/test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs +++ b/test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs @@ -5,173 +5,120 @@ module Spec.Network.Wai.Middleware.Auth.OpenIDConnect (tests) where import Control.Monad (void) import Control.Monad.IO.Class (liftIO) -import Data.ByteString (ByteString) -import qualified Data.IORef as IORef import qualified Crypto.JOSE as JOSE -import qualified Crypto.JWT as JWT -import qualified Control.Monad.Except -import qualified Data.Aeson as Aeson import Data.Function ((&)) import qualified Data.Text as T -import qualified Data.Text.Encoding -import qualified Data.Text.Lazy -import qualified Data.Text.Lazy.Encoding -import qualified Data.Time.Clock as Clock import GHC.Exts (fromList, fromString) -import qualified Network.HTTP.Types.Status as Status -import qualified Network.OAuth.OAuth2 as OA2 -import qualified Network.Wai as Wai +import Network.Wai.Auth.Test (ChangeProvider, + FakeProviderConf(..), + fakeProvider, + const200, get, + parseURI) import qualified Network.Wai.Handler.Warp as Warp import qualified Network.Wai.Middleware.Auth as Auth import Network.Wai.Middleware.Auth.OpenIDConnect import Network.Wai.Middleware.Auth.Provider (Provider(..)) -import Network.Wai.Test (Session, SResponse, - assertHeader, +import Network.Wai.Test (Session, assertHeader, assertStatus, - defaultRequest, - request, runSession, - setClientCookie, - setPath) -import qualified Lens.Micro as Lens + runSession, + setClientCookie) import Test.Tasty (TestTree, testGroup) -import Test.Tasty.HUnit (assertBool, testCase) +import Test.Tasty.HUnit (testCase) import qualified URI.ByteString as U import qualified Web.Cookie as Cookie tests :: TestTree tests = testGroup "Network.Wai.Auth.OpenIDConnect" - [ testCase "when a request without a session is made then redirect to re-authorize" $ do - (provider, _) <- fakeProvider - Warp.testWithApplication (pure provider) $ \port -> do - let host = parseURI $ "http://localhost:" <> T.pack (show port) - middleware <- Auth.mkAuthMiddleware =<< authSettings host - let app = middleware const200 - flip runSession app $ do - redirect1 <- get "/hi" - assertStatus 303 redirect1 - assertHeader "Location" "/prefix" redirect1 - redirect2 <- get "/prefix" - assertStatus 303 redirect2 - assertHeader "location" "/prefix/oidc" redirect2 - redirect3 <- get "/prefix/oidc" - assertStatus 303 redirect3 - assertHeader - "location" - (U.serializeURIRef' host <> "/authorize?scope=openid%2Cscope1&client_id=client-id&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%2Fprefix%2Foidc%2Fcomplete") - redirect3 - - , testCase "when a request is made with a valid session then pass the request through" $ do - (provider, _) <- fakeProvider - Warp.testWithApplication (pure provider) $ \port -> do - let host = parseURI $ "http://localhost:" <> T.pack (show port) - middleware <- Auth.mkAuthMiddleware =<< authSettings host - let app = middleware const200 - flip runSession app $ do - createSession - response <- get "/some/endpoint" - assertStatus 200 response - - , testCase "when an ID token expired and no refresh token is available then redirect to re-authorize" $ do - (provider, changeConfig) <- fakeProvider - changeConfig (\c -> c { jwtExpiresIn = -600 }) - Warp.testWithApplication (pure provider) $ \port -> do - let host = parseURI $ "http://localhost:" <> T.pack (show port) - middleware <- Auth.mkAuthMiddleware =<< authSettings host - let app = middleware const200 - flip runSession app $ do - createSession - response <- get "/some/endpoint" - assertStatus 303 response - - , testCase "when an ID token expired then use a refresh" $ do - (provider, changeConfig) <- fakeProvider - changeConfig (\c -> c { jwtExpiresIn = -600, returnRefreshToken = True }) - Warp.testWithApplication (pure provider) $ \port -> do - let host = parseURI $ "http://localhost:" <> T.pack (show port) - middleware <- Auth.mkAuthMiddleware =<< authSettings host - let app = middleware const200 - flip runSession app $ do - createSession - liftIO $ changeConfig (\c -> c { jwtExpiresIn = 600 }) - response <- get "/some/endpoint" - assertStatus 200 response - - , testCase "when a request is made with an invalid session redirect to re-authorize" $ do - (provider, _) <- fakeProvider - Warp.testWithApplication (pure provider) $ \port -> do - let host = parseURI $ "http://localhost:" <> T.pack (show port) - middleware <- Auth.mkAuthMiddleware =<< authSettings host - let app = middleware const200 - flip runSession app $ do - -- First create a known valid session, so we can see that it's the act - -- of corrupting it that makes the test fail. - createSession - setClientCookie - Cookie.defaultSetCookie - { Cookie.setCookieName = "auth-cookie" - , Cookie.setCookieValue = "garbage" - } - response <- get "/some/endpoint" - assertStatus 303 response - - , testCase "when an ID token has an invalid audience then redirect to re-authorize" $ do - (provider, changeConfig) <- fakeProvider - changeConfig (\c -> c { jwtAudience = fromString "wrong-audience" }) - Warp.testWithApplication (pure provider) $ \port -> do - let host = parseURI $ "http://localhost:" <> T.pack (show port) - middleware <- Auth.mkAuthMiddleware =<< authSettings host - let app = middleware const200 - flip runSession app $ do - createSession - response <- get "/some/endpoint" - assertStatus 303 response - - , testCase "when an ID token has an invalid issuer then redirect to re-authorize" $ do - (provider, changeConfig) <- fakeProvider - Warp.testWithApplication (pure provider) $ \port -> do - let host = parseURI $ "http://localhost:" <> T.pack (show port) - middleware <- Auth.mkAuthMiddleware =<< authSettings host - let app = middleware const200 - flip runSession app $ do - liftIO $ changeConfig (\c -> c { jwtIssuer = "wrong-issuer" }) - createSession - response <- get "/some/endpoint" - assertStatus 303 response - - , testCase "when a session does not contain an ID token then redirect to re-authorize" $ do - (provider, changeConfig) <- fakeProvider - changeConfig (\c -> c { returnIdToken = False }) - Warp.testWithApplication (pure provider) $ \port -> do - let host = parseURI $ "http://localhost:" <> T.pack (show port) - middleware <- Auth.mkAuthMiddleware =<< authSettings host - let app = middleware const200 - flip runSession app $ do - createSession - response <- get "/some/endpoint" - assertStatus 303 response - - , testCase "when an ID token has an invalid signature then redirect to re-authorize" $ do - (provider, changeConfig) <- fakeProvider - Warp.testWithApplication (pure provider) $ \port -> do - let host = parseURI $ "http://localhost:" <> T.pack (show port) - middleware <- Auth.mkAuthMiddleware =<< authSettings host - let app = middleware const200 - flip runSession app $ do - newJWK <- liftIO $ JOSE.genJWK (JOSE.RSAGenParam 256) - liftIO $ changeConfig (\c -> c { jwtJWK = newJWK }) - createSession - response <- get "/some/endpoint" - assertStatus 303 response + [ testCase "when a request without a session is made then redirect to re-authorize" $ + runSessionWithProvider $ \host _ -> do + redirect1 <- get "/hi" + assertStatus 303 redirect1 + assertHeader "Location" "/prefix" redirect1 + redirect2 <- get "/prefix" + assertStatus 303 redirect2 + assertHeader "location" "/prefix/oidc" redirect2 + redirect3 <- get "/prefix/oidc" + assertStatus 303 redirect3 + assertHeader + "location" + (U.serializeURIRef' host <> "/authorize?scope=openid%2Cscope1&client_id=client-id&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%2Fprefix%2Foidc%2Fcomplete") + redirect3 + + , testCase "when a request is made with a valid session then pass the request through" $ + runSessionWithProvider $ \_ _ -> do + createSession + response <- get "/some/endpoint" + assertStatus 200 response + + , testCase "when an ID token expired and no refresh token is available then redirect to re-authorize" $ + runSessionWithProvider $ \_ changeProvider -> do + changeProvider (\c -> c { jwtExpiresIn = -600, returnRefreshToken = False }) + createSession + response <- get "/some/endpoint" + assertStatus 303 response + + , testCase "when an ID token expired then use a refresh token" $ + runSessionWithProvider $ \_ changeProvider -> do + changeProvider (\c -> c { jwtExpiresIn = -600 }) + createSession + changeProvider (\c -> c { jwtExpiresIn = 600 }) + response <- get "/some/endpoint" + assertStatus 200 response + + , testCase "when a request is made with an invalid session redirect to re-authorize" $ + runSessionWithProvider $ \_ _ -> do + -- First create a known valid session, so we can see that it's the act + -- of corrupting it that makes the test fail. + createSession + setClientCookie + Cookie.defaultSetCookie + { Cookie.setCookieName = "auth-cookie" + , Cookie.setCookieValue = "garbage" + } + response <- get "/some/endpoint" + assertStatus 303 response + + , testCase "when an ID token has an invalid audience then redirect to re-authorize" $ + runSessionWithProvider $ \_ changeProvider -> do + changeProvider (\c -> c { jwtAudience = fromString "wrong-audience" }) + createSession + response <- get "/some/endpoint" + assertStatus 303 response + + , testCase "when an ID token has an invalid issuer then redirect to re-authorize" $ + runSessionWithProvider $ \_ changeProvider -> do + changeProvider (\c -> c { jwtIssuer = "wrong-issuer" }) + createSession + response <- get "/some/endpoint" + assertStatus 303 response + + , testCase "when a session does not contain an ID token then redirect to re-authorize" $ + runSessionWithProvider $ \_ changeProvider -> do + changeProvider (\c -> c { returnIdToken = False }) + createSession + response <- get "/some/endpoint" + assertStatus 303 response + + , testCase "when an ID token has an invalid signature then redirect to re-authorize" $ + runSessionWithProvider $ \_ changeProvider -> do + newJWK <- liftIO $ JOSE.genJWK (JOSE.RSAGenParam 256) + changeProvider (\c -> c { jwtJWK = newJWK }) + createSession + response <- get "/some/endpoint" + assertStatus 303 response ] -get :: ByteString -> Session SResponse -get = request . setPath defaultRequest - createSession :: Session () createSession = void $ get "/prefix/oidc/complete?code=1234" -const200 :: Wai.Application -const200 _ respond = respond $ Wai.responseLBS Status.ok200 [] "" +runSessionWithProvider :: (U.URI -> ChangeProvider -> Session a) -> IO a +runSessionWithProvider session = do + (provider, changeProvider) <- fakeProvider + Warp.testWithApplication (pure provider) $ \port -> do + let host = parseURI $ "http://localhost:" <> T.pack (show port) + middleware <- Auth.mkAuthMiddleware =<< authSettings host + let app = middleware const200 + runSession (session host changeProvider) app authSettings :: U.URI -> IO Auth.AuthSettings authSettings host = do @@ -186,116 +133,3 @@ authSettings host = do & Auth.setAuthProviders (fromList [("oidc", Provider oidc)]) & Auth.setAuthPrefix "prefix" & Auth.setAuthCookieName "auth-cookie" - -data FakeOIDCProviderConfig - = FakeOIDCProviderConfig - { jwtExpiresIn :: Clock.NominalDiffTime, - jwtAudience :: JWT.StringOrURI, - jwtIssuer :: T.Text, - jwtJWK :: JOSE.JWK, - jwtSub :: String, - returnIdToken :: Bool, - returnRefreshToken :: Bool - } - -defaultConfig :: IO FakeOIDCProviderConfig -defaultConfig = do - jwk <- JOSE.genJWK (JOSE.RSAGenParam 256) - pure - FakeOIDCProviderConfig - { jwtExpiresIn = 600, - jwtAudience = "client-id", - jwtIssuer = "test-oidc-provider", - jwtJWK = jwk, - jwtSub = "1234", - returnIdToken = True, - returnRefreshToken = False - } - -fakeProvider :: IO (Wai.Application, (FakeOIDCProviderConfig -> FakeOIDCProviderConfig) -> IO ()) -fakeProvider = do - config <- defaultConfig - configRef <- IORef.newIORef config - let changeConfig = IORef.modifyIORef configRef - pure (fakeProvider' configRef, changeConfig) - -fakeProvider' :: IORef.IORef FakeOIDCProviderConfig -> Wai.Application -fakeProvider' configRef req respond = do - config <- IORef.readIORef configRef - case Wai.pathInfo req of - [".well-known", "openid-configuration"] -> - case Data.Text.Encoding.decodeUtf8 <$> Wai.requestHeaderHost req of - Nothing -> - Wai.responseLBS Status.badRequest400 [] "" - & respond - Just host -> - Metadata - { issuer = jwtIssuer config, - authorizationEndpoint = parseURI ("http://" <> host <> "/authorize"), - tokenEndpoint = parseURI ("http://" <> host <> "/token"), - userinfoEndpoint = Nothing, - revocationEndpoint = Nothing, - jwksUri = "http://" <> host <> "/jwks", - responseTypesSupported = ["code"], - subjectTypesSupported = ["public"], - idTokenSigningAlgValuesSupported = ["RS256"], - scopesSupported = Just ["openid"], - tokenEndpointAuthMethodsSupported = Just ["client_secret_basic"], - claimsSupported = Just ["iss", "sub", "aud", "exp", "iat"] - } - & Aeson.encode - & Wai.responseLBS Status.ok200 [("Content-Type", "application/json")] - & respond - ["jwks"] -> - JOSE.JWKSet [jwtJWK config] - & Aeson.encode - & Wai.responseLBS Status.ok200 [("Content-Type", "application/json")] - & respond - ["token"] -> do - now <- Clock.getCurrentTime - let claims = - JWT.emptyClaimsSet - & Lens.set JWT.claimIss (Just (fromString (T.unpack (jwtIssuer config)))) - & Lens.set JWT.claimAud (Just (JWT.Audience [jwtAudience config])) - & Lens.set JWT.claimIat (Just (JWT.NumericDate now)) - & Lens.set JWT.claimExp (Just (JWT.NumericDate (Clock.addUTCTime (jwtExpiresIn config) now))) - & Lens.set JWT.claimSub (Just (fromString (jwtSub config))) - idToken <- doJwtSign (jwtJWK config) claims - OA2.OAuth2Token - { OA2.accessToken = OA2.AccessToken "access-granted", - OA2.refreshToken = - if returnRefreshToken config - then Just (OA2.RefreshToken "refresh-token") - else Nothing, - OA2.expiresIn = Just 3600, - OA2.tokenType = Nothing, - OA2.idToken = - if returnIdToken config - then Just (OA2.IdToken idToken) - else Nothing - } - & Aeson.encode - & Wai.responseLBS Status.ok200 [("Content-Type", "application/json")] - & respond - _ -> - Wai.responseLBS Status.notFound404 [] "" - & respond - -doJwtSign :: JOSE.JWK -> JWT.ClaimsSet -> IO T.Text -doJwtSign jwk claims = do - result <- Control.Monad.Except.runExceptT $ do - alg <- JOSE.bestJWSAlg jwk - JWT.signClaims jwk (JOSE.newJWSHeader ((), alg)) claims - case result of - Left (err :: JOSE.Error) -> fail (show err) - Right bytestring -> - JOSE.encodeCompact bytestring - & Data.Text.Lazy.Encoding.decodeUtf8 - & Data.Text.Lazy.toStrict - & pure - -parseURI :: T.Text -> U.URIRef U.Absolute -parseURI uri = - Data.Text.Encoding.encodeUtf8 uri - & U.parseURI U.laxURIParserOptions - & either (error . show) id diff --git a/wai-middleware-auth.cabal b/wai-middleware-auth.cabal index 6cdc662..e4b4204 100644 --- a/wai-middleware-auth.cabal +++ b/wai-middleware-auth.cabal @@ -82,7 +82,8 @@ test-suite spec type: exitcode-stdio-1.0 main-is: Main.hs hs-source-dirs: test - other-modules: Spec.Network.Wai.Auth.Internal + other-modules: Network.Wai.Auth.Test + , Spec.Network.Wai.Auth.Internal , Spec.Network.Wai.Middleware.Auth.OAuth2 , Spec.Network.Wai.Middleware.Auth.OpenIDConnect build-depends: base From 97b651a859446d2b9b38421418426239c7064e52 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sun, 2 Feb 2020 11:46:48 +0000 Subject: [PATCH 09/23] Make oauth2 and oidc test suites more similar --- .../Network/Wai/Middleware/Auth/OAuth2.hs | 23 +++++--- .../Wai/Middleware/Auth/OpenIDConnect.hs | 54 ++++++++++++++----- 2 files changed, 56 insertions(+), 21 deletions(-) diff --git a/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs b/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs index bb1c28a..5298550 100644 --- a/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs +++ b/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs @@ -29,7 +29,7 @@ import qualified Web.Cookie as Cookie tests :: TestTree tests = testGroup "Network.Wai.Auth.OAuth2" - [ testCase "when a request without a session is made then the response redirects to the oauth2 authorize endpoint" $ + [ testCase "when a request without a session is made then redirect to re-authorize" $ runSessionWithProvider const200 $ \host _ -> do redirect1 <- get "/hi" assertStatus 303 redirect1 @@ -44,20 +44,27 @@ tests = testGroup "Network.Wai.Auth.OAuth2" (TE.encodeUtf8 host <> "/authorize?scope=scope1%2Cscope2&client_id=client-id&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%2Fprefix%2Foauth2%2Fcomplete") redirect3 - , testCase "when a request with an expired session is made then the response redirects to the oauth2 authorize endpoint" $ + , testCase "when a request is made with a valid session then pass the request through" $ + runSessionWithProvider const200 $ \_ _ -> do + createSession + response <- get "/some/endpoint" + assertStatus 200 response + + , testCase "when an access token expired and no refresh token is available then redirect to re-authorize" $ runSessionWithProvider const200 $ \_ changeProvider -> do changeProvider (\c -> c { accessTokenExpiresIn = -600, returnRefreshToken = False }) createSession response <- get "/some/endpoint" assertStatus 303 response - , testCase "when a request with a valid session is made then the middleware passes the request through" $ - runSessionWithProvider const200 $ \_ _ -> do + , testCase "when an access token expired then use a refresh token" $ + runSessionWithProvider const200 $ \_ changeProvider -> do + changeProvider (\c -> c { accessTokenExpiresIn = -600 }) createSession response <- get "/some/endpoint" assertStatus 200 response - , testCase "when a request with an invalid session is made then the response redirects to the oauth2 authorize endpoint" $ + , testCase "when a request is made with an invalid session redirect to re-authorize" $ runSessionWithProvider const200 $ \_ _ -> do -- First create a known valid session, so we can see that it's the act -- of corrupting it that makes the test fail. @@ -70,20 +77,20 @@ tests = testGroup "Network.Wai.Auth.OAuth2" response <- get "/some/endpoint" assertStatus 303 response - , testCase "when a request is made to the oauth2 complete endpoint then the middleware fatches an access token and sets a user sesion" $ + , testCase "when a request is made to the complete endpoint then create a session" $ runSessionWithProvider const200 $ \_ _ -> do response <- get "/prefix/oauth2/complete?code=1234" assertStatus 303 response assertHeader "location" "/" response - , testCase "when a request with a valid session is made then the application can access the session payload" $ + , testCase "when a request with a valid session is made then the app can access the session" $ let app req respond = case getAccessToken req of Nothing -> respond $ Wai.responseLBS Status.badRequest400 [] "" Just _ -> respond $ Wai.responseLBS Status.ok200 [] "" in runSessionWithProvider app $ \_ _ -> do createSession - response <- get "/prefix/oauth2/complete?code=1234" + response <- get "/some/endpoint" assertStatus 200 response ] diff --git a/test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs b/test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs index edd9113..c615983 100644 --- a/test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs +++ b/test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs @@ -9,6 +9,8 @@ import qualified Crypto.JOSE as JOSE import Data.Function ((&)) import qualified Data.Text as T import GHC.Exts (fromList, fromString) +import qualified Network.HTTP.Types.Status as Status +import qualified Network.Wai as Wai import Network.Wai.Auth.Test (ChangeProvider, FakeProviderConf(..), fakeProvider, @@ -30,7 +32,7 @@ import qualified Web.Cookie as Cookie tests :: TestTree tests = testGroup "Network.Wai.Auth.OpenIDConnect" [ testCase "when a request without a session is made then redirect to re-authorize" $ - runSessionWithProvider $ \host _ -> do + runSessionWithProvider const200 $ \host _ -> do redirect1 <- get "/hi" assertStatus 303 redirect1 assertHeader "Location" "/prefix" redirect1 @@ -45,20 +47,20 @@ tests = testGroup "Network.Wai.Auth.OpenIDConnect" redirect3 , testCase "when a request is made with a valid session then pass the request through" $ - runSessionWithProvider $ \_ _ -> do + runSessionWithProvider const200 $ \_ _ -> do createSession response <- get "/some/endpoint" assertStatus 200 response , testCase "when an ID token expired and no refresh token is available then redirect to re-authorize" $ - runSessionWithProvider $ \_ changeProvider -> do + runSessionWithProvider const200 $ \_ changeProvider -> do changeProvider (\c -> c { jwtExpiresIn = -600, returnRefreshToken = False }) createSession response <- get "/some/endpoint" assertStatus 303 response , testCase "when an ID token expired then use a refresh token" $ - runSessionWithProvider $ \_ changeProvider -> do + runSessionWithProvider const200 $ \_ changeProvider -> do changeProvider (\c -> c { jwtExpiresIn = -600 }) createSession changeProvider (\c -> c { jwtExpiresIn = 600 }) @@ -66,7 +68,7 @@ tests = testGroup "Network.Wai.Auth.OpenIDConnect" assertStatus 200 response , testCase "when a request is made with an invalid session redirect to re-authorize" $ - runSessionWithProvider $ \_ _ -> do + runSessionWithProvider const200 $ \_ _ -> do -- First create a known valid session, so we can see that it's the act -- of corrupting it that makes the test fail. createSession @@ -78,29 +80,55 @@ tests = testGroup "Network.Wai.Auth.OpenIDConnect" response <- get "/some/endpoint" assertStatus 303 response + , testCase "when a request is made to the complete endpoint then create a session" $ + runSessionWithProvider const200 $ \_ _ -> do + response <- get "/prefix/oidc/complete?code=1234" + assertStatus 303 response + assertHeader "location" "/" response + + , testCase "when a request with a valid session is made then the app can access the access token" $ + let app req respond = + case getAccessToken req of + Nothing -> respond $ Wai.responseLBS Status.badRequest400 [] "" + Just _ -> respond $ Wai.responseLBS Status.ok200 [] "" + in runSessionWithProvider app $ \_ _ -> do + createSession + response <- get "/some/endpoint" + assertStatus 200 response + + , testCase "when a request with a valid session is made then the app can access the id token" $ + let app req respond = + case getIdToken req of + Nothing -> respond $ Wai.responseLBS Status.badRequest400 [] "" + Just _ -> respond $ Wai.responseLBS Status.ok200 [] "" + in runSessionWithProvider app $ \_ _ -> do + createSession + response <- get "/some/endpoint" + assertStatus 200 response + , testCase "when an ID token has an invalid audience then redirect to re-authorize" $ - runSessionWithProvider $ \_ changeProvider -> do + runSessionWithProvider const200 $ \_ changeProvider -> do changeProvider (\c -> c { jwtAudience = fromString "wrong-audience" }) createSession response <- get "/some/endpoint" assertStatus 303 response , testCase "when an ID token has an invalid issuer then redirect to re-authorize" $ - runSessionWithProvider $ \_ changeProvider -> do + runSessionWithProvider const200 $ \_ changeProvider -> do changeProvider (\c -> c { jwtIssuer = "wrong-issuer" }) createSession response <- get "/some/endpoint" assertStatus 303 response , testCase "when a session does not contain an ID token then redirect to re-authorize" $ - runSessionWithProvider $ \_ changeProvider -> do + runSessionWithProvider const200 $ \_ changeProvider -> do changeProvider (\c -> c { returnIdToken = False }) createSession response <- get "/some/endpoint" assertStatus 303 response , testCase "when an ID token has an invalid signature then redirect to re-authorize" $ - runSessionWithProvider $ \_ changeProvider -> do + runSessionWithProvider const200 $ \_ changeProvider -> do newJWK <- liftIO $ JOSE.genJWK (JOSE.RSAGenParam 256) changeProvider (\c -> c { jwtJWK = newJWK }) createSession @@ -111,14 +139,14 @@ tests = testGroup "Network.Wai.Auth.OpenIDConnect" createSession :: Session () createSession = void $ get "/prefix/oidc/complete?code=1234" -runSessionWithProvider :: (U.URI -> ChangeProvider -> Session a) -> IO a -runSessionWithProvider session = do +runSessionWithProvider :: Wai.Application -> (U.URI -> ChangeProvider -> Session a) -> IO a +runSessionWithProvider app session = do (provider, changeProvider) <- fakeProvider Warp.testWithApplication (pure provider) $ \port -> do let host = parseURI $ "http://localhost:" <> T.pack (show port) middleware <- Auth.mkAuthMiddleware =<< authSettings host - let app = middleware const200 - runSession (session host changeProvider) app + let app' = middleware app + runSession (session host changeProvider) app' authSettings :: U.URI -> IO Auth.AuthSettings authSettings host = do From d62b4a9542c58c2c43135346d3b15a8377fcd6d5 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sun, 2 Feb 2020 12:02:32 +0000 Subject: [PATCH 10/23] Update CHANGELOG I'm not sure if this is the definitive version. We can change it if not. --- CHANGELOG.md | 21 ++++++++++++------- .../Wai/Middleware/Auth/OpenIDConnect.hs | 8 ++++--- src/Network/Wai/Middleware/Auth/Provider.hs | 2 +- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a28cb0d..23293bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,22 +1,27 @@ -# 0.2.3.0 -======== +0.2.3.0 +======= * Support `hoauth2-1.11.0` * Expose `decodeKey` +* OAuth2 provider remove a session when an access token expires. It will use a + refresh token if one is available to create a new session. If no refresh token + is available it will redirect the user to re-authenticate. +* Providers can define logic for refreshing a session without user intervention. +* Add an OpenID Connect provider. -# 0.2.2.0 -======== +0.2.2.0 +======= * Add request logging to executable * Newer multistage Docker build system -# 0.2.1.0 -======== +0.2.1.0 +======= * Fix a bug in deserialization of `UserIdentity` -# 0.2.0.0 -======== +0.2.0.0 +======= * Drop compatiblity with hoauth2 versions <= 1.0.0. * Add a function for getting the oauth2 token from an authenticated request. diff --git a/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs b/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs index b758001..055eafb 100644 --- a/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs +++ b/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs @@ -53,7 +53,7 @@ import qualified URI.ByteString as U -- | An Open ID Connect provider. -- --- @since X.Y.Z +-- @since 0.2.3.0 data OpenIDConnect = OpenIDConnect { oidcMetadata :: Metadata @@ -141,6 +141,8 @@ instance AuthProvider OpenIDConnect where pure (Just (storeClaims claims req, user)) -- | Obtain configuration of an OpenID Connect from its discovery endpoint. +-- +-- @since 0.2.3.0 discover :: U.URI -> IO OpenIDConnect discover base = do let uri = base { U.uriPath = "/.well-known/openid-configuration" } @@ -223,10 +225,10 @@ storeClaims claims req = -- | Get the @IdToken@ for the current user. -- --- If called on a @Request@ behind the middleware, should almost return a +-- If called on a @Request@ behind the middleware, should always return a -- @Just@ value. -- --- @since X.Y.Z +-- @since 0.2.3.0 getIdToken :: Request -> Maybe JWT.ClaimsSet getIdToken req = Vault.lookup idTokenKey (vault req) diff --git a/src/Network/Wai/Middleware/Auth/Provider.hs b/src/Network/Wai/Middleware/Auth/Provider.hs index ef8bb43..c4bbb70 100644 --- a/src/Network/Wai/Middleware/Auth/Provider.hs +++ b/src/Network/Wai/Middleware/Auth/Provider.hs @@ -111,7 +111,7 @@ class AuthProvider ap where -- -- The default implementation never invalidates a session once set. -- - -- @since X.Y.Z + -- @since 0.2.3.0 refreshLoginState :: ap -> Request From dc68949cd8015634e51b349ef3e1d2c368addd61 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sun, 2 Feb 2020 12:07:58 +0000 Subject: [PATCH 11/23] Move metadata into the `Internal` module It's not part of the public API. We could expose the type, but there's (at the moment) nothing you would be able to do with it. Better for it not to show up in documentation then. --- src/Network/Wai/Auth/Internal.hs | 25 +++++++++++++++++ .../Wai/Middleware/Auth/OpenIDConnect.hs | 28 ++----------------- test/Network/Wai/Auth/Test.hs | 1 + 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/src/Network/Wai/Auth/Internal.hs b/src/Network/Wai/Auth/Internal.hs index 0462505..cb105e7 100644 --- a/src/Network/Wai/Auth/Internal.hs +++ b/src/Network/Wai/Auth/Internal.hs @@ -1,15 +1,18 @@ {-# OPTIONS_HADDOCK hide, not-home #-} +{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TupleSections #-} module Network.Wai.Auth.Internal ( OAuth2TokenBinary(..) + , Metadata(..) , encodeToken , decodeToken , oauth2Login , refreshTokens ) where +import qualified Data.Aeson as Aeson import Data.Binary (Binary(get, put), encode, decodeOrFail) import qualified Data.ByteString as S @@ -19,6 +22,7 @@ import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8, decodeUtf8With) import Data.Text.Encoding.Error (lenientDecode) +import GHC.Generics (Generic) import Network.HTTP.Client (Manager) import Network.HTTP.Types (Status, status303, status403, status404, @@ -128,3 +132,24 @@ appendQueryParams uri params = getRedirectURI :: U.URIRef a -> S.ByteString getRedirectURI = U.serializeURIRef' + +data Metadata + = Metadata + { issuer :: T.Text + , authorizationEndpoint :: U.URI + , tokenEndpoint :: U.URI + , userinfoEndpoint :: Maybe T.Text + , revocationEndpoint :: Maybe T.Text + , jwksUri :: T.Text + , responseTypesSupported :: [T.Text] + , subjectTypesSupported :: [T.Text] + , idTokenSigningAlgValuesSupported :: [T.Text] + , scopesSupported :: Maybe [T.Text] + , tokenEndpointAuthMethodsSupported :: Maybe [T.Text] + , claimsSupported :: Maybe [T.Text] + } + deriving (Generic) + +instance Aeson.FromJSON Metadata + +instance Aeson.ToJSON Metadata diff --git a/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs b/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs index 055eafb..b7c7714 100644 --- a/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs +++ b/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs @@ -1,9 +1,7 @@ -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE OverloadedStrings #-} module Network.Wai.Middleware.Auth.OpenIDConnect ( OpenIDConnect - , Metadata(..) , discover , getAccessToken , getIdToken @@ -20,7 +18,6 @@ import Control.Applicative ((<|>)) import qualified Crypto.JOSE as JOSE import qualified Crypto.JWT as JWT import Control.Monad.Except (runExceptT) -import qualified Data.Aeson as Aeson import qualified Data.ByteString.Char8 as S8 import Data.Function ((&)) import qualified Data.Time.Clock as Clock @@ -30,7 +27,6 @@ import qualified Data.Text.Lazy as TL import qualified Data.Text.Lazy.Encoding as TLE import qualified Data.Vault.Lazy as Vault import Foreign.C.Types (CTime (..)) -import GHC.Generics (Generic) import qualified Lens.Micro as Lens import qualified Lens.Micro.Extras as Lens.Extras import Network.HTTP.Simple (httpJSON, @@ -42,7 +38,8 @@ import qualified Network.OAuth.OAuth2 as OA2 import Network.HTTP.Client (Manager) import Network.HTTP.Client.TLS (getGlobalManager) import Network.Wai (Request, vault) -import Network.Wai.Auth.Internal (decodeToken, encodeToken, +import Network.Wai.Auth.Internal (Metadata(..), + decodeToken, encodeToken, oauth2Login, refreshTokens) import Network.Wai.Middleware.Auth.Provider @@ -78,27 +75,6 @@ data OpenIDConnect , oidcAllowedSkew :: Clock.NominalDiffTime } -data Metadata - = Metadata - { issuer :: T.Text - , authorizationEndpoint :: U.URI - , tokenEndpoint :: U.URI - , userinfoEndpoint :: Maybe T.Text - , revocationEndpoint :: Maybe T.Text - , jwksUri :: T.Text - , responseTypesSupported :: [T.Text] - , subjectTypesSupported :: [T.Text] - , idTokenSigningAlgValuesSupported :: [T.Text] - , scopesSupported :: Maybe [T.Text] - , tokenEndpointAuthMethodsSupported :: Maybe [T.Text] - , claimsSupported :: Maybe [T.Text] - } - deriving (Generic) - -instance Aeson.FromJSON Metadata - -instance Aeson.ToJSON Metadata - instance AuthProvider OpenIDConnect where getProviderName _ = "oidc" getProviderInfo = oidcProviderInfo diff --git a/test/Network/Wai/Auth/Test.hs b/test/Network/Wai/Auth/Test.hs index baa870e..9df9e40 100644 --- a/test/Network/Wai/Auth/Test.hs +++ b/test/Network/Wai/Auth/Test.hs @@ -27,6 +27,7 @@ import GHC.Exts (fromString) import qualified Network.HTTP.Types.Status as Status import qualified Network.OAuth.OAuth2 as OA2 import qualified Network.Wai as Wai +import Network.Wai.Auth.Internal (Metadata(..)) import Network.Wai.Middleware.Auth.OpenIDConnect import Network.Wai.Test (Session, SResponse, defaultRequest, From 750896735790641e04c65501a28c81604e84d35f Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sun, 2 Feb 2020 12:11:41 +0000 Subject: [PATCH 12/23] Rename OpenIDConnect to OIDC It's a common shorthand for OpenIDConnect and makes for a shorter module name. --- .../Wai/Middleware/Auth/{OpenIDConnect.hs => OIDC.hs} | 2 +- test/Main.hs | 4 ++-- test/Network/Wai/Auth/Test.hs | 1 - .../Wai/Middleware/Auth/{OpenIDConnect.hs => OIDC.hs} | 6 +++--- wai-middleware-auth.cabal | 4 ++-- 5 files changed, 8 insertions(+), 9 deletions(-) rename src/Network/Wai/Middleware/Auth/{OpenIDConnect.hs => OIDC.hs} (99%) rename test/Spec/Network/Wai/Middleware/Auth/{OpenIDConnect.hs => OIDC.hs} (97%) diff --git a/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs b/src/Network/Wai/Middleware/Auth/OIDC.hs similarity index 99% rename from src/Network/Wai/Middleware/Auth/OpenIDConnect.hs rename to src/Network/Wai/Middleware/Auth/OIDC.hs index b7c7714..bcb773f 100644 --- a/src/Network/Wai/Middleware/Auth/OpenIDConnect.hs +++ b/src/Network/Wai/Middleware/Auth/OIDC.hs @@ -1,6 +1,6 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE OverloadedStrings #-} -module Network.Wai.Middleware.Auth.OpenIDConnect +module Network.Wai.Middleware.Auth.OIDC ( OpenIDConnect , discover , getAccessToken diff --git a/test/Main.hs b/test/Main.hs index a606390..6bd60a9 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -5,7 +5,7 @@ module Main (main) where import Test.Tasty import qualified Spec.Network.Wai.Auth.Internal import qualified Spec.Network.Wai.Middleware.Auth.OAuth2 -import qualified Spec.Network.Wai.Middleware.Auth.OpenIDConnect +import qualified Spec.Network.Wai.Middleware.Auth.OIDC main :: IO () main = defaultMain tests @@ -14,5 +14,5 @@ tests :: TestTree tests = testGroup "wai-middleware-auth" [ Spec.Network.Wai.Auth.Internal.tests , Spec.Network.Wai.Middleware.Auth.OAuth2.tests - , Spec.Network.Wai.Middleware.Auth.OpenIDConnect.tests + , Spec.Network.Wai.Middleware.Auth.OIDC.tests ] diff --git a/test/Network/Wai/Auth/Test.hs b/test/Network/Wai/Auth/Test.hs index 9df9e40..c63b5ca 100644 --- a/test/Network/Wai/Auth/Test.hs +++ b/test/Network/Wai/Auth/Test.hs @@ -28,7 +28,6 @@ import qualified Network.HTTP.Types.Status as Status import qualified Network.OAuth.OAuth2 as OA2 import qualified Network.Wai as Wai import Network.Wai.Auth.Internal (Metadata(..)) -import Network.Wai.Middleware.Auth.OpenIDConnect import Network.Wai.Test (Session, SResponse, defaultRequest, request, setPath) diff --git a/test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs b/test/Spec/Network/Wai/Middleware/Auth/OIDC.hs similarity index 97% rename from test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs rename to test/Spec/Network/Wai/Middleware/Auth/OIDC.hs index c615983..b0ac96f 100644 --- a/test/Spec/Network/Wai/Middleware/Auth/OpenIDConnect.hs +++ b/test/Spec/Network/Wai/Middleware/Auth/OIDC.hs @@ -1,7 +1,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} -module Spec.Network.Wai.Middleware.Auth.OpenIDConnect (tests) where +module Spec.Network.Wai.Middleware.Auth.OIDC (tests) where import Control.Monad (void) import Control.Monad.IO.Class (liftIO) @@ -18,7 +18,7 @@ import Network.Wai.Auth.Test (ChangeProvider, parseURI) import qualified Network.Wai.Handler.Warp as Warp import qualified Network.Wai.Middleware.Auth as Auth -import Network.Wai.Middleware.Auth.OpenIDConnect +import Network.Wai.Middleware.Auth.OIDC import Network.Wai.Middleware.Auth.Provider (Provider(..)) import Network.Wai.Test (Session, assertHeader, assertStatus, @@ -30,7 +30,7 @@ import qualified URI.ByteString as U import qualified Web.Cookie as Cookie tests :: TestTree -tests = testGroup "Network.Wai.Auth.OpenIDConnect" +tests = testGroup "Network.Wai.Auth.OIDC" [ testCase "when a request without a session is made then redirect to re-authorize" $ runSessionWithProvider const200 $ \host _ -> do redirect1 <- get "/hi" diff --git a/wai-middleware-auth.cabal b/wai-middleware-auth.cabal index e4b4204..241c9cc 100644 --- a/wai-middleware-auth.cabal +++ b/wai-middleware-auth.cabal @@ -16,7 +16,7 @@ library Network.Wai.Middleware.Auth.OAuth2 Network.Wai.Middleware.Auth.OAuth2.Github Network.Wai.Middleware.Auth.OAuth2.Google - Network.Wai.Middleware.Auth.OpenIDConnect + Network.Wai.Middleware.Auth.OIDC Network.Wai.Middleware.Auth.Provider Network.Wai.Auth.Executable Network.Wai.Auth.Internal @@ -85,7 +85,7 @@ test-suite spec other-modules: Network.Wai.Auth.Test , Spec.Network.Wai.Auth.Internal , Spec.Network.Wai.Middleware.Auth.OAuth2 - , Spec.Network.Wai.Middleware.Auth.OpenIDConnect + , Spec.Network.Wai.Middleware.Auth.OIDC build-depends: base , aeson , binary From 452bd5c3ef11d127137074052abd359aa7b49db4 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sun, 2 Feb 2020 13:09:31 +0000 Subject: [PATCH 13/23] Tweak documentation --- src/Network/Wai/Middleware/Auth/OIDC.hs | 63 ++++++++++++++++++++++--- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/src/Network/Wai/Middleware/Auth/OIDC.hs b/src/Network/Wai/Middleware/Auth/OIDC.hs index bcb773f..9d2416c 100644 --- a/src/Network/Wai/Middleware/Auth/OIDC.hs +++ b/src/Network/Wai/Middleware/Auth/OIDC.hs @@ -1,17 +1,25 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE OverloadedStrings #-} +-- | An OpenID connect provider. +-- +-- OpenID Connect is a simple identity layer on top of the OAuth2 protocol. +-- Learn more about it here: +-- +-- @since 0.2.3.0 module Network.Wai.Middleware.Auth.OIDC - ( OpenIDConnect + ( -- * Creating a provider + OpenIDConnect , discover - , getAccessToken - , getIdToken - -- * Customizing an OpenIDConnect provider + -- * Customizing a provider , oidcClientId , oidcClientSecret , oidcProviderInfo , oidcManager , oidcScopes , oidcAllowedSkew + -- * Accessing session data + , getAccessToken + , getIdToken ) where import Control.Applicative ((<|>)) @@ -50,6 +58,9 @@ import qualified URI.ByteString as U -- | An Open ID Connect provider. -- +-- To create a value use `discover` to download configuration for an existing +-- provider, then use various setter functions to customize it. +-- -- @since 0.2.3.0 data OpenIDConnect = OpenIDConnect @@ -58,20 +69,32 @@ data OpenIDConnect -- | The client id this application is registered with at the Open ID -- Connect provider. The default is an empty string, you will need to -- overwrite this. + -- + -- @since 0.2.3.0 , oidcClientId :: T.Text -- | The client secret of this application. The default is an empty -- string, you will need to overwrite this. + -- + -- @since 0.2.3.0 , oidcClientSecret :: T.Text -- | The information for this provider. The default contains some -- placeholder texts. If you're using the provider screen you'll want to -- overwrite this. + -- + -- @since 0.2.3.0 , oidcProviderInfo :: ProviderInfo -- | The HTTP manager to use. Defaults to the global manager. + -- + -- @since 0.2.3.0 , oidcManager :: Manager -- | The scopes to set. Defaults to only the "openid" scope. + -- + -- @since 0.2.3.0 , oidcScopes :: [T.Text] -- | The amount of clock skew to allow when validating id tokens. Defaults -- to 0. + -- + -- @since 0.2.3.0 , oidcAllowedSkew :: Clock.NominalDiffTime } @@ -116,7 +139,7 @@ instance AuthProvider OpenIDConnect where Just claims -> pure (Just (storeClaims claims req, user)) --- | Obtain configuration of an OpenID Connect from its discovery endpoint. +-- | Fetch configuration for a provider from its discovery endpoint. -- -- @since 0.2.3.0 discover :: U.URI -> IO OpenIDConnect @@ -169,10 +192,35 @@ validateIdToken' oidc tokens = Just idToken -> either (const Nothing) Just <$> validateIdToken oidc idToken +-- The validation of the ID token below is stricter then specified in the OIDC +-- spec, to make the job of validating tokens easier. If this is too limiting +-- for your user case please open an issue. +-- +-- Full spec for ID token validation: +-- https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation +-- +-- Ways in which the validation below is stricter then the spec requires: +-- - We don't allow the `aud` claim to contain any audiences beyond ourselves. validationSettings :: OpenIDConnect -> JWT.JWTValidationSettings validationSettings oidc = - JWT.defaultJWTValidationSettings (validateAudience oidc) + -- The Client MUST validate that the aud (audience) Claim contains its + -- client_id value registered at the Issuer identified by the iss (issuer) + -- Claim as an audience. The aud (audience) Claim MAY contain an array with + -- more than one element. The ID Token MUST be rejected if the ID Token does + -- not list the Client as a valid audience, or if it contains additional + -- audiences not trusted by the Client. + validateAudience oidc + -- If the ID Token is encrypted, decrypt it using the keys and algorithms + -- that the Client specified during Registration that the OP was to use to + -- encrypt the ID Token. If encryption was negotiated with the OP at + -- Registration time and the ID Token is not encrypted, the RP SHOULD + -- reject it. + & JWT.defaultJWTValidationSettings + -- The current time MUST be before the time represented by the exp Claim. & Lens.set JWT.jwtValidationSettingsCheckIssuedAt True + -- The Issuer Identifier for the OpenID Provider (which is typically + -- obtained during Discovery) MUST exactly match the value of the iss + -- (issuer) Claim. & Lens.set JWT.jwtValidationSettingsIssuerPredicate (validateIssuer oidc) & Lens.set JWT.jwtValidationSettingsAllowedSkew (oidcAllowedSkew oidc) @@ -204,6 +252,9 @@ storeClaims claims req = -- If called on a @Request@ behind the middleware, should always return a -- @Just@ value. -- +-- The token returned was validated when the request was processed by the +-- middleware. +-- -- @since 0.2.3.0 getIdToken :: Request -> Maybe JWT.ClaimsSet getIdToken req = Vault.lookup idTokenKey (vault req) From e8663ff9a43535a0c0d4a80da68a4a9c923e549a Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Sun, 2 Feb 2020 13:16:22 +0000 Subject: [PATCH 14/23] `discover` function takes a text That's one less library (uri-bytestring) a user needs needs to install and import to be able to use this middleware. --- src/Network/Wai/Middleware/Auth/OIDC.hs | 5 +++-- test/Network/Wai/Auth/Test.hs | 1 - test/Spec/Network/Wai/Middleware/Auth/OIDC.hs | 13 ++++++------- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/Network/Wai/Middleware/Auth/OIDC.hs b/src/Network/Wai/Middleware/Auth/OIDC.hs index 9d2416c..0f52f86 100644 --- a/src/Network/Wai/Middleware/Auth/OIDC.hs +++ b/src/Network/Wai/Middleware/Auth/OIDC.hs @@ -142,8 +142,9 @@ instance AuthProvider OpenIDConnect where -- | Fetch configuration for a provider from its discovery endpoint. -- -- @since 0.2.3.0 -discover :: U.URI -> IO OpenIDConnect -discover base = do +discover :: T.Text -> IO OpenIDConnect +discover urlText = do + base <- parseAbsoluteURI urlText let uri = base { U.uriPath = "/.well-known/openid-configuration" } metadata <- fetchMetadata uri jwkset <- fetchJWKSet (jwksUri metadata) diff --git a/test/Network/Wai/Auth/Test.hs b/test/Network/Wai/Auth/Test.hs index c63b5ca..f72b741 100644 --- a/test/Network/Wai/Auth/Test.hs +++ b/test/Network/Wai/Auth/Test.hs @@ -7,7 +7,6 @@ module Network.Wai.Auth.Test , fakeProvider , const200 , get - , parseURI ) where import Control.Monad.IO.Class (liftIO) diff --git a/test/Spec/Network/Wai/Middleware/Auth/OIDC.hs b/test/Spec/Network/Wai/Middleware/Auth/OIDC.hs index b0ac96f..6dda11b 100644 --- a/test/Spec/Network/Wai/Middleware/Auth/OIDC.hs +++ b/test/Spec/Network/Wai/Middleware/Auth/OIDC.hs @@ -8,14 +8,14 @@ import Control.Monad.IO.Class (liftIO) import qualified Crypto.JOSE as JOSE import Data.Function ((&)) import qualified Data.Text as T +import qualified Data.Text.Encoding as TE import GHC.Exts (fromList, fromString) import qualified Network.HTTP.Types.Status as Status import qualified Network.Wai as Wai import Network.Wai.Auth.Test (ChangeProvider, FakeProviderConf(..), fakeProvider, - const200, get, - parseURI) + const200, get) import qualified Network.Wai.Handler.Warp as Warp import qualified Network.Wai.Middleware.Auth as Auth import Network.Wai.Middleware.Auth.OIDC @@ -26,7 +26,6 @@ import Network.Wai.Test (Session, assertHeader, setClientCookie) import Test.Tasty (TestTree, testGroup) import Test.Tasty.HUnit (testCase) -import qualified URI.ByteString as U import qualified Web.Cookie as Cookie tests :: TestTree @@ -43,7 +42,7 @@ tests = testGroup "Network.Wai.Auth.OIDC" assertStatus 303 redirect3 assertHeader "location" - (U.serializeURIRef' host <> "/authorize?scope=openid%2Cscope1&client_id=client-id&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%2Fprefix%2Foidc%2Fcomplete") + (TE.encodeUtf8 host <> "/authorize?scope=openid%2Cscope1&client_id=client-id&response_type=code&redirect_uri=http%3A%2F%2Flocalhost%2Fprefix%2Foidc%2Fcomplete") redirect3 , testCase "when a request is made with a valid session then pass the request through" $ @@ -139,16 +138,16 @@ tests = testGroup "Network.Wai.Auth.OIDC" createSession :: Session () createSession = void $ get "/prefix/oidc/complete?code=1234" -runSessionWithProvider :: Wai.Application -> (U.URI -> ChangeProvider -> Session a) -> IO a +runSessionWithProvider :: Wai.Application -> (T.Text -> ChangeProvider -> Session a) -> IO a runSessionWithProvider app session = do (provider, changeProvider) <- fakeProvider Warp.testWithApplication (pure provider) $ \port -> do - let host = parseURI $ "http://localhost:" <> T.pack (show port) + let host = "http://localhost:" <> T.pack (show port) middleware <- Auth.mkAuthMiddleware =<< authSettings host let app' = middleware app runSession (session host changeProvider) app' -authSettings :: U.URI -> IO Auth.AuthSettings +authSettings :: T.Text -> IO Auth.AuthSettings authSettings host = do oidc' <- discover host let oidc = From 48639e5d96b7577abc45765b7880fef7f0d1fedc Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Thu, 27 Feb 2020 11:27:53 +0100 Subject: [PATCH 15/23] Support older versions of `jose` packages --- src/Network/Wai/Middleware/Auth/OIDC.hs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/Network/Wai/Middleware/Auth/OIDC.hs b/src/Network/Wai/Middleware/Auth/OIDC.hs index 0f52f86..6b7586d 100644 --- a/src/Network/Wai/Middleware/Auth/OIDC.hs +++ b/src/Network/Wai/Middleware/Auth/OIDC.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE OverloadedStrings #-} -- | An OpenID connect provider. @@ -241,9 +242,24 @@ validateIssuer oidc issClaim = fromStringOrURI :: JWT.StringOrURI -> Maybe T.Text fromStringOrURI stringOrURI = - Lens.Extras.preview JWT.string stringOrURI + fmap toText (Lens.Extras.preview JWT.string stringOrURI) <|> fmap (T.pack . show) (Lens.Extras.preview JWT.uri stringOrURI) +-- A small helper class for compatibility with different versions of the `jose` +-- library. Pre-0.8.x `JWT.string` produces a `String`. Post-0.8.x it produces a +-- `Text`. This type class allows us to support both. +-- +-- We can drop this once we no longer wish to support `jose` versions 0.7.x and +-- before. +class ToText a where + toText :: a -> T.Text + +instance ToText T.Text where + toText = id + +instance ToText [Char] where + toText = T.pack + storeClaims :: JWT.ClaimsSet -> Request -> Request storeClaims claims req = req { vault = Vault.insert idTokenKey claims (vault req) } From e40e3c7e90d025ab03ad29bac9b7847643719ac6 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Thu, 27 Feb 2020 18:21:02 +0100 Subject: [PATCH 16/23] Fix decoding of oidc discovery data --- src/Network/Wai/Auth/Internal.hs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/Network/Wai/Auth/Internal.hs b/src/Network/Wai/Auth/Internal.hs index cb105e7..90abf90 100644 --- a/src/Network/Wai/Auth/Internal.hs +++ b/src/Network/Wai/Auth/Internal.hs @@ -150,6 +150,15 @@ data Metadata } deriving (Generic) -instance Aeson.FromJSON Metadata +instance Aeson.FromJSON Metadata where + parseJSON = Aeson.genericParseJSON metadataAesonOptions -instance Aeson.ToJSON Metadata +instance Aeson.ToJSON Metadata where + + toJSON = Aeson.genericToJSON metadataAesonOptions + + toEncoding = Aeson.genericToEncoding metadataAesonOptions + +metadataAesonOptions :: Aeson.Options +metadataAesonOptions = + Aeson.defaultOptions {Aeson.fieldLabelModifier = Aeson.camelTo2 '_'} From 70806681d7eca7681b1acd6a56dcc7aff3d40846 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Fri, 15 May 2020 09:27:55 +0100 Subject: [PATCH 17/23] Drop support for jose 0.7.x --- src/Network/Wai/Middleware/Auth/OIDC.hs | 17 +---------------- wai-middleware-auth.cabal | 2 +- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/src/Network/Wai/Middleware/Auth/OIDC.hs b/src/Network/Wai/Middleware/Auth/OIDC.hs index 6b7586d..b2bdf8e 100644 --- a/src/Network/Wai/Middleware/Auth/OIDC.hs +++ b/src/Network/Wai/Middleware/Auth/OIDC.hs @@ -242,24 +242,9 @@ validateIssuer oidc issClaim = fromStringOrURI :: JWT.StringOrURI -> Maybe T.Text fromStringOrURI stringOrURI = - fmap toText (Lens.Extras.preview JWT.string stringOrURI) + Lens.Extras.preview JWT.string stringOrURI <|> fmap (T.pack . show) (Lens.Extras.preview JWT.uri stringOrURI) --- A small helper class for compatibility with different versions of the `jose` --- library. Pre-0.8.x `JWT.string` produces a `String`. Post-0.8.x it produces a --- `Text`. This type class allows us to support both. --- --- We can drop this once we no longer wish to support `jose` versions 0.7.x and --- before. -class ToText a where - toText :: a -> T.Text - -instance ToText T.Text where - toText = id - -instance ToText [Char] where - toText = T.pack - storeClaims :: JWT.ClaimsSet -> Request -> Request storeClaims claims req = req { vault = Vault.insert idTokenKey claims (vault req) } diff --git a/wai-middleware-auth.cabal b/wai-middleware-auth.cabal index 241c9cc..571ae2c 100644 --- a/wai-middleware-auth.cabal +++ b/wai-middleware-auth.cabal @@ -43,7 +43,7 @@ library , http-conduit , http-reverse-proxy , http-types - , jose + , jose >= 0.8.0 , microlens , mtl , regex-posix From 9bb03503834dcf5448ce4f9374af07baea200a8a Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Fri, 15 May 2020 10:02:49 +0100 Subject: [PATCH 18/23] Add FromJSON instance to `OpenIDConnect` type This allows it to be loaded from a configuration file --- src/Network/Wai/Middleware/Auth/OIDC.hs | 41 ++++++++++++++++++++----- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/src/Network/Wai/Middleware/Auth/OIDC.hs b/src/Network/Wai/Middleware/Auth/OIDC.hs index b2bdf8e..16c0cde 100644 --- a/src/Network/Wai/Middleware/Auth/OIDC.hs +++ b/src/Network/Wai/Middleware/Auth/OIDC.hs @@ -27,6 +27,8 @@ import Control.Applicative ((<|>)) import qualified Crypto.JOSE as JOSE import qualified Crypto.JWT as JWT import Control.Monad.Except (runExceptT) +import Data.Aeson (FromJSON(parseJSON), + withObject, (.:), (.!=)) import qualified Data.ByteString.Char8 as S8 import Data.Function ((&)) import qualified Data.Time.Clock as Clock @@ -84,10 +86,10 @@ data OpenIDConnect -- -- @since 0.2.3.0 , oidcProviderInfo :: ProviderInfo - -- | The HTTP manager to use. Defaults to the global manager. + -- | The HTTP manager to use. Defaults to the global manager when not set. -- -- @since 0.2.3.0 - , oidcManager :: Manager + , oidcManager :: Maybe Manager -- | The scopes to set. Defaults to only the "openid" scope. -- -- @since 0.2.3.0 @@ -99,14 +101,36 @@ data OpenIDConnect , oidcAllowedSkew :: Clock.NominalDiffTime } +instance FromJSON OpenIDConnect where + parseJSON = + withObject "OpenIDConnect Object" $ \obj -> do + metadata <- obj .: "metadata" + jwkSet <- obj .: "jwk_set" + clientId <- obj .: "client_id" + clientSecret <- obj .: "client_secret" + providerInfo <- obj .: "provider_info" .!= defProviderInfo + scopes <- obj .: "scopes" .!= ["openid"] + allowedSkew <- obj .: "allowed_skew" .!= 0 + pure OpenIDConnect { + oidcMetadata = metadata, + oidcJwkSet = jwkSet, + oidcClientId = clientId, + oidcClientSecret = clientSecret, + oidcProviderInfo = providerInfo, + oidcManager = Nothing, + oidcScopes = scopes, + oidcAllowedSkew = allowedSkew + } + instance AuthProvider OpenIDConnect where getProviderName _ = "oidc" getProviderInfo = oidcProviderInfo handleLogin oidc@OpenIDConnect {.. } req suffix renderUrl onSuccess onFailure = do oauth2 <- mkOauth2 oidc (Just renderUrl) + manager <- maybe getGlobalManager pure oidcManager oauth2Login oauth2 - oidcManager + manager (Just oidcScopes) (getProviderName oidc) req @@ -122,7 +146,8 @@ instance AuthProvider OpenIDConnect where case vRes of Nothing -> do oauth2 <- mkOauth2 oidc Nothing - rRes <- refreshTokens tokens (oidcManager oidc) oauth2 + manager <- maybe getGlobalManager pure (oidcManager oidc) + rRes <- refreshTokens tokens manager oauth2 case rRes of Nothing -> pure Nothing Just newTokens -> do @@ -149,18 +174,20 @@ discover urlText = do let uri = base { U.uriPath = "/.well-known/openid-configuration" } metadata <- fetchMetadata uri jwkset <- fetchJWKSet (jwksUri metadata) - manager <- getGlobalManager pure OpenIDConnect { oidcClientId = "" , oidcClientSecret = "" , oidcMetadata = metadata , oidcJwkSet = jwkset - , oidcProviderInfo = ProviderInfo "OpenID Connect Provider" "" "" - , oidcManager = manager + , oidcProviderInfo = defProviderInfo + , oidcManager = Nothing , oidcScopes = ["openid"] , oidcAllowedSkew = 0 } +defProviderInfo :: ProviderInfo +defProviderInfo = ProviderInfo "OpenID Connect Provider" "" "" + fetchMetadata :: U.URI -> IO Metadata fetchMetadata metadataEndpoint = do req <- parseRequestThrow (S8.unpack $ U.serializeURIRef' metadataEndpoint) From d2f645dfa2ae48941b25f51d065ec5e02dfb7d7c Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Fri, 15 May 2020 10:14:15 +0100 Subject: [PATCH 19/23] Move values closer to where they're used --- src/Network/Wai/Middleware/Auth/OAuth2.hs | 30 +++++++++++------------ 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/Network/Wai/Middleware/Auth/OAuth2.hs b/src/Network/Wai/Middleware/Auth/OAuth2.hs index ef62b7f..5c85ec8 100644 --- a/src/Network/Wai/Middleware/Auth/OAuth2.hs +++ b/src/Network/Wai/Middleware/Auth/OAuth2.hs @@ -100,27 +100,27 @@ instance AuthProvider OAuth2 where refreshLoginState OAuth2 {..} req user = do authEndpointURI <- parseAbsoluteURI oa2AuthorizeEndpoint accessTokenEndpointURI <- parseAbsoluteURI oa2AccessTokenEndpoint - let oauth2 = - OA2.OAuth2 - { oauthClientId = getClientId oa2ClientId - , oauthClientSecret = getClientSecret oa2ClientSecret - , oauthOAuthorizeEndpoint = authEndpointURI - , oauthAccessTokenEndpoint = accessTokenEndpointURI - -- Setting callback endpoint to `Nothing` below is a lie. - -- We do have a callback endpoint but in this context - -- don't have access to the function that can render it. - -- We get away with this because the callback endpoint is - -- not needed for obtaining a refresh token, the only - -- way we use the config here constructed. - , oauthCallback = Nothing - } - man <- getGlobalManager let loginState = authLoginState user case decodeToken loginState of Left _ -> pure Nothing Right tokens -> do CTime now <- epochTime if tokenExpired user now tokens then do + let oauth2 = + OA2.OAuth2 + { oauthClientId = getClientId oa2ClientId + , oauthClientSecret = getClientSecret oa2ClientSecret + , oauthOAuthorizeEndpoint = authEndpointURI + , oauthAccessTokenEndpoint = accessTokenEndpointURI + -- Setting callback endpoint to `Nothing` below is a lie. + -- We do have a callback endpoint but in this context + -- don't have access to the function that can render it. + -- We get away with this because the callback endpoint is + -- not needed for obtaining a refresh token, the only + -- way we use the config here constructed. + , oauthCallback = Nothing + } + man <- getGlobalManager rRes <- refreshTokens tokens man oauth2 case rRes of Nothing -> pure Nothing From b1f59072be4f411b0ff0381fb54063f3641c1425 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Fri, 15 May 2020 10:19:26 +0100 Subject: [PATCH 20/23] Replace case statement with `fmap` --- src/Network/Wai/Middleware/Auth/OAuth2.hs | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/Network/Wai/Middleware/Auth/OAuth2.hs b/src/Network/Wai/Middleware/Auth/OAuth2.hs index 5c85ec8..7b16927 100644 --- a/src/Network/Wai/Middleware/Auth/OAuth2.hs +++ b/src/Network/Wai/Middleware/Auth/OAuth2.hs @@ -14,6 +14,7 @@ import Control.Monad.Catch import Data.Aeson.TH (defaultOptions, deriveJSON, fieldLabelModifier) +import Data.Functor ((<&>)) import Data.Int (Int64) import Data.Proxy (Proxy (..)) import qualified Data.Text as T @@ -122,15 +123,10 @@ instance AuthProvider OAuth2 where } man <- getGlobalManager rRes <- refreshTokens tokens man oauth2 - case rRes of - Nothing -> pure Nothing - Just newTokens -> - let user' = - user { - authLoginState = encodeToken newTokens, - authLoginTime = fromIntegral now - } - in pure (Just (req, user')) + pure (rRes <&> \newTokens -> (req, user { + authLoginState = encodeToken newTokens, + authLoginTime = fromIntegral now + })) else pure (Just (req, user)) From cfef2093b2729b98a7bfd3d0a56e9e3072b3ad24 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Fri, 15 May 2020 10:55:13 +0100 Subject: [PATCH 21/23] Changelog change in supported `jose` versions --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 23293bf..d4b03ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ======= * Support `hoauth2-1.11.0` +* Drop support for `jose` versions < 0.8 * Expose `decodeKey` * OAuth2 provider remove a session when an access token expires. It will use a refresh token if one is available to create a new session. If no refresh token From 1d7ac3a41d61f3f8395ccdda565764cc7d0b7c09 Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Fri, 15 May 2020 11:10:18 +0100 Subject: [PATCH 22/23] Fix compilation issue from bad merge --- src/Network/Wai/Middleware/Auth/OAuth2.hs | 2 +- src/Network/Wai/Middleware/Auth/OIDC.hs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Network/Wai/Middleware/Auth/OAuth2.hs b/src/Network/Wai/Middleware/Auth/OAuth2.hs index 7b16927..9ea2b94 100644 --- a/src/Network/Wai/Middleware/Auth/OAuth2.hs +++ b/src/Network/Wai/Middleware/Auth/OAuth2.hs @@ -110,7 +110,7 @@ instance AuthProvider OAuth2 where let oauth2 = OA2.OAuth2 { oauthClientId = getClientId oa2ClientId - , oauthClientSecret = getClientSecret oa2ClientSecret + , oauthClientSecret = Just (getClientSecret oa2ClientSecret) , oauthOAuthorizeEndpoint = authEndpointURI , oauthAccessTokenEndpoint = accessTokenEndpointURI -- Setting callback endpoint to `Nothing` below is a lie. diff --git a/src/Network/Wai/Middleware/Auth/OIDC.hs b/src/Network/Wai/Middleware/Auth/OIDC.hs index 16c0cde..ec47343 100644 --- a/src/Network/Wai/Middleware/Auth/OIDC.hs +++ b/src/Network/Wai/Middleware/Auth/OIDC.hs @@ -203,7 +203,7 @@ mkOauth2 OpenIDConnect {..} renderUrl = do callbackURI <- for renderUrl $ \render -> parseAbsoluteURI $ render (ProviderUrl ["complete"]) [] pure OA2.OAuth2 { oauthClientId = oidcClientId - , oauthClientSecret = oidcClientSecret + , oauthClientSecret = Just oidcClientSecret , oauthOAuthorizeEndpoint = authorizationEndpoint oidcMetadata , oauthAccessTokenEndpoint = tokenEndpoint oidcMetadata , oauthCallback = callbackURI From afd1c007b40a4d6688021b16a8111cf91d17ad2f Mon Sep 17 00:00:00 2001 From: Jasper Woudenberg Date: Fri, 15 May 2020 11:57:24 +0100 Subject: [PATCH 23/23] Fix CI By default the `clientsession` library obtains an encryption key to use for client sessions from a file. When tests run in parallel they all try to use this file at once, leading to test flakiness. In this commit we override the default to use a generated in memory key instead. This should not have the same concurrency downsides. --- test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs | 2 ++ test/Spec/Network/Wai/Middleware/Auth/OIDC.hs | 2 ++ wai-middleware-auth.cabal | 1 + 3 files changed, 5 insertions(+) diff --git a/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs b/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs index 5298550..dcb800f 100644 --- a/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs +++ b/test/Spec/Network/Wai/Middleware/Auth/OAuth2.hs @@ -26,6 +26,7 @@ import Network.Wai.Test (Session, assertHeader, import Test.Tasty (TestTree, testGroup) import Test.Tasty.HUnit (testCase) import qualified Web.Cookie as Cookie +import qualified Web.ClientSession tests :: TestTree tests = testGroup "Network.Wai.Auth.OAuth2" @@ -103,6 +104,7 @@ authSettings host = & Auth.setAuthProviders (fromList [("oauth2", provider host)]) & Auth.setAuthPrefix "prefix" & Auth.setAuthCookieName "auth-cookie" + & Auth.setAuthKey (snd <$> Web.ClientSession.randomKey) provider :: T.Text -> Provider provider host = diff --git a/test/Spec/Network/Wai/Middleware/Auth/OIDC.hs b/test/Spec/Network/Wai/Middleware/Auth/OIDC.hs index 6dda11b..4280db6 100644 --- a/test/Spec/Network/Wai/Middleware/Auth/OIDC.hs +++ b/test/Spec/Network/Wai/Middleware/Auth/OIDC.hs @@ -27,6 +27,7 @@ import Network.Wai.Test (Session, assertHeader, import Test.Tasty (TestTree, testGroup) import Test.Tasty.HUnit (testCase) import qualified Web.Cookie as Cookie +import qualified Web.ClientSession tests :: TestTree tests = testGroup "Network.Wai.Auth.OIDC" @@ -160,3 +161,4 @@ authSettings host = do & Auth.setAuthProviders (fromList [("oidc", Provider oidc)]) & Auth.setAuthPrefix "prefix" & Auth.setAuthCookieName "auth-cookie" + & Auth.setAuthKey (snd <$> Web.ClientSession.randomKey) diff --git a/wai-middleware-auth.cabal b/wai-middleware-auth.cabal index 571ae2c..bb1e473 100644 --- a/wai-middleware-auth.cabal +++ b/wai-middleware-auth.cabal @@ -90,6 +90,7 @@ test-suite spec , aeson , binary , bytestring + , clientsession , cookie , hedgehog , hoauth2