Skip to content

Commit

Permalink
fix: allow null dtypes in UDFs if they match the schema (pola-rs#15699)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Apr 16, 2024
1 parent ee660ee commit 2ac0da2
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 41 deletions.
58 changes: 30 additions & 28 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,36 @@ impl DataType {
_ => false,
}
}

/// Answers whether this dtype matches the dtype expected by a schema.
///
/// (Nested) `Null` dtypes in `self` are allowed to match any dtype in
/// `schema_type`, but not vice versa.
///
/// # Returns
///
/// * `Ok(false)` if the dtypes match as-is and no cast is necessary.
/// * `Ok(true)` if the dtypes are compatible but a cast to `schema_type` is
///   necessary (a `Null` standing in for a concrete dtype, or two decimals
///   whose second component -- presumably the scale -- differs).
///
/// # Errors
///
/// Returns a `SchemaMismatch` error if the dtypes are incompatible. This
/// includes structs with a different number of fields.
pub fn matches_schema_type(&self, schema_type: &DataType) -> PolarsResult<bool> {
    match (self, schema_type) {
        (DataType::List(l), DataType::List(r)) => l.matches_schema_type(r),
        #[cfg(feature = "dtype-struct")]
        (DataType::Struct(l), DataType::Struct(r)) => {
            // `zip` stops at the shorter side, so without this check a struct
            // would be accepted as soon as the common prefix of fields
            // matches, silently ignoring any surplus fields.
            if l.len() != r.len() {
                polars_bail!(SchemaMismatch: "type {:?} is incompatible with expected type {:?}", self, schema_type)
            }
            // A single field that needs a cast makes the whole struct need one.
            let mut must_cast = false;
            for (l, r) in l.iter().zip(r.iter()) {
                must_cast |= l.dtype.matches_schema_type(&r.dtype)?;
            }
            Ok(must_cast)
        },
        (DataType::Null, DataType::Null) => Ok(false),
        #[cfg(feature = "dtype-decimal")]
        (DataType::Decimal(_, s1), DataType::Decimal(_, s2)) => Ok(s1 != s2),
        // Only a `Null` on our side may match a non-null schema dtype (via a
        // cast). The reverse, a concrete dtype against a `Null` schema dtype,
        // is rejected by the arms below.
        (DataType::Null, _) => Ok(true),
        (l, r) if l == r => Ok(false),
        (l, r) => {
            polars_bail!(SchemaMismatch: "type {:?} is incompatible with expected type {:?}", l, r)
        },
    }
}
}

impl PartialEq<ArrowDataType> for DataType {
Expand Down Expand Up @@ -610,34 +640,6 @@ pub fn merge_dtypes(left: &DataType, right: &DataType) -> PolarsResult<DataType>
})
}

/// Answers whether data of dtype `left` can be extended/appended with data of
/// dtype `right`.
///
/// # Returns
///
/// * `Ok(true)`: can extend, but `right` must be cast first.
/// * `Ok(false)`: can extend as is.
///
/// # Errors
///
/// Returns a `SchemaMismatch` error if the dtypes cannot be combined. This
/// includes structs with a different number of fields.
pub(crate) fn can_extend_dtype(left: &DataType, right: &DataType) -> PolarsResult<bool> {
    match (left, right) {
        (DataType::List(l), DataType::List(r)) => can_extend_dtype(l, r),
        #[cfg(feature = "dtype-struct")]
        (DataType::Struct(l), DataType::Struct(r)) => {
            // `zip` stops at the shorter side, so without this check structs
            // with a different number of fields would be accepted as soon as
            // the common prefix of fields is compatible.
            polars_ensure!(l.len() == r.len(), SchemaMismatch: "cannot extend/append {:?} with {:?}", left, right);
            // A single field that needs a cast makes the whole struct need one.
            let mut must_cast = false;
            for (l, r) in l.iter().zip(r.iter()) {
                must_cast |= can_extend_dtype(&l.dtype, &r.dtype)?;
            }
            Ok(must_cast)
        },
        (DataType::Null, DataType::Null) => Ok(false),
        #[cfg(feature = "dtype-decimal")]
        (DataType::Decimal(_, s1), DataType::Decimal(_, s2)) => Ok(s1 != s2),
        // Only a `Null` on the right may be cast to the left dtype. The other
        // way around is not allowed because the left dtype is kept as is: we
        // don't go to the supertype, and we certainly don't want to cast the
        // left side to the null type.
        (_, DataType::Null) => Ok(true),
        (l, r) => {
            polars_ensure!(l == r, SchemaMismatch: "cannot extend/append {:?} with {:?}", left, right);
            Ok(false)
        },
    }
}

