From 95de1b68cc8184e7419b051c2f3d6f835c512614 Mon Sep 17 00:00:00 2001 From: ItsDrike Date: Sun, 28 Jul 2024 18:41:03 +0200 Subject: [PATCH] Prevent interaction with views by others --- src/exts/tvdb_info/main.py | 7 +++--- src/exts/tvdb_info/ui/_media_view.py | 27 +++++++++++++++++++++- src/exts/tvdb_info/ui/episode_view.py | 15 +++++++++++- src/exts/tvdb_info/ui/movie_series_view.py | 17 +++++++++++++- src/exts/tvdb_info/ui/profile_view.py | 12 ++++++++++ src/exts/tvdb_info/ui/search_view.py | 19 +++++++++++---- 6 files changed, 87 insertions(+), 10 deletions(-) diff --git a/src/exts/tvdb_info/main.py b/src/exts/tvdb_info/main.py index 322e4ce..f6567d1 100644 --- a/src/exts/tvdb_info/main.py +++ b/src/exts/tvdb_info/main.py @@ -1,4 +1,4 @@ -from typing import Literal, cast +from typing import Literal from discord import ApplicationContext, Cog, Member, User, option, slash_command @@ -31,7 +31,7 @@ async def profile(self, ctx: ApplicationContext, *, user: User | Member | None = await ctx.defer() if user is None: - user = cast(User | Member, ctx.user) # for some reason, pyright thinks user can be None here + user = ctx.author # Convert Member to User (Member isn't a subclass of User...) if isinstance(user, Member): @@ -46,6 +46,7 @@ async def profile(self, ctx: ApplicationContext, *, user: User | Member | None = bot=self.bot, tvdb_client=self.tvdb_client, user=user, + invoker_user_id=ctx.author.id, watched_list=await user_get_list_safe(self.bot.db_session, db_user, "watched"), favorite_list=await user_get_list_safe(self.bot.db_session, db_user, "favorite"), ) @@ -114,7 +115,7 @@ async def search( await ctx.respond("No results found.") return - view = await search_view(self.bot, ctx.user.id, response) + view = await search_view(self.bot, ctx.user.id, ctx.user.id, response) await view.send(ctx.interaction) diff --git a/src/exts/tvdb_info/ui/_media_view.py b/src/exts/tvdb_info/ui/_media_view.py index b9a6a60..fbdb650 100644 --- a/src/exts/tvdb_info/ui/_media_view.py +++ b/src/exts/tvdb_info/ui/_media_view.py @@ -13,7 +13,15 @@ class MediaView(ErrorHandledView, ABC): """Base class for views that display info about some media (movie/series/episode).""" - def __init__(self, *, bot: Bot, user_id: int, watched_list: UserList, favorite_list: UserList) -> None: + def __init__( + self, + *, + bot: Bot, + user_id: int, + invoker_user_id: int, + watched_list: UserList, + favorite_list: UserList, + ) -> None: """Initialize MediaView. :param bot: The bot instance. @@ -30,6 +38,7 @@ def __init__(self, *, bot: Bot, user_id: int, watched_list: UserList, favorite_l self.bot = bot self.user_id = user_id + self.invoker_user_id = invoker_user_id self.watched_list = watched_list self.favorite_list = favorite_list @@ -108,6 +117,16 @@ async def _refresh(self) -> None: await self.message.edit(embed=self._get_embed(), view=self) + async def _ensure_correct_invoker(self, interaction: discord.Interaction) -> bool: + """Ensure that the interaction was invoked by the author of this view.""" + if interaction.user is None: + raise ValueError("Interaction user is None") + + if interaction.user.id != self.invoker_user_id: + await interaction.response.send_message("You can't interact with this view.", ephemeral=True) + return False + return True + @abstractmethod async def is_favorite(self) -> bool: """Check if the current media is marked as favorite by the user. @@ -147,6 +166,9 @@ async def send(self, interaction: discord.Interaction) -> None: async def _watched_button_callback(self, interaction: discord.Interaction) -> None: """Callback for when the user clicks on the mark as watched button.""" + if not await self._ensure_correct_invoker(interaction): + return + await interaction.response.defer() cur_state = self.watched_button.state await self.set_watched(not cur_state) @@ -156,6 +178,9 @@ async def _watched_button_callback(self, interaction: discord.Interaction) -> No async def _favorite_button_callback(self, interaction: discord.Interaction) -> None: """Callback for when the user clicks on the mark as favorite button.""" + if not await self._ensure_correct_invoker(interaction): + return + await interaction.response.defer() cur_state = self.favorite_button.state await self.set_favorite(not cur_state) diff --git a/src/exts/tvdb_info/ui/episode_view.py b/src/exts/tvdb_info/ui/episode_view.py index e62c074..c021ce9 100644 --- a/src/exts/tvdb_info/ui/episode_view.py +++ b/src/exts/tvdb_info/ui/episode_view.py @@ -22,13 +22,20 @@ def __init__( *, bot: Bot, user_id: int, + invoker_user_id: int, watched_list: UserList, favorite_list: UserList, series: Series, season_idx: int = 1, episode_idx: int = 1, ) -> None: - super().__init__(bot=bot, user_id=user_id, watched_list=watched_list, favorite_list=favorite_list) + super().__init__( + bot=bot, + user_id=user_id, + invoker_user_id=invoker_user_id, + watched_list=watched_list, + favorite_list=favorite_list, + ) self.series = series @@ -156,6 +163,9 @@ async def set_watched(self, state: bool) -> None: async def _episode_dropdown_callback(self, interaction: discord.Interaction) -> None: """Callback for when the user selects an episode from the drop-down.""" + if not await self._ensure_correct_invoker(interaction): + return + if not self.episode_dropdown.values or not isinstance(self.episode_dropdown.values[0], str): raise ValueError("Episode dropdown values are empty or non-string, but callback was triggered.") @@ -171,6 +181,9 @@ async def _episode_dropdown_callback(self, interaction: discord.Interaction) -> async def _season_dropdown_callback(self, interaction: discord.Interaction) -> None: """Callback for when the user selects a season from the drop-down.""" + if not await self._ensure_correct_invoker(interaction): + return + if not self.season_dropdown.values or not isinstance(self.season_dropdown.values[0], str): raise ValueError("Episode dropdown values are empty or non-string, but callback was triggered.") diff --git a/src/exts/tvdb_info/ui/movie_series_view.py b/src/exts/tvdb_info/ui/movie_series_view.py index 2c6e35f..95d7833 100644 --- a/src/exts/tvdb_info/ui/movie_series_view.py +++ b/src/exts/tvdb_info/ui/movie_series_view.py @@ -29,11 +29,18 @@ def __init__( *, bot: Bot, user_id: int, + invoker_user_id: int, watched_list: UserList, favorite_list: UserList, media_data: Movie | Series, ) -> None: - super().__init__(bot=bot, user_id=user_id, watched_list=watched_list, favorite_list=favorite_list) + super().__init__( + bot=bot, + user_id=user_id, + invoker_user_id=invoker_user_id, + watched_list=watched_list, + favorite_list=favorite_list, + ) self.media_data = media_data @@ -140,6 +147,7 @@ def __init__( *, bot: Bot, user_id: int, + invoker_user_id: int, watched_list: UserList, favorite_list: UserList, media_data: Series, @@ -147,6 +155,7 @@ def __init__( super().__init__( bot=bot, user_id=user_id, + invoker_user_id=invoker_user_id, watched_list=watched_list, favorite_list=favorite_list, media_data=media_data, @@ -172,9 +181,13 @@ def _add_items(self) -> None: async def _episodes_button_callback(self, interaction: discord.Interaction) -> None: """Callback for when the user clicks the "View Episodes" button.""" + if not await self._ensure_correct_invoker(interaction): + return + view = EpisodeView( bot=self.bot, user_id=self.user_id, + invoker_user_id=self.invoker_user_id, watched_list=self.watched_list, favorite_list=self.favorite_list, series=self.media_data, @@ -250,6 +263,7 @@ def __init__( *, bot: Bot, user_id: int, + invoker_user_id: int, watched_list: UserList, favorite_list: UserList, media_data: Movie, @@ -257,6 +271,7 @@ def __init__( super().__init__( bot=bot, user_id=user_id, + invoker_user_id=invoker_user_id, watched_list=watched_list, favorite_list=favorite_list, media_data=media_data, diff --git a/src/exts/tvdb_info/ui/profile_view.py b/src/exts/tvdb_info/ui/profile_view.py index 26ea079..d7fbda2 100644 --- a/src/exts/tvdb_info/ui/profile_view.py +++ b/src/exts/tvdb_info/ui/profile_view.py @@ -35,6 +35,7 @@ def __init__( bot: Bot, tvdb_client: TvdbClient, user: discord.User, + invoker_user_id: int, watched_list: UserList, favorite_list: UserList, ) -> None: @@ -42,6 +43,7 @@ def __init__( self.bot = bot self.tvdb_client = tvdb_client self.discord_user = user + self.invoker_user_id = invoker_user_id self.watched_list = watched_list self.favorite_list = favorite_list @@ -151,6 +153,16 @@ async def _initialize(self) -> None: # as that's the only thing we need here and while it is a bit inconsistent, it's a LOT more efficient. self.episodes_total = len(watched_episodes) + async def _ensure_correct_invoker(self, interaction: discord.Interaction) -> bool: + """Ensure that the interaction was invoked by the author of this view.""" + if interaction.user is None: + raise ValueError("Interaction user is None") + + if interaction.user.id != self.invoker_user_id: + await interaction.response.send_message("You can't interact with this view.", ephemeral=True) + return False + return True + def _get_embed(self) -> discord.Embed: embed = discord.Embed( title="Profile", diff --git a/src/exts/tvdb_info/ui/search_view.py b/src/exts/tvdb_info/ui/search_view.py index 9434a4d..b26eb9e 100644 --- a/src/exts/tvdb_info/ui/search_view.py +++ b/src/exts/tvdb_info/ui/search_view.py @@ -14,6 +14,7 @@ def _search_view( bot: Bot, user_id: int, + invoker_user_id: int, watched_list: UserList, favorite_list: UserList, results: Sequence[Movie | Series], @@ -25,6 +26,7 @@ def _search_view( view = MovieView( bot=bot, user_id=user_id, + invoker_user_id=invoker_user_id, watched_list=watched_list, favorite_list=favorite_list, media_data=result, @@ -33,6 +35,7 @@ def _search_view( view = SeriesView( bot=bot, user_id=user_id, + invoker_user_id=invoker_user_id, watched_list=watched_list, favorite_list=favorite_list, media_data=result, @@ -56,12 +59,15 @@ def _search_view( ) async def _search_dropdown_callback(interaction: discord.Interaction) -> None: + if not await view._ensure_correct_invoker(interaction): # pyright: ignore[reportPrivateUsage] + return + if not search_result_dropdown.values or not isinstance(search_result_dropdown.values[0], str): raise ValueError("Dropdown values are empty or not a string but callback was triggered.") index = int(search_result_dropdown.values[0]) - view = _search_view(bot, user_id, watched_list, favorite_list, results, index) - await view.send(interaction) + new_view = _search_view(bot, user_id, invoker_user_id, watched_list, favorite_list, results, index) + await new_view.send(interaction) search_result_dropdown.callback = _search_dropdown_callback @@ -69,7 +75,12 @@ async def _search_dropdown_callback(interaction: discord.Interaction) -> None: return view -async def search_view(bot: Bot, user_id: int, results: Sequence[Movie | Series]) -> MovieView | SeriesView: +async def search_view( + bot: Bot, + user_id: int, + invoker_user_id: int, + results: Sequence[Movie | Series], +) -> MovieView | SeriesView: """Construct a view showing the search results. This uses specific views to render a single result. This view is then modified to @@ -81,4 +92,4 @@ async def search_view(bot: Bot, user_id: int, results: Sequence[Movie | Series]) await refresh_list_items(bot.db_session, watched_list) await refresh_list_items(bot.db_session, favorite_list) - return _search_view(bot, user_id, watched_list, favorite_list, results, 0) + return _search_view(bot, user_id, invoker_user_id, watched_list, favorite_list, results, 0)