From 041cc497a60d2913a004c708d516f10a8e3dea20 Mon Sep 17 00:00:00 2001 From: ylfeng Date: Sun, 28 Feb 2021 11:25:13 +0800 Subject: [PATCH 1/5] add support to cuda --- onnxruntime-sys/build.rs | 12 +++++++++++- onnxruntime-sys/wrapper.h | 4 ++++ onnxruntime/Cargo.toml | 1 + onnxruntime/src/session.rs | 23 +++++++++++++++++++++++ 4 files changed, 39 insertions(+), 1 deletion(-) diff --git a/onnxruntime-sys/build.rs b/onnxruntime-sys/build.rs index 1589555a..29c0c84b 100644 --- a/onnxruntime-sys/build.rs +++ b/onnxruntime-sys/build.rs @@ -66,7 +66,17 @@ fn generate_bindings(_include_dir: &Path) { #[cfg(feature = "generate-bindings")] fn generate_bindings(include_dir: &Path) { - let clang_arg = format!("-I{}", include_dir.display()); + let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); + let clang_arg = match env::var(ORT_ENV_GPU) { + Ok(cuda_env) => match cuda_env.to_lowercase().as_str() { + "1" | "yes" | "true" | "on" => match os.as_str() { + "linux" | "windows" => "-gpu", + _ => format!("-I{}", include_dir.display()) + }, + _ => format!("-I{}", include_dir.display()), + }, + Err(_) => format!("-I{}", include_dir.display()), + }; // Tell cargo to invalidate the built crate whenever the wrapper changes println!("cargo:rerun-if-changed=wrapper.h"); diff --git a/onnxruntime-sys/wrapper.h b/onnxruntime-sys/wrapper.h index e63d3523..8b171332 100644 --- a/onnxruntime-sys/wrapper.h +++ b/onnxruntime-sys/wrapper.h @@ -1 +1,5 @@ #include "onnxruntime_c_api.h" +#include "cpu_provider_factory.h" +#ifdef ORT_USE_CUDA +#include "cuda_provider_factory.h" +#endif \ No newline at end of file diff --git a/onnxruntime/Cargo.toml b/onnxruntime/Cargo.toml index 9ceec820..13d03a2b 100644 --- a/onnxruntime/Cargo.toml +++ b/onnxruntime/Cargo.toml @@ -36,6 +36,7 @@ tracing-subscriber = "0.2" ureq = "1.5.1" [features] +cuda = [] # Fetch model from ONNX Model Zoo (https://github.com/onnx/models) model-fetching = ["ureq"] # Disable build script; used for https://docs.rs diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 04f9cf1c..4340015b 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -124,6 +124,29 @@ impl<'a> SessionBuilder<'a> { Ok(self) } + /// Set the session use cpu provider + pub fn with_cpu(mut self, use_arena: bool) -> Result> { + unsafe { + sys::OrtSessionOptionsAppendExecutionProvider_CPU( + self.session_options_ptr, + use_arena.into(), + ); + } + Ok(self) + } + + /// Set the session use cuda provider + #[cfg(feature = "cuda")] + pub fn with_cuda(mut self, device_id: i32) -> Result> { + unsafe { + sys::OrtSessionOptionsAppendExecutionProvider_CUDA( + self.session_options_ptr, + device_id.into(), + ); + } + Ok(self) + } + /// Set the session's allocator /// /// Defaults to [`AllocatorType::Arena`](../enum.AllocatorType.html#variant.Arena) From a0bdfbd9962e308966754604a45a863af38f379a Mon Sep 17 00:00:00 2001 From: ylfeng Date: Sun, 28 Feb 2021 11:49:35 +0800 Subject: [PATCH 2/5] fix format --- onnxruntime-sys/build.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime-sys/build.rs b/onnxruntime-sys/build.rs index 29c0c84b..fea29001 100644 --- a/onnxruntime-sys/build.rs +++ b/onnxruntime-sys/build.rs @@ -70,8 +70,8 @@ fn generate_bindings(include_dir: &Path) { let clang_arg = match env::var(ORT_ENV_GPU) { Ok(cuda_env) => match cuda_env.to_lowercase().as_str() { "1" | "yes" | "true" | "on" => match os.as_str() { - "linux" | "windows" => "-gpu", - _ => format!("-I{}", include_dir.display()) + "linux" | "windows" => format!("-I{} -DORT_USE_CUDA", include_dir.display()), + _ => format!("-I{}", include_dir.display()), }, _ => format!("-I{}", include_dir.display()), }, From 641024de3cf59518f63f618cfcec6586345570ca Mon Sep 17 00:00:00 2001 From: ylfeng Date: Sun, 28 Feb 2021 13:40:41 +0800 Subject: [PATCH 3/5] fix cuda DEFINE --- onnxruntime-sys/build.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/onnxruntime-sys/build.rs b/onnxruntime-sys/build.rs index fea29001..e0c46f2a 100644 --- a/onnxruntime-sys/build.rs +++ b/onnxruntime-sys/build.rs @@ -67,15 +67,16 @@ fn generate_bindings(_include_dir: &Path) { #[cfg(feature = "generate-bindings")] fn generate_bindings(include_dir: &Path) { let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); - let clang_arg = match env::var(ORT_ENV_GPU) { + let clang_arg = format!("-I{}", include_dir.display()); + let clang_cuda_arg = match env::var(ORT_ENV_GPU) { Ok(cuda_env) => match cuda_env.to_lowercase().as_str() { "1" | "yes" | "true" | "on" => match os.as_str() { - "linux" | "windows" => format!("-I{} -DORT_USE_CUDA", include_dir.display()), - _ => format!("-I{}", include_dir.display()), + "linux" | "windows" => "-DORT_USE_CUDA", + _ => "", }, - _ => format!("-I{}", include_dir.display()), + _ => "", }, - Err(_) => format!("-I{}", include_dir.display()), + Err(_) => "", }; // Tell cargo to invalidate the built crate whenever the wrapper changes @@ -91,6 +92,8 @@ fn generate_bindings(include_dir: &Path) { .header("wrapper.h") // The current working directory is 'onnxruntime-sys' .clang_arg(clang_arg) + // Add define ORT_USE_CUDA + .clang_arg(clang_cuda_arg) // Tell cargo to invalidate the built crate whenever any of the // included header files changed. .parse_callbacks(Box::new(bindgen::CargoCallbacks)) From 9f63b2f7bca0b73b104e047a4df28435658bcfc8 Mon Sep 17 00:00:00 2001 From: ylfeng Date: Sun, 28 Feb 2021 15:33:07 +0800 Subject: [PATCH 4/5] update sample.rs --- onnxruntime-sys/wrapper.h | 2 +- onnxruntime/examples/sample.rs | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/onnxruntime-sys/wrapper.h b/onnxruntime-sys/wrapper.h index 8b171332..b11f5275 100644 --- a/onnxruntime-sys/wrapper.h +++ b/onnxruntime-sys/wrapper.h @@ -2,4 +2,4 @@ #include "cpu_provider_factory.h" #ifdef ORT_USE_CUDA #include "cuda_provider_factory.h" -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/examples/sample.rs b/onnxruntime/examples/sample.rs index d16d08da..26543142 100644 --- a/onnxruntime/examples/sample.rs +++ b/onnxruntime/examples/sample.rs @@ -25,6 +25,15 @@ fn run() -> Result<(), Error> { tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); + #[cfg(feature = "cuda")] + let environment = Environment::builder() + .with_name("test") + .with_gpu(0) + // The ONNX Runtime's log level can be different than the one of the wrapper crate or the application. + .with_log_level(LoggingLevel::Info) + .build()?; + + #[cfg(not(feature = "cuda"))] let environment = Environment::builder() .with_name("test") // The ONNX Runtime's log level can be different than the one of the wrapper crate or the application. From d094718818850b0ecd0a5949e381fa37af4379b1 Mon Sep 17 00:00:00 2001 From: ylfeng Date: Mon, 1 Mar 2021 22:11:44 +0800 Subject: [PATCH 5/5] remove unused mut --- onnxruntime/src/session.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 4340015b..e22b7ee0 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -125,7 +125,7 @@ impl<'a> SessionBuilder<'a> { } /// Set the session use cpu provider - pub fn with_cpu(mut self, use_arena: bool) -> Result> { + pub fn with_cpu(self, use_arena: bool) -> Result> { unsafe { sys::OrtSessionOptionsAppendExecutionProvider_CPU( self.session_options_ptr, @@ -137,7 +137,7 @@ impl<'a> SessionBuilder<'a> { /// Set the session use cuda provider #[cfg(feature = "cuda")] - pub fn with_cuda(mut self, device_id: i32) -> Result> { + pub fn with_cuda(self, device_id: i32) -> Result> { unsafe { sys::OrtSessionOptionsAppendExecutionProvider_CUDA( self.session_options_ptr,