Skip to content

Commit

Permalink
WebSocket support follow-up (#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
Klavionik authored Feb 1, 2022
1 parent 4554d97 commit 8e53014
Show file tree
Hide file tree
Showing 13 changed files with 501 additions and 129 deletions.
1 change: 1 addition & 0 deletions blacksheep/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@
from .headers import Header, Headers
from .messages import Request, Response
from .server import Application, Route, Router
from .server.websocket import WebSocket, WebSocketDisconnectError
from .url import URL, InvalidURL
2 changes: 2 additions & 0 deletions blacksheep/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@ from .messages import Response as Response
from .server import Application as Application
from .server import Route as Route
from .server import Router as Router
from .server.websocket import WebSocket as WebSocket
from .server.websocket import WebSocketDisconnectError as WebSocketDisconnectError
from .url import URL as URL
from .url import InvalidURL as InvalidURL
10 changes: 6 additions & 4 deletions blacksheep/server/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,14 +681,16 @@ async def _handle_http(self, scope, receive, send):
request.content.dispose()

async def __call__(self, scope, receive, send):
if scope["type"] == "lifespan":
return await self._handle_lifespan(receive, send)
if scope["type"] == "http":
return await self._handle_http(scope, receive, send)

if scope["type"] == "websocket":
return await self._handle_websocket(scope, receive, send)

if scope["type"] == "http":
return await self._handle_http(scope, receive, send)
if scope["type"] == "lifespan":
return await self._handle_lifespan(receive, send)

raise TypeError(f"Unsupported scope type: {scope['type']}")


class MountMixin:
Expand Down
7 changes: 7 additions & 0 deletions blacksheep/server/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from blacksheep.normalization import copy_special_attributes
from blacksheep.server import responses
from blacksheep.server.routing import Route
from blacksheep.server.websocket import WebSocket

from .bindings import (
Binder,
Expand Down Expand Up @@ -301,6 +302,12 @@ def _get_parameter_binder(
if isinstance(annotation, (str, ForwardRef)): # pragma: no cover
raise UnsupportedForwardRefInSignatureError(original_annotation)

if annotation is Request:
return RequestBinder()

if annotation is WebSocket:
return WebSocketBinder()

# 1. is the type annotation of BoundValue[T] type?
if _is_bound_value_annotation(annotation):
binder_type = get_binder_by_type(annotation)
Expand Down
13 changes: 8 additions & 5 deletions blacksheep/server/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, route: "Route", values: Optional[Dict[str, bytes]]):


def _get_parameter_pattern_fragment(
parameter_name: bytes, value_pattern: bytes = br"[^\/]+"
parameter_name: bytes, value_pattern: bytes = rb"[^\/]+"
) -> bytes:
return b"/(?P<" + parameter_name + b">" + value_pattern + b")"

Expand Down Expand Up @@ -147,9 +147,9 @@ def _get_regex_for_pattern(self, pattern: bytes):
)

if b"/*" in pattern:
pattern = _route_all_rx.sub(br"?(?P<tail>.*)", pattern)
pattern = _route_all_rx.sub(rb"?(?P<tail>.*)", pattern)
else:
pattern = _route_all_rx.sub(br"(?P<tail>.*)", pattern)
pattern = _route_all_rx.sub(rb"(?P<tail>.*)", pattern)

# support for < > patterns, e.g. /api/cats/<cat_id>
# but also: /api/cats/<int:cat_id> or /api/cats/<uuid:cat_id> for more
Expand All @@ -167,7 +167,7 @@ def _get_regex_for_pattern(self, pattern: bytes):

# route parameters defined using /:name syntax
if b"/:" in pattern:
pattern = _route_param_rx.sub(br"/(?P<\1>[^\/]+)", pattern)
pattern = _route_param_rx.sub(rb"/(?P<\1>[^\/]+)", pattern)

# NB: following code is just to throw user friendly errors;
# regex would fail anyway, but with a more complex message
Expand Down Expand Up @@ -243,7 +243,7 @@ def __repr__(self) -> str:

@property
def mustache_pattern(self) -> str:
return _route_param_rx.sub(br"/{\1}", self.pattern).decode("utf8")
return _route_param_rx.sub(rb"/{\1}", self.pattern).decode("utf8")

@property
def full_pattern(self) -> bytes:
Expand Down Expand Up @@ -321,6 +321,9 @@ def add_connect(self, pattern: str, handler: Callable[..., Any]) -> None:
def add_patch(self, pattern: str, handler: Callable[..., Any]) -> None:
self.add(HTTPMethod.PATCH, pattern, handler)

def add_ws(self, pattern: str, handler: Callable[..., Any]) -> None:
self.add(HTTPMethod.GET, pattern, handler)

def head(self, pattern: Optional[str] = "/") -> Callable[..., Any]:
return self.get_decorator(HTTPMethod.HEAD, pattern)

Expand Down
78 changes: 61 additions & 17 deletions blacksheep/server/websocket.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import json
from enum import Enum
from functools import wraps
from typing import Any, AnyStr, Callable, List, MutableMapping, Optional


class WebSocketDisconnect(Exception):
def __init__(self, code: int = 1000):
self.code = code
from blacksheep.plugins import json


class WebSocketState(Enum):
Expand All @@ -20,6 +16,40 @@ class MessageMode(str, Enum):
BYTES = "bytes"


class WebSocketError(Exception):
"""A base class for all web sockets errors."""


class InvalidWebSocketStateError(WebSocketError):
def __init__(
self,
*,
party: str = "client",
current_state: WebSocketState,
expected_state: WebSocketState,
):
super().__init__(party, current_state, expected_state)
self.party = party
self.current_state = current_state
self.expected_state = expected_state

def __str__(self):
return (
f"Invalid {self.party} state of the WebSocket connection. "
f"Expected state: {self.expected_state}. "
f"Current state: {self.current_state}."
)


class WebSocketDisconnectError(WebSocketError):
def __init__(self, code: int = 1000):
super().__init__(code)
self.code = code

def __str__(self):
return f"The client closed the connection. WebSocket close code: {self.code}."


class WebSocket:
def __init__(
self, scope: MutableMapping[str, Any], receive: Callable, send: Callable
Expand All @@ -34,7 +64,13 @@ def __init__(
self.client_state = WebSocketState.CONNECTING
self.application_state = WebSocketState.CONNECTING

async def connect(self) -> None:
async def _connect(self) -> None:
if self.client_state != WebSocketState.CONNECTING:
raise InvalidWebSocketStateError(
current_state=self.client_state,
expected_state=WebSocketState.CONNECTING,
)

message = await self._receive()
assert message["type"] == "websocket.connect"

Expand All @@ -43,10 +79,9 @@ async def connect(self) -> None:
async def accept(
self, headers: Optional[List] = None, subprotocol: str = None
) -> None:
assert self.client_state == WebSocketState.CONNECTING
await self.connect()

headers = headers or []

await self._connect()
self.application_state = WebSocketState.CONNECTED

message = {
Expand All @@ -58,7 +93,12 @@ async def accept(
await self._send(message)

async def receive(self) -> MutableMapping[str, AnyStr]:
assert self.application_state == WebSocketState.CONNECTED
if self.application_state != WebSocketState.CONNECTED:
raise InvalidWebSocketStateError(
party="application",
current_state=self.application_state,
expected_state=WebSocketState.CONNECTED,
)

message = await self._receive()
assert message["type"] == "websocket.receive"
Expand All @@ -74,7 +114,7 @@ async def receive_bytes(self) -> bytes:
return message["bytes"]

async def receive_json(
self, mode: str = MessageMode.TEXT
self, mode: MessageMode = MessageMode.TEXT
) -> MutableMapping[str, Any]:
assert mode in list(MessageMode)
message = await self.receive()
Expand All @@ -85,18 +125,22 @@ async def receive_json(
if mode == MessageMode.BYTES:
return json.loads(message["bytes"].decode())

async def send(self, message: MutableMapping[str, AnyStr]) -> None:
assert self.client_state == WebSocketState.CONNECTED
async def _send_message(self, message: MutableMapping[str, AnyStr]) -> None:
if self.client_state != WebSocketState.CONNECTED:
raise InvalidWebSocketStateError(
current_state=self.client_state,
expected_state=WebSocketState.CONNECTED,
)
await self._send(message)

async def send_text(self, data: str) -> None:
await self.send({"type": "websocket.send", "text": data})
await self._send_message({"type": "websocket.send", "text": data})

async def send_bytes(self, data: bytes) -> None:
await self.send({"type": "websocket.send", "bytes": data})
await self._send_message({"type": "websocket.send", "bytes": data})

async def send_json(
self, data: MutableMapping[Any, Any], mode: str = MessageMode.TEXT
self, data: MutableMapping[Any, Any], mode: MessageMode = MessageMode.TEXT
):
assert mode in list(MessageMode)
text = json.dumps(data)
Expand All @@ -114,7 +158,7 @@ async def disconnect():

if message["type"] == "websocket.disconnect":
self.application_state = self.client_state = WebSocketState.DISCONNECTED
raise WebSocketDisconnect(message["code"])
raise WebSocketDisconnectError(message["code"])

return message

Expand Down
3 changes: 3 additions & 0 deletions blacksheep/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from blacksheep.contents import FormContent, JSONContent, TextContent
from blacksheep.testing.client import TestClient
from blacksheep.testing.messages import MockReceive, MockSend
from blacksheep.testing.simulator import AbstractTestSimulator

__all__ = [
Expand All @@ -8,4 +9,6 @@
"JSONContent",
"TextContent",
"FormContent",
"MockReceive",
"MockSend",
]
5 changes: 4 additions & 1 deletion blacksheep/testing/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,12 @@ async def __call__(self):
message = self.messages[self.index]
except IndexError:
message = b""
else:
self.index += 1

if isinstance(message, dict):
return message
self.index += 1

await asyncio.sleep(0)
return {
"body": message,
Expand Down
32 changes: 32 additions & 0 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -3267,6 +3267,26 @@ async def example_3(user: Identity):
assert content == "User name: Charlie Brown"


@pytest.mark.asyncio
async def test_request_binding(app):
@app.router.get("/")
async def example(req: Request):
assert isinstance(req, Request)
return "Foo"

await app.start()

await app(
get_example_scope("GET", "/", []),
MockReceive(),
MockSend(),
)

content = await app.response.text()
assert app.response.status == 200
assert content == "Foo"


@pytest.mark.asyncio
async def test_use_auth_raises_if_app_is_already_started(app):
class MockAuthHandler(AuthenticationHandler):
Expand Down Expand Up @@ -3690,3 +3710,15 @@ async def test_async_event_raises_for_fire_method():

with pytest.raises(TypeError):
await event.fire()


@pytest.mark.asyncio
async def test_application_raises_for_unhandled_scope_type(app):
with pytest.raises(TypeError) as app_type_error:
await app(
{"type": "foo"},
MockReceive(),
MockSend(),
)

assert str(app_type_error.value) == "Unsupported scope type: foo"
8 changes: 8 additions & 0 deletions tests/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,9 @@ def patch_foo():
def delete_foo():
...

def ws():
...

router.add_trace("/", home_verbose)
router.add_options("/", home_options)
router.add_connect("/", home_connect)
Expand All @@ -430,6 +433,7 @@ def delete_foo():
router.add_patch("/foo", patch_foo)
router.add_post("/foo", create_foo)
router.add_delete("/foo", delete_foo)
router.add_ws("/ws", ws)

m = router.get_match(HTTPMethod.GET, b"/")
assert m is not None
Expand Down Expand Up @@ -466,6 +470,10 @@ def delete_foo():
assert m is not None
assert m.handler is delete_foo

m = router.get_match(HTTPMethod.GET, b"/ws")
assert m is not None
assert m.handler is ws


def test_router_match_among_many_decorators():
router = Router()
Expand Down
Loading

0 comments on commit 8e53014

Please sign in to comment.