diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fa7b7a3..b74ccbc1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,9 +5,18 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [2.0a1] - 2023-02-17 :heart: + +- Improves how custom binders can be defined, reducing code verbosity for + custom types. This is an important feature to implement common validation of + common parameters across multiple endpoints. +- Adds support for binder types defining OpenAPI Specification for their + parameters. +- Fixes bug #305 (`ClientSession ssl=False` not working as intended). + ## [2.0a0] - 2023-01-08 :hourglass_flowing_sand: -- Renames the `plugins` namespace to `settings` +- Renames the `plugins` namespace to `settings`. - Upgrades `rodi` to v2, which includes improvements. - Adds support for alternative implementation of containers for dependency injection, using the new `ContainerProtocol` in `rodi`. diff --git a/README.md b/README.md index 120dbf38..48766155 100644 --- a/README.md +++ b/README.md @@ -240,7 +240,7 @@ import asyncio from blacksheep.client import ClientSession -async def client_example(loop): +async def client_example(): async with ClientSession() as client: response = await client.get("https://docs.python.org/3/") @@ -249,9 +249,7 @@ async def client_example(loop): print(text) -loop = asyncio.get_event_loop() -loop.run_until_complete(client_example(loop)) - +asyncio.run(client_example()) ``` ## Supported platforms and runtimes diff --git a/blacksheep/client/connection.py b/blacksheep/client/connection.py index 81e6d977..0456fb69 100644 --- a/blacksheep/client/connection.py +++ b/blacksheep/client/connection.py @@ -24,6 +24,7 @@ INSECURE_SSLCONTEXT = ssl.SSLContext(protocol=ssl.PROTOCOL_TLS_CLIENT) INSECURE_SSLCONTEXT.check_hostname = False +INSECURE_SSLCONTEXT.verify_mode = ssl.CERT_NONE class IncomingContent(Content): @@ -80,7 +81,6 @@ def __init__(self, response, transport): class ClientConnection(asyncio.Protocol): - __slots__ = ( "loop", "pool", diff --git a/blacksheep/client/cookies.py b/blacksheep/client/cookies.py index bec5ef58..9fb04dbc 100644 --- a/blacksheep/client/cookies.py +++ b/blacksheep/client/cookies.py @@ -32,7 +32,6 @@ def not_ip_address(value: str): class StoredCookie: - __slots__ = ("cookie", "persistent", "creation_time", "expiry_time") def __init__(self, cookie: Cookie): @@ -203,7 +202,6 @@ def _get_cookies_checking_exp( schema: str, cookies: Dict[str, StoredCookie] ) -> Iterable[Cookie]: for cookie_name, stored_cookie in cookies.copy().items(): - if stored_cookie.is_expired(): del cookies[cookie_name] continue diff --git a/blacksheep/client/pool.py b/blacksheep/client/pool.py index 994ecedd..4d5147c8 100644 --- a/blacksheep/client/pool.py +++ b/blacksheep/client/pool.py @@ -25,8 +25,10 @@ def get_ssl_context( "Invalid ssl argument, expected one of: " "None, False, True, instance of ssl.SSLContext." ) + if ssl: raise InvalidArgument("SSL argument specified for non-https scheme.") + return None diff --git a/blacksheep/client/session.py b/blacksheep/client/session.py index 2b49c79f..efe4696a 100644 --- a/blacksheep/client/session.py +++ b/blacksheep/client/session.py @@ -46,7 +46,6 @@ def __contains__(self, item: Any) -> bool: class ClientRequestContext: - __slots__ = ("path", "cookies") def __init__(self, request, cookies: Optional[CookieJar] = None): @@ -345,7 +344,7 @@ async def _send_core(self, request: Request) -> Response: return response - async def _send_using_connection(self, request) -> Response: + async def _send_using_connection(self, request, attempt: int = 1) -> Response: connection = await self.get_connection(request.url) try: @@ -353,9 +352,9 @@ async def _send_using_connection(self, request) -> Response: connection.send(request), self.request_timeout ) except ConnectionClosedError as connection_closed_error: - if connection_closed_error.can_retry: + if connection_closed_error.can_retry and attempt < 4: await asyncio.sleep(self.delay_before_retry) - return await self._send_using_connection(request) + return await self._send_using_connection(request, attempt + 1) raise except TimeoutError: raise RequestTimeout(request.url, self.request_timeout) diff --git a/blacksheep/common/files/info.py b/blacksheep/common/files/info.py index fc9b58dc..16caae78 100644 --- a/blacksheep/common/files/info.py +++ b/blacksheep/common/files/info.py @@ -5,7 +5,6 @@ class FileInfo: - __slots__ = ("etag", "size", "mime", "modified_time") def __init__(self, size: int, etag: str, mime: str, modified_time: str): diff --git a/blacksheep/multipart.py b/blacksheep/multipart.py index be13449c..8629dbbb 100644 --- a/blacksheep/multipart.py +++ b/blacksheep/multipart.py @@ -107,7 +107,6 @@ def parse_multipart(value: bytes) -> Generator[FormPart, None, None]: default_charset = None for part_bytes in split_multipart(value): - try: yield parse_part(part_bytes, default_charset) except CharsetPart as charset: diff --git a/blacksheep/ranges.py b/blacksheep/ranges.py index c1218f51..967426e8 100644 --- a/blacksheep/ranges.py +++ b/blacksheep/ranges.py @@ -123,7 +123,6 @@ def _parse_range_value(range_value: str): class Range: - __slots__ = ("_unit", "_parts") def __init__(self, unit: str, parts: Sequence[RangePart]): diff --git a/blacksheep/server/bindings.py b/blacksheep/server/bindings.py index eff3a9d3..50e8333f 100644 --- a/blacksheep/server/bindings.py +++ b/blacksheep/server/bindings.py @@ -239,8 +239,15 @@ class RequestMethod(BoundValue[str]): """ +def _implicit_default(obj: "Binder"): + try: + return issubclass(obj.handle, BoundValue) + except (AttributeError, TypeError): + return False + + class Binder(metaclass=BinderMeta): # type: ignore - handle: ClassVar[Type[BoundValue]] + handle: ClassVar[Type[Any]] name_alias: ClassVar[str] = "" type_alias: ClassVar[Any] = None @@ -252,7 +259,7 @@ def __init__( required: bool = True, converter: Optional[Callable] = None, ): - self._implicit = implicit + self._implicit = implicit or not _implicit_default(self) self.parameter_name = name self.expected_type = expected_type self.required = required @@ -316,7 +323,10 @@ def example(id: str): # applied implicitly ... """ - value = await self.get_value(request) + try: + value = await self.get_value(request) + except ValueError as value_error: + raise BadRequest("Invalid parameter.") from value_error if value is None and self.default is not empty: return self.default @@ -334,6 +344,7 @@ def example(id: str): @abstractmethod async def get_value(self, request: Request) -> Any: """Gets a value from the given request object.""" + raise NotImplementedError() def get_binder_by_type(bound_value_type: Type[BoundValue]) -> Type[Binder]: @@ -405,7 +416,7 @@ class BodyBinder(Binder): def __init__( self, - expected_type: T, + expected_type, name: str = "body", implicit: bool = False, required: bool = False, @@ -592,7 +603,7 @@ class SyncBinder(Binder): def __init__( self, - expected_type: T = List[str], + expected_type: Any = List[str], name: str = "", implicit: bool = False, required: bool = False, diff --git a/blacksheep/server/normalization.py b/blacksheep/server/normalization.py index ebe0ad71..df5b324a 100644 --- a/blacksheep/server/normalization.py +++ b/blacksheep/server/normalization.py @@ -206,7 +206,7 @@ def __init__(self, parameter_name, route): def _check_union( - parameter: inspect.Parameter, annotation: Any, method: Callable[..., Any] + parameter: ParamInfo, annotation: Any, method: Callable[..., Any] ) -> Tuple[bool, Any]: """ Checks if the given annotation is Optional[] - in such case unwraps it @@ -292,7 +292,7 @@ def _get_bound_value_type(bound_type: Type[BoundValue]) -> Type[Any]: def _get_parameter_binder( - parameter: inspect.Parameter, + parameter: ParamInfo, services: ContainerProtocol, route: Optional[Route], method: Callable[..., Any], @@ -316,6 +316,13 @@ def _get_parameter_binder( if annotation in Binder.aliases: return Binder.aliases[annotation](services) + if ( + annotation in Binder.handlers + and annotation not in services + and not issubclass(annotation, BoundValue) + ): + return Binder.handlers[annotation](annotation, parameter.name) + # 1. is the type annotation of BoundValue[T] type? if _is_bound_value_annotation(annotation): binder_type = get_binder_by_type(annotation) @@ -377,7 +384,7 @@ def _get_parameter_binder( def get_parameter_binder( - parameter: inspect.Parameter, + parameter: ParamInfo, services: ContainerProtocol, route: Optional[Route], method: Callable[..., Any], diff --git a/blacksheep/server/openapi/docstrings.py b/blacksheep/server/openapi/docstrings.py index 4410b8e6..3afd4461 100644 --- a/blacksheep/server/openapi/docstrings.py +++ b/blacksheep/server/openapi/docstrings.py @@ -89,7 +89,6 @@ def parse_docstring(self, docstring: str) -> DocstringInfo: def type_repr_to_type(type_repr: str) -> Optional[Type]: - array_match = _array_rx.match(type_repr) if array_match: diff --git a/blacksheep/server/openapi/v3.py b/blacksheep/server/openapi/v3.py index 6bb134cb..3faf85eb 100644 --- a/blacksheep/server/openapi/v3.py +++ b/blacksheep/server/openapi/v3.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, fields, is_dataclass from datetime import date, datetime from enum import Enum, IntEnum -from typing import Any, Dict, List, Mapping, Optional, Tuple, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union from typing import _GenericAlias as GenericAlias from typing import get_type_hints from uuid import UUID @@ -295,6 +295,9 @@ def __init__( DataClassTypeHandler(), PydanticModelTypeHandler(), ] + self._binder_docs: Dict[ + Type[Binder], Iterable[Union[Parameter, Reference]] + ] = {} @property def object_types_handlers(self) -> List[ObjectTypeHandler]: @@ -304,6 +307,7 @@ def get_ui_page_title(self) -> str: return self.info.title def generate_documentation(self, app: Application) -> OpenAPI: + self._optimize_binders_docs() return OpenAPI( info=self.info, paths=self.get_paths(app), components=self.components ) @@ -697,12 +701,16 @@ def get_parameters( if not hasattr(handler, "binders"): return None binders: List[Binder] = handler.binders - parameters: Mapping[str, Union[Parameter, Reference]] = {} + parameters: Dict[str, Union[Parameter, Reference]] = {} docs = self.get_handler_docs(handler) parameters_info = (docs.parameters if docs else None) or dict() for binder in binders: + if binder.__class__ in self._binder_docs: + self._handle_binder_docs(binder, parameters) + continue + location = self.get_parameter_location_for_binder(binder) if not location: @@ -971,3 +979,51 @@ def get_routes_docs( self.events.on_paths_created.fire_sync(paths_doc) return paths_doc + + def set_binder_docs( + self, + binder_type: Type[Binder], + params_docs: Iterable[Union[Parameter, Reference]], + ): + """ + Configures parameters documentation for a given binder type. A binder can + read values from one or more input parameters, this is why this method supports + an iterable of Parameter or Reference objects. In most use cases, it is + desirable to use a Parameter here. Reference objects are configured + automatically when the documentation is built. + """ + self._binder_docs[binder_type] = params_docs + + def _handle_binder_docs( + self, binder: Binder, parameters: Dict[str, Union[Parameter, Reference]] + ): + params_docs = self._binder_docs[binder.__class__] + + for i, param_doc in enumerate(params_docs): + parameters[f"{binder.__class__.__qualname__}_{i}"] = param_doc + + def _optimize_binders_docs(self): + """ + Optimizes the documentation for custom binders to use references and + components.parameters, instead of duplicating parameters documentation in each + operation where they are used. + """ + new_dict = {} + params_docs: Iterable[Union[Parameter, Reference]] + + for key, params_docs in self._binder_docs.items(): + new_docs: List[Reference] = [] + + for param in params_docs: + if isinstance(param, Reference): + new_docs.append(param) + else: + if self.components.parameters is None: + self.components.parameters = {} + + self.components.parameters[param.name] = param + new_docs.append(Reference(f"#/components/parameters/{param.name}")) + + new_dict[key] = new_docs + + self._binder_docs = new_dict diff --git a/blacksheep/server/routing.py b/blacksheep/server/routing.py index 49060167..b3cd8e8f 100644 --- a/blacksheep/server/routing.py +++ b/blacksheep/server/routing.py @@ -78,7 +78,6 @@ def __init__(self, parameter_pattern_name: str, matched_parameter: str) -> None: class RouteMatch: - __slots__ = ("values", "pattern", "handler") def __init__(self, route: "Route", values: Optional[Dict[str, bytes]]): @@ -98,7 +97,6 @@ def _get_parameter_pattern_fragment( class Route: - __slots__ = ( "handler", "pattern", @@ -378,7 +376,6 @@ def ws(self, pattern) -> Callable[..., Any]: class Router(RouterBase): - __slots__ = ("routes", "_map", "_fallback") def __init__(self): @@ -478,7 +475,6 @@ def get_matching_route(self, method: AnyStr, value: AnyStr) -> Optional[Route]: class RegisteredRoute: - __slots__ = ("method", "pattern", "handler") def __init__(self, method: str, pattern: str, handler: Callable): diff --git a/itests/test_client.py b/itests/test_client.py index 82380138..05a72764 100644 --- a/itests/test_client.py +++ b/itests/test_client.py @@ -155,7 +155,6 @@ async def test_post_form(session, data): @pytest.mark.asyncio async def test_post_multipart_form_with_files(session): - if os.path.exists("out"): shutil.rmtree("out") @@ -192,7 +191,6 @@ async def test_post_multipart_form_with_files(session): @pytest.mark.asyncio async def test_post_multipart_form_with_images(session): - if os.path.exists("out"): shutil.rmtree("out") diff --git a/itests/test_server.py b/itests/test_server.py index 167ccb0b..6faeb10e 100644 --- a/itests/test_server.py +++ b/itests/test_server.py @@ -164,7 +164,6 @@ def test_post_form_urlencoded(session_1, data, echoed): def test_post_multipart_form_with_files(session_1): - if os.path.exists("out"): shutil.rmtree("out") diff --git a/setup.py b/setup.py index 29038245..6dad9f9f 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ def readme(): setup( name="blacksheep", - version="2.0a0", + version="2.0a1", description="Fast web framework for Python asyncio", long_description=readme(), long_description_content_type="text/markdown", diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index 6fe16e67..61e0c2ec 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -72,7 +72,6 @@ def get_scenarios(fn): async def test_non_url_redirect( responses, expected_status, expected_location, pools_factory ): - async with ClientSession( base_url=b"http://localhost:8080", pools=pools_factory(responses) ) as client: @@ -109,7 +108,6 @@ async def test_non_url_redirect( ), ) async def test_good_redirect(responses, expected_response_body, pools_factory): - async with ClientSession( base_url=b"http://localhost:8080", pools=pools_factory(responses) ) as client: @@ -137,7 +135,6 @@ async def test_good_redirect(responses, expected_response_body, pools_factory): ], ) async def test_not_follow_redirect(responses, expected_location, pools_factory): - async with ClientSession( base_url=b"http://localhost:8080", pools=pools_factory(responses), @@ -186,7 +183,6 @@ async def test_not_follow_redirect(responses, expected_location, pools_factory): async def test_maximum_number_of_redirects_detection( responses, maximum_redirects, pools_factory ): - async with ClientSession( base_url=b"http://localhost:8080", pools=pools_factory(responses) ) as client: @@ -245,11 +241,9 @@ async def test_maximum_number_of_redirects_detection( async def test_circular_redirect_detection( responses, expected_error_message, pools_factory ): - async with ClientSession( base_url=b"http://localhost:8080", pools=pools_factory(responses) ) as client: - with pytest.raises(CircularRedirectError) as error: await client.get(b"/") diff --git a/tests/test_application.py b/tests/test_application.py index 2a60cc7d..107930a9 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -2645,7 +2645,6 @@ async def home(foo: FromRoute[parameter_type]): async def test_valid_header_parameter_parse( parameter_type, parameter, expected_value, app ): - T = TypeVar("T") class XFooHeader(FromHeader[T]): @@ -3880,7 +3879,7 @@ def decorator(next_handler): async def wrapped(*args, **kwargs) -> Response: response = ensure_response(await next_handler(*args, **kwargs)) - for (name, value) in additional_headers: + for name, value in additional_headers: response.add_header(name.encode(), value.encode()) return response diff --git a/tests/test_bindings.py b/tests/test_bindings.py index 19b36c3e..8cf9400f 100644 --- a/tests/test_bindings.py +++ b/tests/test_bindings.py @@ -57,7 +57,6 @@ def __init__(self, a: str, b: List[str]): @pytest.mark.asyncio async def test_from_body_json_binding(): - request = Request("POST", b"/", [JSONContentType]).with_content( JSONContent({"a": "world", "b": 9000}) ) @@ -73,7 +72,6 @@ async def test_from_body_json_binding(): @pytest.mark.asyncio async def test_from_body_json_binding_extra_parameters_strategy(): - request = Request("POST", b"/", [JSONContentType]).with_content( JSONContent( { @@ -95,7 +93,6 @@ async def test_from_body_json_binding_extra_parameters_strategy(): @pytest.mark.asyncio async def test_from_body_json_with_converter(): - request = Request("POST", b"/", [JSONContentType]).with_content( JSONContent( { @@ -459,7 +456,6 @@ async def test_identity_binder(): @pytest.mark.asyncio async def test_from_body_form_binding_urlencoded(): - request = Request("POST", b"/", []).with_content( FormContent({"a": "world", "b": 9000}) ) @@ -475,7 +471,6 @@ async def test_from_body_form_binding_urlencoded(): @pytest.mark.asyncio async def test_from_body_form_binding_urlencoded_keys_duplicates(): - request = Request("POST", b"/", []).with_content( FormContent([("a", "world"), ("b", "one"), ("b", "two"), ("b", "three")]) ) @@ -491,7 +486,6 @@ async def test_from_body_form_binding_urlencoded_keys_duplicates(): @pytest.mark.asyncio async def test_from_body_form_binding_multipart(): - request = Request("POST", b"/", []).with_content( MultiPartFormData([FormPart(b"a", b"world"), FormPart(b"b", b"9000")]) ) @@ -506,7 +500,6 @@ async def test_from_body_form_binding_multipart(): @pytest.mark.asyncio async def test_from_body_form_binding_multipart_keys_duplicates(): - request = Request("POST", b"/", []).with_content( MultiPartFormData( [ diff --git a/tests/test_controllers.py b/tests/test_controllers.py index 3658ed63..b2d7e5e8 100644 --- a/tests/test_controllers.py +++ b/tests/test_controllers.py @@ -467,7 +467,6 @@ async def test_controller_with_base_route_as_string_attribute(app): get = app.controllers_router.get class Home(Controller): - route = "/home" def greet(self): @@ -499,7 +498,6 @@ async def test_application_raises_for_invalid_route_class_attribute(app): get = app.controllers_router.get class Home(Controller): - route = False def greet(self): @@ -602,7 +600,6 @@ async def test_controllers_with_duplicate_routes_with_base_route_throw( # and another handler class A(Controller): - route = "home" @get(first_pattern) @@ -632,7 +629,6 @@ async def test_controller_with_duplicate_route_with_base_route_throw( # and another handler class A(Controller): - route = "home" @get(first_pattern) diff --git a/tests/test_cookies.py b/tests/test_cookies.py index 1f3cb0a2..055c6d06 100644 --- a/tests/test_cookies.py +++ b/tests/test_cookies.py @@ -262,7 +262,6 @@ def test_parse_cookie_separators(value, expected_name, expected_value, expected_ def test_raise_for_value_exceeding_length(): - with pytest.raises(CookieValueExceedsMaximumLength): Cookie("crash", "A" * 4967) diff --git a/tests/test_cors.py b/tests/test_cors.py index 369ab83f..2b37e086 100644 --- a/tests/test_cors.py +++ b/tests/test_cors.py @@ -580,7 +580,6 @@ async def test_non_cors_options_request(app): @pytest.mark.asyncio async def test_use_cors_raises_for_started_app(app): - await app.start() with pytest.raises(ApplicationAlreadyStartedCORSError): diff --git a/tests/test_files_serving.py b/tests/test_files_serving.py index d2a56cbf..3fe8ad7e 100644 --- a/tests/test_files_serving.py +++ b/tests/test_files_serving.py @@ -663,7 +663,6 @@ async def test_serve_files_multiple_folders(files2_index_contents, app): def test_validate_source_path_raises_for_invalid_path(): - with pytest.raises(InvalidArgument): validate_source_path("./not-existing") diff --git a/tests/test_openapi_v3.py b/tests/test_openapi_v3.py index 805bcdf9..1e8be225 100644 --- a/tests/test_openapi_v3.py +++ b/tests/test_openapi_v3.py @@ -173,7 +173,6 @@ def check_consistency(cls, v, values): class PydConstrained(BaseModel): - a: PositiveInt b: NegativeFloat big_int: conint(gt=1000, lt=1024) diff --git a/tests/test_ranges.py b/tests/test_ranges.py index 18962b5d..6fe938b0 100644 --- a/tests/test_ranges.py +++ b/tests/test_ranges.py @@ -141,7 +141,6 @@ def test_range_eq_not_implemented(item): def test_range_part_raises_if_start_gt_end(): - with pytest.raises(ValueError): RangePart(400, 300) @@ -155,7 +154,6 @@ def test_range_part_raises_if_start_gt_end(): def test_range_part_raises_if_any_part_is_negative(): - with pytest.raises(ValueError): RangePart(-100, 0) diff --git a/tests/test_url.py b/tests/test_url.py index 53446543..706cacfc 100644 --- a/tests/test_url.py +++ b/tests/test_url.py @@ -108,6 +108,5 @@ def test_base_url(value, expected_base_url): def test_raises_for_invalid_scheme(): - with pytest.raises(InvalidURL): URL(b"file://D:/a/b/c")