Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More prints of lowered kernels #201

Merged
Merged 12 commits on Dec 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 180 additions & 14 deletions src/enzyme_ad/jax/Passes/LowerKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,18 +247,21 @@ void *CompileHostModule(std::string &key, mlir::ModuleOp modOp, bool run_init) {
}

auto ptr = (void *)EntrySym->getValue();
llvm::errs() << " entry ptr: " << ptr << "\n";

kernels[key] = ptr;

auto NVSym = JIT->lookup(LibA.get(), "nv_func_init");
if (!NVSym) {
llvm::errs() << " lookupError " << NVSym.takeError() << "\n";
return nullptr;
}
if (run_init) {
auto NVSym = JIT->lookup(LibA.get(), "nv_func_init");
if (!NVSym) {
llvm::errs() << " lookupError " << NVSym.takeError() << "\n";
return nullptr;
}

auto nvptr = (void *)NVSym->getValue();
auto nvptr = (void *)NVSym->getValue();

((void (*)())(nvptr))();
((void (*)())(nvptr))();
}

return ptr;
}
Expand All @@ -272,6 +275,8 @@ extern "C" void EnzymeGPUCustomCall(void *__restrict__ stream,
XlaCustomCallStatus *__restrict__ status) {
auto ptr = (void (*)(void *, void **))(opaqueptr[0]);
printf("ptr=%p\n", ptr);
printf("stream=%p\n", stream);
printf("bufferptr=%p\n", buffers);
printf("buffer[0]=%p\n", buffers[0]);
// auto ptr = (void(*)(void*, void**, size_t, size_t, size_t, size_t, size_t,
// size_t)) (opaqueptr[0][0]);
Expand Down Expand Up @@ -422,6 +427,19 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,

builder.setInsertionPointToEnd(&submod.getBodyRegion().front());

auto printfunc = builder.create<func::FuncOp>(loc, "printf", calleeType);
printfunc.setVisibility(SymbolTable::Visibility::Private);

LLVM::GlobalOp printStrStream;
{
std::string value = "found pointer [stream] = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrStream = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strstream",
builder.getStringAttr(value + '\0'));
}

auto func = builder.create<func::FuncOp>(loc, "entry", calleeType);

auto &entryBlock = *func.addEntryBlock();
Expand Down Expand Up @@ -453,10 +471,19 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
arguments.push_back(ld);
}
auto dynshmem = builder.create<arith::ConstantIntOp>(loc, shmem, i32);

{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrStream)->getResult(0),
stream};
builder.create<func::CallOp>(loc, printfunc, printargs1);
}

stream = builder
.create<UnrealizedConversionCastOp>(
loc, gpu::AsyncTokenType::get(stream.getContext()), stream)
->getResult(0);

builder.create<gpu::LaunchFuncOp>(loc, gpufunc, gridSize, blockSize, dynshmem,
arguments, stream.getType(),
ValueRange(stream));
Expand Down Expand Up @@ -498,6 +525,10 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
pm.run(submod);

OpBuilder builder(submod);

SymbolTable st2(submod);
auto print2 = st2.lookup<LLVM::LLVMFuncOp>("printf");

builder.setInsertionPointToStart(&submod.getBodyRegion().front());
auto ptrty = LLVM::LLVMPointerType::get(builder.getContext());
auto i64 = builder.getIntegerType(64);
Expand All @@ -517,7 +548,7 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
mlir::Type cutys[] = {ptrty, idx, idx, idx, idx, idx,
idx, i32, ptrty, ptrty, ptrty};

auto launch_ty = LLVM::LLVMFunctionType::get(voidty, cutys);
auto launch_ty = LLVM::LLVMFunctionType::get(i32, cutys);
LLVM::LLVMFuncOp launch =
builder.create<LLVM::LLVMFuncOp>(loc, "cuLaunchKernel", launch_ty);

