Skip to content

Commit

Permalink
Support getting the N closest vectors in a ProximityMap (Adds depende…
Browse files Browse the repository at this point in the history
…ncy on github.com/esote/minmaxheap)
  • Loading branch information
nicktobey committed Oct 28, 2024
1 parent d539f18 commit 0af74b8
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 22 deletions.
1 change: 1 addition & 0 deletions go/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ require (
github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662 // indirect
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 // indirect
github.com/dolthub/maphash v0.0.0-20221220182448-74e1e1ea1577 // indirect
github.com/esote/minmaxheap v1.0.0 // indirect
github.com/go-fonts/liberation v0.2.0 // indirect
github.com/go-kit/kit v0.10.0 // indirect
github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m
github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/esote/minmaxheap v1.0.0 h1:rgA7StnXXpZG6qlM0S7pUmEv1KpWe32rYT4x8J8ntaA=
github.com/esote/minmaxheap v1.0.0/go.mod h1:Ln8+i7fS1k3PLgZI2JAo0iA1as95QnIYiGCrqSJ5FZk=
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
Expand Down
35 changes: 35 additions & 0 deletions go/store/prolly/proximity_map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package prolly

import (
"context"
"fmt"
"os"
"testing"

Expand Down Expand Up @@ -154,6 +155,40 @@ func TestDoubleEntryProximityMapGetClosest(t *testing.T) {
require.Equal(t, matches, 1)
}

func TestProximityMapGetManyClosest(t *testing.T) {
ctx := context.Background()
ns := tree.NewTestNodeStore()
vectors := []interface{}{
"[0.0, 0.0]",
"[0.0, 10.0]",
"[10.0, 10.0]",
"[10.0, 0.0]",
}
queryVector := "[3.0, 1.0]"
sortOrder := []int{0, 3, 1, 2} // indexes in sorted order: [0.0, 0.0], [10.0, 0.0], [0.0, 10.0], [10.0, 10.0]
distances := []float64{10.0, 50.0, 90.0, 130.0}
m, keys, values := createProximityMap(t, ctx, ns, vectors, []int64{1, 2, 3, 4}, 2)

for limit := 0; limit <= 4; limit++ {
t.Run(fmt.Sprintf("limit %d", limit), func(t *testing.T) {
matches := 0

cb := func(foundKey val.Tuple, foundValue val.Tuple, distance float64) error {
require.Equal(t, val.Tuple(keys[sortOrder[matches]]), foundKey)
require.Equal(t, val.Tuple(values[sortOrder[matches]]), foundValue)
require.InDelta(t, distance, distances[matches], 0.1)
matches++
return nil
}

err := m.GetClosest(ctx, newJsonValue(t, queryVector), cb, limit)
require.NoError(t, err)
require.Equal(t, matches, limit)
})
}

}

func TestMultilevelProximityMap(t *testing.T) {
ctx := context.Background()
ns := tree.NewTestNodeStore()
Expand Down
111 changes: 89 additions & 22 deletions go/store/prolly/tree/proximity_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
package tree

import (
"container/heap"
"context"
"fmt"
"math"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"sort"

"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/prolly/message"
"github.com/dolthub/go-mysql-server/sql"
"github.com/esote/minmaxheap"
)

type KeyValueDistanceFn[K, V ~[]byte] func(key K, value V, distance float64) error
Expand Down Expand Up @@ -105,47 +107,112 @@ func (t ProximityMap[K, V, O]) Has(ctx context.Context, query K) (ok bool, err e
return ok, err
}

type DistancePriorityHeapElem struct {
key Item
value Item
distance float64
}

type DistancePriorityHeap []DistancePriorityHeapElem

var _ heap.Interface = (*DistancePriorityHeap)(nil)

func newNodePriorityHeap(capacity int) DistancePriorityHeap {
// Allocate one extra slot: whenever this fills we remove the max element.
return make(DistancePriorityHeap, 0, capacity+1)
}

func (n DistancePriorityHeap) Len() int {
return len(n)
}

func (n DistancePriorityHeap) Less(i, j int) bool {
return n[i].distance < n[j].distance
}

func (n DistancePriorityHeap) Swap(i, j int) {
n[i], n[j] = n[j], n[i]
}

func (n *DistancePriorityHeap) Push(x any) {
*n = append(*n, x.(DistancePriorityHeapElem))
}

func (n *DistancePriorityHeap) Pop() any {
length := len(*n)
last := (*n)[length-1]
*n = (*n)[:length-1]
return last
}

func (n *DistancePriorityHeap) Insert(key Item, value Item, distance float64) {
minmaxheap.Push(n, DistancePriorityHeapElem{
key: key,
value: value,
distance: distance,
})
if len(*n) == cap(*n) {
minmaxheap.PopMax(n)
}
}

// GetClosest performs an approximate nearest neighbors search. It finds |limit| vectors that are close to the query vector,
// and calls |cb| with the matching key-value pairs.
func (t ProximityMap[K, V, O]) GetClosest(ctx context.Context, query interface{}, cb KeyValueDistanceFn[K, V], limit int) (err error) {
if limit != 1 {
return fmt.Errorf("currently only limit = 1 (find single closest vector) is supported for ProximityMap")
if limit == 0 {
return nil
}

queryVector, err := sql.ConvertToVector(query)
if err != nil {
return err
}

nd := t.Root
// |nodes| holds the current candidates for closest vectors, up to |limit|
nodes := newNodePriorityHeap(limit)

var closestKey K
var closestIdx int
distance := math.Inf(1)
for i := 0; i < int(t.Root.count); i++ {
k := t.Root.GetKey(i)
newDistance, err := t.DistanceType.Eval(t.Convert(k), queryVector)
if err != nil {
return err
}
nodes.Insert(k, t.Root.GetValue(i), newDistance)
}

for {
for i := 0; i < int(nd.count); i++ {
k := nd.GetKey(i)
newDistance, err := t.DistanceType.Eval(t.Convert(k), queryVector)
for level := t.Root.Level() - 1; level >= 0; level-- {
// visit each candidate node at the current level, building a priority list of candidates for the next level.
nextLevelNodes := newNodePriorityHeap(limit)

for _, keyAndDistance := range nodes {
address := keyAndDistance.value

node, err := fetchChild(ctx, t.NodeStore, hash.New(address))
if err != nil {
return err
}
if newDistance < distance {
closestIdx = i
distance = newDistance
closestKey = []byte(k)
nextLevelNodes.Insert(keyAndDistance.key, node.GetValue(0), keyAndDistance.distance)
for i := 1; i < int(node.count); i++ {
k := node.GetKey(i)
newDistance, err := t.DistanceType.Eval(t.Convert(k), queryVector)
if err != nil {
return err
}
nextLevelNodes.Insert(k, node.GetValue(i), newDistance)
}
}
nodes = nextLevelNodes
}

if nd.IsLeaf() {
return cb(closestKey, []byte(nd.GetValue(closestIdx)), distance)
}

nd, err = fetchChild(ctx, t.NodeStore, nd.getAddress(closestIdx))
for nodes.Len() > 0 {
node := minmaxheap.Pop(&nodes).(DistancePriorityHeapElem)
err := cb([]byte(node.key), []byte(node.value), node.distance)
if err != nil {
return err
}
}

return nil
}

func (t ProximityMap[K, V, O]) IterAll(ctx context.Context) (*OrderedTreeIter[K, V], error) {
Expand Down

0 comments on commit 0af74b8

Please sign in to comment.