Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(core): consistently generated session ids #531

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 45 additions & 26 deletions python/src/uagents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,27 @@
from uagents.utils import get_logger


async def _run_interval(func: IntervalCallback, ctx: Context, period: float):
async def _run_interval(func: IntervalCallback, agent: "Agent", period: float):
"""
Run the provided interval callback function at a specified period.

Args:
func (IntervalCallback): The interval callback function to run.
ctx (Context): The context for the agent.
agent (Agent): The agent that is running the interval callback.
period (float): The time period at which to run the callback function.
"""
logger = agent._logger

while True:
try:
ctx = agent._build_context()
await func(ctx)
except OSError as ex:
ctx.logger.exception(f"OS Error in interval handler: {ex}")
logger.exception(f"OS Error in interval handler: {ex}")
except RuntimeError as ex:
ctx.logger.exception(f"Runtime Error in interval handler: {ex}")
logger.exception(f"Runtime Error in interval handler: {ex}")
except Exception as ex:
ctx.logger.exception(f"Exception in interval handler: {ex}")
logger.exception(f"Exception in interval handler: {ex}")

await asyncio.sleep(period)

Expand Down Expand Up @@ -377,21 +380,6 @@ def __init__(
# keep track of supported protocols
self.protocols: Dict[str, Protocol] = {}

self._ctx = InternalContext(
agent=AgentRepresentation(
address=self.address,
name=self._name,
signing_callback=self._identity.sign_digest,
),
storage=self._storage,
ledger=self._ledger,
resolver=self._resolver,
dispenser=self._dispenser,
interval_messages=self._interval_messages,
wallet_messaging_client=self._wallet_messaging_client,
logger=self._logger,
)

# register with the dispatcher
self._dispatcher.register(self.address, self)

Expand Down Expand Up @@ -426,6 +414,27 @@ async def _handle_get_messages(_ctx: Context):

self._init_done = True

def _build_context(self) -> InternalContext:
"""
An internal method to build the context for the agent

@return:
"""
return InternalContext(
agent=AgentRepresentation(
address=self.address,
name=self._name,
signing_callback=self._identity.sign_digest,
),
storage=self._storage,
ledger=self._ledger,
resolver=self._resolver,
dispenser=self._dispenser,
interval_messages=self._interval_messages,
wallet_messaging_client=self._wallet_messaging_client,
logger=self._logger,
)

def _initialize_wallet_and_identity(self, seed, name, wallet_key_derivation_index):
"""
Initialize the wallet and identity for the agent.
Expand Down Expand Up @@ -997,7 +1006,10 @@ async def handle_rest(
if not handler:
return None

args = (self._ctx, message) if message else (self._ctx,)
context = self._build_context()
args = [context]
if message:
args.append(message)

return await handler(*args) # type: ignore

Expand All @@ -1015,7 +1027,8 @@ async def _startup(self):
)
for handler in self._on_startup:
try:
await handler(self._ctx)
ctx = self._build_context()
await handler(ctx)
except OSError as ex:
self._logger.exception(f"OS Error in startup handler: {ex}")
except RuntimeError as ex:
Expand All @@ -1030,7 +1043,8 @@ async def _shutdown(self):
"""
for handler in self._on_shutdown:
try:
await handler(self._ctx)
ctx = self._build_context()
await handler(ctx)
except OSError as ex:
self._logger.exception(f"OS Error in shutdown handler: {ex}")
except RuntimeError as ex:
Expand Down Expand Up @@ -1061,7 +1075,7 @@ def start_interval_tasks(self):

"""
for func, period in self._interval_handlers:
self._loop.create_task(_run_interval(func, self._ctx, period))
self._loop.create_task(_run_interval(func, self, period))

def start_message_receivers(self):
"""
Expand All @@ -1075,7 +1089,7 @@ def start_message_receivers(self):
if self._wallet_messaging_client is not None:
for task in [
self._wallet_messaging_client.poll_server(),
self._wallet_messaging_client.process_message_queue(self._ctx),
self._wallet_messaging_client.process_message_queue(self),
]:
self._loop.create_task(task)

Expand Down Expand Up @@ -1163,7 +1177,7 @@ async def _process_message_queue(self):
)

context = ExternalContext(
agent=self._ctx.agent,
agent=self,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be the agent representation not the agent itself

storage=self._storage,
ledger=self._ledger,
resolver=self._resolver,
Expand All @@ -1179,6 +1193,11 @@ async def _process_message_queue(self):
protocol=protocol_info,
)

# sanity check
assert (
context.session == session
), "Context object should always have message session"

