diff --git a/go/go.mod b/go/go.mod index 48e8433d66b..8d558149955 100644 --- a/go/go.mod +++ b/go/go.mod @@ -111,6 +111,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662 // indirect github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 // indirect github.com/dolthub/maphash v0.0.0-20221220182448-74e1e1ea1577 // indirect + github.com/esote/minmaxheap v1.0.0 // indirect github.com/go-fonts/liberation v0.2.0 // indirect github.com/go-kit/kit v0.10.0 // indirect github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81 // indirect diff --git a/go/go.sum b/go/go.sum index 22f0b10fa16..0c250dbb11c 100644 --- a/go/go.sum +++ b/go/go.sum @@ -213,6 +213,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/esote/minmaxheap v1.0.0 h1:rgA7StnXXpZG6qlM0S7pUmEv1KpWe32rYT4x8J8ntaA= +github.com/esote/minmaxheap v1.0.0/go.mod h1:Ln8+i7fS1k3PLgZI2JAo0iA1as95QnIYiGCrqSJ5FZk= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= diff --git a/go/store/prolly/proximity_map_test.go b/go/store/prolly/proximity_map_test.go index 014ee8a5f9c..c694cdf68cd 100644 --- a/go/store/prolly/proximity_map_test.go +++ b/go/store/prolly/proximity_map_test.go @@ -16,6 +16,7 @@ package prolly import ( "context" + "fmt" "os" "testing" @@ -154,6 +155,40 @@ func TestDoubleEntryProximityMapGetClosest(t *testing.T) { require.Equal(t, matches, 1) } +func TestProximityMapGetManyClosest(t *testing.T) { + ctx := context.Background() + ns := tree.NewTestNodeStore() + vectors := []interface{}{ + "[0.0, 0.0]", + "[0.0, 10.0]", + "[10.0, 10.0]", + "[10.0, 0.0]", + } + queryVector := "[3.0, 1.0]" + sortOrder := []int{0, 3, 1, 2} // indexes in sorted order: [0.0, 0.0], [10.0, 0.0], [0.0, 10.0], [10.0, 10.0] + distances := []float64{10.0, 50.0, 90.0, 130.0} + m, keys, values := createProximityMap(t, ctx, ns, vectors, []int64{1, 2, 3, 4}, 2) + + for limit := 0; limit <= 4; limit++ { + t.Run(fmt.Sprintf("limit %d", limit), func(t *testing.T) { + matches := 0 + + cb := func(foundKey val.Tuple, foundValue val.Tuple, distance float64) error { + require.Equal(t, val.Tuple(keys[sortOrder[matches]]), foundKey) + require.Equal(t, val.Tuple(values[sortOrder[matches]]), foundValue) + require.InDelta(t, distance, distances[matches], 0.1) + matches++ + return nil + } + + err := m.GetClosest(ctx, newJsonValue(t, queryVector), cb, limit) + require.NoError(t, err) + require.Equal(t, matches, limit) + }) + } + +} + func TestMultilevelProximityMap(t *testing.T) { ctx := context.Background() ns := tree.NewTestNodeStore() diff --git a/go/store/prolly/tree/proximity_map.go b/go/store/prolly/tree/proximity_map.go index f3b6f75150e..a61a0a825f5 100644 --- a/go/store/prolly/tree/proximity_map.go +++ b/go/store/prolly/tree/proximity_map.go @@ -15,14 +15,16 @@ package tree import ( + "container/heap" "context" "fmt" "math" - - "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/expression" + "sort" "github.com/dolthub/dolt/go/store/hash" + "github.com/dolthub/dolt/go/store/prolly/message" + "github.com/dolthub/go-mysql-server/sql" + "github.com/esote/minmaxheap" ) type KeyValueDistanceFn[K, V ~[]byte] func(key K, value V, distance float64) error @@ -105,11 +107,60 @@ func (t ProximityMap[K, V, O]) Has(ctx context.Context, query K) (ok bool, err e return ok, err } +type DistancePriorityHeapElem struct { + key Item + value Item + distance float64 +} + +type DistancePriorityHeap []DistancePriorityHeapElem + +var _ heap.Interface = (*DistancePriorityHeap)(nil) + +func newNodePriorityHeap(capacity int) DistancePriorityHeap { + // Allocate one extra slot: whenever this fills we remove the max element. + return make(DistancePriorityHeap, 0, capacity+1) +} + +func (n DistancePriorityHeap) Len() int { + return len(n) +} + +func (n DistancePriorityHeap) Less(i, j int) bool { + return n[i].distance < n[j].distance +} + +func (n DistancePriorityHeap) Swap(i, j int) { + n[i], n[j] = n[j], n[i] +} + +func (n *DistancePriorityHeap) Push(x any) { + *n = append(*n, x.(DistancePriorityHeapElem)) +} + +func (n *DistancePriorityHeap) Pop() any { + length := len(*n) + last := (*n)[length-1] + *n = (*n)[:length-1] + return last +} + +func (n *DistancePriorityHeap) Insert(key Item, value Item, distance float64) { + minmaxheap.Push(n, DistancePriorityHeapElem{ + key: key, + value: value, + distance: distance, + }) + if len(*n) == cap(*n) { + minmaxheap.PopMax(n) + } +} + // GetClosest performs an approximate nearest neighbors search. It finds |limit| vectors that are close to the query vector, // and calls |cb| with the matching key-value pairs. func (t ProximityMap[K, V, O]) GetClosest(ctx context.Context, query interface{}, cb KeyValueDistanceFn[K, V], limit int) (err error) { - if limit != 1 { - return fmt.Errorf("currently only limit = 1 (find single closest vector) is supported for ProximityMap") + if limit == 0 { + return nil } queryVector, err := sql.ConvertToVector(query) @@ -117,35 +168,51 @@ func (t ProximityMap[K, V, O]) GetClosest(ctx context.Context, query interface{} return err } - nd := t.Root + // |nodes| holds the current candidates for closest vectors, up to |limit| + nodes := newNodePriorityHeap(limit) - var closestKey K - var closestIdx int - distance := math.Inf(1) + for i := 0; i < int(t.Root.count); i++ { + k := t.Root.GetKey(i) + newDistance, err := t.DistanceType.Eval(t.Convert(k), queryVector) + if err != nil { + return err + } + nodes.Insert(k, t.Root.GetValue(i), newDistance) + } - for { - for i := 0; i < int(nd.count); i++ { - k := nd.GetKey(i) - newDistance, err := t.DistanceType.Eval(t.Convert(k), queryVector) + for level := t.Root.Level() - 1; level >= 0; level-- { + // visit each candidate node at the current level, building a priority list of candidates for the next level. + nextLevelNodes := newNodePriorityHeap(limit) + + for _, keyAndDistance := range nodes { + address := keyAndDistance.value + + node, err := fetchChild(ctx, t.NodeStore, hash.New(address)) if err != nil { return err } - if newDistance < distance { - closestIdx = i - distance = newDistance - closestKey = []byte(k) + nextLevelNodes.Insert(keyAndDistance.key, node.GetValue(0), keyAndDistance.distance) + for i := 1; i < int(node.count); i++ { + k := node.GetKey(i) + newDistance, err := t.DistanceType.Eval(t.Convert(k), queryVector) + if err != nil { + return err + } + nextLevelNodes.Insert(k, node.GetValue(i), newDistance) } } + nodes = nextLevelNodes + } - if nd.IsLeaf() { - return cb(closestKey, []byte(nd.GetValue(closestIdx)), distance) - } - - nd, err = fetchChild(ctx, t.NodeStore, nd.getAddress(closestIdx)) + for nodes.Len() > 0 { + node := minmaxheap.Pop(&nodes).(DistancePriorityHeapElem) + err := cb([]byte(node.key), []byte(node.value), node.distance) if err != nil { return err } } + + return nil } func (t ProximityMap[K, V, O]) IterAll(ctx context.Context) (*OrderedTreeIter[K, V], error) {