diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 16ed432..3582acb 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -152,7 +152,7 @@ def _retrieve_login_form_from_endpoint(self, endpoint): login_form_html_node = self._parse_form_from_html(response.text) if login_form_html_node is None: raise SAMLError(self._ERROR_NO_FORM % endpoint) - form_action = urljoin(endpoint, + form_action = urljoin(response.url, login_form_html_node.attrib.get('action', '')) if not form_action.lower().startswith('https://'): raise SAMLError('Your SAML IdP must use HTTPS connection') diff --git a/tests/unit/test_saml.py b/tests/unit/test_saml.py index db2e218..92328c0 100644 --- a/tests/unit/test_saml.py +++ b/tests/unit/test_saml.py @@ -154,15 +154,34 @@ def test_non_https_url(self, generic_auth, mock_requests_session, # The error is raised after the call to get the form, but before the # call to submit it. mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=login_form + spec=requests.Response, status_code=200, text=login_form, + url='http://example.com' ) + with pytest.raises(SAMLError, match='HTTPS'): generic_auth.retrieve_saml_assertion(config) + def test_endpoint_form_redirect_url(self, generic_auth, generic_config, + login_form, mock_requests_session): + mock_requests_session.get.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=login_form, + url='https://test.com' + ) + mock_requests_session.post.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=( + '
' + ) + ) + generic_auth.retrieve_saml_assertion(generic_config) + url_used = mock_requests_session.post.call_args[0][0] + assert url_used == "https://test.com/login" + def test_form_action_appended_to_url(self, generic_auth, generic_config, login_form, mock_requests_session): mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=login_form + spec=requests.Response, status_code=200, text=login_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=( @@ -174,10 +193,37 @@ def test_form_action_appended_to_url(self, generic_auth, generic_config, url_used = mock_requests_session.post.call_args[0][0] assert url_used == "https://example.com/login" + def test_form_action_replaces_url(self, generic_auth, generic_config, + login_form, mock_requests_session): + saml_form = ( + '' + '' + '' + ) + + mock_requests_session.get.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=saml_form, + url='https://example.com' + ) + mock_requests_session.post.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=( + '' + ) + ) + generic_auth.retrieve_saml_assertion(generic_config) + url_used = mock_requests_session.post.call_args[0][0] + assert url_used == "https://www.test.com" + def test_extract_assertion(self, generic_auth, mock_requests_session, generic_config, login_form): mock_requests_session.get.return_value = mock.Mock( - spec=requests.Response, status_code=200, text=login_form + spec=requests.Response, status_code=200, text=login_form, + url='https://example.com' ) mock_requests_session.post.return_value = mock.Mock( spec=requests.Response, status_code=200, text=( @@ -206,7 +252,8 @@ def test_passes_in_other_form_fields(self, generic_auth, generic_config, '