Skip to content

Commit

Permalink
Add support for paid_media and thumbnail in download_media (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
SpEcHiDe authored Jan 13, 2025
1 parent 0870553 commit 4ca8aff
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 71 deletions.
1 change: 1 addition & 0 deletions docs/source/releases/changes-in-this-fork.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ If you found any issue or have any suggestions, feel free to make `an issue <htt
Breaking Changes in this Fork
==============================

- In :meth:`~pyrogram.Client.download_media`, if the message is a :obj:`~pyrogram.types.PaidMediaInfo` with more than one ``paid_media`` **and** ``idx`` was not specified, then a list of paths or binary file-like objects is returned.
- Make :meth:`~pyrogram.Client.get_messages` accept only keyword-only arguments. `48d4230 <https://github.com/TelegramPlayGround/pyrogram/commit/48d42304f3ee51034d515919320634935e6b2c83>`_
- PR `#115 <https://github.com/TelegramPlayGround/pyrogram/pull/115>`_ This `change <https://github.com/pyrogram/pyrogram/pull/966#issuecomment-1108858881>`_ breaks some usages with offset-naive and offset-aware datetimes.
- PR from upstream: `#1411 <https://github.com/pyrogram/pyrogram/pull/1411>`_ without attribution.
Expand Down
225 changes: 159 additions & 66 deletions pyrogram/methods/messages/download_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
import asyncio
import io
import os
import re
from datetime import datetime
from typing import Union, Optional, Callable
from typing import Callable, Optional, Union

import pyrogram
from pyrogram import types, enums
from pyrogram import enums, types, utils
from pyrogram.file_id import FileId, FileType, PHOTO_TYPES

DEFAULT_DOWNLOAD_DIR = "downloads/"
Expand All @@ -32,19 +33,39 @@
class DownloadMedia:
async def download_media(
self: "pyrogram.Client",
message: Union["types.Message", "types.Story", str],
message: Union[
"types.Message",
"types.Audio",
"types.Document",
"types.Photo",
"types.Sticker",
"types.Video",
"types.Animation",
"types.Voice",
"types.VideoNote",
# TODO
"types.Story",
"types.PaidMediaInfo",
"types.PaidMediaPhoto",
"types.PaidMediaVideo",
"types.Thumbnail",
"types.StrippedThumbnail",
"types.PaidMediaPreview",
str,
],
file_name: str = DEFAULT_DOWNLOAD_DIR,
in_memory: bool = False,
block: bool = True,
idx: int = None,
progress: Callable = None,
progress_args: tuple = ()
) -> Optional[Union[str, "io.BytesIO"]]:
) -> Optional[Union[str, "io.BytesIO", list[str], list["io.BytesIO"]]]:
"""Download the media from a message.
.. include:: /_includes/usable-by/users-bots.rst
Parameters:
message (:obj:`~pyrogram.types.Message` | :obj:`~pyrogram.types.Story` | ``str``):
message (:obj:`~pyrogram.types.Message` | :obj:`~pyrogram.types.Audio` | :obj:`~pyrogram.types.Document` | :obj:`~pyrogram.types.Photo` | :obj:`~pyrogram.types.Sticker` | :obj:`~pyrogram.types.Video` | :obj:`~pyrogram.types.Animation` | :obj:`~pyrogram.types.Voice` | :obj:`~pyrogram.types.VideoNote` | :obj:`~pyrogram.types.Story` | :obj:`~pyrogram.types.PaidMediaInfo` | :obj:`~pyrogram.types.PaidMediaPhoto` | :obj:`~pyrogram.types.PaidMediaVideo` | :obj:`~pyrogram.types.Thumbnail` | :obj:`~pyrogram.types.StrippedThumbnail` | :obj:`~pyrogram.types.PaidMediaPreview` | :obj:`~pyrogram.types.Story` | ``str``):
Pass a Message containing the media, the media itself (message.audio, message.video, ...) or a file id
as string.
Expand All @@ -63,6 +84,9 @@ async def download_media(
Blocks the code execution until the file has been downloaded.
Defaults to True.
idx (``int``, *optional*):
In case of a :obj:`~pyrogram.types.PaidMediaInfo` with more than one ``paid_media``, the zero based index of the :obj:`~pyrogram.types.PaidMedia` to download. Raises ``IndexError`` if the index specified does not exist in the original ``message``.
progress (``Callable``, *optional*):
Pass a callback function to view the file transmission progress.
The function must take *(current, total)* as positional arguments (look at Other Parameters below for a
Expand Down Expand Up @@ -90,9 +114,11 @@ async def download_media(
otherwise, in case the download failed or was deliberately stopped with
:meth:`~pyrogram.Client.stop_transmission`, None is returned.
Otherwise, in case ``in_memory=True``, a binary file-like object with its attribute ".name" set is returned.
If the message is a :obj:`~pyrogram.types.PaidMediaInfo` with more than one ``paid_media`` containing ``minithumbnail`` and ``idx`` is not specified, then a list of paths or binary file-like objects is returned.
Raises:
RPCError: In case of a Telegram RPC error.
IndexError: In case of wrong value of ``idx``.
ValueError: If the message doesn't contain any downloadable media.
Example:
Expand Down Expand Up @@ -122,94 +148,161 @@ async def progress(current, total):
file_bytes = bytes(file.getbuffer())
"""

