Skip to content

Commit

Permalink
fix: propagate validity when cast primitive to list (#11846)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Oct 19, 2023
1 parent a42185f commit a05b298
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
7 changes: 6 additions & 1 deletion crates/polars-arrow/src/compute/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,12 @@ pub fn cast(
// Safety: offsets _are_ monotonically increasing
let offsets = unsafe { Offsets::new_unchecked(offsets) };

let list_array = ListArray::<i64>::new(to_type.clone(), offsets.into(), values, None);
let list_array = ListArray::<i64>::new(
to_type.clone(),
offsets.into(),
values,
array.validity().cloned(),
);

Ok(Box::new(list_array))
},
Expand Down
8 changes: 4 additions & 4 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,10 @@ def test_list_recursive_categorical_cast() -> None:
@pytest.mark.parametrize(
("data", "expected_data", "dtype"),
[
([1, 2], [[1], [2]], pl.Int64),
([1.0, 2.0], [[1.0], [2.0]], pl.Float64),
(["x", "y"], [["x"], ["y"]], pl.Utf8),
([True, False], [[True], [False]], pl.Boolean),
([None, 1, 2], [None, [1], [2]], pl.Int64),
([None, 1.0, 2.0], [None, [1.0], [2.0]], pl.Float64),
([None, "x", "y"], [None, ["x"], ["y"]], pl.Utf8),
([None, True, False], [None, [True], [False]], pl.Boolean),
],
)
def test_non_nested_cast_to_list(
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,18 @@ def test_equal() -> None:
assert s3.dt.convert_time_zone("Asia/Tokyo").series_equal(s4) is True


@pytest.mark.parametrize(
"dtype",
[pl.Int64, pl.Float64, pl.Utf8, pl.Boolean],
)
def test_eq_missing_list_and_primitive(dtype: PolarsDataType) -> None:
s1 = pl.Series([None, None], dtype=dtype)
s2 = pl.Series([None, None], dtype=pl.List(dtype))

expected = pl.Series([True, True])
assert_series_equal(s1.eq_missing(s2), expected)


def test_to_frame() -> None:
s1 = pl.Series([1, 2])
s2 = pl.Series("s", [1, 2])
Expand Down

0 comments on commit a05b298

Please sign in to comment.