Skip to content

Commit

Permalink
Fix printhook on xre ubuntu, add __fprintf_chk hook; clang-format
Browse files Browse the repository at this point in the history
  • Loading branch information
LamForest committed Dec 7, 2024
1 parent 2b8a52c commit 72dacb5
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 18 deletions.
3 changes: 2 additions & 1 deletion include/logger/logger_stl.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#pragma once

#include <vector>
#include <ostream>
#include <vector>

#include "logger.h"

namespace logger {
Expand Down
4 changes: 2 additions & 2 deletions lib/cuda_mock.cpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
#include "cuda_mock.h"

#include <Python.h>
#include <dlfcn.h>
#include <regex.h>
#include <string.h>
#include <Python.h>

#include <csetjmp>
#include <unordered_set>

#include "GlobalVarMgr.h"
#include "env_mgr.h"
#include "backtrace.h"
#include "cuda_op_tracer.h"
#include "env_mgr.h"
#include "hook.h"
#include "logger/logger.h"

Expand Down
17 changes: 15 additions & 2 deletions lib/cuda_op_tracer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ extern "C" CUresult cudaLaunchKernel_wrapper(const void* func, dim3 gridDim,

hook::HookInstaller getHookInstaller(const HookerInfo& info) {
static const char* symbolName = "cudaLaunchKernel";
static void* newFuncAddr = reinterpret_cast<void*>(&cudaLaunchKernel_wrapper);
static void* newFuncAddr =
reinterpret_cast<void*>(&cudaLaunchKernel_wrapper);
if (info.srcLib && info.targeLib && info.symbolName && info.newFuncPtr) {
kCudaRTLibName = info.srcLib;
kPytorchCudaLibName = info.targeLib;
Expand All @@ -81,14 +82,26 @@ hook::HookInstaller getHookInstaller(const HookerInfo& info) {
hook::HookInstaller installer;
installer.isTargetLib = [](const char* libName) -> bool {
CudaInfoCollection::instance().collectRtLib(libName);
// TODO: why does this log line not take effect?
MLOG(HOOK, INFO) << "[installer.isTargetLib] libName:"
<< kPytorchCudaLibName
<< " targetlibName: " << libName;

/*
Use substring (fuzzy) matching rather than exact matching, because:
1. libName may contain a version number, e.g. libtorch_cuda.so.1
2. libName may contain the full path, e.g.
/usr/local/cuda-10.2/lib64/libcudart.so
*/
if (std::string(libName).find(kPytorchCudaLibName) !=
std::string::npos) {
return true;
}
return false;
};
installer.isTargetSymbol = [=](const char* symbol) -> bool {
// LOG(INFO) << "visit symbol:" << symbol;
MLOG(HOOK, INFO) << "[installer.isTargetSymbol] symbol:" << symbol
<< " targetSymbolName:" << symbolName;
if (std::string(symbol) == symbolName) {
return true;
}
Expand Down
2 changes: 1 addition & 1 deletion lib/cuda_op_tracer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct HookerInfo {
const char* srcLib = nullptr;
// the dynamic lib which the target symbol will be replace
const char* targeLib = nullptr;
// the symbol which will be replaced
// the symbol which will be replaced
const char* symbolName = nullptr;
void* newFuncPtr = nullptr;
};
Expand Down
32 changes: 23 additions & 9 deletions lib/hook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,14 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) {
CHECK(installer.newFuncPtr, "new_func_ptr can't be empty!");
if (!installer.isTargetLib(pltTable->lib_name.c_str()) ||
!isTargetLibFromEnv(pltTable->lib_name.c_str())) {
MLOG(HOOK, INFO) << "[install_hooker SKIP]"
<< pltTable->lib_name.c_str()
<< ", not target lib, skip";
return -1;
}
MLOG(HOOK, INFO)
<< "[install_hooker INSTALL] ====start install hook for lib "
<< pltTable->lib_name.c_str() << "======";
size_t index = 0;
while (index < pltTable->rela_plt_cnt) {
auto plt = pltTable->rela_plt + index++;
Expand All @@ -262,13 +268,16 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) {

size_t idx = ELF64_R_SYM(plt->r_info);
idx = pltTable->dynsym[idx].st_name;
MLOG(HOOK, INFO) << pltTable->symbol_table +
void* addr =
reinterpret_cast<void*>(pltTable->base_header_addr + plt->r_offset);

MLOG(HOOK, INFO) << "GOT[" << index << "]=" << addr << " symbol="
<< pltTable->symbol_table +
idx; // got symbol name from STRTAB
if (!installer.isTargetSymbol(pltTable->symbol_table + idx)) {
continue;
}
void* addr =
reinterpret_cast<void*>(pltTable->base_header_addr + plt->r_offset);

int prot = get_memory_permission(addr);
if (prot == 0) {
return -1;
Expand All @@ -277,6 +286,9 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) {
// writable page
if (mprotect(ALIGN_ADDR(addr), page_size, PROT_READ | PROT_WRITE) !=
0) {
MLOG(HOOK, ERROR) << "GOT[" << index
<< "] is readonly and cannot be converted to "
"writable. abort hooking.";
return -1;
}
}
Expand All @@ -289,19 +301,21 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) {
originalInfo.oldFuncPtr =
reinterpret_cast<void*>(*reinterpret_cast<size_t*>(addr));
auto new_func_ptr = installer.newFuncPtr(originalInfo);

MLOG(HOOK, INFO) << "start replace GOT[" << index << "]"
<< pltTable->symbol_table << ", *" << addr << "="
<< new_func_ptr << ", original GOT[" << index
<< "]=" << *reinterpret_cast<void**>(addr);

*reinterpret_cast<size_t*>(addr) =
reinterpret_cast<size_t>(new_func_ptr);
MLOG(HOOK, INFO) << "store " << new_func_ptr << " to " << addr
<< " original value:"
<< *reinterpret_cast<void**>(addr);
// we will not recover the address protect
// TODO: move this to uninstall function
// if (!(prot & PROT_WRITE)) {
// mprotect(ALIGN_ADDR(addr), page_size, prot);
// }
MLOG(HOOK, INFO) << "replace:" << pltTable->symbol_table + idx
<< " with " << pltTable->symbol_table + idx
<< " success";
MLOG(HOOK, INFO) << "replace " << pltTable->symbol_table + idx
<< " success. ";
if (installer.onSuccess) {
installer.onSuccess();
}
Expand Down
9 changes: 9 additions & 0 deletions lib/hooks/print_hook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@
XpuRuntimePrintfHook::instance()->save_to_internel_buffer(buf); \
MLOG(PROFILE, WARN) << buf;

/*
 * Stand-in installed for the `__printf_chk` symbol (see
 * XpuRuntimePrintfHook::newFuncPtr). The __chk family takes one extra
 * `flag` parameter compared with the plain printf family.
 *
 * All formatting/capture work is delegated to __internal_printf();
 * presumably this is the helper macro whose tail is visible just above,
 * which stores `buf` in XpuRuntimePrintfHook's internal buffer and logs it
 * via MLOG(PROFILE, WARN) -- TODO confirm against the full macro.
 *
 * NOTE(review): always returns 0 rather than a character count, so hooked
 * callers that inspect printf's return value will observe 0.
 */
static int builtin_printf_chk(int flag, const char* fmt, ...) {
// NOTE(review): the macro likely refers to `fmt` by name -- keep the
// parameter name in sync with the macro definition.
__internal_printf();
return 0;
}
/*
 * Stand-in installed for the `__fprintf_chk` symbol (see
 * XpuRuntimePrintfHook::newFuncPtr). Compared with builtin_printf_chk it
 * takes an additional leading pointer argument; presumably that is the
 * FILE* stream of the original call -- TODO confirm. Nothing visible in
 * this function uses `stdcout` or `flag` directly.
 *
 * NOTE(review): always returns 0 rather than a character count.
 */
static int builtin_fprintf_chk(void* stdcout, int flag, const char* fmt, ...) {
// Delegates to the shared __internal_printf() helper; presumably the macro
// defined above that captures the formatted text -- TODO confirm.
__internal_printf();
return 0;
}

static int builtin_printf(const char* fmt, ...) {
__internal_printf();
Expand Down Expand Up @@ -56,13 +61,17 @@ bool XpuRuntimePrintfHook::targetSym(const char* name) {
// Select the replacement function for the symbol currently being hooked.
//
// The name reported by curSymName() is compared, in a fixed order, against
// the supported printf-family symbols, and the address of the matching
// builtin_* stand-in is returned.
//
// `info` (original-symbol information handed in by the hook framework) is
// not consulted here.
//
// Returns the replacement's address, or nullptr (after logging an error)
// when the current symbol is none of the supported names.
//
// NOTE(review): "vfprintf" is routed to builtin_fprintf. vfprintf receives a
// va_list rather than variadic arguments, so confirm that builtin_fprintf
// (defined outside this view) copes with that calling convention.
// NOTE(review): callers should presumably check for the nullptr result
// before storing it anywhere -- verify against the hook installer.
void* XpuRuntimePrintfHook::newFuncPtr(const hook::OriginalInfo& info) {
    // Fortified (_chk) variants first.
    if (adt::StringRef("__printf_chk") == curSymName()) {
        return reinterpret_cast<void*>(&builtin_printf_chk);
    }
    if (adt::StringRef("__fprintf_chk") == curSymName()) {
        return reinterpret_cast<void*>(&builtin_fprintf_chk);
    }

    // Plain printf.
    if (adt::StringRef("printf") == curSymName()) {
        return reinterpret_cast<void*>(&builtin_printf);
    }

    // Every stream-taking variant shares a single stand-in.
    const bool isStreamPrintf = adt::StringRef("fprintf") == curSymName() ||
                                adt::StringRef("__fprintf") == curSymName() ||
                                adt::StringRef("vfprintf") == curSymName();
    if (isStreamPrintf) {
        return reinterpret_cast<void*>(&builtin_fprintf);
    }

    // Unknown symbol: report it and hand back a null replacement.
    MLOG(HOOK, ERROR) << "cannot find function pointer for " << curSymName()
                      << ", return NULL instead";
    return nullptr;
}

Expand Down
5 changes: 3 additions & 2 deletions lib/statistic.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

#include <iosfwd>
#include <set>
#include <unordered_map>
#include <string>
#include <unordered_map>

namespace hook {

Expand Down Expand Up @@ -45,7 +45,8 @@ class MemoryStatisticCollection {
return lib == other.lib && devId == other.devId &&
kind == other.kind;
}
friend std::ostream& operator<<(std::ostream& os, const PtrIdentity& id);
friend std::ostream& operator<<(std::ostream& os,
const PtrIdentity& id);
};
struct PtrIdentityHash {
size_t operator()(const PtrIdentity& id) const {
Expand Down
7 changes: 6 additions & 1 deletion src/cuda_mock/cuda_mock_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,12 @@ def dump_to_cache(self):

gProfileDataCollection = ProfileDataCollection("gpu" if is_nvidia_gpu else "xpu")
gDefaultTargetLib = ["libxpucuda.so", "libcuda.so"]
gDefaultTargetSymbols = ["__printf_chk", "printf", "fprintf", "__fprintf", "vfprintf",]
gDefaultTargetSymbols = [
# symbols without the _chk suffix
"printf", "fprintf",
# symbols with the _chk suffix
"__printf_chk", "__fprintf_chk",
]
class __XpuRuntimeProfiler:
def __init__(self, target_libs = gDefaultTargetLib, target_symbols = gDefaultTargetSymbols):
print_hook_initialize(target_libs=target_libs, target_symbols=target_symbols)
Expand Down

0 comments on commit 72dacb5

Please sign in to comment.