From 354b046b002d2ba29137d5b6ada0d15bfe25b93e Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Mon, 23 Oct 2023 12:37:55 +0800 Subject: [PATCH] fix: construct list series from any values subject to dtype --- py-polars/polars/utils/_construction.py | 22 ++++++++++++++++++++- py-polars/src/series/construction.rs | 13 ++++++++++++ py-polars/tests/unit/datatypes/test_list.py | 20 +++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/py-polars/polars/utils/_construction.py b/py-polars/polars/utils/_construction.py index 8baf75cd19fb..d4952be71ea2 100644 --- a/py-polars/polars/utils/_construction.py +++ b/py-polars/polars/utils/_construction.py @@ -267,6 +267,26 @@ def sequence_from_anyvalue_or_object(name: str, values: Sequence[Any]) -> PySeri raise +def sequence_from_anyvalue_and_dtype_or_object( + name: str, values: Sequence[Any], dtype: PolarsDataType +) -> PySeries: + """ + Last resort conversion. + + AnyValues are most flexible and if they fail we go for object types + + """ + try: + return PySeries.new_from_anyvalues_and_dtype(name, values, dtype, strict=True) + # raised if we cannot convert to Wrap + except RuntimeError: + return PySeries.new_object(name, values, _strict=False) + except ComputeError as exc: + if "mixed dtypes" in str(exc): + return PySeries.new_object(name, values, _strict=False) + raise + + def iterable_to_pyseries( name: str, values: Iterable[Any], @@ -518,7 +538,7 @@ def sequence_to_pyseries( if isinstance(dtype, Object): return PySeries.new_object(name, values, strict) if dtype: - srs = sequence_from_anyvalue_or_object(name, values) + srs = sequence_from_anyvalue_and_dtype_or_object(name, values, dtype) if dtype.is_not(srs.dtype()): srs = srs.cast(dtype, strict=False) return srs diff --git a/py-polars/src/series/construction.rs b/py-polars/src/series/construction.rs index 7ea8027fd1cc..acb624e2d3ed 100644 --- a/py-polars/src/series/construction.rs +++ b/py-polars/src/series/construction.rs @@ -180,6 +180,19 @@ impl PySeries { Ok(s.into()) } + #[staticmethod] + fn new_from_anyvalues_and_dtype( + name: &str, + val: Vec>>, + dtype: Wrap, + strict: bool, + ) -> PyResult { + let avs = slice_extract_wrapped(&val); + let s = Series::from_any_values_and_dtype(name, avs, &dtype.0, strict) + .map_err(PyPolarsErr::from)?; + Ok(s.into()) + } + #[staticmethod] fn new_str(name: &str, val: Wrap, _strict: bool) -> Self { val.0.into_series().with_name(name).into() diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 23e1a7be42b4..b4ff7cd3a393 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -571,3 +571,23 @@ def test_list_inner_cast_physical_11513() -> None: def test_list_is_nested_deprecated(dtype: PolarsDataType, expected: bool) -> None: with pytest.deprecated_call(): assert dtype.is_nested is expected + + +def test_list_series_construction_with_dtype_11849_11878() -> None: + s = pl.Series([[1, 2], [3.3, 4.9]], dtype=pl.List(pl.Float64)) + assert s.to_list() == [[1, 2], [3.3, 4.9]] + + s1 = pl.Series([[1, 2], [3.0, 4.0]], dtype=pl.List(pl.Float64)) + s2 = pl.Series([[1, 2], [3.0, 4.9]], dtype=pl.List(pl.Float64)) + assert_series_equal(s1 == s2, pl.Series([True, False])) + + s = pl.Series( + "groups", + [[{"1": "A", "2": None}], [{"1": "B", "2": "C"}, {"1": "D", "2": "E"}]], + dtype=pl.List(pl.Struct([pl.Field("1", pl.Utf8), pl.Field("2", pl.Utf8)])), + ) + + assert s.to_list() == [ + [{"1": "A", "2": None}], + [{"1": "B", "2": "C"}, {"1": "D", "2": "E"}], + ]