From a495dd1c5d8af5c6e0be40df528a5709840dcfbd Mon Sep 17 00:00:00 2001 From: Kamil Niski Date: Sat, 20 Apr 2024 18:22:23 +0200 Subject: [PATCH] Update `EasyAuditMiddleware` to support async context Replaced standard threading with 'asgiref.local' in EasyAuditMiddleware. Also, made EasyAuditMiddleware extend Django's MiddlewareMixin to automatically handle sync and async execution modes. Github issue: #291 Update EasyAuditMiddleware for async compatibility The EasyAuditMiddleware class has been updated to be compatible with both synchronous and asynchronous processes. The class initialization now checks if the 'get_response' function is a coroutine and sets the class' async capability accordingly. An '__acall__' method has been introduced to handle asynchronous calls. Refactor sync_to_async calls in test_main.py Replaced sync_to_async keyword usage throughout the test_main.py file with direct asyncio calls, specifically in the ASGIRequestEvent tests. This change both simplifies the code and reduces reliance on the sync_to_async function. Add test for async capability in middleware A new test has been added to verify the async capability of the EasyAuditMiddleware. This ensures that the Django logger does not emit a debug message for asynchronous handler adaptation for the middleware. This is aligned with the Django documentation recommendations on async views. --- easyaudit/middleware/easyaudit.py | 41 ++++++++++++++++++------------- tests/test_main.py | 32 ++++++++++++++++++------ 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/easyaudit/middleware/easyaudit.py b/easyaudit/middleware/easyaudit.py index c23a9964..62d93f4f 100644 --- a/easyaudit/middleware/easyaudit.py +++ b/easyaudit/middleware/easyaudit.py @@ -1,6 +1,11 @@ # makes easy-audit thread-safe import contextlib -from threading import local +from typing import Callable + +from asgiref.local import Local +from asgiref.sync import iscoroutinefunction, markcoroutinefunction +from django.http.request import HttpRequest +from django.http.response import HttpResponse class MockRequest: @@ -10,7 +15,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) -_thread_locals = local() +_thread_locals = Local() def get_current_request(): @@ -38,30 +43,32 @@ def clear_request(): class EasyAuditMiddleware: - """Makes request available to this app signals.""" + async_capable = True + sync_capable = True - def __init__(self, get_response=None): + def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None: self.get_response = get_response + if iscoroutinefunction(self.get_response): + markcoroutinefunction(self) - def __call__(self, request): - _thread_locals.request = ( - request # seems redundant w/process_request, but keeping in for now. - ) - if hasattr(self, "process_request"): - response = self.process_request(request) - response = response or self.get_response(request) - if hasattr(self, "process_response"): - response = self.process_response(request, response) - return response + def __call__(self, request: HttpRequest) -> HttpResponse: + if iscoroutinefunction(self): + return self.__acall__(request) - def process_request(self, request): _thread_locals.request = request + response = self.get_response(request) - def process_response(self, request, response): with contextlib.suppress(AttributeError): del _thread_locals.request + return response - def process_exception(self, request, exception): + async def __acall__(self, request: HttpRequest) -> HttpResponse: + _thread_locals.request = request + + response = await self.get_response(request) + with contextlib.suppress(AttributeError): del _thread_locals.request + + return response diff --git a/tests/test_main.py b/tests/test_main.py index b938fd2d..aa744a3e 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,4 +1,5 @@ import json +import logging import pytest from asgiref.sync import sync_to_async @@ -297,16 +298,16 @@ def test_middleware_logged_in_user_in_request(self, user, client): class TestASGIRequestEvent: async def test_login(self, async_user, async_client, username, password): await sync_to_async(async_client.login)(username=username, password=password) - assert await sync_to_async(RequestEvent.objects.count)() == 0 + assert await RequestEvent.objects.acount() == 0 resp = await async_client.get(reverse("test_app:index")) assert resp.status_code == 200 - qs = await sync_to_async(RequestEvent.objects.filter)(user=async_user) - assert await sync_to_async(qs.exists)() + qs = RequestEvent.objects.filter(user=async_user) + assert await qs.aexists() async def test_remote_addr_default(self, async_client): - assert await sync_to_async(RequestEvent.objects.count)() == 0 + assert await RequestEvent.objects.acount() == 0 resp = await async_client.request( method="GET", @@ -318,11 +319,11 @@ async def test_remote_addr_default(self, async_client): ) assert resp.status_code == 200 - event = await sync_to_async(RequestEvent.objects.get)(url=reverse("test_app:index")) + event = await RequestEvent.objects.aget(url=reverse("test_app:index")) assert event.remote_ip == "127.0.0.1" async def test_remote_addr_another(self, async_client): - assert await sync_to_async(RequestEvent.objects.count)() == 0 + assert await RequestEvent.objects.acount() == 0 resp = await async_client.request( method="GET", @@ -335,9 +336,26 @@ async def test_remote_addr_another(self, async_client): ) assert resp.status_code == 200 - event = await sync_to_async(RequestEvent.objects.get)(url=reverse("test_app:index")) + event = await RequestEvent.objects.aget(url=reverse("test_app:index")) assert event.remote_ip == "10.0.0.1" + async def test_middleware_is_async_capable(self, async_client, caplog, settings): + """Test for async capability of EasyAuditMiddleware. + + If the EasyAuditMiddleware is async capable Django `django.request` logger + will not emit debug message 'Asynchronous handler adapted for middleware …' + + See: https://docs.djangoproject.com/en/5.0/topics/async/#async-views + """ + unwanted_log_message = ( + "Asynchronous handler adapted for middleware " + "easyaudit.middleware.easyaudit.EasyAuditMiddleware" + ) + settings.DEBUG = True + with caplog.at_level(logging.DEBUG, "django.request"): + await async_client.get(reverse("test_app:index")) + assert unwanted_log_message not in caplog.text + @pytest.mark.django_db class TestWSGIRequestEvent: