Skip to content
This repository has been archived by the owner on Sep 6, 2024. It is now read-only.

Improve tvdb more #40

Merged
merged 9 commits into from
Jul 21, 2024
81 changes: 54 additions & 27 deletions src/exts/tvdb_info/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import Literal

import aiohttp
Expand All @@ -6,7 +7,8 @@

from src.bot import Bot
from src.settings import THETVDB_COPYRIGHT_FOOTER, THETVDB_LOGO
from src.tvdb import Movie, Series, TvdbClient
from src.tvdb import FetchMeta, Movie, Series, TvdbClient
from src.tvdb.errors import InvalidIdError
from src.utils.log import get_logger

log = get_logger(__name__)
Expand All @@ -18,22 +20,23 @@
class InfoView(discord.ui.View):
"""View for displaying information about a movie or series."""

def __init__(self, results: list[Movie | Series]):
def __init__(self, results: Sequence[Movie | Series]) -> None:
super().__init__(disable_on_timeout=True)
self.results = results
self.dropdown = discord.ui.Select(
placeholder="Not what you're looking for? Select a different result.",
options=[
discord.SelectOption(
label=result.bilingual_name or "",
value=str(i),
description=result.overview[:100] if result.overview else None,
)
for i, result in enumerate(self.results)
],
)
self.dropdown.callback = self._dropdown_callback
self.add_item(self.dropdown)
if len(self.results) > 1:
self.dropdown = discord.ui.Select(
placeholder="Not what you're looking for? Select a different result.",
options=[
discord.SelectOption(
label=result.bilingual_name or "",
value=str(i),
description=result.overview[:100] if result.overview else None,
)
for i, result in enumerate(self.results)
],
)
self.dropdown.callback = self._dropdown_callback
self.add_item(self.dropdown)
self.index = 0

def _get_embed(self) -> discord.Embed:
Expand Down Expand Up @@ -87,24 +90,48 @@ def __init__(self, bot: Bot) -> None:
choices=["movie", "series"],
required=False,
)
@option("by_id", input_type=bool, description="Search by tvdb ID.", required=False)
async def search(
self, ctx: ApplicationContext, query: str, entity_type: Literal["movie", "series"] | None = None
self,
ctx: ApplicationContext,
*,
query: str,
entity_type: Literal["movie", "series"] | None = None,
by_id: bool = False,
) -> None:
"""Search for a movie or series."""
await ctx.defer()
async with aiohttp.ClientSession() as session:
client = TvdbClient(session)
match entity_type:
case "movie":
response = await client.search(query, limit=5, entity_type="movie")
case "series":
response = await client.search(query, limit=5, entity_type="series")
case None:
response = await client.search(query, limit=5)

if not response:
await ctx.respond("No results found.")
return
if by_id:
if query.startswith("movie-"):
entity_type = "movie"
query = query[6:]
elif query.startswith("series-"):
entity_type = "series"
query = query[7:]
try:
match entity_type:
case "movie":
response = [await Movie.fetch(query, client, extended=True, meta=FetchMeta.TRANSLATIONS)]
case "series":
response = [await Series.fetch(query, client, extended=True, meta=FetchMeta.TRANSLATIONS)]
case None:
await ctx.respond(
"You must specify a type (movie or series) when searching by ID.", ephemeral=True
)
return
except InvalidIdError:
await ctx.respond(
'Invalid ID. Id must be an integer, or "movie-" / "series-" followed by an integer.',
ephemeral=True,
)
return
else:
response = await client.search(query, limit=5, entity_type=entity_type)
if not response:
await ctx.respond("No results found.")
return
view = InfoView(response)
await view.send(ctx)

Expand Down
10 changes: 3 additions & 7 deletions src/tvdb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
from .client import InvalidApiKeyError, Movie, Series, TvdbClient
from .client import FetchMeta, Movie, Series, TvdbClient
from .errors import InvalidApiKeyError

__all__ = [
"TvdbClient",
"InvalidApiKeyError",
"Movie",
"Series",
]
__all__ = ["TvdbClient", "InvalidApiKeyError", "Movie", "Series", "FetchMeta"]
162 changes: 121 additions & 41 deletions src/tvdb/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import ClassVar, Literal, final, overload, override
from enum import Enum
from typing import ClassVar, Literal, Self, final, overload, override

