diff --git a/blacksheep/__init__.py b/blacksheep/__init__.py index 16500ba8..49c27b09 100644 --- a/blacksheep/__init__.py +++ b/blacksheep/__init__.py @@ -24,4 +24,5 @@ from .headers import Header, Headers from .messages import Request, Response from .server import Application, Route, Router +from .server.websocket import WebSocket, WebSocketDisconnectError from .url import URL, InvalidURL diff --git a/blacksheep/__init__.pyi b/blacksheep/__init__.pyi index 700364ba..18259bbb 100644 --- a/blacksheep/__init__.pyi +++ b/blacksheep/__init__.pyi @@ -20,5 +20,7 @@ from .messages import Response as Response from .server import Application as Application from .server import Route as Route from .server import Router as Router +from .server.websocket import WebSocket as WebSocket +from .server.websocket import WebSocketDisconnectError as WebSocketDisconnectError from .url import URL as URL from .url import InvalidURL as InvalidURL diff --git a/blacksheep/server/application.py b/blacksheep/server/application.py index 8559d7e6..6cfd03f8 100644 --- a/blacksheep/server/application.py +++ b/blacksheep/server/application.py @@ -681,14 +681,16 @@ async def _handle_http(self, scope, receive, send): request.content.dispose() async def __call__(self, scope, receive, send): - if scope["type"] == "lifespan": - return await self._handle_lifespan(receive, send) + if scope["type"] == "http": + return await self._handle_http(scope, receive, send) if scope["type"] == "websocket": return await self._handle_websocket(scope, receive, send) - if scope["type"] == "http": - return await self._handle_http(scope, receive, send) + if scope["type"] == "lifespan": + return await self._handle_lifespan(receive, send) + + raise TypeError(f"Unsupported scope type: {scope['type']}") class MountMixin: diff --git a/blacksheep/server/normalization.py b/blacksheep/server/normalization.py index 54e45bb3..2bdb53ea 100644 --- a/blacksheep/server/normalization.py +++ b/blacksheep/server/normalization.py @@ -26,6 +26,7 @@ from blacksheep.normalization import copy_special_attributes from blacksheep.server import responses from blacksheep.server.routing import Route +from blacksheep.server.websocket import WebSocket from .bindings import ( Binder, @@ -301,6 +302,12 @@ def _get_parameter_binder( if isinstance(annotation, (str, ForwardRef)): # pragma: no cover raise UnsupportedForwardRefInSignatureError(original_annotation) + if annotation is Request: + return RequestBinder() + + if annotation is WebSocket: + return WebSocketBinder() + # 1. is the type annotation of BoundValue[T] type? if _is_bound_value_annotation(annotation): binder_type = get_binder_by_type(annotation) diff --git a/blacksheep/server/routing.py b/blacksheep/server/routing.py index 426708f4..a11b2705 100644 --- a/blacksheep/server/routing.py +++ b/blacksheep/server/routing.py @@ -86,7 +86,7 @@ def __init__(self, route: "Route", values: Optional[Dict[str, bytes]]): def _get_parameter_pattern_fragment( - parameter_name: bytes, value_pattern: bytes = br"[^\/]+" + parameter_name: bytes, value_pattern: bytes = rb"[^\/]+" ) -> bytes: return b"/(?P<" + parameter_name + b">" + value_pattern + b")" @@ -147,9 +147,9 @@ def _get_regex_for_pattern(self, pattern: bytes): ) if b"/*" in pattern: - pattern = _route_all_rx.sub(br"?(?P.*)", pattern) + pattern = _route_all_rx.sub(rb"?(?P.*)", pattern) else: - pattern = _route_all_rx.sub(br"(?P.*)", pattern) + pattern = _route_all_rx.sub(rb"(?P.*)", pattern) # support for < > patterns, e.g. /api/cats/ # but also: /api/cats/ or /api/cats/ for more @@ -167,7 +167,7 @@ def _get_regex_for_pattern(self, pattern: bytes): # route parameters defined using /:name syntax if b"/:" in pattern: - pattern = _route_param_rx.sub(br"/(?P<\1>[^\/]+)", pattern) + pattern = _route_param_rx.sub(rb"/(?P<\1>[^\/]+)", pattern) # NB: following code is just to throw user friendly errors; # regex would fail anyway, but with a more complex message @@ -243,7 +243,7 @@ def __repr__(self) -> str: @property def mustache_pattern(self) -> str: - return _route_param_rx.sub(br"/{\1}", self.pattern).decode("utf8") + return _route_param_rx.sub(rb"/{\1}", self.pattern).decode("utf8") @property def full_pattern(self) -> bytes: @@ -321,6 +321,9 @@ def add_connect(self, pattern: str, handler: Callable[..., Any]) -> None: def add_patch(self, pattern: str, handler: Callable[..., Any]) -> None: self.add(HTTPMethod.PATCH, pattern, handler) + def add_ws(self, pattern: str, handler: Callable[..., Any]) -> None: + self.add(HTTPMethod.GET, pattern, handler) + def head(self, pattern: Optional[str] = "/") -> Callable[..., Any]: return self.get_decorator(HTTPMethod.HEAD, pattern) diff --git a/blacksheep/server/websocket.py b/blacksheep/server/websocket.py index 8ffc6826..ded21f6f 100644 --- a/blacksheep/server/websocket.py +++ b/blacksheep/server/websocket.py @@ -1,12 +1,8 @@ -import json from enum import Enum from functools import wraps from typing import Any, AnyStr, Callable, List, MutableMapping, Optional - -class WebSocketDisconnect(Exception): - def __init__(self, code: int = 1000): - self.code = code +from blacksheep.plugins import json class WebSocketState(Enum): @@ -20,6 +16,40 @@ class MessageMode(str, Enum): BYTES = "bytes" +class WebSocketError(Exception): + """A base class for all web sockets errors.""" + + +class InvalidWebSocketStateError(WebSocketError): + def __init__( + self, + *, + party: str = "client", + current_state: WebSocketState, + expected_state: WebSocketState, + ): + super().__init__(party, current_state, expected_state) + self.party = party + self.current_state = current_state + self.expected_state = expected_state + + def __str__(self): + return ( + f"Invalid {self.party} state of the WebSocket connection. " + f"Expected state: {self.expected_state}. " + f"Current state: {self.current_state}." + ) + + +class WebSocketDisconnectError(WebSocketError): + def __init__(self, code: int = 1000): + super().__init__(code) + self.code = code + + def __str__(self): + return f"The client closed the connection. WebSocket close code: {self.code}." + + class WebSocket: def __init__( self, scope: MutableMapping[str, Any], receive: Callable, send: Callable @@ -34,7 +64,13 @@ def __init__( self.client_state = WebSocketState.CONNECTING self.application_state = WebSocketState.CONNECTING - async def connect(self) -> None: + async def _connect(self) -> None: + if self.client_state != WebSocketState.CONNECTING: + raise InvalidWebSocketStateError( + current_state=self.client_state, + expected_state=WebSocketState.CONNECTING, + ) + message = await self._receive() assert message["type"] == "websocket.connect" @@ -43,10 +79,9 @@ async def connect(self) -> None: async def accept( self, headers: Optional[List] = None, subprotocol: str = None ) -> None: - assert self.client_state == WebSocketState.CONNECTING - await self.connect() - headers = headers or [] + + await self._connect() self.application_state = WebSocketState.CONNECTED message = { @@ -58,7 +93,12 @@ async def accept( await self._send(message) async def receive(self) -> MutableMapping[str, AnyStr]: - assert self.application_state == WebSocketState.CONNECTED + if self.application_state != WebSocketState.CONNECTED: + raise InvalidWebSocketStateError( + party="application", + current_state=self.application_state, + expected_state=WebSocketState.CONNECTED, + ) message = await self._receive() assert message["type"] == "websocket.receive" @@ -74,7 +114,7 @@ async def receive_bytes(self) -> bytes: return message["bytes"] async def receive_json( - self, mode: str = MessageMode.TEXT + self, mode: MessageMode = MessageMode.TEXT ) -> MutableMapping[str, Any]: assert mode in list(MessageMode) message = await self.receive() @@ -85,18 +125,22 @@ async def receive_json( if mode == MessageMode.BYTES: return json.loads(message["bytes"].decode()) - async def send(self, message: MutableMapping[str, AnyStr]) -> None: - assert self.client_state == WebSocketState.CONNECTED + async def _send_message(self, message: MutableMapping[str, AnyStr]) -> None: + if self.client_state != WebSocketState.CONNECTED: + raise InvalidWebSocketStateError( + current_state=self.client_state, + expected_state=WebSocketState.CONNECTED, + ) await self._send(message) async def send_text(self, data: str) -> None: - await self.send({"type": "websocket.send", "text": data}) + await self._send_message({"type": "websocket.send", "text": data}) async def send_bytes(self, data: bytes) -> None: - await self.send({"type": "websocket.send", "bytes": data}) + await self._send_message({"type": "websocket.send", "bytes": data}) async def send_json( - self, data: MutableMapping[Any, Any], mode: str = MessageMode.TEXT + self, data: MutableMapping[Any, Any], mode: MessageMode = MessageMode.TEXT ): assert mode in list(MessageMode) text = json.dumps(data) @@ -114,7 +158,7 @@ async def disconnect(): if message["type"] == "websocket.disconnect": self.application_state = self.client_state = WebSocketState.DISCONNECTED - raise WebSocketDisconnect(message["code"]) + raise WebSocketDisconnectError(message["code"]) return message diff --git a/blacksheep/testing/__init__.py b/blacksheep/testing/__init__.py index e9f8cd08..f3896da3 100644 --- a/blacksheep/testing/__init__.py +++ b/blacksheep/testing/__init__.py @@ -1,5 +1,6 @@ from blacksheep.contents import FormContent, JSONContent, TextContent from blacksheep.testing.client import TestClient +from blacksheep.testing.messages import MockReceive, MockSend from blacksheep.testing.simulator import AbstractTestSimulator __all__ = [ @@ -8,4 +9,6 @@ "JSONContent", "TextContent", "FormContent", + "MockReceive", + "MockSend", ] diff --git a/blacksheep/testing/messages.py b/blacksheep/testing/messages.py index 24833297..bf65b01d 100644 --- a/blacksheep/testing/messages.py +++ b/blacksheep/testing/messages.py @@ -31,9 +31,12 @@ async def __call__(self): message = self.messages[self.index] except IndexError: message = b"" + else: + self.index += 1 + if isinstance(message, dict): return message - self.index += 1 + await asyncio.sleep(0) return { "body": message, diff --git a/tests/test_application.py b/tests/test_application.py index 5a5900c3..162e07cc 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -3267,6 +3267,26 @@ async def example_3(user: Identity): assert content == "User name: Charlie Brown" +@pytest.mark.asyncio +async def test_request_binding(app): + @app.router.get("/") + async def example(req: Request): + assert isinstance(req, Request) + return "Foo" + + await app.start() + + await app( + get_example_scope("GET", "/", []), + MockReceive(), + MockSend(), + ) + + content = await app.response.text() + assert app.response.status == 200 + assert content == "Foo" + + @pytest.mark.asyncio async def test_use_auth_raises_if_app_is_already_started(app): class MockAuthHandler(AuthenticationHandler): @@ -3690,3 +3710,15 @@ async def test_async_event_raises_for_fire_method(): with pytest.raises(TypeError): await event.fire() + + +@pytest.mark.asyncio +async def test_application_raises_for_unhandled_scope_type(app): + with pytest.raises(TypeError) as app_type_error: + await app( + {"type": "foo"}, + MockReceive(), + MockSend(), + ) + + assert str(app_type_error.value) == "Unsupported scope type: foo" diff --git a/tests/test_router.py b/tests/test_router.py index f1a372ab..e4f6ce90 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -422,6 +422,9 @@ def patch_foo(): def delete_foo(): ... + def ws(): + ... + router.add_trace("/", home_verbose) router.add_options("/", home_options) router.add_connect("/", home_connect) @@ -430,6 +433,7 @@ def delete_foo(): router.add_patch("/foo", patch_foo) router.add_post("/foo", create_foo) router.add_delete("/foo", delete_foo) + router.add_ws("/ws", ws) m = router.get_match(HTTPMethod.GET, b"/") assert m is not None @@ -466,6 +470,10 @@ def delete_foo(): assert m is not None assert m.handler is delete_foo + m = router.get_match(HTTPMethod.GET, b"/ws") + assert m is not None + assert m.handler is ws + def test_router_match_among_many_decorators(): router = Router() diff --git a/tests/test_websocket.py b/tests/test_websocket.py new file mode 100644 index 00000000..73dc28ca --- /dev/null +++ b/tests/test_websocket.py @@ -0,0 +1,369 @@ +import pytest + +from blacksheep.server.websocket import ( + InvalidWebSocketStateError, + MessageMode, + WebSocket, + WebSocketDisconnectError, + WebSocketState, +) +from blacksheep.testing.messages import MockReceive, MockSend +from tests.utils.application import FakeApplication + + +@pytest.fixture +def example_scope(): + return {"type": "websocket"} + + +@pytest.mark.asyncio +async def test_connect_raises_if_not_connecting(example_scope): + ws = WebSocket( + example_scope, MockReceive([{"type": "websocket.connect"}]), MockSend() + ) + + ws.client_state = WebSocketState.CONNECTED + + with pytest.raises(InvalidWebSocketStateError) as error: + await ws.accept() + + assert error.value.current_state == WebSocketState.CONNECTED + assert error.value.expected_state == WebSocketState.CONNECTING + + assert str(error.value) == ( + f"Invalid {error.value.party} state of the WebSocket connection. " + f"Expected state: {error.value.expected_state}. " + f"Current state: {error.value.current_state}." + ) + + +@pytest.mark.asyncio +async def test_websocket_accept(example_scope): + """ + A websocket gets fully connected when the ASGI server sends a message of type + 'websocket.connect' and the server accepts the connection. + """ + ws = WebSocket( + example_scope, MockReceive([{"type": "websocket.connect"}]), MockSend() + ) + + await ws.accept() + + assert ws.client_state == WebSocketState.CONNECTED + assert ws.application_state == WebSocketState.CONNECTED + + +@pytest.mark.asyncio +async def test_websocket_receive_text(example_scope): + """ + A first message is received when the underlying ASGI server first sends a + 'websocket.connect' message, then a content message. + """ + ws = WebSocket( + example_scope, + MockReceive( + [ + {"type": "websocket.connect"}, + {"type": "websocket.receive", "text": "Lorem ipsum dolor sit amet"}, + ] + ), + MockSend(), + ) + + await ws.accept() + + message = await ws.receive_text() + assert message == "Lorem ipsum dolor sit amet" + + +@pytest.mark.asyncio +async def test_websocket_receive_bytes(example_scope): + """ + A first message is received when the underlying ASGI server first sends a + 'websocket.connect' message, then a content message. + """ + ws = WebSocket( + example_scope, + MockReceive( + [ + {"type": "websocket.connect"}, + {"type": "websocket.receive", "bytes": b"Lorem ipsum dolor sit amet"}, + ] + ), + MockSend(), + ) + + await ws.accept() + + message = await ws.receive_bytes() + assert message == b"Lorem ipsum dolor sit amet" + + +@pytest.mark.asyncio +async def test_websocket_receive_json(example_scope): + """ + A first message is received when the underlying ASGI server first sends a + 'websocket.connect' message, then a content message. + """ + ws = WebSocket( + example_scope, + MockReceive( + [ + {"type": "websocket.connect"}, + {"type": "websocket.receive", "text": '{"message": "Lorem ipsum"}'}, + ] + ), + MockSend(), + ) + + await ws.accept() + + message = await ws.receive_json() + assert message == {"message": "Lorem ipsum"} + + +@pytest.mark.asyncio +async def test_websocket_receive_json_from_bytes(example_scope): + """ + A first message is received when the underlying ASGI server first sends a + 'websocket.connect' message, then a content message. + """ + ws = WebSocket( + example_scope, + MockReceive( + [ + {"type": "websocket.connect"}, + {"type": "websocket.receive", "bytes": b'{"message": "Lorem ipsum"}'}, + ] + ), + MockSend(), + ) + + await ws.accept() + + message = await ws.receive_json(mode=MessageMode.BYTES) + assert message == {"message": "Lorem ipsum"} + + +@pytest.mark.asyncio +async def test_websocket_send_text(example_scope): + """ + A message is sent by the server to clients, by sending a message to the underlying + ASGI server with type "websocket.send" and a "text" or "bytes" property. + """ + mocked_send = MockSend() + ws = WebSocket( + example_scope, + MockReceive([{"type": "websocket.connect"}]), + mocked_send, + ) + + await ws.accept() + + await ws.send_text("Lorem ipsum dolor sit amet") + + assert len(mocked_send.messages) > 0 + message = mocked_send.messages[-1] + + assert message.get("text") == "Lorem ipsum dolor sit amet" + assert message.get("type") == "websocket.send" + + +@pytest.mark.asyncio +async def test_websocket_send_bytes(example_scope): + """ + A message is sent by the server to clients, by sending a message to the underlying + ASGI server with type "websocket.send" and a "text" or "bytes" property. + """ + mocked_send = MockSend() + ws = WebSocket( + example_scope, + MockReceive([{"type": "websocket.connect"}]), + mocked_send, + ) + + await ws.accept() + + await ws.send_bytes(b"Lorem ipsum dolor sit amet") + + assert len(mocked_send.messages) > 0 + message = mocked_send.messages[-1] + + assert message.get("bytes") == b"Lorem ipsum dolor sit amet" + assert message.get("type") == "websocket.send" + + +@pytest.mark.asyncio +async def test_websocket_send_json(example_scope): + """ + A message is sent by the server to clients, by sending a message to the underlying + ASGI server with type "websocket.send" and a "text" or "bytes" property. + """ + mocked_send = MockSend() + ws = WebSocket( + example_scope, + MockReceive([{"type": "websocket.connect"}]), + mocked_send, + ) + + await ws.accept() + + await ws.send_json({"message": "Lorem ipsum dolor sit amet"}) + + assert len(mocked_send.messages) > 0 + message = mocked_send.messages[-1] + + assert message.get("text") == '{"message":"Lorem ipsum dolor sit amet"}' + assert message.get("type") == "websocket.send" + + +@pytest.mark.asyncio +async def test_websocket_send_json_as_bytes(example_scope): + """ + A message is sent by the server to clients, by sending a message to the underlying + ASGI server with type "websocket.send" and a "text" or "bytes" property. + """ + mocked_send = MockSend() + ws = WebSocket( + example_scope, + MockReceive([{"type": "websocket.connect"}]), + mocked_send, + ) + + await ws.accept() + + await ws.send_json({"message": "Lorem ipsum dolor sit amet"}, MessageMode.BYTES) + + assert len(mocked_send.messages) > 0 + message = mocked_send.messages[-1] + + assert message.get("bytes") == b'{"message":"Lorem ipsum dolor sit amet"}' + assert message.get("type") == "websocket.send" + + +@pytest.mark.asyncio +async def test_connecting_websocket_raises_for_receive(example_scope): + ws = WebSocket(example_scope, MockReceive(), MockSend()) + + assert ws.client_state == WebSocketState.CONNECTING + + with pytest.raises(InvalidWebSocketStateError) as error: + await ws.receive() + + assert error.value.current_state == WebSocketState.CONNECTING + assert error.value.expected_state == WebSocketState.CONNECTED + + assert str(error.value) == ( + f"Invalid {error.value.party} state of the WebSocket connection. " + f"Expected state: {error.value.expected_state}. " + f"Current state: {error.value.current_state}." + ) + + +@pytest.mark.asyncio +async def test_connecting_websocket_raises_for_send(example_scope): + ws = WebSocket(example_scope, MockReceive(), MockSend()) + + assert ws.client_state == WebSocketState.CONNECTING + + with pytest.raises(InvalidWebSocketStateError) as error: + await ws.send_text("Error") + + assert error.value.current_state == WebSocketState.CONNECTING + assert error.value.expected_state == WebSocketState.CONNECTED + + assert str(error.value) == ( + f"Invalid {error.value.party} state of the WebSocket connection. " + f"Expected state: {error.value.expected_state}. " + f"Current state: {error.value.current_state}." + ) + + +@pytest.mark.asyncio +async def test_websocket_raises_for_receive_when_closed_by_client(example_scope): + """ + If the underlying ASGI server sends a message of type "websocket.disconnect", + it means that the client disconnected. + """ + ws = WebSocket( + example_scope, + MockReceive( + [ + {"type": "websocket.connect"}, + {"type": "websocket.disconnect", "code": 500}, + ] + ), + MockSend(), + ) + + await ws.accept() + + with pytest.raises(WebSocketDisconnectError) as error: + await ws.receive() + + assert error.value.code == 500 + + assert str(error.value) == ( + f"The client closed the connection. WebSocket close code: {error.value.code}." + ) + + +@pytest.mark.asyncio +async def test_application_handling_websocket_request_not_found(): + """ + If a client tries to open a WebSocket connection on an endpoint that is not handled, + the application returns an ASGI message to close the connection. + """ + app = FakeApplication() + mock_send = MockSend() + mock_receive = MockReceive() + + await app({"type": "websocket", "path": "/ws"}, mock_receive, mock_send) + + close_message = mock_send.messages[0] + assert close_message == {"type": "websocket.close", "code": 1000} + + +@pytest.mark.asyncio +async def test_application_handling_proper_websocket_request(): + """ + If a client tries to open a WebSocket connection on an endpoint that is handled, + the application websocket handler is called. + """ + app = FakeApplication() + mock_send = MockSend() + mock_receive = MockReceive([{"type": "websocket.connect"}]) + + @app.router.ws("/ws/{foo}") + async def websocket_handler(websocket, foo): + assert isinstance(websocket, WebSocket) + assert websocket.application_state == WebSocketState.CONNECTING + assert websocket.client_state == WebSocketState.CONNECTING + + assert foo == "001" + + await websocket.accept() + + await app.start() + await app({"type": "websocket", "path": "/ws/001"}, mock_receive, mock_send) + + +@pytest.mark.asyncio +async def test_application_websocket_binding_by_type_annotation(): + """ + This test verifies that the WebSocketBinder can bind a WebSocket by type annotation. + """ + app = FakeApplication() + mock_send = MockSend() + mock_receive = MockReceive([{"type": "websocket.connect"}]) + + @app.router.ws("/ws") + async def websocket_handler(my_ws: WebSocket): + assert isinstance(my_ws, WebSocket) + assert my_ws.application_state == WebSocketState.CONNECTING + assert my_ws.client_state == WebSocketState.CONNECTING + + await my_ws.accept() + + await app.start() + await app({"type": "websocket", "path": "/ws"}, mock_receive, mock_send) diff --git a/wsdemo/app.py b/wsdemo/app.py deleted file mode 100644 index 6936dc5f..00000000 --- a/wsdemo/app.py +++ /dev/null @@ -1,25 +0,0 @@ -import blacksheep -import pathlib - -from blacksheep.server.responses import redirect -from blacksheep.server.websocket import WebSocket - -STATIC_PATH = pathlib.Path(__file__).parent / 'static' - -app = blacksheep.Application() -app.serve_files(STATIC_PATH, root_path='/static') - - -@app.router.ws('/ws/{client_id}') -async def ws(websocket: WebSocket, client_id: str): - await websocket.accept() - print(f'client_id={client_id}') - - while True: - msg = await websocket.receive_text() - await websocket.send_text(msg) - - -@app.router.get('/') -def r(): - return redirect('/static/chat.html') diff --git a/wsdemo/static/chat.html b/wsdemo/static/chat.html deleted file mode 100644 index ba8c0f7f..00000000 --- a/wsdemo/static/chat.html +++ /dev/null @@ -1,77 +0,0 @@ - - - - - BlackSheep WebSocket Test - - - -
-

Your client ID is {{ CLIENT_ID }}

-

Status: {{ status }}

-
    -
  • {{ message }}
  • -
-
- - - - -
-
- - - \ No newline at end of file