Skip to content

Commit

Permalink
[ROCm][Windows] Disable roctracer-related code (pytorch#143329)
Browse files Browse the repository at this point in the history
Currently, roctracer is not available on Windows. This PR disables every use of it in Windows builds and adds dummy functions for Windows that keep compatibility with existing code but warn the user that roctracer is unavailable on that platform.

Pull Request resolved: pytorch#143329
Approved by: https://github.com/sraikund16
  • Loading branch information
m-gallus authored and pytorchmergebot committed Jan 3, 2025
1 parent 891a86d commit 37e9da0
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 4 deletions.
3 changes: 3 additions & 0 deletions cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,9 @@ if(USE_ROCM)
if(HIP_NEW_TYPE_ENUMS)
list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS)
endif()
if(WIN32)
add_definitions(-DROCM_ON_WINDOWS)
endif()
add_definitions(-DROCM_VERSION=${ROCM_VERSION_DEV_INT})
add_definitions(-DTORCH_HIP_VERSION=${TORCH_HIP_VERSION})
message("TORCH_HIP_VERSION=${TORCH_HIP_VERSION} is added as a compiler defines")
Expand Down
4 changes: 3 additions & 1 deletion torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ if(USE_ROCM)
USE_ROCM
__HIP_PLATFORM_AMD__
)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${ROCM_ROCTX_LIB})
if(NOT WIN32)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${ROCM_ROCTX_LIB})
endif()
endif()

if(USE_XPU)
Expand Down
62 changes: 60 additions & 2 deletions torch/csrc/cuda/shared/nvtx.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
#ifdef _WIN32
#include <wchar.h> // _wgetenv for nvtx
#endif

#ifndef ROCM_ON_WINDOWS
#ifdef TORCH_CUDA_USE_NVTX3
#include <nvtx3/nvtx3.hpp>
#else
#else // TORCH_CUDA_USE_NVTX3
#include <nvToolsExt.h>
#endif
#endif // TORCH_CUDA_USE_NVTX3
#else // ROCM_ON_WINDOWS
#include <c10/util/Exception.h>
#endif // ROCM_ON_WINDOWS
#include <cuda_runtime.h>
#include <torch/csrc/utils/pybind.h>

