From b4823f5571158563325fa328b60a96374bdf39e7 Mon Sep 17 00:00:00 2001 From: Aleksandr Movchan Date: Tue, 4 Feb 2025 17:18:02 +0000 Subject: [PATCH] Add retry endpoint for failed tasks --- aana/routers/task.py | 26 ++++++++++++++++++++++++++ aana/storage/repository/task.py | 19 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/aana/routers/task.py b/aana/routers/task.py index 620e8fe5..8884ff22 100644 --- a/aana/routers/task.py +++ b/aana/routers/task.py @@ -135,6 +135,32 @@ async def delete_task( return TaskInfo.from_entity(task) +@router.post( + "/tasks/{task_id}/retry", + summary="Retry Failed Task", + description="Retry a failed task by resetting its status to CREATED.", +) +async def retry_task( + task_id: str, db: GetDbDependency, user_id: UserIdDependency +) -> TaskInfo: + """Retry a failed task by resetting its status.""" + task_repo = TaskRepository(db) + task = task_repo.read(task_id, check=False) + if not task or task.user_id != user_id: + raise HTTPException( + status_code=404, + detail="Task not found", + ) + if task.status != TaskStatus.FAILED: + raise HTTPException( + status_code=400, + detail="Only failed tasks can be retried", + ) + + updated_task = task_repo.retry_task(task.id) + return TaskInfo.from_entity(updated_task) + + # Legacy endpoints (to be removed in the future) diff --git a/aana/storage/repository/task.py b/aana/storage/repository/task.py index 5e9c21d8..980e2824 100644 --- a/aana/storage/repository/task.py +++ b/aana/storage/repository/task.py @@ -274,6 +274,25 @@ def update_status( self.session.commit() return task + def retry_task(self, task_id: str) -> TaskEntity: + """Retry a task. The task will reset to CREATED status. + + Args: + task_id (str): The ID of the task. + + Returns: + TaskEntity: The updated task. + """ + task = self.read(task_id) + task.status = TaskStatus.CREATED + task.progress = 0 + task.result = None + task.num_retries = 0 + task.assigned_at = None + task.completed_at = None + self.session.commit() + return task + def get_active_tasks(self) -> list[TaskEntity]: """Fetches all active tasks.