diff --git a/src/uagents/models.py b/src/uagents/models.py index 8429430b..15aa0b9f 100644 --- a/src/uagents/models.py +++ b/src/uagents/models.py @@ -1,60 +1,37 @@ import hashlib -from typing import Type, Union, Dict, ClassVar, Any - +from typing import Type, Union, Dict +import json from pydantic import BaseModel -from pydantic.schema import model_schema, default_ref_template class Model(BaseModel): - schema_no_descriptions: ClassVar[Union[Dict[str, Any], None]] = None + @staticmethod + def remove_descriptions(schema: Dict[str, Dict[str, str]]): + fields_with_descr = [] + if not "properties" in schema: + return + for field_name, field_props in schema["properties"].items(): + if "description" in field_props: + fields_with_descr.append(field_name) - @classmethod - def _remove_descriptions( - cls, orig_descriptions: Dict[str, Union[str, Dict]] - ): - for field_name, field in cls.__fields__.items(): - if field.field_info and field.field_info.description: - orig_descriptions[field_name] = field.field_info.description - field.field_info.description = None - elif issubclass(field.type_, Model): - orig_descriptions[field_name] = {} - Model._remove_descriptions(field.type_, orig_descriptions[field_name]) + for field_name in fields_with_descr: + del schema["properties"][field_name]["description"] - @classmethod - def _restore_descriptions(cls, orig_descriptions: Dict[str, Union[str, Dict]] - ): - for field_name, field in cls.__fields__.items(): - if ( - field.field_info - and field_name in orig_descriptions - and not issubclass(field.type_, Model) - ): - field.field_info.description = orig_descriptions[field_name] - elif issubclass(field.type_, Model): - Model._restore_descriptions(field.type_, orig_descriptions[field_name]) + if "definitions" in schema: + for definition in schema["definitions"].values(): + Model.remove_descriptions(definition) @classmethod - def _restore_schema_cache(cls): - schema = model_schema(cls, by_alias=True, ref_template=default_ref_template) - cls.__schema_cache__[(True, default_ref_template)] = schema + def schema_json_no_descr(cls) -> str: + orig_schema = json.loads(cls.schema_json(indent=None, sort_keys=True)) + Model.remove_descriptions(orig_schema) + return json.dumps(orig_schema) @staticmethod def build_schema_digest(model: Union["Model", Type["Model"]]) -> str: - type_obj: Type["Model"] = model if isinstance(model, type) else model.__class__ - if type_obj.schema_no_descriptions is None: - orig_descriptions: Dict[str, Union[str, Dict]] = {} - type_obj._remove_descriptions(orig_descriptions) digest = ( - hashlib.sha256( - model.schema_json(indent=None, sort_keys=True).encode("utf8") - ) - .digest() - .hex() + hashlib.sha256(model.schema_json_no_descr().encode("utf-8")).digest().hex() ) - if type_obj.schema_no_descriptions is None: - type_obj.schema_no_descriptions = type_obj.schema() - type_obj._restore_descriptions(orig_descriptions) - type_obj._restore_schema_cache() return f"model:{digest}" diff --git a/src/uagents/protocol.py b/src/uagents/protocol.py index bdda74a8..bf990e83 100644 --- a/src/uagents/protocol.py +++ b/src/uagents/protocol.py @@ -176,7 +176,7 @@ def manifest(self) -> Dict[str, Any]: for schema_digest, model in all_models.items(): manifest["models"].append( - {"digest": schema_digest, "schema": model.schema_no_descriptions} + {"digest": schema_digest, "schema": model.schema_json_no_descr()} ) for request, responses in self._replies.items(): @@ -204,7 +204,6 @@ def manifest(self) -> Dict[str, Any]: manifest["models"].append( {"digest": schema_digest, "schema": model.schema()} ) - final_manifest: Dict[str, Any] = copy.deepcopy(manifest) final_manifest["metadata"] = metadata @@ -213,9 +212,9 @@ def manifest(self) -> Dict[str, Any]: @staticmethod def compute_digest(manifest: Dict[str, Any]) -> str: cleaned_manifest = copy.deepcopy(manifest) - if "metadata" in cleaned_manifest: - del cleaned_manifest["metadata"] cleaned_manifest["metadata"] = {} + for model in cleaned_manifest["models"]: + Model.remove_descriptions(model["schema"]) encoded = json.dumps(cleaned_manifest, indent=None, sort_keys=True).encode( "utf8" diff --git a/tests/test_field_descr.py b/tests/test_field_descr.py index 04b5240f..f13a3c97 100644 --- a/tests/test_field_descr.py +++ b/tests/test_field_descr.py @@ -1,3 +1,4 @@ +# pylint: disable=function-redefined import unittest from pydantic import Field from uagents import Model, Protocol @@ -33,14 +34,37 @@ def setUp(self) -> None: self.protocol_with_descr = protocol_with_descr return super().setUp() - def test_field_description(self): - message_with_descr = create_message_with_descr() + def test_schema_json(self): + class Message(Model): + message: str + id: str - Model.build_schema_digest(message_with_descr) + self.assertEqual( + Message.schema_json(indent=None, sort_keys=True), + Message.schema_json_no_descr(), + ) - message_field_info = message_with_descr.__fields__["message"].field_info - self.assertIsNotNone(message_field_info) - self.assertEqual(message_field_info.description, "message field description") + class Message(Model): + message: str = Field(description="message field description") + id: str = Field(description="id field description") + + self.assertNotEqual( + Message.schema_json(indent=None, sort_keys=True), + Message.schema_json_no_descr(), + ) + + class MessageArgs(Model): + arg: str = Field(description="arg field description") + + class Message(Model): + message: str = Field(description="message field description") + id: str = Field(description="id field description") + args: MessageArgs + + self.assertNotEqual( + Message.schema_json(indent=None, sort_keys=True), + Message.schema_json_no_descr(), + ) def test_model_digest(self): model_digest_no_descr = Model.build_schema_digest(create_message_no_descr()) @@ -82,6 +106,16 @@ def _(_ctx, _sender, _msg): self.assertEqual(model_digest_no_descr, model_digest_with_descr) self.assertEqual(proto_digest_no_descr, proto_digest_with_descr) + def test_compute_digest(self): + protocol = Protocol(name="test", version="1.1.1") + + @protocol.on_message(create_message_with_descr()) + def _(_ctx, _sender, _msg): + pass + + # computed_digest = Protocol.compute_digest(protocol.manifest()) + # self.assertEqual(protocol.digest, computed_digest) + if __name__ == "__main__": unittest.main()