Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: hermetic ctk with deliverable #285

Merged
merged 8 commits into from
Feb 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 61 additions & 7 deletions cuda/extensions.bzl
Original file line number Diff line number Diff line change
@@ -1,10 +1,51 @@
"""Entry point for extensions used by bzlmod."""

load("//cuda/private:repositories.bzl", "local_cuda")
load("//cuda/private:compat.bzl", "components_mapping_compat")
load("//cuda/private:repositories.bzl", "cuda_component", "local_cuda")

cuda_component_tag = tag_class(attrs = {
"name": attr.string(mandatory = True, doc = "Repo name for the deliverable cuda_component"),
"component_name": attr.string(doc = "Short name of the component defined in registry."),
"integrity": attr.string(
doc = "Expected checksum in Subresource Integrity format of the file downloaded. " +
"This must match the checksum of the file downloaded.",
),
"sha256": attr.string(
doc = "The expected SHA-256 of the file downloaded. This must match the SHA-256 of the file downloaded.",
),
"strip_prefix": attr.string(
doc = "A directory prefix to strip from the extracted files. " +
"Many archives contain a top-level directory that contains all of the useful files in archive.",
),
"url": attr.string(
doc = "A URL to a file that will be made available to Bazel. " +
"This must be a file, http or https URL." +
"Redirections are followed. Authentication is not supported. " +
"More flexibility can be achieved by the urls parameter that allows " +
"to specify alternative URLs to fetch from.",
),
"urls": attr.string_list(
doc = "A list of URLs to a file that will be made available to Bazel. " +
"Each entry must be a file, http or https URL. " +
"Redirections are followed. Authentication is not supported. " +
"URLs are tried in order until one succeeds, so you should list local mirrors first. " +
"If all downloads fail, the rule will fail.",
),
})

cuda_toolkit_tag = tag_class(attrs = {
"name": attr.string(doc = "Name for the toolchain repository", default = "local_cuda"),
"toolkit_path": attr.string(doc = "Path to the CUDA SDK, if empty the environment variable CUDA_PATH will be used to deduce this path."),
"name": attr.string(mandatory = True, doc = "Name for the toolchain repository", default = "local_cuda"),
"toolkit_path": attr.string(
doc = "Path to the CUDA SDK, if empty the environment variable CUDA_PATH will be used to deduce this path.",
),
"components_mapping": components_mapping_compat.attr(
doc = "A mapping from component names to component repos of a deliverable CUDA Toolkit. " +
"Only the repo part of the label is usefull",
),
"version": attr.string(doc = "cuda toolkit version. Required for deliverable toolkit only."),
"nvcc_version": attr.string(
doc = "nvcc version. Required for deliverable toolkit only. Fallback to version if omitted.",
),
})

def _find_modules(module_ctx):
Expand All @@ -25,10 +66,20 @@ def _find_modules(module_ctx):
def _module_tag_to_dict(t):
return {attr: getattr(t, attr) for attr in dir(t)}

def _init(module_ctx):
def _impl(module_ctx):
# Toolchain configuration is only allowed in the root module, or in rules_cuda.
root, rules_cuda = _find_modules(module_ctx)
toolkits = root.tags.toolkit or rules_cuda.tags.toolkit
components = None
toolkits = None
if root.tags.toolkit:
components = root.tags.component
toolkits = root.tags.toolkit
else:
components = rules_cuda.tags.component
toolkits = rules_cuda.tags.toolkit

for component in components:
cuda_component(**_module_tag_to_dict(component))

registrations = {}
for toolkit in toolkits:
Expand All @@ -43,6 +94,9 @@ def _init(module_ctx):
local_cuda(**_module_tag_to_dict(toolkit))

toolchain = module_extension(
implementation = _init,
tag_classes = {"toolkit": cuda_toolkit_tag},
implementation = _impl,
tag_classes = {
"component": cuda_component_tag,
"toolkit": cuda_toolkit_tag,
},
)
26 changes: 26 additions & 0 deletions cuda/private/compat.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
_is_attr_string_keyed_label_dict_available = getattr(attr, "string_keyed_label_dict", None) != None
_is_bzlmod_enabled = str(Label("//:invalid")).startswith("@@")

def _attr(*args, **kwargs):
"""Compatibility layer for attr.string_keyed_label_dict(...)"""
if _is_attr_string_keyed_label_dict_available:
return attr.string_keyed_label_dict(*args, **kwargs)
else:
return attr.string_dict(*args, **kwargs)

def _repo_str(repo_str_or_repo_label):
"""Get mapped repo as string.

Args:
repo_str_or_repo_label: `"@repo"` or `Label("@repo")` """
if type(repo_str_or_repo_label) == "Label":
canonical_repo_name = repo_str_or_repo_label.repo_name
repo_str = ("@@{}" if _is_bzlmod_enabled else "@{}").format(canonical_repo_name)
return repo_str
else:
return repo_str_or_repo_label

components_mapping_compat = struct(
attr = _attr,
repo_str = _repo_str,
)
Loading
Loading