Skip to content
This repository has been archived by the owner on Sep 6, 2024. It is now read-only.

Prevent interaction with views by others #74

Merged
merged 1 commit into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/exts/tvdb_info/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal, cast
from typing import Literal

from discord import ApplicationContext, Cog, Member, User, option, slash_command

Expand Down Expand Up @@ -31,7 +31,7 @@ async def profile(self, ctx: ApplicationContext, *, user: User | Member | None =
await ctx.defer()

if user is None:
user = cast(User | Member, ctx.user) # for some reason, pyright thinks user can be None here
user = ctx.author

# Convert Member to User (Member isn't a subclass of User...)
if isinstance(user, Member):
Expand All @@ -46,6 +46,7 @@ async def profile(self, ctx: ApplicationContext, *, user: User | Member | None =
bot=self.bot,
tvdb_client=self.tvdb_client,
user=user,
invoker_user_id=ctx.author.id,
watched_list=await user_get_list_safe(self.bot.db_session, db_user, "watched"),
favorite_list=await user_get_list_safe(self.bot.db_session, db_user, "favorite"),
)
Expand Down Expand Up @@ -114,7 +115,7 @@ async def search(
await ctx.respond("No results found.")
return

view = await search_view(self.bot, ctx.user.id, response)
view = await search_view(self.bot, ctx.user.id, ctx.user.id, response)
await view.send(ctx.interaction)


Expand Down
27 changes: 26 additions & 1 deletion src/exts/tvdb_info/ui/_media_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@
class MediaView(ErrorHandledView, ABC):
"""Base class for views that display info about some media (movie/series/episode)."""

def __init__(self, *, bot: Bot, user_id: int, watched_list: UserList, favorite_list: UserList) -> None:
def __init__(
self,
*,
bot: Bot,
user_id: int,
invoker_user_id: int,
watched_list: UserList,
favorite_list: UserList,
) -> None:
"""Initialize MediaView.

:param bot: The bot instance.
Expand All @@ -30,6 +38,7 @@ def __init__(self, *, bot: Bot, user_id: int, watched_list: UserList, favorite_l

self.bot = bot
self.user_id = user_id
self.invoker_user_id = invoker_user_id
self.watched_list = watched_list
self.favorite_list = favorite_list

Expand Down Expand Up @@ -108,6 +117,16 @@ async def _refresh(self) -> None:

await self.message.edit(embed=self._get_embed(), view=self)

async def _ensure_correct_invoker(self, interaction: discord.Interaction) -> bool:
"""Ensure that the interaction was invoked by the author of this view."""
if interaction.user is None:
raise ValueError("Interaction user is None")

if interaction.user.id != self.invoker_user_id:
await interaction.response.send_message("You can't interact with this view.", ephemeral=True)
return False
return True

@abstractmethod
async def is_favorite(self) -> bool:
"""Check if the current media is marked as favorite by the user.
Expand Down Expand Up @@ -147,6 +166,9 @@ async def send(self, interaction: discord.Interaction) -> None:

async def _watched_button_callback(self, interaction: discord.Interaction) -> None:
"""Callback for when the user clicks on the mark as watched button."""
if not await self._ensure_correct_invoker(interaction):
return

await interaction.response.defer()
cur_state = self.watched_button.state
await self.set_watched(not cur_state)
Expand All @@ -156,6 +178,9 @@ async def _watched_button_callback(self, interaction: discord.Interaction) -> No

async def _favorite_button_callback(self, interaction: discord.Interaction) -> None:
"""Callback for when the user clicks on the mark as favorite button."""
if not await self._ensure_correct_invoker(interaction):
return

await interaction.response.defer()
cur_state = self.favorite_button.state
await self.set_favorite(not cur_state)
Expand Down
15 changes: 14 additions & 1 deletion src/exts/tvdb_info/ui/episode_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,20 @@ def __init__(
*,
bot: Bot,
user_id: int,
invoker_user_id: int,
watched_list: UserList,
favorite_list: UserList,
series: Series,
season_idx: int = 1,
episode_idx: int = 1,
) -> None:
super().__init__(bot=bot, user_id=user_id, watched_list=watched_list, favorite_list=favorite_list)
super().__init__(
bot=bot,
user_id=user_id,
invoker_user_id=invoker_user_id,
watched_list=watched_list,
favorite_list=favorite_list,
)

self.series = series

Expand Down Expand Up @@ -156,6 +163,9 @@ async def set_watched(self, state: bool) -> None:

async def _episode_dropdown_callback(self, interaction: discord.Interaction) -> None:
"""Callback for when the user selects an episode from the drop-down."""
if not await self._ensure_correct_invoker(interaction):
return

if not self.episode_dropdown.values or not isinstance(self.episode_dropdown.values[0], str):
raise ValueError("Episode dropdown values are empty or non-string, but callback was triggered.")

Expand All @@ -171,6 +181,9 @@ async def _episode_dropdown_callback(self, interaction: discord.Interaction) ->

async def _season_dropdown_callback(self, interaction: discord.Interaction) -> None:
"""Callback for when the user selects a season from the drop-down."""
if not await self._ensure_correct_invoker(interaction):
return

if not self.season_dropdown.values or not isinstance(self.season_dropdown.values[0], str):
raise ValueError("Episode dropdown values are empty or non-string, but callback was triggered.")

Expand Down
17 changes: 16 additions & 1 deletion src/exts/tvdb_info/ui/movie_series_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,18 @@ def __init__(
*,
bot: Bot,
user_id: int,
invoker_user_id: int,
watched_list: UserList,
favorite_list: UserList,
media_data: Movie | Series,
) -> None:
super().__init__(bot=bot, user_id=user_id, watched_list=watched_list, favorite_list=favorite_list)
super().__init__(
bot=bot,
user_id=user_id,
invoker_user_id=invoker_user_id,
watched_list=watched_list,
favorite_list=favorite_list,
)

self.media_data = media_data

Expand Down Expand Up @@ -140,13 +147,15 @@ def __init__(
*,
bot: Bot,
user_id: int,
invoker_user_id: int,
watched_list: UserList,
favorite_list: UserList,
media_data: Series,
) -> None:
super().__init__(
bot=bot,
user_id=user_id,
invoker_user_id=invoker_user_id,
watched_list=watched_list,
favorite_list=favorite_list,
media_data=media_data,
Expand All @@ -172,9 +181,13 @@ def _add_items(self) -> None:

async def _episodes_button_callback(self, interaction: discord.Interaction) -> None:
"""Callback for when the user clicks the "View Episodes" button."""
if not await self._ensure_correct_invoker(interaction):
return

view = EpisodeView(
bot=self.bot,
user_id=self.user_id,
invoker_user_id=self.invoker_user_id,
watched_list=self.watched_list,
favorite_list=self.favorite_list,
series=self.media_data,
Expand Down Expand Up @@ -250,13 +263,15 @@ def __init__(
*,
bot: Bot,
user_id: int,
invoker_user_id: int,
watched_list: UserList,
favorite_list: UserList,
media_data: Movie,
) -> None:
super().__init__(
bot=bot,
user_id=user_id,
invoker_user_id=invoker_user_id,
watched_list=watched_list,
favorite_list=favorite_list,
media_data=media_data,
Expand Down
12 changes: 12 additions & 0 deletions src/exts/tvdb_info/ui/profile_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@ def __init__(
bot: Bot,
tvdb_client: TvdbClient,
user: discord.User,
invoker_user_id: int,
watched_list: UserList,
favorite_list: UserList,
) -> None:
super().__init__()
self.bot = bot
self.tvdb_client = tvdb_client
self.discord_user = user
self.invoker_user_id = invoker_user_id
self.watched_list = watched_list
self.favorite_list = favorite_list

Expand Down Expand Up @@ -151,6 +153,16 @@ async def _initialize(self) -> None:
# as that's the only thing we need here and while it is a bit inconsistent, it's a LOT more efficient.
self.episodes_total = len(watched_episodes)

async def _ensure_correct_invoker(self, interaction: discord.Interaction) -> bool:
"""Ensure that the interaction was invoked by the author of this view."""
if interaction.user is None:
raise ValueError("Interaction user is None")

if interaction.user.id != self.invoker_user_id:
await interaction.response.send_message("You can't interact with this view.", ephemeral=True)
return False
return True

def _get_embed(self) -> discord.Embed:
embed = discord.Embed(
title="Profile",
Expand Down
19 changes: 15 additions & 4 deletions src/exts/tvdb_info/ui/search_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
def _search_view(
bot: Bot,
user_id: int,
invoker_user_id: int,
watched_list: UserList,
favorite_list: UserList,
results: Sequence[Movie | Series],
Expand All @@ -25,6 +26,7 @@ def _search_view(
view = MovieView(
bot=bot,
user_id=user_id,
invoker_user_id=invoker_user_id,
watched_list=watched_list,
favorite_list=favorite_list,
media_data=result,
Expand All @@ -33,6 +35,7 @@ def _search_view(
view = SeriesView(
bot=bot,
user_id=user_id,
invoker_user_id=invoker_user_id,
watched_list=watched_list,
favorite_list=favorite_list,
media_data=result,
Expand All @@ -56,20 +59,28 @@ def _search_view(
)

async def _search_dropdown_callback(interaction: discord.Interaction) -> None:
if not await view._ensure_correct_invoker(interaction): # pyright: ignore[reportPrivateUsage]
return

if not search_result_dropdown.values or not isinstance(search_result_dropdown.values[0], str):
raise ValueError("Dropdown values are empty or not a string but callback was triggered.")

index = int(search_result_dropdown.values[0])
view = _search_view(bot, user_id, watched_list, favorite_list, results, index)
await view.send(interaction)
new_view = _search_view(bot, user_id, invoker_user_id, watched_list, favorite_list, results, index)
await new_view.send(interaction)

search_result_dropdown.callback = _search_dropdown_callback

view.add_item(search_result_dropdown)
return view


async def search_view(bot: Bot, user_id: int, results: Sequence[Movie | Series]) -> MovieView | SeriesView:
async def search_view(
bot: Bot,
user_id: int,
invoker_user_id: int,
results: Sequence[Movie | Series],
) -> MovieView | SeriesView:
"""Construct a view showing the search results.

This uses specific views to render a single result. This view is then modified to
Expand All @@ -81,4 +92,4 @@ async def search_view(bot: Bot, user_id: int, results: Sequence[Movie | Series])
await refresh_list_items(bot.db_session, watched_list)
await refresh_list_items(bot.db_session, favorite_list)

return _search_view(bot, user_id, watched_list, favorite_list, results, 0)
return _search_view(bot, user_id, invoker_user_id, watched_list, favorite_list, results, 0)
Loading