Skip to content

Commit

Permalink
Merge branch 'main' into chat_with_video_db
Browse files Browse the repository at this point in the history
  • Loading branch information
movchan74 committed Jan 23, 2024
2 parents f42b731 + d555e9a commit 3805829
Show file tree
Hide file tree
Showing 13 changed files with 115 additions and 176 deletions.
16 changes: 1 addition & 15 deletions aana/api/api_generation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import traceback
from collections.abc import AsyncGenerator, Callable
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import StreamingResponse
from mobius_pipeline.exceptions import BaseException
from mobius_pipeline.node.socket import Socket
from mobius_pipeline.pipeline.pipeline import Pipeline
from pydantic import BaseModel, Field, ValidationError, create_model, parse_raw_as
from ray.exceptions import RayTaskError

from aana.api.app import custom_exception_handler
from aana.api.responses import AanaJSONResponse
Expand Down Expand Up @@ -389,19 +386,8 @@ async def generator_wrapper() -> AsyncGenerator[bytes, None]:
):
output = self.process_output(output)
yield AanaJSONResponse(content=output).body
except RayTaskError as e:
yield custom_exception_handler(None, e).body
except BaseException as e:
yield custom_exception_handler(None, e).body
except Exception as e:
error = e.__class__.__name__
stacktrace = traceback.format_exc()
yield AanaJSONResponse(
status_code=400,
content=ExceptionResponseModel(
error=error, message=str(e), stacktrace=stacktrace
).dict(),
).body
yield custom_exception_handler(None, e).body

return StreamingResponse(
generator_wrapper(), media_type="application/json"
Expand Down
6 changes: 2 additions & 4 deletions aana/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ async def validation_exception_handler(request: Request, exc: ValidationError):
)


def custom_exception_handler(
request: Request | None, exc_raw: BaseException | RayTaskError
):
def custom_exception_handler(request: Request | None, exc_raw: Exception):
"""This handler is used to handle custom exceptions raised in the application.
BaseException is the base exception for all the exceptions
Expand All @@ -43,7 +41,7 @@ def custom_exception_handler(
Args:
request (Request): The request object
exc_raw (Union[BaseException, RayTaskError]): The exception raised
exc_raw (Exception): The exception raised
Returns:
JSONResponse: JSON response with the error details. The response contains the following fields:
Expand Down
4 changes: 2 additions & 2 deletions aana/api/responses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pathlib import PosixPath
from pathlib import Path
from typing import Any

import orjson
Expand Down Expand Up @@ -28,7 +28,7 @@ def json_serializer_default(obj: Any) -> Any:
"""
if isinstance(obj, BaseModel):
return obj.dict()
if isinstance(obj, PosixPath):
if isinstance(obj, Path):
return str(obj)
raise TypeError

Expand Down
1 change: 0 additions & 1 deletion aana/configs/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ class Settings(BaseSettings):
"""A pydantic model for SDK settings."""

tmp_data_dir: Path = Path("/tmp/aana_data") # noqa: S108
youtube_video_dir = tmp_data_dir / "youtube_videos"
image_dir = tmp_data_dir / "images"
video_dir = tmp_data_dir / "videos"

Expand Down
1 change: 1 addition & 0 deletions aana/exceptions/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self, table_name: str, id: int | media_id_type): # noqa: A002
super().__init__(table=table_name, id=id)
self.table_name = table_name
self.id = id
self.http_status_code = 404

