Skip to content

Commit

Permalink
Revert "Use websockets for json communication (#4490)" (#4548)
Browse files: browse the repository at this point in the history
This reverts commit 7c460c1.
  • Loading branch information
AndrewJGaut authored Oct 10, 2023
1 parent 5e97254 commit bf1dff7
Show file tree
Hide file tree
Showing 14 changed files with 166 additions and 349 deletions.
156 changes: 37 additions & 119 deletions codalab/bin/ws_server.py
Original file line number Diff line number Diff line change
@@ -1,157 +1,75 @@
# Main entry point to the CodaLab Websocket Server.
# The Websocket Server handles communication between the REST server and workers.
# Main entry point for CodaLab cl-ws-server.
import argparse
import asyncio
from collections import defaultdict
import logging
import os
import random
import re
import time
from typing import Any, Dict
import websockets
import threading

from codalab.lib.codalab_manager import CodaLabManager


class TimedLock:
"""A lock that gets automatically released after timeout_seconds.
"""

def __init__(self, timeout_seconds: float = 60):
self._lock = threading.Lock()
self._time_since_locked: float
self._timeout: float = timeout_seconds

def acquire(self, blocking=True, timeout=-1):
acquired = self._lock.acquire(blocking, timeout)
if acquired:
self._time_since_locked = time.time()
return acquired

def locked(self):
return self._lock.locked()

def release(self):
self._lock.release()
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
logging.basicConfig(format='%(asctime)s %(message)s %(pathname)s %(lineno)d')

def timeout(self):
return time.time() - self._time_since_locked > self._timeout
worker_to_ws: Dict[str, Any] = {}

def release_if_timeout(self):
if self.locked() and self.timeout():
self.release()

async def rest_server_handler(websocket):
"""Handles routes of the form: /main. This route is called by the rest-server
whenever a worker needs to be pinged (to ask it to check in). The body of the
message is the worker id to ping. This function sends a message to the worker
with that worker id through an appropriate websocket.
"""
# Got a message from the rest server.
worker_id = await websocket.recv()
logger.warning(f"Got a message from the rest server, to ping worker: {worker_id}.")

worker_to_ws: Dict[str, Dict[str, Any]] = defaultdict(
dict
) # Maps worker ID to socket ID to websocket
worker_to_lock: Dict[str, Dict[str, TimedLock]] = defaultdict(
dict
) # Maps worker ID to socket ID to lock
ACK = b'a'
logger = logging.getLogger(__name__)
manager = CodaLabManager()
bundle_model = manager.model()
worker_model = manager.worker_model()
server_secret = os.getenv("CODALAB_SERVER_SECRET")
try:
worker_ws = worker_to_ws[worker_id]
await worker_ws.send(worker_id)
except KeyError:
logger.error(f"Websocket not found for worker: {worker_id}")


async def send_to_worker_handler(server_websocket, worker_id):
    """Relay one message from the rest-server or bundle-manager to a worker.

    Handles routes of the form: /send_to_worker/{worker_id}.

    Protocol, as seen on ``server_websocket``:
      1. The first frame received is compared against ``server_secret``;
         on mismatch the connection is closed with code 1008.
      2. If no websocket is registered for ``worker_id``, or every registered
         websocket's lock is already held, the connection is closed with
         code 1011.
      3. Otherwise the next frame received is forwarded verbatim to one
         worker websocket (chosen in random order among those whose lock can
         be acquired without blocking) and ``ACK`` is sent back.

    :param server_websocket: websocket opened by the calling server.
    :param worker_id: ID of the worker the message is addressed to.
    """
    # Authenticate server.
    received_secret = await server_websocket.recv()
    if received_secret != server_secret:
        logger.warning("Server unable to authenticate.")
        await server_websocket.close(1008, "Server unable to authenticate.")
        return

    # Check if any websockets available. `.get` is used so that looking up an
    # unknown worker does not create an empty entry in the defaultdict.
    sockets = worker_to_ws.get(worker_id)
    if not sockets:
        logger.warning(f"No websockets currently available for worker {worker_id}")
        await server_websocket.close(
            1011, f"No websockets currently available for worker {worker_id}"
        )
        return

    # Send message from server to worker.
    # random.sample() requires a sequence: passing a dict view is deprecated
    # since Python 3.9 and raises TypeError from 3.11, so snapshot to a list
    # and shuffle that instead.
    candidates = list(sockets.items())
    random.shuffle(candidates)
    for socket_id, worker_websocket in candidates:
        lock = worker_to_lock[worker_id][socket_id]
        if not lock.acquire(blocking=False):
            continue
        try:
            data = await server_websocket.recv()
            await worker_websocket.send(data)
            await server_websocket.send(ACK)
        finally:
            # Always release, even if either peer disconnects mid-transfer;
            # otherwise this socket stays marked busy until the timeout
            # failsafe fires.
            lock.release()
        return

    logger.warning(f"All websockets for worker {worker_id} are currently busy.")
    await server_websocket.close(1011, f"All websockets for worker {worker_id} are currently busy.")


async def worker_connection_handler(websocket: Any, worker_id: str, socket_id: str) -> None:
"""Handles routes of the form: /worker_connect/{worker_id}/{socket_id}.
This route is called when a worker first connects to the ws-server, creating
a connection that can be used to ask the worker to check-in later.
async def worker_handler(websocket, worker_id):
"""Handles routes of the form: /worker/{id}. This route is called when
a worker first connects to the ws-server, creating a connection that can
be used to ask the worker to check-in later.
"""
# Authenticate worker.
access_token = await websocket.recv()
user_id = worker_model.get_user_id_for_worker(worker_id=worker_id)
authenticated = bundle_model.access_token_exists_for_user(
'codalab_worker_client', user_id, access_token # TODO: Avoid hard-coding this if possible.
)
logger.error(f"AUTHENTICATED: {authenticated}")
if not authenticated:
logger.warning(f"Thread {socket_id} for worker {worker_id} unable to authenticate.")
await websocket.close(
1008, f"Thread {socket_id} for worker {worker_id} unable to authenticate."
)
return

# Establish a connection with worker and keep it alive.
worker_to_ws[worker_id][socket_id] = websocket
worker_to_lock[worker_id][socket_id] = TimedLock()
logger.warning(f"Worker {worker_id} connected; has {len(worker_to_ws[worker_id])} connections")
# runs on worker connect
worker_to_ws[worker_id] = websocket
logger.warning(f"Connected to worker {worker_id}!")

while True:
try:
await asyncio.wait_for(websocket.recv(), timeout=60)
worker_to_lock[worker_id][
socket_id
].release_if_timeout() # Failsafe in case not released
except asyncio.futures.TimeoutError:
pass
except websockets.exceptions.ConnectionClosed:
logger.warning(f"Socket connection closed with worker {worker_id}.")
logger.error(f"Socket connection closed with worker {worker_id}.")
break
del worker_to_ws[worker_id][socket_id]
del worker_to_lock[worker_id][socket_id]
logger.warning(f"Worker {worker_id} now has {len(worker_to_ws[worker_id])} connections")


ROUTES = (
(r'^.*/main$', rest_server_handler),
(r'^.*/worker/(.+)$', worker_handler),
)


async def ws_handler(websocket, *args):
"""Handler for websocket connections. Routes websockets to the appropriate
route handler defined in ROUTES."""
ROUTES = (
(r'^.*/send_to_worker/(.+)$', send_to_worker_handler),
(r'^.*/worker_connect/(.+)/(.+)$', worker_connection_handler),
)
logger.info(f"websocket handler, path: {websocket.path}.")
logger.warning(f"websocket handler, path: {websocket.path}.")
for (pattern, handler) in ROUTES:
match = re.match(pattern, websocket.path)
if match:
return await handler(websocket, *match.groups())
return await websocket.close(1011, f"Path {websocket.path} is not valid.")
assert False


async def async_main():
"""Main function that runs the websocket server."""
parser = argparse.ArgumentParser()
parser.add_argument(
'--port', help='Port to run the server on.', type=int, required=False, default=2901
)
parser.add_argument('--port', help='Port to run the server on.', type=int, required=True)
args = parser.parse_args()
logging.debug(f"Running ws-server on 0.0.0.0:{args.port}")
async with websockets.serve(ws_handler, "0.0.0.0", args.port):
Expand Down
9 changes: 1 addition & 8 deletions codalab/lib/codalab_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,6 @@ def ws_server(self):
ws_port = self.config['ws-server']['ws_port']
return f"ws://ws-server:{ws_port}"

@property  # type: ignore
@cached
def server_secret(self):
    """Value of the CODALAB_SERVER_SECRET environment variable, or None if unset."""
    secret = os.environ.get("CODALAB_SERVER_SECRET")
    return secret

@property # type: ignore
@cached
def worker_socket_dir(self):
Expand Down Expand Up @@ -385,9 +380,7 @@ def model(self):

@cached
def worker_model(self):
return WorkerModel(
self.model().engine, self.worker_socket_dir, self.ws_server, self.server_secret
)
return WorkerModel(self.model().engine, self.worker_socket_dir, self.ws_server)

@cached
def upload_manager(self):
Expand Down
14 changes: 9 additions & 5 deletions codalab/lib/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _get_target_info_within_bundle(self, target, depth):
read_args = {'type': 'get_target_info', 'depth': depth}
self._send_read_message(worker, response_socket_id, target, read_args)
with closing(self._worker_model.start_listening(response_socket_id)) as sock:
result = self._worker_model.recv_json_message_with_unix_socket(sock, 60)
result = self._worker_model.get_json_message(sock, 60)
if result is None: # dead workers are a fact of life now
logging.info('Unable to reach worker, bundle state {}'.format(bundle_state))
raise NotFoundError(
Expand Down Expand Up @@ -365,7 +365,9 @@ def _send_read_message(self, worker, response_socket_id, target, read_args):
'path': target.subpath,
'read_args': read_args,
}
if not self._worker_model.send_json_message(message, worker['worker_id']):
if not self._worker_model.send_json_message(
worker['socket_id'], worker['worker_id'], message, 60
): # dead workers are a fact of life now
logging.info('Unable to reach worker')

def _send_netcat_message(self, worker, response_socket_id, uuid, port, message):
Expand All @@ -376,19 +378,21 @@ def _send_netcat_message(self, worker, response_socket_id, uuid, port, message):
'port': port,
'message': message,
}
if not self._worker_model.send_json_message(message, worker['worker_id']):
if not self._worker_model.send_json_message(
worker['socket_id'], worker['worker_id'], message, 60
): # dead workers are a fact of life now
logging.info('Unable to reach worker')

def _get_read_response_stream(self, response_socket_id):
with closing(self._worker_model.start_listening(response_socket_id)) as sock:
header_message = self._worker_model.recv_json_message_with_unix_socket(sock, 60)
header_message = self._worker_model.get_json_message(sock, 60)
precondition(header_message is not None, 'Unable to reach worker')
if 'error_code' in header_message:
raise http_error_to_exception(
header_message['error_code'], header_message['error_message']
)

fileobj = self._worker_model.recv_stream(sock, 60)
fileobj = self._worker_model.get_stream(sock, 60)
precondition(fileobj is not None, 'Unable to reach worker')
return fileobj

Expand Down
21 changes: 1 addition & 20 deletions codalab/model/bundle_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2881,7 +2881,7 @@ def get_oauth2_token(self, access_token=None, refresh_token=None):

return OAuth2Token(self, **row)

def find_oauth2_token(self, client_id, user_id, expires_after=datetime.datetime.utcnow()):
def find_oauth2_token(self, client_id, user_id, expires_after):
with self.engine.begin() as connection:
row = connection.execute(
select([oauth2_token])
Expand All @@ -2900,25 +2900,6 @@ def find_oauth2_token(self, client_id, user_id, expires_after=datetime.datetime.

return OAuth2Token(self, **row)

def access_token_exists_for_user(self, client_id: str, user_id: str, access_token: str) -> bool:
    """Tell whether an unexpired OAuth2 token row matches the given credentials.

    A row matches when its client_id, user_id and access_token columns equal
    the arguments and its `expires` column is later than the current UTC time.

    :return: True if at least one such row exists, False otherwise.
    """
    with self.engine.begin() as connection:
        is_live_match = and_(
            oauth2_token.c.client_id == client_id,
            oauth2_token.c.user_id == user_id,
            oauth2_token.c.access_token == access_token,
            oauth2_token.c.expires > datetime.datetime.utcnow(),
        )
        # One row is enough to answer the existence question.
        query = select([oauth2_token]).where(is_live_match).limit(1)
        found = connection.execute(query).fetchone()

    if found is None:
        return False
    return True

def save_oauth2_token(self, token):
with self.engine.begin() as connection:
result = connection.execute(oauth2_token.insert().values(token.columns))
Expand Down
Loading

0 comments on commit bf1dff7

Please sign in to comment.