Skip to content

Commit

Permalink
Allow additional parameters for youtube
Browse files Browse the repository at this point in the history
and improved errors for invalid video URL
  • Loading branch information
Aleksandr Movchan committed Dec 15, 2023
1 parent a597569 commit c33c956
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 38 deletions.
16 changes: 1 addition & 15 deletions aana/api/api_generation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import traceback
from collections.abc import AsyncGenerator, Callable
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import StreamingResponse
from mobius_pipeline.exceptions import BaseException
from mobius_pipeline.node.socket import Socket
from mobius_pipeline.pipeline.pipeline import Pipeline
from pydantic import BaseModel, Field, ValidationError, create_model, parse_raw_as
from ray.exceptions import RayTaskError

from aana.api.app import custom_exception_handler
from aana.api.responses import AanaJSONResponse
Expand Down Expand Up @@ -389,19 +386,8 @@ async def generator_wrapper() -> AsyncGenerator[bytes, None]:
):
output = self.process_output(output)
yield AanaJSONResponse(content=output).body
except RayTaskError as e:
yield custom_exception_handler(None, e).body
except BaseException as e:
yield custom_exception_handler(None, e)
except Exception as e:
error = e.__class__.__name__
stacktrace = traceback.format_exc()
yield AanaJSONResponse(
status_code=400,
content=ExceptionResponseModel(
error=error, message=str(e), stacktrace=stacktrace
).dict(),
).body
yield custom_exception_handler(None, e).body

return StreamingResponse(
generator_wrapper(), media_type="application/json"
Expand Down
10 changes: 3 additions & 7 deletions aana/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ async def validation_exception_handler(request: Request, exc: ValidationError):
)


def custom_exception_handler(
request: Request | None, exc_raw: BaseException | RayTaskError
):
def custom_exception_handler(request: Request | None, exc_raw: Exception):
"""This handler is used to handle custom exceptions raised in the application.
BaseException is the base exception for all the exceptions
Expand All @@ -43,7 +41,7 @@ def custom_exception_handler(
Args:
request (Request): The request object
exc_raw (Union[BaseException, RayTaskError]): The exception raised
exc_raw (Exception): The exception raised
Returns:
JSONResponse: JSON response with the error details. The response contains the following fields:
Expand All @@ -60,8 +58,6 @@ def custom_exception_handler(
stacktrace = str(exc_raw)
# get the original exception
exc: BaseException = exc_raw.cause
if not isinstance(exc, BaseException):
raise TypeError(exc)
else:
# if it is not a RayTaskError
# then we need to get the stack trace
Expand All @@ -70,7 +66,7 @@ def custom_exception_handler(
# get the data from the exception
# can be used to return additional info
# like image path, url, model name etc.
data = exc.get_data()
data = exc.get_data() if isinstance(exc, BaseException) else {}
# get the name of the class of the exception
# can be used to identify the type of the error
error = exc.__class__.__name__
Expand Down
3 changes: 3 additions & 0 deletions aana/api/responses.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Any

import orjson
Expand Down Expand Up @@ -27,6 +28,8 @@ def json_serializer_default(obj: Any) -> Any:
"""
if isinstance(obj, BaseModel):
return obj.dict()
if isinstance(obj, Path):
return str(obj)
raise TypeError


Expand Down
2 changes: 1 addition & 1 deletion aana/models/core/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def save(self):
raise ValueError( # noqa: TRY003
"At least one of 'path', 'url', or 'content' must be provided."
)
self.path = file_path
self.is_saved = True

def save_from_bytes(self, file_path: Path, content: bytes):
Expand All @@ -89,6 +88,7 @@ def save_from_bytes(self, file_path: Path, content: bytes):
content (bytes): the content of the media
"""
file_path.write_bytes(content)
self.path = file_path

def save_from_content(self, file_path: Path):
"""Save the media from the content.
Expand Down
28 changes: 27 additions & 1 deletion aana/models/core/video.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import hashlib
import hashlib # noqa: I001
from dataclasses import dataclass
from pathlib import Path
import torch, decord # noqa: F401 # See https://github.com/dmlc/decord/issues/263

from aana.configs.settings import settings
from aana.exceptions.general import VideoReadingException
from aana.models.core.media import Media


Expand Down Expand Up @@ -44,6 +46,30 @@ def validate(self):
"At least one of 'path', 'url' or 'content' must be provided."
)

def is_video(self) -> bool:
"""Checks if it's a valid video."""
if not self.path:
return False
try:
decord.VideoReader(str(self.path))
except Exception:
return False
return True

def save_from_url(self, file_path):
"""Save the media from the URL.
Args:
file_path (Path): the path to save the media to
Raises:
DownloadError: if the media can't be downloaded
"""
super().save_from_url(file_path)
# check that the file is a video
if not self.is_video():
raise VideoReadingException(video=self)

def __repr__(self) -> str:
"""Get the representation of the video.
Expand Down
2 changes: 1 addition & 1 deletion aana/models/core/video_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def from_url(cls, url: str) -> "VideoSource":
"""
# TODO: Check that the URL is valid