media = message
medium = [message]

if isinstance(message, types.Message):
if message.new_chat_photo:
media = message.new_chat_photo
medium = [message.new_chat_photo]

elif (
not (self.me and self.me.is_bot) and
message.story or message.reply_to_story
):
story_media = message.story or message.reply_to_story or None
if story_media and story_media.media:
media = getattr(story_media, story_media.media.value, None)
medium = [getattr(story_media, story_media.media.value, None)]
else:
medium = []

elif message.paid_media:
if any([isinstance(paid_media, (types.PaidMediaPhoto, types.PaidMediaVideo)) for paid_media in message.paid_media.paid_media]):
medium = [getattr(paid_media, "photo", (getattr(paid_media, "video", None))) for paid_media in message.paid_media.paid_media]
elif any([isinstance(paid_media, types.PaidMediaPreview) for paid_media in message.paid_media.paid_media]):
medium = [getattr(getattr(paid_media, "minithumbnail"), "data", None) for paid_media in message.paid_media.paid_media]
else:
media = None
medium = []

else:
if message.media:
media = getattr(message, message.media.value, None)
medium = [getattr(message, message.media.value, None)]
else:
media = None
medium = []

elif isinstance(message, str):
media = message

if isinstance(media, types.Story):
elif isinstance(message, types.Story):
if (self.me and self.me.is_bot):
raise ValueError("This method cannot be used by bots")
else:
if media.media:
media = getattr(message, message.media.value, None)
if medium.media:
medium = [getattr(message, message.media.value, None)]
else:
media = None
medium = []

elif isinstance(message, types.PaidMediaInfo):
if any([isinstance(paid_media, (types.PaidMediaPhoto, types.PaidMediaVideo)) for paid_media in message.paid_media]):
medium = [getattr(paid_media, "photo", (getattr(paid_media, "video", None))) for paid_media in message.paid_media]
elif any([isinstance(paid_media, types.PaidMediaPreview) for paid_media in message.paid_media]):
medium = [getattr(getattr(paid_media, "minithumbnail"), "data", None) for paid_media in message.paid_media]
else:
medium = []

elif isinstance(message, types.PaidMediaPhoto):
medium = [message.photo]

elif isinstance(message, types.PaidMediaVideo):
medium = [message.video]

elif isinstance(message, types.PaidMediaPreview):
medium = [getattr(getattr(message, "minithumbnail"), "data", None)]

elif isinstance(message, types.StrippedThumbnail):
medium = [message.data]

elif isinstance(message, types.Thumbnail):
medium = [message]

elif isinstance(message, str):
medium = [message]

if not media:
medium = types.List(filter(lambda x: x is not None, medium))

if len(medium) == 0:
raise ValueError(
f"The message {message if isinstance(message, str) else message.id} doesn't contain any downloadable media"
)