namespace torch::cuda::shared {

#ifndef ROCM_ON_WINDOWS
struct RangeHandle {
nvtxRangeId_t id;
const char* msg;
Expand Down Expand Up @@ -58,4 +64,56 @@ void initNvtxBindings(PyObject* module) {
nvtx.def("deviceRangeEnd", device_nvtxRangeEnd);
}

#else // ROCM_ON_WINDOWS

// Reports, through TORCH_WARN_ONCE, that roctracer is not available on
// Windows. Every stub in this ROCM_ON_WINDOWS branch calls this helper
// instead of performing any tracing work.
static void printUnavailableWarning() {
TORCH_WARN_ONCE("Warning: roctracer isn't available on Windows");
}

// ROCM_ON_WINDOWS placeholder bound as `_nvtx.rangePushA`.
// The string argument is ignored; the function calls
// printUnavailableWarning() and always returns 0.
static int rangePushA(const std::string&) {
printUnavailableWarning();
return 0;
}

// ROCM_ON_WINDOWS placeholder bound as `_nvtx.rangePop`.
// Calls printUnavailableWarning() and always returns 0.
static int rangePop() {
printUnavailableWarning();
return 0;
}

// ROCM_ON_WINDOWS placeholder bound as `_nvtx.rangeStartA`.
// The string argument is ignored; the function calls
// printUnavailableWarning() and always returns 0.
static int rangeStartA(const std::string&) {
printUnavailableWarning();
return 0;
}

// ROCM_ON_WINDOWS placeholder bound as `_nvtx.rangeEnd`.
// The integer argument is ignored; the function only calls
// printUnavailableWarning().
static void rangeEnd(int) {
printUnavailableWarning();
}

// ROCM_ON_WINDOWS placeholder bound as `_nvtx.markA`.
// The string argument is ignored; the function only calls
// printUnavailableWarning().
static void markA(const std::string&) {
printUnavailableWarning();
}

// ROCM_ON_WINDOWS placeholder bound as `_nvtx.deviceRangeStart`.
// Both arguments are ignored; the function calls printUnavailableWarning()
// and hands Python `None` back to the caller.
static py::object deviceRangeStart(const std::string&, std::intptr_t) {
printUnavailableWarning();
return py::none(); // Return an appropriate default object
}

// ROCM_ON_WINDOWS placeholder bound as `_nvtx.deviceRangeEnd`.
// Both arguments are ignored; the function only calls
// printUnavailableWarning().
static void deviceRangeEnd(py::object, std::intptr_t) {
printUnavailableWarning();
}

// ROCM_ON_WINDOWS variant of the NVTX binding setup.
//
// Creates the `_nvtx` submodule (docstring "unavailable") on the Python module
// passed in and registers seven functions on it, each bound to one of the
// warning-only stubs defined in this branch.
//
// @param module  Python module object that receives the `_nvtx` submodule.
void initNvtxBindings(PyObject* module) {
  auto parent = py::handle(module).cast<py::module>();

  // def() returns the module itself, so the registrations can be chained;
  // the names and their order are unchanged.
  parent.def_submodule("_nvtx", "unavailable")
      .def("rangePushA", rangePushA)
      .def("rangePop", rangePop)
      .def("rangeStartA", rangeStartA)
      .def("rangeEnd", rangeEnd)
      .def("markA", markA)
      .def("deviceRangeStart", deviceRangeStart)
      .def("deviceRangeEnd", deviceRangeEnd);
}
#endif // ROCM_ON_WINDOWS

} // namespace torch::cuda::shared
20 changes: 19 additions & 1 deletion torch/csrc/profiler/stubs/cuda.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#include <sstream>

#ifndef ROCM_ON_WINDOWS
#ifdef TORCH_CUDA_USE_NVTX3
#include <nvtx3/nvtx3.hpp>
#else
#include <nvToolsExt.h>
#endif

#else // ROCM_ON_WINDOWS
#include <c10/util/Exception.h>
#endif // ROCM_ON_WINDOWS
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/ApproximateClock.h>
#include <c10/util/irange.h>
Expand Down Expand Up @@ -71,6 +74,7 @@ struct CUDAMethods : public ProfilerStubs {
return ms * 1000.0;
}

#ifndef ROCM_ON_WINDOWS
// Compiled when ROCM_ON_WINDOWS is not defined: forwards `name` to
// ::nvtxMark.
void mark(const char* name) const override {
::nvtxMark(name);
}
Expand All @@ -82,6 +86,20 @@ struct CUDAMethods : public ProfilerStubs {
// Compiled when ROCM_ON_WINDOWS is not defined: forwards to ::nvtxRangePop.
void rangePop() const override {
::nvtxRangePop();
}
#else // ROCM_ON_WINDOWS
// Reports, through TORCH_WARN_ONCE, that roctracer is not available on
// Windows. Shared by the mark/rangePush/rangePop overrides in this
// ROCM_ON_WINDOWS branch.
static void printUnavailableWarning() {
TORCH_WARN_ONCE("Warning: roctracer isn't available on Windows");
}
// ROCM_ON_WINDOWS override of mark(): records nothing and only reports that
// roctracer is unavailable. The parameter is left unnamed because it is never
// read; this avoids unused-parameter diagnostics and matches the
// unnamed-parameter style of the Windows stubs in nvtx.cpp.
void mark(const char* /*name*/) const override {
  printUnavailableWarning();
}
// ROCM_ON_WINDOWS override of rangePush(): opens no range and only reports
// that roctracer is unavailable. The parameter is left unnamed because it is
// never read; this avoids unused-parameter diagnostics and matches the
// unnamed-parameter style of the Windows stubs in nvtx.cpp.
void rangePush(const char* /*name*/) const override {
  printUnavailableWarning();
}
// ROCM_ON_WINDOWS override of rangePop(): closes no range and only calls
// printUnavailableWarning().
void rangePop() const override {
printUnavailableWarning();
}
#endif

void onEachDevice(std::function<void(int)> op) const override {
at::cuda::OptionalCUDAGuard device_guard;
Expand Down

0 comments on commit 37e9da0

Please sign in to comment.