From 6942d48add2347b149a2ae046138043d18f8eb8c Mon Sep 17 00:00:00 2001 From: Andrew Kimball Date: Thu, 14 Nov 2024 00:04:18 -0800 Subject: [PATCH] vecindex: add recall test command The new "recall" test command computes the percentage of search results that are true nearest results. This is a standard measure of index and search quality. Epic: CRDB-42943 Release note: None --- pkg/sql/vecindex/BUILD.bazel | 1 + pkg/sql/vecindex/testdata/search-features.ddt | 21 ++++ pkg/sql/vecindex/vecstore/in_memory_store.go | 13 ++ .../vecindex/vecstore/in_memory_store_test.go | 10 ++ pkg/sql/vecindex/vector_index_test.go | 118 ++++++++++++++++++ 5 files changed, 163 insertions(+) diff --git a/pkg/sql/vecindex/BUILD.bazel b/pkg/sql/vecindex/BUILD.bazel index c22a5a81d5c6..56d6486a96b0 100644 --- a/pkg/sql/vecindex/BUILD.bazel +++ b/pkg/sql/vecindex/BUILD.bazel @@ -46,6 +46,7 @@ go_test( "//pkg/util/num32", "//pkg/util/vector", "@com_github_cockroachdb_datadriven//:datadriven", + "@com_github_cockroachdb_errors//:errors", "@com_github_stretchr_testify//require", "@org_gonum_v1_gonum//floats/scalar", "@org_gonum_v1_gonum//stat", diff --git a/pkg/sql/vecindex/testdata/search-features.ddt b/pkg/sql/vecindex/testdata/search-features.ddt index b29e706785ec..2b2c2008954a 100644 --- a/pkg/sql/vecindex/testdata/search-features.ddt +++ b/pkg/sql/vecindex/testdata/search-features.ddt @@ -73,3 +73,24 @@ vec220: 0.7957 (centroid=0.4226) vec387: 0.8038 (centroid=0.4652) vec637: 0.8039 (centroid=0.5211) 356 leaf vectors, 567 vectors, 97 full vectors, 103 partitions + +# Test recall at different beam sizes. +recall topk=10 beam-size=8 samples=50 +---- +53.60% recall@10 +46.62 leaf vectors, 86.08 vectors, 20.18 full vectors, 15.00 partitions + +recall topk=10 beam-size=16 samples=50 +---- +76.40% recall@10 +94.02 leaf vectors, 168.58 vectors, 24.84 full vectors, 29.00 partitions + +recall topk=10 beam-size=32 samples=50 +---- +91.80% recall@10 +188.30 leaf vectors, 317.30 vectors, 28.52 full vectors, 55.00 partitions + +recall topk=10 beam-size=64 samples=50 +---- +97.40% recall@10 +371.40 leaf vectors, 585.00 vectors, 31.60 full vectors, 103.00 partitions diff --git a/pkg/sql/vecindex/vecstore/in_memory_store.go b/pkg/sql/vecindex/vecstore/in_memory_store.go index 28956253d65d..cf86eaf2b3f4 100644 --- a/pkg/sql/vecindex/vecstore/in_memory_store.go +++ b/pkg/sql/vecindex/vecstore/in_memory_store.go @@ -352,6 +352,19 @@ func (s *InMemoryStore) DeleteVector(txn Txn, key PrimaryKey) { delete(s.mu.vectors, string(key)) } +// GetAllVectors returns all vectors that have been added to the store as key +// and vector pairs. This is used for testing. +func (s *InMemoryStore) GetAllVectors() []VectorWithKey { + s.mu.Lock() + defer s.mu.Unlock() + + refs := make([]VectorWithKey, 0, len(s.mu.vectors)) + for key, vec := range s.mu.vectors { + refs = append(refs, VectorWithKey{Key: ChildKey{PrimaryKey: PrimaryKey(key)}, Vector: vec}) + } + return refs +} + // MarshalBinary saves the in-memory store as a bytes. This allows the store to // be saved and later loaded without needing to rebuild it from scratch. func (s *InMemoryStore) MarshalBinary() (data []byte, err error) { diff --git a/pkg/sql/vecindex/vecstore/in_memory_store_test.go b/pkg/sql/vecindex/vecstore/in_memory_store_test.go index c4be8e3eb7c6..1edbb20da3cf 100644 --- a/pkg/sql/vecindex/vecstore/in_memory_store_test.go +++ b/pkg/sql/vecindex/vecstore/in_memory_store_test.go @@ -8,6 +8,7 @@ package vecstore import ( "context" "runtime" + "slices" "sync" "testing" @@ -53,6 +54,15 @@ func TestInMemoryStore(t *testing.T) { require.Nil(t, results[1].Vector) require.Equal(t, vec2, results[2].Vector) require.Nil(t, results[3].Vector) + + vectors := store.GetAllVectors() + slices.SortFunc(vectors, func(a, b VectorWithKey) int { + return a.Key.Compare(b.Key) + }) + require.Equal(t, []VectorWithKey{ + {Key: ChildKey{PrimaryKey: PrimaryKey{11}}, Vector: vector.T{100, 200}}, + {Key: ChildKey{PrimaryKey: PrimaryKey{12}}, Vector: vector.T{300, 400}}, + }, vectors) }) t.Run("insert empty root partition into the store", func(t *testing.T) { diff --git a/pkg/sql/vecindex/vector_index_test.go b/pkg/sql/vecindex/vector_index_test.go index 85c4de85b9b5..75dc76383f19 100644 --- a/pkg/sql/vecindex/vector_index_test.go +++ b/pkg/sql/vecindex/vector_index_test.go @@ -7,8 +7,10 @@ package vecindex import ( "bytes" + "cmp" "context" "fmt" + "sort" "strconv" "strings" "testing" @@ -17,8 +19,10 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/vecindex/quantize" "github.com/cockroachdb/cockroach/pkg/sql/vecindex/testutils" "github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecstore" + "github.com/cockroachdb/cockroach/pkg/util/num32" "github.com/cockroachdb/cockroach/pkg/util/vector" "github.com/cockroachdb/datadriven" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/require" ) @@ -46,6 +50,9 @@ func TestDataDriven(t *testing.T) { case "delete": return state.Delete(d) + + case "recall": + return state.Recall(d) } t.Fatalf("unknown cmd: %s", d.Cmd) @@ -289,6 +296,93 @@ func (s *testState) Delete(d *datadriven.TestData) string { return tree } +func (s *testState) Recall(d *datadriven.TestData) string { + searchSet := vecstore.SearchSet{MaxResults: 1} + options := SearchOptions{} + samples := 50 + var err error + for _, arg := range d.CmdArgs { + switch arg.Key { + case "samples": + require.Len(s.T, arg.Vals, 1) + samples, err = strconv.Atoi(arg.Vals[0]) + require.NoError(s.T, err) + + case "topk": + require.Len(s.T, arg.Vals, 1) + searchSet.MaxResults, err = strconv.Atoi(arg.Vals[0]) + require.NoError(s.T, err) + + case "beam-size": + require.Len(s.T, arg.Vals, 1) + options.BaseBeamSize, err = strconv.Atoi(arg.Vals[0]) + require.NoError(s.T, err) + } + } + + txn := beginTransaction(s.Ctx, s.T, s.InMemStore) + defer commitTransaction(s.Ctx, s.T, s.InMemStore, txn) + + // calcTruth calculates the true nearest neighbors for the query vector. + calcTruth := func(queryVector vector.T, data []vecstore.VectorWithKey) []vecstore.PrimaryKey { + distances := make([]float32, len(data)) + offsets := make([]int, len(data)) + for i := 0; i < len(data); i++ { + distances[i] = num32.L2SquaredDistance(queryVector, data[i].Vector) + offsets[i] = i + } + sort.SliceStable(offsets, func(i int, j int) bool { + res := cmp.Compare(distances[offsets[i]], distances[offsets[j]]) + if res != 0 { + return res < 0 + } + return data[offsets[i]].Key.Compare(data[offsets[j]].Key) < 0 + }) + + truth := make([]vecstore.PrimaryKey, searchSet.MaxResults) + for i := 0; i < len(truth); i++ { + truth[i] = data[offsets[i]].Key.PrimaryKey + } + return truth + } + + data := s.InMemStore.GetAllVectors() + + // Search for last "samples" features. + var sumMAP float64 + for feature := s.Features.Count - samples; feature < s.Features.Count; feature++ { + // Calculate truth set for the vector. + queryVector := s.Features.At(feature) + truth := calcTruth(queryVector, data) + + // Calculate prediction set for the vector. + err = s.Index.Search(s.Ctx, txn, queryVector, &searchSet, options) + require.NoError(s.T, err) + results := searchSet.PopResults() + + prediction := make([]vecstore.PrimaryKey, searchSet.MaxResults) + for res := 0; res < len(results); res++ { + prediction[res] = results[res].ChildKey.PrimaryKey + } + + sumMAP += findMAP(prediction, truth) + } + + recall := sumMAP / float64(samples) * 100 + quantizedLeafVectors := float64(searchSet.Stats.QuantizedLeafVectorCount) / float64(samples) + quantizedVectors := float64(searchSet.Stats.QuantizedVectorCount) / float64(samples) + fullVectors := float64(searchSet.Stats.FullVectorCount) / float64(samples) + partitions := float64(searchSet.Stats.PartitionCount) / float64(samples) + + var buf bytes.Buffer + buf.WriteString(fmt.Sprintf("%.2f%% recall@%d\n", recall, searchSet.MaxResults)) + buf.WriteString(fmt.Sprintf("%.2f leaf vectors, ", quantizedLeafVectors)) + buf.WriteString(fmt.Sprintf("%.2f vectors, ", quantizedVectors)) + buf.WriteString(fmt.Sprintf("%.2f full vectors, ", fullVectors)) + buf.WriteString(fmt.Sprintf("%.2f partitions", partitions)) + return buf.String() +} + // parseVector parses a vector string in this form: (1.5, 6, -4). func (s *testState) parseVector(str string) vector.T { // Remove parentheses and split by commas. @@ -328,3 +422,27 @@ func commitTransaction(ctx context.Context, t *testing.T, store vecstore.Store, err := store.CommitTransaction(ctx, txn) require.NoError(t, err) } + +// findMAP returns mean average precision, which compares a set of predicted +// results with the true set of results. Both sets are expected to be of equal +// length. It returns the percentage overlap of the predicted set with the truth +// set. +func findMAP(prediction, truth []vecstore.PrimaryKey) float64 { + if len(prediction) != len(truth) { + panic(errors.AssertionFailedf("prediction and truth sets are not same length")) + } + + predictionMap := make(map[string]bool, len(prediction)) + for _, p := range prediction { + predictionMap[string(p)] = true + } + + var intersect float64 + for _, t := range truth { + _, ok := predictionMap[string(t)] + if ok { + intersect++ + } + } + return intersect / float64(len(truth)) +}