diff --git a/lib/xpu_mock.cpp b/lib/xpu_mock.cpp index be11830..9fd5d38 100644 --- a/lib/xpu_mock.cpp +++ b/lib/xpu_mock.cpp @@ -45,8 +45,6 @@ DEF_FUNCTION_INT(xpu_malloc, void** pdevptr, uint64_t size, int kind) { int devId = 0; CHECK(origin_xpu_current_device, "xpu_current_device not binded"); - CHECK(origin_xpu_malloc, "xpu_malloc not binded"); - r = origin_xpu_current_device(&devId); if (r != 0) { return r; @@ -71,8 +69,6 @@ DEF_FUNCTION_INT(xpu_free, void* devptr) { int devId = 0; CHECK(origin_xpu_current_device, "xpu_current_device not binded"); - CHECK(origin_xpu_free, "xpu_free not binded"); - r = origin_xpu_current_device(&devId); if (r != 0) { return r; @@ -112,23 +108,59 @@ DEF_FUNCTION_INT(xpu_stream_destroy, void* stream) { //-------------------------- cuda api --------------------------// -DEF_FUNCTION_INT(cudaMalloc, void** devPtr, size_t size) { - return origin_cudaMalloc(devPtr, size); +DEF_FUNCTION_INT(cudaSetDevice, int device) { + return origin_cudaSetDevice(device); } -DEF_FUNCTION_INT(cudaFree, void* devPtr) { return origin_cudaFree(devPtr); } +DEF_FUNCTION_INT(cudaGetDevice, int* device) { + return origin_cudaGetDevice(device); +} -DEF_FUNCTION_INT(cudaMemcpy, void* dst, const void* src, size_t count, - int kind) { - return origin_cudaMemcpy(dst, src, count, kind); +DEF_FUNCTION_INT(cudaMalloc, void** devPtr, size_t size) { + int r = 0; + int devId = 0; + + CHECK(origin_cudaGetDevice, "cudaGetDevice not binded"); + r = origin_cudaGetDevice(&devId); + if (r != 0) { + return r; + } + + r = origin_cudaMalloc(devPtr, size); + if (r != 0) { + LOG(WARN) << "xpu cudaMalloc device memory failed!\n" + << hook::MemoryStatisticCollection::instance(); + return r; + } + + hook::MemoryStatisticCollection::instance().record_alloc( + hook::HookRuntimeContext::instance().curLibName(), devId, *devPtr, size, + /*kind=GLOBAL_MEM*/ 0); + + return r; } -DEF_FUNCTION_INT(cudaSetDevice, int device) { - return origin_cudaSetDevice(device); +DEF_FUNCTION_INT(cudaFree, void* devPtr) { + int r = 0; + int devId = 0; + + CHECK(origin_cudaGetDevice, "cudaGetDevice not binded"); + r = origin_cudaGetDevice(&devId); + if (r != 0) { + return r; + } + + r = origin_cudaFree(devPtr); + + hook::MemoryStatisticCollection::instance().record_free( + hook::HookRuntimeContext::instance().curLibName(), devId, devPtr); + + return r; } -DEF_FUNCTION_INT(cudaGetDevice, int* device) { - return origin_cudaGetDevice(device); +DEF_FUNCTION_INT(cudaMemcpy, void* dst, const void* src, size_t count, + int kind) { + return origin_cudaMemcpy(dst, src, count, kind); } #define BUILD_FEATURE(name) \