diff --git a/codalab/bin/ws_server.py b/codalab/bin/ws_server.py index b77bad10c..583befc9f 100644 --- a/codalab/bin/ws_server.py +++ b/codalab/bin/ws_server.py @@ -1,75 +1,157 @@ -# Main entry point for CodaLab cl-ws-server. +# Main entry point to the CodaLab Websocket Server. +# The Websocket Server handles communication between the REST server and workers. import argparse import asyncio +from collections import defaultdict import logging +import os +import random import re +import time from typing import Any, Dict import websockets +import threading -logger = logging.getLogger(__name__) -logger.setLevel(logging.WARNING) -logging.basicConfig(format='%(asctime)s %(message)s %(pathname)s %(lineno)d') - -worker_to_ws: Dict[str, Any] = {} +from codalab.lib.codalab_manager import CodaLabManager -async def rest_server_handler(websocket): - """Handles routes of the form: /main. This route is called by the rest-server - whenever a worker needs to be pinged (to ask it to check in). The body of the - message is the worker id to ping. This function sends a message to the worker - with that worker id through an appropriate websocket. +class TimedLock: + """A lock that gets automatically released after timeout_seconds. """ - # Got a message from the rest server. - worker_id = await websocket.recv() - logger.warning(f"Got a message from the rest server, to ping worker: {worker_id}.") - try: - worker_ws = worker_to_ws[worker_id] - await worker_ws.send(worker_id) - except KeyError: - logger.error(f"Websocket not found for worker: {worker_id}") + def __init__(self, timeout_seconds: float = 60): + self._lock = threading.Lock() + self._time_since_locked: float + self._timeout: float = timeout_seconds + def acquire(self, blocking=True, timeout=-1): + acquired = self._lock.acquire(blocking, timeout) + if acquired: + self._time_since_locked = time.time() + return acquired + + def locked(self): + return self._lock.locked() + + def release(self): + self._lock.release() + + def timeout(self): + return time.time() - self._time_since_locked > self._timeout + + def release_if_timeout(self): + if self.locked() and self.timeout(): + self.release() -async def worker_handler(websocket, worker_id): - """Handles routes of the form: /worker/{id}. This route is called when - a worker first connects to the ws-server, creating a connection that can - be used to ask the worker to check-in later. - """ - # runs on worker connect - worker_to_ws[worker_id] = websocket - logger.warning(f"Connected to worker {worker_id}!") +worker_to_ws: Dict[str, Dict[str, Any]] = defaultdict( + dict +) # Maps worker ID to socket ID to websocket +worker_to_lock: Dict[str, Dict[str, TimedLock]] = defaultdict( + dict +) # Maps worker ID to socket ID to lock +ACK = b'a' +logger = logging.getLogger(__name__) +manager = CodaLabManager() +bundle_model = manager.model() +worker_model = manager.worker_model() +server_secret = os.getenv("CODALAB_SERVER_SECRET") + + +async def send_to_worker_handler(server_websocket, worker_id): + """Handles routes of the form: /send_to_worker/{worker_id}. This route is called by + the rest-server or bundle-manager when either wants to send a message/stream to the worker. + """ + # Authenticate server. + received_secret = await server_websocket.recv() + if received_secret != server_secret: + logger.warning("Server unable to authenticate.") + await server_websocket.close(1008, "Server unable to authenticate.") + return + + # Check if any websockets available + if worker_id not in worker_to_ws or len(worker_to_ws[worker_id]) == 0: + logger.warning(f"No websockets currently available for worker {worker_id}") + await server_websocket.close( + 1011, f"No websockets currently available for worker {worker_id}" + ) + return + + # Send message from server to worker. + for socket_id, worker_websocket in random.sample( + worker_to_ws[worker_id].items(), len(worker_to_ws[worker_id]) + ): + if worker_to_lock[worker_id][socket_id].acquire(blocking=False): + data = await server_websocket.recv() + await worker_websocket.send(data) + await server_websocket.send(ACK) + worker_to_lock[worker_id][socket_id].release() + return + + logger.warning(f"All websockets for worker {worker_id} are currently busy.") + await server_websocket.close(1011, f"All websockets for worker {worker_id} are currently busy.") + + +async def worker_connection_handler(websocket: Any, worker_id: str, socket_id: str) -> None: + """Handles routes of the form: /worker_connect/{worker_id}/{socket_id}. + This route is called when a worker first connects to the ws-server, creating + a connection that can be used to ask the worker to check-in later. + """ + # Authenticate worker. + access_token = await websocket.recv() + user_id = worker_model.get_user_id_for_worker(worker_id=worker_id) + authenticated = bundle_model.access_token_exists_for_user( + 'codalab_worker_client', user_id, access_token # TODO: Avoid hard-coding this if possible. + ) + logger.error(f"AUTHENTICATED: {authenticated}") + if not authenticated: + logger.warning(f"Thread {socket_id} for worker {worker_id} unable to authenticate.") + await websocket.close( + 1008, f"Thread {socket_id} for worker {worker_id} unable to authenticate." + ) + return + + # Establish a connection with worker and keep it alive. + worker_to_ws[worker_id][socket_id] = websocket + worker_to_lock[worker_id][socket_id] = TimedLock() + logger.warning(f"Worker {worker_id} connected; has {len(worker_to_ws[worker_id])} connections") while True: try: await asyncio.wait_for(websocket.recv(), timeout=60) + worker_to_lock[worker_id][ + socket_id + ].release_if_timeout() # Failsafe in case not released except asyncio.futures.TimeoutError: pass except websockets.exceptions.ConnectionClosed: - logger.error(f"Socket connection closed with worker {worker_id}.") + logger.warning(f"Socket connection closed with worker {worker_id}.") break - - -ROUTES = ( - (r'^.*/main$', rest_server_handler), - (r'^.*/worker/(.+)$', worker_handler), -) + del worker_to_ws[worker_id][socket_id] + del worker_to_lock[worker_id][socket_id] + logger.warning(f"Worker {worker_id} now has {len(worker_to_ws[worker_id])} connections") async def ws_handler(websocket, *args): """Handler for websocket connections. Routes websockets to the appropriate route handler defined in ROUTES.""" - logger.warning(f"websocket handler, path: {websocket.path}.") + ROUTES = ( + (r'^.*/send_to_worker/(.+)$', send_to_worker_handler), + (r'^.*/worker_connect/(.+)/(.+)$', worker_connection_handler), + ) + logger.info(f"websocket handler, path: {websocket.path}.") for (pattern, handler) in ROUTES: match = re.match(pattern, websocket.path) if match: return await handler(websocket, *match.groups()) - assert False + return await websocket.close(1011, f"Path {websocket.path} is not valid.") async def async_main(): """Main function that runs the websocket server.""" parser = argparse.ArgumentParser() - parser.add_argument('--port', help='Port to run the server on.', type=int, required=True) + parser.add_argument( + '--port', help='Port to run the server on.', type=int, required=False, default=2901 + ) args = parser.parse_args() logging.debug(f"Running ws-server on 0.0.0.0:{args.port}") async with websockets.serve(ws_handler, "0.0.0.0", args.port): diff --git a/codalab/lib/codalab_manager.py b/codalab/lib/codalab_manager.py index 71fcd088c..ab46d5ebd 100644 --- a/codalab/lib/codalab_manager.py +++ b/codalab/lib/codalab_manager.py @@ -248,6 +248,11 @@ def ws_server(self): ws_port = self.config['ws-server']['ws_port'] return f"ws://ws-server:{ws_port}" + @property # type: ignore + @cached + def server_secret(self): + return os.getenv("CODALAB_SERVER_SECRET") + @property # type: ignore @cached def worker_socket_dir(self): @@ -380,7 +385,9 @@ def model(self): @cached def worker_model(self): - return WorkerModel(self.model().engine, self.worker_socket_dir, self.ws_server) + return WorkerModel( + self.model().engine, self.worker_socket_dir, self.ws_server, self.server_secret + ) @cached def upload_manager(self): diff --git a/codalab/lib/download_manager.py b/codalab/lib/download_manager.py index fd77cb9a1..09d617c55 100644 --- a/codalab/lib/download_manager.py +++ b/codalab/lib/download_manager.py @@ -137,7 +137,7 @@ def _get_target_info_within_bundle(self, target, depth): read_args = {'type': 'get_target_info', 'depth': depth} self._send_read_message(worker, response_socket_id, target, read_args) with closing(self._worker_model.start_listening(response_socket_id)) as sock: - result = self._worker_model.get_json_message(sock, 60) + result = self._worker_model.recv_json_message_with_unix_socket(sock, 60) if result is None: # dead workers are a fact of life now logging.info('Unable to reach worker, bundle state {}'.format(bundle_state)) raise NotFoundError( @@ -365,9 +365,7 @@ def _send_read_message(self, worker, response_socket_id, target, read_args): 'path': target.subpath, 'read_args': read_args, } - if not self._worker_model.send_json_message( - worker['socket_id'], worker['worker_id'], message, 60 - ): # dead workers are a fact of life now + if not self._worker_model.send_json_message(message, worker['worker_id']): logging.info('Unable to reach worker') def _send_netcat_message(self, worker, response_socket_id, uuid, port, message): @@ -378,21 +376,19 @@ def _send_netcat_message(self, worker, response_socket_id, uuid, port, message): 'port': port, 'message': message, } - if not self._worker_model.send_json_message( - worker['socket_id'], worker['worker_id'], message, 60 - ): # dead workers are a fact of life now + if not self._worker_model.send_json_message(message, worker['worker_id']): logging.info('Unable to reach worker') def _get_read_response_stream(self, response_socket_id): with closing(self._worker_model.start_listening(response_socket_id)) as sock: - header_message = self._worker_model.get_json_message(sock, 60) + header_message = self._worker_model.recv_json_message_with_unix_socket(sock, 60) precondition(header_message is not None, 'Unable to reach worker') if 'error_code' in header_message: raise http_error_to_exception( header_message['error_code'], header_message['error_message'] ) - fileobj = self._worker_model.get_stream(sock, 60) + fileobj = self._worker_model.recv_stream(sock, 60) precondition(fileobj is not None, 'Unable to reach worker') return fileobj diff --git a/codalab/model/bundle_model.py b/codalab/model/bundle_model.py index 73b665085..4224ac1d0 100644 --- a/codalab/model/bundle_model.py +++ b/codalab/model/bundle_model.py @@ -2881,7 +2881,7 @@ def get_oauth2_token(self, access_token=None, refresh_token=None): return OAuth2Token(self, **row) - def find_oauth2_token(self, client_id, user_id, expires_after): + def find_oauth2_token(self, client_id, user_id, expires_after=datetime.datetime.utcnow()): with self.engine.begin() as connection: row = connection.execute( select([oauth2_token]) @@ -2900,6 +2900,25 @@ def find_oauth2_token(self, client_id, user_id, expires_after): return OAuth2Token(self, **row) + def access_token_exists_for_user(self, client_id: str, user_id: str, access_token: str) -> bool: + """Check that the provided access_token exists in the database for the provided user_id. + """ + with self.engine.begin() as connection: + row = connection.execute( + select([oauth2_token]) + .where( + and_( + oauth2_token.c.client_id == client_id, + oauth2_token.c.user_id == user_id, + oauth2_token.c.access_token == access_token, + oauth2_token.c.expires > datetime.datetime.utcnow(), + ) + ) + .limit(1) + ).fetchone() + + return row is not None + def save_oauth2_token(self, token): with self.engine.begin() as connection: result = connection.execute(oauth2_token.insert().values(token.columns)) diff --git a/codalab/model/worker_model.py b/codalab/model/worker_model.py index 6fe1ecbc4..746e66928 100644 --- a/codalab/model/worker_model.py +++ b/codalab/model/worker_model.py @@ -1,4 +1,3 @@ -import asyncio from contextlib import closing import datetime import json @@ -6,7 +5,8 @@ import os import socket import time -import websockets +from websockets.sync.client import connect +import traceback from sqlalchemy import and_, select @@ -20,6 +20,8 @@ ) logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logging.basicConfig(format='%(asctime)s %(message)s %(pathname)s %(lineno)d') class WorkerModel(object): @@ -35,10 +37,13 @@ class WorkerModel(object): listen on these sockets for messages and send messages to these sockets. """ - def __init__(self, engine, socket_dir, ws_server): + ACK = b'a' + + def __init__(self, engine, socket_dir, ws_server, server_secret): self._engine = engine self._socket_dir = socket_dir self._ws_server = ws_server + self._server_secret = server_secret def worker_checkin( self, @@ -121,7 +126,6 @@ def worker_checkin( user_id=user_id, worker_id=worker_id, dependencies=blob ) ) - return socket_id @staticmethod @@ -285,7 +289,7 @@ def start_listening(self, socket_id): get_ methods below. as in: with closing(worker_model.start_listening(socket_id)) as sock: - message = worker_model.get_json_message(sock, timeout_secs) + message = worker_model.get_json_message_with_unix_socket(sock, timeout_secs) """ self._cleanup_socket(socket_id) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) @@ -293,9 +297,7 @@ def start_listening(self, socket_id): sock.listen(0) return sock - ACK = b'a' - - def get_stream(self, sock, timeout_secs): + def recv_stream(self, sock, timeout_secs): """ Receives a single message on the given socket and returns a file-like object that can be used for streaming the message data. @@ -316,14 +318,14 @@ def get_stream(self, sock, timeout_secs): except socket.timeout: return None - def get_json_message(self, sock, timeout_secs): + def recv_json_message_with_unix_socket(self, sock, timeout_secs): """ Receives a single message on the given socket and returns the message data parsed as JSON. If no messages are received within timeout_secs seconds, returns None. """ - fileobj = self.get_stream(sock, timeout_secs) + fileobj = self.recv_stream(sock, timeout_secs) if fileobj is None: return None @@ -363,18 +365,9 @@ def send_stream(self, socket_id, fileobj, timeout_secs): return False - def _ping_worker_ws(self, worker_id): - async def ping_ws(): - async with websockets.connect(f"{self._ws_server}/main") as websocket: - await websocket.send(worker_id) - - futures = [ping_ws()] - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(asyncio.wait(futures)) - logging.warn(f"Pinged worker through websockets, worker id: {worker_id}") - - def send_json_message(self, socket_id, worker_id, message, timeout_secs, autoretry=True): + def send_json_message_with_unix_socket( + self, socket_id, worker_id, message, timeout_secs, autoretry=True + ): """ Sends a JSON message to the given socket, retrying until it is received correctly. @@ -385,7 +378,6 @@ def send_json_message(self, socket_id, worker_id, message, timeout_secs, autoret Note, only the worker should call this method with autoretry set to False. See comments below. """ - self._ping_worker_ws(worker_id) start_time = time.time() while time.time() - start_time < timeout_secs: with closing(socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)) as sock: @@ -409,6 +401,7 @@ def send_json_message(self, socket_id, worker_id, message, timeout_secs, autoret success = True except socket.error as e: logging.error(f"socket error when calling send_json_message: {e}") + logging.error(traceback.print_exc()) if not success: # Shouldn't be too expensive just to keep retrying. @@ -451,3 +444,44 @@ def has_reply_permission(self, user_id, worker_id, socket_id): if row: return True return False + + def send_json_message( + self, data: dict, worker_id: str, timeout_secs: int = 5, initial_sleep: float = 0.1 + ): + """ + Send JSON message to the worker. + + :param worker_id: The ID of the worker to send data to + :param data: Data to send to worker. Could be a file or a json message. + :param timeout_secs: Seconds until send fails due to timeout. + :param initial_sleep: Time to sleep before retrying send after a failure. + Note: upon successive failures, exponential backoff is applied. + + :return True if data was sent properly, False otherwise. + """ + start_time = time.time() + sleep_time = initial_sleep + while time.time() - start_time < timeout_secs: + try: + with connect(f"{self._ws_server}/send_to_worker/{worker_id}") as websocket: + websocket.send(self._server_secret) # Authenticate + websocket.send(json.dumps(data).encode()) + ack = websocket.recv() + return ack == self.ACK + except Exception as e: + logger.error(f"Send to worker {worker_id} failed with {e}. Retrying...") + time.sleep(sleep_time) + sleep_time *= 2 # Exponential backoff + return False + + def get_user_id_for_worker(self, worker_id): + """Return the user_id corresponding to the worker with ID worker_id + """ + with self._engine.begin() as conn: + row = conn.execute( + select([cl_worker]).where(cl_worker.c.worker_id == worker_id) + ).fetchone() + + if row is None: + return None + return row.user_id diff --git a/codalab/rest/bundle_actions.py b/codalab/rest/bundle_actions.py index dfbfc4d82..f2eaa941c 100644 --- a/codalab/rest/bundle_actions.py +++ b/codalab/rest/bundle_actions.py @@ -32,9 +32,7 @@ def create_bundle_actions(): # The state updates of bundles in PREPARING, RUNNING, or FINALIZING state will be handled on the worker side. if worker: precondition( - local.worker_model.send_json_message( - worker['socket_id'], worker['worker_id'], action, 60 - ), + local.worker_model.send_json_message(action, worker['worker_id'], 60), 'Unable to reach worker.', ) local.model.update_bundle(bundle, {'metadata': {'actions': new_actions}}) diff --git a/codalab/rest/workers.py b/codalab/rest/workers.py index c0a9ba6fc..9dcd5cf9f 100644 --- a/codalab/rest/workers.py +++ b/codalab/rest/workers.py @@ -1,7 +1,6 @@ from __future__ import ( absolute_import, ) # Without this line "from worker.worker import VERSION" doesn't work. -from contextlib import closing import http.client import json from datetime import datetime @@ -25,10 +24,9 @@ def checkin(worker_id): Waits for a message for the worker for WAIT_TIME_SECS seconds. Returns the message or None if there isn't one. """ - WAIT_TIME_SECS = 5.0 # Old workers might not have all the fields, so allow subsets to be missing. - socket_id = local.worker_model.worker_checkin( + local.worker_model.worker_checkin( request.user.user_id, worker_id, request.json.get("tag"), @@ -45,7 +43,6 @@ def checkin(worker_id): request.json.get("preemptible", False), ) - messages = [] for run in request.json["runs"]: try: worker_run = BundleCheckinState.from_dict(run) @@ -60,22 +57,21 @@ def checkin(worker_id): 'Kill requested: User time quota exceeded. To apply for more quota, please visit the following link: ' 'https://codalab-worksheets.readthedocs.io/en/latest/FAQ/#how-do-i-request-more-disk-quota-or-time-quota' ) - messages.append({'type': 'kill', 'uuid': bundle.uuid, 'kill_message': kill_message}) + local.worker_model.send_json_message( + {'type': 'kill', 'uuid': bundle.uuid, 'kill_message': kill_message}, worker_id + ) elif local.model.get_user_disk_quota_left(bundle.owner_id) <= 0: # Then, user has gone over their disk quota and we kill the job. kill_message = ( 'Kill requested: User disk quota exceeded. To apply for more quota, please visit the following link: ' 'https://codalab-worksheets.readthedocs.io/en/latest/FAQ/#how-do-i-request-more-disk-quota-or-time-quota' ) - messages.append({'type': 'kill', 'uuid': bundle.uuid, 'kill_message': kill_message}) + local.worker_model.send_json_message( + {'type': 'kill', 'uuid': bundle.uuid, 'kill_message': kill_message}, worker_id + ) except Exception as e: logger.info("Exception in REST checkin: {}".format(e)) - with closing(local.worker_model.start_listening(socket_id)) as sock: - messages.append(local.worker_model.get_json_message(sock, WAIT_TIME_SECS)) - response.content_type = 'application/json' - return json.dumps(messages) - def check_reply_permission(worker_id, socket_id): """ @@ -96,7 +92,9 @@ def reply(worker_id, socket_id): Replies with a single JSON message to the given socket ID. """ check_reply_permission(worker_id, socket_id) - local.worker_model.send_json_message(socket_id, worker_id, request.json, 60, autoretry=False) + local.worker_model.send_json_message_with_unix_socket( + socket_id, worker_id, request.json, 60, autoretry=False + ) @post( @@ -124,7 +122,9 @@ def reply_data(worker_id, socket_id): abort(http.client.BAD_REQUEST, "Header message should be in JSON format.") check_reply_permission(worker_id, socket_id) - local.worker_model.send_json_message(socket_id, worker_id, header_message, 60, autoretry=False) + local.worker_model.send_json_message_with_unix_socket( + socket_id, worker_id, header_message, 60, autoretry=False + ) local.worker_model.send_stream(socket_id, request["wsgi.input"], 60) diff --git a/codalab/server/bundle_manager.py b/codalab/server/bundle_manager.py index 642530d9b..e341762e2 100644 --- a/codalab/server/bundle_manager.py +++ b/codalab/server/bundle_manager.py @@ -385,10 +385,7 @@ def _acknowledge_recently_finished_bundles(self, workers): ) self._model.transition_bundle_worker_offline(bundle) elif self._worker_model.send_json_message( - worker['socket_id'], - worker['worker_id'], - {'type': 'mark_finalized', 'uuid': bundle.uuid}, - 1, + {'type': 'mark_finalized', 'uuid': bundle.uuid}, worker['worker_id'] ): logger.info( 'Acknowledged finalization of run bundle {} on worker {}'.format( @@ -398,6 +395,8 @@ def _acknowledge_recently_finished_bundles(self, workers): bundle_location = self._bundle_store.get_bundle_location(bundle.uuid) # TODO(Ashwin): fix this -- bundle location could be linked. self._model.transition_bundle_finished(bundle, bundle_location) + else: + logger.info(f"Bundle {bundle.uuid} could not be finalized.") def _bring_offline_stuck_running_bundles(self, workers): """ @@ -742,16 +741,17 @@ def _try_start_bundle(self, workers, worker, bundle, bundle_resources): remove_path(path) os.mkdir(path) if self._worker_model.send_json_message( - worker['socket_id'], - worker['worker_id'], self._construct_run_message(worker['shared_file_system'], bundle, bundle_resources), - 1, + worker['worker_id'], ): logger.info( 'Starting run bundle {} on worker {}'.format(bundle.uuid, worker['worker_id']) ) return True else: + logger.info( + f"Bundle {bundle.uuid} could not be started on worker {worker['worker_id']}" + ) self._model.transition_bundle_staged(bundle) workers.restage(bundle.uuid) return False diff --git a/codalab/worker/main.py b/codalab/worker/main.py index 35b1f76f8..1f6614ead 100644 --- a/codalab/worker/main.py +++ b/codalab/worker/main.py @@ -116,6 +116,12 @@ def parse_args(): default=None, help='Limit the amount of memory to a worker in bytes' '(e.g. 3, 3k, 3m, 3g, 3t).', ) + parser.add_argument( + '--num-coroutines', + help='Number of worker threads to have running concurrently waiting for socket messages. Must be a natural number.', + type=int, + default=10, + ) parser.add_argument( '--password-file', help='Path to the file containing the username and ' @@ -391,6 +397,7 @@ def main(): exit_on_exception=args.exit_on_exception, shared_memory_size_gb=args.shared_memory_size_gb, preemptible=args.preemptible, + num_coroutines=args.num_coroutines, bundle_runtime=bundle_runtime_class, ) diff --git a/codalab/worker/worker.py b/codalab/worker/worker.py index ca5712fb3..f1662f2ad 100644 --- a/codalab/worker/worker.py +++ b/codalab/worker/worker.py @@ -13,6 +13,7 @@ from typing import Optional, Set, Dict from types import SimpleNamespace import websockets +import json import psutil @@ -90,6 +91,8 @@ def __init__( exit_on_exception=False, # type: bool shared_memory_size_gb=1, # type: int preemptible=False, # type: bool + num_coroutines=10, # type: int + # Number of threads to have running concurrently waiting for socket messages. MUST be a natural number. ): self.image_manager = image_manager self.dependency_manager = dependency_manager @@ -136,6 +139,8 @@ def __init__( self.ws_server = ws_server + self.num_coroutines = num_coroutines + self.runs = {} # type: Dict[str, RunState] self.docker_network_prefix = docker_network_prefix self.init_docker_networks(docker_network_prefix) @@ -292,6 +297,80 @@ def check_num_runs_stop(self): """ return self.exit_after_num_runs == self.num_runs and len(self.runs) == 0 + def process_message(self, message): + """ + Process messages from the rest server in the worker. + + :param message: (list(dict)) A list of JSON messages for the worker. + :return: None + """ + message = json.loads(message.decode()) + with self._lock: + # Stop processing any new runs received from server + if not message or self.terminate_and_restage or self.terminate: + return + action_type = message['type'] + logger.debug('Received %s message: %s', action_type, message) + if action_type == 'run': + self.initialize_run(message['bundle'], message['resources']) + else: + uuid = message['uuid'] + socket_id = message.get('socket_id', None) + if uuid not in self.runs: + if action_type in ['read', 'netcat']: + self.read_run_missing(socket_id) + return + if action_type == 'kill': + kill_message = 'Kill requested' + if 'kill_message' in message: + kill_message = message['kill_message'] + self.kill(uuid, kill_message) + elif action_type == 'mark_finalized': + self.mark_finalized(uuid) + elif action_type == 'read': + self.read(socket_id, uuid, message['path'], message['read_args']) + elif action_type == 'netcat': + self.netcat(socket_id, uuid, message['port'], message['message']) + elif action_type == 'write': + self.write(uuid, message['subpath'], message['string']) + else: + logger.warning("Unrecognized action type from server: %s", action_type) + + async def recv_messages(self, websocket): + # Authenticate with Websocket Server. + await websocket.send(self.bundle_service._get_access_token()) + + # Loop and keep waiting for messages. + while not self.terminate: + try: + await websocket.send("a") + message = await asyncio.wait_for(websocket.recv(), timeout=10) + self.process_message(message) + except asyncio.TimeoutError: + pass + except websockets.exceptions.ConnectionClosed: + logger.warning("Websocket connection closed, starting a new one...") + break + except Exception: + logger.error(traceback.print_exc()) + + async def listen(self, thread_id): + wss_uri = f"{self.ws_server}/worker_connect/{self.id}/{thread_id}" + while not self.terminate: + logger.info(f"Connecting to {wss_uri}") + try: + async with websockets.connect(f"{wss_uri}", max_queue=1) as websocket: + await self.recv_messages(websocket) + except Exception: + logger.error(f"Error connecting to ws-server: {traceback.print_exc()}") + time.sleep(3) + + def listen_thread_fn(self): + coroutines = [self.listen(i) for i in range(self.num_coroutines)] + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(asyncio.gather(*coroutines)) + def start(self): """Return whether we ran anything.""" logger.info(f"my id is: {self.id}") @@ -301,39 +380,8 @@ def start(self): if not self.shared_file_system: self.dependency_manager.start() - async def listen(self): - logger.warning("Started websocket listening thread") - while not self.terminate: - logger.warning(f"Connecting anew to: {self.ws_server}/worker/{self.id}") - async with websockets.connect( - f"{self.ws_server}/worker/{self.id}", max_queue=1 - ) as websocket: - - async def receive_msg(): - await websocket.send("a") - data = await asyncio.wait_for(websocket.recv(), timeout=10) - logger.warning( - f"Got websocket message, got data: {data}, going to check in now." - ) - self.checkin() - self.last_checkin = time.time() - - while not self.terminate: - try: - await receive_msg() - except asyncio.TimeoutError: - pass - except websockets.exceptions.ConnectionClosed: - logger.warning("Websocket connection closed, starting a new one...") - break - - def listen_thread_fn(self): - futures = [listen(self)] - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(asyncio.wait(futures)) - - self.listen_thread = threading.Thread(target=listen_thread_fn, args=[self]) + asyncio.new_event_loop() + self.listen_thread = threading.Thread(target=self.listen_thread_fn) self.listen_thread.start() while not self.terminate: try: @@ -394,8 +442,7 @@ def cleanup(self): Blocks until cleanup is complete and it is safe to quit """ logger.info("Stopping Worker") - if self.listen_thread: - self.listen_thread.join() + self.listen_thread.join() self.image_manager.stop() if not self.shared_file_system: self.dependency_manager.stop() @@ -512,7 +559,7 @@ def checkin(self): }, ) try: - response = self.bundle_service.checkin(self.id, request) + self.bundle_service.checkin(self.id, request) logger.info('Connected! Successful check in!') self.last_checkin_successful = True except BundleServiceException as ex: @@ -523,41 +570,6 @@ def checkin(self): ) time.sleep(self.CHECKIN_COOLDOWN) self.last_checkin_successful = False - response = None - # Stop processing any new runs received from server - if not response or self.terminate_and_restage or self.terminate: - return - if type(response) is not list: - response = [response] - for action in response: - if not action: - continue - action_type = action['type'] - logger.debug('Received %s message: %s', action_type, action) - if action_type == 'run': - self.initialize_run(action['bundle'], action['resources']) - else: - uuid = action['uuid'] - socket_id = action.get('socket_id', None) - if uuid not in self.runs: - if action_type in ['read', 'netcat']: - self.read_run_missing(socket_id) - return - if action_type == 'kill': - kill_message = 'Kill requested' - if 'kill_message' in action: - kill_message = action['kill_message'] - self.kill(uuid, kill_message) - elif action_type == 'mark_finalized': - self.mark_finalized(uuid) - elif action_type == 'read': - self.read(socket_id, uuid, action['path'], action['read_args']) - elif action_type == 'netcat': - self.netcat(socket_id, uuid, action['port'], action['message']) - elif action_type == 'write': - self.write(uuid, action['subpath'], action['string']) - else: - logger.warning("Unrecognized action type from server: %s", action_type) self.process_runs() def process_runs(self): diff --git a/codalab_service.py b/codalab_service.py index 451f69257..298cd6ac8 100755 --- a/codalab_service.py +++ b/codalab_service.py @@ -246,6 +246,10 @@ def has_callable_default(self): CodalabArg(name='ws_port', help='Port for websocket server', type=int, default=2901), CodalabArg(name='rest_num_processes', help='Number of processes', type=int, default=1), CodalabArg(name='server', help='URL to server (used by external worker to connect to)'), + CodalabArg( + name='server_secret', + help='Secret key used to authenticate the REST server with the Websocket server', + ), CodalabArg( name='shared_file_system', help='Whether worker has access to the bundle mount', type=bool ), diff --git a/docker_config/compose_files/docker-compose.yml b/docker_config/compose_files/docker-compose.yml index 4012a0d06..cffe9af64 100644 --- a/docker_config/compose_files/docker-compose.yml +++ b/docker_config/compose_files/docker-compose.yml @@ -28,6 +28,7 @@ x-codalab-env: &codalab-env - CODALAB_EMAIL_USERNAME=${CODALAB_EMAIL_USERNAME} - CODALAB_EMAIL_PASSWORD=${CODALAB_EMAIL_PASSWORD} - CODALAB_SERVER=${CODALAB_SERVER} + - CODALAB_SERVER_SECRET=${CODALAB_SERVER_SECRET} - CODALAB_SHARED_FILE_SYSTEM=${CODALAB_SHARED_FILE_SYSTEM} - CODALAB_LINK_MOUNTS=${CODALAB_LINK_MOUNTS} - CODALAB_BUNDLE_MANAGER_WORKER_TIMEOUT_SECONDS=${CODALAB_BUNDLE_MANAGER_WORKER_TIMEOUT_SECONDS} diff --git a/requirements.txt b/requirements.txt index 6abd213e5..735a14a72 100644 --- a/requirements.txt +++ b/requirements.txt @@ -41,7 +41,7 @@ urllib3==1.26.11 retry==0.9.2 spython==0.1.14 flufl.lock==6.0 -websockets==9.1 +websockets==11.0.3 kubernetes==12.0.1 google-cloud-storage==2.0.0 httpio==0.3.0 diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 42c21a673..ba381995c 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -43,6 +43,7 @@ import time import traceback import requests +import websockets global cl @@ -755,6 +756,28 @@ def test_auth(ctx): os.environ["CODALAB_PASSWORD"] = password check_contains("user: codalab", _run_command([cl, 'status'])) + # Set up websocket authentication tests. + worker_id = 'auth-test-worker' + create_worker(ctx, current_user()[0], worker_id) + codalab_server_secret = os.environ["CODALAB_SERVER_SECRET"] + os.environ["CODALAB_SERVER_SECRET"] = "fake-secret" + worker_model = CodaLabManager().worker_model() # The server secret will be set to "fake-secret" + ws_server_uri = worker_model._ws_server + os.environ["CODALAB_SERVER_SECRET"] = codalab_server_secret + check_equals(worker_model.send_json_message({'a': 1}, 'auth-test-worker', 1), False) + + # Test worker authentication for websocket endpoint. + exception = None + try: + with websockets.sync.client.connect(f"{ws_server_uri}/worker_connect/{worker_id}/15") as ws: + ws.send("fake-access-token") + ws.recv() + except websockets.exceptions.ConnectionClosedError as e: + exception = e + check_contains(exception.reason, f"Thread 15 for worker {worker_id} unable to authenticate.") + check_equals(1008, exception.code) + check_equals(exception.rcvd_then_sent, True) + @TestModule.register('upload1') def test_upload1(ctx):