diff --git a/onnxruntime/core/framework/provider_bridge_ort.cc b/onnxruntime/core/framework/provider_bridge_ort.cc
index d29e4b10b6336..459de0f4242d7 100644
--- a/onnxruntime/core/framework/provider_bridge_ort.cc
+++ b/onnxruntime/core/framework/provider_bridge_ort.cc
@@ -117,9 +117,8 @@ using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef;
 
 namespace onnxruntime {
 
-//ProviderHost* g_host{};
-
-ProviderInfo_CUDA* GetProviderInfo_CUDA();
+ProviderInfo_CUDA* TryGetProviderInfo_CUDA();
+ProviderInfo_CUDA& GetProviderInfo_CUDA();
 
 struct TensorShapeProto_Dimension_Iterator_Impl : TensorShapeProto_Dimension_Iterator {
   TensorShapeProto_Dimension_Iterator_Impl(google::protobuf::internal::RepeatedPtrIterator&& v) : v_{std::move(v)} {}
@@ -195,15 +194,15 @@ struct ProviderHostImpl : ProviderHost {
   void CPUAllocator__Free(CPUAllocator* p, void* allocation) override { return p->CPUAllocator::Free(allocation); }
 
 #ifdef USE_CUDA
-  std::unique_ptr<IAllocator> CreateCUDAAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_CUDA()->CreateCUDAAllocator(device_id, name); }
-  std::unique_ptr<IAllocator> CreateCUDAPinnedAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_CUDA()->CreateCUDAPinnedAllocator(device_id, name); }
-  std::unique_ptr<IDataTransfer> CreateGPUDataTransfer(void* stream) override { return GetProviderInfo_CUDA()->CreateGPUDataTransfer(stream); }
+  std::unique_ptr<IAllocator> CreateCUDAAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_CUDA().CreateCUDAAllocator(device_id, name); }
+  std::unique_ptr<IAllocator> CreateCUDAPinnedAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_CUDA().CreateCUDAPinnedAllocator(device_id, name); }
+  std::unique_ptr<IDataTransfer> CreateGPUDataTransfer(void* stream) override { return GetProviderInfo_CUDA().CreateGPUDataTransfer(stream); }
 
-  void cuda__Impl_Cast(void* stream, const int64_t* input_data, int32_t* output_data, size_t count) override { return GetProviderInfo_CUDA()->cuda__Impl_Cast(stream, input_data, output_data, count); }
-  void cuda__Impl_Cast(void* stream, const int32_t* input_data, int64_t* output_data, size_t count) override { return GetProviderInfo_CUDA()->cuda__Impl_Cast(stream, input_data, output_data, count); }
+  void cuda__Impl_Cast(void* stream, const int64_t* input_data, int32_t* output_data, size_t count) override { return GetProviderInfo_CUDA().cuda__Impl_Cast(stream, input_data, output_data, count); }
+  void cuda__Impl_Cast(void* stream, const int32_t* input_data, int64_t* output_data, size_t count) override { return GetProviderInfo_CUDA().cuda__Impl_Cast(stream, input_data, output_data, count); }
 
-  bool CudaCall_false(int retCode, const char* exprString, const char* libName, int successCode, const char* msg) override { return GetProviderInfo_CUDA()->CudaCall_false(retCode, exprString, libName, successCode, msg); }
-  bool CudaCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg) override { return GetProviderInfo_CUDA()->CudaCall_true(retCode, exprString, libName, successCode, msg); }
+  bool CudaCall_false(int retCode, const char* exprString, const char* libName, int successCode, const char* msg) override { return GetProviderInfo_CUDA().CudaCall_false(retCode, exprString, libName, successCode, msg); }
+  bool CudaCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg) override { return GetProviderInfo_CUDA().CudaCall_true(retCode, exprString, libName, successCode, msg); }
 #endif
 
   std::string GetEnvironmentVar(const std::string& var_name) override { return Env::Default().GetEnvironmentVar(var_name); }
