Add EndpointAPI.wait_until_connected_to #86

Closed
194 changes: 142 additions & 52 deletions lahja/asyncio/endpoint.py
@@ -161,6 +161,12 @@ def __init__(
self._running = asyncio.Event()
self._stopped = asyncio.Event()

def __str__(self) -> str:
return f"RemoteEndpoint[{self.name if self.name is not None else id(self)}]"

def __repr__(self) -> str:
return f"<{self}>"
Member Author:

can/should extract to standalone.
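
A standalone form of this helper (hypothetical; not part of this diff) might look like:

from typing import Optional

def describe_remote(name: Optional[str], fallback: object) -> str:
    # Build a stable, human-readable label, falling back to the object id
    # when the remote has not been given a name.
    return f"RemoteEndpoint[{name if name is not None else id(fallback)}]"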


async def wait_started(self) -> None:
await self._running.wait()

@@ -200,9 +206,13 @@ async def _run(self) -> None:
async with self._received_response:
self._received_response.notify_all()
elif isinstance(message, SubscriptionsUpdated):
self.subscribed_messages = message.subscriptions
async with self._received_subscription:
self.subscribed_messages = message.subscriptions
self._received_subscription.notify_all()
# The ack is sent after releasing the lock since we've already
# exited the code which actually updates the subscriptions and
# we are merely responding to the sender to acknowledge
# receipt.
if message.response_expected:
await self.send_message(SubscriptionsAck())
else:
@@ -212,14 +222,19 @@ async def notify_subscriptions_updated(
self, subscriptions: Set[Type[BaseEvent]], block: bool = True
) -> None:
"""
Alert the ``Endpoint`` which has connected to us that our subscription set has
changed. If ``block`` is ``True`` then this function will block until the remote
endpoint has acknowledged the new subscription set. If ``block`` is ``False`` then this
function will return immediately after the send finishes.
Alert the endpoint on the other side of this connection that the local
subscriptions have changed. If ``block`` is ``True`` then this function
will block until the remote endpoint has acknowledged the new
subscription set. If ``block`` is ``False`` then this function will
return immediately after the send finishes.
"""
Contributor:

The comment just above this line isn't quite right anymore: "Alert the Endpoint which has connected to us that our subscription set has changed." Since RemoteEndpoint now handles both inbound and outbound connections, this should be something like "Alert the remote that this endpoint's subscription set has changed".

# The extra lock ensures only one coroutine can notify this endpoint at any one time
# and that no replies are accidentally received by the wrong
# coroutines. Without this, in the case where `block=True`, this inner
# block would release the lock on the call to `wait()` which would
# allow the ack from a different update to incorrectly result in this
# returning before the ack had been received.
async with self._notify_lock:
# The lock ensures only one coroutine can notify this endpoint at any one time
# and that no replies are accidentally received by the wrong coroutines.
async with self._received_response:
try:
await self.conn.send_message(
@@ -248,15 +263,6 @@ async def wait_until_subscription_received(self) -> None:
await self._received_subscription.wait()


@asynccontextmanager # type: ignore
async def run_remote_endpoint(remote: RemoteEndpoint) -> AsyncIterable[RemoteEndpoint]:
await remote.start()
try:
yield remote
finally:
await remote.stop()


TFunc = TypeVar("TFunc", bound=Callable[..., Any])


@@ -276,6 +282,10 @@ class AsyncioEndpoint(BaseEndpoint):
_receiving_queue: "asyncio.Queue[Tuple[Union[bytes, BaseEvent], Optional[BroadcastConfig]]]"
_receiving_loop_running: asyncio.Event

_subscription_updates_running: asyncio.Event
_subscription_updates_condition: asyncio.Condition
_subscription_updates_queue: "asyncio.Queue[None]"

_futures: Dict[Optional[str], "asyncio.Future[BaseEvent]"]

_full_connections: Dict[str, RemoteEndpoint]
@@ -315,9 +325,18 @@ def __init__(self, name: str) -> None:
# over an IPC socket.
self._server_tasks: Set["asyncio.Future[Any]"] = set()

# way to signal that the connections to other endpoints have changed
self._connections_changed = asyncio.Condition()

self._running = False
self._serving = False

def __str__(self) -> str:
return f"Endpoint[{self.name}]"

def __repr__(self) -> str:
return f"<{self.name}>"

@property
def is_running(self) -> bool:
return self._running
@@ -365,16 +384,55 @@ def run(self, *args, **kwargs): # type: ignore
async def start(self) -> None:
if self.is_running:
raise RuntimeError(f"Endpoint {self.name} is already running")

self._receiving_loop_running = asyncio.Event()
self._receiving_queue = asyncio.Queue()

self._subscription_updates_running = asyncio.Event()
self._subscription_updates_condition = asyncio.Condition()
self._subscription_updates_queue = asyncio.Queue()
Contributor:

These could be initialized in __init__ which would save you from needing to declare them at the beginning of this class and would also make it easier to reason about the potential for None dereferences.

The await *.running() pattern sounds like it could be pulled out into a helper which removes the potentially None instance attribute entirely. Something like await self.start_coro(self._process_subscription_updates) which doesn't return until the coro has started.

Member Author:

I'm going to leave this for now and open an issue to track as I've seen this issue arising as well but I don't want to conflate this PR with an extra fix that I'm not sure I know the right solution for.
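
A rough sketch of the start_coro-style helper suggested above (the name, signature, and readiness protocol are assumptions, not lahja's API):

import asyncio
from typing import Awaitable, Callable

async def start_coro(coro_fn: Callable[[asyncio.Event], Awaitable[None]]) -> "asyncio.Future[None]":
    # Schedule the coroutine in the background and only return once it has
    # signaled readiness, so callers never observe a half-started loop.
    started = asyncio.Event()
    task = asyncio.ensure_future(coro_fn(started))
    await started.wait()
    return task

async def process_subscription_updates(started: asyncio.Event) -> None:
    started.set()  # replaces a separate *_running event attribute
    while True:
        await asyncio.sleep(1)  # placeholder for the real update loop

start() would then call task = await start_coro(process_subscription_updates) instead of tracking _subscription_updates_running separately.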

self._running = True

self._endpoint_tasks.add(asyncio.ensure_future(self._connect_receiving_queue()))
self._endpoint_tasks.add(
asyncio.ensure_future(self._process_subscription_updates())
)
self._endpoint_tasks.add(
asyncio.ensure_future(self._process_subscription_updates_queue())
)

await self._receiving_loop_running.wait()
await self._subscription_updates_running.wait()

self.logger.debug("Endpoint[%s]: running", self.name)

async def _process_subscription_updates_queue(self) -> None:
while self.is_running:
await self._subscription_updates_queue.get()
async with self._subscription_updates_condition:
self._subscription_updates_condition.notify_all()

async def _process_subscription_updates(self) -> None:
self._subscription_updates_running.set()
async with self._subscription_updates_condition:
Member Author:

@lithp Needed this to ensure we don't end up with concurrent subscription updates being sent out, since in that case it's possible for them to arrive out of order, leaving the other side with potentially inaccurate subscription records.

TODO: add a code comment explaining this.

Contributor:

How would they arrive out of order? Domain sockets always deliver messages in order and _notify_lock ensures there's only one update per socket happening at any time.

Member Author:

I started undoing this but then stopped as I'm inclined to leave it this way. Here's my reasoning.

  1. we end up needing to do subscription updates from synchronous methods with the addition of subscribe_nowait which seems to be a necessary API.
  2. I'm much more comfortable with that method doing a Queue.put_nowait than I am adding an additional call to asyncio.ensure_future(self._notify_subscription_updates()).
  3. I think that the asyncio.Condition based approach is more appropriate and under that approach we still need a background process to listen for changes. The various branches that I have in progress have made good use of that Condition API.

I'm fine unrolling this if you're still not convinced at the point where the trio based endpoint is being added to the library.
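
A minimal standalone sketch of the queue-plus-Condition bridge described in this thread (class and method names are illustrative, not lahja's API):

import asyncio
from typing import Awaitable, Callable

class SubscriptionNotifier:
    def __init__(self) -> None:
        self._queue: "asyncio.Queue[None]" = asyncio.Queue()
        self._condition = asyncio.Condition()

    def notify_nowait(self) -> None:
        # Safe to call from synchronous code such as a subscribe_nowait() API.
        self._queue.put_nowait(None)

    async def drain_queue(self) -> None:
        # Background task: turn queued synchronous pokes into Condition wakeups.
        while True:
            await self._queue.get()
            async with self._condition:
                self._condition.notify_all()

    async def broadcast_updates(self, send_update: Callable[[], Awaitable[None]]) -> None:
        # Background task: at most one broadcast is in flight at a time, so a
        # remote cannot observe subscription updates out of order.
        async with self._condition:
            while True:
                await self._condition.wait()
                await send_update()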

while self.is_running:
await self._subscription_updates_condition.wait()
# make a copy so that the set doesn't change while we iterate over it
subscribed_events = self.subscribed_events
await asyncio.gather(
*(
remote.notify_subscriptions_updated(subscribed_events)
for remote in itertools.chain(
self._half_connections.copy(),
self._full_connections.values(),
)
)
)

def _notify_subscriptions_changed_nowait(self) -> None:
self._subscription_updates_queue.put_nowait(None)

@check_event_loop
async def start_server(self, ipc_path: Path) -> None:
"""
@@ -399,21 +457,25 @@ async def start_server(self, ipc_path: Path) -> None:
async def _accept_conn(self, reader: StreamReader, writer: StreamWriter) -> None:
conn = Connection(reader, writer)
remote = RemoteEndpoint(None, conn, self._receiving_queue.put)
self._half_connections.add(remote)

task = asyncio.ensure_future(self._handle_client(remote))
task = asyncio.ensure_future(self._run_remote_endpoint(remote))
task.add_done_callback(self._server_tasks.remove)
task.add_done_callback(lambda _: self._half_connections.remove(remote))
self._server_tasks.add(task)

# the Endpoint on the other end blocks until it receives this message
await remote.notify_subscriptions_updated(self.subscribed_events)
await remote.wait_started()

async def _handle_client(self, remote: RemoteEndpoint) -> None:
try:
async with run_remote_endpoint(remote):
await remote.wait_stopped()
finally:
self._half_connections.remove(remote)
# we **must** ensure that the subscription updates are locked between
# the time that we manually update this individual connection and that
# we place it within the set of tracked connections, otherwise, a
# subscription update from elsewhere can occur between the time these
# two statements execute resulting in the remote missing a new
# subscription update. Note that inverting these statements should
# also mitigate this, but it has the downside of the manual update
# potentially being redundant.
async with self._subscription_updates_condition:
await remote.notify_subscriptions_updated(self.subscribed_events)
await self._add_half_connection(remote)
Contributor:

Ahh, okay, I see the problem now: _add_half_connection blocks when it tries to acquire connections_changed, so there's a period of time where the remote isn't receiving subscription updates. Double-sending the same set of subscriptions has a very small cost, though! And I think that trying not to send it introduces greater inefficiencies elsewhere; serializing all updates isn't cheap!

Member Author:

(I thought I responded to this somewhere but I don't see it)

I'm inclined to leave it this way because it correctly manages the appropriate locks on the resource, whereas the other mechanism solves the problem through indirection. First, I think the overall cost here is low because our usage patterns for the library don't involve adding new connections at high frequency. Second, inverting these statements instead seems like a solution that could easily be undone incidentally in a refactor, since it's an implicit solution.
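
A condensed sketch of the ordering concern under discussion (names are placeholders, not lahja's API):

import asyncio
from typing import Any, Set

async def add_remote(remote: Any, tracked_remotes: Set[Any], updates_condition: asyncio.Condition, current_subscriptions: Any) -> None:
    # Hold the same lock the subscription broadcaster uses, so no update can
    # fire between sending the current set and starting to track the remote.
    async with updates_condition:
        await remote.notify_subscriptions_updated(current_subscriptions)
        tracked_remotes.add(remote)
    # Inverting the two statements (track first, then notify) also closes the
    # race window, at the cost of occasionally sending a redundant update.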


@property
def subscribed_events(self) -> Set[Type[BaseEvent]]:
@@ -426,17 +488,6 @@ def subscribed_events(self) -> Set[Type[BaseEvent]]:
.union(self._queues.keys())
)

async def _notify_subscriptions_changed(self) -> None:
"""
Tell all inbound connections of our new subscriptions
"""
# make a copy so that the set doesn't change while we iterate over it
subscribed_events = self.subscribed_events
for remote in self._half_connections.copy():
await remote.notify_subscriptions_updated(subscribed_events)
for remote in tuple(self._full_connections.values()):
await remote.notify_subscriptions_updated(subscribed_events)

def get_connected_endpoints_and_subscriptions(
self
) -> Tuple[Tuple[Optional[str], Set[Type[BaseEvent]]], ...]:
@@ -532,9 +583,8 @@ async def connect_to_endpoint(self, config: ConnectionConfig) -> None:

conn = await Connection.connect_to(config.path)
remote = RemoteEndpoint(config.name, conn, self._receiving_queue.put)
self._full_connections[config.name] = remote

task = asyncio.ensure_future(self._handle_server(remote))
task = asyncio.ensure_future(self._run_remote_endpoint(remote))
subscriptions_task = asyncio.ensure_future(
self.watch_outbound_subscriptions(remote)
)
@@ -543,19 +593,30 @@ async def connect_to_endpoint(self, config: ConnectionConfig) -> None:

task.add_done_callback(self._endpoint_tasks.remove)
task.add_done_callback(lambda _: subscriptions_task.cancel())
task.add_done_callback(lambda _: self._full_connections.pop(config.name, None))

self._endpoint_tasks.add(task)

# don't return control until the caller can safely call broadcast()
await remote.wait_until_subscription_received()

async def _handle_server(self, remote: RemoteEndpoint) -> None:
await remote.wait_started()

# we **must** ensure that the subscription updates are locked between
# the time that we manually update this individual connection and that
# we place it within the set of tracked connections, otherwise, a
# subscription update from elsewhere can occur between the time these
# two statements execute resulting in the remote missing a new
# subscription update. Note that inverting these statements should
# also mitigate this, but it has the downside of the manual update
# potentially being redundant.
async with self._subscription_updates_condition:
await remote.notify_subscriptions_updated(self.subscribed_events)
await self._add_full_connection(remote)

async def _run_remote_endpoint(self, remote: RemoteEndpoint) -> None:
await remote.start()
try:
async with run_remote_endpoint(remote):
await remote.wait_stopped()
await remote.wait_stopped()
finally:
if remote.name is not None:
self._full_connections.pop(remote.name)
await remote.stop()

async def watch_outbound_subscriptions(self, outbound: RemoteEndpoint) -> None:
while outbound in self._full_connections.values():
@@ -567,6 +628,30 @@ async def watch_outbound_subscriptions(self, outbound: RemoteEndpoint) -> None:
def is_connected_to(self, endpoint_name: str) -> bool:
return endpoint_name in self._full_connections

async def wait_until_connected_to(self, endpoint_name: str) -> None:
if self.is_connected_to(endpoint_name):
return

async with self._connections_changed:
while True:
await self._connections_changed.wait()
if self.is_connected_to(endpoint_name):
return

async def _add_full_connection(self, remote: RemoteEndpoint) -> None:
if remote.name is None:
raise Exception("TODO: remote is not named")
async with self._connections_changed:
self._full_connections[remote.name] = remote
self._connections_changed.notify_all()

async def _add_half_connection(self, remote: RemoteEndpoint) -> None:
if remote.name is not None:
raise Exception("TODO: remote is named and should be a full connection")
async with self._connections_changed:
self._half_connections.add(remote)
self._connections_changed.notify_all()

async def _process_item(
self, item: BaseEvent, config: Optional[BroadcastConfig]
) -> None:
@@ -749,7 +834,7 @@ def _remove_async_subscription(
# the user `await subscription.remove()`. This means this Endpoint will keep
# getting events for a little while after it stops listening for them but
# that's a performance problem, not a correctness problem.
asyncio.ensure_future(self._notify_subscriptions_changed())
self._notify_subscriptions_changed_nowait()

def _remove_sync_subscription(
self, event_type: Type[BaseEvent], handler_fn: SubscriptionSyncHandler
Expand All @@ -761,7 +846,7 @@ def _remove_sync_subscription(
# the user `await subscription.remove()`. This means this Endpoint will keep
# getting events for a little while after it stops listening for them but
# that's a performance problem, not a correctness problem.
asyncio.ensure_future(self._notify_subscriptions_changed())
self._notify_subscriptions_changed_nowait()

async def subscribe(
self,
Expand All @@ -786,7 +871,8 @@ async def subscribe(
self._remove_sync_subscription, event_type, casted_handler
)

await self._notify_subscriptions_changed()
async with self._subscription_updates_condition:
self._subscription_updates_condition.notify_all()

return Subscription(unsubscribe_fn)

Expand All @@ -805,7 +891,9 @@ async def stream(
self._queues[event_type] = []

self._queues[event_type].append(casted_queue)
await self._notify_subscriptions_changed()

async with self._subscription_updates_condition:
self._subscription_updates_condition.notify_all()

if num_events is None:
# loop forever
Expand All @@ -825,4 +913,6 @@ async def stream(
self._queues[event_type].remove(casted_queue)
if not self._queues[event_type]:
del self._queues[event_type]
await self._notify_subscriptions_changed()
# use nowait here since removing a subscription is not time
# sensitive and not blocking here is better performance.
self._notify_subscriptions_changed_nowait()
21 changes: 17 additions & 4 deletions lahja/tools/benchmark/process.py
@@ -1,6 +1,8 @@
import itertools
import logging
import multiprocessing
import os
import signal
import time
from typing import Any, AsyncGenerator, List, NamedTuple, Optional, Tuple # noqa: F401

@@ -42,15 +44,26 @@ def start(self) -> None:

def stop(self) -> None:
assert self._process is not None
self._process.terminate()
self._process.join(1)
if self._process.pid is not None:
os.kill(self._process.pid, signal.SIGINT)
else:
self._process.terminate()
Member Author:

TODO: extract to standalone.
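
A possible standalone helper along these lines (hypothetical; not part of this diff):

import os
import signal
from multiprocessing import Process

def stop_process(process: Process, timeout: float = 1.0) -> None:
    # Prefer SIGINT so the child can shut down via its KeyboardInterrupt handler.
    if process.pid is not None:
        os.kill(process.pid, signal.SIGINT)
    else:
        process.terminate()
    process.join(timeout)
    # Escalate if the child is still alive after the grace period.
    if process.is_alive():
        process.terminate()
        process.join(timeout)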


try:
self._process.join(1)
except TimeoutError:
self._process.terminate()
self._process.join(1)

@staticmethod
def launch(config: DriverProcessConfig) -> None:
# UNCOMMENT FOR DEBUGGING
# logger = multiprocessing.log_to_stderr()
# logger.setLevel(logging.INFO)
config.backend.run(DriverProcess.worker, config)
try:
config.backend.run(DriverProcess.worker, config)
except KeyboardInterrupt:
return

@staticmethod
async def worker(config: DriverProcessConfig) -> None:
@@ -95,7 +108,6 @@ async def worker(backend: BaseBackend, name: str, num_events: int) -> None:
await event_bus.connect_to_endpoints(
ConnectionConfig.from_name(REPORTER_ENDPOINT)
)
await event_bus.wait_until_all_remotes_subscribed_to(TotalRecordedEvent)

stats = LocalStatistic()
events = event_bus.stream(PerfMeasureEvent, num_events=num_events)
@@ -104,6 +116,7 @@ async def worker(backend: BaseBackend, name: str, num_events: int) -> None:
RawMeasureEntry(sent_at=event.sent_at, received_at=time.time())
)

await event_bus.wait_until_all_remotes_subscribed_to(TotalRecordedEvent)
await event_bus.broadcast(
TotalRecordedEvent(stats.crunch(event_bus.name)),
BroadcastConfig(filter_endpoint=REPORTER_ENDPOINT),