Adjusted docstrings, check for matching keys in test_io_data_dict and io_data_dict
Michael Panchenko committed Feb 28, 2024
1 parent 41e2082 commit bae92ee
Showing 1 changed file with 32 additions and 10 deletions.
42 changes: 32 additions & 10 deletions src/sensai/evaluation/eval_util.py
@@ -577,20 +577,33 @@ def __init__(self, io_data_dict: Dict[str, InputOutputData], key_name: str = "da
meta_data_dict: Optional[Dict[str, Dict[str, Any]]] = None,
evaluator_params: Optional[Union[RegressionEvaluatorParams, ClassificationEvaluatorParams, Dict[str, Any]]] = None,
cross_validator_params: Optional[Union[VectorModelCrossValidatorParams, Dict[str, Any]]] = None,
test_io_data_dict: Optional[Dict[str, InputOutputData]] = None):
test_io_data_dict: Optional[Dict[str, InputOutputData | None]] = None):
"""
:param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models
:param io_data_dict: a dictionary mapping from names to the data sets with which to evaluate models.
For evaluation or cross-validation, these data sets will usually be split according to the rules
specified by `evaluator_params` or `cross_validator_params`. An exception is the case where
explicit test data sets are specified via `test_io_data_dict`; for those data sets, the io_data
will not be split for evaluation, and the corresponding test_io_data will be used instead.
:param key_name: a name for the key value used in io_data_dict, which will be used as a column name in result data frames
:param meta_data_dict: a dictionary which maps from a name (same keys as in io_data_dict) to a dictionary, which maps
from a column name to a value and which is to be used to extend the result data frames containing per-dataset results
:param evaluator_params: parameters to use for the instantiation of evaluators (relevant if use_cross_validation==False)
:param cross_validator_params: parameters to use for the instantiation of cross-validators (relevant if use_cross_validation==True)
:param test_io_data_dict: a dictionary mapping from names to the test data sets to use for evaluation. Datasets under the same
keys as in io_data_dict will be used for evaluation of the models that were trained on the respective io_data_dict.
The keys don't need to be the same as in io_data_dict: unused keys are ignored, and for missing keys the test_io_data is None.
:param test_io_data_dict: a dictionary mapping from names to the test data sets to use for evaluation or to None.
Entries with non-None values will be used for evaluation of the models that were trained on the io_data
registered under the same key in io_data_dict.
If passed, its keys must be a superset of io_data_dict's keys (the values may be None, e.g.
if you want to use explicit test data sets for some entries and splitting of the io_data for others).
If not None, cross-validation cannot be used when calling ``compare_models``.
"""
if test_io_data_dict is not None:
missing_keys = set(io_data_dict).difference(test_io_data_dict)
if len(missing_keys) > 0:
raise ValueError(
"If test_io_data_dict is passed, its keys must be a superset of io_data_dict's keys. "
f"However, the following keys are missing: {missing_keys}")
self.io_data_dict = io_data_dict
self.test_io_data_dict = test_io_data_dict or {}
self.test_io_data_dict = test_io_data_dict

self.key_name = key_name
self.evaluator_params = evaluator_params
self.cross_validator_params = cross_validator_params
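
As a hedged, standalone illustration of the key contract enforced by the check above (the dataset names and string values below are made up; plain strings stand in for InputOutputData instances):

    # Plain strings stand in for InputOutputData instances (illustration only)
    io_data_dict = {"dataset_a": "io_data_a", "dataset_b": "io_data_b"}

    # Valid: keys are a superset of io_data_dict's keys; None means "split the io_data
    # for this entry instead of using an explicit test set"
    test_io_data_dict = {"dataset_a": "test_data_a", "dataset_b": None}
    assert not set(io_data_dict).difference(test_io_data_dict)

    # Invalid: "dataset_b" is missing entirely, so __init__ would raise ValueError
    bad_test_io_data_dict = {"dataset_a": "test_data_a"}
    assert set(io_data_dict).difference(bad_test_io_data_dict) == {"dataset_b"}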
@@ -617,25 +630,34 @@ def compare_models(self,
"""
:param model_factories: a sequence of factory functions for the creation of models to evaluate; every factory must result
in a model with a fixed model name (otherwise results cannot be correctly aggregated)
:param use_cross_validation: whether to use cross-validation (rather than a single split) for model evaluation
:param use_cross_validation: whether to use cross-validation (rather than a single split) for model evaluation.
This can only be used if the instance's ``test_io_data_dict`` is None.
:param result_writer: a writer with which to store results; if None, results are not stored
:param write_per_dataset_results: whether to use result_writer (if not None) in order to generate detailed results for each
dataset in a subdirectory named according to the name of the dataset
:param write_csvs: whether to write metrics table to CSV files
:param column_name_for_model_ranking: column name to use for ranking models
:param rank_max: if true, use max for ranking, else min
:param add_combined_eval_stats: whether to also report, for each model, evaluation metrics on the combined set of data points from
all EvalStats objects.
Note that for classification, this is only possible if all individual experiments use the same set of class labels.
:param create_metric_distribution_plots: whether to create, for each model, plots of the distribution of each metric across the
datasets (applies only if resultWriter is not None)
datasets (applies only if result_writer is not None)
:param create_combined_eval_stats_plots: whether to combine, for each type of model, the EvalStats objects from the individual
experiments into a single object that holds all results and use it to create plots reflecting the overall result (applies only
if result_writer is not None).
Note that for classification, this is only possible if all individual experiments use the same set of class labels.
:param distribution_plots_cdf: whether to create CDF plots for the metric distributions. Applies only if
create_metric_distribution_plots is True and result_writer is not None.
:param distribution_plots_cdf_complementary: whether to plot the complementary cdf instead of the regular cdf, provided that
distribution_plots_cdf is True.
:param visitors: visitors which may process individual results. Plots generated by visitors are created/collected at the end of the
comparison.
:return: an object containing the full comparison results
"""
if self.test_io_data_dict and use_cross_validation:
raise ValueError("Cannot use cross-validation when `test_io_data_dict` is specified")

all_results_df = pd.DataFrame()
eval_stats_by_model_name = defaultdict(list)
results_by_model_name: Dict[str, List[ModelComparisonData.Result]] = defaultdict(list)
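
A small, hedged sketch of the guard added at the top of compare_models; the helper below merely mirrors the check and is not part of the library:

    # Mirrors the check added to compare_models (illustration only)
    def check_compare_models_args(test_io_data_dict, use_cross_validation):
        if test_io_data_dict and use_cross_validation:
            raise ValueError("Cannot use cross-validation when `test_io_data_dict` is specified")

    check_compare_models_args(None, use_cross_validation=True)  # fine: no explicit test data
    check_compare_models_args({"dataset_a": "test_data_a"}, use_cross_validation=False)  # fine
    # check_compare_models_args({"dataset_a": "test_data_a"}, use_cross_validation=True)  # would raise ValueError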
@@ -664,7 +686,7 @@ def compare_models(self,
else:
raise ValueError("The models have to be either all regression models or all classification, not a mixture")

test_io_data = self.test_io_data_dict.get(key)
test_io_data = self.test_io_data_dict[key] if self.test_io_data_dict is not None else None
ev = create_evaluation_util(inputOutputData, is_regression=is_regression, evaluator_params=self.evaluator_params,
cross_validator_params=self.cross_validator_params, test_io_data=test_io_data)

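The switch from .get(key) to direct indexing in this hunk is safe because of the superset check in __init__: every key of io_data_dict is guaranteed to be present in test_io_data_dict, and a stored None simply means the io_data is split as before. A standalone illustration with made-up keys:

    test_io_data_dict = {"dataset_a": "test_data_a", "dataset_b": None}
    for key in ("dataset_a", "dataset_b"):
        test_io_data = test_io_data_dict[key] if test_io_data_dict is not None else None
        print(key, "->", test_io_data)  # "dataset_b -> None": its io_data will be split instead
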
@@ -924,7 +946,7 @@ def create_distribution_plots(self, result_writer: ResultWriter, cdf=True, cdf_c
:param result_writer: the result writer
:param cdf: whether to additionally plot, for each distribution, the cumulative distribution function
:param cdf_complementary: whether to plot the complementary cdf, provided that ``cdf`` is True
:param cdf_complementary: whether to plot the complementary cdf instead of the regular cdf, provided that ``cdf`` is True
"""
for modelName in self.get_model_names():
eval_stats_collection = self.get_eval_stats_collection(modelName)
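
For context on the cdf/cdf_complementary options documented above, a hedged numerical sketch of the relationship between an empirical CDF and its complement (illustrative values only, not sensai code; numpy is assumed to be available):

    import numpy as np

    metric_values = np.array([0.62, 0.71, 0.74, 0.80, 0.91])  # made-up metric values
    sorted_values = np.sort(metric_values)
    cdf = np.arange(1, len(sorted_values) + 1) / len(sorted_values)  # fraction of values <= x
    ccdf = 1.0 - cdf  # complementary CDF: fraction of values > x
    for x, c, cc in zip(sorted_values, cdf, ccdf):
        print(f"{x:.2f}  cdf={c:.2f}  ccdf={cc:.2f}")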