From a2cfc24ff695340d345dd5c75894a8e323b95146 Mon Sep 17 00:00:00 2001 From: Franklin Delehelle Date: Wed, 11 Dec 2024 14:51:35 +0200 Subject: [PATCH 1/2] fix: better CUDA detection --- crates/cudart-sys/Cargo.toml | 4 ++-- crates/cudart-sys/src/utils.rs | 33 ++++++++++++++++++++++++--------- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/crates/cudart-sys/Cargo.toml b/crates/cudart-sys/Cargo.toml index 77f7233..282e123 100644 --- a/crates/cudart-sys/Cargo.toml +++ b/crates/cudart-sys/Cargo.toml @@ -11,7 +11,7 @@ name = "era_cudart_sys" description = "Raw CUDA bindings for ZKsync" [dependencies] -serde_json = "1.0" +regex-lite = "0.1" [build-dependencies] -serde_json = "1.0" +regex-lite = "0.1" diff --git a/crates/cudart-sys/src/utils.rs b/crates/cudart-sys/src/utils.rs index 025caf9..a514cb6 100644 --- a/crates/cudart-sys/src/utils.rs +++ b/crates/cudart-sys/src/utils.rs @@ -4,12 +4,15 @@ use std::path::{Path, PathBuf}; pub fn get_cuda_path() -> Option<&'static Path> { #[cfg(target_os = "linux")] { - let path = Path::new("/usr/local/cuda"); - if path.exists() { - Some(path) - } else { - None + for path_name in [option_env!("CUDA_PATH"), Some("/usr/local/cuda")].iter().flatten() { + println!("trying {path_name}..."); + let path = Path::new(path_name); + if path.exists() { + println!("CUDA installation found at `{}`", path.display()); + return Some(path) + } } + None } #[cfg(target_os = "windows")] { @@ -42,12 +45,24 @@ pub fn get_cuda_lib_path() -> Option { pub fn get_cuda_version() -> Option { if let Some(version) = option_env!("CUDA_VERSION") { + println!("CUDA version defined in CUDA_VERSION as `{}`", version); Some(version.to_string()) } else if let Some(path) = get_cuda_path() { - let file = File::open(path.join("version.json")).expect("CUDA Toolkit should be installed"); - let reader = std::io::BufReader::new(file); - let value: serde_json::Value = serde_json::from_reader(reader).unwrap(); - Some(value["cuda"]["version"].as_str().unwrap().to_string()) + println!("inferring CUDA version from nvcc output..."); + let re = regex_lite::Regex::new(r"V(?\d{2}\.\d+\.\d+)").unwrap(); + let nvcc_out = std::process::Command::new("nvcc") + .arg("--version") + .output() + .expect("failed to start `nvcc`"); + let nvcc_str = std::str::from_utf8(&nvcc_out.stdout).expect("`nvcc` output is not UTF8"); + let captures = re.captures(&nvcc_str).unwrap(); + let version = captures + .get(0) + .expect("unable to find nvcc version in the form VMM.mm.pp in the output of `nvcc --version`:\n{nvcc_str}") + .as_str() + .to_string(); + println!("CUDA version inferred to be `{version}`."); + Some(version) } else { None } From aa59fa9600fe194650cffb2868207feb7a3389c5 Mon Sep 17 00:00:00 2001 From: Franklin Delehelle Date: Tue, 14 Jan 2025 13:28:46 +0100 Subject: [PATCH 2/2] remove prints --- crates/cudart-sys/src/utils.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/crates/cudart-sys/src/utils.rs b/crates/cudart-sys/src/utils.rs index a514cb6..2337371 100644 --- a/crates/cudart-sys/src/utils.rs +++ b/crates/cudart-sys/src/utils.rs @@ -4,12 +4,13 @@ use std::path::{Path, PathBuf}; pub fn get_cuda_path() -> Option<&'static Path> { #[cfg(target_os = "linux")] { - for path_name in [option_env!("CUDA_PATH"), Some("/usr/local/cuda")].iter().flatten() { - println!("trying {path_name}..."); + for path_name in [option_env!("CUDA_PATH"), Some("/usr/local/cuda")] + .iter() + .flatten() + { let path = Path::new(path_name); if path.exists() { - println!("CUDA installation found at `{}`", path.display()); - return Some(path) + return Some(path); } } None @@ -45,10 +46,8 @@ pub fn get_cuda_lib_path() -> Option { pub fn get_cuda_version() -> Option { if let Some(version) = option_env!("CUDA_VERSION") { - println!("CUDA version defined in CUDA_VERSION as `{}`", version); Some(version.to_string()) } else if let Some(path) = get_cuda_path() { - println!("inferring CUDA version from nvcc output..."); let re = regex_lite::Regex::new(r"V(?\d{2}\.\d+\.\d+)").unwrap(); let nvcc_out = std::process::Command::new("nvcc") .arg("--version") @@ -61,7 +60,6 @@ pub fn get_cuda_version() -> Option { .expect("unable to find nvcc version in the form VMM.mm.pp in the output of `nvcc --version`:\n{nvcc_str}") .as_str() .to_string(); - println!("CUDA version inferred to be `{version}`."); Some(version) } else { None