if isinstance(media, str):
file_id_str = media
else:
file_id_str = media.file_id

file_id_obj = FileId.decode(file_id_str)

file_type = file_id_obj.file_type
media_file_name = getattr(media, "file_name", "") # TODO
file_size = getattr(media, "file_size", 0)
mime_type = getattr(media, "mime_type", "")
date = getattr(media, "date", None)

directory, file_name = os.path.split(file_name)
file_name = file_name or media_file_name or ""

if not os.path.isabs(file_name):
directory = self.WORKDIR / (directory or DEFAULT_DOWNLOAD_DIR)

if not file_name:
guessed_extension = self.guess_extension(mime_type)

if file_type in PHOTO_TYPES:
extension = ".jpg"
elif file_type == FileType.VOICE:
extension = guessed_extension or ".ogg"
elif file_type in (FileType.VIDEO, FileType.ANIMATION, FileType.VIDEO_NOTE):
extension = guessed_extension or ".mp4"
elif file_type == FileType.DOCUMENT:
extension = guessed_extension or ".zip"
elif file_type == FileType.STICKER:
extension = guessed_extension or ".webp"
elif file_type == FileType.AUDIO:
extension = guessed_extension or ".mp3"
if idx is not None:
medium = [medium[idx]]

dledmedia = []

for media in medium:
if isinstance(media, bytes):
thumb = utils.from_inline_bytes(
utils.expand_inline_bytes(
media
)
)
if in_memory:
dledmedia.append(thumb)
continue

directory, file_name = os.path.split(file_name)
file_name = file_name or thumb.name

if not os.path.isabs(file_name):
directory = self.PARENT_DIR / (directory or DEFAULT_DOWNLOAD_DIR)

os.makedirs(directory, exist_ok=True) if not in_memory else None
temp_file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))

with open(temp_file_path, "wb") as file:
file.write(thumb.getbuffer())

dledmedia.append(temp_file_path)
continue

elif isinstance(media, str):
file_id_str = media
else:
extension = ".unknown"
file_id_str = media.file_id

file_id_obj = FileId.decode(file_id_str)

file_type = file_id_obj.file_type
media_file_name = getattr(media, "file_name", "") # TODO
file_size = getattr(media, "file_size", 0)
mime_type = getattr(media, "mime_type", "")
date = getattr(media, "date", None)

directory, file_name = os.path.split(file_name)
# TODO
file_name = file_name or media_file_name or ""

if not os.path.isabs(file_name):
directory = self.WORKDIR / (directory or DEFAULT_DOWNLOAD_DIR)

if not file_name:
guessed_extension = self.guess_extension(mime_type)

if file_type in PHOTO_TYPES:
extension = ".jpg"
elif file_type == FileType.VOICE:
extension = guessed_extension or ".ogg"
elif file_type in (FileType.VIDEO, FileType.ANIMATION, FileType.VIDEO_NOTE):
extension = guessed_extension or ".mp4"
elif file_type == FileType.DOCUMENT:
extension = guessed_extension or ".zip"
elif file_type == FileType.STICKER:
extension = guessed_extension or ".webp"
elif file_type == FileType.AUDIO:
extension = guessed_extension or ".mp3"
else:
extension = ".unknown"

file_name = "{}_{}_{}{}".format(
FileType(file_id_obj.file_type).name.lower(),
(date or datetime.now()).strftime("%Y-%m-%d_%H-%M-%S"),
self.rnd_id(),
extension
)