Expand All @@ -536,6 +567,66 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp printStrSet;
{
std::string value = "found pointer [set] = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrSet = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strset",
builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp printStrGlob;
{
std::string value = "found pointer [glob] = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrGlob = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strglob",
builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp printStrCu;
{
std::string value = "found pointer [cu] = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrCu = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strcu",
builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp printStrMod;
{
std::string value = "found pointer mod = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrMod = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strmod",
builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp printStrLdFunc;
{
std::string value = "found pointer ld func = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrLdFunc = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal,
"strldfunc", builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp printStrLaunch;
{
std::string value = "found pointer launch = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrLaunch = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal,
"strlaunch", builder.getStringAttr(value + '\0'));
}

builder.setInsertionPointToStart(&submod.getBodyRegion().front());

LLVM::LLVMFuncOp initfn = builder.create<LLVM::LLVMFuncOp>(
Expand All @@ -547,6 +638,16 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
builder.create<LLVM::GlobalCtorsOp>(loc, builder.getArrayAttr(funcs),
builder.getArrayAttr(idxs));

LLVM::GlobalOp printStrFunc;
{
std::string value = "found pointer func = %p\n";
auto type = LLVM::LLVMArrayType::get(
mlir::IntegerType::get(builder.getContext(), 8), value.size() + 1);
printStrFunc = builder.create<LLVM::GlobalOp>(
loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, "strfunc",
builder.getStringAttr(value + '\0'));
}

LLVM::GlobalOp binary;
submod.walk([&](gpu::BinaryOp op) {
gpu::ObjectAttr object = getSelectedObject(op);
Expand Down Expand Up @@ -583,16 +684,28 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
auto addr_modbin = builder.create<LLVM::AddressOfOp>(loc, binary);
SmallVector<mlir::Value> modargs = {modptr->getResult(0),
addr_modbin->getResult(0)};

mlir::Value loadRes;
if (cuModuleLoadDataPtr) {
auto addr_glob_int = builder.create<LLVM::ConstantOp>(
loc, i64, builder.getI64IntegerAttr(cuModuleLoadDataPtr));
auto addr_glob =
builder.create<LLVM::IntToPtrOp>(loc, ptrty, addr_glob_int);
modargs.insert(modargs.begin(), addr_glob);
builder.create<LLVM::CallOp>(loc, modload_ty, modargs);
loadRes = builder.create<LLVM::CallOp>(loc, modload_ty, modargs)
->getResult(0);
} else {
builder.create<LLVM::CallOp>(loc, modload, modargs);
loadRes =
builder.create<LLVM::CallOp>(loc, modload, modargs)->getResult(0);
}
loadRes = builder.create<LLVM::IntToPtrOp>(loc, ptrty, loadRes);
{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrMod)->getResult(0),
loadRes};
builder.create<LLVM::CallOp>(loc, print2, printargs1);
}

auto mod = builder.create<LLVM::LoadOp>(loc, ptrty, modptr);

auto addr_kernstr =
Expand All @@ -601,19 +714,45 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
SmallVector<mlir::Value> funcargs = {funcptr->getResult(0),
mod->getResult(0),
addr_kernstr->getResult(0)};
mlir::Value getRes;
if (cuModuleGetFunctionPtr) {
auto addr_glob_int = builder.create<LLVM::ConstantOp>(
loc, i64, builder.getI64IntegerAttr(cuModuleGetFunctionPtr));
auto addr_glob =
builder.create<LLVM::IntToPtrOp>(loc, ptrty, addr_glob_int);
funcargs.insert(funcargs.begin(), addr_glob);
builder.create<LLVM::CallOp>(loc, funcload_ty, funcargs);
getRes = builder.create<LLVM::CallOp>(loc, funcload_ty, funcargs)
->getResult(0);
} else {
builder.create<LLVM::CallOp>(loc, funcload, funcargs);
getRes = builder.create<LLVM::CallOp>(loc, funcload, funcargs)
->getResult(0);
}

getRes = builder.create<LLVM::IntToPtrOp>(loc, ptrty, getRes);
{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrFunc)
->getResult(0),
getRes};
builder.create<LLVM::CallOp>(loc, print2, printargs1);
}

auto func = builder.create<LLVM::LoadOp>(loc, ptrty, funcptr);
{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrLdFunc)
->getResult(0),
func};
builder.create<LLVM::CallOp>(loc, print2, printargs1);
}

auto addr_glob = builder.create<LLVM::AddressOfOp>(loc, glob);
{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrSet)->getResult(0),
addr_glob};
builder.create<LLVM::CallOp>(loc, print2, printargs1);
}
builder.create<LLVM::StoreOp>(loc, func, addr_glob);
builder.create<LLVM::ReturnOp>(loc, ValueRange());
}
Expand All @@ -639,15 +778,42 @@ void *CompileKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
params,
builder.create<LLVM::ZeroOp>(loc, ptrty)};

{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrGlob)
->getResult(0),
addr_glob};
builder.create<LLVM::CallOp>(loc, print2, printargs1);
}

{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrCu)->getResult(0),
cufunc};
builder.create<LLVM::CallOp>(loc, print2, printargs1);
}

mlir::Value callRes;
if (cuLaunchKernelPtr) {
auto addr_glob_int = builder.create<LLVM::ConstantOp>(
loc, i64, builder.getI64IntegerAttr(cuLaunchKernelPtr));
auto addr_glob =
builder.create<LLVM::IntToPtrOp>(loc, ptrty, addr_glob_int);
args.insert(args.begin(), addr_glob);
builder.create<LLVM::CallOp>(loc, launch_ty, args);
callRes =
builder.create<LLVM::CallOp>(loc, launch_ty, args)->getResult(0);
} else {
builder.create<LLVM::CallOp>(loc, launch, args);
callRes =
builder.create<LLVM::CallOp>(loc, launch, args)->getResult(0);
}

callRes = builder.create<LLVM::IntToPtrOp>(loc, ptrty, callRes);
{
Value printargs1[] = {
builder.create<LLVM::AddressOfOp>(loc, printStrLaunch)
->getResult(0),
callRes};
builder.create<LLVM::CallOp>(loc, print2, printargs1);
}

op.erase();
Expand Down
Loading