diff --git a/vectorstores/mongovector/mongovector.go b/vectorstores/mongovector/mongovector.go index 36afc6c21..71fe128e5 100644 --- a/vectorstores/mongovector/mongovector.go +++ b/vectorstores/mongovector/mongovector.go @@ -26,6 +26,8 @@ var ( ErrInvalidScoreThreshold = errors.New("score threshold must be between 0 and 1") ) +// Store wraps a Mongo collection for writing to and searching an Atlas +// vector database. type Store struct { coll mongo.Collection embedder embeddings.Embedder @@ -35,6 +37,7 @@ type Store struct { var _ vectorstores.VectorStore = &Store{} +// New returns a Store that can read and write to the vector store. func New(coll mongo.Collection, embedder embeddings.Embedder, opts ...Option) Store { store := Store{ coll: coll, @@ -50,6 +53,8 @@ func New(coll mongo.Collection, embedder embeddings.Embedder, opts ...Option) St return store } +// AddDocuments will create embeddings for the given documents using the +// user-specified embedding model, then insert that data into a vector store. func (store *Store) AddDocuments( ctx context.Context, docs []schema.Document, @@ -113,8 +118,8 @@ func (store *Store) AddDocuments( return ids, nil } -// SimilaritySearch creates a vector embedding from the query using the embedder -// and queries to find the most similar documents. +// SimilaritySearch searches a vector store from the vector transformed from the +// query by the user-specified embedding model. // // This method searches the store-wrapped collection with an optionally // provided index at instantiation, with a default index of "vector_index". 
diff --git a/vectorstores/mongovector/mongovector_test.go b/vectorstores/mongovector/mongovector_test.go index c7bd67172..ce808e223 100644 --- a/vectorstores/mongovector/mongovector_test.go +++ b/vectorstores/mongovector/mongovector_test.go @@ -2,9 +2,13 @@ package mongovector import ( "context" + "errors" + "flag" + "fmt" "os" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -16,16 +20,48 @@ import ( "go.mongodb.org/mongo-driver/v2/mongo/options" ) +var testWithoutSetup = flag.Bool("no-atlas-setup", false, "don't create required indexes") + const ( - testDB = "langchaingo-test" - testColl = "vstore" - testIndexDP1536 = "vector_index_dotProduct_1536" - testIndexDP1536NoFilters = "vector_index_dotProduct_1536_no_filters" - testIndexSize1536 = 1536 - testIndexDP3 = "vector_index_dotProduct_3" - testIndexSize3 = 3 + testDB = "langchaingo-test" + testColl = "vstore" + testIndexDP1536 = "vector_index_dotProduct_1536" + testIndexDP1536WithFilter = "vector_index_dotProduct_1536_w_filters" + testIndexDP3 = "vector_index_dotProduct_3" + testIndexSize1536 = 1536 + testIndexSize3 = 3 ) +func TestMain(m *testing.M) { + flag.Parse() + + defer func() { + os.Exit(m.Run()) + }() + + if *testWithoutSetup { + return + } + + // Create the required vector search indexes for the tests. 
+ + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + if err := resetForE2E(ctx, testIndexDP1536, testIndexSize1536, nil); err != nil { + fmt.Fprintf(os.Stderr, "setup failed for 1536: %v\n", err) + } + + filters := []string{"pageContent"} + if err := resetForE2E(ctx, testIndexDP1536WithFilter, testIndexSize1536, filters); err != nil { + fmt.Fprintf(os.Stderr, "setup failed for 1536 w filter: %v\n", err) + } + + if err := resetForE2E(ctx, testIndexDP3, testIndexSize3, nil); err != nil { + fmt.Fprintf(os.Stderr, "setup failed for 3: %v\n", err) + } +} + func TestNew(t *testing.T) { t.Parallel() @@ -86,7 +122,7 @@ func TestNew(t *testing.T) { } // resetVectorStore will reset the vector space defined by the given collection. -func resetVectorStore(t *testing.T, coll mongo.Collection, pageContentName string) { +func resetVectorStore(t *testing.T, coll mongo.Collection) { t.Helper() filter := bson.D{{Key: pageContentName, Value: bson.D{{Key: "$exists", Value: true}}}} @@ -97,7 +133,7 @@ func resetVectorStore(t *testing.T, coll mongo.Collection, pageContentName strin // setupTest will prepare the Atlas vector search for adding to and searching // a vector space. 
-func setupTest(t *testing.T, dim int, index string) (Store, *mockEmbedder) { +func setupTest(t *testing.T, dim int, index string) Store { uri := os.Getenv("MONGODB_URI") if uri == "" { t.Skip("Must set MONGODB_URI to run test") @@ -113,16 +149,16 @@ func setupTest(t *testing.T, dim int, index string) (Store, *mockEmbedder) { assert.NoError(t, err, "failed to create collection") coll := client.Database(testDB).Collection(testColl) - resetVectorStore(t, *coll, pageContentName) + resetVectorStore(t, *coll) emb := newMockEmbedder(dim, "") store := New(*coll, emb, WithIndex(index)) - return store, emb + return store } func TestStore_AddDocuments(t *testing.T) { - store, _ := setupTest(t, 0, testIndexDP1536) + store := setupTest(t, testIndexSize1536, testIndexDP1536) tests := []struct { name string @@ -164,7 +200,7 @@ func TestStore_AddDocuments(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - resetVectorStore(t, store.coll, pageContentName) + resetVectorStore(t, store.coll) ids, err := store.AddDocuments(context.Background(), test.docs, test.options...) 
if len(test.wantErr) > 0 { @@ -195,10 +231,14 @@ type simSearchTest struct { wantErr string } -func runSimilaritySearchTest(t *testing.T, test simSearchTest) { +func runSimilaritySearchTest(t *testing.T, store Store, test simSearchTest) { t.Helper() - store, emb := setupTest(t, testIndexSize1536, testIndexDP1536) + resetVectorStore(t, store.coll) + + semb := store.embedder.(*mockEmbedder) + + emb := newMockEmbedder(semb.dim, semb.query) for _, doc := range test.seed { emb.addDocument(doc) } @@ -246,6 +286,8 @@ func runSimilaritySearchTest(t *testing.T, test simSearchTest) { } func TestStore_SimilaritySearch_ExactQuery(t *testing.T) { + store := setupTest(t, testIndexSize1536, testIndexDP1536) + seed := []schema.Document{ {PageContent: "v1", Score: 1}, {PageContent: "v090", Score: 0.90}, @@ -254,7 +296,7 @@ func TestStore_SimilaritySearch_ExactQuery(t *testing.T) { } t.Run("numDocuments=1 of 4", func(t *testing.T) { - runSimilaritySearchTest(t, + runSimilaritySearchTest(t, store, simSearchTest{ numDocuments: 1, seed: seed, @@ -265,7 +307,7 @@ func TestStore_SimilaritySearch_ExactQuery(t *testing.T) { }) t.Run("numDocuments=3 of 4", func(t *testing.T) { - runSimilaritySearchTest(t, + runSimilaritySearchTest(t, store, simSearchTest{ numDocuments: 3, seed: seed, @@ -279,6 +321,8 @@ func TestStore_SimilaritySearch_ExactQuery(t *testing.T) { } func TestStore_SimilaritySearch_NonExactQuery(t *testing.T) { + store := setupTest(t, testIndexSize1536, testIndexDP1536) + seed := []schema.Document{ {PageContent: "v090", Score: 0.90}, {PageContent: "v051", Score: 0.51}, @@ -286,7 +330,7 @@ func TestStore_SimilaritySearch_NonExactQuery(t *testing.T) { } t.Run("numDocuments=1 of 3", func(t *testing.T) { - runSimilaritySearchTest(t, + runSimilaritySearchTest(t, store, simSearchTest{ numDocuments: 1, seed: seed, @@ -295,7 +339,7 @@ func TestStore_SimilaritySearch_NonExactQuery(t *testing.T) { }) t.Run("numDocuments=3 of 4", func(t *testing.T) { - runSimilaritySearchTest(t, + 
runSimilaritySearchTest(t, store, simSearchTest{ numDocuments: 3, seed: seed, @@ -304,7 +348,7 @@ func TestStore_SimilaritySearch_NonExactQuery(t *testing.T) { }) t.Run("with score threshold", func(t *testing.T) { - runSimilaritySearchTest(t, + runSimilaritySearchTest(t, store, simSearchTest{ numDocuments: 3, seed: seed, @@ -314,7 +358,7 @@ func TestStore_SimilaritySearch_NonExactQuery(t *testing.T) { }) t.Run("with invalid score threshold", func(t *testing.T) { - runSimilaritySearchTest(t, + runSimilaritySearchTest(t, store, simSearchTest{ numDocuments: 3, seed: seed, @@ -330,7 +374,7 @@ func TestStore_SimilaritySearch_NonExactQuery(t *testing.T) { } t.Run("with metadata", func(t *testing.T) { - runSimilaritySearchTest(t, + runSimilaritySearchTest(t, store, simSearchTest{ numDocuments: 3, seed: metadataSeed, @@ -339,7 +383,7 @@ func TestStore_SimilaritySearch_NonExactQuery(t *testing.T) { }) t.Run("with metadata and score threshold", func(t *testing.T) { - runSimilaritySearchTest(t, + runSimilaritySearchTest(t, store, simSearchTest{ numDocuments: 3, seed: metadataSeed, @@ -354,7 +398,7 @@ func TestStore_SimilaritySearch_NonExactQuery(t *testing.T) { doc := schema.Document{PageContent: "v090", Score: 0.90, Metadata: map[string]any{"phi": 1.618}} emb.addDocument(doc) - runSimilaritySearchTest(t, + runSimilaritySearchTest(t, store, simSearchTest{ numDocuments: 1, seed: []schema.Document{doc}, @@ -367,7 +411,7 @@ func TestStore_SimilaritySearch_NonExactQuery(t *testing.T) { }) t.Run("with non-existant namespace", func(t *testing.T) { - runSimilaritySearchTest(t, + runSimilaritySearchTest(t, store, simSearchTest{ numDocuments: 1, seed: metadataSeed, @@ -378,13 +422,14 @@ func TestStore_SimilaritySearch_NonExactQuery(t *testing.T) { }) t.Run("with filter", func(t *testing.T) { - runSimilaritySearchTest(t, + runSimilaritySearchTest(t, store, simSearchTest{ numDocuments: 1, seed: metadataSeed, want: metadataSeed[len(metadataSeed)-1:], options: []vectorstores.Option{ 
vectorstores.WithFilters(bson.D{{Key: "pageContent", Value: "v0001"}}), + vectorstores.WithNameSpace(testIndexDP1536WithFilter), }, }) }) @@ -395,13 +440,12 @@ func TestStore_SimilaritySearch_NonExactQuery(t *testing.T) { doc := schema.Document{PageContent: "v090", Score: 0.90, Metadata: map[string]any{"phi": 1.618}} emb.addDocument(doc) - runSimilaritySearchTest(t, + runSimilaritySearchTest(t, store, simSearchTest{ numDocuments: 1, seed: metadataSeed, options: []vectorstores.Option{ vectorstores.WithFilters(bson.D{{Key: "pageContent", Value: "v0001"}}), - vectorstores.WithNameSpace(testIndexDP1536NoFilters), vectorstores.WithEmbedder(emb), }, wantErr: "'pageContent' needs to be indexed as token", @@ -409,7 +453,7 @@ func TestStore_SimilaritySearch_NonExactQuery(t *testing.T) { }) t.Run("with deduplicator", func(t *testing.T) { - runSimilaritySearchTest(t, + runSimilaritySearchTest(t, store, simSearchTest{ numDocuments: 1, seed: metadataSeed,
@@ -420,3 +464,186 @@ func TestStore_SimilaritySearch_NonExactQuery(t *testing.T) {
 			})
 	})
 }
+
+// dropVectorSearchIndex drops the search index named idxName from coll and
+// then polls the collection's search indexes until that name is no longer
+// listed. It returns an error if coll is nil, if the drop or a listing fails,
+// or if ctx is done before the index disappears.
+func dropVectorSearchIndex(ctx context.Context, coll *mongo.Collection, idxName string) error {
+	const pollInterval = 5 * time.Second
+
+	if coll == nil {
+		return errors.New("collection must not be nil")
+	}
+
+	view := coll.SearchIndexes()
+
+	if err := view.DropOne(ctx, idxName); err != nil {
+		return fmt.Errorf("failed to drop index: %w", err)
+	}
+
+	// Await the drop of the index.
+	for {
+		cursor, err := view.List(ctx, options.SearchIndexes().SetName(idxName))
+		if err != nil {
+			return fmt.Errorf("failed to list search indexes: %w", err)
+		}
+
+		listed := cursor.Next(ctx)
+		err = cursor.Err()
+
+		// Release the cursor on every iteration. A close failure does not
+		// change whether the index was listed, so it is ignored.
+		_ = cursor.Close(ctx)
+
+		if err != nil {
+			return fmt.Errorf("failed to list search indexes: %w", err)
+		}
+
+		if !listed {
+			return nil
+		}
+
+		// Wait before polling again, but give up as soon as ctx is done.
+		select {
+		case <-ctx.Done():
+			return fmt.Errorf("awaiting drop of index %q: %w", idxName, ctx.Err())
+		case <-time.After(pollInterval):
+		}
+	}
+}
+
+// vectorField defines the fields of an index used for vector search. Zero
+// values are omitted from the encoded definition, so a field only needs the
+// members that apply to its Type.
+type vectorField struct {
+	Type          string `bson:"type,omitempty"`
+	Path          string `bson:"path,omitempty"`
+	NumDimensions int    `bson:"numDimensions,omitempty"`
+	Similarity    string `bson:"similarity,omitempty"`
+}
+
+// createVectorSearchIndex creates a search index of type "vectorSearch" named
+// idxName on coll from the given fields, then polls the collection's search
+// indexes until the index is listed as queryable. It returns the name
+// returned by CreateOne, or an error if coll is nil, if creation or a listing
+// fails, or if ctx is done before the index becomes queryable.
+func createVectorSearchIndex(
+	ctx context.Context,
+	coll *mongo.Collection,
+	idxName string,
+	fields ...vectorField,
+) (string, error) {
+	const pollInterval = 5 * time.Second
+
+	if coll == nil {
+		return "", errors.New("collection must not be nil")
+	}
+
+	def := struct {
+		Fields []vectorField `bson:"fields"`
+	}{
+		Fields: fields,
+	}
+
+	view := coll.SearchIndexes()
+
+	siOpts := options.SearchIndexes().SetName(idxName).SetType("vectorSearch")
+	searchName, err := view.CreateOne(ctx, mongo.SearchIndexModel{Definition: def, Options: siOpts})
+	if err != nil {
+		return "", fmt.Errorf("failed to create the search index: %w", err)
+	}
+
+	// Await the creation of the index.
+	for {
+		cursor, err := view.List(ctx, options.SearchIndexes().SetName(searchName))
+		if err != nil {
+			return "", fmt.Errorf("failed to list search indexes: %w", err)
+		}
+
+		// An index that is not listed yet is treated like one that is not
+		// queryable yet: keep polling instead of reporting success early.
+		ready := false
+		if cursor.Next(ctx) {
+			name := cursor.Current.Lookup("name").StringValue()
+			queryable := cursor.Current.Lookup("queryable").Boolean()
+			ready = name == searchName && queryable
+		}
+
+		err = cursor.Err()
+
+		// Release the cursor on every iteration. A close failure does not
+		// change the readiness already read, so it is ignored.
+		_ = cursor.Close(ctx)
+
+		if err != nil {
+			return "", fmt.Errorf("failed to list search indexes: %w", err)
+		}
+
+		if ready {
+			return searchName, nil
+		}
+
+		// Wait before polling again, but give up as soon as ctx is done.
+		select {
+		case <-ctx.Done():
+			return "", fmt.Errorf("awaiting creation of index %q: %w", searchName, ctx.Err())
+		case <-time.After(pollInterval):
+		}
+	}
+}
+
+// resetForE2E connects to the deployment named by the MONGODB_URI environment
+// variable, creates the test collection, and replaces the search index idx
+// with a dotProduct vector index of dim dimensions on "plot_embedding" plus
+// one filter field per entry in filters. It blocks until the new index is
+// queryable or ctx is done.
+func resetForE2E(ctx context.Context, idx string, dim int, filters []string) error {
+	uri := os.Getenv("MONGODB_URI")
+	if uri == "" {
+		return errors.New("MONGODB_URI required")
+	}
+
+	client, err := mongo.Connect(options.Client().ApplyURI(uri))
+	if err != nil {
+		return fmt.Errorf("failed to connect to server: %w", err)
+	}
+
+	defer func() { _ = client.Disconnect(context.Background()) }()
+
+	// Create the vectorstore collection
+	err = client.Database(testDB).CreateCollection(ctx, testColl)
+	if err != nil {
+		return fmt.Errorf("failed to create vector store collection: %w", err)
+	}
+
+	coll := client.Database(testDB).Collection(testColl)
+
+	// Best effort: the drop error is deliberately ignored so setup continues
+	// to index creation (presumably the drop fails when the index does not
+	// exist yet -- TODO confirm).
+	_ = dropVectorSearchIndex(ctx, coll, idx)
+
+	fields := make([]vectorField, 0, 1+len(filters))
+
+	fields = append(fields, vectorField{
+		Type:          "vector",
+		Path:          "plot_embedding",
+		NumDimensions: dim,
+		Similarity:    "dotProduct",
+	})
+
+	for _, filter := range filters {
+		fields = append(fields, vectorField{
+			Type: "filter",
+			Path: filter,
+		})
+	}
+
+	_, err = createVectorSearchIndex(ctx, coll, idx, fields...)
+	if err != nil {
+		return fmt.Errorf("failed to create index: %w", err)
+	}
+
+	return nil
+}