file_name = "{}_{}_{}{}".format(
FileType(file_id_obj.file_type).name.lower(),
(date or datetime.now()).strftime("%Y-%m-%d_%H-%M-%S"),
self.rnd_id(),
extension
downloader = self.handle_download(
(file_id_obj, directory, file_name, in_memory, file_size, progress, progress_args)
)

downloader = self.handle_download(
(file_id_obj, directory, file_name, in_memory, file_size, progress, progress_args)
)
if block:
dledmedia.append(await downloader)
else:
asyncio.get_event_loop().create_task(downloader)

if block:
return await downloader
else:
asyncio.get_event_loop().create_task(downloader)
return types.List(dledmedia) if block and len(dledmedia) > 1 else dledmedia[0] if block and len(dledmedia) == 1 else None
3 changes: 2 additions & 1 deletion pyrogram/types/input_paid_media/paid_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def _parse(
duration=getattr(extended_media, "video_duration", None),
minithumbnail=types.StrippedThumbnail(
client=client,
data=extended_media.thumb
# TODO
data=getattr(getattr(extended_media, "thumb"), "bytes", None)
) if getattr(extended_media, "thumb", None) else None
)
if isinstance(extended_media, raw.types.MessageExtendedMedia):
Expand Down
9 changes: 8 additions & 1 deletion pyrogram/types/messages_and_media/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -5298,9 +5298,10 @@ async def download(
file_name: str = "",
in_memory: bool = False,
block: bool = True,
idx: int = None,
progress: Callable = None,
progress_args: tuple = ()
) -> Union[str, "io.BytesIO"]:
) -> Optional[Union[str, "io.BytesIO", list[str], list["io.BytesIO"]]]:
"""Bound method *download* of :obj:`~pyrogram.types.Message`.
Use as a shortcut for:
Expand Down Expand Up @@ -5330,6 +5331,9 @@ async def download(
Blocks the code execution until the file has been downloaded.
Defaults to True.
idx (``int``, *optional*):
In case of a :obj:`~pyrogram.types.PaidMediaInfo` with more than one ``paid_media``, the zero based index of the :obj:`~pyrogram.types.PaidMedia` to download. Raises ``IndexError`` if the index specified does not exist in the original ``message``.
progress (``Callable``, *optional*):
Pass a callback function to view the file transmission progress.
The function must take *(current, total)* as positional arguments (look at Other Parameters below for a
Expand Down Expand Up @@ -5357,9 +5361,11 @@ async def download(
otherwise, in case the download failed or was deliberately stopped with
:meth:`~pyrogram.Client.stop_transmission`, None is returned.
Otherwise, in case ``in_memory=True``, a binary file-like object with its attribute ".name" set is returned.
If the message is a :obj:`~pyrogram.types.PaidMediaInfo` with more than one ``paid_media`` containing ``minithumbnail`` and ``idx`` is not specified, then a list of paths or binary file-like objects is returned.
Raises:
RPCError: In case of a Telegram RPC error.
IndexError: In case of wrong value of ``idx``.
ValueError: If the message doesn't contain any downloadable media.
"""
Expand All @@ -5368,6 +5374,7 @@ async def download(
file_name=file_name,
in_memory=in_memory,
block=block,
idx=idx,
progress=progress,
progress_args=progress_args,
)
Expand Down
10 changes: 7 additions & 3 deletions pyrogram/types/messages_and_media/story.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.

import io
from datetime import datetime
from typing import Union, Callable
from typing import Callable, Optional, Union

import pyrogram
from pyrogram import raw, utils, types, enums
Expand Down Expand Up @@ -453,7 +454,7 @@ async def download(
block: bool = True,
progress: Callable = None,
progress_args: tuple = ()
) -> str:
) -> Optional[Union[str, "io.BytesIO"]]:
"""Bound method *download* of :obj:`~pyrogram.types.Story`.
Use as a shortcut for:
Expand Down Expand Up @@ -506,7 +507,10 @@ async def download(
You can either keep ``*args`` or add every single extra argument in your function signature.
Returns:
On success, the absolute path of the downloaded file as string is returned, None otherwise.
``str`` | ``None`` | :obj:`io.BytesIO`: On success, the absolute path of the downloaded file is returned,
otherwise, in case the download failed or was deliberately stopped with
:meth:`~pyrogram.Client.stop_transmission`, None is returned.
Otherwise, in case ``in_memory=True``, a binary file-like object with its attribute ".name" set is returned.
Raises:
RPCError: In case of a Telegram RPC error.
Expand Down
Loading

1 comment on commit 4ca8aff

@KurimuzonAkuma
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why

Please sign in to comment.