def __reduce__(self):
"""Used for pickling."""
Expand Down
6 changes: 4 additions & 2 deletions aana/exceptions/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,20 @@ class DownloadException(BaseException):
url (str): the URL of the file that caused the exception
"""

def __init__(self, url: str):
def __init__(self, url: str, msg: str = ""):
"""Initialize the exception.
Args:
url (str): the URL of the file that caused the exception
msg (str): the error message
"""
super().__init__(url=url)
self.url = url
self.msg = msg

def __reduce__(self):
"""Used for pickling."""
return (self.__class__, (self.url,))
return (self.__class__, (self.url, self.msg))


class VideoException(BaseException):
Expand Down
2 changes: 1 addition & 1 deletion aana/models/core/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def save(self):
raise ValueError( # noqa: TRY003
"At least one of 'path', 'url', or 'content' must be provided."
)
self.path = file_path
self.is_saved = True

def save_from_bytes(self, file_path: Path, content: bytes):
Expand All @@ -89,6 +88,7 @@ def save_from_bytes(self, file_path: Path, content: bytes):
content (bytes): the content of the media
"""
file_path.write_bytes(content)
self.path = file_path

def save_from_content(self, file_path: Path):
"""Save the media from the content.
Expand Down
45 changes: 43 additions & 2 deletions aana/models/core/video.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import hashlib
import hashlib # noqa: I001
from dataclasses import dataclass
from pathlib import Path
import torch, decord # noqa: F401 # See https://github.com/dmlc/decord/issues/263
from decord import DECORDError

from aana.configs.settings import settings
from aana.exceptions.general import VideoReadingException
from aana.models.core.media import Media


Expand All @@ -28,7 +31,12 @@ class Video(Media):
media_dir: Path | None = settings.video_dir

def validate(self):
"""Validate the video."""
"""Validate the video.
Raises:
ValueError: if none of 'path', 'url', or 'content' is provided
VideoReadingException: if the video is not valid
"""
# validate the parent class
super().validate()

Expand All @@ -44,6 +52,39 @@ def validate(self):
"At least one of 'path', 'url' or 'content' must be provided."
)

# check that the video is valid
if self.path and not self.is_video():
raise VideoReadingException(video=self)

def is_video(self) -> bool:
"""Checks if it's a valid video."""
if not self.path:
return False

try:
decord.VideoReader(str(self.path))
except DECORDError:
try:
decord.AudioReader(str(self.path))
except DECORDError:
return False
return True

def save_from_url(self, file_path):
"""Save the media from the URL.
Args:
file_path (Path): the path to save the media to
Raises:
DownloadError: if the media can't be downloaded
VideoReadingException: if the media is not a valid video
"""
super().save_from_url(file_path)
# check that the file is a video
if not self.is_video():
raise VideoReadingException(video=self)

