diff --git a/src/operations.rs b/src/operations.rs index 9a78dbc5..4ab415ab 100644 --- a/src/operations.rs +++ b/src/operations.rs @@ -14,7 +14,7 @@ use crate::utils::get_bin_dir; use crate::utils::get_julianightlies_base_url; use crate::utils::get_juliaserver_base_url; use crate::utils::is_valid_julia_path; -use anyhow::{anyhow, bail, Context, Result}; +use anyhow::{anyhow, bail, Context, Error, Result}; use bstr::ByteSlice; use bstr::ByteVec; use console::style; @@ -1430,8 +1430,6 @@ pub fn update_version_db(paths: &GlobalPaths) -> Result<()> { let online_dbversion = download_juliaup_version(&dbversion_url.to_string()) .with_context(|| "Failed to download current version db version.")?; - let direct_download_etags = download_direct_download_etags(&old_config_file.data)?; - let bundled_dbversion = get_bundled_dbversion() .with_context(|| "Failed to determine the bundled version db version.")?; @@ -1463,6 +1461,8 @@ pub fn update_version_db(paths: &GlobalPaths) -> Result<()> { delete_old_version_db = true; } + let direct_download_etags = download_direct_download_etags(&old_config_file.data)?; + let mut new_config_file = load_mut_config_db(paths).with_context(|| { "`run_command_update_version_db` command failed to load configuration db." })?; @@ -1520,83 +1520,142 @@ pub fn update_version_db(paths: &GlobalPaths) -> Result<()> { Ok(()) } +// A generic function to run a function with a timeout and a message to inform the user why it is taking so long +fn run_with_slow_message(func: F, timeout_secs: u64, message: &str) -> Result +where + F: FnOnce() -> Result + Send + 'static, + R: Send + 'static, +{ + use std::sync::mpsc::channel; + use std::thread; + use std::time::Duration; + + let (tx, rx) = channel(); + + // Run the function in a separate thread + thread::spawn(move || { + let result = func(); + tx.send(result).unwrap(); + }); + + // Attempt to receive the result with a timeout + match rx.recv_timeout(Duration::from_secs(timeout_secs)) { + Ok(result) => result, + Err(std::sync::mpsc::RecvTimeoutError::Timeout) => { + // Function has not completed within timeout_secs seconds, inform why + eprintln!("{}", message); + + // Now wait for the function to complete + let result = rx.recv().unwrap(); + result + } + Err(e) => panic!("Error receiving result: {:?}", e), + } +} + #[cfg(windows)] fn download_direct_download_etags(config_data: &JuliaupConfig) -> Result> { use windows::core::HSTRING; + use windows::Foundation::Uri; + use windows::Web::Http::HttpClient; use windows::Web::Http::HttpMethod; use windows::Web::Http::HttpRequestMessage; - let http_client = - windows::Web::Http::HttpClient::new().with_context(|| "Failed to create HttpClient.")?; + let http_client = HttpClient::new().with_context(|| "Failed to create HttpClient.")?; - let requests: Vec<_> = config_data - .installed_channels - .iter() - .filter_map(|(channel_name, channel)| { - if let JuliaupConfigChannel::DirectDownloadChannel { - path: _, - url, - local_etag: _, - server_etag: _, - version: _, - } = channel - { - let request_uri = - windows::Foundation::Uri::CreateUri(&windows::core::HSTRING::from(url)) - .with_context(|| "Failed to convert url string to Uri.") - .unwrap(); + let mut requests = Vec::new(); - let request = - HttpRequestMessage::Create(&HttpMethod::Head().unwrap(), &request_uri).unwrap(); + for (channel_name, channel) in &config_data.installed_channels { + if let JuliaupConfigChannel::DirectDownloadChannel { url, .. } = channel { + let http_client = http_client.clone(); + let url_clone = url.clone(); + let channel_name_clone = channel_name.clone(); + let message = format!( + "{} for new version on channel '{}' is taking a while... This can be slow due to server caching", + style("Checking").green().bold(), + channel_name + ); - let request = http_client.SendRequestAsync(&request).unwrap(); + let etag = run_with_slow_message( + move || { + let request_uri = Uri::CreateUri(&HSTRING::from(&url_clone)) + .with_context(|| format!("Failed to create URI from {}", &url_clone))?; - Some((channel_name, request)) - } else { - None - } - }) - .collect(); + let request = HttpRequestMessage::Create(&HttpMethod::Head()?, &request_uri) + .with_context(|| "Failed to create HttpRequestMessage.")?; - let requests: Vec<_> = requests - .into_iter() - .map(|(channel_name, request)| { - ( - channel_name.clone(), - request - .get() - .unwrap() - .Headers() - .unwrap() - .Lookup(&HSTRING::from("etag")) - .unwrap() - .to_string(), - ) - }) - .collect(); + let async_op = http_client + .SendRequestAsync(&request) + .map_err(|e| anyhow!("Failed to send request: {:?}", e))?; + + let response = async_op + .get() + .map_err(|e| anyhow!("Failed to get response: {:?}", e))?; + + let headers = response + .Headers() + .map_err(|e| anyhow!("Failed to get headers: {:?}", e))?; + + let etag = headers + .Lookup(&HSTRING::from("ETag")) + .map_err(|e| anyhow!("ETag header not found: {:?}", e))? + .to_string(); + + Ok::(etag) + }, + 3, // Timeout in seconds + &message, + )?; + + requests.push((channel_name_clone, etag)); + } + } Ok(requests) } #[cfg(not(windows))] fn download_direct_download_etags(config_data: &JuliaupConfig) -> Result> { - let client = reqwest::blocking::Client::new(); + use std::sync::Arc; + + let client = Arc::new(reqwest::blocking::Client::new()); let mut requests = Vec::new(); for (channel_name, channel) in &config_data.installed_channels { if let JuliaupConfigChannel::DirectDownloadChannel { url, .. } = channel { - let etag = client - .head(url) - .send()? - .headers() - .get("etag") - .ok_or_else(|| anyhow!("ETag header not found in response"))? - .to_str() - .map_err(|e| anyhow!("Failed to parse ETag header: {}", e))? - .to_string(); - - requests.push((channel_name.clone(), etag)); + let client = Arc::clone(&client); + let url_clone = url.clone(); + let channel_name_clone = channel_name.clone(); + let message = format!( + "{} for new version on channel '{}' is taking a while... This can be slow due to server caching", + style("Checking").green().bold(), + channel_name + ); + + let etag = run_with_slow_message( + move || { + let response = client.head(&url_clone).send().with_context(|| { + format!("Failed to send HEAD request to {}", &url_clone) + })?; + + let etag = response + .headers() + .get("etag") + .ok_or_else(|| { + anyhow!("ETag header not found in response from {}", &url_clone) + })? + .to_str() + .map_err(|e| anyhow!("Failed to parse ETag header: {}", e))? + .to_string(); + + Ok::(etag) + }, + 3, // Timeout in seconds + &message, + )?; + + requests.push((channel_name_clone, etag)); } }