Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
make nvrtc compile cuda-c to cubin directly (#1313)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqiu authored Mar 29, 2023
1 parent d53c64d commit f26d24d
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 18 deletions.
4 changes: 2 additions & 2 deletions cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code)
auto ptx = compiler(source_code);
CHECK(!ptx.empty());

// TODO(Superjomn) Whether to support multiple CUDA modules?
cuda_module_.reset(new CUDAModule(ptx, CUDAModule::Kind::PTX));
cuda_module_.reset(
new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX));

RuntimeSymbols symbols;

Expand Down
44 changes: 33 additions & 11 deletions cinn/backends/nvrtc/nvrtc_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,28 @@
#include "cinn/common/common.h"
#include "cinn/utils/string.h"

DECLARE_bool(nvrtc_compile_to_cubin);

namespace cinn {
namespace backends {
namespace nvrtc {

std::string Compiler::operator()(const std::string& code, bool include_headers) {
return CompilePTX(code, include_headers);
return CompileCudaSource(code, include_headers);
}

Compiler::Compiler() {
if (FLAGS_nvrtc_compile_to_cubin) {
#if CUDA_VERSION >= 11010
compile_to_cubin_ = true;
#endif
}
VLOG(4) << "FLAGS_nvrtc_compile_to_cubin: " << FLAGS_nvrtc_compile_to_cubin
<< ", compile_to_cubin_: " << compile_to_cubin_;
}

bool Compiler::compile_to_cubin() { return compile_to_cubin_; }

std::vector<std::string> Compiler::FindCUDAIncludePaths() {
const std::string delimiter = "/";
std::string cuda_include_path;
Expand All @@ -56,7 +70,7 @@ std::vector<std::string> Compiler::FindCUDAIncludePaths() {

std::vector<std::string> Compiler::FindCINNRuntimeIncludePaths() { return {Context::Global().runtime_include_dir()}; }

std::string Compiler::CompilePTX(const std::string& code, bool include_headers) {
std::string Compiler::CompileCudaSource(const std::string& code, bool include_headers) {
const auto& header_gen = JitSafeHeaderGenerator::GetInstance();
std::vector<std::string> compile_options;
std::vector<const char*> param_cstrings{};
Expand All @@ -72,8 +86,11 @@ std::string Compiler::CompilePTX(const std::string& code, bool include_headers)
LOG(WARNING) << "cannot detect compute capability from your device, "
<< "fall back to compute_30.";
}

compile_options.push_back("-arch=compute_" + cc);
if (compile_to_cubin_) {
compile_options.push_back("-arch=sm_" + cc);
} else {
compile_options.push_back("-arch=compute_" + cc);
}
compile_options.push_back("-std=c++14");
compile_options.push_back("-default-device");

Expand Down Expand Up @@ -107,15 +124,20 @@ std::string Compiler::CompilePTX(const std::string& code, bool include_headers)
CHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
}

size_t ptx_size;
NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size));
size_t size;
std::string data;
if (compile_to_cubin_) {
NVRTC_CALL(nvrtcGetCUBINSize(prog, &size));
data.resize(size);
NVRTC_CALL(nvrtcGetCUBIN(prog, &data[0]));
} else {
NVRTC_CALL(nvrtcGetPTXSize(prog, &size));
data.resize(size);
NVRTC_CALL(nvrtcGetPTX(prog, &data[0]));
}

std::string ptx;
ptx.resize(ptx_size);
NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));
NVRTC_CALL(nvrtcDestroyProgram(&prog));

return ptx;
return data;
}

} // namespace nvrtc
Expand Down
18 changes: 15 additions & 3 deletions cinn/backends/nvrtc/nvrtc_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ namespace nvrtc {
*/
class Compiler {
public:
Compiler();

/**
* Compile the \p code and get PTX string.
* @param code The CUDA source code.
Expand All @@ -39,6 +41,11 @@ class Compiler {
*/
std::string operator()(const std::string& code, bool include_headers = true);

/** Compile into cubin or not
* @return Compile into cubin or not.
*/
bool compile_to_cubin();

private:
/**
* Get the directories of CUDA's header files.
Expand All @@ -53,11 +60,16 @@ class Compiler {
std::vector<std::string> FindCINNRuntimeIncludePaths();

/**
* Compile CUDA source code and get PTX.
* Compile CUDA source code and get PTX or CUBIN.
* @param code source code string.
* @return PTX string.
* @return PTX or CUBIN string.
*/
std::string CompileCudaSource(const std::string& code, bool include_headers);

/**
* whether to compile the source code into cubin, only works with cuda version > 11.1
*/
std::string CompilePTX(const std::string& code, bool include_headers);
bool compile_to_cubin_{false};
};

} // namespace nvrtc
Expand Down
1 change: 1 addition & 0 deletions cinn/hlir/framework/op_lowering.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ std::vector<Expr> OpLowerer::IRElementwiseCompute(poly::StageMap& stages,

std::vector<Expr> ast_exprs;
for (auto& node : sub_group->nodes) {
VLOG(4) << "Lower op: " << node->op()->name;
auto node_data = GetNodeData(node);
CHECK_EQ(GetAllNodeData(node).size(), 1U);
std::vector<common::CINNValue> cinn_inputs;
Expand Down
2 changes: 1 addition & 1 deletion cinn/hlir/framework/parallel_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ void ParallelCompiler::Task::CodegenAndJit() {
CHECK(!ptx.empty());

// load cumodule
cumodule.reset(new CUDAModule(ptx, CUDAModule::Kind::PTX));
cumodule.reset(new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX));
// register kernel
backends::RuntimeSymbols symbols;
for (auto& fn : dmodule.functions()) {
Expand Down
3 changes: 2 additions & 1 deletion cinn/runtime/cuda/cuda_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ namespace cuda {
class CUDAModule {
public:
enum class Kind {
PTX = 0,
PTX = 0,
CUBIN = 1,
};

CUDAModule(const std::string& data, Kind kind);
Expand Down
4 changes: 4 additions & 0 deletions cinn/runtime/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ DEFINE_bool(cinn_use_dense_merge_pass,
BoolFromEnv("FLAGS_cinn_use_dense_merge_pass", false),
"Whether use dense merge pass.");
DEFINE_bool(nvrtc_compile_to_cubin,
BoolFromEnv("FLAGS_nvrtc_compile_to_cubin", false),
"Whether nvrtc compile cuda source into cubin instead of ptx (only works after cuda-11.1).");

// FLAGS for performance analysis and accuracy debug
DEFINE_bool(cinn_sync_run,
BoolFromEnv("FLAGS_cinn_sync_run", false),
Expand Down

0 comments on commit f26d24d

Please sign in to comment.