Skip to content

Commit

Permalink
replace allpairspy with covertable library
Browse files Browse the repository at this point in the history
- Both library provides an pair-wise generator algorithm.
- There is a know bug in the allpairspy library, which causes that not all valid parameter-value-pairs are generated, if filter rules are used. The bug exists since two years.
  • Loading branch information
SimeonEhrig committed Feb 7, 2024
1 parent 5bb3749 commit b323da5
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 39 deletions.
34 changes: 18 additions & 16 deletions bashi/generator.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
"""Functions to generate the combination-list"""

from typing import Dict
from typing import Dict, List
from collections import OrderedDict

from allpairspy import AllPairs
from covertable import make

from bashi.types import (
Parameter,
ParameterValue,
ParameterValueMatrix,
FilterFunction,
Combination,
CombinationList,
)
from bashi.utils import get_default_filter_chain, FilterAdapter
from bashi.utils import get_default_filter_chain


def generate_combination_list(
Expand All @@ -32,20 +33,21 @@ def generate_combination_list(
"""
filter_chain = get_default_filter_chain(custom_filter)

param_map: Dict[int, Parameter] = {}
for index, key in enumerate(parameter_value_matrix.keys()):
param_map[index] = key
filter_adapter = FilterAdapter(param_map, filter_chain)

comb_list: CombinationList = []

# convert List[Pair] to CombinationList
for all_pair in AllPairs( # type: ignore
parameters=parameter_value_matrix, n=2, filter_func=filter_adapter
):
comb: Combination = OrderedDict()
for index, param in enumerate(all_pair._fields): # type: ignore
comb[param] = all_pair[index] # type: ignore
comb_list.append(comb)
all_pairs: List[Dict[Parameter, ParameterValue]] = make(
factors=parameter_value_matrix,
length=2,
pre_filter=filter_chain,
) # type: ignore

# convert List[Dict[Parameter, ParameterValue]] to CombinationList
for all_pair in all_pairs:
tmp_comb: Combination = OrderedDict()
# covertable does not keep the ordering of the parameters
# therefore we sort it
for param in parameter_value_matrix.keys():
tmp_comb[param] = all_pair[param]
comb_list.append(tmp_comb)

return comb_list
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers= [
"Operating System :: OS Independent",
]
dependencies = [
"allpairspy == 2.5.1",
"covertable == 2.1.0",
"typeguard",
"packaging"
]
Expand Down
16 changes: 8 additions & 8 deletions tests/test_expected_parameter_value_pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import io
import packaging.version as pkv

# allpairspy has no type hints
from allpairspy import AllPairs # type: ignore
from utils_test import parse_param_val, parse_param_vals, parse_expected_val_pairs
from covertable import make
from bashi.types import (
Parameter,
ParameterValue,
Expand Down Expand Up @@ -488,15 +487,16 @@ def test_check_parameter_value_pair_in_combination_list_complete_list_plus_wrong

self.assertEqual(output_wrong_many_pairs_list, expected_output_many_wrong_pairs_list)

def test_unrestricted_allpairspy_generator(self):
def test_unrestricted_covertable_generator(self):
comb_list: CombinationList = []
# pylance shows a warning, because it cannot determine the concrete type of a namedtuple,
# which is returned by AllPairs
for all_pair in AllPairs(parameters=self.param_matrix): # type: ignore
comb: Combination = OrderedDict()
for index, param in enumerate(all_pair._fields): # type: ignore
comb[param] = all_pair[index] # type: ignore
comb_list.append(comb)
all_pairs: List[Dict[Parameter, ParameterValue]] = make(
factors=self.param_matrix
) # type: ignore

for all_pair in all_pairs:
comb_list.append(OrderedDict(all_pair))

self.assertTrue(
check_parameter_value_pair_in_combination_list(comb_list, self.expected_param_val_pairs)
Expand Down
19 changes: 5 additions & 14 deletions tests/test_generate_combination_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ def test_generator_without_custom_filter(self):
)
)

# TODO(SimeonEhrig): remove expectedFailure, if this PR was merged:
# https://github.com/thombashi/allpairspy/pull/10
@unittest.expectedFailure
def test_generator_with_custom_filter(self):
def custom_filter(row: ParameterValueTuple) -> bool:
if DEVICE_COMPILER in row and row[DEVICE_COMPILER].name == NVCC:
Expand Down Expand Up @@ -168,7 +165,6 @@ def custom_filter(row: ParameterValueTuple) -> bool:

missing_combinations = io.StringIO()

# because of a bug in the allpairspy library, valid pairs are missing.
try:
self.assertTrue(
check_parameter_value_pair_in_combination_list(
Expand All @@ -177,7 +173,7 @@ def custom_filter(row: ParameterValueTuple) -> bool:
)
except AssertionError as e:
# remove comment to print missing, valid pairs
# print(f"\n{missing_combinations.getvalue()}")
print(f"\n{missing_combinations.getvalue()}")
raise e


Expand All @@ -194,14 +190,8 @@ def test_generator_without_custom_filter(self):
check_parameter_value_pair_in_combination_list(comb_list, expected_param_val_pairs)
)

# TODO(SimeonEhrig): remove expectedFailure, if this PR was merged:
# https://github.com/thombashi/allpairspy/pull/10
@unittest.expectedFailure
def test_generator_with_custom_filter(self):
def custom_filter(row: ParameterValueTuple) -> bool:
if DEVICE_COMPILER in row and row[DEVICE_COMPILER].name == NVCC:
return False

if (
CMAKE in row
and row[CMAKE].version == pkv.parse("3.23")
Expand Down Expand Up @@ -236,7 +226,8 @@ def custom_filter(row: ParameterValueTuple) -> bool:
)
except AssertionError as e:
# remove comment to display missing combinations
# missing_combinations_str = missing_combinations.getvalue()
# print(f"\nnumber of missing combinations: {len(missing_combinations_str.split('\n'))}")
# print(f"\n{missing_combinations_str}")
missing_combinations_str = missing_combinations.getvalue()
print(f"\n{missing_combinations_str}")
number_of_combs = len(missing_combinations_str.split("\n"))
print(f"\nnumber of missing combinations: {number_of_combs}")
raise e

0 comments on commit b323da5

Please sign in to comment.