From d49cf26a040d85a9aba5c77045d31b8554f3469e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 11 Feb 2025 03:01:45 +0000 Subject: [PATCH 1/3] Bump ruby/setup-ruby from 1.215.0 to 1.218.0 (#8584) --- .github/workflows/tests_sdk_ruby.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests_sdk_ruby.yml b/.github/workflows/tests_sdk_ruby.yml index 17c303ace1bf..6795c7062af2 100644 --- a/.github/workflows/tests_sdk_ruby.yml +++ b/.github/workflows/tests_sdk_ruby.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Ruby ${{ matrix.ruby-version }} - uses: ruby/setup-ruby@2654679fe7f7c29875c669398a8ec0791b8a64a1 + uses: ruby/setup-ruby@d781c1b4ed31764801bfae177617bb0446f5ef8d with: ruby-version: ${{ matrix.ruby-version }} - name: Set up Python 3.8 From 7804781ac6743d5ed14fe4a892e0f25ebe1ff778 Mon Sep 17 00:00:00 2001 From: Aman Date: Tue, 11 Feb 2025 16:06:06 -0500 Subject: [PATCH 2/3] Security Hub - get_findings and batch_import_findings (#8518) --- IMPLEMENTATION_COVERAGE.md | 86 +++++++++++- docs/docs/services/securityhub.rst | 112 +++++++++++++++ moto/backend_index.py | 1 + moto/securityhub/__init__.py | 1 + moto/securityhub/exceptions.py | 21 +++ moto/securityhub/models.py | 122 ++++++++++++++++ moto/securityhub/responses.py | 61 ++++++++ moto/securityhub/urls.py | 12 ++ tests/test_securityhub/__init__.py | 0 tests/test_securityhub/test_securityhub.py | 156 +++++++++++++++++++++ 10 files changed, 571 insertions(+), 1 deletion(-) create mode 100644 docs/docs/services/securityhub.rst create mode 100644 moto/securityhub/__init__.py create mode 100644 moto/securityhub/exceptions.py create mode 100644 moto/securityhub/models.py create mode 100644 moto/securityhub/responses.py create mode 100644 moto/securityhub/urls.py create mode 100644 tests/test_securityhub/__init__.py create mode 100644 tests/test_securityhub/test_securityhub.py diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index be482c7dcf99..68d5289fc598 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -8232,6 +8232,91 @@ - [ ] validate_resource_policy +## securityhub +
+2% implemented + +- [ ] accept_administrator_invitation +- [ ] accept_invitation +- [ ] batch_delete_automation_rules +- [ ] batch_disable_standards +- [ ] batch_enable_standards +- [ ] batch_get_automation_rules +- [ ] batch_get_configuration_policy_associations +- [ ] batch_get_security_controls +- [ ] batch_get_standards_control_associations +- [X] batch_import_findings +- [ ] batch_update_automation_rules +- [ ] batch_update_findings +- [ ] batch_update_standards_control_associations +- [ ] create_action_target +- [ ] create_automation_rule +- [ ] create_configuration_policy +- [ ] create_finding_aggregator +- [ ] create_insight +- [ ] create_members +- [ ] decline_invitations +- [ ] delete_action_target +- [ ] delete_configuration_policy +- [ ] delete_finding_aggregator +- [ ] delete_insight +- [ ] delete_invitations +- [ ] delete_members +- [ ] describe_action_targets +- [ ] describe_hub +- [ ] describe_organization_configuration +- [ ] describe_products +- [ ] describe_standards +- [ ] describe_standards_controls +- [ ] disable_import_findings_for_product +- [ ] disable_organization_admin_account +- [ ] disable_security_hub +- [ ] disassociate_from_administrator_account +- [ ] disassociate_from_master_account +- [ ] disassociate_members +- [ ] enable_import_findings_for_product +- [ ] enable_organization_admin_account +- [ ] enable_security_hub +- [ ] get_administrator_account +- [ ] get_configuration_policy +- [ ] get_configuration_policy_association +- [ ] get_enabled_standards +- [ ] get_finding_aggregator +- [ ] get_finding_history +- [X] get_findings +- [ ] get_insight_results +- [ ] get_insights +- [ ] get_invitations_count +- [ ] get_master_account +- [ ] get_members +- [ ] get_security_control_definition +- [ ] invite_members +- [ ] list_automation_rules +- [ ] list_configuration_policies +- [ ] list_configuration_policy_associations +- [ ] list_enabled_products_for_import +- [ ] list_finding_aggregators +- [ ] list_invitations +- [ ] list_members +- [ ] list_organization_admin_accounts +- [ ] list_security_control_definitions +- [ ] list_standards_control_associations +- [ ] list_tags_for_resource +- [ ] start_configuration_policy_association +- [ ] start_configuration_policy_disassociation +- [ ] tag_resource +- [ ] untag_resource +- [ ] update_action_target +- [ ] update_configuration_policy +- [ ] update_finding_aggregator +- [ ] update_findings +- [ ] update_insight +- [ ] update_organization_configuration +- [ ] update_security_control +- [ ] update_security_hub_configuration +- [ ] update_standards_control +
+ ## service-quotas
10% implemented @@ -9624,7 +9709,6 @@ - savingsplans - schemas - security-ir -- securityhub - securitylake - serverlessrepo - servicecatalog diff --git a/docs/docs/services/securityhub.rst b/docs/docs/services/securityhub.rst new file mode 100644 index 000000000000..a395350a148b --- /dev/null +++ b/docs/docs/services/securityhub.rst @@ -0,0 +1,112 @@ +.. _implementedservice_securityhub: + +.. |start-h3| raw:: html + +

+ +.. |end-h3| raw:: html + +

+ +=========== +securityhub +=========== + +.. autoclass:: moto.securityhub.models.SecurityHubBackend + +|start-h3| Implemented features for this service |end-h3| + +- [ ] accept_administrator_invitation +- [ ] accept_invitation +- [ ] batch_delete_automation_rules +- [ ] batch_disable_standards +- [ ] batch_enable_standards +- [ ] batch_get_automation_rules +- [ ] batch_get_configuration_policy_associations +- [ ] batch_get_security_controls +- [ ] batch_get_standards_control_associations +- [X] batch_import_findings + + Import findings in batch to SecurityHub. + + Args: + findings: List of finding dictionaries to import + + Returns: + Tuple of (failed_count, success_count, failed_findings) + + +- [ ] batch_update_automation_rules +- [ ] batch_update_findings +- [ ] batch_update_standards_control_associations +- [ ] create_action_target +- [ ] create_automation_rule +- [ ] create_configuration_policy +- [ ] create_finding_aggregator +- [ ] create_insight +- [ ] create_members +- [ ] decline_invitations +- [ ] delete_action_target +- [ ] delete_configuration_policy +- [ ] delete_finding_aggregator +- [ ] delete_insight +- [ ] delete_invitations +- [ ] delete_members +- [ ] describe_action_targets +- [ ] describe_hub +- [ ] describe_organization_configuration +- [ ] describe_products +- [ ] describe_standards +- [ ] describe_standards_controls +- [ ] disable_import_findings_for_product +- [ ] disable_organization_admin_account +- [ ] disable_security_hub +- [ ] disassociate_from_administrator_account +- [ ] disassociate_from_master_account +- [ ] disassociate_members +- [ ] enable_import_findings_for_product +- [ ] enable_organization_admin_account +- [ ] enable_security_hub +- [ ] get_administrator_account +- [ ] get_configuration_policy +- [ ] get_configuration_policy_association +- [ ] get_enabled_standards +- [ ] get_finding_aggregator +- [ ] get_finding_history +- [X] get_findings + + Returns findings based on optional filters and sort criteria. + + +- [ ] get_insight_results +- [ ] get_insights +- [ ] get_invitations_count +- [ ] get_master_account +- [ ] get_members +- [ ] get_security_control_definition +- [ ] invite_members +- [ ] list_automation_rules +- [ ] list_configuration_policies +- [ ] list_configuration_policy_associations +- [ ] list_enabled_products_for_import +- [ ] list_finding_aggregators +- [ ] list_invitations +- [ ] list_members +- [ ] list_organization_admin_accounts +- [ ] list_security_control_definitions +- [ ] list_standards_control_associations +- [ ] list_tags_for_resource +- [ ] start_configuration_policy_association +- [ ] start_configuration_policy_disassociation +- [ ] tag_resource +- [ ] untag_resource +- [ ] update_action_target +- [ ] update_configuration_policy +- [ ] update_finding_aggregator +- [ ] update_findings +- [ ] update_insight +- [ ] update_organization_configuration +- [ ] update_security_control +- [ ] update_security_hub_configuration +- [ ] update_standards_control + diff --git a/moto/backend_index.py b/moto/backend_index.py index 9d8abc8295d6..2135505927e2 100644 --- a/moto/backend_index.py +++ b/moto/backend_index.py @@ -184,6 +184,7 @@ ("scheduler", re.compile("https?://scheduler\\.(.+)\\.amazonaws\\.com")), ("sdb", re.compile("https?://sdb\\.(.+)\\.amazonaws\\.com")), ("secretsmanager", re.compile("https?://secretsmanager\\.(.+)\\.amazonaws\\.com")), + ("securityhub", re.compile("https?://securityhub\\.(.+)\\.amazonaws\\.com")), ( "servicediscovery", re.compile("https?://(data-)?servicediscovery\\.(.+)\\.amazonaws\\.com"), diff --git a/moto/securityhub/__init__.py b/moto/securityhub/__init__.py new file mode 100644 index 000000000000..68a7cc517c9e --- /dev/null +++ b/moto/securityhub/__init__.py @@ -0,0 +1 @@ +from .models import securityhub_backends # noqa: F401 diff --git a/moto/securityhub/exceptions.py b/moto/securityhub/exceptions.py new file mode 100644 index 000000000000..1ff50c3d8573 --- /dev/null +++ b/moto/securityhub/exceptions.py @@ -0,0 +1,21 @@ +"""Exceptions raised by the securityhub service.""" + +from moto.core.exceptions import JsonRESTError + + +class SecurityHubClientError(JsonRESTError): + code = 400 + + +class _InvalidOperationException(SecurityHubClientError): + def __init__(self, error_type: str, op: str, msg: str): + super().__init__( + error_type, + "An error occurred (%s) when calling the %s operation: %s" + % (error_type, op, msg), + ) + + +class InvalidInputException(_InvalidOperationException): + def __init__(self, op: str, msg: str): + super().__init__("InvalidInputException", op, msg) diff --git a/moto/securityhub/models.py b/moto/securityhub/models.py new file mode 100644 index 000000000000..afcd6fcaed4d --- /dev/null +++ b/moto/securityhub/models.py @@ -0,0 +1,122 @@ +"""SecurityHubBackend class with methods for supported APIs.""" + +from typing import Any, Dict, List, Optional, Tuple + +from moto.core.base_backend import BackendDict, BaseBackend +from moto.core.common_models import BaseModel +from moto.securityhub.exceptions import InvalidInputException +from moto.utilities.paginator import paginate + + +class Finding(BaseModel): + def __init__(self, finding_id: str, finding_data: Dict[str, Any]): + self.id = finding_id + self.data = finding_data + + def as_dict(self) -> Dict[str, Any]: + return self.data + + +class SecurityHubBackend(BaseBackend): + """Implementation of SecurityHub APIs.""" + + PAGINATION_MODEL = { + "get_findings": { + "input_token": "next_token", + "limit_key": "max_results", + "limit_default": 100, + "unique_attribute": "Id", + "fail_on_invalid_token": True, + } + } + + def __init__(self, region_name: str, account_id: str): + super().__init__(region_name, account_id) + self.findings: List[Finding] = [] + + @paginate(pagination_model=PAGINATION_MODEL) + def get_findings( + self, + filters: Optional[Dict[str, Any]] = None, + sort_criteria: Optional[List[Dict[str, str]]] = None, + max_results: Optional[int] = None, + next_token: Optional[str] = None, + ) -> List[Dict[str, str]]: + """ + Returns findings based on optional filters and sort criteria. + """ + if max_results is not None: + try: + max_results = int(max_results) + if max_results < 1 or max_results > 100: + raise InvalidInputException( + op="GetFindings", + msg="MaxResults must be a number between 1 and 100", + ) + except ValueError: + raise InvalidInputException( + op="GetFindings", msg="MaxResults must be a number greater than 0" + ) + + findings = self.findings + + # TODO: Apply filters if provided + # TODO: Apply sort criteria if provided + + return [f.as_dict() for f in findings] + + def batch_import_findings( + self, findings: List[Dict[str, Any]] + ) -> Tuple[int, int, List[Dict[str, Any]]]: + """ + Import findings in batch to SecurityHub. + + Args: + findings: List of finding dictionaries to import + + Returns: + Tuple of (failed_count, success_count, failed_findings) + """ + failed_count = 0 + success_count = 0 + failed_findings = [] + + for finding_data in findings: + try: + if ( + not isinstance(finding_data["Resources"], list) + or len(finding_data["Resources"]) == 0 + ): + raise InvalidInputException( + op="BatchImportFindings", + msg="Finding must contain at least one resource in the Resources array", + ) + + finding_id = finding_data["Id"] + + existing_finding = next( + (f for f in self.findings if f.id == finding_id), None + ) + + if existing_finding: + existing_finding.data.update(finding_data) + else: + new_finding = Finding(finding_id, finding_data) + self.findings.append(new_finding) + + success_count += 1 + + except Exception as e: + failed_count += 1 + failed_findings.append( + { + "Id": finding_data.get("Id", ""), + "ErrorCode": "InvalidInput", + "ErrorMessage": str(e), + } + ) + + return failed_count, success_count, failed_findings + + +securityhub_backends = BackendDict(SecurityHubBackend, "securityhub") diff --git a/moto/securityhub/responses.py b/moto/securityhub/responses.py new file mode 100644 index 000000000000..8ce0ce1e2855 --- /dev/null +++ b/moto/securityhub/responses.py @@ -0,0 +1,61 @@ +"""Handles incoming securityhub requests, invokes methods, returns responses.""" + +import json + +from moto.core.responses import BaseResponse + +from .models import SecurityHubBackend, securityhub_backends + + +class SecurityHubResponse(BaseResponse): + def __init__(self) -> None: + super().__init__(service_name="securityhub") + + @property + def securityhub_backend(self) -> SecurityHubBackend: + return securityhub_backends[self.current_account][self.region] + + def get_findings(self) -> str: + filters = self._get_param("Filters") + sort_criteria = self._get_param("SortCriteria") + max_results = self._get_param("MaxResults") + next_token = self._get_param("NextToken") + + findings, next_token = self.securityhub_backend.get_findings( + filters=filters, + sort_criteria=sort_criteria, + max_results=max_results, + next_token=next_token, + ) + + response = {"Findings": findings, "NextToken": next_token} + return json.dumps(response) + + def batch_import_findings(self) -> str: + raw_body = self.body + if isinstance(raw_body, bytes): + raw_body = raw_body.decode("utf-8") + body = json.loads(raw_body) + + findings = body.get("Findings", []) + + failed_count, success_count, failed_findings = ( + self.securityhub_backend.batch_import_findings( + findings=findings, + ) + ) + + return json.dumps( + { + "FailedCount": failed_count, + "FailedFindings": [ + { + "ErrorCode": finding.get("ErrorCode"), + "ErrorMessage": finding.get("ErrorMessage"), + "Id": finding.get("Id"), + } + for finding in failed_findings + ], + "SuccessCount": success_count, + } + ) diff --git a/moto/securityhub/urls.py b/moto/securityhub/urls.py new file mode 100644 index 000000000000..162e66d8ad4e --- /dev/null +++ b/moto/securityhub/urls.py @@ -0,0 +1,12 @@ +"""securityhub base URL and path.""" + +from .responses import SecurityHubResponse + +url_bases = [ + r"https?://securityhub\.(.+)\.amazonaws\.com", +] + +url_paths = { + "{0}/findings$": SecurityHubResponse.dispatch, + "{0}/findings/import$": SecurityHubResponse.dispatch, +} diff --git a/tests/test_securityhub/__init__.py b/tests/test_securityhub/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_securityhub/test_securityhub.py b/tests/test_securityhub/test_securityhub.py new file mode 100644 index 000000000000..3a95c07a4923 --- /dev/null +++ b/tests/test_securityhub/test_securityhub.py @@ -0,0 +1,156 @@ +"""Unit tests for securityhub-supported APIs.""" + +import boto3 +import pytest +from botocore.exceptions import ClientError + +from moto import mock_aws +from moto.core import DEFAULT_ACCOUNT_ID + + +@mock_aws +def test_get_findings(): + client = boto3.client("securityhub", region_name="us-east-1") + + test_finding = { + "AwsAccountId": DEFAULT_ACCOUNT_ID, + "CreatedAt": "2024-01-01T00:00:00.001Z", + "UpdatedAt": "2024-01-01T00:00:00.000Z", + "Description": "Test finding description", + "GeneratorId": "test-generator", + "Id": "test-finding-001", + "ProductArn": f"arn:aws:securityhub:{client.meta.region_name}:{DEFAULT_ACCOUNT_ID}:product/{DEFAULT_ACCOUNT_ID}/default", + "Resources": [{"Id": "test-resource", "Type": "AwsEc2Instance"}], + "SchemaVersion": "2018-10-08", + "Severity": {"Label": "HIGH"}, + "Title": "Test Finding", + "Types": ["Software and Configuration Checks"], + } + + import_response = client.batch_import_findings(Findings=[test_finding]) + assert import_response["SuccessCount"] == 1 + + response = client.get_findings() + + assert "Findings" in response + assert isinstance(response["Findings"], list) + assert len(response["Findings"]) == 1 + finding = response["Findings"][0] + assert finding["Id"] == "test-finding-001" + assert finding["SchemaVersion"] == "2018-10-08" + + +@mock_aws +def test_batch_import_findings(): + client = boto3.client("securityhub", region_name="us-east-2") + + valid_finding = { + "AwsAccountId": DEFAULT_ACCOUNT_ID, + "CreatedAt": "2024-01-01T00:00:00.000Z", + "UpdatedAt": "2024-01-01T00:00:00.000Z", + "Description": "Test finding description", + "GeneratorId": "test-generator", + "Id": "test-finding-001", + "ProductArn": f"arn:aws:securityhub:{client.meta.region_name}:{DEFAULT_ACCOUNT_ID}:product/{DEFAULT_ACCOUNT_ID}/default", + "Resources": [{"Id": "test-resource", "Type": "AwsEc2Instance"}], + "SchemaVersion": "2018-10-08", + "Severity": {"Label": "HIGH"}, + "Title": "Test Finding", + "Types": ["Software and Configuration Checks"], + } + + response = client.batch_import_findings(Findings=[valid_finding]) + assert response["SuccessCount"] == 1 + assert response["FailedCount"] == 0 + assert response["FailedFindings"] == [] + + invalid_finding = valid_finding.copy() + invalid_finding["Id"] = "test-finding-002" + invalid_finding["Severity"]["Label"] = "INVALID_LABEL" + + response = client.batch_import_findings(Findings=[invalid_finding]) + + assert response["SuccessCount"] == 1 + assert response["FailedCount"] == 0 + assert len(response["FailedFindings"]) == 0 + + +@mock_aws +def test_get_findings_invalid_parameters(): + client = boto3.client("securityhub", region_name="us-east-1") + + with pytest.raises(ClientError) as exc: + client.get_findings(MaxResults=101) + + err = exc.value.response["Error"] + assert err["Code"] == "InvalidInputException" + assert "MaxResults must be a number between 1 and 100" in err["Message"] + + +@mock_aws +def test_batch_import_multiple_findings(): + client = boto3.client("securityhub", region_name="us-east-1") + + findings = [ + { + "AwsAccountId": DEFAULT_ACCOUNT_ID, + "CreatedAt": "2024-01-01T00:00:00.000Z", + "UpdatedAt": "2024-01-01T00:00:00.000Z", + "Description": f"Test finding description {i}", + "GeneratorId": "test-generator", + "Id": f"test-finding-{i:03d}", + "ProductArn": f"arn:aws:securityhub:{client.meta.region_name}:{DEFAULT_ACCOUNT_ID}:product/{DEFAULT_ACCOUNT_ID}/default", + "Resources": [{"Id": f"test-resource-{i}", "Type": "AwsEc2Instance"}], + "SchemaVersion": "2018-10-08", + "Severity": {"Label": "HIGH"}, + "Title": f"Test Finding {i}", + "Types": ["Software and Configuration Checks"], + } + for i in range(1, 4) + ] + + import_response = client.batch_import_findings(Findings=findings) + assert import_response["SuccessCount"] == 3 + assert import_response["FailedCount"] == 0 + assert import_response["FailedFindings"] == [] + + get_response = client.get_findings() + assert "Findings" in get_response + assert isinstance(get_response["Findings"], list) + assert len(get_response["Findings"]) == 3 + + imported_ids = {finding["Id"] for finding in get_response["Findings"]} + expected_ids = {f"test-finding-{i:03d}" for i in range(1, 4)} + assert imported_ids == expected_ids + + +@mock_aws +def test_get_findings_max_results(): + client = boto3.client("securityhub", region_name="us-east-1") + + findings = [ + { + "AwsAccountId": DEFAULT_ACCOUNT_ID, + "CreatedAt": "2024-01-01T00:00:00.000Z", + "UpdatedAt": "2024-01-01T00:00:00.000Z", + "Description": f"Test finding description {i}", + "GeneratorId": "test-generator", + "Id": f"test-finding-{i:03d}", + "ProductArn": f"arn:aws:securityhub:{client.meta.region_name}:{DEFAULT_ACCOUNT_ID}:product/{DEFAULT_ACCOUNT_ID}/default", + "Resources": [{"Id": f"test-resource-{i}", "Type": "AwsEc2Instance"}], + "SchemaVersion": "2018-10-08", + "Severity": {"Label": "HIGH"}, + "Title": f"Test Finding {i}", + "Types": ["Software and Configuration Checks"], + } + for i in range(1, 4) + ] + + import_response = client.batch_import_findings(Findings=findings) + assert import_response["SuccessCount"] == 3 + + get_response = client.get_findings(MaxResults=1) + assert "Findings" in get_response + assert isinstance(get_response["Findings"], list) + assert len(get_response["Findings"]) == 1 + assert "NextToken" in get_response From 777ca53026b94c7dc9237de34c84ee41e6e1f2f6 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Tue, 11 Feb 2025 20:38:10 -0100 Subject: [PATCH 3/3] CognitoIDP: AccessTokens and IDTokens now contain the jti and origin_jti values (#8581) --- moto/cognitoidp/models.py | 59 +++++++---- tests/test_cognitoidp/test_cognitoidp.py | 121 +++++++++++++++++++++++ 2 files changed, 162 insertions(+), 18 deletions(-) diff --git a/moto/cognitoidp/models.py b/moto/cognitoidp/models.py index 4439041a7503..0783bee9cd22 100644 --- a/moto/cognitoidp/models.py +++ b/moto/cognitoidp/models.py @@ -433,7 +433,7 @@ def __init__( self.groups: Dict[str, CognitoIdpGroup] = OrderedDict() self.users: Dict[str, CognitoIdpUser] = OrderedDict() self.resource_servers: Dict[str, CognitoResourceServer] = OrderedDict() - self.refresh_tokens: Dict[str, Optional[Tuple[str, str]]] = {} + self.refresh_tokens: Dict[str, Optional[Tuple[str, str, str]]] = {} self.access_tokens: Dict[str, Tuple[str, str]] = {} self.id_tokens: Dict[str, Tuple[str, str]] = {} @@ -534,6 +534,7 @@ def create_jwt( "token_use": token_use, "auth_time": now, "exp": now + expires_in, + "jti": str(random.uuid4()), } username_is_email = "email" in self.extended_config.get( "UsernameAttributes", [] @@ -574,8 +575,14 @@ def add_custom_attributes(self, custom_attributes: List[Dict[str, str]]) -> None for attribute in attributes: self.schema_attributes[attribute.name] = attribute - def create_id_token(self, client_id: str, username: str) -> Tuple[str, int]: + def create_id_token( + self, client_id: str, username: str, origin_jti: str + ) -> Tuple[str, int]: + """ + :returns: (id_token, expires_in) + """ extra_data = self.get_user_extra_data_by_client_id(client_id, username) + extra_data["origin_jti"] = origin_jti user = self._get_user(username) for attr in user.attributes: if attr["Name"].startswith("custom:"): @@ -588,13 +595,24 @@ def create_id_token(self, client_id: str, username: str) -> Tuple[str, int]: self.id_tokens[id_token] = (client_id, username) return id_token, expires_in - def create_refresh_token(self, client_id: str, username: str) -> str: + def create_refresh_token(self, client_id: str, username: str) -> Tuple[str, str]: + """ + :returns: (refresh_token, origin_jti) + """ refresh_token = str(random.uuid4()) - self.refresh_tokens[refresh_token] = (client_id, username) - return refresh_token + origin_jti = str(random.uuid4()) + self.refresh_tokens[refresh_token] = (client_id, username, origin_jti) + return refresh_token, origin_jti - def create_access_token(self, client_id: str, username: str) -> Tuple[str, int]: - extra_data = {} + def create_access_token( + self, client_id: str, username: str, origin_jti: str + ) -> Tuple[str, int]: + """ + :returns: (access_token, expires_in) + """ + extra_data: Dict[str, Any] = { + "origin_jti": origin_jti, + } user = self._get_user(username) if len(user.groups) > 0: extra_data["cognito:groups"] = [group.group_name for group in user.groups] @@ -611,12 +629,14 @@ def create_tokens_from_refresh_token( res = self.refresh_tokens[refresh_token] if res is None: raise NotAuthorizedError(refresh_token) - client_id, username = res + client_id, username, origin_jti = res if not username: raise NotAuthorizedError(refresh_token) - access_token, expires_in = self.create_access_token(client_id, username) - id_token, _ = self.create_id_token(client_id, username) + access_token, expires_in = self.create_access_token( + client_id, username, origin_jti=origin_jti + ) + id_token, _ = self.create_id_token(client_id, username, origin_jti=origin_jti) return access_token, id_token, expires_in def get_user_extra_data_by_client_id( @@ -640,11 +660,10 @@ def sign_out(self, username: str) -> None: for token, token_tuple in list(self.refresh_tokens.items()): if token_tuple is None: continue - _, logged_in_user = token_tuple + _, logged_in_user, _ = token_tuple if username == logged_in_user: self.refresh_tokens[token] = None - for access_token, token_tuple in list(self.access_tokens.items()): - _, logged_in_user = token_tuple + for access_token, (_, logged_in_user) in list(self.access_tokens.items()): if username == logged_in_user: self.access_tokens.pop(access_token) @@ -1427,7 +1446,7 @@ def _log_user_in( client: CognitoIdpUserPoolClient, username: str, ) -> Dict[str, Dict[str, Any]]: - refresh_token = user_pool.create_refresh_token(client.id, username) + refresh_token, _ = user_pool.create_refresh_token(client.id, username) access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token( refresh_token ) @@ -2083,11 +2102,15 @@ def initiate_auth( "Session": session, } - access_token, expires_in = user_pool.create_access_token( + new_refresh_token, origin_jti = user_pool.create_refresh_token( client_id, username ) - id_token, _ = user_pool.create_id_token(client_id, username) - new_refresh_token = user_pool.create_refresh_token(client_id, username) + access_token, expires_in = user_pool.create_access_token( + client_id, username, origin_jti=origin_jti + ) + id_token, _ = user_pool.create_id_token( + client_id, username, origin_jti=origin_jti + ) return { "AuthenticationResult": { @@ -2107,7 +2130,7 @@ def initiate_auth( if res is None: raise NotAuthorizedError("Refresh Token has been revoked") - client_id, username = res + client_id, username, _ = res if not username: raise ResourceNotFoundError(username) diff --git a/tests/test_cognitoidp/test_cognitoidp.py b/tests/test_cognitoidp/test_cognitoidp.py index 06028c1412ed..1a6cfb57deb1 100644 --- a/tests/test_cognitoidp/test_cognitoidp.py +++ b/tests/test_cognitoidp/test_cognitoidp.py @@ -9,6 +9,7 @@ import time import uuid from unittest import SkipTest, mock +from uuid import UUID import boto3 import pycognito @@ -1547,6 +1548,126 @@ def test_group_in_access_token(user_pool=None, user_pool_client=None): assert payload.claims["cognito:groups"] == [group_name] +@pytest.mark.aws_verified +@cognitoidp_aws_verified( + generate_secret=True, explicit_auth_flows=["ADMIN_NO_SRP_AUTH"] +) +def test_jti_in_tokens(user_pool=None, user_pool_client=None): + conn = boto3.client("cognito-idp", "us-west-2") + + username = str(uuid.uuid4()) + temporary_password = "P2$Sword" + new_password = "P2$Sword" + user_pool_id = user_pool["UserPool"]["Id"] + user_attribute_value = str(uuid.uuid4()) + client_id = user_pool_client["UserPoolClient"]["ClientId"] + client_secret = user_pool_client["UserPoolClient"]["ClientSecret"] + secret_hash = pycognito.aws_srp.AWSSRP.get_secret_hash( + username=username, client_id=client_id, client_secret=client_secret + ) + + conn.admin_create_user( + UserPoolId=user_pool_id, + Username=username, + TemporaryPassword=temporary_password, + UserAttributes=[{"Name": "given_name", "Value": user_attribute_value}], + ) + + result = conn.admin_initiate_auth( + UserPoolId=user_pool_id, + ClientId=client_id, + AuthFlow="ADMIN_NO_SRP_AUTH", + AuthParameters={ + "USERNAME": username, + "PASSWORD": temporary_password, + "SECRET_HASH": secret_hash, + }, + ) + + # This sets a new password and logs the user in (creates tokens) + initial_login = conn.admin_respond_to_auth_challenge( + UserPoolId=user_pool_id, + Session=result["Session"], + ClientId=client_id, + ChallengeName="NEW_PASSWORD_REQUIRED", + ChallengeResponses={ + "USERNAME": username, + "NEW_PASSWORD": new_password, + "SECRET_HASH": secret_hash, + }, + )["AuthenticationResult"] + + initial_access_claims = get_jwt_payload(initial_login["AccessToken"]).claims + + initial_id_claims = get_jwt_payload(initial_login["IdToken"]).claims + + # origin_jti + # A token-revocation identifier associated with your user's refresh token. + # Should be the same for all tokens, for a single session + assert UUID(initial_access_claims["origin_jti"]) + assert initial_access_claims["origin_jti"] == initial_id_claims["origin_jti"] + + # jti + # The unique identifier of the JWT. + # Should be unique for every token + assert UUID(initial_access_claims["jti"]) + assert UUID(initial_id_claims["jti"]) + assert initial_access_claims["jti"] != initial_id_claims["jti"] + + # refresh current session + refreshed_tokens = conn.initiate_auth( + ClientId=client_id, + AuthFlow="REFRESH_TOKEN", + AuthParameters={ + "SECRET_HASH": secret_hash, + "REFRESH_TOKEN": initial_login["RefreshToken"], + }, + )["AuthenticationResult"] + refresh_access_claims = get_jwt_payload(refreshed_tokens["AccessToken"]).claims + + refresh_id_claims = get_jwt_payload(refreshed_tokens["IdToken"]).claims + + assert initial_access_claims["origin_jti"] == refresh_access_claims["origin_jti"] + assert refresh_access_claims["origin_jti"] == refresh_id_claims["origin_jti"] + + assert initial_access_claims["jti"] != refresh_access_claims["jti"] + assert refresh_access_claims["jti"] != refresh_id_claims["jti"] + + # new session + aws_srp = pycognito.aws_srp.AWSSRP( + username=username, + password=new_password, + pool_id=user_pool_id, + client_id=client_id, + client_secret=client_secret, + client=conn, + ) + auth_params = aws_srp.get_auth_params() + + result = conn.initiate_auth( + ClientId=client_id, + AuthFlow="USER_SRP_AUTH", + AuthParameters=auth_params, + ) + + challenge_response = aws_srp.process_challenge( + result["ChallengeParameters"], auth_params + ) + new_session = conn.respond_to_auth_challenge( + ClientId=client_id, + ChallengeName=result["ChallengeName"], + ChallengeResponses=challenge_response, + )["AuthenticationResult"] + new_access_claims = get_jwt_payload(new_session["AccessToken"]).claims + + new_id_claims = get_jwt_payload(new_session["IdToken"]).claims + + assert initial_access_claims["origin_jti"] != new_access_claims["origin_jti"] + assert new_access_claims["origin_jti"] == new_id_claims["origin_jti"] + + assert new_access_claims["jti"] != new_id_claims["jti"] + + @mock_aws def test_other_attributes_in_id_token(): conn = boto3.client("cognito-idp", "us-west-2")