diff --git a/neo4j-app/neo4j_app/icij_worker/worker/worker.py b/neo4j-app/neo4j_app/icij_worker/worker/worker.py index ef468f98..e0681ac7 100644 --- a/neo4j-app/neo4j_app/icij_worker/worker/worker.py +++ b/neo4j-app/neo4j_app/icij_worker/worker/worker.py @@ -137,27 +137,33 @@ async def consume(self) -> Tuple[Task, str]: @final @asynccontextmanager async def acknowledgment_cm(self, task: Task, project: str): - async with self._persist_error(task, project): + try: self._current = task, project self.debug('Task(id="%s") locked', task.id) - try: - event = TaskEvent( - task_id=task.id, progress=0, status=TaskStatus.RUNNING - ) - await self.publish_event(event, project) - yield - await self.acknowledge(task, project) - except asyncio.CancelledError as e: - self.error('Task(id="%s") worker cancelled, exiting', task.id) - raise e - except RecoverableError: - self.error('Task(id="%s") encountered error', task.id) - await self.negatively_acknowledge(task, project, requeue=True) - except Exception as fatal_error: - await self.negatively_acknowledge(task, project, requeue=False) - raise fatal_error - self._current = None - self.info('Task(id="%s") successful !', task.id) + event = TaskEvent(task_id=task.id, progress=0, status=TaskStatus.RUNNING) + await self.publish_event(event, project) + yield + await self.acknowledge(task, project) + except asyncio.CancelledError as e: + self.error( + 'Task(id="%s") worker cancelled, exiting without persisting error', + task.id, + ) + raise e + except RecoverableError: + self.error('Task(id="%s") encountered error', task.id) + await self.negatively_acknowledge(task, project, requeue=True) + except Exception as fatal_error: + if isinstance(fatal_error, MaxRetriesExceeded): + self.error('Task(id="%s") exceeded max retries, exiting !', task.id) + else: + self.error('Task(id="%s") fatal error, exiting !', task.id) + task_error = TaskError.from_exception(fatal_error) + await self.save_error(error=task_error, task=task, project=project) + await self.negatively_acknowledge(task, project, requeue=False) + raise fatal_error + self._current = None + self.info('Task(id="%s") successful !', task.id) @final async def acknowledge(self, task: Task, project: str): @@ -252,22 +258,6 @@ async def _publish_progress(self, progress: float, task: Task, project: str): event = TaskEvent(progress=progress, task_id=task.id) await self.publish_event(event, project) - @final - @asynccontextmanager - async def _persist_error(self, task: Task, project: str): - try: - yield - except asyncio.CancelledError as e: # pylint: disable=broad-except - self.debug("worker cancelled, no need to persist error") - raise e - except Exception as e: # pylint: disable=broad-except - if isinstance(e, MaxRetriesExceeded): - self.error('Task(id="%s") exceeded max retries, exiting !', task.id) - else: - self.error('Task(id="%s") fatal error, exiting !', task.id) - error = TaskError.from_exception(e) - await self.save_error(error=error, task=task, project=project) - @final def parse_task( self, task: Task, project: str diff --git a/neo4j-app/neo4j_app/tests/icij_worker/worker/test_worker.py b/neo4j-app/neo4j_app/tests/icij_worker/worker/test_worker.py index 7bc6468f..d71de048 100644 --- a/neo4j-app/neo4j_app/tests/icij_worker/worker/test_worker.py +++ b/neo4j-app/neo4j_app/tests/icij_worker/worker/test_worker.py @@ -19,7 +19,7 @@ TaskResult, TaskStatus, ) -from neo4j_app.icij_worker.exceptions import TaskCancelled +from neo4j_app.icij_worker.exceptions import TaskCancelled, UnregisteredTask from neo4j_app.icij_worker.worker.worker import add_missing_args, task_wrapper from neo4j_app.tests.conftest import TEST_PROJECT, async_true_after from neo4j_app.tests.icij_worker.conftest import MockManager, MockWorker @@ -250,7 +250,10 @@ async def test_task_wrapper_should_handle_non_recoverable_error( # When await task_manager.enqueue(task, project) - await task_wrapper(worker) + try: + await task_wrapper(worker) + except ValueError: + pass saved_errors = await task_manager.get_task_errors( task_id="some-id", project=project ) @@ -312,7 +315,10 @@ async def test_task_wrapper_should_handle_unregistered_task(mock_worker: MockWor # When await task_manager.enqueue(task, project) - await task_wrapper(worker) + try: + await task_wrapper(worker) + except UnregisteredTask: + pass saved_task = await task_manager.get_task(task_id="some-id", project=project) saved_errors = await task_manager.get_task_errors( task_id="some-id", project=project