Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Skip erroring OOIs in ooi_repository.py #4069

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
37 changes: 29 additions & 8 deletions octopoes/octopoes/repositories/ooi_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import structlog
from bits.definitions import BitDefinition
from httpx import HTTPStatusError, codes
from pydantic import RootModel, TypeAdapter
from pydantic import RootModel, TypeAdapter, ValidationError

from octopoes.config.settings import (
DEFAULT_LIMIT,
Expand Down Expand Up @@ -240,12 +240,13 @@ def serialize(cls, ooi: OOI) -> dict[str, Any]:
return export

@classmethod
def deserialize(cls, data: dict[str, Any], to_type: type[OOI] | None = None) -> OOI:
def deserialize(
cls, data: dict[str, Any], to_type: type[OOI] | None = None, skip_errors: bool = False
) -> OOI | bool:
if "object_type" not in data:
raise ValueError("Data is missing object_type")

object_cls = type_by_name(data["object_type"])
object_cls = to_type or object_cls
object_cls = to_type or type_by_name(data["object_type"])
user_id = data.get("user_id")

# remove type prefixes
Expand All @@ -259,7 +260,19 @@ def deserialize(cls, data: dict[str, Any], to_type: type[OOI] | None = None) ->
if scan_profiles := data.get("_reference", []):
stripped["scan_profile"] = scan_profiles[0]

return object_cls.model_validate(stripped)
try:
return object_cls.model_validate(stripped)
except ValidationError as error:
if skip_errors:
logger.error(
"""An OOI could not be validated due to a mismatch between the database and the current models.
PK: %r on (wanted) type %s. Validation error: %r""",
stripped["primary_key"],
object_cls,
error,
)
return False
raise error

def get(self, reference: Reference, valid_time: datetime) -> OOI:
try:
Expand Down Expand Up @@ -308,7 +321,11 @@ def load_bulk_as_list(self, references: set[Reference], valid_time: datetime) ->
return []

query = Query().where_in(OOI, id=references).pull(OOI, fields="[* {:_reference [*]}]")
return [self.deserialize(x[0]) for x in self.session.client.query(query, valid_time)]
return [
deserialized
for x in self.session.client.query(query, valid_time)
if (deserialized := self.deserialize(data=x[0], skip_errors=True)) is not False
]

def list_oois(
self,
Expand Down Expand Up @@ -385,7 +402,7 @@ def list_oois(
)

res = self.session.client.query(data_query, valid_time)
oois = [self.deserialize(x[0]) for x in res]
oois = [deserialized for x in res if (deserialized := self.deserialize(data=x[0], skip_errors=True))]
return Paginated(count=count, items=oois)

def list_oois_by_object_types(self, types: set[type[OOI]], valid_time: datetime) -> list[OOI]:
Expand All @@ -400,7 +417,11 @@ def list_oois_by_object_types(self, types: set[type[OOI]], valid_time: datetime)
:in-args [[{object_types}]]
}}
""".format(object_types=" ".join(map(lambda t: str_val(t.get_object_type()), types)))
return [self.deserialize(x[0]) for x in self.session.client.query(data_query, valid_time)]
return [
deserialized
for x in self.session.client.query(data_query, valid_time)
if (deserialized := self.deserialize(data=x[0], skip_errors=True))
]

def list_random(
self, valid_time: datetime, amount: int = 1, scan_levels: set[ScanLevel] = DEFAULT_SCAN_LEVEL_FILTER
Expand Down
Loading