Skip to content

Commit

Permalink
feat: Support authentication with HuggingFace login (#17881)
Browse files: browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Jul 26, 2024
1 parent 2a6ebec commit e0b4012
Showing 1 changed file with 50 additions and 12 deletions.
62 changes: 50 additions & 12 deletions crates/polars-io/src/cloud/options.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -39,8 +39,6 @@ use url::Url;
#[cfg(feature = "file_cache")]
use crate::file_cache::get_env_file_cache_ttl;
#[cfg(feature = "aws")]
use crate::path_utils::resolve_homedir;
#[cfg(feature = "aws")]
use crate::pl_async::with_concurrency_budget;

#[cfg(feature = "aws")]
Expand Down Expand Up @@ -211,6 +209,8 @@ fn read_config(
builder: &mut AmazonS3Builder,
items: &[(&Path, &[(&str, AmazonS3ConfigKey)])],
) -> Option<()> {
use crate::path_utils::resolve_homedir;

for (path, keys) in items {
if keys
.iter()
Expand Down Expand Up @@ -474,26 +474,64 @@ impl CloudOptions {
CloudType::Hf => {
#[cfg(feature = "http")]
{
let mut this = Self::default();
use polars_core::config;

if let Ok(v) = std::env::var("HF_TOKEN") {
this.config = Some(CloudConfig::Http {
headers: vec![("Authorization".into(), format!("Bearer {}", v))],
})
}
use crate::path_utils::resolve_homedir;

let mut this = Self::default();
let mut token = None;
let verbose = config::verbose();

for (i, (k, v)) in config.into_iter().enumerate() {
let (k, v) = (k.as_ref(), v.into());

if i == 0 && k == "token" {
this.config = Some(CloudConfig::Http {
headers: vec![("Authorization".into(), format!("Bearer {}", v))],
})
if verbose {
eprintln!("HF token sourced from storage_options");
}
token = Some(v);
} else {
polars_bail!(ComputeError: "unknown configuration key: {}", k)
polars_bail!(ComputeError: "unknown configuration key for HF: {}", k)
}
}

token = token
.or_else(|| {
let v = std::env::var("HF_TOKEN").ok();
if v.is_some() && verbose {
eprintln!("HF token sourced from HF_TOKEN env var");
}
v
})
.or_else(|| {
let hf_home = std::env::var("HF_HOME");
let hf_home = hf_home.as_deref();
let hf_home = hf_home.unwrap_or("~/.cache/huggingface");
let hf_home = resolve_homedir(std::path::Path::new(&hf_home));
let cached_token_path = hf_home.join("token");

let v = std::string::String::from_utf8(
std::fs::read(&cached_token_path).ok()?,
)
.ok()
.filter(|x| !x.is_empty());

if v.is_some() && verbose {
eprintln!(
"HF token sourced from {}",
cached_token_path.to_str().unwrap()
);
}

v
});

if let Some(v) = token {
this.config = Some(CloudConfig::Http {
headers: vec![("Authorization".into(), format!("Bearer {}", v))],
})
}

Ok(this)
}
#[cfg(not(feature = "http"))]
Expand Down

0 comments on commit e0b4012

Please sign in to comment.