From e0b4012f7d585ba801a48d5f972328344a6a82da Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Fri, 26 Jul 2024 16:52:02 +1000 Subject: [PATCH] feat: Support authentication with HuggingFace login (#17881) --- crates/polars-io/src/cloud/options.rs | 62 +++++++++++++++++++++------ 1 file changed, 50 insertions(+), 12 deletions(-) diff --git a/crates/polars-io/src/cloud/options.rs b/crates/polars-io/src/cloud/options.rs index c9958f4b745e..f544765ddad9 100644 --- a/crates/polars-io/src/cloud/options.rs +++ b/crates/polars-io/src/cloud/options.rs @@ -39,8 +39,6 @@ use url::Url; #[cfg(feature = "file_cache")] use crate::file_cache::get_env_file_cache_ttl; #[cfg(feature = "aws")] -use crate::path_utils::resolve_homedir; -#[cfg(feature = "aws")] use crate::pl_async::with_concurrency_budget; #[cfg(feature = "aws")] @@ -211,6 +209,8 @@ fn read_config( builder: &mut AmazonS3Builder, items: &[(&Path, &[(&str, AmazonS3ConfigKey)])], ) -> Option<()> { + use crate::path_utils::resolve_homedir; + for (path, keys) in items { if keys .iter() @@ -474,26 +474,64 @@ impl CloudOptions { CloudType::Hf => { #[cfg(feature = "http")] { - let mut this = Self::default(); + use polars_core::config; - if let Ok(v) = std::env::var("HF_TOKEN") { - this.config = Some(CloudConfig::Http { - headers: vec![("Authorization".into(), format!("Bearer {}", v))], - }) - } + use crate::path_utils::resolve_homedir; + + let mut this = Self::default(); + let mut token = None; + let verbose = config::verbose(); for (i, (k, v)) in config.into_iter().enumerate() { let (k, v) = (k.as_ref(), v.into()); if i == 0 && k == "token" { - this.config = Some(CloudConfig::Http { - headers: vec![("Authorization".into(), format!("Bearer {}", v))], - }) + if verbose { + eprintln!("HF token sourced from storage_options"); + } + token = Some(v); } else { - polars_bail!(ComputeError: "unknown configuration key: {}", k) + polars_bail!(ComputeError: "unknown configuration key for HF: {}", k) } } + token = token + .or_else(|| { + let v = std::env::var("HF_TOKEN").ok(); + if v.is_some() && verbose { + eprintln!("HF token sourced from HF_TOKEN env var"); + } + v + }) + .or_else(|| { + let hf_home = std::env::var("HF_HOME"); + let hf_home = hf_home.as_deref(); + let hf_home = hf_home.unwrap_or("~/.cache/huggingface"); + let hf_home = resolve_homedir(std::path::Path::new(&hf_home)); + let cached_token_path = hf_home.join("token"); + + let v = std::string::String::from_utf8( + std::fs::read(&cached_token_path).ok()?, + ) + .ok() + .filter(|x| !x.is_empty()); + + if v.is_some() && verbose { + eprintln!( + "HF token sourced from {}", + cached_token_path.to_str().unwrap() + ); + } + + v + }); + + if let Some(v) = token { + this.config = Some(CloudConfig::Http { + headers: vec![("Authorization".into(), format!("Bearer {}", v))], + }) + } + Ok(this) } #[cfg(not(feature = "http"))]