From b323da5d4cb8269b92907abb6fc715f58367b69f Mon Sep 17 00:00:00 2001 From: Simeon Ehrig Date: Wed, 7 Feb 2024 13:57:51 +0100 Subject: [PATCH] replace allpairspy with covertable library - Both library provides an pair-wise generator algorithm. - There is a know bug in the allpairspy library, which causes that not all valid parameter-value-pairs are generated, if filter rules are used. The bug exists since two years. --- bashi/generator.py | 34 +++++++++++--------- pyproject.toml | 2 +- tests/test_expected_parameter_value_pairs.py | 16 ++++----- tests/test_generate_combination_list.py | 19 +++-------- 4 files changed, 32 insertions(+), 39 deletions(-) diff --git a/bashi/generator.py b/bashi/generator.py index 9202ebf..0f3b55f 100644 --- a/bashi/generator.py +++ b/bashi/generator.py @@ -1,18 +1,19 @@ """Functions to generate the combination-list""" -from typing import Dict +from typing import Dict, List from collections import OrderedDict -from allpairspy import AllPairs +from covertable import make from bashi.types import ( Parameter, + ParameterValue, ParameterValueMatrix, FilterFunction, Combination, CombinationList, ) -from bashi.utils import get_default_filter_chain, FilterAdapter +from bashi.utils import get_default_filter_chain def generate_combination_list( @@ -32,20 +33,21 @@ def generate_combination_list( """ filter_chain = get_default_filter_chain(custom_filter) - param_map: Dict[int, Parameter] = {} - for index, key in enumerate(parameter_value_matrix.keys()): - param_map[index] = key - filter_adapter = FilterAdapter(param_map, filter_chain) - comb_list: CombinationList = [] - # convert List[Pair] to CombinationList - for all_pair in AllPairs( # type: ignore - parameters=parameter_value_matrix, n=2, filter_func=filter_adapter - ): - comb: Combination = OrderedDict() - for index, param in enumerate(all_pair._fields): # type: ignore - comb[param] = all_pair[index] # type: ignore - comb_list.append(comb) + all_pairs: List[Dict[Parameter, ParameterValue]] = make( + factors=parameter_value_matrix, + length=2, + pre_filter=filter_chain, + ) # type: ignore + + # convert List[Dict[Parameter, ParameterValue]] to CombinationList + for all_pair in all_pairs: + tmp_comb: Combination = OrderedDict() + # covertable does not keep the ordering of the parameters + # therefore we sort it + for param in parameter_value_matrix.keys(): + tmp_comb[param] = all_pair[param] + comb_list.append(tmp_comb) return comb_list diff --git a/pyproject.toml b/pyproject.toml index 5f61a27..22cf465 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers= [ "Operating System :: OS Independent", ] dependencies = [ - "allpairspy == 2.5.1", + "covertable == 2.1.0", "typeguard", "packaging" ] diff --git a/tests/test_expected_parameter_value_pairs.py b/tests/test_expected_parameter_value_pairs.py index 3b1b6d3..ec3cd29 100644 --- a/tests/test_expected_parameter_value_pairs.py +++ b/tests/test_expected_parameter_value_pairs.py @@ -5,9 +5,8 @@ import io import packaging.version as pkv -# allpairspy has no type hints -from allpairspy import AllPairs # type: ignore from utils_test import parse_param_val, parse_param_vals, parse_expected_val_pairs +from covertable import make from bashi.types import ( Parameter, ParameterValue, @@ -488,15 +487,16 @@ def test_check_parameter_value_pair_in_combination_list_complete_list_plus_wrong self.assertEqual(output_wrong_many_pairs_list, expected_output_many_wrong_pairs_list) - def test_unrestricted_allpairspy_generator(self): + def test_unrestricted_covertable_generator(self): comb_list: CombinationList = [] # pylance shows a warning, because it cannot determine the concrete type of a namedtuple, # which is returned by AllPairs - for all_pair in AllPairs(parameters=self.param_matrix): # type: ignore - comb: Combination = OrderedDict() - for index, param in enumerate(all_pair._fields): # type: ignore - comb[param] = all_pair[index] # type: ignore - comb_list.append(comb) + all_pairs: List[Dict[Parameter, ParameterValue]] = make( + factors=self.param_matrix + ) # type: ignore + + for all_pair in all_pairs: + comb_list.append(OrderedDict(all_pair)) self.assertTrue( check_parameter_value_pair_in_combination_list(comb_list, self.expected_param_val_pairs) diff --git a/tests/test_generate_combination_list.py b/tests/test_generate_combination_list.py index 988f7b5..29b3b66 100644 --- a/tests/test_generate_combination_list.py +++ b/tests/test_generate_combination_list.py @@ -50,9 +50,6 @@ def test_generator_without_custom_filter(self): ) ) - # TODO(SimeonEhrig): remove expectedFailure, if this PR was merged: - # https://github.com/thombashi/allpairspy/pull/10 - @unittest.expectedFailure def test_generator_with_custom_filter(self): def custom_filter(row: ParameterValueTuple) -> bool: if DEVICE_COMPILER in row and row[DEVICE_COMPILER].name == NVCC: @@ -168,7 +165,6 @@ def custom_filter(row: ParameterValueTuple) -> bool: missing_combinations = io.StringIO() - # because of a bug in the allpairspy library, valid pairs are missing. try: self.assertTrue( check_parameter_value_pair_in_combination_list( @@ -177,7 +173,7 @@ def custom_filter(row: ParameterValueTuple) -> bool: ) except AssertionError as e: # remove comment to print missing, valid pairs - # print(f"\n{missing_combinations.getvalue()}") + print(f"\n{missing_combinations.getvalue()}") raise e @@ -194,14 +190,8 @@ def test_generator_without_custom_filter(self): check_parameter_value_pair_in_combination_list(comb_list, expected_param_val_pairs) ) - # TODO(SimeonEhrig): remove expectedFailure, if this PR was merged: - # https://github.com/thombashi/allpairspy/pull/10 - @unittest.expectedFailure def test_generator_with_custom_filter(self): def custom_filter(row: ParameterValueTuple) -> bool: - if DEVICE_COMPILER in row and row[DEVICE_COMPILER].name == NVCC: - return False - if ( CMAKE in row and row[CMAKE].version == pkv.parse("3.23") @@ -236,7 +226,8 @@ def custom_filter(row: ParameterValueTuple) -> bool: ) except AssertionError as e: # remove comment to display missing combinations - # missing_combinations_str = missing_combinations.getvalue() - # print(f"\nnumber of missing combinations: {len(missing_combinations_str.split('\n'))}") - # print(f"\n{missing_combinations_str}") + missing_combinations_str = missing_combinations.getvalue() + print(f"\n{missing_combinations_str}") + number_of_combs = len(missing_combinations_str.split("\n")) + print(f"\nnumber of missing combinations: {number_of_combs}") raise e