diff --git a/CHANGELOG.md b/CHANGELOG.md index c938960ac..cb923591e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## 2024-05-15 +- [PI-336] Changelog deletes +- Dependabot (pydantic) + ## 2024-05-02 - [PI-341] Prod permissions - [PI-268] Search for a device diff --git a/VERSION b/VERSION index 97f7bd042..ed8ce936b 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2024.05.02 +2024.05.15 diff --git a/changelog/2024-05-15.md b/changelog/2024-05-15.md new file mode 100644 index 000000000..4bb53dc47 --- /dev/null +++ b/changelog/2024-05-15.md @@ -0,0 +1,2 @@ +- [PI-336] Changelog deletes +- Dependabot (pydantic) diff --git a/infrastructure/terraform/per_workspace/modules/etl/sds/main.tf b/infrastructure/terraform/per_workspace/modules/etl/sds/main.tf index b6c5ae656..29473e73e 100644 --- a/infrastructure/terraform/per_workspace/modules/etl/sds/main.tf +++ b/infrastructure/terraform/per_workspace/modules/etl/sds/main.tf @@ -116,7 +116,7 @@ module "worker_transform" { "dynamodb:Query" ], "Effect": "Allow", - "Resource": ["${var.table_arn}"] + "Resource": ["${var.table_arn}", "${var.table_arn}/*"] }, { "Action": [ diff --git a/pyproject.toml b/pyproject.toml index 9ba1bec19..497521541 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "connecting-party-manager" -version = "2024.05.02" +version = "2024.05.15" description = "Repository for the Connecting Party Manager API and related services" authors = ["NHS England"] license = "LICENSE.md" diff --git a/src/etl/sds/tests/test_sds_etl_components.py b/src/etl/sds/tests/test_sds_etl_components.py index a7be4f9c3..761f962d1 100644 --- a/src/etl/sds/tests/test_sds_etl_components.py +++ b/src/etl/sds/tests/test_sds_etl_components.py @@ -4,7 +4,7 @@ import boto3 import pytest -from domain.core.device import DeviceType +from domain.core.device import DeviceStatus, DeviceType from domain.core.device_key import DeviceKeyType from etl.clear_state_inputs import EMPTY_JSON_DATA, EMPTY_LDIF_DATA from etl_utils.constants import CHANGELOG_NUMBER, WorkerKey @@ -24,6 +24,21 @@ from test_helpers.pytest_skips import long_running from test_helpers.terraform import read_terraform_output +# Note that unique identifier "000428682512" is the identifier of 'GOOD_SDS_RECORD' +DELETION_REQUEST_000428682512 = """ +dn: o=nhs,ou=Services,uniqueIdentifier=000428682512 +changetype: delete +objectclass: delete +uniqueidentifier: 000428682512 +""" + +DELETION_REQUEST_000842065542 = """ +dn: o=nhs,ou=Services,uniqueIdentifier=000842065542 +changetype: delete +objectclass: delete +uniqueidentifier: 000842065542 +""" + @pytest.fixture def state_machine_input(request: pytest.FixtureRequest): @@ -74,7 +89,7 @@ def execute_state_machine( error_message = cause["errorMessage"] stack_trace = cause["stackTrace"] except Exception: - error_message = response["cause"] + error_message = response.get("cause", "no error message") stack_trace = [] print( # noqa: T201 @@ -83,7 +98,7 @@ def execute_state_machine( "\n", *stack_trace, ) - raise RuntimeError(response["error"]) + raise RuntimeError(response.get("error", "no error message")) return response @@ -102,6 +117,12 @@ def get_object(key: WorkerKey) -> str: return response["Body"].read() +def put_object(key: WorkerKey, body: bytes) -> str: + client = boto3.client("s3") + etl_bucket = read_terraform_output("sds_etl.value.bucket") + return client.put_object(Bucket=etl_bucket, Key=key, Body=body) + + @pytest.mark.integration @pytest.mark.parametrize( "worker_data", @@ -206,3 +227,88 @@ def test_end_to_end_bulk_trigger(repository: MockDeviceRepository): assert product_count == accredited_system_count == 5670 assert endpoint_count == message_handling_system_count == 154506 + + +@pytest.mark.integration +@pytest.mark.parametrize( + "worker_data", + [ + { + WorkerKey.EXTRACT: "\n".join([GOOD_SDS_RECORD, ANOTHER_GOOD_SDS_RECORD]), + WorkerKey.TRANSFORM: pkl_dumps_lz4(deque()), + WorkerKey.LOAD: pkl_dumps_lz4(deque()), + } + ], + indirect=True, +) +@pytest.mark.parametrize( + "state_machine_input", + [ + StateMachineInput.bulk(changelog_number=123), + ], + indirect=True, +) +def test_end_to_end_changelog_delete( + repository: MockDeviceRepository, worker_data, state_machine_input +): + """Note that the start of this test is the same as test_end_to_end, and then makes changes""" + extract_data = get_object(key=WorkerKey.EXTRACT) + transform_data = pkl_loads_lz4(get_object(key=WorkerKey.TRANSFORM)) + load_data = pkl_loads_lz4(get_object(key=WorkerKey.LOAD)) + + assert len(extract_data) == 0 + assert len(transform_data) == 0 + assert len(load_data) == 0 + assert len(list(repository.all_devices())) == len( + worker_data[WorkerKey.EXTRACT].split("\n\n") + ) + + # Now execute a changelog initial state in the ETL + put_object(key=WorkerKey.EXTRACT, body=DELETION_REQUEST_000428682512) + response = execute_state_machine( + state_machine_input=StateMachineInput.update( + changelog_number_start=124, changelog_number_end=125 + ) + ) + assert response["status"] == "SUCCEEDED" + + # Verify that the device with unique id 000428682512 is now "inactive" + (device,) = repository.read_by_index( + questionnaire_id="spine_device/1", + question_name="unique_identifier", + value="000428682512", + ) + assert device.status == DeviceStatus.INACTIVE + + # Verify that the other device is still "active" + (device,) = repository.read_by_index( + questionnaire_id="spine_device/1", + question_name="unique_identifier", + value="000842065542", + ) + assert device.status == DeviceStatus.ACTIVE + + # Execute another changelog initial state in the ETL + put_object(key=WorkerKey.EXTRACT, body=DELETION_REQUEST_000842065542) + response = execute_state_machine( + state_machine_input=StateMachineInput.update( + changelog_number_start=124, changelog_number_end=125 + ) + ) + assert response["status"] == "SUCCEEDED" + + # Verify that the device with unique id 000428682512 is still "inactive" + (device,) = repository.read_by_index( + questionnaire_id="spine_device/1", + question_name="unique_identifier", + value="000428682512", + ) + assert device.status == DeviceStatus.INACTIVE + + # Verify that the other device is now "inactive" + (device,) = repository.read_by_index( + questionnaire_id="spine_device/1", + question_name="unique_identifier", + value="000842065542", + ) + assert device.status == DeviceStatus.INACTIVE diff --git a/src/etl/sds/trigger/update/operations.py b/src/etl/sds/trigger/update/operations.py index c6d23db87..dff2af633 100644 --- a/src/etl/sds/trigger/update/operations.py +++ b/src/etl/sds/trigger/update/operations.py @@ -97,11 +97,9 @@ def get_latest_changelog_number_from_ldap( ], ) - _, (unpack_record) = record - - return int( - unpack_record[ChangelogAttributes.LAST_CHANGELOG_NUMBER][0].decode("utf-8") - ) + _, (_record) = record + (last_changelog_number_str,) = _record[ChangelogAttributes.LAST_CHANGELOG_NUMBER] + return int(last_changelog_number_str) def get_changelog_entries_from_ldap( diff --git a/src/etl/sds/trigger/update/tests/test_update_trigger.py b/src/etl/sds/trigger/update/tests/test_update_trigger.py index be8024a41..fc743ccb3 100644 --- a/src/etl/sds/trigger/update/tests/test_update_trigger.py +++ b/src/etl/sds/trigger/update/tests/test_update_trigger.py @@ -3,6 +3,7 @@ from unittest import mock import boto3 +import pytest from etl_utils.constants import CHANGELOG_NUMBER from moto import mock_aws @@ -16,39 +17,96 @@ "CPM_FQDN": "cpm-fqdn", "LDAP_HOST": "ldap-host", "ETL_BUCKET": "etl-bucket", - "ETL_EXTRACT_INPUT_KEY": "etl-input", "LDAP_CHANGELOG_USER": "user", "LDAP_CHANGELOG_PASSWORD": "eggs", # pragma: allowlist secret } ALLOWED_EXCEPTIONS = (JSONDecodeError,) -CHANGELOG_NUMBER_VALUE = "538684" +LATEST_CHANGELOG_NUMBER = b"540382" +CURRENT_CHANGELOG_NUMBER = str(int(LATEST_CHANGELOG_NUMBER) - 1).encode() +CHANGELOG_NUMBER_RESULT = [ + 101, + [ + [ + "cn=changelog,o=nhs", + { + "firstchangenumber": [b"46425"], + "lastchangenumber": [LATEST_CHANGELOG_NUMBER], + }, + ] + ], +] + +CHANGE_RESULT = ( + 101, + [ + [ + "changenumber=540246,cn=changelog,o=nhs", + { + "objectClass": [ + b"top", + b"changeLogEntry", + b"nhsExternalChangelogEntry", + ], + "changeNumber": [b"540246"], + "changes": [ + b"\\nobjectClass: nhsmhsservice\\nobjectClass: top\\nnhsIDCode: F2R5Q\\nnhsMHSPartyKey: F2R5Q-823886\\nnhsMHSServiceName: urn:nhs:names:services:pdsquery\\nuniqueIdentifier: 4d554a907e83a4067695" + ], + "changeTime": [b"20240502100040Z"], + "changeType": [b"add"], + "targetDN": [ + b"uniqueIdentifier=4d554a907e83a4067695,ou=Services,o=nhs" + ], + }, + ] + ], +) + +CHANGE_RESULT_WITHOUT_UNIQUE_IDENTIFIER = ( + 101, + [ + [ + "changenumber=540246,cn=changelog,o=nhs", + { + "objectClass": [ + b"top", + b"changeLogEntry", + b"nhsExternalChangelogEntry", + ], + "changeNumber": [b"540246"], + "changes": [ + b"\\nobjectClass: nhsmhsservice\\nobjectClass: top\\nnhsIDCode: F2R5Q\\nnhsMHSPartyKey: F2R5Q-823886\\nnhsMHSServiceName: urn:nhs:names:services:pdsquery\\n" + ], + "changeTime": [b"20240502100040Z"], + "changeType": [b"add"], + "targetDN": [ + b"uniqueIdentifier=4d554a907e83a4067695,ou=Services,o=nhs" + ], + }, + ] + ], +) -def test_update(): + +CHANGE_AS_LDIF = """dn: o=nhs,ou=services,uniqueidentifier=4d554a907e83a4067695 +changetype: add +objectClass: nhsmhsservice +objectClass: top +nhsIDCode: F2R5Q +nhsMHSPartyKey: F2R5Q-823886 +nhsMHSServiceName: urn:nhs:names:services:pdsquery +uniqueIdentifier: 4d554a907e83a4067695""".encode() + + +@pytest.mark.parametrize( + "change_result", [CHANGE_RESULT, CHANGE_RESULT_WITHOUT_UNIQUE_IDENTIFIER] +) +def test_update(change_result): mocked_ldap = mock.Mock() mocked_ldap_client = mock.Mock() mocked_ldap.initialize.return_value = mocked_ldap_client - mocked_ldap_client.result.return_value = ( - 101, - [ - ( - "changenumber=75852519,cn=changelog,o=nhs", - { - "objectclass": { - "top", - "changeLogEntry", - "nhsExternalChangelogEntry", - }, - "changenumber": "75852519", - "changes": "foo", - "changetime": "20240116173441Z", - "changetype": "add", - "targetdn": "uniqueIdentifier=200000042019,ou=Services,o=nhs", - }, - ), - ], - ) + mocked_ldap_client.result.side_effect = (CHANGELOG_NUMBER_RESULT, change_result) with mock_aws(), mock.patch.dict( os.environ, MOCKED_UPDATE_TRIGGER_ENVIRONMENT, clear=True @@ -75,7 +133,7 @@ def test_update(): s3_client.put_object( Bucket=MOCKED_UPDATE_TRIGGER_ENVIRONMENT["ETL_BUCKET"], Key=CHANGELOG_NUMBER, - Body="0", + Body=CURRENT_CHANGELOG_NUMBER, ) from etl.sds.trigger.update import update @@ -86,12 +144,36 @@ def test_update(): update.CACHE["ldap_client"] = mocked_ldap_client # Remove start execution, since it's meaningless - idx = update.steps.index(_start_execution) - update.steps.pop(idx) + if _start_execution in update.steps: + idx = update.steps.index(_start_execution) + update.steps.pop(idx) # Don't execute the notify lambda - update.notify = mock.Mock(return_value="abc") + update.notify = ( + lambda lambda_client, function_name, result, trigger_type: result + ) + # Execute the test response = update.handler() - assert response == "abc" + # Verify the changelog number is NOT updated (as it should be updated in the ETL, not the trigger) + changelog_number_response = s3_client.get_object( + Bucket=MOCKED_UPDATE_TRIGGER_ENVIRONMENT["ETL_BUCKET"], Key=CHANGELOG_NUMBER + ) + assert changelog_number_response["Body"].read() == CURRENT_CHANGELOG_NUMBER + + # Verify the history file was created + etl_history_response = s3_client.get_object( + Bucket=MOCKED_UPDATE_TRIGGER_ENVIRONMENT["ETL_BUCKET"], + Key=f"history/changelog/{int(LATEST_CHANGELOG_NUMBER)}/input--extract/unprocessed", + ) + assert etl_history_response["Body"].read().lower() == CHANGE_AS_LDIF.lower() + + # Verify the ETL input file was created + etl_input_response = s3_client.get_object( + Bucket=MOCKED_UPDATE_TRIGGER_ENVIRONMENT["ETL_BUCKET"], + Key="input--extract/unprocessed", + ) + assert etl_input_response["Body"].read().lower() == CHANGE_AS_LDIF.lower() + + assert not isinstance(response, Exception), response diff --git a/src/etl/sds/trigger/update/tests/test_update_trigger_operations.py b/src/etl/sds/trigger/update/tests/test_update_trigger_operations.py index afabb8115..92431aeba 100644 --- a/src/etl/sds/trigger/update/tests/test_update_trigger_operations.py +++ b/src/etl/sds/trigger/update/tests/test_update_trigger_operations.py @@ -211,7 +211,7 @@ def test_parse_changelog_changes_with_add(): assert ( result - == "dn: o=nhs,ou=Services,uniqueIdentifier=f1c55263f1ee924f460f\nchangetype: add\nobjectClass: nhsMhs\nobjectClass: top\nnhsApproverURP: uniqueidentifier=555050304105,uniqueidentifier=555008548101,uid=555008545108,ou=people, o=nhs\nnhsContractPropertyTemplateKey: 14\nnhsDateApproved: 20240417082830\nnhsDateDNSApproved: 20240417082830\nnhsDateRequested: 20240417082818\nnhsDNSApprover: uniqueidentifier=555050304105,uniqueidentifier=555008548101,uid=555008545108,ou=people, o=nhs\nnhsEPInteractionType: ebXML\nnhsIDCode: X26\nnhsMHSAckRequested: never\nnhsMhsCPAId: f1c55263f1ee924f460f\nnhsMHSDuplicateElimination: never\nnhsMHSEndPoint: https://simple-sync.intspineservices.nhs.uk/\nnhsMhsFQDN: simple-sync.intspineservices.nhs.uk\nnhsMHsIN: QUPA_IN050000UK32\nnhsMhsIPAddress: 0.0.0.0\nnhsMHSIsAuthenticated: none\nnhsMHSPartyKey: X26-823848\nnhsMHsSN: urn:nhs:names:services:pdsquery\nnhsMhsSvcIA: urn:nhs:names:services:pdsquery:QUPA_IN050000UK32\nnhsMHSSyncReplyMode: None\nnhsProductKey: 10894\nnhsProductName: Compliance\nnhsProductVersion: Initial\nnhsRequestorURP: uniqueidentifier=555050304105,uniqueidentifier=555008548101,uid=555008545108,ou=people, o=nhs\nuniqueIdentifier: f1c55263f1ee924f460f" + == "dn: o=nhs,ou=services,uniqueidentifier=f1c55263f1ee924f460f\nchangetype: add\nobjectClass: nhsMhs\nobjectClass: top\nnhsApproverURP: uniqueidentifier=555050304105,uniqueidentifier=555008548101,uid=555008545108,ou=people, o=nhs\nnhsContractPropertyTemplateKey: 14\nnhsDateApproved: 20240417082830\nnhsDateDNSApproved: 20240417082830\nnhsDateRequested: 20240417082818\nnhsDNSApprover: uniqueidentifier=555050304105,uniqueidentifier=555008548101,uid=555008545108,ou=people, o=nhs\nnhsEPInteractionType: ebXML\nnhsIDCode: X26\nnhsMHSAckRequested: never\nnhsMhsCPAId: f1c55263f1ee924f460f\nnhsMHSDuplicateElimination: never\nnhsMHSEndPoint: https://simple-sync.intspineservices.nhs.uk/\nnhsMhsFQDN: simple-sync.intspineservices.nhs.uk\nnhsMHsIN: QUPA_IN050000UK32\nnhsMhsIPAddress: 0.0.0.0\nnhsMHSIsAuthenticated: none\nnhsMHSPartyKey: X26-823848\nnhsMHsSN: urn:nhs:names:services:pdsquery\nnhsMhsSvcIA: urn:nhs:names:services:pdsquery:QUPA_IN050000UK32\nnhsMHSSyncReplyMode: None\nnhsProductKey: 10894\nnhsProductName: Compliance\nnhsProductVersion: Initial\nnhsRequestorURP: uniqueidentifier=555050304105,uniqueidentifier=555008548101,uid=555008545108,ou=people, o=nhs\nuniqueIdentifier: f1c55263f1ee924f460f" ) @@ -231,7 +231,13 @@ def test_parse_changelog_changes_with_modify(): assert ( result - == "dn: o=nhs,ou=Services,uniqueIdentifier=200000002202\nchangetype: modify\nreplace: nhsAsSvcIA\nnhsAsSvcIA: urn:nhs:names:services:cpisquery:MCCI_IN010000UK13\nnhsAsSvcIA: urn:nhs:names:services:cpisquery:QUPC_IN000006GB01" + == """dn: o=nhs,ou=services,uniqueidentifier=200000002202 +changetype: modify +objectclass: modify +replace: nhsAsSvcIA +nhsAsSvcIA: urn:nhs:names:services:cpisquery:MCCI_IN010000UK13 +nhsAsSvcIA: urn:nhs:names:services:cpisquery:QUPC_IN000006GB01 +uniqueidentifier: 200000002202""" ) @@ -248,5 +254,8 @@ def test_parse_changelog_changes_with_delete(): assert ( result - == "dn: o=nhs,ou=Services,uniqueIdentifier=7abed27a247a511b7f0a\nchangetype: delete" + == """dn: o=nhs,ou=services,uniqueidentifier=7abed27a247a511b7f0a +changetype: delete +objectclass: delete +uniqueidentifier: 7abed27a247a511b7f0a""" ) diff --git a/src/etl/sds/worker/transform/tests/test_transform_worker.py b/src/etl/sds/worker/transform/tests/test_transform_worker.py index 49d69b443..7ed8ddb19 100644 --- a/src/etl/sds/worker/transform/tests/test_transform_worker.py +++ b/src/etl/sds/worker/transform/tests/test_transform_worker.py @@ -108,8 +108,8 @@ def test_transform_worker_pass_dupe_check_mock( response = transform.handler(event={}, context=None) assert response == { "stage_name": "transform", - # 4 x initial unprocessed because a key event + 2 questionnaire events are also created - "processed_records": n_initial_processed + 4 * n_initial_unprocessed, + # 5 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created + "processed_records": n_initial_processed + 5 * n_initial_unprocessed, "unprocessed_records": 0, "error_message": None, } @@ -122,7 +122,7 @@ def test_transform_worker_pass_dupe_check_mock( # Confirm that everything has now been processed, and that there is no # unprocessed data left in the bucket - assert n_final_processed == n_initial_processed + 4 * n_initial_unprocessed + assert n_final_processed == n_initial_processed + 5 * n_initial_unprocessed assert n_final_unprocessed == 0 @@ -148,7 +148,7 @@ def test_transform_worker_pass_no_dupes( assert response == { "stage_name": "transform", # 2 x initial unprocessed because a key event is also created - "processed_records": n_initial_processed + 4 * n_initial_unprocessed, + "processed_records": n_initial_processed + 5 * n_initial_unprocessed, "unprocessed_records": 0, "error_message": None, } @@ -161,7 +161,7 @@ def test_transform_worker_pass_no_dupes( # Confirm that everything has now been processed, and that there is no # unprocessed data left in the bucket - assert n_final_processed == n_initial_processed + 4 * n_initial_unprocessed + assert n_final_processed == n_initial_processed + 5 * n_initial_unprocessed assert n_final_unprocessed == 0 @@ -243,8 +243,8 @@ def test_transform_worker_bad_record( response = transform.handler(event={}, context=None) assert response == { "stage_name": "transform", - # 4 x initial unprocessed because a key event + 2 questionnaire events are also created - "processed_records": n_initial_processed + (4 * bad_record_index), + # 5 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created + "processed_records": n_initial_processed + (5 * bad_record_index), "unprocessed_records": n_initial_unprocessed - bad_record_index, "error_message": ( "The following errors were encountered\n" @@ -264,8 +264,8 @@ def test_transform_worker_bad_record( # Confirm that there are still unprocessed records, and that there may have been # some records processed successfully assert n_final_unprocessed > 0 - # 4 x initial unprocessed because a key event + 2 questionnaire events are also created - assert n_final_processed == n_initial_processed + (4 * bad_record_index) + # 5 x initial unprocessed because a key event + 2 questionnaire events + 1 index event are also created + assert n_final_processed == n_initial_processed + (5 * bad_record_index) assert n_final_unprocessed == n_initial_unprocessed - bad_record_index diff --git a/src/layers/domain/core/device.py b/src/layers/domain/core/device.py index 61a4f22a7..09b915b72 100644 --- a/src/layers/domain/core/device.py +++ b/src/layers/domain/core/device.py @@ -1,5 +1,6 @@ from collections import defaultdict from enum import StrEnum, auto +from itertools import chain from uuid import UUID, uuid4 from attr import dataclass, field @@ -27,6 +28,10 @@ class QuestionnaireResponseNotFoundError(Exception): pass +class QuestionNotFoundError(Exception): + pass + + @dataclass(kw_only=True, slots=True) class DeviceCreatedEvent(Event): id: str @@ -62,6 +67,14 @@ class DeviceKeyDeletedEvent(Event): key: str +@dataclass(kw_only=True, slots=True) +class DeviceIndexAddedEvent(Event): + id: str + questionnaire_id: str + question_name: str + value: str + + class DeviceType(StrEnum): """ A Product is to be classified as being one of the following. These terms @@ -84,6 +97,39 @@ class DeviceStatus(StrEnum): INACTIVE = auto() # "soft" delete +def _get_unique_answers( + questionnaire_responses: list[QuestionnaireResponse], question_name: str +): + all_responses = chain.from_iterable( + _questionnaire_response.responses + for _questionnaire_response in questionnaire_responses + ) + matching_responses = filter( + lambda response: question_name in response, all_responses + ) + matching_response_answers = ( + answer for responses in matching_responses for answer in responses.values() + ) + unique_answers = set(chain.from_iterable(matching_response_answers)) + return unique_answers + + +def _get_questionnaire_responses( + questionnaire_responses: dict[str, list[QuestionnaireResponse]], + questionnaire_id: str, +) -> list[QuestionnaireResponse]: + _questionnaire_responses = questionnaire_responses.get(questionnaire_id) + if _questionnaire_responses is None: + raise QuestionnaireNotFoundError( + f"This device does not contain a Questionnaire with id '{questionnaire_id}'" + ) + elif not _questionnaire_responses: + raise QuestionnaireResponseNotFoundError( + f"This device does not contain a QuestionnaireResponse for Questionnaire with id '{questionnaire_id}'" + ) + return _questionnaire_responses + + class Device(AggregateRoot): """ An entity in the database. It could model all sorts of different logical or @@ -132,6 +178,36 @@ def delete_key(self, key: str) -> DeviceKeyDeletedEvent: event = DeviceKeyDeletedEvent(id=self.id, key=device_key.key) return self.add_event(event) + def add_index( + self, questionnaire_id: str, question_name: str + ) -> list[DeviceIndexAddedEvent]: + questionnaire_responses = _get_questionnaire_responses( + questionnaire_responses=self.questionnaire_responses, + questionnaire_id=questionnaire_id, + ) + if question_name not in questionnaire_responses[0].questionnaire.questions: + raise QuestionNotFoundError( + f"Questionnaire '{questionnaire_id}' does not " + f"contain question '{question_name}'" + ) + unique_answers = _get_unique_answers( + questionnaire_responses=questionnaire_responses, + question_name=question_name, + ) + + events = [] + for answer in unique_answers: + event = DeviceIndexAddedEvent( + id=self.id, + questionnaire_id=questionnaire_id, + question_name=question_name, + value=answer, + ) + events.append(event) + self.add_event(event) + + return events + def add_questionnaire_response( self, questionnaire_response: QuestionnaireResponse, @@ -236,6 +312,7 @@ class DeviceEventDeserializer(EventDeserializer): DeviceUpdatedEvent, DeviceKeyAddedEvent, DeviceKeyDeletedEvent, + DeviceIndexAddedEvent, QuestionnaireResponseAddedEvent, QuestionnaireResponseUpdatedEvent, QuestionnaireResponseDeletedEvent, diff --git a/src/layers/domain/core/questionnaire.py b/src/layers/domain/core/questionnaire.py index 6217b8f74..d12c3ffe9 100644 --- a/src/layers/domain/core/questionnaire.py +++ b/src/layers/domain/core/questionnaire.py @@ -210,7 +210,7 @@ class QuestionnaireResponse(BaseModel): """ Validates questionnaire responses against questionnaire questions Responses is of the form: - ["question_name": ["answer_1", ..., "answer_n"]] + [{"question_name": ["answer_1", ..., "answer_n"]}] where n > 1 if Question.multiple is true for the Question in Questionnaire with the matching Question.name diff --git a/src/layers/domain/core/questionnaires/tests/test_spine_device_questionnaire.py b/src/layers/domain/core/questionnaires/tests/test_spine_device_questionnaire.py index 7506b9e72..138c16e65 100644 --- a/src/layers/domain/core/questionnaires/tests/test_spine_device_questionnaire.py +++ b/src/layers/domain/core/questionnaires/tests/test_spine_device_questionnaire.py @@ -3,7 +3,7 @@ from domain.core.questionnaire import Questionnaire from domain.core.questionnaires import QuestionnaireInstance from event.json import json_load -from hypothesis import assume, given +from hypothesis import assume, given, settings from sds.cpm_translation.tests.test_cpm_translation import ( NHS_ACCREDITED_SYSTEM_STRATEGY, ) @@ -52,6 +52,7 @@ def _test_spine_device_questionnaire_v1( return True +@settings(deadline=1500) @given(nhs_accredited_system=NHS_ACCREDITED_SYSTEM_STRATEGY) def test_spine_device_questionnaire_v1_local( nhs_accredited_system: NhsAccreditedSystem, diff --git a/src/layers/domain/core/questionnaires/tests/test_spine_endpoint_questionnaire.py b/src/layers/domain/core/questionnaires/tests/test_spine_endpoint_questionnaire.py index a0988f166..209804026 100644 --- a/src/layers/domain/core/questionnaires/tests/test_spine_endpoint_questionnaire.py +++ b/src/layers/domain/core/questionnaires/tests/test_spine_endpoint_questionnaire.py @@ -3,7 +3,7 @@ from domain.core.questionnaire import Questionnaire from domain.core.questionnaires import QuestionnaireInstance from event.json import json_load -from hypothesis import given +from hypothesis import given, settings from sds.cpm_translation.tests.test_cpm_translation import NHS_MHS_STRATEGY from sds.domain.nhs_mhs import NhsMhs @@ -37,6 +37,7 @@ def _test_spine_endpoint_questionnaire_v1( return True +@settings(deadline=1500) @given(nhs_mhs=NHS_MHS_STRATEGY) def test_spine_endpoint_questionnaire_v1_local(nhs_mhs: NhsMhs): spine_endpoint_questionnaire_v1 = render_questionnaire( diff --git a/src/layers/domain/core/tests/test_device.py b/src/layers/domain/core/tests/test_device.py index 56fc69e8b..07d129e41 100644 --- a/src/layers/domain/core/tests/test_device.py +++ b/src/layers/domain/core/tests/test_device.py @@ -1,6 +1,9 @@ +from itertools import chain + import pytest from domain.core.device import ( Device, + DeviceIndexAddedEvent, DeviceKeyAddedEvent, DeviceKeyDeletedEvent, DeviceStatus, @@ -8,6 +11,9 @@ DeviceUpdatedEvent, QuestionnaireNotFoundError, QuestionnaireResponseNotFoundError, + QuestionNotFoundError, + _get_questionnaire_responses, + _get_unique_answers, ) from domain.core.device_key import DeviceKey, DeviceKeyType from domain.core.error import NotFoundError @@ -169,3 +175,99 @@ def test_device_delete_questionnaire_response_key_error(device: Device): device.delete_questionnaire_response( questionnaire_id="bar/1", questionnaire_response_index=0 ) + + +def test__get_unique_answers(): + questionnaire = Questionnaire(name="foo", version=1) + questionnaire.add_question(name="question1", multiple=True) + questionnaire_response_1 = questionnaire.respond( + [ + {"question1": ["foo"]}, + {"question1": ["bar"]}, + {"question1": ["foo"]}, + ] + ) + + questionnaire_response_2 = questionnaire.respond( + [ + {"question1": ["baz", "BAR"]}, + {"question1": ["foo"]}, + ] + ) + + questionnaire_response_3 = questionnaire.respond( + [ + {"question1": ["FOO"]}, + {"question1": ["bar"]}, + {"question1": ["foo"]}, + ] + ) + + unique_answers = _get_unique_answers( + questionnaire_responses=[ + questionnaire_response_1, + questionnaire_response_2, + questionnaire_response_3, + ], + question_name="question1", + ) + + assert unique_answers == {"foo", "bar", "FOO", "BAR", "baz"} + + +def test__get_questionnaire_responses(): + questionnaire = Questionnaire(name="foo", version=1) + questionnaire.add_question(name="question1") + questionnaire_response = questionnaire.respond([{"question1": ["foo"]}]) + questionnaire_responses = [questionnaire_response] + assert ( + _get_questionnaire_responses( + questionnaire_responses={questionnaire.id: questionnaire_responses}, + questionnaire_id=questionnaire.id, + ) + == questionnaire_responses + ) + + +def test_device_add_index(device: Device): + questionnaire = Questionnaire(name="foo", version=1) + questionnaire.add_question(name="question1", multiple=True) + + N_QUESTIONNAIRE_RESPONSES = 123 + N_UNIQUE_ANSWERS = 7 + + answers = [["a", "b", "c"], ["d"], ["e", "f", "g"], ["a"], ["b", "c"]] + assert len(set(chain.from_iterable(answers))) == N_UNIQUE_ANSWERS + + for _ in range(N_QUESTIONNAIRE_RESPONSES): + for _answers in answers: + questionnaire_response = questionnaire.respond( + responses=[{"question1": _answers}] + ) + device.add_questionnaire_response( + questionnaire_response=questionnaire_response + ) + + events = device.add_index(questionnaire_id="foo/1", question_name="question1") + assert len(events) == N_UNIQUE_ANSWERS + assert all(isinstance(event, DeviceIndexAddedEvent) for event in events) + + +def test_device_add_index_no_such_questionnaire(device: Device): + with pytest.raises(QuestionnaireNotFoundError): + device.add_index(questionnaire_id="foo/1", question_name="question1") + + +def test_device_add_index_no_such_questionnaire_response(device: Device): + device.questionnaire_responses["foo/1"] = [] + with pytest.raises(QuestionnaireResponseNotFoundError): + device.add_index(questionnaire_id="foo/1", question_name="question1") + + +def test_device_add_index_no_such_question(device: Device): + questionnaire = Questionnaire(name="foo", version=1) + questionnaire_response = questionnaire.respond(responses=[]) + device.add_questionnaire_response(questionnaire_response=questionnaire_response) + + with pytest.raises(QuestionNotFoundError): + device.add_index(questionnaire_id="foo/1", question_name="question1") diff --git a/src/layers/domain/repository/device_repository.py b/src/layers/domain/repository/device_repository.py index 9b06e3f55..0b20a888b 100644 --- a/src/layers/domain/repository/device_repository.py +++ b/src/layers/domain/repository/device_repository.py @@ -6,6 +6,7 @@ from domain.core.device import ( Device, DeviceCreatedEvent, + DeviceIndexAddedEvent, DeviceKey, DeviceKeyAddedEvent, DeviceKeyDeletedEvent, @@ -160,6 +161,25 @@ def handle_QuestionnaireResponseDeletedEvent( ) ) + def handle_DeviceIndexAddedEvent(self, event: DeviceIndexAddedEvent): + pk = TableKeys.DEVICE.key(event.id) + sk = TableKeys.DEVICE_INDEX.key( + event.questionnaire_id, event.question_name, event.value + ) + event_data = asdict(event) + condition_expression = ( + {"ConditionExpression": ConditionExpression.MUST_NOT_EXIST} + if event_data.get("_trust", False) is False + else {} + ) + return TransactItem( + Put=TransactionStatement( + TableName=self.table_name, + Item=marshall(pk=pk, sk=sk, pk_1=sk, sk_1=pk, **asdict(event)), + **condition_expression, + ) + ) + def query_by_key_type(self, key_type, **kwargs) -> "QueryOutputTypeDef": pk_2 = TableKeys.DEVICE_KEY_TYPE.key(key_type) args = { @@ -182,6 +202,19 @@ def query_by_device_type(self, type: DeviceType, **kwargs) -> "QueryOutputTypeDe } return self.client.query(**args, **kwargs) + def read_by_index(self, questionnaire_id: str, question_name: str, value: str): + pk_1 = TableKeys.DEVICE_INDEX.key(questionnaire_id, question_name, value) + result = self.client.query( + TableName=self.table_name, + IndexName="idx_gsi_1", + KeyConditionExpression="pk_1 = :pk_1", + ExpressionAttributeValues={ + ":pk_1": marshall_value(pk_1), + }, + ) + items = (unmarshall(i) for i in result["Items"]) + return [self.read(strip_key_prefix(item["pk"])) for item in items] + def read_by_key(self, key) -> Device: pk_1 = TableKeys.DEVICE_KEY.key(key) args = { diff --git a/src/layers/domain/repository/keys.py b/src/layers/domain/repository/keys.py index 715b0f69e..59f8093be 100644 --- a/src/layers/domain/repository/keys.py +++ b/src/layers/domain/repository/keys.py @@ -9,6 +9,7 @@ class TableKeys(StrEnum): DEVICE_KEY = "DK" DEVICE_TYPE = "DT" DEVICE_KEY_TYPE = "DKT" + DEVICE_INDEX = "DI" PRODUCT_TEAM = "PT" QUESTIONNAIRE = "Q" QUESTIONNAIRE_RESPONSE = "QR" diff --git a/src/layers/domain/repository/tests/device_repository_tests/test_device_repository_indexes.py b/src/layers/domain/repository/tests/device_repository_tests/test_device_repository_indexes.py new file mode 100644 index 000000000..22389dda1 --- /dev/null +++ b/src/layers/domain/repository/tests/device_repository_tests/test_device_repository_indexes.py @@ -0,0 +1,132 @@ +import pytest +from domain.core.device import Device, DeviceType +from domain.core.product_team import ProductTeam +from domain.core.questionnaire import Questionnaire +from domain.core.root import Root +from domain.repository.device_repository import DeviceRepository + + +@pytest.fixture +def product_team(): + org = Root.create_ods_organisation(ods_code="AB123") + return org.create_product_team( + id="6f8c285e-04a2-4194-a84e-dabeba474ff7", name="Team" + ) + + +@pytest.fixture +def shoe_questionnaire() -> Questionnaire: + _questionnaire = Questionnaire(name="shoe", version=1) + _questionnaire.add_question( + name="foot", answer_types=(str,), mandatory=True, choices={"L", "R"} + ) + _questionnaire.add_question(name="shoe-size", answer_types=(int,), mandatory=True) + return _questionnaire + + +@pytest.fixture +def device_right_shoe_size_123( + product_team: ProductTeam, shoe_questionnaire: Questionnaire +) -> Device: + shoe_response = shoe_questionnaire.respond( + responses=[{"foot": ["R"]}, {"shoe-size": [123]}], + ) + device = product_team.create_device(name="Device-1", type=DeviceType.PRODUCT) + device.add_questionnaire_response(questionnaire_response=shoe_response) + device.add_index(questionnaire_id=shoe_questionnaire.id, question_name="foot") + device.add_index(questionnaire_id=shoe_questionnaire.id, question_name="shoe-size") + return device + + +@pytest.fixture +def device_left_shoe_size_123( + product_team: ProductTeam, shoe_questionnaire: Questionnaire +) -> Device: + shoe_response = shoe_questionnaire.respond( + responses=[{"foot": ["L"]}, {"shoe-size": [123]}], + ) + device = product_team.create_device(name="Device-2", type=DeviceType.PRODUCT) + device.add_questionnaire_response(questionnaire_response=shoe_response) + device.add_index(questionnaire_id=shoe_questionnaire.id, question_name="foot") + device.add_index(questionnaire_id=shoe_questionnaire.id, question_name="shoe-size") + return device + + +@pytest.mark.integration +def test__device_repository__query_by_index__find_right_shoes( + device_right_shoe_size_123: Device, + device_left_shoe_size_123: Device, + repository: DeviceRepository, +): + repository.write(device_right_shoe_size_123) + repository.write(device_left_shoe_size_123) + + (device,) = repository.read_by_index( + questionnaire_id="shoe/1", question_name="foot", value="R" + ) + assert device.id == device_right_shoe_size_123.id + + +@pytest.mark.integration +def test__device_repository__query_by_index__find_left_shoes( + device_right_shoe_size_123: Device, + device_left_shoe_size_123: Device, + repository: DeviceRepository, +): + repository.write(device_right_shoe_size_123) + repository.write(device_left_shoe_size_123) + + (device,) = repository.read_by_index( + questionnaire_id="shoe/1", question_name="foot", value="L" + ) + assert device.id == device_left_shoe_size_123.id + + +@pytest.mark.integration +def test__device_repository__query_by_index__find_shoes_by_size_int( + device_right_shoe_size_123: Device, + device_left_shoe_size_123: Device, + repository: DeviceRepository, +): + repository.write(device_right_shoe_size_123) + repository.write(device_left_shoe_size_123) + + (device_1, device_2) = repository.read_by_index( + questionnaire_id="shoe/1", question_name="shoe-size", value=123 + ) + assert {device_1.id, device_2.id} == { + device_left_shoe_size_123.id, + device_right_shoe_size_123.id, + } + + +@pytest.mark.integration +def test__device_repository__query_by_index__find_shoes_by_size_str( + device_right_shoe_size_123: Device, + device_left_shoe_size_123: Device, + repository: DeviceRepository, +): + repository.write(device_right_shoe_size_123) + repository.write(device_left_shoe_size_123) + + (device_1, device_2) = repository.read_by_index( + questionnaire_id="shoe/1", question_name="shoe-size", value="123" + ) + assert {device_1.id, device_2.id} == { + device_left_shoe_size_123.id, + device_right_shoe_size_123.id, + } + + +@pytest.mark.integration +def test__device_repository__query_by_index__empty_result( + device_right_shoe_size_123: Device, + device_left_shoe_size_123: Device, + repository: DeviceRepository, +): + repository.write(device_right_shoe_size_123) + repository.write(device_left_shoe_size_123) + result = repository.read_by_index( + questionnaire_id="shoe/1", question_name="shoe-size", value="345" + ) + assert result == [] diff --git a/src/layers/etl_utils/ldif/model.py b/src/layers/etl_utils/ldif/model.py index 7ea98e222..8294057fb 100644 --- a/src/layers/etl_utils/ldif/model.py +++ b/src/layers/etl_utils/ldif/model.py @@ -16,7 +16,7 @@ class DistinguishedName(BaseModel): @classmethod def parse(cls, raw_distinguished_name: str) -> Self: - unsorted_parts = DISTINGUISHED_NAME_RE.findall(raw_distinguished_name) + unsorted_parts = DISTINGUISHED_NAME_RE.findall(raw_distinguished_name.lower()) if not unsorted_parts: raise BadDistinguishedName(raw_distinguished_name) sorted_parts = sorted(unsorted_parts, key=lambda *args: args) diff --git a/src/layers/etl_utils/ldif/tests/test_ldif.py b/src/layers/etl_utils/ldif/tests/test_ldif.py index f980e4f27..ae8894ecc 100644 --- a/src/layers/etl_utils/ldif/tests/test_ldif.py +++ b/src/layers/etl_utils/ldif/tests/test_ldif.py @@ -109,25 +109,25 @@ [ "foo=FOO", DistinguishedName( - parts=(("foo", "FOO"),), + parts=(("foo", "foo"),), ), ], [ "foo=FOO,bar.baz=BAR.BAZ", DistinguishedName( parts=( - ("bar.baz", "BAR.BAZ"), - ("foo", "FOO"), + ("bar.baz", "bar.baz"), + ("foo", "foo"), ), ), ], ), ) -def test_distinguished_name(raw_distinguished_name, parsed_distinguished_name): +def test_distinguished_name(raw_distinguished_name: str, parsed_distinguished_name): distinguished_name = DistinguishedName.parse(raw_distinguished_name) assert distinguished_name == parsed_distinguished_name assert sorted(distinguished_name.raw.split(",")) == sorted( - raw_distinguished_name.split(",") + raw_distinguished_name.lower().split(",") ) diff --git a/src/layers/sds/cpm_translation/__init__.py b/src/layers/sds/cpm_translation/__init__.py index 096576e97..eb85869d4 100644 --- a/src/layers/sds/cpm_translation/__init__.py +++ b/src/layers/sds/cpm_translation/__init__.py @@ -4,25 +4,17 @@ from domain.core.device import Device from domain.core.questionnaire import Questionnaire from domain.repository.device_repository import DeviceRepository -from sds.domain.constants import ChangeType from sds.domain.nhs_accredited_system import NhsAccreditedSystem from sds.domain.nhs_mhs import NhsMhs +from sds.domain.parse import UnknownSdsModel +from sds.domain.sds_deletion_request import SdsDeletionRequest from .translations import ( - delete_accredited_system_devices, - delete_mhs_device, - modify_accredited_system_devices, - modify_mhs_device, + delete_devices, nhs_accredited_system_to_cpm_devices, nhs_mhs_to_cpm_device, ) -ACCREDITED_SYSTEM_TRANSLATIONS = { - ChangeType.ADD: nhs_accredited_system_to_cpm_devices, - ChangeType.MODIFY: modify_accredited_system_devices, - ChangeType.DELETE: delete_accredited_system_devices, -} - BAD_UNIQUE_IDENTIFIERS = { "31af51067f47f1244d38", # pragma: allowlist secret "a83e1431f26461894465", # pragma: allowlist secret @@ -39,22 +31,6 @@ def update_in_list_of_dict(obj: list[dict[str, str]], key, value): obj.append({key: value}) -MESSAGE_HANDLING_SYSTEM_TRANSLATIONS = { - ChangeType.ADD: lambda **kwargs: [nhs_mhs_to_cpm_device(**kwargs)], - ChangeType.MODIFY: lambda **kwargs: [modify_mhs_device(**kwargs)], - ChangeType.DELETE: lambda **kwargs: [delete_mhs_device(**kwargs)], -} - - -def _parse_object_class(object_class: str) -> str: - for _object_class in (NhsMhs.OBJECT_CLASS, NhsAccreditedSystem.OBJECT_CLASS): - if object_class.lower() == _object_class: - return _object_class - raise NotImplementedError( - f"No method implemented that translates objects of type '{_object_class}'" - ) - - def translate( obj: dict[str, str], spine_device_questionnaire: Questionnaire, @@ -67,28 +43,38 @@ def translate( if obj.get("unique_identifier") in BAD_UNIQUE_IDENTIFIERS: return [] - object_class = _parse_object_class(obj["object_class"]) - + object_class = obj["object_class"].lower() if object_class == NhsAccreditedSystem.OBJECT_CLASS: - instance = NhsAccreditedSystem.construct(**obj) - translate_kwargs = dict( - nhs_accredited_system=instance, + nhs_accredited_system = NhsAccreditedSystem.construct(**obj) + devices = nhs_accredited_system_to_cpm_devices( + nhs_accredited_system=nhs_accredited_system, questionnaire=spine_device_questionnaire, _questionnaire=_spine_device_questionnaire, _trust=_trust, + ) + elif object_class == NhsMhs.OBJECT_CLASS: + nhs_mhs = NhsMhs.construct(**obj) + devices = [ + nhs_mhs_to_cpm_device( + nhs_mhs=nhs_mhs, + questionnaire=spine_endpoint_questionnaire, + _questionnaire=_spine_endpoint_questionnaire, + _trust=_trust, + ) + ] + elif object_class == SdsDeletionRequest.OBJECT_CLASS: + deletion_request = SdsDeletionRequest.construct(**obj) + devices = delete_devices( + deletion_request=deletion_request, + questionnaire_ids=[ + spine_endpoint_questionnaire.id, + spine_device_questionnaire.id, + ], repository=repository, ) - translations = ACCREDITED_SYSTEM_TRANSLATIONS else: - instance = NhsMhs.construct(**obj) - translate_kwargs = dict( - nhs_mhs=instance, - questionnaire=spine_endpoint_questionnaire, - _questionnaire=_spine_endpoint_questionnaire, - _trust=_trust, - repository=repository, + raise UnknownSdsModel( + f"No translation available for models with object class '{object_class}'" ) - translations = MESSAGE_HANDLING_SYSTEM_TRANSLATIONS - devices = translations[instance.change_type](**translate_kwargs) return list(chain.from_iterable(map(Device.export_events, devices))) diff --git a/src/layers/sds/cpm_translation/tests/test_cpm_translation.py b/src/layers/sds/cpm_translation/tests/test_cpm_translation.py index 4c222fbb4..516694332 100644 --- a/src/layers/sds/cpm_translation/tests/test_cpm_translation.py +++ b/src/layers/sds/cpm_translation/tests/test_cpm_translation.py @@ -1,11 +1,17 @@ from itertools import chain from string import ascii_letters, digits +from typing import Generator import pytest +from domain.core.device import DeviceStatus, DeviceType, DeviceUpdatedEvent from domain.core.load_questionnaire import render_questionnaire +from domain.core.questionnaire import Questionnaire from domain.core.questionnaires import QuestionnaireInstance +from domain.core.root import Root from domain.core.validation import ODS_CODE_REGEX, SdsId +from domain.repository.device_repository import DeviceRepository from etl_utils.ldif.model import DistinguishedName +from event.aws.client import dynamodb_client from event.json import json_load from hypothesis import given, settings from hypothesis.provisional import urls @@ -15,8 +21,15 @@ nhs_mhs_to_cpm_device, translate, ) +from sds.cpm_translation.translations import delete_devices from sds.domain.nhs_accredited_system import NhsAccreditedSystem from sds.domain.nhs_mhs import NhsMhs +from sds.domain.sds_deletion_request import SdsDeletionRequest + +from test_helpers.dynamodb import mock_table +from test_helpers.terraform import read_terraform_output + +TABLE_NAME = "my_table" DUMMY_DISTINGUISHED_NAME = DistinguishedName( parts=(("ou", "services"), ("uniqueidentifier", "foobar"), ("o", "nhs")) @@ -26,6 +39,7 @@ "device_key_added_event", "questionnaire_instance_event", "questionnaire_response_added_event", + "device_index_added_event", ] @@ -60,7 +74,20 @@ ) -@settings(deadline=1000) +@pytest.fixture +def repository( + request: pytest.FixtureRequest, +) -> Generator[DeviceRepository, None, None]: + if request.node.get_closest_marker("integration"): + table_name = read_terraform_output("dynamodb_table_name.value") + client = dynamodb_client() + yield DeviceRepository(table_name=table_name, dynamodb_client=client) + else: + with mock_table(TABLE_NAME) as client: + yield DeviceRepository(table_name=TABLE_NAME, dynamodb_client=client) + + +@settings(deadline=1500) @given(nhs_mhs=NHS_MHS_STRATEGY) def test_nhs_mhs_to_cpm_device(nhs_mhs: NhsMhs): questionnaire = render_questionnaire( @@ -76,7 +103,7 @@ def test_nhs_mhs_to_cpm_device(nhs_mhs: NhsMhs): assert event_names == EXPECTED_EVENTS -@settings(deadline=1000) +@settings(deadline=1500) @given(nhs_accredited_system=NHS_ACCREDITED_SYSTEM_STRATEGY) def test_nhs_accredited_system_to_cpm_devices( nhs_accredited_system: NhsAccreditedSystem, @@ -99,7 +126,7 @@ def test_nhs_accredited_system_to_cpm_devices( @pytest.mark.s3("sds/etl/bulk/1701246-fix-18032023.json") -def test_translate(test_data_paths): +def test_translate_bulk(test_data_paths): (path,) = test_data_paths with open(path) as f: data = json_load(f) @@ -131,3 +158,125 @@ def test_translate(test_data_paths): ods_codes = obj.get("nhs_as_client") or ["dummy_org"] n_ods_codes = len(ods_codes) assert event_names == EXPECTED_EVENTS * n_ods_codes, obj + + +@pytest.mark.integration +def test_delete_devices(repository: DeviceRepository): + # Set initial state + org = Root.create_ods_organisation(ods_code="AB123") + product_team = org.create_product_team( + id="6f8c285e-04a2-4194-a84e-dabeba474ff7", name="Team" + ) + _questionnaire = Questionnaire(name="my_questionnaire", version=1) + _questionnaire.add_question( + name="unique_identifier", answer_types=(str,), mandatory=True + ) + _questionnaire_response = _questionnaire.respond( + responses=[{"unique_identifier": ["001"]}] + ) + _device_1 = product_team.create_device(name="Device-1", type=DeviceType.PRODUCT) + _device_1.add_questionnaire_response(questionnaire_response=_questionnaire_response) + _device_1.add_index( + questionnaire_id=_questionnaire.id, question_name="unique_identifier" + ) + repository.write(_device_1) + + _device_2 = product_team.create_device(name="Device-2", type=DeviceType.PRODUCT) + _device_2.add_questionnaire_response(questionnaire_response=_questionnaire_response) + _device_2.add_index( + questionnaire_id=_questionnaire.id, question_name="unique_identifier" + ) + repository.write(_device_2) + + deletion_request = SdsDeletionRequest( + _distinguished_name=DistinguishedName( + parts=(("ou", "services"), ("uniqueidentifier", "001"), ("o", "nhs")) + ), + object_class="delete", + unique_identifier="001", + ) + + devices = delete_devices( + deletion_request=deletion_request, + questionnaire_ids=[_questionnaire.id], + repository=repository, + ) + (event_1, event_2) = sorted( + chain.from_iterable(device.events for device in devices), + key=lambda event: event.name, + ) + assert event_1 == DeviceUpdatedEvent( + id=_device_1.id, + name=_device_1.name, + type=_device_1.type, + product_team_id=_device_1.product_team_id, + ods_code=_device_1.ods_code, + status=DeviceStatus.INACTIVE, + ) + assert event_2 == DeviceUpdatedEvent( + id=_device_2.id, + name=_device_2.name, + type=_device_2.type, + product_team_id=_device_2.product_team_id, + ods_code=_device_2.ods_code, + status=DeviceStatus.INACTIVE, + ) + + +@pytest.mark.integration +def test_delete_devices_no_questionnaire(repository: DeviceRepository): + deletion_request = SdsDeletionRequest( + _distinguished_name=DistinguishedName( + parts=(("ou", "services"), ("uniqueidentifier", "001"), ("o", "nhs")) + ), + object_class="delete", + unique_identifier="001", + ) + + devices = delete_devices( + deletion_request=deletion_request, + questionnaire_ids=["does not exist"], + repository=repository, + ) + assert len(devices) == 0 + + +@pytest.mark.integration +def test_delete_devices_no_matching_device(repository: DeviceRepository): + # Set initial state + org = Root.create_ods_organisation(ods_code="AB123") + product_team = org.create_product_team( + id="6f8c285e-04a2-4194-a84e-dabeba474ff7", name="Team" + ) + _questionnaire = Questionnaire(name="my_questionnaire", version=1) + _questionnaire.add_question( + name="unique_identifier", answer_types=(str,), mandatory=True + ) + _questionnaire_response = _questionnaire.respond( + responses=[{"unique_identifier": ["001"]}] + ) + _device = product_team.create_device(name="Device-1", type=DeviceType.PRODUCT) + _device.add_questionnaire_response(questionnaire_response=_questionnaire_response) + _device.add_index( + questionnaire_id=_questionnaire.id, question_name="unique_identifier" + ) + repository.write(_device) + + deletion_request = SdsDeletionRequest( + _distinguished_name=DistinguishedName( + parts=( + ("ou", "services"), + ("uniqueidentifier", "does not exist"), + ("o", "nhs"), + ) + ), + object_class="delete", + unique_identifier="does not exist", + ) + + devices = delete_devices( + deletion_request=deletion_request, + questionnaire_ids=[_questionnaire.id], + repository=repository, + ) + assert len(devices) == 0 diff --git a/src/layers/sds/cpm_translation/translations.py b/src/layers/sds/cpm_translation/translations.py index 8e3592b11..b3f8e1db1 100644 --- a/src/layers/sds/cpm_translation/translations.py +++ b/src/layers/sds/cpm_translation/translations.py @@ -10,6 +10,7 @@ from domain.repository.device_repository import DeviceRepository from sds.domain.nhs_accredited_system import NhsAccreditedSystem from sds.domain.nhs_mhs import NhsMhs +from sds.domain.sds_deletion_request import SdsDeletionRequest DEFAULT_PRODUCT_TEAM = { "id": UUID(int=0x12345678123456781234567812345678), @@ -71,7 +72,6 @@ def nhs_accredited_system_to_cpm_devices( questionnaire: Questionnaire, _questionnaire: dict, _trust: bool = False, - **extra ) -> Generator[Device, None, None]: unique_identifier = nhs_accredited_system.unique_identifier product_name = nhs_accredited_system.nhs_product_name or unique_identifier @@ -104,11 +104,14 @@ def nhs_accredited_system_to_cpm_devices( _questionnaire=_questionnaire, _trust=True, ) + _device.add_index( + questionnaire_id=questionnaire.id, question_name="unique_identifier" + ) yield _device def modify_accredited_system_devices( - nhs_accredited_system: NhsAccreditedSystem, repository: DeviceRepository, **extra + nhs_accredited_system: NhsAccreditedSystem, repository: DeviceRepository ) -> Generator[Device, None, None]: for ( _, @@ -119,24 +122,11 @@ def modify_accredited_system_devices( yield device -def delete_accredited_system_devices( - nhs_accredited_system: NhsAccreditedSystem, repository: DeviceRepository, **extra -): - for ( - _, - accredited_system_id, - ) in accredited_system_ids(nhs_accredited_system): - device = repository.read_by_key(key=accredited_system_id) - device.delete() - yield device - - def nhs_mhs_to_cpm_device( nhs_mhs: NhsMhs, questionnaire: Questionnaire, _questionnaire: dict, _trust: bool = False, - **extra ) -> Device: ods_code = nhs_mhs.nhs_id_code _scoped_party_key = scoped_party_key(nhs_mhs) @@ -160,16 +150,30 @@ def nhs_mhs_to_cpm_device( _questionnaire=_questionnaire, _trust=True, ) + device.add_index( + questionnaire_id=questionnaire.id, question_name="unique_identifier" + ) return device -def modify_mhs_device(nhs_mhs: NhsMhs, repository: DeviceRepository, **extra): +def modify_mhs_device(nhs_mhs: NhsMhs, repository: DeviceRepository): device = repository.read_by_key(key=scoped_party_key(nhs_mhs)) device.update(something="foo") return device -def delete_mhs_device(nhs_mhs: NhsMhs, repository: DeviceRepository, **extra): - device = repository.read_by_key(key=scoped_party_key(nhs_mhs)) - device.delete() - return device +def delete_devices( + deletion_request: SdsDeletionRequest, + questionnaire_ids: list[str], + repository: DeviceRepository, +) -> list[Device]: + devices = [] + for questionnaire_id in questionnaire_ids: + for _device in repository.read_by_index( + questionnaire_id=questionnaire_id, + question_name="unique_identifier", + value=deletion_request.unique_identifier, + ): + _device.delete() + devices.append(_device) + return devices diff --git a/src/layers/sds/domain/changelog.py b/src/layers/sds/domain/changelog.py index e59be383e..a6da9b67b 100644 --- a/src/layers/sds/domain/changelog.py +++ b/src/layers/sds/domain/changelog.py @@ -50,13 +50,22 @@ def validate_change_number( def validate_changes(cls, changes): if isinstance(changes, bytes): changes = changes.decode("unicode_escape") - return changes + return changes.strip("\n") def changes_as_ldif(self) -> str: - change_lines = [ + distinguished_name = dict(self.target_distinguished_name.parts) + unique_identifier_line = ( + f"uniqueidentifier: {distinguished_name['uniqueidentifier']}" + ) + header_lines = [ f"dn: {self.target_distinguished_name.raw}", f"changetype: {self.change_type}", ] - if self.change_type is not ChangeType.DELETE: - change_lines.append(self.changes.strip("\n")) + if self.change_type is not ChangeType.ADD: + header_lines.append(f"{OBJECT_CLASS_FIELD_NAME}: {self.change_type}") + + change_lines = header_lines + list(filter(bool, self.changes.split("\n"))) + if unique_identifier_line not in map(str.lower, change_lines): + change_lines.append(unique_identifier_line) + return "\n".join(change_lines) diff --git a/src/layers/sds/domain/constants.py b/src/layers/sds/domain/constants.py index e4f2f2f8e..7fa48930c 100644 --- a/src/layers/sds/domain/constants.py +++ b/src/layers/sds/domain/constants.py @@ -1,6 +1,10 @@ from enum import StrEnum, auto -FILTER_TERMS = [("objectClass", "nhsMHS"), ("objectClass", "nhsAS")] +FILTER_TERMS = [ + ("objectClass", "nhsMHS"), + ("objectClass", "nhsAS"), + ("objectClass", "delete"), +] class CaseInsensitiveEnum(StrEnum): diff --git a/src/layers/sds/domain/parse.py b/src/layers/sds/domain/parse.py index f65928757..cb0b4c40b 100644 --- a/src/layers/sds/domain/parse.py +++ b/src/layers/sds/domain/parse.py @@ -1,5 +1,6 @@ from etl_utils.ldif.model import DistinguishedName from sds.domain.nhs_mhs_cp import NhsMhsCp +from sds.domain.sds_deletion_request import SdsDeletionRequest from .base import OBJECT_CLASS_FIELD_NAME, SdsBaseModel from .nhs_accredited_system import NhsAccreditedSystem @@ -20,6 +21,7 @@ class UnknownSdsModel(Exception): NhsMhsService, NhsMhs, NhsMhsCp, + SdsDeletionRequest, ) EMPTY_SET = set() diff --git a/src/layers/sds/domain/sds_deletion_request.py b/src/layers/sds/domain/sds_deletion_request.py new file mode 100644 index 000000000..969b03991 --- /dev/null +++ b/src/layers/sds/domain/sds_deletion_request.py @@ -0,0 +1,12 @@ +from typing import ClassVar, Literal + +from pydantic import Field +from sds.domain.base import OBJECT_CLASS_FIELD_NAME, SdsBaseModel +from sds.domain.organizational_unit import OrganizationalUnitDistinguishedName + + +class SdsDeletionRequest(SdsBaseModel): + distinguished_name: OrganizationalUnitDistinguishedName = Field(exclude=True) + OBJECT_CLASS: ClassVar[Literal["delete"]] = "delete" + object_class: str = Field(alias=OBJECT_CLASS_FIELD_NAME) + unique_identifier: str = Field(alias="uniqueidentifier") diff --git a/src/layers/sds/domain/tests/test_sds_changelog_model.py b/src/layers/sds/domain/tests/test_sds_changelog_model.py index 911e5fd73..4ad502c82 100644 --- a/src/layers/sds/domain/tests/test_sds_changelog_model.py +++ b/src/layers/sds/domain/tests/test_sds_changelog_model.py @@ -36,7 +36,7 @@ def test_changelog_model_against_changelog_data(test_data_paths): assert changelog_record.change_time == "20240116173441Z" assert changelog_record.change_type == "add" assert changelog_record.target_distinguished_name == DistinguishedName( - parts=(("o", "nhs"), ("ou", "Services"), ("uniqueIdentifier", "200000042019")) + parts=(("o", "nhs"), ("ou", "services"), ("uniqueidentifier", "200000042019")) ) @@ -56,14 +56,11 @@ def test_changelog_changes_are_valid_ldif(test_data_paths): **record, ) - # HACK THE RECORD - FOR SOME REASON DOESN'T START WITH DN LINE? - changelog_record.changes = ( - "dn: uniqueidentifier=200000042019,ou=services,o=nhs" + changelog_record.changes - ) - # Check that the change itself is valid LDIF nested_ldif_lines = list( - parse_ldif(file_opener=StringIO, path_or_data=changelog_record.changes) + parse_ldif( + file_opener=StringIO, path_or_data=changelog_record.changes_as_ldif() + ) ) assert len(nested_ldif_lines) == 1