-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The FilterAdapter provides a nice interface of allpairspy filter rules.
- Loading branch information
1 parent
6d0f9fe
commit abbf873
Showing
3 changed files
with
238 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
"""Different helper functions for bashi""" | ||
|
||
from typing import Dict, Callable, Tuple, List | ||
from collections import OrderedDict | ||
from packaging.version import Version | ||
from typeguard import typechecked | ||
|
||
|
||
class FilterAdapter: | ||
""" | ||
An adapter for the filter functions used by allpairspy to provide a better filter function | ||
interface. | ||
Independent of the type of `parameter` (in the bashi naming convention: | ||
parameter-value-matrix type) used as an argument of AllPairs.__init__(), allpairspy always | ||
passes the same row type to the filter function: List of parameter-values. | ||
Therefore, the parameter name is encoded in the position in the row list. This makes it | ||
much more difficult to write filter rules. | ||
The FilterAdapter transforms the list of parameter values into a parameter-value-tuple, which | ||
has the type OrderedDict[str, Tuple[str, Version]]. | ||
This user writes a filter rule function with the expected line type | ||
OrderedDict[str, Tuple[str, Version]], creates a FunctionAdapter object with the functor as a | ||
parameter and passes the FunctionAdapter object to AllPairs.__init__(). | ||
filter function example: | ||
def filter_function(row: OrderedDict[str, Tuple[str, Version]]): | ||
if ( | ||
DEVICE_COMPILER in row | ||
and row[DEVICE_COMPILER][NAME] == NVCC | ||
and row[DEVICE_COMPILER][VERSION] < pkv.parse("12.0") | ||
): | ||
return False | ||
return True | ||
""" | ||
|
||
@typechecked | ||
def __init__( | ||
self, | ||
param_map: Dict[int, str], | ||
filter_func: Callable[[OrderedDict[str, Tuple[str, Version]]], bool], | ||
): | ||
"""Create a new FilterAdapter, see class doc string. | ||
Args: | ||
param_map (Dict[int, str]): The param_map maps the index position of a parameter to the | ||
parameter name. Assuming the parameter-value-matrix has the following keys: | ||
["param1", "param2", "param3"], the param_map should look like this: | ||
{0: "param1", 1 : "param2", 2 : "param3"}. | ||
filter_func (Callable[[OrderedDict[str, Tuple[str, Version]]], bool]): The filter | ||
function used by allpairspy, see class doc string. | ||
""" | ||
self.param_map = param_map | ||
self.filter_func = filter_func | ||
|
||
def __call__(self, row: List[Tuple[str, Version]]) -> bool: | ||
"""The expected interface of allpairspy filter rule. | ||
Transform the type of row from List[Tuple[str, Version]] to | ||
[OrderedDict[str, Tuple[str, Version]]]. | ||
Args: | ||
row (List[Tuple[str, Version]]): the parameter-value-tuple | ||
Returns: | ||
bool: Returns True, if the parameter-value-tuple is valid | ||
""" | ||
ordered_row: OrderedDict[str, Tuple[str, Version]] = OrderedDict() | ||
for index, param_name in enumerate(row): | ||
ordered_row[self.param_map[index]] = param_name | ||
return self.filter_func(ordered_row) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
# pylint: disable=missing-docstring | ||
import unittest | ||
from typing import Tuple, Dict, List | ||
from collections import OrderedDict | ||
from packaging.version import Version | ||
import packaging.version as pkv | ||
from typeguard import typechecked | ||
from bashi.utils import FilterAdapter | ||
|
||
|
||
class TestFilterAdapterDataSet1(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
cls.param_val_tuple: OrderedDict[str, Tuple[str, Version]] = OrderedDict() | ||
cls.param_val_tuple["param1"] = ("param-val-name1", pkv.parse("1")) | ||
cls.param_val_tuple["param2"] = ("param-val-name2", pkv.parse("2")) | ||
cls.param_val_tuple["param3"] = ("param-val-name3", pkv.parse("3")) | ||
|
||
cls.param_map: Dict[int, str] = {} | ||
for index, param_name in enumerate(cls.param_val_tuple.keys()): | ||
cls.param_map[index] = param_name | ||
|
||
cls.test_row: List[Tuple[str, Version]] = [] | ||
for param_val in cls.param_val_tuple.values(): | ||
cls.test_row.append(param_val) | ||
|
||
# use typechecked to do a deep type check | ||
# isinstance() only verify the "outer" data type, which is OrderedDict | ||
# isinstance() does not verify the key and value type | ||
def test_function_type(self): | ||
@typechecked | ||
def filter_function(row: OrderedDict[str, Tuple[str, Version]]) -> bool: | ||
if len(row.keys()) < 1: | ||
raise AssertionError("There is no element in row.") | ||
|
||
# typechecked does not check the types of Tuple, therefore I "unwrap" it | ||
@typechecked | ||
def check_param_value_type(_: Tuple[str, Version]): | ||
pass | ||
|
||
check_param_value_type(next(iter(row.values()))) | ||
|
||
return True | ||
|
||
filter_adapter = FilterAdapter(self.param_map, filter_function) | ||
self.assertTrue(filter_adapter(self.test_row)) | ||
|
||
def test_function_length(self): | ||
def filter_function(row: OrderedDict[str, Tuple[str, Version]]) -> bool: | ||
if len(row) != 3: | ||
raise AssertionError(f"Size of test_row is {len(row)}. Expected is 3.") | ||
|
||
return True | ||
|
||
filter_adapter = FilterAdapter(self.param_map, filter_function) | ||
self.assertTrue(filter_adapter(self.test_row)) | ||
|
||
def test_function_row_order(self): | ||
def filter_function(row: OrderedDict[str, Tuple[str, Version]]) -> bool: | ||
excepted_param_order = ["param1", "param2", "param3"] | ||
if len(excepted_param_order) != len(row): | ||
raise AssertionError( | ||
"excepted_key_order and row has not the same length.\n" | ||
f"{len(excepted_param_order)} != {len(row)}" | ||
) | ||
|
||
for index, param in enumerate(row.keys()): | ||
if excepted_param_order[index] != param: | ||
raise AssertionError( | ||
f"The {index}. parameter is not the expected " | ||
f"parameter: {excepted_param_order[index]}" | ||
) | ||
|
||
expected_param_value_order = [ | ||
("param-val-name1", pkv.parse("1")), | ||
("param-val-name2", pkv.parse("2")), | ||
("param-val-name3", pkv.parse("3")), | ||
] | ||
|
||
for index, param_value in enumerate(row.values()): | ||
expected_value_name = expected_param_value_order[index][0] | ||
expected_value_version = expected_param_value_order[index][1] | ||
if ( | ||
expected_value_name != param_value[0] | ||
or expected_value_version != param_value[1] | ||
): | ||
raise AssertionError( | ||
f"The {index}. parameter-value is not the expected parameter-value\n" | ||
f"Get: {param_value}\n" | ||
f"Expected: {expected_param_value_order[index]}" | ||
) | ||
|
||
return True | ||
|
||
filter_adapter = FilterAdapter(self.param_map, filter_function) | ||
self.assertTrue(filter_adapter(self.test_row)) | ||
|
||
def test_lambda(self): | ||
filter_adapter = FilterAdapter(self.param_map, lambda row: len(row) == 3) | ||
self.assertTrue(filter_adapter(self.test_row), "row has not the length of 3") | ||
|
||
|
||
# do a complex test with a different data set | ||
class TestFilterAdapterDataSet2(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
cls.param_val_tuple: OrderedDict[str, Tuple[str, Version]] = OrderedDict() | ||
cls.param_val_tuple["param6b"] = ("param-val-name1", pkv.parse("3.21.2")) | ||
cls.param_val_tuple["param231a"] = ("param-val-name67asd", pkv.parse("2.4")) | ||
cls.param_val_tuple["param234s"] = ("param-val-678", pkv.parse("3")) | ||
cls.param_val_tuple["foo"] = ("foo", pkv.parse("12.3")) | ||
cls.param_val_tuple["bar"] = ("bar", pkv.parse("3")) | ||
|
||
cls.param_map: Dict[int, str] = {} | ||
for index, param_name in enumerate(cls.param_val_tuple.keys()): | ||
cls.param_map[index] = param_name | ||
|
||
cls.test_row: List[Tuple[str, Version]] = [] | ||
for param_val in cls.param_val_tuple.values(): | ||
cls.test_row.append(param_val) | ||
|
||
def test_function_row_lenght_order(self): | ||
def filter_function(row: OrderedDict[str, Tuple[str, Version]]) -> bool: | ||
excepted_param_order = ["param6b", "param231a", "param234s", "foo", "bar"] | ||
if len(excepted_param_order) != len(row): | ||
raise AssertionError( | ||
"excepted_key_order and row has not the same length.\n" | ||
f"{len(excepted_param_order)} != {len(row)}" | ||
) | ||
|
||
for index, param in enumerate(row.keys()): | ||
if excepted_param_order[index] != param: | ||
raise AssertionError( | ||
f"The {index}. parameter is not the expected " | ||
f"parameter: {excepted_param_order[index]}" | ||
) | ||
|
||
expected_param_value_order = [ | ||
("param-val-name1", pkv.parse("3.21.2")), | ||
("param-val-name67asd", pkv.parse("2.4")), | ||
("param-val-678", pkv.parse("3")), | ||
("foo", pkv.parse("12.3")), | ||
("bar", pkv.parse("3")), | ||
] | ||
|
||
for index, param_value in enumerate(row.values()): | ||
expected_value_name = expected_param_value_order[index][0] | ||
expected_value_version = expected_param_value_order[index][1] | ||
if ( | ||
expected_value_name != param_value[0] | ||
or expected_value_version != param_value[1] | ||
): | ||
raise AssertionError( | ||
f"The {index}. parameter-value is not the expected parameter-value\n" | ||
f"Get: {param_value}\n" | ||
f"Expected: {expected_param_value_order[index]}" | ||
) | ||
|
||
return True | ||
|
||
filter_adapter = FilterAdapter(self.param_map, filter_function) | ||
self.assertTrue(filter_adapter(self.test_row)) |