diff --git a/src/datachain/lib/models/__init__.py b/src/datachain/lib/models/__init__.py index b4335237b..398bd1a8a 100644 --- a/src/datachain/lib/models/__init__.py +++ b/src/datachain/lib/models/__init__.py @@ -1,5 +1,6 @@ -from . import yolo -from .bbox import BBox +from . import ultralytics +from .bbox import BBox, OBBox from .pose import Pose, Pose3D +from .segment import Segments -__all__ = ["BBox", "Pose", "Pose3D", "yolo"] +__all__ = ["BBox", "OBBox", "Pose", "Pose3D", "Segments", "ultralytics"] diff --git a/src/datachain/lib/models/bbox.py b/src/datachain/lib/models/bbox.py index 501ed6296..cefc5181d 100644 --- a/src/datachain/lib/models/bbox.py +++ b/src/datachain/lib/models/bbox.py @@ -1,5 +1,3 @@ -from typing import Optional - from pydantic import Field from datachain.lib.data_model import DataModel @@ -11,10 +9,7 @@ class BBox(DataModel): Attributes: title (str): The title of the bounding box. - x1 (float): The x-coordinate of the top-left corner of the bounding box. - y1 (float): The y-coordinate of the top-left corner of the bounding box. - x2 (float): The x-coordinate of the bottom-right corner of the bounding box. - y2 (float): The y-coordinate of the bottom-right corner of the bounding box. + coords (list[int]): The coordinates of the bounding box. The bounding box is defined by two points: - (x1, y1): The top-left corner of the box. @@ -22,24 +17,100 @@ class BBox(DataModel): """ title: str = Field(default="") - x1: float = Field(default=0) - y1: float = Field(default=0) - x2: float = Field(default=0) - y2: float = Field(default=0) + coords: list[int] = Field(default=None) + + @staticmethod + def from_list(coords: list[float], title: str = "") -> "BBox": + assert len(coords) == 4, "Bounding box coordinates must be a list of 4 floats." + assert all( + isinstance(value, (int, float)) for value in coords + ), "Bounding box coordinates must be integers or floats." + return BBox( + title=title, + coords=[round(c) for c in coords], + ) + + @staticmethod + def from_dict(coords: dict[str, float], title: str = "") -> "BBox": + assert ( + len(coords) == 4 + ), "Bounding box coordinates must be a dictionary of 4 floats." + assert set(coords) == { + "x1", + "y1", + "x2", + "y2", + }, "Bounding box coordinates must contain keys with coordinates." + assert all( + isinstance(value, (int, float)) for value in coords.values() + ), "Bounding box coordinates must be integers or floats." + return BBox( + title=title, + coords=[ + round(coords["x1"]), + round(coords["y1"]), + round(coords["x2"]), + round(coords["y2"]), + ], + ) + + +class OBBox(DataModel): + """ + A data model for representing oriented bounding boxes. + + Attributes: + title (str): The title of the oriented bounding box. + coords (list[int]): The coordinates of the oriented bounding box. + + The oriented bounding box is defined by four points: + - (x1, y1): The first corner of the box. + - (x2, y2): The second corner of the box. + - (x3, y3): The third corner of the box. + - (x4, y4): The fourth corner of the box. + """ + + title: str = Field(default="") + coords: list[int] = Field(default=None) + + @staticmethod + def from_list(coords: list[float], title: str = "") -> "OBBox": + assert ( + len(coords) == 8 + ), "Oriented bounding box coordinates must be a list of 8 floats." + assert all( + isinstance(value, (int, float)) for value in coords + ), "Oriented bounding box coordinates must be integers or floats." + return OBBox( + title=title, + coords=[round(c) for c in coords], + ) @staticmethod - def from_xywh(bbox: list[float], title: Optional[str] = None) -> "BBox": - """ - Converts a bounding box in (x, y, width, height) format - to a BBox data model instance. - - Args: - bbox (list[float]): A bounding box, represented as a list - of four floats [x, y, width, height]. - - Returns: - BBox2D: An instance of the BBox data model. - """ - assert len(bbox) == 4, f"Bounding box must have 4 elements, got f{len(bbox)}" - x, y, w, h = bbox - return BBox(title=title or "", x1=x, y1=y, x2=x + w, y2=y + h) + def from_dict(coords: dict[str, float], title: str = "") -> "OBBox": + assert set(coords) == { + "x1", + "y1", + "x2", + "y2", + "x3", + "y3", + "x4", + "y4", + }, "Oriented bounding box coordinates must contain keys with coordinates." + assert all( + isinstance(value, (int, float)) for value in coords.values() + ), "Oriented bounding box coordinates must be integers or floats." + return OBBox( + title=title, + coords=[ + round(coords["x1"]), + round(coords["y1"]), + round(coords["x2"]), + round(coords["y2"]), + round(coords["x3"]), + round(coords["y3"]), + round(coords["x4"]), + round(coords["y4"]), + ], + ) diff --git a/src/datachain/lib/models/pose.py b/src/datachain/lib/models/pose.py index 5cb95a29b..c4926b463 100644 --- a/src/datachain/lib/models/pose.py +++ b/src/datachain/lib/models/pose.py @@ -8,15 +8,48 @@ class Pose(DataModel): A data model for representing pose keypoints. Attributes: - x (list[float]): The x-coordinates of the keypoints. - y (list[float]): The y-coordinates of the keypoints. + x (list[int]): The x-coordinates of the keypoints. + y (list[int]): The y-coordinates of the keypoints. The keypoints are represented as lists of x and y coordinates, where each index corresponds to a specific body part. """ - x: list[float] = Field(default=None) - y: list[float] = Field(default=None) + x: list[int] = Field(default=None) + y: list[int] = Field(default=None) + + @staticmethod + def from_list(points: list[list[float]]) -> "Pose": + assert len(points) == 2, "Pose coordinates must be a list of 2 lists." + points_x, points_y = points + assert ( + len(points_x) == len(points_y) == 17 + ), "Pose x and y coordinates must have the same length of 17." + assert all( + isinstance(value, (int, float)) for value in [*points_x, *points_y] + ), "Pose coordinates must be integers or floats." + return Pose( + x=[round(coord) for coord in points_x], + y=[round(coord) for coord in points_y], + ) + + @staticmethod + def from_dict(points: dict[str, list[float]]) -> "Pose": + assert set(points) == { + "x", + "y", + }, "Pose coordinates must contain keys 'x' and 'y'." + points_x, points_y = points["x"], points["y"] + assert ( + len(points_x) == len(points_y) == 17 + ), "Pose x and y coordinates must have the same length of 17." + assert all( + isinstance(value, (int, float)) for value in [*points_x, *points_y] + ), "Pose coordinates must be integers or floats." + return Pose( + x=[round(coord) for coord in points_x], + y=[round(coord) for coord in points_y], + ) class Pose3D(DataModel): @@ -24,14 +57,52 @@ class Pose3D(DataModel): A data model for representing 3D pose keypoints. Attributes: - x (list[float]): The x-coordinates of the keypoints. - y (list[float]): The y-coordinates of the keypoints. + x (list[int]): The x-coordinates of the keypoints. + y (list[int]): The y-coordinates of the keypoints. visible (list[float]): The visibility of the keypoints. The keypoints are represented as lists of x, y, and visibility values, where each index corresponds to a specific body part. """ - x: list[float] = Field(default=None) - y: list[float] = Field(default=None) + x: list[int] = Field(default=None) + y: list[int] = Field(default=None) visible: list[float] = Field(default=None) + + @staticmethod + def from_list(points: list[list[float]]) -> "Pose3D": + assert len(points) == 3, "Pose coordinates must be a list of 3 lists." + points_x, points_y, points_v = points + assert ( + len(points_x) == len(points_y) == len(points_v) == 17 + ), "Pose x, y, and visibility coordinates must have the same length of 17." + assert all( + isinstance(value, (int, float)) + for value in [*points_x, *points_y, *points_v] + ), "Pose coordinates must be integers or floats." + return Pose3D( + x=[round(coord) for coord in points_x], + y=[round(coord) for coord in points_y], + visible=points_v, + ) + + @staticmethod + def from_dict(points: dict[str, list[float]]) -> "Pose3D": + assert set(points) == { + "x", + "y", + "visible", + }, "Pose coordinates must contain keys 'x', 'y', and 'visible'." + points_x, points_y, points_v = points["x"], points["y"], points["visible"] + assert ( + len(points_x) == len(points_y) == len(points_v) == 17 + ), "Pose x, y, and visibility coordinates must have the same length of 17." + assert all( + isinstance(value, (int, float)) + for value in [*points_x, *points_y, *points_v] + ), "Pose coordinates must be integers or floats." + return Pose3D( + x=[round(coord) for coord in points_x], + y=[round(coord) for coord in points_y], + visible=points_v, + ) diff --git a/src/datachain/lib/models/segment.py b/src/datachain/lib/models/segment.py new file mode 100644 index 000000000..04e4dec05 --- /dev/null +++ b/src/datachain/lib/models/segment.py @@ -0,0 +1,53 @@ +from pydantic import Field + +from datachain.lib.data_model import DataModel + + +class Segments(DataModel): + """ + A data model for representing segments. + + Attributes: + title (str): The title of the segments. + x (list[int]): The x-coordinates of the segments. + y (list[int]): The y-coordinates of the segments. + + The segments are represented as lists of x and y coordinates, where each index + corresponds to a specific segment. + """ + + title: str = Field(default="") + x: list[int] = Field(default=None) + y: list[int] = Field(default=None) + + @staticmethod + def from_list(points: list[list[float]], title: str = "") -> "Segments": + assert len(points) == 2, "Segments coordinates must be a list of 2 lists." + points_x, points_y = points + assert len(points_x) == len( + points_y + ), "Segments x and y coordinates must have the same length." + assert all( + isinstance(value, (int, float)) for value in [*points_x, *points_y] + ), "Segments coordinates must be integers or floats." + return Segments( + title=title, + x=[round(coord) for coord in points_x], + y=[round(coord) for coord in points_y], + ) + + @staticmethod + def from_dict(points: dict[str, list[float]], title: str = "") -> "Segments": + assert set(points) == { + "x", + "y", + }, "Segments coordinates must contain keys 'x' and 'y'." + points_x, points_y = points["x"], points["y"] + assert all( + isinstance(value, (int, float)) for value in [*points_x, *points_y] + ), "Segments coordinates must be integers or floats." + return Segments( + title=title, + x=[round(coord) for coord in points_x], + y=[round(coord) for coord in points_y], + ) diff --git a/src/datachain/lib/models/ultralytics/__init__.py b/src/datachain/lib/models/ultralytics/__init__.py new file mode 100644 index 000000000..504dccc91 --- /dev/null +++ b/src/datachain/lib/models/ultralytics/__init__.py @@ -0,0 +1,14 @@ +from .bbox import YoloBBox, YoloBBoxes, YoloOBBox, YoloOBBoxes +from .pose import YoloPose, YoloPoses +from .segment import YoloSegment, YoloSegments + +__all__ = [ + "YoloBBox", + "YoloBBoxes", + "YoloOBBox", + "YoloOBBoxes", + "YoloPose", + "YoloPoses", + "YoloSegment", + "YoloSegments", +] diff --git a/src/datachain/lib/models/ultralytics/bbox.py b/src/datachain/lib/models/ultralytics/bbox.py new file mode 100644 index 000000000..b5fa939ad --- /dev/null +++ b/src/datachain/lib/models/ultralytics/bbox.py @@ -0,0 +1,189 @@ +""" +This module contains the YOLO models. + +YOLO stands for "You Only Look Once", a family of object detection models that +are designed to be fast and accurate. The models are trained to detect objects +in images by dividing the image into a grid and predicting the bounding boxes +and class probabilities for each grid cell. + +More information about YOLO can be found here: +- https://pjreddie.com/darknet/yolo/ +- https://docs.ultralytics.com/ +""" + +from io import BytesIO +from typing import TYPE_CHECKING + +from PIL import Image +from pydantic import Field + +from datachain.lib.data_model import DataModel +from datachain.lib.models.bbox import BBox, OBBox + +if TYPE_CHECKING: + from ultralytics.engine.results import Results + from ultralytics.models import YOLO + + from datachain.lib.file import File + + +class YoloBBox(DataModel): + """ + A class representing a bounding box detected by a YOLO model. + + Attributes: + cls: The class of the detected object. + name: The name of the detected object. + confidence: The confidence score of the detection. + box: The bounding box of the detected object + """ + + cls: int = Field(default=-1) + name: str = Field(default="") + confidence: float = Field(default=0) + box: BBox = Field(default=None) + + @staticmethod + def from_file(yolo: "YOLO", file: "File") -> "YoloBBox": + results = yolo(Image.open(BytesIO(file.read()))) + if len(results) == 0: + return YoloBBox() + return YoloBBox.from_result(results[0]) + + @staticmethod + def from_result(result: "Results") -> "YoloBBox": + summary = result.summary() + if not summary: + return YoloBBox() + name = summary[0].get("name", "") + box = ( + BBox.from_dict(summary[0]["box"], title=name) + if "box" in summary[0] + else BBox() + ) + return YoloBBox( + cls=summary[0]["class"], + name=name, + confidence=summary[0]["confidence"], + box=box, + ) + + +class YoloBBoxes(DataModel): + """ + A class representing a list of bounding boxes detected by a YOLO model. + + Attributes: + cls: A list of classes of the detected objects. + name: A list of names of the detected objects. + confidence: A list of confidence scores of the detections. + box: A list of bounding boxes of the detected objects + """ + + cls: list[int] + name: list[str] + confidence: list[float] + box: list[BBox] + + @staticmethod + def from_file(yolo: "YOLO", file: "File") -> "YoloBBoxes": + results = yolo(Image.open(BytesIO(file.read()))) + return YoloBBoxes.from_results(results) + + @staticmethod + def from_results(results: list["Results"]) -> "YoloBBoxes": + cls, names, confidence, box = [], [], [], [] + for r in results: + for s in r.summary(): + name = s.get("name", "") + cls.append(s["class"]) + names.append(name) + confidence.append(s["confidence"]) + box.append(BBox.from_dict(s.get("box", {}), title=name)) + return YoloBBoxes( + cls=cls, + name=names, + confidence=confidence, + box=box, + ) + + +class YoloOBBox(DataModel): + """ + A class representing an oriented bounding box detected by a YOLO model. + + Attributes: + cls: The class of the detected object. + name: The name of the detected object. + confidence: The confidence score of the detection. + box: The oriented bounding box of the detected object. + """ + + cls: int = Field(default=-1) + name: str = Field(default="") + confidence: float = Field(default=0) + box: OBBox = Field(default=None) + + @staticmethod + def from_file(yolo: "YOLO", file: "File") -> "YoloOBBox": + results = yolo(Image.open(BytesIO(file.read()))) + if len(results) == 0: + return YoloOBBox() + return YoloOBBox.from_result(results[0]) + + @staticmethod + def from_result(result: "Results") -> "YoloOBBox": + summary = result.summary() + if not summary: + return YoloOBBox() + name = summary[0].get("name", "") + box = ( + OBBox.from_dict(summary[0]["box"], title=name) + if "box" in summary[0] + else OBBox() + ) + return YoloOBBox( + cls=summary[0]["class"], + name=name, + confidence=summary[0]["confidence"], + box=box, + ) + + +class YoloOBBoxes(DataModel): + """ + A class representing a list of oriented bounding boxes detected by a YOLO model. + + Attributes: + cls: A list of classes of the detected objects. + name: A list of names of the detected objects. + confidence: A list of confidence scores of the detections. + box: A list of oriented bounding boxes of the detected objects. + """ + + cls: list[int] + name: list[str] + confidence: list[float] + box: list[OBBox] + + @staticmethod + def from_file(yolo: "YOLO", file: "File") -> "YoloOBBoxes": + results = yolo(Image.open(BytesIO(file.read()))) + return YoloOBBoxes.from_results(results) + + @staticmethod + def from_results(results: list["Results"]) -> "YoloOBBoxes": + cls, names, confidence, box = [], [], [], [] + for r in results: + for s in r.summary(): + name = s.get("name", "") + cls.append(s["class"]) + names.append(name) + confidence.append(s["confidence"]) + box.append(OBBox.from_dict(s.get("box", {}), title=name)) + return YoloOBBoxes( + cls=cls, + name=names, + confidence=confidence, + box=box, + ) diff --git a/src/datachain/lib/models/ultralytics/pose.py b/src/datachain/lib/models/ultralytics/pose.py new file mode 100644 index 000000000..e97d48c2f --- /dev/null +++ b/src/datachain/lib/models/ultralytics/pose.py @@ -0,0 +1,126 @@ +""" +This module contains the YOLO models. + +YOLO stands for "You Only Look Once", a family of object detection models that +are designed to be fast and accurate. The models are trained to detect objects +in images by dividing the image into a grid and predicting the bounding boxes +and class probabilities for each grid cell. + +More information about YOLO can be found here: +- https://pjreddie.com/darknet/yolo/ +- https://docs.ultralytics.com/ +""" + +from typing import TYPE_CHECKING + +from pydantic import Field + +from datachain.lib.data_model import DataModel +from datachain.lib.models.bbox import BBox +from datachain.lib.models.pose import Pose3D + +if TYPE_CHECKING: + from ultralytics.engine.results import Results + + +class YoloPoseBodyPart: + """An enumeration of body parts for YOLO pose keypoints.""" + + nose = 0 + left_eye = 1 + right_eye = 2 + left_ear = 3 + right_ear = 4 + left_shoulder = 5 + right_shoulder = 6 + left_elbow = 7 + right_elbow = 8 + left_wrist = 9 + right_wrist = 10 + left_hip = 11 + right_hip = 12 + left_knee = 13 + right_knee = 14 + left_ankle = 15 + right_ankle = 16 + + +class YoloPose(DataModel): + """ + A data model for YOLO pose keypoints. + + Attributes: + cls: The class of the pose. + name: The name of the pose. + confidence: The confidence score of the pose. + box: The bounding box of the pose. + keypoints: The 3D pose keypoints. + """ + + cls: int = Field(default=-1) + name: str = Field(default="") + confidence: float = Field(default=0) + box: BBox = Field(default=None) + keypoints: Pose3D = Field(default=None) + + @staticmethod + def from_result(result: "Results") -> "YoloPose": + summary = result.summary() + if not summary: + return YoloPose() + name = summary[0].get("name", "") + box = ( + BBox.from_dict(summary[0]["box"], title=name) + if "box" in summary[0] + else BBox() + ) + keypoints = ( + Pose3D.from_dict(summary[0]["keypoints"]) + if "keypoints" in summary[0] + else Pose3D() + ) + return YoloPose( + cls=summary[0]["class"], + name=name, + confidence=summary[0]["confidence"], + box=box, + keypoints=keypoints, + ) + + +class YoloPoses(DataModel): + """ + A data model for a list of YOLO pose keypoints. + + Attributes: + cls: The classes of the poses. + name: The names of the poses. + confidence: The confidence scores of the poses. + box: The bounding boxes of the poses. + keypoints: The 3D pose keypoints of the poses. + """ + + cls: list[int] + name: list[str] + confidence: list[float] + box: list[BBox] + keypoints: list[Pose3D] + + @staticmethod + def from_results(results: list["Results"]) -> "YoloPoses": + cls, names, confidence, box, keypoints = [], [], [], [], [] + for r in results: + for s in r.summary(): + name = s.get("name", "") + cls.append(s["class"]) + names.append(name) + confidence.append(s["confidence"]) + box.append(BBox.from_dict(s.get("box", {}), title=name)) + keypoints.append(Pose3D.from_dict(s.get("keypoints", {}))) + return YoloPoses( + cls=cls, + name=names, + confidence=confidence, + box=box, + keypoints=keypoints, + ) diff --git a/src/datachain/lib/models/ultralytics/segment.py b/src/datachain/lib/models/ultralytics/segment.py new file mode 100644 index 000000000..0614e7ed8 --- /dev/null +++ b/src/datachain/lib/models/ultralytics/segment.py @@ -0,0 +1,121 @@ +""" +This module contains the YOLO models. + +YOLO stands for "You Only Look Once", a family of object detection models that +are designed to be fast and accurate. The models are trained to detect objects +in images by dividing the image into a grid and predicting the bounding boxes +and class probabilities for each grid cell. + +More information about YOLO can be found here: +- https://pjreddie.com/darknet/yolo/ +- https://docs.ultralytics.com/ +""" + +from io import BytesIO +from typing import TYPE_CHECKING + +from PIL import Image +from pydantic import Field + +from datachain.lib.data_model import DataModel +from datachain.lib.models.bbox import BBox +from datachain.lib.models.segment import Segments + +if TYPE_CHECKING: + from ultralytics.engine.results import Results + from ultralytics.models import YOLO + + from datachain.lib.file import File + + +class YoloSegment(DataModel): + """ + A data model for a single YOLO segment. + + Attributes: + cls (int): The class of the segment. + name (str): The name of the segment. + confidence (float): The confidence of the segment. + box (BBox): The bounding box of the segment. + segments (Segments): The segments of the segment. + """ + + cls: int = Field(default=-1) + name: str = Field(default="") + confidence: float = Field(default=0) + box: BBox = Field(default=None) + segments: Segments = Field(default=None) + + @staticmethod + def from_file(yolo: "YOLO", file: "File") -> "YoloSegment": + results = yolo(Image.open(BytesIO(file.read()))) + if len(results) == 0: + return YoloSegment() + return YoloSegment.from_result(results[0]) + + @staticmethod + def from_result(result: "Results") -> "YoloSegment": + summary = result.summary() + if not summary: + return YoloSegment() + name = summary[0].get("name", "") + box = ( + BBox.from_dict(summary[0]["box"], title=name) + if "box" in summary[0] + else BBox() + ) + segments = ( + Segments.from_dict(summary[0]["segments"], title=name) + if "segments" in summary[0] + else Segments() + ) + return YoloSegment( + cls=summary[0]["class"], + name=summary[0]["name"], + confidence=summary[0]["confidence"], + box=box, + segments=segments, + ) + + +class YoloSegments(DataModel): + """ + A data model for a list of YOLO segments. + + Attributes: + cls (list[int]): The classes of the segments. + name (list[str]): The names of the segments. + confidence (list[float]): The confidences of the segments. + box (list[BBox]): The bounding boxes of the segments. + segments (list[Segments]): The segments of the segments. + """ + + cls: list[int] + name: list[str] + confidence: list[float] + box: list[BBox] + segments: list[Segments] + + @staticmethod + def from_file(yolo: "YOLO", file: "File") -> "YoloSegments": + results = yolo(Image.open(BytesIO(file.read()))) + return YoloSegments.from_results(results) + + @staticmethod + def from_results(results: list["Results"]) -> "YoloSegments": + cls, names, confidence, box, segments = [], [], [], [], [] + for r in results: + for s in r.summary(): + name = s.get("name", "") + cls.append(s["class"]) + names.append(name) + confidence.append(s["confidence"]) + box.append(BBox.from_dict(s.get("box", {}), title=name)) + segments.append(Segments.from_dict(s.get("segments", {}), title=name)) + return YoloSegments( + cls=cls, + name=names, + confidence=confidence, + box=box, + segments=segments, + ) diff --git a/src/datachain/lib/models/yolo.py b/src/datachain/lib/models/yolo.py deleted file mode 100644 index 4231240a6..000000000 --- a/src/datachain/lib/models/yolo.py +++ /dev/null @@ -1,39 +0,0 @@ -""" -This module contains the YOLO models. - -YOLO stands for "You Only Look Once", a family of object detection models that -are designed to be fast and accurate. The models are trained to detect objects -in images by dividing the image into a grid and predicting the bounding boxes -and class probabilities for each grid cell. - -More information about YOLO can be found here: -- https://pjreddie.com/darknet/yolo/ -- https://docs.ultralytics.com/ -""" - - -class PoseBodyPart: - """ - An enumeration of body parts for YOLO pose keypoints. - - More information about the body parts can be found here: - https://docs.ultralytics.com/tasks/pose/ - """ - - nose = 0 - left_eye = 1 - right_eye = 2 - left_ear = 3 - right_ear = 4 - left_shoulder = 5 - right_shoulder = 6 - left_elbow = 7 - right_elbow = 8 - left_wrist = 9 - right_wrist = 10 - left_hip = 11 - right_hip = 12 - left_knee = 13 - right_knee = 14 - left_ankle = 15 - right_ankle = 16 diff --git a/tests/unit/lib/test_models.py b/tests/unit/lib/test_models.py index c3ea2b463..f090f31dd 100644 --- a/tests/unit/lib/test_models.py +++ b/tests/unit/lib/test_models.py @@ -1,50 +1,142 @@ +import pytest + from datachain.lib import models +from datachain.lib.models.ultralytics.pose import YoloPoseBodyPart -def test_bbox(): - bbox = models.BBox(title="BBox", x1=0.5, y1=1.5, x2=2.5, y2=3.5) +@pytest.mark.parametrize( + "bbox", + [ + models.BBox(title="BBox", coords=[0, 1, 2, 3]), + models.BBox.from_list([0.3, 1.1, 1.7, 3.4], title="BBox"), + models.BBox.from_dict({"x1": 0, "y1": 0.8, "x2": 2.2, "y2": 2.9}, title="BBox"), + ], +) +def test_bbox(bbox): assert bbox.model_dump() == { "title": "BBox", - "x1": 0.5, - "y1": 1.5, - "x2": 2.5, - "y2": 3.5, + "coords": [0, 1, 2, 3], } -def test_bbox_from_xywh(): - bbox = models.BBox.from_xywh([0.5, 1.5, 2.5, 3.5]) - assert bbox.model_dump() == {"title": "", "x1": 0.5, "y1": 1.5, "x2": 3, "y2": 5} +@pytest.mark.parametrize( + "obbox", + [ + models.OBBox(title="OBBox", coords=[0, 1, 2, 3, 4, 5, 6, 7]), + models.OBBox.from_list([0.3, 1.1, 1.7, 3.4, 4.0, 4.9, 5.6, 7.0], title="OBBox"), + models.OBBox.from_dict( + { + "x1": 0, + "y1": 0.8, + "x2": 2.2, + "y2": 2.9, + "x3": 3.9, + "y3": 5.4, + "x4": 6.0, + "y4": 7.4, + }, + title="OBBox", + ), + ], +) +def test_obbox(obbox): + assert obbox.model_dump() == { + "title": "OBBox", + "coords": [0, 1, 2, 3, 4, 5, 6, 7], + } - bbox = models.BBox.from_xywh([0.5, 1.5, 2.5, 3.5], title="BBox") - assert bbox.model_dump() == { - "title": "BBox", - "x1": 0.5, - "y1": 1.5, - "x2": 3, - "y2": 5, + +@pytest.mark.parametrize( + "pose", + [ + models.Pose(x=list(range(17)), y=[y * 2 for y in range(17)]), + models.Pose.from_list([list(range(17)), [y * 2 for y in range(17)]]), + models.Pose.from_dict({"x": list(range(17)), "y": [y * 2 for y in range(17)]}), + ], +) +def test_pose(pose): + assert pose.model_dump() == { + "x": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "y": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32], + } + + +@pytest.mark.parametrize( + "pose", + [ + models.Pose3D( + x=list(range(17)), y=[y * 2 for y in range(17)], visible=[0.2] * 17 + ), + models.Pose3D.from_list( + [list(range(17)), [y * 2 for y in range(17)], [0.2] * 17] + ), + models.Pose3D.from_dict( + { + "x": list(range(17)), + "y": [y * 2 for y in range(17)], + "visible": [0.2] * 17, + } + ), + ], +) +def test_pose3d(pose): + assert pose.model_dump() == { + "x": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "y": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32], + "visible": [ + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + 0.2, + ], + } + + +@pytest.mark.parametrize( + "segments", + [ + models.Segments(x=[0, 1, 2], y=[2, 3, 5], title="Segments"), + models.Segments.from_list([[0, 1, 2], [2, 3, 5]], title="Segments"), + models.Segments.from_dict({"x": [0, 1, 2], "y": [2, 3, 5]}, title="Segments"), + ], +) +def test_segments(segments): + assert segments.model_dump() == { + "title": "Segments", + "x": [0, 1, 2], + "y": [2, 3, 5], } -def test_pose(): - x = [x * 0.5 for x in range(17)] - y = [y * 1.5 for y in range(17)] - pose = models.Pose(x=x, y=y) - assert pose.model_dump() == {"x": x, "y": y} - assert pose.x[models.yolo.PoseBodyPart.nose] == 0 - assert pose.x[models.yolo.PoseBodyPart.left_eye] == 0.5 - assert pose.x[models.yolo.PoseBodyPart.right_eye] == 1 - assert pose.x[models.yolo.PoseBodyPart.left_ear] == 1.5 - assert pose.x[models.yolo.PoseBodyPart.right_ear] == 2 - assert pose.x[models.yolo.PoseBodyPart.left_shoulder] == 2.5 - assert pose.x[models.yolo.PoseBodyPart.right_shoulder] == 3 - assert pose.x[models.yolo.PoseBodyPart.left_elbow] == 3.5 - assert pose.x[models.yolo.PoseBodyPart.right_elbow] == 4 - assert pose.x[models.yolo.PoseBodyPart.left_wrist] == 4.5 - assert pose.x[models.yolo.PoseBodyPart.right_wrist] == 5 - assert pose.x[models.yolo.PoseBodyPart.left_hip] == 5.5 - assert pose.x[models.yolo.PoseBodyPart.right_hip] == 6 - assert pose.x[models.yolo.PoseBodyPart.left_knee] == 6.5 - assert pose.x[models.yolo.PoseBodyPart.right_knee] == 7 - assert pose.x[models.yolo.PoseBodyPart.left_ankle] == 7.5 - assert pose.x[models.yolo.PoseBodyPart.right_ankle] == 8 +def test_yolo_pose_body_parts(): + pose = models.Pose(x=list(range(17)), y=list(range(17))) + assert pose.x[YoloPoseBodyPart.nose] == 0 + assert pose.x[YoloPoseBodyPart.left_eye] == 1 + assert pose.x[YoloPoseBodyPart.right_eye] == 2 + assert pose.x[YoloPoseBodyPart.left_ear] == 3 + assert pose.x[YoloPoseBodyPart.right_ear] == 4 + assert pose.x[YoloPoseBodyPart.left_shoulder] == 5 + assert pose.x[YoloPoseBodyPart.right_shoulder] == 6 + assert pose.x[YoloPoseBodyPart.left_elbow] == 7 + assert pose.x[YoloPoseBodyPart.right_elbow] == 8 + assert pose.x[YoloPoseBodyPart.left_wrist] == 9 + assert pose.x[YoloPoseBodyPart.right_wrist] == 10 + assert pose.x[YoloPoseBodyPart.left_hip] == 11 + assert pose.x[YoloPoseBodyPart.right_hip] == 12 + assert pose.x[YoloPoseBodyPart.left_knee] == 13 + assert pose.x[YoloPoseBodyPart.right_knee] == 14 + assert pose.x[YoloPoseBodyPart.left_ankle] == 15 + assert pose.x[YoloPoseBodyPart.right_ankle] == 16