Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure embeddings #14266

Merged
merged 5 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions frigate/api/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
status_code=404,
)

thumb_result = context.embeddings.search_thumbnail(search_event)
thumb_result = context.search_thumbnail(search_event)
thumb_ids = dict(
zip(
[result[0] for result in thumb_result],
Expand All @@ -487,7 +487,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
search_types = search_type.split(",")

if "thumbnail" in search_types:
thumb_result = context.embeddings.search_thumbnail(query)
thumb_result = context.search_thumbnail(query)
thumb_ids = dict(
zip(
[result[0] for result in thumb_result],
Expand All @@ -504,7 +504,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
)

if "description" in search_types:
desc_result = context.embeddings.search_description(query)
desc_result = context.search_description(query)
desc_ids = dict(
zip(
[result[0] for result in desc_result],
Expand Down Expand Up @@ -944,9 +944,9 @@ def set_description(
# If semantic search is enabled, update the index
if request.app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = request.app.embeddings
context.embeddings.upsert_description(
event_id=event_id,
description=new_description,
context.update_description(
event_id,
new_description,
)

response_message = (
Expand Down Expand Up @@ -1033,8 +1033,8 @@ def delete_event(request: Request, event_id: str):
# If semantic search is enabled, update the index
if request.app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = request.app.embeddings
context.embeddings.delete_thumbnail(id=[event_id])
context.embeddings.delete_description(id=[event_id])
context.db.delete_embeddings_thumbnail(id=[event_id])
context.db.delete_embeddings_description(id=[event_id])
return JSONResponse(
content=({"success": True, "message": "Event " + event_id + " deleted"}),
status_code=200,
Expand Down
4 changes: 2 additions & 2 deletions frigate/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def check_db_data_migrations(self) -> None:
def init_embeddings_client(self) -> None:
if self.config.semantic_search.enabled:
# Create a client for other processes to use
self.embeddings = EmbeddingsContext(self.config, self.db)
self.embeddings = EmbeddingsContext(self.db)

def init_external_event_processor(self) -> None:
self.external_event_processor = ExternalEventProcessor(self.config)
Expand Down Expand Up @@ -699,7 +699,7 @@ def stop(self) -> None:

# Save embeddings stats to disk
if self.embeddings:
self.embeddings.save_stats()
self.embeddings.stop()

# Stop Communicators
self.inter_process_communicator.stop()
Expand Down
62 changes: 62 additions & 0 deletions frigate/comms/embeddings_updater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Facilitates communication between processes."""

from enum import Enum
from typing import Callable

import zmq

SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"


class EmbeddingsRequestEnum(Enum):
embed_description = "embed_description"
embed_thumbnail = "embed_thumbnail"
generate_search = "generate_search"


class EmbeddingsResponder:
def __init__(self) -> None:
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REP)
self.socket.bind(SOCKET_REP_REQ)

def check_for_request(self, process: Callable) -> None:
while True: # load all messages that are queued
has_message, _, _ = zmq.select([self.socket], [], [], 1)

if not has_message:
break

try:
(topic, value) = self.socket.recv_json(flags=zmq.NOBLOCK)

response = process(topic, value)

if response is not None:
self.socket.send_json(response)
else:
self.socket.send_json([])
except zmq.ZMQError:
break

def stop(self) -> None:
self.socket.close()
self.context.destroy()


class EmbeddingsRequestor:
"""Simplifies sending data to EmbeddingsResponder and getting a reply."""

def __init__(self) -> None:
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect(SOCKET_REP_REQ)

def send_data(self, topic: str, data: any) -> str:
"""Sends data and then waits for reply."""
self.socket.send_json((topic, data))
return self.socket.recv_json()

def stop(self) -> None:
self.socket.close()
self.context.destroy()
8 changes: 8 additions & 0 deletions frigate/db/sqlitevecq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,11 @@ def _load_vec_extension(self, conn: sqlite3.Connection) -> None:
conn.enable_load_extension(True)
conn.load_extension(self.sqlite_vec_path)
conn.enable_load_extension(False)

def delete_embeddings_thumbnail(self, event_ids: list[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.execute_sql(f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids)

def delete_embeddings_description(self, event_ids: list[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.execute_sql(f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids)
108 changes: 104 additions & 4 deletions frigate/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
import signal
import threading
from types import FrameType
from typing import Optional
from typing import Optional, Union

from setproctitle import setproctitle

from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor
from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event
from frigate.util.builtin import serialize
from frigate.util.services import listen

from .embeddings import Embeddings
Expand Down Expand Up @@ -70,10 +72,11 @@ def receiveSignal(signalNumber: int, frame: Optional[FrameType]) -> None:


class EmbeddingsContext:
def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase):
self.embeddings = Embeddings(config.semantic_search, db)
def __init__(self, db: SqliteVecQueueDatabase):
self.db = db
self.thumb_stats = ZScoreNormalization()
self.desc_stats = ZScoreNormalization()
self.requestor = EmbeddingsRequestor()

# load stats from disk
try:
Expand All @@ -84,11 +87,108 @@ def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase):
except FileNotFoundError:
pass

def save_stats(self):
def stop(self):
"""Write the stats to disk as JSON on exit."""
contents = {
"thumb_stats": self.thumb_stats.to_dict(),
"desc_stats": self.desc_stats.to_dict(),
}
with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
json.dump(contents, f)
self.requestor.stop()

def search_thumbnail(
self, query: Union[Event, str], event_ids: list[str] = None
) -> list[tuple[str, float]]:
if query.__class__ == Event:
cursor = self.db.execute_sql(
"""
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
""",
[query.id],
)

row = cursor.fetchone() if cursor else None

if row:
query_embedding = row[0]
else:
# If no embedding found, generate it and return it
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.embed_thumbnail.value,
{"id": query.id, "thumbnail": query.thumbnail},
)
)
else:
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.generate_search.value, query
)
)

sql_query = """
SELECT
id,
distance
FROM vec_thumbnails
WHERE thumbnail_embedding MATCH ?
AND k = 100
"""

# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))

# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"

parameters = [query_embedding] + event_ids if event_ids else [query_embedding]

results = self.db.execute_sql(sql_query, parameters).fetchall()

return results

def search_description(
self, query_text: str, event_ids: list[str] = None
) -> list[tuple[str, float]]:
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.generate_search.value, query_text
)
)

# Prepare the base SQL query
sql_query = """
SELECT
id,
distance
FROM vec_descriptions
WHERE description_embedding MATCH ?
AND k = 100
"""

# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))

# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"

parameters = [query_embedding] + event_ids if event_ids else [query_embedding]

results = self.db.execute_sql(sql_query, parameters).fetchall()

return results

def update_description(self, event_id: str, description: str) -> None:
self.requestor.send_data(
EmbeddingsRequestEnum.embed_description.value,
{"id": event_id, "description": description},
)
Loading