youtube_pattern = r"^(?:https?:\/\/)?(?:www\.)?(?:youtube\.com\/watch\?v=|youtube\.[a-zA-Z]{2,3}(\.[a-zA-Z]{2})?\/watch\?v=|youtu\.be\/)([a-zA-Z0-9_-]+)$"
youtube_pattern = r"^(?:https?:\/\/)?(?:www\.)?(?:youtube\.com\/watch\?v=|youtube\.[a-zA-Z]{2,3}(\.[a-zA-Z]{2})?\/watch\?v=|youtu\.be\/)([a-zA-Z0-9_-]+)(?:&[^\s]+)*$"

if re.match(youtube_pattern, url):
return cls.YOUTUBE
Expand Down
11 changes: 4 additions & 7 deletions aana/tests/test_chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,10 @@ def test_chat_template_custom():
prompt = apply_chat_template(
tokenizer, dialog, "llama2"
) # Apply custom chat template "llama2"
assert ( # noqa: S101
prompt
== (
"<s>[INST] <<SYS>>\\nYou are a friendly chatbot who always responds in the style "
"of a pirate\\n<</SYS>>\\n\\nHow many helicopters can a human eat in one sitting? "
"[/INST] I don't know, how many? </s><s>[INST] One, but only if they're really hungry! [/INST]"
)
assert prompt == (
"<s>[INST] <<SYS>>\\nYou are a friendly chatbot who always responds in the style "
"of a pirate\\n<</SYS>>\\n\\nHow many helicopters can a human eat in one sitting? "
"[/INST] I don't know, how many? </s><s>[INST] One, but only if they're really hungry! [/INST]"
)


Expand Down
27 changes: 21 additions & 6 deletions aana/tests/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,28 @@
import pytest

from aana.configs.settings import settings
from aana.exceptions.general import DownloadException
from aana.exceptions.general import DownloadException, VideoReadingException
from aana.models.core.video import Video
from aana.models.pydantic.video_input import VideoInput
from aana.utils.video import download_video


def mocked_download_file(url: str) -> bytes:
"""Mock download_file to return different content based on URL."""
if url == "http://example.com/squirrel.mp4":
path = resources.path("aana.tests.files.videos", "squirrel.mp4")
elif url == "http://example.com/Starry_Night.jpeg":
path = resources.path("aana.tests.files.images", "Starry_Night.jpeg")
else:
raise DownloadException(url)
return path.read_bytes()


@pytest.fixture
def mock_download_file(mocker):
"""Mock download_file."""
"""Mock download_file to return different content based on URL."""
mock = mocker.patch("aana.models.core.media.download_file", autospec=True)
path = resources.path("aana.tests.files.videos", "squirrel.mp4")
content = path.read_bytes()
mock.return_value = content
mock.side_effect = mocked_download_file
return mock


Expand Down Expand Up @@ -63,7 +72,7 @@ def test_video(mock_download_file):
video.cleanup()


def test_media_dir():
def test_media_dir(mock_download_file):
"""Test that the media_dir is set correctly."""
# Test saving from URL to disk
video_dir = settings.video_dir
Expand Down Expand Up @@ -211,6 +220,12 @@ def test_download_video(mock_download_file):
with pytest.raises(DownloadException):
download_video(youtube_video_input)

# Test url that doesn't contain a video
url = "http://example.com/Starry_Night.jpeg"
video_input = VideoInput(url=url)
with pytest.raises(VideoReadingException):
download_video(video_input)

# Test Video object as input
path = resources.path("aana.tests.files.videos", "squirrel.mp4")
video = Video(path=path)
Expand Down
18 changes: 18 additions & 0 deletions aana/tests/test_video_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ def test_video_source_from_url():
"https://www.youtube.co.uk/watch?v=yModCU1O",
"https://www.youtube.com/watch?v=18pCXD709TI",
"https://www.youtube.com/watch?v=18pCXD7",
"https://youtube.com/watch?v=yModCU1OVHY&t=4s&ab_channel=A24",
"http://youtube.com/watch?v=yModCU1OVHY&t=4s&ab_channel=A24",
"https://www.youtube.com/watch?v=yModCU1OVHY&t=4s&ab_channel=A24",
"http://www.youtube.com/watch?v=yModCU1OVHY&t=4s&ab_channel=A24",
"www.youtube.com/watch?v=yModCU1OVHY&t=4s&ab_channel=A24",
"youtube.com/watch?v=yModCU1OVHY&t=4s&ab_channel=A24",
"https://youtube.de/watch?v=yModCU1OVHY&t=4s&ab_channel=A24",
"http://youtube.de/watch?v=yModCU1OVHY&t=4s&ab_channel=A24",
"https://www.youtube.de/watch?v=yModCU1OVHY&t=4s&ab_channel=A24",
"http://www.youtube.de/watch?v=yModCU1OVHY&t=4s&ab_channel=A24",
"www.youtube.de/watch?v=yModCU1OVHY&t=4s&ab_channel=A24",
"youtube.de/watch?v=yModCU1OVHY&t=4s&ab_channel=A24",
"https://youtu.be/yModCU1OVHY&t=4s&ab_channel=A24",
"http://youtu.be/yModCU1OVHY&t=4s&ab_channel=A24",
"https://www.youtu.be/yModCU1OVHY&t=4s&ab_channel=A24",
"http://www.youtu.be/yModCU1OVHY&t=4s&ab_channel=A24",
"www.youtu.be/yModCU1OVHY&t=4s&ab_channel=A24",
"youtu.be/yModCU1OVHY&t=4s&ab_channel=A24",
]

not_youtube_urls = [
Expand Down

0 comments on commit c33c956

Please sign in to comment.