Skip to content

Commit

Permalink
Add fetch_replies/topics/messages
Browse files Browse the repository at this point in the history
  • Loading branch information
KurimuzonAkuma committed Feb 28, 2025
1 parent 690f6be commit 691bf57
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 35 deletions.
16 changes: 16 additions & 0 deletions pyrogram/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,16 @@ class Client(Methods):
The platform where this client is running.
Defaults to 'other'
fetch_replies (``bool``, *optional*):
Pass True to automatically fetch replies for messages.
Defaults to True.
fetch_topics (``bool``, *optional*):
Pass True to automatically fetch forum topics.
fetch_stories (``bool``, *optional*):
Pass True to automatically fetch stories.
init_connection_params (:obj:`~pyrogram.raw.base.JSONValue`, *optional*):
Additional initConnection parameters.
For now, only the tz_offset field is supported, for specifying timezone offset in seconds.
Expand Down Expand Up @@ -268,6 +278,9 @@ def __init__(
max_message_cache_size: int = MAX_MESSAGE_CACHE_SIZE,
storage_engine: Optional[Storage] = None,
client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER,
fetch_replies: Optional[bool] = True,
fetch_topics: Optional[bool] = True,
fetch_stories: Optional[bool] = True,
init_connection_params: Optional["raw.base.JSONValue"] = None,
connection_factory: Type[Connection] = Connection,
protocol_factory: Type[TCP] = TCPAbridged
Expand Down Expand Up @@ -304,6 +317,9 @@ def __init__(
self.max_concurrent_transmissions = max_concurrent_transmissions
self.max_message_cache_size = max_message_cache_size
self.client_platform = client_platform
self.fetch_replies = fetch_replies
self.fetch_topics = fetch_topics
self.fetch_stories = fetch_stories
self.init_connection_params = init_connection_params
self.connection_factory = connection_factory
self.protocol_factory = protocol_factory
Expand Down
44 changes: 23 additions & 21 deletions pyrogram/types/messages_and_media/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,29 +981,30 @@ async def _parse(
)

if isinstance(action, raw.types.MessageActionPinMessage):
try:
parsed_message.pinned_message = await client.get_messages(
chat_id=parsed_message.chat.id,
pinned=True,
replies=0
)
parsed_message.service = enums.MessageServiceType.PINNED_MESSAGE

parsed_message.service = enums.MessageServiceType.PINNED_MESSAGE
except (MessageIdsEmpty, ChannelPrivate):
pass
if client.fetch_replies:
try:
parsed_message.pinned_message = await client.get_messages(
chat_id=parsed_message.chat.id,
pinned=True,
replies=0
)

except (MessageIdsEmpty, ChannelPrivate):
pass
elif isinstance(action, raw.types.MessageActionGameScore):
parsed_message.game_high_score = types.GameHighScore._parse_action(client, message, users)
parsed_message.service = enums.MessageServiceType.GAME_HIGH_SCORE

if message.reply_to and replies:
if client.fetch_replies and message.reply_to and replies:
try:
parsed_message.reply_to_message = await client.get_messages(
chat_id=parsed_message.chat.id,
message_ids=message.id,
reply=True,
replies=0
)

parsed_message.service = enums.MessageServiceType.GAME_HIGH_SCORE
except (MessageIdsEmpty, ChannelPrivate):
pass

Expand Down Expand Up @@ -1360,7 +1361,7 @@ async def _parse(

reply_to_message = client.message_cache[key]

if not reply_to_message:
if client.fetch_replies and not reply_to_message:
try:
reply_to_message = await client.get_messages(
replies=replies - 1,
Expand All @@ -1371,20 +1372,21 @@ async def _parse(

parsed_message.reply_to_message = reply_to_message
elif isinstance(message.reply_to, raw.types.MessageReplyStoryHeader):
if client.me and not client.me.is_bot:
if client.fetch_stories and client.me and not client.me.is_bot:
parsed_message.reply_to_story = await client.get_stories(
utils.get_peer_id(message.reply_to.peer),
message.reply_to.story_id
)

if not parsed_message.topic and parsed_message.chat.is_forum and client.me and not client.me.is_bot:
try:
parsed_message.topic = await client.get_forum_topics_by_id(
chat_id=parsed_message.chat.id,
topic_ids=parsed_message.message_thread_id or 1
)
except (ChannelPrivate, ChannelForumMissing):
pass
if client.fetch_topics:
try:
parsed_message.topic = await client.get_forum_topics_by_id(
chat_id=parsed_message.chat.id,
topic_ids=parsed_message.message_thread_id or 1
)
except (ChannelPrivate, ChannelForumMissing):
pass

if not parsed_message.poll: # Do not cache poll messages
client.message_cache[(parsed_message.chat.id, parsed_message.id)] = parsed_message
Expand Down
29 changes: 15 additions & 14 deletions pyrogram/types/messages_and_media/story.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ async def _parse(

if isinstance(story, raw.types.StoryItemDeleted):
return Story(client=client, id=story.id, deleted=True, from_user=from_user, sender_chat=sender_chat, chat=chat)
if isinstance(story, raw.types.StoryItemSkipped):
if client.fetch_stories and isinstance(story, raw.types.StoryItemSkipped):
try:
r = await client.invoke(
raw.functions.stories.GetStoriesByID(
Expand All @@ -279,25 +279,26 @@ async def _parse(
return Story(client=client, id=story.id, from_user=from_user, sender_chat=sender_chat, chat=chat)

if not getattr(story, "story", None):
try:
r = await client.invoke(
raw.functions.stories.GetStoriesByID(
peer=await client.resolve_peer(chat.id),
id=[story.id]
if client.fetch_stories:
try:
r = await client.invoke(
raw.functions.stories.GetStoriesByID(
peer=await client.resolve_peer(chat.id),
id=[story.id]
)
)
)

users.update({i.id: i for i in r.users})
chats.update({i.id: i for i in r.chats})
users.update({i.id: i for i in r.users})
chats.update({i.id: i for i in r.chats})

if r.stories:
story = r.stories[0]
except (ChannelPrivate, ChannelInvalid):
pass
if r.stories:
story = r.stories[0]
except (ChannelPrivate, ChannelInvalid):
pass
else:
story = story.story

if getattr(story, "min", None):
if client.fetch_stories and getattr(story, "min", None):
try:
r = await client.invoke(
raw.functions.stories.GetStoriesByID(
Expand Down

0 comments on commit 691bf57

Please sign in to comment.