Skip to content

Commit

Permalink
refactored filter rules
Browse files Browse the repository at this point in the history
- nvcc host compiler
- same host and device compiler name and version
- clang-cuda older than version 14
  • Loading branch information
SimeonEhrig committed Mar 12, 2024
1 parent dfdafff commit ee930fb
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 71 deletions.
61 changes: 26 additions & 35 deletions bashi/filter_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,35 +62,25 @@ def compiler_filter(
reason(output, "nvcc is not allowed as host compiler")
return False

# Rule: c2
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER].name == NVCC
and HOST_COMPILER in row
and not row[HOST_COMPILER].name in [GCC, CLANG]
):
reason(output, "only gcc and clang are allowed as nvcc host compiler")
return False

# Rule: c3
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER].name != NVCC
and HOST_COMPILER in row
and row[HOST_COMPILER].name != row[DEVICE_COMPILER].name
):
reason(output, "host and device compiler name must be the same (except for nvcc)")
return False
if HOST_COMPILER in row and DEVICE_COMPILER in row:
if NVCC in (row[HOST_COMPILER].name, row[DEVICE_COMPILER].name):
# Rule: c2
if row[HOST_COMPILER].name not in (GCC, CLANG):
reason(output, "only gcc and clang are allowed as nvcc host compiler")
return False
else:
# Rule: c3
if row[HOST_COMPILER].name != row[DEVICE_COMPILER].name:
reason(output, "host and device compiler name must be the same (except for nvcc)")
return False

# Rule: c4
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER].name != NVCC
and HOST_COMPILER in row
and row[HOST_COMPILER].version != row[DEVICE_COMPILER].version
):
reason(output, "host and device compiler version must be the same (except for nvcc)")
return False
# Rule: c4
if row[HOST_COMPILER].version != row[DEVICE_COMPILER].version:
reason(
output,
"host and device compiler version must be the same (except for nvcc)",
)
return False

# no idea how to remove the nested blocks without hurting the performance
# pylint: disable=too-many-nested-blocks
Expand Down Expand Up @@ -150,12 +140,13 @@ def compiler_filter(
# this rule will never be used: because of an implementation detail of the covertable library,
# it is not possible to add the clang-cuda versions and filter them out afterwards
# this rule is only used by bashi-verify
if (
DEVICE_COMPILER in row
and row[DEVICE_COMPILER].name == CLANG_CUDA
and row[DEVICE_COMPILER].version < pkv.parse("14")
):
reason(output, "all clang versions older than 14 are disabled as CUDA Compiler")
return False
for compiler in (HOST_COMPILER, DEVICE_COMPILER):
if (
compiler in row
and row[compiler].name == CLANG_CUDA
and row[compiler].version < pkv.parse("14")
):
reason(output, "all clang versions older than 14 are disabled as CUDA Compiler")
return False

return True
38 changes: 2 additions & 36 deletions bashi/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from typing import Dict, List
from collections import OrderedDict
import copy
import packaging.version as pkv

from covertable import make # type: ignore

Expand Down Expand Up @@ -34,44 +32,12 @@ def generate_combination_list(
Returns:
CombinationList: combination-list
"""
# use a local copy so that parameter_value_matrix is not modified
local_param_val_mat = copy.deepcopy(parameter_value_matrix)

filter_chain = get_default_filter_chain(custom_filter)

def host_compiler_filter(param_val: ParameterValue) -> bool:
    """Decide whether a parameter-value may stay in the host compiler list.

    Args:
        param_val (ParameterValue): candidate host compiler; its ``name`` and
            ``version`` attributes are inspected.

    Returns:
        bool: False if the value is rejected by rule n1 or v5, otherwise True.
    """
    # Rule: n1 -> nvcc is removed as host compiler
    rejected_by_n1 = param_val.name == NVCC
    # Rule: v5 -> clang-cuda older than 14 is removed
    # (the version is only compared when the name matches)
    return not (
        rejected_by_n1
        or (param_val.name == CLANG_CUDA and param_val.version < pkv.parse("14"))
    )

def device_compiler_filter(param_val: ParameterValue) -> bool:
    """Decide whether a parameter-value may stay in the device compiler list.

    Args:
        param_val (ParameterValue): candidate device compiler; its ``name`` and
            ``version`` attributes are inspected.

    Returns:
        bool: False if the value is rejected by rule v5, otherwise True.
    """
    # Rule: v5 -> clang-cuda older than 14 is removed
    too_old_clang_cuda = param_val.name == CLANG_CUDA and param_val.version < pkv.parse("14")
    return not too_old_clang_cuda

pre_filters = {HOST_COMPILER: host_compiler_filter, DEVICE_COMPILER: device_compiler_filter}

# some filter rules require that specific parameter-values are already removed from the
# parameter-value-matrix,
# otherwise the covertable library throws an error
for param, filter_func in pre_filters.items():
if param in local_param_val_mat:
local_param_val_mat[param] = list(filter(filter_func, local_param_val_mat[param]))

comb_list: CombinationList = []

all_pairs: List[Dict[Parameter, ParameterValue]] = make(
factors=local_param_val_mat,
factors=parameter_value_matrix,
length=2,
pre_filter=filter_chain,
) # type: ignore
Expand All @@ -81,7 +47,7 @@ def device_compiler_filter(param_val: ParameterValue) -> bool:
tmp_comb: Combination = OrderedDict({})
# covertable does not keep the ordering of the parameters
# therefore we sort it
for param in local_param_val_mat.keys():
for param in parameter_value_matrix.keys():
tmp_comb[param] = all_pair[param]
comb_list.append(tmp_comb)

Expand Down

0 comments on commit ee930fb

Please sign in to comment.