Skip to content

Commit

Permalink
GODRIVER-3305 Cont. w/ tests
Browse files Browse the repository at this point in the history
  • Loading branch information
prestonvasquez committed Sep 6, 2024
1 parent fb3cf39 commit ea4cd0b
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ import (
)

const (
defaultIndex = "vector_index"
defaultPageContentName = "page_content"
defaultPath = "plot_embedding"
metadataName = "metadata"
scoreName = "score"
defaultIndex = "vector_index"
pageContentName = "pageContent"
defaultPath = "plot_embedding"
metadataName = "metadata"
scoreName = "score"
)

var (
Expand All @@ -27,22 +27,20 @@ var (
)

type Store struct {
coll mongo.Collection
embedder embeddings.Embedder
index string
pageContentName string
path string
coll mongo.Collection
embedder embeddings.Embedder
index string
path string
}

var _ vectorstores.VectorStore = &Store{}

func New(coll mongo.Collection, embedder embeddings.Embedder, opts ...Option) Store {
store := Store{
coll: coll,
embedder: embedder,
index: defaultIndex,
pageContentName: defaultPageContentName,
path: defaultPath,
coll: coll,
embedder: embedder,
index: defaultIndex,
path: defaultPath,
}

for _, opt := range opts {
Expand All @@ -62,7 +60,7 @@ func (store *Store) AddDocuments(
set(&opts)
}

if opts.ScoreThreshold != 0 || opts.Filters != nil || opts.NameSpace != "" {
if opts.ScoreThreshold != 0 || opts.Filters != nil || opts.NameSpace != "" || opts.Deduplicater != nil {
return nil, ErrUnsupportedOptions
}

Expand Down Expand Up @@ -94,7 +92,7 @@ func (store *Store) AddDocuments(
bdocs := []bson.D{}
for i := 0; i < len(vectors); i++ {
bdocs = append(bdocs, bson.D{
{Key: store.pageContentName, Value: docs[i].PageContent},
{Key: pageContentName, Value: docs[i].PageContent},
{Key: store.path, Value: vectors[i]},
{Key: metadataName, Value: docs[i].Metadata},
})
Expand Down Expand Up @@ -139,6 +137,10 @@ func (store *Store) SimilaritySearch(
return nil, ErrInvalidScoreThreshold
}

if opts.Deduplicater != nil {
return nil, ErrUnsupportedOptions
}

// Created an llm-specific-n-dimensional vector for searching the vector
// space
embedder := store.embedder
Expand All @@ -152,6 +154,11 @@ func (store *Store) SimilaritySearch(
index = opts.NameSpace
}

// If filters are unset, use an empty document.
if opts.Filters == nil {
opts.Filters = bson.D{}
}

vector, err := embedder.EmbedQuery(ctx, query)
if err != nil {
return nil, err
Expand All @@ -164,12 +171,14 @@ func (store *Store) SimilaritySearch(
QueryVector []float32 `bson:"queryVector"` // Query as vector
NumCandidates int `bson:"numCandidates"` // Number of nearest neighbors to use during the search.
Limit int `bson:"limit"` // Number of docments to return
Filter any `bson:"filter"` // MQL matching expression
}{
Index: index,
Path: store.path,
QueryVector: vector,
NumCandidates: 150,
Limit: numDocuments,
Filter: opts.Filters,
}

pipeline := mongo.Pipeline{
Expand All @@ -184,39 +193,23 @@ func (store *Store) SimilaritySearch(
// Execute the aggregation.
cur, err := store.coll.Aggregate(ctx, pipeline)
if err != nil {
fmt.Println("err")
return nil, err
}

found := []schema.Document{}
for cur.Next(ctx) {
doc := bson.M{}
doc := schema.Document{}
err := cur.Decode(&doc)
if err != nil {
return nil, err
}

schemaDoc := schema.Document{}

score, ok := doc[scoreName].(float64)
if ok {
if score < float64(opts.ScoreThreshold) {
continue
}

schemaDoc.Score = float32(score) // score ∈ [0, 1]
}

pageContent, ok := doc[store.pageContentName].(string)
if ok {
schemaDoc.PageContent = pageContent
}

metadata, ok := doc[metadataName].(map[string]any)
if ok {
schemaDoc.Metadata = metadata
if doc.Score < opts.ScoreThreshold {
continue
}

found = append(found, schemaDoc)
found = append(found, doc)
}

return found, nil
Expand Down
Loading

0 comments on commit ea4cd0b

Please sign in to comment.