Skip to content

Commit

Permalink
feat: Support API token for scanning hf:// (#17682)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Jul 19, 2024
1 parent a212ce9 commit c4738d2
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 77 deletions.
12 changes: 11 additions & 1 deletion crates/polars-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,17 @@ async = [
"polars-error/regex",
"polars-parquet?/async",
]
cloud = ["object_store", "async", "polars-error/object_store", "url", "serde_json", "serde", "file_cache", "reqwest"]
cloud = [
"object_store",
"async",
"polars-error/object_store",
"url",
"serde_json",
"serde",
"file_cache",
"reqwest",
"http",
]
file_cache = ["async", "dep:blake3", "dep:fs4"]
aws = ["object_store/aws", "cloud", "reqwest"]
azure = ["object_store/azure", "cloud"]
Expand Down
9 changes: 3 additions & 6 deletions crates/polars-io/src/cloud/object_store_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ pub async fn build_object_store(
}
}

#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
let options = options.map(std::borrow::Cow::Borrowed).unwrap_or_default();

let cloud_type = CloudType::from_url(&parsed)?;
Expand Down Expand Up @@ -111,16 +110,14 @@ pub async fn build_object_store(
allow_cache = false;
#[cfg(feature = "http")]
{
let store = object_store::http::HttpBuilder::new()
.with_url(url)
.with_client_options(super::get_client_options())
.build()?;
Ok::<_, PolarsError>(Arc::new(store) as Arc<dyn ObjectStore>)
let store = options.build_http(url)?;
PolarsResult::Ok(Arc::new(store) as Arc<dyn ObjectStore>)
}
}
#[cfg(not(feature = "http"))]
return err_missing_feature("http", &cloud_location.scheme);
},
CloudType::Hf => panic!("impl error: unresolved hf:// path"),
}?;
if allow_cache {
let mut cache = OBJECT_STORE_CACHE.write().await;
Expand Down
141 changes: 105 additions & 36 deletions crates/polars-io/src/cloud/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ use polars_error::*;
use polars_utils::cache::FastFixedCache;
#[cfg(feature = "aws")]
use regex::Regex;
#[cfg(feature = "http")]
use reqwest::header::HeaderMap;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "aws")]
Expand Down Expand Up @@ -54,19 +56,27 @@ static BUCKET_REGION: Lazy<std::sync::Mutex<FastFixedCache<SmartString, SmartStr
#[allow(dead_code)]
type Configs<T> = Vec<(T, String)>;

#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub(crate) enum CloudConfig {
#[cfg(feature = "aws")]
Aws(Configs<AmazonS3ConfigKey>),
#[cfg(feature = "azure")]
Azure(Configs<AzureConfigKey>),
#[cfg(feature = "gcp")]
Gcp(Configs<GoogleConfigKey>),
#[cfg(feature = "http")]
Http { headers: Vec<(String, String)> },
}

#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// Options to connect to various cloud providers.
pub struct CloudOptions {
pub max_retries: usize,
#[cfg(feature = "file_cache")]
pub file_cache_ttl: u64,
#[cfg(feature = "aws")]
aws: Option<Configs<AmazonS3ConfigKey>>,
#[cfg(feature = "azure")]
azure: Option<Configs<AzureConfigKey>>,
#[cfg(feature = "gcp")]
gcp: Option<Configs<GoogleConfigKey>>,
pub(crate) config: Option<CloudConfig>,
}

impl Default for CloudOptions {
Expand All @@ -75,16 +85,29 @@ impl Default for CloudOptions {
max_retries: 2,
#[cfg(feature = "file_cache")]
file_cache_ttl: get_env_file_cache_ttl(),
#[cfg(feature = "aws")]
aws: Default::default(),
#[cfg(feature = "azure")]
azure: Default::default(),
#[cfg(feature = "gcp")]
gcp: Default::default(),
config: None,
}
}
}

#[cfg(feature = "http")]
pub(crate) fn try_build_http_header_map_from_items_slice<S: AsRef<str>>(
headers: &[(S, S)],
) -> PolarsResult<HeaderMap> {
use reqwest::header::{HeaderName, HeaderValue};

let mut map = HeaderMap::with_capacity(headers.len());
for (k, v) in headers {
let (k, v) = (k.as_ref(), v.as_ref());
map.insert(
HeaderName::from_str(k).map_err(to_compute_err)?,
HeaderValue::from_str(v).map_err(to_compute_err)?,
);
}

Ok(map)
}

