Skip to content

Commit

Permalink
remove HashableT in frame.pyi where possible (#1104)
Browse files Browse the repository at this point in the history
* remove HashableT in frame.pyi where possible

* fix to_records, update pyright version

* fix up Hashable refs and add tests
  • Loading branch information
Dr-Irv authored Feb 4, 2025
1 parent 583d198 commit 54b15c3
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 24 deletions.
73 changes: 50 additions & 23 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ from pandas._typing import (
HashableT,
HashableT1,
HashableT2,
HashableT3,
IgnoreRaise,
IndexingInt,
IndexLabel,
Expand Down Expand Up @@ -175,13 +174,13 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]):
@overload
def __getitem__(self, idx: Scalar) -> Series | _T: ...
@overload
def __getitem__(
def __getitem__( # type: ignore[overload-overlap]
self,
idx: (
IndexType
| MaskType
| Callable[[DataFrame], IndexType | MaskType | list[HashableT]]
| list[HashableT]
| Callable[[DataFrame], IndexType | MaskType | Sequence[Hashable]]
| list[Hashable]
| tuple[
IndexType
| MaskType
Expand Down Expand Up @@ -236,7 +235,7 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]):
@overload
def __setitem__(
self,
idx: tuple[_IndexSliceTuple, HashableT],
idx: tuple[_IndexSliceTuple, Hashable],
value: Scalar | NAType | NaTType | ArrayLike | Series | list | None,
) -> None: ...

Expand Down Expand Up @@ -438,6 +437,24 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
_str | npt.DTypeLike | Mapping[HashableT2, npt.DTypeLike] | None
) = ...,
) -> np.recarray: ...
@overload
def to_stata(
self,
path: FilePath | WriteBuffer[bytes],
*,
convert_dates: dict[HashableT1, StataDateFormat] | None = ...,
write_index: _bool = ...,
byteorder: Literal["<", ">", "little", "big"] | None = ...,
time_stamp: dt.datetime | None = ...,
data_label: _str | None = ...,
variable_labels: dict[HashableT2, str] | None = ...,
version: Literal[117, 118, 119],
convert_strl: SequenceNotStr[Hashable] | None = ...,
compression: CompressionOptions = ...,
storage_options: StorageOptions = ...,
value_labels: dict[Hashable, dict[float, str]] | None = ...,
) -> None: ...
@overload
def to_stata(
self,
path: FilePath | WriteBuffer[bytes],
Expand All @@ -449,7 +466,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
data_label: _str | None = ...,
variable_labels: dict[HashableT2, str] | None = ...,
version: Literal[114, 117, 118, 119] | None = ...,
convert_strl: list[HashableT3] | None = ...,
convert_strl: None = ...,
compression: CompressionOptions = ...,
storage_options: StorageOptions = ...,
value_labels: dict[Hashable, dict[float, str]] | None = ...,
Expand All @@ -462,7 +479,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
engine: ParquetEngine = ...,
compression: Literal["snappy", "gzip", "brotli", "lz4", "zstd"] | None = ...,
index: bool | None = ...,
partition_cols: list[HashableT] | None = ...,
partition_cols: Sequence[Hashable] | None = ...,
storage_options: StorageOptions = ...,
**kwargs: Any,
) -> None: ...
Expand All @@ -473,7 +490,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
engine: ParquetEngine = ...,
compression: Literal["snappy", "gzip", "brotli", "lz4", "zstd"] | None = ...,
index: bool | None = ...,
partition_cols: list[HashableT] | None = ...,
partition_cols: Sequence[Hashable] | None = ...,
storage_options: StorageOptions = ...,
**kwargs: Any,
) -> bytes: ...
Expand All @@ -499,7 +516,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
def to_html(
self,
buf: FilePath | WriteBuffer[str],
columns: list[HashableT] | Index | Series | None = ...,
columns: SequenceNotStr[Hashable] | Index | Series | None = ...,
col_space: ColspaceArgType | None = ...,
header: _bool = ...,
index: _bool = ...,
Expand Down Expand Up @@ -546,7 +563,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
def to_html(
self,
buf: None = ...,
columns: Sequence[HashableT] | None = ...,
columns: Sequence[Hashable] | None = ...,
col_space: ColspaceArgType | None = ...,
header: _bool = ...,
index: _bool = ...,
Expand Down Expand Up @@ -597,8 +614,8 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
root_name: str = ...,
row_name: str = ...,
na_rep: str | None = ...,
attr_cols: list[HashableT1] | None = ...,
elem_cols: list[HashableT2] | None = ...,
attr_cols: SequenceNotStr[Hashable] | None = ...,
elem_cols: SequenceNotStr[Hashable] | None = ...,
namespaces: dict[str | None, str] | None = ...,
prefix: str | None = ...,
encoding: str = ...,
Expand All @@ -617,8 +634,8 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
root_name: str | None = ...,
row_name: str | None = ...,
na_rep: str | None = ...,
attr_cols: list[HashableT1] | None = ...,
elem_cols: list[HashableT2] | None = ...,
attr_cols: list[Hashable] | None = ...,
elem_cols: list[Hashable] | None = ...,
namespaces: dict[str | None, str] | None = ...,
prefix: str | None = ...,
encoding: str = ...,
Expand Down Expand Up @@ -846,7 +863,12 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
def set_index(
self,
keys: (
Label | Series | Index | np.ndarray | Iterator[HashableT] | list[HashableT]
Label
| Series
| Index
| np.ndarray
| Iterator[Hashable]
| Sequence[Hashable]
),
*,
drop: _bool = ...,
Expand All @@ -858,7 +880,12 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
def set_index(
self,
keys: (
Label | Series | Index | np.ndarray | Iterator[HashableT] | list[HashableT]
Label
| Series
| Index
| np.ndarray
| Iterator[Hashable]
| Sequence[Hashable]
),
*,
drop: _bool = ...,
Expand All @@ -876,7 +903,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
col_fill: Hashable = ...,
inplace: Literal[True],
allow_duplicates: _bool = ...,
names: Hashable | list[HashableT] = ...,
names: Hashable | Sequence[Hashable] = ...,
) -> None: ...
@overload
def reset_index(
Expand All @@ -888,7 +915,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
drop: _bool = ...,
inplace: Literal[False] = ...,
allow_duplicates: _bool = ...,
names: Hashable | list[HashableT] = ...,
names: Hashable | Sequence[Hashable] = ...,
) -> Self: ...
@overload
def reset_index(
Expand All @@ -900,7 +927,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
col_level: int | _str = ...,
col_fill: Hashable = ...,
allow_duplicates: _bool = ...,
names: Hashable | list[HashableT] = ...,
names: Hashable | Sequence[Hashable] = ...,
) -> Self | None: ...
def isna(self) -> Self: ...
def isnull(self) -> Self: ...
Expand Down Expand Up @@ -1681,7 +1708,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
def columns(self) -> Index[str]: ...
@columns.setter # setter needs to be right next to getter; otherwise mypy complains
def columns(
self, cols: AnyArrayLike | list[HashableT] | tuple[HashableT, ...]
self, cols: AnyArrayLike | SequenceNotStr[Hashable] | tuple[Hashable, ...]
) -> None: ...
@property
def dtypes(self) -> Series: ...
Expand Down Expand Up @@ -2359,8 +2386,8 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
def to_string(
self,
buf: FilePath | WriteBuffer[str],
columns: list[HashableT1] | Index | Series | None = ...,
col_space: int | list[int] | dict[HashableT2, int] | None = ...,
columns: SequenceNotStr[Hashable] | Index | Series | None = ...,
col_space: int | list[int] | dict[HashableT, int] | None = ...,
header: _bool | list[_str] | tuple[str, ...] = ...,
index: _bool = ...,
na_rep: _str = ...,
Expand All @@ -2382,7 +2409,7 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
def to_string(
self,
buf: None = ...,
columns: list[HashableT] | Index | Series | None = ...,
columns: Sequence[Hashable] | Index | Series | None = ...,
col_space: int | list[int] | dict[Hashable, int] | None = ...,
header: _bool | Sequence[_str] = ...,
index: _bool = ...,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ mypy = "1.14.1"
pandas = "2.2.3"
pyarrow = ">=10.0.1"
pytest = ">=7.1.2"
pyright = ">= 1.1.391"
pyright = ">= 1.1.393"
poethepoet = ">=0.16.5"
loguru = ">=0.6.0"
typing-extensions = ">=4.4.0"
Expand Down
36 changes: 36 additions & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3091,6 +3091,11 @@ def test_to_records() -> None:
),
np.recarray,
)
dtypes = {"col1": np.int8, "col2": np.int16}
check(
assert_type(DF.to_records(False, dtypes), np.recarray),
np.recarray,
)


def test_to_dict() -> None:
Expand Down Expand Up @@ -3815,6 +3820,37 @@ def _constructor(self) -> type[MyClass]:
check(assert_type(df[["a", "b"]], MyClass), MyClass)


def test_hashable_args() -> None:
# GH 1104
df = pd.DataFrame([["abc"]], columns=["test"], index=["ind"])
test = ["test"]

with ensure_clean() as path:

df.to_stata(path, version=117, convert_strl=test)
df.to_stata(path, version=117, convert_strl=["test"])

df.to_html(path, columns=test)
df.to_html(path, columns=["test"])

df.to_xml(path, attr_cols=test)
df.to_xml(path, attr_cols=["test"])

df.to_xml(path, elem_cols=test)
df.to_xml(path, elem_cols=["test"])

# Next lines should work, but it is a mypy bug
# https://github.com/python/mypy/issues/3004
# pyright accepts this, so we only type check for pyright,
# and also test the code with pytest
df.columns = test # type: ignore[assignment]
df.columns = ["test"] # type: ignore[assignment]

testDict = {"test": 1}
df.to_string("test", col_space=testDict)
df.to_string("test", col_space={"test": 1})


# GH 906
@pd.api.extensions.register_dataframe_accessor("geo")
class GeoAccessor: ...

0 comments on commit 54b15c3

Please sign in to comment.