From cb157a27d671e00b14d536f31b59bd2314372081 Mon Sep 17 00:00:00 2001 From: Edoardo Morassutto Date: Fri, 25 Dec 2020 16:40:49 +0100 Subject: [PATCH] Add support for ephemeral services. Ephemeral services are services that are not fixed in the configuration file, but dynamically added as they connect. This is especially useful in a setup in which cmsWorker/cmsContestWebServer are scaled dynamically, as one might do when configuring CMS for running on cloud services. --- cms/conf.py | 87 ++++++++++++++++++++++++++++ cms/io/web_service.py | 2 + cms/server/contest/server.py | 9 ++- cms/service/EvaluationService.py | 15 ++++- cms/service/Worker.py | 8 +++ cms/service/workerpool.py | 25 +++++++- cms/util.py | 38 ++++++++++-- cmstestsuite/unit_tests/util_test.py | 16 ++++- config/cms.conf.sample | 16 +++++ 9 files changed, 203 insertions(+), 13 deletions(-) diff --git a/cms/conf.py b/cms/conf.py index 96e53ab836..4d7ad170db 100644 --- a/cms/conf.py +++ b/cms/conf.py @@ -22,11 +22,14 @@ # along with this program. If not, see . import errno +import ipaddress import json import logging import os +import socket import sys from collections import namedtuple +from contextlib import closing from .log import set_detailed_logs @@ -44,6 +47,7 @@ class ServiceCoord(namedtuple("ServiceCoord", "name shard")): service (thus identifying it). """ + def __repr__(self): return "%s,%d" % (self.name, self.shard) @@ -53,6 +57,75 @@ class ConfigError(Exception): pass +class EphemeralServiceConfig: + """Configuration of an ephemeral service. An ephemeral service is a + normal service whose shard is chosen depending on its address and + port. The port is assigned inside a range and the address must be + inside the subnet. + """ + EPHEMERAL_SHARD_OFFSET = 10000 + + def __init__(self, subnet, min_port, max_port): + self.subnet = ipaddress.ip_network(subnet) + self.min_port = min_port + self.max_port = max_port + if min_port > max_port: + raise ConfigError("Invalid port range: [%s, %s]" + % (min_port, max_port)) + + def get_shard(self, address, port): + """Get the ephemeral shard for a service given its address and port. + + address (IPv4Address|IPv6Address): address of the service. + port (int): port of the service. + + return (int): shard of the service + """ + if address not in self.subnet: + raise ValueError("The address is not inside the subnet") + host_id = int(address) & int(self.subnet.hostmask) + num_ports = self.max_port - self.min_port + 1 + shard = host_id * num_ports + (port - self.min_port) + return shard + self.EPHEMERAL_SHARD_OFFSET + + def get_address(self, shard): + """Get the address and port of a service given its shard. + + shard (int): shard of the service + + return (Address): address and port of the service + """ + shard -= self.EPHEMERAL_SHARD_OFFSET + num_ports = self.max_port - self.min_port + 1 + port_offset = shard % num_ports + host_id = (shard - port_offset) // num_ports + + port = self.min_port + port_offset + addr = self.subnet.network_address + host_id + if addr not in self.subnet: + raise ValueError("The shard is not valid") + return Address(str(addr), port) + + def find_free_port(self, address): + """Find the first open port. + + address (IPv4Address|IPv6Address): local address to bind to + """ + if address.version == 4: + family = socket.AF_INET + else: + family = socket.AF_INET6 + for port in range(self.min_port, self.max_port+1): + with closing(socket.socket(family, socket.SOCK_STREAM)) as sock: + try: + sock.bind((str(address), port)) + return port + except socket.error: + continue + raise ValueError("No free port found in range [%s, %s] " + "for address %s" % (minport, maxport, address)) + + class AsyncConfig: """This class will contain the configuration for the services. This needs to be populated at the initilization stage. @@ -69,6 +142,7 @@ class AsyncConfig: """ core_services = {} other_services = {} + ephemeral_services = {} # type: dict[str, EphemeralServiceConfig] async_config = AsyncConfig() @@ -81,6 +155,7 @@ class Config: directory for information on the meaning of the fields. """ + def __init__(self): """Default values for configuration, plus decide if this instance is running from the system path or from the source @@ -274,6 +349,18 @@ def _load_unique(self, path): self.async_config.other_services[coord] = Address(*shard) del data["other_services"] + for service_name in data.get("ephemeral_services", {}): + if service_name.startswith("_"): + continue + service = data["ephemeral_services"][service_name] + self.async_config.ephemeral_services[service_name] = \ + EphemeralServiceConfig( + service["subnet"], + service["min_port"], + service["max_port"], + ) + del data["ephemeral_services"] + # Put everything else in self. for key, value in data.items(): setattr(self, key, value) diff --git a/cms/io/web_service.py b/cms/io/web_service.py index 21e1580514..89fce6f570 100644 --- a/cms/io/web_service.py +++ b/cms/io/web_service.py @@ -106,6 +106,8 @@ def __init__(self, listen_port, handlers, parameters, shard=0, if num_proxies_used > 0: self.wsgi_app = ProxyFix(self.wsgi_app, num_proxies_used) + logger.info("%s listening on '%s' at port %d", + type(self).__name__, listen_address, listen_port) self.web_server = WSGIServer((listen_address, listen_port), self) def __call__(self, environ, start_response): diff --git a/cms/server/contest/server.py b/cms/server/contest/server.py index 25a217aa91..9979ff9774 100644 --- a/cms/server/contest/server.py +++ b/cms/server/contest/server.py @@ -45,6 +45,7 @@ from cms.io import WebService from cms.locale import get_translations from cms.server.contest.jinja2_toolbox import CWS_ENVIRONMENT +from cms.util import is_shard_ephemeral from cmscommon.binary import hex_to_bin from .handlers import HANDLERS from .handlers.base import ContestListHandler @@ -73,8 +74,12 @@ def __init__(self, shard, contest_id=None): } try: - listen_address = config.contest_listen_address[shard] - listen_port = config.contest_listen_port[shard] + if is_shard_ephemeral(shard): + index = 0 + else: + index = shard + listen_address = config.contest_listen_address[index] + listen_port = config.contest_listen_port[index] except IndexError: raise ConfigError("Wrong shard number for %s, or missing " "address/port configuration. Please check " diff --git a/cms/service/EvaluationService.py b/cms/service/EvaluationService.py index 7a1b67d517..24f95b5cb0 100644 --- a/cms/service/EvaluationService.py +++ b/cms/service/EvaluationService.py @@ -161,7 +161,8 @@ def enqueue(self, item, priority, timestamp): item_entry = item.to_dict() del item_entry["testcase_codename"] item_entry["multiplicity"] = 1 - entry = {"item": item_entry, "priority": priority, "timestamp": make_timestamp(timestamp)} + entry = {"item": item_entry, "priority": priority, + "timestamp": make_timestamp(timestamp)} self.queue_status_cumulative[key] = entry return success @@ -197,6 +198,11 @@ def _remove_from_cumulative_status(self, queue_entry): if self.queue_status_cumulative[key]["item"]["multiplicity"] == 0: del self.queue_status_cumulative[key] + def add_worker(self, worker_coord): + """Add a new worker to the pool. + """ + self.pool.add_worker(worker_coord, ephemeral=True) + def with_post_finish_lock(func): """Decorator for locking on self.post_finish_lock. @@ -379,6 +385,13 @@ def workers_status(self): """ return self.get_executor().pool.get_status() + @rpc_method + def add_worker(self, coord): + """Register a new worker to the list of workers. + """ + service, shard = coord + self.get_executor().add_worker(ServiceCoord(service, shard)) + def check_workers_timeout(self): """We ask WorkerPool for the unresponsive workers, and we put again their operations in the queue. diff --git a/cms/service/Worker.py b/cms/service/Worker.py index c9c31647d3..e2fc2cb9be 100644 --- a/cms/service/Worker.py +++ b/cms/service/Worker.py @@ -30,6 +30,7 @@ import gevent.lock +from cms import ServiceCoord from cms.db import SessionGen, Contest, enumerate_files from cms.db.filecacher import FileCacher, TombstoneError from cms.grading import JobException @@ -64,6 +65,13 @@ def __init__(self, shard, fake_worker_time=None): self._fake_worker_time = fake_worker_time + self.evaluation_service = self.connect_to( + ServiceCoord("EvaluationService", 0), + on_connect=self.on_es_connection) + + def on_es_connection(self, address): + self.evaluation_service.add_worker(coord=self._my_coord) + @rpc_method def precache_files(self, contest_id): """RPC to ask the worker to precache of files in the contest. diff --git a/cms/service/workerpool.py b/cms/service/workerpool.py index 094cb0bb3b..31261be6a1 100644 --- a/cms/service/workerpool.py +++ b/cms/service/workerpool.py @@ -140,17 +140,20 @@ def wait_for_workers(self): """Wait until a worker might be available.""" self._workers_available_event.wait() - def add_worker(self, worker_coord): + def add_worker(self, worker_coord, ephemeral=False): """Add a new worker to the worker pool. worker_coord (ServiceCoord): the coordinates of the worker. + ephemeral (bool): remove the worker from the pool after the + disconnection. """ shard = worker_coord.shard # Instruct GeventLibrary to connect ES to the Worker. self._worker[shard] = self._service.connect_to( worker_coord, - on_connect=self.on_worker_connected) + on_connect=self.on_worker_connected, + on_disconnect=lambda: self.on_worker_disconnected(worker_coord, ephemeral)) # And we fill all data. self._operations[shard] = WorkerPool.WORKER_INACTIVE @@ -183,6 +186,24 @@ def on_worker_connected(self, worker_coord): # so we wake up the consumers. self._workers_available_event.set() + def on_worker_disconnected(self, worker_coord, ephemeral): + """If the worker is ephemeral, disable and the remove the worker + form the pool. + """ + if not ephemeral: + return + shard = worker_coord.shard + if self._operations[shard] != WorkerPool.WORKER_DISABLED: + # disable the worker and re-enqueue the lost operations + lost_operations = self.disable_worker(shard) + for operation in lost_operations: + logger.info("Operation %s put again in the queue because " + "the worker disconnected.", operation) + priority, timestamp = operation.side_data + self._service.enqueue(operation, priority, timestamp) + del self._worker[shard] + logger.info("Worker %s removed", worker_coord) + def acquire_worker(self, operations): """Tries to assign an operation to an available worker. If no workers are available then this returns None, otherwise this returns diff --git a/cms/util.py b/cms/util.py index 5c5aab78a5..cd0e6f910a 100644 --- a/cms/util.py +++ b/cms/util.py @@ -23,6 +23,7 @@ import argparse import itertools +import ipaddress import logging import netifaces import os @@ -35,6 +36,7 @@ import gevent.socket from cms import ServiceCoord, ConfigError, async_config, config +from cms.conf import EphemeralServiceConfig logger = logging.getLogger(__name__) @@ -136,8 +138,19 @@ def get_safe_shard(service, provided_shard): raise (ValueError): if no safe shard can be returned. """ + addrs = _find_local_addresses() + # Try to assign an ephemeral shard first. This needs to be done before + # autodetecting the shared using the ip since here we cannot detect if + # the service is already running on that port. + if provided_shard is None and service in config.async_config.ephemeral_services: + ephemeral_config = config.async_config.ephemeral_services[service] + for addr in addrs: + addr = ipaddress.ip_address(addr[1]) + if addr in ephemeral_config.subnet: + port = ephemeral_config.find_free_port(addr) + shard = ephemeral_config.get_shard(addr, port) + return shard if provided_shard is None: - addrs = _find_local_addresses() computed_shard = _get_shard_from_addresses(service, addrs) if computed_shard is None: logger.critical("Couldn't autodetect shard number and " @@ -157,6 +170,16 @@ def get_safe_shard(service, provided_shard): return provided_shard +def is_shard_ephemeral(shard): + """Checks if the shard is ephemeral. + + shard (int): the shard to check. + + return (bool): True if the shard is ephemeral. + """ + return shard >= EphemeralServiceConfig.EPHEMERAL_SHARD_OFFSET + + def get_service_address(key): """Give the Address of a ServiceCoord. @@ -164,10 +187,13 @@ def get_service_address(key): returns (Address): listening address of key. """ + service, shard = key if key in async_config.core_services: return async_config.core_services[key] elif key in async_config.other_services: return async_config.other_services[key] + elif service in async_config.ephemeral_services: + return async_config.ephemeral_services[service].get_address(shard) else: raise KeyError("Service not found.") @@ -179,11 +205,11 @@ def get_service_shards(service): returns (int): the number of shards defined in the configuration. """ - for i in itertools.count(): - try: - get_service_address(ServiceCoord(service, i)) - except KeyError: - return i + count = 0 + for services in (async_config.core_services, async_config.other_services): + count += len([0 for s in services if s.name == service]) + + return count def default_argument_parser(description, cls, ask_contest=None): diff --git a/cmstestsuite/unit_tests/util_test.py b/cmstestsuite/unit_tests/util_test.py index 24493824c5..d5060e55f6 100755 --- a/cmstestsuite/unit_tests/util_test.py +++ b/cmstestsuite/unit_tests/util_test.py @@ -24,6 +24,7 @@ import tempfile import unittest from unittest.mock import Mock +from cms.conf import EphemeralServiceConfig import cms.util from cms import Address, ServiceCoord, \ @@ -35,8 +36,10 @@ class FakeAsyncConfig: core_services = { ServiceCoord("Service", 0): Address("0.0.0.0", 0), ServiceCoord("Service", 1): Address("0.0.0.1", 1), - } + } other_services = {} + ephemeral_services = {"Service": + EphemeralServiceConfig("1.0.0.0/8", 1, 1000)} def _set_up_async_config(restore=False): @@ -66,6 +69,7 @@ def _set_up_ip_addresses(addresses=None, restore=False): class TestGetSafeShard(unittest.TestCase): """Test the function cms.util.get_safe_shard.""" + def setUp(self): """Set up the default mocks.""" _set_up_async_config() @@ -109,6 +113,7 @@ class TestGetServiceAddress(unittest.TestCase): """Test the function cms.util.get_service_address. """ + def setUp(self): """Set up the default mocks.""" _set_up_async_config() @@ -128,7 +133,7 @@ def test_success(self): def test_shard_not_present(self): """Test failure when the shard of the service is invalid.""" - with self.assertRaises(KeyError): + with self.assertRaises(ValueError): get_service_address(ServiceCoord("Service", 2)) def test_service_not_present(self): @@ -136,11 +141,17 @@ def test_service_not_present(self): with self.assertRaises(KeyError): get_service_address(ServiceCoord("ServiceNotPresent", 0)) + def test_ephemeral(self): + """Test ephemeral service case.""" + self.assertEqual(get_service_address(ServiceCoord( + "Service", EphemeralServiceConfig.EPHEMERAL_SHARD_OFFSET + 1000)), Address("1.0.0.1", 1)) + class TestGetServiceShards(unittest.TestCase): """Test the function cms.util.get_service_shards. """ + def setUp(self): """Set up the default mocks.""" _set_up_async_config() @@ -159,6 +170,7 @@ class TestRmtree(unittest.TestCase): """Test the function cms.util.rmtree. """ + def setUp(self): """Set up temporary directory.""" self.tmpdir = tempfile.mkdtemp() diff --git a/config/cms.conf.sample b/config/cms.conf.sample index be43efc5af..e33d4363d6 100644 --- a/config/cms.conf.sample +++ b/config/cms.conf.sample @@ -52,6 +52,20 @@ "TestFileCacher": [["localhost", 27501]] }, + "ephemeral_services": + { + "Worker": { + "subnet": "127.0.0.0/8", + "min_port": 26000, + "max_port": 26999 + }, + "ContestWebServer": { + "subnet": "127.0.0.0/8", + "min_port": 21000, + "max_port": 21000 + } + }, + "_section": "Database", @@ -106,6 +120,8 @@ "_help": "in core_services. If you access them through a proxy (acting", "_help": "as a load balancer) running on the same host you could put", "_help": "127.0.0.1 here for additional security.", + "_help": "When using ephemeral services only the first address and port", + "_help": "are used", "contest_listen_address": [""], "contest_listen_port": [8888],