@@ -1048,7 +1047,7 @@ void UnloadSharedProviders() {
 
 // Used by test code
 std::unique_ptr<IAllocator> CreateCUDAPinnedAllocator(int16_t device_id, const char* name) {
-  if (auto* info = onnxruntime::GetProviderInfo_CUDA())
+  if (auto* info = onnxruntime::TryGetProviderInfo_CUDA())
     return info->CreateCUDAPinnedAllocator(device_id, name);
 
   return nullptr;
@@ -1095,10 +1094,17 @@ ProviderInfo_OpenVINO* GetProviderInfo_OpenVINO() {
   return nullptr;
 }
 
-ProviderInfo_CUDA* GetProviderInfo_CUDA() {
+ProviderInfo_CUDA* TryGetProviderInfo_CUDA() {
   if (auto* provider = s_library_cuda.Get())
     return reinterpret_cast<ProviderInfo_CUDA*>(provider->GetInfo());
-  LOGS_DEFAULT(WARNING) << "GetProviderInfo_CUDA called, returning nullptr";
+
   return nullptr;
 }
+
+ProviderInfo_CUDA& GetProviderInfo_CUDA() {
+  if(auto* info = TryGetProviderInfo_CUDA())
+    return *info;
+
+  ORT_THROW("CUDA Provider not available, can't get interface for it");
+}
 
@@ -1108,13 +1114,13 @@ void CopyGpuToCpu(
     const size_t size,
     const OrtMemoryInfo& dst_location,
     const OrtMemoryInfo& src_location) {
-  if (auto* info = onnxruntime::GetProviderInfo_CUDA())
+  if (auto* info = onnxruntime::TryGetProviderInfo_CUDA())
     return info->CopyGpuToCpu(dst_ptr, src_ptr, size, dst_location, src_location);
   ORT_THROW("GPU-to-CPU copy is not implemented.");
 }
 
 void cudaMemcpy_HostToDevice(void* dst, const void* src, size_t count) {
-  if (auto* info = onnxruntime::GetProviderInfo_CUDA())
+  if (auto* info = onnxruntime::TryGetProviderInfo_CUDA())
     return info->cudaMemcpy_HostToDevice(dst, src, count);
   ORT_THROW("cudaMemcpy_HostToDevice is not implemented.");
 }
@@ -1122,7 +1128,7 @@ void cudaMemcpy_HostToDevice(void* dst, const void* src, size_t count) {
 #if defined(USE_CUDA) && defined(ORT_USE_NCCL) && defined(USE_NCCL_P2P)
 namespace cuda {
 INcclService& INcclService::GetInstance() {
-  return GetProviderInfo_CUDA()->GetINcclService();
+  return GetProviderInfo_CUDA().GetINcclService();
 }
 }  // namespace cuda
 #endif
@@ -1192,17 +1198,17 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessi
 ORT_API_STATUS_IMPL(OrtApis::SetCurrentGpuDeviceId, _In_ int device_id) {
   API_IMPL_BEGIN
-  if (auto* info = onnxruntime::GetProviderInfo_CUDA())
+  if (auto* info = onnxruntime::TryGetProviderInfo_CUDA())
     return info->SetCurrentGpuDeviceId(device_id);
-  return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled.");
+  return CreateStatus(ORT_FAIL, "CUDA execution provider is either not enabled or not available.");
   API_IMPL_END
 }
 
 ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, _In_ int* device_id) {
   API_IMPL_BEGIN
-  if (auto* info = onnxruntime::GetProviderInfo_CUDA())
+  if (auto* info = onnxruntime::TryGetProviderInfo_CUDA())
     return info->GetCurrentGpuDeviceId(device_id);
-  return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled.");
+  return CreateStatus(ORT_FAIL, "CUDA execution provider is either not enabled or not available.");
   API_IMPL_END
 }
diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc
index 1d224f1d5e472..e1aff9a706839 100644
--- a/onnxruntime/core/platform/windows/env.cc
+++ b/onnxruntime/core/platform/windows/env.cc
@@ -40,6 +40,7 @@ EXTERN_C IMAGE_DOS_HEADER __ImageBase;
 namespace onnxruntime {
 
 namespace {
+
 class WindowsThread : public EnvThread {
  private:
   struct Param {
diff --git a/onnxruntime/python/_pybind_state.py b/onnxruntime/python/_pybind_state.py
index e7ea3efa88524..2b326410500fa 100644
--- a/onnxruntime/python/_pybind_state.py
+++ b/onnxruntime/python/_pybind_state.py
@@ -7,58 +7,12 @@
 """
 import os
 import platform
-import sys
 
 from . import _ld_preload  # noqa: F401
 
 if platform.system() == "Windows":
     from . import version_info
 
-    if version_info.use_cuda:
-        cuda_version_major, cuda_version_minor = version_info.cuda_version.split(".")
-        if int(cuda_version_major) < 11:
-            # Prior to CUDA 11 both major and minor version at build time/runtime have to match.
-            cuda_env_variable = f"CUDA_PATH_V{cuda_version_major}_{cuda_version_minor}"
-            if cuda_env_variable not in os.environ:
-                raise ImportError(f"CUDA Toolkit {version_info.cuda_version} not installed on the machine.")
-        else:
-            # With CUDA 11 and newer only the major version at build time/runtime has to match.
-            # Use the most recent minor version available.
-            cuda_env_variable = None
-            for i in range(9, -1, -1):
-                if f"CUDA_PATH_V{cuda_version_major}_{i}" in os.environ:
-                    cuda_env_variable = f"CUDA_PATH_V{cuda_version_major}_{i}"
-                    break
-            if not cuda_env_variable:
-                raise ImportError(f"CUDA Toolkit {cuda_version_major}.x not installed on the machine.")
-
-        cuda_bin_dir = os.path.join(os.environ[cuda_env_variable], "bin")
-
-        # prefer CUDNN_HOME if set. fallback to the CUDA install directory (would have required user to manually
-        # copy the cudnn dll there
-        cudnn_path = os.environ["CUDNN_HOME"] if "CUDNN_HOME" in os.environ else os.environ[cuda_env_variable]
-        cudnn_bin_dir = os.path.join(cudnn_path, "bin")
-
-        if not os.path.isfile(os.path.join(cudnn_bin_dir, f"cudnn64_{version_info.cudnn_version}.dll")):
-            raise ImportError(f"cuDNN {version_info.cudnn_version} not installed in {cudnn_bin_dir}. "
-                              f"Set the CUDNN_HOME environment variable to the path of the 'cuda' directory "
-                              f"in your CUDNN installation if necessary.")
-
-        if sys.version_info >= (3, 8):
-            # Python 3.8 (and later) doesn't search system PATH when loading DLLs, so the CUDA location needs to be
-            # specified explicitly using the new API introduced in Python 3.8.
-            os.add_dll_directory(cuda_bin_dir)
-            cuda_root = os.path.join(cuda_bin_dir, "..", "..")
-            for root, _, files in os.walk(cuda_root):
-                for f in files:
-                    if f == "cupti.lib":
-                        os.add_dll_directory(root)
-        else:
-            # Python 3.7 (and earlier) searches directories listed in PATH variable.
-            # Make sure that the target CUDA version is at the beginning (important if multiple CUDA versions are
-            # installed on the machine.)
-            os.environ["PATH"] = cuda_bin_dir + os.pathsep + os.environ["PATH"]
-
     if version_info.vs2019 and platform.architecture()[0] == "64bit":
         if not os.path.isfile("C:\\Windows\\System32\\vcruntime140_1.dll"):
             raise ImportError(
diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
index 7b28fed2d1faf..9d5804514373f 100644
--- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
+++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
@@ -67,11 +67,11 @@ void CpuToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
 
 #ifdef USE_CUDA
 void CpuToCudaMemCpy(void* dst, const void* src, size_t num_bytes) {
-  GetProviderInfo_CUDA()->cudaMemcpy_HostToDevice(dst, src, num_bytes);
+  GetProviderInfo_CUDA().cudaMemcpy_HostToDevice(dst, src, num_bytes);
 }
 
 void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
-  GetProviderInfo_CUDA()->cudaMemcpy_DeviceToHost(dst, src, num_bytes);
+  GetProviderInfo_CUDA().cudaMemcpy_DeviceToHost(dst, src, num_bytes);
 }
 
 const std::unordered_map* GetCudaToHostMemCpyFunction() {
@@ -82,7 +82,7 @@ const std::unordered_map* GetCudaToHostMemCpy
 }
 
 bool IsCudaDeviceIdValid(const onnxruntime::logging::Logger& logger, int id) {
-  int num_devices = GetProviderInfo_CUDA()->cudaGetDeviceCount();
+  int num_devices = GetProviderInfo_CUDA().cudaGetDeviceCount();
 
   if (0 == num_devices) {
     LOGS(logger, WARNING) << "your system does not have a CUDA capable device.";
@@ -105,7 +105,7 @@ AllocatorPtr GetCudaAllocator(OrtDevice::DeviceId id) {
 
   if (id_to_allocator_map->find(id) == id_to_allocator_map->end()) {
     // TODO: Expose knobs so that users can set fields associated with OrtArenaCfg so that we can pass it to the following method
-    id_to_allocator_map->insert({id, GetProviderInfo_CUDA()->CreateCudaAllocator(id, gpu_mem_limit, arena_extend_strategy, external_allocator_info, nullptr)});
+    id_to_allocator_map->insert({id, GetProviderInfo_CUDA().CreateCudaAllocator(id, gpu_mem_limit, arena_extend_strategy, external_allocator_info, nullptr)});
   }
 
   return (*id_to_allocator_map)[id];
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 1238f9e1e181b..62d009f30be0d 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -470,24 +470,33 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
 #endif
   } else if (type == kCudaExecutionProvider) {
 #ifdef USE_CUDA
-    const auto it = provider_options_map.find(type);
-    CUDAExecutionProviderInfo info{};
-    if (it != provider_options_map.end())
-      GetProviderInfo_CUDA()->CUDAExecutionProviderInfo__FromProviderOptions(it->second, info);
-    else {
-      info.device_id = cuda_device_id;
-      info.gpu_mem_limit = gpu_mem_limit;
-      info.arena_extend_strategy = arena_extend_strategy;
-      info.cudnn_conv_algo_search = cudnn_conv_algo_search;
-      info.do_copy_in_default_stream = do_copy_in_default_stream;
-      info.external_allocator_info = external_allocator_info;
-    }
+    if(auto* cuda_provider_info = TryGetProviderInfo_CUDA())
+    {
+      const auto it = provider_options_map.find(type);
+      CUDAExecutionProviderInfo info{};
+      if (it != provider_options_map.end())
+        cuda_provider_info->CUDAExecutionProviderInfo__FromProviderOptions(it->second, info);
+      else {
+        info.device_id = cuda_device_id;
+        info.gpu_mem_limit = gpu_mem_limit;
+        info.arena_extend_strategy = arena_extend_strategy;
+        info.cudnn_conv_algo_search = cudnn_conv_algo_search;
+        info.do_copy_in_default_stream = do_copy_in_default_stream;
+        info.external_allocator_info = external_allocator_info;
+      }
 
-    // This variable is never initialized because the APIs by which is it should be initialized are deprecated, however they still
-    // exist are are in-use. Neverthless, it is used to return CUDAAllocator, hence we must try to initialize it here if we can
-    // since FromProviderOptions might contain external CUDA allocator.
-    external_allocator_info = info.external_allocator_info;
-    RegisterExecutionProvider(sess, *GetProviderInfo_CUDA()->CreateExecutionProviderFactory(info));
+      // This variable is never initialized because the APIs by which is it should be initialized are deprecated, however they still
+      // exist are are in-use. Neverthless, it is used to return CUDAAllocator, hence we must try to initialize it here if we can
+      // since FromProviderOptions might contain external CUDA allocator.
+      external_allocator_info = info.external_allocator_info;
+      RegisterExecutionProvider(sess, *cuda_provider_info->CreateExecutionProviderFactory(info));
+    }
+    else
+    {
+      if(!Env::Default().GetEnvironmentVar("CUDA_PATH").empty()) {
+        ORT_THROW("CUDA_PATH is set but CUDA wasn't able to be loaded. Please install the correct version of CUDA and cuDNN as mentioned in the GPU requirements page, make sure they're in the PATH, and that your GPU is supported.");
+      }
+    }
 #endif
   } else if (type == kRocmExecutionProvider) {
 #ifdef USE_ROCM
diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h
index 5193d38dd6607..f6fe658f8f2a3 100644
--- a/onnxruntime/python/onnxruntime_pybind_state_common.h
+++ b/onnxruntime/python/onnxruntime_pybind_state_common.h
@@ -161,7 +161,8 @@ extern std::string nuphar_settings;
 #if defined(USE_CUDA) || defined(USE_ROCM)
 #ifdef USE_CUDA
 namespace onnxruntime {
-ProviderInfo_CUDA* GetProviderInfo_CUDA();
+ProviderInfo_CUDA* TryGetProviderInfo_CUDA();
+ProviderInfo_CUDA& GetProviderInfo_CUDA();
 namespace python {
 // TODO remove deprecated global config
 extern OrtCudnnConvAlgoSearch cudnn_conv_algo_search;
diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc
index 755f25c2cbee9..7c41528949861 100644
--- a/onnxruntime/test/framework/inference_session_test.cc
+++ b/onnxruntime/test/framework/inference_session_test.cc
@@ -68,7 +68,7 @@ struct KernelRegistryAndStatus {
 namespace onnxruntime {
 
 #ifdef USE_CUDA
-ProviderInfo_CUDA* GetProviderInfo_CUDA();
+ProviderInfo_CUDA& GetProviderInfo_CUDA();
 #endif
 
 class FuseAdd : public OpKernel {
@@ -353,7 +353,7 @@ void RunModelWithBindingMatMul(InferenceSession& session_object,
                                cpu_allocator);
 #ifdef USE_CUDA
     cudaStream_t stream = static_cast<cudaStream_t>(gpu_provider->GetComputeStream());
-    st = GetProviderInfo_CUDA()->CreateGPUDataTransfer(stream)->CopyTensor(rtensor, *cpu_tensor.get(), 0);
+    st = GetProviderInfo_CUDA().CreateGPUDataTransfer(stream)->CopyTensor(rtensor, *cpu_tensor.get(), 0);
 #elif USE_ROCM
     hipStream_t stream = static_cast<hipStream_t>(gpu_provider->GetComputeStream());
     st = GPUDataTransfer(stream).CopyTensor(rtensor, *cpu_tensor.get(), 0);