Skip to content

Commit

Permalink
vecindex: add recall test command
Browse files Browse the repository at this point in the history
The new "recall" test command computes the percentage of search results
that are true nearest results. This is a standard measure of index and
search quality.

Epic: CRDB-42943

Release note: None
  • Loading branch information
andy-kimball committed Nov 15, 2024
1 parent 89df825 commit 6942d48
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 0 deletions.
1 change: 1 addition & 0 deletions pkg/sql/vecindex/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ go_test(
"//pkg/util/num32",
"//pkg/util/vector",
"@com_github_cockroachdb_datadriven//:datadriven",
"@com_github_cockroachdb_errors//:errors",
"@com_github_stretchr_testify//require",
"@org_gonum_v1_gonum//floats/scalar",
"@org_gonum_v1_gonum//stat",
Expand Down
21 changes: 21 additions & 0 deletions pkg/sql/vecindex/testdata/search-features.ddt
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,24 @@ vec220: 0.7957 (centroid=0.4226)
vec387: 0.8038 (centroid=0.4652)
vec637: 0.8039 (centroid=0.5211)
356 leaf vectors, 567 vectors, 97 full vectors, 103 partitions

# Test recall at different beam sizes.
recall topk=10 beam-size=8 samples=50
----
53.60% recall@10
46.62 leaf vectors, 86.08 vectors, 20.18 full vectors, 15.00 partitions

recall topk=10 beam-size=16 samples=50
----
76.40% recall@10
94.02 leaf vectors, 168.58 vectors, 24.84 full vectors, 29.00 partitions

recall topk=10 beam-size=32 samples=50
----
91.80% recall@10
188.30 leaf vectors, 317.30 vectors, 28.52 full vectors, 55.00 partitions

recall topk=10 beam-size=64 samples=50
----
97.40% recall@10
371.40 leaf vectors, 585.00 vectors, 31.60 full vectors, 103.00 partitions
13 changes: 13 additions & 0 deletions pkg/sql/vecindex/vecstore/in_memory_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,19 @@ func (s *InMemoryStore) DeleteVector(txn Txn, key PrimaryKey) {
delete(s.mu.vectors, string(key))
}

// GetAllVectors returns all vectors that have been added to the store as key
// and vector pairs. This is used for testing.
func (s *InMemoryStore) GetAllVectors() []VectorWithKey {
s.mu.Lock()
defer s.mu.Unlock()

refs := make([]VectorWithKey, 0, len(s.mu.vectors))
for key, vec := range s.mu.vectors {
refs = append(refs, VectorWithKey{Key: ChildKey{PrimaryKey: PrimaryKey(key)}, Vector: vec})
}
return refs
}

// MarshalBinary saves the in-memory store as a bytes. This allows the store to
// be saved and later loaded without needing to rebuild it from scratch.
func (s *InMemoryStore) MarshalBinary() (data []byte, err error) {
Expand Down
10 changes: 10 additions & 0 deletions pkg/sql/vecindex/vecstore/in_memory_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package vecstore
import (
"context"
"runtime"
"slices"
"sync"
"testing"

Expand Down Expand Up @@ -53,6 +54,15 @@ func TestInMemoryStore(t *testing.T) {
require.Nil(t, results[1].Vector)
require.Equal(t, vec2, results[2].Vector)
require.Nil(t, results[3].Vector)

vectors := store.GetAllVectors()
slices.SortFunc(vectors, func(a, b VectorWithKey) int {
return a.Key.Compare(b.Key)
})
require.Equal(t, []VectorWithKey{
{Key: ChildKey{PrimaryKey: PrimaryKey{11}}, Vector: vector.T{100, 200}},
{Key: ChildKey{PrimaryKey: PrimaryKey{12}}, Vector: vector.T{300, 400}},
}, vectors)
})

t.Run("insert empty root partition into the store", func(t *testing.T) {
Expand Down
118 changes: 118 additions & 0 deletions pkg/sql/vecindex/vector_index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ package vecindex

import (
"bytes"
"cmp"
"context"
"fmt"
"sort"
"strconv"
"strings"
"testing"
Expand All @@ -17,8 +19,10 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/quantize"
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/testutils"
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecstore"
"github.com/cockroachdb/cockroach/pkg/util/num32"
"github.com/cockroachdb/cockroach/pkg/util/vector"
"github.com/cockroachdb/datadriven"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -46,6 +50,9 @@ func TestDataDriven(t *testing.T) {

case "delete":
return state.Delete(d)

case "recall":
return state.Recall(d)
}

t.Fatalf("unknown cmd: %s", d.Cmd)
Expand Down Expand Up @@ -289,6 +296,93 @@ func (s *testState) Delete(d *datadriven.TestData) string {
return tree
}

func (s *testState) Recall(d *datadriven.TestData) string {
searchSet := vecstore.SearchSet{MaxResults: 1}
options := SearchOptions{}
samples := 50
var err error
for _, arg := range d.CmdArgs {
switch arg.Key {
case "samples":
require.Len(s.T, arg.Vals, 1)
samples, err = strconv.Atoi(arg.Vals[0])
require.NoError(s.T, err)

case "topk":
require.Len(s.T, arg.Vals, 1)
searchSet.MaxResults, err = strconv.Atoi(arg.Vals[0])
require.NoError(s.T, err)

case "beam-size":
require.Len(s.T, arg.Vals, 1)
options.BaseBeamSize, err = strconv.Atoi(arg.Vals[0])
require.NoError(s.T, err)
}
}

txn := beginTransaction(s.Ctx, s.T, s.InMemStore)
defer commitTransaction(s.Ctx, s.T, s.InMemStore, txn)

// calcTruth calculates the true nearest neighbors for the query vector.
calcTruth := func(queryVector vector.T, data []vecstore.VectorWithKey) []vecstore.PrimaryKey {
distances := make([]float32, len(data))
offsets := make([]int, len(data))
for i := 0; i < len(data); i++ {
distances[i] = num32.L2SquaredDistance(queryVector, data[i].Vector)
offsets[i] = i
}
sort.SliceStable(offsets, func(i int, j int) bool {
res := cmp.Compare(distances[offsets[i]], distances[offsets[j]])
if res != 0 {
return res < 0
}
return data[offsets[i]].Key.Compare(data[offsets[j]].Key) < 0
})

truth := make([]vecstore.PrimaryKey, searchSet.MaxResults)
for i := 0; i < len(truth); i++ {
truth[i] = data[offsets[i]].Key.PrimaryKey
}
return truth
}

data := s.InMemStore.GetAllVectors()

// Search for last "samples" features.
var sumMAP float64
for feature := s.Features.Count - samples; feature < s.Features.Count; feature++ {
// Calculate truth set for the vector.
queryVector := s.Features.At(feature)
truth := calcTruth(queryVector, data)

// Calculate prediction set for the vector.
err = s.Index.Search(s.Ctx, txn, queryVector, &searchSet, options)
require.NoError(s.T, err)
results := searchSet.PopResults()

prediction := make([]vecstore.PrimaryKey, searchSet.MaxResults)
for res := 0; res < len(results); res++ {
prediction[res] = results[res].ChildKey.PrimaryKey
}

sumMAP += findMAP(prediction, truth)
}

recall := sumMAP / float64(samples) * 100
quantizedLeafVectors := float64(searchSet.Stats.QuantizedLeafVectorCount) / float64(samples)
quantizedVectors := float64(searchSet.Stats.QuantizedVectorCount) / float64(samples)
fullVectors := float64(searchSet.Stats.FullVectorCount) / float64(samples)
partitions := float64(searchSet.Stats.PartitionCount) / float64(samples)

var buf bytes.Buffer
buf.WriteString(fmt.Sprintf("%.2f%% recall@%d\n", recall, searchSet.MaxResults))
buf.WriteString(fmt.Sprintf("%.2f leaf vectors, ", quantizedLeafVectors))
buf.WriteString(fmt.Sprintf("%.2f vectors, ", quantizedVectors))
buf.WriteString(fmt.Sprintf("%.2f full vectors, ", fullVectors))
buf.WriteString(fmt.Sprintf("%.2f partitions", partitions))
return buf.String()
}

// parseVector parses a vector string in this form: (1.5, 6, -4).
func (s *testState) parseVector(str string) vector.T {
// Remove parentheses and split by commas.
Expand Down Expand Up @@ -328,3 +422,27 @@ func commitTransaction(ctx context.Context, t *testing.T, store vecstore.Store,
err := store.CommitTransaction(ctx, txn)
require.NoError(t, err)
}

// findMAP returns mean average precision, which compares a set of predicted
// results with the true set of results. Both sets are expected to be of equal
// length. It returns the percentage overlap of the predicted set with the truth
// set.
func findMAP(prediction, truth []vecstore.PrimaryKey) float64 {
if len(prediction) != len(truth) {
panic(errors.AssertionFailedf("prediction and truth sets are not same length"))
}

predictionMap := make(map[string]bool, len(prediction))
for _, p := range prediction {
predictionMap[string(p)] = true
}

var intersect float64
for _, t := range truth {
_, ok := predictionMap[string(t)]
if ok {
intersect++
}
}
return intersect / float64(len(truth))
}

0 comments on commit 6942d48

Please sign in to comment.