diff --git a/app/Main.hs b/app/Main.hs index 0f25bf9..426d81c 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -3,6 +3,7 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TemplateHaskell #-} module Main where + import qualified Data.ByteString as S import Data.Serialize (put, runPut) import Network.Wai.Auth.Executable @@ -100,7 +101,9 @@ main = do then snd <$> randomKey else do keyContent <- S.readFile keyInput - either error return (decodeKey keyContent <|> initKey keyContent) + case decodeKey keyContent of + Left _ -> either error return (initKey keyContent) + Right key -> pure key if null keyOutput then S.putStr key else S.writeFile keyOutput key diff --git a/src/Network/Wai/Auth/Config.hs b/src/Network/Wai/Auth/Config.hs index c150aba..dcfa849 100644 --- a/src/Network/Wai/Auth/Config.hs +++ b/src/Network/Wai/Auth/Config.hs @@ -11,7 +11,7 @@ module Network.Wai.Auth.Config , decodeKey ) where -import Data.Aeson +import Data.Aeson hiding (Key) import Data.Aeson.TH (deriveJSON) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) diff --git a/src/Network/Wai/Auth/Internal.hs b/src/Network/Wai/Auth/Internal.hs index 488b45a..312ff0e 100644 --- a/src/Network/Wai/Auth/Internal.hs +++ b/src/Network/Wai/Auth/Internal.hs @@ -1,6 +1,5 @@ {-# OPTIONS_HADDOCK hide, not-home #-} {-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TupleSections #-} module Network.Wai.Auth.Internal @@ -12,6 +11,8 @@ module Network.Wai.Auth.Internal , refreshTokens ) where + +import Control.Monad.Except (runExceptT) import qualified Data.Aeson as Aeson import Data.Binary (Binary(get, put), encode, decodeOrFail) @@ -92,7 +93,7 @@ oauth2Login oauth2 man oa2Scope providerName req suffix onSuccess onFailure = let params = queryString req in case lookup "code" params of Just (Just code) -> do - eRes <- OA2.fetchAccessToken man oauth2 $ getExchangeToken code + eRes <- runExceptT $ OA2.fetchAccessToken man oauth2 $ getExchangeToken code case eRes of Left err -> onFailure status501 $ S8.pack $ show err Right token -> onSuccess $ encodeToken token @@ -119,7 +120,7 @@ refreshTokens tokens manager oauth2 = case OA2.refreshToken tokens of Nothing -> pure Nothing Just refreshToken -> do - res <- OA2.refreshAccessToken manager oauth2 refreshToken + res <- runExceptT $ OA2.refreshAccessToken manager oauth2 refreshToken case res of Left _ -> pure Nothing Right newTokens -> pure (Just newTokens) diff --git a/src/Network/Wai/Middleware/Auth/OAuth2.hs b/src/Network/Wai/Middleware/Auth/OAuth2.hs index 704b634..d9e914e 100644 --- a/src/Network/Wai/Middleware/Auth/OAuth2.hs +++ b/src/Network/Wai/Middleware/Auth/OAuth2.hs @@ -14,6 +14,7 @@ import Control.Monad.Catch import Data.Aeson.TH (defaultOptions, deriveJSON, fieldLabelModifier) +import Data.Default (Default(..)) import Data.Functor ((<&>)) import Data.Int (Int64) import Data.Proxy (Proxy (..)) @@ -83,12 +84,12 @@ instance AuthProvider OAuth2 where accessTokenEndpointURI <- parseAbsoluteURI oa2AccessTokenEndpoint callbackURI <- parseAbsoluteURI $ renderUrl (ProviderUrl ["complete"]) [] let oauth2 = - OA2.OAuth2 - { oauthClientId = getClientId oa2ClientId - , oauthClientSecret = Just $ getClientSecret oa2ClientSecret - , oauthOAuthorizeEndpoint = authEndpointURI - , oauthAccessTokenEndpoint = accessTokenEndpointURI - , oauthCallback = Just callbackURI + def + { OA2.oauth2ClientId = getClientId oa2ClientId + , OA2.oauth2ClientSecret = getClientSecret oa2ClientSecret + , OA2.oauth2AuthorizeEndpoint = authEndpointURI + , OA2.oauth2TokenEndpoint = accessTokenEndpointURI + , OA2.oauth2RedirectUri = callbackURI } man <- getGlobalManager oauth2Login @@ -110,18 +111,11 @@ instance AuthProvider OAuth2 where CTime now <- epochTime if tokenExpired user now tokens then do let oauth2 = - OA2.OAuth2 - { oauthClientId = getClientId oa2ClientId - , oauthClientSecret = Just (getClientSecret oa2ClientSecret) - , oauthOAuthorizeEndpoint = authEndpointURI - , oauthAccessTokenEndpoint = accessTokenEndpointURI - -- Setting callback endpoint to `Nothing` below is a lie. - -- We do have a callback endpoint but in this context - -- don't have access to the function that can render it. - -- We get away with this because the callback endpoint is - -- not needed for obtaining a refresh token, the only - -- way we use the config here constructed. - , oauthCallback = Nothing + def + { OA2.oauth2ClientId = getClientId oa2ClientId + , OA2.oauth2ClientSecret = getClientSecret oa2ClientSecret + , OA2.oauth2AuthorizeEndpoint = authEndpointURI + , OA2.oauth2TokenEndpoint = accessTokenEndpointURI } man <- getGlobalManager rRes <- refreshTokens tokens man oauth2 diff --git a/src/Network/Wai/Middleware/Auth/OIDC.hs b/src/Network/Wai/Middleware/Auth/OIDC.hs index 7bf6a2f..b045820 100644 --- a/src/Network/Wai/Middleware/Auth/OIDC.hs +++ b/src/Network/Wai/Middleware/Auth/OIDC.hs @@ -31,7 +31,9 @@ import Control.Monad.Except (runExceptT) import Data.Aeson (FromJSON(parseJSON), withObject, (.:), (.!=)) import qualified Data.ByteString.Char8 as S8 +import Data.Default (Default(..)) import Data.Function ((&)) +import Data.Maybe (fromMaybe) import qualified Data.Time.Clock as Clock import Data.Traversable (for) import qualified Data.Text as T @@ -210,12 +212,12 @@ fetchJWKSet jwkSetEndpoint = do mkOauth2 :: OpenIDConnect -> Maybe (Text.Hamlet.Render ProviderUrl) -> IO OA2.OAuth2 mkOauth2 OpenIDConnect {..} renderUrl = do callbackURI <- for renderUrl $ \render -> parseAbsoluteURI $ render (ProviderUrl ["complete"]) [] - pure OA2.OAuth2 - { oauthClientId = oidcClientId - , oauthClientSecret = Just oidcClientSecret - , oauthOAuthorizeEndpoint = authorizationEndpoint oidcMetadata - , oauthAccessTokenEndpoint = tokenEndpoint oidcMetadata - , oauthCallback = callbackURI + pure def + { OA2.oauth2ClientId = oidcClientId + , OA2.oauth2ClientSecret = oidcClientSecret + , OA2.oauth2AuthorizeEndpoint = authorizationEndpoint oidcMetadata + , OA2.oauth2TokenEndpoint = tokenEndpoint oidcMetadata + , OA2.oauth2RedirectUri = fromMaybe (OA2.oauth2RedirectUri def) callbackURI } validateIdToken :: OpenIDConnect -> OA2.IdToken -> IO (Either JWT.JWTError JWT.ClaimsSet) diff --git a/src/Network/Wai/Middleware/Auth/Provider.hs b/src/Network/Wai/Middleware/Auth/Provider.hs index c4bbb70..caba265 100644 --- a/src/Network/Wai/Middleware/Auth/Provider.hs +++ b/src/Network/Wai/Middleware/Auth/Provider.hs @@ -29,11 +29,10 @@ import Blaze.ByteString.Builder (toByteString) import Control.Arrow (second) import Data.Aeson (FromJSON (..), Object, Result (..), Value) -import Data.Aeson.Types (parseEither) - import Data.Aeson.TH (defaultOptions, deriveJSON, fieldLabelModifier) -import Data.Aeson.Types (Parser) +import Data.Aeson.Types (Parser, parseEither) +import Data.Aeson.KeyMap (toHashMapText) import Data.Binary (Binary) import qualified Data.ByteString as S import qualified Data.ByteString.Builder as B @@ -184,13 +183,14 @@ mkProviderParser _ = -- | Parse configuration for providers from an `Object`. parseProviders :: Object -> [ProviderParser] -> Result Providers -parseProviders unparsedProvidersHM providerParsers = +parseProviders unparsedProvidersO providerParsers = if HM.null unrecognized then sequence $ HM.intersectionWith parseProvider unparsedProvidersHM parsersHM else Error $ "Provider name(s) are not recognized: " ++ T.unpack (T.intercalate ", " $ HM.keys unrecognized) where + unparsedProvidersHM = toHashMapText unparsedProvidersO parsersHM = HM.fromList providerParsers unrecognized = HM.difference unparsedProvidersHM parsersHM parseProvider v p = either Error Success $ parseEither p v diff --git a/stack.yaml b/stack.yaml index b17acde..2d0dc93 100644 --- a/stack.yaml +++ b/stack.yaml @@ -1 +1 @@ -resolver: lts-17.12 +resolver: lts-22.4 diff --git a/stack.yaml.lock b/stack.yaml.lock index a0e9c1f..f71a69f 100644 --- a/stack.yaml.lock +++ b/stack.yaml.lock @@ -6,7 +6,7 @@ packages: [] snapshots: - completed: - size: 565712 - url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/17/6.yaml - sha256: 4e5e581a709c88e3fe26a9ce8bf331435729bead762fb5c190064c6c5bb1b835 - original: lts-17.6 + sha256: 8b211c5a6aad3787e023dfddaf7de7868968e4f240ecedf14ad1c5b2199046ca + size: 714097 + url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/22/4.yaml + original: lts-22.4 diff --git a/test/Network/Wai/Auth/Test.hs b/test/Network/Wai/Auth/Test.hs index f72b741..bf645f6 100644 --- a/test/Network/Wai/Auth/Test.hs +++ b/test/Network/Wai/Auth/Test.hs @@ -14,7 +14,6 @@ import Data.ByteString (ByteString) import qualified Data.IORef as IORef import qualified Crypto.JOSE as JOSE import qualified Crypto.JWT as JWT -import qualified Control.Monad.Except import qualified Data.Aeson as Aeson import Data.Function ((&)) import qualified Data.Text as T @@ -139,7 +138,7 @@ fakeProvider' configRef req respond = do doJwtSign :: JOSE.JWK -> JWT.ClaimsSet -> IO T.Text doJwtSign jwk claims = do - result <- Control.Monad.Except.runExceptT $ do + result <- JOSE.runJOSE $ do alg <- JOSE.bestJWSAlg jwk JWT.signClaims jwk (JOSE.newJWSHeader ((), alg)) claims case result of diff --git a/test/Spec/Network/Wai/Auth/Internal.hs b/test/Spec/Network/Wai/Auth/Internal.hs index afa0e93..2af8882 100644 --- a/test/Spec/Network/Wai/Auth/Internal.hs +++ b/test/Spec/Network/Wai/Auth/Internal.hs @@ -11,7 +11,7 @@ import Hedgehog import Hedgehog.Gen as Gen import Hedgehog.Range as Range import Network.Wai.Auth.Internal -import qualified Network.OAuth.OAuth2.Internal as OA2 +import qualified Network.OAuth.OAuth2 as OA2 tests :: TestTree tests = testGroup "Network.Wai.Auth.Internal" diff --git a/wai-middleware-auth.cabal b/wai-middleware-auth.cabal index b2d96f8..50dc1e2 100644 --- a/wai-middleware-auth.cabal +++ b/wai-middleware-auth.cabal @@ -37,6 +37,7 @@ library , cereal , clientsession , cookie >= 0.4.2 + , data-default , exceptions , hoauth2 >= 1.11 , http-client