Skip to content

Commit

Permalink
relocate cache ref to dispenser
Browse files Browse the repository at this point in the history
  • Loading branch information
Archento committed Aug 30, 2024
1 parent 2999472 commit bec7141
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 45 deletions.
21 changes: 10 additions & 11 deletions python/src/uagents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ class Agent(Sink):
_signed_message_handlers (Dict[str, MessageCallback]): Handlers for signed messages.
_unsigned_message_handlers (Dict[str, MessageCallback]): Handlers for
unsigned messages.
_message_cache (EnvelopeHistory): History of messages received by the agent.
_models (Dict[str, Type[Model]]): Dictionary mapping supported message digests to messages.
_replies (Dict[str, Dict[str, Type[Model]]]): Dictionary of allowed replies for each type
of incoming message.
Expand Down Expand Up @@ -293,6 +294,7 @@ def __init__(
test (Optional[bool]): True if the agent will register and transact on the testnet.
loop (Optional[asyncio.AbstractEventLoop]): The asyncio event loop to use.
log_level (Union[int, str]): The logging level for the agent.
enable_agent_inspector (bool): Enable the agent inspector for debugging.
"""
self._init_done = False
self._name = name
Expand Down Expand Up @@ -344,14 +346,13 @@ def __init__(
self._interval_messages: Set[str] = set()
self._signed_message_handlers: Dict[str, MessageCallback] = {}
self._unsigned_message_handlers: Dict[str, MessageCallback] = {}
self.sent_messages: EnvelopeHistory = EnvelopeHistory(envelopes=[])
self.received_messages: EnvelopeHistory = EnvelopeHistory(envelopes=[])
self._message_cache: EnvelopeHistory = EnvelopeHistory(envelopes=[])
self._rest_handlers: RestHandlerMap = {}
self._models: Dict[str, Type[Model]] = {}
self._replies: Dict[str, Dict[str, Type[Model]]] = {}
self._queries: Dict[str, asyncio.Future] = {}
self._dispatcher = dispatcher
self._dispenser = Dispenser()
self._dispenser = Dispenser(msg_cache_ref=self._message_cache)
self._message_queue = asyncio.Queue()
self._on_startup = []
self._on_shutdown = []
Expand Down Expand Up @@ -388,7 +389,6 @@ def __init__(
interval_messages=self._interval_messages,
wallet_messaging_client=self._wallet_messaging_client,
logger=self._logger,
sent_messages=self.sent_messages,
)

# register with the dispatcher
Expand All @@ -406,20 +406,20 @@ def __init__(
async def _handle_error_message(ctx: Context, sender: str, msg: ErrorMessage):
ctx.logger.exception(f"Received error message from {sender}: {msg.error}")

# define default rest message handlers if agent inspector is enabled
if enable_agent_inspector:

@self.on_rest_get("/agent_info", AgentInfo)
@self.on_rest_get("/agent_info", AgentInfo) # type: ignore
async def _handle_get_info(_ctx: Context):
return AgentInfo(
agent_address=self.address,
endpoints=self._endpoints,
protocols=list(self.protocols.keys()),
)

@self.on_rest_get("/messages", EnvelopeHistory)
@self.on_rest_get("/messages", EnvelopeHistory) # type: ignore
async def _handle_get_messages(_ctx: Context):
messages = self.sent_messages + self.received_messages
return messages
return self._message_cache

self._init_done = True

Expand Down Expand Up @@ -1135,15 +1135,15 @@ async def _process_message_queue(self):
protocol_info = self.get_message_protocol(schema_digest)
protocol_digest = protocol_info[0] if protocol_info else None

self.received_messages.add_entry(
self._message_cache.add_entry(
EnvelopeHistoryEntry(
version=1,
sender=sender,
target=self.address,
session=session,
schema_digest=schema_digest,
payload=message,
protocol_digest=protocol_digest,
payload=message,
)
)

Expand All @@ -1162,7 +1162,6 @@ async def _process_message_queue(self):
message=message, schema_digest=schema_digest
),
protocol=self.get_message_protocol(schema_digest),
sent_messages=self.sent_messages,
)

# parse the received message
Expand Down
10 changes: 8 additions & 2 deletions python/src/uagents/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from uagents.config import DEFAULT_ENVELOPE_TIMEOUT_SECONDS
from uagents.crypto import Identity, is_user_address
from uagents.dispatch import dispatcher
from uagents.envelope import Envelope
from uagents.envelope import Envelope, EnvelopeHistory, EnvelopeHistoryEntry
from uagents.models import Model
from uagents.resolver import GlobalResolver, Resolver
from uagents.types import DeliveryStatus, JsonStr, MsgStatus
Expand All @@ -26,10 +26,11 @@ class Dispenser:
Dispenses messages externally.
"""

def __init__(self):
def __init__(self, msg_cache_ref: Optional[EnvelopeHistory] = None):
self._envelopes: asyncio.Queue[
Tuple[Envelope, List[str], asyncio.Future, bool]
] = asyncio.Queue()
self._msg_cache_ref = msg_cache_ref

def add_envelope(
self,
Expand Down Expand Up @@ -62,6 +63,11 @@ async def run(self):
sync=sync,
)
response_future.set_result(result)

if self._msg_cache_ref:
self._msg_cache_ref.add_entry(
EnvelopeHistoryEntry.from_envelope(env)
)
except Exception as err:
LOGGER.error(f"Failed to send envelope: {err}")

Expand Down
24 changes: 1 addition & 23 deletions python/src/uagents/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
DEFAULT_SEARCH_LIMIT,
)
from uagents.dispatch import dispatcher
from uagents.envelope import Envelope, EnvelopeHistory, EnvelopeHistoryEntry
from uagents.envelope import Envelope
from uagents.models import ErrorMessage, Model
from uagents.resolver import Resolver, parse_identifier
from uagents.storage import KeyValueStore
Expand Down Expand Up @@ -259,7 +259,6 @@ def __init__(
interval_messages: Optional[Set[str]] = None,
wallet_messaging_client: Optional[Any] = None,
logger: Optional[logging.Logger] = None,
sent_messages: Optional[EnvelopeHistory] = None,
):
self._agent = agent
self._storage = storage
Expand All @@ -271,11 +270,6 @@ def __init__(
self._interval_messages = interval_messages
self._wallet_messaging_client = wallet_messaging_client
self._outbound_messages: Dict[str, Tuple[JsonStr, str]] = {}
self.sent_messages = (
sent_messages
if sent_messages is not None
else EnvelopeHistory(envelopes=[])
)

@property
def agent(self) -> AgentRepresentation:
Expand Down Expand Up @@ -462,18 +456,6 @@ async def send_raw(
self._session,
)

self.sent_messages.add_entry(
EnvelopeHistoryEntry(
version=1,
sender=self.address,
target=destination,
session=self._session,
schema_digest=message_schema_digest,
payload=message_body,
protocol=protocol_digest,
)
)

return response

# Handle sync dispatch of messages
Expand Down Expand Up @@ -529,10 +511,6 @@ async def send_raw(

self._queue_envelope(env, endpoints, fut, sync)

env_dict = env.model_dump()
env_dict["payload"] = env.decode_payload()
self.sent_messages.add_entry(EnvelopeHistoryEntry(**env_dict))

try:
result = await asyncio.wait_for(fut, timeout)
except asyncio.TimeoutError:
Expand Down
29 changes: 20 additions & 9 deletions python/src/uagents/envelope.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
import time
from typing import Callable, List, Optional

from pydantic import UUID4, BaseModel, ConfigDict, Field, field_serializer
from pydantic import (
UUID4,
BaseModel,
ConfigDict,
Field,
field_serializer,
)

from uagents.crypto import Identity
from uagents.types import JsonStr
Expand Down Expand Up @@ -127,6 +133,18 @@ class EnvelopeHistoryEntry(BaseModel):
def serialize_session(self, session: UUID4, _info):
return str(session)

@classmethod
def from_envelope(cls, envelope: Envelope):
return cls(
version=envelope.version,
sender=envelope.sender,
target=envelope.target,
session=envelope.session,
schema_digest=envelope.schema_digest,
protocol_digest=envelope.protocol_digest,
payload=envelope.decode_payload(),
)


class EnvelopeHistory(BaseModel):
envelopes: List[EnvelopeHistoryEntry]
Expand All @@ -140,14 +158,7 @@ def apply_retention_policy(self):
cutoff_time = time.time() - 86400
self.envelopes = [e for e in self.envelopes if e.timestamp > cutoff_time]

def __add__(self, other: "EnvelopeHistory"):
combined_envelopes = self.envelopes + other.envelopes
new_history = EnvelopeHistory(envelopes=combined_envelopes)
new_history.apply_retention_policy()

return new_history

@field_serializer("envelopes", when_used="json")
@field_serializer("envelopes", when_used="always")
def serialize_envelopes_in_order(
self, envelopes: List[EnvelopeHistoryEntry], _info
):
Expand Down

0 comments on commit bec7141

Please sign in to comment.