diff --git a/channels/generic/http.py b/channels/generic/http.py index d067717a2..2deee2518 100644 --- a/channels/generic/http.py +++ b/channels/generic/http.py @@ -12,6 +12,8 @@ class AsyncHttpConsumer(AsyncConsumer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.body = [] + self.headers_sent = False + self.more_body = True async def send_headers(self, *, status=200, headers=None): """ @@ -27,6 +29,8 @@ async def send_headers(self, *, status=200, headers=None): elif isinstance(headers, dict): headers = list(headers.items()) + self.headers_sent = True + await self.send( {"type": "http.response.start", "status": status, "headers": headers} ) @@ -40,9 +44,12 @@ async def send_body(self, body, *, more_body=False): the channel will be ignored. """ assert isinstance(body, bytes), "Body is not bytes" - await self.send( - {"type": "http.response.body", "body": body, "more_body": more_body} - ) + + if self.more_body: + self.more_body = more_body + await self.send( + {"type": "http.response.body", "body": body, "more_body": more_body} + ) async def send_response(self, status, body, **kwargs): """ @@ -70,6 +77,19 @@ async def disconnect(self): """ pass + async def close(self, body=b"", status=500, headers=None): + """ + Closes the HTTP response from the server end. + """ + if not self.more_body: + # HTTP Response is already closed, nothing to do. + return + + if not self.headers_sent: + await self.send_headers(status=status, headers=headers) + + await self.send_body(body) + async def http_request(self, message): """ Async entrypoint - concatenates body fragments and hands off control @@ -80,10 +100,15 @@ async def http_request(self, message): if not message.get("more_body"): try: await self.handle(b"".join(self.body)) - finally: - await self.disconnect() - raise StopConsumer() + except StopConsumer: + await self.close(status=200) + except: + # TODO This is just a patch, after bubbling up the exception no body is calling http_disconnect. + await self.close(body=b"Internal Server Error", status=500) + raise + # TODO This should be called by daphne whenever any exception bubbles up (As it does with websockets) + # IMHO this should be parallel to websocket_disconnect, removing the consumer from potential channel_layer groups. async def http_disconnect(self, message): """ Let the user do their cleanup and close the consumer. diff --git a/channels/testing/http.py b/channels/testing/http.py index 6b1514ca7..760d522ca 100644 --- a/channels/testing/http.py +++ b/channels/testing/http.py @@ -31,7 +31,7 @@ def __init__(self, application, method, path, body=b"", headers=None): async def get_response(self, timeout=1): """ - Get the application's response. Returns a dict with keys of + Get the application's full response. Returns a dict with keys of "body", "headers" and "status". """ # If we've not sent the request yet, do so @@ -54,3 +54,36 @@ async def get_response(self, timeout=1): del response_start["type"] response_start.setdefault("headers", []) return response_start + + async def send_request(self): + """ + Sends the request to the application without then waiting for + headers or any response. + """ + if not self.sent_request: + self.sent_request = True + await self.send_input({"type": "http.request", "body": self.body}) + + async def get_response_start(self, timeout=1): + """ + Gets the start of the response (its headers and status code) + """ + response_start = await self.receive_output(timeout) + assert response_start["type"] == "http.response.start" + + # Return structured info + del response_start["type"] + response_start.setdefault("headers", []) + return response_start + + async def get_body_chunk(self, timeout=1): + """ + Gets one chunk of body. + """ + chunk = await self.receive_output(timeout) + assert chunk["type"] == "http.response.body" + assert isinstance(chunk["body"], bytes) + if not chunk.get("more_body", False): + await self.wait(timeout) + + return chunk["body"] diff --git a/tests/test_generic_http.py b/tests/test_generic_http.py index 9150eed4a..430460bee 100644 --- a/tests/test_generic_http.py +++ b/tests/test_generic_http.py @@ -1,8 +1,10 @@ import json import pytest +from django.test import override_settings from channels.generic.http import AsyncHttpConsumer +from channels.layers import get_channel_layer from channels.testing import HttpCommunicator @@ -32,3 +34,64 @@ async def handle(self, body): assert response["body"] == b'{"value": 42}' assert response["status"] == 200 assert response["headers"] == [(b"Content-Type", b"application/json")] + + +@pytest.mark.asyncio +async def test_async_http_consumer_with_channel_layer(): + """ + Tests that AsyncHttpConsumer is implemented correctly. + """ + + class TestConsumer(AsyncHttpConsumer): + """ + Abstract consumer that provides a method that handles running a command and getting a response on a + device. + """ + + channel_layer_alias = "testlayer" + + async def handle(self, body): + # Add consumer to a known test group that we will use to send events to. + await self.channel_layer.group_add("test_group", self.channel_name) + await self.send_headers( + status=200, headers=[(b"Content-Type", b"application/json")] + ) + + async def send_to_long_poll(self, event): + received_data = str(event["data"]).encode("utf8") + # We just echo what we receive, and close the response. + await self.send_body(received_data, more_body=False) + + channel_layers_setting = { + "testlayer": {"BACKEND": "channels.layers.InMemoryChannelLayer"} + } + + with override_settings(CHANNEL_LAYERS=channel_layers_setting): + # Open a connection + communicator = HttpCommunicator( + TestConsumer, + method="POST", + path="/test/", + body=json.dumps({"value": 42, "anything": False}).encode("utf-8"), + ) + + # We issue the HTTP request + await communicator.send_request() + + # Gets the response start (status and headers) + response_start = await communicator.get_response_start(timeout=1) + + # Make sure that the start of the response looks good so far. + assert response_start["status"] == 200 + assert response_start["headers"] == [(b"Content-Type", b"application/json")] + + # Send now a message to the consumer through the channel layer. Using the known test_group. + channel_layer = get_channel_layer("testlayer") + await channel_layer.group_send( + "test_group", + {"type": "send.to.long.poll", "data": "hello from channel layer"}, + ) + + # Now we should be able to get the message back on the remaining chunk of body. + body = await communicator.get_body_chunk(timeout=1) + assert body == b"hello from channel layer" diff --git a/tests/test_http_stream.py b/tests/test_http_stream.py new file mode 100644 index 000000000..c492dc278 --- /dev/null +++ b/tests/test_http_stream.py @@ -0,0 +1,43 @@ +import asyncio + +import pytest + +from channels.generic.http import AsyncHttpConsumer +from channels.testing import HttpCommunicator + + +@pytest.mark.asyncio +async def test_async_http_consumer(): + """ + Tests that AsyncHttpConsumer is implemented correctly. + """ + + class TestConsumer(AsyncHttpConsumer): + async def handle(self, body): + self.is_streaming = True + await self.send_headers( + headers=[ + (b"Cache-Control", b"no-cache"), + (b"Content-Type", b"text/event-stream"), + (b"Transfer-Encoding", b"chunked"), + ] + ) + asyncio.get_event_loop().create_task(self.stream()) + + async def stream(self): + for n in range(0, 3): + if not self.is_streaming: + break + payload = "data: %d\n\n" % (n + 1) + await self.send_body(payload.encode("utf-8"), more_body=True) + await asyncio.sleep(0.2) + await self.send_body(b"") + + async def disconnect(self): + self.is_streaming = False + + # Open a connection + communicator = HttpCommunicator(TestConsumer, method="GET", path="/test/", body=b"") + response = await communicator.get_response() + assert response["body"] == b"data: 1\n\ndata: 2\n\ndata: 3\n\n" + assert response["status"] == 200