Skip to content

Commit

Permalink
fix(rust,python): Consistently propagate nulls for numpy ufuncs (#1…
Browse files: browse the repository at this point in the history
  • Loading branch information
rob-sil authored Nov 5, 2023
1 parent 936e9c5 commit ed962f8
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
12 changes: 8 additions & 4 deletions crates/polars-core/src/chunked_array/ops/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ where
Some(true) => value,
_ => opt_val,
})
.collect_trusted();
.collect_trusted::<Self>()
.with_name(self.name());
Ok(ca)
}
}
Expand Down Expand Up @@ -166,7 +167,8 @@ impl<'a> ChunkSet<'a, bool, bool> for BooleanChunked {
Some(true) => value,
_ => opt_val,
})
.collect_trusted();
.collect_trusted::<Self>()
.with_name(self.name());
Ok(ca)
}
}
Expand Down Expand Up @@ -229,7 +231,8 @@ impl<'a> ChunkSet<'a, &'a str, String> for Utf8Chunked {
Some(true) => value,
_ => opt_val,
})
.collect_trusted();
.collect_trusted::<Self>()
.with_name(self.name());
Ok(ca)
}
}
Expand Down Expand Up @@ -293,7 +296,8 @@ impl<'a> ChunkSet<'a, &'a [u8], Vec<u8>> for BinaryChunked {
Some(true) => value,
_ => opt_val,
})
.collect_trusted();
.collect_trusted::<Self>()
.with_name(self.name());
Ok(ca)
}
}
Expand Down
11 changes: 9 additions & 2 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,10 +1141,12 @@ def __array_ufunc__(

args: list[int | float | np.ndarray[Any, Any]] = []

validity_mask = self.is_not_null()
for arg in inputs:
if isinstance(arg, (int, float, np.ndarray)):
args.append(arg)
elif isinstance(arg, Series):
validity_mask &= arg.is_not_null()
args.append(arg.view(ignore_nulls=True))
else:
raise TypeError(
Expand Down Expand Up @@ -1187,7 +1189,12 @@ def __array_ufunc__(
)

series = f(lambda out: ufunc(*args, out=out, dtype=dtype_char, **kwargs))
return self._from_pyseries(series)
return (
self._from_pyseries(series)
.to_frame()
.select(F.when(validity_mask).then(F.col(self.name)))
.to_series(0)
)
else:
raise NotImplementedError(
"only `__call__` is implemented for numpy ufuncs on a Series, got "
Expand Down Expand Up @@ -4344,7 +4351,7 @@ def to_init_repr(self, n: int = 1000) -> str:
f'pl.Series("{self.name}", {self.head(n).to_list()}, dtype=pl.{self.dtype})'
)

def set(self, filter: Series, value: int | float | str) -> Series:
def set(self, filter: Series, value: int | float | str | bool | None) -> Series:
"""
Set masked values.
Expand Down
9 changes: 8 additions & 1 deletion py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,13 @@ def test_ufunc() -> None:
pl.Series("a", [3.0, None, 9.0, 12.0, 15.0, None]),
)

# Test if nulls propagate through ufuncs
a3 = pl.Series("a", [None, None, 3, 3])
b3 = pl.Series("b", [None, 3, None, 3])
assert_series_equal(
cast(pl.Series, np.maximum(a3, b3)), pl.Series("a", [None, None, None, 3])
)


def test_numpy_string_array() -> None:
s_utf8 = pl.Series("a", ["aa", "bb", "cc", "dd"], dtype=pl.Utf8)
Expand Down Expand Up @@ -894,7 +901,7 @@ def test_set() -> None:
a = pl.Series("a", [True, False, True])
mask = pl.Series("msk", [True, False, True])
a[mask] = False
assert_series_equal(a, pl.Series("", [False] * 3))
assert_series_equal(a, pl.Series("a", [False] * 3))


def test_set_value_as_list_fail() -> None:
Expand Down

0 comments on commit ed962f8

Please sign in to comment.