Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix map_elements ignoring skip_nulls=True for struct dtype #20668

Merged
4 commits merged on Jan 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions crates/polars-python/src/map/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod series;
use std::collections::BTreeMap;

use polars::chunked_array::builder::get_list_builder;
use polars::export::arrow::bitmap::MutableBitmap;
use polars::prelude::*;
use polars_core::export::rayon::prelude::*;
use polars_core::utils::CustomIterTools;
Expand Down Expand Up @@ -72,14 +73,20 @@ fn iterator_to_struct<'a>(
struct_fields.insert(fld.name().clone(), buf);
}

let mut validity = MutableBitmap::with_capacity(capacity);
validity.extend_constant(init_null_count, false);
validity.push(true);

for dict in it {
match dict? {
None => {
validity.push(false);
for field_items in struct_fields.values_mut() {
field_items.push(AnyValue::Null);
}
},
Some(dict) => {
validity.push(true);
let dict = dict.downcast::<PyDict>()?;
let current_len = struct_fields
.values()
Expand Down Expand Up @@ -129,6 +136,7 @@ fn iterator_to_struct<'a>(
Ok(
StructChunked::from_series(name, fields[0].len(), fields.iter())
.unwrap()
.with_outer_validity(Some(validity.freeze()))
.into_series()
.into(),
)
Expand Down
50 changes: 37 additions & 13 deletions crates/polars-python/src/map/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2357,16 +2357,22 @@ impl<'a> ApplyLambda<'a> for StructChunked {
let mut null_count = 0;

for val in iter_struct(self) {
let out = lambda.call1((Wrap(val),))?;
if out.is_none() {
null_count += 1;
continue;
match val {
AnyValue::Null => null_count += 1,
_ => {
let out = lambda.call1((Wrap(val),))?;
if out.is_none() {
null_count += 1;
continue;
}
return infer_and_finish(self, py, lambda, &out, null_count);
},
}
return infer_and_finish(self, py, lambda, &out, null_count);
}

// todo! full null
Ok(self.clone().into_series().into())
Ok(Self::full_null(self.name().clone(), self.len())
.into_series()
.into())
}

fn apply_into_struct(
Expand All @@ -2379,7 +2385,10 @@ impl<'a> ApplyLambda<'a> for StructChunked {
let skip = 1;
let it = iter_struct(self)
.skip(init_null_count + skip)
.map(|val| lambda.call1((Wrap(val),)).map(Some));
.map(|val| match val {
AnyValue::Null => Ok(None),
_ => lambda.call1((Wrap(val),)).map(Some),
});
iterator_to_struct(
py,
it,
Expand All @@ -2404,7 +2413,10 @@ impl<'a> ApplyLambda<'a> for StructChunked {
let skip = usize::from(first_value.is_some());
let it = iter_struct(self)
.skip(init_null_count + skip)
.map(|val| call_lambda_and_extract(py, lambda, Wrap(val)));
.map(|val| match val {
AnyValue::Null => Ok(None),
_ => call_lambda_and_extract(py, lambda, Wrap(val)),
});

iterator_to_primitive(
it,
Expand All @@ -2425,7 +2437,10 @@ impl<'a> ApplyLambda<'a> for StructChunked {
let skip = usize::from(first_value.is_some());
let it = iter_struct(self)
.skip(init_null_count + skip)
.map(|val| call_lambda_and_extract(py, lambda, Wrap(val)));
.map(|val| match val {
AnyValue::Null => Ok(None),
_ => call_lambda_and_extract(py, lambda, Wrap(val)),
});

iterator_to_bool(
it,
Expand All @@ -2446,7 +2461,10 @@ impl<'a> ApplyLambda<'a> for StructChunked {
let skip = usize::from(first_value.is_some());
let it = iter_struct(self)
.skip(init_null_count + skip)
.map(|val| call_lambda_and_extract(py, lambda, Wrap(val)));
.map(|val| match val {
AnyValue::Null => Ok(None),
_ => call_lambda_and_extract(py, lambda, Wrap(val)),
});

iterator_to_string(
it,
Expand All @@ -2468,7 +2486,10 @@ impl<'a> ApplyLambda<'a> for StructChunked {
let lambda = lambda.bind(py);
let it = iter_struct(self)
.skip(init_null_count + skip)
.map(|val| call_lambda_series_out(py, lambda, Wrap(val)).map(Some));
.map(|val| match val {
AnyValue::Null => Ok(None),
_ => call_lambda_series_out(py, lambda, Wrap(val)).map(Some),
});
iterator_to_list(
dt,
it,
Expand Down Expand Up @@ -2513,7 +2534,10 @@ impl<'a> ApplyLambda<'a> for StructChunked {
let skip = usize::from(first_value.is_some());
let it = iter_struct(self)
.skip(init_null_count + skip)
.map(|val| call_lambda_and_extract(py, lambda, Wrap(val)));
.map(|val| match val {
AnyValue::Null => Ok(None),
_ => call_lambda_and_extract(py, lambda, Wrap(val)),
});

iterator_to_object(
it,
Expand Down
36 changes: 22 additions & 14 deletions py-polars/tests/unit/operations/map/test_map_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
from datetime import date, datetime, timedelta
from typing import Any

import numpy as np
import pytest
Expand Down Expand Up @@ -49,20 +50,28 @@ def test_map_elements_arithmetic_consistency() -> None:

def test_map_elements_struct() -> None:
df = pl.DataFrame(
{"A": ["a", "a"], "B": [2, 3], "C": [True, False], "D": [12.0, None]}
{
"A": ["a", "a", None],
"B": [2, 3, None],
"C": [True, False, None],
"D": [12.0, None, None],
"E": [None, [1], [2, 3]],
}
)
out = df.with_columns(pl.struct(df.columns).alias("struct")).select(
pl.col("struct").map_elements(lambda x: x["A"]).alias("A_field"),
pl.col("struct").map_elements(lambda x: x["B"]).alias("B_field"),
pl.col("struct").map_elements(lambda x: x["C"]).alias("C_field"),
pl.col("struct").map_elements(lambda x: x["D"]).alias("D_field"),
pl.col("struct").map_elements(lambda x: x["E"]).alias("E_field"),
)
expected = pl.DataFrame(
{
"A_field": ["a", "a"],
"B_field": [2, 3],
"C_field": [True, False],
"D_field": [12.0, None],
"A_field": ["a", "a", None],
"B_field": [2, 3, None],
"C_field": [True, False, None],
"D_field": [12.0, None, None],
"E_field": [None, [1], [2, 3]],
}
)

Expand Down Expand Up @@ -171,17 +180,16 @@ def test_empty_list_in_map_elements() -> None:
).to_dict(as_series=False) == {"a": [[], [1, 2], [], [5]]}


def test_map_elements_skip_nulls() -> None:
some_map = {None: "a", 1: "b"}
s = pl.Series([None, 1])
@pytest.mark.parametrize("value", [1, True, "abc", [1, 2], {"a": 1}])
@pytest.mark.parametrize("return_value", [1, True, "abc", [1, 2], {"a": 1}])
def test_map_elements_skip_nulls(value: Any, return_value: Any) -> None:
s = pl.Series([value, None])

assert s.map_elements(
lambda x: some_map[x], return_dtype=pl.String, skip_nulls=True
).to_list() == [None, "b"]
result = s.map_elements(lambda x: return_value, skip_nulls=True).to_list()
assert result == [return_value, None]

assert s.map_elements(
lambda x: some_map[x], return_dtype=pl.String, skip_nulls=False
).to_list() == ["a", "b"]
result = s.map_elements(lambda x: return_value, skip_nulls=False).to_list()
assert result == [return_value, return_value]


def test_map_elements_object_dtypes() -> None:
Expand Down
Loading