From 66d7710b6328b2d13f2c970f70d0887d859e6f00 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Thu, 12 Sep 2024 21:01:07 -0600 Subject: [PATCH] vectorstores: add mongovector (#1005) * mongovector: Add mongo vectorstore implementation --------- Co-authored-by: Travis Cline --- go.mod | 1 + go.sum | 2 + vectorstores/mongovector/doc.go | 46 ++ vectorstores/mongovector/mock_embedder.go | 207 +++++++ vectorstores/mongovector/mock_llm.go | 38 ++ vectorstores/mongovector/mongovector.go | 250 ++++++++ vectorstores/mongovector/mongovector_test.go | 592 +++++++++++++++++++ vectorstores/mongovector/option.go | 35 ++ 8 files changed, 1171 insertions(+) create mode 100644 vectorstores/mongovector/doc.go create mode 100644 vectorstores/mongovector/mock_embedder.go create mode 100644 vectorstores/mongovector/mock_llm.go create mode 100644 vectorstores/mongovector/mongovector.go create mode 100644 vectorstores/mongovector/mongovector_test.go create mode 100644 vectorstores/mongovector/option.go diff --git a/go.mod b/go.mod index 6401ef780..f3270f4c2 100644 --- a/go.mod +++ b/go.mod @@ -158,6 +158,7 @@ require ( gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82 // indirect gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84 // indirect gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f // indirect + go.mongodb.org/mongo-driver/v2 v2.0.0-beta1 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect diff --git a/go.sum b/go.sum index 5ad429213..1579d1c52 100644 --- a/go.sum +++ b/go.sum @@ -783,6 +783,8 @@ go.mongodb.org/mongo-driver v1.7.5/go.mod h1:VXEWRZ6URJIkUq2SCAyapmhH0ZLRBP+FT4x go.mongodb.org/mongo-driver v1.10.0/go.mod h1:wsihk0Kdgv8Kqu1Anit4sfK+22vSFbUrAVEYRhCXrA8= go.mongodb.org/mongo-driver v1.14.0 h1:P98w8egYRjYe3XDjxhYJagTokP/H6HzlsnojRgZRd80= go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= +go.mongodb.org/mongo-driver/v2 v2.0.0-beta1 h1:vwKMYa9FCX1OW7efPaH0FUaD6o+WC0kiC7VtHtNX7UU= +go.mongodb.org/mongo-driver/v2 v2.0.0-beta1/go.mod h1:pfndQmffp38kKjbwVfoavadsdC0Nsg/qb+INK01PNaM= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 h1:A3SayB3rNyt+1S6qpI9mHPkeHTZbD7XILEqWnYZb2l0= diff --git a/vectorstores/mongovector/doc.go b/vectorstores/mongovector/doc.go new file mode 100644 index 000000000..ea3611e8a --- /dev/null +++ b/vectorstores/mongovector/doc.go @@ -0,0 +1,46 @@ +// Package mongovector implements a vector store using MongoDB as the backend. +// +// The mongovector package provides a way to store and retrieve document embeddings +// using MongoDB's vector search capabilities. It implements the VectorStore +// interface from the vectorstores package, allowing it to be used interchangeably +// with other vector store implementations. +// +// Key features: +// - Store document embeddings in MongoDB +// - Perform similarity searches on stored embeddings +// - Configurable index and path settings +// - Support for custom embedding functions +// +// Main types: +// - Store: The main type that implements the VectorStore interface +// - Option: A function type for configuring the Store +// +// Usage: +// +// import ( +// "github.com/tmc/langchaingo/vectorstores/mongovector" +// "go.mongodb.org/mongo-driver/mongo" +// ) +// +// // Create a new Store +// coll := // ... obtain a *mongo.Collection +// embedder := // ... obtain an embeddings.Embedder +// store := mongovector.New(coll, embedder) +// +// // Add documents +// docs := []schema.Document{ +// {PageContent: "Document 1"}, +// {PageContent: "Document 2"}, +// } +// ids, err := store.AddDocuments(context.Background(), docs) +// +// // Perform similarity search +// results, err := store.SimilaritySearch(context.Background(), "query", 5) +// +// The package also provides options for customizing the Store: +// - WithIndex: Set a custom index name +// - WithPath: Set a custom path for the vector field +// - WithNumCandidates: Set the number of candidates for similarity search +// +// For more detailed information, see the documentation for individual types and functions. +package mongovector diff --git a/vectorstores/mongovector/mock_embedder.go b/vectorstores/mongovector/mock_embedder.go new file mode 100644 index 000000000..a1e40ffb6 --- /dev/null +++ b/vectorstores/mongovector/mock_embedder.go @@ -0,0 +1,207 @@ +package mongovector + +import ( + "context" + "crypto/rand" + "fmt" + "math/big" + "time" + + "github.com/tmc/langchaingo/embeddings" + "github.com/tmc/langchaingo/schema" + "github.com/tmc/langchaingo/vectorstores" +) + +type mockEmbedder struct { + queryVector []float32 + docs map[string]schema.Document + docVectors map[string][]float32 +} + +var _ embeddings.Embedder = &mockEmbedder{} + +func newMockEmbedder(dim int) *mockEmbedder { + emb := &mockEmbedder{ + queryVector: newNormalizedVector(dim), + docs: make(map[string]schema.Document), + docVectors: make(map[string][]float32), + } + + return emb +} + +// mockDocuments will add the given documents to the embedder, assigning each +// a vector such that similarity score = 0.5 * ( 1 + vector * queryVector). +func (emb *mockEmbedder) mockDocuments(doc ...schema.Document) { + for _, d := range doc { + emb.docs[d.PageContent] = d + } +} + +// existingVectors returns all the vectors that have been added to the embedder. +// The query vector is included in the list to maintain orthogonality. +func (emb *mockEmbedder) existingVectors() [][]float32 { + vectors := make([][]float32, 0, len(emb.docs)+1) + for _, vec := range emb.docVectors { + vectors = append(vectors, vec) + } + + return append(vectors, emb.queryVector) +} + +// EmbedDocuments will return the embedded vectors for the given texts. If the +// text does not exist in the document set, a zero vector will be returned. +func (emb *mockEmbedder) EmbedDocuments(_ context.Context, texts []string) ([][]float32, error) { + vectors := make([][]float32, len(texts)) + for i := range vectors { + // If the text does not exist in the document set, return a zero vector. + doc, ok := emb.docs[texts[i]] + if !ok { + vectors[i] = make([]float32, len(emb.queryVector)) + } + + // If the vector exists, use it. + existing, ok := emb.docVectors[texts[i]] + if ok { + vectors[i] = existing + + continue + } + + // If it does not exist, make a linearly independent vector. + newVectorBasis := newOrthogonalVector(len(emb.queryVector), emb.existingVectors()...) + + // Update the newVector to be scaled by the score. + newVector := dotProductNormFn(doc.Score, emb.queryVector, newVectorBasis) + + vectors[i] = newVector + emb.docVectors[texts[i]] = newVector + } + + return vectors, nil +} + +// EmbedQuery returns the query vector. +func (emb *mockEmbedder) EmbedQuery(context.Context, string) ([]float32, error) { + return emb.queryVector, nil +} + +// Insert all of the mock documents collected by the embedder. +func flushMockDocuments(ctx context.Context, store Store, emb *mockEmbedder) error { + docs := make([]schema.Document, 0, len(emb.docs)) + for _, doc := range emb.docs { + docs = append(docs, doc) + } + + _, err := store.AddDocuments(ctx, docs, vectorstores.WithEmbedder(emb)) + if err != nil { + return err + } + + // Consistency on indexes is not synchronous. + // nolint:mnd + time.Sleep(10 * time.Second) + + return nil +} + +// newNormalizedFloat32 will generate a random float32 in [-1, 1]. +// nolint:mnd +func newNormalizedFloat32() (float32, error) { + max := big.NewInt(1 << 24) + + n, err := rand.Int(rand.Reader, max) + if err != nil { + return 0.0, fmt.Errorf("failed to normalize float32") + } + + return 2.0*(float32(n.Int64())/float32(1<<24)) - 1.0, nil +} + +// dotProduct will return the dot product between two slices of f32. +func dotProduct(v1, v2 []float32) float32 { + var sum float32 + + for i := range v1 { + sum += v1[i] * v2[i] + } + + return sum +} + +// linearlyIndependent true if the vectors are linearly independent. +func linearlyIndependent(v1, v2 []float32) bool { + var ratio float32 + + for i := range v1 { + if v1[i] != 0 { + r := v2[i] / v1[i] + + if ratio == 0 { + ratio = r + + continue + } + + if r == ratio { + continue + } + + return true + } + + if v2[i] != 0 { + return true + } + } + + return false +} + +// Create a vector of values between [-1, 1] of the specified size. +func newNormalizedVector(dim int) []float32 { + vector := make([]float32, dim) + for i := range vector { + vector[i], _ = newNormalizedFloat32() + } + + return vector +} + +// Use Gram Schmidt to return a vector orthogonal to the basis, so long as +// the vectors in the basis are linearly independent. +func newOrthogonalVector(dim int, basis ...[]float32) []float32 { + candidate := newNormalizedVector(dim) + + for _, b := range basis { + dp := dotProduct(candidate, b) + basisNorm := dotProduct(b, b) + + for i := range candidate { + candidate[i] -= (dp / basisNorm) * b[i] + } + } + + return candidate +} + +// return a new vector such that v1 * v2 = 2S - 1. +func dotProductNormFn(score float32, qvector, basis []float32) []float32 { + var sum float32 + + // Populate v2 upto dim-1. + for i := range qvector[:len(qvector)-1] { + sum += qvector[i] * basis[i] + } + + // Calculate v_{2, dim} such that v1 * v2 = 2S - 1: + basis[len(basis)-1] = (2*score - 1 - sum) / qvector[len(qvector)-1] + + // If the vectors are linearly independent, regenerate the dim-1 elements + // of v2. + if !linearlyIndependent(qvector, basis) { + return dotProductNormFn(score, qvector, basis) + } + + return basis +} diff --git a/vectorstores/mongovector/mock_llm.go b/vectorstores/mongovector/mock_llm.go new file mode 100644 index 000000000..d597b3c04 --- /dev/null +++ b/vectorstores/mongovector/mock_llm.go @@ -0,0 +1,38 @@ +package mongovector + +import ( + "context" + + "github.com/tmc/langchaingo/embeddings" +) + +// mockLLM will create consistent text embeddings mocking the OpenAI +// text-embedding-3-small algorithm. +type mockLLM struct { + seen map[string][]float32 + dim int +} + +var _ embeddings.EmbedderClient = &mockLLM{} + +// createEmbedding will return vector embeddings for the mock LLM, maintaining +// consistency. +func (emb *mockLLM) CreateEmbedding(_ context.Context, texts []string) ([][]float32, error) { + if emb.seen == nil { + emb.seen = map[string][]float32{} + } + + vectors := make([][]float32, len(texts)) + for i, text := range texts { + if f32s := emb.seen[text]; len(f32s) > 0 { + vectors[i] = f32s + + continue + } + + vectors[i] = newNormalizedVector(emb.dim) + emb.seen[text] = vectors[i] // ensure consistency + } + + return vectors, nil +} diff --git a/vectorstores/mongovector/mongovector.go b/vectorstores/mongovector/mongovector.go new file mode 100644 index 000000000..067b5a4cf --- /dev/null +++ b/vectorstores/mongovector/mongovector.go @@ -0,0 +1,250 @@ +package mongovector + +import ( + "context" + "errors" + "fmt" + + "github.com/tmc/langchaingo/embeddings" + "github.com/tmc/langchaingo/schema" + "github.com/tmc/langchaingo/vectorstores" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" +) + +const ( + defaultIndex = "vector_index" + pageContentName = "pageContent" + defaultPath = "plot_embedding" + metadataName = "metadata" + scoreName = "score" + defaultNumCandidatesScalar = 10 +) + +var ( + ErrEmbedderWrongNumberVectors = errors.New("number of vectors from embedder does not match number of documents") + ErrUnsupportedOptions = errors.New("unsupported options") + ErrInvalidScoreThreshold = errors.New("score threshold must be between 0 and 1") +) + +// Store wraps a Mongo collection for writing to and searching an Atlas +// vector database. +type Store struct { + coll *mongo.Collection + embedder embeddings.Embedder + index string + path string + numCandidates int +} + +var _ vectorstores.VectorStore = &Store{} + +// New returns a Store that can read and write to the vector store. +func New(coll *mongo.Collection, embedder embeddings.Embedder, opts ...Option) Store { + store := Store{ + coll: coll, + embedder: embedder, + index: defaultIndex, + path: defaultPath, + } + + for _, opt := range opts { + opt(&store) + } + + return store +} + +func mergeAddOpts(store *Store, opts ...vectorstores.Option) (*vectorstores.Options, error) { + mopts := &vectorstores.Options{} + for _, set := range opts { + set(mopts) + } + + if mopts.ScoreThreshold != 0 || mopts.Filters != nil || mopts.NameSpace != "" || mopts.Deduplicater != nil { + return nil, ErrUnsupportedOptions + } + + if mopts.Embedder == nil { + mopts.Embedder = store.embedder + } + + if mopts.Embedder == nil { + return nil, fmt.Errorf("embedder is unset") + } + + return mopts, nil +} + +// AddDocuments will create embeddings for the given documents using the +// user-specified embedding model, then insert that data into a vector store. +func (store *Store) AddDocuments( + ctx context.Context, + docs []schema.Document, + opts ...vectorstores.Option, +) ([]string, error) { + cfg, err := mergeAddOpts(store, opts...) + if err != nil { + return nil, err + } + + // Collect the page contents for embedding. + texts := make([]string, 0, len(docs)) + for _, doc := range docs { + texts = append(texts, doc.PageContent) + } + + vectors, err := cfg.Embedder.EmbedDocuments(ctx, texts) + if err != nil { + return nil, err + } + + if len(vectors) != len(docs) { + return nil, ErrEmbedderWrongNumberVectors + } + + bdocs := []bson.D{} + for i := range vectors { + bdocs = append(bdocs, bson.D{ + {Key: pageContentName, Value: docs[i].PageContent}, + {Key: store.path, Value: vectors[i]}, + {Key: metadataName, Value: docs[i].Metadata}, + }) + } + + res, err := store.coll.InsertMany(ctx, bdocs) + if err != nil { + return nil, err + } + + // Since we don't allow user-defined ids, the InsertedIDs list will always + // be primitive objects. + ids := make([]string, 0, len(docs)) + for _, id := range res.InsertedIDs { + id, ok := id.(fmt.Stringer) + if !ok { + return nil, fmt.Errorf("expected id for embedding to be a stringer") + } + + ids = append(ids, id.String()) + } + + return ids, nil +} + +func mergeSearchOpts(store *Store, opts ...vectorstores.Option) (*vectorstores.Options, error) { + mopts := &vectorstores.Options{} + for _, set := range opts { + set(mopts) + } + + // Validate that the score threshold is in [0, 1] + if mopts.ScoreThreshold > 1 || mopts.ScoreThreshold < 0 { + return nil, ErrInvalidScoreThreshold + } + + if mopts.Deduplicater != nil { + return nil, ErrUnsupportedOptions + } + + // Created an llm-specific-n-dimensional vector for searching the vector + // space + if mopts.Embedder == nil { + mopts.Embedder = store.embedder + } + + if mopts.Embedder == nil { + return nil, fmt.Errorf("embedder is unset") + } + + // If the user provides a method-level index, update. + if mopts.NameSpace == "" { + mopts.NameSpace = store.index + } + + // If filters are unset, use an empty document. + if mopts.Filters == nil { + mopts.Filters = bson.D{} + } + + return mopts, nil +} + +// SimilaritySearch searches a vector store from the vector transformed from the +// query by the user-specified embedding model. +// +// This method searches the store-wrapped collection with an optionally +// provided index at instantiation, with a default index of "vector_index". +// Since multiple indexes can be defined for a collection, the options.NameSpace +// value can be used here to change the search index. The priority is +// options.NameSpace > Store.index > defaultIndex. +func (store *Store) SimilaritySearch( + ctx context.Context, + query string, + numDocuments int, + opts ...vectorstores.Option, +) ([]schema.Document, error) { + cfg, err := mergeSearchOpts(store, opts...) + if err != nil { + return nil, err + } + + vector, err := cfg.Embedder.EmbedQuery(ctx, query) + if err != nil { + return nil, err + } + + numCandidates := defaultNumCandidatesScalar * numDocuments + if store.numCandidates == 0 { + numCandidates = numDocuments + } + + // Create the pipeline for performing the similarity search. + stage := struct { + Index string `bson:"index"` // Name of Atlas Vector Search Index tied to Collection + Path string `bson:"path"` // Field in Collection containing embedding vectors + QueryVector []float32 `bson:"queryVector"` // Query as vector + NumCandidates int `bson:"numCandidates"` // Number of nearest neighbors to use during the search. + Limit int `bson:"limit"` // Number of docments to return + Filter any `bson:"filter"` // MQL matching expression + }{ + Index: cfg.NameSpace, + Path: store.path, + QueryVector: vector, + NumCandidates: numCandidates, + Limit: numDocuments, + Filter: cfg.Filters, + } + + pipeline := mongo.Pipeline{ + bson.D{ + {Key: "$vectorSearch", Value: stage}, + }, + bson.D{ + {Key: "$set", Value: bson.D{{Key: scoreName, Value: bson.D{{Key: "$meta", Value: "vectorSearchScore"}}}}}, + }, + } + + // Execute the aggregation. + cur, err := store.coll.Aggregate(ctx, pipeline) + if err != nil { + return nil, err + } + + found := []schema.Document{} + for cur.Next(ctx) { + doc := schema.Document{} + err := cur.Decode(&doc) + if err != nil { + return nil, err + } + + if doc.Score < cfg.ScoreThreshold { + continue + } + + found = append(found, doc) + } + + return found, nil +} diff --git a/vectorstores/mongovector/mongovector_test.go b/vectorstores/mongovector/mongovector_test.go new file mode 100644 index 000000000..8084aec18 --- /dev/null +++ b/vectorstores/mongovector/mongovector_test.go @@ -0,0 +1,592 @@ +package mongovector + +import ( + "context" + "errors" + "flag" + "fmt" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tmc/langchaingo/embeddings" + "github.com/tmc/langchaingo/schema" + "github.com/tmc/langchaingo/vectorstores" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" + "go.mongodb.org/mongo-driver/v2/mongo/options" +) + +// Run the test without setting up the test space. +// +//nolint:gochecknoglobals +var testWithoutSetup = flag.Bool("no-atlas-setup", false, "don't create required indexes") + +const ( + testDB = "langchaingo-test" + testColl = "vstore" + testIndexDP1536 = "vector_index_dotProduct_1536" + testIndexDP1536WithFilter = "vector_index_dotProduct_1536_w_filters" + testIndexDP3 = "vector_index_dotProduct_3" + testIndexSize1536 = 1536 + testIndexSize3 = 3 +) + +func TestMain(m *testing.M) { + flag.Parse() + + defer func() { + os.Exit(m.Run()) + }() + + if *testWithoutSetup { + return + } + + // Create the required vector search indexes for the tests. + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + if err := resetForE2E(ctx, testIndexDP1536, testIndexSize1536, nil); err != nil { + fmt.Fprintf(os.Stderr, "setup failed for 1536: %v\n", err) + } + + filters := []string{"pageContent"} + if err := resetForE2E(ctx, testIndexDP1536WithFilter, testIndexSize1536, filters); err != nil { + fmt.Fprintf(os.Stderr, "setup failed for 1536 w filter: %v\n", err) + } + + if err := resetForE2E(ctx, testIndexDP3, testIndexSize3, nil); err != nil { + fmt.Fprintf(os.Stderr, "setup failed for 3: %v\n", err) + } +} + +func TestNew(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts []Option + wantIndex string + wantPageContentName string + wantPath string + }{ + { + name: "nil options", + opts: nil, + wantIndex: "vector_index", + wantPageContentName: "page_content", + wantPath: "plot_embedding", + }, + { + name: "no options", + opts: []Option{}, + wantIndex: "vector_index", + wantPageContentName: "page_content", + wantPath: "plot_embedding", + }, + { + name: "mixed custom options", + opts: []Option{WithIndex("custom_vector_index")}, + wantIndex: "custom_vector_index", + wantPageContentName: "page_content", + wantPath: "plot_embedding", + }, + { + name: "all custom options", + opts: []Option{ + WithIndex("custom_vector_index"), + WithPath("custom_plot_embedding"), + }, + wantIndex: "custom_vector_index", + wantPath: "custom_plot_embedding", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + embedder, err := embeddings.NewEmbedder(&mockLLM{}) + require.NoError(t, err, "failed to construct embedder") + + store := New(&mongo.Collection{}, embedder, test.opts...) + + assert.Equal(t, test.wantIndex, store.index) + assert.Equal(t, test.wantPath, store.path) + }) + } +} + +// resetVectorStore will reset the vector space defined by the given collection. +func resetVectorStore(t *testing.T, coll *mongo.Collection) { + t.Helper() + + filter := bson.D{{Key: pageContentName, Value: bson.D{{Key: "$exists", Value: true}}}} + + _, err := coll.DeleteMany(context.Background(), filter) + assert.NoError(t, err, "failed to reset vector store") +} + +// setupTest will prepare the Atlas vector search for adding to and searching +// a vector space. +func setupTest(t *testing.T, dim int, index string) Store { + t.Helper() + + uri := os.Getenv("MONGODB_URI") + if uri == "" { + t.Skip("Must set MONGODB_URI to run test") + } + + require.NotEmpty(t, uri, "MONGODB_URI required") + + client, err := mongo.Connect(options.Client().ApplyURI(uri)) + require.NoError(t, err, "failed to connect to MongoDB server") + + // Create the vectorstore collection + err = client.Database(testDB).CreateCollection(context.Background(), testColl) + require.NoError(t, err, "failed to create collection") + + coll := client.Database(testDB).Collection(testColl) + resetVectorStore(t, coll) + + emb := newMockEmbedder(dim) + store := New(coll, emb, WithIndex(index)) + + return store +} + +//nolint:paralleltest +func TestStore_AddDocuments(t *testing.T) { + store := setupTest(t, testIndexSize1536, testIndexDP1536) + + tests := []struct { + name string + docs []schema.Document + options []vectorstores.Option + wantErr []string + }{ + { + name: "nil docs", + docs: nil, + wantErr: []string{"must provide at least one element in input slice"}, + options: []vectorstores.Option{}, + }, + { + name: "no docs", + docs: []schema.Document{}, + wantErr: []string{"must provide at least one element in input slice"}, + options: []vectorstores.Option{}, + }, + { + name: "single empty doc", + docs: []schema.Document{{}}, + wantErr: []string{}, // May vary by embedder + options: []vectorstores.Option{}, + }, + { + name: "single non-empty doc", + docs: []schema.Document{{PageContent: "foo"}}, + wantErr: []string{}, + options: []vectorstores.Option{}, + }, + { + name: "one non-empty doc and one empty doc", + docs: []schema.Document{{PageContent: "foo"}, {}}, + wantErr: []string{}, // May vary by embedder + options: []vectorstores.Option{}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + resetVectorStore(t, store.coll) + + ids, err := store.AddDocuments(context.Background(), test.docs, test.options...) + if len(test.wantErr) > 0 { + require.Error(t, err) + for _, want := range test.wantErr { + if strings.Contains(err.Error(), want) { + return + } + } + + t.Errorf("expected error %q to contain of %v", err.Error(), test.wantErr) + } else { + require.NoError(t, err) + } + + assert.Equal(t, len(test.docs), len(ids)) + }) + } +} + +type simSearchTest struct { + ctx context.Context //nolint:containedctx + seed []schema.Document + numDocuments int // Number of documents to return + options []vectorstores.Option // Search query options + want []schema.Document + wantErr string +} + +func runSimilaritySearchTest(t *testing.T, store Store, test simSearchTest) { + t.Helper() + + resetVectorStore(t, store.coll) + + // Merge options + opts := vectorstores.Options{} + for _, opt := range test.options { + opt(&opts) + } + + var emb *mockEmbedder + if opts.Embedder != nil { + var ok bool + + emb, ok = opts.Embedder.(*mockEmbedder) + require.True(t, ok) + } else { + semb, ok := store.embedder.(*mockEmbedder) + require.True(t, ok) + + emb = newMockEmbedder(len(semb.queryVector)) + emb.mockDocuments(test.seed...) + + test.options = append(test.options, vectorstores.WithEmbedder(emb)) + } + + err := flushMockDocuments(context.Background(), store, emb) + require.NoError(t, err, "failed to flush mock embedder") + + raw, err := store.SimilaritySearch(test.ctx, "", test.numDocuments, test.options...) + if test.wantErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, test.wantErr) + } else { + require.NoError(t, err) + } + + assert.Len(t, raw, len(test.want)) + + got := make(map[string]schema.Document) + for _, g := range raw { + got[g.PageContent] = g + } + + for _, w := range test.want { + got := got[w.PageContent] + if w.Score != 0 { + assert.InDelta(t, w.Score, got.Score, 1e-4, "score out of bounds for %w", w.PageContent) + } + + assert.Equal(t, w.PageContent, got.PageContent, "page contents differ") + assert.Equal(t, w.Metadata, got.Metadata, "metadata differs") + } +} + +//nolint:paralleltest +func TestStore_SimilaritySearch_ExactQuery(t *testing.T) { + store := setupTest(t, testIndexSize3, testIndexDP3) + + seed := []schema.Document{ + {PageContent: "v1", Score: 1}, + {PageContent: "v090", Score: 0.90}, + {PageContent: "v051", Score: 0.51}, + {PageContent: "v0001", Score: 0.001}, + } + + t.Run("numDocuments=1 of 4", func(t *testing.T) { + runSimilaritySearchTest(t, store, + simSearchTest{ + numDocuments: 1, + seed: seed, + want: []schema.Document{ + {PageContent: "v1", Score: 1}, + }, + }) + }) + + t.Run("numDocuments=3 of 4", func(t *testing.T) { + runSimilaritySearchTest(t, store, + simSearchTest{ + numDocuments: 3, + seed: seed, + want: []schema.Document{ + {PageContent: "v1", Score: 1}, + {PageContent: "v090", Score: 0.90}, + {PageContent: "v051", Score: 0.51}, + }, + }) + }) +} + +//nolint:funlen,paralleltest +func TestStore_SimilaritySearch_NonExactQuery(t *testing.T) { + store := setupTest(t, testIndexSize1536, testIndexDP1536) + + seed := []schema.Document{ + {PageContent: "v090", Score: 0.90}, + {PageContent: "v051", Score: 0.51}, + {PageContent: "v0001", Score: 0.001}, + } + + t.Run("numDocuments=1 of 3", func(t *testing.T) { + runSimilaritySearchTest(t, store, + simSearchTest{ + numDocuments: 1, + seed: seed, + want: seed[:1], + }) + }) + + t.Run("numDocuments=3 of 4", func(t *testing.T) { + runSimilaritySearchTest(t, store, + simSearchTest{ + numDocuments: 3, + seed: seed, + want: seed, + }) + }) + + t.Run("with score threshold", func(t *testing.T) { + runSimilaritySearchTest(t, store, + simSearchTest{ + numDocuments: 3, + seed: seed, + options: []vectorstores.Option{vectorstores.WithScoreThreshold(0.50)}, + want: seed[:2], + }) + }) + + t.Run("with invalid score threshold", func(t *testing.T) { + runSimilaritySearchTest(t, store, + simSearchTest{ + numDocuments: 3, + seed: seed, + options: []vectorstores.Option{vectorstores.WithScoreThreshold(-0.50)}, + wantErr: ErrInvalidScoreThreshold.Error(), + }) + }) + + metadataSeed := []schema.Document{ + {PageContent: "v090", Score: 0.90}, + {PageContent: "v051", Score: 0.51, Metadata: map[string]any{"pi": 3.14}}, + {PageContent: "v0001", Score: 0.001}, + } + + t.Run("with metadata", func(t *testing.T) { + runSimilaritySearchTest(t, store, + simSearchTest{ + numDocuments: 3, + seed: metadataSeed, + want: metadataSeed, + }) + }) + + t.Run("with metadata and score threshold", func(t *testing.T) { + runSimilaritySearchTest(t, store, + simSearchTest{ + numDocuments: 3, + seed: metadataSeed, + want: metadataSeed[:2], + options: []vectorstores.Option{vectorstores.WithScoreThreshold(0.50)}, + }) + }) + + t.Run("with namespace", func(t *testing.T) { + emb := newMockEmbedder(testIndexSize3) + + doc := schema.Document{PageContent: "v090", Score: 0.90, Metadata: map[string]any{"phi": 1.618}} + emb.mockDocuments(doc) + + runSimilaritySearchTest(t, store, + simSearchTest{ + numDocuments: 1, + seed: []schema.Document{doc}, + want: []schema.Document{doc}, + options: []vectorstores.Option{ + vectorstores.WithNameSpace(testIndexDP3), + vectorstores.WithEmbedder(emb), + }, + }) + }) + + t.Run("with non-existent namespace", func(t *testing.T) { + runSimilaritySearchTest(t, store, + simSearchTest{ + numDocuments: 1, + seed: metadataSeed, + options: []vectorstores.Option{ + vectorstores.WithNameSpace("some-non-existent-index-name"), + }, + }) + }) + + t.Run("with filter", func(t *testing.T) { + runSimilaritySearchTest(t, store, + simSearchTest{ + numDocuments: 1, + seed: metadataSeed, + want: metadataSeed[len(metadataSeed)-1:], + options: []vectorstores.Option{ + vectorstores.WithFilters(bson.D{{Key: "pageContent", Value: "v0001"}}), + vectorstores.WithNameSpace(testIndexDP1536WithFilter), + }, + }) + }) + + t.Run("with non-tokenized filter", func(t *testing.T) { + emb := newMockEmbedder(testIndexSize1536) + + doc := schema.Document{PageContent: "v090", Score: 0.90, Metadata: map[string]any{"phi": 1.618}} + emb.mockDocuments(doc) + + runSimilaritySearchTest(t, store, + simSearchTest{ + numDocuments: 1, + seed: metadataSeed, + options: []vectorstores.Option{ + vectorstores.WithFilters(bson.D{{Key: "pageContent", Value: "v0001"}}), + vectorstores.WithEmbedder(emb), + }, + wantErr: "'pageContent' needs to be indexed as token", + }) + }) + + t.Run("with deduplicator", func(t *testing.T) { + runSimilaritySearchTest(t, store, + simSearchTest{ + numDocuments: 1, + seed: metadataSeed, + options: []vectorstores.Option{ + vectorstores.WithDeduplicater(func(context.Context, schema.Document) bool { return true }), + }, + wantErr: ErrUnsupportedOptions.Error(), + }) + }) +} + +// vectorField defines the fields of an index used for vector search. +type vectorField struct { + Type string `bson:"type,omitempty"` + Path string `bson:"path,omityempty"` + NumDimensions int `bson:"numDimensions,omitempty"` + Similarity string `bson:"similarity,omitempty"` +} + +// createVectorSearchIndex will create a vector search index on the "db.vstore" +// collection named "vector_index" with the provided field. This function blocks +// until the index has been created. +func createVectorSearchIndex( + ctx context.Context, + coll *mongo.Collection, + idxName string, + fields ...vectorField, +) (string, error) { + def := struct { + Fields []vectorField `bson:"fields"` + }{ + Fields: fields, + } + + view := coll.SearchIndexes() + + siOpts := options.SearchIndexes().SetName(idxName).SetType("vectorSearch") + searchName, err := view.CreateOne(ctx, mongo.SearchIndexModel{Definition: def, Options: siOpts}) + if err != nil { + return "", fmt.Errorf("failed to create the search index: %w", err) + } + + // Await the creation of the index. + var doc bson.Raw + for doc == nil { + cursor, err := view.List(ctx, options.SearchIndexes().SetName(searchName)) + if err != nil { + return "", fmt.Errorf("failed to list search indexes: %w", err) + } + + if !cursor.Next(ctx) { + break + } + + name := cursor.Current.Lookup("name").StringValue() + queryable := cursor.Current.Lookup("queryable").Boolean() + if name == searchName && queryable { + doc = cursor.Current + } else { + time.Sleep(5 * time.Second) + } + } + + return searchName, nil +} + +func searchIndexExists(ctx context.Context, coll *mongo.Collection, idx string) (bool, error) { + view := coll.SearchIndexes() + + siOpts := options.SearchIndexes().SetName(idx).SetType("vectorSearch") + cursor, err := view.List(ctx, siOpts) + if err != nil { + return false, fmt.Errorf("failed to list search indexes: %w", err) + } + + name := cursor.Current.Lookup("name").StringValue() + queryable := cursor.Current.Lookup("queryable").Boolean() + + return name == idx && queryable, nil +} + +func resetForE2E(ctx context.Context, idx string, dim int, filters []string) error { + uri := os.Getenv("MONGODB_URI") + if uri == "" { + return errors.New("MONGODB_URI required") + } + + client, err := mongo.Connect(options.Client().ApplyURI(uri)) + if err != nil { + return fmt.Errorf("failed to connect to server: %w", err) + } + + defer func() { _ = client.Disconnect(ctx) }() + + // Create the vectorstore collection + err = client.Database(testDB).CreateCollection(ctx, testColl) + if err != nil { + return fmt.Errorf("failed to create vector store collection: %w", err) + } + + coll := client.Database(testDB).Collection(testColl) + + if ok, _ := searchIndexExists(ctx, coll, idx); ok { + return nil + } + + fields := []vectorField{} + + fields = append(fields, vectorField{ + Type: "vector", + Path: "plot_embedding", + NumDimensions: dim, + Similarity: "dotProduct", + }) + + for _, filter := range filters { + fields = append(fields, vectorField{ + Type: "filter", + Path: filter, + }) + } + + _, err = createVectorSearchIndex(ctx, coll, idx, fields...) + if err != nil { + return fmt.Errorf("faield to create index: %w", err) + } + + return nil +} diff --git a/vectorstores/mongovector/option.go b/vectorstores/mongovector/option.go new file mode 100644 index 000000000..0d13fb9e9 --- /dev/null +++ b/vectorstores/mongovector/option.go @@ -0,0 +1,35 @@ +package mongovector + +// Option sets mongovector-specific options when constructing a Store. +type Option func(p *Store) + +// WithIndex will set the default index to use when adding or searching +// documents with the store. +// +// Atlas Vector Search doesn't return results if you misspell the index name or +// if the specified index doesn't already exist on the cluster. +// +// The index can be update at the operation level with the NameSpace +// vectorstores option. +func WithIndex(index string) Option { + return func(p *Store) { + p.index = index + } +} + +// WithPath will set the path parameter used by the Atlas Search operators to +// specify the field or fields to be searched. +func WithPath(path string) Option { + return func(p *Store) { + p.path = path + } +} + +// WithNumCandidates sets the number of nearest neighbors to use during a +// similarity search. By default this value is 10 times the number of documents +// (or limit) passed as an argument to SimilaritySearch. +func WithNumCandidates(numCandidates int) Option { + return func(p *Store) { + p.numCandidates = numCandidates + } +}