diff --git a/cuda/extensions.bzl b/cuda/extensions.bzl index edb80a48..35368f4b 100644 --- a/cuda/extensions.bzl +++ b/cuda/extensions.bzl @@ -1,10 +1,51 @@ """Entry point for extensions used by bzlmod.""" -load("//cuda/private:repositories.bzl", "local_cuda") +load("//cuda/private:compat.bzl", "components_mapping_compat") +load("//cuda/private:repositories.bzl", "cuda_component", "local_cuda") + +cuda_component_tag = tag_class(attrs = { + "name": attr.string(mandatory = True, doc = "Repo name for the deliverable cuda_component"), + "component_name": attr.string(doc = "Short name of the component defined in registry."), + "integrity": attr.string( + doc = "Expected checksum in Subresource Integrity format of the file downloaded. " + + "This must match the checksum of the file downloaded.", + ), + "sha256": attr.string( + doc = "The expected SHA-256 of the file downloaded. This must match the SHA-256 of the file downloaded.", + ), + "strip_prefix": attr.string( + doc = "A directory prefix to strip from the extracted files. " + + "Many archives contain a top-level directory that contains all of the useful files in archive.", + ), + "url": attr.string( + doc = "A URL to a file that will be made available to Bazel. " + + "This must be a file, http or https URL. " + + "Redirections are followed. Authentication is not supported. " + + "More flexibility can be achieved by the urls parameter that allows " + + "to specify alternative URLs to fetch from.", + ), + "urls": attr.string_list( + doc = "A list of URLs to a file that will be made available to Bazel. " + + "Each entry must be a file, http or https URL. " + + "Redirections are followed. Authentication is not supported. " + + "URLs are tried in order until one succeeds, so you should list local mirrors first. 
" + + "If all downloads fail, the rule will fail.", + ), +}) cuda_toolkit_tag = tag_class(attrs = { - "name": attr.string(doc = "Name for the toolchain repository", default = "local_cuda"), - "toolkit_path": attr.string(doc = "Path to the CUDA SDK, if empty the environment variable CUDA_PATH will be used to deduce this path."), + "name": attr.string(mandatory = True, doc = "Name for the toolchain repository", default = "local_cuda"), + "toolkit_path": attr.string( + doc = "Path to the CUDA SDK, if empty the environment variable CUDA_PATH will be used to deduce this path.", + ), + "components_mapping": components_mapping_compat.attr( + doc = "A mapping from component names to component repos of a deliverable CUDA Toolkit. " + + "Only the repo part of the label is useful", + ), + "version": attr.string(doc = "cuda toolkit version. Required for deliverable toolkit only."), + "nvcc_version": attr.string( + doc = "nvcc version. Required for deliverable toolkit only. Fallback to version if omitted.", + ), }) def _find_modules(module_ctx): @@ -25,10 +66,20 @@ def _find_modules(module_ctx): def _module_tag_to_dict(t): return {attr: getattr(t, attr) for attr in dir(t)} -def _init(module_ctx): +def _impl(module_ctx): # Toolchain configuration is only allowed in the root module, or in rules_cuda. 
root, rules_cuda = _find_modules(module_ctx) - toolkits = root.tags.toolkit or rules_cuda.tags.toolkit + components = None + toolkits = None + if root.tags.toolkit: + components = root.tags.component + toolkits = root.tags.toolkit + else: + components = rules_cuda.tags.component + toolkits = rules_cuda.tags.toolkit + + for component in components: + cuda_component(**_module_tag_to_dict(component)) registrations = {} for toolkit in toolkits: @@ -43,6 +94,9 @@ def _init(module_ctx): local_cuda(**_module_tag_to_dict(toolkit)) toolchain = module_extension( - implementation = _init, - tag_classes = {"toolkit": cuda_toolkit_tag}, + implementation = _impl, + tag_classes = { + "component": cuda_component_tag, + "toolkit": cuda_toolkit_tag, + }, ) diff --git a/cuda/private/compat.bzl b/cuda/private/compat.bzl new file mode 100644 index 00000000..e2f39837 --- /dev/null +++ b/cuda/private/compat.bzl @@ -0,0 +1,26 @@ +_is_attr_string_keyed_label_dict_available = getattr(attr, "string_keyed_label_dict", None) != None +_is_bzlmod_enabled = str(Label("//:invalid")).startswith("@@") + +def _attr(*args, **kwargs): + """Compatibility layer for attr.string_keyed_label_dict(...)""" + if _is_attr_string_keyed_label_dict_available: + return attr.string_keyed_label_dict(*args, **kwargs) + else: + return attr.string_dict(*args, **kwargs) + +def _repo_str(repo_str_or_repo_label): + """Get mapped repo as string. 
+ + Args: + repo_str_or_repo_label: `"@repo"` or `Label("@repo")` """ + if type(repo_str_or_repo_label) == "Label": + canonical_repo_name = repo_str_or_repo_label.repo_name + repo_str = ("@@{}" if _is_bzlmod_enabled else "@{}").format(canonical_repo_name) + return repo_str + else: + return repo_str_or_repo_label + +components_mapping_compat = struct( + attr = _attr, + repo_str = _repo_str, +) diff --git a/cuda/private/repositories.bzl b/cuda/private/repositories.bzl index 8325729b..68590c3a 100644 --- a/cuda/private/repositories.bzl +++ b/cuda/private/repositories.bzl @@ -2,20 +2,19 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") +load("//cuda/private:compat.bzl", "components_mapping_compat") load("//cuda/private:template_helper.bzl", "template_helper") +load("//cuda/private:templates/registry.bzl", "FULL_COMPONENT_NAME", "REGISTRY") load("//cuda/private:toolchain.bzl", "register_detected_cuda_toolchains") -def _to_forward_slash(s): - return s.replace("\\", "/") - def _is_linux(ctx): return ctx.os.name.startswith("linux") def _is_windows(ctx): return ctx.os.name.lower().startswith("windows") -def _get_nvcc_version(repository_ctx, cuda_path): - result = repository_ctx.execute([cuda_path + "/bin/nvcc", "--version"]) +def _get_nvcc_version(repository_ctx, nvcc_root): + result = repository_ctx.execute([nvcc_root + "/bin/nvcc", "--version"]) if result.return_code != 0: return [-1, -1] for line in [line for line in result.stdout.split("\n") if ", release " in line]: @@ -27,21 +26,7 @@ def _get_nvcc_version(repository_ctx, cuda_path): return version[:2] return [-1, -1] -def detect_cuda_toolkit(repository_ctx): - """Detect CUDA Toolkit. 
- - The path to CUDA Toolkit is determined as: - - the value of `toolkit_path` passed to local_cuda as an attribute - - taken from `CUDA_PATH` environment variable or - - determined through 'which ptxas' or - - defaults to '/usr/local/cuda' - - Args: - repository_ctx: repository_ctx - - Returns: - A struct contains the information of CUDA Toolkit. - """ +def _detect_local_cuda_toolkit(repository_ctx): cuda_path = repository_ctx.attr.toolkit_path if cuda_path == "": cuda_path = repository_ctx.os.environ.get("CUDA_PATH", None) @@ -101,6 +86,67 @@ def detect_cuda_toolkit(repository_ctx): fatbinary_label = fatbinary, ) +def _detect_deliverable_cuda_toolkit(repository_ctx): + # NOTE: component nvcc contains some headers that will be used. + required_components = ["cccl", "cudart", "nvcc"] + for rc in required_components: + if rc not in repository_ctx.attr.components_mapping: + fail('component "{}" is required.'.format(rc)) + + nvcc_repo = components_mapping_compat.repo_str(repository_ctx.attr.components_mapping["nvcc"]) + + bin_ext = ".exe" if _is_windows(repository_ctx) else "" + nvcc = "{}//:nvcc/bin/nvcc{}".format(nvcc_repo, bin_ext) + nvlink = "{}//:nvcc/bin/nvlink{}".format(nvcc_repo, bin_ext) + link_stub = "{}//:nvcc/bin/crt/link.stub".format(nvcc_repo) + bin2c = "{}//:nvcc/bin/bin2c{}".format(nvcc_repo, bin_ext) + fatbinary = "{}//:nvcc/bin/fatbinary{}".format(nvcc_repo, bin_ext) + + cuda_version_str = repository_ctx.attr.version + if cuda_version_str == None or cuda_version_str == "": + fail("attr version is required.") + + nvcc_version_str = repository_ctx.attr.nvcc_version + if nvcc_version_str == None or nvcc_version_str == "": + nvcc_version_str = cuda_version_str + + cuda_version_major, cuda_version_minor = cuda_version_str.split(".")[:2] + nvcc_version_major, nvcc_version_minor = nvcc_version_str.split(".")[:2] + + return struct( + path = None, # scattered components + version_major = cuda_version_major, + version_minor = cuda_version_minor, + 
nvcc_version_major = nvcc_version_major, + nvcc_version_minor = nvcc_version_minor, + nvcc_label = nvcc, + nvlink_label = nvlink, + link_stub_label = link_stub, + bin2c_label = bin2c, + fatbinary_label = fatbinary, + ) + +def detect_cuda_toolkit(repository_ctx): + """Detect CUDA Toolkit. + + The path to CUDA Toolkit is determined as: + - use nvcc component from deliverable + - the value of `toolkit_path` passed to local_cuda as an attribute + - taken from `CUDA_PATH` environment variable or + - determined through 'which ptxas' or + - defaults to '/usr/local/cuda' + + Args: + repository_ctx: repository_ctx + + Returns: + A struct contains the information of CUDA Toolkit. + """ + if repository_ctx.attr.components_mapping != {}: + return _detect_deliverable_cuda_toolkit(repository_ctx) + else: + return _detect_local_cuda_toolkit(repository_ctx) + def config_cuda_toolkit_and_nvcc(repository_ctx, cuda): """Generate `@local_cuda//BUILD` and `@local_cuda//defs.bzl` and `@local_cuda//toolchain/BUILD` @@ -109,30 +155,40 @@ def config_cuda_toolkit_and_nvcc(repository_ctx, cuda): cuda: The struct returned from detect_cuda_toolkit """ - # True: locally installed cuda toolkit - # False: hermatic cuda toolkit (components) + # True: locally installed cuda toolkit (@local_cuda with full install of local CTK) + # False: hermatic cuda toolkit (@local_cuda with alias of components) # None: cuda toolkit is not presented - is_local_cuda = None - if cuda.path != None: + is_local_ctk = None + + if len(repository_ctx.attr.components_mapping) != 0: + is_local_ctk = False + + if is_local_ctk == None and cuda.path != None: # When using a special cuda toolkit path install, need to manually fix up the lib64 links if cuda.path == "/usr/lib/nvidia-cuda-toolkit": repository_ctx.symlink(cuda.path + "/bin", "cuda/bin") repository_ctx.symlink("/usr/lib/x86_64-linux-gnu", "cuda/lib64") else: repository_ctx.symlink(cuda.path, "cuda") - is_local_cuda = True + is_local_ctk = True # Generate 
@local_cuda//BUILD - if is_local_cuda == None: + if is_local_ctk == None: repository_ctx.symlink(Label("//cuda/private:templates/BUILD.local_cuda_disabled"), "BUILD") - elif is_local_cuda: + elif is_local_ctk: libpath = "lib64" if _is_linux(repository_ctx) else "lib" template_helper.generate_build(repository_ctx, libpath) else: - fail("hermatic cuda toolchain is not implemented") + template_helper.generate_build( + repository_ctx, + libpath = "lib", + components = repository_ctx.attr.components_mapping, + is_local_cuda = True, + is_deliverable = True, + ) # Generate @local_cuda//defs.bzl - template_helper.generate_defs_bzl(repository_ctx, is_local_cuda) + template_helper.generate_defs_bzl(repository_ctx, is_local_ctk == True) # Generate @local_cuda//toolchain/BUILD template_helper.generate_toolchain_build(repository_ctx, cuda) @@ -191,13 +247,96 @@ def _local_cuda_impl(repository_ctx): local_cuda = repository_rule( implementation = _local_cuda_impl, - attrs = {"toolkit_path": attr.string(mandatory = False)}, + attrs = { + "toolkit_path": attr.string(doc = "Path to the CUDA SDK, if empty the environment variable CUDA_PATH will be used to deduce this path."), + "components_mapping": components_mapping_compat.attr( + doc = "A mapping from component names to component repos of a deliverable CUDA Toolkit. " + + "Only the repo part of the label is useful", + ), + "version": attr.string(doc = "cuda toolkit version. Required for deliverable toolkit only."), + "nvcc_version": attr.string( + doc = "nvcc version. Required for deliverable toolkit only. 
Fallback to version if omitted.", + ), + }, configure = True, local = True, environ = ["CUDA_PATH", "PATH", "CUDA_CLANG_PATH", "BAZEL_LLVM"], # remotable = True, ) +def _cuda_component_impl(repository_ctx): + component_name = None + if repository_ctx.attr.component_name: + component_name = repository_ctx.attr.component_name + if component_name not in REGISTRY: + fail("invalid component '{}', available: {}".format(component_name, repr(REGISTRY.keys()))) + else: + component_name = repository_ctx.name.split("local_cuda_")[-1] + if component_name not in REGISTRY: + fail("invalid derived component '{}', available: {}, ".format(component_name, repr(REGISTRY.keys())) + + "if derivation result is unexpected, please specify `component_name` attribute manually") + + if not repository_ctx.attr.url and not repository_ctx.attr.urls: + fail("either attribute `url` or `urls` must be filled") + if repository_ctx.attr.url and repository_ctx.attr.urls: + fail("attributes `url` and `urls` cannot be used at the same time") + + repository_ctx.download_and_extract( + url = repository_ctx.attr.url or repository_ctx.attr.urls, + output = component_name, + integrity = repository_ctx.attr.integrity, + sha256 = repository_ctx.attr.sha256, + stripPrefix = repository_ctx.attr.strip_prefix, + ) + + template_helper.generate_build( + repository_ctx, + libpath = "lib", + components = {component_name: repository_ctx.name}, + is_local_cuda = False, + is_deliverable = True, + ) + +cuda_component = repository_rule( + implementation = _cuda_component_impl, + attrs = { + "component_name": attr.string(doc = "Short name of the component defined in registry."), + "integrity": attr.string( + doc = "Expected checksum in Subresource Integrity format of the file downloaded. " + + "This must match the checksum of the file downloaded.", + ), + "sha256": attr.string( + doc = "The expected SHA-256 of the file downloaded. 
This must match the SHA-256 of the file downloaded.", + ), + "strip_prefix": attr.string( + doc = "A directory prefix to strip from the extracted files. " + + "Many archives contain a top-level directory that contains all of the useful files in archive.", + ), + "url": attr.string( + doc = "A URL to a file that will be made available to Bazel. " + + "This must be a file, http or https URL. " + + "Redirections are followed. Authentication is not supported. " + + "More flexibility can be achieved by the urls parameter that allows " + + "to specify alternative URLs to fetch from.", + ), + "urls": attr.string_list( + doc = "A list of URLs to a file that will be made available to Bazel. " + + "Each entry must be a file, http or https URL. " + + "Redirections are followed. Authentication is not supported. " + + "URLs are tried in order until one succeeds, so you should list local mirrors first. " + + "If all downloads fail, the rule will fail.", + ), + }, +) + +def default_components_mapping(components): + """Create a default components_mapping from list of component names. + + Args: + components: list of string, a list of component names. + """ + return {c: "@local_cuda_" + c for c in components} + def rules_cuda_dependencies(): """Populate the dependencies for rules_cuda. This will setup other bazel rules as workspace dependencies""" maybe( @@ -220,17 +359,23 @@ def rules_cuda_dependencies(): ], ) -def rules_cuda_toolchains(toolkit_path = None, register_toolchains = False): +def rules_cuda_toolchains(toolkit_path = None, components_mapping = None, version = None, nvcc_version = None, register_toolchains = False): """Populate the local_cuda repo. Args: toolkit_path: Optionally specify the path to CUDA toolkit. If not specified, it will be detected automatically. + components_mapping: dict mapping from component_name to its corresponding cuda_component's repo_name + version: str for cuda toolkit version. Required for deliverable toolkit only. 
+ nvcc_version: str for nvcc version. Required for deliverable toolkit only. Fallback to version if omitted. register_toolchains: Register the toolchains if enabled. """ local_cuda( name = "local_cuda", toolkit_path = toolkit_path, + components_mapping = components_mapping, + version = version, + nvcc_version = nvcc_version, ) if register_toolchains: diff --git a/cuda/private/template_helper.bzl b/cuda/private/template_helper.bzl index 72c7dca4..9590b7ba 100644 --- a/cuda/private/template_helper.bzl +++ b/cuda/private/template_helper.bzl @@ -1,3 +1,4 @@ +load("//cuda/private:compat.bzl", "components_mapping_compat") load("//cuda/private:templates/registry.bzl", "REGISTRY") def _to_forward_slash(s): @@ -9,35 +10,82 @@ def _is_linux(ctx): def _is_windows(ctx): return ctx.os.name.lower().startswith("windows") -def _generate_build(repository_ctx, libpath): +def _generate_local_cuda_build_impl(repository_ctx, libpath, components, is_local_cuda, is_deliverable): # stitch template fragment fragments = [ Label("//cuda/private:templates/BUILD.local_cuda_shared"), Label("//cuda/private:templates/BUILD.local_cuda_headers"), Label("//cuda/private:templates/BUILD.local_cuda_build_setting"), ] - fragments.extend([Label("//cuda/private:templates/BUILD.{}".format(c)) for c in REGISTRY if len(REGISTRY[c]) > 0]) + if is_local_cuda and not is_deliverable: # generate `@local_cuda//BUILD` for local host CTK + fragments.extend([Label("//cuda/private:templates/BUILD.{}".format(c)) for c in components]) + elif is_local_cuda and is_deliverable: # generate `@local_cuda//BUILD` for CTK with deliverables + pass + elif not is_local_cuda and is_deliverable: # generate `@local_cuda_//BUILD` for a deliverable + if len(components) != 1: + fail("one deliverable at a time") + fragments.append(Label("//cuda/private:templates/BUILD.{}".format(components.keys()[0]))) + else: + fail("unreachable") template_content = [] for frag in fragments: template_content.append("# Generated from fragment " + 
str(frag)) template_content.append(repository_ctx.read(frag)) + if is_local_cuda and is_deliverable: # generate `@local_cuda//BUILD` for CTK with deliverables + for comp in components: + for target in REGISTRY[comp]: + repo = components_mapping_compat.repo_str(components[comp]) + line = 'alias(name = "{target}", actual = "{repo}//:{target}")'.format(target = target, repo = repo) + template_content.append(line) + + # add an empty line to separate aliased targets from different components + template_content.append("") + template_content = "\n".join(template_content) template_path = repository_ctx.path("BUILD.tpl") repository_ctx.file(template_path, content = template_content, executable = False) substitutions = { - "%{component_name}": "cuda", + "%{component_name}": "cuda" if is_local_cuda else components.keys()[0], "%{libpath}": libpath, } repository_ctx.template("BUILD", template_path, substitutions = substitutions, executable = False) -def _generate_defs_bzl(repository_ctx, is_local_cuda): +def _generate_build(repository_ctx, libpath, components = None, is_local_cuda = True, is_deliverable = False): + """Generate `@local_cuda//BUILD` + + Notes: + - is_local_cuda==False and is_deliverable==False is an error + - is_local_cuda==True and is_deliverable==False generate `@local_cuda//BUILD` for local host CTK + - is_local_cuda==True and is_deliverable==True generate `@local_cuda//BUILD` for CTK with deliverables + - is_local_cuda==False and is_deliverable==True generate `@local_cuda_<component>//BUILD` for a deliverable + generates `@local_cuda//BUILD` + + Args: + repository_ctx: repository_ctx + libpath: substitution of %{libpath} + components: dict[str, str], the components of CTK to be included, mapped to the repo names for the components + is_local_cuda: See Notes, True for @local_cuda generation, False for @local_cuda_<component> generation. 
+ is_deliverable: See Notes + """ + + if is_local_cuda and not is_deliverable: + if components == None: + components = [c for c in REGISTRY if len(REGISTRY[c]) > 0] + else: + for c in components: + if c not in REGISTRY: + fail("{} is not a valid component".format(c)) + + _generate_local_cuda_build_impl(repository_ctx, libpath, components, is_local_cuda, is_deliverable) + +def _generate_defs_bzl(repository_ctx, is_local_ctk): tpl_label = Label("//cuda/private:templates/defs.bzl.tpl") substitutions = { - "%{is_local_cuda}": str(is_local_cuda), + "%{is_local_ctk}": str(is_local_ctk), } repository_ctx.template("defs.bzl", tpl_label, substitutions = substitutions, executable = False) diff --git a/cuda/private/templates/BUILD.cudart b/cuda/private/templates/BUILD.cudart index c9fa6d5e..09bcfeb7 100644 --- a/cuda/private/templates/BUILD.cudart +++ b/cuda/private/templates/BUILD.cudart @@ -38,7 +38,7 @@ cc_library( "-lpthread", "-lrt", ]), - deps = [ + deps = additional_header_deps("cudart") + [ ":%{component_name}_headers", ] + if_linux([ # devrt is required for jit linking when rdc is enabled @@ -66,7 +66,7 @@ cc_library( "-lpthread", "-lrt", ]), - deps = [":cudadevrt_a"], + deps = additional_header_deps("cudart") + [":cudadevrt_a"], # FIXME: # visibility = ["@rules_cuda//cuda:__pkg__"], ) diff --git a/cuda/private/templates/BUILD.local_cuda_shared b/cuda/private/templates/BUILD.local_cuda_shared index af548866..e1c1ff26 100644 --- a/cuda/private/templates/BUILD.local_cuda_shared +++ b/cuda/private/templates/BUILD.local_cuda_shared @@ -1,4 +1,5 @@ load("@bazel_skylib//rules:common_settings.bzl", "bool_setting") # @unused +load("@local_cuda//:defs.bzl", "additional_header_deps", "if_local_cuda_toolkit") # @unused load("@rules_cuda//cuda:defs.bzl", "cc_import_versioned_sos", "if_linux", "if_windows") # @unused package( diff --git a/cuda/private/templates/BUILD.nvprof b/cuda/private/templates/BUILD.nvprof deleted file mode 100644 index e69de29b..00000000 diff --git 
a/cuda/private/templates/README.md b/cuda/private/templates/README.md new file mode 100644 index 00000000..3764247b --- /dev/null +++ b/cuda/private/templates/README.md @@ -0,0 +1,86 @@ +## Template files + +- `BUILD.local_cuda_shared`: For `local_cuda` repo (CTK + toolchain) or `local_cuda_%{component_name}` +- `BUILD.local_cuda_headers`: For `local_cuda` repo (CTK + toolchain) or `local_cuda_%{component_name}` headers +- `BUILD.local_cuda_build_setting`: For `local_cuda` repo (CTK + toolchain) build_setting +- `BUILD.local_cuda_disabled`: For creating a dummy local configuration. +- `BUILD.local_toolchain_disabled`: For creating a dummy local toolchain. +- `BUILD.local_toolchain_clang`: For Clang device compilation toolchain. +- `BUILD.local_toolchain_nvcc`: For NVCC device compilation toolchain. +- `BUILD.local_toolchain_nvcc_msvc`: For NVCC device compilation toolchain (with MSVC as host compiler). +- Otherwise, each `BUILD.*` corresponds to a component in CUDA Toolkit. + +## Repository organization + +We organize the generated repo as follows, for both `local_cuda` and `local_cuda_<component_name>` + +``` +<repo_root> # bazel unconditionally creates a directory for us +├── %{component_name}/ # cuda for local ctk, component name otherwise +│ ├── include/ # +│ └── %{libpath}/ # lib or lib64, platform dependent +├── defs.bzl # generated +├── BUILD # generated from BUILD.local_cuda and one/all of the component(s) +└── WORKSPACE # generated +``` + +If the repo is `local_cuda`, we additionally generate toolchain config as follows + +``` +<repo_root> +└── toolchain/ + ├── BUILD # the default nvcc toolchain + ├── clang/ # the optional clang toolchain + │ └── BUILD # + └── disabled/ # the fallback toolchain + └── BUILD # +``` + +## How are component repositories and `@local_cuda` connected? + +The `registry.bzl` file holds mappings from our (`rules_cuda`) component names to various things. + +The registry serves the following purposes: + +1. 
maps our component names to full component names used in the `redistrib.json` file. + + This is purely for looking up the json files. + +2. maps our component names to target names to be exposed under the `@local_cuda` repo. + + To expose those targets, we use a `components_mapping` attr that maps our component names to labels of the component + repositories (for example, `@local_cuda_nvcc`) as follows + +```starlark +# in registry.bzl +... + "cudart": ["cuda", "cuda_runtime", "cuda_runtime_static"], +... + +# in WORKSPACE.bazel +cuda_component( + name = "local_cuda_cudart_v12.6.77", + component_name = "cudart", + ... +) + +local_cuda( + name = "local_cuda", + components_mapping = {"cudart": "@local_cuda_cudart_v12.6.77"}, + ... +) +``` + +This basically means the component `cudart` has `cuda`, `cuda_runtime` and `cuda_runtime_static` targets defined. + +- In a locally installed CTK, we set up the targets in `@local_cuda` directly. +- In a deliverable CTK, we set up the targets in the `@local_cuda_cudart_v12.6.77` repo, and alias all of them in + `@local_cuda` as follows + +```starlark +alias(name = "cuda", actual = "@local_cuda_cudart_v12.6.77//:cuda") +alias(name = "cuda_runtime", actual = "@local_cuda_cudart_v12.6.77//:cuda_runtime") +alias(name = "cuda_runtime_static", actual = "@local_cuda_cudart_v12.6.77//:cuda_runtime_static") +``` + +`cuda_component` is in charge of setting up the repo `@local_cuda_cudart_v12.6.77`. 
diff --git a/cuda/private/templates/defs.bzl.tpl b/cuda/private/templates/defs.bzl.tpl index 012f03ec..84f388ba 100644 --- a/cuda/private/templates/defs.bzl.tpl +++ b/cuda/private/templates/defs.bzl.tpl @@ -1,6 +1,18 @@ -def if_local_cuda(if_true, if_false = []): - is_local_cuda = %{is_local_cuda} - if is_local_cuda: +def if_local_cuda_toolkit(if_true, if_false = []): + is_local_ctk = %{is_local_ctk} + if is_local_ctk: return if_true else: return if_false + +def if_deliverable_cuda_toolkit(if_true, if_false = []): + return if_local_cuda_toolkit(if_false, if_true) + +def additional_header_deps(component_name): + if component_name == "cudart": + return if_deliverable_cuda_toolkit([ + "@local_cuda//:nvcc_headers", + "@local_cuda//:cccl_headers", + ]) + + return [] diff --git a/cuda/private/templates/registry.bzl b/cuda/private/templates/registry.bzl index fc884439..06a43a23 100644 --- a/cuda/private/templates/registry.bzl +++ b/cuda/private/templates/registry.bzl @@ -1,8 +1,8 @@ -# map component name to consumable targets +# map short component name to consumable targets REGISTRY = { "cudart": ["cuda", "cuda_runtime", "cuda_runtime_static"], - "nvcc": ["compiler_deps", "nvptxcompiler"], - "cccl": ["cub", "thrust"], + "nvcc": ["compiler_deps", "nvptxcompiler", "nvcc_headers"], + "cccl": ["cub", "thrust", "cccl_headers"], "cublas": ["cublas"], "cufft": ["cufft", "cufft_static"], "cufile": [], @@ -19,3 +19,24 @@ REGISTRY = { "nvrtc": ["nvrtc"], "nvtx": ["nvtx"], } + +# map short component name to full component name +FULL_COMPONENT_NAME = { + "cudart": "cuda_cudart", + "nvcc": "cuda_nvcc", + "cccl": "cuda_cccl", + "cublas": "libcublas", + "cufft": "libcufft", + "cufile": "libcufile", + "cupti": "libcupti", + "curand": "libcurand", + "cusolver": "libcusolver", + "cusparse": "libcusparse", + "npp": "libnpp", + "nvidia_fs": "nvidia_fs", + "nvjitlink": "libnvjitlink", + "nvjpeg": "libnvjpeg", + "nvml": "cuda_nvml_dev", + "nvrtc": "cuda_nvrtc", + "nvtx": "cuda_nvtx", +} diff 
--git a/cuda/repositories.bzl b/cuda/repositories.bzl index 1e5fca0c..9635d0db 100644 --- a/cuda/repositories.bzl +++ b/cuda/repositories.bzl @@ -1,12 +1,19 @@ load( "//cuda/private:repositories.bzl", + _cuda_component = "cuda_component", + _default_components_mapping = "default_components_mapping", _local_cuda = "local_cuda", _rules_cuda_dependencies = "rules_cuda_dependencies", _rules_cuda_toolchains = "rules_cuda_toolchains", ) load("//cuda/private:toolchain.bzl", _register_detected_cuda_toolchains = "register_detected_cuda_toolchains") +# rules +cuda_component = _cuda_component local_cuda = _local_cuda + +# macros rules_cuda_dependencies = _rules_cuda_dependencies rules_cuda_toolchains = _rules_cuda_toolchains register_detected_cuda_toolchains = _register_detected_cuda_toolchains +default_components_mapping = _default_components_mapping diff --git a/tests/integration/test_all.sh b/tests/integration/test_all.sh index 446dd240..48c0f8a8 100755 --- a/tests/integration/test_all.sh +++ b/tests/integration/test_all.sh @@ -49,3 +49,14 @@ pushd "$this_dir/toolchain_rules" bazel build //:use_rule bazel clean && bazel shutdown popd + +# toolchain configured with deliverables +pushd "$this_dir/toolchain_redist" + bazel build //... --@rules_cuda//cuda:enable=False + bazel build //... 
--@rules_cuda//cuda:enable=True + bazel build //:optinally_use_rule --@rules_cuda//cuda:enable=False + bazel build //:optinally_use_rule --@rules_cuda//cuda:enable=True + bazel build //:use_library + bazel build //:use_rule + bazel clean && bazel shutdown +popd diff --git a/tests/integration/toolchain_redist/BUILD.bazel b/tests/integration/toolchain_redist/BUILD.bazel new file mode 120000 index 00000000..60a18fe7 --- /dev/null +++ b/tests/integration/toolchain_redist/BUILD.bazel @@ -0,0 +1 @@ +../BUILD.to_symlink \ No newline at end of file diff --git a/tests/integration/toolchain_redist/MODULE.bazel b/tests/integration/toolchain_redist/MODULE.bazel new file mode 100644 index 00000000..7c42b10e --- /dev/null +++ b/tests/integration/toolchain_redist/MODULE.bazel @@ -0,0 +1,52 @@ +module(name = "bzlmod_use_repo_no_toolchain") + +bazel_dep(name = "rules_cuda", version = "0.0.0") +local_path_override( + module_name = "rules_cuda", + path = "../../..", +) + +cuda = use_extension("@rules_cuda//cuda:extensions.bzl", "toolchain") +cuda.component( + name = "local_cuda_cccl", + component_name = "cccl", + sha256 = "9c3145ef01f73e50c0f5fcf923f0899c847f487c529817daa8f8b1a3ecf20925", + strip_prefix = "cuda_cccl-linux-x86_64-12.6.77-archive", + urls = ["https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/linux-x86_64/cuda_cccl-linux-x86_64-12.6.77-archive.tar.xz"], +) +cuda.component( + name = "local_cuda_cudart", + component_name = "cudart", + sha256 = "f74689258a60fd9c5bdfa7679458527a55e22442691ba678dcfaeffbf4391ef9", + strip_prefix = "cuda_cudart-linux-x86_64-12.6.77-archive", + urls = ["https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/linux-x86_64/cuda_cudart-linux-x86_64-12.6.77-archive.tar.xz"], +) +cuda.component( + name = "local_cuda_nvcc", + component_name = "nvcc", + sha256 = "840deff234d9bef20d6856439c49881cb4f29423b214f9ecd2fa59b7ac323817", + strip_prefix = "cuda_nvcc-linux-x86_64-12.6.85-archive", + urls = 
["https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/linux-x86_64/cuda_nvcc-linux-x86_64-12.6.85-archive.tar.xz"], +) +cuda.toolkit( + name = "local_cuda", + components_mapping = { + "cccl": "@local_cuda_cccl", + "cudart": "@local_cuda_cudart", + "nvcc": "@local_cuda_nvcc", + }, + version = "12.6", +) +use_repo( + cuda, + "local_cuda", + "local_cuda_cccl", + "local_cuda_cudart", + "local_cuda_nvcc", +) + +bazel_dep(name = "rules_cuda_examples", version = "0.0.0") +local_path_override( + module_name = "rules_cuda_examples", + path = "../../../examples", +)