diff --git a/s3transfer/futures.py b/s3transfer/futures.py index 68775d04..0be265a9 100644 --- a/s3transfer/futures.py +++ b/s3transfer/futures.py @@ -17,6 +17,8 @@ from collections import namedtuple from concurrent import futures +from botocore.context import get_context + from s3transfer.compat import MAXINT from s3transfer.exceptions import CancelledError, TransferNotDoneError from s3transfer.utils import FunctionContainer, TaskSemaphore @@ -467,7 +469,9 @@ def submit(self, task, tag=None, block=True): semaphore.release, task.transfer_id, acquire_token ) # Submit the task to the underlying executor. - future = ExecutorFuture(self._executor.submit(task)) + # Pass the current context to ensure child threads persist the + # parent thread's context. + future = ExecutorFuture(self._executor.submit(task, get_context())) # Add the Semaphore.release() callback to the future such that # it is invoked once the future completes. future.add_done_callback(release_callback) diff --git a/s3transfer/tasks.py b/s3transfer/tasks.py index 4183715a..211e08fc 100644 --- a/s3transfer/tasks.py +++ b/s3transfer/tasks.py @@ -13,6 +13,8 @@ import copy import logging +from botocore.context import start_as_current_context + from s3transfer.utils import get_callbacks logger = logging.getLogger(__name__) @@ -118,32 +120,33 @@ def _get_kwargs_with_params_to_exclude(self, kwargs, exclude): filtered_kwargs[param] = value return filtered_kwargs - def __call__(self): + def __call__(self, ctx=None): """The callable to use when submitting a Task to an executor""" - try: - # Wait for all of futures this task depends on. - self._wait_on_dependent_futures() - # Gather up all of the main keyword arguments for main(). - # This includes the immediately provided main_kwargs and - # the values for pending_main_kwargs that source from the return - # values from the task's dependent futures. - kwargs = self._get_all_main_kwargs() - # If the task is not done (really only if some other related - # task to the TransferFuture had failed) then execute the task's - # main() method. - if not self._transfer_coordinator.done(): - return self._execute_main(kwargs) - except Exception as e: - self._log_and_set_exception(e) - finally: - # Run any done callbacks associated to the task no matter what. - for done_callback in self._done_callbacks: - done_callback() - - if self._is_final: - # If this is the final task announce that it is done if results - # are waiting on its completion. - self._transfer_coordinator.announce_done() + with start_as_current_context(ctx): + try: + # Wait for all of futures this task depends on. + self._wait_on_dependent_futures() + # Gather up all of the main keyword arguments for main(). + # This includes the immediately provided main_kwargs and + # the values for pending_main_kwargs that source from the return + # values from the task's dependent futures. + kwargs = self._get_all_main_kwargs() + # If the task is not done (really only if some other related + # task to the TransferFuture had failed) then execute the task's + # main() method. + if not self._transfer_coordinator.done(): + return self._execute_main(kwargs) + except Exception as e: + self._log_and_set_exception(e) + finally: + # Run any done callbacks associated to the task no matter what. + for done_callback in self._done_callbacks: + done_callback() + + if self._is_final: + # If this is the final task announce that it is done if results + # are waiting on its completion. + self._transfer_coordinator.announce_done() def _execute_main(self, kwargs): # Do not display keyword args that should not be printed, especially diff --git a/tests/unit/test_tasks.py b/tests/unit/test_tasks.py index 9759e8fb..b3501505 100644 --- a/tests/unit/test_tasks.py +++ b/tests/unit/test_tasks.py @@ -14,6 +14,8 @@ from functools import partial from threading import Event +from botocore.context import ClientContext, get_context + from s3transfer.futures import BoundedExecutor, TransferCoordinator from s3transfer.subscribers import BaseSubscriber from s3transfer.tasks import ( @@ -69,6 +71,11 @@ def _submit(self, transfer_future, **kwargs): pass +class ReturnContextTask(Task): + def _main(self): + return get_context() + + class ExceptionSubmissionTask(SubmissionTask): def _submit( self, @@ -723,6 +730,15 @@ def test_single_failed_pending_future_in_list(self): with self.assertRaises(TaskFailureException): self.transfer_coordinator.result() + def test_passing_context_to_task_call(self): + ctx = ClientContext() + ctx.features.add('FOO') + task = ReturnContextTask(self.transfer_coordinator) + self.assertEqual(task(ctx).features, {'FOO'}) + # `task(ctx)` returned, so the current context should be reset to None. + current_ctx = get_context() + self.assertEqual(current_ctx, None) + class BaseMultipartTaskTest(BaseTaskTest): def setUp(self):