diff --git a/codalab/bin/ws_server.py b/codalab/bin/ws_server.py index 583befc9f..b77bad10c 100644 --- a/codalab/bin/ws_server.py +++ b/codalab/bin/ws_server.py @@ -1,157 +1,75 @@ -# Main entry point to the CodaLab Websocket Server. -# The Websocket Server handles communication between the REST server and workers. +# Main entry point for CodaLab cl-ws-server. import argparse import asyncio -from collections import defaultdict import logging -import os -import random import re -import time from typing import Any, Dict import websockets -import threading -from codalab.lib.codalab_manager import CodaLabManager - - -class TimedLock: - """A lock that gets automatically released after timeout_seconds. - """ - - def __init__(self, timeout_seconds: float = 60): - self._lock = threading.Lock() - self._time_since_locked: float - self._timeout: float = timeout_seconds - - def acquire(self, blocking=True, timeout=-1): - acquired = self._lock.acquire(blocking, timeout) - if acquired: - self._time_since_locked = time.time() - return acquired - - def locked(self): - return self._lock.locked() - - def release(self): - self._lock.release() +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) +logging.basicConfig(format='%(asctime)s %(message)s %(pathname)s %(lineno)d') - def timeout(self): - return time.time() - self._time_since_locked > self._timeout +worker_to_ws: Dict[str, Any] = {} - def release_if_timeout(self): - if self.locked() and self.timeout(): - self.release() +async def rest_server_handler(websocket): + """Handles routes of the form: /main. This route is called by the rest-server + whenever a worker needs to be pinged (to ask it to check in). The body of the + message is the worker id to ping. This function sends a message to the worker + with that worker id through an appropriate websocket. + """ + # Got a message from the rest server. + worker_id = await websocket.recv() + logger.warning(f"Got a message from the rest server, to ping worker: {worker_id}.") -worker_to_ws: Dict[str, Dict[str, Any]] = defaultdict( - dict -) # Maps worker ID to socket ID to websocket -worker_to_lock: Dict[str, Dict[str, TimedLock]] = defaultdict( - dict -) # Maps worker ID to socket ID to lock -ACK = b'a' -logger = logging.getLogger(__name__) -manager = CodaLabManager() -bundle_model = manager.model() -worker_model = manager.worker_model() -server_secret = os.getenv("CODALAB_SERVER_SECRET") + try: + worker_ws = worker_to_ws[worker_id] + await worker_ws.send(worker_id) + except KeyError: + logger.error(f"Websocket not found for worker: {worker_id}") -async def send_to_worker_handler(server_websocket, worker_id): - """Handles routes of the form: /send_to_worker/{worker_id}. This route is called by - the rest-server or bundle-manager when either wants to send a message/stream to the worker. - """ - # Authenticate server. - received_secret = await server_websocket.recv() - if received_secret != server_secret: - logger.warning("Server unable to authenticate.") - await server_websocket.close(1008, "Server unable to authenticate.") - return - - # Check if any websockets available - if worker_id not in worker_to_ws or len(worker_to_ws[worker_id]) == 0: - logger.warning(f"No websockets currently available for worker {worker_id}") - await server_websocket.close( - 1011, f"No websockets currently available for worker {worker_id}" - ) - return - - # Send message from server to worker. - for socket_id, worker_websocket in random.sample( - worker_to_ws[worker_id].items(), len(worker_to_ws[worker_id]) - ): - if worker_to_lock[worker_id][socket_id].acquire(blocking=False): - data = await server_websocket.recv() - await worker_websocket.send(data) - await server_websocket.send(ACK) - worker_to_lock[worker_id][socket_id].release() - return - - logger.warning(f"All websockets for worker {worker_id} are currently busy.") - await server_websocket.close(1011, f"All websockets for worker {worker_id} are currently busy.") - - -async def worker_connection_handler(websocket: Any, worker_id: str, socket_id: str) -> None: - """Handles routes of the form: /worker_connect/{worker_id}/{socket_id}. - This route is called when a worker first connects to the ws-server, creating - a connection that can be used to ask the worker to check-in later. +async def worker_handler(websocket, worker_id): + """Handles routes of the form: /worker/{id}. This route is called when + a worker first connects to the ws-server, creating a connection that can + be used to ask the worker to check-in later. """ - # Authenticate worker. - access_token = await websocket.recv() - user_id = worker_model.get_user_id_for_worker(worker_id=worker_id) - authenticated = bundle_model.access_token_exists_for_user( - 'codalab_worker_client', user_id, access_token # TODO: Avoid hard-coding this if possible. - ) - logger.error(f"AUTHENTICATED: {authenticated}") - if not authenticated: - logger.warning(f"Thread {socket_id} for worker {worker_id} unable to authenticate.") - await websocket.close( - 1008, f"Thread {socket_id} for worker {worker_id} unable to authenticate." - ) - return - - # Establish a connection with worker and keep it alive. - worker_to_ws[worker_id][socket_id] = websocket - worker_to_lock[worker_id][socket_id] = TimedLock() - logger.warning(f"Worker {worker_id} connected; has {len(worker_to_ws[worker_id])} connections") + # runs on worker connect + worker_to_ws[worker_id] = websocket + logger.warning(f"Connected to worker {worker_id}!") + while True: try: await asyncio.wait_for(websocket.recv(), timeout=60) - worker_to_lock[worker_id][ - socket_id - ].release_if_timeout() # Failsafe in case not released except asyncio.futures.TimeoutError: pass except websockets.exceptions.ConnectionClosed: - logger.warning(f"Socket connection closed with worker {worker_id}.") + logger.error(f"Socket connection closed with worker {worker_id}.") break - del worker_to_ws[worker_id][socket_id] - del worker_to_lock[worker_id][socket_id] - logger.warning(f"Worker {worker_id} now has {len(worker_to_ws[worker_id])} connections") + + +ROUTES = ( + (r'^.*/main$', rest_server_handler), + (r'^.*/worker/(.+)$', worker_handler), +) async def ws_handler(websocket, *args): """Handler for websocket connections. Routes websockets to the appropriate route handler defined in ROUTES.""" - ROUTES = ( - (r'^.*/send_to_worker/(.+)$', send_to_worker_handler), - (r'^.*/worker_connect/(.+)/(.+)$', worker_connection_handler), - ) - logger.info(f"websocket handler, path: {websocket.path}.") + logger.warning(f"websocket handler, path: {websocket.path}.") for (pattern, handler) in ROUTES: match = re.match(pattern, websocket.path) if match: return await handler(websocket, *match.groups()) - return await websocket.close(1011, f"Path {websocket.path} is not valid.") + assert False async def async_main(): """Main function that runs the websocket server.""" parser = argparse.ArgumentParser() - parser.add_argument( - '--port', help='Port to run the server on.', type=int, required=False, default=2901 - ) + parser.add_argument('--port', help='Port to run the server on.', type=int, required=True) args = parser.parse_args() logging.debug(f"Running ws-server on 0.0.0.0:{args.port}") async with websockets.serve(ws_handler, "0.0.0.0", args.port): diff --git a/codalab/lib/codalab_manager.py b/codalab/lib/codalab_manager.py index ab46d5ebd..71fcd088c 100644 --- a/codalab/lib/codalab_manager.py +++ b/codalab/lib/codalab_manager.py @@ -248,11 +248,6 @@ def ws_server(self): ws_port = self.config['ws-server']['ws_port'] return f"ws://ws-server:{ws_port}" - @property # type: ignore - @cached - def server_secret(self): - return os.getenv("CODALAB_SERVER_SECRET") - @property # type: ignore @cached def worker_socket_dir(self): @@ -385,9 +380,7 @@ def model(self): @cached def worker_model(self): - return WorkerModel( - self.model().engine, self.worker_socket_dir, self.ws_server, self.server_secret - ) + return WorkerModel(self.model().engine, self.worker_socket_dir, self.ws_server) @cached def upload_manager(self): diff --git a/codalab/lib/download_manager.py b/codalab/lib/download_manager.py index 09d617c55..fd77cb9a1 100644 --- a/codalab/lib/download_manager.py +++ b/codalab/lib/download_manager.py @@ -137,7 +137,7 @@ def _get_target_info_within_bundle(self, target, depth): read_args = {'type': 'get_target_info', 'depth': depth} self._send_read_message(worker, response_socket_id, target, read_args) with closing(self._worker_model.start_listening(response_socket_id)) as sock: - result = self._worker_model.recv_json_message_with_unix_socket(sock, 60) + result = self._worker_model.get_json_message(sock, 60) if result is None: # dead workers are a fact of life now logging.info('Unable to reach worker, bundle state {}'.format(bundle_state)) raise NotFoundError( @@ -365,7 +365,9 @@ def _send_read_message(self, worker, response_socket_id, target, read_args): 'path': target.subpath, 'read_args': read_args, } - if not self._worker_model.send_json_message(message, worker['worker_id']): + if not self._worker_model.send_json_message( + worker['socket_id'], worker['worker_id'], message, 60 + ): # dead workers are a fact of life now logging.info('Unable to reach worker') def _send_netcat_message(self, worker, response_socket_id, uuid, port, message): @@ -376,19 +378,21 @@ def _send_netcat_message(self, worker, response_socket_id, uuid, port, message): 'port': port, 'message': message, } - if not self._worker_model.send_json_message(message, worker['worker_id']): + if not self._worker_model.send_json_message( + worker['socket_id'], worker['worker_id'], message, 60 + ): # dead workers are a fact of life now logging.info('Unable to reach worker') def _get_read_response_stream(self, response_socket_id): with closing(self._worker_model.start_listening(response_socket_id)) as sock: - header_message = self._worker_model.recv_json_message_with_unix_socket(sock, 60) + header_message = self._worker_model.get_json_message(sock, 60) precondition(header_message is not None, 'Unable to reach worker') if 'error_code' in header_message: raise http_error_to_exception( header_message['error_code'], header_message['error_message'] ) - fileobj = self._worker_model.recv_stream(sock, 60) + fileobj = self._worker_model.get_stream(sock, 60) precondition(fileobj is not None, 'Unable to reach worker') return fileobj diff --git a/codalab/model/bundle_model.py b/codalab/model/bundle_model.py index 4224ac1d0..73b665085 100644 --- a/codalab/model/bundle_model.py +++ b/codalab/model/bundle_model.py @@ -2881,7 +2881,7 @@ def get_oauth2_token(self, access_token=None, refresh_token=None): return OAuth2Token(self, **row) - def find_oauth2_token(self, client_id, user_id, expires_after=datetime.datetime.utcnow()): + def find_oauth2_token(self, client_id, user_id, expires_after): with self.engine.begin() as connection: row = connection.execute( select([oauth2_token]) @@ -2900,25 +2900,6 @@ def find_oauth2_token(self, client_id, user_id, expires_after=datetime.datetime. return OAuth2Token(self, **row) - def access_token_exists_for_user(self, client_id: str, user_id: str, access_token: str) -> bool: - """Check that the provided access_token exists in the database for the provided user_id. - """ - with self.engine.begin() as connection: - row = connection.execute( - select([oauth2_token]) - .where( - and_( - oauth2_token.c.client_id == client_id, - oauth2_token.c.user_id == user_id, - oauth2_token.c.access_token == access_token, - oauth2_token.c.expires > datetime.datetime.utcnow(), - ) - ) - .limit(1) - ).fetchone() - - return row is not None - def save_oauth2_token(self, token): with self.engine.begin() as connection: result = connection.execute(oauth2_token.insert().values(token.columns)) diff --git a/codalab/model/worker_model.py b/codalab/model/worker_model.py index 746e66928..6fe1ecbc4 100644 --- a/codalab/model/worker_model.py +++ b/codalab/model/worker_model.py @@ -1,3 +1,4 @@ +import asyncio from contextlib import closing import datetime import json @@ -5,8 +6,7 @@ import os import socket import time -from websockets.sync.client import connect -import traceback +import websockets from sqlalchemy import and_, select @@ -20,8 +20,6 @@ ) logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) -logging.basicConfig(format='%(asctime)s %(message)s %(pathname)s %(lineno)d') class WorkerModel(object): @@ -37,13 +35,10 @@ class WorkerModel(object): listen on these sockets for messages and send messages to these sockets. """ - ACK = b'a' - - def __init__(self, engine, socket_dir, ws_server, server_secret): + def __init__(self, engine, socket_dir, ws_server): self._engine = engine self._socket_dir = socket_dir self._ws_server = ws_server - self._server_secret = server_secret def worker_checkin( self, @@ -126,6 +121,7 @@ def worker_checkin( user_id=user_id, worker_id=worker_id, dependencies=blob ) ) + return socket_id @staticmethod @@ -289,7 +285,7 @@ def start_listening(self, socket_id): get_ methods below. as in: with closing(worker_model.start_listening(socket_id)) as sock: - message = worker_model.get_json_message_with_unix_socket(sock, timeout_secs) + message = worker_model.get_json_message(sock, timeout_secs) """ self._cleanup_socket(socket_id) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) @@ -297,7 +293,9 @@ def start_listening(self, socket_id): sock.listen(0) return sock - def recv_stream(self, sock, timeout_secs): + ACK = b'a' + + def get_stream(self, sock, timeout_secs): """ Receives a single message on the given socket and returns a file-like object that can be used for streaming the message data. @@ -318,14 +316,14 @@ def recv_stream(self, sock, timeout_secs): except socket.timeout: return None - def recv_json_message_with_unix_socket(self, sock, timeout_secs): + def get_json_message(self, sock, timeout_secs): """ Receives a single message on the given socket and returns the message data parsed as JSON. If no messages are received within timeout_secs seconds, returns None. """ - fileobj = self.recv_stream(sock, timeout_secs) + fileobj = self.get_stream(sock, timeout_secs) if fileobj is None: return None @@ -365,9 +363,18 @@ def send_stream(self, socket_id, fileobj, timeout_secs): return False - def send_json_message_with_unix_socket( - self, socket_id, worker_id, message, timeout_secs, autoretry=True - ): + def _ping_worker_ws(self, worker_id): + async def ping_ws(): + async with websockets.connect(f"{self._ws_server}/main") as websocket: + await websocket.send(worker_id) + + futures = [ping_ws()] + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(asyncio.wait(futures)) + logging.warn(f"Pinged worker through websockets, worker id: {worker_id}") + + def send_json_message(self, socket_id, worker_id, message, timeout_secs, autoretry=True): """ Sends a JSON message to the given socket, retrying until it is received correctly. @@ -378,6 +385,7 @@ def send_json_message_with_unix_socket( Note, only the worker should call this method with autoretry set to False. See comments below. """ + self._ping_worker_ws(worker_id) start_time = time.time() while time.time() - start_time < timeout_secs: with closing(socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)) as sock: @@ -401,7 +409,6 @@ def send_json_message_with_unix_socket( success = True except socket.error as e: logging.error(f"socket error when calling send_json_message: {e}") - logging.error(traceback.print_exc()) if not success: # Shouldn't be too expensive just to keep retrying. @@ -444,44 +451,3 @@ def has_reply_permission(self, user_id, worker_id, socket_id): if row: return True return False - - def send_json_message( - self, data: dict, worker_id: str, timeout_secs: int = 5, initial_sleep: float = 0.1 - ): - """ - Send JSON message to the worker. - - :param worker_id: The ID of the worker to send data to - :param data: Data to send to worker. Could be a file or a json message. - :param timeout_secs: Seconds until send fails due to timeout. - :param initial_sleep: Time to sleep before retrying send after a failure. - Note: upon successive failures, exponential backoff is applied. - - :return True if data was sent properly, False otherwise. - """ - start_time = time.time() - sleep_time = initial_sleep - while time.time() - start_time < timeout_secs: - try: - with connect(f"{self._ws_server}/send_to_worker/{worker_id}") as websocket: - websocket.send(self._server_secret) # Authenticate - websocket.send(json.dumps(data).encode()) - ack = websocket.recv() - return ack == self.ACK - except Exception as e: - logger.error(f"Send to worker {worker_id} failed with {e}. Retrying...") - time.sleep(sleep_time) - sleep_time *= 2 # Exponential backoff - return False - - def get_user_id_for_worker(self, worker_id): - """Return the user_id corresponding to the worker with ID worker_id - """ - with self._engine.begin() as conn: - row = conn.execute( - select([cl_worker]).where(cl_worker.c.worker_id == worker_id) - ).fetchone() - - if row is None: - return None - return row.user_id diff --git a/codalab/rest/bundle_actions.py b/codalab/rest/bundle_actions.py index f2eaa941c..dfbfc4d82 100644 --- a/codalab/rest/bundle_actions.py +++ b/codalab/rest/bundle_actions.py @@ -32,7 +32,9 @@ def create_bundle_actions(): # The state updates of bundles in PREPARING, RUNNING, or FINALIZING state will be handled on the worker side. if worker: precondition( - local.worker_model.send_json_message(action, worker['worker_id'], 60), + local.worker_model.send_json_message( + worker['socket_id'], worker['worker_id'], action, 60 + ), 'Unable to reach worker.', ) local.model.update_bundle(bundle, {'metadata': {'actions': new_actions}}) diff --git a/codalab/rest/workers.py b/codalab/rest/workers.py index 9dcd5cf9f..c0a9ba6fc 100644 --- a/codalab/rest/workers.py +++ b/codalab/rest/workers.py @@ -1,6 +1,7 @@ from __future__ import ( absolute_import, ) # Without this line "from worker.worker import VERSION" doesn't work. +from contextlib import closing import http.client import json from datetime import datetime @@ -24,9 +25,10 @@ def checkin(worker_id): Waits for a message for the worker for WAIT_TIME_SECS seconds. Returns the message or None if there isn't one. """ + WAIT_TIME_SECS = 5.0 # Old workers might not have all the fields, so allow subsets to be missing. - local.worker_model.worker_checkin( + socket_id = local.worker_model.worker_checkin( request.user.user_id, worker_id, request.json.get("tag"), @@ -43,6 +45,7 @@ def checkin(worker_id): request.json.get("preemptible", False), ) + messages = [] for run in request.json["runs"]: try: worker_run = BundleCheckinState.from_dict(run) @@ -57,21 +60,22 @@ def checkin(worker_id): 'Kill requested: User time quota exceeded. To apply for more quota, please visit the following link: ' 'https://codalab-worksheets.readthedocs.io/en/latest/FAQ/#how-do-i-request-more-disk-quota-or-time-quota' ) - local.worker_model.send_json_message( - {'type': 'kill', 'uuid': bundle.uuid, 'kill_message': kill_message}, worker_id - ) + messages.append({'type': 'kill', 'uuid': bundle.uuid, 'kill_message': kill_message}) elif local.model.get_user_disk_quota_left(bundle.owner_id) <= 0: # Then, user has gone over their disk quota and we kill the job. kill_message = ( 'Kill requested: User disk quota exceeded. To apply for more quota, please visit the following link: ' 'https://codalab-worksheets.readthedocs.io/en/latest/FAQ/#how-do-i-request-more-disk-quota-or-time-quota' ) - local.worker_model.send_json_message( - {'type': 'kill', 'uuid': bundle.uuid, 'kill_message': kill_message}, worker_id - ) + messages.append({'type': 'kill', 'uuid': bundle.uuid, 'kill_message': kill_message}) except Exception as e: logger.info("Exception in REST checkin: {}".format(e)) + with closing(local.worker_model.start_listening(socket_id)) as sock: + messages.append(local.worker_model.get_json_message(sock, WAIT_TIME_SECS)) + response.content_type = 'application/json' + return json.dumps(messages) + def check_reply_permission(worker_id, socket_id): """ @@ -92,9 +96,7 @@ def reply(worker_id, socket_id): Replies with a single JSON message to the given socket ID. """ check_reply_permission(worker_id, socket_id) - local.worker_model.send_json_message_with_unix_socket( - socket_id, worker_id, request.json, 60, autoretry=False - ) + local.worker_model.send_json_message(socket_id, worker_id, request.json, 60, autoretry=False) @post( @@ -122,9 +124,7 @@ def reply_data(worker_id, socket_id): abort(http.client.BAD_REQUEST, "Header message should be in JSON format.") check_reply_permission(worker_id, socket_id) - local.worker_model.send_json_message_with_unix_socket( - socket_id, worker_id, header_message, 60, autoretry=False - ) + local.worker_model.send_json_message(socket_id, worker_id, header_message, 60, autoretry=False) local.worker_model.send_stream(socket_id, request["wsgi.input"], 60) diff --git a/codalab/server/bundle_manager.py b/codalab/server/bundle_manager.py index e341762e2..642530d9b 100644 --- a/codalab/server/bundle_manager.py +++ b/codalab/server/bundle_manager.py @@ -385,7 +385,10 @@ def _acknowledge_recently_finished_bundles(self, workers): ) self._model.transition_bundle_worker_offline(bundle) elif self._worker_model.send_json_message( - {'type': 'mark_finalized', 'uuid': bundle.uuid}, worker['worker_id'] + worker['socket_id'], + worker['worker_id'], + {'type': 'mark_finalized', 'uuid': bundle.uuid}, + 1, ): logger.info( 'Acknowledged finalization of run bundle {} on worker {}'.format( @@ -395,8 +398,6 @@ def _acknowledge_recently_finished_bundles(self, workers): bundle_location = self._bundle_store.get_bundle_location(bundle.uuid) # TODO(Ashwin): fix this -- bundle location could be linked. self._model.transition_bundle_finished(bundle, bundle_location) - else: - logger.info(f"Bundle {bundle.uuid} could not be finalized.") def _bring_offline_stuck_running_bundles(self, workers): """ @@ -741,17 +742,16 @@ def _try_start_bundle(self, workers, worker, bundle, bundle_resources): remove_path(path) os.mkdir(path) if self._worker_model.send_json_message( - self._construct_run_message(worker['shared_file_system'], bundle, bundle_resources), + worker['socket_id'], worker['worker_id'], + self._construct_run_message(worker['shared_file_system'], bundle, bundle_resources), + 1, ): logger.info( 'Starting run bundle {} on worker {}'.format(bundle.uuid, worker['worker_id']) ) return True else: - logger.info( - f"Bundle {bundle.uuid} could not be started on worker {worker['worker_id']}" - ) self._model.transition_bundle_staged(bundle) workers.restage(bundle.uuid) return False diff --git a/codalab/worker/main.py b/codalab/worker/main.py index 1f6614ead..35b1f76f8 100644 --- a/codalab/worker/main.py +++ b/codalab/worker/main.py @@ -116,12 +116,6 @@ def parse_args(): default=None, help='Limit the amount of memory to a worker in bytes' '(e.g. 3, 3k, 3m, 3g, 3t).', ) - parser.add_argument( - '--num-coroutines', - help='Number of worker threads to have running concurrently waiting for socket messages. Must be a natural number.', - type=int, - default=10, - ) parser.add_argument( '--password-file', help='Path to the file containing the username and ' @@ -397,7 +391,6 @@ def main(): exit_on_exception=args.exit_on_exception, shared_memory_size_gb=args.shared_memory_size_gb, preemptible=args.preemptible, - num_coroutines=args.num_coroutines, bundle_runtime=bundle_runtime_class, ) diff --git a/codalab/worker/worker.py b/codalab/worker/worker.py index f1662f2ad..ca5712fb3 100644 --- a/codalab/worker/worker.py +++ b/codalab/worker/worker.py @@ -13,7 +13,6 @@ from typing import Optional, Set, Dict from types import SimpleNamespace import websockets -import json import psutil @@ -91,8 +90,6 @@ def __init__( exit_on_exception=False, # type: bool shared_memory_size_gb=1, # type: int preemptible=False, # type: bool - num_coroutines=10, # type: int - # Number of threads to have running concurrently waiting for socket messages. MUST be a natural number. ): self.image_manager = image_manager self.dependency_manager = dependency_manager @@ -139,8 +136,6 @@ def __init__( self.ws_server = ws_server - self.num_coroutines = num_coroutines - self.runs = {} # type: Dict[str, RunState] self.docker_network_prefix = docker_network_prefix self.init_docker_networks(docker_network_prefix) @@ -297,80 +292,6 @@ def check_num_runs_stop(self): """ return self.exit_after_num_runs == self.num_runs and len(self.runs) == 0 - def process_message(self, message): - """ - Process messages from the rest server in the worker. - - :param message: (list(dict)) A list of JSON messages for the worker. - :return: None - """ - message = json.loads(message.decode()) - with self._lock: - # Stop processing any new runs received from server - if not message or self.terminate_and_restage or self.terminate: - return - action_type = message['type'] - logger.debug('Received %s message: %s', action_type, message) - if action_type == 'run': - self.initialize_run(message['bundle'], message['resources']) - else: - uuid = message['uuid'] - socket_id = message.get('socket_id', None) - if uuid not in self.runs: - if action_type in ['read', 'netcat']: - self.read_run_missing(socket_id) - return - if action_type == 'kill': - kill_message = 'Kill requested' - if 'kill_message' in message: - kill_message = message['kill_message'] - self.kill(uuid, kill_message) - elif action_type == 'mark_finalized': - self.mark_finalized(uuid) - elif action_type == 'read': - self.read(socket_id, uuid, message['path'], message['read_args']) - elif action_type == 'netcat': - self.netcat(socket_id, uuid, message['port'], message['message']) - elif action_type == 'write': - self.write(uuid, message['subpath'], message['string']) - else: - logger.warning("Unrecognized action type from server: %s", action_type) - - async def recv_messages(self, websocket): - # Authenticate with Websocket Server. - await websocket.send(self.bundle_service._get_access_token()) - - # Loop and keep waiting for messages. - while not self.terminate: - try: - await websocket.send("a") - message = await asyncio.wait_for(websocket.recv(), timeout=10) - self.process_message(message) - except asyncio.TimeoutError: - pass - except websockets.exceptions.ConnectionClosed: - logger.warning("Websocket connection closed, starting a new one...") - break - except Exception: - logger.error(traceback.print_exc()) - - async def listen(self, thread_id): - wss_uri = f"{self.ws_server}/worker_connect/{self.id}/{thread_id}" - while not self.terminate: - logger.info(f"Connecting to {wss_uri}") - try: - async with websockets.connect(f"{wss_uri}", max_queue=1) as websocket: - await self.recv_messages(websocket) - except Exception: - logger.error(f"Error connecting to ws-server: {traceback.print_exc()}") - time.sleep(3) - - def listen_thread_fn(self): - coroutines = [self.listen(i) for i in range(self.num_coroutines)] - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(asyncio.gather(*coroutines)) - def start(self): """Return whether we ran anything.""" logger.info(f"my id is: {self.id}") @@ -380,8 +301,39 @@ def start(self): if not self.shared_file_system: self.dependency_manager.start() - asyncio.new_event_loop() - self.listen_thread = threading.Thread(target=self.listen_thread_fn) + async def listen(self): + logger.warning("Started websocket listening thread") + while not self.terminate: + logger.warning(f"Connecting anew to: {self.ws_server}/worker/{self.id}") + async with websockets.connect( + f"{self.ws_server}/worker/{self.id}", max_queue=1 + ) as websocket: + + async def receive_msg(): + await websocket.send("a") + data = await asyncio.wait_for(websocket.recv(), timeout=10) + logger.warning( + f"Got websocket message, got data: {data}, going to check in now." + ) + self.checkin() + self.last_checkin = time.time() + + while not self.terminate: + try: + await receive_msg() + except asyncio.TimeoutError: + pass + except websockets.exceptions.ConnectionClosed: + logger.warning("Websocket connection closed, starting a new one...") + break + + def listen_thread_fn(self): + futures = [listen(self)] + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(asyncio.wait(futures)) + + self.listen_thread = threading.Thread(target=listen_thread_fn, args=[self]) self.listen_thread.start() while not self.terminate: try: @@ -442,7 +394,8 @@ def cleanup(self): Blocks until cleanup is complete and it is safe to quit """ logger.info("Stopping Worker") - self.listen_thread.join() + if self.listen_thread: + self.listen_thread.join() self.image_manager.stop() if not self.shared_file_system: self.dependency_manager.stop() @@ -559,7 +512,7 @@ def checkin(self): }, ) try: - self.bundle_service.checkin(self.id, request) + response = self.bundle_service.checkin(self.id, request) logger.info('Connected! Successful check in!') self.last_checkin_successful = True except BundleServiceException as ex: @@ -570,6 +523,41 @@ def checkin(self): ) time.sleep(self.CHECKIN_COOLDOWN) self.last_checkin_successful = False + response = None + # Stop processing any new runs received from server + if not response or self.terminate_and_restage or self.terminate: + return + if type(response) is not list: + response = [response] + for action in response: + if not action: + continue + action_type = action['type'] + logger.debug('Received %s message: %s', action_type, action) + if action_type == 'run': + self.initialize_run(action['bundle'], action['resources']) + else: + uuid = action['uuid'] + socket_id = action.get('socket_id', None) + if uuid not in self.runs: + if action_type in ['read', 'netcat']: + self.read_run_missing(socket_id) + return + if action_type == 'kill': + kill_message = 'Kill requested' + if 'kill_message' in action: + kill_message = action['kill_message'] + self.kill(uuid, kill_message) + elif action_type == 'mark_finalized': + self.mark_finalized(uuid) + elif action_type == 'read': + self.read(socket_id, uuid, action['path'], action['read_args']) + elif action_type == 'netcat': + self.netcat(socket_id, uuid, action['port'], action['message']) + elif action_type == 'write': + self.write(uuid, action['subpath'], action['string']) + else: + logger.warning("Unrecognized action type from server: %s", action_type) self.process_runs() def process_runs(self): diff --git a/codalab_service.py b/codalab_service.py index 298cd6ac8..451f69257 100755 --- a/codalab_service.py +++ b/codalab_service.py @@ -246,10 +246,6 @@ def has_callable_default(self): CodalabArg(name='ws_port', help='Port for websocket server', type=int, default=2901), CodalabArg(name='rest_num_processes', help='Number of processes', type=int, default=1), CodalabArg(name='server', help='URL to server (used by external worker to connect to)'), - CodalabArg( - name='server_secret', - help='Secret key used to authenticate the REST server with the Websocket server', - ), CodalabArg( name='shared_file_system', help='Whether worker has access to the bundle mount', type=bool ), diff --git a/docker_config/compose_files/docker-compose.yml b/docker_config/compose_files/docker-compose.yml index cffe9af64..4012a0d06 100644 --- a/docker_config/compose_files/docker-compose.yml +++ b/docker_config/compose_files/docker-compose.yml @@ -28,7 +28,6 @@ x-codalab-env: &codalab-env - CODALAB_EMAIL_USERNAME=${CODALAB_EMAIL_USERNAME} - CODALAB_EMAIL_PASSWORD=${CODALAB_EMAIL_PASSWORD} - CODALAB_SERVER=${CODALAB_SERVER} - - CODALAB_SERVER_SECRET=${CODALAB_SERVER_SECRET} - CODALAB_SHARED_FILE_SYSTEM=${CODALAB_SHARED_FILE_SYSTEM} - CODALAB_LINK_MOUNTS=${CODALAB_LINK_MOUNTS} - CODALAB_BUNDLE_MANAGER_WORKER_TIMEOUT_SECONDS=${CODALAB_BUNDLE_MANAGER_WORKER_TIMEOUT_SECONDS} diff --git a/requirements.txt b/requirements.txt index 735a14a72..6abd213e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -41,7 +41,7 @@ urllib3==1.26.11 retry==0.9.2 spython==0.1.14 flufl.lock==6.0 -websockets==11.0.3 +websockets==9.1 kubernetes==12.0.1 google-cloud-storage==2.0.0 httpio==0.3.0 diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index ba381995c..42c21a673 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -43,7 +43,6 @@ import time import traceback import requests -import websockets global cl @@ -756,28 +755,6 @@ def test_auth(ctx): os.environ["CODALAB_PASSWORD"] = password check_contains("user: codalab", _run_command([cl, 'status'])) - # Set up websocket authentication tests. - worker_id = 'auth-test-worker' - create_worker(ctx, current_user()[0], worker_id) - codalab_server_secret = os.environ["CODALAB_SERVER_SECRET"] - os.environ["CODALAB_SERVER_SECRET"] = "fake-secret" - worker_model = CodaLabManager().worker_model() # The server secret will be set to "fake-secret" - ws_server_uri = worker_model._ws_server - os.environ["CODALAB_SERVER_SECRET"] = codalab_server_secret - check_equals(worker_model.send_json_message({'a': 1}, 'auth-test-worker', 1), False) - - # Test worker authentication for websocket endpoint. - exception = None - try: - with websockets.sync.client.connect(f"{ws_server_uri}/worker_connect/{worker_id}/15") as ws: - ws.send("fake-access-token") - ws.recv() - except websockets.exceptions.ConnectionClosedError as e: - exception = e - check_contains(exception.reason, f"Thread 15 for worker {worker_id} unable to authenticate.") - check_equals(1008, exception.code) - check_equals(exception.rcvd_then_sent, True) - @TestModule.register('upload1') def test_upload1(ctx):