diff --git a/docs/src/user/storage.rst b/docs/src/user/storage.rst
index ff379ab88..14e1b5e9e 100644
--- a/docs/src/user/storage.rst
+++ b/docs/src/user/storage.rst
@@ -153,6 +153,46 @@ simply run the upgrade command.
 
 .. _storage_python_apis:
 
+``dump`` Export database content
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The ``dump`` command exports database content to a PickledDB PKL file.
+
+.. code-block:: sh
+
+   orion db dump -o backup.pkl
+
+You can also dump a specific experiment.
+
+.. code-block:: sh
+
+   orion db dump -n exp-name -v exp-version -o backup-exp.pkl
+
+``load`` Import database content
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The ``load`` command imports database content
+from any PickledDB PKL file (including files generated by the ``dump`` command).
+
+You must specify a conflict resolution policy with the ``-r/--resolve`` argument;
+it is applied whenever conflicts are detected during import. Available policies are:
+
+- ``ignore``, to ignore imported data on conflict
+- ``overwrite``, to replace old data with imported data
+- ``bump``, to bump the version of imported data before importing it
+
+By default, the whole PKL file is imported.
+
+.. code-block:: sh
+
+   orion db load backup.pkl -r ignore
+
+You can also import a specific experiment.
+
+.. code-block:: sh
+
+   orion db load backup.pkl -r overwrite -n exp-name -v exp-version
+
 Python APIs
 ===========
 
diff --git a/docs/src/user/web_api.rst b/docs/src/user/web_api.rst
index 2351cd99a..0f533cc5f 100644
--- a/docs/src/user/web_api.rst
+++ b/docs/src/user/web_api.rst
@@ -393,7 +393,6 @@ visualize your experiments and their results.
 
     :statuscode 404: When the specified experiment doesn't exist in the database.
 
-
 Benchmarks
 ----------
 The benchmark resource permits the retrieval of in-progress and completed benchmarks. You can
@@ -487,6 +486,101 @@ retrieve individual benchmarks as well as a list of all your benchmarks.
        or assessment, task or algorithms are not part of the existing benchmark
        configuration.
 
+Database dumping
+----------------
+
+The database dumping resource lets you dump database content
+into a PickledDB database and download it as a PKL file.
+
+.. http:get:: /dump
+
+    Return a PKL file containing database content.
+
+    :query name: Optional name of the experiment to export. If unspecified, the whole database is dumped.
+    :query version: Optional version of the experiment to export.
+        If unspecified and a name is specified, the **latest** version of the experiment is exported.
+        If both name and version are unspecified, the whole database is dumped.
+
+    :statuscode 404: When an error occurs during dumping.
+
+Database loading
+----------------
+
+The database loading resource lets you import data from a PKL file.
+
+.. http:post:: /load
+
+    Import data into the database from a PKL file.
+    This is a POST request, as a file must be uploaded.
+    It launches an import task in a separate backend process and returns a task ID
+    which may be used to monitor the task progress.
+
+    :query file: PKL file to import.
+    :query resolve: Policy to resolve conflicts during import. Either:
+
+        - ``ignore``: ignore imported data on conflict
+        - ``overwrite``: overwrite existing data on conflict
+        - ``bump``: bump the version of imported data before insertion on conflict
+
+    :query name: Optional name of the experiment to import. If unspecified, all data from the PKL file is imported.
+    :query version: Optional version of the experiment to import.
+        If unspecified and a name is specified, the **latest** version of the experiment is imported.
+        If both name and version are unspecified, all data from the PKL file is imported.
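+
+    **Example request**
+
+    For illustration only, the snippet below drives ``/dump``, ``/load`` and
+    ``/import-status`` with the third-party ``requests`` package; the server
+    address and file names are placeholders, not values prescribed by Oríon.
+
+    .. code-block:: python
+
+        import requests
+
+        base_url = "http://127.0.0.1:8000"  # placeholder address of the web API
+
+        # Download a full dump of the database.
+        dump = requests.get(f"{base_url}/dump")
+        with open("backup.pkl", "wb") as f:
+            f.write(dump.content)
+
+        # Upload the dump back, bumping experiment versions on conflict.
+        with open("backup.pkl", "rb") as f:
+            reply = requests.post(
+                f"{base_url}/load",
+                files={"file": ("backup.pkl", f)},
+                data={"resolve": "bump"},
+            )
+        task_id = reply.json()["task"]
+
+        # Check the import task status; repeat this call until it is no longer "active".
+        status = requests.get(f"{base_url}/import-status/{task_id}").json()
+        print(status["status"], status["progress_value"])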
+
+    **Example response**
+
+    .. sourcecode:: http
+
+        HTTP/1.1 200 OK
+        Content-Type: text/javascript
+
+    .. code-block:: json
+
+        {
+            "task": "e453679d-e36b-427a-a14d-58fe5e42ca19"
+        }
+
+    :>json task: The ID of the running task that is importing the data.
+
+    :statuscode 400: When an invalid query parameter is passed in the request.
+    :statuscode 403: When an import task is already running.
+
+Import progression
+------------------
+
+The import progression resource lets you monitor an import task launched by the ``/load`` endpoint.
+
+.. http:get:: /import-status/:name
+
+    Return the status of the running import task identified by ``name``,
+    the task ID returned by the ``/load`` endpoint.
+
+    **Example response**
+
+    .. sourcecode:: http
+
+        HTTP/1.1 200 OK
+        Content-Type: text/javascript
+
+    .. code-block:: json
+
+        {
+            "messages": ["latest", "logging", "lines", "from", "import", "process"],
+            "progress_message": "description of current import step",
+            "progress_value": 0.889,
+            "status": "active"
+        }
+
+    :>json messages: Latest logging lines printed by the import process since the last call to the ``/import-status`` endpoint.
+    :>json progress_message: Description of the current import step.
+    :>json progress_value: Floating-point value (between 0 and 1 inclusive) representing the current import progression.
+    :>json status: Import process status. Either:
+
+        - ``active``: still running
+        - ``error``: terminated with an error (see the latest messages for error info)
+        - ``finished``: successfully terminated
+
+    :statuscode 400: When an invalid query parameter is passed in the request.
+
 Errors
 ------
 Oríon uses `conventional HTTP response codes `_
diff --git a/src/orion/core/cli/db/dump.py b/src/orion/core/cli/db/dump.py
new file mode 100644
index 000000000..5e32ffbe3
--- /dev/null
+++ b/src/orion/core/cli/db/dump.py
@@ -0,0 +1,61 @@
+#!/usr/bin/env python
+# pylint: disable=,protected-access
+"""
+Storage export tool
+===================
+
+Export database content into a file.
+
+"""
+import logging
+
+from orion.core.cli import base as cli
+from orion.core.io import experiment_builder
+from orion.storage.backup import dump_database
+from orion.storage.base import setup_storage
+
+logger = logging.getLogger(__name__)
+
+DESCRIPTION = "Export storage"
+
+
+def add_subparser(parser):
+    """Add the subparser that needs to be used for this command"""
+    dump_parser = parser.add_parser("dump", help=DESCRIPTION, description=DESCRIPTION)
+
+    cli.get_basic_args_group(dump_parser)
+
+    dump_parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        default="dump.pkl",
+        help="Output file path (default: dump.pkl)",
+    )
+
+    dump_parser.add_argument(
+        "-f",
+        "--force",
+        action="store_true",
+        help="Whether to force overwrite if destination file already exists. "
+        "If specified, delete destination file and recreate a new one from scratch. 
" + "Otherwise (default), raise an error if destination file already exists.", + ) + + dump_parser.set_defaults(func=main) + + return dump_parser + + +def main(args): + """Script to dump storage""" + config = experiment_builder.get_cmd_config(args) + storage = setup_storage(config.get("storage")) + logger.info(f"Loaded src {storage}") + dump_database( + storage, + args["output"], + name=config.get("name"), + version=config.get("version"), + overwrite=args["force"], + ) diff --git a/src/orion/core/cli/db/load.py b/src/orion/core/cli/db/load.py new file mode 100644 index 000000000..b8e230e13 --- /dev/null +++ b/src/orion/core/cli/db/load.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +""" +Storage import tool +=================== + +Import database content from a file. + +""" +import logging + +from orion.core.cli import base as cli +from orion.core.io import experiment_builder +from orion.storage.backup import load_database +from orion.storage.base import setup_storage + +logger = logging.getLogger(__name__) + +DESCRIPTION = "Import storage" + + +def add_subparser(parser): + """Add the subparser that needs to be used for this command""" + load_parser = parser.add_parser("load", help=DESCRIPTION, description=DESCRIPTION) + + cli.get_basic_args_group(load_parser) + + load_parser.add_argument( + "file", + type=str, + help="File to import", + ) + + load_parser.add_argument( + "-r", + "--resolve", + type=str, + choices=("ignore", "overwrite", "bump"), + help="Strategy to resolve conflicts: " + "'ignore', 'overwrite' or 'bump' " + "(bump version of imported experiment). " + "When overwriting, prior trials will be deleted. " + "If not specified, an exception will be raised on any conflict detected.", + ) + + load_parser.set_defaults(func=main) + + return load_parser + + +def main(args): + """Script to import storage""" + config = experiment_builder.get_cmd_config(args) + storage = setup_storage(config.get("storage")) + logger.info(f"Loaded dst {storage}") + load_database( + storage, + load_host=args["file"], + resolve=args["resolve"], + name=config.get("name"), + version=config.get("version"), + ) diff --git a/src/orion/core/cli/db/upgrade.py b/src/orion/core/cli/db/upgrade.py index 2c2991c96..d2871e3c6 100644 --- a/src/orion/core/cli/db/upgrade.py +++ b/src/orion/core/cli/db/upgrade.py @@ -135,7 +135,7 @@ def upgrade_documents(storage): ) storage.update_experiment(uid=experiment, **experiment) - storage.initialize_algorithm_lock(uid, algorithm) + storage.write_algorithm_lock(uid, algorithm) for trial in storage.fetch_trials(uid=uid): # trial_config = trial.to_dict() diff --git a/src/orion/core/cli/frontend.py b/src/orion/core/cli/frontend.py index 1fedf5415..3e2322949 100644 --- a/src/orion/core/cli/frontend.py +++ b/src/orion/core/cli/frontend.py @@ -17,7 +17,6 @@ from gunicorn.app.base import BaseApplication logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) DESCRIPTION = "Starts Oríon Dashboard" diff --git a/src/orion/core/utils/__init__.py b/src/orion/core/utils/__init__.py index 7f6c33370..48e5179a5 100644 --- a/src/orion/core/utils/__init__.py +++ b/src/orion/core/utils/__init__.py @@ -14,6 +14,7 @@ from contextlib import contextmanager from glob import glob from importlib import import_module +from tempfile import NamedTemporaryFile import pkg_resources @@ -229,3 +230,13 @@ def sigterm_as_interrupt(): yield None signal.signal(signal.SIGTERM, previous) + + +def generate_temporary_file(basename="dump", suffix=".pkl"): + """Generate a temporary file where data could be saved. 
+ + Create an empty file without collision. + Return name of generated file. + """ + with NamedTemporaryFile(prefix=f"{basename}_", suffix=suffix, delete=False) as tf: + return tf.name diff --git a/src/orion/serving/storage_resource.py b/src/orion/serving/storage_resource.py new file mode 100644 index 000000000..54d8bf5bc --- /dev/null +++ b/src/orion/serving/storage_resource.py @@ -0,0 +1,256 @@ +""" +Module responsible for storage import/export REST endpoints +=========================================================== + +Serves all the requests made to storage import/export REST endpoints. + +""" +import json +import logging +import multiprocessing +import os +import traceback +import uuid +from datetime import datetime +from queue import Empty + +import falcon +from falcon import Request, Response + +from orion.core.io.database import DatabaseError +from orion.core.utils import generate_temporary_file +from orion.storage.backup import dump_database, load_database + + +class Notifications: + """Stream handler to collect messages in a shared queue. + + Used instead of stdout in import progress to capture log messages. + """ + + def __init__(self): + """Initialize with a shared queue.""" + self.queue = multiprocessing.Queue() + + def write(self, buf: str): + """Write received data""" + for line in buf.rstrip().splitlines(): + self.queue.put(line) + + def flush(self): + """Placeholder to flush data""" + + +class ImportTask: + """Wrapper to represent an import task. Used to monitor task progress. + + There is two ways to monitor task: + + - either get messages collected in stream handler queue. + Stream handler collects all messages logged in task. + - either regularly check latest message in progress_message. + Progress message only describe the latest step running in task. + + Attributes + ---------- + task_id: str + Used to identify the task in web API + _notifications: Notifications + Stream handler with shared queue to capture task messages + _progress_message: + Latest progress message + _progress_value: + Latest progress (0 <= floating value <= 1) + _completed: + Shared status: 0 for running, -1 for failure, 1 for success + _lock: + Lock to use to prevent concurrent executions + when updating task state. 
+ """ + + # String representation of task status + IMPORT_STATUS = {0: "active", -1: "error", 1: "finished"} + + def __init__(self): + self.task_id = str(uuid.uuid4()) + self._notifications = Notifications() + self._progress_message = multiprocessing.Array("c", 512) + self._progress_value = multiprocessing.Value("d", 0.0) + self._completed = multiprocessing.Value("i", 0) + self._lock = multiprocessing.Lock() + + def set_progress(self, message: str, progress: float): + with self._lock: + self._progress_message.value = message.encode() + self._progress_value.value = progress + + def is_completed(self): + """Return True if task is completed""" + return self._completed.value + + def set_completed(self, success=True, notification=None): + """Set task terminated status + + Parameters + ---------- + success: bool + True if task is successful, False otherwise + notification: str + Optional message to add in notifications + """ + with self._lock: + self._completed.value = 1 if success else -1 + if notification: + self._notifications.write(notification) + + def listen_logging(self): + """Set notifications as logging stream to collect logging messages.""" + # logging.basicConfig() won't do anything if there are already handlers + # for root logger, so we must clear previous handlers first + root_logger = logging.getLogger() + if root_logger.handlers: + for handler in root_logger.handlers: + handler.close() + root_logger.handlers.clear() + # Then set stream and keep previous log level + logging.basicConfig(stream=self._notifications, level=root_logger.level) + + def flush_state(self): + """Return a dictionary with current task state. + + NB: Collect all messages currently in stream handler queue, + so stream handler queue is emptied after this method is called. + """ + latest_messages = [] + with self._lock: + while True: + try: + latest_messages.append(self._notifications.queue.get_nowait()) + except Empty: + break + return { + "messages": latest_messages, + "progress_message": self._progress_message.value.decode(), + "progress_value": self._progress_value.value, + "status": ImportTask.IMPORT_STATUS[self._completed.value], + } + + +def _import_data(task: ImportTask, storage, load_host, resolve, name, version): + """Function to run import task. + + Set stream handler, launch load_database and set task status. 
+ """ + try: + print("Import starting.", task.task_id) + task.listen_logging() + load_database( + storage, + load_host, + resolve, + name, + version, + progress_callback=task.set_progress, + ) + task.set_completed(success=True) + except Exception as exc: + traceback.print_tb(exc.__traceback__) + print("Import error.", exc) + # Add error message to shared messages + task.set_completed(success=False, notification=f"Error: {exc}") + finally: + # Remove imported files + os.unlink(load_host) + lock_file = f"{load_host}.lock" + if os.path.exists(lock_file): + os.unlink(lock_file) + print("Import terminated.") + + +class StorageResource: + """Handle requests for the dump/ REST endpoint""" + + def __init__(self, storage): + self.storage = storage + self.current_task: ImportTask = None + + def on_get_dump(self, req: Request, resp: Response): + """Handle the GET requests for dump/""" + name = req.get_param("name") + version = req.get_param_as_int("version") + dump_host = generate_temporary_file(basename="dump") + download_suffix = "" if name is None else f" {name}" + if download_suffix and version is not None: + download_suffix = f"{download_suffix}.{version}" + try: + dump_database( + self.storage, dump_host, name=name, version=version, overwrite=True + ) + resp.downloadable_as = f"dump{download_suffix if download_suffix else ''} ({datetime.now()}).pkl" + resp.content_type = "application/octet-stream" + with open(dump_host, "rb") as file: + resp.data = file.read() + except DatabaseError as exc: + raise falcon.HTTPNotFound(title=type(exc).__name__, description=str(exc)) + finally: + # Clean dumped files + for path in (dump_host, f"{dump_host}.lock"): + if os.path.exists(path): + os.unlink(path) + + def on_post_load(self, req: Request, resp: Response): + """Handle the POST requests for load/""" + if self.current_task is not None and not self.current_task.is_completed(): + raise falcon.HTTPForbidden(description="An import is already running") + load_host = None + resolve = None + name = None + version = None + for part in req.get_media(): + if part.name == "file": + if part.filename: + load_host = generate_temporary_file(basename="load") + with open(load_host, "wb") as dst: + part.stream.pipe(dst) + elif part.name == "resolve": + resolve = part.get_text().strip() + if resolve not in ("ignore", "overwrite", "bump"): + raise falcon.HTTPInvalidParam( + "Invalid value for resolve", "resolve" + ) + elif part.name == "name": + name = part.get_text().strip() or None + elif part.name == "version": + version = part.get_text().strip() + if version: + try: + version = int(version) + except ValueError: + raise falcon.HTTPInvalidParam( + "Version must be an integer", "version" + ) + if version < 0: + raise falcon.HTTPInvalidParam( + "Version must be a positiver integer", "version" + ) + else: + version = None + else: + raise falcon.HTTPInvalidParam("Unknown parameter", part.name) + if load_host is None: + raise falcon.HTTPInvalidParam("Missing file to import", "file") + if resolve is None: + raise falcon.HTTPInvalidParam("Missing resolve policy", "resolve") + self.current_task = ImportTask() + p = multiprocessing.Process( + target=_import_data, + args=(self.current_task, self.storage, load_host, resolve, name, version), + ) + p.start() + resp.body = json.dumps({"task": self.current_task.task_id}) + + def on_get_import_status(self, req: Request, resp: Response, name: str): + """Handle the GET requests for import-status/""" + if self.current_task is None or self.current_task.task_id != name: + raise 
falcon.HTTPInvalidParam("Unknown import task", "name") + resp.body = json.dumps(self.current_task.flush_state()) diff --git a/src/orion/serving/webapi.py b/src/orion/serving/webapi.py index bdd1a8964..1ae6c7cfd 100644 --- a/src/orion/serving/webapi.py +++ b/src/orion/serving/webapi.py @@ -16,10 +16,10 @@ from orion.serving.experiments_resource import ExperimentsResource from orion.serving.plots_resources import PlotsResource from orion.serving.runtime import RuntimeResource +from orion.serving.storage_resource import StorageResource from orion.serving.trials_resource import TrialsResource logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) class MyCORSMiddleware(CORSMiddleware): @@ -113,6 +113,7 @@ def __init__(self, storage, config=None): benchmarks_resource = BenchmarksResource(self.storage) trials_resource = TrialsResource(self.storage) plots_resource = PlotsResource(self.storage) + storage_resource = StorageResource(self.storage) # Build routes self.add_route("/", root_resource) @@ -152,6 +153,11 @@ def __init__(self, storage, config=None): self.add_route( "/plots/regret/{experiment_name}", plots_resource, suffix="regret" ) + self.add_route("/dump", storage_resource, suffix="dump") + self.add_route("/load", storage_resource, suffix="load") + self.add_route( + "/import-status/{name}", storage_resource, suffix="import_status" + ) def start(self): """A hook to when a Gunicorn worker calls run().""" diff --git a/src/orion/storage/backup.py b/src/orion/storage/backup.py new file mode 100644 index 000000000..a0f52beb0 --- /dev/null +++ b/src/orion/storage/backup.py @@ -0,0 +1,747 @@ +# pylint: disable=,protected-access,too-many-locals,too-many-branches,too-many-statements +""" +Module responsible for storage export/import +============================================ + +Provide functions to export and import database content. +""" +import logging +import os +import shutil +from typing import Any, Dict, List + +from orion.core.io.database import DatabaseError +from orion.core.io.database.pickleddb import PickledDB +from orion.core.utils import generate_temporary_file +from orion.core.utils.tree import TreeNode +from orion.storage.base import BaseStorageProtocol, setup_storage + +logger = logging.getLogger(__name__) + +COL_EXPERIMENTS = "experiments" +COL_ALGOS = "algo" +COL_BENCHMARKS = "benchmarks" +COL_TRIALS = "trials" + +COLLECTIONS = {"experiments", "algo", "benchmarks", "trials"} +EXPERIMENT_RELATED_COLLECTIONS = {"algo", "trials"} + +STEP_COLLECT_EXPERIMENTS = 0 +STEP_CHECK_BENCHMARKS = 1 +STEP_CHECK_DST_EXPERIMENTS = 2 +STEP_CHECK_SRC_EXPERIMENTS = 3 +STEP_DELETE_OLD_DATA = 4 +STEP_INSERT_NEW_DATA = 5 +STEP_NAMES = [ + "Collect source experiments to load", + "Check benchmarks", + "Check destination experiments", + "Check source experiments", + "Delete data to replace in destination", + "Insert new data in destination", +] + + +def dump_database(storage, dump_host, name=None, version=None, overwrite=False): + """Dump a database + + Parameters + ---------- + storage: BaseStorageProtocol + storage of database to dump + dump_host: + file path to dump into (dumped file will be a pickled file) + name: + (optional) name of experiment to dump (by default, full database is dumped) + version: + (optional) version of experiment to dump. + By default, use the latest version of provided `name`. + overwrite: + (optional) define how to manage destination file if already exists. + If false (default), raise an exception. 
+ If true, delete existing file and create a new one with dumped data. + """ + dump_host = os.path.abspath(dump_host) + + # For pickled databases, make sure src is not dst + if hasattr(storage, "_db"): + orig_db = storage._db + if isinstance(orig_db, PickledDB) and dump_host == os.path.abspath( + orig_db.host + ): + raise DatabaseError("Cannot dump pickleddb to itself.") + + # Temporary output file to be used for dumping. Default is dump_host + tmp_dump_host = dump_host + + if os.path.exists(dump_host): + if overwrite: + # Work on a temporary file, not directly into dump_host. + # dump_host will then be replaced with temporary file + # if no error occurred. + tmp_dump_host = generate_temporary_file() + assert os.path.exists(tmp_dump_host) + assert os.stat(tmp_dump_host).st_size == 0 + logger.info(f"Overwriting previous output at {dump_host}") + else: + raise DatabaseError( + f"Export output already exists (specify `--force` to overwrite) at {dump_host}" + ) + + try: + dst_storage = setup_storage( + {"database": {"host": tmp_dump_host, "type": "pickleddb"}} + ) + logger.info(f"Dump to {dump_host}") + _dump(storage, dst_storage, name, version) + except Exception as exc: + # An exception occurred when dumping. + # If existed, original dump_host has not been modified. + for path in (tmp_dump_host, f"{tmp_dump_host}.lock"): + if os.path.isfile(path): + os.unlink(path) + raise exc + else: + # No error occurred + # Move tmp_dump_host to dump_host if necessary + if tmp_dump_host != dump_host: + # NB: If an OS error occurs here, we can't do anything. + os.unlink(dump_host) + shutil.move(tmp_dump_host, dump_host) + # Cleanup + tmp_lock_host = f"{tmp_dump_host}.lock" + if os.path.isfile(tmp_lock_host): + os.unlink(tmp_lock_host) + + +def load_database( + storage, load_host, resolve=None, name=None, version=None, progress_callback=None +): + """Import data into a database + + Parameters + ---------- + storage: BaseStorageProtocol + storage of destination database to load into + load_host: + file path containing data to import + (should be a pickled file representing a PickledDB) + resolve: + policy to resolve import conflict. Either None, 'ignore', 'overwrite' or 'bump'. + - None will raise an exception on any conflict detected + - 'ignore' will ignore imported data on conflict + - 'overwrite' will overwrite old data in destination database on conflict + - 'bump' will bump imported data version before adding it, + if data with same ID is found in destination + name: + (optional) name of experiment to import (by default, whole file is imported) + version: + (optional) version of experiment to import. + By default, use the latest version of provided `name`. + progress_callback: + Optional callback to report progression. 
Receives 2 parameters: + - step description (string) + - overall progress (0 <= floating value <= 1) + """ + load_host = os.path.abspath(load_host) + + # For pickled databases, make sure src is not dst + if hasattr(storage, "_db"): + dst_db = storage._db + if isinstance(dst_db, PickledDB) and load_host == os.path.abspath(dst_db.host): + raise DatabaseError("Cannot load pickleddb to itself.") + + src_storage: BaseStorageProtocol = setup_storage( + {"database": {"host": load_host, "type": "pickleddb"}} + ) + logger.info(f"Loaded src {load_host}") + + import_benchmarks = False + _describe_import_progress(STEP_COLLECT_EXPERIMENTS, 0, 1, progress_callback) + if name is None: + import_benchmarks = True + # Retrieve all src experiments for export + experiments = src_storage.fetch_experiments({}) + else: + # Find experiments based on given name and version + query = {"name": name} + if version is not None: + query["version"] = version + experiments = src_storage.fetch_experiments(query) + if not experiments: + raise DatabaseError( + f"No experiment found with query {query}. Nothing to import." + ) + if len(experiments) > 1: + experiments = sorted(experiments, key=lambda d: d["version"])[-1:] + logger.info( + f"Found experiment {experiments[0]['name']}.{experiments[0]['version']}" + ) + _describe_import_progress(STEP_COLLECT_EXPERIMENTS, 1, 1, progress_callback) + + preparation = _prepare_import( + src_storage, + storage, + experiments, + resolve, + import_benchmarks, + progress_callback=progress_callback, + ) + _execute_import(storage, *preparation, progress_callback=progress_callback) + + +def _dump(src_storage, dst_storage, name=None, version=None): + """Dump data from source storage to destination storage. + + Parameters + ---------- + src_storage: BaseStorageProtocol + input storage + dst_storage: BaseStorageProtocol + output storage + name: + (optional) if provided, dump only data related to experiment with this name + version: + (optional) version of experiment to dump + """ + # Get collection names in a set + if name is None: + # Nothing to filter, dump everything + # Dump benchmarks + logger.info("Dumping benchmarks") + for benchmark in src_storage.fetch_benchmark({}): + dst_storage.create_benchmark(benchmark) + # Dump experiments + logger.info("Dumping experiments, algos and trials") + # Dump experiments ordered from parents to children, + # so that we can get new parent IDs from dst + # before writing children. + graph = get_experiment_parent_links(src_storage.fetch_experiments({})) + sorted_experiments = graph.get_sorted_data() + src_to_dst_id = {} + for i, src_exp in enumerate(sorted_experiments): + logger.info( + f"Dumping experiment {i + 1}: {src_exp['name']}.{src_exp['version']}" + ) + _dump_experiment(src_storage, dst_storage, src_exp, src_to_dst_id) + else: + # Get experiments with given name + query = {"name": name} + if version is not None: + query["version"] = version + experiments = src_storage.fetch_experiments(query) + if not experiments: + raise DatabaseError( + f"No experiment found with query {query}. Nothing to dump." 
+ ) + exp_data = sorted(experiments, key=lambda d: d["version"])[-1] + logger.info(f"Found experiment {exp_data['name']}.{exp_data['version']}") + # As we dump only 1 experiment, remove parent links if exist + if exp_data["refers"]: + if exp_data["refers"]["root_id"] is not None: + logger.info("Removing reference to root experiment before dumping") + exp_data["refers"]["root_id"] = None + if exp_data["refers"]["parent_id"] is not None: + logger.info("Removing reference to parent experiment before dumping") + exp_data["refers"]["parent_id"] = None + # Dump selected experiments and related data + logger.info(f"Dumping experiment {name}") + _dump_experiment(src_storage, dst_storage, exp_data, {}) + + +def _dump_experiment(src_storage, dst_storage, src_exp, src_to_dst_id: dict): + """Dump a single experiment and related data from src to dst storage. + + Parameters + ---------- + src_storage: + src storage + dst_storage: + dst storage + src_exp: dict + src experiment + src_to_dst_id: dict + Dictionary mapping experiment ID from src to dst. + Used to set dst parent ID when writing child experiment in dst storage. + Updated with new dst ID corresponding to `src_exp`. + """ + _write_experiment( + src_exp, + algo_lock_info=src_storage.get_algorithm_lock_info(uid=src_exp["_id"]), + trials=src_storage.fetch_trials(uid=src_exp["_id"]), + dst_storage=dst_storage, + src_to_dst_id=src_to_dst_id, + verbose=True, + ) + + +def _prepare_import( + src_storage, + dst_storage, + experiments, + resolve=None, + import_benchmarks=True, + progress_callback=None, +): + """Prepare importation. + + Compute all changes to apply to make import and return changes as dictionaries. + + Parameters + ---------- + src_storage: BaseStorageProtocol + storage to import from + dst_storage: BaseStorageProtocol + storage to import into + experiments: + experiments to import from src_storage into dst_storage + resolve: + resolve policy + import_benchmarks: + if True, benchmarks will be also imported from src_database + progress_callback: + See :func:`load_database` + + Returns + ------- + A couple (queries to delete, data to add) representing + changes to apply to dst_storage to make import + """ + assert resolve is None or resolve in ("ignore", "overwrite", "bump") + + queries_to_delete = {} + data_to_add = {} + + if import_benchmarks: + src_benchmarks = src_storage.fetch_benchmark({}) + for i, src_benchmark in enumerate(src_benchmarks): + _describe_import_progress( + STEP_CHECK_BENCHMARKS, i, len(src_benchmarks), progress_callback + ) + dst_benchmarks = dst_storage.fetch_benchmark( + {"name": src_benchmark["name"]} + ) + if dst_benchmarks: + (dst_benchmark,) = dst_benchmarks + if resolve == "ignore": + logger.info( + f'Ignored benchmark already in dst: {src_benchmark["name"]}' + ) + continue + if resolve == "overwrite": + logger.info( + f'Overwrite benchmark in dst, name: {src_benchmark["name"]}' + ) + queries_to_delete.setdefault(COL_BENCHMARKS, []).append( + {"_id": dst_benchmark["_id"]} + ) + elif resolve == "bump": + raise DatabaseError( + "Can't bump benchmark version, " + "as benchmarks do not currently support versioning." 
+ ) + else: # resolve is None or unknown + raise DatabaseError( + f"Conflict detected without strategy to resolve ({resolve}) " + f"for benchmark {src_benchmark['name']}" + ) + # Delete benchmark database ID so that a new one will be generated on insertion + del src_benchmark["_id"] + data_to_add.setdefault(COL_BENCHMARKS, []).append(src_benchmark) + _describe_import_progress( + STEP_CHECK_BENCHMARKS, + len(src_benchmarks), + len(src_benchmarks), + progress_callback, + ) + + _describe_import_progress(STEP_CHECK_DST_EXPERIMENTS, 0, 1, progress_callback) + all_dst_experiments = dst_storage.fetch_experiments({}) + # Dictionary mapping dst exp name to exp version to list of exps with same name and version + dst_exp_map = {} + last_versions = {} + for dst_exp in all_dst_experiments: + name = dst_exp["name"] + version = dst_exp["version"] + last_versions[name] = max(last_versions.get(name, 0), version) + dst_exp_map.setdefault(name, {}).setdefault(version, []).append(dst_exp) + _describe_import_progress(STEP_CHECK_DST_EXPERIMENTS, 1, 1, progress_callback) + + if len(experiments) == 1: + # As we load only 1 experiment, remove parent links if exist + (exp_data,) = experiments + if exp_data["refers"]: + if exp_data["refers"]["root_id"] is not None: + logger.info("Removing reference to root experiment before loading") + exp_data["refers"]["root_id"] = None + if exp_data["refers"]["parent_id"] is not None: + logger.info("Removing reference to parent experiment before loading") + exp_data["refers"]["parent_id"] = None + else: + # Load experiments ordered from parents to children, + # so that we can get new parent IDs from dst + # before writing children. + graph = get_experiment_parent_links(experiments) + experiments = graph.get_sorted_data() + + for i, experiment in enumerate(experiments): + _describe_import_progress( + STEP_CHECK_SRC_EXPERIMENTS, i, len(experiments), progress_callback + ) + dst_experiments = dst_exp_map.get(experiment["name"], {}).get( + experiment["version"], [] + ) + if dst_experiments: + (dst_experiment,) = dst_experiments + if resolve == "ignore": + logger.info( + f"Ignored experiment already in dst: " + f'{experiment["name"]}.{experiment["version"]}' + ) + continue + if resolve == "overwrite": + # We must remove experiment data in dst + logger.info( + f"Overwrite experiment in dst: " + f'{dst_experiment["name"]}.{dst_experiment["version"]}' + ) + for collection in EXPERIMENT_RELATED_COLLECTIONS: + queries_to_delete.setdefault(collection, []).append( + {"experiment": dst_experiment["_id"]} + ) + queries_to_delete.setdefault(COL_EXPERIMENTS, []).append( + {"_id": dst_experiment["_id"]} + ) + elif resolve == "bump": + old_version = experiment["version"] + new_version = last_versions[experiment["name"]] + 1 + last_versions[experiment["name"]] = new_version + experiment["version"] = new_version + logger.info( + f'Bumped version of src experiment: {experiment["name"]}, ' + f"from {old_version} to {new_version}" + ) + else: # resolve is None or unknown + raise DatabaseError( + f"Conflict detected without strategy to resolve ({resolve}) " + f"for experiment {experiment['name']}.{experiment['version']}" + ) + else: + logger.info( + f'Import experiment {experiment["name"]}.{experiment["version"]}' + ) + + # Get data related to experiment to import. + algo = src_storage.get_algorithm_lock_info(uid=experiment["_id"]) + trials = src_storage.fetch_trials(uid=experiment["_id"]) + # We will use experiment key to link experiment to related data. 
+ exp_key = _get_exp_key(experiment) + # Set data to add + data_to_add.setdefault(COL_EXPERIMENTS, []).append(experiment) + data_to_add.setdefault(COL_ALGOS, {})[exp_key] = algo + data_to_add.setdefault(COL_TRIALS, {})[exp_key] = trials + _describe_import_progress( + STEP_CHECK_SRC_EXPERIMENTS, + len(experiments), + len(experiments), + progress_callback, + ) + + return queries_to_delete, data_to_add + + +def _execute_import( + dst_storage, queries_to_delete, data_to_add, progress_callback=None +): + """Execute import + + Parameters + ---------- + dst_storage: BaseStorageProtocol + destination storage where to apply changes + queries_to_delete: dict + dictionary mapping a collection name to a list of queries to use + to find and delete data + data_to_add: dict + dictionary mapping a collection name to a list of data to add + progress_callback: + See :func:`load_database` + """ + + # Delete data + + total_queries = sum(len(queries) for queries in queries_to_delete.values()) + for collection_name in COLLECTIONS: + queries_to_delete.setdefault(collection_name, ()) + i_query = 0 + for query_delete_benchmark in queries_to_delete[COL_BENCHMARKS]: + logger.info( + f"Deleting from {len(queries_to_delete[COL_BENCHMARKS])} queries into {COL_BENCHMARKS}" + ) + dst_storage.delete_benchmark(query_delete_benchmark) + _describe_import_progress( + STEP_DELETE_OLD_DATA, i_query, total_queries, progress_callback + ) + i_query += 1 + for query_delete_experiment in queries_to_delete[COL_EXPERIMENTS]: + logger.info( + f"Deleting from {len(queries_to_delete[COL_EXPERIMENTS])} queries " + f"into {COL_EXPERIMENTS}" + ) + dst_storage.delete_experiment(uid=query_delete_experiment["_id"]) + _describe_import_progress( + STEP_DELETE_OLD_DATA, i_query, total_queries, progress_callback + ) + i_query += 1 + for query_delete_trials in queries_to_delete[COL_TRIALS]: + logger.info( + f"Deleting from {len(queries_to_delete[COL_TRIALS])} queries into {COL_TRIALS}" + ) + dst_storage.delete_trials(uid=query_delete_trials["experiment"]) + _describe_import_progress( + STEP_DELETE_OLD_DATA, i_query, total_queries, progress_callback + ) + i_query += 1 + for query_delete_algo in queries_to_delete[COL_ALGOS]: + logger.info( + f"Deleting from {len(queries_to_delete[COL_ALGOS])} queries into {COL_ALGOS}" + ) + dst_storage.delete_algorithm_lock(uid=query_delete_algo["experiment"]) + _describe_import_progress( + STEP_DELETE_OLD_DATA, i_query, total_queries, progress_callback + ) + i_query += 1 + + _describe_import_progress( + STEP_DELETE_OLD_DATA, total_queries, total_queries, progress_callback + ) + + # Add data + + nb_data_to_add = len(data_to_add.get(COL_BENCHMARKS, ())) + len( + data_to_add.get(COL_EXPERIMENTS, ()) + ) + i_data = 0 + + for new_benchmark in data_to_add.get(COL_BENCHMARKS, ()): + dst_storage.create_benchmark(new_benchmark) + _describe_import_progress( + STEP_INSERT_NEW_DATA, i_data, nb_data_to_add, progress_callback + ) + i_data += 1 + + src_to_dst_id = {} + for src_exp in data_to_add.get(COL_EXPERIMENTS, ()): + exp_key = _get_exp_key(src_exp) + new_algo = data_to_add[COL_ALGOS][exp_key] + new_trials = data_to_add[COL_TRIALS][exp_key] + _write_experiment( + src_exp, + algo_lock_info=new_algo, + trials=new_trials, + dst_storage=dst_storage, + src_to_dst_id=src_to_dst_id, + verbose=False, + ) + _describe_import_progress( + STEP_INSERT_NEW_DATA, i_data, nb_data_to_add, progress_callback + ) + i_data += 1 + + _describe_import_progress( + STEP_INSERT_NEW_DATA, nb_data_to_add, nb_data_to_add, progress_callback + ) + + +def 
_write_experiment( + src_exp, algo_lock_info, trials, dst_storage, src_to_dst_id: dict, verbose=False +): + # Remove src experiment database ID + src_id = src_exp.pop("_id") + assert src_id not in src_to_dst_id + + # Update experiment parent ID + old_parent_id = _get_exp_parent_id(src_exp) + if old_parent_id is not None: + _set_exp_parent_id(src_exp, src_to_dst_id[old_parent_id]) + + # Update experiment root ID if different from experiment ID + old_root_id = _get_exp_root_id(src_exp) + if old_root_id is not None: + if old_root_id != src_id: + _set_exp_root_id(src_exp, src_to_dst_id[old_root_id]) + + # Dump experiment and algo + dst_storage.create_experiment( + src_exp, + algo_locked=algo_lock_info.locked, + algo_state=algo_lock_info.state, + algo_heartbeat=algo_lock_info.heartbeat, + ) + if verbose: + logger.info("\tCreated exp") + # Link experiment src ID to dst ID + (dst_exp,) = dst_storage.fetch_experiments( + {"name": src_exp["name"], "version": src_exp["version"]} + ) + src_to_dst_id[src_id] = dst_exp["_id"] + # Update root ID if equals to experiment ID + if old_root_id is not None and old_root_id == src_id: + _set_exp_root_id(src_exp, src_to_dst_id[src_id]) + dst_storage.update_experiment( + uid=src_to_dst_id[src_id], refers=src_exp["refers"] + ) + # Dump trials + trial_old_to_new_id = {} + for trial in trials: + old_id = trial.id + # Set trial parent to new dst exp ID + trial.experiment = src_to_dst_id[trial.experiment] + # Remove src trial database ID, so that new ID will be generated at insertion. + # Trial parents are identified using trial identifier (trial.id) + # which is not related to trial database ID (trial.id_override). + # So, we can safely remove trial database ID. + trial.id_override = None + if trial.parent is not None: + trial.parent = trial_old_to_new_id[trial.parent] + dst_trial = dst_storage.register_trial(trial) + trial_old_to_new_id[old_id] = dst_trial.id + if verbose: + logger.info("\tDumped trials") + + +def _describe_import_progress(step, value, total, callback=None): + print("STEP", step + 1, STEP_NAMES[step], value, total) + if callback: + if total == 0: + value = total = 1 + callback(STEP_NAMES[step], (step + (value / total)) / len(STEP_NAMES)) + + +class _Graph: + """Helper class to build experiments or trials graph.""" + + def __init__(self, key_to_data: dict): + """Initialize + + Parameters + ---------- + key_to_data: + Dictionary mapping key (used as node) to related object. + """ + self.key_to_data: dict = key_to_data + self.key_to_node: Dict[Any, TreeNode] = {} + self.root = TreeNode(None) + + def add_link(self, parent, child): + """Link parent node to child node.""" + if parent not in self.key_to_node: + self.key_to_node[parent] = TreeNode(parent, self.root) + if child in self.key_to_node: + child_node = self.key_to_node[child] + # A node should have at most 1 parent. + assert child_node.parent is self.root + child_node.set_parent(self.key_to_node[parent]) + else: + self.key_to_node[child] = TreeNode(child, self.key_to_node[parent]) + + def _get_sorted_nodes(self) -> List[TreeNode]: + """Return list of sorted nodes from parents to children.""" + # Exclude root node. 
+ return list(self.root)[1:] + + def get_sorted_data(self) -> list: + """Return list of sorted data from parents to children.""" + return [self.key_to_data[node.item] for node in self._get_sorted_nodes()] + + def get_sorted_links(self): + """Return sorted edges (node, child)""" + for node in self._get_sorted_nodes(): + if node.children: + for child_node in node.children: + yield node.item, child_node.item + else: + yield node.item, None + + +def get_experiment_parent_links(experiments: list) -> _Graph: + """Generate experiments graphs based on experiment parents. + + Does not currently check experiment roots. + """ + graph = _Graph({_get_exp_key(exp): exp for exp in experiments}) + exp_id_to_key = {exp["_id"]: _get_exp_key(exp) for exp in experiments} + for exp in experiments: + parent_id = _get_exp_parent_id(exp) + if parent_id is not None: + parent_key = exp_id_to_key[parent_id] + child_key = _get_exp_key(exp) + graph.add_link(parent_key, child_key) + return graph + + +def get_experiment_root_links(experiments: list) -> _Graph: + """Generate experiments graphs based on experiment roots.""" + special_root_key = ("__root__",) + graph = _Graph( + {**{_get_exp_key(exp): exp for exp in experiments}, **{special_root_key: None}} + ) + exp_id_to_key = {exp["_id"]: _get_exp_key(exp) for exp in experiments} + for exp in experiments: + root_id = _get_exp_root_id(exp) + if root_id is not None: + if root_id == exp["_id"]: + # If root is exp, use a special root key + root_key = special_root_key + else: + root_key = exp_id_to_key[root_id] + child_key = _get_exp_key(exp) + graph.add_link(root_key, child_key) + return graph + + +def get_trial_parent_links(trials: list) -> _Graph: + """Generate trials graph based on trial parents. Not yet used.""" + trial_map = {_get_trial_key(trial): trial for trial in trials} + graph = _Graph(trial_map) + for trial in trials: + parent = _get_trial_parent(trial) + if parent is not None: + assert parent in trial_map + graph.add_link(parent, _get_trial_key(trial)) + return graph + + +def _get_trial_key(trial): + """Return trial key, as trial ID""" + return trial["id"] if isinstance(trial, dict) else trial.id + + +def _get_trial_parent(trial): + """Return trial parent""" + return trial["parent"] if isinstance(trial, dict) else trial.parent + + +def _get_exp_key(exp: dict) -> tuple: + """Return experiment key as tuple (name, version)""" + return exp["name"], exp["version"] + + +def _get_exp_parent_id(exp: dict): + """Get experiment parent ID or None if unavailable""" + return exp.get("refers", {}).get("parent_id", None) + + +def _set_exp_parent_id(exp: dict, parent_id): + """Set experiment parent ID""" + exp.setdefault("refers", {})["parent_id"] = parent_id + + +def _get_exp_root_id(exp: dict): + """Get experiment root ID or None if unavailable""" + return exp.get("refers", {}).get("root_id", None) + + +def _set_exp_root_id(exp: dict, parent_id): + """Set experiment root ID""" + exp.setdefault("refers", {})["root_id"] = parent_id diff --git a/src/orion/storage/base.py b/src/orion/storage/base.py index c7a5247ac..dce8a902e 100644 --- a/src/orion/storage/base.py +++ b/src/orion/storage/base.py @@ -150,13 +150,18 @@ class LockedAlgorithmState: Configuration of the locked algorithm. locked: bool Whether the algorithm is locked or not. 
Default: True + heartbeat: datetime + Current heartbeat of algorithm """ - def __init__(self, state: dict, configuration: dict, locked: bool = True): + def __init__( + self, state: dict, configuration: dict, locked: bool = True, heartbeat=None + ): self._original_state = state self.configuration = configuration self._state = state self.locked = locked + self.heartbeat = heartbeat @property def state(self) -> dict: @@ -186,8 +191,30 @@ def fetch_benchmark(self, query: dict, selection: dict | None = None): """Fetch all benchmarks that match the query""" raise NotImplementedError() - def create_experiment(self, config: ExperimentConfig): - """Insert a new experiment inside the database""" + def delete_benchmark(self, query: dict): + """Delete benchmarks that match given query""" + raise NotImplementedError() + + def create_experiment( + self, + config: ExperimentConfig, + algo_locked: int = 0, + algo_state: dict | None = None, + algo_heartbeat: datetime | None = None, + ): + """Insert a new experiment inside the database + + Parameters + ---------- + config: + experiment config + algo_locked: int + Whether algo is initially locked (1) o not (0, default) + algo_state: dict, optional + Initial algo state + algo_heartbeat: datetime, optional + Initial algo heartbeat. Default to datetime.utcnow(). + """ raise NotImplementedError() def delete_experiment( @@ -524,10 +551,15 @@ def update_heartbeat(self, trial: Trial): """Update trial's heartbeat""" raise NotImplementedError() - def initialize_algorithm_lock( - self, experiment_id: int | str, algorithm_config: dict + def write_algorithm_lock( + self, + experiment_id: int | str, + algorithm_config: dict, + locked: int = 0, + state: dict | None = None, + heartbeat: datetime | None = None, ): - """Initialize algorithm lock for given experiment + """Write algorithm lock for given experiment Parameters ---------- @@ -535,6 +567,12 @@ def initialize_algorithm_lock( ID of the experiment in storage. algorithm_config: dict Configuration of the algorithm. + locked: int + Whether algorithm is locked (1) or not (0, default). + state: dict, optional + Optional algorithm state. + heartbeat: datetime, optional + Algorithm heartbeat. Default to datetime.utcnow(). 
""" raise NotImplementedError() diff --git a/src/orion/storage/legacy.py b/src/orion/storage/legacy.py index 88945465a..ef619558b 100644 --- a/src/orion/storage/legacy.py +++ b/src/orion/storage/legacy.py @@ -110,11 +110,21 @@ def fetch_benchmark(self, query, selection=None): """Fetch all benchmarks that match the query""" return self._db.read("benchmarks", query, selection) - def create_experiment(self, config): + def delete_benchmark(self, query: dict): + """See :func:`orion.storage.base.BaseStorageProtocol.delete_benchmark`""" + return self._db.remove("benchmarks", query) + + def create_experiment( + self, config, algo_locked=0, algo_state=None, algo_heartbeat=None + ): """See :func:`orion.storage.base.BaseStorageProtocol.create_experiment`""" exp_rval = self._db.write("experiments", data=config, query=None) - self.initialize_algorithm_lock( - experiment_id=config["_id"], algorithm_config=config.get("algorithm", {}) + self.write_algorithm_lock( + experiment_id=config["_id"], + algorithm_config=config.get("algorithm", {}), + locked=algo_locked, + state=algo_state, + heartbeat=algo_heartbeat, ) return exp_rval @@ -350,16 +360,20 @@ def fetch_trials_by_status(self, experiment, status): query = dict(experiment=experiment._id, status=status) return self._fetch_trials(query) - def initialize_algorithm_lock(self, experiment_id, algorithm_config): - """See :func:`orion.storage.base.BaseStorageProtocol.initialize_algorithm_lock`""" + def write_algorithm_lock( + self, experiment_id, algorithm_config, locked=0, state=None, heartbeat=None + ): + """See :func:`orion.storage.base.BaseStorageProtocol.write_algorithm_lock`""" return self._db.write( "algo", { "experiment": experiment_id, "configuration": algorithm_config, - "locked": 0, - "state": None, - "heartbeat": datetime.datetime.utcnow(), + "locked": locked, + "state": None if state is None else pickle.dumps(state), + "heartbeat": datetime.datetime.utcnow() + if heartbeat is None + else heartbeat, }, ) @@ -396,6 +410,7 @@ def get_algorithm_lock_info(self, experiment=None, uid=None): else None, configuration=algo_state_lock["configuration"], locked=algo_state_lock["locked"], + heartbeat=algo_state_lock["heartbeat"], ) def delete_algorithm_lock(self, experiment=None, uid=None): diff --git a/src/orion/storage/track.py b/src/orion/storage/track.py index e3a698829..221f9a110 100644 --- a/src/orion/storage/track.py +++ b/src/orion/storage/track.py @@ -383,7 +383,9 @@ def _get_project(self, name): assert self.project, "Project should have been found" - def create_experiment(self, config): + def create_experiment( + self, config, algo_locked=0, algo_state=None, algo_heartbeat=None + ): """Insert a new experiment inside the database""" self._get_project(config["name"]) @@ -744,7 +746,7 @@ def update_heartbeat(self, trial): trial.storage, heartbeat=to_epoch(datetime.datetime.utcnow()) ) - def _initialize_algorithm_lock(self, experiment_id): + def _write_algorithm_lock(self, experiment_id): raise NotImplementedError return self._db.write( "algo", diff --git a/src/orion/testing/state.py b/src/orion/testing/state.py index c4dcd161d..21c1cf9b9 100644 --- a/src/orion/testing/state.py +++ b/src/orion/testing/state.py @@ -258,11 +258,11 @@ def _set_tables(self): if self._experiments: self.database.write("experiments", self._experiments) for experiment in self._experiments: - self.storage.initialize_algorithm_lock( + self.storage.write_algorithm_lock( experiment["_id"], experiment.get("algorithm") ) # For tests that need a deterministic experiment id. 
- self.storage.initialize_algorithm_lock( + self.storage.write_algorithm_lock( experiment["name"], experiment.get("algorithm") ) if self._trials: diff --git a/tests/functional/commands/conftest.py b/tests/functional/commands/conftest.py index 1fe757c3b..389fc5446 100644 --- a/tests/functional/commands/conftest.py +++ b/tests/functional/commands/conftest.py @@ -2,7 +2,10 @@ """Common fixtures and utils for unittests and functional tests.""" import copy import os +import pickle import zlib +from collections import Counter +from tempfile import NamedTemporaryFile import pytest import yaml @@ -11,6 +14,13 @@ import orion.core.io.experiment_builder as experiment_builder import orion.core.utils.backward as backward from orion.core.worker.trial import Trial +from orion.storage.backup import ( + _get_exp_key, + dump_database, + get_experiment_parent_links, + get_experiment_root_links, + get_trial_parent_links, +) @pytest.fixture() @@ -27,6 +37,11 @@ def exp_config(): return exp_config +@pytest.fixture +def empty_database(storage): + """Empty database""" + + @pytest.fixture def only_experiments_db(storage, exp_config): """Clean the database and insert only experiments.""" @@ -37,7 +52,11 @@ def only_experiments_db(storage, exp_config): def ensure_deterministic_id(name, storage, version=1, update=None): """Change the id of experiment to its name.""" experiment = storage.fetch_experiments({"name": name, "version": version})[0] + algo_lock_info = storage.get_algorithm_lock_info(uid=experiment["_id"]) + storage.delete_experiment(uid=experiment["_id"]) + storage.delete_algorithm_lock(uid=experiment["_id"]) + _id = zlib.adler32(str((name, version)).encode()) experiment["_id"] = _id @@ -47,7 +66,12 @@ def ensure_deterministic_id(name, storage, version=1, update=None): if update is not None: experiment.update(update) - storage.create_experiment(experiment) + storage.create_experiment( + experiment, + algo_locked=algo_lock_info.locked, + algo_state=algo_lock_info.state, + algo_heartbeat=algo_lock_info.heartbeat, + ) # Experiments combinations fixtures @@ -441,3 +465,417 @@ def three_experiments_same_name_with_trials( orionstate.database.write("trials", trial2.to_dict()) orionstate.database.write("trials", trial3.to_dict()) x_value += 1 + + +@pytest.fixture +def three_experiments_branch_same_name_trials( + three_experiments_branch_same_name, orionstate, storage +): + """Create three experiments, two of them with the same name but different versions and one + with a child, and add trials including children trials. + + Add algorithm state for one experiment. 
+ + NB: It seems 2 experiments are children: + * test_single_exp_child.1 child of test_single_exp.2 + * test_single_exp.2 child of test_single_exp.1 + * test_single_exp.1 has no parent + """ + exp1 = experiment_builder.build(name="test_single_exp", version=1, storage=storage) + exp2 = experiment_builder.build(name="test_single_exp", version=2, storage=storage) + exp3 = experiment_builder.build( + name="test_single_exp_child", version=1, storage=storage + ) + + x = {"name": "/x", "type": "real"} + y = {"name": "/y", "type": "real"} + z = {"name": "/z", "type": "real"} + x_value = 0.0 + for status in Trial.allowed_stati: + x["value"] = x_value + 0.1 # To avoid duplicates + y["value"] = x_value * 10 + z["value"] = x_value * 100 + trial1 = Trial(experiment=exp1.id, params=[x], status=status) + trial2 = Trial(experiment=exp2.id, params=[x, y], status=status) + trial3 = Trial(experiment=exp3.id, params=[x, y, z], status=status) + # Add a child to a trial from exp1 + child = trial1.branch(params={"/x": 1}) + orionstate.database.write("trials", trial1.to_dict()) + orionstate.database.write("trials", trial2.to_dict()) + orionstate.database.write("trials", trial3.to_dict()) + orionstate.database.write("trials", child.to_dict()) + x_value += 1 + # exp1 should have 12 trials (including child trials) + # exp2 and exp3 should have 6 trials each + + # Add some algo data for exp1 + orionstate.database.read_and_write( + collection_name="algo", + query={"experiment": exp1.id}, + data={ + "state": pickle.dumps( + {"my_algo_state": "some_data", "my_other_state_data": "some_other_data"} + ) + }, + ) + + +@pytest.fixture +def three_experiments_branch_same_name_trials_benchmarks( + three_experiments_branch_same_name_trials, orionstate, storage +): + """Create three experiments, two of them with the same name but different versions and one + with a child, and add trials including children trials. + + Add algorithm state for one experiment. + Add benchmarks to database. 
+ """ + # Add benchmarks, copied from db_dashboard_full.pkl + orionstate.database.write( + "benchmarks", + [ + { + "_id": 1, + "algorithms": ["gridsearch", "random"], + "name": "branin_baselines_webapi", + "targets": [ + { + "assess": {"AverageResult": {"repetitions": 10}}, + "task": {"Branin": {"max_trials": 50}}, + } + ], + }, + { + "_id": 2, + "algorithms": [ + "gridsearch", + "random", + {"tpe": {"n_initial_points": 20}}, + ], + "name": "all_algos_webapi", + "targets": [ + { + "assess": {"AverageResult": {"repetitions": 3}}, + "task": { + "Branin": {"max_trials": 10}, + "EggHolder": {"dim": 4, "max_trials": 20}, + "RosenBrock": {"dim": 3, "max_trials": 10}, + }, + } + ], + }, + { + "_id": 3, + "algorithms": ["random", {"tpe": {"n_initial_points": 20}}], + "name": "all_assessments_webapi_2", + "targets": [ + { + "assess": { + "AverageRank": {"repetitions": 3}, + "AverageResult": {"repetitions": 3}, + "ParallelAssessment": { + "executor": "joblib", + "n_workers": (1, 2, 4, 8), + "repetitions": 3, + }, + }, + "task": { + "Branin": {"max_trials": 10}, + "RosenBrock": {"dim": 3, "max_trials": 10}, + }, + } + ], + }, + ], + ) + + +@pytest.fixture +def pkl_experiments(three_experiments_branch_same_name_trials, orionstate, storage): + """Dump three_experiments_branch_same_name_trials to a PKL file""" + with NamedTemporaryFile(prefix="dumped_", suffix=".pkl", delete=False) as tf: + pkl_path = tf.name + dump_database(storage, pkl_path, overwrite=True) + return pkl_path + + +@pytest.fixture +def pkl_experiments_and_benchmarks( + three_experiments_branch_same_name_trials_benchmarks, orionstate, storage +): + """Dump three_experiments_branch_same_name_trials_benchmarks to a PKL file""" + with NamedTemporaryFile(prefix="dumped_", suffix=".pkl", delete=False) as tf: + pkl_path = tf.name + dump_database(storage, pkl_path, overwrite=True) + return pkl_path + + +@pytest.fixture +def other_empty_database(): + """Get an empty database and associated configuration file. + + To be used where we need both global config (e.g. for pkl_* fixtures) + and another config for an empty database. + """ + from orion.storage.base import setup_storage + + with NamedTemporaryFile(prefix="empty_", suffix=".pkl", delete=False) as tf: + pkl_path = tf.name + with NamedTemporaryFile(prefix="orion_config_", suffix=".yaml", delete=False) as tf: + config_content = f""" +storage: + database: + type: 'pickleddb' + host: '{pkl_path}' +""".lstrip() + tf.write(config_content.encode()) + cfg_path = tf.name + storage = setup_storage({"database": {"type": "pickleddb", "host": pkl_path}}) + return storage, cfg_path + + +class _Helpers: + """Helper functions for testing. + + Primarily provided for tests that use fixture (and derived) + `three_experiments_branch_same_name_trials_benchmarks` + """ + + @staticmethod + def check_db(db, nb_exps, nb_algos, nb_trials, nb_benchmarks, nb_child_exps=0): + """Check number of expected data in given database.""" + experiments = db.read("experiments") + assert len(experiments) == nb_exps + assert len(db.read("algo")) == nb_algos + assert len(db.read("trials")) == nb_trials + assert len(db.read("benchmarks")) == nb_benchmarks + + # Check we have expected number of child experiments. 
+        exp_map = {exp["_id"]: exp for exp in experiments}
+        assert len(exp_map) == nb_exps
+        child_exps = []
+        for exp in experiments:
+            parent = exp["refers"]["parent_id"]
+            if parent is not None:
+                assert parent in exp_map
+                child_exps.append(exp)
+        assert len(child_exps) == nb_child_exps
+
+    @staticmethod
+    def check_exp(
+        db,
+        name,
+        version,
+        nb_trials,
+        nb_child_trials=0,
+        algo_state=None,
+        trial_links=None,
+    ):
+        """Check an experiment.
+        - Check that the experiment was found.
+        - Check that exactly 1 algorithm was found for this experiment.
+        - Check the algorithm state if algo_state is provided.
+        - Check that the expected number of trials was found for this experiment.
+        - Check that the expected number of child trials was found among the experiment trials.
+        - Check the expected trial links if trial_links is provided.
+        """
+        experiments = db.read("experiments", {"name": name, "version": version})
+        assert len(experiments) == 1
+        (experiment,) = experiments
+        algos = db.read("algo", {"experiment": experiment["_id"]})
+        trials = db.read("trials", {"experiment": experiment["_id"]})
+        assert len(algos) == 1
+        assert len(trials) == nb_trials
+
+        if algo_state is not None:
+            (algo,) = algos
+            assert algo_state == pickle.loads(algo["state"])
+
+        trial_map = {trial["id"]: trial for trial in trials}
+        assert len(trial_map) == nb_trials
+        child_trials = []
+        for trial in trials:
+            parent = trial["parent"]
+            if parent is not None:
+                assert parent in trial_map
+                child_trials.append(trial)
+        assert len(child_trials) == nb_child_trials
+
+        if trial_links is not None:
+            trial_graph = get_trial_parent_links(trials)
+            given_links = sorted(trial_graph.get_sorted_links())
+            trial_links = sorted(trial_links)
+            assert len(trial_links) == len(given_links)
+            assert trial_links == given_links
+
+    @staticmethod
+    def check_empty_db(loaded_db):
+        """Check that given database is empty"""
+        _Helpers.check_db(
+            loaded_db, nb_exps=0, nb_algos=0, nb_trials=0, nb_benchmarks=0
+        )
+
+    @staticmethod
+    def assert_tested_db_structure(dumped_db, nb_orig_benchmarks=3, nb_duplicated=1):
+        """Check counts and experiments for database from specific fixture
+        `three_experiments_branch_same_name_trials[_benchmarks]`.
+        """
+        _Helpers.check_db(
+            dumped_db,
+            nb_exps=3 * nb_duplicated,
+            nb_algos=3 * nb_duplicated,
+            nb_trials=24 * nb_duplicated,
+            nb_benchmarks=nb_orig_benchmarks * nb_duplicated,
+            nb_child_exps=2 * nb_duplicated,
+        )
+        expected_parent_links = []
+        expected_root_links = []
+        for i in range(nb_duplicated):
+            _Helpers.check_exp(
+                dumped_db, "test_single_exp", 1 + 2 * i, nb_trials=12, nb_child_trials=6
+            )
+            _Helpers.check_exp(dumped_db, "test_single_exp", 2 + 2 * i, nb_trials=6)
+            _Helpers.check_exp(dumped_db, "test_single_exp_child", 1 + i, nb_trials=6)
+            expected_parent_links.extend(
+                [
+                    (("test_single_exp", 1 + 2 * i), ("test_single_exp", 2 + 2 * i)),
+                    (("test_single_exp", 2 + 2 * i), ("test_single_exp_child", 1 + i)),
+                    (("test_single_exp_child", 1 + i), None),
+                ]
+            )
+            expected_root_links.extend(
+                [
+                    (("__root__",), ("test_single_exp", 1 + 2 * i)),
+                    (("test_single_exp", 1 + 2 * i), ("test_single_exp", 2 + 2 * i)),
+                    (("test_single_exp", 1 + 2 * i), ("test_single_exp_child", 1 + i)),
+                    (("test_single_exp", 2 + 2 * i), None),
+                    (("test_single_exp_child", 1 + i), None),
+                ]
+            )
+        # Test experiments parent links.
+        experiments = dumped_db.read("experiments")
+        parent_graph = get_experiment_parent_links(experiments)
+        assert sorted(parent_graph.get_sorted_links()) == sorted(expected_parent_links)
+        # Test experiments root links.
+ root_graph = get_experiment_root_links(experiments) + root_links = sorted(root_graph.get_sorted_links()) + assert root_links == sorted(expected_root_links) + # Check that experiment with root key (__root__,) + # do have same root ID as experiment ID + key_to_exp = {_get_exp_key(exp): exp for exp in experiments} + nb_verified_identical_roots = 0 + for root_key, exp_key in root_links: + if root_key == ("__root__",): + exp = key_to_exp[exp_key] + assert exp["_id"] == exp["refers"]["root_id"] + nb_verified_identical_roots += 1 + assert nb_verified_identical_roots == 1 * nb_duplicated + + @staticmethod + def assert_tested_trial_status(dumped_db, nb_duplicated=1, counts=None): + """Check that trials have valid status.""" + if counts is None: + counts = { + "new": 9, + "reserved": 3, + "suspended": 3, + "completed": 3, + "interrupted": 3, + "broken": 3, + } + trial_status_count = Counter( + trial["status"] for trial in dumped_db.read("trials") + ) + assert len(trial_status_count) == len(counts) + for status, count in counts.items(): + assert trial_status_count[status] == count * nb_duplicated + + @staticmethod + def check_unique_import( + loaded_db, + name, + version, + nb_trials, + nb_child_trials=0, + nb_versions=1, + algo_state=None, + trial_links=None, + ): + """Check all versions of an experiment in given database""" + _Helpers.check_db( + loaded_db, + nb_exps=1 * nb_versions, + nb_algos=1 * nb_versions, + nb_trials=nb_trials * nb_versions, + nb_benchmarks=0, + ) + for i in range(nb_versions): + _Helpers.check_exp( + loaded_db, + name, + version + i, + nb_trials=nb_trials, + nb_child_trials=nb_child_trials, + algo_state=algo_state, + trial_links=trial_links, + ) + + @staticmethod + def check_unique_import_test_single_expV1(loaded_db, nb_versions=1): + """Check all versions of original experiment test_single_exp.1 in given database""" + _Helpers.check_unique_import( + loaded_db, + "test_single_exp", + 1, + nb_trials=12, + nb_child_trials=6, + nb_versions=nb_versions, + algo_state={ + "my_algo_state": "some_data", + "my_other_state_data": "some_other_data", + }, + trial_links=[ + ( + "9dbe618878008376d0ef47dba77b4175", + "7bc7d88c3f84329ae15667af1fc5eba0", + ), + ("7bc7d88c3f84329ae15667af1fc5eba0", None), + ( + "68e541fa91d9017a50fe534c2e70e34c", + "0caeb769dd8becc1c5064d3638128948", + ), + ("0caeb769dd8becc1c5064d3638128948", None), + ( + "ebd7c227cd7d1911c3b56daa9d02b2c2", + "0e6dce570d2bec70b0c7e26ba6aab617", + ), + ("0e6dce570d2bec70b0c7e26ba6aab617", None), + ( + "26da495bc13561b163e1e67654c913d4", + "7fbcacb8b1a6fd12d57f8b84de009c42", + ), + ("7fbcacb8b1a6fd12d57f8b84de009c42", None), + ( + "284af14179121d0e8df8e7fc856f5920", + "a40d030ff08ebbb7d97ecffaf93fe1f6", + ), + ("a40d030ff08ebbb7d97ecffaf93fe1f6", None), + ( + "938087683a168d4640ee3f72942d2d16", + "44dc1dd034b0dddca891847b8aac31fb", + ), + ("44dc1dd034b0dddca891847b8aac31fb", None), + ], + ) + + @staticmethod + def check_unique_import_test_single_expV2(loaded_db, nb_versions=1): + """Check all versions of original experiment test_single_exp.2 in given database""" + _Helpers.check_unique_import( + loaded_db, "test_single_exp", 2, nb_trials=6, nb_versions=nb_versions + ) + + +@pytest.fixture +def testing_helpers(): + return _Helpers diff --git a/tests/functional/commands/test_db_dump.py b/tests/functional/commands/test_db_dump.py new file mode 100644 index 000000000..a7a633a29 --- /dev/null +++ b/tests/functional/commands/test_db_dump.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python +"""Perform functional tests for db dump.""" + 
+import os + +import pytest + +import orion.core.cli +from orion.core.io.database.pickleddb import PickledDB +from orion.storage.base import setup_storage + + +def execute(command, assert_code=0): + """Execute orion command and return returncode""" + returncode = orion.core.cli.main(command.split(" ")) + assert returncode == assert_code + + +def clean_dump(dump_path): + """Delete dumped files.""" + for path in (dump_path, f"{dump_path}.lock"): + if os.path.isfile(path): + os.unlink(path) + + +def test_default_storage(three_experiments_branch_same_name): + """Check default storage from three_experiments_branch_same_name""" + storage = setup_storage() + experiments = storage._db.read("experiments") + algos = storage._db.read("algo") + assert len(experiments) == 3 + assert len(algos) == 3 + + +def test_dump_default( + three_experiments_branch_same_name_trials_benchmarks, capsys, testing_helpers +): + """Test dump with default arguments""" + assert not os.path.exists("dump.pkl") + try: + execute("db dump") + assert os.path.isfile("dump.pkl") + dumped_db = PickledDB("dump.pkl") + testing_helpers.assert_tested_db_structure(dumped_db) + finally: + clean_dump("dump.pkl") + + +def test_dump_overwrite( + three_experiments_branch_same_name_trials_benchmarks, + capsys, + testing_helpers, + tmp_path, +): + """Test dump with overwrite argument""" + dump_path = f"{tmp_path}/dump.pkl" + try: + execute(f"db dump -o {dump_path}") + assert os.path.isfile(dump_path) + dumped_db = PickledDB(dump_path) + testing_helpers.assert_tested_db_structure(dumped_db) + + # No overwrite by default. Should fail. + execute(f"db dump -o {dump_path}", assert_code=1) + captured = capsys.readouterr() + assert captured.err.strip().startswith( + "Error: Export output already exists (specify `--force` to overwrite) at" + ) + + # Overwrite. Should pass. + execute(f"db dump --force -o {dump_path}") + assert os.path.isfile(dump_path) + testing_helpers.assert_tested_db_structure(dumped_db) + finally: + clean_dump(dump_path) + + +def test_dump_to_specified_output( + three_experiments_branch_same_name_trials_benchmarks, + capsys, + testing_helpers, + tmp_path, +): + """Test dump to a specified output file""" + dump_path = f"{tmp_path}/test.pkl" + assert not os.path.exists(dump_path) + try: + execute(f"db dump -o {dump_path}") + assert os.path.isfile(dump_path) + dumped_db = PickledDB(dump_path) + testing_helpers.assert_tested_db_structure(dumped_db) + finally: + clean_dump(dump_path) + + +@pytest.mark.parametrize( + "output_already_exists,output_specified,overwrite,error_message", + [ + ( + True, + False, + False, + "Error: Export output already exists (specify `--force` to overwrite)", + ), + ( + True, + False, + True, + "Error: No experiment found with query {'name': 'unknown-experiment'}. " + "Nothing to dump.", + ), + ( + True, + True, + False, + "Error: Export output already exists (specify `--force` to overwrite)", + ), + ( + True, + True, + True, + "Error: No experiment found with query {'name': 'unknown-experiment'}. " + "Nothing to dump.", + ), + ( + False, + False, + False, + "Error: No experiment found with query {'name': 'unknown-experiment'}. " + "Nothing to dump.", + ), + ( + False, + False, + True, + "Error: No experiment found with query {'name': 'unknown-experiment'}. " + "Nothing to dump.", + ), + ( + False, + True, + False, + "Error: No experiment found with query {'name': 'unknown-experiment'}. 
" + "Nothing to dump.", + ), + ( + False, + True, + True, + "Error: No experiment found with query {'name': 'unknown-experiment'}. " + "Nothing to dump.", + ), + ], +) +def test_dump_post_clean_on_error( + output_already_exists, output_specified, overwrite, error_message, capsys, tmp_path +): + """Test how dumped file is cleaned if dump fails.""" + + # Prepare a command that will fail (by looking for unknown experiment) + command = ["db", "dump", "-n", "unknown-experiment"] + if output_specified: + output_specified = f"{tmp_path}/test.pkl" + command += ["--output", output_specified] + if overwrite: + command += ["--force"] + + expected_output = output_specified or "dump.pkl" + + # Create expected file if necessary + output_modified_time = None + if output_already_exists: + assert not os.path.exists(expected_output), expected_output + with open(expected_output, "w"): + pass + assert os.path.isfile(expected_output) + output_modified_time = os.stat(expected_output).st_mtime + + # Execute command and expect it to fail + execute(" ".join(command), assert_code=1) + err = capsys.readouterr().err + + # Check output error + assert err.startswith(error_message) + + # Check dump post-clean + if output_already_exists: + # Output should exist after error. + assert os.path.isfile(expected_output) + # Output should have not been modified. + assert output_modified_time == os.stat(expected_output).st_mtime + # Clean files anyway + os.unlink(expected_output) + if os.path.isfile(f"{expected_output}.lock"): + os.unlink(f"{expected_output}.lock") + else: + # Output should not exist after error. + assert not os.path.exists(expected_output) + assert not os.path.exists(f"{expected_output}.lock") + + +def test_dump_unknown_experiment( + three_experiments_branch_same_name_trials_benchmarks, capsys, tmp_path +): + """Test dump unknown experiment""" + dump_path = f"{tmp_path}/dump.pkl" + try: + execute(f"db dump -n i-dont-exist -o {dump_path}", assert_code=1) + captured = capsys.readouterr() + assert captured.err.startswith( + "Error: No experiment found with query {'name': 'i-dont-exist'}. Nothing to dump." 
+ ) + finally: + # Output file is created as soon as dst storage object is created in dump_database() + # So, we still need to delete it here + clean_dump(dump_path) + + +@pytest.mark.parametrize( + "given_version,expected_version,nb_trials,nb_child_trials,algo_state", + [ + (None, 2, 6, 0, None), + ( + 1, + 1, + 12, + 6, + { + "my_algo_state": "some_data", + "my_other_state_data": "some_other_data", + }, + ), + ], +) +def test_dump_experiment_test_single_exp( + three_experiments_branch_same_name_trials_benchmarks, + testing_helpers, + given_version, + expected_version, + nb_trials, + nb_child_trials, + algo_state, + tmp_path, +): + """Test dump experiment test_single_exp""" + dump_path = f"{tmp_path}/dump.pkl" + try: + command = f"db dump -n test_single_exp -o {dump_path}" + if given_version is not None: + command += f" -v {given_version}" + execute(command) + assert os.path.isfile(dump_path) + dumped_db = PickledDB(dump_path) + testing_helpers.check_unique_import( + dumped_db, + "test_single_exp", + expected_version, + nb_trials=nb_trials, + nb_child_trials=nb_child_trials, + algo_state=algo_state, + ) + finally: + clean_dump(dump_path) diff --git a/tests/functional/commands/test_db_load.py b/tests/functional/commands/test_db_load.py new file mode 100644 index 000000000..7a678c59f --- /dev/null +++ b/tests/functional/commands/test_db_load.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python +"""Perform functional tests for db load.""" + +import os + +import pytest + +import orion.core.cli +from orion.core.io.database.pickleddb import PickledDB +from orion.storage.base import setup_storage + + +def execute(command, assert_code=0): + """Execute orion command and return returncode""" + returncode = orion.core.cli.main(command.split(" ")) + assert returncode == assert_code + + +def test_empty_database(empty_database, testing_helpers): + """Test destination database is empty as expected""" + storage = setup_storage() + db = storage._db + with db.locked_database(write=False) as internal_db: + collections = set(internal_db._db.keys()) + assert collections == {"experiments", "algo", "trials", "benchmarks"} + testing_helpers.check_db(db, nb_exps=0, nb_algos=0, nb_trials=0, nb_benchmarks=0) + + +def test_load_all( + other_empty_database, pkl_experiments_and_benchmarks, testing_helpers +): + """Test load all database""" + assert os.path.isfile(pkl_experiments_and_benchmarks) + storage, cfg_path = other_empty_database + loaded_db = storage._db + testing_helpers.check_empty_db(loaded_db) + + execute(f"db load {pkl_experiments_and_benchmarks} -c {cfg_path}") + testing_helpers.assert_tested_db_structure(loaded_db) + + +def test_load_again_without_resolve( + other_empty_database, pkl_experiments_and_benchmarks, capsys, testing_helpers +): + """Test load all database twice without resolve in second call""" + storage, cfg_path = other_empty_database + loaded_db = storage._db + testing_helpers.check_db( + loaded_db, nb_exps=0, nb_algos=0, nb_trials=0, nb_benchmarks=0 + ) + + execute(f"db load {pkl_experiments_and_benchmarks} -c {cfg_path}") + testing_helpers.assert_tested_db_structure(loaded_db) + + # Again + execute(f"db load {pkl_experiments_and_benchmarks}", assert_code=1) + captured = capsys.readouterr() + assert ( + captured.err.strip() + == "Error: Conflict detected without strategy to resolve (None) for benchmark branin_baselines_webapi" + ) + # Destination should have not changed + testing_helpers.assert_tested_db_structure(loaded_db) + + +def test_load_ignore( + other_empty_database, 
pkl_experiments_and_benchmarks, testing_helpers
+):
+    """Test load all database with --resolve ignore"""
+    storage, cfg_path = other_empty_database
+    loaded_db = storage._db
+    testing_helpers.check_empty_db(loaded_db)
+
+    execute(f"db load {pkl_experiments_and_benchmarks} -c {cfg_path}")
+    testing_helpers.assert_tested_db_structure(loaded_db)
+    testing_helpers.assert_tested_trial_status(loaded_db)
+
+    # Change something in the PKL file to check that changes are ignored
+    src_db = PickledDB(pkl_experiments_and_benchmarks)
+    testing_helpers.assert_tested_trial_status(src_db)
+    for trial in src_db.read("trials"):
+        trial["status"] = "new"
+        src_db.write("trials", trial, query={"_id": trial["_id"]})
+    testing_helpers.assert_tested_db_structure(src_db)
+    # Trial status checking should now fail for the PKL file
+    with pytest.raises(AssertionError):
+        testing_helpers.assert_tested_trial_status(src_db)
+    # ... And pass for a specific count
+    testing_helpers.assert_tested_trial_status(src_db, counts={"new": 24})
+
+    execute(f"db load {pkl_experiments_and_benchmarks} -r ignore -c {cfg_path}")
+    # Duplicated data should be ignored, so we expect the same data.
+    testing_helpers.assert_tested_db_structure(loaded_db)
+    # Trial statuses should not have been modified in the destination database.
+    testing_helpers.assert_tested_trial_status(loaded_db)
+
+
+def test_load_overwrite(
+    other_empty_database, pkl_experiments_and_benchmarks, capsys, testing_helpers
+):
+    """Test load all database with --resolve overwrite"""
+    storage, cfg_path = other_empty_database
+    loaded_db = storage._db
+    testing_helpers.check_empty_db(loaded_db)
+
+    execute(f"db load {pkl_experiments_and_benchmarks} -c {cfg_path}")
+    testing_helpers.assert_tested_db_structure(loaded_db)
+    testing_helpers.assert_tested_trial_status(loaded_db)
+
+    # Change something in the PKL file to check that changes are applied on overwrite
+    src_db = PickledDB(pkl_experiments_and_benchmarks)
+    testing_helpers.assert_tested_trial_status(src_db)
+    for trial in src_db.read("trials"):
+        trial["status"] = "new"
+        src_db.write("trials", trial, query={"_id": trial["_id"]})
+    testing_helpers.assert_tested_db_structure(src_db)
+    # Trial status checking should now fail for the PKL file
+    with pytest.raises(AssertionError):
+        testing_helpers.assert_tested_trial_status(src_db)
+    # ... And pass for a specific count
+    testing_helpers.assert_tested_trial_status(src_db, counts={"new": 24})
+
+    execute(f"db load {pkl_experiments_and_benchmarks} -r overwrite -c {cfg_path}")
+    # We expect the same data structure after overwriting
+    testing_helpers.assert_tested_db_structure(loaded_db)
+    # Trial status checking must fail with the default expected counts
+    with pytest.raises(AssertionError):
+        testing_helpers.assert_tested_trial_status(loaded_db)
+    # ...
And pass for specific changes + testing_helpers.assert_tested_trial_status(loaded_db, counts={"new": 24}) + + # Check output to verify progress callback messages + captured = capsys.readouterr() + assert captured.err.strip() == "" + assert ( + captured.out.strip() + == """ +STEP 1 Collect source experiments to load 0 1 +STEP 1 Collect source experiments to load 1 1 +STEP 2 Check benchmarks 0 3 +STEP 2 Check benchmarks 1 3 +STEP 2 Check benchmarks 2 3 +STEP 2 Check benchmarks 3 3 +STEP 3 Check destination experiments 0 1 +STEP 3 Check destination experiments 1 1 +STEP 4 Check source experiments 0 3 +STEP 4 Check source experiments 1 3 +STEP 4 Check source experiments 2 3 +STEP 4 Check source experiments 3 3 +STEP 5 Delete data to replace in destination 0 0 +STEP 6 Insert new data in destination 0 6 +STEP 6 Insert new data in destination 1 6 +STEP 6 Insert new data in destination 2 6 +STEP 6 Insert new data in destination 3 6 +STEP 6 Insert new data in destination 4 6 +STEP 6 Insert new data in destination 5 6 +STEP 6 Insert new data in destination 6 6 +STEP 1 Collect source experiments to load 0 1 +STEP 1 Collect source experiments to load 1 1 +STEP 2 Check benchmarks 0 3 +STEP 2 Check benchmarks 1 3 +STEP 2 Check benchmarks 2 3 +STEP 2 Check benchmarks 3 3 +STEP 3 Check destination experiments 0 1 +STEP 3 Check destination experiments 1 1 +STEP 4 Check source experiments 0 3 +STEP 4 Check source experiments 1 3 +STEP 4 Check source experiments 2 3 +STEP 4 Check source experiments 3 3 +STEP 5 Delete data to replace in destination 0 12 +STEP 5 Delete data to replace in destination 1 12 +STEP 5 Delete data to replace in destination 2 12 +STEP 5 Delete data to replace in destination 3 12 +STEP 5 Delete data to replace in destination 4 12 +STEP 5 Delete data to replace in destination 5 12 +STEP 5 Delete data to replace in destination 6 12 +STEP 5 Delete data to replace in destination 7 12 +STEP 5 Delete data to replace in destination 8 12 +STEP 5 Delete data to replace in destination 9 12 +STEP 5 Delete data to replace in destination 10 12 +STEP 5 Delete data to replace in destination 11 12 +STEP 5 Delete data to replace in destination 12 12 +STEP 6 Insert new data in destination 0 6 +STEP 6 Insert new data in destination 1 6 +STEP 6 Insert new data in destination 2 6 +STEP 6 Insert new data in destination 3 6 +STEP 6 Insert new data in destination 4 6 +STEP 6 Insert new data in destination 5 6 +STEP 6 Insert new data in destination 6 6 +""".strip() + ) + + +def test_load_bump_no_benchmarks( + other_empty_database, pkl_experiments, testing_helpers +): + """Test load all database with --resolve --bump""" + data_source = pkl_experiments + + storage, cfg_path = other_empty_database + loaded_db = storage._db + testing_helpers.check_empty_db(loaded_db) + + execute(f"db load {data_source} -c {cfg_path}") + testing_helpers.assert_tested_db_structure(loaded_db, nb_orig_benchmarks=0) + + execute(f"db load {data_source} -r bump -c {cfg_path}") + # Duplicated data should be bumped, so we must expect twice quantity of data. + testing_helpers.assert_tested_db_structure( + loaded_db, nb_orig_benchmarks=0, nb_duplicated=2 + ) + + execute(f"db load {data_source} -r bump -c {cfg_path}") + # Duplicated data should be bumped, so we must expect thrice quantity of data. 
+ testing_helpers.assert_tested_db_structure( + loaded_db, nb_orig_benchmarks=0, nb_duplicated=3 + ) + + +def test_load_bump_with_benchmarks( + other_empty_database, pkl_experiments_and_benchmarks, capsys, testing_helpers +): + """Test load all database with benchmarks and --resolve --bump""" + data_source = pkl_experiments_and_benchmarks + + storage, cfg_path = other_empty_database + loaded_db = storage._db + testing_helpers.check_empty_db(loaded_db) + + # First execution should pass, as destination contains nothing. + execute(f"db load {data_source} -c {cfg_path}") + testing_helpers.assert_tested_db_structure(loaded_db) + + # New execution should fail, as benchmarks don't currently support bump. + execute(f"db load {data_source} -r bump", assert_code=1) + captured = capsys.readouterr() + assert ( + captured.err.strip() + == "Error: Can't bump benchmark version, as benchmarks do not currently support versioning." + ) + # Destination should have not changed. + testing_helpers.assert_tested_db_structure(loaded_db) + + +def test_load_one_experiment( + other_empty_database, pkl_experiments_and_benchmarks, testing_helpers +): + """Test load experiment test_single_exp""" + storage, cfg_path = other_empty_database + loaded_db = storage._db + testing_helpers.check_empty_db(loaded_db) + + execute( + f"db load {pkl_experiments_and_benchmarks} -n test_single_exp -c {cfg_path}" + ) + testing_helpers.check_db( + loaded_db, nb_exps=1, nb_algos=1, nb_trials=6, nb_benchmarks=0 + ) + testing_helpers.check_exp(loaded_db, "test_single_exp", 2, nb_trials=6) + + +def test_load_one_experiment_other_version( + other_empty_database, pkl_experiments_and_benchmarks, testing_helpers +): + """Test load version 1 of experiment test_single_exp""" + storage, cfg_path = other_empty_database + loaded_db = storage._db + testing_helpers.check_empty_db(loaded_db) + + execute( + f"db load {pkl_experiments_and_benchmarks} -n test_single_exp -v 1 -c {cfg_path}" + ) + testing_helpers.check_unique_import_test_single_expV1(loaded_db) + + +def test_load_one_experiment_ignore( + other_empty_database, pkl_experiments_and_benchmarks, testing_helpers +): + """Test load experiment test_single_exp with --resolve ignore""" + + storage, cfg_path = other_empty_database + loaded_db = storage._db + testing_helpers.check_empty_db(loaded_db) + + execute( + f"db load {pkl_experiments_and_benchmarks} -n test_single_exp -v 1 -c {cfg_path}" + ) + testing_helpers.check_unique_import_test_single_expV1(loaded_db) + + execute( + f"db load {pkl_experiments_and_benchmarks} -r ignore -n test_single_exp -v 1 -c {cfg_path}" + ) + testing_helpers.check_unique_import_test_single_expV1(loaded_db) + + +def test_load_one_experiment_overwrite( + other_empty_database, pkl_experiments_and_benchmarks, testing_helpers +): + """Test load experiment test_single_exp with --resolve overwrite""" + storage, cfg_path = other_empty_database + loaded_db = storage._db + testing_helpers.check_empty_db(loaded_db) + + execute( + f"db load {pkl_experiments_and_benchmarks} -n test_single_exp -v 1 -c {cfg_path}" + ) + testing_helpers.check_unique_import_test_single_expV1(loaded_db) + + execute( + f"db load {pkl_experiments_and_benchmarks} -r overwrite -n test_single_exp -v 1 -c {cfg_path}" + ) + testing_helpers.check_unique_import_test_single_expV1(loaded_db) + + +def test_load_one_experiment_bump( + other_empty_database, pkl_experiments_and_benchmarks, testing_helpers +): + """Test load experiment test_single_exp with --resolve bump""" + storage, cfg_path = other_empty_database + 
loaded_db = storage._db
+    testing_helpers.check_empty_db(loaded_db)
+
+    execute(
+        f"db load {pkl_experiments_and_benchmarks} -n test_single_exp -v 1 -c {cfg_path}"
+    )
+    testing_helpers.check_unique_import_test_single_expV1(loaded_db)
+
+    execute(
+        f"db load {pkl_experiments_and_benchmarks} -r bump -n test_single_exp -v 1 -c {cfg_path}"
+    )
+    # Duplicated data should be bumped, so we expect twice the amount of data.
+    testing_helpers.check_unique_import_test_single_expV1(loaded_db, nb_versions=2)
+
+    execute(
+        f"db load {pkl_experiments_and_benchmarks} -r bump -n test_single_exp -v 1 -c {cfg_path}"
+    )
+    # Duplicated data should be bumped, so we expect three times the amount of data.
+    testing_helpers.check_unique_import_test_single_expV1(loaded_db, nb_versions=3)
diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py
new file mode 100644
index 000000000..d32153c9f
--- /dev/null
+++ b/tests/functional/conftest.py
@@ -0,0 +1,25 @@
+"""Common fixtures for functional tests"""
+
+# Import fixtures from commands.conftest
+# to be used in test_storage_resource.
+# We also need to import the parent fixtures for the derived fixtures to work.
+from commands.conftest import (
+    one_experiment,
+    pkl_experiments,
+    pkl_experiments_and_benchmarks,
+    testing_helpers,
+    three_experiments_branch_same_name,
+    three_experiments_branch_same_name_trials,
+    three_experiments_branch_same_name_trials_benchmarks,
+    two_experiments_same_name,
+)
+
+# 'Use' imported fixtures here so that formatting tools do not flag them as unused imports
+assert one_experiment
+assert two_experiments_same_name
+assert three_experiments_branch_same_name
+assert three_experiments_branch_same_name_trials
+assert three_experiments_branch_same_name_trials_benchmarks
+assert pkl_experiments
+assert pkl_experiments_and_benchmarks
+assert testing_helpers
diff --git a/tests/functional/serving/test_storage_resource.py b/tests/functional/serving/test_storage_resource.py
new file mode 100644
index 000000000..4f5965a08
--- /dev/null
+++ b/tests/functional/serving/test_storage_resource.py
@@ -0,0 +1,340 @@
+"""Tests for the storage resource.
+
+NB: The ephemeral DB apparently cannot be shared across processes.
+It could still be used to test /dump, but we need a database that can be
+accessed from multiple processes to test /load, as /load launches the import
+task in a separate process.
+
+So we instead use a PickledDB as the destination database for /load tests.
+"""
+import logging
+import os
+import random
+import string
+import time
+
+import pytest
+from falcon import testing
+
+from orion.core.io.database.pickleddb import PickledDB
+from orion.core.utils import generate_temporary_file
+from orion.serving.webapi import WebApi
+from orion.storage.backup import load_database
+from orion.storage.base import setup_storage
+
+
+@pytest.fixture
+def ephemeral_loaded(ephemeral_storage, pkl_experiments):
+    """Load test data in ephemeral storage. To be used before testing /dump requests."""
+    load_database(ephemeral_storage.storage, pkl_experiments, resolve="ignore")
+
+
+@pytest.fixture
+def ephemeral_loaded_with_benchmarks(ephemeral_storage, pkl_experiments_and_benchmarks):
+    """Load test data in ephemeral storage. 
To be used before testing /dump requests.""" + load_database( + ephemeral_storage.storage, pkl_experiments_and_benchmarks, resolve="ignore" + ) + + +class DumpContext: + def __init__(self, client, parameters=None): + self.client = client + self.host = generate_temporary_file() + self.db = None + self.url = "/dump" + ("" if parameters is None else f"?{parameters}") + + def __enter__(self): + response = self.client.simulate_get(self.url) + with open(self.host, "wb") as file: + file.write(response.content) + self.db = PickledDB(self.host) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + _clean_dump(self.host) + print("CLEANED DUMP") + + +class LoadContext: + def __init__(self): + # Create empty PKL file as destination database + self.host = generate_temporary_file("test") + # Setup storage and client + pickled_storage = setup_storage( + {"type": "legacy", "database": {"type": "PickledDB", "host": self.host}} + ) + self.pickled_client = testing.TestClient(WebApi(pickled_storage, {})) + self.db = pickled_storage._db + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + _clean_dump(self.host) + print("CLEANED LOAD") + + +def _clean_dump(dump_path): + """Delete dumped files.""" + for path in (dump_path, f"{dump_path}.lock"): + if os.path.isfile(path): + os.unlink(path) + + +def _random_string(length: int): + """Generate a random string with given length. Used to generate multipart form data for /load request""" + domain = string.ascii_lowercase + string.digits + return "".join(random.choice(domain) for _ in range(length)) + + +def _gen_multipart_form_for_load(file: str, resolve: str, name="", version=""): + """Generate multipart form body and headers for /load request with uploaded file. + Help (2022/11/08): + https://stackoverflow.com/a/45703104 + https://stackoverflow.com/a/23517227 + """ + import io + + boundary = "----MultiPartFormBoundary" + _random_string(16) + in_boundary = "--" + boundary + out_boundary = in_boundary + "--" + + buff = io.BytesIO() + + buff.write(in_boundary.encode()) + buff.write(b"\r\n") + buff.write(b'Content-Disposition: form-data; name="file"; filename="load.pkl"') + buff.write(b"\r\n") + buff.write(b"Content-Type: application/octet-stream") + buff.write(b"\r\n") + buff.write(b"\r\n") + with open(file, "rb") as data: + buff.write(data.read()) + buff.write(b"\r\n") + + buff.write(in_boundary.encode()) + buff.write(b"\r\n") + buff.write(b'Content-Disposition: form-data; name="resolve"') + buff.write(b"\r\n") + buff.write(b"\r\n") + buff.write(resolve.encode()) + buff.write(b"\r\n") + + buff.write(in_boundary.encode()) + buff.write(b"\r\n") + buff.write(b'Content-Disposition: form-data; name="name"') + buff.write(b"\r\n") + buff.write(b"\r\n") + buff.write(name.encode()) + buff.write(b"\r\n") + + buff.write(in_boundary.encode()) + buff.write(b"\r\n") + buff.write(b'Content-Disposition: form-data; name="version"') + buff.write(b"\r\n") + buff.write(b"\r\n") + buff.write(version.encode()) + buff.write(b"\r\n") + + buff.write(out_boundary.encode()) + buff.write(b"\r\n") + + headers = { + "Content-Type": f"multipart/form-data; boundary={boundary}", + "Content-Length": str(buff.tell()), + } + + return buff.getvalue(), headers + + +def test_dump_all(client, ephemeral_loaded_with_benchmarks, testing_helpers): + """Test simple call to /dump""" + with DumpContext(client) as ctx: + testing_helpers.assert_tested_db_structure(ctx.db) + + +def test_dump_one_experiment(client, ephemeral_loaded_with_benchmarks, testing_helpers): + 
"""Test dump only experiment test_single_exp (no version specified)""" + with DumpContext(client, "name=test_single_exp") as ctx: + # We must have dumped version 2 + testing_helpers.check_unique_import_test_single_expV2(ctx.db) + + +def test_dump_one_experiment_other_version( + client, ephemeral_loaded_with_benchmarks, testing_helpers +): + """Test dump version 1 of experiment test_single_exp""" + with DumpContext(client, "name=test_single_exp&version=1") as ctx: + testing_helpers.check_unique_import_test_single_expV1(ctx.db) + + +def test_dump_unknown_experiment(client, ephemeral_loaded_with_benchmarks): + """Test dump unknown experiment""" + response = client.simulate_get("/dump?name=unknown") + assert response.status == "404 Not Found" + assert response.json == { + "title": "DatabaseError", + "description": "No experiment found with query {'name': 'unknown'}. Nothing to dump.", + } + + +def _check_load_and_import_status( + pickled_client, headers, body, finished=True, latest_message=None +): + """Test both /load and /import-status""" + task = pickled_client.simulate_post("/load", headers=headers, body=body).json[ + "task" + ] + # Test /import-status by retrieving all task messages + messages = [] + while True: + progress = pickled_client.simulate_get(f"/import-status/{task}").json + messages.extend(progress["messages"]) + if progress["status"] != "active": + break + time.sleep(0.010) + # Check we have messages + assert messages + for message in messages if latest_message is None else messages[:-1]: + assert message.startswith("INFO:orion.storage.backup") + # Check final task status + if finished: + assert progress["status"] == "finished" + assert progress["progress_value"] == 1.0 + else: + assert progress["status"] == "error" + assert progress["progress_value"] < 1.0 + if latest_message is not None: + assert messages[-1] == latest_message + + +def test_load_all(pkl_experiments_and_benchmarks, testing_helpers, caplog): + """Test both /load and /import-status""" + with caplog.at_level(logging.INFO): + with LoadContext() as ctx: + # Make sure database is empty + testing_helpers.check_empty_db(ctx.db) + + # Generate body and header for request /load + body, headers = _gen_multipart_form_for_load( + pkl_experiments_and_benchmarks, "ignore" + ) + + # Test /load and /import-status 5 times with resolve=ignore + # to check if data are effectively ignored on conflict + for _ in range(5): + _check_load_and_import_status(ctx.pickled_client, headers, body) + # Check expected data in database + # Count should not change as data are ignored on conflict every time + testing_helpers.assert_tested_db_structure(ctx.db) + + +def test_load_one_experiment(pkl_experiments, testing_helpers, caplog): + """Test both /load and /import-status for one experiment""" + with caplog.at_level(logging.INFO): + with LoadContext() as ctx: + # Make sure database is empty + testing_helpers.check_empty_db(ctx.db) + + # Generate body and header for request /load + body, headers = _gen_multipart_form_for_load( + pkl_experiments, "ignore", "test_single_exp" + ) + + _check_load_and_import_status(ctx.pickled_client, headers, body) + + # Check expected data in database + # We must have loaded version 2 + testing_helpers.check_unique_import_test_single_expV2(ctx.db) + + +def test_load_one_experiment_other_version( + pkl_experiments_and_benchmarks, testing_helpers, caplog +): + """Test both /load and /import-status for one experiment with specific version""" + with caplog.at_level(logging.INFO): + with LoadContext() as ctx: + # Make 
sure database is empty + testing_helpers.check_empty_db(ctx.db) + + # Generate body and header for request /load + body, headers = _gen_multipart_form_for_load( + pkl_experiments_and_benchmarks, "ignore", "test_single_exp", "1" + ) + + _check_load_and_import_status(ctx.pickled_client, headers, body) + + # Check expected data in database + testing_helpers.check_unique_import_test_single_expV1(ctx.db) + + +def test_load_unknown_experiment(pkl_experiments, testing_helpers, caplog): + """Test both /load and /import-status for an unknown experiment""" + with caplog.at_level(logging.INFO): + with LoadContext() as ctx: + # Make sure database is empty + testing_helpers.check_empty_db(ctx.db) + + # Generate body and header for request /load + body, headers = _gen_multipart_form_for_load( + pkl_experiments, "ignore", "unknown" + ) + + _check_load_and_import_status( + ctx.pickled_client, + headers, + body, + finished=False, + latest_message="Error: No experiment found with query {'name': 'unknown'}. Nothing to import.", + ) + + # Check database (must be still empty) + testing_helpers.check_empty_db(ctx.db) + + +@pytest.mark.parametrize( + "log_level,expected_message_prefixes", + [ + (logging.WARNING, []), + ( + logging.INFO, + [ + "INFO:orion.storage.backup:Loaded src /tmp/", + "INFO:orion.storage.backup:Import experiment test_single_exp.1", + "INFO:orion.storage.backup:Import experiment test_single_exp.2", + "INFO:orion.storage.backup:Import experiment test_single_exp_child.1", + ], + ), + ], +) +def test_orion_serve_logging( + pkl_experiments_and_benchmarks, log_level, expected_message_prefixes, caplog +): + with caplog.at_level(log_level): + with LoadContext() as ctx: + # Generate body and header for request /load + body, headers = _gen_multipart_form_for_load( + pkl_experiments_and_benchmarks, "ignore" + ) + # Request /load and get import task ID + task = ctx.pickled_client.simulate_post( + "/load", headers=headers, body=body + ).json["task"] + # Collect task messages using request /import-status + messages = [] + while True: + progress = ctx.pickled_client.simulate_get( + f"/import-status/{task}" + ).json + messages.extend(progress["messages"]) + if progress["status"] != "active": + break + time.sleep(0.010) + + # Check messages + assert len(messages) == len(expected_message_prefixes) + for given_msg, expected_msg_prefix in zip( + messages, expected_message_prefixes + ): + assert given_msg.startswith(expected_msg_prefix)
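
The tests above drive ``/load`` and ``/import-status`` through Falcon's test client. For
context, here is a rough sketch of how an external client could exercise the same endpoints
over plain HTTP. The server address and the use of the ``requests`` library are assumptions
for illustration only; the form fields (``file``, ``resolve``, ``name``, ``version``) and the
response fields (``task``, ``status``, ``messages``, ``progress_value``) mirror those used in
the tests.

.. code-block:: python

    import time

    import requests

    BASE_URL = "http://127.0.0.1:8000"  # hypothetical address of a running Oríon server

    # Upload a PKL file to /load with a conflict-resolution policy.
    with open("dump.pkl", "rb") as f:
        response = requests.post(
            f"{BASE_URL}/load",
            files={"file": ("dump.pkl", f, "application/octet-stream")},
            data={"resolve": "ignore", "name": "", "version": ""},
        )
    task_id = response.json()["task"]

    # Poll /import-status until the import task leaves the "active" state.
    while True:
        progress = requests.get(f"{BASE_URL}/import-status/{task_id}").json()
        for message in progress["messages"]:
            print(message)
        if progress["status"] != "active":
            break
        time.sleep(0.5)

    print(progress["status"], progress["progress_value"])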