Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update EasyAuditMiddleware to support async context #292

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions easyaudit/middleware/easyaudit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# makes easy-audit thread-safe
import contextlib
from threading import local
from typing import Callable

from asgiref.local import Local
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
from django.http.request import HttpRequest
from django.http.response import HttpResponse


class MockRequest:
Expand All @@ -10,7 +15,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


_thread_locals = local()
_thread_locals = Local()


def get_current_request():
Expand Down Expand Up @@ -38,30 +43,32 @@ def clear_request():


class EasyAuditMiddleware:
"""Makes request available to this app signals."""
async_capable = True
sync_capable = True

def __init__(self, get_response=None):
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
self.get_response = get_response
if iscoroutinefunction(self.get_response):
markcoroutinefunction(self)

def __call__(self, request):
_thread_locals.request = (
request # seems redundant w/process_request, but keeping in for now.
)
if hasattr(self, "process_request"):
response = self.process_request(request)
response = response or self.get_response(request)
if hasattr(self, "process_response"):
response = self.process_response(request, response)
return response
def __call__(self, request: HttpRequest) -> HttpResponse:
if iscoroutinefunction(self):
return self.__acall__(request)

def process_request(self, request):
_thread_locals.request = request
response = self.get_response(request)

def process_response(self, request, response):
with contextlib.suppress(AttributeError):
del _thread_locals.request

return response

def process_exception(self, request, exception):
async def __acall__(self, request: HttpRequest) -> HttpResponse:
_thread_locals.request = request

response = await self.get_response(request)

with contextlib.suppress(AttributeError):
del _thread_locals.request

return response
32 changes: 25 additions & 7 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import logging

import pytest
from asgiref.sync import sync_to_async
Expand Down Expand Up @@ -297,16 +298,16 @@ def test_middleware_logged_in_user_in_request(self, user, client):
class TestASGIRequestEvent:
async def test_login(self, async_user, async_client, username, password):
await sync_to_async(async_client.login)(username=username, password=password)
assert await sync_to_async(RequestEvent.objects.count)() == 0
assert await RequestEvent.objects.acount() == 0

resp = await async_client.get(reverse("test_app:index"))
assert resp.status_code == 200

qs = await sync_to_async(RequestEvent.objects.filter)(user=async_user)
assert await sync_to_async(qs.exists)()
qs = RequestEvent.objects.filter(user=async_user)
assert await qs.aexists()

async def test_remote_addr_default(self, async_client):
assert await sync_to_async(RequestEvent.objects.count)() == 0
assert await RequestEvent.objects.acount() == 0

resp = await async_client.request(
method="GET",
Expand All @@ -318,11 +319,11 @@ async def test_remote_addr_default(self, async_client):
)
assert resp.status_code == 200

event = await sync_to_async(RequestEvent.objects.get)(url=reverse("test_app:index"))
event = await RequestEvent.objects.aget(url=reverse("test_app:index"))
assert event.remote_ip == "127.0.0.1"

async def test_remote_addr_another(self, async_client):
assert await sync_to_async(RequestEvent.objects.count)() == 0
assert await RequestEvent.objects.acount() == 0

resp = await async_client.request(
method="GET",
Expand All @@ -335,9 +336,26 @@ async def test_remote_addr_another(self, async_client):
)
assert resp.status_code == 200

event = await sync_to_async(RequestEvent.objects.get)(url=reverse("test_app:index"))
event = await RequestEvent.objects.aget(url=reverse("test_app:index"))
assert event.remote_ip == "10.0.0.1"

async def test_middleware_is_async_capable(self, async_client, caplog, settings):
"""Test for async capability of EasyAuditMiddleware.

If the EasyAuditMiddleware is async capable Django `django.request` logger
will not emit debug message 'Asynchronous handler adapted for middleware …'

See: https://docs.djangoproject.com/en/5.0/topics/async/#async-views
"""
unwanted_log_message = (
"Asynchronous handler adapted for middleware "
"easyaudit.middleware.easyaudit.EasyAuditMiddleware"
)
settings.DEBUG = True
with caplog.at_level(logging.DEBUG, "django.request"):
await async_client.get(reverse("test_app:index"))
assert unwanted_log_message not in caplog.text


@pytest.mark.django_db
class TestWSGIRequestEvent:
Expand Down
Loading