From bae92eedc767b56cd9fbaa3fbe617890c77dbf40 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 28 Feb 2024 21:06:41 +0100 Subject: [PATCH] Adjusted docstrings, check for matching keys in test_io_data_dict and io_data_dict --- src/sensai/evaluation/eval_util.py | 42 +++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/src/sensai/evaluation/eval_util.py b/src/sensai/evaluation/eval_util.py index 666fb9a6..9fab0d69 100644 --- a/src/sensai/evaluation/eval_util.py +++ b/src/sensai/evaluation/eval_util.py @@ -577,20 +577,33 @@ def __init__(self, io_data_dict: Dict[str, InputOutputData], key_name: str = "da meta_data_dict: Optional[Dict[str, Dict[str, Any]]] = None, evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams, Dict[str, Any]]] = None, cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None, - test_io_data_dict: Optional[Dict[str, InputOutputData]] = None): + test_io_data_dict: Optional[Dict[str, InputOutputData | None]] = None): """ - :param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models + :param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models. + For evaluation or cross-validation, these datasets will usually be split according to the rules + specified by `evaluator_params or `cross_validator_params`. An exception is the case where + explicit test data sets are specified by passing `test_io_data_dict`. Then, for these data + sets, the io_data will not be split for evaluation, but the test_io_data will be used instead. :param key_name: a name for the key value used in inputOutputDataDict, which will be used as a column name in result data frames :param meta_data_dict: a dictionary which maps from a name (same keys as in inputOutputDataDict) to a dictionary, which maps from a column name to a value and which is to be used to extend the result data frames containing per-dataset results :param evaluator_params: parameters to use for the instantiation of evaluators (relevant if useCrossValidation==False) :param cross_validator_params: parameters to use for the instantiation of cross-validators (relevant if useCrossValidation==True) - :param test_io_data_dict: a dictionary mapping from names to the test data sets to use for evaluation. Datasets under the same - keys as in io_data_dict will be used for evaluation of the models that were trained on the respective io_data_dict. - The keys don't need to be the same as in io_data_dict: unused keys are ignored, and for missing keys the test_io_data is None. + :param test_io_data_dict: a dictionary mapping from names to the test data sets to use for evaluation or to None. + Entries with non-None values will be used for evaluation of the models that were trained on the respective io_data_dict. + If passed, the keys need to be a superset of io_data_dict's keys (note that the values may be None, e.g. + if you want to use test data sets for some entries, and splitting of the io_data for others). + If not None, cross-validation cannot be used when calling ``compare_models``. """ + if test_io_data_dict is not None: + missing_keys = set(io_data_dict).difference(test_io_data_dict) + if len(missing_keys) > 0: + raise ValueError( + "If test_io_data_dict is passed, its keys must be a superset of the io_data_dict's keys." + f"However, found missing_keys: {missing_keys}") self.io_data_dict = io_data_dict - self.test_io_data_dict = test_io_data_dict or {} + self.test_io_data_dict = test_io_data_dict + self.key_name = key_name self.evaluator_params = evaluator_params self.cross_validator_params = cross_validator_params @@ -617,25 +630,34 @@ def compare_models(self, """ :param model_factories: a sequence of factory functions for the creation of models to evaluate; every factory must result in a model with a fixed model name (otherwise results cannot be correctly aggregated) - :param use_cross_validation: whether to use cross-validation (rather than a single split) for model evaluation + :param use_cross_validation: whether to use cross-validation (rather than a single split) for model evaluation. + This can only be used if the instance's ``test_io_data_dict`` is None. :param result_writer: a writer with which to store results; if None, results are not stored :param write_per_dataset_results: whether to use resultWriter (if not None) in order to generate detailed results for each dataset in a subdirectory named according to the name of the dataset + :param write_csvs: whether to write metrics table to CSV files :param column_name_for_model_ranking: column name to use for ranking models :param rank_max: if true, use max for ranking, else min :param add_combined_eval_stats: whether to also report, for each model, evaluation metrics on the combined set data points from all EvalStats objects. Note that for classification, this is only possible if all individual experiments use the same set of class labels. :param create_metric_distribution_plots: whether to create, for each model, plots of the distribution of each metric across the - datasets (applies only if resultWriter is not None) + datasets (applies only if result_writer is not None) :param create_combined_eval_stats_plots: whether to combine, for each type of model, the EvalStats objects from the individual experiments into a single objects that holds all results and use it to create plots reflecting the overall result (applies only if resultWriter is not None). Note that for classification, this is only possible if all individual experiments use the same set of class labels. + :param distribution_plots_cdf: whether to create CDF plots for the metric distributions. Applies only if + create_metric_distribution_plots is True and result_writer is not None. + :param distribution_plots_cdf_complementary: whether to plot the complementary cdf instead of the regular cdf, provided that + distribution_plots_cdf is True. :param visitors: visitors which may process individual results. Plots generated by visitors are created/collected at the end of the comparison. :return: an object containing the full comparison results """ + if self.test_io_data_dict and use_cross_validation: + raise ValueError("Cannot use cross-validation when `test_io_data_dict` is specified") + all_results_df = pd.DataFrame() eval_stats_by_model_name = defaultdict(list) results_by_model_name: Dict[str, List[ModelComparisonData.Result]] = defaultdict(list) @@ -664,7 +686,7 @@ def compare_models(self, else: raise ValueError("The models have to be either all regression models or all classification, not a mixture") - test_io_data = self.test_io_data_dict.get(key) + test_io_data = self.test_io_data_dict[key] if self.test_io_data_dict is not None else None ev = create_evaluation_util(inputOutputData, is_regression=is_regression, evaluator_params=self.evaluator_params, cross_validator_params=self.cross_validator_params, test_io_data=test_io_data) @@ -924,7 +946,7 @@ def create_distribution_plots(self, result_writer: ResultWriter, cdf=True, cdf_c :param result_writer: the result writer :param cdf: whether to additionally plot, for each distribution, the cumulative distribution function - :param cdf_complementary: whether to plot the complementary cdf, provided that ``cdf`` is True + :param cdf_complementary: whether to plot the complementary cdf instead of the regular cdf, provided that ``cdf`` is True """ for modelName in self.get_model_names(): eval_stats_collection = self.get_eval_stats_collection(modelName)