diff --git a/aana/api/exception_handler.py b/aana/api/exception_handler.py index 2cd8fac8..1a6a05a9 100644 --- a/aana/api/exception_handler.py +++ b/aana/api/exception_handler.py @@ -1,30 +1,36 @@ import traceback from fastapi import Request +from fastapi.exceptions import RequestValidationError from pydantic import ValidationError from ray.exceptions import RayTaskError from aana.api.responses import AanaJSONResponse +from aana.configs.settings import settings as aana_settings from aana.core.models.exception import ExceptionResponseModel from aana.exceptions.core import BaseException -async def validation_exception_handler(request: Request, exc: ValidationError): +async def validation_exception_handler( + request: Request, exc: ValidationError | RequestValidationError +): """This handler is used to handle pydantic validation errors. Args: request (Request): The request object - exc (ValidationError): The validation error + exc (ValidationError | RequestValidationError): The exception raised Returns: JSONResponse: JSON response with the error details """ + if isinstance(exc, ValidationError): + data = exc.errors(include_context=False) + elif isinstance(exc, RequestValidationError): + data = exc.errors() return AanaJSONResponse( status_code=422, content=ExceptionResponseModel( - error="ValidationError", - message="Validation error", - data=exc.errors(), + error="ValidationError", message="Validation error", data=data ).model_dump(), ) @@ -60,6 +66,9 @@ def custom_exception_handler(request: Request | None, exc_raw: Exception): # then we need to get the stack trace stacktrace = traceback.format_exc() exc = exc_raw + # Remove the stacktrace if it is disabled + if not aana_settings.include_stacktrace: + stacktrace = None # get the data from the exception # can be used to return additional info # like image path, url, model name etc. diff --git a/aana/configs/settings.py b/aana/configs/settings.py index d3e90d9c..08edaa5f 100644 --- a/aana/configs/settings.py +++ b/aana/configs/settings.py @@ -79,6 +79,7 @@ class Settings(BaseSettings): model_dir (Path): The temporary model directory. num_workers (int): The number of web workers. openai_endpoint_enabled (bool): Flag indicating if the OpenAI-compatible endpoint is enabled. Enabled by default. + include_stacktrace (bool): Flag indicating if stacktrace should be included in error messages. Enabled by default. task_queue (TaskQueueSettings): The task queue settings. db_config (DbSettings): The database configuration. test (TestSettings): The test settings. @@ -93,6 +94,7 @@ class Settings(BaseSettings): num_workers: int = 2 openai_endpoint_enabled: bool = True + include_stacktrace: bool = True task_queue: TaskQueueSettings = TaskQueueSettings()