Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
originalsouth committed Nov 19, 2024
1 parent 87909ae commit 4b853d9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
22 changes: 12 additions & 10 deletions octopoes/nibbles/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TypeVar

from nibbles.definitions import NibbleDefinition, get_nibble_definitions
from octopoes.models import OOI
from octopoes.models import OOI, Reference
from octopoes.models.origin import NibbleOrigin, OriginType
from octopoes.models.types import type_by_name
from octopoes.repositories.ooi_repository import OOIRepository
Expand Down Expand Up @@ -47,11 +47,11 @@ def __init__(
def update_nibbles(self):
self.nibbles: list[NibbleDefinition] = get_nibble_definitions()

def _run(self, ooi: OOI, valid_time: datetime) -> dict[str, list[tuple[set[OOI], set[OOI]]]]:
return_value: dict[str, list[tuple[set[OOI], set[OOI]]]] = {}
def _run(self, ooi: OOI, valid_time: datetime) -> dict[str, dict[frozenset[Reference], set[OOI]]]:
return_value: dict[str, dict[frozenset[Reference], set[OOI]]] = {}
for nibble in filter(lambda x: type(ooi) in x.signature, self.nibbles):
args = self.ooi_repository.nibble_query(ooi, nibble, valid_time)
results = [(set(arg), set(flatten([nibble(arg)]))) for arg in args]
results = {frozenset({a.reference for a in arg}): set(flatten([nibble(arg)])) for arg in args}
if results:
return_value |= {nibble.id: results}
# TODO: we could cache the writes for single OOI nibbles
Expand All @@ -63,31 +63,33 @@ def _cleared(self, ooi: OOI, valid_time: datetime) -> bool:
target_nibbles = filter(lambda x: type(ooi) in x.signature, self.nibbles)
return any(nibble.min_scan_level < ooi_level for nibble in target_nibbles)

def _write(self, inferences: dict[OOI, dict[str, list[tuple[set[OOI], set[OOI]]]]], valid_time: datetime):
def _write(self, inferences: dict[OOI, dict[str, dict[frozenset[Reference], set[OOI]]]], valid_time: datetime):
for source_ooi, results in inferences.items():
self.ooi_repository.save(source_ooi, valid_time)
for nibble_id, run_result in results.items():
for arg, result in run_result:
for arg, result in run_result.items():
nibble_origin = NibbleOrigin(
method=nibble_id,
origin_type=OriginType.NIBBLE,
result=[ooi.reference for ooi in result],
source=source_ooi.reference,
parameters=[ooi.reference for ooi in arg],
parameters=list(arg),
)
for ooi in result:
self.ooi_repository.save(ooi, valid_time=valid_time)
self.origin_repository.save(nibble_origin, valid_time=valid_time)

def infer(self, stack: list[OOI], valid_time: datetime) -> dict[OOI, dict[str, list[tuple[set[OOI], set[OOI]]]]]:
inferences: dict[OOI, dict[str, list[tuple[set[OOI], set[OOI]]]]] = {}
def infer(
self, stack: list[OOI], valid_time: datetime
) -> dict[OOI, dict[str, dict[frozenset[Reference], set[OOI]]]]:
inferences: dict[OOI, dict[str, dict[frozenset[Reference], set[OOI]]]] = {}
blockset = set(stack)
if stack and self._cleared(stack[-1], valid_time):
while stack:
ooi = stack.pop()
results = self._run(ooi, valid_time)
if results:
blocks = set.union(*[ooiset for result in results.values() for _, ooiset in result])
blocks = set.union(*[ooiset for result in results.values() for _, ooiset in result.items()])
stack += [o for o in blocks if o not in blockset]
blockset |= blocks
inferences |= {ooi: results}
Expand Down
11 changes: 8 additions & 3 deletions octopoes/tests/integration/test_nibbles.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,7 @@ def test_url_classification_nibble(xtdb_octopoes_service: OctopoesService, event
assert url in result
assert "url_classification" in result[url]
assert len(result[url]["url_classification"]) == 1
assert len(result[url]["url_classification"][0]) == 2
assert result[url]["url_classification"][0][0] == {url}
assert len(result[url]["url_classification"][0][1]) == 3
assert len(result[url]["url_classification"][frozenset({url.reference})]) == 3


def find_network_url(network: Network, url: URL) -> Iterator[OOI]:
Expand Down Expand Up @@ -175,3 +173,10 @@ def test_find_network_url_nibble(xtdb_octopoes_service: OctopoesService, event_m

result = nibbler.infer([network1], valid_time)
assert network1 in result
assert len(result[network1]["find_network_url"]) == 4
assert result[network1]["find_network_url"][frozenset([network1.reference, url1.reference])] == set(
find_network_url(network1, url1)
)
assert result[network1]["find_network_url"][frozenset([network2.reference, url1.reference])] == set()
assert result[network1]["find_network_url"][frozenset([network1.reference, url2.reference])] == set()
assert result[network1]["find_network_url"][frozenset([network2.reference, url2.reference])] == set()

0 comments on commit 4b853d9

Please sign in to comment.