From 17e997534cc1e6fc0de6164527ada7b7b8951c70 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Mon, 18 Sep 2023 01:24:44 +0800 Subject: [PATCH] fix: Make pl.struct serializable --- .../src/dsl/function_expr/coerce.rs | 6 +++++ .../polars-plan/src/dsl/function_expr/mod.rs | 9 ++++++++ .../src/dsl/function_expr/schema.rs | 5 ++++ .../polars-plan/src/dsl/functions/coerce.rs | 23 +++++++++---------- py-polars/src/functions/lazy.rs | 2 +- py-polars/tests/unit/test_serde.py | 6 +++++ 6 files changed, 38 insertions(+), 13 deletions(-) create mode 100644 crates/polars-plan/src/dsl/function_expr/coerce.rs diff --git a/crates/polars-plan/src/dsl/function_expr/coerce.rs b/crates/polars-plan/src/dsl/function_expr/coerce.rs new file mode 100644 index 000000000000..00c180d0ba4a --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/coerce.rs @@ -0,0 +1,6 @@ +use polars_core::prelude::*; + +#[cfg(feature = "dtype-struct")] +pub fn as_struct(s: &[Series]) -> PolarsResult { + Ok(StructChunked::new(s[0].name(), s)?.into_series()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 9ffb7efd2a45..ec8ef7927d4d 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -11,6 +11,7 @@ mod bounds; mod cat; #[cfg(feature = "round_series")] mod clip; +mod coerce; mod concat; mod correlation; mod cum; @@ -142,6 +143,8 @@ pub enum FunctionExpr { ArrayExpr(ArrayFunction), #[cfg(feature = "dtype-struct")] StructExpr(StructFunction), + #[cfg(feature = "dtype-struct")] + AsStruct, #[cfg(feature = "top_k")] TopK { k: usize, @@ -318,6 +321,8 @@ impl Display for FunctionExpr { ListExpr(func) => return write!(f, "{func}"), #[cfg(feature = "dtype-struct")] StructExpr(func) => return write!(f, "{func}"), + #[cfg(feature = "dtype-struct")] + AsStruct => "as_struct", #[cfg(feature = "top_k")] TopK { .. } => "top_k", Shift(_) => "shift", @@ -571,6 +576,10 @@ impl From for SpecialEq> { FieldByName(name) => map!(struct_::get_by_name, name.clone()), } }, + #[cfg(feature = "dtype-struct")] + AsStruct => { + map_as_slice!(coerce::as_struct) + }, #[cfg(feature = "top_k")] TopK { k, descending } => { map!(top_k, k, descending) diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 6080ad3a406a..47576431ff20 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -131,6 +131,11 @@ impl FunctionExpr { } }, #[cfg(feature = "dtype-struct")] + AsStruct => Ok(Field::new( + fields[0].name(), + DataType::Struct(fields.to_vec()), + )), + #[cfg(feature = "dtype-struct")] StructExpr(s) => { use polars_core::utils::slice_offsets; use StructFunction::*; diff --git a/crates/polars-plan/src/dsl/functions/coerce.rs b/crates/polars-plan/src/dsl/functions/coerce.rs index e28a1697eefd..e009e9e61918 100644 --- a/crates/polars-plan/src/dsl/functions/coerce.rs +++ b/crates/polars-plan/src/dsl/functions/coerce.rs @@ -3,16 +3,15 @@ use super::*; /// Take several expressions and collect them into a [`StructChunked`]. #[cfg(feature = "dtype-struct")] -pub fn as_struct(exprs: &[Expr]) -> Expr { - map_multiple( - |s| StructChunked::new(s[0].name(), s).map(|ca| Some(ca.into_series())), - exprs, - GetOutput::map_fields(|fld| Field::new(fld[0].name(), DataType::Struct(fld.to_vec()))), - ) - .with_function_options(|mut options| { - options.input_wildcard_expansion = true; - options.fmt_str = "as_struct"; - options.pass_name_to_apply = true; - options - }) +pub fn as_struct(exprs: Vec) -> Expr { + Expr::Function { + input: exprs, + function: FunctionExpr::AsStruct, + options: FunctionOptions { + input_wildcard_expansion: true, + pass_name_to_apply: true, + collect_groups: ApplyOptions::ApplyFlat, + ..Default::default() + }, + } } diff --git a/py-polars/src/functions/lazy.rs b/py-polars/src/functions/lazy.rs index 0fb9243e0896..a5ea89da0c80 100644 --- a/py-polars/src/functions/lazy.rs +++ b/py-polars/src/functions/lazy.rs @@ -71,7 +71,7 @@ pub fn arg_where(condition: PyExpr) -> PyExpr { #[pyfunction] pub fn as_struct(exprs: Vec) -> PyExpr { let exprs = exprs.to_exprs(); - dsl::as_struct(&exprs).into() + dsl::as_struct(exprs).into() } #[pyfunction] diff --git a/py-polars/tests/unit/test_serde.py b/py-polars/tests/unit/test_serde.py index 820fafaf1abd..22293206fb91 100644 --- a/py-polars/tests/unit/test_serde.py +++ b/py-polars/tests/unit/test_serde.py @@ -17,6 +17,12 @@ def test_pickling_simple_expression() -> None: assert str(pickle.loads(buf)) == str(e) +def test_pickling_as_struct_11100() -> None: + e = pl.struct("a") + buf = pickle.dumps(e) + assert str(pickle.loads(buf)) == str(e) + + def test_lazyframe_serde() -> None: lf = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}).lazy().select(pl.col("a"))