From 0f56ca0b0b6fb47ae78ea6b7a86878d63e162eaf Mon Sep 17 00:00:00 2001 From: Brandon Willett Date: Tue, 30 Jul 2024 15:39:55 -0400 Subject: [PATCH 1/6] fix: improved handling of api call failure --- sqs_workers/queue.py | 120 ++++++++++++++++++++++++++++------------- sqs_workers/sqs_env.py | 7 ++- sqs_workers/utils.py | 10 +++- 3 files changed, 96 insertions(+), 41 deletions(-) diff --git a/sqs_workers/queue.py b/sqs_workers/queue.py index 1ac5cac..b6275fc 100644 --- a/sqs_workers/queue.py +++ b/sqs_workers/queue.py @@ -10,9 +10,11 @@ Callable, Dict, Generator, + Iterable, List, Literal, Optional, + Tuple, TypeVar, ) @@ -27,6 +29,7 @@ from sqs_workers.exceptions import SQSError from sqs_workers.processors import DEFAULT_CONTEXT_VAR, Processor from sqs_workers.shutdown_policies import NEVER_SHUTDOWN +from sqs_workers.utils import batcher DEFAULT_MESSAGE_GROUP_ID = "default" SEND_BATCH_SIZE = 10 @@ -85,53 +88,80 @@ def process_queue(self, shutdown_policy=NEVER_SHUTDOWN, wait_second=10): ) break - def process_batch(self, wait_seconds=0) -> BatchProcessingResult: + def process_batch(self, wait_seconds: int = 0) -> BatchProcessingResult: """ Process a batch of messages from the queue (10 messages at most), return the number of successfully processed messages, and exit """ - queue = self.get_queue() - if self.batching_policy.batching_enabled: - return self._process_messages_in_batch(queue, wait_seconds) + messages = self.get_raw_messages(wait_seconds, self.batching_policy.batch_size) + success = self.process_messages(messages) + messages_with_success = ((m, success) for m in messages) + else: + messages = self.get_raw_messages(wait_seconds) + success = [self.process_message(message) for message in messages] + messages_with_success = zip(messages, success) - return self._process_messages_individually(queue, wait_seconds) + return self._handle_processed(messages_with_success) - def _process_messages_in_batch(self, queue, wait_seconds): - messages = self.get_raw_messages(wait_seconds, self.batching_policy.batch_size) - result = BatchProcessingResult(self.name) + def _handle_processed(self, messages_with_success: Iterable[Tuple[Any, bool]]): + """ + Handles the results of processing messages. - success = self.process_messages(messages) + For successful messages, we delete the message ID from the queue, which is + equivalent to acknowledging it. - for message in messages: - result.update_with_message(message, success) - if success: - entry = { - "Id": message.message_id, - "ReceiptHandle": message.receipt_handle, - } - queue.delete_messages(Entries=[entry]) - else: - timeout = self.backoff_policy.get_visibility_timeout(message) - message.change_visibility(VisibilityTimeout=timeout) - return result + For failed messages, we change the visibility of the message, in order to + keep it un-consumeable for a little while (a form of backoff). + + In each case (delete or change-viz), we batch the API calls to AWS in order + to try to avoid getting throttled, with batches of size 10 (the limit). The + config (see sqs_env.py) should also retry in the event of exceptions. + """ + queue = self.get_queue() - def _process_messages_individually(self, queue, wait_seconds): - messages = self.get_raw_messages(wait_seconds) result = BatchProcessingResult(self.name) - for message in messages: - success = self.process_message(message) - result.update_with_message(message, success) - if success: - entry = { - "Id": message.message_id, - "ReceiptHandle": message.receipt_handle, - } - queue.delete_messages(Entries=[entry]) - else: - timeout = self.backoff_policy.get_visibility_timeout(message) - message.change_visibility(VisibilityTimeout=timeout) + for subgroup in batcher(messages_with_success, batch_size=10): + entries_to_ack = [] + entries_to_change_viz = [] + + for m, success in subgroup: + result.update_with_message(m, success) + if success: + entries_to_ack.append( + { + "Id": m.message_id, + "ReceiptHandle": m.receipt_handle, + } + ) + else: + entries_to_change_viz.append( + { + "Id": m.message_id, + "ReceiptHandle": m.receipt_handle, + "VisibilityTimeout": self.backoff_policy.get_visibility_timeout(m), + } + ) + + ack_response = queue.delete_messages(Entries=entries_to_ack) + + if ack_response.get("Failed"): + logger.warning( + "Failed to delete processed messages from queue", + extra={"queue": self.name, "failures": ack_response["Failed"]}, + ) + + viz_response = queue.change_message_visibility_batch( + Entries=entries_to_change_viz, + ) + + if viz_response.get("Failed"): + logger.warning( + "Failed to change visibility of messages which failed to process", + extra={"queue": self.name, "failures": viz_response["Failed"]}, + ) + return result def process_message(self, message: Any) -> bool: @@ -151,15 +181,16 @@ def process_messages(self, messages: List[Any]) -> bool: """ raise NotImplementedError() - def get_raw_messages(self, wait_seconds, max_messages=10): + def get_raw_messages(self, wait_seconds: int, max_messages: int = 10) -> list[Any]: """Return raw messages from the queue, addressed by its name""" + queue = self.get_queue() + kwargs = { "WaitTimeSeconds": wait_seconds, "MaxNumberOfMessages": max_messages if max_messages <= 10 else 10, "MessageAttributeNames": ["All"], "AttributeNames": ["All"], } - queue = self.get_queue() if max_messages <= 10: return queue.receive_messages(**kwargs) @@ -180,16 +211,29 @@ def get_raw_messages(self, wait_seconds, max_messages=10): def drain_queue(self, wait_seconds=0): """Delete all messages from the queue without calling purge().""" queue = self.get_queue() + deleted_count = 0 while True: messages = self.get_raw_messages(wait_seconds) if not messages: break + entries = [ {"Id": msg.message_id, "ReceiptHandle": msg.receipt_handle} for msg in messages ] - queue.delete_messages(Entries=entries) + + ack_response = queue.delete_messages(Entries=entries) + + if ack_response.get("Failed"): + logger.warning( + "Failed to delete processed messages from queue", + extra={ + "queue": self.name, + "failures": ack_response["Failed"], + }, + ) + deleted_count += len(messages) return deleted_count diff --git a/sqs_workers/sqs_env.py b/sqs_workers/sqs_env.py index ecb9fa9..356ed5c 100644 --- a/sqs_workers/sqs_env.py +++ b/sqs_workers/sqs_env.py @@ -15,6 +15,7 @@ import attr import boto3 +from botocore.config import Config from typing_extensions import ParamSpec from sqs_workers import DEFAULT_BACKOFF, RawQueue, codecs, context, processors @@ -60,10 +61,12 @@ class SQSEnv: def __attrs_post_init__(self): self.context = self.context_maker() + # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html + retry_config = Config(retries={"max_attempts": 3, "mode": "standard"}) if not self.sqs_client: - self.sqs_client = self.session.client("sqs") + self.sqs_client = self.session.client("sqs", config=retry_config) if not self.sqs_resource: - self.sqs_resource = self.session.resource("sqs") + self.sqs_resource = self.session.resource("sqs", config=retry_config) @overload def queue( diff --git a/sqs_workers/utils.py b/sqs_workers/utils.py index 602f534..d29af08 100644 --- a/sqs_workers/utils.py +++ b/sqs_workers/utils.py @@ -1,7 +1,8 @@ import importlib import logging from inspect import Signature -from typing import Any +from itertools import islice +from typing import Any, Iterable logger = logging.getLogger(__name__) @@ -121,3 +122,10 @@ def ensure_string(obj: Any, encoding="utf-8", errors="strict") -> str: return obj.decode(encoding, errors) else: return str(obj) + + +def batcher(iterable, batch_size) -> Iterable[Iterable[Any]]: + """Cuts an iterable up into sub-iterables of size batch_size.""" + iterator = iter(iterable) + while batch := list(islice(iterator, batch_size)): + yield batch From 17748348df1f5c6f9fa6ac03709e54b1898af72d Mon Sep 17 00:00:00 2001 From: Brandon Willett Date: Tue, 30 Jul 2024 15:40:33 -0400 Subject: [PATCH 2/6] chore: ruff --- sqs_workers/queue.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sqs_workers/queue.py b/sqs_workers/queue.py index b6275fc..05bb94b 100644 --- a/sqs_workers/queue.py +++ b/sqs_workers/queue.py @@ -94,7 +94,9 @@ def process_batch(self, wait_seconds: int = 0) -> BatchProcessingResult: the number of successfully processed messages, and exit """ if self.batching_policy.batching_enabled: - messages = self.get_raw_messages(wait_seconds, self.batching_policy.batch_size) + messages = self.get_raw_messages( + wait_seconds, self.batching_policy.batch_size + ) success = self.process_messages(messages) messages_with_success = ((m, success) for m in messages) else: @@ -140,7 +142,9 @@ def _handle_processed(self, messages_with_success: Iterable[Tuple[Any, bool]]): { "Id": m.message_id, "ReceiptHandle": m.receipt_handle, - "VisibilityTimeout": self.backoff_policy.get_visibility_timeout(m), + "VisibilityTimeout": self.backoff_policy.get_visibility_timeout( + m + ), } ) From 3d576a2d010502ef0553212c868fcb07f3db4f2b Mon Sep 17 00:00:00 2001 From: Brandon Willett Date: Tue, 30 Jul 2024 16:34:43 -0400 Subject: [PATCH 3/6] chore: retrofit memory sqs with actual in-flight tracking --- sqs_workers/memory_sqs.py | 44 ++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/sqs_workers/memory_sqs.py b/sqs_workers/memory_sqs.py index 6849e8c..e1a54ae 100644 --- a/sqs_workers/memory_sqs.py +++ b/sqs_workers/memory_sqs.py @@ -107,6 +107,7 @@ class MemoryQueue: name: str = attr.ib() attributes: Dict[str, Dict[str, str]] = attr.ib() messages: List["MemoryMessage"] = attr.ib(factory=list) + in_flight: List["MemoryMessage"] = attr.ib(factory=list) def __attrs_post_init__(self): self.attributes["QueueArn"] = self.name @@ -146,6 +147,7 @@ def receive_messages(self, WaitTimeSeconds="0", MaxNumberOfMessages="10", **kwar else: ready_messages.append(message) self.messages[:] = push_back_messages + self.in_flight.extend(ready_messages) return ready_messages def delete_messages(self, Entries): @@ -158,19 +160,40 @@ def delete_messages(self, Entries): message_ids = {entry["Id"] for entry in Entries} successfully_deleted = set() - push_back_messages = [] - for message in self.messages: + for i, message in enumerate(self.in_flight): if message.message_id in message_ids: successfully_deleted.add(message.message_id) - else: - push_back_messages.append(message) - self.messages[:] = push_back_messages + del self.in_flight[i] - didnt_deleted = message_ids.difference(successfully_deleted) return { "Successful": [{"Id": _id} for _id in successfully_deleted], - "Failed": [{"Id": _id} for _id in didnt_deleted], + } + + def change_message_visibility_batch(self, Entries): + """ + Changes message visibility by looking at in-flight messages, setting + a new execute_at, and returning it to the pool of processable messages + """ + edited = [] + return_to_pool = [] + entries_by_id = {e["Id"]: e for e in Entries} + + for i, m in enumerate(self.in_flight): + if m.message_id in entries_by_id.keys(): + sec = int(entries_by_id[m.message_id]["VisibilityTimeout"]) + now = datetime.datetime.utcnow() + execute_at = now + datetime.timedelta(seconds=sec) + changed = attr.evolve(m, execute_at=execute_at) + changed.attributes["ApproximateReceiveCount"] += 1 + edited.append(changed) + return_to_pool.append(changed) + del self.in_flight[i] + + self.messages.extend(return_to_pool) + + return { + "Successful": [{"Id": _id} for _id in edited], } def delete(self): @@ -241,10 +264,3 @@ def from_kwargs(cls, queue_impl, kwargs): return MemoryMessage( queue_impl, body, message_atttributes, attributes, execute_at ) - - def change_visibility(self, VisibilityTimeout="0", **kwargs): - timeout = int(VisibilityTimeout) - execute_at = datetime.datetime.utcnow() + datetime.timedelta(seconds=timeout) - message = attr.evolve(self, execute_at=execute_at) - message.attributes["ApproximateReceiveCount"] += 1 - self.queue_impl.messages.append(message) From 3d35e7f12e54bfd3e58d681a778abc477456d94b Mon Sep 17 00:00:00 2001 From: Brandon Willett Date: Tue, 30 Jul 2024 16:38:02 -0400 Subject: [PATCH 4/6] fix: need old typing --- sqs_workers/queue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqs_workers/queue.py b/sqs_workers/queue.py index 05bb94b..08d6599 100644 --- a/sqs_workers/queue.py +++ b/sqs_workers/queue.py @@ -185,7 +185,7 @@ def process_messages(self, messages: List[Any]) -> bool: """ raise NotImplementedError() - def get_raw_messages(self, wait_seconds: int, max_messages: int = 10) -> list[Any]: + def get_raw_messages(self, wait_seconds: int, max_messages: int = 10) -> List[Any]: """Return raw messages from the queue, addressed by its name""" queue = self.get_queue() From 36d0c3f916ff2e8cc622526f3c5039c951a11ad4 Mon Sep 17 00:00:00 2001 From: Brandon Willett Date: Wed, 31 Jul 2024 10:06:23 -0400 Subject: [PATCH 5/6] chore: clean up test memory impl per Goncalo comments --- sqs_workers/memory_sqs.py | 54 +++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/sqs_workers/memory_sqs.py b/sqs_workers/memory_sqs.py index e1a54ae..bb8bdb0 100644 --- a/sqs_workers/memory_sqs.py +++ b/sqs_workers/memory_sqs.py @@ -107,7 +107,7 @@ class MemoryQueue: name: str = attr.ib() attributes: Dict[str, Dict[str, str]] = attr.ib() messages: List["MemoryMessage"] = attr.ib(factory=list) - in_flight: List["MemoryMessage"] = attr.ib(factory=list) + in_flight: Dict[str, "MemoryMessage"] = attr.ib(factory=dict) def __attrs_post_init__(self): self.attributes["QueueArn"] = self.name @@ -147,7 +147,8 @@ def receive_messages(self, WaitTimeSeconds="0", MaxNumberOfMessages="10", **kwar else: ready_messages.append(message) self.messages[:] = push_back_messages - self.in_flight.extend(ready_messages) + for m in ready_messages: + self.in_flight[m.message_id] = m return ready_messages def delete_messages(self, Entries): @@ -157,17 +158,19 @@ def delete_messages(self, Entries): See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/ services/sqs.html#SQS.Queue.delete_messages """ - message_ids = {entry["Id"] for entry in Entries} + found_entries = [] + not_found_entries = [] - successfully_deleted = set() - - for i, message in enumerate(self.in_flight): - if message.message_id in message_ids: - successfully_deleted.add(message.message_id) - del self.in_flight[i] + for e in Entries: + if e["Id"] in self.in_flight: + found_entries.append(e) + self.in_flight.pop(e["Id"]) + else: + not_found_entries.append(e) return { - "Successful": [{"Id": _id} for _id in successfully_deleted], + "Successful": [{"Id": e["Id"]} for e in found_entries], + "Failed": [{"Id": e["Id"]} for e in not_found_entries], } def change_message_visibility_batch(self, Entries): @@ -175,25 +178,26 @@ def change_message_visibility_batch(self, Entries): Changes message visibility by looking at in-flight messages, setting a new execute_at, and returning it to the pool of processable messages """ - edited = [] - return_to_pool = [] - entries_by_id = {e["Id"]: e for e in Entries} - - for i, m in enumerate(self.in_flight): - if m.message_id in entries_by_id.keys(): - sec = int(entries_by_id[m.message_id]["VisibilityTimeout"]) + found_entries = [] + not_found_entries = [] + + for e in Entries: + if e["Id"] in self.in_flight: + found_entries.append(e) + in_flight_message = self.in_flight[e["Id"]] + sec = int(e["VisibilityTimeout"]) now = datetime.datetime.utcnow() execute_at = now + datetime.timedelta(seconds=sec) - changed = attr.evolve(m, execute_at=execute_at) - changed.attributes["ApproximateReceiveCount"] += 1 - edited.append(changed) - return_to_pool.append(changed) - del self.in_flight[i] - - self.messages.extend(return_to_pool) + updated_message = attr.evolve(in_flight_message, execute_at=execute_at) + updated_message.attributes["ApproximateReceiveCount"] += 1 + self.messages.append(updated_message) + self.in_flight.pop(e["Id"]) + else: + not_found_entries.append(e) return { - "Successful": [{"Id": _id} for _id in edited], + "Successful": [{"Id": e["Id"]} for e in found_entries], + "Failed": [{"Id": e["Id"]} for e in not_found_entries], } def delete(self): From f8965d423196a269214577a4426508bf719cf659 Mon Sep 17 00:00:00 2001 From: Brandon Willett Date: Wed, 31 Jul 2024 10:11:50 -0400 Subject: [PATCH 6/6] chore: other small refactoring from PR review --- sqs_workers/memory_sqs.py | 3 ++- sqs_workers/sqs_env.py | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sqs_workers/memory_sqs.py b/sqs_workers/memory_sqs.py index bb8bdb0..a1c26f7 100644 --- a/sqs_workers/memory_sqs.py +++ b/sqs_workers/memory_sqs.py @@ -181,12 +181,13 @@ def change_message_visibility_batch(self, Entries): found_entries = [] not_found_entries = [] + now = datetime.datetime.utcnow() + for e in Entries: if e["Id"] in self.in_flight: found_entries.append(e) in_flight_message = self.in_flight[e["Id"]] sec = int(e["VisibilityTimeout"]) - now = datetime.datetime.utcnow() execute_at = now + datetime.timedelta(seconds=sec) updated_message = attr.evolve(in_flight_message, execute_at=execute_at) updated_message.attributes["ApproximateReceiveCount"] += 1 diff --git a/sqs_workers/sqs_env.py b/sqs_workers/sqs_env.py index 356ed5c..05baf50 100644 --- a/sqs_workers/sqs_env.py +++ b/sqs_workers/sqs_env.py @@ -46,6 +46,10 @@ class SQSEnv: queue_prefix = attr.ib(default="") codec: str = attr.ib(default=codecs.DEFAULT_CONTENT_TYPE) + # retry settings for internal boto + retry_max_attempts: int = attr.ib(default=3) + retry_mode: str = attr.ib(default="standard") + # queue-specific settings backoff_policy = attr.ib(default=DEFAULT_BACKOFF) @@ -62,7 +66,8 @@ class SQSEnv: def __attrs_post_init__(self): self.context = self.context_maker() # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html - retry_config = Config(retries={"max_attempts": 3, "mode": "standard"}) + retry_dict = {"max_attempts": self.retry_max_attempts, "mode": self.retry_mode} + retry_config = Config(retries=retry_dict) if not self.sqs_client: self.sqs_client = self.session.client("sqs", config=retry_config) if not self.sqs_resource: