Apply some refactoring
sergey-misuk-valor committed Jan 25, 2025
1 parent 068fab9 commit 4940804
Showing 6 changed files with 53 additions and 43 deletions.
19 changes: 15 additions & 4 deletions src/hope_dedup_engine/apps/api/models/deduplication.py
@@ -1,3 +1,4 @@
from itertools import chain
from typing import Any, Final, override
from uuid import uuid4

@@ -15,6 +16,7 @@
ImageEmbedding,
ImageEmbeddingError,
Score,
SortedTuple,
)

REFERENCE_PK_LENGTH: Final[int] = 100
@@ -86,10 +88,19 @@ def get_findings(self) -> list[Finding]:
)
)

def get_ignored_pairs(self) -> list[IgnoredPair]:
return list(
self.ignoredreferencepkpair_set.values_list("first", "second")
) + list(self.ignoredfilenamepair_set.values_list("first", "second"))
def get_ignored_pairs(self) -> set[IgnoredPair]:
return set(
chain(
map(
SortedTuple,
self.ignoredreferencepkpair_set.values_list("first", "second"),
),
map(
SortedTuple,
list(self.ignoredfilenamepair_set.values_list("first", "second")),
),
)
)

def update_encodings(self, encodings: list[ImageEmbedding]) -> None:
with transaction.atomic():
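Aside (illustration only, not part of the commit): the refactored get_ignored_pairs returns a set of order-normalized pairs instead of a plain list, so a single membership test covers both orderings of a pair. A minimal standalone sketch of the same idea, with lists standing in for the two values_list() querysets:

reference_pk_pairs = [("ref-1", "ref-2"), ("ref-2", "ref-1")]  # stand-in for ignoredreferencepkpair_set
filename_pairs = [("a.jpg", "b.jpg")]  # stand-in for ignoredfilenamepair_set

# Normalize each pair to a canonical order and deduplicate via a set,
# mirroring what SortedTuple plus set() achieve in the model method.
ignored = {tuple(sorted(pair)) for pair in (*reference_pk_pairs, *filename_pairs)}
print(sorted(ignored))  # [('a.jpg', 'b.jpg'), ('ref-1', 'ref-2')]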
20 changes: 12 additions & 8 deletions src/hope_dedup_engine/apps/faces/celery/pipeline.py
@@ -4,28 +4,32 @@

from hope_dedup_engine.apps.api.models import DeduplicationSet
from hope_dedup_engine.apps.faces.celery.tasks.deduplication import (
deduplication_set_embedding_pairs,
deduplication_set_image_files,
encode_images,
filter_ignored_pairs,
find_duplicates,
get_deduplication_set_embedding_pairs,
get_deduplication_set_image_files,
save_encoding_errors_in_findings,
)
from hope_dedup_engine.utils.celery.utility_tasks import parallelize

IMAGE_ENCODING_BATCH_SIZE = 50
DUPLICATE_FINDING_BATCH_SIZE = 200


def image_pipeline(
deduplication_set: DeduplicationSet, config: dict[str, Any]
) -> Signature:
encode_images_pipeline = parallelize.si(
deduplication_set_image_files.s(deduplication_set.id),
encode_images.s(config),
100,
get_deduplication_set_image_files.s(deduplication_set.id),
encode_images.s(deduplication_set.id, config.get("encoding")),
IMAGE_ENCODING_BATCH_SIZE,
)
find_duplicates_pipeline = parallelize.si(
deduplication_set_embedding_pairs.s(deduplication_set.id),
filter_ignored_pairs.s(deduplication_set.id) | find_duplicates.s(config),
100,
get_deduplication_set_embedding_pairs.s(deduplication_set.id),
filter_ignored_pairs.s(deduplication_set.id)
| find_duplicates.s(deduplication_set.id, config.get("deduplicate", {})),
DUPLICATE_FINDING_BATCH_SIZE,
)

return (
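Aside (hedged sketch, not the project's tasks): the wiring above relies on the standard Celery idiom in which .s() builds a partial signature and | chains signatures so that one task's return value becomes the first positional argument of the next — here, each producer batch flows through filter_ignored_pairs into find_duplicates. A self-contained toy example of that idiom, with made-up task names:

from celery import Celery

app = Celery("example")  # broker/backend config omitted; we only build signatures here


@app.task
def produce(deduplication_set_id: str) -> list[int]:
    return [1, 2, 3]


@app.task
def consume(items: list[int], deduplication_set_id: str) -> int:
    return sum(items)


# produce runs first; its return value is prepended to consume's arguments,
# analogous to producer batches feeding filter_ignored_pairs | find_duplicates.
workflow = produce.s("some-id") | consume.s("some-id")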
31 changes: 14 additions & 17 deletions src/hope_dedup_engine/apps/faces/celery/tasks/deduplication.py
@@ -9,13 +9,13 @@
)
from hope_dedup_engine.config.celery import app
from hope_dedup_engine.constants import FacialError
from hope_dedup_engine.types import EntityEmbedding, Filename
from hope_dedup_engine.types import EntityEmbedding, Filename, SortedTuple
from hope_dedup_engine.utils.celery.task_result import wrapped


@app.task
@wrapped
def deduplication_set_image_files(deduplication_set_id: str) -> list[Filename]:
def get_deduplication_set_image_files(deduplication_set_id: str) -> list[Filename]:
# TODO: optimize it calculating on DB side
deduplication_set: DeduplicationSet = DeduplicationSet.objects.get(
pk=deduplication_set_id
@@ -29,20 +29,21 @@ def deduplication_set_image_files(deduplication_set_id: str) -> list[Filename]:
@wrapped
def encode_images(
images: list[str],
config: dict[str, Any],
deduplication_set_id: str,
encoding_config: dict[str, Any],
) -> None:
"""Encode faces in a chunk of files."""
encodings, errors = encode_faces(images, config.get("encoding"))
encodings, errors = encode_faces(images, encoding_config)
deduplication_set: DeduplicationSet = DeduplicationSet.objects.get(
pk=config.get("deduplication_set_id")
pk=deduplication_set_id
)
deduplication_set.update_encodings(encodings)
deduplication_set.update_encoding_errors(errors)


@app.task
@wrapped
def deduplication_set_embedding_pairs(
def get_deduplication_set_embedding_pairs(
deduplication_set_id: str,
) -> Iterator[tuple[EntityEmbedding, EntityEmbedding]]:
deduplication_set: DeduplicationSet = DeduplicationSet.objects.get(
@@ -69,16 +70,13 @@ def filter_ignored_pairs(
deduplication_set: DeduplicationSet = DeduplicationSet.objects.get(
pk=deduplication_set_id
)
ignored_pairs = set(deduplication_set.get_ignored_pairs())
ignored_pairs = deduplication_set.get_ignored_pairs()
filtered = []
for embedding_pair in embedding_pairs:
first, second = embedding_pair
first_reference_pk, _ = first
second_reference_pk, _ = second
if (first_reference_pk, second_reference_pk) not in ignored_pairs and (
second_reference_pk,
first_reference_pk,
) not in ignored_pairs:
if SortedTuple((first_reference_pk, second_reference_pk)) not in ignored_pairs:
filtered.append(embedding_pair)

return filtered
@@ -88,16 +86,15 @@ def filter_ignored_pairs(
@wrapped
def find_duplicates(
embedding_pairs: list[tuple[EntityEmbedding, EntityEmbedding]],
config: dict[str, Any],
deduplication_set_id: str,
deduplicate_config: dict[str, Any],
) -> None:
"""Deduplicate faces in a chunk of files."""
deduplication_set = DeduplicationSet.objects.get(
pk=config.get("deduplication_set_id")
)
deduplication_set = DeduplicationSet.objects.get(pk=deduplication_set_id)
findings = find_similar_faces(
embedding_pairs,
dedupe_threshold=config.get("deduplicate", {}).get("threshold"),
options=config.get("deduplicate"),
dedupe_threshold=deduplicate_config.get("threshold"),
options=deduplicate_config,
)
deduplication_set.update_findings(findings)

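Aside (assumed data shapes, not the task code): filter_ignored_pairs now keeps an embedding pair only when the order-normalized tuple of its reference PKs is absent from the ignored set, replacing the earlier two-way check. The same test with plain Python values:

# (reference_pk, embedding) tuples stand in for EntityEmbedding values.
embedding_pairs = [
    (("ref-1", [0.1]), ("ref-2", [0.2])),
    (("ref-1", [0.1]), ("ref-3", [0.3])),
]
ignored_pairs = {tuple(sorted(("ref-2", "ref-1")))}  # stand-in for get_ignored_pairs()

filtered = [
    (first, second)
    for first, second in embedding_pairs
    if tuple(sorted((first[0], second[0]))) not in ignored_pairs
]
print(filtered)  # only the ("ref-1", "ref-3") pair survives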
Empty file.
11 changes: 10 additions & 1 deletion src/hope_dedup_engine/types.py
@@ -1,3 +1,6 @@
from collections.abc import Iterable
from typing import Self

from hope_dedup_engine.constants import FacialError

ReferencePK = str
@@ -12,4 +15,10 @@
ImageEmbeddingError = tuple[Filename, FacialError]
Finding = tuple[ReferencePK, ReferencePK, Score]

IgnoredPair = tuple[ReferencePK, ReferencePK]

class SortedTuple(tuple):
def __new__(cls, iterable: Iterable) -> Self:
return tuple.__new__(cls, sorted(iterable))


IgnoredPair = SortedTuple[ReferencePK, ReferencePK]
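Aside (usage example, not part of the commit): because __new__ sorts its input, both orderings of a pair produce equal, equally-hashed SortedTuple instances, which is what allows IgnoredPair values to be stored in and looked up from a set without checking the reversed pair:

from collections.abc import Iterable
from typing import Self


class SortedTuple(tuple):
    def __new__(cls, iterable: Iterable) -> Self:
        return tuple.__new__(cls, sorted(iterable))


# Both orderings normalize to the same tuple, so one set lookup suffices.
assert SortedTuple(("b", "a")) == SortedTuple(("a", "b")) == ("a", "b")
assert SortedTuple(("b", "a")) in {SortedTuple(("a", "b"))}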
15 changes: 2 additions & 13 deletions src/hope_dedup_engine/utils/celery/utility_tasks.py
@@ -1,4 +1,3 @@
from collections.abc import Callable
from itertools import batched
from typing import Any, NoReturn

@@ -11,23 +10,13 @@
SerializedTask = dict[str, Any]


@app.task(bind=True)
@wrapped
def map_[
T, P
](self: celery.Task, results: list[T], serialize_task: SerializedTask) -> list[P]:
"""Celery map/starmap/xmap cannot be used in chain"""
signature: Callable[[T], P] = self.app.signature(serialize_task)
return list(map(signature, results))


@app.task(bind=True)
@wrapped
def parallelize(
self: celery.Task,
producer: SerializedTask,
task: SerializedTask,
size: int,
batch_size: int,
end_task: SerializedTask | None = None,
) -> NoReturn:
producer_signature = self.app.signature(producer)
@@ -36,7 +25,7 @@ def parallelize(
signature: canvas.Signature = self.app.signature(task)

signatures = []
for batch in batched(data, size):
for batch in batched(data, batch_size):
args = (batch,)
if isinstance(signature, canvas._chain):
clone = signature.clone()
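Aside (illustration only; assumes Python 3.12+, where itertools.batched is available): parallelize now slices the producer's output into chunks of batch_size and builds one task signature per chunk. The chunking step in isolation:

from itertools import batched  # Python 3.12+

for batch in batched(range(7), 3):
    print(batch)  # prints (0, 1, 2), then (3, 4, 5), then (6,)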
