From 5fafe3d27f9d63e194a0e7ff80ae759750acca4b Mon Sep 17 00:00:00 2001 From: Joel Klinger Date: Mon, 20 May 2024 13:58:51 +0100 Subject: [PATCH] [feature/PI-346-changelog_modify] modify transform and load --- src/layers/sds/cpm_translation/__init__.py | 21 ++- src/layers/sds/cpm_translation/constants.py | 30 +++ .../cpm_translation/modify/modify_device.py | 75 ++++++++ .../sds/cpm_translation/modify/modify_key.py | 177 ++++++++++++++++++ .../modify/tests/test_modify_key.py | 87 +++++++++ .../sds/cpm_translation/modify/utils.py | 23 +++ .../sds/cpm_translation/translations.py | 126 +++++++------ src/layers/sds/domain/base.py | 48 ++++- src/layers/sds/domain/changelog.py | 5 +- src/layers/sds/domain/parse.py | 2 + .../sds/domain/sds_modification_request.py | 33 ++++ 11 files changed, 561 insertions(+), 66 deletions(-) create mode 100644 src/layers/sds/cpm_translation/constants.py create mode 100644 src/layers/sds/cpm_translation/modify/modify_device.py create mode 100644 src/layers/sds/cpm_translation/modify/modify_key.py create mode 100644 src/layers/sds/cpm_translation/modify/tests/test_modify_key.py create mode 100644 src/layers/sds/cpm_translation/modify/utils.py create mode 100644 src/layers/sds/domain/sds_modification_request.py diff --git a/src/layers/sds/cpm_translation/__init__.py b/src/layers/sds/cpm_translation/__init__.py index eb85869d4..8b338d82a 100644 --- a/src/layers/sds/cpm_translation/__init__.py +++ b/src/layers/sds/cpm_translation/__init__.py @@ -8,20 +8,16 @@ from sds.domain.nhs_mhs import NhsMhs from sds.domain.parse import UnknownSdsModel from sds.domain.sds_deletion_request import SdsDeletionRequest +from sds.domain.sds_modification_request import SdsModificationRequest +from .constants import BAD_UNIQUE_IDENTIFIERS from .translations import ( delete_devices, + modify_devices, nhs_accredited_system_to_cpm_devices, nhs_mhs_to_cpm_device, ) -BAD_UNIQUE_IDENTIFIERS = { - "31af51067f47f1244d38", # pragma: allowlist secret - "a83e1431f26461894465", # pragma: allowlist secret - "S2202584A2577603", - "S100049A300185", -} - def update_in_list_of_dict(obj: list[dict[str, str]], key, value): for item in obj: @@ -72,9 +68,18 @@ def translate( ], repository=repository, ) + elif object_class == SdsModificationRequest.OBJECT_CLASS: + modification_request = SdsModificationRequest.construct(**obj) + devices = modify_devices( + modification_request=modification_request, + questionnaire_ids=[ + spine_endpoint_questionnaire.id, + spine_device_questionnaire.id, + ], + repository=repository, + ) else: raise UnknownSdsModel( f"No translation available for models with object class '{object_class}'" ) - return list(chain.from_iterable(map(Device.export_events, devices))) diff --git a/src/layers/sds/cpm_translation/constants.py b/src/layers/sds/cpm_translation/constants.py new file mode 100644 index 000000000..76e4c6c0b --- /dev/null +++ b/src/layers/sds/cpm_translation/constants.py @@ -0,0 +1,30 @@ +from uuid import UUID + +DEFAULT_ORGANISATION = "CDEF" + +DEFAULT_PRODUCT_TEAM = { + "id": UUID(int=0x12345678123456781234567812345678), + "name": "ROOT", +} + +EXCEPTIONAL_ODS_CODES = { + "696B001", + "TESTEBS1", + "TESTLSP0", + "TESTLSP1", + "TESTLSP3", + "TMSAsync1", + "TMSAsync2", + "TMSAsync3", + "TMSAsync4", + "TMSAsync5", + "TMSAsync6", + "TMSEbs2", +} + +BAD_UNIQUE_IDENTIFIERS = { + "31af51067f47f1244d38", # pragma: allowlist secret + "a83e1431f26461894465", # pragma: allowlist secret + "S2202584A2577603", + "S100049A300185", +} diff --git a/src/layers/sds/cpm_translation/modify/modify_device.py b/src/layers/sds/cpm_translation/modify/modify_device.py new file mode 100644 index 000000000..2416ed354 --- /dev/null +++ b/src/layers/sds/cpm_translation/modify/modify_device.py @@ -0,0 +1,75 @@ +from domain.core.device import Device +from pydantic import ValidationError +from sds.domain.constants import ModificationType +from sds.domain.nhs_accredited_system import NhsAccreditedSystem +from sds.domain.nhs_mhs import NhsMhs + +from .utils import InvalidModificationRequest, new_questionnaire_response_from_template + + +class MandatoryFieldError(Exception): + def __init__(self, field): + super().__init__(f"Field '{field}' cannot be null") + + +class NoValuesToRemove(Exception): + pass + + +def _unique_list(*items): + return list(set(items)) + + +def _parse_and_validate_field( + model: type[NhsAccreditedSystem] | type[NhsMhs], field: str, value +) -> list: + try: + parsed_value = model.parse_and_validate_field(field=field, value=value) + except ValidationError: + raise InvalidModificationRequest(field) + + if isinstance(parsed_value, (set, list)): + return list(parsed_value) + else: + return [parsed_value] + + +def update_device_metadata( + device: Device, + model: type[NhsAccreditedSystem] | type[NhsMhs], + modification_type: ModificationType, + field_alias: str, + new_values: list, +) -> Device: + field_name = model.get_field_name_for_alias(alias=field_alias) + ((questionnaire_response,),) = device.questionnaire_responses.values() + _current_values = questionnaire_response.get_response(question_name=field_name) + + if modification_type == ModificationType.ADD: + _updated_values = _unique_list(*_current_values, *new_values) + parsed_values = _parse_and_validate_field( + model=model, field=field_name, value=_updated_values + ) + elif modification_type == ModificationType.REPLACE: + parsed_values = _parse_and_validate_field( + model=model, field=field_name, value=new_values + ) + elif modification_type == ModificationType.DELETE: + if model.is_mandatory_field(field_name): + raise MandatoryFieldError(field_name) + if not _current_values: + raise InvalidModificationRequest( + field_name, "This device has no such data to delete" + ) + parsed_values = [] + + new_questionnaire_response = new_questionnaire_response_from_template( + questionnaire_response=questionnaire_response, + field_to_update=field_name, + new_values=parsed_values, + ) + device.update_questionnaire_response( + questionnaire_response_index=0, + questionnaire_response=new_questionnaire_response, + ) + return device diff --git a/src/layers/sds/cpm_translation/modify/modify_key.py b/src/layers/sds/cpm_translation/modify/modify_key.py new file mode 100644 index 000000000..549418d13 --- /dev/null +++ b/src/layers/sds/cpm_translation/modify/modify_key.py @@ -0,0 +1,177 @@ +from typing import Callable, Generator + +from domain.core.device import Device +from domain.core.device_key import DeviceKeyType +from domain.core.product_team import ProductTeam +from domain.core.validation import DEVICE_KEY_SEPARATOR +from sds.cpm_translation.utils import get_in_list_of_dict +from sds.domain.constants import ModificationType +from sds.domain.nhs_accredited_system import NhsAccreditedSystem +from sds.domain.nhs_mhs import NhsMhs + +from ..constants import DEFAULT_PRODUCT_TEAM +from .utils import InvalidModificationRequest, new_questionnaire_response_from_template + + +class NotAnSdsKey(Exception): + pass + + +class AccreditedSystemAlreadyExists(Exception): + def __init__(self, ods_code): + super().__init__(f"Accredited System with ODS code '{ods_code}' already exists") + + +MHS_KEY_FIELDS = ["nhs_id_code", "nhs_mhs_party_key", "nhs_mhs_svc_ia"] + + +def get_modify_key_function( + model: type[NhsMhs] | type[NhsAccreditedSystem], + modification_type: ModificationType, + field_name: str, +) -> Callable[[list[Device], str, any], Generator[Device, None, None]]: + match (model, modification_type, field_name): + case ( + NhsAccreditedSystem, + ModificationType.ADD, + "nhs_as_client", + ): + return new_accredited_system + case ( + NhsAccreditedSystem, + ModificationType.REPLACE, + "nhs_as_client", + ): + return replace_accredited_systems + case ( + NhsAccreditedSystem, + ModificationType.DELETE, + "nhs_as_client", + ): + raise InvalidModificationRequest(field_name) + case ( + NhsMhs, + ModificationType.ADD, + "nhs_mhs_party_key" | "nhs_mhs_svc_ia" | "nhs_id_code", + ): + raise InvalidModificationRequest(field_name) + case ( + NhsMhs, + ModificationType.REPLACE, + "nhs_mhs_party_key" | "nhs_mhs_svc_ia" | "nhs_id_code", + ): + return replace_msg_handling_system + case ( + NhsMhs, + ModificationType.DELETE, + "nhs_mhs_party_key" | "nhs_mhs_svc_ia" | "nhs_id_code", + ): + raise InvalidModificationRequest(field_name) + case _: + raise NotAnSdsKey + + +def new_accredited_system( + devices: list[Device], field_name: str, value: str +) -> Generator[Device, None, None]: + (ods_code,) = NhsAccreditedSystem.parse_and_validate_field( + field=field_name, value=value + ) + + current_ods_codes = {device.ods_code for device in devices} + if ods_code in current_ods_codes: + raise AccreditedSystemAlreadyExists(ods_code) + + device = devices[0] + ( + (questionnaire_id, (questionnaire_response,)), + ) = device.questionnaire_responses.items() + new_questionnaire_response = new_questionnaire_response_from_template( + questionnaire_response=questionnaire_response, + field_to_update=field_name, + new_values=[ods_code], + ) + unique_identifier = device.indexes[(questionnaire_id, "unique_identifier")] + new_accredited_system_id = DEVICE_KEY_SEPARATOR.join((ods_code, unique_identifier)) + + product_team = ProductTeam( + id=device.product_team_id, ods_code=ods_code, name=DEFAULT_PRODUCT_TEAM["name"] + ) + new_device = product_team.create_device(name=device.name, type=device.type) + new_device.add_questionnaire_response( + questionnaire_response=new_questionnaire_response + ) + new_device.add_key( + type=DeviceKeyType.ACCREDITED_SYSTEM_ID, key=new_accredited_system_id + ) + new_device.add_index( + questionnaire_id=questionnaire_id, question_name="unique_identifier" + ) + yield new_device + + +def replace_accredited_systems( + devices: list[Device], field_name: str, value: str +) -> Generator[Device, None, None]: + current_ods_codes = {device.ods_code for device in devices} + final_ods_codes = NhsAccreditedSystem.parse_and_validate_field( + field=field_name, value=value + ) + removed_ods_codes = current_ods_codes - final_ods_codes + for device in devices: + if device.ods_code in removed_ods_codes: + device.delete() + yield device + + for new_ods_code in final_ods_codes - current_ods_codes: + yield from new_accredited_system( + devices=devices, field_name=field_name, value=[new_ods_code] + ) + + +def _get_msg_handling_system_scoped_key_parts( + responses: list[dict], +) -> Generator[str, None, None]: + for key_field in MHS_KEY_FIELDS: + (_value,) = get_in_list_of_dict(obj=responses, key=key_field) + yield _value.strip() + + +def replace_msg_handling_system( + devices: list[Device], field_name: str, value: str +) -> Generator[Device, None, None]: + (device,) = devices + device.delete() + yield device + + ( + (questionnaire_id, (_questionnaire_response,)), + ) = device.questionnaire_responses.items() + new_value = NhsMhs.parse_and_validate_field(field=field_name, value=value) + new_questionnaire_response = new_questionnaire_response_from_template( + questionnaire_response=_questionnaire_response, + field_to_update=field_name, + new_values=[new_value], + ) + key_parts = _get_msg_handling_system_scoped_key_parts( + responses=_questionnaire_response.responses + ) + new_scoped_party_key = DEVICE_KEY_SEPARATOR.join(key_parts) + (ods_code,) = get_in_list_of_dict( + obj=_questionnaire_response.responses, key="nhs_id_code" + ) + + product_team = ProductTeam( + id=device.product_team_id, ods_code=ods_code, name=DEFAULT_PRODUCT_TEAM["name"] + ) + new_device = product_team.create_device(name=device.name, type=device.type) + new_device.add_questionnaire_response( + questionnaire_response=new_questionnaire_response + ) + new_device.add_key( + type=DeviceKeyType.MESSAGE_HANDLING_SYSTEM_ID, key=new_scoped_party_key + ) + new_device.add_index( + questionnaire_id=questionnaire_id, question_name="unique_identifier" + ) + yield new_device diff --git a/src/layers/sds/cpm_translation/modify/tests/test_modify_key.py b/src/layers/sds/cpm_translation/modify/tests/test_modify_key.py new file mode 100644 index 000000000..fd8de5849 --- /dev/null +++ b/src/layers/sds/cpm_translation/modify/tests/test_modify_key.py @@ -0,0 +1,87 @@ +import pytest +from sds.cpm_translation.modify.modify_key import ( + InvalidModificationRequest, + NotAnSdsKey, + get_modify_key_function, + new_accredited_system, + replace_accredited_systems, + replace_msg_handling_system, +) +from sds.domain.constants import ModificationType +from sds.domain.nhs_accredited_system import NhsAccreditedSystem +from sds.domain.nhs_mhs import NhsMhs + + +@pytest.mark.parametrize( + ("model", "modification_type", "field", "result"), + [ + ( + NhsAccreditedSystem, + ModificationType.ADD, + "nhs_as_client", + new_accredited_system, + ), + ( + NhsAccreditedSystem, + ModificationType.REPLACE, + "nhs_as_client", + replace_accredited_systems, + ), + ( + NhsMhs, + ModificationType.REPLACE, + "nhs_mhs_party_key", + replace_msg_handling_system, + ), + ( + NhsMhs, + ModificationType.REPLACE, + "nhs_mhs_svc_ia", + replace_msg_handling_system, + ), + ( + NhsMhs, + ModificationType.REPLACE, + "nhs_id_code", + replace_msg_handling_system, + ), + ], +) +def test_get_modify_key_function(model, modification_type, field, result): + assert ( + get_modify_key_function( + model=model, modification_type=modification_type, field=field + ) + is result + ) + + +@pytest.mark.parametrize( + ("model", "modification_type", "field"), + [ + (NhsAccreditedSystem, ModificationType.DELETE, "nhs_as_client"), + (NhsMhs, ModificationType.ADD, "nhs_mhs_party_key"), + (NhsMhs, ModificationType.ADD, "nhs_mhs_svc_ia"), + (NhsMhs, ModificationType.ADD, "nhs_id_code"), + (NhsMhs, ModificationType.DELETE, "nhs_mhs_party_key"), + (NhsMhs, ModificationType.DELETE, "nhs_mhs_svc_ia"), + (NhsMhs, ModificationType.DELETE, "nhs_id_code"), + ], +) +def test_get_modify_key_function_invalid(model, modification_type, field): + with pytest.raises(InvalidModificationRequest): + get_modify_key_function( + model=model, modification_type=modification_type, field=field + ) + + +@pytest.mark.parametrize("model", (NhsAccreditedSystem, NhsMhs)) +@pytest.mark.parametrize( + "modification_type", + (ModificationType.ADD, ModificationType.REPLACE, ModificationType.DELETE), +) +def test_get_modify_key_function_other(model, modification_type): + with pytest.raises(NotAnSdsKey): + get_modify_key_function( + model=model, modification_type=modification_type, field="foo" + ) diff --git a/src/layers/sds/cpm_translation/modify/utils.py b/src/layers/sds/cpm_translation/modify/utils.py new file mode 100644 index 000000000..1c5b7f635 --- /dev/null +++ b/src/layers/sds/cpm_translation/modify/utils.py @@ -0,0 +1,23 @@ +from domain.core.questionnaire import QuestionnaireResponse + +from ..utils import update_in_list_of_dict + + +class InvalidModificationRequest(Exception): + def __init__(self, field, extra_info=""): + msg = ". ".join( + filter( + bool, (f"Cannot modify field '{field}' with this operation", extra_info) + ) + ) + super().__init__(msg) + + +def new_questionnaire_response_from_template( + questionnaire_response: QuestionnaireResponse, field_to_update: str, new_values +) -> QuestionnaireResponse: + update_in_list_of_dict( + obj=questionnaire_response.responses, key=field_to_update, value=new_values + ) + non_empty_responses = list(filter(bool, questionnaire_response.responses)) + return questionnaire_response.questionnaire.respond(non_empty_responses) diff --git a/src/layers/sds/cpm_translation/translations.py b/src/layers/sds/cpm_translation/translations.py index b3f8e1db1..e24f26e2c 100644 --- a/src/layers/sds/cpm_translation/translations.py +++ b/src/layers/sds/cpm_translation/translations.py @@ -1,5 +1,6 @@ +from functools import partial +from itertools import filterfalse from typing import Generator -from uuid import UUID from domain.core.device import Device, DeviceType from domain.core.device_key import DeviceKeyType @@ -11,27 +12,12 @@ from sds.domain.nhs_accredited_system import NhsAccreditedSystem from sds.domain.nhs_mhs import NhsMhs from sds.domain.sds_deletion_request import SdsDeletionRequest +from sds.domain.sds_modification_request import SdsModificationRequest -DEFAULT_PRODUCT_TEAM = { - "id": UUID(int=0x12345678123456781234567812345678), - "name": "ROOT", -} -EXCEPTIONAL_ODS_CODES = { - "696B001", - "TESTEBS1", - "TESTLSP0", - "TESTLSP1", - "TESTLSP3", - "TMSAsync1", - "TMSAsync2", - "TMSAsync3", - "TMSAsync4", - "TMSAsync5", - "TMSAsync6", - "TMSEbs2", -} - -DEFAULT_ORGANISATION = "CDEF" +from .constants import DEFAULT_ORGANISATION, DEFAULT_PRODUCT_TEAM, EXCEPTIONAL_ODS_CODES +from .modify.modify_device import update_device_metadata +from .modify.modify_key import NotAnSdsKey, get_modify_key_function +from .utils import update_in_list_of_dict def accredited_system_ids( @@ -50,14 +36,6 @@ def scoped_party_key(nhs_mhs: NhsMhs) -> str: return DEVICE_KEY_SEPARATOR.join((ods_code, party_key, interaction_id)) -def update_in_list_of_dict(obj: list[dict[str, str]], key, value): - for item in obj: - if key in item: - item[key] = value - return - obj.append({key: value}) - - def create_product_team(ods_code: str) -> ProductTeam: if ods_code in EXCEPTIONAL_ODS_CODES: product_team = ProductTeam(**DEFAULT_PRODUCT_TEAM, ods_code=ods_code) @@ -110,18 +88,6 @@ def nhs_accredited_system_to_cpm_devices( yield _device -def modify_accredited_system_devices( - nhs_accredited_system: NhsAccreditedSystem, repository: DeviceRepository -) -> Generator[Device, None, None]: - for ( - _, - accredited_system_id, - ) in accredited_system_ids(nhs_accredited_system): - device = repository.read_by_key(key=accredited_system_id) - device.update(something="foo") - yield device - - def nhs_mhs_to_cpm_device( nhs_mhs: NhsMhs, questionnaire: Questionnaire, @@ -156,10 +122,65 @@ def nhs_mhs_to_cpm_device( return device -def modify_mhs_device(nhs_mhs: NhsMhs, repository: DeviceRepository): - device = repository.read_by_key(key=scoped_party_key(nhs_mhs)) - device.update(something="foo") - return device +class NoDeviceFound(Exception): + pass + + +def read_devices_by_unique_identifier( + questionnaire_ids: list[str], repository: DeviceRepository, value: str +) -> Generator[Device, None, None]: + for questionnaire_id in questionnaire_ids: + for device in repository.read_by_index( + questionnaire_id=questionnaire_id, + question_name="unique_identifier", + value=value, + ): + if device.is_active(): + yield device + + +def modify_devices( + modification_request: SdsModificationRequest, + questionnaire_ids: list[str], + repository: DeviceRepository, +) -> Generator[Device, None, None]: + devices = list( + read_devices_by_unique_identifier( + questionnaire_ids=questionnaire_ids, + repository=repository, + value=modification_request.unique_identifier, + ) + ) + # Only apply modifications if there are devices to modify + modifications = modification_request.modifications if devices else [] + + _devices = devices + for modification_type, field, new_values in modifications: + device_type = _devices[0].type + model = NhsAccreditedSystem if device_type is DeviceType.PRODUCT else NhsMhs + field_name = model.get_field_name_for_alias(alias=field) + + try: + modify_key = get_modify_key_function( + model=model, + field_name=field_name, + modification_type=modification_type, + ) + _devices += list( + modify_key(devices=_devices, field_name=field_name, value=new_values) + ) + except NotAnSdsKey: + update_metadata = partial( + update_device_metadata, + model=model, + modification_type=modification_type, + field_alias=field, + new_values=new_values, + ) + _active_devices = list(filter(Device.is_active, _devices)) + _inactive_devices = list(filterfalse(Device.is_active, _devices)) + _devices = [*map(update_metadata, _active_devices), *_inactive_devices] + yield from _devices def delete_devices( @@ -168,12 +189,11 @@ def delete_devices( repository: DeviceRepository, ) -> list[Device]: devices = [] - for questionnaire_id in questionnaire_ids: - for _device in repository.read_by_index( - questionnaire_id=questionnaire_id, - question_name="unique_identifier", - value=deletion_request.unique_identifier, - ): - _device.delete() - devices.append(_device) + for _device in read_devices_by_unique_identifier( + questionnaire_ids=questionnaire_ids, + repository=repository, + value=deletion_request.unique_identifier, + ): + _device.delete() + devices.append(_device) return devices diff --git a/src/layers/sds/domain/base.py b/src/layers/sds/domain/base.py index 92c2cafe7..345e9e725 100644 --- a/src/layers/sds/domain/base.py +++ b/src/layers/sds/domain/base.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Literal +from typing import ClassVar, Literal, Self import orjson from etl_utils.ldif.model import DistinguishedName @@ -30,7 +30,8 @@ def _strip_excluded_values_from_object_class( values: dict, excluded_object_class_values=EXCLUDED_OBJECT_CLASS_VALUES, ) -> dict: - object_class = values.get(OBJECT_CLASS_FIELD_NAME) + _object_class = values.get(OBJECT_CLASS_FIELD_NAME) + object_class = set(_object_class) if _object_class else _object_class if object_class: values[OBJECT_CLASS_FIELD_NAME] = object_class - excluded_object_class_values return values @@ -52,7 +53,7 @@ def _unpack_sets(cls: "SdsBaseModel", values: dict) -> dict: def _generate_distinguished_name(cls: "SdsBaseModel", values: dict) -> dict: - _distinguished_name: DistinguishedName = values.pop("_distinguished_name") + _distinguished_name: DistinguishedName = values.pop("_distinguished_name", None) if _distinguished_name is not None: values["distinguished_name"] = dict(_distinguished_name.parts) return values @@ -98,6 +99,27 @@ def alias_fields(cls) -> dict[str, ModelField]: for model_field in cls.__fields__.values() } + @classmethod + def get_field_name_for_alias(cls, alias) -> str: + try: + (field_name,) = ( + field_name + for field_name, model_field in cls.__fields__.items() + if model_field.alias == alias + ) + except ValueError: + raise ValueError(f"No field with alias '{alias}' found") + return field_name + + @classmethod + def get_alias_for_field_name(cls, field) -> str: + (field_name,) = ( + model_field.alias + for field_name, model_field in cls.__fields__.items() + if field_name == field + ) + return field_name + @root_validator(pre=True) def preprocess_inputs(cls, values): for transform in ( @@ -112,3 +134,23 @@ def preprocess_inputs(cls, values): def as_questionnaire_response_responses(self) -> list[dict[str, list]]: data = orjson.loads(self.json(exclude_none=True, exclude={"change_type"})) return [{k: (v if _is_iterable(v) else [v])} for k, v in data.items()] + + @classmethod + def force_optional(cls) -> type[Self]: + _model = type(f"{cls.__name__}-subclass", (cls,), {}) + for field in _model.__fields__.values(): + field.required = False + return _model + + @classmethod + def parse_and_validate_field(cls, field: str, value: list | set): + _model = cls.force_optional() + field_alias = cls.get_alias_for_field_name(field=field) + _value = set(value) if isinstance(value, list) else value + instance = _model(**{field_alias: _value}) + parsed_value = getattr(instance, field) + return parsed_value + + @classmethod + def is_mandatory_field(cls, field: str): + return cls.__fields__[field].required diff --git a/src/layers/sds/domain/changelog.py b/src/layers/sds/domain/changelog.py index a6da9b67b..5964b44dd 100644 --- a/src/layers/sds/domain/changelog.py +++ b/src/layers/sds/domain/changelog.py @@ -61,10 +61,11 @@ def changes_as_ldif(self) -> str: f"dn: {self.target_distinguished_name.raw}", f"changetype: {self.change_type}", ] + change_lines = header_lines + list(filter(bool, self.changes.split("\n"))) + if self.change_type is not ChangeType.ADD: - header_lines.append(f"{OBJECT_CLASS_FIELD_NAME}: {self.change_type}") + change_lines.append(f"{OBJECT_CLASS_FIELD_NAME}: {self.change_type}") - change_lines = header_lines + list(filter(bool, self.changes.split("\n"))) if unique_identifier_line not in map(str.lower, change_lines): change_lines.append(unique_identifier_line) diff --git a/src/layers/sds/domain/parse.py b/src/layers/sds/domain/parse.py index cb0b4c40b..1eaf9b23e 100644 --- a/src/layers/sds/domain/parse.py +++ b/src/layers/sds/domain/parse.py @@ -1,6 +1,7 @@ from etl_utils.ldif.model import DistinguishedName from sds.domain.nhs_mhs_cp import NhsMhsCp from sds.domain.sds_deletion_request import SdsDeletionRequest +from sds.domain.sds_modification_request import SdsModificationRequest from .base import OBJECT_CLASS_FIELD_NAME, SdsBaseModel from .nhs_accredited_system import NhsAccreditedSystem @@ -22,6 +23,7 @@ class UnknownSdsModel(Exception): NhsMhs, NhsMhsCp, SdsDeletionRequest, + SdsModificationRequest, ) EMPTY_SET = set() diff --git a/src/layers/sds/domain/sds_modification_request.py b/src/layers/sds/domain/sds_modification_request.py new file mode 100644 index 000000000..2f65f790b --- /dev/null +++ b/src/layers/sds/domain/sds_modification_request.py @@ -0,0 +1,33 @@ +from typing import ClassVar, Literal + +from pydantic import Field, validator +from sds.domain.base import OBJECT_CLASS_FIELD_NAME, SdsBaseModel +from sds.domain.constants import ModificationType +from sds.domain.organizational_unit import OrganizationalUnitDistinguishedName + + +class ImmutableFieldError(Exception): + pass + + +IMMUTABLE_SDS_FIELDS = { + "unique_identifier", +} + + +class SdsModificationRequest(SdsBaseModel): + distinguished_name: OrganizationalUnitDistinguishedName = Field(exclude=True) + OBJECT_CLASS: ClassVar[Literal["modify"]] = "modify" + object_class: str = Field(alias=OBJECT_CLASS_FIELD_NAME) + unique_identifier: str = Field(alias="uniqueidentifier") + + modifications: list[tuple[ModificationType, str, set[str]]] + + @validator("modifications") + def validate_immutable_fields( + modifications: list[tuple[ModificationType, str, any]] + ): + for _, field, _ in modifications: + if field in IMMUTABLE_SDS_FIELDS: + raise ImmutableFieldError(field) + return modifications