diff --git a/backend/composer/services/graph_service.py b/backend/composer/services/graph_service.py index c20d7645..869c1031 100644 --- a/backend/composer/services/graph_service.py +++ b/backend/composer/services/graph_service.py @@ -45,14 +45,14 @@ def create_paths_from_origin(origin, vias, destinations, current_path, destinati # This checks if the last node in the current path is one of the nodes that can lead to the current via. # In other words, it checks if there is a valid connection # from the last node in the current path to the current via. - if current_path[-1][0] in list( - a.name for a in current_via.from_entities.all()) or not current_via.from_entities.exists(): + if (current_path[-1][0] in list(a.name for a in current_via.from_entities.all()) + or (not current_via.from_entities.exists() and current_path[-1][1] == via_layer - 1)): for entity in current_via.anatomical_entities.all(): # Build new sub-paths including the current via entity new_sub_path = current_path + [(entity.name, via_layer)] # Recursively call to build paths from the next vias - new_paths.extend( - create_paths_from_origin(origin, vias[idx + 1:], destinations, new_sub_path, destination_layer)) + new_paths.extend(create_paths_from_origin(origin, vias[idx + 1:], destinations, + new_sub_path, destination_layer)) # Check for direct connections to destinations from the current via for dest in destinations: diff --git a/backend/tests/models/test_vias.py b/backend/tests/models/test_vias.py index 4a4c06b5..553126cf 100644 --- a/backend/tests/models/test_vias.py +++ b/backend/tests/models/test_vias.py @@ -1,5 +1,6 @@ from django.test import TestCase -from composer.models import ConnectivityStatement, Via, Sentence, AnatomicalEntity +from composer.models import ConnectivityStatement, Via, Sentence, AnatomicalEntity, AnatomicalEntityMeta + class ViaModelTestCase(TestCase): @@ -44,7 +45,8 @@ def test_via_deletion_updates_order(self): def test_via_order_change_clears_from_entities(self): statement, initial_vias = self.create_initial_state() - anatomical_entity = AnatomicalEntity.objects.create(name="Test Entity") + anatomical_entity_meta = AnatomicalEntityMeta.objects.create(name="Test Entity") + anatomical_entity = AnatomicalEntity.objects.create(simple_entity=anatomical_entity_meta) for via in initial_vias: via.from_entities.add(anatomical_entity) diff --git a/backend/tests/test_journey.py b/backend/tests/test_journey.py index 2c5c589f..d533371d 100644 --- a/backend/tests/test_journey.py +++ b/backend/tests/test_journey.py @@ -1,22 +1,32 @@ from django.db import connection from django.test import TestCase, override_settings -from composer.models import Sentence, ConnectivityStatement, AnatomicalEntity, Via, Destination +from composer.models import Sentence, ConnectivityStatement, AnatomicalEntity, AnatomicalEntityMeta, Via, Destination from composer.services.graph_service import generate_paths, consolidate_paths @override_settings(DEBUG=True) class JourneyTestCase(TestCase): + def setUp(self): + self.created_entities = {} + + def create_or_get_anatomical_entity(self, name): + if name not in self.created_entities: + meta, _ = AnatomicalEntityMeta.objects.get_or_create(name=name, ontology_uri=name) + entity, _ = AnatomicalEntity.objects.get_or_create(simple_entity=meta) + self.created_entities[name] = entity + return self.created_entities[name] + def test_journey_simple_graph_with_jump(self): # Test setup sentence = Sentence.objects.create() cs = ConnectivityStatement.objects.create(sentence=sentence) - origin1 = AnatomicalEntity.objects.create(name='Oa') - origin2 = AnatomicalEntity.objects.create(name='Ob') - via1 = AnatomicalEntity.objects.create(name='V1a') - destination1 = AnatomicalEntity.objects.create(name='Da') + origin1 = self.create_or_get_anatomical_entity("Oa") + origin2 = self.create_or_get_anatomical_entity("Ob") + via1 = self.create_or_get_anatomical_entity('V1a') + destination1 = self.create_or_get_anatomical_entity('Da') cs.origins.add(origin1, origin2) @@ -42,10 +52,7 @@ def test_journey_simple_graph_with_jump(self): [('Ob', 0), ('Da', 2)] ] - initial_query_count = len(connection.queries) all_paths = generate_paths(origins, vias, destinations) - new_query_count = len(connection.queries) - initial_query_count - self.assertTrue(new_query_count == 0) all_paths.sort() expected_paths.sort() @@ -70,9 +77,9 @@ def test_journey_simple_direct_graph(self): cs = ConnectivityStatement.objects.create(sentence=sentence) # Create Anatomical Entities - origin1 = AnatomicalEntity.objects.create(name='Oa') - origin2 = AnatomicalEntity.objects.create(name='Ob') - destination1 = AnatomicalEntity.objects.create(name='Da') + origin1 = self.create_or_get_anatomical_entity("Oa") + origin2 = self.create_or_get_anatomical_entity("Ob") + destination1 = self.create_or_get_anatomical_entity('Da') # Add origins cs.origins.add(origin1, origin2) @@ -115,10 +122,10 @@ def test_journey_simple_graph_no_jumps(self): cs = ConnectivityStatement.objects.create(sentence=sentence) # Create Anatomical Entities - origin1 = AnatomicalEntity.objects.create(name='Oa') - origin2 = AnatomicalEntity.objects.create(name='Ob') - via1 = AnatomicalEntity.objects.create(name='V1a') - destination1 = AnatomicalEntity.objects.create(name='Da') + origin1 = self.create_or_get_anatomical_entity("Oa") + origin2 = self.create_or_get_anatomical_entity("Ob") + via1 = self.create_or_get_anatomical_entity('V1a') + destination1 = self.create_or_get_anatomical_entity('Da') # Add origins cs.origins.add(origin1, origin2) @@ -170,11 +177,11 @@ def test_journey_multiple_vias_no_jumps(self): cs = ConnectivityStatement.objects.create(sentence=sentence) # Create Anatomical Entities - origin1 = AnatomicalEntity.objects.create(name='Oa') - origin2 = AnatomicalEntity.objects.create(name='Ob') - via1 = AnatomicalEntity.objects.create(name='V1a') - via2 = AnatomicalEntity.objects.create(name='V1b') - destination1 = AnatomicalEntity.objects.create(name='Da') + origin1 = self.create_or_get_anatomical_entity("Oa") + origin2 = self.create_or_get_anatomical_entity("Ob") + via1 = self.create_or_get_anatomical_entity('V1a') + via2 = self.create_or_get_anatomical_entity('V1b') + destination1 = self.create_or_get_anatomical_entity('Da') # Add origins cs.origins.add(origin1, origin2) @@ -232,15 +239,14 @@ def test_journey_complex_graph(self): sentence = Sentence.objects.create() cs = ConnectivityStatement.objects.create(sentence=sentence) - # Create Anatomical Entities - origin_a = AnatomicalEntity.objects.create(name='Oa') - origin_b = AnatomicalEntity.objects.create(name='Ob') - via1_a = AnatomicalEntity.objects.create(name='V1a') - via2_a = AnatomicalEntity.objects.create(name='V2a') - via2_b = AnatomicalEntity.objects.create(name='V2b') - via3_a = AnatomicalEntity.objects.create(name='V3a') - via4_a = AnatomicalEntity.objects.create(name='V4a') - destination_a = AnatomicalEntity.objects.create(name='Da') + origin_a = self.create_or_get_anatomical_entity("Oa") + origin_b = self.create_or_get_anatomical_entity("Ob") + via1_a = self.create_or_get_anatomical_entity('V1a') + via2_a = self.create_or_get_anatomical_entity('V2a') + via2_b = self.create_or_get_anatomical_entity('V2b') + via3_a = self.create_or_get_anatomical_entity('V3a') + via4_a = self.create_or_get_anatomical_entity('V4a') + destination_a = self.create_or_get_anatomical_entity('Da') # Add origins cs.origins.add(origin_a, origin_b) @@ -304,18 +310,17 @@ def test_journey_complex_graph_2(self): sentence = Sentence.objects.create() cs = ConnectivityStatement.objects.create(sentence=sentence) - # Create Anatomical Entities - origin_a = AnatomicalEntity.objects.create(name='Oa') - origin_b = AnatomicalEntity.objects.create(name='Ob') - via1_a = AnatomicalEntity.objects.create(name='V1a') - via2_a = AnatomicalEntity.objects.create(name='V2a') - via2_b = AnatomicalEntity.objects.create(name='V2b') - via3_a = AnatomicalEntity.objects.create(name='V3a') - via4_a = AnatomicalEntity.objects.create(name='V4a') - via5_a = AnatomicalEntity.objects.create(name='V5a') - via5_b = AnatomicalEntity.objects.create(name='V5b') - via6_a = AnatomicalEntity.objects.create(name='V6a') - destination_a = AnatomicalEntity.objects.create(name='Da') + origin_a = self.create_or_get_anatomical_entity("Oa") + origin_b = self.create_or_get_anatomical_entity("Ob") + via1_a = self.create_or_get_anatomical_entity('V1a') + via2_a = self.create_or_get_anatomical_entity('V2a') + via2_b = self.create_or_get_anatomical_entity('V2b') + via3_a = self.create_or_get_anatomical_entity('V3a') + via4_a = self.create_or_get_anatomical_entity('V4a') + via5_a = self.create_or_get_anatomical_entity('V5a') + via5_b = self.create_or_get_anatomical_entity('V5b') + via6_a = self.create_or_get_anatomical_entity('V6a') + destination_a = self.create_or_get_anatomical_entity('Da') # Add origins cs.origins.add(origin_a, origin_b) @@ -405,9 +410,9 @@ def test_journey_cycles(self): cs = ConnectivityStatement.objects.create(sentence=sentence) # Create Anatomical Entities - origin1 = AnatomicalEntity.objects.create(name='Oa') - origin2 = AnatomicalEntity.objects.create(name='Ob') - destination1 = AnatomicalEntity.objects.create(name='Da') + origin1 = self.create_or_get_anatomical_entity("Oa") + origin2 = self.create_or_get_anatomical_entity("Ob") + destination1 = self.create_or_get_anatomical_entity('Da') # Add origins cs.origins.add(origin1, origin2) @@ -459,10 +464,10 @@ def test_journey_nonconsecutive_vias(self): sentence = Sentence.objects.create() cs = ConnectivityStatement.objects.create(sentence=sentence) - origin1 = AnatomicalEntity.objects.create(name='Oa') - via1 = AnatomicalEntity.objects.create(name='V1a') - via2 = AnatomicalEntity.objects.create(name='V2a') - destination1 = AnatomicalEntity.objects.create(name='Da') + origin1 = self.create_or_get_anatomical_entity("Oa") + via1 = self.create_or_get_anatomical_entity('V1a') + via2 = self.create_or_get_anatomical_entity("V2a") + destination1 = self.create_or_get_anatomical_entity('Da') cs.origins.add(origin1) @@ -508,3 +513,55 @@ def test_journey_nonconsecutive_vias(self): expected_journey.sort() self.assertTrue(journey_paths == expected_journey, f"Expected journey {expected_journey}, but found {journey_paths}") + + def test_journey_implicit_from_entities(self): + # Test setup + sentence = Sentence.objects.create() + cs = ConnectivityStatement.objects.create(sentence=sentence) + + origin1 = self.create_or_get_anatomical_entity("Myenteric") + via1 = self.create_or_get_anatomical_entity('Longitudinal') + via2 = self.create_or_get_anatomical_entity("Serosa") + via3 = self.create_or_get_anatomical_entity("lumbar") + destination1 = self.create_or_get_anatomical_entity('inferior') + + cs.origins.add(origin1) + + via_a = Via.objects.create(connectivity_statement=cs) + via_a.anatomical_entities.add(via1) + + via_b = Via.objects.create(connectivity_statement=cs) + via_b.anatomical_entities.add(via2) + + via_c = Via.objects.create(connectivity_statement=cs) + via_c.anatomical_entities.add(via3) + + destination = Destination.objects.create(connectivity_statement=cs) + destination.anatomical_entities.add(destination1) + + # Prefetch related data + origins = list(cs.origins.all()) + vias = list( + Via.objects.filter(connectivity_statement=cs).prefetch_related('anatomical_entities', 'from_entities')) + destinations = list( + Destination.objects.filter(connectivity_statement=cs).prefetch_related('anatomical_entities', + 'from_entities')) + + expected_paths = [ + [('Myenteric', 0), ('Longitudinal', 1), ('Serosa', 2), ('lumbar', 3), ('inferior', 4)], + ] + + all_paths = generate_paths(origins, vias, destinations) + + all_paths.sort() + expected_paths.sort() + self.assertTrue(all_paths == expected_paths, f"Expected paths {expected_paths}, but found {all_paths}") + + journey_paths = consolidate_paths(all_paths) + expected_journey = [ + [('Myenteric', 0), ('Longitudinal', 1), ('Serosa', 2), ('lumbar', 3), ('inferior', 4)], + ] + journey_paths.sort() + expected_journey.sort() + self.assertTrue(journey_paths == expected_journey, + f"Expected journey {expected_journey}, but found {journey_paths}") diff --git a/frontend/src/components/ProofingTab/GraphDiagram/Widgets/OriginNodeWidget.tsx b/frontend/src/components/ProofingTab/GraphDiagram/Widgets/OriginNodeWidget.tsx index dc4dfb89..a7568e48 100644 --- a/frontend/src/components/ProofingTab/GraphDiagram/Widgets/OriginNodeWidget.tsx +++ b/frontend/src/components/ProofingTab/GraphDiagram/Widgets/OriginNodeWidget.tsx @@ -92,7 +92,7 @@ export const OriginNodeWidget: React.FC = ({ lineHeight: "1.25rem", }} > - Intermediolateral nucleus of eleventh thoracic segment + {model.name}