Skip to content

Commit

Permalink
refactor!: change for hermetic with minor user interface breaking (#300)
Browse files Browse the repository at this point in the history
* refactor!: rules_cuda_dependencies only pull in bazel deps for rules_cuda and split the local_cuda part into newly added rules_cuda_toolchains

* refactor!: rename local_toolchain as toolkit for bzlmod
  • Loading branch information
cloudhan authored Feb 22, 2025
1 parent 458eccf commit fe377e2
Show file tree
Hide file tree
Showing 13 changed files with 65 additions and 38 deletions.
6 changes: 3 additions & 3 deletions .github/release_notes.template
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ http_archive(
urls = ["https://github.com/bazel-contrib/rules_cuda/releases/download/{version}/rules_cuda-{version}.tar.gz"],
)

load("@rules_cuda//cuda:repositories.bzl", "register_detected_cuda_toolchains", "rules_cuda_dependencies")
load("@rules_cuda//cuda:repositories.bzl", "rules_cuda_dependencies", "rules_cuda_toolchains")
rules_cuda_dependencies()
register_detected_cuda_toolchains()
```
rules_cuda_toolchains(register_toolchains = True)
```
2 changes: 1 addition & 1 deletion MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ bazel_dep(name = "bazel_skylib", version = "1.4.2")
bazel_dep(name = "platforms", version = "0.0.6")

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
cuda.toolkit(
name = "local_cuda",
toolkit_path = "",
)
Expand Down
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ http_archive(
strip_prefix = "rules_cuda-{git_commit_hash}",
urls = ["https://github.com/bazel-contrib/rules_cuda/archive/{git_commit_hash}.tar.gz"],
)
load("@rules_cuda//cuda:repositories.bzl", "register_detected_cuda_toolchains", "rules_cuda_dependencies")
load("@rules_cuda//cuda:repositories.bzl", "rules_cuda_dependencies", "rules_cuda_toolchains")
rules_cuda_dependencies()
register_detected_cuda_toolchains()
rules_cuda_toolchains(register_toolchains = True)
```

**NOTE**: the use of `register_detected_cuda_toolchains` depends on the environment variable `CUDA_PATH`. You must also
ensure the host compiler is available. On Windows, this means that you will also need to set the environment variable
**NOTE**: `rules_cuda_toolchains` implicitly calls to `register_detected_cuda_toolchains`, and the use of
`register_detected_cuda_toolchains` depends on the environment variable `CUDA_PATH`. You must also ensure the
host compiler is available. On Windows, this means that you will also need to set the environment variable
`BAZEL_VC` properly.

[`detect_cuda_toolkit`](https://github.com/bazel-contrib/rules_cuda/blob/5633f0c0f7/cuda/private/repositories.bzl#L28-L58)
Expand All @@ -47,7 +48,7 @@ archive_override(
)

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
cuda.toolkit(
name = "local_cuda",
toolkit_path = "",
)
Expand Down
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ local_repository(
path = "examples",
)

load("//cuda:repositories.bzl", "register_detected_cuda_toolchains", "rules_cuda_dependencies")
load("@rules_cuda//cuda:repositories.bzl", "rules_cuda_dependencies", "rules_cuda_toolchains")

rules_cuda_dependencies()

register_detected_cuda_toolchains()
rules_cuda_toolchains(register_toolchains = True)

load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")

Expand Down
25 changes: 14 additions & 11 deletions cuda/extensions.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

load("//cuda/private:repositories.bzl", "local_cuda")

cuda_toolkit = tag_class(attrs = {
cuda_toolkit_tag = tag_class(attrs = {
"name": attr.string(doc = "Name for the toolchain repository", default = "local_cuda"),
"toolkit_path": attr.string(doc = "Path to the CUDA SDK, if empty the environment variable CUDA_PATH will be used to deduce this path."),
})
Expand All @@ -22,24 +22,27 @@ def _find_modules(module_ctx):

return root, our_module

def _module_tag_to_dict(t):
return {attr: getattr(t, attr) for attr in dir(t)}

def _init(module_ctx):
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
root, rules_cuda = _find_modules(module_ctx)
toolchains = root.tags.local_toolchain or rules_cuda.tags.local_toolchain
toolkits = root.tags.toolkit or rules_cuda.tags.toolkit

registrations = {}
for toolchain in toolchains:
if toolchain.name in registrations.keys():
if toolchain.toolkit_path == registrations[toolchain.name]:
# No problem to register a matching toolchain twice
for toolkit in toolkits:
if toolkit.name in registrations.keys():
if toolkit.toolkit_path == registrations[toolkit.name].toolkit_path:
# No problem to register a matching toolkit twice
continue
fail("Multiple conflicting toolchains declared for name {} ({} and {}".format(toolchain.name, toolchain.toolkit_path, registrations[toolchain.name]))
fail("Multiple conflicting toolkits declared for name {} ({} and {}".format(toolkit.name, toolkit.toolkit_path, registrations[toolkit.name].toolkit_path))
else:
registrations[toolchain.name] = toolchain.toolkit_path
for name, toolkit_path in registrations.items():
local_cuda(name = name, toolkit_path = toolkit_path)
registrations[toolkit.name] = toolkit
for _, toolkit in registrations.items():
local_cuda(**_module_tag_to_dict(toolkit))

toolchain = module_extension(
implementation = _init,
tag_classes = {"local_toolchain": cuda_toolkit},
tag_classes = {"toolkit": cuda_toolkit_tag},
)
25 changes: 18 additions & 7 deletions cuda/private/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
load("//cuda/private:template_helper.bzl", "template_helper")
load("//cuda/private:toolchain.bzl", "register_detected_cuda_toolchains")

def _to_forward_slash(s):
return s.replace("\\", "/")
Expand Down Expand Up @@ -197,12 +198,8 @@ local_cuda = repository_rule(
# remotable = True,
)

def rules_cuda_dependencies(toolkit_path = None):
"""Populate the dependencies for rules_cuda. This will setup workspace dependencies (other bazel rules) and local toolchains.
Args:
toolkit_path: Optionally specify the path to CUDA toolkit. If not specified, it will be detected automatically.
"""
def rules_cuda_dependencies():
"""Populate the dependencies for rules_cuda. This will setup other bazel rules as workspace dependencies"""
maybe(
name = "bazel_skylib",
repo_rule = http_archive,
Expand All @@ -223,4 +220,18 @@ def rules_cuda_dependencies(toolkit_path = None):
],
)

local_cuda(name = "local_cuda", toolkit_path = toolkit_path)
def rules_cuda_toolchains(toolkit_path = None, register_toolchains = False):
"""Populate the local_cuda repo.
Args:
toolkit_path: Optionally specify the path to CUDA toolkit. If not specified, it will be detected automatically.
register_toolchains: Register the toolchains if enabled.
"""

local_cuda(
name = "local_cuda",
toolkit_path = toolkit_path,
)

if register_toolchains:
register_detected_cuda_toolchains()
10 changes: 8 additions & 2 deletions cuda/repositories.bzl
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
load("//cuda/private:repositories.bzl", _local_cuda = "local_cuda", _rules_cuda_dependencies = "rules_cuda_dependencies")
load(
"//cuda/private:repositories.bzl",
_local_cuda = "local_cuda",
_rules_cuda_dependencies = "rules_cuda_dependencies",
_rules_cuda_toolchains = "rules_cuda_toolchains",
)
load("//cuda/private:toolchain.bzl", _register_detected_cuda_toolchains = "register_detected_cuda_toolchains")

rules_cuda_dependencies = _rules_cuda_dependencies
local_cuda = _local_cuda
rules_cuda_dependencies = _rules_cuda_dependencies
rules_cuda_toolchains = _rules_cuda_toolchains
register_detected_cuda_toolchains = _register_detected_cuda_toolchains
2 changes: 1 addition & 1 deletion docs/MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ local_path_override(
)

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
cuda.toolkit(
name = "local_cuda",
toolkit_path = "",
)
Expand Down
8 changes: 7 additions & 1 deletion docs/user_docs.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
load("@rules_cuda//cuda:defs.bzl", _cuda_binary = "cuda_binary", _cuda_library = "cuda_library", _cuda_objects = "cuda_objects", _cuda_test = "cuda_test")
load("@rules_cuda//cuda:repositories.bzl", _register_detected_cuda_toolchains = "register_detected_cuda_toolchains", _rules_cuda_dependencies = "rules_cuda_dependencies")
load(
"@rules_cuda//cuda:repositories.bzl",
_register_detected_cuda_toolchains = "register_detected_cuda_toolchains",
_rules_cuda_dependencies = "rules_cuda_dependencies",
_rules_cuda_toolchains = "rules_cuda_toolchains",
)
load("@rules_cuda//cuda/private:rules/flags.bzl", _cuda_archs_flag = "cuda_archs_flag")

cuda_library = _cuda_library
Expand All @@ -12,3 +17,4 @@ cuda_archs = _cuda_archs_flag

register_detected_cuda_toolchains = _register_detected_cuda_toolchains
rules_cuda_dependencies = _rules_cuda_dependencies
rules_cuda_toolchains = _rules_cuda_toolchains
2 changes: 1 addition & 1 deletion examples/MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ local_path_override(
)

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
cuda.toolkit(
name = "local_cuda",
toolkit_path = "",
)
Expand Down
4 changes: 2 additions & 2 deletions examples/WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ local_repository(
# If you want to have a different version of some dependency,
# you should fetch it *before* calling this.

load("@rules_cuda//cuda:repositories.bzl", "register_detected_cuda_toolchains", "rules_cuda_dependencies")
load("@rules_cuda//cuda:repositories.bzl", "rules_cuda_dependencies", "rules_cuda_toolchains")

rules_cuda_dependencies()

register_detected_cuda_toolchains()
rules_cuda_toolchains(register_toolchains = True)

#################################
# Dependencies for nccl example #
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/toolchain_none/MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ local_path_override(
)

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
cuda.toolkit(
name = "local_cuda",
toolkit_path = "/nonexistent/cuda/toolkit/path",
)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/toolchain_root/MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ local_path_override(
)

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
cuda.toolkit(
name = "local_cuda",
toolkit_path = "",
)
Expand Down

0 comments on commit fe377e2

Please sign in to comment.