diff --git a/octopoes/nibbles/runner.py b/octopoes/nibbles/runner.py index 86761cfbdf7..5f2b5a0e781 100644 --- a/octopoes/nibbles/runner.py +++ b/octopoes/nibbles/runner.py @@ -3,7 +3,7 @@ from typing import TypeVar from nibbles.definitions import NibbleDefinition, get_nibble_definitions -from octopoes.models import OOI +from octopoes.models import OOI, Reference from octopoes.models.origin import NibbleOrigin, OriginType from octopoes.models.types import type_by_name from octopoes.repositories.ooi_repository import OOIRepository @@ -47,11 +47,11 @@ def __init__( def update_nibbles(self): self.nibbles: list[NibbleDefinition] = get_nibble_definitions() - def _run(self, ooi: OOI, valid_time: datetime) -> dict[str, list[tuple[set[OOI], set[OOI]]]]: - return_value: dict[str, list[tuple[set[OOI], set[OOI]]]] = {} + def _run(self, ooi: OOI, valid_time: datetime) -> dict[str, dict[frozenset[Reference], set[OOI]]]: + return_value: dict[str, dict[frozenset[Reference], set[OOI]]] = {} for nibble in filter(lambda x: type(ooi) in x.signature, self.nibbles): args = self.ooi_repository.nibble_query(ooi, nibble, valid_time) - results = [(set(arg), set(flatten([nibble(arg)]))) for arg in args] + results = {frozenset({a.reference for a in arg}): set(flatten([nibble(arg)])) for arg in args} if results: return_value |= {nibble.id: results} # TODO: we could cache the writes for single OOI nibbles @@ -63,31 +63,33 @@ def _cleared(self, ooi: OOI, valid_time: datetime) -> bool: target_nibbles = filter(lambda x: type(ooi) in x.signature, self.nibbles) return any(nibble.min_scan_level < ooi_level for nibble in target_nibbles) - def _write(self, inferences: dict[OOI, dict[str, list[tuple[set[OOI], set[OOI]]]]], valid_time: datetime): + def _write(self, inferences: dict[OOI, dict[str, dict[frozenset[Reference], set[OOI]]]], valid_time: datetime): for source_ooi, results in inferences.items(): self.ooi_repository.save(source_ooi, valid_time) for nibble_id, run_result in results.items(): - for arg, result in run_result: + for arg, result in run_result.items(): nibble_origin = NibbleOrigin( method=nibble_id, origin_type=OriginType.NIBBLE, result=[ooi.reference for ooi in result], source=source_ooi.reference, - parameters=[ooi.reference for ooi in arg], + parameters=list(arg), ) for ooi in result: self.ooi_repository.save(ooi, valid_time=valid_time) self.origin_repository.save(nibble_origin, valid_time=valid_time) - def infer(self, stack: list[OOI], valid_time: datetime) -> dict[OOI, dict[str, list[tuple[set[OOI], set[OOI]]]]]: - inferences: dict[OOI, dict[str, list[tuple[set[OOI], set[OOI]]]]] = {} + def infer( + self, stack: list[OOI], valid_time: datetime + ) -> dict[OOI, dict[str, dict[frozenset[Reference], set[OOI]]]]: + inferences: dict[OOI, dict[str, dict[frozenset[Reference], set[OOI]]]] = {} blockset = set(stack) if stack and self._cleared(stack[-1], valid_time): while stack: ooi = stack.pop() results = self._run(ooi, valid_time) if results: - blocks = set.union(*[ooiset for result in results.values() for _, ooiset in result]) + blocks = set.union(*[ooiset for result in results.values() for _, ooiset in result.items()]) stack += [o for o in blocks if o not in blockset] blockset |= blocks inferences |= {ooi: results} diff --git a/octopoes/tests/integration/test_nibbles.py b/octopoes/tests/integration/test_nibbles.py index be44ffd6179..c3bf6de3965 100644 --- a/octopoes/tests/integration/test_nibbles.py +++ b/octopoes/tests/integration/test_nibbles.py @@ -117,9 +117,7 @@ def test_url_classification_nibble(xtdb_octopoes_service: OctopoesService, event assert url in result assert "url_classification" in result[url] assert len(result[url]["url_classification"]) == 1 - assert len(result[url]["url_classification"][0]) == 2 - assert result[url]["url_classification"][0][0] == {url} - assert len(result[url]["url_classification"][0][1]) == 3 + assert len(result[url]["url_classification"][frozenset({url.reference})]) == 3 def find_network_url(network: Network, url: URL) -> Iterator[OOI]: @@ -175,3 +173,10 @@ def test_find_network_url_nibble(xtdb_octopoes_service: OctopoesService, event_m result = nibbler.infer([network1], valid_time) assert network1 in result + assert len(result[network1]["find_network_url"]) == 4 + assert result[network1]["find_network_url"][frozenset([network1.reference, url1.reference])] == set( + find_network_url(network1, url1) + ) + assert result[network1]["find_network_url"][frozenset([network2.reference, url1.reference])] == set() + assert result[network1]["find_network_url"][frozenset([network1.reference, url2.reference])] == set() + assert result[network1]["find_network_url"][frozenset([network2.reference, url2.reference])] == set()