import aiohttp
from yarl import URL
Expand All @@ -8,19 +9,19 @@
from src.tvdb.generated_models import (
MovieBaseRecord,
MovieExtendedRecord,
MoviesIdExtendedGetResponse,
MoviesIdGetResponse,
SearchGetResponse,
SearchResult,
SeriesBaseRecord,
SeriesExtendedRecord,
)
from src.tvdb.models import (
MovieExtendedResponse,
MovieResponse,
SearchResponse,
SeriesExtendedResponse,
SeriesResponse,
SeriesIdExtendedGetResponse,
SeriesIdGetResponse,
)
from src.utils.log import get_logger

from .errors import BadCallError, InvalidApiKeyError, InvalidIdError

log = get_logger(__name__)

type JSON_DATA = dict[str, JSON_DATA] | list[JSON_DATA] | str | int | float | bool | None # noice
Expand All @@ -30,13 +31,32 @@
type AnyRecord = SeriesRecord | MovieRecord


class FetchMeta(Enum):
"""When calling fetch with extended=True, this is used if we want to fetch translations or episodes as well."""

TRANSLATIONS = "translations"
EPISODES = "episodes"


def parse_media_id(media_id: int | str) -> int:
"""Parse the media ID from a string."""
return int(str(media_id).removeprefix("movie-").removeprefix("series-"))
try:
media_id = int(str(media_id).removeprefix("movie-").removeprefix("series-"))
except ValueError:
raise InvalidIdError("Invalid media ID.")
else:
return media_id


class _Media(ABC):
def __init__(self, client: "TvdbClient", data: SeriesRecord | MovieRecord | SearchResult):
ENDPOINT: ClassVar[str]

ResponseType: ClassVar[type[MoviesIdGetResponse | SeriesIdGetResponse]]
ExtendedResponseType: ClassVar[type[MoviesIdExtendedGetResponse | SeriesIdExtendedGetResponse]]

def __init__(self, client: "TvdbClient", data: AnyRecord | SearchResult | None):
if data is None:
raise ValueError("Data can't be None but is allowed to because of the broken pydantic generated models.")
self.data = data
self.client = client
self.name: str | None = self.data.name
Expand Down Expand Up @@ -97,56 +117,105 @@ def id(self, value: int | str) -> None: # pyright: ignore[reportPropertyTypeMis

@classmethod
@abstractmethod
async def fetch(cls, media_id: int | str, *, client: "TvdbClient", extended: bool = False) -> "_Media": ...
def supports_meta(cls, meta: FetchMeta) -> bool:
"""Check if the class supports a specific meta."""
...

@classmethod
@overload
async def fetch(
cls,
media_id: int | str,
client: "TvdbClient",
*,
extended: Literal[False],
short: Literal[False] | None = None,
meta: None = None,
) -> Self: ...

@final
class Movie(_Media):
"""Class to interact with the TVDB API for movies."""
@classmethod
@overload
async def fetch(
cls,
media_id: int | str,
client: "TvdbClient",
*,
extended: Literal[True],
short: bool | None = None,
meta: FetchMeta | None = None,
) -> Self: ...

@override
@classmethod
async def fetch(cls, media_id: int | str, client: "TvdbClient", *, extended: bool = False) -> "Movie":
async def fetch(
cls,
media_id: int | str,
client: "TvdbClient",
*,
extended: bool = False,
short: bool | None = None,
meta: FetchMeta | None = None,
) -> Self:
"""Fetch a movie by its ID.

:param media_id: The ID of the movie.
:param client: The TVDB client to use.
:param extended: Whether to fetch extended information.
:param short: Whether to omit characters and artworks from the response. Requires extended=True to work.
:param meta: The meta to fetch. Requires extended=True to work.
:return:
"""
media_id = parse_media_id(media_id)
response = await client.request("GET", f"movies/{media_id}" + ("/extended" if extended else ""))
response = MovieResponse(**response) if not extended else MovieExtendedResponse(**response) # pyright: ignore[reportCallIssue]
query: dict[str, str] = {}
if extended:
if meta:
query["meta"] = meta.value
if short:
query["short"] = "true"
else:
query["short"] = "false"
elif meta:
raise BadCallError("Meta can only be used with extended=True.")
elif short:
raise BadCallError("Short can only be enabled with extended=True.")
response = await client.request(
"GET",
f"{cls.ENDPOINT}/{media_id}" + ("/extended" if extended else ""),
query=query if query else None,
)
response = cls.ResponseType(**response) if not extended else cls.ExtendedResponseType(**response) # pyright: ignore[reportCallIssue]
return cls(client, response.data)


