From 384a5ac6a75a8daf4c6cd2f2fdbcfa9c5737dfb5 Mon Sep 17 00:00:00 2001 From: Eric Barry Date: Mon, 29 Jul 2019 11:55:04 -0400 Subject: [PATCH 1/2] #37 Fixing incorrect endpoint for posting the form data. --- awsprocesscreds/saml.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 16ed432..3582acb 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -152,7 +152,7 @@ def _retrieve_login_form_from_endpoint(self, endpoint): login_form_html_node = self._parse_form_from_html(response.text) if login_form_html_node is None: raise SAMLError(self._ERROR_NO_FORM % endpoint) - form_action = urljoin(endpoint, + form_action = urljoin(response.url, login_form_html_node.attrib.get('action', '')) if not form_action.lower().startswith('https://'): raise SAMLError('Your SAML IdP must use HTTPS connection') From 46d6b166bf8ecdde85fa21c5685a7c985af895e2 Mon Sep 17 00:00:00 2001 From: Eric Barry Date: Thu, 1 Aug 2019 21:34:50 -0400 Subject: [PATCH 2/2] #37 - fixing tests and adding in two new ones to deal with action posting and landing page redirects. --- tests/unit/test_saml.py | 79 ++++++++++++++++++++++++++++++++++------- 1 file changed, 67 insertions(+), 12 deletions(-) diff --git a/tests/unit/test_saml.py b/tests/unit/test_saml.py index db2e218..92328c0 100644 --- a/tests/unit/test_saml.py +++ b/tests/unit/test_saml.py @@ -154,15 +154,34 @@ def test_non_https_url(self, generic_auth, mock_requests_session, # The error is raised after the call to get the form, but before the # call to submit it. mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=login_form + spec=requests.Response, status_code=200, text=login_form, + url='http://example.com' ) + with pytest.raises(SAMLError, match='HTTPS'): generic_auth.retrieve_saml_assertion(config) + def test_endpoint_form_redirect_url(self, generic_auth, generic_config, + login_form, mock_requests_session): + mock_requests_session.get.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=login_form, + url='https://test.com' + ) + mock_requests_session.post.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=( + '
' + ) + ) + generic_auth.retrieve_saml_assertion(generic_config) + url_used = mock_requests_session.post.call_args[0][0] + assert url_used == "https://test.com/login" + def test_form_action_appended_to_url(self, generic_auth, generic_config, login_form, mock_requests_session): mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=login_form + spec=requests.Response, status_code=200, text=login_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=( @@ -174,10 +193,37 @@ def test_form_action_appended_to_url(self, generic_auth, generic_config, url_used = mock_requests_session.post.call_args[0][0] assert url_used == "https://example.com/login" + def test_form_action_replaces_url(self, generic_auth, generic_config, + login_form, mock_requests_session): + saml_form = ( + '' + '
' + '' + '' + '' + '
' + '' + ) + + mock_requests_session.get.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=saml_form, + url='https://example.com' + ) + mock_requests_session.post.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=( + '
' + ) + ) + generic_auth.retrieve_saml_assertion(generic_config) + url_used = mock_requests_session.post.call_args[0][0] + assert url_used == "https://www.test.com" + def test_extract_assertion(self, generic_auth, mock_requests_session, generic_config, login_form): mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=login_form + spec=requests.Response, status_code=200, text=login_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=( @@ -206,7 +252,8 @@ def test_passes_in_other_form_fields(self, generic_auth, generic_config, '' ) mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=saml_form + spec=requests.Response, status_code=200, text=saml_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=( @@ -238,7 +285,8 @@ def tests_uses_default_form_values(self, generic_auth, generic_config, '' ) mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=saml_form + spec=requests.Response, status_code=200, text=saml_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=( @@ -275,7 +323,8 @@ def test_missing_form_username(self, generic_auth, mock_requests_session, '' ) mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=missing_form_fields + spec=requests.Response, status_code=200, text=missing_form_fields, + url='https://example.com' ) with pytest.raises(SAMLError, match='could not find'): generic_auth.retrieve_saml_assertion(generic_config) @@ -288,7 +337,8 @@ def test_missing_form_password(self, generic_auth, mock_requests_session, '' ) mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=missing_form_fields + spec=requests.Response, status_code=200, text=missing_form_fields, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=( @@ -313,7 +363,8 @@ def test_missing_form_password(self, generic_auth, mock_requests_session, def test_empty_assertion(self, generic_auth, mock_requests_session, login_form, generic_config, assertion_response): mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=login_form + spec=requests.Response, status_code=200, text=login_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=assertion_response @@ -324,7 +375,8 @@ def test_empty_assertion(self, generic_auth, mock_requests_session, def test_non_200_authenticate_response(self, generic_auth, generic_config, mock_requests_session, login_form): mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, text=login_form, status_code=200 + spec=requests.Response, text=login_form, status_code=200, + url='https://example.com' ) # This 401 response represents an auth failure, such as a bad password. @@ -340,7 +392,8 @@ def test_non_200_authenticate_response(self, generic_auth, generic_config, def test_no_saml_assertion_in_response(self, generic_auth, generic_config, mock_requests_session, login_form): mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, text=login_form, status_code=200 + spec=requests.Response, text=login_form, status_code=200, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, text='login failed', @@ -380,7 +433,8 @@ def test_authn_requests_made(self, okta_auth, okta_config, mock_requests_session.get.return_value = mock.Mock( text=('
'), - status_code=200 + status_code=200, + url='https://example.com' ) saml_assertion = okta_auth.retrieve_saml_assertion(okta_config) assert saml_assertion == 'fakeassertion' @@ -434,7 +488,8 @@ def test_uses_adfs_fields(self, adfs_auth, mock_requests_session, '' ) mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=adfs_login_form + spec=requests.Response, status_code=200, text=adfs_login_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=(