#[cfg(feature = "dtype-categorical")]
pub fn create_enum_data_type(categories: Utf8ViewArray) -> DataType {
let rev_map = RevMapping::build_local(categories);
Expand Down
6 changes: 2 additions & 4 deletions crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ impl Series {
///
/// See [`ChunkedArray::append`] and [`ChunkedArray::extend`].
pub fn append(&mut self, other: &Series) -> PolarsResult<&mut Self> {
let must_cast = can_extend_dtype(self.dtype(), other.dtype())?;

let must_cast = other.dtype().matches_schema_type(self.dtype())?;
if must_cast {
let other = other.cast(self.dtype())?;
self._get_inner_mut().append(&other)?;
Expand All @@ -274,8 +273,7 @@ impl Series {
///
/// See [`ChunkedArray::extend`] and [`ChunkedArray::append`].
pub fn extend(&mut self, other: &Series) -> PolarsResult<&mut Self> {
let must_cast = can_extend_dtype(self.dtype(), other.dtype())?;

let must_cast = other.dtype().matches_schema_type(self.dtype())?;
if must_cast {
let other = other.cast(self.dtype())?;
self._get_inner_mut().extend(&other)?;
Expand Down
21 changes: 13 additions & 8 deletions crates/polars-plan/src/dsl/python_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,19 @@ impl SeriesUdf for PythonUdfExpression {
let func = unsafe { CALL_SERIES_UDF_PYTHON.unwrap() };

let output_type = self.output_type.clone().unwrap_or(DataType::Unknown);
let out = func(s[0].clone(), &self.python_function)?;

polars_ensure!(
matches!(output_type, DataType::Unknown) || out.dtype() == &output_type,
SchemaMismatch:
"expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",
output_type, out.dtype(),
);
let mut out = func(s[0].clone(), &self.python_function)?;
if output_type != DataType::Unknown {
let must_cast = out.dtype().matches_schema_type(&output_type).map_err(|_| {
polars_err!(
SchemaMismatch: "expected output type '{:?}', got '{:?}'; set `return_dtype` to the proper datatype",
output_type, out.dtype(),
)
})?;
if must_cast {
out = out.cast(&output_type)?;
}
}

Ok(Some(out))
}

Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_append_to_an_enum() -> None:
def test_append_to_an_enum_with_new_category() -> None:
with pytest.raises(
pl.SchemaError,
match=("cannot extend/append Enum"),
match=("type Enum.*is incompatible with expected type Enum.*"),
):
pl.Series([None, "a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])).append(
pl.Series(["d", "a", "b", "c"], dtype=pl.Enum(["a", "b", "c", "d"]))
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/operations/map/test_map_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ def test_map_elements_infer_list() -> None:
assert df.select([pl.all().map_elements(lambda x: [x])]).dtypes == [pl.List] * 3


def test_map_elements_upcast_null_dtype_empty_list() -> None:
    """A UDF returning only empty lists yields the declared ``return_dtype``."""
    list_dtype = pl.List(pl.Int64)
    frame = pl.DataFrame({"a": [1, 2]})

    # Every call returns an empty list, so the result carries no element
    # values; the output column must still have the requested list dtype.
    result = frame.select(
        pl.col("a").map_elements(lambda _: [], return_dtype=list_dtype)
    )

    expected = pl.DataFrame({"a": [[], []]}, schema={"a": list_dtype})
    assert_frame_equal(result, expected)


def test_map_elements_arithmetic_consistency() -> None:
df = pl.DataFrame({"A": ["a", "a"], "B": [2, 3]})
with pytest.warns(PolarsInefficientMapWarning, match="with this one instead"):
Expand Down

0 comments on commit 2ac0da2

Please sign in to comment.