diff --git a/msgspec/__init__.pyi b/msgspec/__init__.pyi index 86a1051a..c44f2b32 100644 --- a/msgspec/__init__.pyi +++ b/msgspec/__init__.pyi @@ -16,7 +16,7 @@ from typing import ( overload, ) -from typing_extensions import dataclass_transform +from typing_extensions import dataclass_transform, Buffer from . import inspect, json, msgpack, structs, toml, yaml @@ -108,7 +108,7 @@ class Raw(bytes): @overload def __new__(cls) -> "Raw": ... @overload - def __new__(cls, msg: Union[bytes, str]) -> "Raw": ... + def __new__(cls, msg: Union[Buffer, str]) -> "Raw": ... def copy(self) -> "Raw": ... class Meta: diff --git a/msgspec/json.pyi b/msgspec/json.pyi index 3b8b7e5e..450f31e6 100644 --- a/msgspec/json.pyi +++ b/msgspec/json.pyi @@ -14,6 +14,8 @@ from typing import ( overload, ) +from typing_extensions import Buffer + T = TypeVar("T") enc_hook_sig = Optional[Callable[[Any], Any]] @@ -73,19 +75,19 @@ class Decoder(Generic[T]): dec_hook: dec_hook_sig = None, float_hook: float_hook_sig = None, ) -> None: ... - def decode(self, data: Union[bytes, str]) -> T: ... - def decode_lines(self, data: Union[bytes, str]) -> list[T]: ... + def decode(self, data: Union[Buffer, str]) -> T: ... + def decode_lines(self, data: Union[Buffer, str]) -> list[T]: ... @overload def decode( - buf: Union[bytes, str], + buf: Union[Buffer, str], *, strict: bool = True, dec_hook: dec_hook_sig = None, ) -> Any: ... @overload def decode( - buf: Union[bytes, str], + buf: Union[Buffer, str], *, type: Type[T] = ..., strict: bool = True, @@ -93,7 +95,7 @@ def decode( ) -> T: ... @overload def decode( - buf: Union[bytes, str], + buf: Union[Buffer, str], *, type: Any = ..., strict: bool = True, @@ -110,4 +112,4 @@ def schema_components( @overload def format(buf: str, *, indent: int = 2) -> str: ... @overload -def format(buf: bytes, *, indent: int = 2) -> bytes: ... +def format(buf: Buffer, *, indent: int = 2) -> bytes: ... diff --git a/msgspec/toml.py b/msgspec/toml.py index f37228d8..c4306440 100644 --- a/msgspec/toml.py +++ b/msgspec/toml.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import datetime as _datetime -from typing import Any, Callable, Optional, Type, TypeVar, Union, overload, Literal +from typing import TYPE_CHECKING, overload, TypeVar, Any from . import ( DecodeError as _DecodeError, @@ -7,6 +9,11 @@ to_builtins as _to_builtins, ) +if TYPE_CHECKING: + from typing import Callable, Optional, Type, Union, Literal + from typing_extensions import Buffer + + __all__ = ("encode", "decode") @@ -103,7 +110,7 @@ def encode( @overload def decode( - buf: Union[bytes, str], + buf: Union[Buffer, str], *, strict: bool = True, dec_hook: Optional[Callable[[type, Any], Any]] = None, @@ -113,7 +120,7 @@ def decode( @overload def decode( - buf: Union[bytes, str], + buf: Union[Buffer, str], *, type: Type[T] = ..., strict: bool = True, @@ -124,7 +131,7 @@ def decode( @overload def decode( - buf: Union[bytes, str], + buf: Union[Buffer, str], *, type: Any = ..., strict: bool = True, diff --git a/msgspec/yaml.py b/msgspec/yaml.py index bb57ef91..e8f6abc5 100644 --- a/msgspec/yaml.py +++ b/msgspec/yaml.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import datetime as _datetime -from typing import Any, Callable, Optional, Type, TypeVar, Union, overload, Literal +from typing import TYPE_CHECKING, overload, TypeVar, Any from . import ( DecodeError as _DecodeError, @@ -7,6 +9,11 @@ to_builtins as _to_builtins, ) +if TYPE_CHECKING: + from typing import Callable, Optional, Type, Union, Literal + from typing_extensions import Buffer + + __all__ = ("encode", "decode") @@ -96,7 +103,7 @@ def encode( @overload def decode( - buf: Union[bytes, str], + buf: Union[Buffer, str], *, strict: bool = True, dec_hook: Optional[Callable[[type, Any], Any]] = None, diff --git a/tests/basic_typing_examples.py b/tests/basic_typing_examples.py index fd49bd14..33dcb009 100644 --- a/tests/basic_typing_examples.py +++ b/tests/basic_typing_examples.py @@ -654,6 +654,11 @@ def check_msgpack_decode_typed() -> None: reveal_type(o) # assert ("List" in typ or "list" in typ) and "int" in typ +def check_msgpack_decode_from_buffer() -> None: + msg = msgspec.msgpack.encode([1, 2, 3]) + msgspec.toml.decode(memoryview(msg)) + + def check_msgpack_decode_typed_union() -> None: o: Union[int, str] = msgspec.msgpack.decode(b"", type=Union[int, str]) reveal_type(o) # assert "int" in typ and "str" in typ @@ -835,6 +840,10 @@ def check_json_decode_from_str() -> None: reveal_type(o) # assert ("List" in typ or "list" in typ) and "int" in typ +def check_json_decode_from_buffer() -> None: + msgspec.json.decode(memoryview(b"[1, 2, 3]")) + + def check_json_encode_enc_hook() -> None: msgspec.json.encode(object(), enc_hook=lambda x: None) @@ -929,6 +938,10 @@ def check_yaml_decode_from_str() -> None: reveal_type(o) # assert "list" in typ.lower() and "int" in typ +def check_yaml_decode_from_buffer() -> None: + msgspec.yaml.decode(memoryview(b"[1, 2, 3]")) + + def check_yaml_encode_enc_hook() -> None: msgspec.yaml.encode(object(), enc_hook=lambda x: None) @@ -977,6 +990,10 @@ def check_toml_decode_from_str() -> None: reveal_type(o) # assert "dict" in typ.lower() and "int" in typ +def check_toml_decode_from_buffer() -> None: + msgspec.toml.decode(memoryview(b"a = 1")) + + def check_toml_encode_enc_hook() -> None: msgspec.toml.encode(object(), enc_hook=lambda x: None)