Skip to content

Commit

Permalink
Detect blocking calls in coroutines using BlockBuster
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Feb 1, 2025
1 parent cbddcaf commit 34fc433
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ venv*/
.python-version
build/
dist/
.idea/
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ types-PyYAML==6.0.12.20241230
types-dataclasses==0.6.6
pytest==8.3.4
trio==0.28.0
blockbuster==1.5.13

# Documentation
black==24.10.0
Expand Down
3 changes: 2 additions & 1 deletion starlette/datastructures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
import typing
from shlex import shlex
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
Expand Down Expand Up @@ -438,7 +439,7 @@ async def write(self, data: bytes) -> None:
if self.size is not None:
self.size += len(data)

if self._in_memory:
if self._in_memory and self.file.tell() + len(data) <= getattr(self.file, "_max_size", sys.maxsize):
self.file.write(data)
else:
await run_in_threadpool(self.file.write, data)
Expand Down
4 changes: 3 additions & 1 deletion starlette/middleware/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import traceback
import typing

import anyio

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
Expand Down Expand Up @@ -167,7 +169,7 @@ async def _send(message: Message) -> None:
request = Request(scope)
if self.debug:
# In debug mode, return traceback responses.
response = self.debug_response(request, exc)
response = await anyio.to_thread.run_sync(self.debug_response, request, exc)
elif self.handler is None:
# Use our default 500 error handler.
response = self.error_response(request, exc)
Expand Down
2 changes: 1 addition & 1 deletion starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ async def receive() -> Message:
await response_complete.wait()
return {"type": "http.disconnect"}

body = request.read()
body = await anyio.to_thread.run_sync(request.read)
if isinstance(body, str):
body_bytes: bytes = body.encode("utf-8") # pragma: no cover
elif body is None:
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from __future__ import annotations

import functools
from collections.abc import Iterator
from typing import Any, Literal

import pytest
from blockbuster import blockbuster_ctx

from starlette.testclient import TestClient
from tests.types import TestClientFactory


@pytest.fixture(autouse=True)
def blockbuster() -> Iterator[None]:
with blockbuster_ctx("starlette") as bb:
bb.functions["os.stat"].can_block_in("/mimetypes.py", "init")
yield


@pytest.fixture
def test_client_factory(
anyio_backend_name: Literal["asyncio", "trio"],
Expand Down
4 changes: 2 additions & 2 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,9 +466,9 @@ async def cancel_on_disconnect(
# before we start returning the body
await task_group.start(cancel_on_disconnect)

# A timeout is set for 0.1 second in order to ensure that
# A timeout is set for 0.2 second in order to ensure that
# we never deadlock the test run in an infinite loop
with anyio.move_on_after(0.1):
with anyio.move_on_after(0.2):
while True:
await send(
{
Expand Down
12 changes: 6 additions & 6 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_templates(tmpdir: Path, test_client_factory: TestClientFactory) -> None
with open(path, "w") as file:
file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>")

async def homepage(request: Request) -> Response:
def homepage(request: Request) -> Response:
return templates.TemplateResponse(request, "index.html")

app = Starlette(debug=True, routes=[Route("/", endpoint=homepage)])
Expand All @@ -40,7 +40,7 @@ def test_calls_context_processors(tmp_path: Path, test_client_factory: TestClien
path = tmp_path / "index.html"
path.write_text("<html>Hello {{ username }}</html>")

async def homepage(request: Request) -> Response:
def homepage(request: Request) -> Response:
return templates.TemplateResponse(request, "index.html")

def hello_world_processor(request: Request) -> dict[str, str]:
Expand Down Expand Up @@ -69,7 +69,7 @@ def test_template_with_middleware(tmpdir: Path, test_client_factory: TestClientF
with open(path, "w") as file:
file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>")

async def homepage(request: Request) -> Response:
def homepage(request: Request) -> Response:
return templates.TemplateResponse(request, "index.html")

class CustomMiddleware(BaseHTTPMiddleware):
Expand All @@ -96,15 +96,15 @@ def test_templates_with_directories(tmp_path: Path, test_client_factory: TestCli
template_a = dir_a / "template_a.html"
template_a.write_text("<html><a href='{{ url_for('page_a') }}'></a> a</html>")

async def page_a(request: Request) -> Response:
def page_a(request: Request) -> Response:
return templates.TemplateResponse(request, "template_a.html")

dir_b = tmp_path.resolve() / "b"
dir_b.mkdir()
template_b = dir_b / "template_b.html"
template_b.write_text("<html><a href='{{ url_for('page_b') }}'></a> b</html>")

async def page_b(request: Request) -> Response:
def page_b(request: Request) -> Response:
return templates.TemplateResponse(request, "template_b.html")

app = Starlette(
Expand Down Expand Up @@ -150,7 +150,7 @@ def test_templates_with_environment(tmpdir: Path, test_client_factory: TestClien
with open(path, "w") as file:
file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>")

async def homepage(request: Request) -> Response:
def homepage(request: Request) -> Response:
return templates.TemplateResponse(request, "index.html")

env = jinja2.Environment(loader=jinja2.FileSystemLoader(str(tmpdir)))
Expand Down

0 comments on commit 34fc433

Please sign in to comment.