Skip to content

Commit

Permalink
add filter rule that nvcc version only supports up to a specific gcc …
Browse files Browse the repository at this point in the history
…version
  • Loading branch information
SimeonEhrig committed Feb 14, 2024
1 parent 539638f commit 677b9d6
Show file tree
Hide file tree
Showing 4 changed files with 406 additions and 1 deletion.
24 changes: 24 additions & 0 deletions bashi/filter_compiler_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typeguard import typechecked
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.types import Parameter, ParameterValueTuple
from bashi.versions import NVCC_GCC_MAX_VERSION
from bashi.utils import reason


Expand Down Expand Up @@ -60,4 +61,27 @@ def compiler_version_filter(
reason(output, "host and device compiler version must be the same (except for nvcc)")
return False

# now idea, how remove nested blocks without hitting the performance
# pylint: disable=too-many-nested-blocks
if DEVICE_COMPILER in row and row[DEVICE_COMPILER].name == NVCC:
if HOST_COMPILER in row and row[HOST_COMPILER].name == GCC:
# Rule: v2
# remove all unsupported nvcc gcc version combinations
# define which is the latest supported gcc compiler for a nvcc version

# if a nvcc version is not supported by bashi, assume that the version supports the
# latest gcc compiler version
if row[DEVICE_COMPILER].version <= NVCC_GCC_MAX_VERSION[0].nvcc:
# check the maximum supported gcc version for the given nvcc version
for comb in NVCC_GCC_MAX_VERSION:
if row[DEVICE_COMPILER].version >= comb.nvcc:
if row[HOST_COMPILER].version > comb.host:
reason(
output,
f"nvcc {row[DEVICE_COMPILER].version} "
f"does not support gcc {row[HOST_COMPILER].version}",
)
return False
break

return True
21 changes: 20 additions & 1 deletion bashi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
ParameterValueSingle,
ParameterValueTuple,
)
from bashi.versions import COMPILERS, VERSIONS
from bashi.versions import COMPILERS, VERSIONS, NVCC_GCC_MAX_VERSION
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import


Expand Down Expand Up @@ -316,6 +316,7 @@ def reason(output: Optional[IO[str]], msg: str):
)


# pylint: disable=too-many-branches
@typechecked
def get_expected_bashi_parameter_value_pairs(
parameter_matrix: ParameterValueMatrix,
Expand Down Expand Up @@ -383,4 +384,22 @@ def get_expected_bashi_parameter_value_pairs(
all_versions=False,
)

# remove all gcc version, which are to new for a specific nvcc version
nvcc_versions = [packaging.version.parse(str(v)) for v in VERSIONS[NVCC]]
nvcc_versions.sort()
gcc_versions = [packaging.version.parse(str(v)) for v in VERSIONS[GCC]]
gcc_versions.sort()
for nvcc_version in nvcc_versions:
for max_nvcc_gcc_version in NVCC_GCC_MAX_VERSION:
if nvcc_version >= max_nvcc_gcc_version.nvcc:
for gcc_version in gcc_versions:
if gcc_version > max_nvcc_gcc_version.host:
remove_parameter_value_pair(
to_remove=create_parameter_value_pair(
HOST_COMPILER, GCC, gcc_version, DEVICE_COMPILER, NVCC, nvcc_version
),
parameter_value_pairs=param_val_pair_list,
)
break

return param_val_pair_list
43 changes: 43 additions & 0 deletions bashi/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,30 @@
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.types import ValueName, ValueVersion, ParameterValue, ParameterValueMatrix


class NvccHostSupport:
"""Contains a nvcc version and host compiler version. Does automatically parse the input strings
to package.version.Version.
Provides comparision operators for sorting.
"""

def __init__(self, nvcc_version: str, host_version: str):
self.nvcc = pkv.parse(nvcc_version)
self.host = pkv.parse(host_version)

def __lt__(self, other: "NvccHostSupport") -> bool:
return self.nvcc < other.nvcc

def __eq__(self, other: object) -> bool:
if not isinstance(other, NvccHostSupport):
raise TypeError("does not support other types than NvccHostSupport")
return self.nvcc == other.nvcc and self.host == other.host

def __str__(self) -> str:
return f"nvcc {str(self.nvcc)} + host version {self.host}"


VERSIONS: Dict[str, List[Union[str, int, float]]] = {
GCC: [6, 7, 8, 9, 10, 11, 12, 13],
CLANG: [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17],
Expand All @@ -26,6 +50,7 @@
12.0,
12.1,
12.2,
12.3,
],
HIPCC: [5.0, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 6.0],
ICPX: ["2023.1.0", "2023.2.0"],
Expand All @@ -45,6 +70,24 @@
CXX_STANDARD: [17, 20],
}

# define the maximum supported gcc version for a specific nvcc version
# the latest supported nvcc version must be added, even if the supported gcc version does not
# increase
# e.g.:
# NvccHostSupport("12.3", "12"),
# NvccHostSupport("12.0", "12"),
# NvccHostSupport("11.4", "11"),
NVCC_GCC_MAX_VERSION: List[NvccHostSupport] = [
NvccHostSupport("12.3", "12"),
NvccHostSupport("12.0", "12"),
NvccHostSupport("11.4", "11"),
NvccHostSupport("11.1", "10"),
NvccHostSupport("11.0", "9"),
NvccHostSupport("10.1", "8"),
NvccHostSupport("10.0", "7"),
]
NVCC_GCC_MAX_VERSION.sort(reverse=True)


def get_parameter_value_matrix() -> ParameterValueMatrix:
"""Generates a parameter-value-matrix from all supported compilers, softwares and compilation
Expand Down
Loading

0 comments on commit 677b9d6

Please sign in to comment.