Skip to content

Commit

Permalink
Fix Series#get and DataFrame#get (#867)
Browse files Browse the repository at this point in the history
* Remove NDFrame#get method

* Add get method to Series

* Add get method to DataFrame

* Add test cases for {DataFrame,Series}#get(..., default=None), which may return None

* Remove `= ...` from {DataFrame,Series}#get for cases where default parameter is given

* Use _typing.T instead of locally defined type var
  • Loading branch information
skatsuta authored Feb 13, 2024
1 parent 40d9636 commit 7978238
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 3 deletions.
10 changes: 9 additions & 1 deletion pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ from pandas._typing import (
StorageOptions,
StrLike,
Suffixes,
T as _T,
TimestampConvention,
ValidationOptions,
WriteBuffer,
Expand Down Expand Up @@ -1696,7 +1697,14 @@ class DataFrame(NDFrame, OpsMixin):
# def from_dict
# def from_records
def ge(self, other, axis: Axis = ..., level: Level | None = ...) -> DataFrame: ...
# def get
# Overloads for DataFrame.get, split by the kind of key and by whether a
# default is supplied:
#   - a single Hashable key is typed as returning a Series, a list of keys as
#     returning a DataFrame;
#   - with `default` omitted (or None) the result may be None; with any other
#     default the result may instead be of the default's own type (_T).
# The `_T` overloads have no `= ...` on `default`, so they only match when the
# caller passes a default explicitly.
# NOTE(review): `list` is invariant, so an argument annotated as `list[str]`
# may not match `key: list[Hashable]` — confirm with the type checkers in use.
@overload
def get(self, key: Hashable, default: None = ...) -> Series | None: ...
@overload
def get(self, key: Hashable, default: _T) -> Series | _T: ...
@overload
def get(self, key: list[Hashable], default: None = ...) -> DataFrame | None: ...
@overload
def get(self, key: list[Hashable], default: _T) -> DataFrame | _T: ...
def gt(self, other, axis: Axis = ..., level: Level | None = ...) -> DataFrame: ...
def head(self, n: int = ...) -> DataFrame: ...
def infer_objects(self) -> DataFrame: ...
Expand Down
2 changes: 0 additions & 2 deletions pandas-stubs/core/generic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ from pandas._typing import (
AxisIndex,
CompressionOptions,
CSVQuoting,
Dtype,
DtypeArg,
DtypeBackend,
FilePath,
Expand Down Expand Up @@ -299,7 +298,6 @@ class NDFrame(indexing.IndexingMixin):
self, indices, axis=..., is_copy: _bool | None = ..., **kwargs
) -> Self: ...
def __delitem__(self, idx: Hashable) -> None: ...
def get(self, key: object, default: Dtype | None = ...) -> Dtype: ...
def reindex_like(
self,
other,
Expand Down
7 changes: 7 additions & 0 deletions pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ from pandas._typing import (
SortKind,
StrDtypeArg,
StrLike,
T,
TimedeltaDtypeArg,
TimestampConvention,
TimestampDtypeArg,
Expand Down Expand Up @@ -381,6 +382,12 @@ class Series(IndexOpsMixin[S1], NDFrame):
@overload
def __getitem__(self, idx: Scalar) -> S1: ...
def __setitem__(self, key, value) -> None: ...
# Overloads for Series.get, where S1 is the element type parameter of the
# enclosing Series class:
#   1. default omitted or None      -> S1 | None
#   2. default of the element type  -> S1 (the union collapses to S1)
#   3. default of any other type T  -> S1 | T
# Overloads 2 and 3 have no `= ...` on `default`, so they only match when the
# caller passes a default explicitly. Overload 2 is listed before the generic
# overload 3 so that an S1 default resolves to plain S1.
@overload
def get(self, key: Hashable, default: None = ...) -> S1 | None: ...
@overload
def get(self, key: Hashable, default: S1) -> S1: ...
@overload
def get(self, key: Hashable, default: T) -> S1 | T: ...
def repeat(
self, repeats: int | list[int], axis: AxisIndex | None = ...
) -> Series[S1]: ...
Expand Down
35 changes: 35 additions & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3107,3 +3107,38 @@ def test_itertuples() -> None:
for item in df.itertuples():
check(assert_type(item, _PandasNamedTuple), tuple)
assert_type(item.a, Scalar)


# Exercise DataFrame.get with a single-column key and with a list-of-columns
# key, for three default styles: omitted, explicit None, and an int.
# Each case pairs `assert_type` (the statically inferred return type) with
# `check`, a helper defined elsewhere in the test suite that presumably
# verifies the runtime type of the value (and, via its optional third
# argument, the element type) — confirm against its definition.
def test_get() -> None:
# Columns are "a", "b" and "c"; "z" is used below as a key that is absent.
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]})

