Skip to content

Commit

Permalink
fix: Fix map_elements ignoring skip_nulls=True for struct dtype (#…
Browse files (browse the repository at this point in the history)
  • Loading branch information
lukemanley authored Jan 12, 2025
1 parent 87baf86 commit a433272
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 27 deletions.
8 changes: 8 additions & 0 deletions crates/polars-python/src/map/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod series;
use std::collections::BTreeMap;

use polars::chunked_array::builder::get_list_builder;
use polars::export::arrow::bitmap::MutableBitmap;
use polars::prelude::*;
use polars_core::export::rayon::prelude::*;
use polars_core::utils::CustomIterTools;
Expand Down Expand Up @@ -72,14 +73,20 @@ fn iterator_to_struct<'a>(
struct_fields.insert(fld.name().clone(), buf);
}

let mut validity = MutableBitmap::with_capacity(capacity);
validity.extend_constant(init_null_count, false);
validity.push(true);

for dict in it {
match dict? {
None => {
validity.push(false);
for field_items in struct_fields.values_mut() {
field_items.push(AnyValue::Null);
}
},
Some(dict) => {
validity.push(true);
let dict = dict.downcast::<PyDict>()?;
let current_len = struct_fields
.values()
Expand Down Expand Up @@ -129,6 +136,7 @@ fn iterator_to_struct<'a>(
Ok(
StructChunked::from_series(name, fields[0].len(), fields.iter())
.unwrap()
.with_outer_validity(Some(validity.freeze()))
.into_series()
.into(),
)
Expand Down
50 changes: 37 additions & 13 deletions crates/polars-python/src/map/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2357,16 +2357,22 @@ impl<'a> ApplyLambda<'a> for StructChunked {
let mut null_count = 0;

for val in iter_struct(self) {
let out = lambda.call1((Wrap(val),))?;
if out.is_none() {
null_count += 1;
continue;
match val {
AnyValue::Null => null_count += 1,
_ => {
let out = lambda.call1((Wrap(val),))?;
if out.is_none() {
null_count += 1;
continue;
}
return infer_and_finish(self, py, lambda, &out, null_count);
},
}
return infer_and_finish(self, py, lambda, &out, null_count);
}

// todo! full null
Ok(self.clone().into_series().into())
Ok(Self::full_null(self.name().clone(), self.len())
.into_series()
.into())
}

fn apply_into_struct(
Expand All @@ -2379,7 +2385,10 @@ impl<'a> ApplyLambda<'a> for StructChunked {
let skip = 1;
let it = iter_struct(self)
.skip(init_null_count + skip)
.map(|val| lambda.call1((Wrap(val),)).map(Some));
.map(|val| match val {
AnyValue::Null => Ok(None),
_ => lambda.call1((Wrap(val),)).map(Some),
});
iterator_to_struct(
py,
it,
Expand All @@ -2404,7 +2413,10 @@ impl<'a> ApplyLambda<'a> for StructChunked {
let skip = usize::from(first_value.is_some());
let it = iter_struct(self)
.skip(init_null_count + skip)
.map(|val| call_lambda_and_extract(py, lambda, Wrap(val)));
.map(|val| match val {
AnyValue::Null => Ok(None),
_ => call_lambda_and_extract(py, lambda, Wrap(val)),
});

iterator_to_primitive(
it,
Expand All @@ -2425,7 +2437,10 @@ impl<'a> ApplyLambda<'a> for StructChunked {
let skip = usize::from(first_value.is_some());
let it = iter_struct(self)
.skip(init_null_count + skip)
.map(|val| call_lambda_and_extract(py, lambda, Wrap(val)));
.map(|val| match val {
AnyValue::Null => Ok(None),
_ => call_lambda_and_extract(py, lambda, Wrap(val)),
});

iterator_to_bool(
it,
Expand All @@ -2446,7 +2461,10 @@ impl<'a> ApplyLambda<'a> for StructChunked {
let skip = usize::from(first_value.is_some());
let it = iter_struct(self)
.skip(init_null_count + skip)
.map(|val| call_lambda_and_extract(py, lambda, Wrap(val)));
.map(|val| match val {
AnyValue::Null => Ok(None),
_ => call_lambda_and_extract(py, lambda, Wrap(val)),
});

iterator_to_string(
it,
Expand All @@ -2468,7 +2486,10 @@ impl<'a> ApplyLambda<'a> for StructChunked {
let lambda = lambda.bind(py);
let it = iter_struct(self)
.skip(init_null_count + skip)
.map(|val| call_lambda_series_out(py, lambda, Wrap(val)).map(Some));
.map(|val| match val {
AnyValue::Null => Ok(None),
_ => call_lambda_series_out(py, lambda, Wrap(val)).map(Some),
});
iterator_to_list(
dt,
it,
Expand Down Expand Up @@ -2513,7 +2534,10 @@ impl<'a> ApplyLambda<'a> for StructChunked {
let skip = usize::from(first_value.is_some());
let it = iter_struct(self)
.skip(init_null_count + skip)
.map(|val| call_lambda_and_extract(py, lambda, Wrap(val)));
.map(|val| match val {
AnyValue::Null => Ok(None),
_ => call_lambda_and_extract(py, lambda, Wrap(val)),
});

iterator_to_object(
it,
Expand Down
36 changes: 22 additions & 14 deletions py-polars/tests/unit/operations/map/test_map_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
from datetime import date, datetime, timedelta
from typing import Any

import numpy as np
import pytest
Expand Down Expand Up @@ -49,20 +50,28 @@ def test_map_elements_arithmetic_consistency() -> None:

def test_map_elements_struct() -> None:
df = pl.DataFrame(
{"A": ["a", "a"], "B": [2, 3], "C": [True, False], "D": [12.0, None]}
{
"A": ["a", "a", None],
"B": [2, 3, None],
"C": [True, False, None],
"D": [12.0, None, None],
"E": [None, [1], [2, 3]],
}
)
out = df.with_columns(pl.struct(df.columns).alias("struct")).select(
pl.col("struct").map_elements(lambda x: x["A"]).alias("A_field"),
pl.col("struct").map_elements(lambda x: x["B"]).alias("B_field"),
pl.col("struct").map_elements(lambda x: x["C"]).alias("C_field"),
pl.col("struct").map_elements(lambda x: x["D"]).alias("D_field"),
pl.col("struct").map_elements(lambda x: x["E"]).alias("E_field"),
)
expected = pl.DataFrame(
{
"A_field": ["a", "a"],
"B_field": [2, 3],
"C_field": [True, False],
"D_field": [12.0, None],
"A_field": ["a", "a", None],
"B_field": [2, 3, None],
"C_field": [True, False, None],
"D_field": [12.0, None, None],
"E_field": [None, [1], [2, 3]],
}
)

Expand Down Expand Up @@ -171,17 +180,16 @@ def test_empty_list_in_map_elements() -> None:
).to_dict(as_series=False) == {"a": [[], [1, 2], [], [5]]}


def test_map_elements_skip_nulls() -> None:
some_map = {None: "a", 1: "b"}
s = pl.Series([None, 1])
@pytest.mark.parametrize("value", [1, True, "abc", [1, 2], {"a": 1}])
@pytest.mark.parametrize("return_value", [1, True, "abc", [1, 2], {"a": 1}])
def test_map_elements_skip_nulls(value: Any, return_value: Any) -> None:
s = pl.Series([value, None])

assert s.map_elements(
lambda x: some_map[x], return_dtype=pl.String, skip_nulls=True
).to_list() == [None, "b"]
result = s.map_elements(lambda x: return_value, skip_nulls=True).to_list()
assert result == [return_value, None]

assert s.map_elements(
lambda x: some_map[x], return_dtype=pl.String, skip_nulls=False
).to_list() == ["a", "b"]
result = s.map_elements(lambda x: return_value, skip_nulls=False).to_list()
assert result == [return_value, return_value]


def test_map_elements_object_dtypes() -> None:
Expand Down

0 comments on commit a433272

Please sign in to comment.