Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve typing on Windows #2803

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/click/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,14 +538,14 @@ def auto_wrap_for_ansi(stream: t.TextIO, color: bool | None = None) -> t.TextIO:
rv = t.cast(t.TextIO, ansi_wrapper.stream)
_write = rv.write

def _safe_write(s):
def _safe_write(s: str) -> int:
try:
return _write(s)
except BaseException:
ansi_wrapper.reset_all()
raise

rv.write = _safe_write
rv.write = _safe_write # type: ignore[method-assign]

try:
_ansi_stream_wrappers[stream] = rv
Expand Down
7 changes: 3 additions & 4 deletions src/click/_termui_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def _translate_ch_to_exc(ch: str) -> None:
return None


if WIN:
if sys.platform == "win32":
import msvcrt

@contextlib.contextmanager
Expand Down Expand Up @@ -703,12 +703,11 @@ def getchar(echo: bool) -> str:
#
# Anyway, Click doesn't claim to do this Right(tm), and using `getwch`
# is doing the right thing in more situations than with `getch`.
func: t.Callable[[], str]

if echo:
func = msvcrt.getwche # type: ignore
func = t.cast(t.Callable[[], str], msvcrt.getwche)
else:
func = msvcrt.getwch # type: ignore
func = t.cast(t.Callable[[], str], msvcrt.getwch)

rv = func()

Expand Down
55 changes: 34 additions & 21 deletions src/click/_winconsole.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import sys
import time
import typing as t
from ctypes import Array
from ctypes import byref
from ctypes import c_char
from ctypes import c_char_p
Expand Down Expand Up @@ -67,6 +68,14 @@
EOF = b"\x1a"
MAX_BYTES_WRITTEN = 32767

if t.TYPE_CHECKING:
try:
# Using `typing_extensions.Buffer` instead of `collections.abc`
# on Windows for some reason does not have `Sized` implemented.
from collections.abc import Buffer # type: ignore
except ImportError:
from typing_extensions import Buffer

try:
from ctypes import pythonapi
except ImportError:
Expand All @@ -93,32 +102,32 @@ class Py_buffer(Structure):
PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
PyBuffer_Release = pythonapi.PyBuffer_Release

def get_buffer(obj, writable=False):
def get_buffer(obj: Buffer, writable: bool = False) -> Array[c_char]:
buf = Py_buffer()
flags = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
flags: int = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
PyObject_GetBuffer(py_object(obj), byref(buf), flags)

try:
buffer_type = c_char * buf.len
buffer_type: Array[c_char] = c_char * buf.len
return buffer_type.from_address(buf.buf)
finally:
PyBuffer_Release(byref(buf))


class _WindowsConsoleRawIOBase(io.RawIOBase):
def __init__(self, handle):
def __init__(self, handle: int | None) -> None:
self.handle = handle

def isatty(self):
def isatty(self) -> t.Literal[True]:
super().isatty()
return True


class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
def readable(self):
def readable(self) -> t.Literal[True]:
return True

def readinto(self, b):
def readinto(self, b: Buffer) -> int:
bytes_to_be_read = len(b)
if not bytes_to_be_read:
return 0
Expand Down Expand Up @@ -150,18 +159,18 @@ def readinto(self, b):


class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
def writable(self):
def writable(self) -> t.Literal[True]:
return True

@staticmethod
def _get_error_message(errno):
def _get_error_message(errno: int) -> str:
if errno == ERROR_SUCCESS:
return "ERROR_SUCCESS"
elif errno == ERROR_NOT_ENOUGH_MEMORY:
return "ERROR_NOT_ENOUGH_MEMORY"
return f"Windows error {errno}"

def write(self, b):
def write(self, b: Buffer) -> int:
bytes_to_be_written = len(b)
buf = get_buffer(b)
code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2
Expand Down Expand Up @@ -209,7 +218,7 @@ def __getattr__(self, name: str) -> t.Any:
def isatty(self) -> bool:
return self.buffer.isatty()

def __repr__(self):
def __repr__(self) -> str:
return f"<ConsoleStream name={self.name!r} encoding={self.encoding!r}>"


Expand Down Expand Up @@ -267,16 +276,20 @@ def _get_windows_console_stream(
f: t.TextIO, encoding: str | None, errors: str | None
) -> t.TextIO | None:
if (
get_buffer is not None
and encoding in {"utf-16-le", None}
and errors in {"strict", None}
and _is_console(f)
get_buffer is None
or encoding not in {"utf-16-le", None}
or errors not in {"strict", None}
or not _is_console(f)
):
func = _stream_factories.get(f.fileno())
if func is not None:
b = getattr(f, "buffer", None)
return None

func = _stream_factories.get(f.fileno())
if func is None:
return None

b = getattr(f, "buffer", None)

if b is None:
return None
if b is None:
return None

return func(b)
return func(b)
12 changes: 9 additions & 3 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,15 @@ commands = pre-commit run --all-files
[testenv:typing]
deps = -r requirements/typing.txt
commands =
mypy
pyright tests/typing
pyright --verifytypes click --ignoreexternal
mypy --platform linux
mypy --platform darwin
mypy --platform win32
pyright tests/typing --pythonplatform Linux
pyright tests/typing --pythonplatform Darwin
pyright tests/typing --pythonplatform Windows
pyright --verifytypes click --ignoreexternal --pythonplatform Linux
pyright --verifytypes click --ignoreexternal --pythonplatform Darwin
pyright --verifytypes click --ignoreexternal --pythonplatform Windows

[testenv:docs]
deps = -r requirements/docs.txt
Expand Down