Skip to content

Commit

Permalink
Update ultralytics models (#592)
Browse files Browse the repository at this point in the history
Co-authored-by: Helio Machado <[email protected]>
dreadatour and 0x2b3bfa0 authored Nov 17, 2024
1 parent ebc19c6 commit b62d091
Showing 10 changed files with 812 additions and 113 deletions.
7 changes: 4 additions & 3 deletions src/datachain/lib/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import yolo
from .bbox import BBox
from . import ultralytics
from .bbox import BBox, OBBox
from .pose import Pose, Pose3D
from .segment import Segments

__all__ = ["BBox", "Pose", "Pose3D", "yolo"]
__all__ = ["BBox", "OBBox", "Pose", "Pose3D", "Segments", "ultralytics"]
121 changes: 96 additions & 25 deletions src/datachain/lib/models/bbox.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional

from pydantic import Field

from datachain.lib.data_model import DataModel
@@ -11,35 +9,108 @@ class BBox(DataModel):
Attributes:
title (str): The title of the bounding box.
x1 (float): The x-coordinate of the top-left corner of the bounding box.
y1 (float): The y-coordinate of the top-left corner of the bounding box.
x2 (float): The x-coordinate of the bottom-right corner of the bounding box.
y2 (float): The y-coordinate of the bottom-right corner of the bounding box.
coords (list[int]): The coordinates of the bounding box.
The bounding box is defined by two points:
- (x1, y1): The top-left corner of the box.
- (x2, y2): The bottom-right corner of the box.
"""

title: str = Field(default="")
x1: float = Field(default=0)
y1: float = Field(default=0)
x2: float = Field(default=0)
y2: float = Field(default=0)
coords: list[int] = Field(default=None)

@staticmethod
def from_list(coords: list[float], title: str = "") -> "BBox":
    """
    Build a BBox from a flat list of four coordinates.

    Every value is passed through the built-in ``round()`` before it is
    stored, so the resulting ``coords`` hold integers.

    Args:
        coords (list[float]): Exactly four numeric values.
        title (str): Title stored on the box. Defaults to an empty string.

    Returns:
        BBox: A box carrying ``title`` and the rounded coordinates.

    Raises:
        AssertionError: If ``coords`` does not have four elements, or if any
            element is neither an ``int`` nor a ``float``.
    """
    expected_size = 4
    assert (
        len(coords) == expected_size
    ), "Bounding box coordinates must be a list of 4 floats."
    assert not any(
        not isinstance(item, (int, float)) for item in coords
    ), "Bounding box coordinates must be integers or floats."
    rounded = list(map(round, coords))
    return BBox(title=title, coords=rounded)

@staticmethod
def from_dict(coords: dict[str, float], title: str = "") -> "BBox":
    """
    Build a BBox from a mapping with the keys ``x1``, ``y1``, ``x2``, ``y2``.

    The values are rounded with the built-in ``round()`` and stored in the
    order ``[x1, y1, x2, y2]``.

    Args:
        coords (dict[str, float]): Mapping holding exactly the four keys
            listed above, each with a numeric value.
        title (str): Title stored on the box. Defaults to an empty string.

    Returns:
        BBox: A box carrying ``title`` and the rounded coordinates.

    Raises:
        AssertionError: If the mapping does not have four entries, if its
            keys differ from the expected ones, or if any value is neither
            an ``int`` nor a ``float``.
    """
    key_order = ("x1", "y1", "x2", "y2")
    assert (
        len(coords) == 4
    ), "Bounding box coordinates must be a dictionary of 4 floats."
    assert set(coords) == set(
        key_order
    ), "Bounding box coordinates must contain keys with coordinates."
    assert not any(
        not isinstance(item, (int, float)) for item in coords.values()
    ), "Bounding box coordinates must be integers or floats."
    # Fixed key order keeps the stored list layout independent of dict order.
    rounded = [round(coords[key]) for key in key_order]
    return BBox(title=title, coords=rounded)


class OBBox(DataModel):
"""
A data model for representing oriented bounding boxes.
Attributes:
title (str): The title of the oriented bounding box.
coords (list[int]): The coordinates of the oriented bounding box.
The oriented bounding box is defined by four points:
- (x1, y1): The first corner of the box.
- (x2, y2): The second corner of the box.
- (x3, y3): The third corner of the box.
- (x4, y4): The fourth corner of the box.
"""

title: str = Field(default="")
coords: list[int] = Field(default=None)

@staticmethod
def from_list(coords: list[float], title: str = "") -> "OBBox":
    """
    Build an OBBox from a flat list of eight coordinates.

    Every value is passed through the built-in ``round()`` before it is
    stored, so the resulting ``coords`` hold integers.

    Args:
        coords (list[float]): Exactly eight numeric values.
        title (str): Title stored on the box. Defaults to an empty string.

    Returns:
        OBBox: An oriented box carrying ``title`` and the rounded coordinates.

    Raises:
        AssertionError: If ``coords`` does not have eight elements, or if any
            element is neither an ``int`` nor a ``float``.
    """
    expected_size = 8
    assert (
        len(coords) == expected_size
    ), "Oriented bounding box coordinates must be a list of 8 floats."
    assert not any(
        not isinstance(item, (int, float)) for item in coords
    ), "Oriented bounding box coordinates must be integers or floats."
    rounded = list(map(round, coords))
    return OBBox(title=title, coords=rounded)

@staticmethod
def from_xywh(bbox: list[float], title: Optional[str] = None) -> "BBox":
    """
    Converts a bounding box in (x, y, width, height) format
    to a BBox data model instance.

    Args:
        bbox (list[float]): A bounding box, represented as a list
            of four floats [x, y, width, height].
        title (Optional[str]): Title for the box. ``None`` (the default) or
            any other falsy value results in an empty title.

    Returns:
        BBox: An instance of the BBox data model with ``x1=x``, ``y1=y``,
            ``x2=x + width`` and ``y2=y + height``.

    Raises:
        AssertionError: If ``bbox`` does not contain exactly four elements.
    """
    # Report the real element count; no literal "f" before the placeholder.
    assert len(bbox) == 4, f"Bounding box must have 4 elements, got {len(bbox)}"
    x, y, w, h = bbox
    return BBox(title=title or "", x1=x, y1=y, x2=x + w, y2=y + h)
def from_dict(coords: dict[str, float], title: str = "") -> "OBBox":
    """
    Build an OBBox from a mapping with the keys ``x1``..``x4`` and ``y1``..``y4``.

    The values are rounded with the built-in ``round()`` and stored in the
    order ``[x1, y1, x2, y2, x3, y3, x4, y4]``.

    Args:
        coords (dict[str, float]): Mapping holding exactly the eight keys
            listed above, each with a numeric value.
        title (str): Title stored on the box. Defaults to an empty string.

    Returns:
        OBBox: An oriented box carrying ``title`` and the rounded coordinates.

    Raises:
        AssertionError: If the keys differ from the expected ones, or if any
            value is neither an ``int`` nor a ``float``.
    """
    key_order = ("x1", "y1", "x2", "y2", "x3", "y3", "x4", "y4")
    assert set(coords) == set(
        key_order
    ), "Oriented bounding box coordinates must contain keys with coordinates."
    assert not any(
        not isinstance(item, (int, float)) for item in coords.values()
    ), "Oriented bounding box coordinates must be integers or floats."
    # Fixed key order keeps the stored list layout independent of dict order.
    rounded = [round(coords[key]) for key in key_order]
    return OBBox(title=title, coords=rounded)
87 changes: 79 additions & 8 deletions src/datachain/lib/models/pose.py
Original file line number Diff line number Diff line change
@@ -8,30 +8,101 @@ class Pose(DataModel):
A data model for representing pose keypoints.
Attributes:
x (list[float]): The x-coordinates of the keypoints.
y (list[float]): The y-coordinates of the keypoints.
x (list[int]): The x-coordinates of the keypoints.
y (list[int]): The y-coordinates of the keypoints.
The keypoints are represented as lists of x and y coordinates, where each index
corresponds to a specific body part.
"""

x: list[float] = Field(default=None)
y: list[float] = Field(default=None)
x: list[int] = Field(default=None)
y: list[int] = Field(default=None)

@staticmethod
def from_list(points: list[list[float]]) -> "Pose":
    """
    Build a Pose from a two-element list ``[x_values, y_values]``.

    Both inner lists must contain 17 numbers. Each number is rounded with
    the built-in ``round()`` before being stored.

    Args:
        points (list[list[float]]): The x-coordinates followed by the
            y-coordinates.

    Returns:
        Pose: A pose holding the rounded x and y coordinates.

    Raises:
        AssertionError: If ``points`` does not hold two lists, if either list
            is not 17 items long, or if any item is neither an ``int`` nor
            a ``float``.
    """
    assert len(points) == 2, "Pose coordinates must be a list of 2 lists."
    xs, ys = points
    assert (
        len(xs) == len(ys) == 17
    ), "Pose x and y coordinates must have the same length of 17."
    assert not any(
        not isinstance(item, (int, float)) for item in (*xs, *ys)
    ), "Pose coordinates must be integers or floats."
    rounded_x = list(map(round, xs))
    rounded_y = list(map(round, ys))
    return Pose(x=rounded_x, y=rounded_y)

@staticmethod
def from_dict(points: dict[str, list[float]]) -> "Pose":
    """
    Build a Pose from a mapping with the keys ``x`` and ``y``.

    Both lists must contain 17 numbers. Each number is rounded with the
    built-in ``round()`` before being stored.

    Args:
        points (dict[str, list[float]]): Mapping with exactly the keys
            ``"x"`` and ``"y"``.

    Returns:
        Pose: A pose holding the rounded x and y coordinates.

    Raises:
        AssertionError: If the keys differ from ``x``/``y``, if either list
            is not 17 items long, or if any item is neither an ``int`` nor
            a ``float``.
    """
    assert set(points) == {"x", "y"}, "Pose coordinates must contain keys 'x' and 'y'."
    xs = points["x"]
    ys = points["y"]
    assert (
        len(xs) == len(ys) == 17
    ), "Pose x and y coordinates must have the same length of 17."
    assert not any(
        not isinstance(item, (int, float)) for item in (*xs, *ys)
    ), "Pose coordinates must be integers or floats."
    rounded_x = list(map(round, xs))
    rounded_y = list(map(round, ys))
    return Pose(x=rounded_x, y=rounded_y)


class Pose3D(DataModel):
"""
A data model for representing 3D pose keypoints.
Attributes:
x (list[float]): The x-coordinates of the keypoints.
y (list[float]): The y-coordinates of the keypoints.
x (list[int]): The x-coordinates of the keypoints.
y (list[int]): The y-coordinates of the keypoints.
visible (list[float]): The visibility of the keypoints.
The keypoints are represented as lists of x, y, and visibility values,
where each index corresponds to a specific body part.
"""

x: list[float] = Field(default=None)
y: list[float] = Field(default=None)
x: list[int] = Field(default=None)
y: list[int] = Field(default=None)
visible: list[float] = Field(default=None)

@staticmethod
def from_list(points: list[list[float]]) -> "Pose3D":
    """
    Build a Pose3D from a three-element list ``[x_values, y_values, visibility]``.

    All three inner lists must contain 17 numbers. The x and y values are
    rounded with the built-in ``round()``; the visibility values are stored
    as given.

    Args:
        points (list[list[float]]): The x-coordinates, the y-coordinates and
            the visibility values, in that order.

    Returns:
        Pose3D: A pose holding rounded x/y coordinates and the visibility list.

    Raises:
        AssertionError: If ``points`` does not hold three lists, if any list
            is not 17 items long, or if any item is neither an ``int`` nor
            a ``float``.
    """
    assert len(points) == 3, "Pose coordinates must be a list of 3 lists."
    xs, ys, visibility = points
    assert (
        len(xs) == len(ys) == len(visibility) == 17
    ), "Pose x, y, and visibility coordinates must have the same length of 17."
    assert not any(
        not isinstance(item, (int, float)) for item in (*xs, *ys, *visibility)
    ), "Pose coordinates must be integers or floats."
    rounded_x = list(map(round, xs))
    rounded_y = list(map(round, ys))
    # Visibility is deliberately left unrounded.
    return Pose3D(x=rounded_x, y=rounded_y, visible=visibility)

@staticmethod
def from_dict(points: dict[str, list[float]]) -> "Pose3D":
    """
    Build a Pose3D from a mapping with the keys ``x``, ``y`` and ``visible``.

    All three lists must contain 17 numbers. The x and y values are rounded
    with the built-in ``round()``; the visibility values are stored as given.

    Args:
        points (dict[str, list[float]]): Mapping with exactly the keys
            ``"x"``, ``"y"`` and ``"visible"``.

    Returns:
        Pose3D: A pose holding rounded x/y coordinates and the visibility list.

    Raises:
        AssertionError: If the keys differ from the expected ones, if any
            list is not 17 items long, or if any item is neither an ``int``
            nor a ``float``.
    """
    expected_keys = {"x", "y", "visible"}
    assert (
        set(points) == expected_keys
    ), "Pose coordinates must contain keys 'x', 'y', and 'visible'."
    xs = points["x"]
    ys = points["y"]
    visibility = points["visible"]
    assert (
        len(xs) == len(ys) == len(visibility) == 17
    ), "Pose x, y, and visibility coordinates must have the same length of 17."
    assert not any(
        not isinstance(item, (int, float)) for item in (*xs, *ys, *visibility)
    ), "Pose coordinates must be integers or floats."
    rounded_x = list(map(round, xs))
    rounded_y = list(map(round, ys))
    # Visibility is deliberately left unrounded.
    return Pose3D(x=rounded_x, y=rounded_y, visible=visibility)
53 changes: 53 additions & 0 deletions src/datachain/lib/models/segment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from pydantic import Field

from datachain.lib.data_model import DataModel


class Segments(DataModel):
    """
    A data model for representing segments.

    Attributes:
        title (str): The title of the segments.
        x (list[int]): The x-coordinates of the segments.
        y (list[int]): The y-coordinates of the segments.

    The segments are represented as lists of x and y coordinates, where each index
    corresponds to a specific segment.
    """

    title: str = Field(default="")
    x: list[int] = Field(default=None)
    y: list[int] = Field(default=None)

    @staticmethod
    def from_list(points: list[list[float]], title: str = "") -> "Segments":
        """
        Build Segments from a two-element list ``[x_values, y_values]``.

        Each value is rounded with the built-in ``round()`` before being stored.

        Args:
            points (list[list[float]]): The x-coordinates followed by the
                y-coordinates; both lists must have the same length.
            title (str): Title stored on the model. Defaults to an empty string.

        Returns:
            Segments: A model holding ``title`` and the rounded coordinates.

        Raises:
            AssertionError: If ``points`` does not hold two lists, if the lists
                differ in length, or if any item is neither an ``int`` nor
                a ``float``.
        """
        assert len(points) == 2, "Segments coordinates must be a list of 2 lists."
        points_x, points_y = points
        assert len(points_x) == len(
            points_y
        ), "Segments x and y coordinates must have the same length."
        assert all(
            isinstance(value, (int, float)) for value in [*points_x, *points_y]
        ), "Segments coordinates must be integers or floats."
        return Segments(
            title=title,
            x=[round(coord) for coord in points_x],
            y=[round(coord) for coord in points_y],
        )

    @staticmethod
    def from_dict(points: dict[str, list[float]], title: str = "") -> "Segments":
        """
        Build Segments from a mapping with the keys ``x`` and ``y``.

        Each value is rounded with the built-in ``round()`` before being stored.

        Args:
            points (dict[str, list[float]]): Mapping with exactly the keys
                ``"x"`` and ``"y"``; both lists must have the same length.
            title (str): Title stored on the model. Defaults to an empty string.

        Returns:
            Segments: A model holding ``title`` and the rounded coordinates.

        Raises:
            AssertionError: If the keys differ from ``x``/``y``, if the lists
                differ in length, or if any item is neither an ``int`` nor
                a ``float``.
        """
        assert set(points) == {
            "x",
            "y",
        }, "Segments coordinates must contain keys 'x' and 'y'."
        points_x, points_y = points["x"], points["y"]
        # Same length check as from_list, so both constructors reject
        # x/y lists that cannot be paired index by index.
        assert len(points_x) == len(
            points_y
        ), "Segments x and y coordinates must have the same length."
        assert all(
            isinstance(value, (int, float)) for value in [*points_x, *points_y]
        ), "Segments coordinates must be integers or floats."
        return Segments(
            title=title,
            x=[round(coord) for coord in points_x],
            y=[round(coord) for coord in points_y],
        )
14 changes: 14 additions & 0 deletions src/datachain/lib/models/ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .bbox import YoloBBox, YoloBBoxes, YoloOBBox, YoloOBBoxes
from .pose import YoloPose, YoloPoses
from .segment import YoloSegment, YoloSegments

# Names exported by `from <package> import *`: the Yolo* models imported
# from the bbox, pose and segment submodules.
__all__ = [
"YoloBBox",
"YoloBBoxes",
"YoloOBBox",
"YoloOBBoxes",
"YoloPose",
"YoloPoses",
"YoloSegment",
"YoloSegments",
]
Loading

0 comments on commit b62d091

Please sign in to comment.