Skip to content

Commit

Permalink
implement the FilterAdapter
Browse files Browse the repository at this point in the history
The FilterAdapter provides a nice interface of allpairspy filter rules.
  • Loading branch information
SimeonEhrig committed Jan 24, 2024
1 parent 6d0f9fe commit abbf873
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 0 deletions.
73 changes: 73 additions & 0 deletions bashi/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Different helper functions for bashi"""

from typing import Dict, Callable, Tuple, List
from collections import OrderedDict
from packaging.version import Version
from typeguard import typechecked


class FilterAdapter:
"""
An adapter for the filter functions used by allpairspy to provide a better filter function
interface.
Independent of the type of `parameter` (in the bashi naming convention:
parameter-value-matrix type) used as an argument of AllPairs.__init__(), allpairspy always
passes the same row type to the filter function: List of parameter-values.
Therefore, the parameter name is encoded in the position in the row list. This makes it
much more difficult to write filter rules.
The FilterAdapter transforms the list of parameter values into a parameter-value-tuple, which
has the type OrderedDict[str, Tuple[str, Version]].
This user writes a filter rule function with the expected line type
OrderedDict[str, Tuple[str, Version]], creates a FunctionAdapter object with the functor as a
parameter and passes the FunctionAdapter object to AllPairs.__init__().
filter function example:
def filter_function(row: OrderedDict[str, Tuple[str, Version]]):
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER][NAME] == NVCC
and row[DEVICE_COMPILER][VERSION] < pkv.parse("12.0")
):
return False
return True
"""

@typechecked
def __init__(
self,
param_map: Dict[int, str],
filter_func: Callable[[OrderedDict[str, Tuple[str, Version]]], bool],
):
"""Create a new FilterAdapter, see class doc string.
Args:
param_map (Dict[int, str]): The param_map maps the index position of a parameter to the
parameter name. Assuming the parameter-value-matrix has the following keys:
["param1", "param2", "param3"], the param_map should look like this:
{0: "param1", 1 : "param2", 2 : "param3"}.
filter_func (Callable[[OrderedDict[str, Tuple[str, Version]]], bool]): The filter
function used by allpairspy, see class doc string.
"""
self.param_map = param_map
self.filter_func = filter_func

def __call__(self, row: List[Tuple[str, Version]]) -> bool:
"""The expected interface of allpairspy filter rule.
Transform the type of row from List[Tuple[str, Version]] to
[OrderedDict[str, Tuple[str, Version]]].
Args:
row (List[Tuple[str, Version]]): the parameter-value-tuple
Returns:
bool: Returns True, if the parameter-value-tuple is valid
"""
ordered_row: OrderedDict[str, Tuple[str, Version]] = OrderedDict()
for index, param_name in enumerate(row):
ordered_row[self.param_map[index]] = param_name
return self.filter_func(ordered_row)
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,6 @@ Issues = "https://github.com/alpaka-group/bashi/issues"
command_line = "-m unittest discover -s tests/"
branch = true
source = ["bashi"]

[tool.black]
line-length = 100
162 changes: 162 additions & 0 deletions tests/test_filter_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# pylint: disable=missing-docstring
import unittest
from typing import Tuple, Dict, List
from collections import OrderedDict
from packaging.version import Version
import packaging.version as pkv
from typeguard import typechecked
from bashi.utils import FilterAdapter


class TestFilterAdapterDataSet1(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.param_val_tuple: OrderedDict[str, Tuple[str, Version]] = OrderedDict()
cls.param_val_tuple["param1"] = ("param-val-name1", pkv.parse("1"))
cls.param_val_tuple["param2"] = ("param-val-name2", pkv.parse("2"))
cls.param_val_tuple["param3"] = ("param-val-name3", pkv.parse("3"))

cls.param_map: Dict[int, str] = {}
for index, param_name in enumerate(cls.param_val_tuple.keys()):
cls.param_map[index] = param_name

cls.test_row: List[Tuple[str, Version]] = []
for param_val in cls.param_val_tuple.values():
cls.test_row.append(param_val)

# use typechecked to do a deep type check
# isinstance() only verify the "outer" data type, which is OrderedDict
# isinstance() does not verify the key and value type
def test_function_type(self):
@typechecked
def filter_function(row: OrderedDict[str, Tuple[str, Version]]) -> bool:
if len(row.keys()) < 1:
raise AssertionError("There is no element in row.")

# typechecked does not check the types of Tuple, therefore I "unwrap" it
@typechecked
def check_param_value_type(_: Tuple[str, Version]):
pass

check_param_value_type(next(iter(row.values())))

return True

filter_adapter = FilterAdapter(self.param_map, filter_function)
self.assertTrue(filter_adapter(self.test_row))

def test_function_length(self):
def filter_function(row: OrderedDict[str, Tuple[str, Version]]) -> bool:
if len(row) != 3:
raise AssertionError(f"Size of test_row is {len(row)}. Expected is 3.")

return True

filter_adapter = FilterAdapter(self.param_map, filter_function)
self.assertTrue(filter_adapter(self.test_row))

def test_function_row_order(self):
def filter_function(row: OrderedDict[str, Tuple[str, Version]]) -> bool:
excepted_param_order = ["param1", "param2", "param3"]
if len(excepted_param_order) != len(row):
raise AssertionError(
"excepted_key_order and row has not the same length.\n"
f"{len(excepted_param_order)} != {len(row)}"
)

for index, param in enumerate(row.keys()):
if excepted_param_order[index] != param:
raise AssertionError(
f"The {index}. parameter is not the expected "
f"parameter: {excepted_param_order[index]}"
)

expected_param_value_order = [
("param-val-name1", pkv.parse("1")),
("param-val-name2", pkv.parse("2")),
("param-val-name3", pkv.parse("3")),
]

for index, param_value in enumerate(row.values()):
expected_value_name = expected_param_value_order[index][0]
expected_value_version = expected_param_value_order[index][1]
if (
expected_value_name != param_value[0]
or expected_value_version != param_value[1]
):
raise AssertionError(
f"The {index}. parameter-value is not the expected parameter-value\n"
f"Get: {param_value}\n"
f"Expected: {expected_param_value_order[index]}"
)

return True

filter_adapter = FilterAdapter(self.param_map, filter_function)
self.assertTrue(filter_adapter(self.test_row))

def test_lambda(self):
filter_adapter = FilterAdapter(self.param_map, lambda row: len(row) == 3)
self.assertTrue(filter_adapter(self.test_row), "row has not the length of 3")


# do a complex test with a different data set
class TestFilterAdapterDataSet2(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.param_val_tuple: OrderedDict[str, Tuple[str, Version]] = OrderedDict()
cls.param_val_tuple["param6b"] = ("param-val-name1", pkv.parse("3.21.2"))
cls.param_val_tuple["param231a"] = ("param-val-name67asd", pkv.parse("2.4"))
cls.param_val_tuple["param234s"] = ("param-val-678", pkv.parse("3"))
cls.param_val_tuple["foo"] = ("foo", pkv.parse("12.3"))
cls.param_val_tuple["bar"] = ("bar", pkv.parse("3"))

cls.param_map: Dict[int, str] = {}
for index, param_name in enumerate(cls.param_val_tuple.keys()):
cls.param_map[index] = param_name

cls.test_row: List[Tuple[str, Version]] = []
for param_val in cls.param_val_tuple.values():
cls.test_row.append(param_val)

def test_function_row_lenght_order(self):
def filter_function(row: OrderedDict[str, Tuple[str, Version]]) -> bool:
excepted_param_order = ["param6b", "param231a", "param234s", "foo", "bar"]
if len(excepted_param_order) != len(row):
raise AssertionError(
"excepted_key_order and row has not the same length.\n"
f"{len(excepted_param_order)} != {len(row)}"
)

for index, param in enumerate(row.keys()):
if excepted_param_order[index] != param:
raise AssertionError(
f"The {index}. parameter is not the expected "
f"parameter: {excepted_param_order[index]}"
)

expected_param_value_order = [
("param-val-name1", pkv.parse("3.21.2")),
("param-val-name67asd", pkv.parse("2.4")),
("param-val-678", pkv.parse("3")),
("foo", pkv.parse("12.3")),
("bar", pkv.parse("3")),
]

for index, param_value in enumerate(row.values()):
expected_value_name = expected_param_value_order[index][0]
expected_value_version = expected_param_value_order[index][1]
if (
expected_value_name != param_value[0]
or expected_value_version != param_value[1]
):
raise AssertionError(
f"The {index}. parameter-value is not the expected parameter-value\n"
f"Get: {param_value}\n"
f"Expected: {expected_param_value_order[index]}"
)

return True

filter_adapter = FilterAdapter(self.param_map, filter_function)
self.assertTrue(filter_adapter(self.test_row))

0 comments on commit abbf873

Please sign in to comment.