diff --git a/src/simple_openid_connect/flows/authorization_code_flow/__init__.py b/src/simple_openid_connect/flows/authorization_code_flow/__init__.py index e225738..c618484 100644 --- a/src/simple_openid_connect/flows/authorization_code_flow/__init__.py +++ b/src/simple_openid_connect/flows/authorization_code_flow/__init__.py @@ -21,7 +21,7 @@ TokenRequest, TokenSuccessResponse, ) -from simple_openid_connect.exceptions import AuthenticationFailedError +from simple_openid_connect.exceptions import AuthenticationFailedError, ValidationError logger = logging.getLogger(__name__) @@ -31,6 +31,9 @@ def start_authentication( scope: str, client_id: str, redirect_uri: str, + state: Optional[str] = None, + nonce: Optional[str] = None, + prompt: Optional[list[str]] = None, code_challenge: Optional[str] = None, code_challenge_method: Optional[str] = None, ) -> str: @@ -38,6 +41,11 @@ def start_authentication( Start the authentication process by constructing an appropriate :class:`AuthenticationRequest`, serializing it and returning a which the end user now needs to visit. + :param state: The state intended to prevent Cross-Site Request Forgery. + :param nonce: String value used to associate a Client session with an ID Token, and to mitigate replay attacks. + :param prompt: Specifies whether the Authorization Server prompts the End-User for reauthentication and consent. + The defined values are: "none", "login", "consent" and "select_account", multiple may be given as a list. + :returns: A URL to which the user agent should be redirected """ request = AuthenticationRequest( @@ -45,6 +53,9 @@ def start_authentication( client_id=client_id, redirect_uri=redirect_uri, response_type="code", + state=state, + nonce=nonce, + prompt=prompt, code_challenge=code_challenge, code_challenge_method=code_challenge_method, ) @@ -56,6 +67,7 @@ def handle_authentication_result( token_endpoint: str, client_authentication: ClientAuthenticationMethod, redirect_uri: Union[Literal["auto"], str] = "auto", + state: Optional[str] = None, code_verifier: Optional[str] = None, code_challenge: Optional[str] = None, code_challenge_method: Optional[str] = None, @@ -70,8 +82,10 @@ def handle_authentication_result( :param client_authentication: A way for the client to authenticate itself :param redirect_uri: The `redirect_uri` that was specified during the authentication initiation. If the special value `auto` is used, it is assumed that `current_url` is the that callback and it is stripped of query parameters and fragments to reproduce the originally supplied one. + :param state: The `state` that was specified during the authentication initiation. :raises AuthenticationFailedError: If the current url indicates an authentication failure that prevents an access token from being retrieved. + :raises ValidationError: If the returned state does not match the given state. :returns: The result of the token exchange """ @@ -92,6 +106,10 @@ def handle_authentication_result( ) auth_response_msg = AuthenticationSuccessResponse.parse_url(str(current_furl)) + + if state != auth_response_msg.state: + raise ValidationError("Returned state does not match given state.") + return exchange_code_for_tokens( token_endpoint=token_endpoint, authentication_response=auth_response_msg, diff --git a/src/simple_openid_connect/flows/authorization_code_flow/client.py b/src/simple_openid_connect/flows/authorization_code_flow/client.py index 8d197fb..167d265 100644 --- a/src/simple_openid_connect/flows/authorization_code_flow/client.py +++ b/src/simple_openid_connect/flows/authorization_code_flow/client.py @@ -29,6 +29,9 @@ def __init__(self, base_client: "OpenidClient"): def start_authentication( self, + state: Optional[str] = None, + nonce: Optional[str] = None, + prompt: Optional[list[str]] = None, code_challenge: Optional[str] = None, code_challenge_method: Optional[str] = None, ) -> str: @@ -36,6 +39,10 @@ def start_authentication( Start the authentication process by constructing an appropriate :class:`AuthenticationRequest`, serializing it and returning a which the end user now needs to visit. + :param state: The state intended to prevent Cross-Site Request Forgery. + :param nonce: String value used to associate a Client session with an ID Token, and to mitigate replay attacks. + :param prompt: Specifies whether the Authorization Server prompts the End-User for reauthentication and consent. + The defined values are: "none", "login", "consent" and "select_account", multiple may be given as a list. :param code_challenge: The code challenge intended for use with Proof Key for Code Exchange (PKCE) [RFC7636]. :param code_challenge_method: The code challenge method intended for use with Proof Key for Code Exchange (PKCE) [RFC7636], typically "S256" or "plain". @@ -55,6 +62,9 @@ def start_authentication( self._base_client.scope, self._base_client.client_auth.client_id, redirect_uri.tostr(), + state=state, + nonce=nonce, + prompt=prompt, code_challenge=code_challenge, code_challenge_method=code_challenge_method, ) @@ -63,6 +73,7 @@ def handle_authentication_result( self, current_url: str, additional_redirect_args: Optional[Mapping[str, str]] = None, + state: Optional[str] = None, code_verifier: Optional[str] = None, code_challenge: Optional[str] = None, code_challenge_method: Optional[str] = None, @@ -74,6 +85,7 @@ def handle_authentication_result( The authentication result should be encoded into this url by the authorization server. :param additional_redirect_args: Additional URL parameters that were added to the redirect uri. They are probably still present in `current_url` but since they could be of any shape, no attempt is made here to automatically reconstruct them. + :param state: The `state` that was specified during the authentication initiation. :param code_verifier: The code verifier intended for use with Proof Key for Code Exchange (PKCE) [RFC7636]. :param code_challenge: The code challenge intended for use with Proof Key for Code Exchange (PKCE) [RFC7636]. :param code_challenge_method: The code challenge method intended for use with Proof Key for Code Exchange (PKCE) [RFC7636], typically "S256" or "plain". @@ -103,6 +115,7 @@ def handle_authentication_result( token_endpoint=self._base_client.provider_config.token_endpoint, client_authentication=self._base_client.client_auth, redirect_uri=redirect_uri.tostr(), + state=state, code_verifier=code_verifier, code_challenge=code_challenge, code_challenge_method=code_challenge_method,