diff --git a/README.md b/README.md index 275e3a1..24c8054 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,12 @@ async with ( HttpRequestRecorder('any_recorder_name', 8080) as recorder, ClientSession() as http_session ): - expectation = recorder.expect_path(path='/any-path', responses=b'Hello back from recorder') + expectation = recorder.expect_path( + path='/any-path', # path to respond to + responses=b'Hello back from recorder', # responses to return, can be a list + name = 'any name', # optional name for the expected interaction + timeout = 1, # optional timeout in seconds, defaults to 3 + ) await http_session.get('http://localhost:8080/any-path', data=b'Hello') diff --git a/http_request_recorder/http_request_recorder.py b/http_request_recorder/http_request_recorder.py index c69f963..744f3bc 100644 --- a/http_request_recorder/http_request_recorder.py +++ b/http_request_recorder/http_request_recorder.py @@ -11,6 +11,7 @@ ResponsesType = Union[str, bytes, web.Response, Iterable[str], Iterable[bytes], Iterable[web.Response]] + class RecordedRequest: def __init__(self, ): self.body = None @@ -37,8 +38,9 @@ def __init__(self, response): self.was_triggered = Event() self.response = response - def __init__(self, matcher, responses: ResponsesType, name: str = None): + def __init__(self, matcher, responses: ResponsesType, name: str, timeout: int): self.name: str = name + self._timeout: int = timeout self.expected_count = None # None: use infinitely if isinstance(responses, (str, bytes, web.Response)): @@ -84,7 +86,7 @@ async def wait(self) -> str: # suppress (not very helpful) stack of asyncio errors that get raised on timeout with contextlib.suppress(asyncio.TimeoutError): - await asyncio.wait_for(to_return.was_triggered.wait(), timeout=30) + await asyncio.wait_for(to_return.was_triggered.wait(), self._timeout) if not to_return.was_triggered.is_set(): # the above wait_for() timed out, raise a useful Exception: raise TimeoutError(f"{self} timed out waiting for a request") @@ -153,21 +155,22 @@ async def handle_request(self, request: BaseRequest): return web.Response(status=200, body=response) - def expect(self, matcher, responses: ResponsesType = "", name: str = None) -> ExpectedInteraction: - expectation = ExpectedInteraction(matcher, responses, name) + def expect(self, matcher, responses: ResponsesType = "", name: str = None, timeout: int = 3) -> ExpectedInteraction: + expectation = ExpectedInteraction(matcher, responses, name, timeout) self._expectations.append(expectation) return expectation - def expect_path(self, path: str, responses: ResponsesType = "") -> ExpectedInteraction: - return self.expect(lambda request: path == request.path, responses, name=path) + def expect_path(self, path: str, responses: ResponsesType = "", timeout: int = 3) -> ExpectedInteraction: + return self.expect(lambda request: path == request.path, responses, name=path, timeout=timeout) - def expect_xml_rpc(self, in_body: bytes, responses: ResponsesType = ""): + def expect_xml_rpc(self, in_body: bytes, responses: ResponsesType = "", timeout: int = 3): # TODO: test def matcher(request): return "/RPC2" == request.path and in_body in request.body return self.expect(matcher, responses=responses, - name=f"XmlRpc: {in_body.decode('UTF-8')}") + name=f"XmlRpc: {in_body.decode('UTF-8')}", + timeout=timeout) @staticmethod async def _request_string_for_log(request):