diff --git a/include/logger/logger_stl.h b/include/logger/logger_stl.h index 1e591c9..7e92b2e 100644 --- a/include/logger/logger_stl.h +++ b/include/logger/logger_stl.h @@ -1,7 +1,8 @@ #pragma once -#include #include +#include + #include "logger.h" namespace logger { diff --git a/lib/cuda_mock.cpp b/lib/cuda_mock.cpp index 95f3140..298a0c6 100644 --- a/lib/cuda_mock.cpp +++ b/lib/cuda_mock.cpp @@ -1,17 +1,17 @@ #include "cuda_mock.h" +#include #include #include #include -#include #include #include #include "GlobalVarMgr.h" -#include "env_mgr.h" #include "backtrace.h" #include "cuda_op_tracer.h" +#include "env_mgr.h" #include "hook.h" #include "logger/logger.h" diff --git a/lib/cuda_op_tracer.cpp b/lib/cuda_op_tracer.cpp index e360c9e..8b169d6 100644 --- a/lib/cuda_op_tracer.cpp +++ b/lib/cuda_op_tracer.cpp @@ -71,7 +71,8 @@ extern "C" CUresult cudaLaunchKernel_wrapper(const void* func, dim3 gridDim, hook::HookInstaller getHookInstaller(const HookerInfo& info) { static const char* symbolName = "cudaLaunchKernel"; - static void* newFuncAddr = reinterpret_cast(&cudaLaunchKernel_wrapper); + static void* newFuncAddr = + reinterpret_cast(&cudaLaunchKernel_wrapper); if (info.srcLib && info.targeLib && info.symbolName && info.newFuncPtr) { kCudaRTLibName = info.srcLib; kPytorchCudaLibName = info.targeLib; @@ -81,6 +82,17 @@ hook::HookInstaller getHookInstaller(const HookerInfo& info) { hook::HookInstaller installer; installer.isTargetLib = [](const char* libName) -> bool { CudaInfoCollection::instance().collectRtLib(libName); + // TODO 为啥这行打印不生效? + MLOG(HOOK, INFO) << "[installer.isTargetLib] libName:" + << kPytorchCudaLibName + << " targetlibName: " << libName; + + /* + 模糊匹配而不是精确匹配,因为: + 1. libname 可能包含版本号,例如 libtorch_cuda.so.1 + 2. libname包含了完整的路径,例如 + /usr/local/cuda-10.2/lib64/libcudart.so + */ if (std::string(libName).find(kPytorchCudaLibName) != std::string::npos) { return true; @@ -88,7 +100,8 @@ hook::HookInstaller getHookInstaller(const HookerInfo& info) { return false; }; installer.isTargetSymbol = [=](const char* symbol) -> bool { - // LOG(INFO) << "visit symbol:" << symbol; + MLOG(HOOK, INFO) << "[installer.isTargetSymbol] symbol:" << symbol + << " targetSymbolName:" << symbolName; if (std::string(symbol) == symbolName) { return true; } diff --git a/lib/cuda_op_tracer.h b/lib/cuda_op_tracer.h index f2fd1b9..3df6409 100644 --- a/lib/cuda_op_tracer.h +++ b/lib/cuda_op_tracer.h @@ -28,7 +28,7 @@ struct HookerInfo { const char* srcLib = nullptr; // the dynamic lib which the target symbol will be replace const char* targeLib = nullptr; - // the symbol which will be replace + // the symbol which will be replace const char* symbolName = nullptr; void* newFuncPtr = nullptr; }; diff --git a/lib/hook.cpp b/lib/hook.cpp index 7327ed3..4dfd96b 100644 --- a/lib/hook.cpp +++ b/lib/hook.cpp @@ -251,8 +251,14 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) { CHECK(installer.newFuncPtr, "new_func_ptr can't be empty!"); if (!installer.isTargetLib(pltTable->lib_name.c_str()) || !isTargetLibFromEnv(pltTable->lib_name.c_str())) { + MLOG(HOOK, INFO) << "[install_hooker SKIP]" + << pltTable->lib_name.c_str() + << ", not target lib, skip"; return -1; } + MLOG(HOOK, INFO) + << "[install_hooker INSTALL] ====start install hook for lib " + << pltTable->lib_name.c_str() << "======"; size_t index = 0; while (index < pltTable->rela_plt_cnt) { auto plt = pltTable->rela_plt + index++; @@ -262,13 +268,16 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) { size_t idx = ELF64_R_SYM(plt->r_info); idx = pltTable->dynsym[idx].st_name; - MLOG(HOOK, INFO) << pltTable->symbol_table + + void* addr = + reinterpret_cast(pltTable->base_header_addr + plt->r_offset); + + MLOG(HOOK, INFO) << "GOT[" << index << "]=" << addr << " symbol=" + << pltTable->symbol_table + idx; // got symbol name from STRTAB if (!installer.isTargetSymbol(pltTable->symbol_table + idx)) { continue; } - void* addr = - reinterpret_cast(pltTable->base_header_addr + plt->r_offset); + int prot = get_memory_permission(addr); if (prot == 0) { return -1; @@ -277,6 +286,9 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) { // writable page if (mprotect(ALIGN_ADDR(addr), page_size, PROT_READ | PROT_WRITE) != 0) { + MLOG(HOOK, ERROR) << "GOT[" << index + << "] is readonly and cannot be converted to " + "writable. abort hooking."; return -1; } } @@ -289,19 +301,21 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) { originalInfo.oldFuncPtr = reinterpret_cast(*reinterpret_cast(addr)); auto new_func_ptr = installer.newFuncPtr(originalInfo); + + MLOG(HOOK, INFO) << "start replace GOT[" << index << "]" + << pltTable->symbol_table << ", *" << addr << "=" + << new_func_ptr << ", original GOT[" << index + << "]=" << *reinterpret_cast(addr); + *reinterpret_cast(addr) = reinterpret_cast(new_func_ptr); - MLOG(HOOK, INFO) << "store " << new_func_ptr << " to " << addr - << " original value:" - << *reinterpret_cast(addr); // we will not recover the address protect // TODO: move this to uninstall function // if (!(prot & PROT_WRITE)) { // mprotect(ALIGN_ADDR(addr), page_size, prot); // } - MLOG(HOOK, INFO) << "replace:" << pltTable->symbol_table + idx - << " with " << pltTable->symbol_table + idx - << " success"; + MLOG(HOOK, INFO) << "replace " << pltTable->symbol_table + idx + << " success. "; if (installer.onSuccess) { installer.onSuccess(); } diff --git a/lib/hooks/print_hook.cpp b/lib/hooks/print_hook.cpp index f904943..dddf8a1 100644 --- a/lib/hooks/print_hook.cpp +++ b/lib/hooks/print_hook.cpp @@ -18,10 +18,15 @@ XpuRuntimePrintfHook::instance()->save_to_internel_buffer(buf); \ MLOG(PROFILE, WARN) << buf; +/* __chk 系列多了一个flag参数*/ static int builtin_printf_chk(int flag, const char* fmt, ...) { __internal_printf(); return 0; } +static int builtin_fprintf_chk(void* stdcout, int flag, const char* fmt, ...) { + __internal_printf(); + return 0; +} static int builtin_printf(const char* fmt, ...) { __internal_printf(); @@ -56,6 +61,8 @@ bool XpuRuntimePrintfHook::targetSym(const char* name) { void* XpuRuntimePrintfHook::newFuncPtr(const hook::OriginalInfo& info) { if (adt::StringRef("__printf_chk") == curSymName()) { return reinterpret_cast(&builtin_printf_chk); + } else if (adt::StringRef("__fprintf_chk") == curSymName()) { + return reinterpret_cast(&builtin_fprintf_chk); } else if (adt::StringRef("printf") == curSymName()) { return reinterpret_cast(&builtin_printf); } else if (adt::StringRef("fprintf") == curSymName() || @@ -63,6 +70,8 @@ void* XpuRuntimePrintfHook::newFuncPtr(const hook::OriginalInfo& info) { adt::StringRef("vfprintf") == curSymName()) { return reinterpret_cast(&builtin_fprintf); } + MLOG(HOOK, ERROR) << "cannot find function pointer for " << curSymName() + << ", return NULL instead"; return nullptr; } diff --git a/lib/statistic.h b/lib/statistic.h index e2de944..3db54af 100644 --- a/lib/statistic.h +++ b/lib/statistic.h @@ -4,8 +4,8 @@ #include #include -#include #include +#include namespace hook { @@ -45,7 +45,8 @@ class MemoryStatisticCollection { return lib == other.lib && devId == other.devId && kind == other.kind; } - friend std::ostream& operator<<(std::ostream& os, const PtrIdentity& id); + friend std::ostream& operator<<(std::ostream& os, + const PtrIdentity& id); }; struct PtrIdentityHash { size_t operator()(const PtrIdentity& id) const { diff --git a/src/cuda_mock/cuda_mock_impl.py b/src/cuda_mock/cuda_mock_impl.py index a4541e0..235bf8c 100644 --- a/src/cuda_mock/cuda_mock_impl.py +++ b/src/cuda_mock/cuda_mock_impl.py @@ -133,7 +133,12 @@ def dump_to_cache(self): gProfileDataCollection = ProfileDataCollection("gpu" if is_nvidia_gpu else "xpu") gDefaultTargetLib = ["libxpucuda.so", "libcuda.so"] -gDefaultTargetSymbols = ["__printf_chk", "printf", "fprintf", "__fprintf", "vfprintf",] +gDefaultTargetSymbols = [ + # 不带chk后缀的符号 + "printf", "fprintf", + #带chk后缀的符号 + "__printf_chk", "__fprintf_chk", + ] class __XpuRuntimeProfiler: def __init__(self, target_libs = gDefaultTargetLib, target_symbols = gDefaultTargetSymbols): print_hook_initialize(target_libs=target_libs, target_symbols=target_symbols)