Skip to content

Commit

Permalink
GH1089 Partial typehinting
Browse files Browse the repository at this point in the history
  • Loading branch information
loicdiridollou committed Jan 15, 2025
1 parent 01130e1 commit 4c60287
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 9 deletions.
36 changes: 35 additions & 1 deletion pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
@overload
def dot(self, other: Series[S1]) -> Scalar: ...
@overload
def dot(self, other: DataFrame) -> Series[S1]: ...
def dot(self, other: DataFrame) -> Series: ...
@overload
def dot(
self, other: ArrayLike | dict[_str, np.ndarray] | Sequence[S1] | Index[S1]
Expand Down Expand Up @@ -1628,6 +1628,11 @@ class Series(IndexOpsMixin[S1], NDFrame):
self, other: int | np_ndarray_anyint | Series[int]
) -> Series[int]: ...
# def __array__(self, dtype: Optional[_bool] = ...) -> _np_ndarray
@overload
def __div__(self: Series[int], other: Series[int]) -> Series[float]: ...
@overload
def __div__(self: Series[int], other: int) -> Series[float]: ...
@overload
def __div__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
def __eq__(self, other: object) -> Series[_bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
def __floordiv__(self, other: num | _ListLike | Series[S1]) -> Series[int]: ...
Expand All @@ -1648,6 +1653,14 @@ class Series(IndexOpsMixin[S1], NDFrame):
self, other: timedelta | Timedelta | TimedeltaSeries | np.timedelta64
) -> TimedeltaSeries: ...
@overload
def __mul__(self: Series[int], other: int) -> Series[int]: ...
@overload
def __mul__(self: Series[int], other: Series[int]) -> Series[int]: ...
@overload
def __mul__(self: Series[int], other: Series[float]) -> Series[float]: ...
@overload
def __mul__(self: Series[Any], other: Series[Any]) -> Series: ...
@overload
def __mul__(self, other: num | _ListLike | Series) -> Series: ...
def __mod__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
def __ne__(self, other: object) -> Series[_bool]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
Expand All @@ -1674,6 +1687,11 @@ class Series(IndexOpsMixin[S1], NDFrame):
def __rand__( # pyright: ignore[reportIncompatibleMethodOverride]
self, other: int | np_ndarray_anyint | Series[int]
) -> Series[int]: ...
@overload
def __rdiv__(self: Series[int], other: int) -> Series[float]: ...
@overload
def __rdiv__(self: Series[int], other: Series[int]) -> Series[float]: ...
@overload
def __rdiv__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
def __rdivmod__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
def __rfloordiv__(self, other: num | _ListLike | Series[S1]) -> Series[S1]: ...
Expand Down Expand Up @@ -1936,6 +1954,22 @@ class Series(IndexOpsMixin[S1], NDFrame):
axis: AxisIndex | None = ...,
) -> Series[S1]: ...
@overload
def mul(
self: Series[int],
other: Series[int],
level: Level | None = ...,
fill_value: float | None = ...,
axis: AxisIndex | None = ...,
) -> Series[int]: ...
@overload
def mul(
self: Series[int],
other: Series[float],
level: Level | None = ...,
fill_value: float | None = ...,
axis: AxisIndex | None = ...,
) -> Series[float]: ...
@overload
def mul(
self,
other: timedelta | Timedelta | TimedeltaSeries | np.timedelta64,
Expand Down
21 changes: 13 additions & 8 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,10 +656,12 @@ def test_types_element_wise_arithmetic() -> None:
check(assert_type(s - s2, pd.Series), pd.Series, np.integer)
check(assert_type(s.sub(s2, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)

check(assert_type(s * s2, pd.Series), pd.Series, np.integer)
check(assert_type(s.mul(s2, fill_value=0), pd.Series), pd.Series, np.integer)
check(assert_type(s * s2, "pd.Series[int]"), pd.Series, np.integer)
check(assert_type(s.mul(s2, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)

check(assert_type(s / s2, pd.Series), pd.Series, np.float64)
# GH1089 should be the following
# check(assert_type(s / s2, "pd.Series[float]"), pd.Series, np.float64)
check(assert_type(s / s2, "pd.Series"), pd.Series, np.float64)
check(
assert_type(s.div(s2, fill_value=0), "pd.Series[float]"), pd.Series, np.float64
)
Expand Down Expand Up @@ -693,9 +695,11 @@ def test_types_scalar_arithmetic() -> None:
check(assert_type(s - 1, "pd.Series[int]"), pd.Series, np.integer)
check(assert_type(s.sub(1, fill_value=0), "pd.Series[int]"), pd.Series, np.integer)

check(assert_type(s * 2, pd.Series), pd.Series, np.integer)
check(assert_type(s * 2, "pd.Series[int]"), pd.Series, np.integer)
check(assert_type(s.mul(2, fill_value=0), pd.Series), pd.Series, np.integer)

# GH1089 should be
# check(assert_type(s / 2, "pd.Series[float]"), pd.Series, np.float64)
check(assert_type(s / 2, pd.Series), pd.Series, np.float64)
check(
assert_type(s.div(2, fill_value=0), "pd.Series[float]"), pd.Series, np.float64
Expand Down Expand Up @@ -1311,7 +1315,7 @@ def test_types_dot() -> None:
n1 = np.array([[0, 1], [1, 2], [-1, -1], [2, 0]])
check(assert_type(s1.dot(s2), Scalar), np.integer)
check(assert_type(s1 @ s2, Scalar), np.integer)
check(assert_type(s1.dot(df1), "pd.Series[int]"), pd.Series, np.integer)
check(assert_type(s1.dot(df1), pd.Series), pd.Series, np.integer)
check(assert_type(s1 @ df1, pd.Series), pd.Series)
check(assert_type(s1.dot(n1), np.ndarray), np.ndarray)
check(assert_type(s1 @ n1, np.ndarray), np.ndarray)
Expand All @@ -1333,7 +1337,8 @@ def test_series_min_max_sub_axis() -> None:
sd = s1 / s2
check(assert_type(sa, pd.Series), pd.Series)
check(assert_type(ss, pd.Series), pd.Series)
check(assert_type(sm, pd.Series), pd.Series)
# TODO GH1089 This should not match to Series[int]
check(assert_type(sm, pd.Series), pd.Series) # pyright: ignore
check(assert_type(sd, pd.Series), pd.Series)


Expand Down Expand Up @@ -1368,11 +1373,11 @@ def test_series_multiindex_getitem() -> None:
def test_series_mul() -> None:
s = pd.Series([1, 2, 3])
sm = s * 4
check(assert_type(sm, pd.Series), pd.Series)
check(assert_type(sm, "pd.Series[int]"), pd.Series, np.integer)
ss = s - 4
check(assert_type(ss, "pd.Series[int]"), pd.Series, np.integer)
sm2 = s * s
check(assert_type(sm2, pd.Series), pd.Series)
check(assert_type(sm2, "pd.Series[int]"), pd.Series, np.integer)
sp = s + 4
check(assert_type(sp, "pd.Series[int]"), pd.Series, np.integer)

Expand Down

0 comments on commit 4c60287

Please sign in to comment.