Skip to content

Commit

Permalink
Bug Fix: Defensively copy context entities (#340)
Browse files Browse the repository at this point in the history

Co-authored-by: Nathaniel Ruiz Nowell <[email protected]>
  • Loading branch information
tyler-dodge and NathanielRN authored Jun 27, 2022
1 parent 0f13101 commit 14a7ad9
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
15 changes: 14 additions & 1 deletion aws_xray_sdk/core/async_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import sys
import copy

from .context import Context as _Context

Expand Down Expand Up @@ -108,6 +109,18 @@ def task_factory(loop, coro):
else:
current_task = asyncio.Task.current_task(loop=loop)
if current_task is not None and hasattr(current_task, 'context'):
setattr(task, 'context', current_task.context)
if current_task.context.get('entities'):
# NOTE: (enowell) Because the `AWSXRayRecorder`'s `Context` decides
# the parent by looking at its `_local.entities`, we must copy the entities
# for concurrent subsegments. Otherwise, the subsegments would be
# modifying the same `entities` list and sugsegments would take other
# subsegments as parents instead of the original `segment`.
#
# See more: https://github.com/aws/aws-xray-sdk-python/blob/0f13101e4dba7b5c735371cb922f727b1d9f46d8/aws_xray_sdk/core/context.py#L90-L101
new_context = copy.copy(current_task.context)
new_context['entities'] = [item for item in current_task.context['entities']]
else:
new_context = current_task.context
setattr(task, 'context', new_context)

return task
23 changes: 23 additions & 0 deletions tests/test_async_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .util import get_new_stubbed_recorder
from aws_xray_sdk.version import VERSION
from aws_xray_sdk.core.async_context import AsyncContext
import asyncio


xray_recorder = get_new_stubbed_recorder()
Expand Down Expand Up @@ -43,6 +44,28 @@ async def test_capture(loop):
assert platform.python_implementation() == service.get('runtime')
assert platform.python_version() == service.get('runtime_version')

async def test_concurrent_calls(loop):
xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop))
async with xray_recorder.in_segment_async('segment') as segment:
global counter
counter = 0
total_tasks = 10
flag = asyncio.Event()
async def assert_task():
async with xray_recorder.in_subsegment_async('segment') as subsegment:
global counter
counter += 1
# Begin all subsegments before closing any to ensure they overlap
if counter < total_tasks:
await flag.wait()
else:
flag.set()
return subsegment.parent_id
tasks = [assert_task() for task in range(total_tasks)]
subsegs_parent_ids = await asyncio.gather(*tasks)
for subseg_parent_id in subsegs_parent_ids:
assert subseg_parent_id == segment.id


async def test_async_context_managers(loop):
xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop))
Expand Down

0 comments on commit 14a7ad9

Please sign in to comment.