From 21f0ad3a0d6d70cd6966656c126d5bbc3553c3ce Mon Sep 17 00:00:00 2001 From: Joel Klinger Date: Fri, 8 Nov 2024 13:36:23 +0000 Subject: [PATCH] [feature/PI-592-gsi_upgrade] gsi upgrade --- .../terraform/per_workspace/main.tf | 18 +- .../tests/test_index.py | 2 +- src/api/createDevice/tests/test_index.py | 10 +- .../tests/test_index.py | 2 +- .../tests/test_index.py | 4 +- src/api/deleteCpmProduct/tests/test_index.py | 36 +- src/api/readDevice/src/v1/steps.py | 8 +- src/api/readDevice/tests/test_index.py | 2 +- .../readDeviceReferenceData/src/v1/steps.py | 8 +- src/api/searchCpmProduct/src/v1/steps.py | 4 +- .../features/readDevice.failure.feature | 8 +- .../domain/api/common_steps/read_product.py | 2 +- src/layers/domain/core/device/v3.py | 61 +- .../v3/test_cpm_product_repository_keys_v3.py | 13 +- .../v3/test_cpm_product_repository_v3.py | 20 +- .../test_cpm_product_repository_v3_delete.py | 8 +- .../repository/cpm_product_repository/v3.py | 236 ++------ ...est_device_reference_data_repository_v1.py | 8 +- .../device_reference_data_repository/v1.py | 81 +-- .../repository/device_repository/__init__.py | 2 +- .../device_repository/tests/v1/conftest.py | 13 - .../v1/test_device_repository_indexes_v1.py | 142 ----- .../v1/test_device_repository_keys_v1.py | 83 --- ...e_repository_questionnaire_responses_v1.py | 112 ---- .../tests/v1/test_device_repository_v1.py | 73 --- .../v3/test_device_repository_keys_v3.py | 36 +- ...e_repository_questionnaire_responses_v3.py | 30 +- .../v3/test_device_repository_tags_v3.py | 55 +- .../tests/v3/test_device_repository_v3.py | 127 ++-- .../domain/repository/device_repository/v1.py | 277 --------- .../domain/repository/device_repository/v3.py | 572 ++++++++---------- .../repository/keys/tests/test_keys_v1.py | 30 +- src/layers/domain/repository/keys/v1.py | 12 +- src/layers/domain/repository/keys/v3.py | 2 - src/layers/domain/repository/marshall.py | 4 +- .../repository/product_team_repository/v2.py | 100 +-- 
.../repository/repository/tests/model_v3.py | 105 ++++ .../repository/tests/test_repository_v1.py | 4 +- .../repository/tests/test_repository_v3.py | 259 ++++++++ src/layers/domain/repository/repository/v3.py | 305 ++++++++++ .../repository/tests/test_transaction.py | 6 +- src/layers/domain/repository/transaction.py | 25 +- src/test_helpers/dynamodb.py | 21 +- 43 files changed, 1331 insertions(+), 1595 deletions(-) delete mode 100644 src/layers/domain/repository/device_repository/tests/v1/conftest.py delete mode 100644 src/layers/domain/repository/device_repository/tests/v1/test_device_repository_indexes_v1.py delete mode 100644 src/layers/domain/repository/device_repository/tests/v1/test_device_repository_keys_v1.py delete mode 100644 src/layers/domain/repository/device_repository/tests/v1/test_device_repository_questionnaire_responses_v1.py delete mode 100644 src/layers/domain/repository/device_repository/tests/v1/test_device_repository_v1.py delete mode 100644 src/layers/domain/repository/device_repository/v1.py create mode 100644 src/layers/domain/repository/repository/tests/model_v3.py create mode 100644 src/layers/domain/repository/repository/tests/test_repository_v3.py create mode 100644 src/layers/domain/repository/repository/v3.py diff --git a/infrastructure/terraform/per_workspace/main.tf b/infrastructure/terraform/per_workspace/main.tf index 72e844c57..559b87547 100644 --- a/infrastructure/terraform/per_workspace/main.tf +++ b/infrastructure/terraform/per_workspace/main.tf @@ -43,23 +43,15 @@ module "table" { attributes = [ { name = "pk", type = "S" }, { name = "sk", type = "S" }, - { name = "pk_1", type = "S" }, - { name = "sk_1", type = "S" }, - { name = "pk_2", type = "S" }, - { name = "sk_2", type = "S" } + { name = "pk_read", type = "S" }, + { name = "sk_read", type = "S" }, ] global_secondary_indexes = [ { - name = "idx_gsi_1" - hash_key = "pk_1" - range_key = "sk_1" - projection_type = "ALL" - }, - { - name = "idx_gsi_2" - hash_key = "pk_2" - 
range_key = "sk_2" + name = "idx_gsi_read" + hash_key = "pk_read" + range_key = "sk_read" projection_type = "ALL" } ] diff --git a/src/api/createCpmProductForEpr/tests/test_index.py b/src/api/createCpmProductForEpr/tests/test_index.py index 456161090..cb5a1968c 100644 --- a/src/api/createCpmProductForEpr/tests/test_index.py +++ b/src/api/createCpmProductForEpr/tests/test_index.py @@ -66,7 +66,7 @@ def test_index(): table_name=TABLE_NAME, dynamodb_client=index.cache["DYNAMODB_CLIENT"] ) read_product = repo.read( - product_team_id=product_team.id, product_id=created_product["id"] + product_team_id=product_team.id, id=created_product["id"] ).state() assert created_product == read_product diff --git a/src/api/createDevice/tests/test_index.py b/src/api/createDevice/tests/test_index.py index 0859e1291..d8ad98140 100644 --- a/src/api/createDevice/tests/test_index.py +++ b/src/api/createDevice/tests/test_index.py @@ -83,15 +83,19 @@ def test_index() -> None: assert device.name == DEVICE_NAME assert device.ods_code == ODS_CODE assert device.created_on.date() == datetime.today().date() - assert device.updated_on is None - assert device.deleted_on is None + assert not device.updated_on + assert not device.deleted_on # Retrieve the created resource repo = DeviceRepository( table_name=TABLE_NAME, dynamodb_client=index.cache["DYNAMODB_CLIENT"] ) - created_device = repo.read(device.id) + created_device = repo.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ) assert created_device == device diff --git a/src/api/createDeviceReferenceData/tests/test_index.py b/src/api/createDeviceReferenceData/tests/test_index.py index aa2459133..8dee73f05 100644 --- a/src/api/createDeviceReferenceData/tests/test_index.py +++ b/src/api/createDeviceReferenceData/tests/test_index.py @@ -96,7 +96,7 @@ def test_index() -> None: created_device_reference_data = repo.read( product_team_id=device_reference_data.product_team_id, 
product_id=device_reference_data.product_id, - device_reference_data_id=device_reference_data.id, + id=device_reference_data.id, ) assert created_device_reference_data == device_reference_data diff --git a/src/api/createDeviceReferenceDataMessageSet/tests/test_index.py b/src/api/createDeviceReferenceDataMessageSet/tests/test_index.py index 92aa0498b..6a72f845d 100644 --- a/src/api/createDeviceReferenceDataMessageSet/tests/test_index.py +++ b/src/api/createDeviceReferenceDataMessageSet/tests/test_index.py @@ -98,7 +98,7 @@ def test_index_without_questionnaire() -> None: created_device_reference_data = repo.read( product_team_id=device_reference_data.product_team_id, product_id=device_reference_data.product_id, - device_reference_data_id=device_reference_data.id, + id=device_reference_data.id, ) assert created_device_reference_data == device_reference_data @@ -157,6 +157,6 @@ def test_index_with_questionnaire() -> None: created_device_reference_data = repo.read( product_team_id=device_reference_data.product_team_id, product_id=device_reference_data.product_id, - device_reference_data_id=device_reference_data.id, + id=device_reference_data.id, ) assert created_device_reference_data == device_reference_data diff --git a/src/api/deleteCpmProduct/tests/test_index.py b/src/api/deleteCpmProduct/tests/test_index.py index a5319b299..521c738a3 100644 --- a/src/api/deleteCpmProduct/tests/test_index.py +++ b/src/api/deleteCpmProduct/tests/test_index.py @@ -4,14 +4,14 @@ from unittest import mock import pytest -from domain.core.cpm_product.v1 import CpmProduct from domain.core.cpm_system_id.v1 import ProductId from domain.core.enum import Status from domain.core.root.v3 import Root -from domain.repository.cpm_product_repository.v3 import CpmProductRepository +from domain.repository.cpm_product_repository.v3 import ( + CpmProductRepository, + InactiveCpmProductRepository, +) from domain.repository.errors import ItemNotFound -from domain.repository.keys.v3 import TableKey -from 
domain.repository.marshall import marshall, unmarshall from domain.repository.product_team_repository.v2 import ProductTeamRepository from test_helpers.dynamodb import mock_table @@ -27,26 +27,6 @@ VERSION = 1 -class MockCpmProductRepository(CpmProductRepository): - def read_inactive_product( - self, product_team_id: str, product_id: str - ) -> CpmProduct: - pk = TableKey.CPM_PRODUCT_STATUS.key(Status.INACTIVE, product_team_id) - sk = TableKey.CPM_PRODUCT.key(product_id) - args = { - "TableName": self.table_name, - "KeyConditionExpression": "pk = :pk AND sk = :sk", - "ExpressionAttributeValues": marshall(**{":pk": pk, ":sk": sk}), - } - result = self.client.query(**args) - items = [unmarshall(i) for i in result["Items"]] - if len(items) == 0: - raise ItemNotFound(product_team_id, product_id, item_type=CpmProduct) - (item,) = items - - return CpmProduct(**item) - - @contextmanager def mock_lambda(): org = Root.create_ods_organisation(ods_code=CPM_PRODUCT_TEAM_NO_ID["ods_code"]) @@ -100,13 +80,13 @@ def test_index(): table_name=TABLE_NAME, dynamodb_client=index.cache["DYNAMODB_CLIENT"] ) with pytest.raises(ItemNotFound): - repo.read(product_team_id=product_team.id, product_id=PRODUCT_ID) + repo.read(product_team_id=product_team.id, id=PRODUCT_ID) - repo = MockCpmProductRepository( + repo = InactiveCpmProductRepository( table_name=TABLE_NAME, dynamodb_client=index.cache["DYNAMODB_CLIENT"] ) - deleted_product = repo.read_inactive_product( - product_team_id=product_team.id, product_id=PRODUCT_ID + deleted_product = repo.read( + product_team_id=product_team.id, id=PRODUCT_ID ).dict() # Sense checks on the deleted resource diff --git a/src/api/readDevice/src/v1/steps.py b/src/api/readDevice/src/v1/steps.py index d0eaf4685..f3e0c0350 100644 --- a/src/api/readDevice/src/v1/steps.py +++ b/src/api/readDevice/src/v1/steps.py @@ -32,7 +32,7 @@ def read_product(data, cache) -> CpmProduct: table_name=cache["DYNAMODB_TABLE"], dynamodb_client=cache["DYNAMODB_CLIENT"] ) cpm_product 
= product_repo.read( - product_id=path_params.product_id, product_team_id=path_params.product_team_id + id=path_params.product_id, product_team_id=path_params.product_team_id ) return cpm_product @@ -42,7 +42,11 @@ def read_device(data, cache) -> Device: device_repo = DeviceRepository( table_name=cache["DYNAMODB_TABLE"], dynamodb_client=cache["DYNAMODB_CLIENT"] ) - return device_repo.read(path_params.device_id) + return device_repo.read( + product_team_id=path_params.product_team_id, + product_id=path_params.product_id, + id=path_params.device_id, + ) def device_to_dict(data, cache) -> tuple[str, dict]: diff --git a/src/api/readDevice/tests/test_index.py b/src/api/readDevice/tests/test_index.py index 4ed6bfeba..6a159df4c 100644 --- a/src/api/readDevice/tests/test_index.py +++ b/src/api/readDevice/tests/test_index.py @@ -154,7 +154,7 @@ def test_index_no_such_device(version): "errors": [ { "code": "RESOURCE_NOT_FOUND", - "message": "Could not find Device for key ('does not exist')", # device saved by pk, sk = device.id still + "message": f"Could not find Device for key ('{product_team.id}', 'P.XXX-YYY', 'does not exist')", } ], } diff --git a/src/api/readDeviceReferenceData/src/v1/steps.py b/src/api/readDeviceReferenceData/src/v1/steps.py index f074a4158..9e74ca460 100644 --- a/src/api/readDeviceReferenceData/src/v1/steps.py +++ b/src/api/readDeviceReferenceData/src/v1/steps.py @@ -34,20 +34,20 @@ def read_product(data, cache) -> CpmProduct: table_name=cache["DYNAMODB_TABLE"], dynamodb_client=cache["DYNAMODB_CLIENT"] ) cpm_product = product_repo.read( - product_id=path_params.product_id, product_team_id=path_params.product_team_id + id=path_params.product_id, product_team_id=path_params.product_team_id ) return cpm_product def read_device_reference_data(data, cache) -> DeviceReferenceData: path_params: DeviceReferenceDataPathParams = data[parse_path_params] - product_repo = DeviceReferenceDataRepository( + device_reference_data_repo = DeviceReferenceDataRepository( 
table_name=cache["DYNAMODB_TABLE"], dynamodb_client=cache["DYNAMODB_CLIENT"] ) - return product_repo.read( + return device_reference_data_repo.read( product_id=path_params.product_id, product_team_id=path_params.product_team_id, - device_reference_data_id=path_params.device_reference_data_id, + id=path_params.device_reference_data_id, ) diff --git a/src/api/searchCpmProduct/src/v1/steps.py b/src/api/searchCpmProduct/src/v1/steps.py index 6531da1c8..66ad41a13 100644 --- a/src/api/searchCpmProduct/src/v1/steps.py +++ b/src/api/searchCpmProduct/src/v1/steps.py @@ -25,9 +25,7 @@ def query_products(data, cache) -> list: cpm_product_repo = CpmProductRepository( table_name=cache["DYNAMODB_TABLE"], dynamodb_client=cache["DYNAMODB_CLIENT"] ) - results = cpm_product_repo.query_products_by_product_team( - product_team_id=product_team_id - ) + results = cpm_product_repo.search(product_team_id=product_team_id) return results diff --git a/src/api/tests/feature_tests/features/readDevice.failure.feature b/src/api/tests/feature_tests/features/readDevice.failure.feature index 85f7119f3..70b36d48a 100644 --- a/src/api/tests/feature_tests/features/readDevice.failure.feature +++ b/src/api/tests/feature_tests/features/readDevice.failure.feature @@ -72,10 +72,10 @@ Feature: Read Device - failure scenarios And I note the response field "$.id" as "device_id" When I make a "GET" request with "default" headers to "ProductTeam/${ note(product_team_id) }/Product/${ note(product_id) }/Device/not-a-device" Then I receive a status code "404" with body - | path | value | - | errors.0.code | RESOURCE_NOT_FOUND | - | errors.0.message | Could not find Device for key ('not-a-device') | + | path | value | + | errors.0.code | RESOURCE_NOT_FOUND | + | errors.0.message | Could not find Device for key ('${ note(product_team_id) }', '${ note(product_id) }', 'not-a-device') | And the response headers contain: | name | value | | Content-Type | application/json | - | Content-Length | 105 | + | Content-Length | 
164 | diff --git a/src/layers/domain/api/common_steps/read_product.py b/src/layers/domain/api/common_steps/read_product.py index e4d20146c..4dde7a39f 100644 --- a/src/layers/domain/api/common_steps/read_product.py +++ b/src/layers/domain/api/common_steps/read_product.py @@ -30,7 +30,7 @@ def read_product(data, cache) -> CpmProduct: table_name=cache["DYNAMODB_TABLE"], dynamodb_client=cache["DYNAMODB_CLIENT"] ) cpm_product = product_repo.read( - product_id=path_params.product_id, product_team_id=path_params.product_team_id + id=path_params.product_id, product_team_id=path_params.product_team_id ) return cpm_product diff --git a/src/layers/domain/core/device/v3.py b/src/layers/domain/core/device/v3.py index 41e425945..c214a1a6c 100644 --- a/src/layers/domain/core/device/v3.py +++ b/src/layers/domain/core/device/v3.py @@ -2,7 +2,7 @@ from datetime import datetime from enum import StrEnum, auto from functools import cached_property -from urllib.parse import urlencode +from urllib.parse import parse_qs, urlencode from uuid import UUID, uuid4 from attr import dataclass @@ -53,7 +53,7 @@ class DeviceCreatedEvent(Event): updated_on: str = None deleted_on: str = None keys: list[DeviceKey] - tags: list["DeviceTag"] + tags: list[str] questionnaire_responses: dict[str, dict[str, "QuestionnaireResponse"]] @@ -69,7 +69,7 @@ class DeviceUpdatedEvent(Event): updated_on: str = None deleted_on: str = None keys: list[DeviceKey] - tags: list["DeviceTag"] + tags: list[str] questionnaire_responses: dict[str, dict[str, "QuestionnaireResponse"]] @@ -85,9 +85,9 @@ class DeviceDeletedEvent(Event): updated_on: str = None deleted_on: str = None keys: list[DeviceKey] - tags: list["DeviceTag"] + tags: list[str] questionnaire_responses: dict[str, dict[str, "QuestionnaireResponse"]] - deleted_tags: list["DeviceTag"] = None + deleted_tags: list[str] = None @dataclass(kw_only=True, slots=True) @@ -103,7 +103,7 @@ class DeviceKeyAddedEvent(Event): updated_on: str = None deleted_on: str = None keys: 
list[DeviceKey] - tags: list["DeviceTag"] + tags: list[str] questionnaire_responses: dict[str, dict[str, "QuestionnaireResponse"]] @@ -112,13 +112,13 @@ class DeviceKeyDeletedEvent(Event): deleted_key: DeviceKey id: str keys: list[DeviceKey] - tags: list["DeviceTag"] + tags: list[str] updated_on: str = None @dataclass(kw_only=True, slots=True) class DeviceTagAddedEvent(Event): - new_tag: "DeviceTag" + new_tag: str id: str name: str product_team_id: UUID @@ -129,13 +129,13 @@ class DeviceTagAddedEvent(Event): updated_on: str = None deleted_on: str = None keys: list[DeviceKey] - tags: list["DeviceTag"] + tags: list[str] questionnaire_responses: dict[str, dict[str, "QuestionnaireResponse"]] @dataclass(kw_only=True, slots=True) class DeviceTagsAddedEvent(Event): - new_tags: list["DeviceTag"] + new_tags: list[str] id: str name: str product_team_id: UUID @@ -146,7 +146,7 @@ class DeviceTagsAddedEvent(Event): updated_on: str = None deleted_on: str = None keys: list[DeviceKey] - tags: list["DeviceTag"] + tags: list[str] questionnaire_responses: dict[str, dict[str, "QuestionnaireResponse"]] @@ -154,7 +154,7 @@ class DeviceTagsAddedEvent(Event): class DeviceTagsClearedEvent(Event): id: str keys: list[dict] - deleted_tags: list["DeviceTag"] + deleted_tags: list[str] updated_on: str = None @@ -192,12 +192,20 @@ class Config: def encode_tag(cls, values: dict): initialised_with_root = "__root__" in values and len(values) == 1 item_to_process = values["__root__"] if initialised_with_root else values - if initialised_with_root: + + # Case 1: query string is provided (__root__="foo=bar") + if initialised_with_root and isinstance(item_to_process, str): + _components = ((k, v) for k, (v,) in parse_qs(item_to_process).items()) + # Case 2: query components are provided (__root__=("foo", "bar")) + elif initialised_with_root: _components = ((k, v) for k, v in item_to_process) - else: # otherwise initialise directly with key value pairs - _components = sorted((k, str(v)) for k, v in 
item_to_process.items()) + # Case 3: query components are directly provided (("foo", "bar")) + else: + _components = ((k, str(v)) for k, v in item_to_process.items()) - case_insensitive_components = tuple((k, v.lower()) for k, v in _components) + case_insensitive_components = tuple( + sorted((k, v.lower()) for k, v in _components) + ) return {"__root__": case_insensitive_components} def dict(self, *args, **kwargs): @@ -259,7 +267,7 @@ def update(self, **kwargs) -> DeviceUpdatedEvent: @event def delete(self) -> DeviceDeletedEvent: deleted_on = now() - deleted_tags = {t.dict() for t in self.tags} + deleted_tags = {t.value for t in self.tags} device_data = self._update( data=dict( status=Status.INACTIVE, @@ -279,6 +287,7 @@ def add_key(self, key_type: str, key_value: str) -> DeviceKeyAddedEvent: ) self.keys.append(device_key) device_data = self.state() + device_data["tags"] = {t.value for t in self.tags} device_data.pop(UPDATED_ON) # The @event decorator will handle updated_on return DeviceKeyAddedEvent(new_key=device_key, **device_data) @@ -294,7 +303,7 @@ def delete_key(self, key_type: str, key_value: str) -> DeviceKeyDeletedEvent: deleted_key=device_key, id=self.id, keys=[k.dict() for k in self.keys], - tags=[t.dict() for t in self.tags], + tags=[t.value for t in self.tags], ) @event @@ -306,8 +315,9 @@ def add_tag(self, **kwargs) -> DeviceTagAddedEvent: ) self.tags.add(device_tag) device_data = self.state() + device_data["tags"] = {t.value for t in self.tags} device_data.pop(UPDATED_ON) # The @event decorator will handle updated_on - return DeviceTagAddedEvent(new_tag=device_tag, **device_data) + return DeviceTagAddedEvent(new_tag=device_tag.value, **device_data) @event def add_tags(self, tags: list[dict]) -> DeviceTagsAddedEvent: @@ -320,8 +330,11 @@ def add_tags(self, tags: list[dict]) -> DeviceTagsAddedEvent: ) self.tags = self.tags.union(new_tags) device_data = self.state() + device_data["tags"] = {t.value for t in self.tags} device_data.pop(UPDATED_ON) # The 
@event decorator will handle updated_on - return DeviceTagsAddedEvent(new_tags=new_tags, **device_data) + return DeviceTagsAddedEvent( + new_tags={tag.value for tag in new_tags}, **device_data + ) @event def clear_tags(self): @@ -329,7 +342,9 @@ def clear_tags(self): self.tags = set() device_data = self.state() return DeviceTagsClearedEvent( - id=device_data["id"], keys=device_data["keys"], deleted_tags=deleted_tags + id=device_data["id"], + keys=device_data["keys"], + deleted_tags={tag.value for tag in deleted_tags}, ) @event @@ -351,7 +366,7 @@ def add_questionnaire_response( return QuestionnaireResponseUpdatedEvent( entity_id=self.id, entity_keys=[k.dict() for k in self.keys], - entity_tags=[t.dict() for t in self.tags], + entity_tags=[t.value for t in self.tags], questionnaire_responses={ qid: {_created_on: qr.dict() for _created_on, qr in _qr.items()} for qid, _qr in self.questionnaire_responses.items() @@ -383,7 +398,7 @@ def update_questionnaire_response( return QuestionnaireResponseUpdatedEvent( entity_id=self.id, entity_keys=[k.dict() for k in self.keys], - entity_tags=[t.dict() for t in self.tags], + entity_tags=[t.value for t in self.tags], questionnaire_responses={ qid: {_created_on: qr.dict() for _created_on, qr in _qr.items()} for qid, _qr in self.questionnaire_responses.items() diff --git a/src/layers/domain/repository/cpm_product_repository/tests/v3/test_cpm_product_repository_keys_v3.py b/src/layers/domain/repository/cpm_product_repository/tests/v3/test_cpm_product_repository_keys_v3.py index d5fe6759e..c6b340219 100644 --- a/src/layers/domain/repository/cpm_product_repository/tests/v3/test_cpm_product_repository_keys_v3.py +++ b/src/layers/domain/repository/cpm_product_repository/tests/v3/test_cpm_product_repository_keys_v3.py @@ -19,7 +19,7 @@ def test__product_repository__add_key( repository.write(product) product_by_id = repository.read( - product_team_id=product.product_team_id, product_id=product.id + product_team_id=product.product_team_id, 
id=product.id ) assert product_by_id.keys == [party_key] @@ -32,21 +32,16 @@ def test__product_repository__add_key_then_delete( product.add_key(**party_key.dict()) repository.write(product) - product_by_id = repository.read( - product_team_id=product.product_team_id, product_id=product.id - ) - assert product_by_id.keys == [party_key] - - # Read and delete product product_from_db = repository.read( - product_team_id=product.product_team_id, product_id=product.id + product_team_id=product.product_team_id, id=product.id ) + assert product_from_db.keys == [party_key] product_from_db.delete() repository.write(product_from_db) # No longer retrievable with pytest.raises(ItemNotFound): - repository.read(product_team_id=product.product_team_id, product_id=product.id) + repository.read(product_team_id=product.product_team_id, id=product.id) @pytest.mark.integration diff --git a/src/layers/domain/repository/cpm_product_repository/tests/v3/test_cpm_product_repository_v3.py b/src/layers/domain/repository/cpm_product_repository/tests/v3/test_cpm_product_repository_v3.py index 8049f9efc..c85cd299d 100644 --- a/src/layers/domain/repository/cpm_product_repository/tests/v3/test_cpm_product_repository_v3.py +++ b/src/layers/domain/repository/cpm_product_repository/tests/v3/test_cpm_product_repository_v3.py @@ -24,9 +24,7 @@ def _create_product_team(name: str = "FOOBAR Product Team", ods_code: str = "F5H @pytest.mark.integration def test__cpm_product_repository(product: CpmProduct, repository: CpmProductRepository): repository.write(product) - result = repository.read( - product_team_id=product.product_team_id, product_id=product.id - ) + result = repository.read(product_team_id=product.product_team_id, id=product.id) assert result == product @@ -46,16 +44,14 @@ def test__cpm_product_repository__product_does_not_exist( product_team_id = consistent_uuid(1) product_id = "P.XXX-YYY" with pytest.raises(ItemNotFound): - repository.read(product_team_id=product_team_id, 
product_id=product_id) + repository.read(product_team_id=product_team_id, id=product_id) def test__cpm_product_repository_local( product: CpmProduct, repository: CpmProductRepository ): repository.write(product) - result = repository.read( - product_team_id=product.product_team_id, product_id=product.id - ) + result = repository.read(product_team_id=product.product_team_id, id=product.id) assert result == product @@ -65,7 +61,7 @@ def test__cpm_product_repository__product_does_not_exist_local( product_team_id = consistent_uuid(1) product_id = "P.XXX-YYY" with pytest.raises(ItemNotFound): - repository.read(product_team_id=product_team_id, product_id=product_id) + repository.read(product_team_id=product_team_id, id=product_id) @pytest.mark.integration @@ -86,7 +82,7 @@ def test__query_products_by_product_team(): name="cpm-product-name-2", product_id=product_id.id ) repo.write(cpm_product_2) - result = repo.query_products_by_product_team(product_team_id=product_team.id) + result = repo.search(product_team_id=product_team.id) assert len(result) == 2 assert isinstance(result[0], CpmProduct) assert isinstance(result[1], CpmProduct) @@ -118,7 +114,7 @@ def test__query_products_by_product_team_a(): name="cpm-product-name-3", product_id=product_id.id ) repo.write(cpm_product_3) - result = repo.query_products_by_product_team(product_team_id=product_team_a.id) + result = repo.search(product_team_id=product_team_a.id) assert len(result) == 2 assert isinstance(result[0], CpmProduct) assert isinstance(result[1], CpmProduct) @@ -158,7 +154,7 @@ def test__query_products_by_product_team_with_sk_prefix(): client = dynamodb_client() client.put_item(**args) - result = repo.query_products_by_product_team(product_team_id=product_team.id) + result = repo.search(product_team_id=product_team.id) assert len(result) == 2 assert isinstance(result[0], CpmProduct) assert isinstance(result[1], CpmProduct) @@ -181,6 +177,6 @@ def test__cpm_product_repository_search(): ) repo.write(cpm_product) - 
result = repo.query_products_by_product_team(product_team_id=product_team.id) + result = repo.search(product_team_id=product_team.id) assert result == [cpm_product] diff --git a/src/layers/domain/repository/cpm_product_repository/tests/v3/test_cpm_product_repository_v3_delete.py b/src/layers/domain/repository/cpm_product_repository/tests/v3/test_cpm_product_repository_v3_delete.py index d5a7037ad..cb27b5d21 100644 --- a/src/layers/domain/repository/cpm_product_repository/tests/v3/test_cpm_product_repository_v3_delete.py +++ b/src/layers/domain/repository/cpm_product_repository/tests/v3/test_cpm_product_repository_v3_delete.py @@ -11,14 +11,14 @@ def test__cpm_product_repository_delete( ): repository.write(product) # Create product in DB product_from_db = repository.read( - product_team_id=product.product_team_id, product_id=product.id + product_team_id=product.product_team_id, id=product.id ) product_from_db.delete() repository.write(product_from_db) # No longer retrievable with pytest.raises(ItemNotFound): - repository.read(product_team_id=product.product_team_id, product_id=product.id) + repository.read(product_team_id=product.product_team_id, id=product.id) @pytest.mark.integration @@ -36,14 +36,14 @@ def test__cpm_product_repository_delete_local( ): repository.write(product) # Create product in DB product_from_db = repository.read( - product_team_id=product.product_team_id, product_id=product.id + product_team_id=product.product_team_id, id=product.id ) product_from_db.delete() repository.write(product_from_db) # No longer retrievable with pytest.raises(ItemNotFound): - repository.read(product_team_id=product.product_team_id, product_id=product.id) + repository.read(product_team_id=product.product_team_id, id=product.id) def test__cpm_product_repository_cannot_delete_if_does_not_exist_local( diff --git a/src/layers/domain/repository/cpm_product_repository/v3.py b/src/layers/domain/repository/cpm_product_repository/v3.py index 410bb2d82..8bb7882d3 100644 --- 
a/src/layers/domain/repository/cpm_product_repository/v3.py +++ b/src/layers/domain/repository/cpm_product_repository/v3.py @@ -6,233 +6,95 @@ CpmProductKeyAddedEvent, ) from domain.core.product_key.v1 import ProductKey -from domain.repository.device_repository.v2 import TooManyResults -from domain.repository.errors import ItemNotFound from domain.repository.keys.v3 import TableKey -from domain.repository.marshall import marshall, marshall_value, unmarshall -from domain.repository.repository.v2 import Repository -from domain.repository.transaction import ( - ConditionExpression, - TransactionStatement, - TransactItem, - update_transactions, -) - - -def create_product_index( - table_name: str, - product_data: dict, - pk_key_parts: tuple[str], - sk_key_parts: tuple[str], - pk_table_key: TableKey = TableKey.PRODUCT_TEAM, - sk_table_key: TableKey = TableKey.CPM_PRODUCT, - root=False, -) -> TransactItem: - pk = pk_table_key.key(*pk_key_parts) - sk = sk_table_key.key(*sk_key_parts) - return TransactItem( - Put=TransactionStatement( - TableName=table_name, - Item=marshall(pk=pk, sk=sk, root=root, **product_data), - ConditionExpression=ConditionExpression.MUST_NOT_EXIST, - ) - ) - - -def _product_root_primary_key(product_team_id: str, product_id: str) -> dict: - """ - Generates one fully marshalled (i.e. {"pk": {"S": "123"}} DynamoDB - primary key (i.e. pk + sk) for the provided Device, indexed by the Device ID - """ - pk = TableKey.PRODUCT_TEAM.key(product_team_id) - sk = TableKey.CPM_PRODUCT.key(product_id) - return marshall(pk=pk, sk=sk) - - -def _product_non_root_primary_keys( - product_team_id: str, device_keys: list[ProductKey] -) -> list[dict]: - """ - Generates all the fully marshalled (i.e. {"pk": {"S": "123"}} DynamoDB - primary keys (i.e. pk + sk) for the provided Device. 
This is one primary key - for every value of Device.keys and Device.tags - """ - pk = TableKey.PRODUCT_TEAM.key(product_team_id) - device_key_primary_keys = [ - marshall(pk=pk, sk=sk) - for sk in ( - TableKey.CPM_PRODUCT.key(k.key_type, k.key_value) for k in device_keys - ) - ] - return device_key_primary_keys - - -def update_product_indexes( - table_name: str, - data: dict, - product_team_id: str, - product_id: str, - keys: list[ProductKey], -): - # Update the root product - root_primary_key = _product_root_primary_key( - product_team_id=product_team_id, product_id=product_id - ) - update_root_product_transactions = update_transactions( - table_name=table_name, primary_keys=[root_primary_key], data=data - ) - # Update non-root products - non_root_primary_keys = _product_non_root_primary_keys( - product_team_id=product_team_id, device_keys=keys - ) - update_non_root_devices_transactions = update_transactions( - table_name=table_name, primary_keys=non_root_primary_keys, data=data - ) - return update_root_product_transactions + update_non_root_devices_transactions - - -def delete_product_index( - table_name: str, - pk_key_parts: tuple[str], - sk_key_parts: tuple[str], - pk_table_key: TableKey = TableKey.PRODUCT_TEAM, - sk_table_key: TableKey = TableKey.CPM_PRODUCT, -): - pk = pk_table_key.key(*pk_key_parts) - sk = sk_table_key.key(*sk_key_parts) - return TransactItem( - Delete=TransactionStatement( - TableName=table_name, - Key=marshall(pk=pk, sk=sk), - ConditionExpression=ConditionExpression.MUST_EXIST, - ) - ) +from domain.repository.repository.v3 import Repository class CpmProductRepository(Repository[CpmProduct]): def __init__(self, table_name: str, dynamodb_client): super().__init__( - table_name=table_name, model=CpmProduct, dynamodb_client=dynamodb_client + table_name=table_name, + model=CpmProduct, + dynamodb_client=dynamodb_client, + parent_table_keys=(TableKey.PRODUCT_TEAM,), + table_key=TableKey.CPM_PRODUCT, ) + def read(self, product_team_id: str, id: 
str): + return super()._read(parent_ids=(product_team_id,), id=id) + + def search(self, product_team_id: str): + return super()._query(parent_ids=(product_team_id,)) + def handle_CpmProductCreatedEvent(self, event: CpmProductCreatedEvent): - return create_product_index( - table_name=self.table_name, - product_data=asdict(event), - pk_key_parts=(event.product_team_id,), - sk_key_parts=(event.id,), + return self.create_index( + id=event.id, + parent_key_parts=(event.product_team_id,), + data=asdict(event), root=True, ) def handle_CpmProductKeyAddedEvent(self, event: CpmProductKeyAddedEvent): # Create a copy of the Product indexed against the new key - create_transaction = create_product_index( - table_name=self.table_name, - product_data=asdict(event), - pk_key_parts=event.new_key.parts, - sk_key_parts=event.new_key.parts, - pk_table_key=TableKey.CPM_PRODUCT_KEY, - sk_table_key=TableKey.CPM_PRODUCT_KEY, + create_transaction = self.create_index( + id=event.new_key.key_value, + parent_key_parts=(event.product_team_id,), + data=asdict(event), + root=False, ) # Update the value of "keys" on all other copies of this Device product_keys = {ProductKey(**key) for key in event.keys} product_keys_before_update = product_keys - {event.new_key} - update_transactions = update_product_indexes( - table_name=self.table_name, - product_id=event.id, - product_team_id=event.product_team_id, + update_transactions = self.update_indexes( + id=event.id, keys=product_keys_before_update, data={"keys": event.keys, "updated_on": event.updated_on}, ) return [create_transaction] + update_transactions def handle_CpmProductDeletedEvent(self, event: CpmProductDeletedEvent): - inactive_root_copy_transaction = create_product_index( - table_name=self.table_name, - product_data=asdict(event), - pk_table_key=TableKey.CPM_PRODUCT_STATUS, - pk_key_parts=( - event.status, - event.product_team_id, - ), - sk_key_parts=(event.id,), + inactive_root_copy_transaction = self.create_index( + id=event.id, + 
parent_key_parts=(event.product_team_id,), + data=asdict(event), root=True, + table_key=TableKey.CPM_PRODUCT_STATUS, ) - keys = {ProductKey(**k) for k in event.keys} + keys = {ProductKey(**k).key_value for k in event.keys} inactive_key_indexes_copy_transactions = [ - create_product_index( - table_name=self.table_name, - pk_table_key=TableKey.CPM_PRODUCT_STATUS, - pk_key_parts=( - event.status, - event.product_team_id, - ), - sk_key_parts=key.parts, - product_data=asdict(event), + self.create_index( + id=key, + parent_key_parts=(event.product_team_id,), + data=asdict(event), root=False, + table_key=TableKey.CPM_PRODUCT_STATUS, ) for key in keys ] - original_root_delete_transaction = delete_product_index( - table_name=self.table_name, - pk_key_parts=(event.product_team_id,), - sk_key_parts=(event.id,), - ) - - original_key_indexes_delete_transactions = [ - delete_product_index( - table_name=self.table_name, - pk_key_parts=key.parts, - sk_key_parts=key.parts, - pk_table_key=TableKey.CPM_PRODUCT_KEY, - sk_table_key=TableKey.CPM_PRODUCT_KEY, - ) - for key in keys + original_indexes_delete_transactions = [ + self.delete_index(key) for key in (*keys, event.id) ] - return ( [inactive_root_copy_transaction] + inactive_key_indexes_copy_transactions - + [original_root_delete_transaction] - + original_key_indexes_delete_transactions + + original_indexes_delete_transactions ) - def read(self, product_team_id: str, product_id: str) -> CpmProduct: - pk = TableKey.PRODUCT_TEAM.key(product_team_id) - sk = TableKey.CPM_PRODUCT.key(product_id) - args = { - "TableName": self.table_name, - "KeyConditionExpression": "pk = :pk AND sk = :sk", - "ExpressionAttributeValues": marshall(**{":pk": pk, ":sk": sk}), - } - result = self.client.query(**args) - items = [unmarshall(i) for i in result["Items"]] - if len(items) == 0: - raise ItemNotFound(product_team_id, product_id, item_type=CpmProduct) - (item,) = items - return CpmProduct(**item) +class 
InactiveCpmProductRepository(Repository[CpmProduct]): + """Read-only repository""" - def query_products_by_product_team(self, product_team_id) -> list[CpmProduct]: - product_team_id = TableKey.PRODUCT_TEAM.key(product_team_id) - args = { - "TableName": self.table_name, - "KeyConditionExpression": "pk = :pk AND begins_with(sk, :sk_prefix)", - "ExpressionAttributeValues": { - ":pk": marshall_value(product_team_id), - ":sk_prefix": marshall_value(f"{TableKey.CPM_PRODUCT}#"), - }, - } - response = self.client.query(**args) - if "LastEvaluatedKey" in response: - raise TooManyResults(f"Too many results for query '{args}'") - - # Convert to Products - if len(response["Items"]) > 0: - products = map(unmarshall, response["Items"]) - return [CpmProduct(**p) for p in products] + def __init__(self, table_name, dynamodb_client): + super().__init__( + table_name=table_name, + model=CpmProduct, + dynamodb_client=dynamodb_client, + parent_table_keys=(TableKey.PRODUCT_TEAM,), + table_key=TableKey.CPM_PRODUCT_STATUS, + ) - return [] + def read(self, product_team_id: str, id: str): + return self._read(parent_ids=(product_team_id,), id=id) diff --git a/src/layers/domain/repository/device_reference_data_repository/tests/test_device_reference_data_repository_v1.py b/src/layers/domain/repository/device_reference_data_repository/tests/test_device_reference_data_repository_v1.py index 0290d1721..495923bf6 100644 --- a/src/layers/domain/repository/device_reference_data_repository/tests/test_device_reference_data_repository_v1.py +++ b/src/layers/domain/repository/device_reference_data_repository/tests/test_device_reference_data_repository_v1.py @@ -17,7 +17,7 @@ def test__cpm_device_reference_data_repository( result = repository.read( product_team_id=device_reference_data.product_team_id, product_id=device_reference_data.product_id, - device_reference_data_id=device_reference_data.id, + id=device_reference_data.id, ) assert result == device_reference_data @@ -43,7 +43,7 @@ def 
test__cpm_device_reference_data_repository__device_reference_data_does_not_e repository.read( product_team_id=product_team_id, product_id=product_id, - device_reference_data_id=device_reference_data_id, + id=device_reference_data_id, ) @@ -55,7 +55,7 @@ def test__cpm_product_repository_local( result = repository.read( product_team_id=device_reference_data.product_team_id, product_id=device_reference_data.product_id, - device_reference_data_id=device_reference_data.id, + id=device_reference_data.id, ) assert result == device_reference_data @@ -70,7 +70,7 @@ def test__cpm_device_reference_data_repository__device_reference_data_does_not_e repository.read( product_team_id=product_team_id, product_id=product_id, - device_reference_data_id=device_reference_data_id, + id=device_reference_data_id, ) diff --git a/src/layers/domain/repository/device_reference_data_repository/v1.py b/src/layers/domain/repository/device_reference_data_repository/v1.py index 2002b2ed6..9d9d1d0dd 100644 --- a/src/layers/domain/repository/device_reference_data_repository/v1.py +++ b/src/layers/domain/repository/device_reference_data_repository/v1.py @@ -7,11 +7,9 @@ QuestionnaireResponseUpdatedEvent, ) from domain.repository.device_repository.v2 import create_device_index -from domain.repository.errors import ItemNotFound from domain.repository.keys.v3 import TableKey -from domain.repository.marshall import marshall, unmarshall -from domain.repository.repository.v2 import Repository -from domain.repository.transaction import TransactItem, update_transactions +from domain.repository.repository.v3 import Repository +from domain.repository.transaction import TransactItem class QueryType(StrEnum): @@ -36,76 +34,29 @@ def __init__(self, table_name, dynamodb_client): table_name=table_name, model=DeviceReferenceData, dynamodb_client=dynamodb_client, + table_key=TableKey.DEVICE_REFERENCE_DATA, + parent_table_keys=(TableKey.PRODUCT_TEAM, TableKey.CPM_PRODUCT), ) + def read(self, product_team_id: str, 
product_id: str, id: str): + return super()._read(parent_ids=(product_team_id, product_id), id=id) + + def search(self, product_team_id: str, product_id: str): + return super()._query(parent_ids=(product_team_id, product_id)) + def handle_DeviceReferenceDataCreatedEvent( self, event: DeviceReferenceDataCreatedEvent ) -> TransactItem: - data = asdict(event) - data["pk_1"] = TableKey.PRODUCT_TEAM_KEY.key( - event.product_team_id, TableKey.PRODUCT_TEAM, event.product_id - ) - data["sk_1"] = TableKey.DEVICE_REFERENCE_DATA.key(event.id) - return create_device_reference_data( - table_name=self.table_name, id=event.id, data=data, root=True + return self.create_index( + id=event.id, + parent_key_parts=(event.product_team_id, event.product_id), + data=asdict(event), + root=True, ) def handle_QuestionnaireResponseUpdatedEvent( self, event: QuestionnaireResponseUpdatedEvent ) -> TransactItem: - pk = TableKey.DEVICE_REFERENCE_DATA.key(event.id) data = asdict(event) data.pop("id") - return update_transactions( - table_name=self.table_name, primary_keys=[marshall(pk=pk, sk=pk)], data=data - ) - - def _query( - self, - product_team_id: str, - product_id: str, - device_reference_data_id: str, - sk_query_type: QueryType, - ) -> list[dict]: - pk_1 = TableKey.PRODUCT_TEAM_KEY.key( - product_team_id, TableKey.PRODUCT_TEAM, product_id - ) - sk_1 = TableKey.DEVICE_REFERENCE_DATA.key(device_reference_data_id) - sk_condition = sk_query_type.format("sk_1", ":sk_1") - args = { - "TableName": self.table_name, - "IndexName": "idx_gsi_1", - "KeyConditionExpression": f"pk_1 = :pk_1 AND {sk_condition}", - "ExpressionAttributeValues": marshall(**{":pk_1": pk_1, ":sk_1": sk_1}), - } - result = self.client.query(**args) - return result["Items"] - - def read( - self, product_team_id: str, product_id: str, device_reference_data_id: str - ) -> DeviceReferenceData: - items = self._query( - product_team_id=product_team_id, - product_id=product_id, - device_reference_data_id=device_reference_data_id, - 
sk_query_type=QueryType.EQUALS, - ) - try: - (item,) = items - except ValueError: - raise ItemNotFound( - product_team_id, - product_id, - device_reference_data_id, - item_type=DeviceReferenceData, - ) - return DeviceReferenceData(**unmarshall(item)) - - def search(self, product_team_id: str, product_id: str): - items = self._query( - product_team_id=product_team_id, - product_id=product_id, - device_reference_data_id="", - sk_query_type=QueryType.BEGINS_WITH, - ) - return [DeviceReferenceData(**unmarshall(item)) for item in items] + return self.update_indexes(id=event.id, keys=[], data=data) diff --git a/src/layers/domain/repository/device_repository/__init__.py b/src/layers/domain/repository/device_repository/__init__.py index e1ddb07c7..8f55f201f 100644 --- a/src/layers/domain/repository/device_repository/__init__.py +++ b/src/layers/domain/repository/device_repository/__init__.py @@ -1 +1 @@ -from .v1 import * # noqa: F403, F401 +from .v2 import * # noqa: F403, F401 diff --git a/src/layers/domain/repository/device_repository/tests/v1/conftest.py b/src/layers/domain/repository/device_repository/tests/v1/conftest.py deleted file mode 100644 index fe0168c3a..000000000 --- a/src/layers/domain/repository/device_repository/tests/v1/conftest.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Generator - -import pytest -from domain.repository.device_repository import DeviceRepository -from domain.repository.device_repository.tests.utils import repository_fixture - - -@pytest.fixture -def repository(request) -> Generator[DeviceRepository, None, None]: - yield from repository_fixture( - is_integration_test=request.node.get_closest_marker("integration"), - repository_class=DeviceRepository, - ) diff --git a/src/layers/domain/repository/device_repository/tests/v1/test_device_repository_indexes_v1.py b/src/layers/domain/repository/device_repository/tests/v1/test_device_repository_indexes_v1.py deleted file mode 100644 index 4854a190a..000000000 --- 
a/src/layers/domain/repository/device_repository/tests/v1/test_device_repository_indexes_v1.py +++ /dev/null @@ -1,142 +0,0 @@ -import pytest -from domain.core.device import Device, DeviceType -from domain.core.product_team import ProductTeam -from domain.core.questionnaire import Questionnaire -from domain.core.root import Root -from domain.repository.device_repository import DeviceRepository - - -@pytest.fixture -def product_team(): - org = Root.create_ods_organisation(ods_code="AB123") - return org.create_product_team( - id="6f8c285e-04a2-4194-a84e-dabeba474ff7", name="Team" - ) - - -@pytest.fixture -def shoe_questionnaire() -> Questionnaire: - _questionnaire = Questionnaire(name="shoe", version=1) - _questionnaire.add_question( - name="foot", answer_types=(str,), mandatory=True, choices={"L", "R"} - ) - _questionnaire.add_question(name="shoe-size", answer_types=(int,), mandatory=True) - return _questionnaire - - -@pytest.fixture -def device_right_shoe_size_123( - product_team: ProductTeam, shoe_questionnaire: Questionnaire -) -> Device: - shoe_response = shoe_questionnaire.respond( - responses=[{"foot": ["R"]}, {"shoe-size": [123]}], - ) - device = product_team.create_device(name="Device-1", type=DeviceType.PRODUCT) - device.add_questionnaire_response(questionnaire_response=shoe_response) - device.add_index(questionnaire_id=shoe_questionnaire.id, question_name="foot") - device.add_index(questionnaire_id=shoe_questionnaire.id, question_name="shoe-size") - return device - - -@pytest.fixture -def device_left_shoe_size_123( - product_team: ProductTeam, shoe_questionnaire: Questionnaire -) -> Device: - shoe_response = shoe_questionnaire.respond( - responses=[{"foot": ["L"]}, {"shoe-size": [123]}], - ) - device = product_team.create_device(name="Device-2", type=DeviceType.PRODUCT) - device.add_questionnaire_response(questionnaire_response=shoe_response) - device.add_index(questionnaire_id=shoe_questionnaire.id, question_name="foot") - 
device.add_index(questionnaire_id=shoe_questionnaire.id, question_name="shoe-size") - return device - - -@pytest.mark.integration -def test__device_repository__query_by_index__find_right_shoes( - device_right_shoe_size_123: Device, - device_left_shoe_size_123: Device, - repository: DeviceRepository, -): - repository.write(device_right_shoe_size_123) - repository.write(device_left_shoe_size_123) - - (device,) = repository.read_by_index( - questionnaire_id="shoe/1", question_name="foot", value="R" - ) - assert device.id == device_right_shoe_size_123.id - assert device.indexes == device_right_shoe_size_123.indexes - - -@pytest.mark.integration -def test__device_repository__query_by_index__find_left_shoes( - device_right_shoe_size_123: Device, - device_left_shoe_size_123: Device, - repository: DeviceRepository, -): - repository.write(device_right_shoe_size_123) - repository.write(device_left_shoe_size_123) - - (device,) = repository.read_by_index( - questionnaire_id="shoe/1", question_name="foot", value="L" - ) - assert device.id == device_left_shoe_size_123.id - assert device.indexes == device_left_shoe_size_123.indexes - - -@pytest.mark.integration -def test__device_repository__query_by_index__find_shoes_by_size_int( - device_right_shoe_size_123: Device, - device_left_shoe_size_123: Device, - repository: DeviceRepository, -): - repository.write(device_right_shoe_size_123) - repository.write(device_left_shoe_size_123) - - (device_1, device_2) = repository.read_by_index( - questionnaire_id="shoe/1", question_name="shoe-size", value=123 - ) - assert {device_1.id, device_2.id} == { - device_left_shoe_size_123.id, - device_right_shoe_size_123.id, - } - assert {*device_2.indexes, *device_1.indexes} == { - *device_left_shoe_size_123.indexes, - *device_right_shoe_size_123.indexes, - } - - -@pytest.mark.integration -def test__device_repository__query_by_index__find_shoes_by_size_str( - device_right_shoe_size_123: Device, - device_left_shoe_size_123: Device, - repository: 
DeviceRepository, -): - repository.write(device_right_shoe_size_123) - repository.write(device_left_shoe_size_123) - - (device_1, device_2) = repository.read_by_index( - questionnaire_id="shoe/1", question_name="shoe-size", value="123" - ) - assert {device_1.id, device_2.id} == { - device_left_shoe_size_123.id, - device_right_shoe_size_123.id, - } - assert {*device_2.indexes, *device_1.indexes} == { - *device_left_shoe_size_123.indexes, - *device_right_shoe_size_123.indexes, - } - - -@pytest.mark.integration -def test__device_repository__query_by_index__empty_result( - device_right_shoe_size_123: Device, - device_left_shoe_size_123: Device, - repository: DeviceRepository, -): - repository.write(device_right_shoe_size_123) - repository.write(device_left_shoe_size_123) - result = repository.read_by_index( - questionnaire_id="shoe/1", question_name="shoe-size", value="345" - ) - assert result == [] diff --git a/src/layers/domain/repository/device_repository/tests/v1/test_device_repository_keys_v1.py b/src/layers/domain/repository/device_repository/tests/v1/test_device_repository_keys_v1.py deleted file mode 100644 index 304da130a..000000000 --- a/src/layers/domain/repository/device_repository/tests/v1/test_device_repository_keys_v1.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest -from domain.core.device import Device, DeviceKeyType, DeviceType -from domain.core.device_key import DeviceKey -from domain.core.root import Root -from domain.repository.device_repository import DeviceRepository -from domain.repository.marshall import unmarshall - - -@pytest.fixture -def device_with_asid() -> Device: - org = Root.create_ods_organisation(ods_code="AB123") - product_team = org.create_product_team( - id="6f8c285e-04a2-4194-a84e-dabeba474ff7", name="Team" - ) - device = product_team.create_device(name="Device-1", type=DeviceType.PRODUCT) - device.add_key(key="P.WWW-XXX", type=DeviceKeyType.PRODUCT_ID) - device.add_key(key="ABC:1234567890", 
type=DeviceKeyType.ACCREDITED_SYSTEM_ID) - return device - - -@pytest.fixture -def device_with_mhs_id() -> Device: - org = Root.create_ods_organisation(ods_code="AB123") - team = org.create_product_team( - id="6f8c285e-04a2-4194-a84e-dabeba474ff7", name="Team" - ) - device = team.create_device(name="Device-2", type=DeviceType.ENDPOINT) - device.add_key(key="P.WWW-YYY", type=DeviceKeyType.PRODUCT_ID) - device.add_key( - key="ABC:DEF-444:4444444444", type=DeviceKeyType.MESSAGE_HANDLING_SYSTEM_ID - ) - return device - - -@pytest.mark.integration -def test__device_repository__query_by_key_type( - device_with_asid: Device, device_with_mhs_id: Device, repository: DeviceRepository -): - repository.write(device_with_asid) - repository.write(device_with_mhs_id) - - result = repository.query_by_key_type( - key_type=DeviceKeyType.MESSAGE_HANDLING_SYSTEM_ID - ) - (_device,) = map(unmarshall, result["Items"]) - assert _device["id"] == str(device_with_mhs_id.id) - - result = repository.query_by_key_type(key_type=DeviceKeyType.ACCREDITED_SYSTEM_ID) - (_device,) = map(unmarshall, result["Items"]) - assert _device["id"] == str(device_with_asid.id) - - -@pytest.mark.integration -def test__device_repository__query_by_type( - device_with_asid: Device, device_with_mhs_id: Device, repository: DeviceRepository -): - repository.write(device_with_asid) - repository.write(device_with_mhs_id) - - result = repository.query_by_device_type(type=DeviceType.ENDPOINT) - (_device,) = map(unmarshall, result["Items"]) - assert _device["id"] == str(device_with_mhs_id.id) - - result = repository.query_by_device_type(type=DeviceType.PRODUCT) - (_device,) = map(unmarshall, result["Items"]) - assert _device["id"] == str(device_with_asid.id) - - -@pytest.mark.integration -def test__device_repository__delete_key( - device_with_asid: Device, repository: DeviceRepository -): - # Persist model before deleting from model - repository.write(device_with_asid) - - # Retrieve the model and treat this as the initial 
state - intermediate_device = repository.read(id=device_with_asid.id) - intermediate_device.delete_key(key="ABC:1234567890") - repository.write(intermediate_device) - - assert repository.read(device_with_asid.id).keys == { - "P.WWW-XXX": DeviceKey(type=DeviceKeyType.PRODUCT_ID, key="P.WWW-XXX") - } diff --git a/src/layers/domain/repository/device_repository/tests/v1/test_device_repository_questionnaire_responses_v1.py b/src/layers/domain/repository/device_repository/tests/v1/test_device_repository_questionnaire_responses_v1.py deleted file mode 100644 index b1643a38f..000000000 --- a/src/layers/domain/repository/device_repository/tests/v1/test_device_repository_questionnaire_responses_v1.py +++ /dev/null @@ -1,112 +0,0 @@ -import pytest -from domain.core.device import Device, DeviceType -from domain.core.questionnaire import Questionnaire -from domain.core.root import Root -from domain.repository.device_repository import DeviceRepository -from domain.repository.device_repository.tests.utils import devices_exactly_equal - - -@pytest.fixture -def device() -> Device: - shoe_questionnaire = Questionnaire(name="shoe", version=1) - shoe_questionnaire.add_question( - name="foot", answer_types=(str,), mandatory=True, choices={"L", "R"} - ) - shoe_questionnaire.add_question( - name="shoe-size", answer_types=(int,), mandatory=True - ) - shoe_response_1 = shoe_questionnaire.respond( - responses=[{"foot": ["L"]}, {"shoe-size": [123]}], - ) - shoe_response_2 = shoe_questionnaire.respond( - responses=[{"foot": ["L"]}, {"shoe-size": [345]}], - ) - - health_questionnaire = Questionnaire(name="health", version=1) - health_questionnaire.add_question( - name="weight", answer_types=(int,), mandatory=True - ) - health_questionnaire.add_question( - name="height", answer_types=(int,), mandatory=True - ) - health_response = health_questionnaire.respond( - responses=[{"weight": [123]}, {"height": [345]}] - ) - - org = Root.create_ods_organisation(ods_code="AB123") - product_team = 
org.create_product_team( - id="6f8c285e-04a2-4194-a84e-dabeba474ff7", name="Team" - ) - device = product_team.create_device(name="Device-1", type=DeviceType.PRODUCT) - device.add_questionnaire_response(questionnaire_response=shoe_response_1) - device.add_questionnaire_response(questionnaire_response=shoe_response_2) - device.add_questionnaire_response(questionnaire_response=health_response) - return device - - -@pytest.mark.integration -def test__device_repository__with_questionnaires( - device: Device, repository: DeviceRepository -): - repository.write(device) - assert repository.read(id=device.id) == device - - -@pytest.mark.integration -def test__device_repository__delete_questionnaire_response_that_has_been_persisted( - device: Device, repository: DeviceRepository -): - def _sum_array_lengths(obj: dict[str, list]): - return sum(map(len, obj.values())) - - # Persist model before deleting from model - repository.write(device) - - # Retrieve the model and treat this as the initial state - intermediate_device = repository.read(id=device.id) - _responses_before = _sum_array_lengths(device.questionnaire_responses) - responses_before = _sum_array_lengths(intermediate_device.questionnaire_responses) - assert intermediate_device == device - assert _responses_before == responses_before - - # Delete from the model pulled down from the db, and then persist - intermediate_device.delete_questionnaire_response( - questionnaire_id="health/1", questionnaire_response_index=0 - ) - responses_after = _sum_array_lengths(intermediate_device.questionnaire_responses) - repository.write(intermediate_device) - assert responses_after == responses_before - 1 - - # Verify that we get the correct result back - device_from_db = repository.read(id=device.id) - assert devices_exactly_equal(device_from_db, intermediate_device) - assert not devices_exactly_equal(device_from_db, device) - - -@pytest.mark.integration -def 
test__device_repository__modify_questionnaire_response_that_has_been_persisted( - device: Device, repository: DeviceRepository -): - # Persist model before updating model - repository.write(device) - intermediate_device = repository.read(id=device.id) - - # Update the model - questionnaire_responses = intermediate_device.questionnaire_responses - shoe_questionnaire = questionnaire_responses["shoe/1"][0].questionnaire - questionnaire_response = shoe_questionnaire.respond( - responses=[{"foot": ["R"]}, {"shoe-size": [789]}] - ) - intermediate_device.update_questionnaire_response( - questionnaire_response_index=0, questionnaire_response=questionnaire_response - ) - - # Persist and verify consistency - repository.write(intermediate_device) - device_from_db = repository.read(id=intermediate_device.id) - assert devices_exactly_equal(device_from_db, intermediate_device) - assert not devices_exactly_equal(device_from_db, device) - assert device_from_db.questionnaire_responses["shoe/1"][0].responses == [ - {"foot": ["R"]}, - {"shoe-size": [789]}, - ] diff --git a/src/layers/domain/repository/device_repository/tests/v1/test_device_repository_v1.py b/src/layers/domain/repository/device_repository/tests/v1/test_device_repository_v1.py deleted file mode 100644 index 26e229e17..000000000 --- a/src/layers/domain/repository/device_repository/tests/v1/test_device_repository_v1.py +++ /dev/null @@ -1,73 +0,0 @@ -import pytest -from domain.core.device import Device, DeviceKeyType, DeviceStatus, DeviceType -from domain.core.root import Root -from domain.repository.device_repository import DeviceRepository -from domain.repository.errors import AlreadyExistsError, ItemNotFound - - -@pytest.fixture -def device() -> Device: - org = Root.create_ods_organisation(ods_code="AB123") - product_team = org.create_product_team( - id="6f8c285e-04a2-4194-a84e-dabeba474ff7", name="Team" - ) - device = product_team.create_device(name="Device-1", type=DeviceType.PRODUCT) - 
device.add_key(key="P.WWW-XXX", type=DeviceKeyType.PRODUCT_ID) - return device - - -@pytest.mark.integration -def test__device_repository(device: Device, repository: DeviceRepository): - repository.write(device) - assert repository.read(device.id) == device - - -@pytest.mark.integration -def test__device_repository_already_exists(device, repository: DeviceRepository): - repository.write(device) - with pytest.raises(AlreadyExistsError): - repository.write(device) - - -@pytest.mark.integration -def test__device_repository__device_does_not_exist(repository: DeviceRepository): - with pytest.raises(ItemNotFound): - repository.read("123") - - -def test__device_repository_local(device: Device, repository: DeviceRepository): - repository.write(device) - assert repository.read(device.id) == device - - -def test__device_repository__device_does_not_exist_local(repository: DeviceRepository): - with pytest.raises(ItemNotFound): - repository.read("123") - - -@pytest.mark.integration -def test__device_repository__update(device: Device, repository: DeviceRepository): - # Persist model before deleting from model - repository.write(device) - - # Retrieve the model and treat this as the initial state - intermediate_device = repository.read(id=device.id) - intermediate_device.update(name="foo-bar") - repository.write(intermediate_device) - - final_device = repository.read(device.id) - assert final_device.name == "foo-bar" - - -@pytest.mark.integration -def test__device_repository__delete(device: Device, repository: DeviceRepository): - # Persist model before deleting from model - repository.write(device) - - # Retrieve the model and treat this as the initial state - intermediate_device = repository.read(id=device.id) - intermediate_device.delete() - repository.write(intermediate_device) - - final_device = repository.read(device.id) - assert final_device.status is DeviceStatus.INACTIVE diff --git 
a/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_keys_v3.py b/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_keys_v3.py index 0025ec70b..93e15289c 100644 --- a/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_keys_v3.py +++ b/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_keys_v3.py @@ -7,26 +7,40 @@ @pytest.mark.integration def test__device_repository__add_two_keys(device: Device, repository: DeviceRepository): repository.write(device) - second_device = repository.read(device.id) + second_device = repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ) second_device.add_key( key_value="ABC:1234567890", key_type=DeviceKeyType.ACCREDITED_SYSTEM_ID ) repository.write(second_device) - assert repository.read(device.id).keys == [ + assert repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ).keys == [ DeviceKey(key_type=DeviceKeyType.PRODUCT_ID, key_value="P.WWW-XXX"), DeviceKey( key_type=DeviceKeyType.ACCREDITED_SYSTEM_ID, key_value="ABC:1234567890" ), ] - assert repository.read(DeviceKeyType.PRODUCT_ID, "P.WWW-XXX").keys == [ + assert repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id="P.WWW-XXX", + ).keys == [ DeviceKey(key_type=DeviceKeyType.PRODUCT_ID, key_value="P.WWW-XXX"), DeviceKey( key_type=DeviceKeyType.ACCREDITED_SYSTEM_ID, key_value="ABC:1234567890" ), ] assert repository.read( - DeviceKeyType.ACCREDITED_SYSTEM_ID, "ABC:1234567890" + product_team_id=device.product_team_id, + product_id=device.product_id, + id="ABC:1234567890", ).keys == [ DeviceKey(key_type=DeviceKeyType.PRODUCT_ID, key_value="P.WWW-XXX"), DeviceKey( @@ -43,12 +57,18 @@ def test__device_repository__delete_key( repository.write(device_with_asid) # Retrieve the model and treat this as the initial state - 
intermediate_device = repository.read(device_with_asid.id) + intermediate_device = repository.read( + product_team_id=device_with_asid.product_team_id, + product_id=device_with_asid.product_id, + id=device_with_asid.id, + ) intermediate_device.delete_key( key_type=DeviceKeyType.ACCREDITED_SYSTEM_ID, key_value="ABC:1234567890" ) repository.write(intermediate_device) - assert repository.read(device_with_asid.id).keys == [ - DeviceKey(key_type=DeviceKeyType.PRODUCT_ID, key_value="P.WWW-CCC") - ] + assert repository.read( + product_team_id=device_with_asid.product_team_id, + product_id=device_with_asid.product_id, + id=device_with_asid.id, + ).keys == [DeviceKey(key_type=DeviceKeyType.PRODUCT_ID, key_value="P.WWW-CCC")] diff --git a/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_questionnaire_responses_v3.py b/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_questionnaire_responses_v3.py index 5dad90fca..9aaf58236 100644 --- a/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_questionnaire_responses_v3.py +++ b/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_questionnaire_responses_v3.py @@ -54,7 +54,14 @@ def test__device_repository__with_questionnaires( device: Device, repository: DeviceRepository ): repository.write(device) - assert repository.read(device.id) == device + assert ( + repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ) + == device + ) @pytest.mark.integration @@ -67,7 +74,14 @@ def test__device_repository__with_questionnaires_and_tags( """ device.add_tag(foo="bar") repository.write(device) - assert repository.read(device.id) == device + assert ( + repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ) + == device + ) @pytest.mark.integration @@ -76,7 +90,11 @@ def 
test__device_repository__modify_questionnaire_response_that_has_been_persist ): # Persist model before updating model repository.write(device) - intermediate_device = repository.read(device.id) + intermediate_device = repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ) # Update the model questionnaire_responses = intermediate_device.questionnaire_responses @@ -94,7 +112,11 @@ def test__device_repository__modify_questionnaire_response_that_has_been_persist # Persist and verify consistency repository.write(intermediate_device) - device_from_db = repository.read(intermediate_device.id) + device_from_db = repository.read( + product_team_id=intermediate_device.product_team_id, + product_id=intermediate_device.product_id, + id=intermediate_device.id, + ) assert devices_exactly_equal(device_from_db, intermediate_device) assert not devices_exactly_equal(device_from_db, device) assert device_from_db.questionnaire_responses["shoe/1"][ diff --git a/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_tags_v3.py b/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_tags_v3.py index df8c126ff..37e0994d9 100644 --- a/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_tags_v3.py +++ b/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_tags_v3.py @@ -2,7 +2,6 @@ import pytest from domain.core.device.v3 import Device, DeviceTag -from domain.core.device_key.v2 import DeviceKeyType from domain.core.enum import Status from domain.repository.device_repository.v3 import ( CannotDropMandatoryFields, @@ -76,8 +75,22 @@ def _test_add_two_tags( DeviceTag(shoe_size="456"), } - assert repository.read(device.id).tags == expected_tags - assert repository.read(DeviceKeyType.PRODUCT_ID, "P.WWW-XXX").tags == expected_tags + assert ( + repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + 
id=device.id, + ).tags + == expected_tags + ) + assert ( + repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id="P.WWW-XXX", + ).tags + == expected_tags + ) (_device_123,) = repository.query_by_tag(shoe_size=123) assert _device_123.dict(exclude=DONT_COMPARE_FIELDS) == second_device.dict( @@ -94,7 +107,11 @@ def _test_add_two_tags( @pytest.mark.integration def test__device_repository__add_two_tags(device: Device, repository: DeviceRepository): repository.write(device) - second_device = repository.read(device.id) + second_device = repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ) second_device.add_tag(shoe_size=123) second_device.add_tag(shoe_size=456) repository.write(second_device) @@ -109,7 +126,11 @@ def test__device_repository__add_two_tags_at_once( device: Device, repository: DeviceRepository ): repository.write(device) - second_device = repository.read(device.id) + second_device = repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ) second_device.add_tags([dict(shoe_size=123), dict(shoe_size=456)]) repository.write(second_device) @@ -123,7 +144,11 @@ def test__device_repository__add_two_tags_and_then_clear( device: Device, repository: DeviceRepository ): repository.write(device) - second_device = repository.read(device.id) + second_device = repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ) second_device.add_tags([dict(shoe_size=123), dict(shoe_size=456)]) repository.write(second_device) @@ -131,8 +156,22 @@ def test__device_repository__add_two_tags_and_then_clear( second_device.clear_tags() repository.write(second_device) - assert repository.read(device.id).tags == set() - assert repository.read(DeviceKeyType.PRODUCT_ID, "P.WWW-XXX").tags == set() + assert ( + repository.read( + product_team_id=device.product_team_id, + 
product_id=device.product_id, + id=device.id, + ).tags + == set() + ) + assert ( + repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id="P.WWW-XXX", + ).tags + == set() + ) assert repository.query_by_tag(shoe_size=123) == [] assert repository.query_by_tag(shoe_size=456) == [] diff --git a/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_v3.py b/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_v3.py index ac06587b7..d7c641bb2 100644 --- a/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_v3.py +++ b/src/layers/domain/repository/device_repository/tests/v3/test_device_repository_v3.py @@ -3,8 +3,7 @@ import pytest from attr import asdict from domain.core.device.v3 import Device as DeviceV3 -from domain.core.device.v3 import DeviceCreatedEvent, DeviceTag -from domain.core.device_key.v2 import DeviceKey as DeviceKeyV2 +from domain.core.device.v3 import DeviceCreatedEvent from domain.core.device_key.v2 import DeviceKeyType from domain.core.enum import Status from domain.core.root.v3 import Root @@ -13,8 +12,7 @@ DeviceRepository as DeviceRepositoryV3, ) from domain.repository.device_repository.v3 import ( - _device_non_root_primary_keys, - _device_root_primary_key, + InactiveDeviceRepository, compress_device_fields, ) from domain.repository.errors import AlreadyExistsError, ItemNotFound @@ -55,37 +53,16 @@ def another_device_with_same_key() -> DeviceV3: return device -def test__device_root_primary_key(): - primary_key = _device_root_primary_key(device_id="123") - assert primary_key == {"pk": {"S": "D#123"}, "sk": {"S": "D#123"}} - - -def test__device_non_root_primary_keys(): - primary_keys = _device_non_root_primary_keys( - device_id="123", - device_keys=[ - DeviceKeyV2(key_type=DeviceKeyType.PRODUCT_ID, key_value=DEVICE_KEY) - ], - device_tags=[DeviceTag(foo="bar")], - ) - assert primary_keys == [ - { - "pk": {"S": 
"D#product_id#P.WWW-XXX"}, - "sk": {"S": "D#product_id#P.WWW-XXX"}, - }, - { - "pk": {"S": "DT#foo=bar"}, - "sk": {"S": "D#123"}, - }, - ] - - @pytest.mark.integration def test__device_repository_read_by_id( device: DeviceV3, repository: DeviceRepositoryV3 ): repository.write(device) - device_from_db = repository.read(device.id) + device_from_db = repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ) assert device_from_db.dict() == device.dict() @@ -94,7 +71,11 @@ def test__device_repository_read_by_key( device: DeviceV3, repository: DeviceRepositoryV3 ): repository.write(device) - device_from_db = repository.read(*device.keys[0].parts) + device_from_db = repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.keys[0].key_value, + ) assert device_from_db.dict() == device.dict() @@ -125,12 +106,16 @@ def test__device_repository_key_already_exists_on_another_device( @pytest.mark.integration def test__device_repository__device_does_not_exist(repository: DeviceRepositoryV3): with pytest.raises(ItemNotFound): - repository.read("123") + repository.read(product_team_id="foo", product_id="bar", id="123") def test__device_repository_local(device: DeviceV3, repository: DeviceRepositoryV3): repository.write(device) - device_from_db = repository.read(device.id) + device_from_db = repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ) assert device_from_db.dict() == device.dict() @@ -138,7 +123,7 @@ def test__device_repository__device_does_not_exist_local( repository: DeviceRepositoryV3, ): with pytest.raises(ItemNotFound): - repository.read("123") + repository.read(product_team_id="foo", product_id="bar", id="123") @pytest.mark.integration @@ -146,12 +131,20 @@ def test__device_repository__update(device: DeviceV3, repository: DeviceReposito repository.write(device) # Retrieve the model and treat this as the initial 
state - intermediate_device = repository.read(device.id) + intermediate_device = repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ) intermediate_device.update(name="foo-bar") repository.write(intermediate_device) - final_device = repository.read(device.id) + final_device = repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ) assert final_device.name == "foo-bar" @@ -165,17 +158,26 @@ def test__device_repository__delete( ): repository.write(device_with_tag) + read_query = dict( + product_team_id=device_with_tag.product_team_id, + product_id=device_with_tag.product_id, + id=device_with_tag.id, + ) + # Retrieve the model and treat this as the initial state - device = repository.read(device_with_tag.id) + device = repository.read(**read_query) device.delete() repository.write(device) # Attempt to read the original device, expecting an ItemNotFound error with pytest.raises(ItemNotFound): - repository.read(device_with_tag.id) + repository.read(**read_query) # Read the deleted device - deleted_device = repository.read_inactive(device_with_tag.id) + inactive_repository = InactiveDeviceRepository( + table_name=repository.table_name, dynamodb_client=repository.client + ) + deleted_device = inactive_repository.read(**read_query) # Assert device is inactive after being deleted assert deleted_device is not None @@ -195,15 +197,28 @@ def test__device_repository__can_delete_second_device_with_same_key( device = product.create_device(name="OriginalDevice") device.add_key(key_value=DEVICE_KEY, key_type=DeviceKeyType.PRODUCT_ID) repository.write(device) - repository.read(DeviceKeyType.PRODUCT_ID, DEVICE_KEY) # passes + + read_query = dict( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=DEVICE_KEY, + ) + repository.read(**read_query) # passes device.clear_events() device.delete() repository.write(device) with 
pytest.raises(ItemNotFound): - repository.read(DeviceKeyType.PRODUCT_ID, DEVICE_KEY) + repository.read(**read_query) - deleted_device = repository.read_inactive(device.id) + inactive_repository = InactiveDeviceRepository( + table_name=repository.table_name, dynamodb_client=repository.client + ) + deleted_device = inactive_repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ) assert deleted_device.status is Status.INACTIVE # Can re-add the same product id Key after a previous device is inactive @@ -211,16 +226,20 @@ def test__device_repository__can_delete_second_device_with_same_key( _device = product.create_device(name=f"Device-{i}") _device.add_key(key_value=DEVICE_KEY, key_type=DeviceKeyType.PRODUCT_ID) repository.write(_device) - repository.read(DeviceKeyType.PRODUCT_ID, DEVICE_KEY) # passes + repository.read(**read_query) # passes _device.clear_events() _device.delete() repository.write(_device) with pytest.raises(ItemNotFound): - repository.read(DeviceKeyType.PRODUCT_ID, DEVICE_KEY) + repository.read(**read_query) # Assert device is inactive after being deleted - _deleted_device = repository.read_inactive(_device.id) + _deleted_device = inactive_repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=_device.id, + ) assert _deleted_device.status is Status.INACTIVE @@ -229,7 +248,11 @@ def test__device_repository__add_key(device: DeviceV3, repository: DeviceReposit repository.write(device) # Retrieve the model and treat this as the initial state - intermediate_device = repository.read(device.id) + intermediate_device = repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=device.id, + ) assert len(intermediate_device.keys) == 1 intermediate_device.add_key( @@ -239,12 +262,16 @@ def test__device_repository__add_key(device: DeviceV3, repository: DeviceReposit # Read the same device multiple times, indexed by key and id # to 
verify that they're all the same - root_index = [(intermediate_device.id,)] - non_root_indexes = [k.parts for k in intermediate_device.keys] + root_index = intermediate_device.id + non_root_indexes = [k.key_value for k in intermediate_device.keys] retrieved_devices = [] - for key_parts in root_index + non_root_indexes: - _device = repository.read(*key_parts).dict() + for key_value in [root_index] + non_root_indexes: + _device = repository.read( + product_team_id=device.product_team_id, + product_id=device.product_id, + id=key_value, + ).dict() retrieved_devices.append(_device) # Assert that there are 2 keys, the device can be retrieved 3 ways from the db, diff --git a/src/layers/domain/repository/device_repository/v1.py b/src/layers/domain/repository/device_repository/v1.py deleted file mode 100644 index 0988c6f3e..000000000 --- a/src/layers/domain/repository/device_repository/v1.py +++ /dev/null @@ -1,277 +0,0 @@ -from collections import defaultdict -from typing import TYPE_CHECKING - -import orjson -from attr import asdict as _asdict -from domain.core.device import ( - Device, - DeviceCreatedEvent, - DeviceIndexAddedEvent, - DeviceKey, - DeviceKeyAddedEvent, - DeviceKeyDeletedEvent, - DeviceType, - DeviceUpdatedEvent, -) -from domain.core.questionnaire import ( - Questionnaire, - QuestionnaireInstanceEvent, - QuestionnaireResponse, - QuestionnaireResponseAddedEvent, - QuestionnaireResponseDeletedEvent, - QuestionnaireResponseUpdatedEvent, -) -from domain.repository.errors import ItemNotFound -from domain.repository.keys import TableKeys, strip_key_prefix -from domain.repository.marshall import marshall, marshall_value, unmarshall -from domain.repository.questionnaire_repository import deserialise_question -from domain.repository.repository import Repository -from domain.repository.transaction import ( - ConditionExpression, - TransactionStatement, - TransactItem, -) -from event.json import json_loads - -if TYPE_CHECKING: - from mypy_boto3_dynamodb.type_defs 
import QueryOutputTypeDef - - -def asdict(obj) -> dict: - return _asdict(obj, recurse=False) - - -class DeviceRepository(Repository[Device]): - def __init__(self, table_name, dynamodb_client): - super().__init__( - table_name=table_name, model=Device, dynamodb_client=dynamodb_client - ) - - def handle_DeviceCreatedEvent( - self, - event: DeviceCreatedEvent, - condition_expression=ConditionExpression.MUST_NOT_EXIST, - ) -> TransactItem: - pk = TableKeys.DEVICE.key(event.id) - pk_1 = TableKeys.DEVICE_TYPE.key(event.type) - event_data = asdict(event) - _condition_expression = ( - {"ConditionExpression": condition_expression} - if event_data.get("_trust", False) is False - else {} - ) - return TransactItem( - Put=TransactionStatement( - TableName=self.table_name, - Item=marshall(pk=pk, sk=pk, pk_1=pk_1, sk_1=pk, **event_data), - **_condition_expression, - ) - ) - - def handle_DeviceUpdatedEvent(self, event: DeviceUpdatedEvent) -> TransactItem: - return self.handle_DeviceCreatedEvent( - event=event, condition_expression=ConditionExpression.MUST_EXIST - ) - - def handle_DeviceKeyAddedEvent(self, event: DeviceKeyAddedEvent): - pk = TableKeys.DEVICE.key(event.id) - sk = TableKeys.DEVICE_KEY.key(event.key) - pk_2 = TableKeys.DEVICE_KEY_TYPE.key(event.type) - event_data = asdict(event) - condition_expression = ( - {"ConditionExpression": ConditionExpression.MUST_NOT_EXIST} - if event_data.get("_trust", False) is False - else {} - ) - return TransactItem( - Put=TransactionStatement( - TableName=self.table_name, - Item=marshall( - pk=pk, sk=sk, pk_1=sk, sk_1=sk, pk_2=pk_2, sk_2=sk, **event_data - ), - **condition_expression, - ) - ) - - def handle_DeviceKeyDeletedEvent( - self, event: DeviceKeyDeletedEvent - ) -> TransactItem: - pk = TableKeys.DEVICE.key(event.id) - sk = TableKeys.DEVICE_KEY.key(event.key) - return TransactItem( - Delete=TransactionStatement( - TableName=self.table_name, - Key=marshall(pk=pk, sk=sk), - ConditionExpression=ConditionExpression.MUST_EXIST, - ) - 
) - - def handle_QuestionnaireInstanceEvent(self, event: QuestionnaireInstanceEvent): - pk = TableKeys.DEVICE.key(event.entity_id) - sk = TableKeys.QUESTIONNAIRE.key(event.questionnaire_id) - event_data = asdict(event) - event_data["questions"] = orjson.dumps(event_data["questions"]).decode() - return TransactItem( - Put=TransactionStatement( - TableName=self.table_name, - Item=marshall(pk=pk, sk=sk, pk_1=sk, sk_1=pk, **event_data), - ) - ) - - def handle_QuestionnaireResponseAddedEvent( - self, - event: QuestionnaireResponseAddedEvent, - condition_expression=ConditionExpression.MUST_NOT_EXIST, - ) -> TransactItem: - pk = TableKeys.DEVICE.key(event.entity_id) - sk = TableKeys.QUESTIONNAIRE_RESPONSE.key( - event.questionnaire_id, event.questionnaire_response_index - ) - event_data = asdict(event) - event_data["responses"] = orjson.dumps(event_data["responses"]).decode() - condition_expression = ( - {"ConditionExpression": condition_expression} - if event_data.get("_trust", False) is False - else {} - ) - return TransactItem( - Put=TransactionStatement( - TableName=self.table_name, - Item=marshall(pk=pk, sk=sk, pk_1=sk, sk_1=pk, **event_data), - **condition_expression, - ) - ) - - def handle_QuestionnaireResponseUpdatedEvent( - self, event: QuestionnaireResponseUpdatedEvent - ) -> TransactItem: - return self.handle_QuestionnaireResponseAddedEvent( - event=event, condition_expression=ConditionExpression.MUST_EXIST - ) - - def handle_QuestionnaireResponseDeletedEvent( - self, event: QuestionnaireResponseDeletedEvent - ) -> TransactItem: - pk = TableKeys.DEVICE.key(event.entity_id) - sk = TableKeys.QUESTIONNAIRE_RESPONSE.key( - event.questionnaire_id, event.questionnaire_response_index - ) - return TransactItem( - Delete=TransactionStatement( - TableName=self.table_name, - Key=marshall(pk=pk, sk=sk), - ConditionExpression=ConditionExpression.MUST_EXIST, - ) - ) - - def handle_DeviceIndexAddedEvent(self, event: DeviceIndexAddedEvent): - pk = 
TableKeys.DEVICE.key(event.id) - sk = TableKeys.DEVICE_INDEX.key( - event.questionnaire_id, event.question_name, event.value - ) - event_data = asdict(event) - condition_expression = ( - {"ConditionExpression": ConditionExpression.MUST_NOT_EXIST} - if event_data.get("_trust", False) is False - else {} - ) - return TransactItem( - Put=TransactionStatement( - TableName=self.table_name, - Item=marshall(pk=pk, sk=sk, pk_1=sk, sk_1=pk, **asdict(event)), - **condition_expression, - ) - ) - - def query_by_key_type(self, key_type, **kwargs) -> "QueryOutputTypeDef": - pk_2 = TableKeys.DEVICE_KEY_TYPE.key(key_type) - args = { - "TableName": self.table_name, - "IndexName": "idx_gsi_2", - "KeyConditionExpression": "pk_2 = :pk_2", - "ExpressionAttributeValues": {":pk_2": marshall_value(pk_2)}, - } - return self.client.query(**args, **kwargs) - - def query_by_device_type(self, type: DeviceType, **kwargs) -> "QueryOutputTypeDef": - pk_1 = TableKeys.DEVICE_TYPE.key(type) - args = { - "TableName": self.table_name, - "IndexName": "idx_gsi_1", - "KeyConditionExpression": "pk_1 = :pk_1", - "ExpressionAttributeValues": { - ":pk_1": marshall_value(pk_1), - }, - } - return self.client.query(**args, **kwargs) - - def read_by_index(self, questionnaire_id: str, question_name: str, value: str): - pk_1 = TableKeys.DEVICE_INDEX.key(questionnaire_id, question_name, value) - result = self.client.query( - TableName=self.table_name, - IndexName="idx_gsi_1", - KeyConditionExpression="pk_1 = :pk_1", - ExpressionAttributeValues={ - ":pk_1": marshall_value(pk_1), - }, - ) - items = (unmarshall(i) for i in result["Items"]) - return [self.read(strip_key_prefix(item["pk"])) for item in items] - - def read_by_key(self, key) -> Device: - pk_1 = TableKeys.DEVICE_KEY.key(key) - args = { - "TableName": self.table_name, - "IndexName": "idx_gsi_1", - "KeyConditionExpression": "pk_1 = :pk_1 AND sk_1 = :pk_1", - "ExpressionAttributeValues": {":pk_1": marshall_value(pk_1)}, - } - result = self.client.query(**args) 
- items = [unmarshall(i) for i in result["Items"]] - if len(items) == 0: - raise ItemNotFound(key, item_type=Device) - (item,) = items - return self.read(strip_key_prefix(item["pk"])) - - def read(self, id) -> Device: - pk = TableKeys.DEVICE.key(id) - args = { - "TableName": self.table_name, - "KeyConditionExpression": "pk = :pk", - "ExpressionAttributeValues": {":pk": marshall_value(pk)}, - } - result = self.client.query(**args) - items = [unmarshall(i) for i in result["Items"]] - if len(items) == 0: - raise ItemNotFound(id, item_type=Device) - - (device,) = TableKeys.DEVICE.filter(items, key="sk") - keys = TableKeys.DEVICE_KEY.filter_and_group(items, key="sk") - _indexes = TableKeys.DEVICE_INDEX.filter(items, key="sk") - indexes = set( - (idx["questionnaire_id"], idx["question_name"], idx["value"]) - for idx in _indexes - ) - - questionnaires = {} - for id_, data in TableKeys.QUESTIONNAIRE.filter_and_group(items, key="sk"): - data["questions"] = { - question_name: deserialise_question(question) - for question_name, question in json_loads(data["questions"]).items() - } - questionnaires[id_] = Questionnaire(**data) - - questionnaire_responses = defaultdict(list) - for qr in TableKeys.QUESTIONNAIRE_RESPONSE.filter(items, key="sk"): - qid = qr["questionnaire_id"] - _questionnaire_response = QuestionnaireResponse( - questionnaire=questionnaires[qid], responses=json_loads(qr["responses"]) - ) - questionnaire_responses[qid].append(_questionnaire_response) - - return Device( - keys={id_: DeviceKey(**data) for id_, data in keys}, - questionnaire_responses=questionnaire_responses, - indexes=indexes, - **device, - ) diff --git a/src/layers/domain/repository/device_repository/v3.py b/src/layers/domain/repository/device_repository/v3.py index cc9184ca8..c1df27f2c 100644 --- a/src/layers/domain/repository/device_repository/v3.py +++ b/src/layers/domain/repository/device_repository/v3.py @@ -1,8 +1,8 @@ from copy import copy from attr import asdict +from domain.core.device.v3 
import Device as _Device from domain.core.device.v3 import ( - Device, DeviceCreatedEvent, DeviceDeletedEvent, DeviceKeyAddedEvent, @@ -18,16 +18,17 @@ from domain.core.event import Event from domain.core.questionnaire.v2 import QuestionnaireResponseUpdatedEvent from domain.repository.compression import pkl_dumps_gzip, pkl_loads_gzip -from domain.repository.errors import ItemNotFound from domain.repository.keys.v3 import TableKey from domain.repository.marshall import marshall, marshall_value, unmarshall -from domain.repository.repository.v2 import Repository +from domain.repository.repository.v3 import Repository, TooManyResults from domain.repository.transaction import ( ConditionExpression, TransactionStatement, TransactItem, + dynamodb_projection_expression, update_transactions, ) +from pydantic import validator TAGS = "tags" ROOT_FIELDS_TO_COMPRESS = [TAGS] @@ -35,10 +36,6 @@ BATCH_GET_SIZE = 100 -class TooManyResults(Exception): - pass - - class CannotDropMandatoryFields(Exception): def __init__(self, bad_fields: set[str]) -> None: super().__init__(f"Cannot drop mandatory fields: {', '.join(bad_fields)}") @@ -82,120 +79,44 @@ def decompress_device_fields(device: dict): return device -def _device_root_primary_key(device_id: str) -> dict: - """ - Generates one fully marshalled (i.e. {"pk": {"S": "123"}} DynamoDB - primary key (i.e. pk + sk) for the provided Device, indexed by the Device ID - """ - root_pk = TableKey.DEVICE.key(device_id) - return marshall(pk=root_pk, sk=root_pk) - - -def _device_non_root_primary_keys( - device_id: str, device_keys: list[DeviceKey], device_tags: list[DeviceTag] -) -> list[dict]: - """ - Generates all the fully marshalled (i.e. {"pk": {"S": "123"}} DynamoDB - primary keys (i.e. pk + sk) for the provided Device. 
This is one primary key - for every value of Device.keys and Device.tags - """ - root_pk = TableKey.DEVICE.key(device_id) - device_key_primary_keys = [ - marshall(pk=pk, sk=pk) - for pk in (TableKey.DEVICE.key(k.key_type, k.key_value) for k in device_keys) - ] - device_tag_primary_keys = [ - marshall(pk=pk, sk=root_pk) - for pk in (TableKey.DEVICE_TAG.key(t.value) for t in device_tags) - ] - return device_key_primary_keys + device_tag_primary_keys - - -def update_device_indexes( - table_name: str, - data: dict | DeviceUpdatedEvent, - id: str, - keys: list[DeviceKey], - tags: list[DeviceTag], -): - # Update the root device without compressing the 'questionnaire_responses' field - root_primary_key = _device_root_primary_key(device_id=id) - update_root_device_transactions = update_transactions( - table_name=table_name, - primary_keys=[root_primary_key], - data=compress_device_fields(data), - ) - # Update non-root devices with compressed 'questionnaire_responses' field - non_root_primary_keys = _device_non_root_primary_keys( - device_id=id, device_keys=keys, device_tags=tags - ) - update_non_root_devices_transactions = update_transactions( - table_name=table_name, - primary_keys=non_root_primary_keys, - data=compress_device_fields( - data, fields_to_compress=NON_ROOT_FIELDS_TO_COMPRESS - ), - ) - return update_root_device_transactions + update_non_root_devices_transactions - - -def create_device_index( - table_name: str, - pk_key_parts: tuple[str], - device_data: dict, - sk_key_parts=None, - pk_table_key: TableKey = TableKey.DEVICE, - sk_table_key: TableKey = TableKey.DEVICE, - root=False, +def create_tag_index( + table_name: str, device_id: str, tag_value: str, data: dict ) -> TransactItem: - pk = pk_table_key.key(*pk_key_parts) - sk = sk_table_key.key(*sk_key_parts) if sk_key_parts else pk + pk = TableKey.DEVICE_TAG.key(tag_value) + sk = TableKey.DEVICE.key(device_id) return TransactItem( Put=TransactionStatement( TableName=table_name, - Item=marshall(pk=pk, sk=sk, 
root=root, **device_data), + Item=marshall(pk=pk, sk=sk, pk_read=pk, sk_read=sk, root=False, **data), ConditionExpression=ConditionExpression.MUST_NOT_EXIST, ) ) -def create_device_index_batch( - pk_key_parts: tuple[str], - device_data: dict, - sk_key_parts=None, - pk_table_key: TableKey = TableKey.DEVICE, - sk_table_key: TableKey = TableKey.DEVICE, - root=False, -) -> dict: +def create_tag_index_batch(device_id: str, tag_value: str, data: dict): """ - Difference between `create_device_index` and `create_device_index_batch`: + Difference between `create_tag_index` and `create_tag_index_batch`: - `create_device_index` is intended for the event-based - handlers (e.g. `handle_DeviceCreatedEvent`) which are called by the base + `create_index` is intended for the event-based + `handle_TagAddedEvent` which is called by the base `write` method, which expects `TransactItem`s for use with `client.transact_write_items` - `create_device_index_batch` is intended the device-based handler + `create_tag_index_batch` is intended for the entity-based handler `handle_bulk` which is called by the base method `write_bulk`, which expects `BatchWriteItem`s which we render as a `dict` for use with `client.batch_write_items` """ - pk = pk_table_key.key(*pk_key_parts) - sk = sk_table_key.key(*sk_key_parts) if sk_key_parts else pk + pk = TableKey.DEVICE_TAG.key(tag_value) + sk = TableKey.DEVICE.key(device_id) return { "PutRequest": { - "Item": marshall(pk=pk, sk=sk, root=root, **device_data), - }, + "Item": marshall(pk=pk, sk=sk, pk_read=pk, sk_read=sk, root=False, **data) + } } -def delete_device_index( - table_name: str, - pk_key_parts: tuple[str], - sk_key_parts=None, - pk_table_key: TableKey = TableKey.DEVICE, - sk_table_key: TableKey = TableKey.DEVICE, -) -> TransactItem: - pk = pk_table_key.key(*pk_key_parts) - sk = sk_table_key.key(*sk_key_parts) if sk_key_parts else pk +def delete_tag_index(table_name: str, device_id: str, tag_value: str) -> TransactItem: + pk = 
TableKey.DEVICE_TAG.key(tag_value) + sk = TableKey.DEVICE.key(device_id) return TransactItem( Delete=TransactionStatement( TableName=table_name, @@ -205,40 +126,91 @@ def delete_device_index( ) +def update_tag_indexes( + table_name: str, device_id: str, tag_values: list[str], data: dict +) -> TransactItem: + root_pk = TableKey.DEVICE.key(device_id) + tag_keys = [ + marshall(pk=TableKey.DEVICE_TAG.key(tag_value), sk=root_pk) + for tag_value in tag_values + ] + return update_transactions(table_name=table_name, primary_keys=tag_keys, data=data) + + +class Device(_Device): + """Wrapper around domain Device that also deserialises tags""" + + @validator("tags", pre=True) + def deserialise_tags(cls, tags): + if isinstance(tags, str): + tags = [pkl_loads_gzip(tag) for tag in pkl_loads_gzip(tags)] + return tags + + class DeviceRepository(Repository[Device]): def __init__(self, table_name, dynamodb_client): super().__init__( - table_name=table_name, model=Device, dynamodb_client=dynamodb_client + table_name=table_name, + model=Device, + dynamodb_client=dynamodb_client, + parent_table_keys=(TableKey.PRODUCT_TEAM, TableKey.CPM_PRODUCT), + table_key=TableKey.DEVICE, ) + def read(self, product_team_id: str, product_id: str, id: str): + return super()._read(parent_ids=(product_team_id, product_id), id=id) + + def search(self, product_team_id: str, product_id: str, id: str): + return super()._query(parent_ids=(product_team_id, product_id)) + def handle_DeviceCreatedEvent(self, event: DeviceCreatedEvent) -> TransactItem: - return create_device_index( - table_name=self.table_name, - pk_key_parts=(event.id,), - device_data=compress_device_fields(event), + return self.create_index( + id=event.id, + parent_key_parts=(event.product_team_id, event.product_id), + data=compress_device_fields(event), root=True, ) def handle_DeviceUpdatedEvent( self, event: DeviceUpdatedEvent ) -> list[TransactItem]: - keys = {DeviceKey(**key) for key in event.keys} - tags = {DeviceTag(__root__=tag) for tag 
in event.tags} - return update_device_indexes( - table_name=self.table_name, data=event, id=event.id, keys=keys, tags=tags + root_pk = self.table_key.key(event.id) + (root_transaction,) = update_transactions( + table_name=self.table_name, + primary_keys=[marshall(pk=root_pk, sk=root_pk)], + data=compress_device_fields(event), + ) + + non_root_data = compress_device_fields( + event, fields_to_compress=NON_ROOT_FIELDS_TO_COMPRESS + ) + + key_pks = (self.table_key.key(DeviceKey(**k).key_value) for k in event.keys) + key_transactions = update_transactions( + table_name=self.table_name, + primary_keys=[marshall(pk=pk, sk=pk) for pk in key_pks], + data=non_root_data, ) + tag_transactions = update_tag_indexes( + table_name=self.table_name, + device_id=event.id, + tag_values=event.tags, + data=non_root_data, + ) + + return [root_transaction] + key_transactions + tag_transactions + def handle_DeviceDeletedEvent( self, event: DeviceDeletedEvent ) -> list[TransactItem]: # Inactive Devices have tags removed so that they are # no longer searchable - delete_transactions = [ - delete_device_index( + tag_delete_transactions = [ + delete_tag_index( table_name=self.table_name, - pk_key_parts=(DeviceTag(__root__=tag).value,), - sk_key_parts=(event.id,), - pk_table_key=TableKey.DEVICE_TAG, + device_id=event.id, + tag_value=tag, ) for tag in event.deleted_tags ] @@ -248,227 +220,228 @@ def handle_DeviceDeletedEvent( inactive_data["status"] = str(Status.INACTIVE) # Collect keys for the original devices - original_keys = {DeviceKey(**key) for key in event.keys} + original_keys = {DeviceKey(**key).key_value for key in event.keys} # Create copy of original device and indexes with new pk and sk - inactive_root_copy_transactions = [] - inactive_root_copy_transactions.append( - create_device_index( - table_name=self.table_name, - pk_table_key=TableKey.DEVICE_STATUS, - pk_key_parts=(event.status, event.id), - sk_key_parts=(event.id,), - device_data=inactive_data, - root=True, - ) + 
root_copy_transaction = self.create_index( + id=event.id, + parent_key_parts=(event.product_team_id, event.product_id), + data=inactive_data, + table_key=TableKey.DEVICE_STATUS, + root=True, ) - inactive_key_indexes_copy_transactions = [] - for key in original_keys: - inactive_key_indexes_copy_transactions.append( - create_device_index( - table_name=self.table_name, - pk_table_key=TableKey.DEVICE_STATUS, - pk_key_parts=(event.status, event.id), - sk_key_parts=key.parts, - device_data=inactive_data, - root=False, - ) - ) - # Create delete transactions for original device and key indexes - original_root_delete_transactions = [] - original_root_delete_transactions.append( - delete_device_index( - table_name=self.table_name, - pk_key_parts=(event.id,), - pk_table_key=TableKey.DEVICE, - ) - ) - - original_key_indexes_delete_transactions = [] - for key in original_keys: - original_key_indexes_delete_transactions.append( - delete_device_index( - table_name=self.table_name, - pk_key_parts=key.parts, - pk_table_key=TableKey.DEVICE, - ) - ) + root_delete_transaction = self.delete_index(event.id) + key_delete_transactions = [self.delete_index(key) for key in original_keys] return ( - delete_transactions - + inactive_root_copy_transactions - + inactive_key_indexes_copy_transactions - + original_root_delete_transactions - + original_key_indexes_delete_transactions + tag_delete_transactions + + [root_copy_transaction, root_delete_transaction] + + key_delete_transactions ) def handle_DeviceKeyAddedEvent( self, event: DeviceKeyAddedEvent ) -> list[TransactItem]: # Create a copy of the Device indexed against the new key - create_transaction = create_device_index( - table_name=self.table_name, - pk_key_parts=event.new_key.parts, - device_data=compress_device_fields( - event, fields_to_compress=NON_ROOT_FIELDS_TO_COMPRESS - ), + _non_root_data = compress_device_fields( + event, fields_to_compress=NON_ROOT_FIELDS_TO_COMPRESS + ) + create_key_transaction = self.create_index( + 
id=event.new_key.key_value, + parent_key_parts=(event.product_team_id, event.product_id), + data=_non_root_data, + root=False, + ) + + data = {"keys": event.keys, "updated_on": event.updated_on} + + # Update "keys" on the root and key-indexed Devices + device_keys = {DeviceKey(**key).key_value for key in event.keys} + device_keys_before_update = device_keys - {event.new_key.key_value} + update_root_and_key_transactions = self.update_indexes( + id=event.id, keys=device_keys_before_update, data=data ) - # Update the value of "keys" on all other copies of this Device - device_keys = {DeviceKey(**key) for key in event.keys} - device_keys_before_update = device_keys - {event.new_key} - device_tags = {DeviceTag(__root__=tag) for tag in event.tags} - update_transactions = update_device_indexes( + + # Update "keys" on the tag-indexed Devices + update_tag_transactions = update_tag_indexes( table_name=self.table_name, - id=event.id, - keys=device_keys_before_update, - tags=device_tags, - data={ - "keys": event.keys, - "updated_on": event.updated_on, - }, + device_id=event.id, + tag_values=event.tags, + data=data, + ) + + return ( + [create_key_transaction] + + update_root_and_key_transactions + + update_tag_transactions ) - return [create_transaction] + update_transactions def handle_DeviceKeyDeletedEvent( self, event: DeviceKeyDeletedEvent ) -> list[TransactItem]: - # Delete the copy of the Device indexed against the deleted key - delete_transaction = delete_device_index( - table_name=self.table_name, pk_key_parts=event.deleted_key.parts - ) - # Update the value of "keys" on all other copies of this Device - device_keys = {DeviceKey(**key) for key in event.keys} - device_keys_before_update = device_keys - {event.deleted_key} - device_tags = {DeviceTag(__root__=tag) for tag in event.tags} - update_transactions = update_device_indexes( + delete_key_transaction = self.delete_index(event.deleted_key.key_value) + + data = {"keys": event.keys, "updated_on": event.updated_on} + + 
# Update "keys" on the root and key-indexed Devices + device_keys = {DeviceKey(**key).key_value for key in event.keys} + device_keys_before_update = device_keys - {event.deleted_key.key_value} + update_root_and_key_transactions = self.update_indexes( + id=event.id, keys=device_keys_before_update, data=data + ) + + # Update "keys" on the tag-indexed Devices + update_tag_transactions = update_tag_indexes( table_name=self.table_name, - id=event.id, - keys=device_keys_before_update, - tags=device_tags, - data={ - "keys": event.keys, - "updated_on": event.updated_on, - }, + device_id=event.id, + tag_values=event.tags, + data=data, + ) + + return ( + [delete_key_transaction] + + update_root_and_key_transactions + + update_tag_transactions ) - return [delete_transaction] + update_transactions def handle_DeviceTagAddedEvent( self, event: DeviceTagAddedEvent ) -> list[TransactItem]: + data = {"tags": event.tags, "updated_on": event.updated_on} + # Create a copy of the Device indexed against the new tag - create_transaction = create_device_index( + create_tag_transaction = create_tag_index( table_name=self.table_name, - pk_key_parts=(event.new_tag.value,), - sk_key_parts=(event.id,), - pk_table_key=TableKey.DEVICE_TAG, - device_data=compress_device_fields( + device_id=event.id, + tag_value=event.new_tag, + data=compress_device_fields( event, fields_to_compress=NON_ROOT_FIELDS_TO_COMPRESS ), ) - # Update the value of "tags" on all other copies of this Device - device_keys = {DeviceKey(**key) for key in event.keys} - device_tags = {DeviceTag(__root__=tag) for tag in event.tags} - device_tags_before_update = device_tags - {event.new_tag} - update_transactions = update_device_indexes( + + # Update "tags" on the root and key-indexed Devices + device_keys = {DeviceKey(**key).key_value for key in event.keys} + update_root_and_key_transactions = self.update_indexes( + id=event.id, keys=device_keys, data=data + ) + + # Update "tags" on the tag-indexed Devices + 
update_tag_transactions = update_tag_indexes( table_name=self.table_name, - id=event.id, - keys=device_keys, - tags=device_tags_before_update, - data={"tags": event.tags, "updated_on": event.updated_on}, + device_id=event.id, + tag_values=event.tags, + data=data, + ) + + return ( + [create_tag_transaction] + + update_root_and_key_transactions + + update_tag_transactions ) - return [create_transaction] + update_transactions def handle_DeviceTagsAddedEvent(self, event: DeviceTagsAddedEvent): - # Create a copy of the Device indexed against the new tag - device_data = compress_device_fields( + data = {"tags": event.tags, "updated_on": event.updated_on} + _data = compress_device_fields( event, fields_to_compress=NON_ROOT_FIELDS_TO_COMPRESS ) - create_transactions = [ - create_device_index( + + # Create a copy of the Device indexed against the new tag + create_tag_transactions = [ + create_tag_index( table_name=self.table_name, - pk_key_parts=(new_tag.value,), - sk_key_parts=(event.id,), - pk_table_key=TableKey.DEVICE_TAG, - device_data=device_data, + device_id=event.id, + tag_value=tag, + data=_data, ) - for new_tag in event.new_tags + for tag in event.new_tags ] - # Update the value of "tags" on all other copies of this Device - device_keys = {DeviceKey(**key) for key in event.keys} - device_tags = {DeviceTag(__root__=tag) for tag in event.tags} - device_tags_before_update = device_tags - event.new_tags - update_transactions = update_device_indexes( + # Update "tags" on the root and key-indexed Devices + device_keys = {DeviceKey(**key).key_value for key in event.keys} + update_root_and_key_transactions = self.update_indexes( + id=event.id, keys=device_keys, data=data + ) + + # Update "tags" on the tag-indexed Devices + update_tag_transactions = update_tag_indexes( table_name=self.table_name, - id=event.id, - keys=device_keys, - tags=device_tags_before_update, - data={"tags": event.tags, "updated_on": event.updated_on}, + device_id=event.id, + tag_values=set(event.tags) 
- set(event.new_tags), + data=data, + ) + + return ( + create_tag_transactions + + update_root_and_key_transactions + + update_tag_transactions ) - return create_transactions + update_transactions def handle_DeviceTagsClearedEvent(self, event: DeviceTagsClearedEvent): delete_tags_transactions = [ - delete_device_index( + delete_tag_index( table_name=self.table_name, - pk_key_parts=(tag.value,), - sk_key_parts=(event.id,), - pk_table_key=TableKey.DEVICE_TAG, + device_id=event.id, + tag_value=tag, ) for tag in event.deleted_tags ] - keys = {DeviceKey(**key) for key in event.keys} - update_transactions = update_device_indexes( - table_name=self.table_name, - id=event.id, - keys=keys, - tags=[], # tags already deleted in delete_tags_transactions - data={"tags": []}, + keys = {DeviceKey(**key).key_value for key in event.keys} + update_transactions = self.update_indexes( + id=event.id, keys=keys, data={"tags": []} ) return delete_tags_transactions + update_transactions def handle_QuestionnaireResponseUpdatedEvent( self, event: QuestionnaireResponseUpdatedEvent ): - keys = {DeviceKey(**key) for key in event.entity_keys} - tags = {DeviceTag(__root__=tag) for tag in event.entity_tags} - return update_device_indexes( + data = { + "questionnaire_responses": event.questionnaire_responses, + "updated_on": event.updated_on, + } + + # Update "questionnaire_responses" on the root and key-indexed Devices + keys = {DeviceKey(**key).key_value for key in event.entity_keys} + update_root_and_key_transactions = self.update_indexes( + id=event.entity_id, keys=keys, data=data + ) + + # Update "questionnaire_responses" on the tag-indexed Devices + tag_values = {DeviceTag(__root__=tag) for tag in event.entity_tags} + update_tag_transactions = update_tag_indexes( table_name=self.table_name, - id=event.entity_id, - keys=keys, - tags=tags, - data={ - "questionnaire_responses": event.questionnaire_responses, - "updated_on": event.updated_on, - }, + device_id=event.entity_id, + 
tag_values=tag_values, + data=data, ) + return update_root_and_key_transactions + update_tag_transactions def handle_bulk(self, item: dict) -> list[dict]: - create_device_transaction = create_device_index_batch( - pk_key_parts=(item["id"],), - device_data=compress_device_fields(item), - root=True, + parent_key = (item["product_team_id"], item["product_id"]) + + root_data = compress_device_fields(item) + create_device_transaction = self.create_index_batch( + id=item["id"], parent_key_parts=parent_key, data=root_data, root=True ) - device_data = compress_device_fields( + non_root_data = compress_device_fields( item, fields_to_compress=NON_ROOT_FIELDS_TO_COMPRESS ) create_keys_transactions = [ - create_device_index_batch( - pk_key_parts=(key["key_type"], key["key_value"]), - device_data=device_data, + self.create_index_batch( + id=key["key_value"], + parent_key_parts=parent_key, + data=non_root_data, + root=False, ) for key in item["keys"] ] + create_tags_transactions = [ - create_device_index_batch( - pk_key_parts=(DeviceTag(__root__=tag).value,), - sk_key_parts=(item["id"],), - pk_table_key=TableKey.DEVICE_TAG, - device_data=device_data, + create_tag_index_batch( + device_id=item["id"], tag_value=tag, data=non_root_data ) for tag in item["tags"] ] @@ -478,47 +451,6 @@ def handle_bulk(self, item: dict) -> list[dict]: + create_tags_transactions ) - def read(self, *key_parts: str) -> Device: - """ - Read the device by either id or key. If calling by id, then do: - repository.read("123") - If calling by key then you must include the key type (e.g. 
'product_id'): - repository.read("product_id", "123") - - """ - key = TableKey.DEVICE.key(*key_parts) - result = self.client.get_item( - TableName=self.table_name, Key=marshall(pk=key, sk=key) - ) - try: - item = result["Item"] - except KeyError: - raise ItemNotFound(*key_parts, item_type=Device) - - _device = unmarshall(item) - return Device(**decompress_device_fields(_device)) - - def read_inactive(self, *key_parts: str) -> Device: - """ - Read the inactive device by id:: - - repository.read("123") - - """ - pk = TableKey.DEVICE_STATUS.key(Status.INACTIVE, *key_parts) - sk = TableKey.DEVICE.key(*key_parts) - - result = self.client.get_item( - TableName=self.table_name, Key=marshall(pk=pk, sk=sk) - ) - try: - item = result["Item"] - except KeyError: - raise ItemNotFound(*key_parts, item_type=Device) - - _device = unmarshall(item) - return Device(**decompress_device_fields(_device)) - def query_by_tag( self, fields_to_drop: list[str] | set[str] = None, @@ -548,7 +480,7 @@ def query_by_tag( "ExpressionAttributeValues": {":pk": marshall_value(pk)}, "KeyConditionExpression": "pk = :pk", "TableName": self.table_name, - **_dynamodb_projection_expression(fields_to_return), + **dynamodb_projection_expression(fields_to_return), } response = self.client.query(**query_params) @@ -561,19 +493,17 @@ def query_by_tag( return [Device(**d) for d in sorted(devices_as_dict, key=lambda d: d["id"])] -def _dynamodb_projection_expression(updated_fields: list[str]): - expression_attribute_names = {} - update_clauses = [] - - for field_name in updated_fields: - field_name_placeholder = f"#{field_name}" +class InactiveDeviceRepository(Repository[Device]): + """Read-only repository""" - update_clauses.append(field_name_placeholder) - expression_attribute_names[field_name_placeholder] = field_name - - projection_expression = ", ".join(update_clauses) + def __init__(self, table_name, dynamodb_client): + super().__init__( + table_name=table_name, + model=Device, + 
dynamodb_client=dynamodb_client, + parent_table_keys=(TableKey.PRODUCT_TEAM, TableKey.CPM_PRODUCT), + table_key=TableKey.DEVICE_STATUS, + ) - return dict( - ProjectionExpression=projection_expression, - ExpressionAttributeNames=expression_attribute_names, - ) + def read(self, product_team_id: str, product_id: str, id: str): + return self._read(parent_ids=(product_team_id, product_id), id=id) diff --git a/src/layers/domain/repository/keys/tests/test_keys_v1.py b/src/layers/domain/repository/keys/tests/test_keys_v1.py index 5c1180f0c..325aeddb1 100644 --- a/src/layers/domain/repository/keys/tests/test_keys_v1.py +++ b/src/layers/domain/repository/keys/tests/test_keys_v1.py @@ -49,20 +49,22 @@ def test_TableKeys_filter(table_key: TableKeys, expected): ) def test_TableKeys_filter_and_group(table_key: TableKeys, expected): iterable = [ - {"pk_1": "D#foo", "other_data": "FOO"}, - {"pk_1": "PT#baz", "other_data": "BAZ"}, - {"pk_1": "D#bar", "other_data": "BAR"}, + {"pk_read": "D#foo", "other_data": "FOO"}, + {"pk_read": "PT#baz", "other_data": "BAZ"}, + {"pk_read": "D#bar", "other_data": "BAR"}, ] - assert list(table_key.filter_and_group(iterable=iterable, key="pk_1")) == expected + assert ( + list(table_key.filter_and_group(iterable=iterable, key="pk_read")) == expected + ) def test_group_by_key(): iterable = [ - {"pk_1": "D#foo", "other_data": "FOO"}, - {"pk_1": "PT#baz", "other_data": "BAZ"}, - {"pk_1": "D#bar", "other_data": "BAR"}, + {"pk_read": "D#foo", "other_data": "FOO"}, + {"pk_read": "PT#baz", "other_data": "BAZ"}, + {"pk_read": "D#bar", "other_data": "BAR"}, ] - assert list(group_by_key(iterable=iterable, key="pk_1")) == [ + assert list(group_by_key(iterable=iterable, key="pk_read")) == [ ("foo", {"other_data": "FOO"}), ("baz", {"other_data": "BAZ"}), ("bar", {"other_data": "BAR"}), @@ -80,18 +82,10 @@ def test_remove_keys(): **{ "pk": "0", "sk": "0", - "pk_1": "1", - "sk_1": "1", - "pk_2": "2", - "sk_2": "2", - "pk_3": "3", - "sk_3": "3", + "pk_read": "1", + 
"sk_read": "1", "foo": "FOO", - "pk_4": "4", "bar": "BAR", - "sk_4": "4", - "pk_5": "5", - "sk_5": "5", "baz": "BAZ", } ) diff --git a/src/layers/domain/repository/keys/v1.py b/src/layers/domain/repository/keys/v1.py index 19132c6eb..1b376d601 100644 --- a/src/layers/domain/repository/keys/v1.py +++ b/src/layers/domain/repository/keys/v1.py @@ -56,16 +56,8 @@ def strip_key_prefix(key: str): def remove_keys( pk=None, sk=None, - pk_1=None, - sk_1=None, - pk_2=None, - sk_2=None, - pk_3=None, - sk_3=None, - pk_4=None, - sk_4=None, - pk_5=None, - sk_5=None, + pk_read=None, + sk_read=None, **values, ): return values diff --git a/src/layers/domain/repository/keys/v3.py b/src/layers/domain/repository/keys/v3.py index d03b3df81..42cebb723 100644 --- a/src/layers/domain/repository/keys/v3.py +++ b/src/layers/domain/repository/keys/v3.py @@ -5,10 +5,8 @@ class TableKey(TableKeyAction, StrEnum): PRODUCT_TEAM = "PT" - PRODUCT_TEAM_KEY = "PTK" CPM_SYSTEM_ID = "CSI" CPM_PRODUCT = "P" - CPM_PRODUCT_KEY = "PK" CPM_PRODUCT_STATUS = "PS" DEVICE_REFERENCE_DATA = "DRD" DEVICE = "D" diff --git a/src/layers/domain/repository/marshall.py b/src/layers/domain/repository/marshall.py index e07bfdc45..62a6c0aa2 100644 --- a/src/layers/domain/repository/marshall.py +++ b/src/layers/domain/repository/marshall.py @@ -14,12 +14,12 @@ } -def marshall_value(value): +def marshall_value(value) -> dict: fn = MARSHALL_FUNCTION_BY_TYPE.get(type(value), (lambda x: {"S": str(x)})) return fn(value) -def marshall(**data): +def marshall(**data) -> dict: return marshall_value(data)["M"] diff --git a/src/layers/domain/repository/product_team_repository/v2.py b/src/layers/domain/repository/product_team_repository/v2.py index 355be1e3a..dd75b58af 100644 --- a/src/layers/domain/repository/product_team_repository/v2.py +++ b/src/layers/domain/repository/product_team_repository/v2.py @@ -1,95 +1,37 @@ from attr import asdict -from domain.core.cpm_system_id.v1 import ProductTeamId from domain.core.product_team.v3 
import ProductTeam, ProductTeamCreatedEvent from domain.core.product_team_key import ProductTeamKey -from domain.repository.errors import ItemNotFound from domain.repository.keys.v3 import TableKey -from domain.repository.marshall import marshall, marshall_value, unmarshall -from domain.repository.repository.v2 import Repository -from domain.repository.transaction import ( - ConditionExpression, - TransactionStatement, - TransactItem, -) - - -def create_product_team_index( - table_name: str, - product_team_data: dict, - pk_key_parts: tuple[str], - sk_key_parts: tuple[str], - pk_table_key: TableKey = TableKey.PRODUCT_TEAM, - sk_table_key: TableKey = TableKey.PRODUCT_TEAM, - root=False, -) -> TransactItem: - pk = pk_table_key.key(*pk_key_parts) - sk = sk_table_key.key(*sk_key_parts) - return TransactItem( - Put=TransactionStatement( - TableName=table_name, - Item=marshall(pk=pk, sk=sk, root=root, **product_team_data), - ConditionExpression=ConditionExpression.MUST_NOT_EXIST, - ) - ) +from domain.repository.repository.v3 import Repository class ProductTeamRepository(Repository[ProductTeam]): def __init__(self, table_name: str, dynamodb_client): super().__init__( - table_name=table_name, model=ProductTeam, dynamodb_client=dynamodb_client + table_name=table_name, + model=ProductTeam, + dynamodb_client=dynamodb_client, + table_key=TableKey.PRODUCT_TEAM, + parent_table_keys=(TableKey.PRODUCT_TEAM,), ) + def read(self, id: str) -> ProductTeam: + return super()._read(parent_ids=(), id=id) + def handle_ProductTeamCreatedEvent(self, event: ProductTeamCreatedEvent): - create_transaction = create_product_team_index( - table_name=self.table_name, - product_team_data=asdict(event), - pk_key_parts=(event.id,), - sk_key_parts=(event.id,), - root=True, + create_root_transaction = self.create_index( + id=event.id, parent_key_parts=(event.id,), data=asdict(event), root=True ) - product_team_keys = {ProductTeamKey(**key) for key in event.keys} - transactions = [] - for product_team_key 
in product_team_keys: - key_type = ( - product_team_key.key_type.replace("_", " ").title().replace(" ", "") + keys = {ProductTeamKey(**key) for key in event.keys} + create_key_transactions = [ + self.create_index( + id=key.key_value, + parent_key_parts=(key.key_value,), + data=asdict(event), + root=True, ) - create_key_transaction = create_product_team_index( - table_name=self.table_name, - product_team_data=asdict(event), - pk_key_parts=( - key_type, - product_team_key.key_value, - ), - sk_key_parts=( - key_type, - product_team_key.key_value, - ), - pk_table_key=TableKey.PRODUCT_TEAM_KEY, - sk_table_key=TableKey.PRODUCT_TEAM_KEY, - ) - transactions.append(create_key_transaction) - - return [create_transaction] + transactions - - def read(self, id) -> ProductTeam: - pk = ( - TableKey.PRODUCT_TEAM.key(id) - if ProductTeamId.validate_cpm_system_id(id) - else TableKey.PRODUCT_TEAM_KEY.key(f"ProductTeamIdAlias#{id}") - ) - args = { - "TableName": self.table_name, - "KeyConditionExpression": "pk = :pk AND sk = :sk", - "ExpressionAttributeValues": { - ":pk": marshall_value(pk), - ":sk": marshall_value(pk), - }, - } - result = self.client.query(**args) - items = [unmarshall(i) for i in result["Items"]] - if len(items) == 0: - raise ItemNotFound(id, item_type=ProductTeam) - (item,) = items + for key in keys + ] - return ProductTeam(**item) + return [create_root_transaction] + create_key_transactions diff --git a/src/layers/domain/repository/repository/tests/model_v3.py b/src/layers/domain/repository/repository/tests/model_v3.py new file mode 100644 index 000000000..a9d18d24a --- /dev/null +++ b/src/layers/domain/repository/repository/tests/model_v3.py @@ -0,0 +1,105 @@ +from enum import StrEnum + +from attr import asdict, dataclass +from domain.repository.keys.v1 import TableKeyAction +from domain.repository.marshall import marshall +from domain.repository.repository.v3 import Repository +from domain.repository.transaction import ( + ConditionExpression, + 
TransactionStatement, + TransactItem, +) +from event.aws.client import dynamodb_client +from pydantic import BaseModel, Field + +from test_helpers.terraform import read_terraform_output + + +class MyTableKey(TableKeyAction, StrEnum): + FOO = "foo" + BAR = "bar" + + +@dataclass +class MyEventAdd: + field: str + + +@dataclass +class MyOtherEventAdd: + field: str + + +@dataclass +class MyEventDelete: + field: str + + +class MyModel(BaseModel): + field: str + events: list[MyEventAdd | MyOtherEventAdd | MyEventDelete] = Field( + default_factory=list, exclude=True + ) + + class Config: + arbitrary_types_allowed = True + + +class MyRepositoryV3(Repository[MyModel]): + def __init__(self): + table_name = read_terraform_output("dynamodb_table_name.value") + super().__init__( + table_name=table_name, + model=MyModel, + dynamodb_client=dynamodb_client(), + parent_table_keys=(MyTableKey.FOO,), + table_key=MyTableKey.FOO, + ) + + def read(self, id: str): + return self._read(parent_ids=(id,), id=id) + + def handle_bulk(self, item): + return [{"PutRequest": {"Item": marshall(**item)}}] + + def handle_MyEventAdd(self, event: MyEventAdd): + # This event will raise a transaction error on duplicates + return TransactItem( + Put=TransactionStatement( + TableName=self.table_name, + Item=marshall( + pk=MyTableKey.FOO.key(event.field), + sk=MyTableKey.FOO.key(event.field), + pk_read=MyTableKey.FOO.key(event.field), + sk_read=MyTableKey.FOO.key(event.field), + **asdict(event) + ), + ConditionExpression=ConditionExpression.MUST_NOT_EXIST, + ) + ) + + def handle_MyOtherEventAdd(self, event: MyOtherEventAdd): + return TransactItem( + Put=TransactionStatement( + TableName=self.table_name, + Item=marshall( + pk=MyTableKey.BAR.key(event.field), + sk=MyTableKey.BAR.key(event.field), + pk_read=MyTableKey.BAR.key(event.field), + sk_read=MyTableKey.BAR.key(event.field), + **asdict(event) + ), + ) + ) + + def handle_MyEventDelete(self, event: MyEventDelete): + return TransactItem( + 
Delete=TransactionStatement( + TableName=self.table_name, + Key=marshall( + pk=MyTableKey.FOO.key(event.field), + sk=MyTableKey.FOO.key(event.field), + ), + ConditionExpression=ConditionExpression.MUST_EXIST, + ) + ) diff --git a/src/layers/domain/repository/repository/tests/test_repository_v1.py b/src/layers/domain/repository/repository/tests/test_repository_v1.py index d7b032eb5..2a6605cce 100644 --- a/src/layers/domain/repository/repository/tests/test_repository_v1.py +++ b/src/layers/domain/repository/repository/tests/test_repository_v1.py @@ -79,9 +79,9 @@ def test_repository_raise_already_exists_from_single_transaction( ( "ValidationException: Transaction request cannot include multiple operations on one item", f'{{"Put": {{"TableName": "{repository.table_name}", "Item": {{"pk": {{"S": "prefix:456"}}, "sk": {{"S": "prefix:456"}}, "field": {{"S": "456"}}}}}}}}', - f'{{"Put": {{"TableName": "{repository.table_name}", "Item": {{"pk": {{"S": "123"}}, "sk": {{"S": "123"}}, "field": {{"S": "123"}}}}, "ConditionExpression": "attribute_not_exists(pk) AND attribute_not_exists(sk) AND attribute_not_exists(pk_1) AND attribute_not_exists(sk_1) AND attribute_not_exists(pk_2) AND attribute_not_exists(sk_2)"}}}}', + f'{{"Put": {{"TableName": "{repository.table_name}", "Item": {{"pk": {{"S": "123"}}, "sk": {{"S": "123"}}, "field": {{"S": "123"}}}}, "ConditionExpression": "attribute_not_exists(pk)"}}}}', f'{{"Put": {{"TableName": "{repository.table_name}", "Item": {{"pk": {{"S": "prefix:345"}}, "sk": {{"S": "prefix:345"}}, "field": {{"S": "345"}}}}}}}}', - f'{{"Put": {{"TableName": "{repository.table_name}", "Item": {{"pk": {{"S": "123"}}, "sk": {{"S": "123"}}, "field": {{"S": "123"}}}}, "ConditionExpression": "attribute_not_exists(pk) AND attribute_not_exists(sk) AND attribute_not_exists(pk_1) AND attribute_not_exists(sk_1) AND attribute_not_exists(pk_2) AND attribute_not_exists(sk_2)"}}}}', + f'{{"Put": {{"TableName": "{repository.table_name}", "Item": {{"pk": {{"S": 
"123"}}, "sk": {{"S": "123"}}, "field": {{"S": "123"}}}}, "ConditionExpression": "attribute_not_exists(pk)"}}}}', ) ) diff --git a/src/layers/domain/repository/repository/tests/test_repository_v3.py b/src/layers/domain/repository/repository/tests/test_repository_v3.py new file mode 100644 index 000000000..9512e759d --- /dev/null +++ b/src/layers/domain/repository/repository/tests/test_repository_v3.py @@ -0,0 +1,259 @@ +import pytest +from domain.repository.errors import AlreadyExistsError, ItemNotFound +from domain.repository.repository.v3 import ( + exponential_backoff_with_jitter, + retry_with_jitter, +) + +from .model_v3 import ( + MyEventAdd, + MyEventDelete, + MyModel, + MyOtherEventAdd, + MyRepositoryV3, + MyTableKey, +) + + +@pytest.fixture +def repository() -> "MyRepositoryV3": + return MyRepositoryV3() + + +@pytest.mark.integration +def test_single_repository_write(repository: MyRepositoryV3): + value = "123" + my_item = MyModel(field=value, events=[MyEventAdd(field=value)]) + repository.write(my_item) + assert repository.read(id=value).dict() == my_item.dict() + + +@pytest.mark.integration +def test_writes_to_same_key_split_over_batches_repository_write( + repository: MyRepositoryV3, +): + first_value = "123" + second_value = "abc" + third_value = "xyz" + + my_item = MyModel( + field=first_value, + events=[ + MyEventAdd(field=first_value), + MyEventAdd(field=second_value), + MyEventAdd(field=third_value), + # batch split should occur here since MyEventDelete requires + # MyEventAdd to have occurred first + MyEventDelete(field=first_value), + MyEventDelete(field=second_value), + MyEventDelete(field=third_value), + ], + ) + db_responses = repository.write(my_item) + batch_count = len(db_responses) + + with pytest.raises(ItemNotFound): + repository.read(id=first_value) + + with pytest.raises(ItemNotFound): + repository.read(id=second_value) + + with pytest.raises(ItemNotFound): + repository.read(id=third_value) + + assert batch_count == 2 + + 
+@pytest.mark.integration +@pytest.mark.parametrize( + ["number_of_adds", "number_of_batches"], + [ + (12, 1), + (100, 1), + (101, 2), + (150, 2), + (200, 2), + (201, 3), + ], +) +def test_writes_to_different_keys_split_over_batches_repository_write( + repository: MyRepositoryV3, number_of_adds: int, number_of_batches: int +): + my_item = MyModel( + field="abc", + events=[MyEventAdd(field=str(i)) for i in range(number_of_adds)], + ) + db_responses = repository.write(my_item) + batch_count = len(db_responses) + assert batch_count == number_of_batches + + +@pytest.mark.integration +def test_repository_raise_already_exists(repository: MyRepositoryV3): + my_item = MyModel(field="123", events=[MyEventAdd(field="123")]) + repository.write(my_item) + with pytest.raises(AlreadyExistsError): + repository.write(my_item) + + +@pytest.mark.integration +def test_repository_raise_already_exists_multiple_events(repository: MyRepositoryV3): + my_item = MyModel( + field="123", + events=[ + MyOtherEventAdd(field="456"), + MyEventAdd(field="123"), + MyOtherEventAdd(field="345"), + ], + ) + repository.write(my_item) + with pytest.raises(AlreadyExistsError): + repository.write(my_item) # Should cause AlreadyExistsError + + +@pytest.mark.integration +def test_repository_add_and_delete_separate_transactions(repository: MyRepositoryV3): + value = "123" + my_item = MyModel(field=value, events=[MyEventAdd(field=value)]) + repository.write(my_item) + intermediate_item = repository.read(id=value) + + assert intermediate_item == my_item + + intermediate_item.events.append(MyEventDelete(field=value)) + repository.write(intermediate_item) + + with pytest.raises(ItemNotFound): + repository.read(id=value) + + +@pytest.mark.integration +def test_repository_write_bulk(repository: MyRepositoryV3): + responses = repository.write_bulk( + [ + { + "pk": str(i), + "sk": str(i), + "pk_read": MyTableKey.FOO.key(str(i)), + "sk_read": MyTableKey.FOO.key(str(i)), + "field": f"boo-{i}", + } + for i in range(51) 
+ ], + batch_size=25, + ) + assert len(responses) >= 3 # 51/25 + + for i in range(51): + assert repository.read(id=str(i)).field == f"boo-{i}" + + +def test_exponential_backoff_with_jitter(): + base_delay = 0.1 + max_delay = 5 + min_delay = 0.05 + n_samples = 1000 + + delays = [] + for retry in range(n_samples): + delay = exponential_backoff_with_jitter( + n_retries=retry, + base_delay=base_delay, + min_delay=min_delay, + max_delay=max_delay, + ) + assert max_delay >= delay >= min_delay + delays.append(delay) + assert len(set(delays)) == n_samples # all delays should be unique + assert sum(delays[n_samples:]) < sum( + delays[:n_samples] + ) # final delays should be larger than first delays + + +@pytest.mark.parametrize( + "error_code", + [ + "ProvisionedThroughputExceededException", + "ThrottlingException", + "InternalServerError", + ], +) +def test_retry_with_jitter_all_fail(error_code: str): + class MockException(Exception): + def __init__(self, error_code): + self.response = {"Error": {"Code": error_code}} + + max_retries = 3 + + @retry_with_jitter(max_retries=max_retries, error=MockException) + def throw(error_code): + raise MockException(error_code=error_code) + + with pytest.raises(ExceptionGroup) as exception_info: + throw(error_code=error_code) + + assert ( + exception_info.value.message + == f"Failed to put item after {max_retries} retries" + ) + assert len(exception_info.value.exceptions) == max_retries + assert all( + isinstance(exc, MockException) for exc in exception_info.value.exceptions + ) + + +@pytest.mark.parametrize( + "error_code", + [ + "ProvisionedThroughputExceededException", + "ThrottlingException", + "InternalServerError", + ], +) +def test_retry_with_jitter_third_passes(error_code: str): + class MockException(Exception): + retries = 0 + + def __init__(self, error_code): + self.response = {"Error": {"Code": error_code}} + + max_retries = 3 + + @retry_with_jitter(max_retries=max_retries, error=MockException) + def throw(error_code): + if 
MockException.retries == max_retries - 1: + return "foo" + MockException.retries += 1 + raise MockException(error_code=error_code) + + assert throw(error_code=error_code) == "foo" + + +@pytest.mark.parametrize( + "error_code", + [ + "SomeOtherError", + ], +) +def test_retry_with_jitter_other_code(error_code: str): + class MockException(Exception): + def __init__(self, error_code): + self.response = {"Error": {"Code": error_code}} + + @retry_with_jitter(max_retries=3, error=MockException) + def throw(error_code): + raise MockException(error_code=error_code) + + with pytest.raises(MockException) as exception_info: + throw(error_code=error_code) + + assert exception_info.value.response == {"Error": {"Code": error_code}} + + +def test_retry_with_jitter_other_exception(): + @retry_with_jitter(max_retries=3, error=ValueError) + def throw(): + raise TypeError() + + with pytest.raises(TypeError): + throw() diff --git a/src/layers/domain/repository/repository/v3.py b/src/layers/domain/repository/repository/v3.py new file mode 100644 index 000000000..9434d9180 --- /dev/null +++ b/src/layers/domain/repository/repository/v3.py @@ -0,0 +1,305 @@ +import random +import time +from abc import abstractmethod +from enum import StrEnum +from functools import wraps +from itertools import batched, chain +from typing import TYPE_CHECKING, Generator, Iterable + +from botocore.exceptions import ClientError +from domain.core.aggregate_root import AggregateRoot +from domain.repository.errors import ItemNotFound +from domain.repository.keys.v1 import KEY_SEPARATOR +from domain.repository.keys.v3 import TableKey +from domain.repository.marshall import marshall, unmarshall +from domain.repository.transaction import ( # TransactItem, + ConditionExpression, + Transaction, + TransactionStatement, + TransactItem, + handle_client_errors, + update_transactions, +) + +if TYPE_CHECKING: + from mypy_boto3_dynamodb import DynamoDBClient + from mypy_boto3_dynamodb.type_defs import ( + 
BatchWriteItemOutputTypeDef, + TransactWriteItemsOutputTypeDef, + ) + +BATCH_SIZE = 100 +MAX_BATCH_WRITE_SIZE = 10 +RETRY_ERRORS = [ + "ProvisionedThroughputExceededException", + "ThrottlingException", + "InternalServerError", +] + + +class TooManyResults(Exception): + pass + + +class QueryType(StrEnum): + EQUALS = "{} = {}" + BEGINS_WITH = "begins_with({}, {})" + + +def exponential_backoff_with_jitter( + n_retries, base_delay=0.1, min_delay=0.05, max_delay=5 +): + """Calculate the delay with exponential backoff and jitter.""" + delay = min(base_delay * (2**n_retries), max_delay) + return random.uniform(min_delay, delay) + + +def retry_with_jitter(max_retries=5, error=ClientError): + def wrapper(func): + @wraps(func) + def wrapped(*args, **kwargs): + exceptions = [] + while len(exceptions) < max_retries: + try: + return func(*args, **kwargs) + except error as e: + error_code = e.response["Error"]["Code"] + if error_code not in RETRY_ERRORS: + raise + exceptions.append(e) + delay = exponential_backoff_with_jitter(n_retries=len(exceptions)) + time.sleep(delay) + raise ExceptionGroup( + f"Failed to put item after {max_retries} retries", exceptions + ) + + return wrapped + + return wrapper + + +def _split_transactions_by_key( + transact_items: Iterable[TransactItem], n_max: int +) -> Generator[list[TransactItem], None, None]: + buffer, keys = [], set() + for transact_item in transact_items: + transaction_statement = ( + transact_item.Put or transact_item.Delete or transact_item.Update + ) + item = transaction_statement.Key or transaction_statement.Item + key = (item["pk"]["S"], item["sk"]["S"]) + if key in keys: + yield from batched(buffer, n=n_max) + buffer, keys = [], set() + buffer.append(transact_item) + keys.add(key) + yield from batched(buffer, n=n_max) + + +def transact_write_chunk( + client: "DynamoDBClient", chunk: list[TransactItem] +) -> "TransactWriteItemsOutputTypeDef": + transaction = Transaction(TransactItems=chunk) + with 
handle_client_errors(commands=chunk): + _response = client.transact_write_items(**transaction.dict(exclude_none=True)) + return _response + + +@retry_with_jitter() +def batch_write_chunk( + client: "DynamoDBClient", table_name: str, chunk: list[dict] +) -> "BatchWriteItemOutputTypeDef": + while chunk: + _response = client.batch_write_item(RequestItems={table_name: chunk}) + chunk = _response["UnprocessedItems"].get(table_name) + return _response + + +class Repository[ModelType: AggregateRoot]: + + def __init__( + self, + table_name, + model: type[ModelType], + dynamodb_client, + parent_table_keys: tuple[TableKey], + table_key: TableKey, + ): + self.table_name = table_name + self.model = model + self.client: "DynamoDBClient" = dynamodb_client + self.batch_size = BATCH_SIZE + self.parent_table_keys = parent_table_keys + self.table_key = table_key + + @abstractmethod + def handle_bulk(self, item): ... + + def write(self, entity: ModelType, batch_size=None): + batch_size = batch_size or self.batch_size + + def generate_transaction_statements(event): + handler_name = f"handle_{type(event).__name__}" + handler = getattr(self, handler_name) + transact_items = handler(event=event) + + if not isinstance(transact_items, list): + transact_items = [transact_items] + return transact_items + + transact_items = chain.from_iterable( + (generate_transaction_statements(event) for event in entity.events) + ) + + responses = [ + transact_write_chunk(client=self.client, chunk=transact_item_chunk) + for transact_item_chunk in _split_transactions_by_key( + transact_items, batch_size + ) + ] + return responses + + def write_bulk(self, entities: list[ModelType], batch_size=None): + batch_size = batch_size or MAX_BATCH_WRITE_SIZE + batch_write_items = list(chain.from_iterable(map(self.handle_bulk, entities))) + responses = [ + batch_write_chunk( + client=self.client, table_name=self.table_name, chunk=chunk + ) + for chunk in batched(batch_write_items, batch_size) + ] + return responses + + 
def create_index(
+        self,
+        id: str,
+        parent_key_parts: tuple[str],
+        data: dict,
+        root: bool,
+        table_key: TableKey = None,
+        parent_table_keys: tuple[TableKey] = None,
+    ) -> TransactItem:
+        if table_key is None:
+            table_key = self.table_key
+        if parent_table_keys is None:
+            parent_table_keys = self.parent_table_keys
+
+        if len(parent_table_keys) != len(parent_key_parts):
+            raise ValueError(
+                f"Expected {len(parent_table_keys)} parent key parts, got {len(parent_key_parts)}"
+            )
+
+        write_key = table_key.key(id)
+        read_key = KEY_SEPARATOR.join(
+            table_key.key(_id)
+            for table_key, _id in zip(parent_table_keys, parent_key_parts)
+        )
+
+        return TransactItem(
+            Put=TransactionStatement(
+                TableName=self.table_name,
+                Item=marshall(
+                    pk=write_key,
+                    sk=write_key,
+                    pk_read=read_key,
+                    sk_read=write_key,
+                    root=root,
+                    **data,
+                ),
+                ConditionExpression=ConditionExpression.MUST_NOT_EXIST,
+            )
+        )
+
+    def create_index_batch(
+        self,
+        id: str,
+        parent_key_parts: tuple[str],
+        data: dict,
+        root: bool,
+        table_key: TableKey = None,
+        parent_table_keys: tuple[TableKey] = None,
+    ) -> TransactItem:
+        """
+        Difference between `create_index` and `create_index_batch`:
+
+        `create_index` is intended for the event-based
+        handlers (e.g. 
`handle_XyzCreatedEvent`) which are called by the base
+        `write` method, which expects `TransactItem`s for use with `client.transact_write_items`
+
+        `create_index_batch` is intended for the entity-based handler
+        `handle_bulk` which is called by the base method `write_bulk`, which expects
+        `BatchWriteItem`s which we render as a `dict` for use with `client.batch_write_item`
+        """
+
+        if table_key is None:
+            table_key = self.table_key
+        if parent_table_keys is None:
+            parent_table_keys = self.parent_table_keys
+
+        write_key = table_key.key(id)
+        read_key = KEY_SEPARATOR.join(
+            table_key.key(_id)
+            for table_key, _id in zip(parent_table_keys, parent_key_parts)
+        )
+
+        return {
+            "PutRequest": {
+                "Item": marshall(
+                    pk=write_key,
+                    sk=write_key,
+                    pk_read=read_key,
+                    sk_read=write_key,
+                    root=root,
+                    **data,
+                ),
+            },
+        }
+
+    def update_indexes(self, id: str, keys: list[str], data: dict):
+        primary_keys = [
+            marshall(pk=pk, sk=pk) for pk in map(self.table_key.key, [id, *keys])
+        ]
+        return update_transactions(
+            table_name=self.table_name, primary_keys=primary_keys, data=data
+        )
+
+    def delete_index(self, id: str):
+        pk = self.table_key.key(id)
+        return TransactItem(
+            Delete=TransactionStatement(
+                TableName=self.table_name,
+                Key=marshall(pk=pk, sk=pk),
+                ConditionExpression=ConditionExpression.MUST_EXIST,
+            )
+        )
+
+    def _query(self, parent_ids: tuple[str], id: str = None) -> list[ModelType]:
+        pk_read = KEY_SEPARATOR.join(
+            table_key.key(_id)
+            for table_key, _id in zip(self.parent_table_keys, parent_ids)
+        )
+        sk_read = self.table_key.key(id or "")
+
+        sk_query_type = QueryType.BEGINS_WITH if id is None else QueryType.EQUALS
+        sk_condition = sk_query_type.format("sk_read", ":sk_read")
+
+        args = {
+            "TableName": self.table_name,
+            "IndexName": "idx_gsi_read",
+            "KeyConditionExpression": f"pk_read = :pk_read AND {sk_condition}",
+            "ExpressionAttributeValues": marshall(
+                **{":pk_read": pk_read, ":sk_read": sk_read}
+            ),
+        }
+        result = 
self.client.query(**args) + if "LastEvaluatedKey" in result: + raise TooManyResults(f"Too many results for query ({(*parent_ids, id)})") + return [self.model(**item) for item in map(unmarshall, result["Items"])] + + def _read(self, parent_ids: tuple[str], id: str) -> ModelType: + items = self._query(parent_ids=parent_ids or (id,), id=id) + try: + (item,) = items + except ValueError: + raise ItemNotFound(*parent_ids, id, item_type=self.model) + return item diff --git a/src/layers/domain/repository/tests/test_transaction.py b/src/layers/domain/repository/tests/test_transaction.py index 827541ba8..e70301a66 100644 --- a/src/layers/domain/repository/tests/test_transaction.py +++ b/src/layers/domain/repository/tests/test_transaction.py @@ -15,11 +15,7 @@ Put=TransactionStatement( TableName="table", Item={}, - ConditionExpression=( - "attribute_not_exists(pk) AND attribute_not_exists(sk) " - "AND attribute_not_exists(pk_1) AND attribute_not_exists(sk_1) " - "AND attribute_not_exists(pk_2) AND attribute_not_exists(sk_2)" - ), + ConditionExpression="attribute_not_exists(pk)", ) ) ] diff --git a/src/layers/domain/repository/transaction.py b/src/layers/domain/repository/transaction.py index 609795d49..a0a50649a 100644 --- a/src/layers/domain/repository/transaction.py +++ b/src/layers/domain/repository/transaction.py @@ -10,14 +10,12 @@ from .errors import AlreadyExistsError, UnhandledTransaction -ATTRIBUTE_NOT_EXISTS = "attribute_not_exists({})".format - class ConditionExpression(StrEnum): - MUST_EXIST = "attribute_exists(pk) and attribute_exists(sk)" - MUST_NOT_EXIST = " AND ".join( - map(ATTRIBUTE_NOT_EXISTS, ("pk", "sk", "pk_1", "sk_1", "pk_2", "sk_2")) - ) + # NB: confusingly in DynamoDB "pk" in ConditionExpressions means "primary key" ("pk" + "sk") + # rather than the individual "pk" field + MUST_EXIST = "attribute_exists(pk)" + MUST_NOT_EXIST = "attribute_not_exists(pk)" TRANSACTION_ERROR_MAPPING = { @@ -126,3 +124,18 @@ def update_transactions( 
TransactItem(Update=update_statement(Key=key)) for key in primary_keys ] return transact_items + + +def dynamodb_projection_expression(updated_fields: list[str]): + expression_attribute_names = {} + update_clauses = [] + for field_name in updated_fields: + field_name_placeholder = f"#{field_name}" + update_clauses.append(field_name_placeholder) + expression_attribute_names[field_name_placeholder] = field_name + projection_expression = ", ".join(update_clauses) + + return dict( + ProjectionExpression=projection_expression, + ExpressionAttributeNames=expression_attribute_names, + ) diff --git a/src/test_helpers/dynamodb.py b/src/test_helpers/dynamodb.py index fab833a37..a96bd72c4 100644 --- a/src/test_helpers/dynamodb.py +++ b/src/test_helpers/dynamodb.py @@ -13,23 +13,20 @@ ] GLOBAL_SECONDARY_INDEXES = [ { - "IndexName": f"idx_gsi_{i}", + "IndexName": "idx_gsi_read", "KeySchema": [ - {"AttributeName": f"pk_{i}", "KeyType": "HASH"}, - {"AttributeName": f"sk_{i}", "KeyType": "RANGE"}, + {"AttributeName": "pk_read", "KeyType": "HASH"}, + {"AttributeName": "sk_read", "KeyType": "RANGE"}, ], "Projection": {"ProjectionType": "ALL"}, } - for i in range(5) ] -ATTRIBUTE_DEFINITIONS = ( - [ - {"AttributeName": "pk", "AttributeType": "S"}, - {"AttributeName": "sk", "AttributeType": "S"}, - ] - + [{"AttributeName": f"pk_{i}", "AttributeType": "S"} for i in range(5)] - + [{"AttributeName": f"sk_{i}", "AttributeType": "S"} for i in range(5)] -) +ATTRIBUTE_DEFINITIONS = [ + {"AttributeName": "pk", "AttributeType": "S"}, + {"AttributeName": "sk", "AttributeType": "S"}, + {"AttributeName": "pk_read", "AttributeType": "S"}, + {"AttributeName": "sk_read", "AttributeType": "S"}, +] def _scan(client: DynamoDBClient, table_name: str) -> Generator[dict, None, None]: