diff --git a/go/store/prolly/proximity_map.go b/go/store/prolly/proximity_map.go new file mode 100644 index 00000000000..6683cc93273 --- /dev/null +++ b/go/store/prolly/proximity_map.go @@ -0,0 +1,472 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prolly + +import ( + "context" + "github.com/dolthub/dolt/go/store/hash" + "github.com/dolthub/dolt/go/store/prolly/message" + "github.com/dolthub/dolt/go/store/prolly/tree" + "github.com/dolthub/dolt/go/store/val" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "io" +) + +// ProximityMap wraps a tree.ProximityMap but operates on typed Tuples instead of raw bytestrings. +type ProximityMap struct { + tuples tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc] + keyDesc val.TupleDesc + valDesc val.TupleDesc +} + +// NewProximityMap creates a new ProximityMap from a supplied root node. 
+func NewProximityMap(ctx context.Context, ns tree.NodeStore, node tree.Node, keyDesc val.TupleDesc, valDesc val.TupleDesc, distanceType expression.DistanceType) ProximityMap { + tuples := tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]{ + Root: node, + NodeStore: ns, + Order: keyDesc, + DistanceType: distanceType, + Convert: func(bytes []byte) []float64 { + h, _ := keyDesc.GetJSONAddr(0, bytes) + doc := tree.NewJSONDoc(h, ns) + jsonWrapper, err := doc.ToIndexedJSONDocument(ctx) + if err != nil { + panic(err) + } + floats, err := sql.ConvertToVector(jsonWrapper) + if err != nil { + panic(err) + } + return floats + }, + } + return ProximityMap{ + tuples: tuples, + keyDesc: keyDesc, + valDesc: valDesc, + } +} + +func getJsonValueFromHash(ctx context.Context, ns tree.NodeStore, h hash.Hash) (interface{}, error) { + return tree.NewJSONDoc(h, ns).ToIndexedJSONDocument(ctx) +} + +func getVectorFromHash(ctx context.Context, ns tree.NodeStore, h hash.Hash) ([]float64, error) { + otherValue, err := getJsonValueFromHash(ctx, ns, h) + if err != nil { + return nil, err + } + return sql.ConvertToVector(otherValue) +} + +func NewProximityMapFromTupleIter(ctx context.Context, ns tree.NodeStore, distanceType expression.DistanceType, keyDesc val.TupleDesc, valDesc val.TupleDesc, keys [][]byte, values [][]byte, logChunkSize uint8) (ProximityMap, error) { + // The algorithm for building a ProximityMap's tree requires us to start at the root and build out to the leaf nodes. + // Given that our trees are Merkle Trees, this presents an obvious problem. + // Our solution is to create the final tree by applying a series of transformations to intermediate trees. + + // Note: when talking about tree levels, we use "level" when counting from the leaves, and "depth" when counting + // from the root. 
In a tree with 5 levels, the root is level 4 (and depth 0), while the leaves are level 0 (and depth 4) + + // The process looks like this: + // Step 1: Create `levelMap`, a map from (indexLevel, keyBytes) -> values + // - indexLevel: the minimum level in which the vector appears + // - keyBytes: a bytestring containing the bytes of the ProximityMap key (which includes the vector) + // - values: the ProximityMap value tuple + // + // Step 2: Create `pathMaps`, a list of maps, each corresponding to a different level of the ProximityMap + // The pathMap at depth `i` has the schema (vectorAddrs[1]...vectorAddr[i], keyBytes) -> value + // and contains a row for every vector whose maximum depth is i. + // - vectorAddrs: the path of vectors visited when walking from the root to the maximum depth where the vector appears. + // - keyBytes: a bytestring containing the bytes of the ProximityMap key (which includes the vector) + // - values: the ProximityMap value tuple + // + // These maps must be built in order, from shallowest to deepest. + // + // Step 3: Create an iter over each `pathMap` created in the previous step, and walk the shape of the final ProximityMap, + // generating Nodes as we go. + // + // Currently, the intermediate trees are created using the standard NodeStore. This means that the nodes of these + // trees will inevitably be written out to disk when the NodeStore flushes, despite the fact that we know they + // won't be needed once we finish building the ProximityMap. This could potentially be avoided by creating a + // separate in-memory NodeStore for these values. 
+ + vectorIndexSerializer := message.NewVectorIndexSerializer(ns.Pool()) + + makeRootNode := func(keys, values [][]byte, subtrees []uint64, level int) (ProximityMap, error) { + rootMsg := vectorIndexSerializer.Serialize(keys, values, subtrees, level) + rootNode, err := tree.NodeFromBytes(rootMsg) + if err != nil { + return ProximityMap{}, err + } + _, err = ns.Write(ctx, rootNode) + if err != nil { + return ProximityMap{}, err + } + + return NewProximityMap(ctx, ns, rootNode, keyDesc, valDesc, distanceType), nil + } + + // Check if index is empty. + if len(keys) == 0 { + return makeRootNode(nil, nil, nil, 0) + } + + // Step 1: Create `levelMap`, a map from (indexLevel, keyBytes) -> values + // We want the index to be sorted first by level (descending), so currently we store the level in the map as + // 255 - the actual level. TODO: Use a reverse iterator instead. + mutableLevelMap, err := makeLevelMap(ctx, ns, keys, values, valDesc, logChunkSize) + if err != nil { + return ProximityMap{}, err + } + levelMapIter, err := mutableLevelMap.IterAll(ctx) + if err != nil { + return ProximityMap{}, err + } + + // Step 2: Create `pathMaps`, a list of maps, each corresponding to a different level of the ProximityMap + + // The first element of levelMap tells us the height of the tree. + levelMapKey, levelMapValue, err := levelMapIter.Next(ctx) + if err != nil { + return ProximityMap{}, err + } + maxLevel, _ := mutableLevelMap.keyDesc.GetUint8(0, levelMapKey) + maxLevel = 255 - maxLevel + + if maxLevel == 0 { + // index is a single node. + // assuming that the keys are already sorted, we can return them unmodified. 
+ return makeRootNode(keys, values, nil, 0) + } + + // Create every val.TupleBuilder and MutableMap that we will need + // pathMaps[i] is the pathMap for depth i (and level maxLevel - i) + pathMaps, keyTupleBuilders, prefixTupleBuilders, err := createInitialPathMaps(ctx, ns, valDesc, maxLevel) + + // Next, visit each key-value pair in decreasing order of level / increasing order of depth. + // When visiting a pair from depth `i`, we use each of the previous `i` pathMaps to compute a path of `i` index keys. + // This path dictate's that pair's location in the final ProximityMap. + for { + level, _ := mutableLevelMap.keyDesc.GetUint8(0, levelMapKey) + level = 255 - level // we currently store the level as 255 - the actual level for sorting purposes. + depth := int(maxLevel - level) + + keyTupleBuilder := keyTupleBuilders[level] + // Compute the path that this row will have in the vector index, starting with the keys with the highest levels. + // If the highest level is N, then a key at level L will have a path consisting of N-L vector hashes. + // This path is computed in steps. 
+ var hashPath []hash.Hash + keyToInsert, _ := mutableLevelMap.keyDesc.GetBytes(1, levelMapKey) + vectorHashToInsert, _ := keyDesc.GetJSONAddr(0, keyToInsert) + vectorToInsert, err := getVectorFromHash(ctx, ns, vectorHashToInsert) + if err != nil { + return ProximityMap{}, err + } + for pathColumn := 0; pathColumn < depth; pathColumn++ { + prefixTupleBuilder := prefixTupleBuilders[int(maxLevel)-pathColumn] + pathMap := pathMaps[int(maxLevel)-pathColumn] + for tupleElem := 0; tupleElem < pathColumn; tupleElem++ { + prefixTupleBuilder.PutJSONAddr(tupleElem, hashPath[tupleElem]) + } + prefixTuple := prefixTupleBuilder.Build(ns.Pool()) + + prefixRange := PrefixRange(prefixTuple, prefixTupleBuilder.Desc) + pathMapIter, err := pathMap.IterRange(ctx, prefixRange) + if err != nil { + return ProximityMap{}, err + } + var candidateVectorHash hash.Hash + if pathColumn == 0 { + pathMapKey, _, err := pathMapIter.Next(ctx) + if err != nil { + return ProximityMap{}, err + } + originalKey, _ := pathMap.keyDesc.GetBytes(pathColumn, pathMapKey) + candidateVectorHash, _ = keyDesc.GetJSONAddr(0, originalKey) + } else { + candidateVectorHash = hashPath[pathColumn-1] + } + + candidateVector, err := getVectorFromHash(ctx, ns, candidateVectorHash) + if err != nil { + return ProximityMap{}, err + } + closestVectorHash := candidateVectorHash + closestDistance, err := distanceType.Eval(vectorToInsert, candidateVector) + if err != nil { + return ProximityMap{}, err + } + + for { + pathMapKey, _, err := pathMapIter.Next(ctx) + if err == io.EOF { + break + } + if err != nil { + return ProximityMap{}, err + } + originalKey, _ := pathMap.keyDesc.GetBytes(pathColumn, pathMapKey) + candidateVectorHash, _ := keyDesc.GetJSONAddr(0, originalKey) + candidateVector, err = getVectorFromHash(ctx, ns, candidateVectorHash) + if err != nil { + return ProximityMap{}, err + } + candidateDistance, err := distanceType.Eval(vectorToInsert, candidateVector) + if err != nil { + return ProximityMap{}, err + } + if 
candidateDistance < closestDistance { + closestVectorHash = candidateVectorHash + closestDistance = candidateDistance + } + } + + hashPath = append(hashPath, closestVectorHash) + + } + + for i, h := range hashPath { + keyTupleBuilder.PutJSONAddr(i, h) + } + keyTupleBuilder.PutByteString(depth, keyToInsert) + + err = pathMaps[level].Put(ctx, keyTupleBuilder.Build(ns.Pool()), levelMapValue) + if err != nil { + return ProximityMap{}, err + } + + levelMapKey, levelMapValue, err = levelMapIter.Next(ctx) + if err == io.EOF { + break + } + if err != nil { + return ProximityMap{}, err + } + + } + + // Step 3: Create an iter over each `pathMap` created in the previous step, and walk the shape of the final ProximityMap, + // generating Nodes as we go. + + var chunker *vectorIndexChunker + for i, pathMap := range pathMaps[:len(pathMaps)-1] { + chunker, err = newVectorIndexChunker(ctx, pathMap, int(maxLevel)-(i), chunker) + if err != nil { + return ProximityMap{}, err + } + } + rootPathMap := pathMaps[len(pathMaps)-1] + topLevelPathMapIter, err := rootPathMap.IterAll(ctx) + if err != nil { + return ProximityMap{}, err + } + var topLevelKeys [][]byte + var topLevelValues [][]byte + var topLevelSubtrees []uint64 + for { + key, value, err := topLevelPathMapIter.Next(ctx) + if err == io.EOF { + break + } + if err != nil { + return ProximityMap{}, err + } + originalKey, _ := rootPathMap.keyDesc.GetBytes(0, key) + path, _ := keyDesc.GetJSONAddr(0, originalKey) + _, nodeCount, nodeHash, err := chunker.Next(ctx, ns, vectorIndexSerializer, path, originalKey, value, int(maxLevel)-1, 1, keyDesc) + if err != nil { + return ProximityMap{}, err + } + topLevelKeys = append(topLevelKeys, originalKey) + topLevelValues = append(topLevelValues, nodeHash[:]) + topLevelSubtrees = append(topLevelSubtrees, nodeCount) + } + return makeRootNode(topLevelKeys, topLevelValues, topLevelSubtrees, int(maxLevel)) +} + +// makeLevelMap creates a prolly map where the key is prefixed by the maximum level of 
that row in the corresponding ProximityMap. +func makeLevelMap(ctx context.Context, ns tree.NodeStore, keys [][]byte, values [][]byte, valDesc val.TupleDesc, logChunkSize uint8) (*MutableMap, error) { + levelMapKeyDesc := val.NewTupleDescriptor( + val.Type{Enc: val.Uint8Enc, Nullable: false}, + val.Type{Enc: val.ByteStringEnc, Nullable: false}, + ) + + emptyLevelMap, err := NewMapFromTuples(ctx, ns, levelMapKeyDesc, valDesc) + if err != nil { + return nil, err + } + mutableLevelMap := newMutableMap(emptyLevelMap) + + for i := 0; i < len(keys); i++ { + key := keys[i] + keyLevel := tree.DeterministicHashLevel(logChunkSize, []byte(key)) + + levelMapKeyBuilder := val.NewTupleBuilder(levelMapKeyDesc) + levelMapKeyBuilder.PutUint8(0, 255-keyLevel) + levelMapKeyBuilder.PutByteString(1, key) + err = mutableLevelMap.Put(ctx, levelMapKeyBuilder.Build(ns.Pool()), values[i]) + if err != nil { + return nil, err + } + } + + return mutableLevelMap, nil +} + +// createInitialPathMaps creates a list of MutableMaps that will eventually store a single level of the corresponding ProximityMap +func createInitialPathMaps(ctx context.Context, ns tree.NodeStore, valDesc val.TupleDesc, maxLevel uint8) (pathMaps []*MutableMap, keyTupleBuilders, prefixTupleBuilders []*val.TupleBuilder, err error) { + keyTupleBuilders = make([]*val.TupleBuilder, maxLevel+1) + prefixTupleBuilders = make([]*val.TupleBuilder, maxLevel+1) + pathMaps = make([]*MutableMap, maxLevel+1) + + // Make a type slice for the maximum depth pathMap: each other slice we need is a subslice of this one. + pathMapKeyDescTypes := make([]val.Type, maxLevel+1) + for i := uint8(0); i < maxLevel; i++ { + pathMapKeyDescTypes[i] = val.Type{Enc: val.JSONAddrEnc, Nullable: false} + } + pathMapKeyDescTypes[maxLevel] = val.Type{Enc: val.ByteStringEnc, Nullable: false} + + for i := uint8(0); i <= maxLevel; i++ { + pathMapKeyDesc := val.NewTupleDescriptor(pathMapKeyDescTypes[i:]...) 
+ + emptyPathMap, err := NewMapFromTuples(ctx, ns, pathMapKeyDesc, valDesc) + if err != nil { + return nil, nil, nil, err + } + pathMaps[i] = newMutableMap(emptyPathMap) + + keyTupleBuilders[i] = val.NewTupleBuilder(pathMapKeyDesc) + prefixTupleBuilders[i] = val.NewTupleBuilder(val.NewTupleDescriptor(pathMapKeyDescTypes[i:maxLevel]...)) + } + + return pathMaps, keyTupleBuilders, prefixTupleBuilders, nil +} + +// vectorIndexChunker is a stateful chunker that iterates over |pathMap|, a map that contains an element +// for every key-value pair for a given level of a ProximityMap, and provides the path of keys to reach +// that pair from the root. It uses this iterator to build each of the ProximityMap nodes for that level. +type vectorIndexChunker struct { + pathMap *MutableMap + pathMapIter MapIter + lastPathSegment hash.Hash + lastKey []byte + lastValue []byte + lastSubtreeCount uint64 + childChunker *vectorIndexChunker + atEnd bool +} + +func newVectorIndexChunker(ctx context.Context, pathMap *MutableMap, depth int, childChunker *vectorIndexChunker) (*vectorIndexChunker, error) { + pathMapIter, err := pathMap.IterAll(ctx) + if err != nil { + return nil, err + } + firstKey, firstValue, err := pathMapIter.Next(ctx) + if err == io.EOF { + // In rare situations, there aren't any vectors at a given level. 
+ return &vectorIndexChunker{ + pathMap: pathMap, + pathMapIter: pathMapIter, + childChunker: childChunker, + atEnd: true, + }, nil + } + if err != nil { + return nil, err + } + lastPathSegment, _ := pathMap.keyDesc.GetJSONAddr(depth-1, firstKey) + originalKey, _ := pathMap.keyDesc.GetBytes(depth, firstKey) + return &vectorIndexChunker{ + pathMap: pathMap, + pathMapIter: pathMapIter, + childChunker: childChunker, + lastKey: originalKey, + lastValue: firstValue, + lastPathSegment: lastPathSegment, + atEnd: false, + }, nil +} + +func (c *vectorIndexChunker) Next(ctx context.Context, ns tree.NodeStore, serializer message.VectorIndexSerializer, parentPathSegment hash.Hash, parentKey val.Tuple, parentValue val.Tuple, level, depth int, originalKeyDesc val.TupleDesc) (tree.Node, uint64, hash.Hash, error) { + indexMapKeys := [][]byte{parentKey} + var indexMapValues [][]byte + var indexMapSubtrees []uint64 + subtreeSum := uint64(0) + if c.childChunker != nil { + _, childCount, nodeHash, err := c.childChunker.Next(ctx, ns, serializer, parentPathSegment, parentKey, parentValue, level-1, depth+1, originalKeyDesc) + if err != nil { + return tree.Node{}, 0, hash.Hash{}, err + } + indexMapValues = append(indexMapValues, nodeHash[:]) + indexMapSubtrees = append(indexMapSubtrees, childCount) + subtreeSum += childCount + } else { + indexMapValues = append(indexMapValues, parentValue) + subtreeSum++ + } + + for { + if c.atEnd || c.lastPathSegment != parentPathSegment { + msg := serializer.Serialize(indexMapKeys, indexMapValues, indexMapSubtrees, level) + node, err := tree.NodeFromBytes(msg) + if err != nil { + return tree.Node{}, 0, hash.Hash{}, err + } + nodeHash, err := ns.Write(ctx, node) + return node, subtreeSum, nodeHash, err + } + vectorHash, _ := originalKeyDesc.GetJSONAddr(0, c.lastKey) + if c.childChunker != nil { + _, childCount, nodeHash, err := c.childChunker.Next(ctx, ns, serializer, vectorHash, c.lastKey, c.lastValue, level-1, depth+1, originalKeyDesc) + if err != nil 
{ + return tree.Node{}, 0, hash.Hash{}, err + } + c.lastValue = nodeHash[:] + indexMapSubtrees = append(indexMapSubtrees, childCount) + subtreeSum += childCount + } else { + subtreeSum++ + } + indexMapKeys = append(indexMapKeys, c.lastKey) + indexMapValues = append(indexMapValues, c.lastValue) + + nextKey, nextValue, err := c.pathMapIter.Next(ctx) + if err == io.EOF { + c.atEnd = true + } else if err != nil { + return tree.Node{}, 0, hash.Hash{}, err + } else { + c.lastPathSegment, _ = c.pathMap.keyDesc.GetJSONAddr(depth-1, nextKey) + c.lastKey, _ = c.pathMap.keyDesc.GetBytes(depth, nextKey) + c.lastValue = nextValue + } + } +} + +// Count returns the number of key-value pairs in the Map. +func (m ProximityMap) Count() (int, error) { + return m.tuples.Count() +} + +// Get searches for the key-value pair keyed by |key| and passes the results to the callback. +// If |key| is not present in the map, a nil key-value pair are passed. +func (m ProximityMap) Get(ctx context.Context, query interface{}, cb tree.KeyValueFn[val.Tuple, val.Tuple]) (err error) { + return m.tuples.GetExact(ctx, query, cb) +} + +func (m ProximityMap) GetClosest(ctx context.Context, query interface{}, cb tree.KeyValueDistanceFn[val.Tuple, val.Tuple], limit int) (err error) { + return m.tuples.GetClosest(ctx, query, cb, limit) +} diff --git a/go/store/prolly/proximity_map_test.go b/go/store/prolly/proximity_map_test.go new file mode 100644 index 00000000000..684bf41e671 --- /dev/null +++ b/go/store/prolly/proximity_map_test.go @@ -0,0 +1,197 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prolly + +import ( + "context" + "github.com/dolthub/dolt/go/store/hash" + "github.com/dolthub/dolt/go/store/pool" + "github.com/dolthub/dolt/go/store/prolly/tree" + "github.com/dolthub/dolt/go/store/val" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/types" + "github.com/stretchr/testify/require" + "os" + "testing" +) + +func newJsonValue(t *testing.T, v interface{}) sql.JSONWrapper { + doc, _, err := types.JSON.Convert(v) + require.NoError(t, err) + return doc.(sql.JSONWrapper) +} + +// newJsonDocument creates a JSON value from a provided value. 
+func newJsonDocument(t *testing.T, ctx context.Context, ns tree.NodeStore, v interface{}) hash.Hash { + doc := newJsonValue(t, v) + root, err := tree.SerializeJsonToAddr(ctx, ns, doc) + require.NoError(t, err) + return root.HashOf() +} + +func createProximityMap(t *testing.T, ctx context.Context, ns tree.NodeStore, vectors []interface{}, pks []int64, logChunkSize uint8) (ProximityMap, [][]byte, [][]byte) { + bp := pool.NewBuffPool() + + count := len(vectors) + require.Equal(t, count, len(pks)) + + kd := val.NewTupleDescriptor( + val.Type{Enc: val.JSONAddrEnc, Nullable: true}, + ) + + vd := val.NewTupleDescriptor( + val.Type{Enc: val.Int64Enc, Nullable: true}, + ) + + distanceType := expression.DistanceL2Squared{} + + keys := make([][]byte, count) + keyBuilder := val.NewTupleBuilder(kd) + for i, vector := range vectors { + keyBuilder.PutJSONAddr(0, newJsonDocument(t, ctx, ns, vector)) + keys[i] = keyBuilder.Build(bp) + } + + valueBuilder := val.NewTupleBuilder(vd) + values := make([][]byte, count) + for i, pk := range pks { + valueBuilder.PutInt64(0, pk) + values[i] = valueBuilder.Build(bp) + } + + m, err := NewProximityMapFromTupleIter(ctx, ns, distanceType, kd, vd, keys, values, logChunkSize) + require.NoError(t, err) + mapCount, err := m.Count() + require.NoError(t, err) + require.Equal(t, count, mapCount) + + return m, keys, values +} + +func TestEmptyProximityMap(t *testing.T) { + ctx := context.Background() + ns := tree.NewTestNodeStore() + createProximityMap(t, ctx, ns, nil, nil, 10) +} + +func TestSingleEntryProximityMap(t *testing.T) { + ctx := context.Background() + ns := tree.NewTestNodeStore() + m, keys, values := createProximityMap(t, ctx, ns, []interface{}{"[1.0]"}, []int64{1}, 10) + matches := 0 + vectorHash, _ := m.keyDesc.GetJSONAddr(0, keys[0]) + vectorDoc, err := tree.NewJSONDoc(vectorHash, ns).ToIndexedJSONDocument(ctx) + require.NoError(t, err) + err = m.Get(ctx, vectorDoc, func(foundKey val.Tuple, foundValue val.Tuple) error { + 
require.Equal(t, val.Tuple(keys[0]), foundKey) + require.Equal(t, val.Tuple(values[0]), foundValue) + matches++ + return nil + }) + require.NoError(t, err) + require.Equal(t, matches, 1) +} + +func TestDoubleEntryProximityMapGetExact(t *testing.T) { + ctx := context.Background() + ns := tree.NewTestNodeStore() + m, keys, values := createProximityMap(t, ctx, ns, []interface{}{"[0.0, 6.0]", "[3.0, 4.0]"}, []int64{1, 2}, 10) + matches := 0 + for i, key := range keys { + vectorHash, _ := m.keyDesc.GetJSONAddr(0, key) + vectorDoc, err := tree.NewJSONDoc(vectorHash, ns).ToIndexedJSONDocument(ctx) + err = m.Get(ctx, vectorDoc, func(foundKey val.Tuple, foundValue val.Tuple) error { + require.Equal(t, val.Tuple(key), foundKey) + require.Equal(t, val.Tuple(values[i]), foundValue) + matches++ + return nil + }) + require.NoError(t, err) + } + require.Equal(t, matches, len(keys)) +} + +func TestDoubleEntryProximityMapGetClosest(t *testing.T) { + ctx := context.Background() + ns := tree.NewTestNodeStore() + m, keys, values := createProximityMap(t, ctx, ns, []interface{}{"[0.0, 6.0]", "[3.0, 4.0]"}, []int64{1, 2}, 10) + matches := 0 + + cb := func(foundKey val.Tuple, foundValue val.Tuple, distance float64) error { + require.Equal(t, val.Tuple(keys[1]), foundKey) + require.Equal(t, val.Tuple(values[1]), foundValue) + require.InDelta(t, distance, 25.0, 0.1) + matches++ + return nil + } + + err := m.GetClosest(ctx, newJsonValue(t, "[0.0, 0.0]"), cb, 1) + require.NoError(t, err) + require.Equal(t, matches, 1) +} + +func TestMultilevelProximityMap(t *testing.T) { + ctx := context.Background() + ns := tree.NewTestNodeStore() + keyStrings := []interface{}{ + "[0.0, 1.0]", + "[3.0, 4.0]", + "[5.0, 6.0]", + "[7.0, 8.0]", + } + valueStrings := []int64{1, 2, 3, 4} + m, keys, values := createProximityMap(t, ctx, ns, keyStrings, valueStrings, 1) + matches := 0 + for i, key := range keys { + vectorHash, _ := m.keyDesc.GetJSONAddr(0, key) + vectorDoc, err := tree.NewJSONDoc(vectorHash, 
ns).ToIndexedJSONDocument(ctx) + require.NoError(t, err) + err = m.Get(ctx, vectorDoc, func(foundKey val.Tuple, foundValue val.Tuple) error { + require.Equal(t, val.Tuple(key), foundKey) + require.Equal(t, val.Tuple(values[i]), foundValue) + matches++ + return nil + }) + require.NoError(t, err) + } + require.Equal(t, matches, len(keys)) +} + +func TestInsertOrderIndependence(t *testing.T) { + ctx := context.Background() + ns := tree.NewTestNodeStore() + keyStrings1 := []interface{}{ + "[0.0, 1.0]", + "[3.0, 4.0]", + "[5.0, 6.0]", + "[7.0, 8.0]", + } + valueStrings1 := []int64{1, 2, 3, 4} + keyStrings2 := []interface{}{ + "[7.0, 8.0]", + "[5.0, 6.0]", + "[3.0, 4.0]", + "[0.0, 1.0]", + } + valueStrings2 := []int64{4, 3, 2, 1} + m1, _, _ := createProximityMap(t, ctx, ns, keyStrings1, valueStrings1, 1) + _, _ = keyStrings1, valueStrings1 + m2, _, _ := createProximityMap(t, ctx, ns, keyStrings2, valueStrings2, 1) + require.NoError(t, tree.OutputProllyNodeBytes(os.Stdout, m1.tuples.Root)) + require.NoError(t, tree.OutputProllyNodeBytes(os.Stdout, m2.tuples.Root)) + + require.Equal(t, m1.tuples.Root.HashOf(), m2.tuples.Root.HashOf()) +} diff --git a/go/store/prolly/tree/node_splitter.go b/go/store/prolly/tree/node_splitter.go index 05509efd767..874aecbcb27 100644 --- a/go/store/prolly/tree/node_splitter.go +++ b/go/store/prolly/tree/node_splitter.go @@ -25,6 +25,7 @@ import ( "crypto/sha512" "encoding/binary" "math" + "math/bits" "github.com/kch42/buzhash" "github.com/zeebo/xxh3" @@ -264,3 +265,8 @@ func saltFromLevel(level uint8) (salt uint64) { full := sha512.Sum512([]byte{level}) return binary.LittleEndian.Uint64(full[:8]) } + +func DeterministicHashLevel(leadingZerosPerLevel uint8, key Item) uint8 { + h := xxHash32(key, levelSalt[1]) + return uint8(bits.LeadingZeros32(h)) / leadingZerosPerLevel +} diff --git a/go/store/prolly/tree/proximity_map.go b/go/store/prolly/tree/proximity_map.go new file mode 100644 index 00000000000..783bd20aef8 --- /dev/null +++ 
b/go/store/prolly/tree/proximity_map.go @@ -0,0 +1,423 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tree + +import ( + "bytes" + "context" + "fmt" + "github.com/dolthub/dolt/go/store/hash" + "github.com/dolthub/dolt/go/store/prolly/message" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "math" + "sort" +) + +type KeyValueDistanceFn[K, V ~[]byte] func(key K, value V, distance float64) error + +// ProximityMap is a static Prolly Tree where the position of a key in the tree is based on proximity, as opposed to a traditional ordering. +// O provides the ordering only within a node. 
+type ProximityMap[K, V ~[]byte, O Ordering[K]] struct { + Root Node + NodeStore NodeStore + DistanceType expression.DistanceType + Convert func([]byte) []float64 + Order O +} + +func (t ProximityMap[K, V, O]) Count() (int, error) { + return t.Root.TreeCount() +} + +func (t ProximityMap[K, V, O]) Height() int { + return t.Root.Level() + 1 +} + +func (t ProximityMap[K, V, O]) HashOf() hash.Hash { + return t.Root.HashOf() +} + +func (t ProximityMap[K, V, O]) WalkAddresses(ctx context.Context, cb AddressCb) error { + return WalkAddresses(ctx, t.Root, t.NodeStore, cb) +} + +func (t ProximityMap[K, V, O]) WalkNodes(ctx context.Context, cb NodeCb) error { + return WalkNodes(ctx, t.Root, t.NodeStore, cb) +} + +// GetExact searches for an exact vector in the index, calling |cb| with the matching key-value pairs. +func (t ProximityMap[K, V, O]) GetExact(ctx context.Context, query interface{}, cb KeyValueFn[K, V]) (err error) { + nd := t.Root + + queryVector, err := sql.ConvertToVector(query) + if err != nil { + return err + } + + // Find the child with the minimum distance. + + for { + var closestKey K + var closestIdx int + distance := math.Inf(1) + + for i := 0; i < int(nd.count); i++ { + k := nd.GetKey(i) + newDistance, err := t.DistanceType.Eval(t.Convert(k), queryVector) + if err != nil { + return err + } + if newDistance < distance { + closestIdx = i + distance = newDistance + closestKey = []byte(k) + } + } + + if nd.IsLeaf() { + return cb(closestKey, []byte(nd.GetValue(closestIdx))) + } + + nd, err = fetchChild(ctx, t.NodeStore, nd.getAddress(closestIdx)) + if err != nil { + return err + } + } +} + +func (t ProximityMap[K, V, O]) Has(ctx context.Context, query K) (ok bool, err error) { + err = t.GetExact(ctx, query, func(_ K, _ V) error { + ok = true + return nil + }) + return ok, err +} + +// GetClosest performs an approximate nearest neighbors search. It finds |limit| vectors that are close to the query vector, +// and calls |cb| with the matching key-value pairs. 
+func (t ProximityMap[K, V, O]) GetClosest(ctx context.Context, query interface{}, cb KeyValueDistanceFn[K, V], limit int) (err error) { + if limit != 1 { + return fmt.Errorf("currently only limit = 1 (find single closest vector) is supported for ProximityMap") + } + + queryVector, err := sql.ConvertToVector(query) + if err != nil { + return err + } + + nd := t.Root + + var closestKey K + var closestIdx int + distance := math.Inf(1) + + for { + for i := 0; i < int(nd.count); i++ { + k := nd.GetKey(i) + newDistance, err := t.DistanceType.Eval(t.Convert(k), queryVector) + if err != nil { + return err + } + if newDistance < distance { + closestIdx = i + distance = newDistance + closestKey = []byte(k) + } + } + + if nd.IsLeaf() { + return cb(closestKey, []byte(nd.GetValue(closestIdx)), distance) + } + + nd, err = fetchChild(ctx, t.NodeStore, nd.getAddress(closestIdx)) + if err != nil { + return err + } + } +} + +func (t ProximityMap[K, V, O]) IterAll(ctx context.Context) (*OrderedTreeIter[K, V], error) { + c, err := newCursorAtStart(ctx, t.NodeStore, t.Root) + if err != nil { + return nil, err + } + + s, err := newCursorPastEnd(ctx, t.NodeStore, t.Root) + if err != nil { + return nil, err + } + + stop := func(curr *cursor) bool { + return curr.compare(s) >= 0 + } + + if stop(c) { + // empty range + return &OrderedTreeIter[K, V]{curr: nil}, nil + } + + return &OrderedTreeIter[K, V]{curr: c, stop: stop, step: c.advance}, nil +} + +func getJsonValueFromHash(ctx context.Context, ns NodeStore, h hash.Hash) (interface{}, error) { + return NewJSONDoc(h, ns).ToIndexedJSONDocument(ctx) +} + +func getVectorFromHash(ctx context.Context, ns NodeStore, h hash.Hash) ([]float64, error) { + otherValue, err := getJsonValueFromHash(ctx, ns, h) + if err != nil { + return nil, err + } + return sql.ConvertToVector(otherValue) +} + +// Building/inserting into a ProximityMap requires a Fixup step, which reorganizes part (in the worst case, all) of the +// tree such that each node is a child 
of the closest node in the previous row.
// Currently, this is a brute force approach that visits the entire affected region of the tree in level-order, builds
// a new tree structure in memory, and then serializes the new tree to disk. There is room for improvement here.

// memoryNode is an in-memory representation of a Vector Index node.
// It stores a list of (vectorHash, key, value OR address) tuples.
// The first element is always the same vector used as the parent key (the node's "defining key").
// The remaining elements are sorted prior to serialization.
type memoryNode struct {
	vectorHashes []hash.Hash
	keys         [][]byte
	addresses    []memoryNode
	values       [][]byte
}

// memoryNodeSort adapts a memoryNode to sort.Interface so its entries can be
// ordered by key before serialization. In a non-root node the first entry is
// the defining key and must stay in place, so it is excluded from the sort;
// in the root node every entry participates.
type memoryNodeSort[K ~[]byte, O Ordering[K]] struct {
	*memoryNode
	order  O
	isRoot bool
}

var _ sort.Interface = (*memoryNodeSort[[]byte, Ordering[[]byte]])(nil)

// Len reports the number of sortable entries (all keys for the root,
// all but the defining key otherwise).
func (m memoryNodeSort[K, O]) Len() int {
	keys := m.keys[1:]
	if m.isRoot {
		keys = m.keys
	}
	return len(keys)
}

// Less orders entries by key using the supplied Ordering, over the same
// sortable window as Len.
func (m memoryNodeSort[K, O]) Less(i, j int) bool {
	keys := m.keys[1:]
	if m.isRoot {
		keys = m.keys
	}
	return m.order.Compare(keys[i], keys[j]) < 0
}

// Swap exchanges entries i and j in every parallel slice (vector hashes,
// keys, and — when present — child addresses and values), so the tuples stay
// aligned. The same root/non-root offset applied in Len/Less is applied to
// each slice.
func (m memoryNodeSort[K, O]) Swap(i, j int) {
	vectorHashes := m.vectorHashes[1:]
	if m.isRoot {
		vectorHashes = m.vectorHashes
	}
	vectorHashes[i], vectorHashes[j] = vectorHashes[j], vectorHashes[i]
	keys := m.keys[1:]
	if m.isRoot {
		keys = m.keys
	}
	keys[i], keys[j] = keys[j], keys[i]
	if m.addresses != nil {
		addresses := m.addresses[1:]
		if m.isRoot {
			addresses = m.addresses
		}
		addresses[i], addresses[j] = addresses[j], addresses[i]
	}
	if m.values != nil {
		values := m.values[1:]
		if m.isRoot {
			values = m.values
		}
		values[i], values[j] = values[j], values[i]
	}
}

// serializeAndWriteNode serializes the given keys/values (and subtree counts,
// for internal nodes) into a node message at the given level, parses the
// message back into a Node, and writes that Node to the NodeStore.
func serializeAndWriteNode(ctx context.Context, ns NodeStore, s message.Serializer, level int, subtrees []uint64, keys [][]byte, values [][]byte) (node Node, err error) {
	msg := s.Serialize(keys, values, subtrees, level)
	node, err = NodeFromBytes(msg)
	if err != nil {
		return Node{}, err
	}
	_, err = ns.Write(ctx, node)
	return node, err
}

// serializeMemoryNode converts an in-memory (sub)tree into stored Nodes,
// bottom-up. It first sorts this node's entries (preserving the defining key
// for non-root nodes), then either writes a leaf directly (level 0) or
// recursively serializes each child and writes an internal node whose values
// are the child hashes, alongside per-child subtree counts.
func serializeMemoryNode[K ~[]byte, O Ordering[K]](ctx context.Context, m memoryNode, ns NodeStore, s message.Serializer, level int, isRoot bool, order O) (node Node, err error) {
	sort.Sort(memoryNodeSort[K, O]{
		memoryNode: &m,
		isRoot:     isRoot,
		order:      order,
	})
	if level == 0 {
		return serializeAndWriteNode(ctx, ns, s, 0, nil, m.keys, m.values)
	}
	values := make([][]byte, 0, len(m.addresses))
	subTrees := make([]uint64, 0, len(m.addresses))
	for _, address := range m.addresses {
		child, err := serializeMemoryNode(ctx, address, ns, s, level-1, false, order)
		if err != nil {
			return Node{}, err
		}
		childHash := child.HashOf()
		values = append(values, childHash[:])
		childCount, err := message.GetTreeCount(child.msg)
		if err != nil {
			return Node{}, err
		}
		subTrees = append(subTrees, uint64(childCount))
	}
	return serializeAndWriteNode(ctx, ns, s, level, subTrees, m.keys, values)
}

// insert places (vectorHash, key, value) into the in-memory tree. level is
// the number of levels to descend from this node before inserting; at each
// internal step the entry descends into the child whose defining vector is
// closest to vector (per distanceType). At level 0:
//   - if isLeaf, the entry is stored as a key/value pair (overwriting the
//     value when key equals this node's defining key);
//   - otherwise the entry becomes a new child node defined by this key,
//     to be filled in when deeper rows are inserted later.
func (m *memoryNode) insert(ctx context.Context, ns NodeStore, distanceType expression.DistanceType, vectorHash hash.Hash, key Item, value Item, vector []float64, level int, isLeaf bool) error {
	if level == 0 {
		if isLeaf {
			if bytes.Equal(m.keys[0], key) {
				m.values[0] = value
			} else {
				m.vectorHashes = append(m.vectorHashes, vectorHash)
				m.keys = append(m.keys, key)
				m.values = append(m.values, value)
			}
			return nil
		}
		// We're inserting into the row that's currently the bottom of the in-memory representation,
		// but this isn't the leaf row of the final tree: more rows will be added afterward.
		if bytes.Equal(m.keys[0], key) {
			// The key matches this node's defining key: (re)create the
			// child slot that represents it one level down.
			m.addresses[0] = memoryNode{
				vectorHashes: []hash.Hash{vectorHash},
				keys:         [][]byte{key},
				addresses:    []memoryNode{{}},
				values:       [][]byte{nil},
			}
		} else {
			m.vectorHashes = append(m.vectorHashes, vectorHash)
			m.keys = append(m.keys, key)
			m.addresses = append(m.addresses, memoryNode{
				vectorHashes: []hash.Hash{vectorHash},
				keys:         [][]byte{key},
				addresses:    []memoryNode{{}},
				values:       [][]byte{nil},
			})
		}
		return nil
	}
	// Find the entry in this node whose vector is closest to the inserted
	// vector, then recurse into that entry's child.
	closestIdx := 0
	otherVector, err := getVectorFromHash(ctx, ns, m.vectorHashes[0])
	if err != nil {
		return err
	}
	distance, err := distanceType.Eval(vector, otherVector)
	if err != nil {
		return err
	}
	for i := 1; i < len(m.keys); i++ {
		candidateVector, err := getVectorFromHash(ctx, ns, m.vectorHashes[i])
		if err != nil {
			return err
		}
		candidateDistance, err := distanceType.Eval(vector, candidateVector)
		if err != nil {
			return err
		}
		if candidateDistance < distance {
			distance = candidateDistance
			closestIdx = i
		}
	}
	return m.addresses[closestIdx].insert(ctx, ns, distanceType, vectorHash, key, value, vector, level-1, isLeaf)
}

// levelTraversal invokes cb on every node exactly `level` levels below nd,
// in left-to-right order, reading children from the NodeStore as it descends.
func levelTraversal(ctx context.Context, nd Node, ns NodeStore, level int, cb func(nd Node) error) error {
	if level == 0 {
		return cb(nd)
	}
	for i := 0; i < int(nd.count); i++ {
		child, err := ns.Read(ctx, nd.getAddress(i))
		if err != nil {
			return err
		}
		err = levelTraversal(ctx, child, ns, level-1, cb)
		if err != nil {
			return err
		}
	}
	return nil
}

// FixupProximityMap takes the root node of a vector index which may not be in the correct order, and moves and reorders
// nodes to make it correct. It ensures the following invariants:
// - In any node except the root node, the first key is the same as the key in the edge pointing to that node.
// (This is the node's "defining key")
// - All other keys within a node are sorted.
+// - Each non-root node contains only the keys (including transitively) that are closer to that node's defining key than +// any other key in that node's parent. +func FixupProximityMap[K ~[]byte, O Ordering[K]](ctx context.Context, ns NodeStore, distanceType expression.DistanceType, n Node, getHash func([]byte) hash.Hash, order O) (Node, error) { + if n.Level() == 0 { + return n, nil + } + // Iterate over the keys, starting at the level 1 nodes (with root as level 0) + result := memoryNode{ + vectorHashes: make([]hash.Hash, n.Count()), + keys: make([][]byte, n.Count()), + addresses: make([]memoryNode, n.Count()), + } + for i := 0; i < n.Count(); i++ { + keyItem := n.GetKey(i) + result.keys[i] = keyItem + vectorHash := getHash(keyItem) + result.vectorHashes[i] = vectorHash + result.addresses[i] = memoryNode{ + vectorHashes: []hash.Hash{vectorHash}, + keys: [][]byte{keyItem}, + addresses: []memoryNode{{}}, + values: [][]byte{nil}, + } + } + + for level := 1; level <= n.Level(); level++ { + // Insert each key into the appropriate place in the result. + err := levelTraversal(ctx, n, ns, level, func(nd Node) error { + for i := 0; i < nd.Count(); i++ { + key := nd.GetKey(i) + vecHash := getHash(key) + vector, err := getVectorFromHash(ctx, ns, vecHash) + if err != nil { + return err + } + isLeaf := level == n.Level() + err = result.insert(ctx, ns, distanceType, vecHash, key, nd.GetValue(i), vector, level, isLeaf) + if err != nil { + return err + } + } + return nil + }) + if err != nil { + return Node{}, err + } + } + // Convert the in-memory representation back into a Node. + serializer := message.NewVectorIndexSerializer(ns.Pool()) + return serializeMemoryNode[K, O](ctx, result, ns, serializer, n.Level(), true, order) +}