Skip to content

Commit

Permalink
improve debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 14, 2025
1 parent 64a1c28 commit 84f97d3
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 20 deletions.
142 changes: 122 additions & 20 deletions src/enzyme_ad/jax/Passes/LowerKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
std::string cubinFeatures, size_t cuLaunchKernelPtr,
size_t cuModuleLoadDataPtr,
size_t cuModuleGetFunctionPtr, bool compileLaunch,
bool run_init, enzymexla::KernelCallOp kernelCallOp) {
bool run_init, enzymexla::KernelCallOp kernelCallOp,
bool debug, size_t cuResultHandlerPtr) {

OpBuilder builder(op);

Expand Down Expand Up @@ -533,6 +534,7 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
idx, i32, ptrty, ptrty, ptrty};

auto launch_ty = LLVM::LLVMFunctionType::get(i32, cutys);
auto curesult_handler_ty = LLVM::LLVMFunctionType::get(voidty, {i32});
LLVM::LLVMFuncOp launch =
builder.create<LLVM::LLVMFuncOp>(loc, "cuLaunchKernel", launch_ty);

Expand All @@ -557,24 +559,61 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
loc, "nv_func_init", LLVM::LLVMFunctionType::get(ptrty, {}, false),
LLVM::Linkage::External);

auto printfunc = builder.create<LLVM::LLVMFuncOp>(
loc, "printf", LLVM::LLVMFunctionType::get(ptrty, {ptrty, ptrty}, false),
LLVM::Linkage::External);
printfunc.setVisibility(SymbolTable::Visibility::Private);

LLVM::GlobalOp printStrFunc;
{
std::string opstr;
llvm::raw_string_ostream ss(opstr);
LLVM::LLVMFuncOp printfunc = nullptr;
LLVM::LLVMFuncOp putfunc = nullptr;

if (debug) {
printfunc = builder.create<LLVM::LLVMFuncOp>(
loc, "printf",
LLVM::LLVMFunctionType::get(ptrty, {ptrty, ptrty}, false),
LLVM::Linkage::External);
printfunc.setVisibility(SymbolTable::Visibility::Private);
putfunc = builder.create<LLVM::LLVMFuncOp>(
loc, "puts", LLVM::LLVMFunctionType::get(voidty, {ptrty}, false),
LLVM::Linkage::External);
putfunc.setVisibility(SymbolTable::Visibility::Private);
}

ss << kernelCallOp;
std::string value = "launch Kernel result = %d\n modstr=" + modstr + "\n" + opstr + "\n\n";
LLVM::GlobalOp loadModuleStr = nullptr;
if (debug) {
std::string value = "load Module result = %d\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrFunc = builder.create<LLVM::GlobalOp>(
loadModuleStr = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strmod",
builder.getStringAttr(value + '\0'));
}
LLVM::GlobalOp loadFuncStr = nullptr;
if (debug) {
std::string value = "load Func result = %d\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
loadFuncStr = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strfunc",
builder.getStringAttr(value + '\0'));
}
LLVM::GlobalOp launchKernelStr = nullptr;
if (debug) {
std::string value = "launch Kernel result = %d\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
launchKernelStr = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal,
"strlaunch", builder.getStringAttr(value + '\0'));
}
LLVM::GlobalOp modOpStr = nullptr;
if (debug) {
std::string opstr;
llvm::raw_string_ostream ss(opstr);

ss << kernelCallOp;
std::string value = "modstr=" + modstr + "\n" + opstr + "\n\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
modOpStr = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal,
"strmlirmod", builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp binary = nullptr;
submod.walk([&](gpu::BinaryOp op) {
Expand Down Expand Up @@ -618,15 +657,36 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
SmallVector<mlir::Value> modargs = {modptr->getResult(0),
addr_modbin->getResult(0)};

mlir::Value loadModRes = nullptr;
if (cuModuleLoadDataPtr) {
auto addr_glob_int = builder.create<LLVM::ConstantOp>(
loc, i64, builder.getI64IntegerAttr(cuModuleLoadDataPtr));
auto addr_glob =
builder.create<LLVM::IntToPtrOp>(loc, ptrty, addr_glob_int);
modargs.insert(modargs.begin(), addr_glob);
builder.create<LLVM::CallOp>(loc, modload_ty, modargs);
loadModRes = builder.create<LLVM::CallOp>(loc, modload_ty, modargs)
->getResult(0);
} else {
builder.create<LLVM::CallOp>(loc, modload, modargs);
loadModRes =
builder.create<LLVM::CallOp>(loc, modload, modargs)->getResult(0);
}

if (debug) {
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, loadModuleStr)
->getResult(0),
builder.create<LLVM::IntToPtrOp>(loc, ptrty, loadModRes)
->getResult(0)};
builder.create<LLVM::CallOp>(loc, printfunc, printargs1);
}
if (cuResultHandlerPtr) {
auto addr_glob_int = builder.create<LLVM::ConstantOp>(
loc, i64, builder.getI64IntegerAttr(cuResultHandlerPtr));
auto addr_glob =
builder.create<LLVM::IntToPtrOp>(loc, ptrty, addr_glob_int)
->getResult(0);
mlir::Value args[2] = {addr_glob, loadModRes};
builder.create<LLVM::CallOp>(loc, curesult_handler_ty, loadModRes);
}

auto mod = builder.create<LLVM::LoadOp>(loc, ptrty, modptr);
Expand All @@ -637,15 +697,36 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
SmallVector<mlir::Value> funcargs = {funcptr->getResult(0),
mod->getResult(0),
addr_kernstr->getResult(0)};
mlir::Value loadFuncRes = nullptr;
if (cuModuleGetFunctionPtr) {
auto addr_glob_int = builder.create<LLVM::ConstantOp>(
loc, i64, builder.getI64IntegerAttr(cuModuleGetFunctionPtr));
auto addr_glob =
builder.create<LLVM::IntToPtrOp>(loc, ptrty, addr_glob_int);
funcargs.insert(funcargs.begin(), addr_glob);
builder.create<LLVM::CallOp>(loc, funcload_ty, funcargs);
loadFuncRes =
builder.create<LLVM::CallOp>(loc, funcload_ty, funcargs)
->getResult(0);
} else {
builder.create<LLVM::CallOp>(loc, funcload, funcargs);
loadFuncRes = builder.create<LLVM::CallOp>(loc, funcload, funcargs)
->getResult(0);
}

if (debug) {
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, loadFuncStr)->getResult(0),
builder.create<LLVM::IntToPtrOp>(loc, ptrty, loadFuncRes)
->getResult(0)};
builder.create<LLVM::CallOp>(loc, printfunc, printargs1);
}
if (cuResultHandlerPtr) {
auto addr_glob_int = builder.create<LLVM::ConstantOp>(
loc, i64, builder.getI64IntegerAttr(cuResultHandlerPtr));
auto addr_glob =
builder.create<LLVM::IntToPtrOp>(loc, ptrty, addr_glob_int)
->getResult(0);
mlir::Value args[2] = {addr_glob, loadFuncRes};
builder.create<LLVM::CallOp>(loc, curesult_handler_ty, args);
}

