Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chat with Video #23

Merged
merged 8 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion aana/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ def custom_exception_handler(
error = exc.__class__.__name__
# get the message of the exception
message = str(exc)
status_code = exc.http_status_code if hasattr(exc, "http_status_code") else 400
return AanaJSONResponse(
status_code=400,
status_code=status_code,
content=ExceptionResponseModel(
error=error, message=message, data=data, stacktrace=stacktrace
).dict(),
Expand Down
8 changes: 8 additions & 0 deletions aana/configs/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,5 +219,13 @@
],
streaming=True,
),
Endpoint(
name="video_metadata",
path="/video/metadata",
summary="Load video metadata",
outputs=[
EndpointOutput(name="metadata", output="video_metadata"),
],
),
],
}
2 changes: 2 additions & 0 deletions aana/configs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from aana.models.pydantic.prompt import Prompt
from aana.models.pydantic.sampling_params import SamplingParams
from aana.models.pydantic.video_input import VideoInput, VideoInputList
from aana.models.pydantic.video_metadata import VideoMetadata
from aana.models.pydantic.video_params import VideoParams
from aana.models.pydantic.whisper_params import WhisperParams

Expand Down Expand Up @@ -667,6 +668,7 @@
"name": "video_metadata",
"key": "output",
"path": "video_metadata",
"data_model": VideoMetadata,
},
],
},
Expand Down
22 changes: 22 additions & 0 deletions aana/exceptions/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,25 @@ def __init__(self, prompt_len: int, max_len: int):
def __reduce__(self):
"""Used for pickling."""
return (self.__class__, (self.prompt_len, self.max_len))


class MediaIdFoundException(BaseException):
"""Exception raised when a media ID is found in the prompt.

Attributes:
media_id (str): the media ID
"""

def __init__(self, media_id: str):
"""Initialize the exception.

Args:
media_id (str): the media ID
"""
super().__init__(media_id=media_id)
self.media_id = media_id
self.http_status_code = 404

def __reduce__(self):
"""Used for pickling."""
return (self.__class__, (self.media_id,))
22 changes: 22 additions & 0 deletions aana/models/pydantic/video_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from types import MappingProxyType

from pydantic import BaseModel, Field


class VideoMetadata(BaseModel):
"""Metadata of a video.

Attributes:
title (str): the title of the video
description (str): the description of the video
"""

title: str = Field(None, description="The title of the video.")
description: str = Field(None, description="The description of the video.")

class Config:
schema_extra = MappingProxyType(
{
"description": "Metadata of a video.",
}
)
8 changes: 7 additions & 1 deletion aana/utils/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from yt_dlp.utils import DownloadError

from aana.configs.settings import settings
from aana.exceptions.general import DownloadException, VideoReadingException
from aana.exceptions.general import (
DownloadException,
MediaIdFoundException,
VideoReadingException,
)
from aana.models.core.image import Image
from aana.models.core.video import Video
from aana.models.core.video_source import VideoSource
Expand Down Expand Up @@ -348,6 +352,8 @@ def load_video_metadata(media_id: str):
output_dir.mkdir(parents=True, exist_ok=True)

output_path = Path(output_dir) / f"{media_id}.pkl"
if not output_path.exists():
raise MediaIdFoundException(media_id)
with output_path.open("rb") as f:
metadata = pickle.load(f) # noqa: S301
return metadata
Expand Down