def __repr__(self) -> str:
"""Get the representation of the video.
Expand Down
35 changes: 0 additions & 35 deletions aana/models/core/video_source.py

This file was deleted.

4 changes: 2 additions & 2 deletions aana/tests/test_frame_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_extract_frames_failure():
# image file instead of video file will create Video object
# but will fail in extract_frames_decord
path = resources.path("aana.tests.files.images", "Starry_Night.jpeg")
invalid_video = Video(path=path)
params = VideoParams(extract_fps=1.0, fast_mode_enabled=False)
with pytest.raises(VideoReadingException):
invalid_video = Video(path=path)
params = VideoParams(extract_fps=1.0, fast_mode_enabled=False)
extract_frames_decord(video=invalid_video, params=params)
46 changes: 23 additions & 23 deletions aana/tests/test_video.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,18 @@
# ruff: noqa: S101
import hashlib
from importlib import resources
from pathlib import Path

import pytest

from aana.configs.settings import settings
from aana.exceptions.general import DownloadException
from aana.exceptions.general import DownloadException, VideoReadingException
from aana.models.core.video import Video
from aana.models.pydantic.video_input import VideoInput
from aana.utils.video import download_video


@pytest.fixture
def mock_download_file(mocker):
"""Mock download_file."""
mock = mocker.patch("aana.models.core.media.download_file", autospec=True)
path = resources.path("aana.tests.files.videos", "squirrel.mp4")
content = path.read_bytes()
mock.return_value = content
return mock


def test_video(mock_download_file):
def test_video():
"""Test that the video can be created from path, url, or content."""
# Test creation from a path
try:
Expand All @@ -38,7 +29,7 @@ def test_video(mock_download_file):

# Test creation from a URL
try:
url = "http://example.com/squirrel.mp4"
url = "https://mobius-public.s3.eu-west-1.amazonaws.com/squirrel.mp4"
video = Video(url=url, save_on_disk=False)
assert video.path is None
assert video.content is None
Expand Down Expand Up @@ -68,7 +59,7 @@ def test_media_dir():
# Test saving from URL to disk
video_dir = settings.video_dir
try:
url = "http://example.com/squirrel.mp4"
url = "https://mobius-public.s3.eu-west-1.amazonaws.com/squirrel.mp4"
video = Video(url=url, save_on_disk=True)
assert video.media_dir == video_dir
assert video.content is None
Expand All @@ -86,7 +77,7 @@ def test_video_path_not_exist():
Video(path=path)


def test_save_video(mock_download_file):
def test_save_video():
"""Test that save_on_disk works for video."""
# Test that the video is saved to disk when save_on_disk is True
try:
Expand All @@ -102,7 +93,7 @@ def test_save_video(mock_download_file):

# Test saving from URL to disk
try:
url = "http://example.com/squirrel.mp4"
url = "https://mobius-public.s3.eu-west-1.amazonaws.com/squirrel.mp4"
video = Video(url=url, save_on_disk=True)
assert video.content is None
assert video.url == url
Expand All @@ -122,10 +113,10 @@ def test_save_video(mock_download_file):
video.cleanup()


def test_cleanup(mock_download_file):
def test_cleanup():
"""Test that cleanup works for video."""
try:
url = "http://example.com/squirrel.mp4"
url = "https://mobius-public.s3.eu-west-1.amazonaws.com/squirrel.mp4"
video = Video(url=url, save_on_disk=True)
assert video.path.exists()
finally:
Expand All @@ -152,7 +143,7 @@ def test_at_least_one_input():
Video(save_on_disk=True)


def test_download_video(mock_download_file):
def test_download_video():
"""Test download_video."""
# Test VideoInput with path
path = resources.path("aana.tests.files.videos", "squirrel.mp4")
Expand All @@ -165,7 +156,7 @@ def test_download_video(mock_download_file):

# Test VideoInput with url
try:
url = "http://example.com/squirrel.mp4"
url = "https://mobius-public.s3.eu-west-1.amazonaws.com/squirrel.mp4"
video_input = VideoInput(url=url)
video = download_video(video_input)
assert isinstance(video, Video)
Expand All @@ -179,8 +170,11 @@ def test_download_video(mock_download_file):

# Test Youtube URL
youtube_url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
youtube_video_dir = settings.youtube_video_dir
expected_path = youtube_video_dir / "dQw4w9WgXcQ.mp4"
youtube_url_hash = hashlib.md5(
youtube_url.encode(), usedforsecurity=False
).hexdigest()
video_dir = settings.video_dir
expected_path = video_dir / f"{youtube_url_hash}.webm"
# remove the file if it exists
expected_path.unlink(missing_ok=True)

Expand All @@ -192,7 +186,7 @@ def test_download_video(mock_download_file):
assert video.path is not None
assert video.path.exists()
assert video.content is None
assert video.url is None
assert video.url == youtube_url
assert video.media_id == "dQw4w9WgXcQ"
assert (
video.title
Expand All @@ -211,6 +205,12 @@ def test_download_video(mock_download_file):
with pytest.raises(DownloadException):
download_video(youtube_video_input)

# Test url that doesn't contain a video
url = "https://mobius-public.s3.eu-west-1.amazonaws.com/Starry_Night.jpeg"
video_input = VideoInput(url=url)
with pytest.raises(VideoReadingException):
download_video(video_input)

# Test Video object as input
path = resources.path("aana.tests.files.videos", "squirrel.mp4")
video = Video(path=path)
Expand Down
Loading

0 comments on commit 3805829

Please sign in to comment.