diff --git a/crates/polars-python/src/map/mod.rs b/crates/polars-python/src/map/mod.rs index 4e06af2f2225..75a5eba3b754 100644 --- a/crates/polars-python/src/map/mod.rs +++ b/crates/polars-python/src/map/mod.rs @@ -5,6 +5,7 @@ pub mod series; use std::collections::BTreeMap; use polars::chunked_array::builder::get_list_builder; +use polars::export::arrow::bitmap::MutableBitmap; use polars::prelude::*; use polars_core::export::rayon::prelude::*; use polars_core::utils::CustomIterTools; @@ -72,14 +73,20 @@ fn iterator_to_struct<'a>( struct_fields.insert(fld.name().clone(), buf); } + let mut validity = MutableBitmap::with_capacity(capacity); + validity.extend_constant(init_null_count, false); + validity.push(true); + for dict in it { match dict? { None => { + validity.push(false); for field_items in struct_fields.values_mut() { field_items.push(AnyValue::Null); } }, Some(dict) => { + validity.push(true); let dict = dict.downcast::<PyDict>()?; let current_len = struct_fields .values() @@ -129,6 +136,7 @@ fn iterator_to_struct<'a>( Ok( StructChunked::from_series(name, fields[0].len(), fields.iter()) .unwrap() + .with_outer_validity(Some(validity.freeze())) .into_series() .into(), ) diff --git a/crates/polars-python/src/map/series.rs b/crates/polars-python/src/map/series.rs index ab53f2d43d27..baec9cb43dcd 100644 --- a/crates/polars-python/src/map/series.rs +++ b/crates/polars-python/src/map/series.rs @@ -2357,16 +2357,22 @@ impl<'a> ApplyLambda<'a> for StructChunked { let mut null_count = 0; for val in iter_struct(self) { - let out = lambda.call1((Wrap(val),))?; - if out.is_none() { - null_count += 1; - continue; + match val { + AnyValue::Null => null_count += 1, + _ => { + let out = lambda.call1((Wrap(val),))?; + if out.is_none() { + null_count += 1; + continue; + } + return infer_and_finish(self, py, lambda, &out, null_count); + }, } - return infer_and_finish(self, py, lambda, &out, null_count); } - // todo! 
full null - Ok(self.clone().into_series().into()) + Ok(Self::full_null(self.name().clone(), self.len()) + .into_series() + .into()) } fn apply_into_struct( @@ -2379,7 +2385,10 @@ impl<'a> ApplyLambda<'a> for StructChunked { let skip = 1; let it = iter_struct(self) .skip(init_null_count + skip) - .map(|val| lambda.call1((Wrap(val),)).map(Some)); + .map(|val| match val { + AnyValue::Null => Ok(None), + _ => lambda.call1((Wrap(val),)).map(Some), + }); iterator_to_struct( py, it, @@ -2404,7 +2413,10 @@ impl<'a> ApplyLambda<'a> for StructChunked { let skip = usize::from(first_value.is_some()); let it = iter_struct(self) .skip(init_null_count + skip) - .map(|val| call_lambda_and_extract(py, lambda, Wrap(val))); + .map(|val| match val { + AnyValue::Null => Ok(None), + _ => call_lambda_and_extract(py, lambda, Wrap(val)), + }); iterator_to_primitive( it, @@ -2425,7 +2437,10 @@ impl<'a> ApplyLambda<'a> for StructChunked { let skip = usize::from(first_value.is_some()); let it = iter_struct(self) .skip(init_null_count + skip) - .map(|val| call_lambda_and_extract(py, lambda, Wrap(val))); + .map(|val| match val { + AnyValue::Null => Ok(None), + _ => call_lambda_and_extract(py, lambda, Wrap(val)), + }); iterator_to_bool( it, @@ -2446,7 +2461,10 @@ impl<'a> ApplyLambda<'a> for StructChunked { let skip = usize::from(first_value.is_some()); let it = iter_struct(self) .skip(init_null_count + skip) - .map(|val| call_lambda_and_extract(py, lambda, Wrap(val))); + .map(|val| match val { + AnyValue::Null => Ok(None), + _ => call_lambda_and_extract(py, lambda, Wrap(val)), + }); iterator_to_string( it, @@ -2468,7 +2486,10 @@ impl<'a> ApplyLambda<'a> for StructChunked { let lambda = lambda.bind(py); let it = iter_struct(self) .skip(init_null_count + skip) - .map(|val| call_lambda_series_out(py, lambda, Wrap(val)).map(Some)); + .map(|val| match val { + AnyValue::Null => Ok(None), + _ => call_lambda_series_out(py, lambda, Wrap(val)).map(Some), + }); iterator_to_list( dt, it, @@ -2513,7 
+2534,10 @@ impl<'a> ApplyLambda<'a> for StructChunked { let skip = usize::from(first_value.is_some()); let it = iter_struct(self) .skip(init_null_count + skip) - .map(|val| call_lambda_and_extract(py, lambda, Wrap(val))); + .map(|val| match val { + AnyValue::Null => Ok(None), + _ => call_lambda_and_extract(py, lambda, Wrap(val)), + }); iterator_to_object( it, diff --git a/py-polars/tests/unit/operations/map/test_map_elements.py b/py-polars/tests/unit/operations/map/test_map_elements.py index 91cfe13d8bc9..a0634b1ddb4f 100644 --- a/py-polars/tests/unit/operations/map/test_map_elements.py +++ b/py-polars/tests/unit/operations/map/test_map_elements.py @@ -2,6 +2,7 @@ import json from datetime import date, datetime, timedelta +from typing import Any import numpy as np import pytest @@ -49,20 +50,28 @@ def test_map_elements_arithmetic_consistency() -> None: def test_map_elements_struct() -> None: df = pl.DataFrame( - {"A": ["a", "a"], "B": [2, 3], "C": [True, False], "D": [12.0, None]} + { + "A": ["a", "a", None], + "B": [2, 3, None], + "C": [True, False, None], + "D": [12.0, None, None], + "E": [None, [1], [2, 3]], + } ) out = df.with_columns(pl.struct(df.columns).alias("struct")).select( pl.col("struct").map_elements(lambda x: x["A"]).alias("A_field"), pl.col("struct").map_elements(lambda x: x["B"]).alias("B_field"), pl.col("struct").map_elements(lambda x: x["C"]).alias("C_field"), pl.col("struct").map_elements(lambda x: x["D"]).alias("D_field"), + pl.col("struct").map_elements(lambda x: x["E"]).alias("E_field"), ) expected = pl.DataFrame( { - "A_field": ["a", "a"], - "B_field": [2, 3], - "C_field": [True, False], - "D_field": [12.0, None], + "A_field": ["a", "a", None], + "B_field": [2, 3, None], + "C_field": [True, False, None], + "D_field": [12.0, None, None], + "E_field": [None, [1], [2, 3]], } ) @@ -171,17 +180,16 @@ def test_empty_list_in_map_elements() -> None: ).to_dict(as_series=False) == {"a": [[], [1, 2], [], [5]]} -def test_map_elements_skip_nulls() -> 
None: - some_map = {None: "a", 1: "b"} - s = pl.Series([None, 1]) +@pytest.mark.parametrize("value", [1, True, "abc", [1, 2], {"a": 1}]) +@pytest.mark.parametrize("return_value", [1, True, "abc", [1, 2], {"a": 1}]) +def test_map_elements_skip_nulls(value: Any, return_value: Any) -> None: + s = pl.Series([value, None]) - assert s.map_elements( - lambda x: some_map[x], return_dtype=pl.String, skip_nulls=True - ).to_list() == [None, "b"] + result = s.map_elements(lambda x: return_value, skip_nulls=True).to_list() + assert result == [return_value, None] - assert s.map_elements( - lambda x: some_map[x], return_dtype=pl.String, skip_nulls=False - ).to_list() == ["a", "b"] + result = s.map_elements(lambda x: return_value, skip_nulls=False).to_list() + assert result == [return_value, return_value] def test_map_elements_object_dtypes() -> None: