diff --git a/.github/workflows/tests-studio.yml b/.github/workflows/tests-studio.yml index 923731fef..76358ecca 100644 --- a/.github/workflows/tests-studio.yml +++ b/.github/workflows/tests-studio.yml @@ -75,6 +75,9 @@ jobs: path: './backend/datachain' fetch-depth: 0 + - name: Set up FFmpeg + uses: AnimMouse/setup-ffmpeg@v1 + - name: Set up Python ${{ matrix.pyv }} uses: actions/setup-python@v5 with: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3b96a12fd..6a807e29f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -78,6 +78,9 @@ jobs: fetch-depth: 0 ref: ${{ github.event.pull_request.head.sha || github.ref }} + - name: Set up FFmpeg + uses: AnimMouse/setup-ffmpeg@v1 + - name: Set up Python ${{ matrix.pyv }} uses: actions/setup-python@v5 with: diff --git a/pyproject.toml b/pyproject.toml index 39fe6729f..098c70a1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,8 +77,16 @@ hf = [ "numba>=0.60.0", "datasets[audio,vision]>=2.21.0" ] +video = [ + # Use 'av<14' because of incompatibility with imageio + # See https://github.com/PyAV-Org/PyAV/discussions/1700 + "av<14", + "ffmpeg-python", + "imageio[ffmpeg]", + "opencv-python" +] tests = [ - "datachain[torch,remote,vector,hf]", + "datachain[torch,remote,vector,hf,video]", "pytest>=8,<9", "pytest-sugar>=0.9.6", "pytest-cov>=4.1.0", diff --git a/src/datachain/__init__.py b/src/datachain/__init__.py index e8bbc00bf..659b2ce5d 100644 --- a/src/datachain/__init__.py +++ b/src/datachain/__init__.py @@ -4,9 +4,14 @@ ArrowRow, File, FileError, + Image, ImageFile, TarVFile, TextFile, + Video, + VideoFile, + VideoFragment, + VideoFrame, ) from datachain.lib.model_store import ModelStore from datachain.lib.udf import Aggregator, Generator, Mapper @@ -27,6 +32,7 @@ "File", "FileError", "Generator", + "Image", "ImageFile", "Mapper", "ModelStore", @@ -34,6 +40,10 @@ "Sys", "TarVFile", "TextFile", + "Video", + "VideoFile", + "VideoFragment", + "VideoFrame", "is_chain_type", "metrics", "param", diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py index 599fa667e..65101e4ca 100644 --- a/src/datachain/lib/file.py +++ b/src/datachain/lib/file.py @@ -17,7 +17,7 @@ from urllib.request import url2pathname from fsspec.callbacks import DEFAULT_CALLBACK, Callback -from PIL import Image +from PIL import Image as PilImage from pydantic import Field, field_validator from datachain.client.fileslice import FileSlice @@ -27,6 +27,7 @@ from datachain.utils import TIME_ZERO if TYPE_CHECKING: + from numpy import ndarray from typing_extensions import Self from datachain.catalog import Catalog @@ -40,7 +41,7 @@ # how to create file path when exporting ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"] -FileType = Literal["binary", "text", "image"] +FileType = Literal["binary", "text", "image", "video"] class VFileError(DataChainError): @@ -193,7 +194,7 @@ def __init__(self, **kwargs): @classmethod def upload( cls, data: bytes, path: str, catalog: Optional["Catalog"] = None - ) -> "File": + ) -> "Self": if catalog is None: from datachain.catalog.loader import get_catalog @@ -203,6 +204,8 @@ def upload( client = catalog.get_client(parent) file = client.upload(data, name) + if not isinstance(file, cls): + file = cls(**file.model_dump()) file._set_stream(catalog) return file @@ -486,13 +489,217 @@ class ImageFile(File): def read(self): """Returns `PIL.Image.Image` object.""" fobj = super().read() - return Image.open(BytesIO(fobj)) + return PilImage.open(BytesIO(fobj)) def save(self, destination: str): """Writes it's content to destination""" self.read().save(destination) +class Image(DataModel): + """`DataModel` for image file meta information.""" + + width: int = Field(default=-1) + height: int = Field(default=-1) + format: str = Field(default="") + + +class VideoFile(File): + """`DataModel` for reading video files.""" + + def get_info(self) -> "Video": + """Returns video file information.""" + from .video import video_info + + return video_info(self) + + def get_frame_np(self, frame: int) -> "ndarray": + """ + Reads video frame from a file. + + Args: + frame (int): Frame number to read. + + Returns: + ndarray: Video frame. + """ + from .video import video_frame_np + + return video_frame_np(self, frame) + + def get_frame(self, frame: int, format: str = "jpg") -> bytes: + """ + Reads video frame from a file and returns as image bytes. + + Args: + frame (int): Frame number to read. + format (str): Image format (default: 'jpg'). + + Returns: + bytes: Video frame image as bytes. + """ + from .video import video_frame + + return video_frame(self, frame, format) + + def save_frame( + self, + frame: int, + output_file: str, + format: Optional[str] = None, + ) -> "VideoFrame": + """ + Saves video frame as an image file. + + Args: + frame (int): Frame number to read. + output_file (str): Output file path. + format (str): Image format (default: use output file extension). + + Returns: + VideoFrame: Video frame model. + """ + from .video import save_video_frame + + return save_video_frame(self, frame, output_file, format=format) + + def get_frames_np( + self, + start: int = 0, + end: Optional[int] = None, + step: int = 1, + ) -> "Iterator[ndarray]": + """ + Reads video frames from a file. + + Args: + start (int): Frame number to start reading from (default: 0). + end (int): Frame number to stop reading at (default: None). + step (int): Step size for reading frames (default: 1). + + Returns: + Iterator[ndarray]: Iterator of video frames. + """ + from .video import video_frames_np + + yield from video_frames_np(self, start, end, step) + + def get_frames( + self, + start: int = 0, + end: Optional[int] = None, + step: int = 1, + format: str = "jpg", + ) -> "Iterator[bytes]": + """ + Reads video frames from a file and returns as bytes. + + Args: + start (int): Frame number to start reading from (default: 0). + end (int): Frame number to stop reading at (default: None). + step (int): Step size for reading frames (default: 1). + format (str): Image format (default: 'jpg'). + + Returns: + Iterator[bytes]: Iterator of video frames. + """ + from .video import video_frames + + yield from video_frames(self, start, end, step, format) + + def save_frames( + self, + output_dir: str, + start: int = 0, + end: Optional[int] = None, + step: int = 1, + format: str = "jpg", + ) -> "Iterator[VideoFrame]": + """ + Saves video frames as image files. + + Args: + output_dir (str): Output directory path. + start (int): Frame number to start reading from (default: 0). + end (int): Frame number to stop reading at (default: None). + step (int): Step size for reading frames (default: 1). + format (str): Image format (default: 'jpg'). + + Returns: + Iterator[VideoFrame]: List of video frame models. + """ + from .video import save_video_frames + + yield from save_video_frames(self, output_dir, start, end, step, format) + + def save_fragment( + self, + start: float, + end: float, + output_file: str, + ) -> "VideoFragment": + """ + Saves video interval as a new video file. + + Args: + start (float): Start time in seconds. + end (float): End time in seconds. + output_file (str): Output file path. + + Returns: + VideoFragment: Video fragment model. + """ + from .video import save_video_fragment + + return save_video_fragment(self, start, end, output_file) + + def save_fragments( + self, + intervals: list[tuple[float, float]], + output_dir: str, + ) -> "Iterator[VideoFragment]": + """ + Saves video intervals as new video files. + + Args: + intervals (list[tuple[float, float]]): List of start and end times + in seconds. + output_dir (str): Output directory path. + + Returns: + Iterator[VideoFragment]: List of video fragment models. + """ + from .video import save_video_fragments + + yield from save_video_fragments(self, intervals, output_dir) + + +class VideoFragment(VideoFile): + """`DataModel` for reading video fragments.""" + + start: float = Field(default=-1.0) + end: float = Field(default=-1.0) + + +class VideoFrame(ImageFile): + """`DataModel` for reading video frames.""" + + frame: int = Field(default=-1) + timestamp: float = Field(default=-1.0) + + +class Video(DataModel): + """`DataModel` for video file meta information.""" + + width: int = Field(default=-1) + height: int = Field(default=-1) + fps: float = Field(default=-1.0) + duration: float = Field(default=-1.0) + frames: int = Field(default=-1) + format: str = Field(default="") + codec: str = Field(default="") + + class ArrowRow(DataModel): """`DataModel` for reading row from Arrow-supported file.""" @@ -528,5 +735,7 @@ def get_file_type(type_: FileType = "binary") -> type[File]: file = TextFile elif type_ == "image": file = ImageFile # type: ignore[assignment] + elif type_ == "video": + file = VideoFile return file diff --git a/src/datachain/lib/hf.py b/src/datachain/lib/hf.py index 66f4ee4fb..2e31c7f84 100644 --- a/src/datachain/lib/hf.py +++ b/src/datachain/lib/hf.py @@ -20,7 +20,7 @@ except ImportError as exc: raise ImportError( - "Missing dependencies for huggingface datasets:\n" + "Missing dependencies for huggingface datasets.\n" "To install run:\n\n" " pip install 'datachain[hf]'\n" ) from exc diff --git a/src/datachain/lib/vfile.py b/src/datachain/lib/vfile.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/datachain/lib/video.py b/src/datachain/lib/video.py new file mode 100644 index 000000000..0be3755dd --- /dev/null +++ b/src/datachain/lib/video.py @@ -0,0 +1,349 @@ +import posixpath +import shutil +import tempfile +from collections.abc import Iterator +from pathlib import PurePosixPath +from typing import Optional + +from numpy import ndarray + +from datachain.lib.file import ( + File, + FileError, + Video, + VideoFile, + VideoFragment, + VideoFrame, +) + +try: + import ffmpeg + import imageio.v3 as iio +except ImportError as exc: + raise ImportError( + "Missing dependencies for processing video.\n" + "To install run:\n\n" + " pip install 'datachain[video]'\n" + ) from exc + + +def _video_probe(file: VideoFile) -> tuple[dict, dict, float]: + """Probes video file for video stream, video format and fps.""" + try: + probe = ffmpeg.probe(file.get_local_path()) + except ffmpeg.Error as exc: + raise FileError(file, f"unable to probe video file: {exc.stderr}") from exc + except Exception as exc: + raise FileError(file, f"unable to probe video file: {exc}") from exc + + if not probe: + raise FileError(file, "unable to probe video file") + + all_streams = probe.get("streams") + video_format = probe.get("format") + if not all_streams or not video_format: + raise FileError(file, "unable to probe video file") + + video_streams = [s for s in all_streams if s["codec_type"] == "video"] + if len(video_streams) == 0: + raise FileError(file, "no video streams found in video file") + + video_stream = video_streams[0] + + r_frame_rate = video_stream.get("r_frame_rate", "0") + if "/" in r_frame_rate: + num, denom = r_frame_rate.split("/") + fps = float(num) / float(denom) + else: + fps = float(r_frame_rate) + + return video_stream, video_format, fps + + +def video_info(file: VideoFile) -> Video: + """ + Returns video file information. + + Args: + file (VideoFile): Video file object. + + Returns: + Video: Video file information. + """ + video_stream, video_format, fps = _video_probe(file) + + width = int(video_stream.get("width", 0)) + height = int(video_stream.get("height", 0)) + duration = float(video_format.get("duration", 0)) + start_time = float(video_format.get("start_time", 0)) + frames = round((duration - start_time) * fps) + format_name = video_format.get("format_name", "") + codec_name = video_stream.get("codec_name", "") + + return Video( + width=width, + height=height, + fps=fps, + duration=duration, + frames=frames, + format=format_name, + codec=codec_name, + ) + + +def video_frame_np(file: VideoFile, frame: int) -> "ndarray": + """ + Reads video frame from a file. + + Args: + file (VideoFile): Video file object. + frame (int): Frame number to read. + + Returns: + ndarray: Video frame. + """ + if frame < 0: + raise ValueError("frame must be a non-negative integer.") + + with file.open() as f: + return iio.imread(f, index=frame, plugin="pyav") # type: ignore[arg-type] + + +def video_frame(file: VideoFile, frame: int, format: str = "jpg") -> bytes: + """ + Reads video frame from a file and returns as image bytes. + + Args: + file (VideoFile): Video file object. + frame (int): Frame number to read. + format (str): Image format (default: 'jpg'). + + Returns: + bytes: Video frame image as bytes. + """ + img = video_frame_np(file, frame) + return iio.imwrite("", img, extension=f".{format}") + + +def save_video_frame( + file: VideoFile, + frame: int, + output_file: str, + format: Optional[str] = None, +) -> VideoFrame: + """ + Saves video frame as an image file. + + Args: + file (VideoFile): Video file object. + frame (int): Frame number to read. + output_file (str): Output file path. + format (str): Image format (default: use output file extension). + + Returns: + VideoFrame: Video frame model. + """ + _, _, fps = _video_probe(file) + + if format is None: + format = PurePosixPath(output_file).suffix.strip(".") + + img = video_frame(file, frame, format=format) + uploaded_file = File.upload(img, output_file) + + frame_file = VideoFrame( + **uploaded_file.model_dump(), + frame=frame, + timestamp=float(frame) / fps, + ) + frame_file._set_stream(uploaded_file._catalog) + return frame_file + + +def video_frames_np( + file: VideoFile, + start: int = 0, + end: Optional[int] = None, + step: int = 1, +) -> Iterator[ndarray]: + """ + Reads video frames from a file. + + Args: + file (VideoFile): Video file object. + start (int): Frame number to start reading from (default: 0). + end (int): Frame number to stop reading at (default: None). + step (int): Step size for reading frames (default: 1). + + Returns: + Iterator[ndarray]: Iterator of video frames. + """ + if start < 0: + raise ValueError("start_frame must be a non-negative integer.") + if end is not None: + if end < 0: + raise ValueError("end_frame must be a non-negative integer.") + if start > end: + raise ValueError("start_frame must be less than or equal to end_frame.") + if step < 1: + raise ValueError("step must be a positive integer.") + + # Compute the frame shift to determine the number of frames to skip, + # considering the start frame and step size + frame_shift = start % step + + # Iterate over video frames and yield only those within the specified range and step + with file.open() as f: + for frame, img in enumerate(iio.imiter(f.read(), plugin="pyav")): # type: ignore[arg-type] + if frame < start: + continue + if (frame - frame_shift) % step != 0: + continue + if end is not None and frame > end: + break + yield img + + +def video_frames( + file: VideoFile, + start: int = 0, + end: Optional[int] = None, + step: int = 1, + format: str = "jpg", +) -> Iterator[bytes]: + """ + Reads video frames from a file and returns as bytes. + + Args: + file (VideoFile): Video file object. + start (int): Frame number to start reading from (default: 0). + end (int): Frame number to stop reading at (default: None). + step (int): Step size for reading frames (default: 1). + format (str): Image format (default: 'jpg'). + + Returns: + Iterator[bytes]: Iterator of video frames. + """ + for img in video_frames_np(file, start, end, step): + yield iio.imwrite("", img, extension=f".{format}") + + +def save_video_frames( + file: VideoFile, + output_dir: str, + start: int = 0, + end: Optional[int] = None, + step: int = 1, + format: str = "jpg", +) -> Iterator[VideoFrame]: + """ + Saves video frames as image files. + + Args: + file (VideoFile): Video file object. + output_dir (str): Output directory path. + start (int): Frame number to start reading from (default: 0). + end (int): Frame number to stop reading at (default: None). + step (int): Step size for reading frames (default: 1). + format (str): Image format (default: 'jpg'). + + Returns: + Iterator[VideoFrame]: List of video frame models. + """ + _, _, fps = _video_probe(file) + file_stem = file.get_file_stem() + + for i, img in enumerate(video_frames_np(file, start, end, step)): + frame = start + i * step + output_file = posixpath.join(output_dir, f"{file_stem}_{frame:06d}.{format}") + + raw = iio.imwrite("", img, extension=f".{format}") + uploaded_file = File.upload(raw, output_file) + + frame_file = VideoFrame( + **uploaded_file.model_dump(), + frame=frame, + timestamp=float(frame) / fps, + ) + frame_file._set_stream(uploaded_file._catalog) + yield frame_file + + +def save_video_fragment( + file: VideoFile, + start: float, + end: float, + output_file: str, +) -> VideoFragment: + """ + Saves video interval as a new video file. + + Args: + file (VideoFile): Video file object. + start (float): Start time in seconds. + end (float): End time in seconds. + output_file (str): Output file path. + + Returns: + VideoFragment: Video fragment model. + """ + if start < 0 or start >= end: + raise ValueError(f"Invalid time range: ({start}, {end}).") + + temp_dir = tempfile.mkdtemp() + try: + output_file_tmp = posixpath.join(temp_dir, posixpath.basename(output_file)) + ( + ffmpeg.input(file.get_local_path(), ss=start, to=end) + .output(output_file_tmp) + .run(quiet=True) + ) + + with open(output_file_tmp, "rb") as f: + uploaded_file = File.upload(f.read(), output_file) + finally: + shutil.rmtree(temp_dir) + + fragment = VideoFragment( + **uploaded_file.model_dump(), + start=start, + end=end, + ) + fragment._set_stream(uploaded_file._catalog) + return fragment + + +def save_video_fragments( + file: VideoFile, + intervals: list[tuple[float, float]], + output_dir: str, +) -> Iterator[VideoFragment]: + """ + Saves video intervals as new video files. + + Args: + file (VideoFile): Video file object. + intervals (list[tuple[float, float]]): List of start and end times in seconds. + output_dir (str): Output directory path. + + Returns: + Iterator[VideoFragment]: List of video fragment models. + """ + file_stem = file.get_file_stem() + file_ext = file.get_file_ext() + + for start, end in intervals: + if start < 0 or start >= end: + print(f"Invalid time range: ({start}, {end}). Skipping this segment.") + continue + + # Output file name + start_ms = int(start * 1000) + end_ms = int(end * 1000) + output_file = posixpath.join( + output_dir, + f"{file_stem}_{start_ms:03d}_{end_ms:03d}.{file_ext}", + ) + + # Write the video fragment to file and yield it + yield save_video_fragment(file, start, end, output_file) diff --git a/tests/unit/lib/data/Big_Buck_Bunny_360_10s_1MB.mp4 b/tests/unit/lib/data/Big_Buck_Bunny_360_10s_1MB.mp4 new file mode 100644 index 000000000..9b6d89da0 Binary files /dev/null and b/tests/unit/lib/data/Big_Buck_Bunny_360_10s_1MB.mp4 differ diff --git a/tests/unit/lib/test_video.py b/tests/unit/lib/test_video.py new file mode 100644 index 000000000..6d371cf7f --- /dev/null +++ b/tests/unit/lib/test_video.py @@ -0,0 +1,181 @@ +import io +import os +import posixpath + +import pytest +from numpy import ndarray +from PIL import Image + +from datachain.lib.file import FileError, VideoFile + + +@pytest.fixture(autouse=True) +def video_file(catalog) -> VideoFile: + data_path = os.path.join(os.path.dirname(__file__), "data") + file_name = "Big_Buck_Bunny_360_10s_1MB.mp4" + + with open(os.path.join(data_path, file_name), "rb") as f: + file = VideoFile.upload(f.read(), file_name) + + file.ensure_cached() + return file + + +def test_get_info(video_file): + info = video_file.get_info() + assert info.model_dump() == { + "width": 640, + "height": 360, + "fps": 30.0, + "duration": 10.0, + "frames": 300, + "format": "mov,mp4,m4a,3gp,3g2,mj2", + "codec": "h264", + } + + +def test_get_info_error(): + # upload current Python file as video file to get an error while getting video meta + with open(__file__, "rb") as f: + file = VideoFile.upload(f.read(), "test.mp4") + + file.ensure_cached() + with pytest.raises(FileError): + file.get_info() + + +def test_get_frame_np(video_file): + frame = video_file.get_frame_np(0) + assert isinstance(frame, ndarray) + assert frame.shape == (360, 640, 3) + + +def test_get_frame_np_error(video_file): + with pytest.raises(ValueError): + video_file.get_frame_np(-1) + + +@pytest.mark.parametrize( + "format,img_format,header", + [ + ("jpg", "JPEG", [b"\xff\xd8\xff\xe0"]), + ("png", "PNG", [b"\x89PNG\r\n\x1a\n"]), + ("gif", "GIF", [b"GIF87a", b"GIF89a"]), + ], +) +def test_get_frame(video_file, format, img_format, header): + frame = video_file.get_frame(0, format=format) + assert isinstance(frame, bytes) + assert any(frame.startswith(h) for h in header) + + img = Image.open(io.BytesIO(frame)) + assert img.format == img_format + assert img.size == (640, 360) + + +@pytest.mark.parametrize("use_format", [True, False]) +def test_save_frame_ext(tmp_path, video_file, use_format): + filename = "frame" if use_format else "frame.jpg" + format = "jpg" if use_format else None + output_file = posixpath.join(tmp_path, filename) + + frame_file = video_file.save_frame(3, str(output_file), format=format) + assert frame_file.frame == 3 + assert frame_file.timestamp == 3 / 30 + + frame_file.ensure_cached() + img = Image.open(frame_file.get_local_path()) + assert img.format == "JPEG" + assert img.size == (640, 360) + + +def test_get_frames_np(video_file): + frames = list(video_file.get_frames_np(10, 200, 5)) + assert len(frames) == 39 + assert all(isinstance(frame, ndarray) for frame in frames) + assert all(frame.shape == (360, 640, 3) for frame in frames) + + +@pytest.mark.parametrize( + "start_frame,end_frame,step", + [ + (-1, None, None), + (0, -1, None), + (1, 0, None), + (0, 1, -1), + ], +) +def test_get_frames_np_error(video_file, start_frame, end_frame, step): + with pytest.raises(ValueError): + list(video_file.get_frames_np(start_frame, end_frame, step)) + + +def test_get_frames(video_file): + frames = list(video_file.get_frames(10, 200, 5, format="jpg")) + assert len(frames) == 39 + assert all(isinstance(frame, bytes) for frame in frames) + assert all(Image.open(io.BytesIO(frame)).format == "JPEG" for frame in frames) + + +def test_save_frames(tmp_path, video_file): + frame_files = list(video_file.save_frames(str(tmp_path), 10, 200, 5, format="jpg")) + assert len(frame_files) == 39 + + for i, frame_file in enumerate(frame_files): + assert frame_file.frame == 10 + 5 * i + assert frame_file.timestamp == (10 + 5 * i) / 30 + + frame_file.ensure_cached() + img = Image.open(frame_file.get_local_path()) + assert img.format == "JPEG" + assert img.size == (640, 360) + + +def test_save_fragment(tmp_path, video_file): + output_file = posixpath.join(tmp_path, "fragment.mp4") + fragment = video_file.save_fragment(2.5, 5, str(output_file)) + assert fragment.start == 2.5 + assert fragment.end == 5 + + fragment.ensure_cached() + assert fragment.get_info().model_dump() == { + "width": 640, + "height": 360, + "fps": 30.0, + "duration": 2.5, + "frames": 75, + "format": "mov,mp4,m4a,3gp,3g2,mj2", + "codec": "h264", + } + + +def test_save_fragment_error(video_file): + with pytest.raises(ValueError): + video_file.save_fragment(5, 2.5, "fragment.mp4") + + +def test_save_fragments(tmp_path, video_file): + intervals = [(1, 2), (3, 4), (5, 6)] + + fragments = list(video_file.save_fragments(intervals, str(tmp_path))) + assert len(fragments) == 3 + + for i, fragment in enumerate(fragments): + assert fragment.start == 1 + 2 * i + assert fragment.end == 2 + 2 * i + + fragment.ensure_cached() + assert fragment.get_info().model_dump() == { + "width": 640, + "height": 360, + "fps": 30.0, + "duration": 1, + "frames": 30, + "format": "mov,mp4,m4a,3gp,3g2,mj2", + "codec": "h264", + } + + +def test_save_fragments_error(video_file): + fragments = list(video_file.save_fragments([(2, 1)], "fragments")) + assert len(fragments) == 0