Skip to content

Commit

Permalink
Merge pull request #15 from SimeonEhrig/RuleNvccGccSupport
Browse files Browse the repository at this point in the history
filter rules which restic the supported gcc and clang host compiler versions for specific nvcc versions
  • Loading branch information
SimeonEhrig authored Feb 14, 2024
2 parents b454dc3 + dfd0d72 commit 209a5df
Show file tree
Hide file tree
Showing 5 changed files with 735 additions and 4 deletions.
56 changes: 56 additions & 0 deletions bashi/filter_compiler_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
"""

from typing import Optional, IO, List
import packaging.version as pkv
from typeguard import typechecked
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.types import Parameter, ParameterValueTuple
from bashi.versions import NVCC_GCC_MAX_VERSION, NVCC_CLANG_MAX_VERSION
from bashi.utils import reason


Expand All @@ -34,6 +36,7 @@ def compiler_version_filter_typechecked(
return compiler_version_filter(row, output)


# pylint: disable=too-many-branches
def compiler_version_filter(
row: ParameterValueTuple,
output: Optional[IO[str]] = None,
Expand All @@ -60,4 +63,57 @@ def compiler_version_filter(
reason(output, "host and device compiler version must be the same (except for nvcc)")
return False

# now idea, how remove nested blocks without hitting the performance
# pylint: disable=too-many-nested-blocks
if DEVICE_COMPILER in row and row[DEVICE_COMPILER].name == NVCC:
if HOST_COMPILER in row and row[HOST_COMPILER].name == GCC:
# Rule: v2
# remove all unsupported nvcc gcc version combinations
# define which is the latest supported gcc compiler for a nvcc version

# if a nvcc version is not supported by bashi, assume that the version supports the
# latest gcc compiler version
if row[DEVICE_COMPILER].version <= NVCC_GCC_MAX_VERSION[0].nvcc:
# check the maximum supported gcc version for the given nvcc version
for nvcc_gcc_comb in NVCC_GCC_MAX_VERSION:
if row[DEVICE_COMPILER].version >= nvcc_gcc_comb.nvcc:
if row[HOST_COMPILER].version > nvcc_gcc_comb.host:
reason(
output,
f"nvcc {row[DEVICE_COMPILER].version} "
f"does not support gcc {row[HOST_COMPILER].version}",
)
return False
break

if HOST_COMPILER in row and row[HOST_COMPILER].name == CLANG:
# Rule: v4
if row[DEVICE_COMPILER].version >= pkv.parse("11.3") and row[
DEVICE_COMPILER
].version <= pkv.parse("11.5"):
reason(
output,
"clang as host compiler is disabled for nvcc 11.3 to 11.5",
)
return False

# Rule: v3
# remove all unsupported nvcc clang version combinations
# define which is the latest supported clang compiler for a nvcc version

# if a nvcc version is not supported by bashi, assume that the version supports the
# latest clang compiler version
if row[DEVICE_COMPILER].version <= NVCC_CLANG_MAX_VERSION[0].nvcc:
# check the maximum supported gcc version for the given nvcc version
for nvcc_clang_comb in NVCC_CLANG_MAX_VERSION:
if row[DEVICE_COMPILER].version >= nvcc_clang_comb.nvcc:
if row[HOST_COMPILER].version > nvcc_clang_comb.host:
reason(
output,
f"nvcc {row[DEVICE_COMPILER].version} "
f"does not support clang {row[HOST_COMPILER].version}",
)
return False
break

return True
63 changes: 62 additions & 1 deletion bashi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
ParameterValueSingle,
ParameterValueTuple,
)
from bashi.versions import COMPILERS, VERSIONS
from bashi.versions import COMPILERS, VERSIONS, NVCC_GCC_MAX_VERSION, NVCC_CLANG_MAX_VERSION
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import


Expand Down Expand Up @@ -316,6 +316,7 @@ def reason(output: Optional[IO[str]], msg: str):
)


# pylint: disable=too-many-branches
@typechecked
def get_expected_bashi_parameter_value_pairs(
parameter_matrix: ParameterValueMatrix,
Expand Down Expand Up @@ -383,4 +384,64 @@ def get_expected_bashi_parameter_value_pairs(
all_versions=False,
)

# remove all gcc version, which are to new for a specific nvcc version
nvcc_versions = [packaging.version.parse(str(v)) for v in VERSIONS[NVCC]]
nvcc_versions.sort()
gcc_versions = [packaging.version.parse(str(v)) for v in VERSIONS[GCC]]
gcc_versions.sort()
for nvcc_version in nvcc_versions:
for max_nvcc_clang_version in NVCC_GCC_MAX_VERSION:
if nvcc_version >= max_nvcc_clang_version.nvcc:
for clang_version in gcc_versions:
if clang_version > max_nvcc_clang_version.host:
remove_parameter_value_pair(
to_remove=create_parameter_value_pair(
HOST_COMPILER,
GCC,
clang_version,
DEVICE_COMPILER,
NVCC,
nvcc_version,
),
parameter_value_pairs=param_val_pair_list,
)
break

clang_versions = [packaging.version.parse(str(v)) for v in VERSIONS[CLANG]]
clang_versions.sort()

# remove all clang version, which are to new for a specific nvcc version
for nvcc_version in nvcc_versions:
for max_nvcc_clang_version in NVCC_CLANG_MAX_VERSION:
if nvcc_version >= max_nvcc_clang_version.nvcc:
for clang_version in clang_versions:
if clang_version > max_nvcc_clang_version.host:
remove_parameter_value_pair(
to_remove=create_parameter_value_pair(
HOST_COMPILER,
CLANG,
clang_version,
DEVICE_COMPILER,
NVCC,
nvcc_version,
),
parameter_value_pairs=param_val_pair_list,
)
break

# remove all pairs, where clang is host-compiler for nvcc 11.3, 11.4 and 11.5 as device compiler
for nvcc_version in [packaging.version.parse(str(v)) for v in [11.3, 11.4, 11.5]]:
for clang_version in clang_versions:
remove_parameter_value_pair(
to_remove=create_parameter_value_pair(
HOST_COMPILER,
CLANG,
clang_version,
DEVICE_COMPILER,
NVCC,
nvcc_version,
),
parameter_value_pairs=param_val_pair_list,
)

return param_val_pair_list
65 changes: 65 additions & 0 deletions bashi/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,30 @@
from bashi.globals import * # pylint: disable=wildcard-import,unused-wildcard-import
from bashi.types import ValueName, ValueVersion, ParameterValue, ParameterValueMatrix


class NvccHostSupport:
"""Contains a nvcc version and host compiler version. Does automatically parse the input strings
to package.version.Version.
Provides comparision operators for sorting.
"""

def __init__(self, nvcc_version: str, host_version: str):
self.nvcc = pkv.parse(nvcc_version)
self.host = pkv.parse(host_version)

def __lt__(self, other: "NvccHostSupport") -> bool:
return self.nvcc < other.nvcc

def __eq__(self, other: object) -> bool:
if not isinstance(other, NvccHostSupport):
raise TypeError("does not support other types than NvccHostSupport")
return self.nvcc == other.nvcc and self.host == other.host

def __str__(self) -> str:
return f"nvcc {str(self.nvcc)} + host version {self.host}"


VERSIONS: Dict[str, List[Union[str, int, float]]] = {
GCC: [6, 7, 8, 9, 10, 11, 12, 13],
CLANG: [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17],
Expand All @@ -26,6 +50,7 @@
12.0,
12.1,
12.2,
12.3,
],
HIPCC: [5.0, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 6.0],
ICPX: ["2023.1.0", "2023.2.0"],
Expand All @@ -45,6 +70,46 @@
CXX_STANDARD: [17, 20],
}

# define the maximum supported gcc version for a specific nvcc version
# the latest supported nvcc version must be added, even if the supported gcc version does not
# increase
# e.g.:
# NvccHostSupport("12.3", "12"),
# NvccHostSupport("12.0", "12"),
# NvccHostSupport("11.4", "11"),
NVCC_GCC_MAX_VERSION: List[NvccHostSupport] = [
NvccHostSupport("12.3", "12"),
NvccHostSupport("12.0", "12"),
NvccHostSupport("11.4", "11"),
NvccHostSupport("11.1", "10"),
NvccHostSupport("11.0", "9"),
NvccHostSupport("10.1", "8"),
NvccHostSupport("10.0", "7"),
]
NVCC_GCC_MAX_VERSION.sort(reverse=True)

# define the maximum supported clang version for a specific nvcc version
# the latest supported nvcc version must be added, even if the supported clang version does not
# increase
# e.g.:
# NvccHostSupport("12.3", "16"),
# NvccHostSupport("12.2", "15"),
# NvccHostSupport("12.1", "15"),
NVCC_CLANG_MAX_VERSION: List[NvccHostSupport] = [
NvccHostSupport("12.3", "16"),
NvccHostSupport("12.2", "15"),
NvccHostSupport("12.1", "15"),
NvccHostSupport("12.0", "14"),
NvccHostSupport("11.6", "13"),
NvccHostSupport("11.4", "12"),
NvccHostSupport("11.2", "11"),
NvccHostSupport("11.1", "10"),
NvccHostSupport("11.0", "9"),
NvccHostSupport("10.1", "8"),
NvccHostSupport("10.0", "6"),
]
NVCC_CLANG_MAX_VERSION.sort(reverse=True)


def get_parameter_value_matrix() -> ParameterValueMatrix:
"""Generates a parameter-value-matrix from all supported compilers, softwares and compilation
Expand Down
12 changes: 9 additions & 3 deletions example/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from typing import List
import os
import sys
from bashi.generator import generate_combination_list
from bashi.utils import (
get_expected_bashi_parameter_value_pairs,
Expand Down Expand Up @@ -225,8 +226,13 @@ def create_yaml(combination_list: CombinationList):
parameter_value_matrix=param_matrix, custom_filter=custom_filter
)

print("verify combination-list")
verify(comb_list, param_matrix)

create_yaml(comb_list)
print(f"number of combinations: {len(comb_list)}")

print("verify combination-list")
if verify(comb_list, param_matrix):
print("verification passed")
sys.exit(0)

print("verification failed")
sys.exit(1)
Loading

0 comments on commit 209a5df

Please sign in to comment.