Skip to content

Commit

Permalink
fix(bzlmod): allow both root module and our module to call cuda.local…
Browse files Browse the repository at this point in the history
…_toolchain (#264)

- fix(bzlmod): allow both root module and our module to call cuda.local_toolchain
- test: add repo integration tests
  • Loading branch information
cloudhan authored Aug 12, 2024
1 parent 07bd546 commit f00f640
Show file tree
Hide file tree
Showing 14 changed files with 183 additions and 11 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/build-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ jobs:
- run: cd examples && bazelisk build --jobs=1 //if_cuda:main
- run: cd examples && bazelisk build --jobs=1 //if_cuda:main --enable_cuda=False
- run: bazelisk shutdown
# run some repo integration tests
- run: cd tests/integration && ./test_all.sh

# Use Bazel 6
- run: echo "USE_BAZEL_VERSION=6.4.0" >> $GITHUB_ENV
Expand Down
4 changes: 4 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ bazel_dep(name = "bazel_skylib", version = "1.4.2")
bazel_dep(name = "platforms", version = "0.0.6")

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
name = "local_cuda",
toolkit_path = "",
)
use_repo(cuda, "local_cuda")

register_toolchains(
Expand Down
38 changes: 27 additions & 11 deletions cuda/extensions.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,35 @@ cuda_toolkit = tag_class(attrs = {
"toolkit_path": attr.string(doc = "Path to the CUDA SDK, if empty the environment variable CUDA_PATH will be used to deduce this path."),
})

def _find_modules(module_ctx):
root = None
our_module = None
for mod in module_ctx.modules:
if mod.is_root:
root = mod
if mod.name == "rules_cuda":
our_module = mod
if root == None:
root = our_module
if our_module == None:
fail("Unable to find rules_cuda module")

return root, our_module

def _init(module_ctx):
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
root, rules_cuda = _find_modules(module_ctx)
toolchains = root.tags.local_toolchain or rules_cuda.tags.local_toolchain

registrations = {}
for mod in module_ctx.modules:
for toolchain in mod.tags.local_toolchain:
if not mod.is_root:
fail("Only the root module may override the path for the local cuda toolchain")
if toolchain.name in registrations.keys():
if toolchain.toolkit_path == registrations[toolchain.name]:
# No problem to register a matching toolchain twice
continue
fail("Multiple conflicting toolchains declared for name {} ({} and {}".format(toolchain.name, toolchain.toolkit_path, registrations[toolchain.name]))
else:
registrations[toolchain.name] = toolchain.toolkit_path
for toolchain in toolchains:
if toolchain.name in registrations.keys():
if toolchain.toolkit_path == registrations[toolchain.name]:
# No problem to register a matching toolchain twice
continue
fail("Multiple conflicting toolchains declared for name {} ({} and {}".format(toolchain.name, toolchain.toolkit_path, registrations[toolchain.name]))
else:
registrations[toolchain.name] = toolchain.toolkit_path
for name, toolkit_path in registrations.items():
local_cuda(name = name, toolkit_path = toolkit_path)

Expand Down
2 changes: 2 additions & 0 deletions examples/if_cuda/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
load("@rules_cuda//cuda:defs.bzl", "cuda_library", "requires_cuda")

package(default_visibility = ["//visibility:public"])

cuda_library(
name = "kernel",
srcs = ["kernel.cu"],
Expand Down
36 changes: 36 additions & 0 deletions tests/integration/BUILD.to_symlink
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
load("@rules_cuda//cuda:defs.bzl", "cuda_library", "requires_cuda")

cc_library(
name = "use_library",
tags = ["manual"],
deps = ["@local_cuda//:cuda_runtime"],
)

cuda_library(
name = "use_rule",
srcs = ["@rules_cuda_examples//basic:kernel.cu"],
hdrs = ["@rules_cuda_examples//basic:kernel.h"],
tags = ["manual"],
)

cuda_library(
name = "optional_kernel",
srcs = ["@rules_cuda_examples//if_cuda:kernel.cu"],
hdrs = ["@rules_cuda_examples//if_cuda:kernel.h"],
tags = ["manual"],
target_compatible_with = requires_cuda(),
)

cc_binary(
name = "optinally_use_rule",
srcs = ["@rules_cuda_examples//if_cuda:main.cpp"],
defines = [] + select({
"@rules_cuda//cuda:is_enabled": ["CUDA_ENABLED"],
"//conditions:default": ["CUDA_DISABLED"],
}),
tags = ["manual"],
deps = [] + select({
"@rules_cuda//cuda:is_enabled": [":optional_kernel"],
"//conditions:default": [],
}),
)
1 change: 1 addition & 0 deletions tests/integration/MODULE.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
module(name = "rules_cuda_integration_tests")
1 change: 1 addition & 0 deletions tests/integration/WORKSPACE.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
workspace(name = "rules_cuda_integration_tests")
51 changes: 51 additions & 0 deletions tests/integration/test_all.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/bin/bash

this_dir=$(realpath $(dirname $0))

set -ev

# toolchain configured by the root module of the user
pushd "$this_dir/toolchain_root"
bazel build //... --@rules_cuda//cuda:enable=False
bazel build //... --@rules_cuda//cuda:enable=True
bazel build //:optinally_use_rule --@rules_cuda//cuda:enable=False
bazel build //:optinally_use_rule --@rules_cuda//cuda:enable=True
bazel build //:use_library
bazel build //:use_rule
bazel clean && bazel shutdown
popd

# toolchain does not exists
pushd "$this_dir/toolchain_none"
# analysis pass
bazel build //... --@rules_cuda//cuda:enable=False
bazel build //... --@rules_cuda//cuda:enable=True

# force build optional targets
bazel build //:optinally_use_rule --@rules_cuda//cuda:enable=False
ERR=$(bazel build //:optinally_use_rule --@rules_cuda//cuda:enable=True 2>&1 || true)
if ! [[ $ERR == *"didn't satisfy constraint"*"valid_toolchain_is_configured"* ]]; then exit 1; fi

# use library fails because the library file does not exist
ERR=$(bazel build //:use_library 2>&1 || true)
if ! [[ $ERR =~ "target 'cuda_runtime' not declared in package" ]]; then exit 1; fi
if ! [[ $ERR =~ "ERROR: Analysis of target '//:use_library' failed" ]]; then exit 1; fi

# use rule fails because rules_cuda depends non-existent cuda toolkit
ERR=$(bazel build //:use_rule 2>&1 || true)
if ! [[ $ERR =~ "target 'cuda_runtime' not declared in package" ]]; then exit 1; fi
if ! [[ $ERR =~ "ERROR: Analysis of target '//:use_rule' failed" ]]; then exit 1; fi

bazel clean && bazel shutdown
popd

# toolchain configured by rules_cuda
pushd "$this_dir/toolchain_rules"
bazel build //... --@rules_cuda//cuda:enable=False
bazel build //... --@rules_cuda//cuda:enable=True
bazel build //:optinally_use_rule --@rules_cuda//cuda:enable=False
bazel build //:optinally_use_rule --@rules_cuda//cuda:enable=True
bazel build //:use_library
bazel build //:use_rule
bazel clean && bazel shutdown
popd
1 change: 1 addition & 0 deletions tests/integration/toolchain_none/BUILD.bazel
20 changes: 20 additions & 0 deletions tests/integration/toolchain_none/MODULE.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module(name = "toolchain_none")

bazel_dep(name = "rules_cuda", version = "0.0.0")
local_path_override(
module_name = "rules_cuda",
path = "../../..",
)

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
name = "local_cuda",
toolkit_path = "/nonexistent/cuda/toolkit/path",
)
use_repo(cuda, "local_cuda")

bazel_dep(name = "rules_cuda_examples", version = "0.0.0")
local_path_override(
module_name = "rules_cuda_examples",
path = "../../../examples",
)
1 change: 1 addition & 0 deletions tests/integration/toolchain_root/BUILD.bazel
20 changes: 20 additions & 0 deletions tests/integration/toolchain_root/MODULE.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module(name = "bzlmod_use_repo_no_toolchain")

bazel_dep(name = "rules_cuda", version = "0.0.0")
local_path_override(
module_name = "rules_cuda",
path = "../../..",
)

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
name = "local_cuda",
toolkit_path = "",
)
use_repo(cuda, "local_cuda")

bazel_dep(name = "rules_cuda_examples", version = "0.0.0")
local_path_override(
module_name = "rules_cuda_examples",
path = "../../../examples",
)
1 change: 1 addition & 0 deletions tests/integration/toolchain_rules/BUILD.bazel
16 changes: 16 additions & 0 deletions tests/integration/toolchain_rules/MODULE.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module(name = "bzlmod_use_repo")

bazel_dep(name = "rules_cuda", version = "0.0.0")
local_path_override(
module_name = "rules_cuda",
path = "../../..",
)

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
use_repo(cuda, "local_cuda")

bazel_dep(name = "rules_cuda_examples", version = "0.0.0")
local_path_override(
module_name = "rules_cuda_examples",
path = "../../../examples",
)

0 comments on commit f00f640

Please sign in to comment.