Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vecindex: redistribute vectors across level during split #135506

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 201 additions & 13 deletions pkg/sql/vecindex/fixup_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/internal"
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecstore"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/num32"
"github.com/cockroachdb/cockroach/pkg/util/syncutil"
"github.com/cockroachdb/cockroach/pkg/util/vector"
"github.com/cockroachdb/errors"
Expand Down Expand Up @@ -339,13 +340,35 @@ func (fp *fixupProcessor) splitPartition(
if parentPartition != nil {
// De-link the splitting partition from its parent partition.
childKey := vecstore.ChildKey{PartitionKey: partitionKey}
_, err = fp.index.removeFromPartition(ctx, txn, parentPartitionKey, childKey)
count, err := fp.index.removeFromPartition(ctx, txn, parentPartitionKey, childKey)
if err != nil {
return errors.Wrapf(err, "removing splitting partition %d from its parent %d",
partitionKey, parentPartitionKey)
}

// TODO(andyk): Move vectors to/from split partition.
if count != 0 {
// Move any vectors to sibling partitions that have closer centroids.
var parentVectors vector.Set
err = fp.moveVectorsToSiblings(
ctx, txn, parentPartitionKey, parentPartition, &parentVectors, partitionKey, &leftSplit)
if err != nil {
return err
}
err = fp.moveVectorsToSiblings(
ctx, txn, parentPartitionKey, parentPartition, &parentVectors, partitionKey, &rightSplit)
if err != nil {
return err
}

// Move any vectors at the same level that are closer to the new split
// centroids than they are to their own centroids.
if err = fp.linkNearbyVectors(ctx, txn, partitionKey, leftSplit.Partition); err != nil {
return err
}
if err = fp.linkNearbyVectors(ctx, txn, partitionKey, rightSplit.Partition); err != nil {
return err
}
}
}

// Insert the two new partitions into the index. This only adds their data
Expand Down Expand Up @@ -392,23 +415,19 @@ func (fp *fixupProcessor) splitPartition(
// Link the two new partitions into the K-means tree by inserting them
// into the parent level. This can trigger a further split, this time of
// the parent level.
fp.searchCtx = searchContext{
Ctx: ctx,
Workspace: fp.workspace,
Txn: txn,
Level: parentPartition.Level() + 1,
}
searchCtx := fp.reuseSearchContext(ctx, txn)
searchCtx.Level = parentPartition.Level() + 1

fp.searchCtx.Randomized = leftSplit.Partition.Centroid()
searchCtx.Randomized = leftSplit.Partition.Centroid()
childKey := vecstore.ChildKey{PartitionKey: leftPartitionKey}
err = fp.index.insertHelper(&fp.searchCtx, childKey, true /* allowRetry */)
err = fp.index.insertHelper(searchCtx, childKey, true /* allowRetry */)
if err != nil {
return errors.Wrapf(err, "inserting left partition for split of partition %d", partitionKey)
}

fp.searchCtx.Randomized = rightSplit.Partition.Centroid()
searchCtx.Randomized = rightSplit.Partition.Centroid()
childKey = vecstore.ChildKey{PartitionKey: rightPartitionKey}
err = fp.index.insertHelper(&fp.searchCtx, childKey, true /* allowRetry */)
err = fp.index.insertHelper(searchCtx, childKey, true /* allowRetry */)
if err != nil {
return errors.Wrapf(err, "inserting right partition for split of partition %d", partitionKey)
}
Expand Down Expand Up @@ -461,7 +480,8 @@ func (fp *fixupProcessor) splitPartitionData(

right := int(rightOffsets[ri])
if right >= len(leftOffsets) {
panic("expected equal number of left and right offsets that need to be swapped")
panic(errors.AssertionFailedf(
"expected equal number of left and right offsets that need to be swapped"))
}

// Swap vectors.
Expand Down Expand Up @@ -496,6 +516,160 @@ func (fp *fixupProcessor) splitPartitionData(
return leftSplit, rightSplit
}

// moveVectorsToSiblings walks the vectors of one half of a split and relocates
// any vector that ended up further from its new centroid than it was from the
// old one, provided some child of the parent partition has a centroid that is
// closer still. Relocated vectors are added to that sibling partition and
// dropped from the split.
//
// parentVectors is a lazily-populated cache: while its Dims field is zero, the
// first vector that needs it triggers a getFullVectorsForPartition call and the
// result is stored through the pointer, so a later call with the same pointer
// can reuse it. oldPartitionKey is only used for trace logging.
//
// Returns the first error from fetching the parent's vectors or from adding a
// vector to a sibling partition; otherwise nil.
func (fp *fixupProcessor) moveVectorsToSiblings(
	ctx context.Context,
	txn vecstore.Txn,
	parentPartitionKey vecstore.PartitionKey,
	parentPartition *vecstore.Partition,
	parentVectors *vector.Set,
	oldPartitionKey vecstore.PartitionKey,
	split *splitData,
) error {
	// The index is advanced explicitly: whenever a vector is removed, the slot
	// at i is backfilled with the last vector, so the same slot must be
	// examined again on the next pass.
	i := 0
	for i < split.Vectors.Count {
		// Never move the final remaining vector, since that would leave the
		// split partition empty.
		if split.Vectors.Count == 1 {
			break
		}

		vec := split.Vectors.At(i)

		// Only vectors that got further from their centroid as a result of
		// the split are candidates for relocation.
		closest := split.Partition.QuantizedSet().GetCentroidDistances()[i]
		if closest <= split.OldCentroidDistances[i] {
			i++
			continue
		}

		// Lazily load the full vectors of the parent partition's children.
		if parentVectors.Dims == 0 {
			fetched, err := fp.getFullVectorsForPartition(
				ctx, txn, parentPartitionKey, parentPartition)
			if err != nil {
				return err
			}
			*parentVectors = fetched
		}

		// Look for the sibling centroid that beats the vector's own new
		// centroid by the largest margin.
		// NOTE(review): the num32.L2Distance result is compared directly
		// against the stored centroid distance, so both are presumed to use
		// the same (non-squared) measure - confirm.
		siblingOffset := -1
		for offset := 0; offset < parentVectors.Count; offset++ {
			if dist := num32.L2Distance(parentVectors.At(offset), vec); dist < closest {
				closest = dist
				siblingOffset = offset
			}
		}
		if siblingOffset < 0 {
			// No sibling is closer; the vector stays where it is.
			i++
			continue
		}

		siblingPartitionKey := parentPartition.ChildKeys()[siblingOffset].PartitionKey
		log.VEventf(ctx, 3, "moving vector from splitting partition %d to sibling partition %d",
			oldPartitionKey, siblingPartitionKey)

		// Insert the vector into the closer sibling partition.
		movedKey := split.Partition.ChildKeys()[i]
		if _, err := fp.index.addToPartition(
			ctx, txn, parentPartitionKey, siblingPartitionKey, vec, movedKey,
		); err != nil {
			return errors.Wrapf(err, "moving vector to partition %d", siblingPartitionKey)
		}

		// Drop the vector from the split. Slot i now holds what used to be the
		// last vector, so i is deliberately left unchanged.
		split.ReplaceWithLast(i)
	}

	return nil
}

// linkNearbyVectors searches for vectors at the same level that are close to
// the given split partition's centroid. If they are closer than they are to
// their own centroid, then move them to the split partition.
//
// At most MaxPartitionSize - partition.Count() vectors are considered, so that
// linking cannot grow the split partition past the maximum size; if no slots
// remain, the method returns without searching. oldPartitionKey is the key of
// the partition being split and is only used for trace logging.
//
// Returns the first error from the search, from removing a vector from its
// current partition, or from restoring a vector to that partition.
func (fp *fixupProcessor) linkNearbyVectors(
	ctx context.Context,
	txn vecstore.Txn,
	oldPartitionKey vecstore.PartitionKey,
	partition *vecstore.Partition,
) error {
	// TODO(andyk): Add way to filter search set in order to skip vectors deeper
	// down in the search rather than afterwards.
	searchCtx := fp.reuseSearchContext(ctx, txn)
	searchCtx.Options = SearchOptions{ReturnVectors: true}
	searchCtx.Level = partition.Level()
	searchCtx.Randomized = partition.Centroid()

	// Don't link more vectors than the number of remaining slots in the split
	// partition, to avoid triggering another split.
	maxResults := fp.index.options.MaxPartitionSize - partition.Count()
	if maxResults < 1 {
		return nil
	}
	searchSet := vecstore.SearchSet{MaxResults: maxResults}
	err := fp.index.searchHelper(searchCtx, &searchSet, true /* allowRetry */)
	if err != nil {
		return err
	}

	// Scratch buffer for the randomized form of leaf vectors. It is reused on
	// every iteration and returned to the workspace on exit.
	tempVector := fp.workspace.AllocVector(fp.index.quantizer.GetRandomDims())
	defer fp.workspace.FreeVector(tempVector)

	// Filter the results.
	results := searchSet.PopResults()
	for i := range results {
		result := &results[i]

		// Skip vectors that are closer to their own centroid than they are to
		// the split partition's centroid. The centroid distance is squared so
		// that it is comparable with the squared query distance.
		if result.QuerySquaredDistance >= result.CentroidDistance*result.CentroidDistance {
			continue
		}

		// Log the partition that currently holds the vector, which is the
		// partition it is removed from below. The child key is not suitable
		// for this: at the leaf level it presumably identifies a primary-index
		// vector rather than a partition.
		log.VEventf(ctx, 3, "linking vector from partition %d to splitting partition %d",
			result.ParentPartitionKey, oldPartitionKey)

		// Leaf vectors from the primary index need to be randomized.
		vec := result.Vector
		if partition.Level() == vecstore.LeafLevel {
			fp.index.quantizer.RandomizeVector(ctx, vec, tempVector, false /* invert */)
			vec = tempVector
		}

		// Remove the vector from the other partition.
		count, err := fp.index.removeFromPartition(ctx, txn, result.ParentPartitionKey, result.ChildKey)
		if err != nil {
			return err
		}
		if count == 0 && partition.Level() > vecstore.LeafLevel {
			// Removing the vector will result in an empty non-leaf partition, which
			// is not allowed, as the K-means tree would not be fully balanced. Add
			// the vector back to the partition. This is a very rare case and that
			// partition is likely to be merged away regardless.
			_, err = fp.index.store.AddToPartition(
				ctx, txn, result.ParentPartitionKey, vec, result.ChildKey)
			if err != nil {
				return err
			}
			continue
		}

		// Add the vector to the split partition.
		partition.Add(ctx, vec, result.ChildKey)
	}

	return nil
}

// getFullVectorsForPartition fetches the full-size vectors (potentially
// randomized by the quantizer) that are quantized by the given partition.
func (fp *fixupProcessor) getFullVectorsForPartition(
Expand Down Expand Up @@ -543,3 +717,17 @@ func (fp *fixupProcessor) getFullVectorsForPartition(

return vectors, nil
}

// reuseSearchContext resets the processor's shared search context for a new
// operation and returns a pointer to it. Ctx, Workspace and Txn are set from
// the arguments and the processor's workspace; the tempKeys, tempCounts and
// tempVectorsWithKeys slices are carried over from the previous use. Every
// other field is reset to its zero value, so callers must set fields such as
// Level, Options and Randomized themselves.
func (fp *fixupProcessor) reuseSearchContext(ctx context.Context, txn vecstore.Txn) *searchContext {
	sc := &fp.searchCtx

	// Save the scratch slices before the struct is cleared.
	keys := sc.tempKeys
	counts := sc.tempCounts
	vectorsWithKeys := sc.tempVectorsWithKeys

	*sc = searchContext{}
	sc.Ctx = ctx
	sc.Workspace = fp.workspace
	sc.Txn = txn
	sc.tempKeys = keys
	sc.tempCounts = counts
	sc.tempVectorsWithKeys = vectorsWithKeys
	return sc
}
Loading
Loading