# Get single column: the static type is Series | <type of default>
check(assert_type(df.get("a"), Union[pd.Series, None]), pd.Series, np.int64)
# Absent column with no default: the runtime value is None.
check(assert_type(df.get("z"), Union[pd.Series, None]), type(None))
check(
assert_type(df.get("a", default=None), Union[pd.Series, None]),
pd.Series,
np.int64,
)
check(assert_type(df.get("z", default=None), Union[pd.Series, None]), type(None))
# A non-None default widens the static type to include the default's type.
check(
assert_type(df.get("a", default=1), Union[pd.Series, int]), pd.Series, np.int64
)
# Absent column with an int default: the runtime value is the int.
check(assert_type(df.get("z", default=1), Union[pd.Series, int]), int)

# Get multiple columns: a list key makes the static type DataFrame | <type of default>
check(assert_type(df.get(["a"]), Union[pd.DataFrame, None]), pd.DataFrame)
check(assert_type(df.get(["a", "b"]), Union[pd.DataFrame, None]), pd.DataFrame)
# A list containing only an absent column falls back to the default (None).
check(assert_type(df.get(["z"]), Union[pd.DataFrame, None]), type(None))
check(
assert_type(df.get(["a", "b"], default=None), Union[pd.DataFrame, None]),
pd.DataFrame,
)
check(
assert_type(df.get(["z"], default=None), Union[pd.DataFrame, None]), type(None)
)
check(
assert_type(df.get(["a", "b"], default=1), Union[pd.DataFrame, int]),
pd.DataFrame,
)
check(assert_type(df.get(["z"], default=1), Union[pd.DataFrame, int]), int)
21 changes: 21 additions & 0 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Any,
Generic,
TypeVar,
Union,
cast,
)

Expand Down Expand Up @@ -2875,6 +2876,26 @@ def test_round() -> None:
check(assert_type(round(pd.Series([1], dtype=int)), "pd.Series[int]"), pd.Series)


# Exercise Series.get on an int-valued and a str-valued Series, for a present
# and an absent key, with the default omitted, explicitly None, of the element
# type, and of a different type.
# Each case pairs `assert_type` (the statically inferred return type) with
# `check`, a helper defined elsewhere in the test suite that presumably
# verifies the runtime type of the value — confirm against its definition.
def test_get() -> None:
# Index labels are 1, 2, 3; 99 is used below as a label that is absent.
s_int = pd.Series([1, 2, 3], index=[1, 2, 3])

# No default or default=None: the static type is int | None.
check(assert_type(s_int.get(1), Union[int, None]), np.int64)
check(assert_type(s_int.get(99), Union[int, None]), type(None))
check(assert_type(s_int.get(1, default=None), Union[int, None]), np.int64)
check(assert_type(s_int.get(99, default=None), Union[int, None]), type(None))
# A default of the element type keeps the static type as plain int.
check(assert_type(s_int.get(1, default=2), int), np.int64)
# A default of another type widens the static type to int | str.
check(assert_type(s_int.get(99, default="a"), Union[int, str]), str)

# Index labels are "a", "b", "c"; "z" is used below as a label that is absent.
s_str = pd.Series(list("abc"), index=list("abc"))

# Same matrix of cases for a str-valued Series.
check(assert_type(s_str.get("a"), Union[str, None]), str)
check(assert_type(s_str.get("z"), Union[str, None]), type(None))
check(assert_type(s_str.get("a", default=None), Union[str, None]), str)
check(assert_type(s_str.get("z", default=None), Union[str, None]), type(None))
check(assert_type(s_str.get("a", default="b"), str), str)
check(assert_type(s_str.get("z", default=True), Union[str, bool]), bool)


def test_series_new_empty() -> None:
# GH 826
check(assert_type(pd.Series(), "pd.Series[Any]"), pd.Series)
Expand Down

0 comments on commit 7978238

Please sign in to comment.