Skip to content

Commit

Permalink
Use websockets for json communication (#4490)
Browse files Browse the repository at this point in the history
* first commit

* Fix bugs

* Fix more bugs

* slight change

* minor bug fixes

* It's working quite well now

* Update to solve that last bug. I think it should work now

* Bug fix for an issue that was causing messages to be sent to all worker threads

* Fix unittests issue and fix formatting

* Minor change to increase robustness of message sending

* Some more formatting changes

* Add in worker auth and use wss rather than ws for websocket URLs to make it secure. Not yet tested b/c on plane and internet isn't good enough to run CodaLab (and download images and such)

* Added in server auth with secret. Also still need to test (still on plane)

* Fixed issues and got auth working. Now, I'll work on returning error codes so that we have proper tests

* Add in tests for authentication functionality (for worker and server)

* Slight cleanup to data sending code

* Adding in ssl certification

* Add in SSL stuff for worker; still testing on dev

* Revert "Add in SSL stuff for worker; still testing on dev"

This reverts commit 4eb3d7b.

* Revert "Adding in ssl certification"

This reverts commit cbd1505.

* Fixed formatting

* Very minor formatting change to ignore one line for MyPy

* Another minor formatting change

* add exponential backoff to see if that fixes dev issue

* Added code to actually detect worker disconnections now so that some websockets will be invalidated

* Make sockets get looped over in random order to help distribute load

* Clean up ws-server and delete a Dataclass I was using previously

* a few more minor changes

* Make websocket locks more robust and improve error messaging

* More permissive retries in case of other errors (e.g. like 1013). With exponential backoff, it's still not very aggressive at all

* minor change to have a different error message if worker doesn't yet have any sockets open with ws-server

* Rename send_json and send_json_message_with_sock

* Rearrange worker_model to minimize diff

* Final changes

* Fix formatting

* Minor changes to get auth working again and to robustly return an error if the client tries to connect to an invalid path.

* Merge in master and make some minor changes

---------

Co-authored-by: AndrewJGaut <[email protected]>
  • Loading branch information
AndrewJGaut and AndrewJGaut authored Sep 10, 2023
1 parent 528db0f commit 7c460c1
Show file tree
Hide file tree
Showing 14 changed files with 349 additions and 166 deletions.
156 changes: 119 additions & 37 deletions codalab/bin/ws_server.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,157 @@
# Main entry point for CodaLab cl-ws-server.
# Main entry point to the CodaLab Websocket Server.
# The Websocket Server handles communication between the REST server and workers.
import argparse
import asyncio
from collections import defaultdict
import logging
import os
import random
import re
import time
from typing import Any, Dict
import websockets
import threading

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
logging.basicConfig(format='%(asctime)s %(message)s %(pathname)s %(lineno)d')

worker_to_ws: Dict[str, Any] = {}
from codalab.lib.codalab_manager import CodaLabManager


async def rest_server_handler(websocket):
"""Handles routes of the form: /main. This route is called by the rest-server
whenever a worker needs to be pinged (to ask it to check in). The body of the
message is the worker id to ping. This function sends a message to the worker
with that worker id through an appropriate websocket.
class TimedLock:
"""A lock that gets automatically released after timeout_seconds.
"""
# Got a message from the rest server.
worker_id = await websocket.recv()
logger.warning(f"Got a message from the rest server, to ping worker: {worker_id}.")

try:
worker_ws = worker_to_ws[worker_id]
await worker_ws.send(worker_id)
except KeyError:
logger.error(f"Websocket not found for worker: {worker_id}")
def __init__(self, timeout_seconds: float = 60):
self._lock = threading.Lock()
self._time_since_locked: float
self._timeout: float = timeout_seconds

def acquire(self, blocking=True, timeout=-1):
acquired = self._lock.acquire(blocking, timeout)
if acquired:
self._time_since_locked = time.time()
return acquired

def locked(self):
return self._lock.locked()

def release(self):
self._lock.release()

def timeout(self):
return time.time() - self._time_since_locked > self._timeout

def release_if_timeout(self):
if self.locked() and self.timeout():
self.release()

async def worker_handler(websocket, worker_id):
"""Handles routes of the form: /worker/{id}. This route is called when
a worker first connects to the ws-server, creating a connection that can
be used to ask the worker to check-in later.
"""
# runs on worker connect
worker_to_ws[worker_id] = websocket
logger.warning(f"Connected to worker {worker_id}!")

# Registry of open worker connections: worker ID -> socket ID -> websocket.
worker_to_ws: Dict[str, Dict[str, Any]] = defaultdict(dict)

# Parallel registry of per-connection locks: worker ID -> socket ID -> lock.
worker_to_lock: Dict[str, Dict[str, TimedLock]] = defaultdict(dict)

# Acknowledgement payload sent back to the server once its data has been
# forwarded to a worker websocket.
ACK = b'a'

logger = logging.getLogger(__name__)

# Shared CodaLab handles; the models are used to authenticate connecting workers.
manager = CodaLabManager()
bundle_model = manager.model()
worker_model = manager.worker_model()

# Secret the server side must present as its first message; None when the
# environment variable is not set.
server_secret = os.environ.get("CODALAB_SERVER_SECRET")


def _secrets_match(received, expected) -> bool:
    """Compare the presented secret with the configured one in constant time.

    Returns False when either value is missing or when the two values are of
    different types (str vs. bytes), mirroring the plain inequality check this
    replaces.
    """
    import hmac  # local import: only needed for the constant-time comparison

    if received is None or expected is None:
        return False
    if type(received) is not type(expected):
        return False
    if isinstance(received, str):
        # compare_digest only accepts ASCII str, so compare the encoded bytes.
        received = received.encode("utf-8")
        expected = expected.encode("utf-8")
    return hmac.compare_digest(received, expected)


async def send_to_worker_handler(server_websocket, worker_id):
    """Handles routes of the form: /send_to_worker/{worker_id}. This route is called by
    the rest-server or bundle-manager when either wants to send a message/stream to the worker.

    Protocol, as seen from the caller's side of ``server_websocket``:
      1. The first message must be the shared server secret; otherwise the
         connection is closed with code 1008.
      2. If the worker has no registered websockets, or all of them are
         currently locked, the connection is closed with code 1011.
      3. Otherwise the next message is forwarded verbatim to one free worker
         websocket and ``ACK`` is sent back.

    :param server_websocket: websocket connection opened by the server side.
    :param worker_id: ID of the worker the payload should be forwarded to.
    """
    # Authenticate server.
    received_secret = await server_websocket.recv()
    if not _secrets_match(received_secret, server_secret):
        logger.warning("Server unable to authenticate.")
        await server_websocket.close(1008, "Server unable to authenticate.")
        return

    # Check if any websockets available. `.get` avoids creating an empty entry
    # in the defaultdict for workers that never connected.
    if not worker_to_ws.get(worker_id):
        logger.warning(f"No websockets currently available for worker {worker_id}")
        await server_websocket.close(
            1011, f"No websockets currently available for worker {worker_id}"
        )
        return

    # Snapshot the sockets into a list and shuffle it so load is spread across
    # connections. random.sample() no longer accepts dict views (TypeError on
    # Python 3.11+), and the snapshot is also safe against the registry being
    # mutated while this coroutine is suspended.
    candidates = list(worker_to_ws[worker_id].items())
    random.shuffle(candidates)

    # Send message from server to worker.
    for socket_id, worker_websocket in candidates:
        # Keep a reference to the lock object: the registry entry may be
        # deleted (worker disconnect) while we are awaiting below.
        lock = worker_to_lock[worker_id][socket_id]
        if not lock.acquire(blocking=False):
            continue
        try:
            data = await server_websocket.recv()
            await worker_websocket.send(data)
            await server_websocket.send(ACK)
        finally:
            # Always free the socket, even if forwarding failed. The timeout
            # failsafe may already have released the lock, and releasing an
            # unlocked lock raises RuntimeError, hence the guard.
            if lock.locked():
                lock.release()
        return

    logger.warning(f"All websockets for worker {worker_id} are currently busy.")
    await server_websocket.close(1011, f"All websockets for worker {worker_id} are currently busy.")


async def worker_connection_handler(websocket: Any, worker_id: str, socket_id: str) -> None:
"""Handles routes of the form: /worker_connect/{worker_id}/{socket_id}.
This route is called when a worker first connects to the ws-server, creating
a connection that can be used to ask the worker to check-in later.
"""
# Authenticate worker.
access_token = await websocket.recv()
user_id = worker_model.get_user_id_for_worker(worker_id=worker_id)
authenticated = bundle_model.access_token_exists_for_user(
'codalab_worker_client', user_id, access_token # TODO: Avoid hard-coding this if possible.
)
logger.error(f"AUTHENTICATED: {authenticated}")
if not authenticated:
logger.warning(f"Thread {socket_id} for worker {worker_id} unable to authenticate.")
await websocket.close(
1008, f"Thread {socket_id} for worker {worker_id} unable to authenticate."
)
return

# Establish a connection with worker and keep it alive.
worker_to_ws[worker_id][socket_id] = websocket
worker_to_lock[worker_id][socket_id] = TimedLock()
logger.warning(f"Worker {worker_id} connected; has {len(worker_to_ws[worker_id])} connections")
while True:
try:
await asyncio.wait_for(websocket.recv(), timeout=60)
worker_to_lock[worker_id][
socket_id
].release_if_timeout() # Failsafe in case not released
except asyncio.futures.TimeoutError:
pass
except websockets.exceptions.ConnectionClosed:
logger.error(f"Socket connection closed with worker {worker_id}.")
logger.warning(f"Socket connection closed with worker {worker_id}.")
break


ROUTES = (
(r'^.*/main$', rest_server_handler),
(r'^.*/worker/(.+)$', worker_handler),
)
del worker_to_ws[worker_id][socket_id]
del worker_to_lock[worker_id][socket_id]
logger.warning(f"Worker {worker_id} now has {len(worker_to_ws[worker_id])} connections")


async def ws_handler(websocket, *args):
"""Handler for websocket connections. Routes websockets to the appropriate
route handler defined in ROUTES."""
logger.warning(f"websocket handler, path: {websocket.path}.")
ROUTES = (
(r'^.*/send_to_worker/(.+)$', send_to_worker_handler),
(r'^.*/worker_connect/(.+)/(.+)$', worker_connection_handler),
)
logger.info(f"websocket handler, path: {websocket.path}.")
for (pattern, handler) in ROUTES:
match = re.match(pattern, websocket.path)
if match:
return await handler(websocket, *match.groups())
assert False
return await websocket.close(1011, f"Path {websocket.path} is not valid.")


async def async_main():
"""Main function that runs the websocket server."""
parser = argparse.ArgumentParser()
parser.add_argument('--port', help='Port to run the server on.', type=int, required=True)
parser.add_argument(
'--port', help='Port to run the server on.', type=int, required=False, default=2901
)
args = parser.parse_args()
logging.debug(f"Running ws-server on 0.0.0.0:{args.port}")
async with websockets.serve(ws_handler, "0.0.0.0", args.port):
Expand Down
9 changes: 8 additions & 1 deletion codalab/lib/codalab_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,11 @@ def ws_server(self):
ws_port = self.config['ws-server']['ws_port']
return f"ws://ws-server:{ws_port}"

@property  # type: ignore
@cached
def server_secret(self):
    """Value of the CODALAB_SERVER_SECRET environment variable (None if unset)."""
    return os.environ.get("CODALAB_SERVER_SECRET")

@property # type: ignore
@cached
def worker_socket_dir(self):
Expand Down Expand Up @@ -380,7 +385,9 @@ def model(self):

@cached
def worker_model(self):
return WorkerModel(self.model().engine, self.worker_socket_dir, self.ws_server)
return WorkerModel(
self.model().engine, self.worker_socket_dir, self.ws_server, self.server_secret
)

@cached
def upload_manager(self):
Expand Down
14 changes: 5 additions & 9 deletions codalab/lib/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _get_target_info_within_bundle(self, target, depth):
read_args = {'type': 'get_target_info', 'depth': depth}
self._send_read_message(worker, response_socket_id, target, read_args)
with closing(self._worker_model.start_listening(response_socket_id)) as sock:
result = self._worker_model.get_json_message(sock, 60)
result = self._worker_model.recv_json_message_with_unix_socket(sock, 60)
if result is None: # dead workers are a fact of life now
logging.info('Unable to reach worker, bundle state {}'.format(bundle_state))
raise NotFoundError(
Expand Down Expand Up @@ -365,9 +365,7 @@ def _send_read_message(self, worker, response_socket_id, target, read_args):
'path': target.subpath,
'read_args': read_args,
}
if not self._worker_model.send_json_message(
worker['socket_id'], worker['worker_id'], message, 60
): # dead workers are a fact of life now
if not self._worker_model.send_json_message(message, worker['worker_id']):
logging.info('Unable to reach worker')

def _send_netcat_message(self, worker, response_socket_id, uuid, port, message):
Expand All @@ -378,21 +376,19 @@ def _send_netcat_message(self, worker, response_socket_id, uuid, port, message):
'port': port,
'message': message,
}
if not self._worker_model.send_json_message(
worker['socket_id'], worker['worker_id'], message, 60
): # dead workers are a fact of life now
if not self._worker_model.send_json_message(message, worker['worker_id']):
logging.info('Unable to reach worker')

def _get_read_response_stream(self, response_socket_id):
with closing(self._worker_model.start_listening(response_socket_id)) as sock:
header_message = self._worker_model.get_json_message(sock, 60)
header_message = self._worker_model.recv_json_message_with_unix_socket(sock, 60)
precondition(header_message is not None, 'Unable to reach worker')
if 'error_code' in header_message:
raise http_error_to_exception(
header_message['error_code'], header_message['error_message']
)

fileobj = self._worker_model.get_stream(sock, 60)
fileobj = self._worker_model.recv_stream(sock, 60)
precondition(fileobj is not None, 'Unable to reach worker')
return fileobj

Expand Down
21 changes: 20 additions & 1 deletion codalab/model/bundle_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2881,7 +2881,7 @@ def get_oauth2_token(self, access_token=None, refresh_token=None):

return OAuth2Token(self, **row)

def find_oauth2_token(self, client_id, user_id, expires_after):
def find_oauth2_token(self, client_id, user_id, expires_after=datetime.datetime.utcnow()):
with self.engine.begin() as connection:
row = connection.execute(
select([oauth2_token])
Expand All @@ -2900,6 +2900,25 @@ def find_oauth2_token(self, client_id, user_id, expires_after):

return OAuth2Token(self, **row)

def access_token_exists_for_user(self, client_id: str, user_id: str, access_token: str) -> bool:
    """Tell whether an unexpired OAuth2 token row matches the given triple.

    A row matches when its client_id, user_id and access_token all equal the
    arguments and its `expires` value is later than the current UTC time.

    :param client_id: OAuth2 client ID the token must belong to.
    :param user_id: ID of the user the token must belong to.
    :param access_token: token string to look up.
    :return: True if at least one matching row exists, False otherwise.
    """
    with self.engine.begin() as connection:
        still_valid = oauth2_token.c.expires > datetime.datetime.utcnow()
        criteria = and_(
            oauth2_token.c.client_id == client_id,
            oauth2_token.c.user_id == user_id,
            oauth2_token.c.access_token == access_token,
            still_valid,
        )
        # Only existence matters, so one row is enough.
        query = select([oauth2_token]).where(criteria).limit(1)
        found = connection.execute(query).fetchone()

    if found is None:
        return False
    return True

def save_oauth2_token(self, token):
with self.engine.begin() as connection:
result = connection.execute(oauth2_token.insert().values(token.columns))
Expand Down
Loading

0 comments on commit 7c460c1

Please sign in to comment.