Add EndpointAPI.wait_until_connected_to
#86
@@ -161,6 +161,12 @@ def __init__(
        self._running = asyncio.Event()
        self._stopped = asyncio.Event()

    def __str__(self) -> str:
        return f"RemoteEndpoint[{self.name if self.name is not None else id(self)}]"

    def __repr__(self) -> str:
        return f"<{self}>"

    async def wait_started(self) -> None:
        await self._running.wait()

@@ -200,9 +206,13 @@ async def _run(self) -> None:
                async with self._received_response:
                    self._received_response.notify_all()
            elif isinstance(message, SubscriptionsUpdated):
                self.subscribed_messages = message.subscriptions
                async with self._received_subscription:
                    self.subscribed_messages = message.subscriptions
                    self._received_subscription.notify_all()
                # The ack is sent after releasing the lock since we've already
                # exited the code which actually updates the subscriptions and
                # we are merely responding to the sender to acknowledge
                # receipt.
                if message.response_expected:
                    await self.send_message(SubscriptionsAck())
            else:

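The hunk above moves the ``subscribed_messages`` write under ``self._received_subscription`` so that a waiter only wakes after the new subscription set is actually in place, and the ack goes out only after that lock is released. A minimal sketch of the same wait/notify pattern (the ``Receiver`` class and the event name here are illustrative, not the library's API):

```python
import asyncio


class Receiver:
    """Toy stand-in for the remote-endpoint reader loop."""

    def __init__(self) -> None:
        self.subscriptions: set = set()
        self._received_subscription = asyncio.Condition()

    async def on_subscriptions_updated(self, new_subscriptions: set) -> None:
        # Write the new state and notify while holding the condition, so a
        # waiter can never wake up before the write has landed.
        async with self._received_subscription:
            self.subscriptions = new_subscriptions
            self._received_subscription.notify_all()

    async def wait_until_subscription_received(self) -> None:
        async with self._received_subscription:
            await self._received_subscription.wait()


async def main() -> None:
    receiver = Receiver()
    waiter = asyncio.ensure_future(receiver.wait_until_subscription_received())
    await asyncio.sleep(0)  # let the waiter reach wait()
    await receiver.on_subscriptions_updated({"SomeEvent"})
    await waiter
    print(receiver.subscriptions)


asyncio.run(main())
```
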
@@ -212,14 +222,19 @@ async def notify_subscriptions_updated(
        self, subscriptions: Set[Type[BaseEvent]], block: bool = True
    ) -> None:
        """
        Alert the ``Endpoint`` which has connected to us that our subscription set has
        changed. If ``block`` is ``True`` then this function will block until the remote
        endpoint has acknowledged the new subscription set. If ``block`` is ``False`` then this
        function will return immediately after the send finishes.
        Alert the endpoint on the other side of this connection that the local
        subscriptions have changed. If ``block`` is ``True`` then this function
        will block until the remote endpoint has acknowledged the new
        subscription set. If ``block`` is ``False`` then this function will
        return immediately after the send finishes.
        """
Review comment: The comment just above this line isn't quite right anymore: "Alert the …"

        # The extra lock ensures only one coroutine can notify this endpoint at any one time
        # and that no replies are accidentally received by the wrong
        # coroutines. Without this, in the case where `block=True`, this inner
        # block would release the lock on the call to `wait()` which would
        # allow the ack from a different update to incorrectly result in this
        # returning before the ack had been received.
        async with self._notify_lock:
            # The lock ensures only one coroutine can notify this endpoint at any one time
            # and that no replies are accidentally received by the wrong coroutines.
            async with self._received_response:
                try:
                    await self.conn.send_message(

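The new comment explains why ``_notify_lock`` wraps the ``_received_response`` condition: ``wait()`` drops the condition's lock while waiting, so without an outer lock two concurrent ``block=True`` notifications could each pick up the other's ack. A self-contained sketch of that send-then-wait-for-ack shape (class and method names here are invented for illustration):

```python
import asyncio


class Notifier:
    def __init__(self) -> None:
        self._notify_lock = asyncio.Lock()
        self._received_ack = asyncio.Condition()

    async def on_ack(self) -> None:
        # Called by the receive loop when an ack frame arrives.
        async with self._received_ack:
            self._received_ack.notify_all()

    async def notify(self, payload: str, block: bool = True) -> None:
        # The outer lock serializes whole send/ack cycles.  Without it, the
        # wait() below releases the condition's internal lock, so a second
        # concurrent notify() could send and then consume the ack meant for
        # the first call, letting it return too early.
        async with self._notify_lock:
            async with self._received_ack:
                print(f"sending {payload}")
                if block:
                    await self._received_ack.wait()


async def main() -> None:
    notifier = Notifier()

    async def fake_remote() -> None:
        await asyncio.sleep(0.1)
        await notifier.on_ack()

    asyncio.ensure_future(fake_remote())
    await notifier.notify("subscriptions changed")
    print("ack received")


asyncio.run(main())
```
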
@@ -248,15 +263,6 @@ async def wait_until_subscription_received(self) -> None:
        await self._received_subscription.wait()


@asynccontextmanager  # type: ignore
async def run_remote_endpoint(remote: RemoteEndpoint) -> AsyncIterable[RemoteEndpoint]:
    await remote.start()
    try:
        yield remote
    finally:
        await remote.stop()


TFunc = TypeVar("TFunc", bound=Callable[..., Any])

@@ -276,6 +282,10 @@ class AsyncioEndpoint(BaseEndpoint):
    _receiving_queue: "asyncio.Queue[Tuple[Union[bytes, BaseEvent], Optional[BroadcastConfig]]]"
    _receiving_loop_running: asyncio.Event

    _subscription_updates_running: asyncio.Event
    _subscription_updates_condition: asyncio.Condition
    _subscription_updates_queue: "asyncio.Queue[None]"

    _futures: Dict[Optional[str], "asyncio.Future[BaseEvent]"]

    _full_connections: Dict[str, RemoteEndpoint]

@@ -315,9 +325,18 @@ def __init__(self, name: str) -> None:
        # over an IPC socket.
        self._server_tasks: Set["asyncio.Future[Any]"] = set()

        # way to signal that the connections to other endpoints have changed
        self._connections_changed = asyncio.Condition()

        self._running = False
        self._serving = False

    def __str__(self) -> str:
        return f"Endpoint[{self.name}]"

    def __repr__(self) -> str:
        return f"<{self.name}>"

    @property
    def is_running(self) -> bool:
        return self._running

@@ -365,16 +384,55 @@ def run(self, *args, **kwargs):  # type: ignore
    async def start(self) -> None:
        if self.is_running:
            raise RuntimeError(f"Endpoint {self.name} is already running")

        self._receiving_loop_running = asyncio.Event()
        self._receiving_queue = asyncio.Queue()

        self._subscription_updates_running = asyncio.Event()
        self._subscription_updates_condition = asyncio.Condition()
        self._subscription_updates_queue = asyncio.Queue()
Review comment: These could be initialized in …
Reply: I'm going to leave this for now and open an issue to track, as I've seen this issue arising as well, but I don't want to conflate this PR with an extra fix that I'm not sure I know the right solution for.


        self._running = True

        self._endpoint_tasks.add(asyncio.ensure_future(self._connect_receiving_queue()))
        self._endpoint_tasks.add(
            asyncio.ensure_future(self._process_subscription_updates())
        )
        self._endpoint_tasks.add(
            asyncio.ensure_future(self._process_subscription_updates_queue())
        )

        await self._receiving_loop_running.wait()
        await self._subscription_updates_running.wait()

        self.logger.debug("Endpoint[%s]: running", self.name)

    async def _process_subscription_updates_queue(self) -> None:
        while self.is_running:
            await self._subscription_updates_queue.get()
            async with self._subscription_updates_condition:
                self._subscription_updates_condition.notify_all()

    async def _process_subscription_updates(self) -> None:
        self._subscription_updates_running.set()
        async with self._subscription_updates_condition:
Review comment: @lithp Needed this to ensure we don't end up with concurrent subscription updates being sent out, since it's possible that they arrive out of order in that case and then the other side will have potentially inaccurate subscription records. TODO: add a code comment explaining this.
Reply: How would they arrive out of order? Domain sockets always deliver messages in order, and …
Reply: I started undoing this but then stopped, as I'm inclined to leave it this way. Here's my reasoning: … I'm fine unrolling this if you're still not convinced at the point where the …

            while self.is_running:
                await self._subscription_updates_condition.wait()
                # make a copy so that the set doesn't change while we iterate over it
                subscribed_events = self.subscribed_events
                await asyncio.gather(
                    *(
                        remote.notify_subscriptions_updated(subscribed_events)
                        for remote in itertools.chain(
                            self._half_connections.copy(),
                            self._full_connections.values(),
                        )
                    )
                )

    def _notify_subscriptions_changed_nowait(self) -> None:
        self._subscription_updates_queue.put_nowait(None)

    @check_event_loop
    async def start_server(self, ipc_path: Path) -> None:
        """

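Taken together, ``_notify_subscriptions_changed_nowait``, ``_process_subscription_updates_queue`` and ``_process_subscription_updates`` form a pipeline: synchronous callers drop a token on a queue, a pump task turns tokens into condition notifications, and a single broadcaster task sends the updates, so they cannot interleave. A rough, self-contained sketch of that shape (the names and the ``print`` stand-in for the real broadcast are illustrative):

```python
import asyncio
from typing import Set


class SubscriptionBroadcaster:
    def __init__(self) -> None:
        self._updates_queue: "asyncio.Queue[None]" = asyncio.Queue()
        self._updates_condition = asyncio.Condition()
        self.subscribed_events: Set[str] = set()
        self.running = True

    def notify_nowait(self) -> None:
        # Cheap, synchronous signal that the subscription set changed.
        self._updates_queue.put_nowait(None)

    async def pump_queue(self) -> None:
        # Turn queued tokens into condition notifications.
        while self.running:
            await self._updates_queue.get()
            async with self._updates_condition:
                self._updates_condition.notify_all()

    async def broadcast_loop(self) -> None:
        # Only this task sends updates, so remotes see them in order.
        async with self._updates_condition:
            while self.running:
                await self._updates_condition.wait()
                snapshot = set(self.subscribed_events)
                print(f"broadcasting {snapshot}")


async def main() -> None:
    bus = SubscriptionBroadcaster()
    tasks = [
        asyncio.ensure_future(bus.broadcast_loop()),
        asyncio.ensure_future(bus.pump_queue()),
    ]
    bus.subscribed_events.add("SomeEvent")
    bus.notify_nowait()
    await asyncio.sleep(0.1)
    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)


asyncio.run(main())
```
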
@@ -399,21 +457,25 @@ async def start_server(self, ipc_path: Path) -> None:
    async def _accept_conn(self, reader: StreamReader, writer: StreamWriter) -> None:
        conn = Connection(reader, writer)
        remote = RemoteEndpoint(None, conn, self._receiving_queue.put)
        self._half_connections.add(remote)

        task = asyncio.ensure_future(self._handle_client(remote))
        task = asyncio.ensure_future(self._run_remote_endpoint(remote))
        task.add_done_callback(self._server_tasks.remove)
        task.add_done_callback(lambda _: self._half_connections.remove(remote))
        self._server_tasks.add(task)

        # the Endpoint on the other end blocks until it receives this message
        await remote.notify_subscriptions_updated(self.subscribed_events)
        await remote.wait_started()

    async def _handle_client(self, remote: RemoteEndpoint) -> None:
        try:
            async with run_remote_endpoint(remote):
                await remote.wait_stopped()
        finally:
            self._half_connections.remove(remote)
        # we **must** ensure that the subscription updates are locked between
        # the time that we manually update this individual connection and that
        # we place it within the set of tracked connections, otherwise, a
        # subscription update from elsewhere can occur between the time these
        # two statements execute resulting in the remote missing a new
        # subscription update. Note that inverting these statements should
        # also mitigate this, but it has the downside of the manual update
        # potentially being redundant.
        async with self._subscription_updates_condition:
            await remote.notify_subscriptions_updated(self.subscribed_events)
            await self._add_half_connection(remote)

Review comment: Ahh, okay, I see the problem now, …
Reply: (I thought I responded to this somewhere but I don't see it) I'm inclined to leave it this way because it correctly manages the appropriate locks on the resource, whereas the other mechanism solves the problem through indirection. First, I think the overall cost here is low because our usage patterns for the library don't involve adding new connections at a high frequency. Second, doing it the alternate way of inverting these statements seems like a solution that could easily be undone incidentally in a refactor, since it's an implicit solution.

    @property
    def subscribed_events(self) -> Set[Type[BaseEvent]]:

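The long comment in ``_accept_conn`` (and its twin further down in ``connect_to_endpoint``) is about making "tell this one remote" and "start tracking this remote" atomic with respect to the broadcaster task, which runs under the same condition's lock. A small sketch of that idea, with invented names and a toy remote:

```python
import asyncio
from typing import Set


class FakeRemote:
    async def notify(self, events: Set[str]) -> None:
        print(f"remote sees {events}")


class Registry:
    def __init__(self) -> None:
        self._condition = asyncio.Condition()
        self._remotes: Set[FakeRemote] = set()
        self.subscribed_events: Set[str] = {"EventA"}

    async def broadcast_updates(self) -> None:
        # Stand-in for the broadcaster task: it always holds the condition's
        # lock while it iterates over the tracked remotes.
        async with self._condition:
            for remote in self._remotes:
                await remote.notify(self.subscribed_events)

    async def register(self, remote: FakeRemote) -> None:
        # Holding the same lock here means no broadcast can slip in between
        # the manual notify and the add(); otherwise the remote could miss an
        # update that happened in that window.
        async with self._condition:
            await remote.notify(self.subscribed_events)
            self._remotes.add(remote)


async def main() -> None:
    registry = Registry()
    await registry.register(FakeRemote())
    registry.subscribed_events.add("EventB")
    await registry.broadcast_updates()


asyncio.run(main())
```
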
@@ -426,17 +488,6 @@ def subscribed_events(self) -> Set[Type[BaseEvent]]:
            .union(self._queues.keys())
        )

    async def _notify_subscriptions_changed(self) -> None:
        """
        Tell all inbound connections of our new subscriptions
        """
        # make a copy so that the set doesn't change while we iterate over it
        subscribed_events = self.subscribed_events
        for remote in self._half_connections.copy():
            await remote.notify_subscriptions_updated(subscribed_events)
        for remote in tuple(self._full_connections.values()):
            await remote.notify_subscriptions_updated(subscribed_events)

    def get_connected_endpoints_and_subscriptions(
        self
    ) -> Tuple[Tuple[Optional[str], Set[Type[BaseEvent]]], ...]:

@@ -532,9 +583,8 @@ async def connect_to_endpoint(self, config: ConnectionConfig) -> None:

        conn = await Connection.connect_to(config.path)
        remote = RemoteEndpoint(config.name, conn, self._receiving_queue.put)
        self._full_connections[config.name] = remote

        task = asyncio.ensure_future(self._handle_server(remote))
        task = asyncio.ensure_future(self._run_remote_endpoint(remote))
        subscriptions_task = asyncio.ensure_future(
            self.watch_outbound_subscriptions(remote)
        )

@@ -543,19 +593,30 @@ async def connect_to_endpoint(self, config: ConnectionConfig) -> None:

        task.add_done_callback(self._endpoint_tasks.remove)
        task.add_done_callback(lambda _: subscriptions_task.cancel())
        task.add_done_callback(lambda _: self._full_connections.pop(config.name, None))

        self._endpoint_tasks.add(task)

        # don't return control until the caller can safely call broadcast()
        await remote.wait_until_subscription_received()

    async def _handle_server(self, remote: RemoteEndpoint) -> None:
        await remote.wait_started()

        # we **must** ensure that the subscription updates are locked between
        # the time that we manually update this individual connection and that
        # we place it within the set of tracked connections, otherwise, a
        # subscription update from elsewhere can occur between the time these
        # two statements execute resulting in the remote missing a new
        # subscription update. Note that inverting these statements should
        # also mitigate this, but it has the downside of the manual update
        # potentially being redundant.
        async with self._subscription_updates_condition:
            await remote.notify_subscriptions_updated(self.subscribed_events)
            await self._add_full_connection(remote)

    async def _run_remote_endpoint(self, remote: RemoteEndpoint) -> None:
        await remote.start()
        try:
            async with run_remote_endpoint(remote):
                await remote.wait_stopped()
            await remote.wait_stopped()
        finally:
            if remote.name is not None:
                self._full_connections.pop(remote.name)
            await remote.stop()

    async def watch_outbound_subscriptions(self, outbound: RemoteEndpoint) -> None:
        while outbound in self._full_connections.values():

@@ -567,6 +628,30 @@ async def watch_outbound_subscriptions(self, outbound: RemoteEndpoint) -> None:
    def is_connected_to(self, endpoint_name: str) -> bool:
        return endpoint_name in self._full_connections

    async def wait_until_connected_to(self, endpoint_name: str) -> None:
        if self.is_connected_to(endpoint_name):
            return

        async with self._connections_changed:
            while True:
                await self._connections_changed.wait()
                if self.is_connected_to(endpoint_name):
                    return

    async def _add_full_connection(self, remote: RemoteEndpoint) -> None:
        if remote.name is None:
            raise Exception("TODO: remote is not named")
        async with self._connections_changed:
            self._full_connections[remote.name] = remote
            self._connections_changed.notify_all()

    async def _add_half_connection(self, remote: RemoteEndpoint) -> None:
        if remote.name is not None:
            raise Exception("TODO: remote is named and should be a full connection")
        async with self._connections_changed:
            self._half_connections.add(remote)
            self._connections_changed.notify_all()

    async def _process_item(
        self, item: BaseEvent, config: Optional[BroadcastConfig]
    ) -> None:

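The new ``wait_until_connected_to`` replaces the need to poll ``is_connected_to`` in a sleep loop: callers block on the new ``_connections_changed`` condition, which ``_add_full_connection`` and ``_add_half_connection`` notify. A hypothetical caller (the ``endpoint`` variable and the "server" name are assumptions, not code from this PR) might look like:

```python
from typing import Any


async def wait_for_server(endpoint: Any) -> None:
    # `endpoint` is assumed to be a running AsyncioEndpoint that is in the
    # process of connecting to an endpoint named "server" elsewhere.
    await endpoint.wait_until_connected_to("server")
    # From here on the connection is tracked in _full_connections, so
    # is_connected_to("server") is True and directed broadcasts are safe.
    assert endpoint.is_connected_to("server")
```

Because the method re-checks ``is_connected_to`` after every ``wait()``, a wake-up caused by some other connection simply loops back to waiting.
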
@@ -749,7 +834,7 @@ def _remove_async_subscription(
        # the user `await subscription.remove()`. This means this Endpoint will keep
        # getting events for a little while after it stops listening for them but
        # that's a performance problem, not a correctness problem.
        asyncio.ensure_future(self._notify_subscriptions_changed())
        self._notify_subscriptions_changed_nowait()

    def _remove_sync_subscription(
        self, event_type: Type[BaseEvent], handler_fn: SubscriptionSyncHandler

@@ -761,7 +846,7 @@ def _remove_sync_subscription(
        # the user `await subscription.remove()`. This means this Endpoint will keep
        # getting events for a little while after it stops listening for them but
        # that's a performance problem, not a correctness problem.
        asyncio.ensure_future(self._notify_subscriptions_changed())
        self._notify_subscriptions_changed_nowait()

    async def subscribe(
        self,

@@ -786,7 +871,8 @@ async def subscribe(
            self._remove_sync_subscription, event_type, casted_handler
        )

        await self._notify_subscriptions_changed()
        async with self._subscription_updates_condition:
            self._subscription_updates_condition.notify_all()

        return Subscription(unsubscribe_fn)

@@ -805,7 +891,9 @@ async def stream(
            self._queues[event_type] = []

        self._queues[event_type].append(casted_queue)
        await self._notify_subscriptions_changed()

        async with self._subscription_updates_condition:
            self._subscription_updates_condition.notify_all()

        if num_events is None:
            # loop forever

@@ -825,4 +913,6 @@ async def stream(
        self._queues[event_type].remove(casted_queue)
        if not self._queues[event_type]:
            del self._queues[event_type]
            await self._notify_subscriptions_changed()
            # use nowait here since removing a subscription is not time
            # sensitive and not blocking here is better performance.
            self._notify_subscriptions_changed_nowait()

@@ -1,6 +1,8 @@
import itertools
import logging
import multiprocessing
import os
import signal
import time
from typing import Any, AsyncGenerator, List, NamedTuple, Optional, Tuple  # noqa: F401

@@ -42,15 +44,26 @@ def start(self) -> None:

    def stop(self) -> None:
        assert self._process is not None
        self._process.terminate()
        self._process.join(1)
        if self._process.pid is not None:
            os.kill(self._process.pid, signal.SIGINT)
        else:
            self._process.terminate()

Review comment: TODO: extract to standalone.

        try:
            self._process.join(1)
        except TimeoutError:
            self._process.terminate()
            self._process.join(1)

    @staticmethod
    def launch(config: DriverProcessConfig) -> None:
        # UNCOMMENT FOR DEBUGGING
        # logger = multiprocessing.log_to_stderr()
        # logger.setLevel(logging.INFO)
        config.backend.run(DriverProcess.worker, config)
        try:
            config.backend.run(DriverProcess.worker, config)
        except KeyboardInterrupt:
            return

    @staticmethod
    async def worker(config: DriverProcessConfig) -> None:

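The ``stop``/``launch`` changes switch from an unconditional ``terminate()`` (SIGTERM, no cleanup) to sending SIGINT, which surfaces in the child as ``KeyboardInterrupt`` and lets it unwind normally. A standalone sketch of that pattern, independent of the driver classes in this file:

```python
import multiprocessing
import os
import signal
import time


def child() -> None:
    try:
        while True:
            time.sleep(0.1)
    except KeyboardInterrupt:
        # SIGINT arrives as KeyboardInterrupt, so finally blocks and other
        # cleanup in the child get a chance to run before it exits.
        return


if __name__ == "__main__":
    proc = multiprocessing.Process(target=child)
    proc.start()
    time.sleep(0.2)
    if proc.pid is not None:
        os.kill(proc.pid, signal.SIGINT)  # cooperative shutdown
    else:
        proc.terminate()  # fallback: SIGTERM, no cleanup in the child
    # join(timeout) returns after the timeout rather than raising; check
    # exitcode afterwards to see whether the child actually exited.
    proc.join(1)
    print("child exit code:", proc.exitcode)
```
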
@@ -95,7 +108,6 @@ async def worker(backend: BaseBackend, name: str, num_events: int) -> None:
        await event_bus.connect_to_endpoints(
            ConnectionConfig.from_name(REPORTER_ENDPOINT)
        )
        await event_bus.wait_until_all_remotes_subscribed_to(TotalRecordedEvent)

        stats = LocalStatistic()
        events = event_bus.stream(PerfMeasureEvent, num_events=num_events)

@@ -104,6 +116,7 @@ async def worker(backend: BaseBackend, name: str, num_events: int) -> None:
                RawMeasureEntry(sent_at=event.sent_at, received_at=time.time())
            )

        await event_bus.wait_until_all_remotes_subscribed_to(TotalRecordedEvent)
        await event_bus.broadcast(
            TotalRecordedEvent(stats.crunch(event_bus.name)),
            BroadcastConfig(filter_endpoint=REPORTER_ENDPOINT),

Review comment: can/should extract to standalone.