auto func = builder.create<LLVM::LoadOp>(loc, ptrty, funcptr);
Expand Down Expand Up @@ -689,14 +770,28 @@ CallInfo CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
kernRes =
builder.create<LLVM::CallOp>(loc, launch, args)->getResult(0);
}
{
if (debug) {
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrFunc)
builder.create<LLVM::AddressOfOp>(loc, launchKernelStr)
->getResult(0),
builder.create<LLVM::IntToPtrOp>(loc, ptrty, kernRes)
->getResult(0)};
builder.create<LLVM::CallOp>(loc, printfunc, printargs1);
}
if (debug) {
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, modOpStr)->getResult(0)};
builder.create<LLVM::CallOp>(loc, putfunc, printargs1);
}
if (cuResultHandlerPtr) {
auto addr_glob_int = builder.create<LLVM::ConstantOp>(
loc, i64, builder.getI64IntegerAttr(cuResultHandlerPtr));
auto addr_glob =
builder.create<LLVM::IntToPtrOp>(loc, ptrty, addr_glob_int)
->getResult(0);
mlir::Value args[2] = {addr_glob, kernRes};
builder.create<LLVM::CallOp>(loc, curesult_handler_ty, args);
}

op.erase();
ldop.erase();
Expand Down Expand Up @@ -776,6 +871,12 @@ struct LowerKernelPass : public LowerKernelPassBase<LowerKernelPass> {

auto *symbolOp = symbolTable.lookupNearestSymbolFrom(op, op.getFnAttr());
auto fn = cast<FunctionOpInterface>(symbolOp);
if (fn.getArguments().size() != op.getInputs().size()) {
op->emitError() << "Kernel_call had " << op.getInputs().size()
<< " whereas called kernel requires "
<< fn.getArguments().size() << "\n";
return;
}

Value vals[] = {op.getGridx(), op.getGridy(), op.getGridz(),
op.getBlockx(), op.getBlocky(), op.getBlockz(),
Expand All @@ -802,7 +903,8 @@ struct LowerKernelPass : public LowerKernelPassBase<LowerKernelPass> {
data[5], data[6], data[7], toolkitPath.getValue(), linkFilesArray,
indexBitWidth.getValue(), cubinChip.getValue(),
cubinFeatures.getValue(), cuLaunchKernelPtr, cuModuleLoadDataPtr,
cuModuleGetFunctionPtr, compileLaunch, run_init, op);
cuModuleGetFunctionPtr, compileLaunch, run_init, op, debug,
cuResultHandlerPtr);

std::string backendinfo((char *)&cdata, sizeof(CallInfo));

Expand Down
14 changes: 14 additions & 0 deletions src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,20 @@ def LowerKernelPass : Pass<"lower-kernel"> {
/*default=*/"false",
/*description=*/"Run initialization of cuda module"
>,
Option<
/*C++ variable name=*/"debug",
/*CLI argument=*/"debug",
/*type=*/"bool",
/*default=*/"false",
/*description=*/"Compile in debug prints"
>,
Option<
/*C++ variable name=*/"cuResultHandlerPtr",
/*CLI argument=*/"cuResultHandlerPtr",
/*type=*/"size_t",
/*default=*/"0",
/*description=*/"Function handler to call with result of curesult"
>,
];
}

Expand Down

0 comments on commit 84f97d3

Please sign in to comment.