# parse the received message
try:
recovered = model_class.parse_raw(message)
Expand Down
12 changes: 4 additions & 8 deletions python/src/uagents/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def logger(self) -> logging.Logger:

@property
@abstractmethod
def session(self) -> Union[uuid.UUID, None]:
def session(self) -> uuid.UUID:
"""
Get the session UUID associated with the context.

Expand Down Expand Up @@ -256,6 +256,7 @@ def __init__(
ledger: LedgerClient,
resolver: Resolver,
dispenser: Dispenser,
session: Optional[uuid.UUID] = None,
interval_messages: Optional[Set[str]] = None,
wallet_messaging_client: Optional[Any] = None,
logger: Optional[logging.Logger] = None,
Expand All @@ -266,7 +267,7 @@ def __init__(
self._resolver = resolver
self._dispenser = dispenser
self._logger = logger
self._session: Optional[uuid.UUID] = None
self._session = session or uuid.uuid4()
self._interval_messages = interval_messages
self._wallet_messaging_client = wallet_messaging_client
self._outbound_messages: Dict[str, Tuple[JsonStr, str]] = {}
Expand All @@ -288,7 +289,7 @@ def logger(self) -> Union[logging.Logger, None]:
return self._logger

@property
def session(self) -> Union[uuid.UUID, None]:
def session(self) -> uuid.UUID:
"""
Get the session UUID associated with the context.

Expand Down Expand Up @@ -408,7 +409,6 @@ async def send(
we don't have access properties that are only necessary in re-active
contexts, like 'replies', 'message_received', or 'protocol'.
"""
self._session = None
schema_digest = Model.build_schema_digest(message)
message_body = message.model_dump_json()

Expand Down Expand Up @@ -440,8 +440,6 @@ async def send_raw(
protocol_digest: Optional[str] = None,
queries: Optional[Dict[str, asyncio.Future]] = None,
) -> MsgStatus:
self._session = self._session or uuid.uuid4()

# Extract address from destination agent identifier if present
_, parsed_name, parsed_address = parse_identifier(destination)

Expand Down Expand Up @@ -575,7 +573,6 @@ class ExternalContext(InternalContext):
def __init__(
self,
message_received: MsgDigest,
session: Optional[uuid.UUID] = None,
queries: Optional[Dict[str, asyncio.Future]] = None,
replies: Optional[Dict[str, Dict[str, Type[Model]]]] = None,
protocol: Optional[Tuple[str, Protocol]] = None,
Expand All @@ -594,7 +591,6 @@ def __init__(
protocol (Optional[Tuple[str, Protocol]]): The optional Tuple of protocols.
"""
super().__init__(**kwargs)
self._session = session or None
self._queries = queries or {}
self._replies = replies
self._message_received = message_received
Expand Down
4 changes: 2 additions & 2 deletions python/src/uagents/wallet_messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from requests import HTTPError, JSONDecodeError

from uagents.config import WALLET_MESSAGING_POLL_INTERVAL_SECONDS
from uagents.context import Context
from uagents.crypto import Identity
from uagents.types import WalletMessageCallback
from uagents.utils import get_logger
Expand Down Expand Up @@ -79,8 +78,9 @@ async def poll_server(self):
)
await asyncio.sleep(self._poll_interval)

async def process_message_queue(self, ctx: Context):
async def process_message_queue(self, agent: "Agent"): # noqa: F821
while True:
msg: WalletMessage = await self._message_queue.get()
for handler in self._message_handlers:
ctx = agent._build_context()
await handler(ctx, msg)
4 changes: 2 additions & 2 deletions python/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _(ctx: Context):
startup_handlers = self.agent._on_startup
self.assertEqual(len(startup_handlers), 1)
self.assertTrue(isinstance(startup_handlers[0], Callable))
self.assertIsNone(self.agent._ctx.storage.get("startup"))
self.assertIsNone(self.agent._storage.get("startup"))

def test_agent_on_shutdown_event(self):
@self.agent.on_event("shutdown")
Expand All @@ -92,7 +92,7 @@ def _(ctx: Context):
shutdown_handlers = self.agent._on_shutdown
self.assertEqual(len(shutdown_handlers), 1)
self.assertTrue(isinstance(shutdown_handlers[0], Callable))
self.assertIsNone(self.agent._ctx.storage.get("shutdown"))
self.assertIsNone(self.agent._storage.get("shutdown"))

def test_agent_on_rest_get(self):
@self.agent.on_rest_get("/get", Response)
Expand Down
Loading
Loading