@final
class Series(_Media):
"""Class to interact with the TVDB API for series."""
class Movie(_Media):
"""Class to interact with the TVDB API for movies."""

ENDPOINT: ClassVar[str] = "movies"

ResponseType = MoviesIdGetResponse
ExtendedResponseType = MoviesIdExtendedGetResponse

@override
@classmethod
async def fetch(cls, media_id: int | str, client: "TvdbClient", *, extended: bool = False) -> "Series":
"""Fetch a series by its ID.
async def supports_meta(cls, meta: FetchMeta) -> bool:
"""Check if the class supports a specific meta."""
return meta is FetchMeta.TRANSLATIONS

:param media_id: The ID of the series.
:param client: The TVDB client to use.
:param extended: Whether to fetch extended information.
:return:
"""
media_id = parse_media_id(media_id)
response = await client.request("GET", f"series/{media_id}" + ("/extended" if extended else ""))
response = SeriesResponse(**response) if not extended else SeriesExtendedResponse(**response) # pyright: ignore[reportCallIssue]
return cls(client, response.data)

@final
class Series(_Media):
"""Class to interact with the TVDB API for series."""

ENDPOINT: ClassVar[str] = "series"

class InvalidApiKeyError(Exception):
"""Exception raised when the TVDB API key used was invalid."""
ResponseType = SeriesIdGetResponse
ExtendedResponseType = SeriesIdExtendedGetResponse

def __init__(self, response: aiohttp.ClientResponse, response_txt: str):
self.response = response
self.response_txt = response_txt
super().__init__("Invalid TVDB API key.")
@override
@classmethod
async def supports_meta(cls, meta: FetchMeta) -> bool:
"""Check if the class supports a specific meta."""
return meta in {FetchMeta.TRANSLATIONS, FetchMeta.EPISODES}


class TvdbClient:
Expand Down Expand Up @@ -200,15 +269,26 @@ async def request(
return await response.json()

async def search(
self, search_query: str, entity_type: Literal["series", "movie", "all"] = "all", limit: int = 1
self, search_query: str, entity_type: Literal["series", "movie", None] = None, limit: int = 1
) -> list[Movie | Series]:
"""Search for a series or movie in the TVDB database."""
query: dict[str, str] = {"query": search_query, "limit": str(limit)}
if entity_type != "all":
if entity_type:
query["type"] = entity_type
data = await self.request("GET", "search", query=query)
response = SearchResponse(**data) # pyright: ignore[reportCallIssue]
return [Movie(self, result) if result.id[0] == "m" else Series(self, result) for result in response.data]
response = SearchGetResponse(**data) # pyright: ignore[reportCallIssue]
if not response.data:
raise ValueError("This should not happen.")
returnable: list[Movie | Series] = []
for result in response.data:
match result.type:
case "movie":
returnable.append(Movie(self, result))
case "series":
returnable.append(Series(self, result))
case _:
pass
return returnable

async def _login(self) -> None:
"""Obtain the auth token from the TVDB API.
Expand Down
22 changes: 22 additions & 0 deletions src/tvdb/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import aiohttp


class TVDBError(Exception):
"""The base exception for all TVDB errors."""


class BadCallError(TVDBError):
"""Exception raised when the meta value is incompatible with the class."""


class InvalidIdError(TVDBError):
"""Exception raised when the ID provided is invalid."""


class InvalidApiKeyError(TVDBError):
"""Exception raised when the TVDB API key used was invalid."""

def __init__(self, response: aiohttp.ClientResponse, response_txt: str):
self.response = response
self.response_txt = response_txt
super().__init__("Invalid TVDB API key.")
Loading
Loading