Skip to content

Commit

Permalink
refactor: rename local_toolchain as toolkit for bzlmod
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan committed Dec 30, 2024
1 parent a04a94f commit faa178c
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 17 deletions.
2 changes: 1 addition & 1 deletion MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ bazel_dep(name = "bazel_skylib", version = "1.4.2")
bazel_dep(name = "platforms", version = "0.0.6")

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
cuda.toolkit(
name = "local_cuda",
toolkit_path = "",
)
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ archive_override(
)

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
cuda.toolkit(
name = "local_cuda",
toolkit_path = "",
)
Expand Down
25 changes: 14 additions & 11 deletions cuda/extensions.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

load("//cuda/private:repositories.bzl", "local_cuda")

cuda_toolkit = tag_class(attrs = {
cuda_toolkit_tag = tag_class(attrs = {
"name": attr.string(doc = "Name for the toolchain repository", default = "local_cuda"),
"toolkit_path": attr.string(doc = "Path to the CUDA SDK, if empty the environment variable CUDA_PATH will be used to deduce this path."),
})
Expand All @@ -22,24 +22,27 @@ def _find_modules(module_ctx):

return root, our_module

def _module_tag_to_dict(t):
return {attr: getattr(t, attr) for attr in dir(t)}

def _init(module_ctx):
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
root, rules_cuda = _find_modules(module_ctx)
toolchains = root.tags.local_toolchain or rules_cuda.tags.local_toolchain
toolkits = root.tags.toolkit or rules_cuda.tags.toolkit

registrations = {}
for toolchain in toolchains:
if toolchain.name in registrations.keys():
if toolchain.toolkit_path == registrations[toolchain.name]:
# No problem to register a matching toolchain twice
for toolkit in toolkits:
if toolkit.name in registrations.keys():
if toolkit.toolkit_path == registrations[toolkit.name].toolkit_path:
# No problem to register a matching toolkit twice
continue
fail("Multiple conflicting toolchains declared for name {} ({} and {}".format(toolchain.name, toolchain.toolkit_path, registrations[toolchain.name]))
fail("Multiple conflicting toolkits declared for name {} ({} and {}".format(toolkit.name, toolkit.toolkit_path, registrations[toolkit.name].toolkit_path))
else:
registrations[toolchain.name] = toolchain.toolkit_path
for name, toolkit_path in registrations.items():
local_cuda(name = name, toolkit_path = toolkit_path)
registrations[toolkit.name] = toolkit
for _, toolkit in registrations.items():
local_cuda(**_module_tag_to_dict(toolkit))

toolchain = module_extension(
implementation = _init,
tag_classes = {"local_toolchain": cuda_toolkit},
tag_classes = {"toolkit": cuda_toolkit_tag},
)
2 changes: 1 addition & 1 deletion docs/MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ local_path_override(
)

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
cuda.toolkit(
name = "local_cuda",
toolkit_path = "",
)
Expand Down
2 changes: 1 addition & 1 deletion examples/MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ local_path_override(
)

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
cuda.toolkit(
name = "local_cuda",
toolkit_path = "",
)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/toolchain_none/MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ local_path_override(
)

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
cuda.toolkit(
name = "local_cuda",
toolkit_path = "/nonexistent/cuda/toolkit/path",
)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/toolchain_root/MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ local_path_override(
)

cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain")
cuda.local_toolchain(
cuda.toolkit(
name = "local_cuda",
toolkit_path = "",
)
Expand Down

0 comments on commit faa178c

Please sign in to comment.