Skip to content

Commit

Permalink
fix dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 5, 2023
1 parent bcfe20a commit 6d5aadf
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 50 deletions.
31 changes: 14 additions & 17 deletions crates/polars-ops/src/chunked_array/strings/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::*;
fn extract_groups_array(
arr: &Utf8Array<i64>,
reg: &Regex,
names: &[String],
names: &[&str],
data_type: ArrowDataType,
) -> PolarsResult<ArrayRef> {
let mut builders = (0..names.len())
Expand Down Expand Up @@ -42,7 +42,11 @@ fn extract_groups_array(
Ok(StructArray::new(data_type.clone(), values, None).boxed())
}

pub(super) fn extract_groups(ca: &Utf8Chunked, pat: &str) -> PolarsResult<Series> {
pub(super) fn extract_groups(
ca: &Utf8Chunked,
pat: &str,
dtype: &DataType,
) -> PolarsResult<Series> {
let reg = Regex::new(pat)?;
let n_fields = reg.captures_len();

Expand All @@ -51,22 +55,15 @@ pub(super) fn extract_groups(ca: &Utf8Chunked, pat: &str) -> PolarsResult<Series
.map(|ca| ca.into_series());
}

let names = reg
.capture_names()
.enumerate()
.skip(1)
.map(|(idx, opt_name)| {
opt_name
.map(|name| name.to_string())
.unwrap_or_else(|| format!("{idx}"))
})
let data_type = dtype.to_arrow();
// It is an implementation error if `dtype` isn't a struct.
let DataType::Struct(fields) = dtype else {
unreachable!()
};
let names = fields
.iter()
.map(|fld| fld.name.as_str())
.collect::<Vec<_>>();
let data_type = ArrowDataType::Struct(
names
.iter()
.map(|name| ArrowField::new(name.as_str(), ArrowDataType::LargeUtf8, true))
.collect(),
);

let chunks = ca
.downcast_iter()
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,9 @@ pub trait Utf8NameSpaceImpl: AsUtf8 {

#[cfg(feature = "extract_groups")]
/// Extract all capture groups from pattern and return as a struct
fn extract_groups(&self, pat: &str) -> PolarsResult<Series> {
fn extract_groups(&self, pat: &str, dtype: &DataType) -> PolarsResult<Series> {
let ca = self.as_utf8();
super::extract::extract_groups(ca, pat)
super::extract::extract_groups(ca, pat, dtype)
}

/// Count all successive non-overlapping regex matches.
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -646,8 +646,8 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
map_as_slice!(strings::extract_all)
}
#[cfg(feature = "extract_groups")]
ExtractGroups { pat } => {
map!(strings::extract_groups, &pat)
ExtractGroups { pat, dtype } => {
map!(strings::extract_groups, &pat, &dtype)
}
NChars => map!(strings::n_chars),
Length => map!(strings::lengths),
Expand Down
29 changes: 4 additions & 25 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub enum StringFunction {
ExtractAll,
#[cfg(feature = "extract_groups")]
ExtractGroups {
dtype: DataType,
pat: String,
},
#[cfg(feature = "string_from_radix")]
Expand Down Expand Up @@ -95,27 +96,7 @@ impl StringFunction {
Extract { .. } => mapper.with_same_dtype(),
ExtractAll => mapper.with_dtype(DataType::List(Box::new(DataType::Utf8))),
#[cfg(feature = "extract_groups")]
ExtractGroups { pat } => {
let reg = Regex::new(pat)?;
let names = reg
.capture_names()
.enumerate()
.skip(1)
.map(|(idx, opt_name)| {
opt_name
.map(|name| name.to_string())
.unwrap_or_else(|| format!("{idx}"))
})
.collect::<Vec<_>>();

let data_type = DataType::Struct(
names
.iter()
.map(|name| Field::new(name.as_str(), DataType::Utf8))
.collect(),
);
mapper.with_dtype(data_type)
}
ExtractGroups { dtype, .. } => mapper.with_dtype(dtype.clone()),
#[cfg(feature = "string_from_radix")]
FromRadix { .. } => mapper.with_dtype(DataType::Int32),
#[cfg(feature = "extract_jsonpath")]
Expand Down Expand Up @@ -322,11 +303,9 @@ pub(super) fn extract(s: &Series, pat: &str, group_index: usize) -> PolarsResult

#[cfg(feature = "extract_groups")]
/// Extract all capture groups from a regex pattern as a struct
pub(super) fn extract_groups(s: &Series, pat: &str) -> PolarsResult<Series> {
let pat = pat.to_string();

pub(super) fn extract_groups(s: &Series, pat: &str, dtype: &DataType) -> PolarsResult<Series> {
let ca = s.utf8()?;
ca.extract_groups(&pat)
ca.extract_groups(pat, dtype)
}

#[cfg(feature = "string_justify")]
Expand Down
33 changes: 29 additions & 4 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,35 @@ impl StringNameSpace {

#[cfg(feature = "extract_groups")]
// Extract all capture groups from a regex pattern as a struct
pub fn extract_groups(self, pat: &str) -> Expr {
let pat = pat.to_string();
self.0
.map_private(StringFunction::ExtractGroups { pat }.into())
pub fn extract_groups(self, pat: &str) -> PolarsResult<Expr> {
// The regex will be compiled twice, because `Regex` doesn't support serde
// and we need to compile it here to determine the output datatype.
let reg = regex::Regex::new(pat)?;
let names = reg
.capture_names()
.enumerate()
.skip(1)
.map(|(idx, opt_name)| {
opt_name
.map(|name| name.to_string())
.unwrap_or_else(|| format!("{idx}"))
})
.collect::<Vec<_>>();

let dtype = DataType::Struct(
names
.iter()
.map(|name| Field::new(name.as_str(), DataType::Utf8))
.collect(),
);

Ok(self.0.map_private(
StringFunction::ExtractGroups {
dtype,
pat: pat.to_string(),
}
.into(),
))
}

/// Return a copy of the string left filled with ASCII '0' digits to make a string of length width.
Expand Down

0 comments on commit 6d5aadf

Please sign in to comment.