Skip to content

Commit

Permalink
unified compiler_name_filter and compiler_version_filter in a single …
Browse files Browse the repository at this point in the history
…filter
  • Loading branch information
SimeonEhrig committed Mar 12, 2024
1 parent a2a4a03 commit a35b063
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 202 deletions.
6 changes: 2 additions & 4 deletions bashi/filter_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from typeguard import typechecked
from bashi.types import FilterFunction

from bashi.filter_compiler_name import compiler_name_filter
from bashi.filter_compiler_version import compiler_version_filter
from bashi.filter_compiler import compiler_filter
from bashi.filter_backend import backend_filter
from bashi.filter_software_dependency import software_dependency_filter

Expand All @@ -25,8 +24,7 @@ def get_default_filter_chain(
FilterFunction: The filter function chain, which can be directly used in bashi.FilterAdapter
"""
return (
lambda row: compiler_name_filter(row)
and compiler_version_filter(row)
lambda row: compiler_filter(row)
and backend_filter(row)
and software_dependency_filter(row)
and custom_filter_function(row)
Expand Down
52 changes: 40 additions & 12 deletions bashi/filter_compiler_version.py → bashi/filter_compiler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Filter rules basing on host and device compiler names and versions.
All rules implemented in this filter have an identifier that begins with "v" and follows a number.
Examples: v1, v42, v678 ...
All rules implemented in this filter have an identifier that begins with "c" and follows a number.
Examples: c1, c42, c678 ...
These identifiers are used in the test names, for example, to make it clear which test is testing
which rule.
Expand All @@ -26,18 +26,18 @@ def get_required_parameters() -> List[Parameter]:


@typechecked
def compiler_version_filter_typechecked(
def compiler_filter_typechecked(
row: ParameterValueTuple,
output: Optional[IO[str]] = None,
) -> bool:
"""Type-checked version of compiler_version_filter(). Type checking has a big performance cost,
which is why the non type-checked version is used for the pairwise generator.
"""Type-checked version of compiler_filter(). Type checking has a big performance cost, which
is why the non type-checked version is used for the pairwise generator.
"""
return compiler_version_filter(row, output)
return compiler_filter(row, output)


# pylint: disable=too-many-branches
def compiler_version_filter(
def compiler_filter(
row: ParameterValueTuple,
output: Optional[IO[str]] = None,
) -> bool:
Expand All @@ -52,8 +52,36 @@ def compiler_version_filter(
Returns:
bool: True, if parameter-value-tuple is valid.
"""
# Rule: c1
# NVCC as HOST_COMPILER is not allow
# this rule will be never used, because of an implementation detail of the covertable library
# it is not possible to add NVCC as HOST_COMPILER and filter out afterwards
# this rule is only used by bashi-verify
if HOST_COMPILER in row and row[HOST_COMPILER].name == NVCC:
reason(output, "nvcc is not allowed as host compiler")
return False

# Rule: c2
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER].name == NVCC
and HOST_COMPILER in row
and not row[HOST_COMPILER].name in [GCC, CLANG]
):
reason(output, "only gcc and clang are allowed as nvcc host compiler")
return False

# Rule: c3
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER].name != NVCC
and HOST_COMPILER in row
and row[HOST_COMPILER].name != row[DEVICE_COMPILER].name
):
reason(output, "host and device compiler name must be the same (except for nvcc)")
return False

# Rule: v1
# Rule: c4
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER].name != NVCC
Expand All @@ -67,7 +95,7 @@ def compiler_version_filter(
# pylint: disable=too-many-nested-blocks
if DEVICE_COMPILER in row and row[DEVICE_COMPILER].name == NVCC:
if HOST_COMPILER in row and row[HOST_COMPILER].name == GCC:
# Rule: v2
# Rule: c5
# remove all unsupported nvcc gcc version combinations
# define which is the latest supported gcc compiler for a nvcc version

Expand All @@ -87,7 +115,7 @@ def compiler_version_filter(
break

if HOST_COMPILER in row and row[HOST_COMPILER].name == CLANG:
# Rule: v4
# Rule: c7
if row[DEVICE_COMPILER].version >= pkv.parse("11.3") and row[
DEVICE_COMPILER
].version <= pkv.parse("11.5"):
Expand All @@ -97,7 +125,7 @@ def compiler_version_filter(
)
return False

# Rule: v3
# Rule: c6
# remove all unsupported nvcc clang version combinations
# define which is the latest supported clang compiler for a nvcc version

Expand All @@ -116,7 +144,7 @@ def compiler_version_filter(
return False
break

# Rule: v5
# Rule: c8
# clang-cuda 13 and older is not supported
# this rule will be never used, because of an implementation detail of the covertable library
# it is not possible to add the clang-cuda versions and filter it out afterwards
Expand Down
81 changes: 0 additions & 81 deletions bashi/filter_compiler_name.py

This file was deleted.

17 changes: 4 additions & 13 deletions bashi/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.types import ParameterValue, ParameterValueTuple
from bashi.versions import is_supported_version
import bashi.filter_compiler_name
import bashi.filter_compiler_version
import bashi.filter_compiler
import bashi.filter_backend
import bashi.filter_software_dependency

Expand Down Expand Up @@ -244,17 +243,9 @@ def check_filter_chain(row: ParameterValueTuple) -> bool:
all_true = 0
all_true += int(
check_single_filter(
bashi.filter_compiler_name.compiler_name_filter_typechecked,
bashi.filter_compiler.compiler_filter,
row,
bashi.filter_compiler_name.get_required_parameters(),
)
)

all_true += int(
check_single_filter(
bashi.filter_compiler_version.compiler_version_filter_typechecked,
row,
bashi.filter_compiler_version.get_required_parameters(),
bashi.filter_compiler.get_required_parameters(),
)
)
all_true += int(
Expand All @@ -273,7 +264,7 @@ def check_filter_chain(row: ParameterValueTuple) -> bool:
)

# each filter add a one, if it was successful
return all_true == 4
return all_true == 3


def main() -> None:
Expand Down
18 changes: 9 additions & 9 deletions tests/test_clang_cuda_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from collections import OrderedDict as OD
from utils_test import parse_param_val as ppv
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.filter_compiler_version import compiler_version_filter_typechecked
from bashi.filter_compiler import compiler_filter_typechecked


class TestClangCudaOldVersions(unittest.TestCase):
def test_valid_clang_cuda_versions_rule_v5(self):
def test_valid_clang_cuda_versions_rule_c8(self):
for clang_cuda_version in [14, 16, 18, 78]:
self.assertTrue(
compiler_version_filter_typechecked(
compiler_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG_CUDA, clang_cuda_version)),
Expand All @@ -22,10 +22,10 @@ def test_valid_clang_cuda_versions_rule_v5(self):
)
)

def test_valid_clang_cuda_versions_multi_row_rule_v5(self):
def test_valid_clang_cuda_versions_multi_row_rule_c8(self):
for clang_cuda_version in [14, 16, 18, 78]:
self.assertTrue(
compiler_version_filter_typechecked(
compiler_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG_CUDA, clang_cuda_version)),
Expand All @@ -37,11 +37,11 @@ def test_valid_clang_cuda_versions_multi_row_rule_v5(self):
)
)

def test_invalid_clang_cuda_versions_rule_v5(self):
def test_invalid_clang_cuda_versions_rule_c8(self):
for clang_cuda_version in [13, 7, 1]:
reason_msg = io.StringIO()
self.assertFalse(
compiler_version_filter_typechecked(
compiler_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG_CUDA, clang_cuda_version)),
Expand All @@ -56,11 +56,11 @@ def test_invalid_clang_cuda_versions_rule_v5(self):
"all clang versions older than 14 are disabled as CUDA Compiler",
)

def test_invalid_clang_cuda_versions_multi_row_rule_v5(self):
def test_invalid_clang_cuda_versions_multi_row_rule_c8(self):
for clang_cuda_version in [13, 7, 1]:
reason_msg = io.StringIO()
self.assertFalse(
compiler_version_filter_typechecked(
compiler_filter_typechecked(
OD(
{
HOST_COMPILER: ppv((CLANG_CUDA, clang_cuda_version)),
Expand Down
7 changes: 3 additions & 4 deletions tests/test_filter_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from typeguard import typechecked
from bashi.types import ParameterValueTuple, ParameterValue
from bashi.utils import FilterAdapter
from bashi.filter_compiler_name import compiler_name_filter_typechecked
from bashi.filter_compiler_version import compiler_version_filter
from bashi.filter_compiler import compiler_filter_typechecked
from bashi.filter_backend import backend_filter
from bashi.filter_software_dependency import software_dependency_filter

Expand Down Expand Up @@ -108,11 +107,11 @@ def test_compiler_name_filter(self):
"because the test data should no trigger any rule"
)
self.assertTrue(
FilterAdapter(self.param_map, compiler_name_filter_typechecked)(self.test_row),
FilterAdapter(self.param_map, compiler_filter_typechecked)(self.test_row),
error_msg,
)
self.assertTrue(
FilterAdapter(self.param_map, compiler_version_filter)(self.test_row), error_msg
FilterAdapter(self.param_map, compiler_filter_typechecked)(self.test_row), error_msg
)
self.assertTrue(FilterAdapter(self.param_map, backend_filter)(self.test_row), error_msg)
self.assertTrue(
Expand Down
Loading

0 comments on commit a35b063

Please sign in to comment.