From 65a2e9229b3b2629a440b9cd609018c5f2030ac3 Mon Sep 17 00:00:00 2001 From: Gertjan Franken Date: Sun, 13 Oct 2024 13:00:47 +0200 Subject: [PATCH] Improve binary sequence algorithm (#31) * WIP: start implementation of more efficient sequence algorithm * WIP: add fixes and improvements * Fix tests * Fix issue where bgb search would not stop * Fix out of bounds * Update tooltip --- bci/database/mongo/mongodb.py | 194 +++++++-------- bci/distribution/worker_manager.py | 36 +-- bci/evaluations/{ => collectors}/collector.py | 7 +- bci/evaluations/custom/custom_evaluation.py | 21 +- bci/evaluations/evaluation_framework.py | 5 +- bci/evaluations/logic.py | 16 +- bci/evaluations/outcome_checker.py | 13 +- bci/master.py | 138 ++++------- bci/search_strategy/bgb_search.py | 86 +++++++ bci/search_strategy/bgb_sequence.py | 78 ++++++ bci/search_strategy/composite_search.py | 92 +------- bci/search_strategy/n_ary_search.py | 81 ------- bci/search_strategy/n_ary_sequence.py | 59 ----- bci/search_strategy/sequence_elem.py | 45 ---- bci/search_strategy/sequence_strategy.py | 119 ++++++---- bci/util.py | 18 +- bci/version_control/factory.py | 158 +++++++------ .../revision_parser/chromium_parser.py | 23 +- bci/version_control/revision_parser/parser.py | 6 +- bci/version_control/states/revisions/base.py | 156 +++++------- .../states/revisions/chromium.py | 46 ++-- .../states/revisions/firefox.py | 48 ++-- bci/version_control/states/state.py | 131 +++++----- bci/version_control/states/versions/base.py | 51 ++-- .../states/versions/chromium.py | 21 +- .../states/versions/firefox.py | 14 +- bci/web/clients.py | 40 ++-- bci/web/vue/src/App.vue | 14 +- bci/web/vue/src/components/tooltip.vue | 6 +- test/http_collector/test_collector.py | 2 +- .../test_biggest_gap_bisection_search.py | 81 +++++++ .../test_biggest_gap_bisection_sequence.py | 42 ++++ test/sequence/test_composite_search.py | 115 +++++---- test/sequence/test_search_strategy.py | 170 ------------- test/sequence/test_sequence_strategy.py | 223 ++++++++---------- 35 files changed, 1106 insertions(+), 1249 deletions(-) rename bci/evaluations/{ => collectors}/collector.py (92%) create mode 100644 bci/search_strategy/bgb_search.py create mode 100644 bci/search_strategy/bgb_sequence.py delete mode 100644 bci/search_strategy/n_ary_search.py delete mode 100644 bci/search_strategy/n_ary_sequence.py delete mode 100644 bci/search_strategy/sequence_elem.py create mode 100644 test/sequence/test_biggest_gap_bisection_search.py create mode 100644 test/sequence/test_biggest_gap_bisection_sequence.py delete mode 100644 test/sequence/test_search_strategy.py diff --git a/bci/database/mongo/mongodb.py b/bci/database/mongo/mongodb.py index 7918776..3df380e 100644 --- a/bci/database/mongo/mongodb.py +++ b/bci/database/mongo/mongodb.py @@ -9,10 +9,17 @@ from pymongo.collection import Collection from pymongo.errors import ServerSelectionTimeoutError -from bci.evaluations.logic import (DatabaseConnectionParameters, - PlotParameters, TestParameters, TestResult, - WorkerParameters) -from bci.version_control.states.state import State +from bci.evaluations.logic import ( + DatabaseConnectionParameters, + EvaluationParameters, + PlotParameters, + StateResult, + TestParameters, + TestResult, + WorkerParameters, +) +from bci.evaluations.outcome_checker import OutcomeChecker +from bci.version_control.states.state import State, StateCondition logger = logging.getLogger(__name__) @@ -25,8 +32,8 @@ class MongoDB(ABC): instance = None binary_availability_collection_names = { - "chromium": "chromium_binary_availability", - "firefox": "firefox_central_binary_availability" + 'chromium': 'chromium_binary_availability', + 'firefox': 'firefox_central_binary_availability', } def __init__(self): @@ -51,14 +58,15 @@ def connect(db_connection_params: DatabaseConnectionParameters): password=db_connection_params.password, authsource=db_connection_params.database_name, retryWrites=False, - serverSelectionTimeoutMS=10000) + serverSelectionTimeoutMS=10000, + ) # Force connection to check whether MongoDB server is reachable try: CLIENT.server_info() DB = CLIENT[db_connection_params.database_name] - logger.info("Connected to database!") + logger.info('Connected to database!') except ServerSelectionTimeoutError as e: - logger.info("A timeout occurred while attempting to establish connection.", exc_info=True) + logger.info('A timeout occurred while attempting to establish connection.', exc_info=True) raise ServerException from e # Initialize collections @@ -73,16 +81,13 @@ def disconnect(): @staticmethod def __initialize_collections(): - for collection_name in [ - 'chromium_binary_availability', - 'firefox_central_binary_availability' - ]: + for collection_name in ['chromium_binary_availability', 'firefox_central_binary_availability']: if collection_name not in DB.list_collection_names(): DB.create_collection(collection_name) def get_collection(self, name: str): if name not in DB.list_collection_names(): - logger.info(f'Collection \'{name}\' does not exist, creating it...') + logger.info(f"Collection '{name}' does not exist, creating it...") DB.create_collection(name) return DB[name] @@ -102,18 +107,18 @@ def store_result(self, result: TestResult): 'mech_group': result.params.mech_group, 'results': result.data, 'dirty': result.is_dirty, - 'ts': str(datetime.now(timezone.utc).replace(microsecond=0)) + 'ts': str(datetime.now(timezone.utc).replace(microsecond=0)), } if result.driver_version: - document["driver_version"] = result.driver_version + document['driver_version'] = result.driver_version - if browser_config.browser_name == "firefox": + if browser_config.browser_name == 'firefox': build_id = self.get_build_id_firefox(result.params.state) if build_id is None: - document["artisanal"] = True - document["build_id"] = "artisanal" + document['artisanal'] = True + document['build_id'] = 'artisanal' else: - document["build_id"] = build_id + document['build_id'] = build_id collection.insert_one(document) @@ -123,10 +128,7 @@ def get_result(self, params: TestParameters) -> TestResult: document = collection.find_one(query) if document: return params.create_test_result_with( - document['browser_version'], - document['binary_origin'], - document['results'], - document['dirty'] + document['browser_version'], document['binary_origin'], document['results'], document['dirty'] ) else: logger.error(f'Could not find document for query {query}') @@ -143,24 +145,66 @@ def has_all_results(self, params: WorkerParameters) -> bool: return False return True + def get_evaluated_states( + self, params: EvaluationParameters, boundary_states: tuple[State, State], outcome_checker: OutcomeChecker + ) -> list[State]: + collection = self.get_collection(params.database_collection) + query = { + 'browser_config': params.browser_configuration.browser_setting, + 'mech_group': params.evaluation_range.mech_groups[0], # TODO: fix this + 'state.browser_name': params.browser_configuration.browser_name, + 'results': {'$exists': True}, + 'state.type': 'version' if params.evaluation_range.only_release_revisions else 'revision', + 'state.revision_number': { + '$gte': boundary_states[0].revision_nb, + '$lte': boundary_states[1].revision_nb, + }, + } + if params.browser_configuration.extensions: + query['extensions'] = { + '$size': len(params.browser_configuration.extensions), + '$all': params.browser_configuration.extensions, + } + else: + query['extensions'] = [] + if params.browser_configuration.cli_options: + query['cli_options'] = { + '$size': len(params.browser_configuration.cli_options), + '$all': params.browser_configuration.cli_options, + } + else: + query['cli_options'] = [] + cursor = collection.find(query) + states = [] + for doc in cursor: + state = State.from_dict(doc['state']) + state.result = StateResult.from_dict(doc['results'], is_dirty=doc['dirty']) + state.outcome = outcome_checker.get_outcome(state.result) + if doc['dirty']: + state.condition = StateCondition.FAILED + else: + state.condition = StateCondition.COMPLETED + states.append(state) + return states + def __to_query(self, params: TestParameters) -> dict: query = { 'state': params.state.to_dict(), 'browser_automation': params.evaluation_configuration.automation, 'browser_config': params.browser_configuration.browser_setting, - 'mech_group': params.mech_group + 'mech_group': params.mech_group, } if len(params.browser_configuration.extensions) > 0: query['extensions'] = { '$size': len(params.browser_configuration.extensions), - '$all': params.browser_configuration.extensions + '$all': params.browser_configuration.extensions, } else: query['extensions'] = [] if len(params.browser_configuration.cli_options) > 0: query['cli_options'] = { '$size': len(params.browser_configuration.cli_options), - '$all': params.browser_configuration.cli_options + '$all': params.browser_configuration.cli_options, } else: query['cli_options'] = [] @@ -184,34 +228,30 @@ def get_binary_availability_collection(browser_name: str): @staticmethod def has_binary_available_online(browser: str, state: State): collection = MongoDB.get_binary_availability_collection(browser) - document = collection.find_one({'state': state.to_dict(make_complete=False)}) + document = collection.find_one({'state': state.to_dict()}) if document is None: return None - return document["binary_online"] + return document['binary_online'] @staticmethod def get_stored_binary_availability(browser): collection = MongoDB.get_binary_availability_collection(browser) result = collection.find( + {'binary_online': True}, { - "binary_online": True + '_id': False, + 'state': True, }, - { - "_id": False, - "state": True, - } ) - if browser == "firefox": + if browser == 'firefox': result.sort('build_id', -1) return result @staticmethod - def get_complete_state_dict_from_binary_availability_cache(state: State): + def get_complete_state_dict_from_binary_availability_cache(state: State) -> dict: collection = MongoDB.get_binary_availability_collection(state.browser_name) # We have to flatten the state dictionary to ignore missing attributes. - state_dict = { - 'state': state.to_dict(make_complete=False) - } + state_dict = {'state': state.to_dict()} query = flatten(state_dict, reducer='dot') document = collection.find_one(query) if document is None: @@ -222,100 +262,68 @@ def get_complete_state_dict_from_binary_availability_cache(state: State): def store_binary_availability_online_cache(browser: str, state: State, binary_online: bool, url: str = None): collection = MongoDB.get_binary_availability_collection(browser) collection.update_one( + {'state': state.to_dict()}, { - 'state': state.to_dict() - }, - { - "$set": - { + '$set': { 'state': state.to_dict(), 'binary_online': binary_online, 'url': url, - 'ts': str(datetime.now(timezone.utc).replace(microsecond=0)) + 'ts': str(datetime.now(timezone.utc).replace(microsecond=0)), } }, - upsert=True + upsert=True, ) @staticmethod def get_build_id_firefox(state: State): - collection = MongoDB.get_binary_availability_collection("firefox") - - result = collection.find_one({ - "state": state.to_dict() - }, { - "_id": False, - "build_id": 1 - }) + collection = MongoDB.get_binary_availability_collection('firefox') + + result = collection.find_one({'state': state.to_dict()}, {'_id': False, 'build_id': 1}) # Result can only be None if the binary associated with the state_id is artisanal: # This state_id will not be included in the binary_availability_collection and not have a build_id. if result is None or len(result) == 0: return None - return result["build_id"] + return result['build_id'] def get_documents_for_plotting(self, params: PlotParameters, releases: bool = False): collection = self.get_collection(params.database_collection) query = { 'mech_group': params.mech_group, 'browser_config': params.browser_config, - 'state.type': 'version' if releases else 'revision' - } - query['extensions'] = { - '$size': len(params.extensions) if params.extensions else 0 + 'state.type': 'version' if releases else 'revision', } + query['extensions'] = {'$size': len(params.extensions) if params.extensions else 0} if params.extensions: query['extensions']['$all'] = params.extensions - query['cli_options'] = { - '$size': len(params.cli_options) if params.cli_options else 0 - } + query['cli_options'] = {'$size': len(params.cli_options) if params.cli_options else 0} if params.cli_options: query['cli_options']['$all'] = params.cli_options if params.revision_number_range: query['state.revision_number'] = { '$gte': params.revision_number_range[0], - '$lte': params.revision_number_range[1] + '$lte': params.revision_number_range[1], } elif params.major_version_range: query['padded_browser_version'] = { '$gte': str(params.major_version_range[0]).zfill(4), - '$lte': str(params.major_version_range[1] + 1).zfill(4) + '$lte': str(params.major_version_range[1] + 1).zfill(4), } - docs = collection.aggregate([ - { - '$match': query - }, - { - '$project': { - '_id': False, - 'state': True, - 'browser_version': True, - 'dirty': True, - 'results': True - } - }, - { - '$sort': { - 'rev_nb': 1 - } - } - ]) + docs = collection.aggregate( + [ + {'$match': query}, + {'$project': {'_id': False, 'state': True, 'browser_version': True, 'dirty': True, 'results': True}}, + {'$sort': {'rev_nb': 1}}, + ] + ) return list(docs) @staticmethod def get_info() -> dict: if CLIENT and CLIENT.address: - return { - 'type': 'mongo', - 'host': CLIENT.address[0], - 'connected': True - } + return {'type': 'mongo', 'host': CLIENT.address[0], 'connected': True} else: - return { - 'type': 'mongo', - 'host': None, - 'connected': False - } + return {'type': 'mongo', 'host': None, 'connected': False} class ServerException(Exception): diff --git a/bci/distribution/worker_manager.py b/bci/distribution/worker_manager.py index feb35f7..778deca 100644 --- a/bci/distribution/worker_manager.py +++ b/bci/distribution/worker_manager.py @@ -3,7 +3,6 @@ import threading import time from queue import Queue -from typing import Callable import docker import docker.errors @@ -11,8 +10,9 @@ from bci import worker from bci.configuration import Global from bci.evaluations.logic import WorkerParameters +from bci.web.clients import Clients -logger = logging.getLogger('bci') +logger = logging.getLogger(__name__) class WorkerManager: @@ -27,19 +27,15 @@ def __init__(self, max_nb_of_containers: int) -> None: self.container_id_pool.put(i) self.client = docker.from_env() - def start_test(self, params: WorkerParameters, cb: Callable, blocking_wait=True) -> None: + def start_test(self, params: WorkerParameters, blocking_wait=True) -> None: if self.max_nb_of_containers != 1: - return self.__run_container(params, cb, blocking_wait) + return self.__run_container(params, blocking_wait) # Single container mode worker.run(params) - cb() - def __run_container(self, params: WorkerParameters, cb: Callable, blocking_wait=True) -> None: - while ( - blocking_wait - and self.get_nb_of_running_worker_containers() >= self.max_nb_of_containers - ): + def __run_container(self, params: WorkerParameters, blocking_wait=True) -> None: + while blocking_wait and self.get_nb_of_running_worker_containers() >= self.max_nb_of_containers: time.sleep(5) container_id = self.container_id_pool.get() container_name = f'bh_worker_{container_id}' @@ -54,7 +50,7 @@ def start_container_thread(): ignore_removed=True, filters={ 'name': f'^/{container_name}$' # The exact name has to match - } + }, ) # Break loop if no container with same name is active if not active_containers: @@ -76,8 +72,10 @@ def start_container_thread(): command=[params.serialize()], volumes=[ os.path.join(os.getenv('HOST_PWD'), 'config') + ':/app/config:ro', - os.path.join(os.getenv('HOST_PWD'), 'browser/binaries/chromium/artisanal') + ':/app/browser/binaries/chromium/artisanal:rw', - os.path.join(os.getenv('HOST_PWD'), 'browser/binaries/firefox/artisanal') + ':/app/browser/binaries/firefox/artisanal:rw', + os.path.join(os.getenv('HOST_PWD'), 'browser/binaries/chromium/artisanal') + + ':/app/browser/binaries/chromium/artisanal:rw', + os.path.join(os.getenv('HOST_PWD'), 'browser/binaries/firefox/artisanal') + + ':/app/browser/binaries/firefox/artisanal:rw', os.path.join(os.getenv('HOST_PWD'), 'experiments') + ':/app/experiments:ro', os.path.join(os.getenv('HOST_PWD'), 'browser/extensions') + ':/app/browser/extensions:ro', os.path.join(os.getenv('HOST_PWD'), 'logs') + ':/app/logs:rw', @@ -85,18 +83,20 @@ def start_container_thread(): '/dev/shm:/dev/shm', ], ) - logger.debug(f'Container \'{container_name}\' finished experiments with parameters \'{repr(params)}\'') - cb() + logger.debug(f"Container '{container_name}' finished experiments with parameters '{repr(params)}'") + Clients.push_results_to_all() except docker.errors.APIError: - logger.error(f'Could not run container \'{container_name}\' or container was unexpectedly removed', exc_info=True) + logger.error( + f"Could not run container '{container_name}' or container was unexpectedly removed", exc_info=True + ) finally: self.container_id_pool.put(container_id) thread = threading.Thread(target=start_container_thread) thread.start() - logger.info(f'Container \'{container_name}\' started experiments for \'{params.state}\'') + logger.info(f"Container '{container_name}' started experiments for '{params.state}'") # To avoid race-condition where more than max containers are started - time.sleep(5) + time.sleep(3) def get_nb_of_running_worker_containers(self): return len(self.get_runnning_containers()) diff --git a/bci/evaluations/collector.py b/bci/evaluations/collectors/collector.py similarity index 92% rename from bci/evaluations/collector.py rename to bci/evaluations/collectors/collector.py index c9e3278..e59f97f 100644 --- a/bci/evaluations/collector.py +++ b/bci/evaluations/collectors/collector.py @@ -1,11 +1,11 @@ +import logging from abc import abstractmethod from enum import Enum -import logging from bci.evaluations.collectors.base import BaseCollector -from .collectors.requests import RequestCollector -from .collectors.logs import LogCollector +from .logs import LogCollector +from .requests import RequestCollector logger = logging.getLogger(__name__) @@ -16,7 +16,6 @@ class Type(Enum): class Collector: - def __init__(self, types: list[Type]) -> None: self.collectors: list[BaseCollector] = [] if Type.REQUESTS in types: diff --git a/bci/evaluations/custom/custom_evaluation.py b/bci/evaluations/custom/custom_evaluation.py index e85237c..338dcb0 100644 --- a/bci/evaluations/custom/custom_evaluation.py +++ b/bci/evaluations/custom/custom_evaluation.py @@ -1,14 +1,13 @@ import logging import os import textwrap -from unittest import TestResult from bci.browser.configuration.browser import Browser from bci.configuration import Global -from bci.evaluations.collector import Collector, Type +from bci.evaluations.collectors.collector import Collector, Type from bci.evaluations.custom.custom_mongodb import CustomMongoDB from bci.evaluations.evaluation_framework import EvaluationFramework -from bci.evaluations.logic import TestParameters +from bci.evaluations.logic import TestParameters, TestResult from bci.web.clients import Clients logger = logging.getLogger(__name__) @@ -98,17 +97,17 @@ def perform_specific_evaluation(self, browser: Browser, params: TestParameters) is_dirty = True finally: collector.stop() - data = collector.collect_results() + results = collector.collect_results() if not is_dirty: # New way to perform sanity check - if [var_entry for var_entry in data['req_vars'] if var_entry['var'] == 'sanity_check' and var_entry['val'] == 'OK']: + if [var_entry for var_entry in results['req_vars'] if var_entry['var'] == 'sanity_check' and var_entry['val'] == 'OK']: pass # Old way for backwards compatibility - elif [request for request in data['requests'] if 'report/?leak=baseline' in request['url']]: + elif [request for request in results['requests'] if 'report/?leak=baseline' in request['url']]: pass else: is_dirty = True - return params.create_test_result_with(browser_version, binary_origin, data, is_dirty) + return params.create_test_result_with(browser_version, binary_origin, results, is_dirty) def get_mech_groups(self, project: str) -> list[tuple[str, bool]]: if project not in self.tests_per_project: @@ -122,14 +121,14 @@ def get_projects(self) -> list[str]: def get_poc_structure(self, project: str, poc: str) -> dict: return self.dir_tree[project][poc] - def get_poc_file(self, project: str, poc: str, domain: str, path: str, file: str) -> str: - file_path = os.path.join(Global.custom_page_folder, project, poc, domain, path, file) + def get_poc_file(self, project: str, poc: str, domain: str, path: str, file_name: str) -> str: + file_path = os.path.join(Global.custom_page_folder, project, poc, domain, path, file_name) if os.path.isfile(file_path): with open(file_path) as file: return file.read() - def update_poc_file(self, project: str, poc: str, domain: str, path: str, file: str, content: str) -> bool: - file_path = os.path.join(Global.custom_page_folder, project, poc, domain, path, file) + def update_poc_file(self, project: str, poc: str, domain: str, path: str, file_name: str, content: str) -> bool: + file_path = os.path.join(Global.custom_page_folder, project, poc, domain, path, file_name) if os.path.isfile(file_path): if content == '': logger.warning('Attempt to save empty file ignored') diff --git a/bci/evaluations/evaluation_framework.py b/bci/evaluations/evaluation_framework.py index 018900c..e33dbb5 100644 --- a/bci/evaluations/evaluation_framework.py +++ b/bci/evaluations/evaluation_framework.py @@ -7,6 +7,7 @@ from bci.configuration import Global from bci.database.mongo.mongodb import MongoDB from bci.evaluations.logic import TestParameters, TestResult, WorkerParameters +from bci.version_control.states.state import StateCondition logger = logging.getLogger(__name__) @@ -39,12 +40,10 @@ def evaluate(self, worker_params: WorkerParameters): try: browser.pre_test_setup() result = self.perform_specific_evaluation(browser, test_params) - - state.set_evaluation_outcome(result) self.db_class.get_instance().store_result(result) logger.info(f'Test finalized: {test_params}') except Exception as e: - state.set_evaluation_error(str(e)) + state.condition = StateCondition.FAILED logger.error("An error occurred during evaluation", exc_info=True) traceback.print_exc() finally: diff --git a/bci/evaluations/logic.py b/bci/evaluations/logic.py index 25d51f6..0188cea 100644 --- a/bci/evaluations/logic.py +++ b/bci/evaluations/logic.py @@ -8,9 +8,9 @@ import bci.browser.cli_options.chromium as cli_options_chromium import bci.browser.cli_options.firefox as cli_options_firefox -from bci.version_control.states.state import State +from bci.version_control.states.state import State, StateResult -logger = logging.getLogger('bci') +logger = logging.getLogger(__name__) @dataclass(frozen=True) @@ -99,8 +99,8 @@ def from_dict(data: dict) -> EvaluationConfiguration: @dataclass(frozen=True) class EvaluationRange: mech_groups: list[str] - major_version_range: tuple[int] | None = None - revision_number_range: tuple[int] | None = None + major_version_range: tuple[int, int] | None = None + revision_number_range: tuple[int, int] | None = None only_release_revisions: bool = False def __post_init__(self): @@ -251,12 +251,8 @@ def padded_browser_version(self): padded_version.append('0' * (padding_target - len(sub)) + sub) return ".".join(padded_version) - @property - def reproduced(self): - entry_if_reproduced = {'var': 'reproduced', 'val': 'OK'} - reproduced_in_req_vars = [entry for entry in self.data['req_vars'] if entry == entry_if_reproduced] != [] - reproduced_in_log_vars = [entry for entry in self.data['log_vars'] if entry == entry_if_reproduced] != [] - return reproduced_in_req_vars or reproduced_in_log_vars + def get_state_result(self) -> StateResult: + return StateResult.from_dict(self.data, self.is_dirty) @dataclass(frozen=True) diff --git a/bci/evaluations/outcome_checker.py b/bci/evaluations/outcome_checker.py index 54c93db..98fc859 100644 --- a/bci/evaluations/outcome_checker.py +++ b/bci/evaluations/outcome_checker.py @@ -1,7 +1,7 @@ import re -from abc import abstractmethod -from bci.evaluations.logic import SequenceConfiguration, TestResult +from bci.evaluations.logic import SequenceConfiguration +from bci.version_control.states.state import StateResult class OutcomeChecker: @@ -9,8 +9,7 @@ class OutcomeChecker: def __init__(self, sequence_config: SequenceConfiguration): self.sequence_config = sequence_config - @abstractmethod - def get_outcome(self, result: TestResult) -> bool: + def get_outcome(self, result: StateResult) -> bool | None: ''' Returns the outcome of the test result. @@ -24,12 +23,12 @@ def get_outcome(self, result: TestResult) -> bool: return True # Backwards compatibility if self.sequence_config.target_mech_id: - return self.get_outcome_for_proxy(result) + return self.__get_outcome_for_proxy(result) - def get_outcome_for_proxy(self, result: TestResult) -> bool | None: + def __get_outcome_for_proxy(self, result: StateResult) -> bool | None: target_mech_id = self.sequence_config.target_mech_id target_cookie = self.sequence_config.target_cookie_name - requests = result.data.get('requests') + requests = result.requests if requests is None: return None # DISCLAIMER: diff --git a/bci/master.py b/bci/master.py index 2f12751..a842610 100644 --- a/bci/master.py +++ b/bci/master.py @@ -9,31 +9,23 @@ from bci.evaluations.logic import ( DatabaseConnectionParameters, EvaluationParameters, - SequenceConfiguration, - WorkerParameters, ) from bci.evaluations.outcome_checker import OutcomeChecker from bci.evaluations.samesite.samesite_evaluation import SameSiteEvaluationFramework from bci.evaluations.xsleaks.evaluation import XSLeaksEvaluation +from bci.search_strategy.bgb_search import BiggestGapBisectionSearch +from bci.search_strategy.bgb_sequence import BiggestGapBisectionSequence from bci.search_strategy.composite_search import CompositeSearch -from bci.search_strategy.n_ary_search import NArySearch -from bci.search_strategy.n_ary_sequence import NArySequence, SequenceFinished -from bci.search_strategy.sequence_strategy import SequenceStrategy -from bci.version_control import factory -from bci.version_control.states.state import State +from bci.search_strategy.sequence_strategy import SequenceFinished, SequenceStrategy +from bci.version_control.factory import StateFactory from bci.web.clients import Clients logger = logging.getLogger(__name__) class Master: - def __init__(self): - self.state = { - 'is_running': False, - 'reason': 'init', - 'status': 'idle' - } + self.state = {'is_running': False, 'reason': 'init', 'status': 'idle'} self.stop_gracefully = False self.stop_forcefully = False @@ -50,20 +42,16 @@ def __init__(self): self.db_connection_params = Global.get_database_connection_params() self.connect_to_database(self.db_connection_params) self.inititialize_available_evaluation_frameworks() - logger.info("BugHog is ready!") + logger.info('BugHog is ready!') def connect_to_database(self, db_connection_params: DatabaseConnectionParameters): try: MongoDB.connect(db_connection_params) except ServerException: - logger.error("Could not connect to database.", exc_info=True) + logger.error('Could not connect to database.', exc_info=True) def run(self, eval_params: EvaluationParameters): - self.state = { - 'is_running': True, - 'reason': 'user', - 'status': 'running' - } + self.state = {'is_running': True, 'reason': 'user', 'status': 'running'} self.stop_gracefully = False self.stop_forcefully = False @@ -74,94 +62,74 @@ def run(self, eval_params: EvaluationParameters): evaluation_range = eval_params.evaluation_range sequence_config = eval_params.sequence_configuration - logger.info(f'Running experiments for {browser_config.browser_name} ({", ".join(evaluation_range.mech_groups)})') - self.evaluation_framework = self.get_specific_evaluation_framework( - evaluation_config.project + logger.info( + f'Running experiments for {browser_config.browser_name} ({", ".join(evaluation_range.mech_groups)})' ) + self.evaluation_framework = self.get_specific_evaluation_framework(evaluation_config.project) self.worker_manager = WorkerManager(sequence_config.nb_of_containers) try: - state_list = factory.create_state_collection(browser_config, evaluation_range) - - search_strategy = self.parse_search_strategy(sequence_config, state_list) - - outcome_checker = OutcomeChecker(sequence_config) - - # The state_lineage is put into self.evaluation as a means to check on the process through front-end - # self.evaluations.append(state_list) + search_strategy = self.create_sequence_strategy(eval_params) try: - current_state = search_strategy.next() while (self.stop_gracefully or self.stop_forcefully) is False: - worker_params = eval_params.create_worker_params_for(current_state, self.db_connection_params) - - # Callback function for sequence strategy - update_outcome = self.get_update_outcome_cb(search_strategy, worker_params, sequence_config, outcome_checker) + # Update search strategy with new potentially new results + current_state = search_strategy.next() - # Check whether state is already evaluated - if self.evaluation_framework.has_all_results(worker_params): - logger.info(f"'{current_state}' already evaluated.") - update_outcome() - current_state = search_strategy.next() - continue + # Prepare worker parameters + worker_params = eval_params.create_worker_params_for(current_state, self.db_connection_params) # Start worker to perform evaluation - self.worker_manager.start_test(worker_params, update_outcome) + self.worker_manager.start_test(worker_params) - current_state = search_strategy.next() except SequenceFinished: - logger.debug("Last experiment has started") + logger.debug('Last experiment has started') self.state['reason'] = 'finished' except Exception as e: - logger.critical("A critical error occurred", exc_info=True) + logger.critical('A critical error occurred', exc_info=True) raise e finally: # Gracefully exit if self.stop_gracefully: - logger.info("Gracefully stopping experiment queue due to user end signal...") + logger.info('Gracefully stopping experiment queue due to user end signal...') self.state['reason'] = 'user' if self.stop_forcefully: - logger.info("Forcefully stopping experiment queue due to user end signal...") + logger.info('Forcefully stopping experiment queue due to user end signal...') self.state['reason'] = 'user' self.worker_manager.forcefully_stop_all_running_containers() else: - logger.info("Gracefully stopping experiment queue since last experiment started.") + logger.info('Gracefully stopping experiment queue since last experiment started.') # MongoDB.disconnect() - logger.info("Waiting for remaining experiments to stop...") + logger.info('Waiting for remaining experiments to stop...') self.worker_manager.wait_until_all_evaluations_are_done() - logger.info("BugHog has finished the evaluation!") + logger.info('BugHog has finished the evaluation!') self.state['is_running'] = False self.state['status'] = 'idle' Clients.push_info_to_all('is_running', 'state') - @staticmethod - def get_update_outcome_cb(search_strategy: SequenceStrategy, worker_params: WorkerParameters, sequence_config: SequenceConfiguration, checker: OutcomeChecker) -> None: - def cb(): - if sequence_config.target_mech_id is not None and len(worker_params.mech_groups) == 1: - result = MongoDB.get_instance().get_result(worker_params.create_test_params_for(worker_params.mech_groups[0])) - outcome = checker.get_outcome(result) - search_strategy.update_outcome(worker_params.state, outcome) - # Just push results update to all clients. Could be more efficient, but would complicate things... - Clients.push_results_to_all() - return cb - def inititialize_available_evaluation_frameworks(self): - self.available_evaluation_frameworks["samesite"] = SameSiteEvaluationFramework() - self.available_evaluation_frameworks["custom"] = CustomEvaluationFramework() - self.available_evaluation_frameworks["xsleaks"] = XSLeaksEvaluation() + self.available_evaluation_frameworks['samesite'] = SameSiteEvaluationFramework() + self.available_evaluation_frameworks['custom'] = CustomEvaluationFramework() + self.available_evaluation_frameworks['xsleaks'] = XSLeaksEvaluation() @staticmethod - def parse_search_strategy(sequence_config: SequenceConfiguration, state_list: list[State]): + def create_sequence_strategy(eval_params: EvaluationParameters) -> SequenceStrategy: + sequence_config = eval_params.sequence_configuration search_strategy = sequence_config.search_strategy sequence_limit = sequence_config.sequence_limit - if search_strategy == "bin_seq": - return NArySequence(state_list, 2, limit=sequence_limit) - if search_strategy == "bin_search": - return NArySearch(state_list, 2) - if search_strategy == "comp_search": - return CompositeSearch(state_list, 2, sequence_limit, NArySequence, NArySearch) - raise AttributeError("Unknown search strategy option '%s'" % search_strategy) + outcome_checker = OutcomeChecker(sequence_config) + state_factory = StateFactory(eval_params, outcome_checker) + + if search_strategy == 'bgb_sequence': + strategy = BiggestGapBisectionSequence(state_factory, sequence_limit) + elif search_strategy == 'bgb_search': + strategy = BiggestGapBisectionSearch(state_factory) + elif search_strategy == 'comp_search': + strategy = CompositeSearch(state_factory, sequence_limit) + else: + raise AttributeError("Unknown search strategy option '%s'" % search_strategy) + return strategy def get_specific_evaluation_framework(self, evaluation_name: str) -> EvaluationFramework: # TODO: we always use 'custom', in which evaluation_name is a project @@ -173,36 +141,28 @@ def get_specific_evaluation_framework(self, evaluation_name: str) -> EvaluationF def activate_stop_gracefully(self): if self.evaluation_framework: self.stop_gracefully = True - self.state = { - 'is_running': True, - 'reason': 'user', - 'status': 'waiting_to_stop' - } + self.state = {'is_running': True, 'reason': 'user', 'status': 'waiting_to_stop'} Clients.push_info_to_all('state') self.evaluation_framework.stop_gracefully() - logger.info("Received user signal to gracefully stop.") + logger.info('Received user signal to gracefully stop.') else: - logger.info("Received user signal to gracefully stop, but no evaluation is running.") + logger.info('Received user signal to gracefully stop, but no evaluation is running.') def activate_stop_forcefully(self): if self.evaluation_framework: self.stop_forcefully = True - self.state = { - 'is_running': True, - 'reason': 'user', - 'status': 'waiting_to_stop' - } + self.state = {'is_running': True, 'reason': 'user', 'status': 'waiting_to_stop'} Clients.push_info_to_all('state') self.evaluation_framework.stop_gracefully() if self.worker_manager: self.worker_manager.forcefully_stop_all_running_containers() - logger.info("Received user signal to forcefully stop.") + logger.info('Received user signal to forcefully stop.') else: - logger.info("Received user signal to forcefully stop, but no evaluation is running.") + logger.info('Received user signal to forcefully stop, but no evaluation is running.') def stop_bughog(self): - logger.info("Stopping all running BugHog containers...") + logger.info('Stopping all running BugHog containers...') self.activate_stop_forcefully() mongodb_container.stop() - logger.info("Stopping BugHog core...") + logger.info('Stopping BugHog core...') exit(0) diff --git a/bci/search_strategy/bgb_search.py b/bci/search_strategy/bgb_search.py new file mode 100644 index 0000000..45043dc --- /dev/null +++ b/bci/search_strategy/bgb_search.py @@ -0,0 +1,86 @@ +import logging +from typing import Optional + +from bci.search_strategy.bgb_sequence import BiggestGapBisectionSequence +from bci.search_strategy.sequence_strategy import SequenceFinished +from bci.version_control.factory import StateFactory +from bci.version_control.states.state import State + +logger = logging.getLogger(__name__) + + +class BiggestGapBisectionSearch(BiggestGapBisectionSequence): + """ + This search strategy will split the biggest gap between two states in half and return the state in the middle. + It will only consider states where the non-None outcome differs. + It stops when there are no more states to evaluate between two states with different outcomes. + """ + + def __init__(self, state_factory: StateFactory) -> None: + """ + Initializes the search strategy. + + :param state_factory: The factory to create new states. + """ + super().__init__(state_factory, 0) + + def next(self) -> State: + """ + Returns the next state to evaluate. + """ + # Fetch all evaluated states + self._fetch_evaluated_states() + + if self._limit and self._limit <= len(self._completed_states): + raise SequenceFinished() + + if self._lower_state not in self._completed_states: + self._add_state(self._lower_state) + return self._lower_state + if self._upper_state not in self._completed_states: + self._add_state(self._upper_state) + return self._upper_state + + while next_pair := self.__get_next_pair_to_split(): + splitter_state = self._find_best_splitter_state(next_pair[0], next_pair[1]) + if splitter_state is None: + self._unavailability_gap_pairs.add(next_pair) + if splitter_state: + logger.debug(f'Splitting [{next_pair[0].index}]--/{splitter_state.index}/--[{next_pair[1].index}]') + self._add_state(splitter_state) + return splitter_state + raise SequenceFinished() + + def __get_next_pair_to_split(self) -> Optional[tuple[State, State]]: + """ + Returns the next pair of states to split. + """ + # Make pairwise list of states and remove pairs with the same outcome + states = self._completed_states + pairs = [(state1, state2) for state1, state2 in zip(states, states[1:]) if state1.outcome != state2.outcome] + # Remove the first and last pair if they have a first and last state with a None outcome, respectively + if pairs[0][0].outcome is None: + pairs = pairs[1:] + if pairs[-1][1].outcome is None: + pairs = pairs[:-1] + # Remove all pairs that have already been identified as unavailability gaps + pairs = [pair for pair in pairs if pair not in self._unavailability_gap_pairs] + # Remove any pair where the same None-outcome state is present in a pair where the sibling states have the same outcome + pairs_with_failed = [pair for pair in pairs if pair[0].outcome is None or pair[1].outcome is None] + for i in range(0, len(pairs_with_failed), 2): + if i + 1 >= len(pairs_with_failed): + break + first_pair = pairs_with_failed[i] + second_pair = pairs_with_failed[i + 1] + if first_pair[0].outcome == second_pair[1].outcome: + pairs.remove(first_pair) + pairs.remove(second_pair) + + if not pairs: + return None + # Sort pairs to prioritize pairs with bigger gaps. + # This way, we refrain from pinpointing pair-by-pair, making the search more efficient. + # E.g., when the splitter of the first gap is being evaluated, we can already evaluate the + # splitter of the second gap with having to wait for the first gap to be fully evaluated. + pairs.sort(key=lambda pair: pair[1].index - pair[0].index, reverse=True) + return pairs[0] diff --git a/bci/search_strategy/bgb_sequence.py b/bci/search_strategy/bgb_sequence.py new file mode 100644 index 0000000..b9a88d1 --- /dev/null +++ b/bci/search_strategy/bgb_sequence.py @@ -0,0 +1,78 @@ +import logging +from typing import Optional + +from bci.search_strategy.sequence_strategy import SequenceFinished, SequenceStrategy +from bci.version_control.factory import StateFactory +from bci.version_control.states.state import State + +logger = logging.getLogger(__name__) + + +class BiggestGapBisectionSequence(SequenceStrategy): + """ + This sequence strategy will split the biggest gap between two states in half and return the state in the middle. + """ + + def __init__(self, state_factory: StateFactory, limit: int) -> None: + """ + Initializes the sequence strategy. + + :param state_factory: The factory to create new states. + :param limit: The maximum number of states to evaluate. 0 means no limit. + """ + super().__init__(state_factory, limit) + self._unavailability_gap_pairs: set[tuple[State, State]] = set() + """Tuples in this list are **strict** boundaries of ranges without any available binaries.""" + + def next(self) -> State: + """ + Returns the next state to evaluate. + """ + # Fetch all evaluated states on the first call + if not self._completed_states: + self._fetch_evaluated_states() + + if self._limit and self._limit <= len(self._completed_states): + raise SequenceFinished() + + if self._lower_state not in self._completed_states: + self._add_state(self._lower_state) + return self._lower_state + if self._upper_state not in self._completed_states: + self._add_state(self._upper_state) + return self._upper_state + + pairs = list(zip(self._completed_states, self._completed_states[1:])) + while pairs: + furthest_pair = max(pairs, key=lambda x: x[1].index - x[0].index) + splitter_state = self._find_best_splitter_state(furthest_pair[0], furthest_pair[1]) + if splitter_state is None: + self._unavailability_gap_pairs.add(furthest_pair) + elif splitter_state: + logger.debug( + f"Splitting [{furthest_pair[0].index}]--/{splitter_state.index}/--[{furthest_pair[1].index}]" + ) + self._add_state(splitter_state) + return splitter_state + pairs.remove(furthest_pair) + raise SequenceFinished() + + def _find_best_splitter_state(self, first_state: State, last_state: State) -> Optional[State]: + """ + Returns the most suitable state that splits the gap between the two states. + The state should be as close as possible to the middle of the gap and should have an available binary. + """ + if first_state.index + 1 == last_state.index: + return None + best_splitter_index = first_state.index + (last_state.index - first_state.index) // 2 + target_state = self._state_factory.create_state(best_splitter_index) + return self._find_closest_state_with_available_binary(target_state, (first_state, last_state)) + + def _state_is_in_unavailability_gap(self, state: State) -> bool: + """ + Returns True if the state is in a gap between two states without any available binaries. + """ + for pair in self._unavailability_gap_pairs: + if pair[0].index < state.index < pair[1].index: + return True + return False diff --git a/bci/search_strategy/composite_search.py b/bci/search_strategy/composite_search.py index 60fbc23..f59a7c1 100644 --- a/bci/search_strategy/composite_search.py +++ b/bci/search_strategy/composite_search.py @@ -1,89 +1,23 @@ -from bci.search_strategy.sequence_strategy import SequenceStrategy -from bci.search_strategy.n_ary_sequence import NArySequence, SequenceFinished -from bci.search_strategy.n_ary_search import NArySearch -from bci.search_strategy.sequence_elem import SequenceElem, ElemState +from bci.search_strategy.bgb_search import BiggestGapBisectionSearch +from bci.search_strategy.bgb_sequence import BiggestGapBisectionSequence +from bci.search_strategy.sequence_strategy import SequenceFinished +from bci.version_control.factory import StateFactory from bci.version_control.states.state import State -class CompositeSearch(SequenceStrategy): - def __init__( - self, - values: list[State], - n: int, - sequence_limit: int, - sequence_strategy_class: NArySequence.__class__, - search_strategy_class: NArySearch.__class__) -> None: - super().__init__(values) - self.n = n - self.sequence_strategy = sequence_strategy_class(values, n, limit=sequence_limit) - self.search_strategies = [] - self.search_strategy_class = search_strategy_class +class CompositeSearch(): + def __init__(self, state_factory: StateFactory, sequence_limit: int) -> None: + self.sequence_strategy = BiggestGapBisectionSequence(state_factory, limit=sequence_limit) + self.search_strategy = BiggestGapBisectionSearch(state_factory) self.sequence_strategy_finished = False def next(self) -> State: + # First we use the sequence strategy to select the next state if not self.sequence_strategy_finished: - next_elem = self.next_in_sequence_strategy() - if next_elem is not None: - return next_elem - return self.next_in_search_strategy() - - def next_in_sequence_strategy(self) -> State: - try: - return self.sequence_strategy.next() - except SequenceFinished: - self.sequence_strategy_finished = True - self.prepare_search_strategies() - return None - - def next_in_search_strategy(self) -> State: - while True: - if not self.search_strategies: - raise SequenceFinished() - search_strategy = self.search_strategies[0] try: - return search_strategy.next() + return self.sequence_strategy.next() except SequenceFinished: - del self.search_strategies[0] - - def get_active_strategy(self) -> SequenceStrategy: - ''' - Returns the currently active sequence/search strategy. - Returns None if all sequence/search strategies are finished. - ''' - if not self.sequence_strategy_finished: - return self.sequence_strategy - elif self.search_strategies: - return self.search_strategies[0] + self.sequence_strategy_finished = True + return self.search_strategy.next() else: - return None - - def update_outcome(self, elem: State, outcome: bool) -> None: - if active_strategy := self.get_active_strategy(): - active_strategy.update_outcome(elem, outcome) - # We only update the outcome of this object too if we are still using the sequence strategy - # because the elem lists need to be synced up until the search strategies are prepared. - # Not very clean, but does the job for now. - if not self.sequence_strategy_finished: - super().update_outcome(elem, outcome) - - def prepare_search_strategies(self): - shift_index_pairs = self.find_all_shift_index_pairs() - self.search_strategies = [self.search_strategy_class( - self.sequence_strategy.values[left_shift_index:right_shift_index+1], - self.n, - prior_elems=self.get_elems_slice(left_shift_index, right_shift_index+1)) - for left_shift_index, right_shift_index in shift_index_pairs] - - def get_elems_slice(self, start: int, end: int) -> list[SequenceElem]: - return [elem.get_deep_copy(index=i) for i, elem in enumerate(self._elems[start:end])] - - def find_all_shift_index_pairs(self) -> list[tuple[int, int]]: - # Filter out all errors and unevaluated elements - filtered_elems = [elem for elem in self._elems if elem.state not in [ElemState.ERROR, ElemState.INITIALIZED]] - filtered_elems_outcomes = [elem.outcome for elem in filtered_elems] - # Get start indexes of shift in outcome - shift_indexes = [i for i in range(0, len(filtered_elems_outcomes) - 1) if filtered_elems_outcomes[i] != filtered_elems_outcomes[i+1]] - # Convert to index pairs for original value list - shift_elem_pairs = [(filtered_elems[shift_index], filtered_elems[shift_index + 1]) for shift_index in shift_indexes if shift_index + 1 < len(filtered_elems)] - shift_index_pairs = [(left_shift_elem.index, right_shift_elem.index) for left_shift_elem, right_shift_elem in shift_elem_pairs] - return shift_index_pairs + return self.search_strategy.next() diff --git a/bci/search_strategy/n_ary_search.py b/bci/search_strategy/n_ary_search.py deleted file mode 100644 index 38e0279..0000000 --- a/bci/search_strategy/n_ary_search.py +++ /dev/null @@ -1,81 +0,0 @@ -from bisect import insort - -from bci.search_strategy.sequence_elem import SequenceElem -from bci.search_strategy.n_ary_sequence import NArySequence, SequenceFinished, ElemState -from bci.version_control.states.state import State - - -class NArySearch(NArySequence): - - def __init__(self, values: list[State], n: int, prior_elems: list[SequenceElem] = None) -> None: - super().__init__(values, n, prior_elems=prior_elems) - self.lower_bound = 0 - """ - Lower boundary, only indexes equal or higher should be evaluated. - """ - self.upper_bound = len(values) - """ - Strict upper boundary, only indexes strictly lower should be evaluated. - """ - self.outcomes: list[tuple[int, bool]] = [] - if prior_elems: - for elem in prior_elems: - if elem.outcome is not None: - self.update_boundaries(elem.value, elem.outcome) - - def update_outcome(self, value: State, outcome: bool) -> None: - super().update_outcome(value, outcome) - self.update_boundaries(value, outcome) - - def update_boundaries(self, value: State, outcome: bool) -> None: - if outcome is None: - return - new_index = self._elem_info[value].index - insort(self.outcomes, (new_index, outcome), key=lambda x: x[0]) - if len(self.outcomes) < 3: - return - index0, outcome0 = self.outcomes[0] - index1, outcome1 = self.outcomes[1] - index2, outcome2 = self.outcomes[2] - if outcome0 != outcome1: - del self.outcomes[2] - self.lower_bound = index0 - self.upper_bound = index1 - elif outcome1 != outcome2: - del self.outcomes[0] - self.lower_bound = index1 - self.upper_bound = index2 - lower_value = self._elems[self.lower_bound].value - upper_value = self._elems[self.upper_bound - 1].value - self.logger.info(f"Boundaries updated: {lower_value} <= x <= {upper_value}") - - def next(self) -> State: - while True: - while self.index_queue.empty(): - if self.range_queue.empty(): - raise SequenceFinished() - (lower_index, upper_index) = self.range_queue.get() - if lower_index >= self.upper_bound or upper_index < self.lower_bound: - # The range is completely out of bounds, so we just discard it - continue - if lower_index < self.lower_bound: - # The range is partly out of bounds, so we truncate it - # (possible because closest available elem instead of exact elem) - lower_index = self.lower_bound - if upper_index > self.upper_bound: - # Same as above - upper_index = self.upper_bound - new_indexes, new_ranges = self.divide_range(lower_index, upper_index, self.n) - for new_index in new_indexes: - self.index_queue.put(new_index) - for new_range in new_ranges: - self.range_queue.put(new_range) - index = self.index_queue.get() - # Only use index if it's within the active bounds - # Could not be the case if the index was added to the queue after the bounds were updated - if self.lower_bound <= index < self.upper_bound: - # Get closest available elem and check whether it is not yet evaluated - closest_available_elem = self.find_closest_available_elem(index) - if closest_available_elem.state == ElemState.INITIALIZED: - closest_available_elem.state = ElemState.IN_PROGRESS - return closest_available_elem.value diff --git a/bci/search_strategy/n_ary_sequence.py b/bci/search_strategy/n_ary_sequence.py deleted file mode 100644 index 6de0f12..0000000 --- a/bci/search_strategy/n_ary_sequence.py +++ /dev/null @@ -1,59 +0,0 @@ -import math -from queue import Queue -from bci.search_strategy.sequence_elem import ElemState, SequenceElem -from bci.search_strategy.sequence_strategy import SequenceStrategy, SequenceFinished -from bci.version_control.states.state import State - - -class NArySequence(SequenceStrategy): - - def __init__(self, values: list[State], n: int, limit=float('inf'), prior_elems: list[SequenceElem] = None) -> None: - super().__init__(values, prior_elems=prior_elems) - self.n = n - first_index = 0 - last_index = len(self._elems) - 1 - self.index_queue = Queue() - self.index_queue.put(first_index) - self.index_queue.put(last_index) - self.range_queue = Queue() - self.range_queue.put((first_index + 1, last_index)) - self.limit = limit - self.nb_of_started_evaluations = 0 - - def next(self) -> State: - while True: - if self.limit <= self.nb_of_started_evaluations: - raise SequenceFinished() - while self.index_queue.empty(): - if self.range_queue.empty(): - raise SequenceFinished() - (lower_index, higher_index) = self.range_queue.get() - new_indexes, new_ranges = self.divide_range(lower_index, higher_index, self.n) - for new_index in new_indexes: - self.index_queue.put(new_index) - for new_range in new_ranges: - self.range_queue.put(new_range) - target_elem = self.index_queue.get() - closest_available_elem = self.find_closest_available_elem(target_elem) - self.logger.debug(f"Next state should be {repr(target_elem)}, but {repr(closest_available_elem)} is closest available") - if closest_available_elem.state == ElemState.INITIALIZED: - closest_available_elem.state = ElemState.IN_PROGRESS - self.nb_of_started_evaluations += 1 - return closest_available_elem.value - - @staticmethod - def divide_range(lower_index, higher_index, n): - if lower_index == higher_index: - return [], [] - if higher_index - lower_index + 1 <= n: - return list(range(lower_index, higher_index)), [] - step = math.ceil((higher_index - lower_index) / n) - if lower_index + step * n <= higher_index: - indexes = list(range(lower_index, higher_index + 1, step)) - else: - indexes = list(range(lower_index, higher_index + 1, step)) + [higher_index] - ranges = [] - ranges.append((indexes[0], indexes[1])) - for i in range(1, len(indexes) - 1): - ranges.append((indexes[i] + 1, indexes[i + 1])) - return indexes[1:-1], ranges diff --git a/bci/search_strategy/sequence_elem.py b/bci/search_strategy/sequence_elem.py deleted file mode 100644 index 156714f..0000000 --- a/bci/search_strategy/sequence_elem.py +++ /dev/null @@ -1,45 +0,0 @@ -from enum import Enum - -import bci.browser.binary.factory as binary_factory -from bci.version_control.states.state import State - - -class ElemState(Enum): - INITIALIZED = 0 - UNAVAILABLE = 1 - IN_PROGRESS = 2 - ERROR = 3 - DONE = 4 - - -class SequenceElem: - - def __init__(self, index: int, value: State, state: ElemState = ElemState.INITIALIZED, outcome: bool = None) -> None: - self.value = value - self.index = index - if state == ElemState.DONE and outcome is None: - raise AttributeError("Every sequence element that has been evaluated should have an outcome") - self.state = state - self.outcome = outcome - - def is_available(self) -> bool: - binary = binary_factory.get_binary(self.value) - return binary.is_available() - - def update_outcome(self, outcome: bool): - if self.state == ElemState.DONE: - raise AttributeError(f"Outcome was already set to DONE for {repr(self)}") - if outcome is None: - self.state = ElemState.ERROR - else: - self.state = ElemState.DONE - self.outcome = outcome - - def get_deep_copy(self, index=None): - if index is not None: - return SequenceElem(index, self.value, state=self.state, outcome=self.outcome) - else: - return SequenceElem(self.index, self.value, state=self.state, outcome=self.outcome) - - def __repr__(self) -> str: - return f"{str(self.value)}: {self.state}" diff --git a/bci/search_strategy/sequence_strategy.py b/bci/search_strategy/sequence_strategy.py index 74cc175..777e490 100644 --- a/bci/search_strategy/sequence_strategy.py +++ b/bci/search_strategy/sequence_strategy.py @@ -1,65 +1,94 @@ import logging from abc import abstractmethod from threading import Thread -import bci.browser.binary.factory as binary_factory -from bci.search_strategy.sequence_elem import SequenceElem +from typing import Optional + +from bci.version_control.factory import StateFactory from bci.version_control.states.state import State +logger = logging.getLogger(__name__) + class SequenceStrategy: - def __init__(self, values: list[State], prior_elems: list[SequenceElem] = None) -> None: - self.logger = logging.getLogger(__name__) - if prior_elems and len(values) != len(prior_elems): - raise AttributeError(f"List of values and list of elems should be of equal length ({len(values)} != {len(prior_elems)})") - self.values = values - if prior_elems: - self._elems = prior_elems - else: - self._elems = [SequenceElem(index, value) for index, value in enumerate(values)] - self._elem_info = { - elem.value: elem - for elem in self._elems - } - - def update_outcome(self, elem: State, outcome: bool) -> None: - self._elem_info[elem].update_outcome(outcome) + def __init__(self, state_factory: StateFactory, limit) -> None: + """ + Initializes the sequence strategy. - def is_available(self, state: State) -> bool: - return binary_factory.binary_is_available(state) + :param state_factory: The factory to create new states. + :param limit: The maximum number of states to evaluate. 0 means no limit. + """ + self._state_factory = state_factory + self._limit = limit + self._lower_state, self._upper_state = self.__create_available_boundary_states() + self._completed_states = [] @abstractmethod def next(self) -> State: pass - def find_closest_available_elem(self, target_index: int) -> SequenceElem: - diff = 0 - while True: - potential_indexes = set(index for index in [ - target_index + diff, - target_index + diff + 1, - target_index - diff, - target_index - diff - 1, - ] if 0 <= index < len(self._elems)) - - if not potential_indexes: - raise AttributeError(f"Could not find closest available build state for '{target_index}'") + def is_available(self, state: State) -> bool: + return state.has_available_binary() + + def _add_state(self, elem: State) -> None: + """ + Adds an element to the list of evaluated states and sorts the list. + """ + self._completed_states.append(elem) + self._completed_states.sort(key=lambda x: x.index) + + def _fetch_evaluated_states(self) -> None: + """ + Fetches all evaluated states from the database and stores them in the list of evaluated states. + """ + fetched_states = self._state_factory.create_evaluated_states() + for state in self._completed_states: + if state not in fetched_states: + fetched_states.append(state) + fetched_states.sort(key=lambda x: x.index) + self._completed_states = fetched_states + + def __create_available_boundary_states(self) -> tuple[State, State]: + first_state, last_state = self._state_factory.boundary_states + available_first_state = self._find_closest_state_with_available_binary(first_state, (first_state, last_state)) + available_last_state = self._find_closest_state_with_available_binary(last_state, (first_state, last_state)) + if available_first_state is None or available_last_state is None: + raise AttributeError( + f"Could not find boundary states for '{self._lower_state.index}' and '{self._upper_state.index}'" + ) + return available_first_state, available_last_state + + def _find_closest_state_with_available_binary(self, target: State, boundaries: tuple[State, State]) -> State | None: + """ + Finds the closest state with an available binary **strictly** within the given boundaries. + """ + if target.has_available_binary(): + return target + + def index_has_available_binary(index: int) -> Optional[State]: + state = self._state_factory.create_state(index) + if state.has_available_binary(): + return state + else: + return None + + diff = 1 + first_state, last_state = boundaries + best_splitter_index = target.index + while (best_splitter_index - diff - 1) > first_state.index or (best_splitter_index + diff + 1) < last_state.index: threads = [] - for index in potential_indexes: - thread = ThreadWithReturnValue(target=lambda x: x if self._elems[x].is_available() else None, args=(index,)) - thread.start() - threads.append(thread) + for offset in (-diff, diff, - 1 - diff, 1 + diff): + target_index = best_splitter_index + offset + if first_state.index < target_index < last_state.index: + thread = ThreadWithReturnValue(target=index_has_available_binary, args=(target_index,)) + thread.start() + threads.append(thread) - results = [] for thread in threads: - result = thread.join() - if result is not None: - results.append(result) - # If valid results are found, return the one closest to target - if results: - results = sorted(results, key=lambda x: abs(x - target_index)) - return self._elems[results[0]] - # Otherwise re-iterate + state = thread.join() + if state: + return state diff += 2 + return None class SequenceFinished(Exception): diff --git a/bci/util.py b/bci/util.py index 81516c8..68f0952 100644 --- a/bci/util.py +++ b/bci/util.py @@ -64,7 +64,7 @@ def read_web_report(file_name): report_folder = "/reports" path = os.path.join(report_folder, file_name) if not os.path.isfile(path): - raise AttributeError("Could not find report at '%s'" % path) + raise PageNotFound("Could not find report at '%s'" % path) with open(path, "r") as file: return json.load(file) @@ -72,16 +72,16 @@ def read_web_report(file_name): def request_html(url: str): LOGGER.debug(f"Requesting {url}") resp = requests.get(url, timeout=60) - if resp.status_code != 200: - raise AttributeError(f"Could not connect to url '{url}'") + if resp.status_code >= 400: + raise PageNotFound(f"Could not connect to url '{url}'") return resp.content def request_json(url: str): LOGGER.debug(f"Requesting {url}") resp = requests.get(url, timeout=60) - if resp.status_code != 200: - raise AttributeError(f"Could not connect to url '{url}'") + if resp.status_code >= 400: + raise PageNotFound(f"Could not connect to url '{url}'") LOGGER.debug(f"Request completed") return resp.json() @@ -89,7 +89,11 @@ def request_json(url: str): def request_final_url(url: str) -> str: LOGGER.debug(f"Requesting {url}") resp = requests.get(url, timeout=60) - if resp.status_code != 200: - raise AttributeError(f"Could not connect to url '{url}'") + if resp.status_code >= 400: + raise PageNotFound(f"Could not connect to url '{url}'") LOGGER.debug(f"Request completed") return resp.url + + +class PageNotFound(Exception): + pass diff --git a/bci/version_control/factory.py b/bci/version_control/factory.py index f203fd5..444ce70 100644 --- a/bci/version_control/factory.py +++ b/bci/version_control/factory.py @@ -1,81 +1,93 @@ -import re +from __future__ import annotations -import bci.version_control.repository.online.chromium as chromium_repo -import bci.version_control.repository.online.firefox as firefox_repo - -from bci.evaluations.logic import BrowserConfiguration, EvaluationRange -from bci.version_control.repository.repository import Repository +from bci.database.mongo.mongodb import MongoDB +from bci.evaluations.logic import EvaluationParameters +from bci.evaluations.outcome_checker import OutcomeChecker from bci.version_control.states.revisions.chromium import ChromiumRevision from bci.version_control.states.revisions.firefox import FirefoxRevision from bci.version_control.states.state import State +from bci.version_control.states.versions.base import BaseVersion from bci.version_control.states.versions.chromium import ChromiumVersion from bci.version_control.states.versions.firefox import FirefoxVersion -def create_state_collection(browser_config: BrowserConfiguration, eval_range: EvaluationRange) -> list[State]: - if eval_range.only_release_revisions: - return __create_version_collection(browser_config, eval_range) - else: - return __create_revision_collection(browser_config, eval_range) - - -def __create_version_collection(browser_config: BrowserConfiguration, eval_range: EvaluationRange) -> list[State]: - if not eval_range.major_version_range: - raise ValueError('A major version range is required for creating a version collection') - lower_version = eval_range.major_version_range[0] - upper_version = eval_range.major_version_range[1] - - match browser_config.browser_name: - case 'chromium': - state_class = ChromiumVersion - case 'firefox': - state_class = FirefoxVersion - case _: - raise ValueError(f'Unknown browser name: {browser_config.browser_name}') - - return [ - state_class(version) - for version in range(lower_version, upper_version + 1) - ] - - -def __create_revision_collection(browser_config: BrowserConfiguration, eval_range: EvaluationRange) -> list[State]: - if eval_range.major_version_range: - repo = __get_repo(browser_config) - lower_revision_nb = repo.get_release_revision_number(eval_range.major_version_range[0]) - upper_revision_nb = repo.get_release_revision_number(eval_range.major_version_range[1]) - else: - lower_revision_nb, upper_revision_nb = eval_range.revision_number_range - - match browser_config.browser_name: - case 'chromium': - state_class = ChromiumRevision - case 'firefox': - state_class = FirefoxRevision - case _: - raise ValueError(f'Unknown browser name: {browser_config.browser_name}') - - return [ - state_class(revision_number=rev_nb) - for rev_nb in range(lower_revision_nb, upper_revision_nb + 1) - ] - - -def __get_short_version(version: str) -> int: - if '.' not in version: - return int(version) - if re.match(r'^[0-9]+$', version): - return int(version) - if re.match(r'^[0-9]+(\.[0-9]+)+$', version): - return int(version.split(".")[0]) - raise AttributeError(f'Could not convert version \'{version}\' to short version') - - -def __get_repo(browser_config: BrowserConfiguration) -> Repository: - match browser_config.browser_name: - case 'chromium': - return chromium_repo - case 'firefox': - return firefox_repo - case _: - raise ValueError(f'Unknown browser name: {browser_config.browser_name}') +class StateFactory: + def __init__(self, eval_params: EvaluationParameters, outcome_checker: OutcomeChecker) -> None: + """ + Create a state factory object with the given evaluation parameters and boundary indices. + + :param eval_params: The evaluation parameters. + """ + self.__eval_params = eval_params + self.__outcome_checker = outcome_checker + self.boundary_states = self.__create_boundary_states() + + def create_state(self, index: int) -> State: + """ + Create a state object associated with the given index. + The given index represents: + - A major version number if `self.eval_params.evaluation_range.major_version_range` is True. + - A revision number otherwise. + + :param index: The index of the state. + """ + eval_range = self.__eval_params.evaluation_range + if eval_range.only_release_revisions: + return self.__create_version_state(index) + else: + return self.__create_revision_state(index) + + def __create_boundary_states(self) -> tuple[State, State]: + """ + Create the boundary state objects for the evaluation range. + """ + eval_range = self.__eval_params.evaluation_range + if eval_range.major_version_range: + first_state = self.__create_version_state(eval_range.major_version_range[0]) + last_state = self.__create_version_state(eval_range.major_version_range[1]) + if not eval_range.only_release_revisions: + first_state = first_state.convert_to_revision() + last_state = last_state.convert_to_revision() + return first_state, last_state + elif eval_range.revision_number_range: + if eval_range.only_release_revisions: + raise ValueError('Release revisions are not allowed in this evaluation range') + return ( + self.__create_revision_state(eval_range.revision_number_range[0]), + self.__create_revision_state(eval_range.revision_number_range[1]), + ) + else: + raise ValueError('No evaluation range specified') + + def create_evaluated_states(self) -> list[State]: + """ + Create evaluated state objects within the evaluation range where the result is fetched from the database. + """ + db = MongoDB.get_instance() + return db.get_evaluated_states(self.__eval_params, self.boundary_states, self.__outcome_checker) + + def __create_version_state(self, index: int) -> BaseVersion: + """ + Create a version state object associated with the given index. + """ + browser_config = self.__eval_params.browser_configuration + match browser_config.browser_name: + case 'chromium': + return ChromiumVersion(index) + case 'firefox': + return FirefoxVersion(index) + case _: + raise ValueError(f'Unknown browser name: {browser_config.browser_name}') + + def __create_revision_state(self, index: int) -> State: + """ + Create a revision state object associated with the given index. + """ + browser_config = self.__eval_params.browser_configuration + match browser_config.browser_name: + case 'chromium': + return ChromiumRevision(revision_nb=index) + case 'firefox': + return FirefoxRevision(revision_nb=index) + case _: + raise ValueError(f'Unknown browser name: {browser_config.browser_name}') diff --git a/bci/version_control/revision_parser/chromium_parser.py b/bci/version_control/revision_parser/chromium_parser.py index 027d54c..df643b4 100644 --- a/bci/version_control/revision_parser/chromium_parser.py +++ b/bci/version_control/revision_parser/chromium_parser.py @@ -1,29 +1,34 @@ import logging import re from typing import Optional -from bci.version_control.revision_parser.parser import RevisionParser -from bci.util import request_html, request_final_url +from bci.util import PageNotFound, request_final_url, request_html +from bci.version_control.revision_parser.parser import RevisionParser REV_ID_BASE_URL = 'https://chromium.googlesource.com/chromium/src/+/' REV_NUMBER_BASE_URL = 'http://crrev.com/' +logger = logging.getLogger(__name__) -class ChromiumRevisionParser(RevisionParser): - def get_rev_id(self, rev_number: int) -> str: - final_url = request_final_url(f'{REV_NUMBER_BASE_URL}{rev_number}') +class ChromiumRevisionParser(RevisionParser): + def get_revision_id(self, revision_nb: int) -> Optional[str]: + try: + final_url = request_final_url(f'{REV_NUMBER_BASE_URL}{revision_nb}') + except PageNotFound: + logger.warning(f"Could not find revision id for revision number '{revision_nb}'") + return None rev_id = final_url[-40:] assert re.match(r'[a-z0-9]{40}', rev_id) return rev_id - def get_rev_number(self, rev_id: str) -> int: - url = f'{REV_ID_BASE_URL}{rev_id}' + def get_revision_nb(self, revision_id: str) -> int: + url = f'{REV_ID_BASE_URL}{revision_id}' html = request_html(url).decode() rev_number = self.__parse_revision_number(html) if rev_number is None: - logging.getLogger('bci').error(f'Could not parse revision number on \'{url}\'') - raise AttributeError(f'Could not parse revision number on \'{url}\'') + logging.getLogger('bci').error(f"Could not parse revision number on '{url}'") + raise AttributeError(f"Could not parse revision number on '{url}'") assert re.match(r'[0-9]{1,7}', rev_number) return int(rev_number) diff --git a/bci/version_control/revision_parser/parser.py b/bci/version_control/revision_parser/parser.py index 26551f2..e3e2595 100644 --- a/bci/version_control/revision_parser/parser.py +++ b/bci/version_control/revision_parser/parser.py @@ -1,12 +1,12 @@ from abc import abstractmethod +from typing import Optional class RevisionParser: - @abstractmethod - def get_rev_id(self, rev_nb: int): + def get_revision_id(self, revision_nb: int) -> Optional[str]: pass @abstractmethod - def get_rev_number(self, rev_id: str): + def get_revision_nb(self, revision_id: str) -> Optional[int]: pass diff --git a/bci/version_control/states/revisions/base.py b/bci/version_control/states/revisions/base.py index c4e6eaf..fc96fd9 100644 --- a/bci/version_control/states/revisions/base.py +++ b/bci/version_control/states/revisions/base.py @@ -1,135 +1,97 @@ +import logging import re from abc import abstractmethod +from typing import Optional from bci.version_control.states.state import State +logger = logging.getLogger(__name__) -class BaseRevision(State): - def __init__(self, revision_id: str = None, revision_number: int = None, parents=None, children=None): +class BaseRevision(State): + def __init__(self, revision_id: Optional[str] = None, revision_nb: Optional[int] = None): super().__init__() - self._revision_id = None - self._revision_number = None - if revision_id is None and revision_number is None: - raise Exception('A state must be initiliazed with either a revision id or revision number') - if revision_id is not None: - self.revision_id = revision_id - if revision_number is not None: - self.revision_number = revision_number - self.parents = [] if parents is None else parents - self.children = [] if children is None else children - self.result = [] - self.evaluation_target = False + if revision_id is None and revision_nb is None: + raise AttributeError('A state must be initiliazed with either a revision id or revision number') + + self._revision_id = revision_id + self._revision_nb = revision_nb + self._fetch_missing_data() + + if self._revision_id is not None and not self._is_valid_revision_id(self._revision_id): + raise AttributeError(f"Invalid revision id '{self._revision_id}' for state '{self}'") + + if self._revision_nb is not None and not self._is_valid_revision_number(self._revision_nb): + raise AttributeError(f"Invalid revision number '{self._revision_nb}' for state '{self}'") @property @abstractmethod - def browser_name(self): + def browser_name(self) -> str: pass @property - def name(self): - return f'{self.revision_number}' + def name(self) -> str: + return f'{self._revision_nb}' + + @property + def index(self) -> int: + return self._revision_nb + + @property + def revision_nb(self) -> int: + return self._revision_nb - def to_dict(self, make_complete: bool = True) -> dict: - ''' + def to_dict(self) -> dict: + """ Returns a dictionary representation of the state. - If complete is True, any missing information will be fetched. - For example, only the revision id might be known, but not the revision number. - ''' - if make_complete: - return { - 'type': 'revision', - 'browser_name': self.browser_name, - 'revision_id': self.revision_id, - 'revision_number': self.revision_number - } - else: - state_dict = { - 'type': 'revision', - 'browser_name': self.browser_name - } - if self._revision_id is not None: - state_dict['revision_id'] = self._revision_id - if self._revision_number is not None: - state_dict['revision_number'] = self._revision_number - return state_dict + """ + state_dict = {'type': 'revision', 'browser_name': self.browser_name} + if self._revision_id: + state_dict['revision_id'] = self._revision_id + if self._revision_nb: + state_dict['revision_number'] = self._revision_nb + return state_dict @staticmethod def from_dict(data: dict) -> State: - from bci.version_control.states.revisions.chromium import \ - ChromiumRevision - from bci.version_control.states.revisions.firefox import \ - FirefoxRevision + from bci.version_control.states.revisions.chromium import ChromiumRevision + from bci.version_control.states.revisions.firefox import FirefoxRevision + match data['browser_name']: case 'chromium': - return ChromiumRevision( - revision_id=data['revision_id'], revision_number=data['revision_number'] - ) + state = ChromiumRevision(revision_id=data.get('revision_id', None), revision_nb=data['revision_number']) case 'firefox': - return FirefoxRevision( - revision_id=data['revision_id'], revision_number=data['revision_number'] - ) + state = FirefoxRevision(revision_id=data.get('revision_id', None), revision_nb=data['revision_number']) case _: raise Exception(f'Unknown browser: {data["browser_name"]}') + return state def _has_revision_id(self) -> bool: return self._revision_id is not None def _has_revision_number(self) -> bool: - return self._revision_number is not None - - @abstractmethod - def _fetch_revision_id(self) -> str: - pass + return self._revision_nb is not None @abstractmethod - def _fetch_revision_number(self) -> int: + def _fetch_missing_data(self): pass - @property - def revision_id(self) -> str: - if self._revision_id is None: - self.revision_id = self._fetch_revision_id() - return self._revision_id - - @revision_id.setter - def revision_id(self, value: str): - assert value is not None - assert re.match(r'[a-z0-9]{40}', value), f'\'{value}\' is not a valid revision id' - self._revision_id = value + def _is_valid_revision_id(self, revision_id: str) -> bool: + """ + Checks if a revision id is valid. + A valid revision id is a 40 character long string containing only lowercase letters and numbers. + """ + return re.match(r'[a-z0-9]{40}', revision_id) is not None - @property - def revision_number(self) -> int: - if self._revision_number is None: - self.revision_number = self._fetch_revision_number() - return self._revision_number - - @revision_number.setter - def revision_number(self, value: int): - assert value is not None - assert re.match(r'[0-9]{1,7}', str(value)), f'\'{value}\' is not a valid revision number' - self._revision_number = value - - def add_parent(self, new_parent): - if not self.is_parent(new_parent): - self.parents.append(new_parent) - if not new_parent.is_child(self): - new_parent.add_child(self) - - def add_child(self, new_child): - if not self.is_child(new_child): - self.children.append(new_child) - if not new_child.is_parent(self): - new_child.add_parent(self) - - def is_parent(self, parent): - return parent in self.parents - - def is_child(self, child): - return child in self.children + def _is_valid_revision_number(self, revision_number: int) -> bool: + """ + Checks if a revision number is valid. + A valid revision number is a positive integer. + """ + return re.match(r'[0-9]{1,7}', str(revision_number)) is not None def __str__(self): - return f'RevisionState(id: {self._revision_id}, number: {self._revision_number})' + return f'RevisionState(number: {self._revision_nb}, id: {self._revision_id})' def __repr__(self): - return f'RevisionState(id: {self._revision_id}, number: {self._revision_number})' + return f'RevisionState(number: {self._revision_nb}, id: {self._revision_id})' diff --git a/bci/version_control/states/revisions/chromium.py b/bci/version_control/states/revisions/chromium.py index e586e73..0b61d34 100644 --- a/bci/version_control/states/revisions/chromium.py +++ b/bci/version_control/states/revisions/chromium.py @@ -1,39 +1,53 @@ +from typing import Optional + import requests + +from bci.database.mongo.mongodb import MongoDB from bci.version_control.revision_parser.chromium_parser import ChromiumRevisionParser from bci.version_control.states.revisions.base import BaseRevision -from bci.database.mongo.mongodb import MongoDB PARSER = ChromiumRevisionParser() class ChromiumRevision(BaseRevision): - - def __init__(self, revision_id: str = None, revision_number: int = None, parents=None, children=None): - super().__init__(revision_id, revision_number, parents=parents, children=children) + def __init__(self, revision_id: Optional[str] = None, revision_nb: Optional[int] = None): + super().__init__(revision_id, revision_nb) @property def browser_name(self): return 'chromium' - def has_online_binary(self): + def has_online_binary(self) -> bool: cached_binary_available_online = MongoDB.has_binary_available_online('chromium', self) if cached_binary_available_online is not None: return cached_binary_available_online - url = f'https://www.googleapis.com/storage/v1/b/chromium-browser-snapshots/o/Linux_x64%2F{self._revision_number}%2Fchrome-linux.zip' + url = f'https://www.googleapis.com/storage/v1/b/chromium-browser-snapshots/o/Linux_x64%2F{self._revision_nb}%2Fchrome-linux.zip' req = requests.get(url) has_binary_online = req.status_code == 200 MongoDB.store_binary_availability_online_cache('chromium', self, has_binary_online) return has_binary_online def get_online_binary_url(self): - return "https://www.googleapis.com/download/storage/v1/b/chromium-browser-snapshots/o/%s%%2F%s%%2Fchrome-%s.zip?alt=media" % ('Linux_x64', self._revision_number, 'linux') - - def _fetch_revision_id(self) -> str: - if state := MongoDB.get_complete_state_dict_from_binary_availability_cache(self): - return state['revision_id'] - return PARSER.get_rev_id(self._revision_number) - - def _fetch_revision_number(self) -> int: + return ( + 'https://www.googleapis.com/download/storage/v1/b/chromium-browser-snapshots/o/%s%%2F%s%%2Fchrome-%s.zip?alt=media' + % ('Linux_x64', self._revision_nb, 'linux') + ) + + def _fetch_missing_data(self) -> None: + """ + States are initialized with either a revision id or revision number. + This method attempts to fetch other data to complete this state object. + """ + # First check if the missing data is available in the cache + if self._revision_id and self._revision_nb: + return if state := MongoDB.get_complete_state_dict_from_binary_availability_cache(self): - return state['revision_number'] - return PARSER.get_rev_number(self._revision_id) + if self._revision_id is None: + self._revision_id = state.get('revision_id', None) + if self._revision_nb is None: + self._revision_nb = state.get('revision_number', None) + # If not, fetch the missing data from the parser + if self._revision_id is None: + self._revision_id = PARSER.get_revision_id(self._revision_nb) + if self._revision_nb is None: + self._revision_nb = PARSER.get_revision_nb(self._revision_id) diff --git a/bci/version_control/states/revisions/firefox.py b/bci/version_control/states/revisions/firefox.py index f9b948a..7f5d531 100644 --- a/bci/version_control/states/revisions/firefox.py +++ b/bci/version_control/states/revisions/firefox.py @@ -1,41 +1,43 @@ +from typing import Optional + from bci.util import request_json from bci.version_control.states.revisions.base import BaseRevision -BINARY_AVAILABILITY_URL = "https://distrinet.pages.gitlab.kuleuven.be/users/gertjan-franken/bughog-revision-metadata/firefox_binary_availability.json" -REVISION_NUMBER_MAPPING_URL = "https://distrinet.pages.gitlab.kuleuven.be/users/gertjan-franken/bughog-revision-metadata/firefox_revision_nb_to_id.json" +BINARY_AVAILABILITY_URL = 'https://distrinet.pages.gitlab.kuleuven.be/users/gertjan-franken/bughog-revision-metadata/firefox_binary_availability.json' +REVISION_NUMBER_MAPPING_URL = 'https://distrinet.pages.gitlab.kuleuven.be/users/gertjan-franken/bughog-revision-metadata/firefox_revision_nb_to_id.json' -BINARY_AVAILABILITY_MAPPING = request_json(BINARY_AVAILABILITY_URL)["data"] -REVISION_NUMBER_MAPPING = request_json(REVISION_NUMBER_MAPPING_URL)["data"] +BINARY_AVAILABILITY_MAPPING = request_json(BINARY_AVAILABILITY_URL)['data'] +REVISION_NUMBER_MAPPING = request_json(REVISION_NUMBER_MAPPING_URL)['data'] class FirefoxRevision(BaseRevision): - - def __init__(self, revision_id: str = None, revision_number: str = None, parents=None, children=None, version: int = None): - super().__init__(revision_id=revision_id, revision_number=revision_number, parents=parents, children=children) - self.version = version + def __init__( + self, revision_id: Optional[str] = None, revision_nb: Optional[int] = None, major_version: Optional[int] = None + ): + super().__init__(revision_id=revision_id, revision_nb=revision_nb) + self.major_version = major_version @property - def browser_name(self): + def browser_name(self) -> str: return 'firefox' def has_online_binary(self) -> bool: if self._revision_id: return self._revision_id in BINARY_AVAILABILITY_MAPPING - if self._revision_number: - return str(self._revision_number) in REVISION_NUMBER_MAPPING + if self._revision_nb: + return str(self._revision_nb) in REVISION_NUMBER_MAPPING + raise AttributeError('Cannot check binary availability without a revision id or revision number') def get_online_binary_url(self) -> str: - binary_base_url = BINARY_AVAILABILITY_MAPPING[self.revision_id]["files_url"] - app_version = BINARY_AVAILABILITY_MAPPING[self.revision_id]["app_version"] - binary_url = f"{binary_base_url}firefox-{app_version}.en-US.linux-x86_64.tar.bz2" + binary_base_url = BINARY_AVAILABILITY_MAPPING[self._revision_id]['files_url'] + app_version = BINARY_AVAILABILITY_MAPPING[self._revision_id]['app_version'] + binary_url = f'{binary_base_url}firefox-{app_version}.en-US.linux-x86_64.tar.bz2' return binary_url - def _fetch_revision_id(self) -> str: - return REVISION_NUMBER_MAPPING.get(str(self._revision_number), None) - - def _fetch_revision_number(self) -> int: - binary_data = BINARY_AVAILABILITY_MAPPING.get(self._revision_id, None) - if binary_data is not None: - return binary_data.get('revision_number') - else: - return None + def _fetch_missing_data(self): + if self._revision_id is None: + self._revision_id = REVISION_NUMBER_MAPPING.get(str(self._revision_nb), None) + if self._revision_nb is None: + binary_data = BINARY_AVAILABILITY_MAPPING.get(self._revision_id, None) + if binary_data is not None: + self._revision_nb = binary_data.get('revision_number') diff --git a/bci/version_control/states/state.py b/bci/version_control/states/state.py index 31dc10a..92a52b7 100644 --- a/bci/version_control/states/state.py +++ b/bci/version_control/states/state.py @@ -1,34 +1,84 @@ from __future__ import annotations -from abc import abstractmethod, abstractproperty +from abc import abstractmethod +from dataclasses import dataclass +from enum import Enum + + +class StateCondition(Enum): + """ + The condition of a state. + """ + + # This state has been evaluated and the result is available. + COMPLETED = 0 + # The evaluation of this state has failed. + FAILED = 1 + # The evaluation of this state is in progress. + IN_PROGRESS = 2 + # The evaluation of this state has not started yet. + PENDING = 3 + # This state is not available. + UNAVAILABLE = 4 + + +@dataclass(frozen=True) +class StateResult: + requests: list[dict[str, str]] + request_vars: list[dict[str, str]] + log_vars: list[dict[str, str]] + is_dirty: bool + @property + def reproduced(self): + entry_if_reproduced = {'var': 'reproduced', 'val': 'OK'} + reproduced_in_req_vars = [entry for entry in self.request_vars if entry == entry_if_reproduced] != [] + reproduced_in_log_vars = [entry for entry in self.log_vars if entry == entry_if_reproduced] != [] + return reproduced_in_req_vars or reproduced_in_log_vars -class EvaluationResult: - BuildUnavailable = "build unavailable" - Error = "error" - Positive = "positive" - Negative = "negative" - Undefined = "undefined" + @staticmethod + def from_dict(data: dict, is_dirty: bool = False) -> StateResult: + return StateResult(data['requests'], data['req_vars'], data['log_vars'], is_dirty) class State: + def __init__(self): + self.condition = StateCondition.PENDING + self.result: StateResult + self.outcome: bool | None = None + + @property + @abstractmethod + def name(self) -> str: + pass + + @property + @abstractmethod + def browser_name(self) -> str: + pass - @abstractproperty - def name(self): + @property + @abstractmethod + def index(self) -> int: + """ + The index of the element in the sequence. + """ pass - @abstractproperty - def browser_name(self): + @property + @abstractmethod + def revision_nb(self) -> int: pass @abstractmethod - def to_dict(self): + def to_dict(self) -> dict: pass @staticmethod def from_dict(data: dict) -> State: from bci.version_control.states.revisions.base import BaseRevision from bci.version_control.states.versions.base import BaseVersion + match data['type']: case 'revision': return BaseRevision.from_dict(data) @@ -45,49 +95,22 @@ def has_online_binary(self) -> bool: def get_online_binary_url(self) -> str: pass - def is_evaluation_target(self): - return self.evaluation_target - - def set_as_evaluation_target(self): - self.evaluation_target = True - - def set_evaluation_outcome(self, outcome: bool): - if outcome: - self.result = EvaluationResult.Positive + def has_available_binary(self) -> bool: + if self.condition == StateCondition.UNAVAILABLE: + return False else: - self.result = EvaluationResult.Negative + has_available_binary = self.has_online_binary() + if not has_available_binary: + self.condition = StateCondition.UNAVAILABLE + return has_available_binary - def set_evaluation_build_unavailable(self): - self.result = EvaluationResult.BuildUnavailable + def __repr__(self) -> str: + return f'State(index={self.index})' - def set_evaluation_error(self, error_message): - self.result = error_message + def __eq__(self, other: object) -> bool: + if not isinstance(other, State): + return False + return self.index == other.index - @property - def build_unavailable(self): - return self.result == EvaluationResult.BuildUnavailable - - @property - def result_undefined(self): - return len(self.result) == 0 - - # @classmethod - # def create_state_list(cls, evaluation_targets, revision_numbers) -> list: - # states = [] - # ancestor_state = cls(revision_number=revision_numbers[0]) - # descendant_state = cls(revision_number=revision_numbers[len(revision_numbers) - 1]) - - # states.append(ancestor_state) - # prev_state = ancestor_state - # for i in range(1, len(revision_numbers) - 1): - # revision_number = revision_numbers[i] - # curr_state = cls(revision_number=revision_number) - # curr_state.add_parent(prev_state) - # states.append(curr_state) - # prev_state = curr_state - # if evaluation_targets is None or revision_number in evaluation_targets: - # curr_state.set_as_evaluation_target() - - # descendant_state.add_parent(prev_state) - # states.append(descendant_state) - # return states + def __hash__(self) -> int: + return hash((self.index, self.browser_name)) diff --git a/bci/version_control/states/versions/base.py b/bci/version_control/states/versions/base.py index 45d3636..328cd46 100644 --- a/bci/version_control/states/versions/base.py +++ b/bci/version_control/states/versions/base.py @@ -1,60 +1,69 @@ -from abc import abstractmethod, abstractproperty +from abc import abstractmethod from bci.version_control.states.state import State class BaseVersion(State): - def __init__(self, major_version: int): super().__init__() self.major_version = major_version - self._rev_nb = self._get_rev_nb() - self._rev_id = self._get_rev_id() + self._revision_nb = self._get_rev_nb() + self._revision_id = self._get_rev_id() @abstractmethod - def _get_rev_nb(self): + def _get_rev_nb(self) -> int: pass @abstractmethod - def _get_rev_id(self): + def _get_rev_id(self) -> str: pass @property - def name(self): + def name(self) -> str: return f'v_{self.major_version}' - @abstractproperty - def browser_name(self): + @property + @abstractmethod + def browser_name(self) -> str: pass + @property + def index(self) -> int: + return self.major_version + + @property + def revision_nb(self) -> int: + return self._revision_nb + def to_dict(self, make_complete: bool = True) -> dict: return { 'type': 'version', 'browser_name': self.browser_name, 'major_version': self.major_version, - 'revision_id': self._rev_id, - 'revision_number': self._rev_nb + 'revision_id': self._revision_id, + 'revision_number': self._revision_nb, } @staticmethod def from_dict(data: dict) -> State: - from bci.version_control.states.versions.chromium import \ - ChromiumVersion + from bci.version_control.states.versions.chromium import ChromiumVersion from bci.version_control.states.versions.firefox import FirefoxVersion + match data['browser_name']: case 'chromium': - return ChromiumVersion( - major_version=data['major_version'] - ) + state = ChromiumVersion(major_version=data['major_version']) case 'firefox': - return FirefoxVersion( - major_version=data['major_version'] - ) + state = FirefoxVersion(major_version=data['major_version']) case _: raise Exception(f'Unknown browser: {data["browser_name"]}') + return state + + @abstractmethod + def convert_to_revision(self) -> State: + pass def __str__(self): - return f'VersionState(version: {self.major_version}, rev: {self._rev_nb})' + return f'VersionState(version: {self.major_version}, rev: {self._revision_nb})' def __repr__(self): - return f'VersionState(version: {self.major_version}, rev: {self._rev_nb})' + return f'VersionState(version: {self.major_version}, rev: {self._revision_nb})' diff --git a/bci/version_control/states/versions/chromium.py b/bci/version_control/states/versions/chromium.py index 6814b3d..682cfd7 100644 --- a/bci/version_control/states/versions/chromium.py +++ b/bci/version_control/states/versions/chromium.py @@ -1,18 +1,19 @@ import requests -from bci.version_control.repository.online.chromium import get_release_revision_number, get_release_revision_id -from bci.version_control.states.versions.base import BaseVersion + from bci.database.mongo.mongodb import MongoDB +from bci.version_control.repository.online.chromium import get_release_revision_id, get_release_revision_number +from bci.version_control.states.revisions.chromium import ChromiumRevision +from bci.version_control.states.versions.base import BaseVersion class ChromiumVersion(BaseVersion): - def __init__(self, major_version: int): super().__init__(major_version) - def _get_rev_nb(self): + def _get_rev_nb(self) -> int: return get_release_revision_number(self.major_version) - def _get_rev_id(self): + def _get_rev_id(self) -> str: return get_release_revision_id(self.major_version) @property @@ -23,11 +24,17 @@ def has_online_binary(self): cached_binary_available_online = MongoDB.has_binary_available_online('chromium', self) if cached_binary_available_online is not None: return cached_binary_available_online - url = f'https://www.googleapis.com/storage/v1/b/chromium-browser-snapshots/o/Linux_x64%2F{self._rev_nb}%2Fchrome-linux.zip' + url = f'https://www.googleapis.com/storage/v1/b/chromium-browser-snapshots/o/Linux_x64%2F{self._revision_nb}%2Fchrome-linux.zip' req = requests.get(url) has_binary_online = req.status_code == 200 MongoDB.store_binary_availability_online_cache('chromium', self, has_binary_online) return has_binary_online def get_online_binary_url(self): - return "https://www.googleapis.com/download/storage/v1/b/chromium-browser-snapshots/o/%s%%2F%s%%2Fchrome-%s.zip?alt=media" % ('Linux_x64', self._rev_nb, 'linux') + return ( + 'https://www.googleapis.com/download/storage/v1/b/chromium-browser-snapshots/o/%s%%2F%s%%2Fchrome-%s.zip?alt=media' + % ('Linux_x64', self._revision_nb, 'linux') + ) + + def convert_to_revision(self) -> ChromiumRevision: + return ChromiumRevision(revision_nb=self._revision_nb) diff --git a/bci/version_control/states/versions/firefox.py b/bci/version_control/states/versions/firefox.py index ba719a4..a33c14c 100644 --- a/bci/version_control/states/versions/firefox.py +++ b/bci/version_control/states/versions/firefox.py @@ -1,4 +1,5 @@ from bci.version_control.repository.online.firefox import get_release_revision_number, get_release_revision_id +from bci.version_control.states.revisions.firefox import FirefoxRevision from bci.version_control.states.versions.base import BaseVersion @@ -7,18 +8,21 @@ class FirefoxVersion(BaseVersion): def __init__(self, major_version: int): super().__init__(major_version) - def _get_rev_nb(self): + def _get_rev_nb(self) -> int: return get_release_revision_number(self.major_version) def _get_rev_id(self): return get_release_revision_id(self.major_version) @property - def browser_name(self): + def browser_name(self) -> str: return 'firefox' - def has_online_binary(self): - return f'https://www.googleapis.com/storage/v1/b/chromium-browser-snapshots/o/Linux_x64%2F{self._rev_nb}%2Fchrome-linux.zip' + def has_online_binary(self) -> bool: + return True - def get_online_binary_url(self): + def get_online_binary_url(self) -> str: return f'https://ftp.mozilla.org/pub/firefox/releases/{self.major_version}.0/linux-x86_64/en-US/firefox-{self.major_version}.0.tar.bz2' + + def convert_to_revision(self) -> FirefoxRevision: + return FirefoxRevision(revision_nb=self._revision_nb) diff --git a/bci/web/clients.py b/bci/web/clients.py index 54e21fd..eee5d24 100644 --- a/bci/web/clients.py +++ b/bci/web/clients.py @@ -6,7 +6,7 @@ class Clients: __semaphore = threading.Semaphore() - __clients: dict[Server] = {} + __clients: dict[Server, dict | None] = {} @staticmethod def add_client(ws_client: Server): @@ -16,9 +16,7 @@ def add_client(ws_client: Server): @staticmethod def __remove_disconnected_clients(): with Clients.__semaphore: - Clients.__clients = { - k: v for k, v in Clients.__clients.items() if k.connected - } + Clients.__clients = {k: v for k, v in Clients.__clients.items() if k.connected} @staticmethod def associate_params(ws_client: Server, params: dict): @@ -35,7 +33,7 @@ def associate_project(ws_client: Server, project: str): with Clients.__semaphore: if not (params := Clients.__clients.get(ws_client, None)): params = {} - params["project"] = project + params['project'] = project Clients.__clients[ws_client] = params Clients.push_experiments(ws_client) @@ -48,10 +46,10 @@ def push_results(ws_client: Server): ws_client.send( json.dumps( { - "update": { - "plot_data": { - "revision_data": revision_data, - "version_data": version_data, + 'update': { + 'plot_data': { + 'revision_data': revision_data, + 'version_data': version_data, } } } @@ -65,21 +63,21 @@ def push_results_to_all(): Clients.push_results(ws_client) @staticmethod - def push_info(ws_client: Server, *requested_vars: list[str]): + def push_info(ws_client: Server, *requested_vars: str): from bci.main import Main as bci_api update = {} - all = not requested_vars or "all" in requested_vars - if "db_info" in requested_vars or all: - update["db_info"] = bci_api.get_database_info() - if "logs" in requested_vars or all: - update["logs"] = bci_api.get_logs() - if "state" in requested_vars or all: - update["state"] = bci_api.get_state() - ws_client.send(json.dumps({"update": update})) + all = not requested_vars or 'all' in requested_vars + if 'db_info' in requested_vars or all: + update['db_info'] = bci_api.get_database_info() + if 'logs' in requested_vars or all: + update['logs'] = bci_api.get_logs() + if 'state' in requested_vars or all: + update['state'] = bci_api.get_state() + ws_client.send(json.dumps({'update': update})) @staticmethod - def push_info_to_all(*requested_vars: list[str]): + def push_info_to_all(*requested_vars: str): Clients.__remove_disconnected_clients() for ws_client in Clients.__clients.keys(): Clients.push_info(ws_client, *requested_vars) @@ -88,10 +86,10 @@ def push_info_to_all(*requested_vars: list[str]): def push_experiments(ws_client: Server): from bci.main import Main as bci_api - project = Clients.__clients[ws_client].get("project", None) + project = Clients.__clients[ws_client].get('project', None) if project: experiments = bci_api.get_mech_groups_of_evaluation_framework('custom', project) - ws_client.send(json.dumps({"update": {"experiments": experiments}})) + ws_client.send(json.dumps({'update': {'experiments': experiments}})) @staticmethod def push_experiments_to_all(): diff --git a/bci/web/vue/src/App.vue b/bci/web/vue/src/App.vue index b672aa3..09be7e3 100644 --- a/bci/web/vue/src/App.vue +++ b/bci/web/vue/src/App.vue @@ -615,16 +615,16 @@ export default {
- - + value="bgb_sequence" :disabled="this.eval_params.only_release_revisions"> + +
- - - + + +
diff --git a/bci/web/vue/src/components/tooltip.vue b/bci/web/vue/src/components/tooltip.vue index 2ee94b1..d2dabfd 100644 --- a/bci/web/vue/src/components/tooltip.vue +++ b/bci/web/vue/src/components/tooltip.vue @@ -3,11 +3,11 @@ data() { return { tooltips: { - "bin_seq": { + "bgb_sequence": { "tooltip": "Binaries are selected uniformly over the specified evaluation range. Experiment outcomes do not influence the next binary to be evaluated." }, - "bin_search": { - "tooltip": "Perform a search to identify either an introducing or fixing revision. This should only be performed within a range where one shift in reproducibility has been observed." + "bgb_search": { + "tooltip": "Perform a search to identify introducing and fixing revision." }, "comp_search": { "tooltip": "Combines the two strategies above. First, binaries are selected uniformly over the evaluation range, until the sequence limit is reached. Then, for each shift in reproducibility that can be observed, a search is conducted to identify the introducing or fixing binary." diff --git a/test/http_collector/test_collector.py b/test/http_collector/test_collector.py index 57bb5b6..2cc477f 100644 --- a/test/http_collector/test_collector.py +++ b/test/http_collector/test_collector.py @@ -3,7 +3,7 @@ import requests -from bci.evaluations.collector import Collector, Type +from bci.evaluations.collectors.collector import Collector, Type class TestCollector(unittest.TestCase): diff --git a/test/sequence/test_biggest_gap_bisection_search.py b/test/sequence/test_biggest_gap_bisection_search.py new file mode 100644 index 0000000..7c110d3 --- /dev/null +++ b/test/sequence/test_biggest_gap_bisection_search.py @@ -0,0 +1,81 @@ +import unittest + +from bci.search_strategy.bgb_search import BiggestGapBisectionSearch +from bci.search_strategy.sequence_strategy import SequenceFinished +from test.sequence.test_sequence_strategy import TestSequenceStrategy as helper + + +class TestBiggestGapBisectionSearch(unittest.TestCase): + + def test_sbg_search_always_available_search(self): + state_factory = helper.create_state_factory( + helper.always_has_binary, + outcome_func=lambda x: True if x < 50 else False) + sequence = BiggestGapBisectionSearch(state_factory) + index_sequence = [sequence.next().index for _ in range(8)] + assert index_sequence == [0, 99, 49, 74, 61, 55, 52, 50] + self.assertRaises(SequenceFinished, sequence.next) + + def test_sbg_search_even_available_search(self): + state_factory = helper.create_state_factory( + helper.only_has_binaries_for_even, + outcome_func=lambda x: True if x < 35 else False) + sequence = BiggestGapBisectionSearch(state_factory) + + assert sequence.next().index == 0 + assert [state.index for state in sequence._completed_states] == [0] + assert sequence._unavailability_gap_pairs == set() + + while True: + try: + sequence.next() + except SequenceFinished: + break + + assert ([state.index for state in sequence._completed_states] + == [0, 24, 30, 32, 34, 36, 48, 98]) + + self.assertRaises(SequenceFinished, sequence.next) + assert {(first.index, last.index) for (first, last) in sequence._unavailability_gap_pairs} == {(34, 36)} + + + def test_sbg_search_few_available_search(self): + state_factory = helper.create_state_factory( + helper.has_very_few_binaries, + outcome_func=lambda x: True if x < 35 else False) + sequence = BiggestGapBisectionSearch(state_factory) + + assert sequence.next().index == 0 + assert [state.index for state in sequence._completed_states] == [0] + assert sequence._unavailability_gap_pairs == set() + + assert sequence.next().index == 99 + assert [state.index for state in sequence._completed_states] == [0, 99] + + assert sequence.next().index == 44 + assert [state.index for state in sequence._completed_states] == [0, 44, 99] + + assert sequence.next().index == 22 + assert [state.index for state in sequence._completed_states] == [0, 22, 44, 99] + + assert sequence.next().index == 33 + assert [state.index for state in sequence._completed_states] == [0, 22, 33, 44, 99] + + self.assertRaises(SequenceFinished, sequence.next) + assert {(first.index, last.index) for (first, last) in sequence._unavailability_gap_pairs} == {(33, 44)} + + def test_sbg_search_few_available_search_complex(self): + state_factory = helper.create_state_factory( + helper.only_has_binaries_for_even, + evaluated_indexes=[0, 12, 22, 34, 44, 56, 66, 78, 88, 98], + outcome_func=lambda x: True if x < 35 or 66 < x else False) + sequence = BiggestGapBisectionSearch(state_factory) + + while True: + try: + sequence.next() + except SequenceFinished: + break + + assert ([state.index for state in sequence._completed_states] + == [0, 12, 22, 34, 36, 38, 44, 56, 66, 68, 72, 78, 88, 98]) diff --git a/test/sequence/test_biggest_gap_bisection_sequence.py b/test/sequence/test_biggest_gap_bisection_sequence.py new file mode 100644 index 0000000..c476ecd --- /dev/null +++ b/test/sequence/test_biggest_gap_bisection_sequence.py @@ -0,0 +1,42 @@ +import unittest + +from bci.search_strategy.bgb_sequence import BiggestGapBisectionSequence +from bci.search_strategy.sequence_strategy import SequenceFinished +from test.sequence.test_sequence_strategy import TestSequenceStrategy as helper + + +class TestBiggestGapBisectionSequence(unittest.TestCase): + + def test_sbg_sequence_always_available(self): + state_factory = helper.create_state_factory(helper.always_has_binary) + sequence = BiggestGapBisectionSequence(state_factory, 12) + index_sequence = [sequence.next().index for _ in range(12)] + assert index_sequence == [0, 99, 49, 74, 24, 36, 61, 86, 12, 42, 67, 92] + self.assertRaises(SequenceFinished, sequence.next) + + def test_sbg_sequence_even_available(self): + state_factory = helper.create_state_factory(helper.only_has_binaries_for_even) + sequence = BiggestGapBisectionSequence(state_factory, 12) + index_sequence = [sequence.next().index for _ in range(12)] + assert index_sequence == [0, 98, 48, 72, 24, 84, 12, 36, 60, 90, 6, 18] + + def test_sbg_sequence_almost_none_available(self): + state_factory = helper.create_state_factory(helper.has_very_few_binaries) + sequence = BiggestGapBisectionSequence(state_factory, 10) + index_sequence = [sequence.next().index for _ in range(10)] + assert index_sequence == [0, 99, 44, 66, 22, 77, 11, 33, 55, 88] + self.assertRaises(SequenceFinished, sequence.next) + + def test_sbg_sequence_sparse_first_half_avaiable(self): + state_factory = helper.create_state_factory(helper.has_very_few_binaries_in_first_half) + sequence = BiggestGapBisectionSequence(state_factory, 17) + index_sequence = [sequence.next().index for _ in range(17)] + assert index_sequence == [0, 99, 50, 22, 74, 44, 86, 62, 92, 56, 68, 80, 95, 53, 59, 65, 71] + + def test_sbg_sequence_always_available_with_evaluated_states(self): + state_factory = helper.create_state_factory(helper.always_has_binary, evaluated_indexes=[49, 61]) + sequence = BiggestGapBisectionSequence(state_factory, 17) + index_sequence = [sequence.next().index for _ in range(15)] + print(index_sequence) + assert index_sequence == [0, 99, 24, 80, 36, 12, 70, 89, 42, 6, 18, 30, 55, 75, 94] + self.assertRaises(SequenceFinished, sequence.next) diff --git a/test/sequence/test_composite_search.py b/test/sequence/test_composite_search.py index 407d819..5a7d247 100644 --- a/test/sequence/test_composite_search.py +++ b/test/sequence/test_composite_search.py @@ -1,72 +1,71 @@ import unittest -from unittest.mock import patch + from bci.search_strategy.composite_search import CompositeSearch -from bci.search_strategy.n_ary_sequence import NArySequence -from bci.search_strategy.n_ary_search import NArySearch from bci.search_strategy.sequence_strategy import SequenceFinished +from test.sequence.test_sequence_strategy import TestSequenceStrategy as helper class TestCompositeSearch(unittest.TestCase): - @staticmethod - def always_true(_): - return True + def test_binary_sequence_always_available_composite(self): + state_factory = helper.create_state_factory( + helper.always_has_binary, + outcome_func=lambda x: True if x < 50 else False) + sequence = CompositeSearch(state_factory, 10) + + # Sequence + index_sequence = [sequence.next().index for _ in range(10)] + assert index_sequence == [0, 99, 49, 74, 24, 36, 61, 86, 12, 42] + + # Simulate that the previous part of the evaluation has been completed + state_factory = helper.create_state_factory( + helper.always_has_binary, + outcome_func=lambda x: True if x < 50 else False, + evaluated_indexes=[0, 99, 49, 74, 24, 36, 61, 86, 12, 42] + ) + sequence.search_strategy._state_factory = state_factory + + # Sequence + index_sequence = [sequence.next().index for _ in range(3)] + assert index_sequence == [55, 52, 50] + + self.assertRaises(SequenceFinished, sequence.next) - @staticmethod - def only_even(x): - return x % 2 == 0 + def test_binary_sequence_always_available_composite_two_shifts(self): + state_factory = helper.create_state_factory( + helper.always_has_binary, + outcome_func=lambda x: True if x < 33 or 81 < x else False) + sequence = CompositeSearch(state_factory, 10) - def test_find_all_shift_index_pairs(self): - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.always_true): - def outcome(x) -> bool: - return x < 22 or x > 60 - values = list(range(100)) - seq = CompositeSearch(values, 2, 10, NArySequence, NArySearch) - seq.is_available = self.always_true - expected_elem_sequence = [0, 99, 50, 26, 75, 14, 39, 63, 88, 8] - elem_sequence = [] - for _ in range(10): - elem = seq.next() - seq.update_outcome(elem, outcome(elem)) - elem_sequence.append(elem) - assert expected_elem_sequence == elem_sequence - shift_index_pairs = seq.find_all_shift_index_pairs() - assert shift_index_pairs == [(14, 26), (50, 63)] + # Sequence + index_sequence = [sequence.next().index for _ in range(10)] + assert index_sequence == [0, 99, 49, 74, 24, 36, 61, 86, 12, 42] - expected_elem_search = [21, 24, 23, 22, 57, 61, 60] - elem_search = [] - for _ in range(len(expected_elem_search)): - elem = seq.next() - seq.update_outcome(elem, outcome(elem)) - elem_search.append(elem) - assert seq.sequence_strategy_finished - assert expected_elem_search == elem_search - self.assertRaises(SequenceFinished, seq.next) + # Simulate that the previous part of the evaluation has been completed + state_factory = helper.create_state_factory( + helper.always_has_binary, + outcome_func=lambda x: True if x < 33 or 81 < x else False, + evaluated_indexes=[0, 99, 49, 74, 24, 36, 61, 86, 12, 42] + ) + sequence.search_strategy._state_factory = state_factory - def test_composite_search(self): - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.always_true): - def outcome(x) -> bool: - return x < 22 or x > 60 + while True: + try: + print(sequence.next()) + except SequenceFinished: + break - values = list(range(100)) - seq = CompositeSearch(values, 2, 10, NArySequence, NArySearch) - seq.is_available = self.always_true - expected_sequence_part = [0, 99, 50, 26, 75, 14, 39, 63, 88, 8] - expected_search_part = [21, 24, 23, 22, 57, 61, 60] + evaluated_indexes = [state.index for state in sequence.search_strategy._completed_states] - actual_sequence_part = [] - for _ in range(10): - elem = seq.next() - seq.update_outcome(elem, outcome(elem)) - actual_sequence_part.append(elem) - assert expected_sequence_part == actual_sequence_part + assert sequence.sequence_strategy_finished + assert 32 in evaluated_indexes + assert 33 in evaluated_indexes + assert 81 in evaluated_indexes + assert 82 in evaluated_indexes - actual_search_part = [] - while True: - try: - elem = seq.next() - seq.update_outcome(elem, outcome(elem)) - actual_search_part.append(elem) - except SequenceFinished: - break - assert expected_search_part == actual_search_part + assert 1 not in evaluated_indexes + assert 13 not in evaluated_indexes + assert 37 not in evaluated_indexes + assert 50 not in evaluated_indexes + assert 62 not in evaluated_indexes + assert 87 not in evaluated_indexes diff --git a/test/sequence/test_search_strategy.py b/test/sequence/test_search_strategy.py deleted file mode 100644 index abc586a..0000000 --- a/test/sequence/test_search_strategy.py +++ /dev/null @@ -1,170 +0,0 @@ -import unittest -from unittest.mock import patch -from bci.search_strategy.n_ary_search import NArySearch -from bci.search_strategy.sequence_strategy import SequenceFinished - - -class TestSearchStrategy(unittest.TestCase): - @staticmethod - def always_true(_): - return True - - @staticmethod - def only_even(x): - return x.value % 2 == 0 - - @staticmethod - def one_in_15(x): - return x.value % 15 == 0 - - def test_binary_search(self): - def outcome(x) -> bool: - return x < 22 - - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.always_true): - values = list(range(100)) - seq = NArySearch(values, 2) - expected_elem_sequence = [0, 99, 50, 26, 14, 21, 24, 23, 22] - elem_sequence = [] - for _ in expected_elem_sequence: - elem = seq.next() - seq.update_outcome(elem, outcome(elem)) - elem_sequence.append(elem) - assert expected_elem_sequence == elem_sequence - self.assertRaises(SequenceFinished, seq.next) - - def test_binary_search_only_even(self): - def outcome(x) -> bool: - return x < 22 - - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.only_even): - values = list(range(100)) - seq = NArySearch(values, 2) - expected_elem_sequence = [0, 98, 50, 26, 14, 20, 24, 22] - elem_sequence = [] - for _ in expected_elem_sequence: - elem = seq.next() - seq.update_outcome(elem, outcome(elem)) - elem_sequence.append(elem) - assert expected_elem_sequence == elem_sequence - self.assertRaises(SequenceFinished, seq.next) - - def test_3ary_search_only_even(self): - def outcome(x) -> bool: - return x < 15 - - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.only_even): - values = list(range(100)) - seq = NArySearch(values, 3) - expected_elem_sequence = [0, 98, 34, 12, 24, 16, 14] - elem_sequence = [] - for _ in expected_elem_sequence: - elem = seq.next() - seq.update_outcome(elem, outcome(elem)) - elem_sequence.append(elem) - assert expected_elem_sequence == elem_sequence - self.assertRaises(SequenceFinished, seq.next) - - def test_observer_edge_case1(self): - def is_available(x): - return x.value in [766907, 766912, 766922] - - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', is_available): - values = list(range(766907, 766923)) - seq = NArySearch(values, 16) - assert seq.next() == 766907 - seq.update_outcome(766907, False) - assert seq.next() == 766922 - seq.update_outcome(766922, True) - assert seq.next() == 766912 - self.assertRaises(SequenceFinished, seq.next) - - def test_observer_edge_case2(self): - def outcome(x) -> bool: - return x < 454750 - - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.one_in_15): - values = list(range(454462, 455227)) - seq = NArySearch(values, 8) - elem_sequence = [] - while True: - try: - elem = seq.next() - seq.update_outcome(elem, outcome(elem)) - elem_sequence.append(elem) - except SequenceFinished: - break - assert 454740 in elem_sequence[-2:] - assert 454725 in elem_sequence[-2:] - - def test_correct_ending_2(self): - def outcome(x) -> bool: - return x < 561011 - - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.one_in_15): - values = list(range(560417, 562154)) - seq = NArySearch(values, 2) - elem_sequence = [] - while True: - try: - elem = seq.next() - seq.update_outcome(elem, outcome(elem)) - elem_sequence.append(elem) - except SequenceFinished: - break - assert 561015 in elem_sequence[-2:] - assert 561000 in elem_sequence[-2:] - - def test_correct_ending_4(self): - def outcome(x) -> bool: - return x < 561011 - - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.one_in_15): - values = list(range(560417, 562154)) - seq = NArySearch(values, 4) - elem_sequence = [] - while True: - try: - elem = seq.next() - seq.update_outcome(elem, outcome(elem)) - elem_sequence.append(elem) - except SequenceFinished: - break - assert 561015 in elem_sequence[-2:] - assert 561000 in elem_sequence[-2:] - - def test_correct_ending_8(self): - def outcome(x) -> bool: - return x < 561011 - - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.one_in_15): - values = list(range(560417, 562154)) - seq = NArySearch(values, 8) - elem_sequence = [] - while True: - try: - elem = seq.next() - seq.update_outcome(elem, outcome(elem)) - elem_sequence.append(elem) - except SequenceFinished: - break - assert 561015 in elem_sequence[-2:] - assert 561000 in elem_sequence[-2:] - - def test_correct_ending_16(self): - def outcome(x) -> bool: - return x < 561011 - - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.one_in_15): - values = list(range(560417, 562154)) - seq = NArySearch(values, 16) - elem_sequence = [] - while True: - try: - elem = seq.next() - seq.update_outcome(elem, outcome(elem)) - elem_sequence.append(elem) - except SequenceFinished: - break - assert 561015 in elem_sequence[-2:] - assert 561000 in elem_sequence[-2:] diff --git a/test/sequence/test_sequence_strategy.py b/test/sequence/test_sequence_strategy.py index 4418114..92a7e1e 100644 --- a/test/sequence/test_sequence_strategy.py +++ b/test/sequence/test_sequence_strategy.py @@ -1,139 +1,102 @@ import unittest -from unittest.mock import patch -from bci.search_strategy.n_ary_sequence import NArySequence, SequenceFinished +from typing import Callable +from unittest.mock import MagicMock + +from bci.evaluations.logic import EvaluationConfiguration, EvaluationRange +from bci.evaluations.outcome_checker import OutcomeChecker +from bci.search_strategy.sequence_strategy import SequenceStrategy +from bci.version_control.factory import StateFactory +from bci.version_control.states.state import State class TestSequenceStrategy(unittest.TestCase): + ''' + Helper functions to create states and state factories for testing. + ''' + + @staticmethod + def get_states(indexes: list[int], is_available, outcome_func) -> list[State]: + return [TestSequenceStrategy.create_state(index, is_available, outcome_func) for index in indexes] + @staticmethod - def always_true(_): + def create_state_factory( + is_available: Callable, + evaluated_indexes: list[int] = None, + outcome_func: Callable = None) -> StateFactory: + eval_params = MagicMock(spec=EvaluationConfiguration) + eval_params.evaluation_range = MagicMock(spec=EvaluationRange) + eval_params.evaluation_range.major_version_range = [0, 99] + + factory = MagicMock(spec=StateFactory) + factory.__eval_params = eval_params + factory.__outcome_checker = TestSequenceStrategy.create_outcome_checker(outcome_func) + factory.create_state = lambda index: TestSequenceStrategy.create_state(index, is_available, outcome_func) + first_state = TestSequenceStrategy.create_state(0, is_available, outcome_func) + last_state = TestSequenceStrategy.create_state(99, is_available, outcome_func) + factory.boundary_states = (first_state, last_state) + + if evaluated_indexes: + factory.create_evaluated_states = lambda: TestSequenceStrategy.get_states(evaluated_indexes, lambda _: True, outcome_func) + else: + factory.create_evaluated_states = lambda: [] + return factory + + @staticmethod + def create_state(index, is_available: Callable, outcome_func: Callable) -> State: + state = MagicMock(spec=State) + state.index = index + state.has_available_binary = lambda: is_available(index) + state.outcome = outcome_func(index) if outcome_func else None + state.__eq__ = State.__eq__ + state.__repr__ = State.__repr__ + return state + + @staticmethod + def create_outcome_checker(outcome_func: Callable) -> OutcomeChecker: + if outcome_func: + outcome_checker = MagicMock() + outcome_checker.get_outcome = outcome_func + return outcome_checker + else: + return None + + @staticmethod + def always_has_binary(index) -> bool: return True @staticmethod - def only_even(x): - return x.value % 2 == 0 - - def test_binary_sequence(self): - values = list(range(100)) - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.always_true): - seq = NArySequence(values, 2) - assert seq.next() == 0 - assert seq.next() == 99 - assert seq.next() == 50 - - def test_binary_sequence_ending(self): - values = list(range(10)) - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.always_true): - seq = NArySequence(values, 2) - assert seq.next() == 0 - assert seq.next() == 9 - assert seq.next() == 5 - assert seq.next() == 3 - assert seq.next() == 8 - assert seq.next() == 2 - assert seq.next() == 4 - assert seq.next() == 7 - assert seq.next() == 1 - assert seq.next() == 6 - self.assertRaises(SequenceFinished, seq.next) - - def test_binary_sequence_ending_only_even_available(self): - values = list(range(100)) - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.only_even): - seq = NArySequence(values, 2) - outputted_values = set() - for _ in range(50): - n = seq.next() - assert n % 2 == 0 - assert n not in outputted_values - outputted_values.add(n) - self.assertRaises(SequenceFinished, seq.next) - - def test_3ary_sequence(self): - values = list(range(100)) - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.always_true): - seq = NArySequence(values, 3) - assert seq.next() == 0 - assert seq.next() == 99 - assert seq.next() == 34 - assert seq.next() == 67 - assert seq.next() == 12 - assert seq.next() == 23 - assert seq.next() == 46 - assert seq.next() == 57 - assert seq.next() == 79 - assert seq.next() == 90 - - def test_3nary_sequence_ending(self): - values = list(range(10)) - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.always_true): - seq = NArySequence(values, 3) - assert seq.next() == 0 - assert seq.next() == 9 - assert seq.next() == 4 - assert seq.next() == 7 - assert seq.next() == 2 - assert seq.next() == 3 - assert seq.next() == 5 - assert seq.next() == 6 - assert seq.next() == 8 - assert seq.next() == 1 - self.assertRaises(SequenceFinished, seq.next) - - def test_3nary_sequence_ending_only_even_available(self): - values = list(range(100)) - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.only_even): - seq = NArySequence(values, 3) - outputted_values = set() - for _ in range(50): - n = seq.next() - assert n % 2 == 0 - assert n not in outputted_values - outputted_values.add(n) - self.assertRaises(SequenceFinished, seq.next) - - def test_4ary_sequence(self): - values = list(range(100)) - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.always_true): - seq = NArySequence(values, 4) - assert seq.next() == 0 - assert seq.next() == 99 - assert seq.next() == 26 - assert seq.next() == 51 - assert seq.next() == 76 - - def test_4nary_sequence_ending(self): - values = list(range(10)) - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.always_true): - seq = NArySequence(values, 4) - assert seq.next() == 0 - assert seq.next() == 9 - assert seq.next() == 3 - assert seq.next() == 5 - assert seq.next() == 7 - assert seq.next() == 1 - assert seq.next() == 2 - assert seq.next() == 4 - assert seq.next() == 6 - assert seq.next() == 8 - self.assertRaises(SequenceFinished, seq.next) - - def test_4nary_sequence_ending_only_even_available(self): - values = list(range(100)) - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.only_even): - seq = NArySequence(values, 4) - outputted_values = set() - for _ in range(50): - n = seq.next() - assert n % 2 == 0 - assert n not in outputted_values - outputted_values.add(n) - self.assertRaises(SequenceFinished, seq.next) - - def test_limit(self): - values = list(range(100)) - with patch('bci.search_strategy.sequence_elem.SequenceElem.is_available', self.always_true): - seq = NArySequence(values, 2, limit=37) - for _ in range(37): - seq.next() - self.assertRaises(SequenceFinished, seq.next) + def only_has_binaries_for_even(index) -> bool: + return index % 2 == 0 + + @staticmethod + def has_very_few_binaries(index) -> bool: + return index % 11 == 0 + + @staticmethod + def has_very_few_binaries_in_first_half(index) -> bool: + if index < 50: + return index % 22 == 0 + return True + + ''' + Actual tests + ''' + + def test_find_closest_state_with_available_binary_1(self): + state_factory = TestSequenceStrategy.create_state_factory(TestSequenceStrategy.always_has_binary) + sequence_strategy = SequenceStrategy(state_factory, 0) + state = sequence_strategy._find_closest_state_with_available_binary(state_factory.create_state(5), (state_factory.create_state(0), state_factory.create_state(10))) + assert state.index == 5 + + def test_find_closest_state_with_available_binary_2(self): + state_factory = TestSequenceStrategy.create_state_factory(TestSequenceStrategy.only_has_binaries_for_even) + sequence_strategy = SequenceStrategy(state_factory, 0) + state = sequence_strategy._find_closest_state_with_available_binary(state_factory.create_state(5), (state_factory.create_state(0), state_factory.create_state(10))) + assert state.index == 4 + + def test_find_closest_state_with_available_binary_3(self): + state_factory = TestSequenceStrategy.create_state_factory(TestSequenceStrategy.only_has_binaries_for_even) + sequence_strategy = SequenceStrategy(state_factory, 0) + state = sequence_strategy._find_closest_state_with_available_binary(state_factory.create_state(1), (state_factory.create_state(0), state_factory.create_state(2))) + assert state is None