Skip to content

Commit

Permalink
Add a workaround for CloudWatch GetLogEvents empty results (#1652)
Browse files Browse the repository at this point in the history
  • Loading branch information
un-def authored Sep 3, 2024
1 parent a740a0a commit 1537163
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 31 deletions.
53 changes: 37 additions & 16 deletions src/dstack/_internal/server/services/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,10 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi
stream = self._get_stream_name(
project.name, request.run_name, request.job_submission_id, log_producer
)
parameters = {
"logGroupName": self._group,
"logStreamName": stream,
"limit": request.limit,
"startFromHead": (not request.descending),
}
if request.start_time:
# XXX: Since callers use start_time/end_time for pagination, one millisecond is added
# to avoid an infinite loop because startTime boundary is inclusive.
parameters["startTime"] = _datetime_to_unix_time_ms(request.start_time) + 1
if request.end_time:
# No need to subtract one millisecond in this case, though; it seems that endTime is
# exclusive, that is, time interval boundaries are [startTime, endTime)
parameters["endTime"] = _datetime_to_unix_time_ms(request.end_time)
cw_events: List[_CloudWatchLogEvent]
with self._wrap_boto_errors():
try:
response = self._client.get_log_events(**parameters)
cw_events = response["events"]
cw_events = self._get_log_events(stream, request)
except botocore.exceptions.ClientError as e:
if not self._is_resource_not_found_exception(e):
raise
Expand All @@ -122,6 +107,42 @@ def poll_logs(self, project: ProjectModel, request: PollLogsRequest) -> JobSubmi
]
return JobSubmissionLogs(logs=logs)

def _get_log_events(self, stream: str, request: PollLogsRequest) -> List[_CloudWatchLogEvent]:
    """Fetch log events for the stream within the request's time window.

    When reading backwards (descending), CloudWatch may return empty pages
    even though more events exist, so we keep following nextBackwardToken
    for a bounded number of attempts.
    """
    from_head = not request.descending
    call_kwargs = {
        "logGroupName": self._group,
        "logStreamName": stream,
        "limit": request.limit,
        "startFromHead": from_head,
    }
    if request.start_time:
        # XXX: Callers use start_time/end_time for pagination; one millisecond is
        # added because the startTime boundary is inclusive, otherwise the caller
        # would loop forever on the same page.
        call_kwargs["startTime"] = _datetime_to_unix_time_ms(request.start_time) + 1
    if request.end_time:
        # endTime appears to be exclusive ([startTime, endTime)), so no
        # adjustment is needed in this case.
        call_kwargs["endTime"] = _datetime_to_unix_time_ms(request.end_time)
    response = self._client.get_log_events(**call_kwargs)
    events: List[_CloudWatchLogEvent] = response["events"]
    if from_head or events:
        return events
    # Workaround for https://github.com/boto/boto3/issues/3718
    # Required only when startFromHead = false (the default value).
    token: str = response["nextBackwardToken"]
    # Bound the retries so a misbehaving API cannot trap us in an infinite loop.
    for _ in range(10):
        call_kwargs["nextToken"] = token
        response = self._client.get_log_events(**call_kwargs)
        events = response["events"]
        # An unchanged token means the end of the stream was reached.
        if events or response["nextBackwardToken"] == token:
            return events
        token = response["nextBackwardToken"]
    logger.warning("too many empty responses from stream %s, returning dummy response", stream)
    return []

