Skip to content

Commit

Permalink
Merge pull request #16 from jwoudenberg/support-refresh-tokens
Browse files Browse the repository at this point in the history
Support refresh tokens
  • Loading branch information
jwoudenberg authored May 15, 2020
2 parents c9544ab + afd1c00 commit 235a5d6
Show file tree
Hide file tree
Showing 12 changed files with 1,073 additions and 127 deletions.
22 changes: 14 additions & 8 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
# 0.2.3.0
========
0.2.3.0
=======

* Support `hoauth2-1.11.0`
* Drop support for `jose` versions < 0.8
* Expose `decodeKey`
* OAuth2 provider remove a session when an access token expires. It will use a
refresh token if one is available to create a new session. If no refresh token
is available it will redirect the user to re-authenticate.
* Providers can define logic for refreshing a session without user intervention.
* Add an OpenID Connect provider.

# 0.2.2.0
========
0.2.2.0
=======

* Add request logging to executable
* Newer multistage Docker build system

# 0.2.1.0
========
0.2.1.0
=======

* Fix a bug in deserialization of `UserIdentity`

# 0.2.0.0
========
0.2.0.0
=======

* Drop compatiblity with hoauth2 versions <= 1.0.0.
* Add a function for getting the oauth2 token from an authenticated request.
Expand Down
123 changes: 123 additions & 0 deletions src/Network/Wai/Auth/Internal.hs
Original file line number Diff line number Diff line change
@@ -1,15 +1,38 @@
{-# OPTIONS_HADDOCK hide, not-home #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
module Network.Wai.Auth.Internal
( OAuth2TokenBinary(..)
, Metadata(..)
, encodeToken
, decodeToken
, oauth2Login
, refreshTokens
) where

import qualified Data.Aeson as Aeson
import Data.Binary (Binary(get, put), encode,
decodeOrFail)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8 (pack)
import qualified Data.ByteString.Lazy as SL
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8,
decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import GHC.Generics (Generic)
import Network.HTTP.Client (Manager)
import Network.HTTP.Types (Status, status303,
status403, status404,
status501)
import qualified Network.OAuth.OAuth2 as OA2
import Network.Wai (Request, Response,
queryString, responseLBS)
import Network.Wai.Middleware.Auth.Provider
import qualified URI.ByteString as U
import URI.ByteString (URI)

decodeToken :: S.ByteString -> Either String OA2.OAuth2Token
decodeToken bs =
Expand Down Expand Up @@ -39,3 +62,103 @@ instance Binary OAuth2TokenBinary where
idToken <- fmap OA2.IdToken <$> get
pure $ OAuth2TokenBinary $
OA2.OAuth2Token accessToken refreshToken expiresIn tokenType idToken

oauth2Login
:: OA2.OAuth2
-> Manager
-> Maybe [T.Text]
-> T.Text
-> Request
-> [T.Text]
-> (AuthLoginState -> IO Response)
-> (Status -> S.ByteString -> IO Response)
-> IO Response
oauth2Login oauth2 man oa2Scope providerName req suffix onSuccess onFailure =
case suffix of
[] -> do
let scope = (encodeUtf8 . T.intercalate ",") <$> oa2Scope
let redirectUrl =
getRedirectURI $
appendQueryParams
(OA2.authorizationUrl oauth2)
(maybe [] ((: []) . ("scope", )) scope)
return $
responseLBS
status303
[("Location", redirectUrl)]
"Redirect to OAuth2 Authentication server"
["complete"] ->
let params = queryString req
in case lookup "code" params of
Just (Just code) -> do
eRes <- OA2.fetchAccessToken man oauth2 $ getExchangeToken code
case eRes of
Left err -> onFailure status501 $ S8.pack $ show err
Right token -> onSuccess $ encodeToken token
_ ->
case lookup "error" params of
(Just (Just "access_denied")) ->
onFailure
status403
"User rejected access to the application."
(Just (Just error_code)) ->
onFailure status501 $ "Received an error: " <> error_code
(Just Nothing) ->
onFailure status501 $
"Unknown error connecting to " <>
encodeUtf8 providerName
Nothing ->
onFailure
status404
"Page not found. Please continue with login."
_ -> onFailure status404 "Page not found. Please continue with login."

refreshTokens :: OA2.OAuth2Token -> Manager -> OA2.OAuth2 -> IO (Maybe OA2.OAuth2Token)
refreshTokens tokens manager oauth2 =
case OA2.refreshToken tokens of
Nothing -> pure Nothing
Just refreshToken -> do
res <- OA2.refreshAccessToken manager oauth2 refreshToken
case res of
Left _ -> pure Nothing
Right newTokens -> pure (Just newTokens)

getExchangeToken :: S.ByteString -> OA2.ExchangeToken
getExchangeToken = OA2.ExchangeToken . decodeUtf8With lenientDecode

appendQueryParams :: URI -> [(S.ByteString, S.ByteString)] -> URI
appendQueryParams uri params =
OA2.appendQueryParams params uri

getRedirectURI :: U.URIRef a -> S.ByteString
getRedirectURI = U.serializeURIRef'

data Metadata
= Metadata
{ issuer :: T.Text
, authorizationEndpoint :: U.URI
, tokenEndpoint :: U.URI
, userinfoEndpoint :: Maybe T.Text
, revocationEndpoint :: Maybe T.Text
, jwksUri :: T.Text
, responseTypesSupported :: [T.Text]
, subjectTypesSupported :: [T.Text]
, idTokenSigningAlgValuesSupported :: [T.Text]
, scopesSupported :: Maybe [T.Text]
, tokenEndpointAuthMethodsSupported :: Maybe [T.Text]
, claimsSupported :: Maybe [T.Text]
}
deriving (Generic)

instance Aeson.FromJSON Metadata where
parseJSON = Aeson.genericParseJSON metadataAesonOptions

instance Aeson.ToJSON Metadata where

toJSON = Aeson.genericToJSON metadataAesonOptions

toEncoding = Aeson.genericToEncoding metadataAesonOptions

metadataAesonOptions :: Aeson.Options
metadataAesonOptions =
Aeson.defaultOptions {Aeson.fieldLabelModifier = Aeson.camelTo2 '_'}
28 changes: 25 additions & 3 deletions src/Network/Wai/Middleware/Auth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ import GHC.Generics (Generic)
import Network.HTTP.Types (Header, status200,
status303, status404,
status501)
import Network.Wai (Middleware, Request,
import Network.Wai (mapResponseHeaders,
Middleware, Request,
pathInfo, rawPathInfo,
rawQueryString,
responseBuilder,
Expand Down Expand Up @@ -273,8 +274,29 @@ mkAuthMiddleware AuthSettings {..} = do
authState <- loadCookieValue secretKey asStateKey req
case authState of
Just (AuthLoggedIn user) ->
let req' = req {vault = Vault.insert userKey user $ vault req}
in app req' respond
let providerName = decodeUtf8With lenientDecode (authProviderName user)
in case HM.lookup providerName asProviders of
Nothing ->
-- We can no longer find the provider the user originally
-- authenticated with, and as a result have no way to check if the
-- session is still valid. For backwards compatibility with older
-- versions of this library we'll assume the session remains valid.
let req' = req {vault = Vault.insert userKey user $ vault req}
in app req' respond
Just provider -> do
refreshResult <- refreshLoginState provider req user
case refreshResult of
Nothing ->
-- The session has expired, the user needs to re-authenticate.
enforceLogin "/" req respond
Just (req', user') ->
let req'' = req' {vault = Vault.insert userKey user' $ vault req'}
respond' response
| user' == user = respond response
| otherwise = do
cookieHeader <- saveAuthState (AuthLoggedIn user')
respond $ mapResponseHeaders (cookieHeader :) response
in app req'' respond'
Just (AuthNeedRedirect url) -> enforceLogin url req respond
Nothing -> enforceLogin "/" req respond

Expand Down
126 changes: 59 additions & 67 deletions src/Network/Wai/Middleware/Auth/OAuth2.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,23 @@ import Control.Monad.Catch
import Data.Aeson.TH (defaultOptions,
deriveJSON,
fieldLabelModifier)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8 (pack)
import Data.Functor ((<&>))
import Data.Int (Int64)
import Data.Proxy (Proxy (..))
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8,
decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Data.Text.Encoding (encodeUtf8)
import Foreign.C.Types (CTime (..))
import Network.HTTP.Client.TLS (getGlobalManager)
import Network.HTTP.Types (status303, status403,
status404, status501)
import qualified Network.OAuth.OAuth2 as OA2
import Network.Wai (Request, queryString,
responseLBS)
import Network.Wai.Auth.Internal (encodeToken, decodeToken)
import Network.Wai (Request)
import Network.Wai.Auth.Internal (decodeToken, encodeToken,
oauth2Login,
refreshTokens)
import Network.Wai.Auth.Tools (toLowerUnderscore)
import qualified Network.Wai.Middleware.Auth as MA
import Network.Wai.Middleware.Auth.Provider
import System.PosixCompat.Time (epochTime)
import qualified URI.ByteString as U
import URI.ByteString (URI)

-- | General OAuth2 authentication `Provider`.
data OAuth2 = OAuth2
Expand Down Expand Up @@ -62,25 +60,12 @@ parseAbsoluteURI urlTxt = do
Left err -> throwM $ URIParseException err
Right url -> return url

parseAbsoluteURI' :: MonadThrow m => T.Text -> m U.URI
parseAbsoluteURI' = parseAbsoluteURI

getExchangeToken :: S.ByteString -> OA2.ExchangeToken
getExchangeToken = OA2.ExchangeToken . decodeUtf8With lenientDecode

appendQueryParams :: URI -> [(S.ByteString, S.ByteString)] -> URI
appendQueryParams uri params =
OA2.appendQueryParams params uri

getClientId :: T.Text -> T.Text
getClientId = id

getClientSecret :: T.Text -> T.Text
getClientSecret = id

getRedirectURI :: U.URIRef a -> S.ByteString
getRedirectURI = U.serializeURIRef'

-- | Aeson parser for `OAuth2` provider.
--
-- @since 0.1.0
Expand All @@ -92,9 +77,9 @@ instance AuthProvider OAuth2 where
getProviderName _ = "oauth2"
getProviderInfo = oa2ProviderInfo
handleLogin oa2@OAuth2 {..} req suffix renderUrl onSuccess onFailure = do
authEndpointURI <- parseAbsoluteURI' oa2AuthorizeEndpoint
accessTokenEndpointURI <- parseAbsoluteURI' oa2AccessTokenEndpoint
callbackURI <- parseAbsoluteURI' $ renderUrl (ProviderUrl ["complete"]) []
authEndpointURI <- parseAbsoluteURI oa2AuthorizeEndpoint
accessTokenEndpointURI <- parseAbsoluteURI oa2AccessTokenEndpoint
callbackURI <- parseAbsoluteURI $ renderUrl (ProviderUrl ["complete"]) []
let oauth2 =
OA2.OAuth2
{ oauthClientId = getClientId oa2ClientId
Expand All @@ -103,46 +88,53 @@ instance AuthProvider OAuth2 where
, oauthAccessTokenEndpoint = accessTokenEndpointURI
, oauthCallback = Just callbackURI
}
case suffix of
[] -> do
let scope = (encodeUtf8 . T.intercalate ",") <$> oa2Scope
let redirectUrl =
getRedirectURI $
appendQueryParams
(OA2.authorizationUrl oauth2)
(maybe [] ((: []) . ("scope", )) scope)
return $
responseLBS
status303
[("Location", redirectUrl)]
"Redirect to OAuth2 Authentication server"
["complete"] ->
let params = queryString req
in case lookup "code" params of
Just (Just code) -> do
man <- getGlobalManager
eRes <- OA2.fetchAccessToken man oauth2 $ getExchangeToken code
case eRes of
Left err -> onFailure status501 $ S8.pack $ show err
Right token -> onSuccess $ encodeToken token
_ ->
case lookup "error" params of
(Just (Just "access_denied")) ->
onFailure
status403
"User rejected access to the application."
(Just (Just error_code)) ->
onFailure status501 $ "Received an error: " <> error_code
(Just Nothing) ->
onFailure status501 $
"Unknown error connecting to " <>
encodeUtf8 (getProviderName oa2)
Nothing ->
onFailure
status404
"Page not found. Please continue with login."
_ -> onFailure status404 "Page not found. Please continue with login."

man <- getGlobalManager
oauth2Login
oauth2
man
oa2Scope
(getProviderName oa2)
req
suffix
onSuccess
onFailure
refreshLoginState OAuth2 {..} req user = do
authEndpointURI <- parseAbsoluteURI oa2AuthorizeEndpoint
accessTokenEndpointURI <- parseAbsoluteURI oa2AccessTokenEndpoint
let loginState = authLoginState user
case decodeToken loginState of
Left _ -> pure Nothing
Right tokens -> do
CTime now <- epochTime
if tokenExpired user now tokens then do
let oauth2 =
OA2.OAuth2
{ oauthClientId = getClientId oa2ClientId
, oauthClientSecret = Just (getClientSecret oa2ClientSecret)
, oauthOAuthorizeEndpoint = authEndpointURI
, oauthAccessTokenEndpoint = accessTokenEndpointURI
-- Setting callback endpoint to `Nothing` below is a lie.
-- We do have a callback endpoint but in this context
-- don't have access to the function that can render it.
-- We get away with this because the callback endpoint is
-- not needed for obtaining a refresh token, the only
-- way we use the config here constructed.
, oauthCallback = Nothing
}
man <- getGlobalManager
rRes <- refreshTokens tokens man oauth2
pure (rRes <&> \newTokens -> (req, user {
authLoginState = encodeToken newTokens,
authLoginTime = fromIntegral now
}))
else
pure (Just (req, user))

tokenExpired :: AuthUser -> Int64 -> OA2.OAuth2Token -> Bool
tokenExpired user now tokens =
case OA2.expiresIn tokens of
Nothing -> False
Just expiresIn -> authLoginTime user + (fromIntegral expiresIn) < now

$(deriveJSON defaultOptions { fieldLabelModifier = toLowerUnderscore . drop 3} ''OAuth2)

Expand Down
Loading

0 comments on commit 235a5d6

Please sign in to comment.