diff --git a/python/src/uagents/agent.py b/python/src/uagents/agent.py index e2637a6f..7eab36aa 100644 --- a/python/src/uagents/agent.py +++ b/python/src/uagents/agent.py @@ -36,7 +36,7 @@ parse_agentverse_config, parse_endpoint_config, ) -from uagents.context import Context, ExternalContext, InternalContext +from uagents.context import Context, ContextFactory, ExternalContext, InternalContext from uagents.crypto import Identity, derive_key_from_seed, is_user_address from uagents.dispatch import Sink, dispatcher from uagents.envelope import EnvelopeHistory, EnvelopeHistoryEntry @@ -71,24 +71,30 @@ from uagents.utils import get_logger -async def _run_interval(func: IntervalCallback, ctx: Context, period: float): +async def _run_interval( + func: IntervalCallback, + logger: logging.Logger, + context_factory: ContextFactory, + period: float, +): """ Run the provided interval callback function at a specified period. Args: func (IntervalCallback): The interval callback function to run. - ctx (Context): The context for the agent. + agent (Agent): The agent that is running the interval callback. period (float): The time period at which to run the callback function. """ while True: try: + ctx = context_factory() await func(ctx) except OSError as ex: - ctx.logger.exception(f"OS Error in interval handler: {ex}") + logger.exception(f"OS Error in interval handler: {ex}") except RuntimeError as ex: - ctx.logger.exception(f"Runtime Error in interval handler: {ex}") + logger.exception(f"Runtime Error in interval handler: {ex}") except Exception as ex: - ctx.logger.exception(f"Exception in interval handler: {ex}") + logger.exception(f"Exception in interval handler: {ex}") await asyncio.sleep(period) @@ -377,21 +383,6 @@ def __init__( # keep track of supported protocols self.protocols: Dict[str, Protocol] = {} - self._ctx = InternalContext( - agent=AgentRepresentation( - address=self.address, - name=self._name, - signing_callback=self._identity.sign_digest, - ), - storage=self._storage, - ledger=self._ledger, - resolver=self._resolver, - dispenser=self._dispenser, - interval_messages=self._interval_messages, - wallet_messaging_client=self._wallet_messaging_client, - logger=self._logger, - ) - # register with the dispatcher self._dispatcher.register(self.address, self) @@ -426,6 +417,27 @@ async def _handle_get_messages(_ctx: Context): self._init_done = True + def _build_context(self) -> InternalContext: + """ + An internal method to build the context for the agent + + @return: + """ + return InternalContext( + agent=AgentRepresentation( + address=self.address, + name=self._name, + signing_callback=self._identity.sign_digest, + ), + storage=self._storage, + ledger=self._ledger, + resolver=self._resolver, + dispenser=self._dispenser, + interval_messages=self._interval_messages, + wallet_messaging_client=self._wallet_messaging_client, + logger=self._logger, + ) + def _initialize_wallet_and_identity(self, seed, name, wallet_key_derivation_index): """ Initialize the wallet and identity for the agent. @@ -997,7 +1009,10 @@ async def handle_rest( if not handler: return None - args = (self._ctx, message) if message else (self._ctx,) + context = self._build_context() + args = [context] + if message: + args.append(message) return await handler(*args) # type: ignore @@ -1015,7 +1030,8 @@ async def _startup(self): ) for handler in self._on_startup: try: - await handler(self._ctx) + ctx = self._build_context() + await handler(ctx) except OSError as ex: self._logger.exception(f"OS Error in startup handler: {ex}") except RuntimeError as ex: @@ -1030,7 +1046,8 @@ async def _shutdown(self): """ for handler in self._on_shutdown: try: - await handler(self._ctx) + ctx = self._build_context() + await handler(ctx) except OSError as ex: self._logger.exception(f"OS Error in shutdown handler: {ex}") except RuntimeError as ex: @@ -1061,7 +1078,9 @@ def start_interval_tasks(self): """ for func, period in self._interval_handlers: - self._loop.create_task(_run_interval(func, self._ctx, period)) + self._loop.create_task( + _run_interval(func, self._logger, self._build_context, period) + ) def start_message_receivers(self): """ @@ -1075,7 +1094,9 @@ def start_message_receivers(self): if self._wallet_messaging_client is not None: for task in [ self._wallet_messaging_client.poll_server(), - self._wallet_messaging_client.process_message_queue(self._ctx), + self._wallet_messaging_client.process_message_queue( + self._build_context + ), ]: self._loop.create_task(task) @@ -1163,7 +1184,7 @@ async def _process_message_queue(self): ) context = ExternalContext( - agent=self._ctx.agent, + agent=self, storage=self._storage, ledger=self._ledger, resolver=self._resolver, @@ -1179,6 +1200,11 @@ async def _process_message_queue(self): protocol=protocol_info, ) + # sanity check + assert ( + context.session == session + ), "Context object should always have message session" + # parse the received message try: recovered = model_class.parse_raw(message) diff --git a/python/src/uagents/context.py b/python/src/uagents/context.py index a6139170..3c9e5ba8 100644 --- a/python/src/uagents/context.py +++ b/python/src/uagents/context.py @@ -10,6 +10,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Dict, List, Optional, @@ -116,7 +117,7 @@ def logger(self) -> logging.Logger: @property @abstractmethod - def session(self) -> Union[uuid.UUID, None]: + def session(self) -> uuid.UUID: """ Get the session UUID associated with the context. @@ -256,6 +257,7 @@ def __init__( ledger: LedgerClient, resolver: Resolver, dispenser: Dispenser, + session: Optional[uuid.UUID] = None, interval_messages: Optional[Set[str]] = None, wallet_messaging_client: Optional[Any] = None, logger: Optional[logging.Logger] = None, @@ -266,7 +268,7 @@ def __init__( self._resolver = resolver self._dispenser = dispenser self._logger = logger - self._session: Optional[uuid.UUID] = None + self._session = session or uuid.uuid4() self._interval_messages = interval_messages self._wallet_messaging_client = wallet_messaging_client self._outbound_messages: Dict[str, Tuple[JsonStr, str]] = {} @@ -288,7 +290,7 @@ def logger(self) -> Union[logging.Logger, None]: return self._logger @property - def session(self) -> Union[uuid.UUID, None]: + def session(self) -> uuid.UUID: """ Get the session UUID associated with the context. @@ -408,7 +410,6 @@ async def send( we don't have access properties that are only necessary in re-active contexts, like 'replies', 'message_received', or 'protocol'. """ - self._session = None schema_digest = Model.build_schema_digest(message) message_body = message.model_dump_json() @@ -440,8 +441,6 @@ async def send_raw( protocol_digest: Optional[str] = None, queries: Optional[Dict[str, asyncio.Future]] = None, ) -> MsgStatus: - self._session = self._session or uuid.uuid4() - # Extract address from destination agent identifier if present _, parsed_name, parsed_address = parse_identifier(destination) @@ -575,7 +574,6 @@ class ExternalContext(InternalContext): def __init__( self, message_received: MsgDigest, - session: Optional[uuid.UUID] = None, queries: Optional[Dict[str, asyncio.Future]] = None, replies: Optional[Dict[str, Dict[str, Type[Model]]]] = None, protocol: Optional[Tuple[str, Protocol]] = None, @@ -594,7 +592,6 @@ def __init__( protocol (Optional[Tuple[str, Protocol]]): The optional Tuple of protocols. """ super().__init__(**kwargs) - self._session = session or None self._queries = queries or {} self._replies = replies self._message_received = message_received @@ -674,3 +671,6 @@ async def send( protocol_digest=self._protocol[0], queries=self._queries, ) + + +ContextFactory = Callable[[], Context] diff --git a/python/src/uagents/wallet_messaging.py b/python/src/uagents/wallet_messaging.py index 2c86355a..cb66b560 100644 --- a/python/src/uagents/wallet_messaging.py +++ b/python/src/uagents/wallet_messaging.py @@ -11,7 +11,7 @@ from requests import HTTPError, JSONDecodeError from uagents.config import WALLET_MESSAGING_POLL_INTERVAL_SECONDS -from uagents.context import Context +from uagents.context import ContextFactory from uagents.crypto import Identity from uagents.types import WalletMessageCallback from uagents.utils import get_logger @@ -79,8 +79,9 @@ async def poll_server(self): ) await asyncio.sleep(self._poll_interval) - async def process_message_queue(self, ctx: Context): + async def process_message_queue(self, context_factory: ContextFactory): # noqa: F821 while True: msg: WalletMessage = await self._message_queue.get() for handler in self._message_handlers: + ctx = context_factory() await handler(ctx, msg) diff --git a/python/tests/test_agent.py b/python/tests/test_agent.py index e8dcdde3..4d63bb2b 100644 --- a/python/tests/test_agent.py +++ b/python/tests/test_agent.py @@ -82,7 +82,7 @@ def _(ctx: Context): startup_handlers = self.agent._on_startup self.assertEqual(len(startup_handlers), 1) self.assertTrue(isinstance(startup_handlers[0], Callable)) - self.assertIsNone(self.agent._ctx.storage.get("startup")) + self.assertIsNone(self.agent._storage.get("startup")) def test_agent_on_shutdown_event(self): @self.agent.on_event("shutdown") @@ -92,7 +92,7 @@ def _(ctx: Context): shutdown_handlers = self.agent._on_shutdown self.assertEqual(len(shutdown_handlers), 1) self.assertTrue(isinstance(shutdown_handlers[0], Callable)) - self.assertIsNone(self.agent._ctx.storage.get("shutdown")) + self.assertIsNone(self.agent._storage.get("shutdown")) def test_agent_on_rest_get(self): @self.agent.on_rest_get("/get", Response) diff --git a/python/tests/test_context.py b/python/tests/test_context.py index c15db9f9..34dae936 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -52,10 +52,8 @@ def setUp(self): self.alice = Agent(name="alice", seed="alice recovery phrase", resolve=resolver) self.bob = Agent(name="bob", seed="bob recovery phrase") - self.agent = self.alice - self.context = self.agent._ctx self.loop = asyncio.get_event_loop() - self.loop.create_task(self.context._dispenser.run()) + self.loop.create_task(self.alice._dispenser.run()) def get_external_context( self, @@ -65,13 +63,13 @@ def get_external_context( queries: Optional[Dict[str, asyncio.Future]] = None, ): return ExternalContext( - agent=self.context.agent, - storage=self.agent._storage, - ledger=self.agent._ledger, - resolver=self.agent._resolver, - dispenser=self.agent._dispenser, - wallet_messaging_client=self.agent._wallet_messaging_client, - logger=self.agent._logger, + agent=self.alice, + storage=self.alice._storage, + ledger=self.alice._ledger, + resolver=self.alice._resolver, + dispenser=self.alice._dispenser, + wallet_messaging_client=self.alice._wallet_messaging_client, + logger=self.alice._logger, queries=queries, session=None, replies=replies, @@ -79,13 +77,14 @@ def get_external_context( ) async def test_send_local_dispatch(self): - result = await self.context.send(self.bob.address, msg) + context = self.alice._build_context() + result = await context.send(self.bob.address, msg) exp_msg_status = MsgStatus( status=DeliveryStatus.DELIVERED, detail="Message dispatched locally", destination=self.bob.address, endpoint="", - session=self.context.session, + session=context.session, ) self.assertEqual(result, exp_msg_status) @@ -121,28 +120,29 @@ async def test_send_local_dispatch_invalid_reply(self): self.assertEqual(result, exp_msg_status) async def test_send_local_dispatch_valid_interval_msg(self): - self.context._interval_messages = {msg_digest} - result = await self.context.send(self.bob.address, msg) + context = self.alice._build_context() + context._interval_messages = {msg_digest} + result = await context.send(self.bob.address, msg) exp_msg_status = MsgStatus( status=DeliveryStatus.DELIVERED, detail="Message dispatched locally", destination=self.bob.address, endpoint="", - session=self.context.session, + session=context.session, ) self.assertEqual(result, exp_msg_status) - self.context._interval_messages = set() async def test_send_local_dispatch_invalid_interval_msg(self): - self.context._interval_messages = {msg_digest} - result = await self.context.send(self.bob.address, incoming) + context = self.alice._build_context() + context._interval_messages = {msg_digest} + result = await context.send(self.bob.address, incoming) exp_msg_status = MsgStatus( status=DeliveryStatus.FAILED, detail="Invalid interval message", destination=self.bob.address, endpoint="", - session=self.context.session, + session=context.session, ) self.assertEqual(result, exp_msg_status) @@ -170,13 +170,14 @@ async def test_send_resolve_sync_query(self): async def test_send_external_dispatch_resolve_failure(self): destination = Identity.generate().address - result = await self.context.send(destination, msg) + context = self.alice._build_context() + result = await context.send(destination, msg) exp_msg_status = MsgStatus( status=DeliveryStatus.FAILED, detail="Unable to resolve destination endpoint", destination=destination, endpoint="", - session=self.context.session, + session=context.session, ) self.assertEqual(result, exp_msg_status) @@ -186,8 +187,10 @@ async def test_send_external_dispatch_success(self, mocked_responses): # Mock the HTTP POST request with a status code and response content mocked_responses.post(endpoints[0], status=200) + context = self.alice._build_context() + # Perform the actual operation - result = await self.context.send(self.clyde.address, msg) + result = await context.send(self.clyde.address, msg) # Define the expected message status exp_msg_status = MsgStatus( @@ -195,7 +198,7 @@ async def test_send_external_dispatch_success(self, mocked_responses): detail="Message successfully delivered via HTTP", destination=self.clyde.address, endpoint=endpoints[0], - session=self.context.session, + session=context.session, ) # Assertions @@ -206,8 +209,10 @@ async def test_send_external_dispatch_failure(self, mocked_responses): # Mock the HTTP POST request with a status code and response content mocked_responses.post(endpoints[0], status=404) + context = self.alice._build_context() + # Perform the actual operation - result = await self.context.send(self.clyde.address, msg) + result = await context.send(self.clyde.address, msg) # Define the expected message status exp_msg_status = MsgStatus( @@ -215,7 +220,7 @@ async def test_send_external_dispatch_failure(self, mocked_responses): detail="Message delivery failed", destination=self.clyde.address, endpoint="", - session=self.context.session, + session=context.session, ) # Assertions @@ -231,8 +236,10 @@ async def test_send_external_dispatch_multiple_endpoints_first_success( mocked_responses.post(endpoints[0], status=200) mocked_responses.post(endpoints[1], status=404) + context = self.alice._build_context() + # Perform the actual operation - result = await self.context.send(self.clyde.address, msg) + result = await context.send(self.clyde.address, msg) # Define the expected message status exp_msg_status = MsgStatus( @@ -240,7 +247,7 @@ async def test_send_external_dispatch_multiple_endpoints_first_success( detail="Message successfully delivered via HTTP", destination=self.clyde.address, endpoint=endpoints[0], - session=self.context.session, + session=context.session, ) # Assertions @@ -261,8 +268,10 @@ async def test_send_external_dispatch_multiple_endpoints_second_success( mocked_responses.post(endpoints[0], status=404) mocked_responses.post(endpoints[1], status=200) + context = self.alice._build_context() + # Perform the actual operation - result = await self.context.send(self.clyde.address, msg) + result = await context.send(self.clyde.address, msg) # Define the expected message status exp_msg_status = MsgStatus( @@ -270,7 +279,7 @@ async def test_send_external_dispatch_multiple_endpoints_second_success( detail="Message successfully delivered via HTTP", destination=self.clyde.address, endpoint=endpoints[1], - session=self.context.session, + session=context.session, ) # Assertions @@ -288,8 +297,10 @@ async def test_send_external_dispatch_multiple_endpoints_failure( mocked_responses.post(endpoints[0], status=404) mocked_responses.post(endpoints[1], status=404) + context = self.alice._build_context() + # Perform the actual operation - result = await self.context.send(self.clyde.address, msg) + result = await context.send(self.clyde.address, msg) # Define the expected message status exp_msg_status = MsgStatus( @@ -297,7 +308,7 @@ async def test_send_external_dispatch_multiple_endpoints_failure( detail="Message delivery failed", destination=self.clyde.address, endpoint="", - session=self.context.session, + session=context.session, ) # Assertions