def write_logs(
self,
project: ProjectModel,
Expand Down
164 changes: 149 additions & 15 deletions src/tests/_internal/server/services/test_logs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from datetime import datetime, timezone
from pathlib import Path
from unittest.mock import Mock, call
Expand Down Expand Up @@ -61,7 +62,11 @@ async def project(self, test_db, session: AsyncSession) -> ProjectModel:
def mock_client(self, monkeypatch: pytest.MonkeyPatch) -> Mock:
    """Replace the boto3 CloudWatch Logs client with a Mock.

    The default get_log_events response is empty but carries both pagination
    tokens, matching the shape of a real GetLogEvents API response.
    """
    mock = Mock()
    monkeypatch.setattr("boto3.Session.client", Mock(return_value=mock))
    mock.get_log_events.return_value = {
        "events": [],
        "nextBackwardToken": "bwd",
        # Fixed key typo: the real API field is "nextForwardToken",
        # not "nextFormartToken".
        "nextForwardToken": "fwd",
    }
    return mock

@pytest.fixture
Expand Down Expand Up @@ -160,19 +165,17 @@ def test_ensure_stream_exists_cached_forced(
)

@pytest.mark.asyncio
async def test_poll_logs_response(
async def test_poll_logs_non_empty_response(
self,
project: ProjectModel,
log_storage: CloudWatchLogStorage,
mock_client: Mock,
poll_logs_request: PollLogsRequest,
):
mock_client.get_log_events.return_value = {
"events": [
{"timestamp": 1696586513234, "message": "SGVsbG8="},
{"timestamp": 1696586513235, "message": "V29ybGQ="},
]
}
mock_client.get_log_events.return_value["events"] = [
{"timestamp": 1696586513234, "message": "SGVsbG8="},
{"timestamp": 1696586513235, "message": "V29ybGQ="},
]
job_submission_logs = log_storage.poll_logs(project, poll_logs_request)

assert job_submission_logs.logs == [
Expand All @@ -189,19 +192,33 @@ async def test_poll_logs_response(
]

@pytest.mark.asyncio
async def test_poll_logs_response_descending(
async def test_poll_logs_empty_response(
self,
project: ProjectModel,
log_storage: CloudWatchLogStorage,
mock_client: Mock,
poll_logs_request: PollLogsRequest,
):
mock_client.get_log_events.return_value = {
"events": [
{"timestamp": 1696586513234, "message": "SGVsbG8="},
{"timestamp": 1696586513235, "message": "V29ybGQ="},
]
}
# Check that we don't use the workaround when descending=False -> startFromHead=True
# https://github.com/dstackai/dstack/issues/1647
mock_client.get_log_events.return_value["events"] = []
job_submission_logs = log_storage.poll_logs(project, poll_logs_request)

assert job_submission_logs.logs == []
mock_client.get_log_events.assert_called_once()

@pytest.mark.asyncio
async def test_poll_logs_descending_non_empty_response_on_first_call(
self,
project: ProjectModel,
log_storage: CloudWatchLogStorage,
mock_client: Mock,
poll_logs_request: PollLogsRequest,
):
mock_client.get_log_events.return_value["events"] = [
{"timestamp": 1696586513234, "message": "SGVsbG8="},
{"timestamp": 1696586513235, "message": "V29ybGQ="},
]
poll_logs_request.descending = True
job_submission_logs = log_storage.poll_logs(project, poll_logs_request)

Expand All @@ -218,6 +235,118 @@ async def test_poll_logs_response_descending(
),
]

@pytest.mark.asyncio
async def test_poll_logs_descending_two_first_calls_return_empty_response(
    self,
    project: ProjectModel,
    log_storage: CloudWatchLogStorage,
    mock_client: Mock,
    poll_logs_request: PollLogsRequest,
):
    """Descending polls keep following nextBackwardToken past empty pages."""
    # The first two calls return empty event lists, though the token is not the same, meaning
    # there are more events.
    # https://github.com/dstackai/dstack/issues/1647
    mock_client.get_log_events.side_effect = [
        {
            "events": [],
            "nextBackwardToken": "bwd1",
            "nextForwardToken": "fwd",
        },
        {
            "events": [],
            "nextBackwardToken": "bwd2",
            "nextForwardToken": "fwd",
        },
        {
            "events": [
                {"timestamp": 1696586513234, "message": "SGVsbG8="},
                {"timestamp": 1696586513235, "message": "V29ybGQ="},
            ],
            "nextBackwardToken": "bwd3",
            "nextForwardToken": "fwd",
        },
    ]
    poll_logs_request.descending = True
    job_submission_logs = log_storage.poll_logs(project, poll_logs_request)

    # descending=True: the later timestamp (…235000) comes first.
    assert job_submission_logs.logs == [
        LogEvent(
            timestamp=datetime(2023, 10, 6, 10, 1, 53, 235000, tzinfo=timezone.utc),
            log_source=LogEventSource.STDOUT,
            message="V29ybGQ=",
        ),
        LogEvent(
            timestamp=datetime(2023, 10, 6, 10, 1, 53, 234000, tzinfo=timezone.utc),
            log_source=LogEventSource.STDOUT,
            message="SGVsbG8=",
        ),
    ]
    # One initial call plus two follow-up calls to get past the empty pages.
    assert mock_client.get_log_events.call_count == 3

@pytest.mark.asyncio
async def test_poll_logs_descending_empty_response_with_same_token(
    self,
    project: ProjectModel,
    log_storage: CloudWatchLogStorage,
    mock_client: Mock,
    poll_logs_request: PollLogsRequest,
):
    """A repeated nextBackwardToken stops pagination immediately."""
    # The first two calls return empty event lists with the same token, meaning we reached
    # the end.
    # https://github.com/dstackai/dstack/issues/1647
    mock_client.get_log_events.side_effect = [
        {
            "events": [],
            "nextBackwardToken": "bwd",
            "nextForwardToken": "fwd",
        },
        {
            "events": [],
            "nextBackwardToken": "bwd",
            "nextForwardToken": "fwd",
        },
        # We should not reach this response
        {
            "events": [
                {"timestamp": 1696586513234, "message": "SGVsbG8="},
            ],
            "nextBackwardToken": "bwd2",
            "nextForwardToken": "fwd",
        },
    ]
    poll_logs_request.descending = True
    job_submission_logs = log_storage.poll_logs(project, poll_logs_request)

    assert job_submission_logs.logs == []
    # Pagination stops as soon as the token repeats; the third response is never fetched.
    assert mock_client.get_log_events.call_count == 2

@pytest.mark.asyncio
async def test_poll_logs_descending_empty_response_max_tries(
    self,
    project: ProjectModel,
    log_storage: CloudWatchLogStorage,
    mock_client: Mock,
    poll_logs_request: PollLogsRequest,
):
    """The workaround gives up after a bounded number of retries."""
    # Test for a circuit breaker when the API returns empty results on each call, but the
    # token is different on each call.
    # https://github.com/dstackai/dstack/issues/1647
    counter = itertools.count()

    def _response_producer(*args, **kwargs):
        # Always empty, always a fresh token — pagination never terminates naturally.
        return {
            "events": [],
            "nextBackwardToken": f"bwd{next(counter)}",
            "nextForwardToken": "fwd",
        }

    mock_client.get_log_events.side_effect = _response_producer
    poll_logs_request.descending = True
    job_submission_logs = log_storage.poll_logs(project, poll_logs_request)

    assert job_submission_logs.logs == []
    assert mock_client.get_log_events.call_count == 11  # initial call + 10 tries

@pytest.mark.asyncio
async def test_poll_logs_request_params_asc_no_diag_no_dates(
self,
Expand Down Expand Up @@ -245,6 +374,11 @@ async def test_poll_logs_request_params_desc_diag_with_dates(
mock_client: Mock,
poll_logs_request: PollLogsRequest,
):
# Ensure the first response has events to avoid triggering a workaround for
# https://github.com/dstackai/dstack/issues/1647
mock_client.get_log_events.return_value["events"] = [
{"timestamp": 1696586513234, "message": "SGVsbG8="}
]
poll_logs_request.start_time = datetime(
2023, 10, 6, 10, 1, 53, 234000, tzinfo=timezone.utc
)
Expand Down

0 comments on commit 1537163

Please sign in to comment.