Skip to content
This repository has been archived by the owner on Sep 6, 2024. It is now read-only.

Commit

Permalink
Prevent interaction with views by others
Browse files Browse the repository at this point in the history
  • Loading branch information
ItsDrike committed Jul 28, 2024
1 parent 068974e commit 95de1b6
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 10 deletions.
7 changes: 4 additions & 3 deletions src/exts/tvdb_info/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal, cast
from typing import Literal

from discord import ApplicationContext, Cog, Member, User, option, slash_command

Expand Down Expand Up @@ -31,7 +31,7 @@ async def profile(self, ctx: ApplicationContext, *, user: User | Member | None =
await ctx.defer()

if user is None:
user = cast(User | Member, ctx.user) # for some reason, pyright thinks user can be None here
user = ctx.author

# Convert Member to User (Member isn't a subclass of User...)
if isinstance(user, Member):
Expand All @@ -46,6 +46,7 @@ async def profile(self, ctx: ApplicationContext, *, user: User | Member | None =
bot=self.bot,
tvdb_client=self.tvdb_client,
user=user,
invoker_user_id=ctx.author.id,
watched_list=await user_get_list_safe(self.bot.db_session, db_user, "watched"),
favorite_list=await user_get_list_safe(self.bot.db_session, db_user, "favorite"),
)
Expand Down Expand Up @@ -114,7 +115,7 @@ async def search(
await ctx.respond("No results found.")
return

view = await search_view(self.bot, ctx.user.id, response)
view = await search_view(self.bot, ctx.user.id, ctx.user.id, response)
await view.send(ctx.interaction)


Expand Down
27 changes: 26 additions & 1 deletion src/exts/tvdb_info/ui/_media_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@
class MediaView(ErrorHandledView, ABC):
"""Base class for views that display info about some media (movie/series/episode)."""

def __init__(self, *, bot: Bot, user_id: int, watched_list: UserList, favorite_list: UserList) -> None:
def __init__(
self,
*,
bot: Bot,
user_id: int,
invoker_user_id: int,
watched_list: UserList,
favorite_list: UserList,
) -> None:
"""Initialize MediaView.
:param bot: The bot instance.
Expand All @@ -30,6 +38,7 @@ def __init__(self, *, bot: Bot, user_id: int, watched_list: UserList, favorite_l

self.bot = bot
self.user_id = user_id
self.invoker_user_id = invoker_user_id
self.watched_list = watched_list
self.favorite_list = favorite_list

Expand Down Expand Up @@ -108,6 +117,16 @@ async def _refresh(self) -> None:

await self.message.edit(embed=self._get_embed(), view=self)

async def _ensure_correct_invoker(self, interaction: discord.Interaction) -> bool:
"""Ensure that the interaction was invoked by the author of this view."""
if interaction.user is None:
raise ValueError("Interaction user is None")

if interaction.user.id != self.invoker_user_id:
await interaction.response.send_message("You can't interact with this view.", ephemeral=True)
return False
return True

@abstractmethod
async def is_favorite(self) -> bool:
"""Check if the current media is marked as favorite by the user.
Expand Down Expand Up @@ -147,6 +166,9 @@ async def send(self, interaction: discord.Interaction) -> None:

async def _watched_button_callback(self, interaction: discord.Interaction) -> None:
"""Callback for when the user clicks on the mark as watched button."""
if not await self._ensure_correct_invoker(interaction):
return

await interaction.response.defer()
cur_state = self.watched_button.state
await self.set_watched(not cur_state)
Expand All @@ -156,6 +178,9 @@ async def _watched_button_callback(self, interaction: discord.Interaction) -> No

async def _favorite_button_callback(self, interaction: discord.Interaction) -> None:
"""Callback for when the user clicks on the mark as favorite button."""
if not await self._ensure_correct_invoker(interaction):
return

await interaction.response.defer()
cur_state = self.favorite_button.state
await self.set_favorite(not cur_state)
Expand Down
15 changes: 14 additions & 1 deletion src/exts/tvdb_info/ui/episode_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,20 @@ def __init__(
*,
bot: Bot,
user_id: int,
invoker_user_id: int,
watched_list: UserList,
favorite_list: UserList,
series: Series,
season_idx: int = 1,
episode_idx: int = 1,
) -> None:
super().__init__(bot=bot, user_id=user_id, watched_list=watched_list, favorite_list=favorite_list)
super().__init__(
bot=bot,
user_id=user_id,
invoker_user_id=invoker_user_id,
watched_list=watched_list,
favorite_list=favorite_list,
)

self.series = series

Expand Down Expand Up @@ -156,6 +163,9 @@ async def set_watched(self, state: bool) -> None:

async def _episode_dropdown_callback(self, interaction: discord.Interaction) -> None:
"""Callback for when the user selects an episode from the drop-down."""
if not await self._ensure_correct_invoker(interaction):
return

if not self.episode_dropdown.values or not isinstance(self.episode_dropdown.values[0], str):
raise ValueError("Episode dropdown values are empty or non-string, but callback was triggered.")

Expand All @@ -171,6 +181,9 @@ async def _episode_dropdown_callback(self, interaction: discord.Interaction) ->

async def _season_dropdown_callback(self, interaction: discord.Interaction) -> None:
"""Callback for when the user selects a season from the drop-down."""
if not await self._ensure_correct_invoker(interaction):
return

if not self.season_dropdown.values or not isinstance(self.season_dropdown.values[0], str):
raise ValueError("Episode dropdown values are empty or non-string, but callback was triggered.")

Expand Down
17 changes: 16 additions & 1 deletion src/exts/tvdb_info/ui/movie_series_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,18 @@ def __init__(
*,
bot: Bot,
user_id: int,
invoker_user_id: int,
watched_list: UserList,
favorite_list: UserList,
media_data: Movie | Series,
) -> None:
super().__init__(bot=bot, user_id=user_id, watched_list=watched_list, favorite_list=favorite_list)
super().__init__(
bot=bot,
user_id=user_id,
invoker_user_id=invoker_user_id,
watched_list=watched_list,
favorite_list=favorite_list,
)

self.media_data = media_data

Expand Down Expand Up @@ -140,13 +147,15 @@ def __init__(
*,
bot: Bot,
user_id: int,
invoker_user_id: int,
watched_list: UserList,
favorite_list: UserList,
media_data: Series,
) -> None:
super().__init__(
bot=bot,
user_id=user_id,
invoker_user_id=invoker_user_id,
watched_list=watched_list,
favorite_list=favorite_list,
media_data=media_data,
Expand All @@ -172,9 +181,13 @@ def _add_items(self) -> None:

async def _episodes_button_callback(self, interaction: discord.Interaction) -> None:
"""Callback for when the user clicks the "View Episodes" button."""
if not await self._ensure_correct_invoker(interaction):
return

view = EpisodeView(
bot=self.bot,
user_id=self.user_id,
invoker_user_id=self.invoker_user_id,
watched_list=self.watched_list,
favorite_list=self.favorite_list,
series=self.media_data,
Expand Down Expand Up @@ -250,13 +263,15 @@ def __init__(
*,
bot: Bot,
user_id: int,
invoker_user_id: int,
watched_list: UserList,
favorite_list: UserList,
media_data: Movie,
) -> None:
super().__init__(
bot=bot,
user_id=user_id,
invoker_user_id=invoker_user_id,
watched_list=watched_list,
favorite_list=favorite_list,
media_data=media_data,
Expand Down
12 changes: 12 additions & 0 deletions src/exts/tvdb_info/ui/profile_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@ def __init__(
bot: Bot,
tvdb_client: TvdbClient,
user: discord.User,
invoker_user_id: int,
watched_list: UserList,
favorite_list: UserList,
) -> None:
super().__init__()
self.bot = bot
self.tvdb_client = tvdb_client
self.discord_user = user
self.invoker_user_id = invoker_user_id
self.watched_list = watched_list
self.favorite_list = favorite_list

Expand Down Expand Up @@ -151,6 +153,16 @@ async def _initialize(self) -> None:
# as that's the only thing we need here and while it is a bit inconsistent, it's a LOT more efficient.
self.episodes_total = len(watched_episodes)

async def _ensure_correct_invoker(self, interaction: discord.Interaction) -> bool:
"""Ensure that the interaction was invoked by the author of this view."""
if interaction.user is None:
raise ValueError("Interaction user is None")

if interaction.user.id != self.invoker_user_id:
await interaction.response.send_message("You can't interact with this view.", ephemeral=True)
return False
return True

def _get_embed(self) -> discord.Embed:
embed = discord.Embed(
title="Profile",
Expand Down
19 changes: 15 additions & 4 deletions src/exts/tvdb_info/ui/search_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
def _search_view(
bot: Bot,
user_id: int,
invoker_user_id: int,
watched_list: UserList,
favorite_list: UserList,
results: Sequence[Movie | Series],
Expand All @@ -25,6 +26,7 @@ def _search_view(
view = MovieView(
bot=bot,
user_id=user_id,
invoker_user_id=invoker_user_id,
watched_list=watched_list,
favorite_list=favorite_list,
media_data=result,
Expand All @@ -33,6 +35,7 @@ def _search_view(
view = SeriesView(
bot=bot,
user_id=user_id,
invoker_user_id=invoker_user_id,
watched_list=watched_list,
favorite_list=favorite_list,
media_data=result,
Expand All @@ -56,20 +59,28 @@ def _search_view(
)

async def _search_dropdown_callback(interaction: discord.Interaction) -> None:
if not await view._ensure_correct_invoker(interaction): # pyright: ignore[reportPrivateUsage]
return

if not search_result_dropdown.values or not isinstance(search_result_dropdown.values[0], str):
raise ValueError("Dropdown values are empty or not a string but callback was triggered.")

index = int(search_result_dropdown.values[0])
view = _search_view(bot, user_id, watched_list, favorite_list, results, index)
await view.send(interaction)
new_view = _search_view(bot, user_id, invoker_user_id, watched_list, favorite_list, results, index)
await new_view.send(interaction)

search_result_dropdown.callback = _search_dropdown_callback

view.add_item(search_result_dropdown)
return view


async def search_view(bot: Bot, user_id: int, results: Sequence[Movie | Series]) -> MovieView | SeriesView:
async def search_view(
bot: Bot,
user_id: int,
invoker_user_id: int,
results: Sequence[Movie | Series],
) -> MovieView | SeriesView:
"""Construct a view showing the search results.
This uses specific views to render a single result. This view is then modified to
Expand All @@ -81,4 +92,4 @@ async def search_view(bot: Bot, user_id: int, results: Sequence[Movie | Series])
await refresh_list_items(bot.db_session, watched_list)
await refresh_list_items(bot.db_session, favorite_list)

return _search_view(bot, user_id, watched_list, favorite_list, results, 0)
return _search_view(bot, user_id, invoker_user_id, watched_list, favorite_list, results, 0)

0 comments on commit 95de1b6

Please sign in to comment.