Merge pull request #98 from Doist/brandon/response
fix: improved handling of api call failure
brandon-doist authored Jul 31, 2024
2 parents f1b78a1 + f8965d4 commit f204fc9
Showing 4 changed files with 144 additions and 59 deletions.
57 changes: 39 additions & 18 deletions sqs_workers/memory_sqs.py
@@ -107,6 +107,7 @@ class MemoryQueue:
     name: str = attr.ib()
     attributes: Dict[str, Dict[str, str]] = attr.ib()
     messages: List["MemoryMessage"] = attr.ib(factory=list)
+    in_flight: Dict[str, "MemoryMessage"] = attr.ib(factory=dict)

     def __attrs_post_init__(self):
         self.attributes["QueueArn"] = self.name
@@ -146,6 +147,8 @@ def receive_messages(self, WaitTimeSeconds="0", MaxNumberOfMessages="10", **kwargs):
             else:
                 ready_messages.append(message)
         self.messages[:] = push_back_messages
+        for m in ready_messages:
+            self.in_flight[m.message_id] = m
         return ready_messages

     def delete_messages(self, Entries):
@@ -155,22 +158,47 @@ def delete_messages(self, Entries):
         See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/
         services/sqs.html#SQS.Queue.delete_messages
         """
-        message_ids = {entry["Id"] for entry in Entries}
+        found_entries = []
+        not_found_entries = []

-        successfully_deleted = set()
-        push_back_messages = []
-        for message in self.messages:
-            if message.message_id in message_ids:
-                successfully_deleted.add(message.message_id)
+        for e in Entries:
+            if e["Id"] in self.in_flight:
+                found_entries.append(e)
+                self.in_flight.pop(e["Id"])
             else:
-                push_back_messages.append(message)
-        self.messages[:] = push_back_messages
+                not_found_entries.append(e)

-        didnt_deleted = message_ids.difference(successfully_deleted)
         return {
-            "Successful": [{"Id": _id} for _id in successfully_deleted],
-            "Failed": [{"Id": _id} for _id in didnt_deleted],
+            "Successful": [{"Id": e["Id"]} for e in found_entries],
+            "Failed": [{"Id": e["Id"]} for e in not_found_entries],
         }

+    def change_message_visibility_batch(self, Entries):
+        """
+        Changes message visibility by looking at in-flight messages, setting
+        a new execute_at, and returning it to the pool of processable messages
+        """
+        found_entries = []
+        not_found_entries = []
+
+        now = datetime.datetime.utcnow()
+
+        for e in Entries:
+            if e["Id"] in self.in_flight:
+                found_entries.append(e)
+                in_flight_message = self.in_flight[e["Id"]]
+                sec = int(e["VisibilityTimeout"])
+                execute_at = now + datetime.timedelta(seconds=sec)
+                updated_message = attr.evolve(in_flight_message, execute_at=execute_at)
+                updated_message.attributes["ApproximateReceiveCount"] += 1
+                self.messages.append(updated_message)
+                self.in_flight.pop(e["Id"])
+            else:
+                not_found_entries.append(e)
+
+        return {
+            "Successful": [{"Id": e["Id"]} for e in found_entries],
+            "Failed": [{"Id": e["Id"]} for e in not_found_entries],
+        }
+
     def delete(self):
@@ -241,10 +269,3 @@ def from_kwargs(cls, queue_impl, kwargs):
         return MemoryMessage(
             queue_impl, body, message_atttributes, attributes, execute_at
         )
-
-    def change_visibility(self, VisibilityTimeout="0", **kwargs):
-        timeout = int(VisibilityTimeout)
-        execute_at = datetime.datetime.utcnow() + datetime.timedelta(seconds=timeout)
-        message = attr.evolve(self, execute_at=execute_at)
-        message.attributes["ApproximateReceiveCount"] += 1
-        self.queue_impl.messages.append(message)
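
Note: the memory backend now mirrors SQS's visibility model. Messages handed out by receive_messages move into the in_flight map, delete_messages acks only messages that are actually in flight, and change_message_visibility_batch re-queues a message with a future execute_at. A minimal, self-contained sketch of that lifecycle (ToyQueue, ack, and nack are illustrative names, not part of sqs_workers):

import datetime

class ToyQueue:
    def __init__(self):
        self.messages = []    # ready (or delayed) messages
        self.in_flight = {}   # received but not yet acked, keyed by message id

    def receive(self):
        now = datetime.datetime.utcnow()
        ready = [m for m in self.messages if m["execute_at"] <= now]
        self.messages = [m for m in self.messages if m["execute_at"] > now]
        for m in ready:
            self.in_flight[m["id"]] = m  # track until acked or nacked
        return ready

    def ack(self, message_id):
        # Only messages we actually handed out can be deleted.
        return self.in_flight.pop(message_id, None) is not None

    def nack(self, message_id, visibility_timeout):
        # Return the message to the pool with a delayed execute_at (backoff).
        m = self.in_flight.pop(message_id)
        m["execute_at"] = datetime.datetime.utcnow() + datetime.timedelta(
            seconds=visibility_timeout
        )
        self.messages.append(m)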
124 changes: 86 additions & 38 deletions sqs_workers/queue.py
@@ -10,9 +10,11 @@
     Callable,
     Dict,
     Generator,
+    Iterable,
     List,
     Literal,
     Optional,
+    Tuple,
     TypeVar,
 )

@@ -27,6 +29,7 @@
 from sqs_workers.exceptions import SQSError
 from sqs_workers.processors import DEFAULT_CONTEXT_VAR, Processor
 from sqs_workers.shutdown_policies import NEVER_SHUTDOWN
+from sqs_workers.utils import batcher

 DEFAULT_MESSAGE_GROUP_ID = "default"
 SEND_BATCH_SIZE = 10
@@ -85,53 +88,84 @@ def process_queue(self, shutdown_policy=NEVER_SHUTDOWN, wait_second=10):
                 )
                 break

-    def process_batch(self, wait_seconds=0) -> BatchProcessingResult:
+    def process_batch(self, wait_seconds: int = 0) -> BatchProcessingResult:
         """
         Process a batch of messages from the queue (10 messages at most), return
         the number of successfully processed messages, and exit
         """
-        queue = self.get_queue()
-
         if self.batching_policy.batching_enabled:
-            return self._process_messages_in_batch(queue, wait_seconds)
+            messages = self.get_raw_messages(
+                wait_seconds, self.batching_policy.batch_size
+            )
+            success = self.process_messages(messages)
+            messages_with_success = ((m, success) for m in messages)
+        else:
+            messages = self.get_raw_messages(wait_seconds)
+            success = [self.process_message(message) for message in messages]
+            messages_with_success = zip(messages, success)

-        return self._process_messages_individually(queue, wait_seconds)
+        return self._handle_processed(messages_with_success)

-    def _process_messages_in_batch(self, queue, wait_seconds):
-        messages = self.get_raw_messages(wait_seconds, self.batching_policy.batch_size)
-        result = BatchProcessingResult(self.name)
+    def _handle_processed(self, messages_with_success: Iterable[Tuple[Any, bool]]):
+        """
+        Handles the results of processing messages.

-        success = self.process_messages(messages)
+        For successful messages, we delete the message ID from the queue, which is
+        equivalent to acknowledging it.

-        for message in messages:
-            result.update_with_message(message, success)
-            if success:
-                entry = {
-                    "Id": message.message_id,
-                    "ReceiptHandle": message.receipt_handle,
-                }
-                queue.delete_messages(Entries=[entry])
-            else:
-                timeout = self.backoff_policy.get_visibility_timeout(message)
-                message.change_visibility(VisibilityTimeout=timeout)
-        return result
+        For failed messages, we change the visibility of the message, in order to
+        keep it un-consumeable for a little while (a form of backoff).

-    def _process_messages_individually(self, queue, wait_seconds):
-        messages = self.get_raw_messages(wait_seconds)
-        result = BatchProcessingResult(self.name)
+        In each case (delete or change-viz), we batch the API calls to AWS in order
+        to try to avoid getting throttled, with batches of size 10 (the limit). The
+        config (see sqs_env.py) should also retry in the event of exceptions.
+        """
+        queue = self.get_queue()

-        for message in messages:
-            success = self.process_message(message)
-            result.update_with_message(message, success)
-            if success:
-                entry = {
-                    "Id": message.message_id,
-                    "ReceiptHandle": message.receipt_handle,
-                }
-                queue.delete_messages(Entries=[entry])
-            else:
-                timeout = self.backoff_policy.get_visibility_timeout(message)
-                message.change_visibility(VisibilityTimeout=timeout)
+        result = BatchProcessingResult(self.name)
+
+        for subgroup in batcher(messages_with_success, batch_size=10):
+            entries_to_ack = []
+            entries_to_change_viz = []
+
+            for m, success in subgroup:
+                result.update_with_message(m, success)
+                if success:
+                    entries_to_ack.append(
+                        {
+                            "Id": m.message_id,
+                            "ReceiptHandle": m.receipt_handle,
+                        }
+                    )
+                else:
+                    entries_to_change_viz.append(
+                        {
+                            "Id": m.message_id,
+                            "ReceiptHandle": m.receipt_handle,
+                            "VisibilityTimeout": self.backoff_policy.get_visibility_timeout(
+                                m
+                            ),
+                        }
+                    )
+
+            ack_response = queue.delete_messages(Entries=entries_to_ack)
+
+            if ack_response.get("Failed"):
+                logger.warning(
+                    "Failed to delete processed messages from queue",
+                    extra={"queue": self.name, "failures": ack_response["Failed"]},
+                )
+
+            viz_response = queue.change_message_visibility_batch(
+                Entries=entries_to_change_viz,
+            )
+
+            if viz_response.get("Failed"):
+                logger.warning(
+                    "Failed to change visibility of messages which failed to process",
+                    extra={"queue": self.name, "failures": viz_response["Failed"]},
+                )
+
         return result

     def process_message(self, message: Any) -> bool:
@@ -151,15 +185,16 @@ def process_messages(self, messages: List[Any]) -> bool:
         """
         raise NotImplementedError()

-    def get_raw_messages(self, wait_seconds, max_messages=10):
+    def get_raw_messages(self, wait_seconds: int, max_messages: int = 10) -> List[Any]:
         """Return raw messages from the queue, addressed by its name"""
-        queue = self.get_queue()
-
         kwargs = {
             "WaitTimeSeconds": wait_seconds,
             "MaxNumberOfMessages": max_messages if max_messages <= 10 else 10,
             "MessageAttributeNames": ["All"],
             "AttributeNames": ["All"],
         }
+        queue = self.get_queue()

         if max_messages <= 10:
             return queue.receive_messages(**kwargs)
@@ -180,16 +215,29 @@ def get_raw_messages(self, wait_seconds, max_messages=10):
     def drain_queue(self, wait_seconds=0):
         """Delete all messages from the queue without calling purge()."""
         queue = self.get_queue()
+
         deleted_count = 0
         while True:
             messages = self.get_raw_messages(wait_seconds)
             if not messages:
                 break
+
             entries = [
                 {"Id": msg.message_id, "ReceiptHandle": msg.receipt_handle}
                 for msg in messages
             ]
-            queue.delete_messages(Entries=entries)
+
+            ack_response = queue.delete_messages(Entries=entries)
+
+            if ack_response.get("Failed"):
+                logger.warning(
+                    "Failed to delete processed messages from queue",
+                    extra={
+                        "queue": self.name,
+                        "failures": ack_response["Failed"],
+                    },
+                )
+
             deleted_count += len(messages)
         return deleted_count

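A condensed sketch of the new ack / change-visibility flow, written against a plain boto3 Queue resource. handle_processed and backoff_seconds are illustrative names (not the library's API); pairs is an iterable of (message, success) tuples. Unlike the diff above, it also skips empty batches, since SQS rejects a batch request with no entries:

import logging
from itertools import islice

logger = logging.getLogger(__name__)

def handle_processed(queue, pairs, backoff_seconds):
    iterator = iter(pairs)
    while batch := list(islice(iterator, 10)):  # SQS batch APIs cap at 10 entries
        acks, retries = [], []
        for m, ok in batch:
            entry = {"Id": m.message_id, "ReceiptHandle": m.receipt_handle}
            if ok:
                acks.append(entry)
            else:
                retries.append({**entry, "VisibilityTimeout": backoff_seconds(m)})
        if acks:  # SQS rejects an empty batch request
            resp = queue.delete_messages(Entries=acks)
            if resp.get("Failed"):
                logger.warning("delete_messages failures: %s", resp["Failed"])
        if retries:
            resp = queue.change_message_visibility_batch(Entries=retries)
            if resp.get("Failed"):
                logger.warning("change_message_visibility_batch failures: %s", resp["Failed"])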
12 changes: 10 additions & 2 deletions sqs_workers/sqs_env.py
@@ -15,6 +15,7 @@

 import attr
 import boto3
+from botocore.config import Config
 from typing_extensions import ParamSpec

 from sqs_workers import DEFAULT_BACKOFF, RawQueue, codecs, context, processors
@@ -45,6 +46,10 @@ class SQSEnv:
     queue_prefix = attr.ib(default="")
     codec: str = attr.ib(default=codecs.DEFAULT_CONTENT_TYPE)

+    # retry settings for internal boto
+    retry_max_attempts: int = attr.ib(default=3)
+    retry_mode: str = attr.ib(default="standard")
+
     # queue-specific settings
     backoff_policy = attr.ib(default=DEFAULT_BACKOFF)

@@ -60,10 +65,13 @@ class SQSEnv:

     def __attrs_post_init__(self):
         self.context = self.context_maker()
+        # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html
+        retry_dict = {"max_attempts": self.retry_max_attempts, "mode": self.retry_mode}
+        retry_config = Config(retries=retry_dict)
         if not self.sqs_client:
-            self.sqs_client = self.session.client("sqs")
+            self.sqs_client = self.session.client("sqs", config=retry_config)
         if not self.sqs_resource:
-            self.sqs_resource = self.session.resource("sqs")
+            self.sqs_resource = self.session.resource("sqs", config=retry_config)

     @overload
     def queue(
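For reference, a standalone sketch of the retry wiring above, using boto3's documented retry configuration (the values shown match the new defaults):

import boto3
from botocore.config import Config

# "standard" mode retries throttling and transient errors with backoff;
# max_attempts caps total attempts, including the initial request.
retry_config = Config(retries={"max_attempts": 3, "mode": "standard"})

session = boto3.session.Session()
sqs_client = session.client("sqs", config=retry_config)
sqs_resource = session.resource("sqs", config=retry_config)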
10 changes: 9 additions & 1 deletion sqs_workers/utils.py
@@ -1,7 +1,8 @@
 import importlib
 import logging
 from inspect import Signature
-from typing import Any
+from itertools import islice
+from typing import Any, Iterable

 logger = logging.getLogger(__name__)

@@ -121,3 +122,10 @@ def ensure_string(obj: Any, encoding="utf-8", errors="strict") -> str:
         return obj.decode(encoding, errors)
     else:
         return str(obj)
+
+
+def batcher(iterable, batch_size) -> Iterable[Iterable[Any]]:
+    """Cuts an iterable up into sub-iterables of size batch_size."""
+    iterator = iter(iterable)
+    while batch := list(islice(iterator, batch_size)):
+        yield batch
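
A quick usage sketch of the helper (not part of the diff). It consumes any iterable lazily, including the generators fed to it by process_batch, and the := operator means it needs Python 3.8+:

>>> from sqs_workers.utils import batcher
>>> [len(batch) for batch in batcher(range(25), batch_size=10)]
[10, 10, 5]
>>> list(batcher([], batch_size=10))
[]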
