Skip to content

Commit

Permalink
refactor(rust): Make all functions in binary namespace non-anonymous (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Oct 31, 2023
1 parent 5689ad5 commit d4c56a2
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 48 deletions.
26 changes: 26 additions & 0 deletions crates/polars-plan/src/dsl/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,30 @@ impl BinaryNameSpace {
true,
)
}

/// Build an expression that hex-decodes the values of this binary expression.
///
/// `strict` is stored in [`BinaryFunction::HexDecode`] and handed to the
/// decode kernel when the expression is evaluated.
/// NOTE(review): presumably `strict` decides whether undecodable input is an
/// error — confirm against the kernel.
#[cfg(feature = "binary_encoding")]
pub fn hex_decode(self, strict: bool) -> Expr {
    let function = FunctionExpr::BinaryExpr(BinaryFunction::HexDecode(strict));
    self.0.map_private(function)
}

/// Build an expression that hex-encodes the values of this binary expression.
///
/// The call is represented as the named [`BinaryFunction::HexEncode`] plan
/// node.
#[cfg(feature = "binary_encoding")]
pub fn hex_encode(self) -> Expr {
    let function = FunctionExpr::BinaryExpr(BinaryFunction::HexEncode);
    self.0.map_private(function)
}

/// Build an expression that base64-decodes the values of this binary
/// expression.
///
/// `strict` is stored in [`BinaryFunction::Base64Decode`] and handed to the
/// decode kernel when the expression is evaluated.
/// NOTE(review): presumably `strict` decides whether undecodable input is an
/// error — confirm against the kernel.
#[cfg(feature = "binary_encoding")]
pub fn base64_decode(self, strict: bool) -> Expr {
    let function = FunctionExpr::BinaryExpr(BinaryFunction::Base64Decode(strict));
    self.0.map_private(function)
}

/// Build an expression that base64-encodes the values of this binary
/// expression.
///
/// The call is represented as the named [`BinaryFunction::Base64Encode`] plan
/// node.
#[cfg(feature = "binary_encoding")]
pub fn base64_encode(self) -> Expr {
    let function = FunctionExpr::BinaryExpr(BinaryFunction::Base64Encode);
    self.0.map_private(function)
}
}
55 changes: 55 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,28 @@ pub enum BinaryFunction {
Contains,
StartsWith,
EndsWith,
#[cfg(feature = "binary_encoding")]
HexDecode(bool),
#[cfg(feature = "binary_encoding")]
HexEncode,
#[cfg(feature = "binary_encoding")]
Base64Decode(bool),
#[cfg(feature = "binary_encoding")]
Base64Encode,
}

impl BinaryFunction {
pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
use BinaryFunction::*;
match self {
Contains { .. } => mapper.with_dtype(DataType::Boolean),
EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean),
#[cfg(feature = "binary_encoding")]
HexDecode(_) | Base64Decode(_) => mapper.with_same_dtype(),
#[cfg(feature = "binary_encoding")]
HexEncode | Base64Encode => mapper.with_dtype(DataType::Utf8),
}
}
}

impl Display for BinaryFunction {
Expand All @@ -18,6 +40,14 @@ impl Display for BinaryFunction {
Contains { .. } => "contains",
StartsWith => "starts_with",
EndsWith => "ends_with",
#[cfg(feature = "binary_encoding")]
HexDecode(_) => "hex_decode",
#[cfg(feature = "binary_encoding")]
HexEncode => "hex_encode",
#[cfg(feature = "binary_encoding")]
Base64Decode(_) => "base64_decode",
#[cfg(feature = "binary_encoding")]
Base64Encode => "base64_encode",
};
write!(f, "bin.{s}")
}
Expand All @@ -38,6 +68,7 @@ pub(super) fn ends_with(s: &[Series]) -> PolarsResult<Series> {
.with_name(ca.name())
.into_series())
}

pub(super) fn starts_with(s: &[Series]) -> PolarsResult<Series> {
let ca = s[0].binary()?;
let prefix = s[1].binary()?;
Expand All @@ -48,6 +79,30 @@ pub(super) fn starts_with(s: &[Series]) -> PolarsResult<Series> {
.into_series())
}

/// Hex-decode the binary series `s`, forwarding `strict` to the kernel.
///
/// Any error from `s.binary()` or from the decode itself is returned to the
/// caller.
#[cfg(feature = "binary_encoding")]
pub(super) fn hex_decode(s: &Series, strict: bool) -> PolarsResult<Series> {
    let decoded = s.binary()?.hex_decode(strict)?;
    Ok(decoded.into_series())
}

/// Hex-encode the binary series `s`.
///
/// Any error from `s.binary()` is returned to the caller.
#[cfg(feature = "binary_encoding")]
pub(super) fn hex_encode(s: &Series) -> PolarsResult<Series> {
    s.binary().map(|bytes| bytes.hex_encode())
}

/// Base64-decode the binary series `s`, forwarding `strict` to the kernel.
///
/// Any error from `s.binary()` or from the decode itself is returned to the
/// caller.
#[cfg(feature = "binary_encoding")]
pub(super) fn base64_decode(s: &Series, strict: bool) -> PolarsResult<Series> {
    let decoded = s.binary()?.base64_decode(strict)?;
    Ok(decoded.into_series())
}

/// Base64-encode the binary series `s`.
///
/// Any error from `s.binary()` is returned to the caller.
#[cfg(feature = "binary_encoding")]
pub(super) fn base64_encode(s: &Series) -> PolarsResult<Series> {
    s.binary().map(|bytes| bytes.base64_encode())
}

impl From<BinaryFunction> for FunctionExpr {
fn from(b: BinaryFunction) -> Self {
FunctionExpr::BinaryExpr(b)
Expand Down
8 changes: 8 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,14 @@ impl From<BinaryFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
StartsWith => {
map_as_slice!(binary::starts_with)
},
#[cfg(feature = "binary_encoding")]
HexDecode(strict) => map!(binary::hex_decode, strict),
#[cfg(feature = "binary_encoding")]
HexEncode => map!(binary::hex_encode),
#[cfg(feature = "binary_encoding")]
Base64Decode(strict) => map!(binary::base64_decode, strict),
#[cfg(feature = "binary_encoding")]
Base64Encode => map!(binary::base64_encode),
}
}
}
Expand Down
7 changes: 1 addition & 6 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,7 @@ impl FunctionExpr {
SearchSorted(_) => mapper.with_dtype(IDX_DTYPE),
#[cfg(feature = "strings")]
StringExpr(s) => s.get_field(mapper),
BinaryExpr(s) => {
use BinaryFunction::*;
match s {
Contains { .. } | EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean),
}
},
BinaryExpr(s) => s.get_field(mapper),
#[cfg(feature = "temporal")]
TemporalExpr(fun) => fun.get_field(mapper),
#[cfg(feature = "range")]
Expand Down
2 changes: 1 addition & 1 deletion crates/polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ extract_jsonpath = [
"polars-lazy?/extract_jsonpath",
]
string_encoding = ["polars-ops/string_encoding", "polars-core/strings"]
binary_encoding = ["polars-ops/binary_encoding"]
binary_encoding = ["polars-ops/binary_encoding", "polars-lazy?/binary_encoding"]
group_by_list = ["polars-core/group_by_list", "polars-ops/group_by_list"]
lazy_regex = ["polars-lazy?/regex"]
cum_agg = ["polars-ops/cum_agg", "polars-lazy?/cum_agg"]
Expand Down
45 changes: 4 additions & 41 deletions py-polars/src/expr/binary.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use polars::prelude::*;
use pyo3::prelude::*;

use crate::PyExpr;
Expand All @@ -23,57 +22,21 @@ impl PyExpr {

#[cfg(feature = "binary_encoding")]
fn bin_hex_decode(&self, strict: bool) -> Self {
self.inner
.clone()
.map(
move |s| {
s.binary()?
.hex_decode(strict)
.map(|s| Some(s.into_series()))
},
GetOutput::same_type(),
)
.with_fmt("bin.hex_decode")
.into()
self.inner.clone().binary().hex_decode(strict).into()
}

#[cfg(feature = "binary_encoding")]
fn bin_base64_decode(&self, strict: bool) -> Self {
self.inner
.clone()
.map(
move |s| {
s.binary()?
.base64_decode(strict)
.map(|s| Some(s.into_series()))
},
GetOutput::same_type(),
)
.with_fmt("bin.base64_decode")
.into()
self.inner.clone().binary().base64_decode(strict).into()
}

#[cfg(feature = "binary_encoding")]
fn bin_hex_encode(&self) -> Self {
self.inner
.clone()
.map(
move |s| s.binary().map(|s| Some(s.hex_encode().into_series())),
GetOutput::from_type(DataType::Utf8),
)
.with_fmt("bin.hex_encode")
.into()
self.inner.clone().binary().hex_encode().into()
}

#[cfg(feature = "binary_encoding")]
fn bin_base64_encode(&self) -> Self {
self.inner
.clone()
.map(
move |s| s.binary().map(|s| Some(s.base64_encode().into_series())),
GetOutput::from_type(DataType::Utf8),
)
.with_fmt("bin.base64_encode")
.into()
self.inner.clone().binary().base64_encode().into()
}
}

0 comments on commit d4c56a2

Please sign in to comment.