#[allow(dead_code)]
/// Parse an untype configuration hashmap to a typed configuration for the given configuration key type.
fn parsed_untyped_config<T, I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
Expand Down Expand Up @@ -112,6 +135,7 @@ pub enum CloudType {
File,
Gcp,
Http,
Hf,
}

impl CloudType {
Expand All @@ -123,6 +147,7 @@ impl CloudType {
"gs" | "gcp" | "gcs" => Self::Gcp,
"file" => Self::File,
"http" | "https" => Self::Http,
"hf" => Self::Hf,
_ => polars_bail!(ComputeError: "unknown url scheme"),
})
}
Expand Down Expand Up @@ -225,21 +250,20 @@ impl CloudOptions {
mut self,
configs: I,
) -> Self {
self.aws = Some(
configs
.into_iter()
.map(|(k, v)| (k, v.into()))
.collect::<Configs<AmazonS3ConfigKey>>(),
);
self.config = Some(CloudConfig::Aws(
configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
));
self
}

/// Build the [`object_store::ObjectStore`] implementation for AWS.
#[cfg(feature = "aws")]
pub async fn build_aws(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
let options = self.aws.as_ref();
let mut builder = AmazonS3Builder::from_env().with_url(url);
if let Some(options) = options {
if let Some(options) = &self.config {
let CloudConfig::Aws(options) = options else {
panic!("impl error: cloud type mismatch")
};
for (key, value) in options.iter() {
builder = builder.with_config(*key, value);
}
Expand Down Expand Up @@ -328,21 +352,20 @@ impl CloudOptions {
mut self,
configs: I,
) -> Self {
self.azure = Some(
configs
.into_iter()
.map(|(k, v)| (k, v.into()))
.collect::<Configs<AzureConfigKey>>(),
);
self.config = Some(CloudConfig::Azure(
configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
));
self
}

/// Build the [`object_store::ObjectStore`] implementation for Azure.
#[cfg(feature = "azure")]
pub fn build_azure(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
let options = self.azure.as_ref();
let mut builder = MicrosoftAzureBuilder::from_env();
if let Some(options) = options {
if let Some(options) = &self.config {
let CloudConfig::Azure(options) = options else {
panic!("impl error: cloud type mismatch")
};
for (key, value) in options.iter() {
builder = builder.with_config(*key, value);
}
Expand All @@ -362,21 +385,20 @@ impl CloudOptions {
mut self,
configs: I,
) -> Self {
self.gcp = Some(
configs
.into_iter()
.map(|(k, v)| (k, v.into()))
.collect::<Configs<GoogleConfigKey>>(),
);
self.config = Some(CloudConfig::Gcp(
configs.into_iter().map(|(k, v)| (k, v.into())).collect(),
));
self
}

/// Build the [`object_store::ObjectStore`] implementation for GCP.
#[cfg(feature = "gcp")]
pub fn build_gcp(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
let options = self.gcp.as_ref();
let mut builder = GoogleCloudStorageBuilder::from_env();
if let Some(options) = options {
if let Some(options) = &self.config {
let CloudConfig::Gcp(options) = options else {
panic!("impl error: cloud type mismatch")
};
for (key, value) in options.iter() {
builder = builder.with_config(*key, value);
}
Expand All @@ -390,6 +412,23 @@ impl CloudOptions {
.map_err(to_compute_err)
}

#[cfg(feature = "http")]
pub fn build_http(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
object_store::http::HttpBuilder::new()
.with_url(url)
.with_client_options({
let mut opts = super::get_client_options();
if let Some(CloudConfig::Http { headers }) = &self.config {
opts = opts.with_default_headers(try_build_http_header_map_from_items_slice(
headers.as_slice(),
)?);
}
opts
})
.build()
.map_err(to_compute_err)
}

/// Parse a configuration from a Hashmap. This is the interface from Python.
#[allow(unused_variables)]
pub fn from_untyped_config<I: IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>>(
Expand Down Expand Up @@ -432,6 +471,36 @@ impl CloudOptions {
polars_bail!(ComputeError: "'gcp' feature is not enabled");
}
},
CloudType::Hf => {
#[cfg(feature = "http")]
{
let mut this = Self::default();

if let Ok(v) = std::env::var("HF_TOKEN") {
this.config = Some(CloudConfig::Http {
headers: vec![("Authorization".into(), format!("Bearer {}", v))],
})
}

for (i, (k, v)) in config.into_iter().enumerate() {
let (k, v) = (k.as_ref(), v.into());

if i == 0 && k == "token" {
this.config = Some(CloudConfig::Http {
headers: vec![("Authorization".into(), format!("Bearer {}", v))],
})
} else {
polars_bail!(ComputeError: "unknown configuration key: {}", k)
}
}

Ok(this)
}
#[cfg(not(feature = "http"))]
{
polars_bail!(ComputeError: "'http' feature is not enabled");
}
},
}
}
}
Expand Down
48 changes: 31 additions & 17 deletions crates/polars-io/src/path_utils/hugging_face.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
use std::collections::VecDeque;
use std::path::PathBuf;

use polars_error::{polars_bail, to_compute_err, PolarsResult};
use polars_error::{polars_bail, polars_err, to_compute_err, PolarsResult};

use crate::cloud::{extract_prefix_expansion, Matcher};
use crate::cloud::{
extract_prefix_expansion, try_build_http_header_map_from_items_slice, CloudConfig,
CloudOptions, Matcher,
};
use crate::path_utils::HiveIdxTracker;
use crate::pl_async::with_concurrency_budget;

Expand Down Expand Up @@ -198,14 +201,25 @@ impl<'a> GetPages<'a> {
pub(super) async fn expand_paths_hf(
paths: &[PathBuf],
check_directory_level: bool,
cloud_options: Option<&CloudOptions>,
) -> PolarsResult<(usize, Vec<PathBuf>)> {
assert!(!paths.is_empty());

let client = &reqwest::ClientBuilder::new()
.http1_only()
.https_only(true)
.build()
.unwrap();
let client = reqwest::ClientBuilder::new().http1_only().https_only(true);

let client = if let Some(CloudOptions {
config: Some(CloudConfig::Http { headers }),
..
}) = cloud_options
{
client.default_headers(try_build_http_header_map_from_items_slice(
headers.as_slice(),
)?)
} else {
client
};

let client = &client.build().unwrap();

let mut out_paths = vec![];
let mut stack = VecDeque::new();
Expand Down Expand Up @@ -263,26 +277,26 @@ pub(super) async fn expand_paths_hf(
client,
};

fn try_parse_api_response(bytes: &[u8]) -> PolarsResult<Vec<HFAPIResponse>> {
serde_json::from_slice::<Vec<HFAPIResponse>>(bytes).map_err(
|e| polars_err!(ComputeError: "failed to parse API response as JSON: error: {}, value: {}", e, std::str::from_utf8(bytes).unwrap()),
)
}

if let Some(matcher) = expansion_matcher {
while let Some(bytes) = gp.next().await {
let bytes = bytes?;
let bytes = bytes.as_ref();
entries.extend(
serde_json::from_slice::<Vec<HFAPIResponse>>(bytes)
.map_err(to_compute_err)?
.into_iter()
.filter(|x| {
matcher.is_matching(x.path.as_str()) && (!x.is_file() || x.size > 0)
}),
);
entries.extend(try_parse_api_response(bytes)?.into_iter().filter(|x| {
matcher.is_matching(x.path.as_str()) && (!x.is_file() || x.size > 0)
}));
}
} else {
while let Some(bytes) = gp.next().await {
let bytes = bytes?;
let bytes = bytes.as_ref();
entries.extend(
serde_json::from_slice::<Vec<HFAPIResponse>>(bytes)
.map_err(to_compute_err)?
try_parse_api_response(bytes)?
.into_iter()
.filter(|x| !x.is_file() || x.size > 0),
);
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/path_utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ pub fn expand_paths_hive(
if first_path.starts_with("hf://") {
let (expand_start_idx, paths) =
crate::pl_async::get_runtime().block_on_potential_spawn(
hugging_face::expand_paths_hf(paths, check_directory_level),
hugging_face::expand_paths_hf(paths, check_directory_level, cloud_options),
)?;

return Ok((Arc::from(paths), expand_start_idx));
Expand Down
Loading

0 comments on commit c4738d2

Please sign in to comment.