Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vecindex: add recall test command #135233

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/sql/vecindex/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ go_test(
"//pkg/util/num32",
"//pkg/util/vector",
"@com_github_cockroachdb_datadriven//:datadriven",
"@com_github_cockroachdb_errors//:errors",
"@com_github_stretchr_testify//require",
"@org_gonum_v1_gonum//floats/scalar",
"@org_gonum_v1_gonum//stat",
Expand Down
21 changes: 21 additions & 0 deletions pkg/sql/vecindex/testdata/search-features.ddt
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,24 @@ vec220: 0.7957 (centroid=0.4226)
vec387: 0.8038 (centroid=0.4652)
vec637: 0.8039 (centroid=0.5211)
356 leaf vectors, 567 vectors, 97 full vectors, 103 partitions

# Test recall at different beam sizes.
recall topk=10 beam-size=8 samples=50
----
53.60% recall@10
46.62 leaf vectors, 86.08 vectors, 20.18 full vectors, 15.00 partitions

recall topk=10 beam-size=16 samples=50
----
76.40% recall@10
94.02 leaf vectors, 168.58 vectors, 24.84 full vectors, 29.00 partitions

recall topk=10 beam-size=32 samples=50
----
91.80% recall@10
188.30 leaf vectors, 317.30 vectors, 28.52 full vectors, 55.00 partitions

recall topk=10 beam-size=64 samples=50
----
97.40% recall@10
371.40 leaf vectors, 585.00 vectors, 31.60 full vectors, 103.00 partitions
13 changes: 13 additions & 0 deletions pkg/sql/vecindex/vecstore/in_memory_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,19 @@ func (s *InMemoryStore) DeleteVector(txn Txn, key PrimaryKey) {
delete(s.mu.vectors, string(key))
}

// GetAllVectors returns all vectors that have been added to the store as key
// and vector pairs. This is used for testing.
func (s *InMemoryStore) GetAllVectors() []VectorWithKey {
s.mu.Lock()
defer s.mu.Unlock()

refs := make([]VectorWithKey, 0, len(s.mu.vectors))
for key, vec := range s.mu.vectors {
refs = append(refs, VectorWithKey{Key: ChildKey{PrimaryKey: PrimaryKey(key)}, Vector: vec})
}
return refs
}

// MarshalBinary saves the in-memory store as a bytes. This allows the store to
// be saved and later loaded without needing to rebuild it from scratch.
func (s *InMemoryStore) MarshalBinary() (data []byte, err error) {
Expand Down
10 changes: 10 additions & 0 deletions pkg/sql/vecindex/vecstore/in_memory_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package vecstore
import (
"context"
"runtime"
"slices"
"sync"
"testing"

Expand Down Expand Up @@ -53,6 +54,15 @@ func TestInMemoryStore(t *testing.T) {
require.Nil(t, results[1].Vector)
require.Equal(t, vec2, results[2].Vector)
require.Nil(t, results[3].Vector)

vectors := store.GetAllVectors()
slices.SortFunc(vectors, func(a, b VectorWithKey) int {
return a.Key.Compare(b.Key)
})
require.Equal(t, []VectorWithKey{
{Key: ChildKey{PrimaryKey: PrimaryKey{11}}, Vector: vector.T{100, 200}},
{Key: ChildKey{PrimaryKey: PrimaryKey{12}}, Vector: vector.T{300, 400}},
}, vectors)
})

t.Run("insert empty root partition into the store", func(t *testing.T) {
Expand Down
118 changes: 118 additions & 0 deletions pkg/sql/vecindex/vector_index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ package vecindex

import (
"bytes"
"cmp"
"context"
"fmt"
"sort"
"strconv"
"strings"
"testing"
Expand All @@ -17,8 +19,10 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/quantize"
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/testutils"
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecstore"
"github.com/cockroachdb/cockroach/pkg/util/num32"
"github.com/cockroachdb/cockroach/pkg/util/vector"
"github.com/cockroachdb/datadriven"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -46,6 +50,9 @@ func TestDataDriven(t *testing.T) {

case "delete":
return state.Delete(d)

case "recall":
return state.Recall(d)
}

t.Fatalf("unknown cmd: %s", d.Cmd)
Expand Down Expand Up @@ -289,6 +296,93 @@ func (s *testState) Delete(d *datadriven.TestData) string {
return tree
}

func (s *testState) Recall(d *datadriven.TestData) string {
searchSet := vecstore.SearchSet{MaxResults: 1}
options := SearchOptions{}
samples := 50
var err error
for _, arg := range d.CmdArgs {
switch arg.Key {
case "samples":
require.Len(s.T, arg.Vals, 1)
samples, err = strconv.Atoi(arg.Vals[0])
require.NoError(s.T, err)

case "topk":
require.Len(s.T, arg.Vals, 1)
searchSet.MaxResults, err = strconv.Atoi(arg.Vals[0])
require.NoError(s.T, err)

case "beam-size":
require.Len(s.T, arg.Vals, 1)
options.BaseBeamSize, err = strconv.Atoi(arg.Vals[0])
require.NoError(s.T, err)
}
}

txn := beginTransaction(s.Ctx, s.T, s.InMemStore)
defer commitTransaction(s.Ctx, s.T, s.InMemStore, txn)

// calcTruth calculates the true nearest neighbors for the query vector.
calcTruth := func(queryVector vector.T, data []vecstore.VectorWithKey) []vecstore.PrimaryKey {
distances := make([]float32, len(data))
offsets := make([]int, len(data))
for i := 0; i < len(data); i++ {
distances[i] = num32.L2SquaredDistance(queryVector, data[i].Vector)
offsets[i] = i
}
sort.SliceStable(offsets, func(i int, j int) bool {
res := cmp.Compare(distances[offsets[i]], distances[offsets[j]])
if res != 0 {
return res < 0
}
return data[offsets[i]].Key.Compare(data[offsets[j]].Key) < 0
})

truth := make([]vecstore.PrimaryKey, searchSet.MaxResults)
for i := 0; i < len(truth); i++ {
truth[i] = data[offsets[i]].Key.PrimaryKey
}
return truth
}

data := s.InMemStore.GetAllVectors()

// Search for last "samples" features.
var sumMAP float64
for feature := s.Features.Count - samples; feature < s.Features.Count; feature++ {
// Calculate truth set for the vector.
queryVector := s.Features.At(feature)
truth := calcTruth(queryVector, data)

// Calculate prediction set for the vector.
err = s.Index.Search(s.Ctx, txn, queryVector, &searchSet, options)
require.NoError(s.T, err)
results := searchSet.PopResults()

prediction := make([]vecstore.PrimaryKey, searchSet.MaxResults)
for res := 0; res < len(results); res++ {
prediction[res] = results[res].ChildKey.PrimaryKey
}

sumMAP += findMAP(prediction, truth)
}

recall := sumMAP / float64(samples) * 100
quantizedLeafVectors := float64(searchSet.Stats.QuantizedLeafVectorCount) / float64(samples)
quantizedVectors := float64(searchSet.Stats.QuantizedVectorCount) / float64(samples)
fullVectors := float64(searchSet.Stats.FullVectorCount) / float64(samples)
partitions := float64(searchSet.Stats.PartitionCount) / float64(samples)

var buf bytes.Buffer
buf.WriteString(fmt.Sprintf("%.2f%% recall@%d\n", recall, searchSet.MaxResults))
buf.WriteString(fmt.Sprintf("%.2f leaf vectors, ", quantizedLeafVectors))
buf.WriteString(fmt.Sprintf("%.2f vectors, ", quantizedVectors))
buf.WriteString(fmt.Sprintf("%.2f full vectors, ", fullVectors))
buf.WriteString(fmt.Sprintf("%.2f partitions", partitions))
return buf.String()
}

// parseVector parses a vector string in this form: (1.5, 6, -4).
func (s *testState) parseVector(str string) vector.T {
// Remove parentheses and split by commas.
Expand Down Expand Up @@ -328,3 +422,27 @@ func commitTransaction(ctx context.Context, t *testing.T, store vecstore.Store,
err := store.CommitTransaction(ctx, txn)
require.NoError(t, err)
}

// findMAP returns mean average precision, which compares a set of predicted
// results with the true set of results. Both sets are expected to be of equal
// length. It returns the percentage overlap of the predicted set with the truth
// set.
func findMAP(prediction, truth []vecstore.PrimaryKey) float64 {
if len(prediction) != len(truth) {
panic(errors.AssertionFailedf("prediction and truth sets are not same length"))
}

predictionMap := make(map[string]bool, len(prediction))
for _, p := range prediction {
predictionMap[string(p)] = true
}

var intersect float64
for _, t := range truth {
_, ok := predictionMap[string(t)]
if ok {
intersect++
}
}
return intersect / float64(len(truth))
}