Skip to content

Commit

Permalink
Restructure embeddings (#14266)
Browse files Browse the repository at this point in the history
* Restructure embeddings

* Use ZMQ to proxy embeddings requests

* Handle serialization

* Formatting

* Remove unused
  • Loading branch information
NickM-27 authored Oct 10, 2024
1 parent a2ca18a commit 8ade85e
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 142 deletions.
16 changes: 8 additions & 8 deletions frigate/api/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
status_code=404,
)

thumb_result = context.embeddings.search_thumbnail(search_event)
thumb_result = context.search_thumbnail(search_event)
thumb_ids = dict(
zip(
[result[0] for result in thumb_result],
Expand All @@ -487,7 +487,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
search_types = search_type.split(",")

if "thumbnail" in search_types:
thumb_result = context.embeddings.search_thumbnail(query)
thumb_result = context.search_thumbnail(query)
thumb_ids = dict(
zip(
[result[0] for result in thumb_result],
Expand All @@ -504,7 +504,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
)

if "description" in search_types:
desc_result = context.embeddings.search_description(query)
desc_result = context.search_description(query)
desc_ids = dict(
zip(
[result[0] for result in desc_result],
Expand Down Expand Up @@ -944,9 +944,9 @@ def set_description(
# If semantic search is enabled, update the index
if request.app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = request.app.embeddings
context.embeddings.upsert_description(
event_id=event_id,
description=new_description,
context.update_description(
event_id,
new_description,
)

response_message = (
Expand Down Expand Up @@ -1033,8 +1033,8 @@ def delete_event(request: Request, event_id: str):
# If semantic search is enabled, update the index
if request.app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = request.app.embeddings
context.embeddings.delete_thumbnail(id=[event_id])
context.embeddings.delete_description(id=[event_id])
context.db.delete_embeddings_thumbnail(id=[event_id])
context.db.delete_embeddings_description(id=[event_id])
return JSONResponse(
content=({"success": True, "message": "Event " + event_id + " deleted"}),
status_code=200,
Expand Down
4 changes: 2 additions & 2 deletions frigate/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def check_db_data_migrations(self) -> None:
def init_embeddings_client(self) -> None:
if self.config.semantic_search.enabled:
# Create a client for other processes to use
self.embeddings = EmbeddingsContext(self.config, self.db)
self.embeddings = EmbeddingsContext(self.db)

def init_external_event_processor(self) -> None:
self.external_event_processor = ExternalEventProcessor(self.config)
Expand Down Expand Up @@ -699,7 +699,7 @@ def stop(self) -> None:

# Save embeddings stats to disk
if self.embeddings:
self.embeddings.save_stats()
self.embeddings.stop()

# Stop Communicators
self.inter_process_communicator.stop()
Expand Down
62 changes: 62 additions & 0 deletions frigate/comms/embeddings_updater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Facilitates communication between processes."""

from enum import Enum
from typing import Callable

import zmq

SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"


class EmbeddingsRequestEnum(Enum):
embed_description = "embed_description"
embed_thumbnail = "embed_thumbnail"
generate_search = "generate_search"


class EmbeddingsResponder:
def __init__(self) -> None:
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REP)
self.socket.bind(SOCKET_REP_REQ)

def check_for_request(self, process: Callable) -> None:
while True: # load all messages that are queued
has_message, _, _ = zmq.select([self.socket], [], [], 1)

if not has_message:
break

try:
(topic, value) = self.socket.recv_json(flags=zmq.NOBLOCK)

response = process(topic, value)

if response is not None:
self.socket.send_json(response)
else:
self.socket.send_json([])
except zmq.ZMQError:
break

def stop(self) -> None:
self.socket.close()
self.context.destroy()


class EmbeddingsRequestor:
"""Simplifies sending data to EmbeddingsResponder and getting a reply."""

def __init__(self) -> None:
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect(SOCKET_REP_REQ)

def send_data(self, topic: str, data: any) -> str:
"""Sends data and then waits for reply."""
self.socket.send_json((topic, data))
return self.socket.recv_json()

def stop(self) -> None:
self.socket.close()
self.context.destroy()
8 changes: 8 additions & 0 deletions frigate/db/sqlitevecq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,11 @@ def _load_vec_extension(self, conn: sqlite3.Connection) -> None:
conn.enable_load_extension(True)
conn.load_extension(self.sqlite_vec_path)
conn.enable_load_extension(False)

def delete_embeddings_thumbnail(self, event_ids: list[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.execute_sql(f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids)

def delete_embeddings_description(self, event_ids: list[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.execute_sql(f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids)
108 changes: 104 additions & 4 deletions frigate/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
import signal
import threading
from types import FrameType
from typing import Optional
from typing import Optional, Union

from setproctitle import setproctitle

from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor
from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR
from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event
from frigate.util.builtin import serialize
from frigate.util.services import listen

from .embeddings import Embeddings
Expand Down Expand Up @@ -70,10 +72,11 @@ def receiveSignal(signalNumber: int, frame: Optional[FrameType]) -> None:


class EmbeddingsContext:
def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase):
self.embeddings = Embeddings(config.semantic_search, db)
def __init__(self, db: SqliteVecQueueDatabase):
self.db = db
self.thumb_stats = ZScoreNormalization()
self.desc_stats = ZScoreNormalization()
self.requestor = EmbeddingsRequestor()

# load stats from disk
try:
Expand All @@ -84,11 +87,108 @@ def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase):
except FileNotFoundError:
pass

def save_stats(self):
def stop(self):
"""Write the stats to disk as JSON on exit."""
contents = {
"thumb_stats": self.thumb_stats.to_dict(),
"desc_stats": self.desc_stats.to_dict(),
}
with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
json.dump(contents, f)
self.requestor.stop()

def search_thumbnail(
self, query: Union[Event, str], event_ids: list[str] = None
) -> list[tuple[str, float]]:
if query.__class__ == Event:
cursor = self.db.execute_sql(
"""
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
""",
[query.id],
)

row = cursor.fetchone() if cursor else None

if row:
query_embedding = row[0]
else:
# If no embedding found, generate it and return it
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.embed_thumbnail.value,
{"id": query.id, "thumbnail": query.thumbnail},
)
)
else:
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.generate_search.value, query
)
)

sql_query = """
SELECT
id,
distance
FROM vec_thumbnails
WHERE thumbnail_embedding MATCH ?
AND k = 100
"""

# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))

# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"

parameters = [query_embedding] + event_ids if event_ids else [query_embedding]

results = self.db.execute_sql(sql_query, parameters).fetchall()

return results

def search_description(
self, query_text: str, event_ids: list[str] = None
) -> list[tuple[str, float]]:
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.generate_search.value, query_text
)
)

# Prepare the base SQL query
sql_query = """
SELECT
id,
distance
FROM vec_descriptions
WHERE description_embedding MATCH ?
AND k = 100
"""

# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))

# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"

parameters = [query_embedding] + event_ids if event_ids else [query_embedding]

results = self.db.execute_sql(sql_query, parameters).fetchall()

return results

def update_description(self, event_id: str, description: str) -> None:
self.requestor.send_data(
EmbeddingsRequestEnum.embed_description.value,
{"id": event_id, "description": description},
)
Loading

0 comments on commit 8ade85e

Please sign in to comment.