forked from openai/chatgpt-retrieval-plugin
-
Notifications
You must be signed in to change notification settings - Fork 1
/
datastore.py
86 lines (76 loc) · 2.93 KB
/
datastore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import asyncio
from models.models import (
Document,
DocumentChunk,
DocumentMetadataFilter,
Query,
QueryResult,
QueryWithEmbedding,
)
from services.chunks import get_document_chunks
from services.openai import get_embeddings
class DataStore(ABC):
    """
    Abstract base class for a document datastore backend.

    Subclasses implement ``_upsert``, ``_query`` and ``delete``. This base class
    provides the public ``upsert`` (delete old vectors, chunk, then ``_upsert``)
    and ``query`` (embed query text, then ``_query``) entry points.
    """

    async def upsert(
        self, documents: List[Document], chunk_token_size: Optional[int] = None
    ) -> List[str]:
        """
        Takes in a list of documents and inserts them into the database.

        First deletes all the existing vectors with the document id (if necessary,
        depends on the vector db), then inserts the new ones.

        Args:
            documents: Documents to insert. Documents whose ``id`` is falsy skip
                the pre-insert delete step.
            chunk_token_size: Forwarded unchanged to ``get_document_chunks``.

        Returns:
            A list of document ids, as returned by ``_upsert``. An empty
            ``documents`` list returns ``[]`` without touching the datastore.
        """
        # Nothing to delete, chunk or insert; avoid handing backends an empty batch.
        if not documents:
            return []

        # Delete any existing vectors for documents with the input document ids,
        # so the new chunks replace the old ones rather than accumulating next to
        # them. The deletes are independent of each other, so run them concurrently.
        await asyncio.gather(
            *[
                self.delete(
                    filter=DocumentMetadataFilter(document_id=document.id),
                    delete_all=False,
                )
                for document in documents
                if document.id
            ]
        )

        chunks = get_document_chunks(documents, chunk_token_size)
        return await self._upsert(chunks)

    @abstractmethod
    async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
        """
        Takes in a dict of document chunks and inserts them into the database.

        Args:
            chunks: Mapping from a string key (presumably the document id, as
                produced by ``get_document_chunks`` — confirm) to that
                document's list of chunks.

        Returns:
            A list of document ids.
        """
        raise NotImplementedError

    async def query(self, queries: List[Query]) -> List[QueryResult]:
        """
        Takes in a list of queries and filters and returns a list of query results
        with matching document chunks and scores.

        Each query's text is embedded with ``get_embeddings``. The queries are then
        re-built as ``QueryWithEmbedding`` objects and passed to ``_query``.

        Args:
            queries: The queries to run.

        Returns:
            The results from ``_query``. An empty ``queries`` list returns ``[]``
            without calling the embedding service.
        """
        # An empty batch has no results. Calling the embedding service with no
        # input would at best waste a request and at worst raise.
        if not queries:
            return []

        # Get a list of just the query texts from the Query list.
        query_texts = [query.query for query in queries]
        # NOTE(review): get_embeddings is called synchronously inside this
        # coroutine; if it performs network I/O it blocks the event loop for the
        # duration of the call — confirm whether that is acceptable.
        query_embeddings = get_embeddings(query_texts)

        # Hydrate the queries with embeddings. zip pairs them by position (and
        # stops at the shorter input), so this assumes get_embeddings returns
        # exactly one embedding per input text, in the same order.
        queries_with_embeddings = [
            QueryWithEmbedding(**query.dict(), embedding=embedding)
            for query, embedding in zip(queries, query_embeddings)
        ]
        return await self._query(queries_with_embeddings)

    @abstractmethod
    async def _query(self, queries: List[QueryWithEmbedding]) -> List[QueryResult]:
        """
        Takes in a list of queries with embeddings and filters and returns a list
        of query results with matching document chunks and scores.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete(
        self,
        ids: Optional[List[str]] = None,
        filter: Optional[DocumentMetadataFilter] = None,
        delete_all: Optional[bool] = None,
    ) -> bool:
        """
        Removes vectors by ids, filter, or everything in the datastore.

        Multiple parameters can be used at once.

        Args:
            ids: Ids whose vectors should be removed.
            filter: Metadata filter selecting vectors to remove. ``upsert``
                calls this with a ``document_id`` filter.
            delete_all: Whether to remove everything in the datastore.

        Returns:
            Whether the operation was successful.
        """
        raise NotImplementedError