Skip to content

Commit

Permalink
fix: negatively acknowledge error only after it has been saved
Browse files Browse the repository at this point in the history
  • Loading branch information
ClemDoum committed Dec 21, 2023
1 parent 5373056 commit bec29f2
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 38 deletions.
60 changes: 25 additions & 35 deletions neo4j-app/neo4j_app/icij_worker/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,27 +137,33 @@ async def consume(self) -> Tuple[Task, str]:
@final
@asynccontextmanager
async def acknowledgment_cm(self, task: Task, project: str):
async with self._persist_error(task, project):
try:
self._current = task, project
self.debug('Task(id="%s") locked', task.id)
try:
event = TaskEvent(
task_id=task.id, progress=0, status=TaskStatus.RUNNING
)
await self.publish_event(event, project)
yield
await self.acknowledge(task, project)
except asyncio.CancelledError as e:
self.error('Task(id="%s") worker cancelled, exiting', task.id)
raise e
except RecoverableError:
self.error('Task(id="%s") encountered error', task.id)
await self.negatively_acknowledge(task, project, requeue=True)
except Exception as fatal_error:
await self.negatively_acknowledge(task, project, requeue=False)
raise fatal_error
self._current = None
self.info('Task(id="%s") successful !', task.id)
event = TaskEvent(task_id=task.id, progress=0, status=TaskStatus.RUNNING)
await self.publish_event(event, project)
yield
await self.acknowledge(task, project)
except asyncio.CancelledError as e:
self.error(
'Task(id="%s") worker cancelled, exiting without persisting error',
task.id,
)
raise e
except RecoverableError:
self.error('Task(id="%s") encountered error', task.id)
await self.negatively_acknowledge(task, project, requeue=True)
except Exception as fatal_error:
if isinstance(fatal_error, MaxRetriesExceeded):
self.error('Task(id="%s") exceeded max retries, exiting !', task.id)
else:
self.error('Task(id="%s") fatal error, exiting !', task.id)
task_error = TaskError.from_exception(fatal_error)
await self.save_error(error=task_error, task=task, project=project)
await self.negatively_acknowledge(task, project, requeue=False)
raise fatal_error
self._current = None
self.info('Task(id="%s") successful !', task.id)

@final
async def acknowledge(self, task: Task, project: str):
Expand Down Expand Up @@ -252,22 +258,6 @@ async def _publish_progress(self, progress: float, task: Task, project: str):
event = TaskEvent(progress=progress, task_id=task.id)
await self.publish_event(event, project)

@final
@asynccontextmanager
async def _persist_error(self, task: Task, project: str):
try:
yield
except asyncio.CancelledError as e: # pylint: disable=broad-except
self.debug("worker cancelled, no need to persist error")
raise e
except Exception as e: # pylint: disable=broad-except
if isinstance(e, MaxRetriesExceeded):
self.error('Task(id="%s") exceeded max retries, exiting !', task.id)
else:
self.error('Task(id="%s") fatal error, exiting !', task.id)
error = TaskError.from_exception(e)
await self.save_error(error=error, task=task, project=project)

@final
def parse_task(
self, task: Task, project: str
Expand Down
12 changes: 9 additions & 3 deletions neo4j-app/neo4j_app/tests/icij_worker/worker/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
TaskResult,
TaskStatus,
)
from neo4j_app.icij_worker.exceptions import TaskCancelled
from neo4j_app.icij_worker.exceptions import TaskCancelled, UnregisteredTask
from neo4j_app.icij_worker.worker.worker import add_missing_args, task_wrapper
from neo4j_app.tests.conftest import TEST_PROJECT, async_true_after
from neo4j_app.tests.icij_worker.conftest import MockManager, MockWorker
Expand Down Expand Up @@ -250,7 +250,10 @@ async def test_task_wrapper_should_handle_non_recoverable_error(

# When
await task_manager.enqueue(task, project)
await task_wrapper(worker)
try:
await task_wrapper(worker)
except ValueError:
pass
saved_errors = await task_manager.get_task_errors(
task_id="some-id", project=project
)
Expand Down Expand Up @@ -312,7 +315,10 @@ async def test_task_wrapper_should_handle_unregistered_task(mock_worker: MockWor

# When
await task_manager.enqueue(task, project)
await task_wrapper(worker)
try:
await task_wrapper(worker)
except UnregisteredTask:
pass
saved_task = await task_manager.get_task(task_id="some-id", project=project)
saved_errors = await task_manager.get_task_errors(
task_id="some-id", project=project
Expand Down

0 comments on commit bec29f2

Please sign in to comment.