diff --git a/go/Godeps/LICENSES b/go/Godeps/LICENSES index f5221763117..0bef2af10de 100644 --- a/go/Godeps/LICENSES +++ b/go/Godeps/LICENSES @@ -3705,6 +3705,134 @@ SOFTWARE. = LICENSE c2d1663aa29baf618aacf2a16cd5d14a2c6922e13388c82d850cb47a = ================================================================================ +================================================================================ += github.com/esote/minmaxheap licensed under: = + +Creative Commons Legal Code + +CC0 1.0 Universal + + CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE + LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN + ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS + INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES + REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS + PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM + THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED + HEREUNDER. + +Statement of Purpose + +The laws of most jurisdictions throughout the world automatically confer +exclusive Copyright and Related Rights (defined below) upon the creator +and subsequent owner(s) (each and all, an "owner") of an original work of +authorship and/or a database (each, a "Work"). + +Certain owners wish to permanently relinquish those rights to a Work for +the purpose of contributing to a commons of creative, cultural and +scientific works ("Commons") that the public can reliably and without fear +of later claims of infringement build upon, modify, incorporate in other +works, reuse and redistribute as freely as possible in any form whatsoever +and for any purposes, including without limitation commercial purposes. +These owners may contribute to the Commons to promote the ideal of a free +culture and the further production of creative, cultural and scientific +works, or to gain reputation or greater distribution for their Work in +part through the use and efforts of others. + +For these and/or other purposes and motivations, and without any +expectation of additional consideration or compensation, the person +associating CC0 with a Work (the "Affirmer"), to the extent that he or she +is an owner of Copyright and Related Rights in the Work, voluntarily +elects to apply CC0 to the Work and publicly distribute the Work under its +terms, with knowledge of his or her Copyright and Related Rights in the +Work and the meaning and intended legal effect of CC0 on those rights. + +1. Copyright and Related Rights. A Work made available under CC0 may be +protected by copyright and related or neighboring rights ("Copyright and +Related Rights"). Copyright and Related Rights include, but are not +limited to, the following: + + i. the right to reproduce, adapt, distribute, perform, display, + communicate, and translate a Work; + ii. moral rights retained by the original author(s) and/or performer(s); +iii. publicity and privacy rights pertaining to a person's image or + likeness depicted in a Work; + iv. rights protecting against unfair competition in regards to a Work, + subject to the limitations in paragraph 4(a), below; + v. rights protecting the extraction, dissemination, use and reuse of data + in a Work; + vi. database rights (such as those arising under Directive 96/9/EC of the + European Parliament and of the Council of 11 March 1996 on the legal + protection of databases, and under any national implementation + thereof, including any amended or successor version of such + directive); and +vii. 
other similar, equivalent or corresponding rights throughout the + world based on applicable law or treaty, and any national + implementations thereof. + +2. Waiver. To the greatest extent permitted by, but not in contravention +of, applicable law, Affirmer hereby overtly, fully, permanently, +irrevocably and unconditionally waives, abandons, and surrenders all of +Affirmer's Copyright and Related Rights and associated claims and causes +of action, whether now known or unknown (including existing as well as +future claims and causes of action), in the Work (i) in all territories +worldwide, (ii) for the maximum duration provided by applicable law or +treaty (including future time extensions), (iii) in any current or future +medium and for any number of copies, and (iv) for any purpose whatsoever, +including without limitation commercial, advertising or promotional +purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each +member of the public at large and to the detriment of Affirmer's heirs and +successors, fully intending that such Waiver shall not be subject to +revocation, rescission, cancellation, termination, or any other legal or +equitable action to disrupt the quiet enjoyment of the Work by the public +as contemplated by Affirmer's express Statement of Purpose. + +3. Public License Fallback. Should any part of the Waiver for any reason +be judged legally invalid or ineffective under applicable law, then the +Waiver shall be preserved to the maximum extent permitted taking into +account Affirmer's express Statement of Purpose. In addition, to the +extent the Waiver is so judged Affirmer hereby grants to each affected +person a royalty-free, non transferable, non sublicensable, non exclusive, +irrevocable and unconditional license to exercise Affirmer's Copyright and +Related Rights in the Work (i) in all territories worldwide, (ii) for the +maximum duration provided by applicable law or treaty (including future +time extensions), (iii) in any current or future medium and for any number +of copies, and (iv) for any purpose whatsoever, including without +limitation commercial, advertising or promotional purposes (the +"License"). The License shall be deemed effective as of the date CC0 was +applied by Affirmer to the Work. Should any part of the License for any +reason be judged legally invalid or ineffective under applicable law, such +partial invalidity or ineffectiveness shall not invalidate the remainder +of the License, and in such case Affirmer hereby affirms that he or she +will not (i) exercise any of his or her remaining Copyright and Related +Rights in the Work or (ii) assert any associated claims and causes of +action with respect to the Work, in either case contrary to Affirmer's +express Statement of Purpose. + +4. Limitations and Disclaimers. + + a. No trademark or patent rights held by Affirmer are waived, abandoned, + surrendered, licensed or otherwise affected by this document. + b. Affirmer offers the Work as-is and makes no representations or + warranties of any kind concerning the Work, express, implied, + statutory or otherwise, including without limitation warranties of + title, merchantability, fitness for a particular purpose, non + infringement, or the absence of latent or other defects, accuracy, or + the present or absence of errors, whether or not discoverable, all to + the greatest extent permissible under applicable law. + c. 
Affirmer disclaims responsibility for clearing rights of other persons + that may apply to the Work or any use thereof, including without + limitation any person's Copyright and Related Rights in the Work. + Further, Affirmer disclaims responsibility for obtaining any necessary + consents, permissions or other rights required for any use of the + Work. + d. Affirmer understands and acknowledges that Creative Commons is not a + party to this document and has no duty or obligation with respect to + this CC0 or use of the Work. + += LICENSE ee1aaed0069c3ab9317a520603cf7b5a38cdc60177f40862e1ce4d98 = +================================================================================ + ================================================================================ = github.com/fatih/color licensed under: = diff --git a/go/cmd/dolt/commands/indexcmds/cat.go b/go/cmd/dolt/commands/indexcmds/cat.go index ea0bde18747..cfc3310167e 100644 --- a/go/cmd/dolt/commands/indexcmds/cat.go +++ b/go/cmd/dolt/commands/indexcmds/cat.go @@ -151,7 +151,7 @@ func (cmd CatCmd) prettyPrintResults(ctx context.Context, doltSch schema.Schema, sqlCtx := sql.NewEmptyContext() - rowItr, err := table.NewTableIterator(ctx, doltSch, idx, 0) + rowItr, err := table.NewTableIterator(ctx, doltSch, idx) if err != nil { return err } diff --git a/go/gen/fb/serial/fileidentifiers.go b/go/gen/fb/serial/fileidentifiers.go index 628cc4f99b8..1f39f839c28 100644 --- a/go/gen/fb/serial/fileidentifiers.go +++ b/go/gen/fb/serial/fileidentifiers.go @@ -42,6 +42,7 @@ const StashFileID = "STSH" const StatisticFileID = "STAT" const DoltgresRootValueFileID = "DGRV" const TupleFileID = "TUPL" +const VectorIndexNodeFileID = "IVFF" const MessageTypesKind int = 27 diff --git a/go/gen/fb/serial/schema.go b/go/gen/fb/serial/schema.go index 9568d2b2e07..c833c20ba6b 100644 --- a/go/gen/fb/serial/schema.go +++ b/go/gen/fb/serial/schema.go @@ -17,9 +17,35 @@ package serial import ( + "strconv" + flatbuffers "github.com/dolthub/flatbuffers/v23/go" ) +type DistanceType byte + +const ( + DistanceTypeNull DistanceType = 0 + DistanceTypeL2_Squared DistanceType = 1 +) + +var EnumNamesDistanceType = map[DistanceType]string{ + DistanceTypeNull: "Null", + DistanceTypeL2_Squared: "L2_Squared", +} + +var EnumValuesDistanceType = map[string]DistanceType{ + "Null": DistanceTypeNull, + "L2_Squared": DistanceTypeL2_Squared, +} + +func (v DistanceType) String() string { + if s, ok := EnumNamesDistanceType[v]; ok { + return s + } + return "DistanceType(" + strconv.FormatInt(int64(v), 10) + ")" +} + type TableSchema struct { _tab flatbuffers.Table } @@ -667,7 +693,35 @@ func (rcv *Index) TryFulltextInfo(obj *FulltextInfo) (*FulltextInfo, error) { return nil, nil } -const IndexNumFields = 12 +func (rcv *Index) VectorKey() bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(28)) + if o != 0 { + return rcv._tab.GetBool(o + rcv._tab.Pos) + } + return false +} + +func (rcv *Index) MutateVectorKey(n bool) bool { + return rcv._tab.MutateBoolSlot(28, n) +} + +func (rcv *Index) TryVectorInfo(obj *VectorInfo) (*VectorInfo, error) { + o := flatbuffers.UOffsetT(rcv._tab.Offset(30)) + if o != 0 { + x := rcv._tab.Indirect(o + rcv._tab.Pos) + if obj == nil { + obj = new(VectorInfo) + } + obj.Init(rcv._tab.Bytes, x) + if VectorInfoNumFields < obj.Table().NumFields() { + return nil, flatbuffers.ErrTableHasUnknownFields + } + return obj, nil + } + return nil, nil +} + +const IndexNumFields = 14 func IndexStart(builder *flatbuffers.Builder) { builder.StartObject(IndexNumFields) @@ -720,6 
+774,12 @@ func IndexAddFulltextKey(builder *flatbuffers.Builder, fulltextKey bool) { func IndexAddFulltextInfo(builder *flatbuffers.Builder, fulltextInfo flatbuffers.UOffsetT) { builder.PrependUOffsetTSlot(11, flatbuffers.UOffsetT(fulltextInfo), 0) } +func IndexAddVectorKey(builder *flatbuffers.Builder, vectorKey bool) { + builder.PrependBoolSlot(12, vectorKey, false) +} +func IndexAddVectorInfo(builder *flatbuffers.Builder, vectorInfo flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(13, flatbuffers.UOffsetT(vectorInfo), 0) +} func IndexEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { return builder.EndObject() } @@ -878,6 +938,62 @@ func FulltextInfoEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { return builder.EndObject() } +type VectorInfo struct { + _tab flatbuffers.Table +} + +func InitVectorInfoRoot(o *VectorInfo, buf []byte, offset flatbuffers.UOffsetT) error { + n := flatbuffers.GetUOffsetT(buf[offset:]) + return o.Init(buf, n+offset) +} + +func TryGetRootAsVectorInfo(buf []byte, offset flatbuffers.UOffsetT) (*VectorInfo, error) { + x := &VectorInfo{} + return x, InitVectorInfoRoot(x, buf, offset) +} + +func TryGetSizePrefixedRootAsVectorInfo(buf []byte, offset flatbuffers.UOffsetT) (*VectorInfo, error) { + x := &VectorInfo{} + return x, InitVectorInfoRoot(x, buf, offset+flatbuffers.SizeUint32) +} + +func (rcv *VectorInfo) Init(buf []byte, i flatbuffers.UOffsetT) error { + rcv._tab.Bytes = buf + rcv._tab.Pos = i + if VectorInfoNumFields < rcv.Table().NumFields() { + return flatbuffers.ErrTableHasUnknownFields + } + return nil +} + +func (rcv *VectorInfo) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *VectorInfo) DistanceType() DistanceType { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return DistanceType(rcv._tab.GetByte(o + rcv._tab.Pos)) + } + return 0 +} + +func (rcv *VectorInfo) MutateDistanceType(n DistanceType) bool { + return rcv._tab.MutateByteSlot(4, byte(n)) +} + +const VectorInfoNumFields = 1 + +func VectorInfoStart(builder *flatbuffers.Builder) { + builder.StartObject(VectorInfoNumFields) +} +func VectorInfoAddDistanceType(builder *flatbuffers.Builder, distanceType DistanceType) { + builder.PrependByteSlot(0, byte(distanceType), 0) +} +func VectorInfoEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} + type CheckConstraint struct { _tab flatbuffers.Table } diff --git a/go/gen/fb/serial/vectorindexnode.go b/go/gen/fb/serial/vectorindexnode.go new file mode 100644 index 00000000000..9928b29b4d6 --- /dev/null +++ b/go/gen/fb/serial/vectorindexnode.go @@ -0,0 +1,346 @@ +// Copyright 2022-2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by the FlatBuffers compiler. DO NOT EDIT. 
+ +package serial + +import ( + flatbuffers "github.com/dolthub/flatbuffers/v23/go" +) + +type VectorIndexNode struct { + _tab flatbuffers.Table +} + +func InitVectorIndexNodeRoot(o *VectorIndexNode, buf []byte, offset flatbuffers.UOffsetT) error { + n := flatbuffers.GetUOffsetT(buf[offset:]) + return o.Init(buf, n+offset) +} + +func TryGetRootAsVectorIndexNode(buf []byte, offset flatbuffers.UOffsetT) (*VectorIndexNode, error) { + x := &VectorIndexNode{} + return x, InitVectorIndexNodeRoot(x, buf, offset) +} + +func TryGetSizePrefixedRootAsVectorIndexNode(buf []byte, offset flatbuffers.UOffsetT) (*VectorIndexNode, error) { + x := &VectorIndexNode{} + return x, InitVectorIndexNodeRoot(x, buf, offset+flatbuffers.SizeUint32) +} + +func (rcv *VectorIndexNode) Init(buf []byte, i flatbuffers.UOffsetT) error { + rcv._tab.Bytes = buf + rcv._tab.Pos = i + if VectorIndexNodeNumFields < rcv.Table().NumFields() { + return flatbuffers.ErrTableHasUnknownFields + } + return nil +} + +func (rcv *VectorIndexNode) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *VectorIndexNode) KeyItems(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *VectorIndexNode) KeyItemsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *VectorIndexNode) KeyItemsBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *VectorIndexNode) MutateKeyItems(j int, n byte) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateByte(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +func (rcv *VectorIndexNode) KeyOffsets(j int) uint32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetUint32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv *VectorIndexNode) KeyOffsetsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *VectorIndexNode) MutateKeyOffsets(j int, n uint32) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateUint32(a+flatbuffers.UOffsetT(j*4), n) + } + return false +} + +func (rcv *VectorIndexNode) ValueItems(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *VectorIndexNode) ValueItemsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *VectorIndexNode) ValueItemsBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *VectorIndexNode) MutateValueItems(j int, n byte) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateByte(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +func (rcv *VectorIndexNode) ValueOffsets(j int) uint32 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetUint32(a + flatbuffers.UOffsetT(j*4)) + } + return 0 +} + +func (rcv 
*VectorIndexNode) ValueOffsetsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *VectorIndexNode) MutateValueOffsets(j int, n uint32) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateUint32(a+flatbuffers.UOffsetT(j*4), n) + } + return false +} + +func (rcv *VectorIndexNode) AddressArray(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *VectorIndexNode) AddressArrayLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *VectorIndexNode) AddressArrayBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *VectorIndexNode) MutateAddressArray(j int, n byte) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateByte(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +func (rcv *VectorIndexNode) SubtreeCounts(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *VectorIndexNode) SubtreeCountsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *VectorIndexNode) SubtreeCountsBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *VectorIndexNode) MutateSubtreeCounts(j int, n byte) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateByte(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +func (rcv *VectorIndexNode) TreeCount() uint64 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(16)) + if o != 0 { + return rcv._tab.GetUint64(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *VectorIndexNode) MutateTreeCount(n uint64) bool { + return rcv._tab.MutateUint64Slot(16, n) +} + +func (rcv *VectorIndexNode) TreeLevel() byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(18)) + if o != 0 { + return rcv._tab.GetByte(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *VectorIndexNode) MutateTreeLevel(n byte) bool { + return rcv._tab.MutateByteSlot(18, n) +} + +func (rcv *VectorIndexNode) LogChunkSize() byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(20)) + if o != 0 { + return rcv._tab.GetByte(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *VectorIndexNode) MutateLogChunkSize(n byte) bool { + return rcv._tab.MutateByteSlot(20, n) +} + +func (rcv *VectorIndexNode) DistanceType() DistanceType { + o := flatbuffers.UOffsetT(rcv._tab.Offset(22)) + if o != 0 { + return DistanceType(rcv._tab.GetByte(o + rcv._tab.Pos)) + } + return 0 +} + +func (rcv *VectorIndexNode) MutateDistanceType(n DistanceType) bool { + return rcv._tab.MutateByteSlot(22, byte(n)) +} + +const VectorIndexNodeNumFields = 10 + +func VectorIndexNodeStart(builder *flatbuffers.Builder) { + builder.StartObject(VectorIndexNodeNumFields) +} +func VectorIndexNodeAddKeyItems(builder *flatbuffers.Builder, keyItems flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(keyItems), 0) +} 
+func VectorIndexNodeStartKeyItemsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func VectorIndexNodeAddKeyOffsets(builder *flatbuffers.Builder, keyOffsets flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(keyOffsets), 0) +} +func VectorIndexNodeStartKeyOffsetsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func VectorIndexNodeAddValueItems(builder *flatbuffers.Builder, valueItems flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(valueItems), 0) +} +func VectorIndexNodeStartValueItemsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func VectorIndexNodeAddValueOffsets(builder *flatbuffers.Builder, valueOffsets flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(valueOffsets), 0) +} +func VectorIndexNodeStartValueOffsetsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(4, numElems, 4) +} +func VectorIndexNodeAddAddressArray(builder *flatbuffers.Builder, addressArray flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(addressArray), 0) +} +func VectorIndexNodeStartAddressArrayVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func VectorIndexNodeAddSubtreeCounts(builder *flatbuffers.Builder, subtreeCounts flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(5, flatbuffers.UOffsetT(subtreeCounts), 0) +} +func VectorIndexNodeStartSubtreeCountsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func VectorIndexNodeAddTreeCount(builder *flatbuffers.Builder, treeCount uint64) { + builder.PrependUint64Slot(6, treeCount, 0) +} +func VectorIndexNodeAddTreeLevel(builder *flatbuffers.Builder, treeLevel byte) { + builder.PrependByteSlot(7, treeLevel, 0) +} +func VectorIndexNodeAddLogChunkSize(builder *flatbuffers.Builder, logChunkSize byte) { + builder.PrependByteSlot(8, logChunkSize, 0) +} +func VectorIndexNodeAddDistanceType(builder *flatbuffers.Builder, distanceType DistanceType) { + builder.PrependByteSlot(9, byte(distanceType), 0) +} +func VectorIndexNodeEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/go/go.mod b/go/go.mod index d5ff754e1b7..3a836fd08f8 100644 --- a/go/go.mod +++ b/go/go.mod @@ -56,9 +56,10 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 github.com/creasty/defaults v1.6.0 github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - github.com/dolthub/go-mysql-server v0.19.1-0.20250115230045-115c98b242ba + github.com/dolthub/go-mysql-server v0.19.1-0.20250116005911-204fe88cb899 github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 github.com/dolthub/swiss v0.1.0 + github.com/esote/minmaxheap v1.0.0 github.com/goccy/go-json v0.10.2 github.com/google/btree v1.1.2 github.com/google/go-github/v57 v57.0.0 diff --git a/go/go.sum b/go/go.sum index d9e6e674ad3..db0db490ddd 100644 --- a/go/go.sum +++ b/go/go.sum @@ -179,8 +179,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90 h1:Sni8jrP0sy/w9ZYXoff4g/ixe+7bFCZlfCqXKJSU+zM= 
github.com/dolthub/go-icu-regex v0.0.0-20241215010122-db690dd53c90/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= -github.com/dolthub/go-mysql-server v0.19.1-0.20250115230045-115c98b242ba h1:+eSlrGeIq8Xh36nCtt9auq0DJIs+Rle6UJBjaCmijnU= -github.com/dolthub/go-mysql-server v0.19.1-0.20250115230045-115c98b242ba/go.mod h1:5HtKnb+IAiv+27bo50KGANbUB4HAzGEF9rlFF2ZBLZg= +github.com/dolthub/go-mysql-server v0.19.1-0.20250116005911-204fe88cb899 h1:cjntkeERfjYl7Zi+RGWDwhiGk+SmS6AUXBMYkIfrLHc= +github.com/dolthub/go-mysql-server v0.19.1-0.20250116005911-204fe88cb899/go.mod h1:5HtKnb+IAiv+27bo50KGANbUB4HAzGEF9rlFF2ZBLZg= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= @@ -209,6 +209,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/esote/minmaxheap v1.0.0 h1:rgA7StnXXpZG6qlM0S7pUmEv1KpWe32rYT4x8J8ntaA= +github.com/esote/minmaxheap v1.0.0/go.mod h1:Ln8+i7fS1k3PLgZI2JAo0iA1as95QnIYiGCrqSJ5FZk= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= diff --git a/go/go.work.sum b/go/go.work.sum index 7203a8e7f6f..71f195420ad 100644 --- a/go/go.work.sum +++ b/go/go.work.sum @@ -387,6 +387,7 @@ github.com/dolthub/go-mysql-server v0.18.2-0.20240918194055-f75a63f64679 h1:O6eS github.com/dolthub/go-mysql-server v0.18.2-0.20240918194055-f75a63f64679/go.mod h1:m88EMm9OthVVa6qIhbpnRDpj/eYUXuNpvY/+0YWKVwc= github.com/dolthub/vitess v0.0.0-20241104125316-860772ba6683 h1:2/RJeUfNAXS7mbBnEr9C36htiCJHk5XldDPzhxtEsME= github.com/dolthub/vitess v0.0.0-20241104125316-860772ba6683/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= +github.com/dolthub/vitess v0.0.0-20241231200706-18992bb25fdc/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dylanmei/iso8601 v0.1.0/go.mod h1:w9KhXSgIyROl1DefbMYIE7UVSIvELTbMrCfx+QkYnoQ= github.com/dylanmei/winrmtest v0.0.0-20210303004826-fbc9ae56efb6/go.mod h1:6BLLhzn1VEiJ4veuAGhINBTrBlV889Wd+aU4auxKOww= github.com/eapache/go-resiliency v1.1.0 h1:1NtRmCAqadE2FN4ZcN6g90TP3uk8cg9rn9eNK2197aU= @@ -403,6 +404,8 @@ github.com/envoyproxy/protoc-gen-validate v0.10.1 h1:c0g45+xCJhdgFGw7a5QAfdS4byA github.com/envoyproxy/protoc-gen-validate v0.10.1/go.mod h1:DRjgyB0I43LtJapqN6NiRwroiAU2PaFuvk/vjgh61ss= github.com/envoyproxy/protoc-gen-validate v1.0.2 h1:QkIBuU5k+x7/QXPvPPnWXWlCdaBFApVqftFV6k087DA= github.com/envoyproxy/protoc-gen-validate v1.0.2/go.mod h1:GpiZQP3dDbg4JouG/NNS7QWXpgx6x8QiMKdmN72jogE= +github.com/esote/minmaxheap v1.0.0 h1:rgA7StnXXpZG6qlM0S7pUmEv1KpWe32rYT4x8J8ntaA= +github.com/esote/minmaxheap v1.0.0/go.mod h1:Ln8+i7fS1k3PLgZI2JAo0iA1as95QnIYiGCrqSJ5FZk= github.com/fogleman/gg v1.3.0 h1:/7zJX8F6AaYQc57WQCyN9cAIz+4bCJGO9B+dyW29am8= github.com/form3tech-oss/jwt-go 
v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db h1:gb2Z18BhTPJPpLQWj4T+rfKHYCHxRHCtRxhKKjRidVw= @@ -482,7 +485,6 @@ github.com/hashicorp/go-slug v0.15.0/go.mod h1:THWVTAXwJEinbsp4/bBRcmbaO5EYNLTqx github.com/hashicorp/go-sockaddr v1.0.0 h1:GeH6tui99pF4NJgfnhp+L6+FfobzVW3Ah46sLo0ICXs= github.com/hashicorp/go-syslog v1.0.0 h1:KaodqZuhUoZereWVIYmpUgZysurB1kBLX2j0MwMrUAE= github.com/hashicorp/go-tfe v1.58.0/go.mod h1:XnTtBj3tVQ4uFkcFsv8Grn+O1CVcIcceL1uc2AgUcaU= -github.com/hashicorp/go-uuid v1.0.1 h1:fv1ep09latC32wFoVwnqcnKJGnMSdBanPczbHAYm1BE= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-version v1.2.0 h1:3vNe/fWF5CBgRIguda1meWhsZHy3m8gCJ5wx+dIzX/E= github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= @@ -692,6 +694,7 @@ golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= golang.org/x/oauth2 v0.11.0/go.mod h1:LdF7O/8bLR/qWK9DrpXmbHLTouvRHK0SgJl0GmDBchk= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/go/libraries/doltcore/doltdb/durable/index.go b/go/libraries/doltcore/doltdb/durable/index.go index 9340d6608bf..726e2038154 100644 --- a/go/libraries/doltcore/doltdb/durable/index.go +++ b/go/libraries/doltcore/doltdb/durable/index.go @@ -21,6 +21,8 @@ import ( "io" "strings" + "github.com/dolthub/go-mysql-server/sql/expression/function/vector" + "github.com/dolthub/dolt/go/libraries/doltcore/schema" "github.com/dolthub/dolt/go/store/hash" "github.com/dolthub/dolt/go/store/prolly" @@ -112,7 +114,7 @@ func indexFromAddr(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeS return IndexFromNomsMap(v.(types.Map), vrw, ns), nil case types.Format_DOLT: - m, err := shim.MapInterfaceFromValue(v, sch, ns, isKeylessTable) + m, err := shim.MapInterfaceFromValue(ctx, v, sch, ns, isKeylessTable) if err != nil { return nil, err } @@ -125,23 +127,23 @@ func indexFromAddr(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeS // NewEmptyPrimaryIndex creates a new empty Index for use as the primary index in a table. func NewEmptyPrimaryIndex(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, indexSchema schema.Schema) (Index, error) { - return newEmptyIndex(ctx, vrw, ns, indexSchema, false) + return newEmptyIndex(ctx, vrw, ns, indexSchema, false, false) } // NewEmptyForeignKeyIndex creates a new empty Index for use as a foreign key index. // Foreign keys cannot appear on keyless tables. func NewEmptyForeignKeyIndex(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, indexSchema schema.Schema) (Index, error) { - return newEmptyIndex(ctx, vrw, ns, indexSchema, false) + return newEmptyIndex(ctx, vrw, ns, indexSchema, false, false) } // NewEmptyIndexFromTableSchema creates a new empty Index described by a schema.Index. 
func NewEmptyIndexFromTableSchema(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, idx schema.Index, tableSchema schema.Schema) (Index, error) { indexSchema := idx.Schema() - return newEmptyIndex(ctx, vrw, ns, indexSchema, schema.IsKeyless(tableSchema)) + return newEmptyIndex(ctx, vrw, ns, indexSchema, idx.IsVector(), schema.IsKeyless(tableSchema)) } // newEmptyIndex returns an index with no rows. -func newEmptyIndex(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, sch schema.Schema, isKeylessSecondary bool) (Index, error) { +func newEmptyIndex(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, sch schema.Schema, isVector bool, isKeylessSecondary bool) (Index, error) { switch vrw.Format() { case types.Format_LD_1: m, err := types.NewMap(ctx, vrw) @@ -155,7 +157,11 @@ func newEmptyIndex(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeS if isKeylessSecondary { kd = prolly.AddHashToSchema(kd) } - return NewEmptyProllyIndex(ctx, ns, kd, vd) + if isVector { + return NewEmptyProximityIndex(ctx, ns, kd, vd) + } else { + return NewEmptyProllyIndex(ctx, ns, kd, vd) + } default: return nil, errNbfUnknown @@ -170,6 +176,18 @@ func NewEmptyProllyIndex(ctx context.Context, ns tree.NodeStore, kd, vd val.Tupl return IndexFromProllyMap(m), nil } +func NewEmptyProximityIndex(ctx context.Context, ns tree.NodeStore, kd, vd val.TupleDesc) (Index, error) { + proximityMapBuilder, err := prolly.NewProximityMapBuilder(ctx, ns, vector.DistanceL2Squared{}, kd, vd, prolly.DefaultLogChunkSize) + if err != nil { + return nil, err + } + m, err := proximityMapBuilder.Flush(ctx) + if err != nil { + return nil, err + } + return IndexFromProximityMap(m), nil +} + type nomsIndex struct { index types.Map vrw types.ValueReadWriter @@ -264,6 +282,8 @@ func MapFromIndex(i Index) prolly.MapInterfaceWithMutable { switch indexType := i.(type) { case prollyIndex: return indexType.index + case proximityIndex: + return indexType.index } return i.(prollyIndex).index } @@ -278,6 +298,8 @@ func IndexFromMapInterface(m prolly.MapInterface) Index { switch m := m.(type) { case prolly.Map: return IndexFromProllyMap(m) + case prolly.ProximityMap: + return IndexFromProximityMap(m) default: panic("unknown map type") } diff --git a/go/libraries/doltcore/doltdb/durable/proximity_index.go b/go/libraries/doltcore/doltdb/durable/proximity_index.go new file mode 100644 index 00000000000..465106b081c --- /dev/null +++ b/go/libraries/doltcore/doltdb/durable/proximity_index.go @@ -0,0 +1,113 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package durable + +import ( + "bytes" + "context" + "strings" + + "github.com/dolthub/dolt/go/libraries/doltcore/schema" + "github.com/dolthub/dolt/go/store/hash" + "github.com/dolthub/dolt/go/store/prolly" + "github.com/dolthub/dolt/go/store/prolly/shim" + "github.com/dolthub/dolt/go/store/prolly/tree" + "github.com/dolthub/dolt/go/store/types" +) + +type proximityIndex struct { + index prolly.ProximityMap +} + +var _ Index = proximityIndex{} + +// ProximityMapFromIndex unwraps the Index and returns the underlying prolly.ProximityMap. +func ProximityMapFromIndex(i Index) prolly.ProximityMap { + return i.(proximityIndex).index +} + +// IndexFromProximityMap wraps a prolly.ProximityMap and returns it as an Index. +func IndexFromProximityMap(m prolly.ProximityMap) Index { + return proximityIndex{index: m} +} + +// HashOf implements Index. +func (i proximityIndex) HashOf() (hash.Hash, error) { + return i.index.HashOf(), nil +} + +// Count implements Index. +func (i proximityIndex) Count() (uint64, error) { + c, err := i.index.Count() + return uint64(c), err +} + +// Empty implements Index. +func (i proximityIndex) Empty() (bool, error) { + c, err := i.index.Count() + if err != nil { + return false, err + } + return c == 0, nil +} + +// Format implements Index. +func (i proximityIndex) Format() *types.NomsBinFormat { + return types.Format_DOLT +} + +// bytes implements Index. +func (i proximityIndex) bytes() ([]byte, error) { + return shim.ValueFromMap(i.index).(types.SerialMessage), nil +} + +var _ Index = proximityIndex{} + +func (i proximityIndex) AddColumnToRows(ctx context.Context, newCol string, newSchema schema.Schema) (Index, error) { + var last bool + colIdx, iCol := 0, 0 + err := newSchema.GetNonPKCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { + last = false + if strings.EqualFold(col.Name, newCol) { + last = true + colIdx = iCol + } + iCol++ + return false, nil + }) + if err != nil { + return nil, err + } + + _ = colIdx + // If the column we added was last among non-primary key columns we can skip this step + if last { + return i, nil + } + + // If not, then we have to iterate over this table's rows and update all the offsets for the new column + rowMap := ProximityMapFromIndex(i) + // TODO: Allow for mutation of ProximityMaps + + return IndexFromProximityMap(rowMap), nil +} + +func (i proximityIndex) DebugString(ctx context.Context, ns tree.NodeStore, schema schema.Schema) string { + var b bytes.Buffer + i.index.WalkNodes(ctx, func(ctx context.Context, nd tree.Node) error { + return tree.OutputProllyNode(ctx, &b, nd, ns, schema) + }) + return b.String() +} diff --git a/go/libraries/doltcore/doltdb/durable/table.go b/go/libraries/doltcore/doltdb/durable/table.go index 9f91f6dea60..eba4b079476 100644 --- a/go/libraries/doltcore/doltdb/durable/table.go +++ b/go/libraries/doltcore/doltdb/durable/table.go @@ -850,7 +850,7 @@ func (t doltDevTable) GetTableRows(ctx context.Context) (Index, error) { if err != nil { return nil, err } - m, err := shim.MapInterfaceFromValue(types.SerialMessage(rowbytes), sch, t.ns, false) + m, err := shim.MapInterfaceFromValue(ctx, types.SerialMessage(rowbytes), sch, t.ns, false) if err != nil { return nil, err } diff --git a/go/libraries/doltcore/schema/encoding/schema_marshaling.go b/go/libraries/doltcore/schema/encoding/schema_marshaling.go index c5c9aee5686..4cd078f48a2 100644 --- a/go/libraries/doltcore/schema/encoding/schema_marshaling.go +++ b/go/libraries/doltcore/schema/encoding/schema_marshaling.go @@ -332,6 +332,7 @@ 
func (sd encodedSchemaData) decodeSchema() (schema.Schema, error) { IsUnique: encodedIndex.Unique, IsSpatial: encodedIndex.Spatial, IsFullText: encodedIndex.FullText, + IsVector: false, // noms encoding does not support vector indexes IsUserDefined: !encodedIndex.IsSystemDefined, Comment: encodedIndex.Comment, FullTextProperties: schema.FullTextProperties{ diff --git a/go/libraries/doltcore/schema/encoding/serialization.go b/go/libraries/doltcore/schema/encoding/serialization.go index 3eb53784b85..c62ba28155b 100644 --- a/go/libraries/doltcore/schema/encoding/serialization.go +++ b/go/libraries/doltcore/schema/encoding/serialization.go @@ -20,6 +20,7 @@ import ( fb "github.com/dolthub/flatbuffers/v23/go" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression/function/vector" "github.com/dolthub/go-mysql-server/sql/planbuilder" sqltypes "github.com/dolthub/go-mysql-server/sql/types" @@ -415,6 +416,11 @@ func serializeSecondaryIndexes(b *fb.Builder, sch schema.Schema, indexes []schem ftInfo = serializeFullTextInfo(b, idx) } + var vectorInfo fb.UOffsetT + if idx.IsVector() { + vectorInfo = serializeVectorInfo(b, idx) + } + serial.IndexStart(b) serial.IndexAddName(b, no) serial.IndexAddComment(b, co) @@ -429,6 +435,10 @@ func serializeSecondaryIndexes(b *fb.Builder, sch schema.Schema, indexes []schem if idx.IsFullText() { serial.IndexAddFulltextInfo(b, ftInfo) } + if idx.IsVector() { + serial.IndexAddVectorKey(b, true) + serial.IndexAddVectorInfo(b, vectorInfo) + } offs[i] = serial.IndexEnd(b) } @@ -454,14 +464,21 @@ func deserializeSecondaryIndexes(sch schema.Schema, s *serial.TableSchema) error return err } + vi, err := deserializeVectorInfo(&idx) + if err != nil { + return err + } + name := string(idx.Name()) props := schema.IndexProperties{ IsUnique: idx.UniqueKey(), IsSpatial: idx.SpatialKey(), IsFullText: idx.FulltextKey(), + IsVector: idx.VectorKey(), IsUserDefined: !idx.SystemDefined(), Comment: string(idx.Comment()), FullTextProperties: fti, + VectorProperties: vi, } tags := make([]uint64, idx.IndexColumnsLength()) @@ -555,6 +572,19 @@ func serializeFullTextInfo(b *fb.Builder, idx schema.Index) fb.UOffsetT { return serial.FulltextInfoEnd(b) } +func serializeVectorInfo(b *fb.Builder, idx schema.Index) fb.UOffsetT { + props := idx.VectorProperties() + + serial.VectorInfoStart(b) + + switch props.DistanceType { + case vector.DistanceL2Squared{}: + serial.VectorInfoAddDistanceType(b, serial.DistanceTypeL2_Squared) + } + + return serial.VectorInfoEnd(b) +} + func deserializeFullTextInfo(idx *serial.Index) (schema.FullTextProperties, error) { fulltext := serial.FulltextInfo{} has, err := idx.TryFulltextInfo(&fulltext) @@ -586,6 +616,25 @@ func deserializeFullTextInfo(idx *serial.Index) (schema.FullTextProperties, erro }, nil } +func deserializeVectorInfo(idx *serial.Index) (schema.VectorProperties, error) { + vectorInfo := serial.VectorInfo{} + has, err := idx.TryVectorInfo(&vectorInfo) + if err != nil { + return schema.VectorProperties{}, err + } + if has == nil { + return schema.VectorProperties{}, nil + } + + switch vectorInfo.DistanceType() { + case serial.DistanceTypeL2_Squared: + return schema.VectorProperties{ + DistanceType: vector.DistanceL2Squared{}, + }, nil + } + return schema.VectorProperties{}, fmt.Errorf("unknown distance type in vector index info: %s", vectorInfo.DistanceType()) +} + func keylessSerialSchema(s *serial.TableSchema) (bool, error) { n := s.ColumnsLength() if n < 2 { diff --git a/go/libraries/doltcore/schema/index.go 
b/go/libraries/doltcore/schema/index.go index 13b8a02876e..942b5d753d4 100644 --- a/go/libraries/doltcore/schema/index.go +++ b/go/libraries/doltcore/schema/index.go @@ -47,6 +47,8 @@ type Index interface { IsSpatial() bool // IsFullText returns whether the given index has the FULLTEXT constraint. IsFullText() bool + // IsVector returns whether the given index has the VECTOR constraint. + IsVector() bool // IsUserDefined returns whether the given index was created by a user or automatically generated. IsUserDefined() bool // Name returns the name of the index. @@ -62,22 +64,26 @@ type Index interface { PrefixLengths() []uint16 // FullTextProperties returns all properties belonging to a Full-Text index. FullTextProperties() FullTextProperties + // VectorProperties returns all properties belonging to a vector index. + VectorProperties() VectorProperties } var _ Index = (*indexImpl)(nil) type indexImpl struct { - name string - tags []uint64 - allTags []uint64 - indexColl *indexCollectionImpl - isUnique bool - isSpatial bool - isFullText bool - isUserDefined bool - comment string - prefixLengths []uint16 - fullTextProps FullTextProperties + name string + tags []uint64 + allTags []uint64 + indexColl *indexCollectionImpl + isUnique bool + isSpatial bool + isFullText bool + isVector bool + isUserDefined bool + comment string + prefixLengths []uint16 + fullTextProps FullTextProperties + vectorProperties VectorProperties } func NewIndex(name string, tags, allTags []uint64, indexColl IndexCollection, props IndexProperties) Index { @@ -87,16 +93,18 @@ func NewIndex(name string, tags, allTags []uint64, indexColl IndexCollection, pr } return &indexImpl{ - name: name, - tags: tags, - allTags: allTags, - indexColl: indexCollImpl, - isUnique: props.IsUnique, - isSpatial: props.IsSpatial, - isFullText: props.IsFullText, - isUserDefined: props.IsUserDefined, - comment: props.Comment, - fullTextProps: props.FullTextProperties, + name: name, + tags: tags, + allTags: allTags, + indexColl: indexCollImpl, + isUnique: props.IsUnique, + isSpatial: props.IsSpatial, + isFullText: props.IsFullText, + isVector: props.IsVector, + isUserDefined: props.IsUserDefined, + comment: props.Comment, + fullTextProps: props.FullTextProperties, + vectorProperties: props.VectorProperties, } } @@ -209,6 +217,11 @@ func (ix *indexImpl) IsFullText() bool { return ix.isFullText } +// IsVector implements Index. +func (ix *indexImpl) IsVector() bool { + return ix.isVector +} + // IsUserDefined implements Index. func (ix *indexImpl) IsUserDefined() bool { return ix.isUserDefined @@ -309,6 +322,11 @@ func (ix *indexImpl) FullTextProperties() FullTextProperties { return ix.fullTextProps } +// VectorProperties implements Index. +func (ix *indexImpl) VectorProperties() VectorProperties { + return ix.vectorProperties +} + // copy returns an exact copy of the calling index. 
func (ix *indexImpl) copy() *indexImpl { newIx := *ix diff --git a/go/libraries/doltcore/schema/index_coll.go b/go/libraries/doltcore/schema/index_coll.go index b27b1d25653..c61844c7f99 100644 --- a/go/libraries/doltcore/schema/index_coll.go +++ b/go/libraries/doltcore/schema/index_coll.go @@ -18,6 +18,8 @@ import ( "fmt" "sort" "strings" + + "github.com/dolthub/go-mysql-server/sql/expression/function/vector" ) type IndexCollection interface { @@ -83,6 +85,8 @@ type IndexProperties struct { IsUserDefined bool Comment string FullTextProperties + IsVector bool + VectorProperties } type FullTextProperties struct { @@ -96,6 +100,10 @@ type FullTextProperties struct { KeyPositions []uint16 } +type VectorProperties struct { + DistanceType vector.DistanceType +} + type indexCollectionImpl struct { colColl *ColCollection indexes map[string]*indexImpl @@ -210,17 +218,19 @@ func (ixc *indexCollectionImpl) AddIndexByColTags(indexName string, tags []uint6 } index := &indexImpl{ - indexColl: ixc, - name: indexName, - tags: tags, - allTags: combineAllTags(tags, ixc.pks), - isUnique: props.IsUnique, - isSpatial: props.IsSpatial, - isFullText: props.IsFullText, - isUserDefined: props.IsUserDefined, - comment: props.Comment, - prefixLengths: prefixLengths, - fullTextProps: props.FullTextProperties, + indexColl: ixc, + name: indexName, + tags: tags, + allTags: combineAllTags(tags, ixc.pks), + isUnique: props.IsUnique, + isSpatial: props.IsSpatial, + isFullText: props.IsFullText, + isVector: props.IsVector, + isUserDefined: props.IsUserDefined, + comment: props.Comment, + prefixLengths: prefixLengths, + fullTextProps: props.FullTextProperties, + vectorProperties: props.VectorProperties, } ixc.indexes[lowerName] = index for _, tag := range tags { @@ -243,6 +253,7 @@ func (ixc *indexCollectionImpl) UnsafeAddIndexByColTags(indexName string, tags [ isUnique: props.IsUnique, isSpatial: props.IsSpatial, isFullText: props.IsFullText, + isVector: props.IsVector, isUserDefined: props.IsUserDefined, comment: props.Comment, prefixLengths: prefixLengths, @@ -410,6 +421,7 @@ func (ixc *indexCollectionImpl) Merge(indexes ...Index) { isUnique: index.IsUnique(), isSpatial: index.IsSpatial(), isFullText: index.IsFullText(), + isVector: index.IsVector(), isUserDefined: index.IsUserDefined(), comment: index.Comment(), prefixLengths: index.PrefixLengths(), diff --git a/go/libraries/doltcore/sqle/alterschema.go b/go/libraries/doltcore/sqle/alterschema.go index 7defa7011e4..e11d5d63484 100755 --- a/go/libraries/doltcore/sqle/alterschema.go +++ b/go/libraries/doltcore/sqle/alterschema.go @@ -268,9 +268,11 @@ func replaceColumnInSchema(sch schema.Schema, oldCol schema.Column, newCol schem IsUnique: index.IsUnique(), IsSpatial: index.IsSpatial(), IsFullText: index.IsFullText(), + IsVector: index.IsVector(), IsUserDefined: index.IsUserDefined(), Comment: index.Comment(), FullTextProperties: index.FullTextProperties(), + VectorProperties: index.VectorProperties(), }) if err != nil { return nil, err diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go index 20043743b26..d8410141b4b 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go @@ -852,6 +852,18 @@ func TestIndexes(t *testing.T) { enginetest.TestIndexes(t, harness) } +func TestVectorIndexes(t *testing.T) { + harness := newDoltHarness(t) + defer harness.Close() + enginetest.TestVectorIndexes(t, harness) +} + +func 
TestVectorFunctions(t *testing.T) { + harness := newDoltHarness(t) + defer harness.Close() + enginetest.TestVectorFunctions(t, harness) +} + func TestIndexPrefix(t *testing.T) { skipOldFormat(t) harness := newDoltHarness(t) diff --git a/go/libraries/doltcore/sqle/index/dolt_index.go b/go/libraries/doltcore/sqle/index/dolt_index.go index 8362facb818..11f2c8b1d86 100644 --- a/go/libraries/doltcore/sqle/index/dolt_index.go +++ b/go/libraries/doltcore/sqle/index/dolt_index.go @@ -23,6 +23,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/expression/function/vector" "github.com/dolthub/go-mysql-server/sql/fulltext" sqltypes "github.com/dolthub/go-mysql-server/sql/types" @@ -394,6 +395,7 @@ func getSecondaryIndex(ctx context.Context, db, tbl string, t *doltdb.Table, sch unique: idx.IsUnique(), spatial: idx.IsSpatial(), fulltext: idx.IsFullText(), + vector: idx.IsVector(), isPk: false, comment: idx.Comment(), vrw: vrw, @@ -404,6 +406,7 @@ func getSecondaryIndex(ctx context.Context, db, tbl string, t *doltdb.Table, sch doltBinFormat: types.IsFormat_DOLT(vrw.Format()), prefixLengths: idx.PrefixLengths(), fullTextProps: idx.FullTextProperties(), + vectorProps: idx.VectorProperties(), }, nil } @@ -425,6 +428,7 @@ func ConvertFullTextToSql(ctx context.Context, db, tbl string, sch schema.Schema unique: idx.IsUnique(), spatial: idx.IsSpatial(), fulltext: idx.IsFullText(), + vector: idx.IsVector(), isPk: false, comment: idx.Comment(), vrw: nil, @@ -435,6 +439,7 @@ func ConvertFullTextToSql(ctx context.Context, db, tbl string, sch schema.Schema doltBinFormat: true, prefixLengths: idx.PrefixLengths(), fullTextProps: idx.FullTextProperties(), + vectorProps: idx.VectorProperties(), }, nil } @@ -556,6 +561,7 @@ type doltIndex struct { unique bool spatial bool fulltext bool + vector bool isPk bool comment string order sql.IndexOrder @@ -571,6 +577,7 @@ type doltIndex struct { prefixLengths []uint16 fullTextProps schema.FullTextProperties + vectorProps schema.VectorProperties } type LookupMeta struct { @@ -619,8 +626,12 @@ func (di *doltIndex) CanSupport(...sql.Range) bool { } // CanSupportOrderBy implements the interface sql.Index. -func (di *doltIndex) CanSupportOrderBy(_ sql.Expression) bool { - return false +func (di *doltIndex) CanSupportOrderBy(expr sql.Expression) bool { + distance, ok := expr.(*vector.Distance) + if !ok { + return false + } + return di.vector && di.vectorProps.DistanceType.CanEval(distance.DistanceType) } // ColumnExpressionTypes implements the interface sql.Index. @@ -994,7 +1005,7 @@ func (di *doltIndex) IsFullText() bool { // IsVector implements sql.Index func (di *doltIndex) IsVector() bool { - return false + return di.vector } // IsPrimaryKey implements DoltIndex. diff --git a/go/libraries/doltcore/sqle/index/index_reader.go b/go/libraries/doltcore/sqle/index/index_reader.go index 45da12ba0b5..1ebfd09595e 100644 --- a/go/libraries/doltcore/sqle/index/index_reader.go +++ b/go/libraries/doltcore/sqle/index/index_reader.go @@ -146,7 +146,7 @@ func newPointPartitionIter(ctx *sql.Context, lookup sql.IndexLookup, idx *doltIn } var _ sql.PartitionIter = (*pointPartition)(nil) -var _ sql.Partition = (*pointPartition)(nil) +var _ sql.Partition = pointPartition{} type pointPartition struct { r prolly.Range @@ -246,6 +246,42 @@ func GetDurableIndex(ctx *sql.Context, return s.Secondary, nil } +// vectorPartitionIter is the sql.PartitionIter for vector indexes. 
+// Because it only ever has one partition, it also implements sql.Partition
+// and returns itself in calls to Next.
+type vectorPartitionIter struct {
+	Column sql.Expression
+	sql.OrderAndLimit
+	used bool
+}
+
+var _ sql.PartitionIter = (*vectorPartitionIter)(nil)
+var _ sql.Partition = vectorPartitionIter{}
+
+// Key returns the key used to distinguish partitions. Since vectorPartitionIter only ever
+// has one partition, this value is unused.
+func (v vectorPartitionIter) Key() []byte {
+	return nil
+}
+
+func (v *vectorPartitionIter) Close(_ *sql.Context) error {
+	return nil
+}
+
+func (v *vectorPartitionIter) Next(_ *sql.Context) (sql.Partition, error) {
+	if v.used {
+		return nil, io.EOF
+	}
+	v.used = true
+	return *v, nil
+}
+
+func NewVectorPartitionIter(lookup sql.IndexLookup) (sql.PartitionIter, error) {
+	return &vectorPartitionIter{
+		OrderAndLimit: lookup.VectorOrderAndLimit,
+	}, nil
+}
+
 // IndexScanBuilder generates secondary lookups for partitions and
 // encapsulates fast path optimizations for certain point lookups.
 type IndexScanBuilder interface {
@@ -296,10 +332,19 @@ func NewIndexReaderBuilder(
 	}
 
 	if isDoltFormat {
-		base.sec = durable.ProllyMapFromIndex(s.Secondary)
-		base.secKd, base.secVd = base.sec.Descriptors()
-		base.ns = base.sec.NodeStore()
+		secondaryIndex := durable.MapFromIndex(s.Secondary)
+		base.ns = secondaryIndex.NodeStore()
+		base.secKd, base.secVd = secondaryIndex.Descriptors()
 		base.prefDesc = base.secKd.PrefixDesc(len(di.columns))
+		switch si := secondaryIndex.(type) {
+		case prolly.Map:
+			base.sec = si
+		case prolly.ProximityMap:
+			base.isProximity = true
+			base.proximitySecondary = si
+		default:
+			return nil, fmt.Errorf("unknown index type %v", secondaryIndex)
+		}
 	}
 
 	switch {
@@ -382,10 +427,12 @@ type baseIndexImplBuilder struct {
 	sch         sql.PrimaryKeySchema
 	projections []uint64
 
-	sec          prolly.Map
-	secKd, secVd val.TupleDesc
-	prefDesc     val.TupleDesc
-	ns           tree.NodeStore
+	isProximity        bool
+	sec                prolly.Map
+	proximitySecondary prolly.ProximityMap
+	secKd, secVd       val.TupleDesc
+	prefDesc           val.TupleDesc
+	ns                 tree.NodeStore
 }
 
 func (ib *baseIndexImplBuilder) Key() doltdb.DataCacheKey {
@@ -414,6 +461,10 @@ func (ib *baseIndexImplBuilder) NewSecondaryIter(strict bool, cnt int, nullSafe
 // every subsequent point lookup. Note that equality joins can have a mix of
 // point lookups on concrete values, and range lookups for null matches.
 func (ib *baseIndexImplBuilder) newPointLookup(ctx *sql.Context, rang prolly.Range) (iter prolly.MapIter, err error) {
+	if ib.isProximity {
+		// TODO: It should be possible to do a point lookup with a proximity index. 
+		return nil, fmt.Errorf("can't perform point lookup with a proximity index")
+	}
 	err = ib.sec.GetPrefix(ctx, rang.Tup, ib.prefDesc, func(key val.Tuple, value val.Tuple) (err error) {
 		if key != nil && rang.Matches(key) {
 			iter = prolly.NewPointLookup(key, value)
@@ -430,6 +481,9 @@ func (ib *baseIndexImplBuilder) rangeIter(ctx *sql.Context, part sql.Partition)
 	case pointPartition:
 		return ib.newPointLookup(ctx, p.r)
 	case rangePartition:
+		if ib.isProximity {
+			return nil, fmt.Errorf("range iter not allowed for vector index")
+		}
 		if p.isReverse {
 			return ib.sec.IterRangeReverse(ctx, p.prollyRange)
 		} else {
@@ -437,11 +491,26 @@ func (ib *baseIndexImplBuilder) rangeIter(ctx *sql.Context, part sql.Partition)
 	case DoltgresPartition:
 		return doltgresProllyMapIterator(ctx, ib.secKd, ib.ns, ib.sec.Node(), p.rang)
+	case vectorPartitionIter:
+		return nil, fmt.Errorf("range iter not allowed for vector partition")
 	default:
 		panic(fmt.Sprintf("unexpected prolly partition type: %T", part))
 	}
 }
 
+// proximityIter returns an iter over the |limit| index entries closest to the
+// candidate vector in the lookup.
+func (ib *baseIndexImplBuilder) proximityIter(ctx *sql.Context, part vectorPartitionIter) (prolly.MapIter, error) {
+	candidateVector, err := part.Literal.Eval(ctx, nil)
+	if err != nil {
+		return nil, err
+	}
+	limit, err := part.Limit.Eval(ctx, nil)
+	if err != nil {
+		return nil, err
+	}
+	return ib.proximitySecondary.GetClosest(ctx, candidateVector, int(limit.(int64)))
+}
+
 // coveringIndexImplBuilder constructs row iters for covering lookups,
 // where we only need to cursor seek on a single index to both identify
 // target keys and fill all requested projections
@@ -524,6 +593,9 @@ func (ib *coveringIndexImplBuilder) OutputSchema() schema.Schema {
 
 // NewRangeMapIter implements IndexScanBuilder
 func (ib *coveringIndexImplBuilder) NewRangeMapIter(ctx context.Context, r prolly.Range, reverse bool) (prolly.MapIter, error) {
+	if ib.isProximity {
+		return nil, fmt.Errorf("range map iter not allowed for vector index")
+	}
 	if reverse {
 		return ib.sec.IterRangeReverse(ctx, r)
 	} else {
@@ -533,13 +605,19 @@ func (ib *coveringIndexImplBuilder) NewRangeMapIter(ctx context.Context, r proll
 
 // NewPartitionRowIter implements IndexScanBuilder
 func (ib *coveringIndexImplBuilder) NewPartitionRowIter(ctx *sql.Context, part sql.Partition) (sql.RowIter, error) {
-	rangeIter, err := ib.rangeIter(ctx, part)
+	var indexIter prolly.MapIter
+	var err error
+	if proximityPartition, ok := part.(vectorPartitionIter); ok {
+		indexIter, err = ib.proximityIter(ctx, proximityPartition)
+	} else {
+		indexIter, err = ib.rangeIter(ctx, part)
+	}
 	if err != nil {
 		return nil, err
 	}
 	return prollyCoveringIndexIter{
 		idx:       ib.idx,
-		indexIter: rangeIter,
+		indexIter: indexIter,
 		keyDesc:   ib.secKd,
 		valDesc:   ib.secVd,
 		keyMap:    ib.keyMap,
@@ -611,6 +689,9 @@ func (ib *nonCoveringIndexImplBuilder) OutputSchema() schema.Schema {
 
 // NewRangeMapIter implements IndexScanBuilder
 func (ib *nonCoveringIndexImplBuilder) NewRangeMapIter(ctx context.Context, r prolly.Range, reverse bool) (prolly.MapIter, error) {
+	if ib.isProximity {
+		return nil, fmt.Errorf("range map iter not allowed for vector index")
+	}
 	var secIter prolly.MapIter
 	var err error
 	if reverse {
@@ -632,13 +713,19 @@
 
 // NewPartitionRowIter implements IndexScanBuilder
 func (ib *nonCoveringIndexImplBuilder) NewPartitionRowIter(ctx *sql.Context, part sql.Partition) (sql.RowIter, error) {
-	rangeIter, err := ib.rangeIter(ctx, part)
+	var indexIter prolly.MapIter
+	var err error
+	if proximityPartition, ok := 
part.(vectorPartitionIter); ok { + indexIter, err = ib.proximityIter(ctx, proximityPartition) + } else { + indexIter, err = ib.rangeIter(ctx, part) + } if err != nil { return nil, err } return prollyIndexIter{ idx: ib.idx, - indexIter: rangeIter, + indexIter: indexIter, primary: ib.pri, pkBld: ib.pkBld, pkMap: ib.pkMap, diff --git a/go/libraries/doltcore/sqle/index/prolly_row_iter.go b/go/libraries/doltcore/sqle/index/prolly_row_iter.go index 932a9a5723c..54b82b91a43 100644 --- a/go/libraries/doltcore/sqle/index/prolly_row_iter.go +++ b/go/libraries/doltcore/sqle/index/prolly_row_iter.go @@ -39,7 +39,7 @@ type prollyRowIter struct { var _ sql.RowIter = prollyRowIter{} -func NewProllyRowIterForMap(sch schema.Schema, rows prolly.Map, iter prolly.MapIter, projections []uint64) sql.RowIter { +func NewProllyRowIterForMap(sch schema.Schema, rows prolly.MapInterface, iter prolly.MapIter, projections []uint64) sql.RowIter { if projections == nil { projections = sch.GetAllCols().Tags } diff --git a/go/libraries/doltcore/sqle/indexed_dolt_table.go b/go/libraries/doltcore/sqle/indexed_dolt_table.go index a58320ea5a5..03df8870d5e 100644 --- a/go/libraries/doltcore/sqle/indexed_dolt_table.go +++ b/go/libraries/doltcore/sqle/indexed_dolt_table.go @@ -148,6 +148,9 @@ func (t *WritableIndexedDoltTable) LookupBuilder(ctx *sql.Context) (index.IndexS } func (t *WritableIndexedDoltTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup) (sql.PartitionIter, error) { + if lookup.VectorOrderAndLimit.OrderBy != nil { + return index.NewVectorPartitionIter(lookup) + } return index.NewRangePartitionIter(ctx, t.DoltTable, lookup, t.isDoltFormat) } diff --git a/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go b/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go index fd174dbfc84..93c262f621e 100644 --- a/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go +++ b/go/libraries/doltcore/sqle/sqlfmt/schema_fmt.go @@ -221,7 +221,7 @@ func GenerateCreateTableIndentedColumnDefinition(col schema.Column, tableCollati // GenerateCreateTableIndexDefinition returns index definition for CREATE TABLE statement with indentation of 2 spaces func GenerateCreateTableIndexDefinition(index schema.Index) string { - return sql.GenerateCreateTableIndexDefinition(index.IsUnique(), index.IsSpatial(), index.IsFullText(), false, index.Name(), + return sql.GenerateCreateTableIndexDefinition(index.IsUnique(), index.IsSpatial(), index.IsFullText(), index.IsVector(), index.Name(), sql.QuoteIdentifiers(index.ColumnNames()), index.Comment()) } diff --git a/go/libraries/doltcore/sqle/sqlutil/schema.go b/go/libraries/doltcore/sqle/sqlutil/schema.go index b7c8a646325..4cd462ad5ff 100644 --- a/go/libraries/doltcore/sqle/sqlutil/schema.go +++ b/go/libraries/doltcore/sqle/sqlutil/schema.go @@ -56,6 +56,7 @@ func ParseCreateTableStatement(ctx *sql.Context, root doltdb.RootValue, engine * IsUnique: idx.IsUnique(), IsSpatial: idx.IsSpatial(), IsFullText: idx.IsFullText(), + IsVector: idx.IsVector(), Comment: idx.Comment, } name := getIndexName(idx) diff --git a/go/libraries/doltcore/sqle/tables.go b/go/libraries/doltcore/sqle/tables.go index 7eac3dadf61..e8fb46ea5d1 100644 --- a/go/libraries/doltcore/sqle/tables.go +++ b/go/libraries/doltcore/sqle/tables.go @@ -30,6 +30,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/expression/function/vector" "github.com/dolthub/go-mysql-server/sql/fulltext" sqltypes "github.com/dolthub/go-mysql-server/sql/types" @@ -1974,6 
+1975,7 @@ func modifyFulltextIndexesForRewrite(ctx *sql.Context, keyCols fulltext.KeyColum
 		IsUnique:           idx.IsUnique(),
 		IsSpatial:          idx.IsSpatial(),
 		IsFullText:         true,
+		IsVector:           false,
 		IsUserDefined:      true,
 		Comment:            idx.Comment(),
 		FullTextProperties: ft,
@@ -2031,6 +2033,7 @@ func modifyFulltextIndexForColumnDrop(index schema.Index, newSch schema.Schema,
 		IsUnique:           index.IsUnique(),
 		IsSpatial:          false,
 		IsFullText:         true,
+		IsVector:           false,
 		IsUserDefined:      index.IsUserDefined(),
 		Comment:            index.Comment(),
 		FullTextProperties: index.FullTextProperties(),
@@ -2085,9 +2088,11 @@ func modifyIndexesForTableRewrite(ctx *sql.Context, oldSch schema.Schema, oldCol
 			IsUnique:           index.IsUnique(),
 			IsSpatial:          index.IsSpatial(),
 			IsFullText:         index.IsFullText(),
+			IsVector:           index.IsVector(),
 			IsUserDefined:      index.IsUserDefined(),
 			Comment:            index.Comment(),
 			FullTextProperties: index.FullTextProperties(),
+			VectorProperties:   index.VectorProperties(),
 		})
 	}
@@ -2459,11 +2464,17 @@ func (t *AlterableDoltTable) CreateIndex(ctx *sql.Context, idx sql.IndexDef) err
 	if err := dsess.CheckAccessForDb(ctx, t.db, branch_control.Permissions_Write); err != nil {
 		return err
 	}
-	if idx.Constraint != sql.IndexConstraint_None && idx.Constraint != sql.IndexConstraint_Unique && idx.Constraint != sql.IndexConstraint_Spatial {
+	if idx.Constraint != sql.IndexConstraint_None && idx.Constraint != sql.IndexConstraint_Unique && idx.Constraint != sql.IndexConstraint_Spatial && idx.Constraint != sql.IndexConstraint_Vector {
 		return fmt.Errorf("only the following types of index constraints are supported: none, unique, spatial, vector")
 	}
-	return t.createIndex(ctx, idx, fulltext.KeyColumns{}, fulltext.IndexTableNames{})
+	var vectorProperties schema.VectorProperties
+	if idx.Constraint == sql.IndexConstraint_Vector {
+		vectorProperties = schema.VectorProperties{
+			DistanceType: vector.DistanceL2Squared{},
+		}
+	}
+	return t.createIndex(ctx, idx, fulltext.KeyColumns{}, fulltext.IndexTableNames{}, vectorProperties)
 }
 
 // DropIndex implements sql.IndexAlterableTable
@@ -2547,11 +2558,11 @@ func (t *AlterableDoltTable) CreateFulltextIndex(ctx *sql.Context, idx sql.Index
 		return fmt.Errorf("attempted to create non-FullText index through FullText interface")
 	}
 
-	return t.createIndex(ctx, idx, keyCols, tableNames)
+	return t.createIndex(ctx, idx, keyCols, tableNames, schema.VectorProperties{})
 }
 
 // createIndex handles the common functionality between CreateIndex and CreateFulltextIndex.
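+// The vectorProperties argument is only meaningful when idx.Constraint is
+// sql.IndexConstraint_Vector; all other callers pass the zero value.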
-func (t *AlterableDoltTable) createIndex(ctx *sql.Context, idx sql.IndexDef, keyCols fulltext.KeyColumns, tableNames fulltext.IndexTableNames) error { +func (t *AlterableDoltTable) createIndex(ctx *sql.Context, idx sql.IndexDef, keyCols fulltext.KeyColumns, tableNames fulltext.IndexTableNames, vectorProperties schema.VectorProperties) error { columns := make([]string, len(idx.Columns)) for i, indexCol := range idx.Columns { columns[i] = indexCol.Name @@ -2574,6 +2585,7 @@ func (t *AlterableDoltTable) createIndex(ctx *sql.Context, idx sql.IndexDef, key IsUnique: idx.Constraint == sql.IndexConstraint_Unique, IsSpatial: idx.Constraint == sql.IndexConstraint_Spatial, IsFullText: idx.Constraint == sql.IndexConstraint_Fulltext, + IsVector: idx.Constraint == sql.IndexConstraint_Vector, IsUserDefined: true, Comment: idx.Comment, FullTextProperties: schema.FullTextProperties{ @@ -2586,6 +2598,7 @@ func (t *AlterableDoltTable) createIndex(ctx *sql.Context, idx sql.IndexDef, key KeyName: keyCols.Name, KeyPositions: keyPositions, }, + VectorProperties: vectorProperties, }, t.opts) if err != nil { return err @@ -2888,7 +2901,7 @@ func (t *WritableDoltTable) UpdateForeignKey(ctx *sql.Context, fkName string, sq // CreateIndexForForeignKey implements sql.ForeignKeyTable func (t *AlterableDoltTable) CreateIndexForForeignKey(ctx *sql.Context, idx sql.IndexDef) error { if idx.Constraint != sql.IndexConstraint_None && idx.Constraint != sql.IndexConstraint_Unique && idx.Constraint != sql.IndexConstraint_Spatial { - return fmt.Errorf("only the following types of index constraints are supported: none, unique, spatial") + return fmt.Errorf("only the following types of index constraints are supported for foreign keys: none, unique, spatial") } columns := make([]string, len(idx.Columns)) for i, indexCol := range idx.Columns { @@ -2903,7 +2916,8 @@ func (t *AlterableDoltTable) CreateIndexForForeignKey(ctx *sql.Context, idx sql. ret, err := creation.CreateIndex(ctx, table, t.Name(), idx.Name, columns, allocatePrefixLengths(idx.Columns), schema.IndexProperties{ IsUnique: idx.Constraint == sql.IndexConstraint_Unique, IsSpatial: idx.Constraint == sql.IndexConstraint_Spatial, - IsFullText: idx.Constraint == sql.IndexConstraint_Fulltext, + IsFullText: false, + IsVector: false, IsUserDefined: false, Comment: "", }, t.opts) diff --git a/go/libraries/doltcore/sqle/temp_table.go b/go/libraries/doltcore/sqle/temp_table.go index 6a07a804d02..6b0a21f21d7 100644 --- a/go/libraries/doltcore/sqle/temp_table.go +++ b/go/libraries/doltcore/sqle/temp_table.go @@ -298,7 +298,8 @@ func (t *TempTable) CreateIndex(ctx *sql.Context, idx sql.IndexDef) error { ret, err := creation.CreateIndex(ctx, t.table, t.Name(), idx.Name, cols, allocatePrefixLengths(idx.Columns), schema.IndexProperties{ IsUnique: idx.Constraint == sql.IndexConstraint_Unique, IsSpatial: idx.Constraint == sql.IndexConstraint_Spatial, - IsFullText: idx.Constraint == sql.IndexConstraint_Fulltext, + IsFullText: false, + IsVector: false, IsUserDefined: true, Comment: idx.Comment, }, t.opts) diff --git a/go/libraries/doltcore/table/editor/creation/external_build_index.go b/go/libraries/doltcore/table/editor/creation/external_build_index.go index 878cb456602..285965ae827 100644 --- a/go/libraries/doltcore/table/editor/creation/external_build_index.go +++ b/go/libraries/doltcore/table/editor/creation/external_build_index.go @@ -41,26 +41,29 @@ const ( // single prolly tree materialization by presorting the index keys in an // intermediate file format. 
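+// Vector indexes are the exception: proximity maps are not ordered by key, so
+// for them the presorting step is skipped and the build is delegated to
+// BuildProximityIndex below.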
func BuildProllyIndexExternal(ctx *sql.Context, vrw types.ValueReadWriter, ns tree.NodeStore, sch schema.Schema, tableName string, idx schema.Index, primary prolly.Map, uniqCb DupEntryCb) (durable.Index, error) {
-	empty, err := durable.NewEmptyIndexFromTableSchema(ctx, vrw, ns, idx, sch)
-	if err != nil {
-		return nil, err
-	}
-	secondary := durable.ProllyMapFromIndex(empty)
-
 	iter, err := primary.IterAll(ctx)
 	if err != nil {
 		return nil, err
 	}
 	p := primary.Pool()
-	prefixDesc := secondary.KeyDesc().PrefixDesc(idx.Count())
-	secondaryBld, err := index.NewSecondaryKeyBuilder(ctx, tableName, sch, idx, secondary.KeyDesc(), p, secondary.NodeStore())
+	keyDesc, _ := idx.Schema().GetMapDescriptors()
+	if schema.IsKeyless(sch) {
+		keyDesc = prolly.AddHashToSchema(keyDesc)
+	}
+
+	prefixDesc := keyDesc.PrefixDesc(idx.Count())
+	secondaryBld, err := index.NewSecondaryKeyBuilder(ctx, tableName, sch, idx, keyDesc, p, ns)
 	if err != nil {
 		return nil, err
 	}
 
+	if idx.IsVector() {
+		return BuildProximityIndex(ctx, ns, idx, keyDesc, prefixDesc, iter, secondaryBld, uniqCb)
+	}
+
 	sorter := sort.NewTupleSorter(batchSize, fileMax, func(t1, t2 val.Tuple) bool {
-		return secondary.KeyDesc().Compare(t1, t2) < 0
+		return keyDesc.Compare(t1, t2) < 0
 	}, tempfiles.MovableTempFileProvider)
 	defer sorter.Close()
@@ -98,6 +101,9 @@ func BuildProllyIndexExternal(ctx *sql.Context, vrw types.ValueReadWriter, ns tr
 	}
 	defer it.Close()
 
+	empty, err := durable.NewEmptyIndexFromTableSchema(ctx, vrw, ns, idx, sch)
+	if err != nil {
+		return nil, err
+	}
+	secondary := durable.ProllyMapFromIndex(empty)
+
 	tupIter := &tupleIterWithCb{iter: it, prefixDesc: prefixDesc, uniqCb: uniqCb}
 	ret, err := prolly.MutateMapWithTupleIter(ctx, secondary, tupIter)
 	if err != nil {
@@ -110,6 +116,48 @@ func BuildProllyIndexExternal(ctx *sql.Context, vrw types.ValueReadWriter, ns tr
 	return durable.IndexFromProllyMap(ret), nil
 }
 
+// BuildProximityIndex builds a vector index by inserting every secondary index
+// key into a ProximityMapBuilder.
+func BuildProximityIndex(
+	ctx *sql.Context,
+	ns tree.NodeStore,
+	idx schema.Index,
+	keyDesc val.TupleDesc,
+	prefixDesc val.TupleDesc,
+	iter prolly.MapIter,
+	secondaryBld index.SecondaryKeyBuilder,
+	uniqCb DupEntryCb,
+) (durable.Index, error) {
+	// Secondary indexes have no non-key columns
+	valDesc := val.NewTupleDescriptor()
+	proximityMapBuilder, err := prolly.NewProximityMapBuilder(ctx, ns, idx.VectorProperties().DistanceType, keyDesc, valDesc, prolly.DefaultLogChunkSize)
+	if err != nil {
+		return nil, err
+	}
+	for {
+		k, v, err := iter.Next(ctx)
+		if err == io.EOF {
+			break
+		} else if err != nil {
+			return nil, err
+		}
+
+		idxKey, err := secondaryBld.SecondaryKeyFromRow(ctx, k, v)
+		if err != nil {
+			return nil, err
+		}
+
+		if uniqCb != nil && prefixDesc.HasNulls(idxKey) {
+			continue
+		}
+
+		if err := proximityMapBuilder.Insert(ctx, idxKey, val.EmptyTuple); err != nil {
+			return nil, err
+		}
+	}
+	proximityMap, err := proximityMapBuilder.Flush(ctx)
+	if err != nil {
+		return nil, err
+	}
+	return durable.IndexFromProximityMap(proximityMap), nil
+}
+
 type tupleIterWithCb struct {
 	iter sort.KeyIter
 	err  error
diff --git a/go/libraries/doltcore/table/table_iterator.go b/go/libraries/doltcore/table/table_iterator.go
index e771539345e..4d95dc3b017 100644
--- a/go/libraries/doltcore/table/table_iterator.go
+++ b/go/libraries/doltcore/table/table_iterator.go
@@ -59,15 +59,11 @@ func (i rowIterImpl) Close(ctx context.Context) error {
 
 // NewTableIterator creates a RowIter that iterates sql.Row's from |idx|.
-// |offset| can be supplied to read at some start point in |idx|.
-func NewTableIterator(ctx context.Context, sch schema.Schema, idx durable.Index, offset uint64) (RowIter, error) {
+func NewTableIterator(ctx context.Context, sch schema.Schema, idx durable.Index) (RowIter, error) {
 	var rowItr sql.RowIter
 	if types.IsFormat_DOLT(idx.Format()) {
-		m := durable.ProllyMapFromIndex(idx)
-		c, err := m.Count()
-		if err != nil {
-			return nil, err
-		}
-		itr, err := m.IterOrdinalRange(ctx, offset, uint64(c))
+		m := durable.MapFromIndex(idx)
+		itr, err := m.IterAll(ctx)
 		if err != nil {
 			return nil, err
 		}
@@ -78,7 +74,7 @@ func NewTableIterator(ctx context.Context, sch schema.Schema, idx durable.Index,
 	} else {
 		noms := durable.NomsMapFromIndex(idx)
-		itr, err := noms.IteratorAt(ctx, offset)
+		itr, err := noms.IteratorAt(ctx, 0)
 		if err != nil {
 			return nil, err
 		}
diff --git a/go/libraries/doltcore/table/table_iterator_test.go b/go/libraries/doltcore/table/table_iterator_test.go
index 486cce37fc7..7bb5af0f4fb 100644
--- a/go/libraries/doltcore/table/table_iterator_test.go
+++ b/go/libraries/doltcore/table/table_iterator_test.go
@@ -38,15 +38,12 @@ var testRand = rand.New(rand.NewSource(1))
 
 func TestTableIteratorProlly(t *testing.T) {
 	n := 100
-	for i := 0; i < 10; i++ {
-		offset := testRand.Intn(n)
-		m, tups := mustMakeProllyMap(t, n)
-		idx := durable.IndexFromProllyMap(m)
-		itr, err := NewTableIterator(context.Background(), sch, idx, uint64(offset))
-		require.NoError(t, err)
-		expectedRows := tuplesToRows(t, tups[offset:])
-		testIterator(t, itr, expectedRows)
-	}
+	m, tups := mustMakeProllyMap(t, n)
+	idx := durable.IndexFromProllyMap(m)
+	itr, err := NewTableIterator(context.Background(), sch, idx)
+	require.NoError(t, err)
+	expectedRows := tuplesToRows(t, tups)
+	testIterator(t, itr, expectedRows)
 }
 
 func testIterator(t *testing.T, iter RowIter, expected []sql.Row) {
diff --git a/go/serial/fileidentifiers.go b/go/serial/fileidentifiers.go
index 628cc4f99b8..1f39f839c28 100644
--- a/go/serial/fileidentifiers.go
+++ b/go/serial/fileidentifiers.go
@@ -42,6 +42,7 @@ const StashFileID = "STSH"
 const StatisticFileID = "STAT"
 const DoltgresRootValueFileID = "DGRV"
 const TupleFileID = "TUPL"
+const VectorIndexNodeFileID = "IVFF"
 
 const MessageTypesKind int = 27
diff --git a/go/serial/generate.sh b/go/serial/generate.sh
index 91ae999a07f..367f8a857e6 100755
--- a/go/serial/generate.sh
+++ b/go/serial/generate.sh
@@ -38,7 +38,8 @@ fi
   table.fbs \
   tag.fbs \
   tuple.fbs \
-  workingset.fbs
+  workingset.fbs \
+  vectorindexnode.fbs
 
 # prefix files with copyright header
 for FILE in $GEN_DIR/*.go;
diff --git a/go/serial/schema.fbs b/go/serial/schema.fbs
index 05598b0ca3d..449290c698c 100644
--- a/go/serial/schema.fbs
+++ b/go/serial/schema.fbs
@@ -17,6 +17,10 @@ include "collation.fbs";
 
 namespace serial;
 
+enum DistanceType : uint8 {
+  Null = 0,
+  L2_Squared = 1,
+}
 
 table TableSchema {
   columns:[Column] (required);
@@ -104,6 +108,11 @@ table Index {
   // fulltext information
   fulltext_key:bool;
   fulltext_info:FulltextInfo;
+
+  // vector information
+  // these fields should be set for vector indexes and otherwise omitted, for backwards compatibility
+  vector_key:bool;
+  vector_info:VectorInfo;
 }
 
 table FulltextInfo {
@@ -117,6 +126,10 @@ table FulltextInfo {
   key_positions:[uint16];
 }
 
+table VectorInfo {
+  distance_type:DistanceType;
+}
+
 table CheckConstraint {
   name:string;
   expression:string;
diff --git a/go/serial/vectorindexnode.fbs b/go/serial/vectorindexnode.fbs
new file mode 100644
index
00000000000..6705257df19 --- /dev/null +++ b/go/serial/vectorindexnode.fbs @@ -0,0 +1,65 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +include "schema.fbs"; + +namespace serial; + +// VectorIndexNode is a node that makes up a vector index. Every key contains a vector value, +// and keys are organized according to their proximity to their parent node. +table VectorIndexNode { + // sorted array of key items + key_items:[ubyte] (required); + // item offsets for |key_items| + // first offset is 0, last offset is len(key_items) + key_offsets:[uint32] (required); + + // array of values items, ordered by paired key + value_items:[ubyte]; + // item offsets for |value_items| + // first offset is 0, last offset is len(value_items) + value_offsets:[uint32]; + + // array of chunk addresses + // - subtree addresses for internal prolly tree nodes + // - value addresses for AddressMap leaf nodes + // note that while the keys in this index are addresses to JSON chunks, we don't store those in the address_array + // because we are guaranteed to have other references to those chunks in the primary index. + address_array:[ubyte] (required); + + // array of varint encoded subtree counts + // see: go/store/prolly/message/varint.go + subtree_counts:[ubyte]; + // total count of prolly tree + tree_count:uint64; + // prolly tree level, 0 for leaf nodes + tree_level:uint8; + + // the base-2 log of the average (geometric mean) number of vectors stored in each node. + // currently this is always set to 8, but other numbers are used in testing, and future versions of dolt + // may choose to use a different size, or even select the best size for each index. + // all nodes in an index must use the same size, and when modifying an existing index, we must use this value. + log_chunk_size:uint8; + + // each node encodes the distance function used for the index. This allows lookups without needing to retrieve the + // distance function from the schema. + distance_type:DistanceType; +} + + +// KEEP THIS IN SYNC WITH fileidentifiers.go +file_identifier "IVFF"; + +root_type VectorIndexNode; + diff --git a/go/store/prolly/message/message.go b/go/store/prolly/message/message.go index 8b9202cd963..21923a6241a 100644 --- a/go/store/prolly/message/message.go +++ b/go/store/prolly/message/message.go @@ -41,6 +41,9 @@ func UnpackFields(msg serial.Message) (fileId string, keys, values ItemAccess, l case serial.ProllyTreeNodeFileID: keys, values, level, count, err = getProllyMapKeysAndValues(msg) return + case serial.VectorIndexNodeFileID: + keys, values, level, count, err = getVectorIndexKeysAndValues(msg) + return case serial.AddressMapFileID: keys, values, level, count, err = getAddressMapKeysAndValues(msg) return @@ -75,6 +78,8 @@ func WalkAddresses(ctx context.Context, msg serial.Message, cb func(ctx context. 
switch id {
 	case serial.ProllyTreeNodeFileID:
 		return walkProllyMapAddresses(ctx, msg, cb)
+	case serial.VectorIndexNodeFileID:
+		return walkVectorIndexAddresses(ctx, msg, cb)
 	case serial.AddressMapFileID:
 		return walkAddressMapAddresses(ctx, msg, cb)
 	case serial.MergeArtifactsFileID:
@@ -93,6 +98,8 @@ func GetTreeCount(msg serial.Message) (int, error) {
 	switch id {
 	case serial.ProllyTreeNodeFileID:
 		return getProllyMapTreeCount(msg)
+	case serial.VectorIndexNodeFileID:
+		return getVectorIndexTreeCount(msg)
 	case serial.AddressMapFileID:
 		return getAddressMapTreeCount(msg)
 	case serial.MergeArtifactsFileID:
diff --git a/go/store/prolly/message/vector_index.go b/go/store/prolly/message/vector_index.go
new file mode 100644
index 00000000000..811b8e0b38a
--- /dev/null
+++ b/go/store/prolly/message/vector_index.go
@@ -0,0 +1,229 @@
+// Copyright 2024 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package message
+
+import (
+	"context"
+	"encoding/binary"
+	"fmt"
+	"math"
+
+	fb "github.com/dolthub/flatbuffers/v23/go"
+	"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
+
+	"github.com/dolthub/dolt/go/gen/fb/serial"
+	"github.com/dolthub/dolt/go/store/hash"
+	"github.com/dolthub/dolt/go/store/pool"
+)
+
+const (
+	// These constants are mirrored from serial.VectorIndexNode
+	// They are only as stable as the flatbuffers schema that defines them.
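+	// (In flatbuffers, field N of a table is assigned vtable offset 4 + 2*N, so
+	// these values correspond to the first five fields of VectorIndexNode as
+	// declared in vectorindexnode.fbs.)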
+	vectorIvfKeyItemBytesVOffset      fb.VOffsetT = 4
+	vectorIvfKeyOffsetsVOffset        fb.VOffsetT = 6
+	vectorIvfValueItemBytesVOffset    fb.VOffsetT = 8
+	vectorIvfValueOffsetsVOffset      fb.VOffsetT = 10
+	vectorIvfAddressArrayBytesVOffset fb.VOffsetT = 12
+)
+
+var vectorIvfFileID = []byte(serial.VectorIndexNodeFileID)
+
+func distanceTypeToEnum(distanceType vector.DistanceType) serial.DistanceType {
+	switch distanceType.(type) {
+	case vector.DistanceL2Squared:
+		return serial.DistanceTypeL2_Squared
+	}
+	return serial.DistanceTypeNull
+}
+
+func NewVectorIndexSerializer(pool pool.BuffPool, logChunkSize uint8, distanceType vector.DistanceType) VectorIndexSerializer {
+	return VectorIndexSerializer{pool: pool, logChunkSize: logChunkSize, distanceType: distanceType}
+}
+
+type VectorIndexSerializer struct {
+	pool         pool.BuffPool
+	logChunkSize uint8
+	distanceType vector.DistanceType
+}
+
+var _ Serializer = VectorIndexSerializer{}
+
+func (s VectorIndexSerializer) Serialize(keys, values [][]byte, subtrees []uint64, level int) serial.Message {
+	var (
+		keyTups, keyOffs fb.UOffsetT
+		valTups, valOffs fb.UOffsetT
+		refArr, cardArr  fb.UOffsetT
+	)
+
+	keySz, valSz, bufSz := estimateVectorIndexSize(keys, values, subtrees)
+	b := getFlatbufferBuilder(s.pool, bufSz)
+
+	// serialize keys and offsets
+	keyTups = writeItemBytes(b, keys, keySz)
+	serial.VectorIndexNodeStartKeyOffsetsVector(b, len(keys)+1)
+	keyOffs = writeItemOffsets32(b, keys, keySz)
+
+	if level == 0 {
+		// serialize value tuples for leaf nodes
+		valTups = writeItemBytes(b, values, valSz)
+		serial.VectorIndexNodeStartValueOffsetsVector(b, len(values)+1)
+		valOffs = writeItemOffsets32(b, values, valSz)
+	} else {
+		// serialize child refs and subtree counts for internal nodes
+		refArr = writeItemBytes(b, values, valSz)
+		cardArr = writeCountArray(b, subtrees)
+	}
+
+	// populate the node's vtable
+	serial.VectorIndexNodeStart(b)
+	serial.VectorIndexNodeAddKeyItems(b, keyTups)
+	serial.VectorIndexNodeAddKeyOffsets(b, keyOffs)
+	if level == 0 {
+		serial.VectorIndexNodeAddValueItems(b, valTups)
+		serial.VectorIndexNodeAddValueOffsets(b, valOffs)
+		serial.VectorIndexNodeAddTreeCount(b, uint64(len(keys)))
+	} else {
+		serial.VectorIndexNodeAddAddressArray(b, refArr)
+		serial.VectorIndexNodeAddSubtreeCounts(b, cardArr)
+		serial.VectorIndexNodeAddTreeCount(b, sumSubtrees(subtrees))
+	}
+	serial.VectorIndexNodeAddTreeLevel(b, uint8(level))
+	serial.VectorIndexNodeAddLogChunkSize(b, s.logChunkSize)
+	serial.VectorIndexNodeAddDistanceType(b, distanceTypeToEnum(s.distanceType))
+
+	return serial.FinishMessage(b, serial.VectorIndexNodeEnd(b), vectorIvfFileID)
+}
+
+func getVectorIndexKeysAndValues(msg serial.Message) (keys, values ItemAccess, level, count uint16, err error) {
+	keys.offsetSize = OFFSET_SIZE_32
+	values.offsetSize = OFFSET_SIZE_32
+	var pm serial.VectorIndexNode
+	err = serial.InitVectorIndexNodeRoot(&pm, msg, serial.MessagePrefixSz)
+	if err != nil {
+		return
+	}
+	keys.bufStart = lookupVectorOffset(vectorIvfKeyItemBytesVOffset, pm.Table())
+	keys.bufLen = uint32(pm.KeyItemsLength())
+	keys.offStart = lookupVectorOffset(vectorIvfKeyOffsetsVOffset, pm.Table())
+	keys.offLen = uint32(pm.KeyOffsetsLength() * uint16Size)
+
+	count = uint16(keys.offLen/2) - 1
+	level = uint16(pm.TreeLevel())
+
+	vv := pm.ValueItemsBytes()
+	if vv != nil {
+		values.bufStart = lookupVectorOffset(vectorIvfValueItemBytesVOffset, pm.Table())
+		values.bufLen = uint32(pm.ValueItemsLength())
+		values.offStart = lookupVectorOffset(vectorIvfValueOffsetsVOffset, pm.Table())
+		values.offLen = uint32(pm.ValueOffsetsLength() * uint16Size)
+	} else {
+		values.bufStart = lookupVectorOffset(vectorIvfAddressArrayBytesVOffset, pm.Table())
+		values.bufLen = uint32(pm.AddressArrayLength())
+		values.itemWidth = hash.ByteLen
+	}
+	return
+}
+
+func walkVectorIndexAddresses(ctx context.Context, msg serial.Message, cb func(ctx context.Context, addr hash.Hash) error) error {
+	var pm serial.VectorIndexNode
+	err := serial.InitVectorIndexNodeRoot(&pm, msg, serial.MessagePrefixSz)
+	if err != nil {
+		return err
+	}
+	arr := pm.AddressArrayBytes()
+	for i := 0; i < len(arr)/hash.ByteLen; i++ {
+		addr := hash.New(arr[i*addrSize : (i+1)*addrSize])
+		if err := cb(ctx, addr); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func getVectorIndexCount(msg serial.Message) (uint16, error) {
+	var pm serial.VectorIndexNode
+	err := serial.InitVectorIndexNodeRoot(&pm, msg, serial.MessagePrefixSz)
+	if err != nil {
+		return 0, err
+	}
+	return uint16(pm.KeyOffsetsLength() - 1), nil
+}
+
+func getVectorIndexTreeLevel(msg serial.Message) (int, error) {
+	var pm serial.VectorIndexNode
+	err := serial.InitVectorIndexNodeRoot(&pm, msg, serial.MessagePrefixSz)
+	if err != nil {
+		return 0, err
+	}
+	return int(pm.TreeLevel()), nil
+}
+
+func getVectorIndexTreeCount(msg serial.Message) (int, error) {
+	var pm serial.VectorIndexNode
+	err := serial.InitVectorIndexNodeRoot(&pm, msg, serial.MessagePrefixSz)
+	if err != nil {
+		return 0, err
+	}
+	return int(pm.TreeCount()), nil
+}
+
+func getVectorIndexSubtrees(msg serial.Message) ([]uint64, error) {
+	sz, err := getVectorIndexCount(msg)
+	if err != nil {
+		return nil, err
+	}
+
+	var pm serial.VectorIndexNode
+	n := fb.GetUOffsetT(msg[serial.MessagePrefixSz:])
+	err = pm.Init(msg, serial.MessagePrefixSz+n)
+	if err != nil {
+		return nil, err
+	}
+
+	counts := make([]uint64, sz)
+
+	return decodeVarints(pm.SubtreeCountsBytes(), counts), nil
+}
+
+// estimateVectorIndexSize returns the exact size of the tuple vectors for keys and values,
+// and an estimate of the overall size of the final flatbuffer.
+func estimateVectorIndexSize(keys, values [][]byte, subtrees []uint64) (int, int, int) {
+	var keySz, valSz, bufSz int
+	for i := range keys {
+		keySz += len(keys[i])
+		valSz += len(values[i])
+	}
+	subtreesSz := len(subtrees) * binary.MaxVarintLen64
+
+	// constraints enforced upstream
+	if keySz > math.MaxUint32 {
+		panic(fmt.Sprintf("key vector exceeds size limit ( %d > %d )", keySz, math.MaxUint32))
+	}
+	if valSz > math.MaxUint32 {
+		panic(fmt.Sprintf("value vector exceeds size limit ( %d > %d )", valSz, math.MaxUint32))
+	}
+
+	// The following estimates the final size of the message based on the expected size of the flatbuffer components.
+	bufSz += keySz + valSz               // tuples
+	bufSz += subtreesSz                  // subtree counts
+	bufSz += len(keys)*4 + len(values)*4 // offsets
+	bufSz += 8 + 1 + 1 + 1               // metadata
+	bufSz += 72                          // vtable (approx)
+	bufSz += 100                         // padding?
+	bufSz += serial.MessagePrefixSz
+
+	return keySz, valSz, bufSz
+}
diff --git a/go/store/prolly/proximity_map.go b/go/store/prolly/proximity_map.go
new file mode 100644
index 00000000000..467196fbf16
--- /dev/null
+++ b/go/store/prolly/proximity_map.go
@@ -0,0 +1,589 @@
+// Copyright 2024 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package prolly
+
+import (
+	"context"
+	"io"
+	"iter"
+
+	"github.com/dolthub/go-mysql-server/sql"
+	"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
+
+	"github.com/dolthub/dolt/go/store/hash"
+	"github.com/dolthub/dolt/go/store/pool"
+	"github.com/dolthub/dolt/go/store/prolly/message"
+	"github.com/dolthub/dolt/go/store/prolly/tree"
+	"github.com/dolthub/dolt/go/store/val"
+)
+
+// ProximityMap wraps a tree.ProximityMap but operates on typed Tuples instead of raw bytestrings.
+// A ProximityMap is like a Map, except that walking the tree does not produce a sorted order. Instead, each key
+// is stored such that it is closer to its parent key than any of its uncle keys, according to a distance function
+// defined on the tree.ProximityMap.
+type ProximityMap struct {
+	tuples       tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]
+	keyDesc      val.TupleDesc
+	valDesc      val.TupleDesc
+	logChunkSize uint8
+}
+
+// MutateInterface converts the map to a MutableMapInterface
+func (m ProximityMap) MutateInterface() MutableMapInterface {
+	return newProximityMutableMap(m)
+}
+
+func (m ProximityMap) WalkNodes(ctx context.Context, cb tree.NodeCb) error {
+	return m.tuples.WalkNodes(ctx, cb)
+}
+
+func (m ProximityMap) Node() tree.Node {
+	return m.tuples.Root
+}
+
+func (m ProximityMap) HashOf() hash.Hash {
+	return m.tuples.HashOf()
+}
+
+var _ MapInterface = ProximityMap{}
+
+// Count returns the number of key-value pairs in the Map.
+func (m ProximityMap) Count() (int, error) {
+	return m.tuples.Count()
+}
+
+func (m ProximityMap) Descriptors() (val.TupleDesc, val.TupleDesc) {
+	return m.keyDesc, m.valDesc
+}
+
+func (m ProximityMap) NodeStore() tree.NodeStore {
+	return m.tuples.NodeStore
+}
+
+func (m ProximityMap) ValDesc() val.TupleDesc {
+	return m.valDesc
+}
+
+func (m ProximityMap) KeyDesc() val.TupleDesc {
+	return m.keyDesc
+}
+
+func (m ProximityMap) Pool() pool.BuffPool {
+	return m.tuples.NodeStore.Pool()
+}
+
+func (m ProximityMap) IterAll(ctx context.Context) (MapIter, error) {
+	return m.tuples.IterAll(ctx)
+}
+
+// Get searches for key-value pairs keyed by |query| and passes the results to the callback.
+// If |query| is not present in the map, a nil key-value pair is passed.
+func (m ProximityMap) Get(ctx context.Context, query val.Tuple, cb tree.KeyValueFn[val.Tuple, val.Tuple]) (err error) {
+	return m.tuples.Get(ctx, query, cb)
+}
+
+// Has returns true if |key| is present in the Map.
+func (m ProximityMap) Has(ctx context.Context, key val.Tuple) (ok bool, err error) {
+	return m.tuples.Has(ctx, key)
+}
+
+// GetClosest returns a MapIter that produces the |limit| closest key-value pairs to the provided query key.
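+// A minimal usage sketch (assuming the query value is something that
+// sql.ConvertToVector accepts, such as a JSON array of floats):
+//
+//	iter, err := m.GetClosest(ctx, query, 3)
+//	if err != nil {
+//		return err
+//	}
+//	for {
+//		k, v, err := iter.Next(ctx)
+//		if err == io.EOF {
+//			break
+//		}
+//		if err != nil {
+//			return err
+//		}
+//		// k and v are the next-closest key and value tuples.
+//	}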
+func (m ProximityMap) GetClosest(ctx context.Context, query interface{}, limit int) (mapIter MapIter, err error) {
+	kvPairs := make([]kvPair, 0, limit)
+	cb := func(key val.Tuple, value val.Tuple, distance float64) error {
+		kvPairs = append(kvPairs, kvPair{key, value})
+		return nil
+	}
+	err = m.tuples.GetClosest(ctx, query, cb, limit)
+	if err != nil {
+		return nil, err
+	}
+	return &proximityMapIter{
+		m.keyDesc, m.valDesc, kvPairs, 0,
+	}, nil
+}
+
+type kvPair struct {
+	key, value val.Tuple
+}
+
+type proximityMapIter struct {
+	keyDesc, valueDesc val.TupleDesc
+	kvPairs            []kvPair
+	i                  int
+}
+
+var _ MapIter = (*proximityMapIter)(nil)
+
+func (p *proximityMapIter) Next(ctx context.Context) (k val.Tuple, v val.Tuple, err error) {
+	if p.i >= len(p.kvPairs) {
+		return nil, nil, io.EOF
+	}
+	pair := p.kvPairs[p.i]
+	k = pair.key
+	v = pair.value
+	p.i++
+	return
+}
+
+// NewProximityMap creates a new ProximityMap from a supplied root node.
+func NewProximityMap(ns tree.NodeStore, node tree.Node, keyDesc val.TupleDesc, valDesc val.TupleDesc, distanceType vector.DistanceType, logChunkSize uint8) ProximityMap {
+	tuples := tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]{
+		Root:         node,
+		NodeStore:    ns,
+		Order:        keyDesc,
+		DistanceType: distanceType,
+		Convert: func(ctx context.Context, bytes []byte) []float64 {
+			h, _ := keyDesc.GetJSONAddr(0, bytes)
+			doc := tree.NewJSONDoc(h, ns)
+			jsonWrapper, err := doc.ToIndexedJSONDocument(ctx)
+			if err != nil {
+				panic(err)
+			}
+			floats, err := sql.ConvertToVector(jsonWrapper)
+			if err != nil {
+				panic(err)
+			}
+			return floats
+		},
+	}
+	return ProximityMap{
+		tuples:       tuples,
+		keyDesc:      keyDesc,
+		valDesc:      valDesc,
+		logChunkSize: logChunkSize,
+	}
+}
+
+var proximitylevelMapKeyDesc = val.NewTupleDescriptor(
+	val.Type{Enc: val.Uint8Enc, Nullable: false},
+	val.Type{Enc: val.ByteStringEnc, Nullable: false},
+)
+
+// NewProximityMapBuilder creates a new ProximityMapBuilder, which builds a ProximityMap from a list of key-value pairs.
+func NewProximityMapBuilder(ctx context.Context, ns tree.NodeStore, distanceType vector.DistanceType, keyDesc val.TupleDesc, valDesc val.TupleDesc, logChunkSize uint8) (ProximityMapBuilder, error) {
+
+	emptyLevelMap, err := NewMapFromTuples(ctx, ns, proximitylevelMapKeyDesc, valDesc)
+	if err != nil {
+		return ProximityMapBuilder{}, err
+	}
+	mutableLevelMap := newMutableMap(emptyLevelMap)
+	return ProximityMapBuilder{
+		ns:                    ns,
+		vectorIndexSerializer: message.NewVectorIndexSerializer(ns.Pool(), logChunkSize, distanceType),
+		distanceType:          distanceType,
+		keyDesc:               keyDesc,
+		valDesc:               valDesc,
+		logChunkSize:          logChunkSize,
+		maxLevel:              0,
+		levelMap:              mutableLevelMap,
+	}, nil
+}
+
+// ProximityMapBuilder is used to create a ProximityMap.
+//
+// Each node has an average of 2^|logChunkSize| key-value pairs.
+//
+// The algorithm for building a ProximityMap's tree requires us to start at the root and build out to the leaf nodes.
+// Given that our trees are Merkle Trees, this presents an obvious problem.
+// Our solution is to create the final tree by applying a series of transformations to intermediate trees.
+//
+// Note: when talking about tree levels, we use "level" when counting from the leaves, and "depth" when counting
+// from the root. In a tree with 5 levels, the root is level 4 (and depth 0), while the leaves are level 0 (and depth 4).
+//
+// The process looks like this:
+// Step 1: Create `levelMap`, a map from (indexLevel, keyBytes) -> values
+// - indexLevel: the minimum level in which the vector appears
+// - keyBytes: a bytestring containing the bytes of the ProximityMap key (which includes the vector)
+// - values: the ProximityMap value tuple
+//
+// Step 2: Create `pathMaps`, a list of maps, each corresponding to a different level of the ProximityMap
+//
+// The pathMap at depth `i` has the schema (vectorAddrs[1]...vectorAddrs[i], keyBytes) -> value
+// and contains a row for every vector whose maximum depth is i.
+// - vectorAddrs: the path of vectors visited when walking from the root to the maximum depth where the vector appears.
+// - keyBytes: a bytestring containing the bytes of the ProximityMap key (which includes the vector)
+// - values: the ProximityMap value tuple
+//
+// These maps must be built in order, from shallowest to deepest.
+//
+// Step 3: Create an iter over each `pathMap` created in the previous step, and walk the shape of the final ProximityMap,
+// generating Nodes as we go.
+//
+// Step 1 is accomplished via repeated calls to the Insert method. Steps 2 and 3 are performed when Flush is called.
+//
+// Currently, the intermediate trees are created using the standard NodeStore. This means that the nodes of these
+// trees will inevitably be written out to disk when the NodeStore flushes, despite the fact that we know they
+// won't be needed once we finish building the ProximityMap. This could potentially be avoided by creating a
+// separate in-memory NodeStore for these values.
+type ProximityMapBuilder struct {
+	ns                    tree.NodeStore
+	vectorIndexSerializer message.VectorIndexSerializer
+	distanceType          vector.DistanceType
+	keyDesc, valDesc      val.TupleDesc
+	logChunkSize          uint8
+
+	maxLevel uint8
+	levelMap *MutableMap
+}
+
+// Insert adds a new key-value pair to the ProximityMap under construction.
+// It computes the key's level in the proximity map and adds an entry to the |levelMap|.
+func (b *ProximityMapBuilder) Insert(ctx context.Context, key, value []byte) error {
+	keyLevel := tree.DeterministicHashLevel(b.logChunkSize, key)
+	if keyLevel > b.maxLevel {
+		b.maxLevel = keyLevel
+	}
+
+	// We want the index to be sorted by level (descending), so currently we store the level in the map as
+	// 255 - the actual level.
+
+	// In the future, if MutableMap supports a ReverseIter function, we can use that instead.
+	levelMapKeyBuilder := val.NewTupleBuilder(proximitylevelMapKeyDesc)
+	levelMapKeyBuilder.PutUint8(0, 255-keyLevel)
+	levelMapKeyBuilder.PutByteString(1, key)
+	return b.levelMap.Put(ctx, levelMapKeyBuilder.Build(b.ns.Pool()), value)
+}
+
+// When set to true, enables an additional check in ProximityMapBuilder.InsertAtLevel.
+// This should always be false in production.
+const assertProximityMapLevels = false
+
+// InsertAtLevel inserts into a proximity map when the level for a key is already known.
+// This is called when an existing tree is being modified, and can skip the level calculation.
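+// Levels come from tree.DeterministicHashLevel, which derives the level from a
+// hash of the key bytes alone: a given key always lands on the same level no
+// matter when it is inserted, which is what makes the index history-independent.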
+func (b *ProximityMapBuilder) InsertAtLevel(ctx context.Context, key, value []byte, keyLevel uint8) error {
+	if assertProximityMapLevels {
+		if keyLevel != tree.DeterministicHashLevel(b.logChunkSize, key) {
+			panic("wrong level")
+		}
+	}
+
+	if keyLevel > b.maxLevel {
+		b.maxLevel = keyLevel
+	}
+	levelMapKeyBuilder := val.NewTupleBuilder(proximitylevelMapKeyDesc)
+	levelMapKeyBuilder.PutUint8(0, 255-keyLevel)
+	levelMapKeyBuilder.PutByteString(1, key)
+	return b.levelMap.Put(ctx, levelMapKeyBuilder.Build(b.ns.Pool()), value)
+}
+
+// makeRootNode creates a ProximityMap with a root node constructed from the provided parameters.
+func (b *ProximityMapBuilder) makeRootNode(ctx context.Context, keys, values [][]byte, subtrees []uint64, level int) (ProximityMap, error) {
+	rootMsg := b.vectorIndexSerializer.Serialize(keys, values, subtrees, level)
+	rootNode, _, err := tree.NodeFromBytes(rootMsg)
+	if err != nil {
+		return ProximityMap{}, err
+	}
+	_, err = b.ns.Write(ctx, rootNode)
+	if err != nil {
+		return ProximityMap{}, err
+	}
+
+	return NewProximityMap(b.ns, rootNode, b.keyDesc, b.valDesc, b.distanceType, b.logChunkSize), nil
+}
+
+// Flush finishes constructing a ProximityMap. Call this after all calls to Insert.
+func (b *ProximityMapBuilder) Flush(ctx context.Context) (ProximityMap, error) {
+
+	flushedLevelMap, err := b.levelMap.Map(ctx)
+	if err != nil {
+		return ProximityMap{}, err
+	}
+
+	levelMapSize, err := flushedLevelMap.Count()
+	if err != nil {
+		return ProximityMap{}, err
+	}
+
+	if levelMapSize == 0 {
+		// Index is empty.
+		return b.makeRootNode(ctx, nil, nil, nil, 0)
+	}
+
+	if b.maxLevel == 0 {
+		// index is a single node.
+		// assuming that the keys are already sorted, we can return them unmodified.
+		levelMapIter, err := b.levelMap.IterAll(ctx)
+		if err != nil {
+			return ProximityMap{}, err
+		}
+		var keys, values [][]byte
+		for {
+			key, value, err := levelMapIter.Next(ctx)
+			if err == io.EOF {
+				break
+			}
+			if err != nil {
+				return ProximityMap{}, err
+			}
+			originalKey, _ := proximitylevelMapKeyDesc.GetBytes(1, key)
+			keys = append(keys, originalKey)
+			values = append(values, value)
+		}
+		return b.makeRootNode(ctx, keys, values, nil, 0)
+	}
+
+	// Create `pathMaps`, a list of maps, each corresponding to a different level of the ProximityMap
+	pathMaps, err := b.makePathMaps(ctx, b.levelMap)
+	if err != nil {
+		return ProximityMap{}, err
+	}
+
+	// Create an iter over each `pathMap` created in the previous step, and walk the shape of the final ProximityMap,
+	// generating Nodes as we go.
+	return b.makeProximityMapFromPathMaps(ctx, pathMaps)
+}
+
+// makePathMaps creates a set of prolly maps, each of which corresponds to a different level in the to-be-built ProximityMap
+func (b *ProximityMapBuilder) makePathMaps(ctx context.Context, mutableLevelMap *MutableMap) ([]*MutableMap, error) {
+	levelMapIter, err := mutableLevelMap.IterAll(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	// The first element of levelMap tells us the height of the tree.
+	levelMapKey, levelMapValue, err := levelMapIter.Next(ctx)
+	if err != nil {
+		return nil, err
+	}
+	maxLevel, _ := mutableLevelMap.keyDesc.GetUint8(0, levelMapKey)
+	maxLevel = 255 - maxLevel
+
+	// Create every val.TupleBuilder and MutableMap that we will need
+	// pathMaps[i] is the pathMap for level i (and depth maxLevel - i)
+	pathMaps, keyTupleBuilder, prefixTupleBuilder, err := b.createInitialPathMaps(ctx, maxLevel)
+	if err != nil {
+		return nil, err
+	}
+
+	// Next, visit each key-value pair in decreasing order of level / increasing order of depth.
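+	// (This is levelMap's natural iteration order, since Insert stores levels
+	// inverted as 255 - level.)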
+	// When visiting a pair from depth `i`, we use each of the previous `i` pathMaps to compute a path of `i` index keys.
+	// This path dictates that pair's location in the final ProximityMap.
+	for {
+		level, _ := mutableLevelMap.keyDesc.GetUint8(0, levelMapKey)
+		level = 255 - level // we currently store the level as 255 - the actual level for sorting purposes.
+		depth := int(maxLevel - level)
+
+		// hashPath is a list of concatenated hashes, representing the sequence of closest vectors at each level of the tree.
+		var hashPath []byte
+		keyToInsert, _ := mutableLevelMap.keyDesc.GetBytes(1, levelMapKey)
+		vectorHashToInsert, _ := b.keyDesc.GetJSONAddr(0, keyToInsert)
+		vectorToInsert, err := getVectorFromHash(ctx, b.ns, vectorHashToInsert)
+		if err != nil {
+			return nil, err
+		}
+		// Compute the path that this row will have in the vector index, starting at the root.
+		// A key-value pair at depth D will have a path of D prior keys.
+		// This path is computed in steps, by performing a lookup in each of the prior pathMaps.
+		for pathDepth := 0; pathDepth < depth; pathDepth++ {
+			lookupLevel := int(maxLevel) - pathDepth
+			pathMap := pathMaps[lookupLevel]
+
+			pathMapIter, err := b.getNextPathSegmentCandidates(ctx, pathMap, prefixTupleBuilder, hashPath)
+			if err != nil {
+				return nil, err
+			}
+
+			// Create an iterator that yields every candidate vector
+			nextCandidate, stopIter := iter.Pull2(func(yield func(hash.Hash, error) bool) {
+				for {
+					pathMapKey, _, err := pathMapIter.Next(ctx)
+					if err == io.EOF {
+						return
+					}
+					if err != nil {
+						yield(hash.Hash{}, err)
+						return
+					}
+					originalKey, _ := pathMap.keyDesc.GetBytes(1, pathMapKey)
+					candidateVectorHash, _ := b.keyDesc.GetJSONAddr(0, originalKey)
+					yield(candidateVectorHash, nil)
+				}
+			})
+			defer stopIter()
+
+			closestVectorHash, err := b.getClosestVector(ctx, vectorToInsert, nextCandidate)
+			if err != nil {
+				return nil, err
+			}
+
+			hashPath = append(hashPath, closestVectorHash[:]...)
+		}
+
+		// Once we have the path for this key, we turn it into a tuple and add it to the next pathMap.
+		keyTupleBuilder.PutByteString(0, hashPath)
+		keyTupleBuilder.PutByteString(1, keyToInsert)
+
+		keyTuple := keyTupleBuilder.Build(b.ns.Pool())
+		err = pathMaps[level].Put(ctx, keyTuple, levelMapValue)
+		if err != nil {
+			return nil, err
+		}
+
+		// Since a key that appears at level N also appears at every previous level, we insert into those level maps too
+		// Since level is unsigned, we can't write `for childLevel >= 0` here.
+		childLevel := level - 1
+		if level > 0 {
+			for {
+				hashPath = append(hashPath, vectorHashToInsert[:]...)
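+				// At every level below its own, the key acts as its own closest
+				// parent, so the path is extended with the key's own vector hash.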
+				keyTupleBuilder.PutByteString(0, hashPath)
+				keyTupleBuilder.PutByteString(1, keyToInsert)
+
+				childKeyTuple := keyTupleBuilder.Build(b.ns.Pool())
+				err = pathMaps[childLevel].Put(ctx, childKeyTuple, levelMapValue)
+				if err != nil {
+					return nil, err
+				}
+
+				if childLevel == 0 {
+					break
+				}
+				childLevel--
+			}
+		}
+
+		levelMapKey, levelMapValue, err = levelMapIter.Next(ctx)
+		if err == io.EOF {
+			return pathMaps, nil
+		}
+		if err != nil {
+			return nil, err
+		}
+	}
+}
+
+// createInitialPathMaps creates a list of MutableMaps that will eventually store a single level of the to-be-built ProximityMap
+func (b *ProximityMapBuilder) createInitialPathMaps(ctx context.Context, maxLevel uint8) (pathMaps []*MutableMap, keyTupleBuilder, prefixTupleBuilder *val.TupleBuilder, err error) {
+	pathMaps = make([]*MutableMap, maxLevel+1)
+
+	pathMapKeyDescTypes := []val.Type{{Enc: val.ByteStringEnc, Nullable: false}, {Enc: val.ByteStringEnc, Nullable: false}}
+
+	pathMapKeyDesc := val.NewTupleDescriptor(pathMapKeyDescTypes...)
+
+	emptyPathMap, err := NewMapFromTuples(ctx, b.ns, pathMapKeyDesc, b.valDesc)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+
+	keyTupleBuilder = val.NewTupleBuilder(pathMapKeyDesc)
+	prefixTupleBuilder = val.NewTupleBuilder(val.NewTupleDescriptor(pathMapKeyDescTypes[0]))
+
+	for i := uint8(0); i <= maxLevel; i++ {
+		pathMaps[i] = newMutableMap(emptyPathMap)
+	}
+
+	return pathMaps, keyTupleBuilder, prefixTupleBuilder, nil
+}
+
+// getNextPathSegmentCandidates takes a list of keys, representing a path into the ProximityMap from the root.
+// It returns an iter over all possible keys that could be the next path segment.
+func (b *ProximityMapBuilder) getNextPathSegmentCandidates(ctx context.Context, pathMap *MutableMap, prefixTupleBuilder *val.TupleBuilder, currentPath []byte) (MapIter, error) {
+	prefixTupleBuilder.PutByteString(0, currentPath)
+	prefixTuple := prefixTupleBuilder.Build(b.ns.Pool())
+
+	prefixRange := PrefixRange(prefixTuple, prefixTupleBuilder.Desc)
+	return pathMap.IterRange(ctx, prefixRange)
+}
+
+// getClosestVector iterates over a range of candidate vectors to determine which one is the closest to the target.
+func (b *ProximityMapBuilder) getClosestVector(ctx context.Context, targetVector []float64, nextCandidate func() (candidate hash.Hash, err error, valid bool)) (hash.Hash, error) {
+	// First call to nextCandidate is guaranteed to be valid because there's at least one vector in the set.
+ // (non-root nodes inherit the first vector from their parent) + candidateVectorHash, err, _ := nextCandidate() + if err != nil { + return hash.Hash{}, err + } + + candidateVector, err := getVectorFromHash(ctx, b.ns, candidateVectorHash) + if err != nil { + return hash.Hash{}, err + } + closestVectorHash := candidateVectorHash + closestDistance, err := b.distanceType.Eval(targetVector, candidateVector) + if err != nil { + return hash.Hash{}, err + } + + for { + candidateVectorHash, err, valid := nextCandidate() + if err != nil { + return hash.Hash{}, err + } + if !valid { + return closestVectorHash, nil + } + candidateVector, err = getVectorFromHash(ctx, b.ns, candidateVectorHash) + if err != nil { + return hash.Hash{}, err + } + candidateDistance, err := b.distanceType.Eval(targetVector, candidateVector) + if err != nil { + return hash.Hash{}, err + } + if candidateDistance < closestDistance { + closestVectorHash = candidateVectorHash + closestDistance = candidateDistance + } + } +} + +// makeProximityMapFromPathMaps builds a ProximityMap from a list of maps, each of which corresponds to a different tree level. +func (b *ProximityMapBuilder) makeProximityMapFromPathMaps(ctx context.Context, pathMaps []*MutableMap) (proximityMap ProximityMap, err error) { + maxLevel := len(pathMaps) - 1 + + // We create a chain of vectorIndexChunker objects, with the leaf row at the tail. + // Because the root node has no parent, the logic is slightly different. We don't make a vectorIndexChunker for it. + var chunker *vectorIndexChunker + for _, pathMap := range pathMaps[:maxLevel] { + chunker, err = newVectorIndexChunker(ctx, pathMap, chunker) + if err != nil { + return ProximityMap{}, err + } + } + + rootPathMap := pathMaps[maxLevel] + topLevelPathMapIter, err := rootPathMap.IterAll(ctx) + if err != nil { + return ProximityMap{}, err + } + var topLevelKeys [][]byte + var topLevelValues [][]byte + var topLevelSubtrees []uint64 + for { + key, _, err := topLevelPathMapIter.Next(ctx) + if err == io.EOF { + break + } + if err != nil { + return ProximityMap{}, err + } + originalKey, _ := rootPathMap.keyDesc.GetBytes(1, key) + path, _ := b.keyDesc.GetJSONAddr(0, originalKey) + _, nodeCount, nodeHash, err := chunker.Next(ctx, b.ns, b.vectorIndexSerializer, path, maxLevel-1, 1, b.keyDesc) + if err != nil { + return ProximityMap{}, err + } + topLevelKeys = append(topLevelKeys, originalKey) + topLevelValues = append(topLevelValues, nodeHash[:]) + topLevelSubtrees = append(topLevelSubtrees, nodeCount) + } + return b.makeRootNode(ctx, topLevelKeys, topLevelValues, topLevelSubtrees, maxLevel) +} + +func getJsonValueFromHash(ctx context.Context, ns tree.NodeStore, h hash.Hash) (sql.JSONWrapper, error) { + return tree.NewJSONDoc(h, ns).ToIndexedJSONDocument(ctx) +} + +func getVectorFromHash(ctx context.Context, ns tree.NodeStore, h hash.Hash) ([]float64, error) { + otherValue, err := getJsonValueFromHash(ctx, ns, h) + if err != nil { + return nil, err + } + return sql.ConvertToVector(otherValue) +} diff --git a/go/store/prolly/proximity_map_test.go b/go/store/prolly/proximity_map_test.go new file mode 100644 index 00000000000..0b1ef475852 --- /dev/null +++ b/go/store/prolly/proximity_map_test.go @@ -0,0 +1,744 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prolly + +import ( + "context" + "fmt" + "io" + "math" + "math/rand" + "os" + "strconv" + "strings" + "testing" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression/function/vector" + "github.com/dolthub/go-mysql-server/sql/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dolthub/dolt/go/store/hash" + "github.com/dolthub/dolt/go/store/pool" + "github.com/dolthub/dolt/go/store/prolly/tree" + "github.com/dolthub/dolt/go/store/val" +) + +func newJsonValue(t *testing.T, v interface{}) sql.JSONWrapper { + doc, _, err := types.JSON.Convert(v) + require.NoError(t, err) + return doc.(sql.JSONWrapper) +} + +// newJsonDocument creates a JSON value from a provided value. +func newJsonDocument(t *testing.T, ctx context.Context, ns tree.NodeStore, v interface{}) hash.Hash { + doc := newJsonValue(t, v) + root, err := tree.SerializeJsonToAddr(ctx, ns, doc) + require.NoError(t, err) + return root.HashOf() +} + +var testKeyDesc = val.NewTupleDescriptor( + val.Type{Enc: val.JSONAddrEnc, Nullable: true}, +) + +var testValDesc = val.NewTupleDescriptor( + val.Type{Enc: val.Int64Enc, Nullable: true}, +) + +func buildTuple(t *testing.T, ctx context.Context, ns tree.NodeStore, pool pool.BuffPool, desc val.TupleDesc, row []interface{}) val.Tuple { + builder := val.NewTupleBuilder(desc) + for i, column := range row { + err := tree.PutField(ctx, ns, builder, i, column) + require.NoError(t, err) + } + return builder.Build(pool) +} + +func buildTuples(t *testing.T, ctx context.Context, ns tree.NodeStore, pool pool.BuffPool, desc val.TupleDesc, rows [][]interface{}) [][]byte { + result := make([][]byte, len(rows)) + for i, row := range rows { + result[i] = buildTuple(t, ctx, ns, pool, desc, row) + } + return result +} + +func createAndValidateProximityMap(t *testing.T, ctx context.Context, ns tree.NodeStore, keyDesc val.TupleDesc, keyBytes [][]byte, valueDesc val.TupleDesc, valueBytes [][]byte, logChunkSize uint8) ProximityMap { + m := createProximityMap(t, ctx, ns, keyDesc, keyBytes, valueDesc, valueBytes, logChunkSize) + validateProximityMapSkipHistoryIndependenceCheck(t, ctx, ns, &m, testKeyDesc, testValDesc, keyBytes, valueBytes) + return m +} + +func createProximityMap(t *testing.T, ctx context.Context, ns tree.NodeStore, keyDesc val.TupleDesc, keyBytes [][]byte, valueDesc val.TupleDesc, valueBytes [][]byte, logChunkSize uint8) ProximityMap { + count := len(keyBytes) + require.Equal(t, count, len(valueBytes)) + + distanceType := vector.DistanceL2Squared{} + + builder, err := NewProximityMapBuilder(ctx, ns, distanceType, keyDesc, valueDesc, logChunkSize) + require.NoError(t, err) + + for i, key := range keyBytes { + value := valueBytes[i] + err = builder.Insert(ctx, key, value) + require.NoError(t, err) + } + + m, err := builder.Flush(ctx) + require.NoError(t, err) + + mapCount, err := m.Count() + require.NoError(t, err) + require.Equal(t, count, mapCount) + + return m +} + +func validateProximityMap(t *testing.T, ctx context.Context, ns tree.NodeStore, m *ProximityMap, keyDesc, valDesc 
val.TupleDesc, keys, values [][]byte, logChunkSize uint8) {
+	validateProximityMapSkipHistoryIndependenceCheck(t, ctx, ns, m, keyDesc, valDesc, keys, values)
+	validateHistoryIndependence(t, ctx, ns, m, keyDesc, keys, valDesc, values, logChunkSize)
+}
+
+func validateProximityMapSkipHistoryIndependenceCheck(t *testing.T, ctx context.Context, ns tree.NodeStore, m *ProximityMap, keyDesc, valDesc val.TupleDesc, keys, values [][]byte) {
+	expectedSize := len(keys)
+	actualSize, err := m.Count()
+	require.NoError(t, err)
+	require.Equal(t, expectedSize, actualSize)
+	// Check that every key and value appears in the map exactly once.
+	matches := 0
+	for i := 0; i < actualSize; i++ {
+		err = m.Get(ctx, keys[i], func(foundKey val.Tuple, foundValue val.Tuple) error {
+			require.Equal(t, val.Tuple(keys[i]), foundKey)
+			require.Equal(t, val.Tuple(values[i]), foundValue)
+			matches++
+			return nil
+		})
+		require.NoError(t, err)
+	}
+	require.Equal(t, expectedSize, matches)
+
+	// Check that the invariant holds: each vector is closer to its parent than any of its uncles.
+	err = tree.WalkNodes(ctx, m.tuples.Root, ns, func(ctx context.Context, nd tree.Node) error {
+		validateProximityMapNode(t, ctx, ns, nd, vector.DistanceL2Squared{}, keyDesc, valDesc)
+		return nil
+	})
+	require.NoError(t, err)
+}
+
+func validateHistoryIndependence(t *testing.T, ctx context.Context, ns tree.NodeStore, m *ProximityMap, keyDesc val.TupleDesc, keyBytes [][]byte, valueDesc val.TupleDesc, valueBytes [][]byte, logChunkSize uint8) {
+	// Build a new map with the supplied keys and values and confirm that it has the same root hash.
+	other := createProximityMap(t, ctx, ns, keyDesc, keyBytes, valueDesc, valueBytes, logChunkSize)
+	require.Equal(t, other.HashOf(), m.HashOf())
+}
+
+func vectorFromKey(t *testing.T, ctx context.Context, ns tree.NodeStore, keyDesc val.TupleDesc, key []byte) []float64 {
+	vectorHash, _ := keyDesc.GetJSONAddr(0, key)
+	jsonWrapper, err := getJsonValueFromHash(ctx, ns, vectorHash)
+	require.NoError(t, err)
+	floats, err := sql.ConvertToVector(jsonWrapper)
+	require.NoError(t, err)
+	return floats
+}
+
+func validateProximityMapNode(t *testing.T, ctx context.Context, ns tree.NodeStore, nd tree.Node, distanceType vector.DistanceType, keyDesc val.TupleDesc, desc val.TupleDesc) {
+	// For each node, the node's grandchildren should be closer to their parent than the other children.
+	if nd.Level() == 0 {
+		// Leaf node
+		return
+	}
+	if nd.Count() <= 1 {
+		// A node with only one child is trivially valid.
+ return + } + // Get the vector in each key + vectors := make([][]float64, nd.Count()) + for vectorIdx := 0; vectorIdx < nd.Count(); vectorIdx++ { + vectorKey := nd.GetKey(vectorIdx) + vectors[vectorIdx] = vectorFromKey(t, ctx, ns, keyDesc, vectorKey) + } + for childIdx := 0; childIdx < nd.Count(); childIdx++ { + // Get the child node + childHash := hash.New(nd.GetValue(childIdx)) + childNode, err := ns.Read(ctx, childHash) + require.NoError(t, err) + for childKeyIdx := 0; childKeyIdx < childNode.Count(); childKeyIdx++ { + childVectorKey := childNode.GetKey(childKeyIdx) + childVector := vectorFromKey(t, ctx, ns, keyDesc, childVectorKey) + minDistance := math.MaxFloat64 + closestKeyIdx := -1 + for otherChildIdx := 0; otherChildIdx < nd.Count(); otherChildIdx++ { + distance, err := distanceType.Eval(childVector, vectors[otherChildIdx]) + require.NoError(t, err) + if distance < minDistance { + minDistance = distance + closestKeyIdx = otherChildIdx + } + } + require.Equal(t, closestKeyIdx, childIdx) + } + } +} + +func TestEmptyProximityMap(t *testing.T) { + ctx := context.Background() + ns := tree.NewTestNodeStore() + createAndValidateProximityMap(t, ctx, ns, testKeyDesc, nil, testValDesc, nil, 10) +} + +func TestSingleEntryProximityMap(t *testing.T) { + ctx := context.Background() + ns := tree.NewTestNodeStore() + pb := pool.NewBuffPool() + keys := buildTuples(t, ctx, ns, pb, testKeyDesc, [][]interface{}{{"[1.0]"}}) + values := buildTuples(t, ctx, ns, pb, testValDesc, [][]interface{}{{int64(1)}}) + createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys, testValDesc, values, 10) +} + +func TestDoubleEntryProximityMapGetExact(t *testing.T) { + ctx := context.Background() + ns := tree.NewTestNodeStore() + pb := pool.NewBuffPool() + + keyRows := [][]interface{}{{"[0.0, 6.0]"}, {"[3.0, 4.0]"}} + keys := buildTuples(t, ctx, ns, pb, testKeyDesc, keyRows) + + valueRows := [][]interface{}{{int64(1)}, {int64(2)}} + values := buildTuples(t, ctx, ns, pb, testValDesc, valueRows) + + m := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys, testValDesc, values, 10) + matches := 0 + for i, key := range keys { + err := m.Get(ctx, key, func(foundKey val.Tuple, foundValue val.Tuple) error { + require.Equal(t, val.Tuple(key), foundKey) + require.Equal(t, val.Tuple(values[i]), foundValue) + matches++ + return nil + }) + require.NoError(t, err) + } + require.Equal(t, matches, len(keys)) +} + +func TestDoubleEntryProximityMapGetClosest(t *testing.T) { + ctx := context.Background() + ns := tree.NewTestNodeStore() + pb := pool.NewBuffPool() + + keyRows := [][]interface{}{{"[0.0, 6.0]"}, {"[3.0, 4.0]"}} + keys := buildTuples(t, ctx, ns, pb, testKeyDesc, keyRows) + + valueRows := [][]interface{}{{int64(1)}, {int64(2)}} + values := buildTuples(t, ctx, ns, pb, testValDesc, valueRows) + + m := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys, testValDesc, values, 10) + + matches := 0 + + mapIter, err := m.GetClosest(ctx, newJsonValue(t, "[0.0, 0.0]"), 1) + require.NoError(t, err) + for { + k, v, err := mapIter.Next(ctx) + if err == io.EOF { + break + } + require.NoError(t, err) + require.Equal(t, val.Tuple(keys[1]), k) + require.Equal(t, val.Tuple(values[1]), v) + matches++ + } + + require.NoError(t, err) + require.Equal(t, matches, 1) +} + +func TestProximityMapGetManyClosest(t *testing.T) { + ctx := context.Background() + ns := tree.NewTestNodeStore() + pb := pool.NewBuffPool() + + keyRows := [][]interface{}{ + {"[0.0, 0.0]"}, + {"[0.0, 10.0]"}, + {"[10.0, 10.0]"}, + {"[10.0, 0.0]"}, + } + keys 
+
+	valueRows := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}}
+	values := buildTuples(t, ctx, ns, pb, testValDesc, valueRows)
+
+	m := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys, testValDesc, values, 10)
+
+	queryVector := "[3.0, 1.0]"
+	sortOrder := []int{0, 3, 1, 2} // indexes in sorted order: [0.0, 0.0], [10.0, 0.0], [0.0, 10.0], [10.0, 10.0]
+
+	for limit := 0; limit <= 4; limit++ {
+		t.Run(fmt.Sprintf("limit %d", limit), func(t *testing.T) {
+			matches := 0
+
+			mapIter, err := m.GetClosest(ctx, newJsonValue(t, queryVector), limit)
+			require.NoError(t, err)
+			for {
+				k, v, err := mapIter.Next(ctx)
+				if err == io.EOF {
+					break
+				}
+				require.NoError(t, err)
+				require.Equal(t, val.Tuple(keys[sortOrder[matches]]), k)
+				require.Equal(t, val.Tuple(values[sortOrder[matches]]), v)
+				matches++
+			}
+			require.NoError(t, err)
+			require.Equal(t, limit, matches)
+		})
+	}
+}
+
+func TestProximityMapWithOverflowNode(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	pb := pool.NewBuffPool()
+
+	// Create an index with enough rows that it can't fit in a single physical chunk
+	keyRows := make([][]interface{}, 0, 4000)
+	valueRows := make([][]interface{}, 0, 4000)
+
+	for i := int64(0); i < 4000; i++ {
+		keyRows = append(keyRows, []interface{}{fmt.Sprintf("[%d]", i)})
+		valueRows = append(valueRows, []interface{}{i})
+	}
+
+	keys := buildTuples(t, ctx, ns, pb, testKeyDesc, keyRows)
+	values := buildTuples(t, ctx, ns, pb, testValDesc, valueRows)
+
+	// Set logChunkSize high enough that every row lands in a single node, forcing that node
+	// to be stored as an overflow node.
+	m := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys, testValDesc, values, 16)
+
+	count, err := m.Count()
+	require.NoError(t, err)
+	require.Equal(t, 4000, count)
+}
+
+func TestMultilevelProximityMap(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	pb := pool.NewBuffPool()
+
+	keyRows := [][]interface{}{
+		{"[0.0, 1.0]"},
+		{"[3.0, 4.0]"},
+		{"[5.0, 6.0]"},
+		{"[7.0, 8.0]"},
+	}
+	keys := buildTuples(t, ctx, ns, pb, testKeyDesc, keyRows)
+
+	valueRows := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}}
+	values := buildTuples(t, ctx, ns, pb, testValDesc, valueRows)
+
+	m := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys, testValDesc, values, 1)
+	matches := 0
+	for i, key := range keys {
+		err := m.Get(ctx, key, func(foundKey val.Tuple, foundValue val.Tuple) error {
+			require.Equal(t, val.Tuple(key), foundKey)
+			require.Equal(t, val.Tuple(values[i]), foundValue)
+			matches++
+			return nil
+		})
+		require.NoError(t, err)
+	}
+	require.Equal(t, len(keys), matches)
+}
+
+func TestLargerMultilevelProximityMap(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	pb := pool.NewBuffPool()
+
+	keyRows := [][]interface{}{
+		{"[0.0, 1.0]"},
+		{"[3.0, 4.0]"},
+		{"[5.0, 6.0]"},
+		{"[7.0, 8.0]"},
+		{"[9.0, 10.0]"},
+		{"[11.0, 12.0]"},
+		{"[13.0, 14.0]"},
+		{"[15.0, 16.0]"},
+	}
+	keys := buildTuples(t, ctx, ns, pb, testKeyDesc, keyRows)
+
+	valueRows := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}, {int64(5)}, {int64(6)}, {int64(7)}, {int64(8)}}
+	values := buildTuples(t, ctx, ns, pb, testValDesc, valueRows)
+
+	m := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys, testValDesc, values, 1)
+	matches := 0
+	for i, key := range keys {
+		err := m.Get(ctx, key, func(foundKey val.Tuple, foundValue val.Tuple) error {
+			require.Equal(t, val.Tuple(key), foundKey)
+			require.Equal(t, val.Tuple(values[i]), foundValue)
+			matches++
+			return nil
+		})
+		require.NoError(t, err)
+	}
+	require.Equal(t, len(keys), matches)
+}
+
+func TestInsertOrderIndependence(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	pb := pool.NewBuffPool()
+
+	keyRows1 := [][]interface{}{
+		{"[0.0, 1.0]"},
+		{"[3.0, 4.0]"},
+		{"[5.0, 6.0]"},
+		{"[7.0, 8.0]"},
+	}
+	keys1 := buildTuples(t, ctx, ns, pb, testKeyDesc, keyRows1)
+
+	valueRows1 := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}}
+	values1 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows1)
+
+	keyRows2 := [][]interface{}{
+		{"[7.0, 8.0]"},
+		{"[5.0, 6.0]"},
+		{"[3.0, 4.0]"},
+		{"[0.0, 1.0]"},
+	}
+	keys2 := buildTuples(t, ctx, ns, pb, testKeyDesc, keyRows2)
+
+	valueRows2 := [][]interface{}{{int64(4)}, {int64(3)}, {int64(2)}, {int64(1)}}
+	values2 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows2)
+
+	m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, 1)
+	m2 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys2, testValDesc, values2, 1)
+
+	if !assert.Equal(t, m1.tuples.Root.HashOf(), m2.tuples.Root.HashOf(), "trees have different hashes") {
+		require.NoError(t, tree.OutputProllyNodeBytes(os.Stdout, m1.tuples.Root))
+		require.NoError(t, tree.OutputProllyNodeBytes(os.Stdout, m2.tuples.Root))
+	}
+}
+
+func TestIncrementalInserts(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	pb := pool.NewBuffPool()
+	logChunkSize := uint8(1)
+	distanceType := vector.DistanceL2Squared{}
+	flusher := ProximityFlusher{logChunkSize: logChunkSize, distanceType: distanceType}
+	keyRows1 := [][]interface{}{
+		{"[0.0, 1.0]"},
+		{"[3.0, 4.0]"},
+		{"[5.0, 6.0]"},
+		{"[7.0, 8.0]"},
+	}
+	keys1 := buildTuples(t, ctx, ns, pb, testKeyDesc, keyRows1)
+
+	valueRows1 := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}}
+	values1 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows1)
+
+	m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, logChunkSize)
+
+	mutableMap := newProximityMutableMap(m1)
+
+	keyRows2 := [][]interface{}{
+		{"[9.0, 10.0]"},
+		{"[11.0, 12.0]"},
+		{"[13.0, 14.0]"},
+		{"[15.0, 16.0]"},
+	}
+	keys2 := buildTuples(t, ctx, ns, pb, testKeyDesc, keyRows2)
+
+	valueRows2 := [][]interface{}{{int64(5)}, {int64(6)}, {int64(7)}, {int64(8)}}
+	values2 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows2)
+
+	for i, key := range keys2 {
+		err := mutableMap.Put(ctx, key, values2[i])
+		require.NoError(t, err)
+	}
+
+	// Check that map looks how we expect.
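+	// Map flushes the pending edits into a new immutable map; m1 itself is unchanged.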
+	newMap, err := flusher.Map(ctx, mutableMap)
+	require.NoError(t, err)
+
+	combinedKeyRows := [][]interface{}{
+		{"[0.0, 1.0]"},
+		{"[3.0, 4.0]"},
+		{"[5.0, 6.0]"},
+		{"[7.0, 8.0]"},
+		{"[9.0, 10.0]"},
+		{"[11.0, 12.0]"},
+		{"[13.0, 14.0]"},
+		{"[15.0, 16.0]"},
+	}
+	combinedKeys := buildTuples(t, ctx, ns, pb, testKeyDesc, combinedKeyRows)
+
+	combinedValueRows := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}, {int64(5)}, {int64(6)}, {int64(7)}, {int64(8)}}
+	combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
+
+	validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
+}
+
+func TestIncrementalUpdates(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	pb := pool.NewBuffPool()
+	logChunkSize := uint8(1)
+	distanceType := vector.DistanceL2Squared{}
+	flusher := ProximityFlusher{logChunkSize: logChunkSize, distanceType: distanceType}
+	keyRows1 := [][]interface{}{
+		{"[0.0, 1.0]"},
+		{"[3.0, 4.0]"},
+		{"[5.0, 6.0]"},
+		{"[7.0, 8.0]"},
+	}
+	keys1 := buildTuples(t, ctx, ns, pb, testKeyDesc, keyRows1)
+
+	valueRows1 := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}}
+	values1 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows1)
+
+	m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, logChunkSize)
+
+	mutableMap := newProximityMutableMap(m1)
+
+	bp := pool.NewBuffPool()
+
+	keyBuilder := val.NewTupleBuilder(testKeyDesc)
+	valueBuilder := val.NewTupleBuilder(testValDesc)
+
+	// update leaf node
+	{
+		keyBuilder.PutJSONAddr(0, newJsonDocument(t, ctx, ns, "[0.0, 1.0]"))
+		nextKey := keyBuilder.Build(bp)
+
+		valueBuilder.PutInt64(0, 5)
+		nextValue := valueBuilder.Build(bp)
+
+		err := mutableMap.Put(ctx, nextKey, nextValue)
+		require.NoError(t, err)
+
+		newMap, err := flusher.Map(ctx, mutableMap)
+		require.NoError(t, err)
+
+		newCount, err := newMap.Count()
+		require.NoError(t, err)
+
+		require.Equal(t, 4, newCount)
+
+		// validate
+		combinedKeyRows := [][]interface{}{
+			{"[0.0, 1.0]"},
+			{"[3.0, 4.0]"},
+			{"[5.0, 6.0]"},
+			{"[7.0, 8.0]"},
+		}
+		combinedKeys := buildTuples(t, ctx, ns, pb, testKeyDesc, combinedKeyRows)
+		combinedValueRows := [][]interface{}{{int64(5)}, {int64(2)}, {int64(3)}, {int64(4)}}
+		combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
+
+		validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
+	}
+
+	// update root node
+	{
+		keyBuilder.PutJSONAddr(0, newJsonDocument(t, ctx, ns, "[5.0, 6.0]"))
+		nextKey := keyBuilder.Build(bp)
+
+		valueBuilder.PutInt64(0, 6)
+		nextValue := valueBuilder.Build(bp)
+
+		err := mutableMap.Put(ctx, nextKey, nextValue)
+		require.NoError(t, err)
+
+		newMap, err := flusher.Map(ctx, mutableMap)
+		require.NoError(t, err)
+
+		combinedKeyRows := [][]interface{}{
+			{"[0.0, 1.0]"},
+			{"[3.0, 4.0]"},
+			{"[5.0, 6.0]"},
+			{"[7.0, 8.0]"},
+		}
+		combinedKeys := buildTuples(t, ctx, ns, pb, testKeyDesc, combinedKeyRows)
+		combinedValueRows := [][]interface{}{{int64(5)}, {int64(2)}, {int64(6)}, {int64(4)}}
+		combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
+
+		validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
+	}
+}
+
+func TestIncrementalDeletes(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	pb := pool.NewBuffPool()
+	logChunkSize := uint8(1)
+	distanceType := vector.DistanceL2Squared{}
+	flusher := ProximityFlusher{logChunkSize: logChunkSize, distanceType: distanceType}
+	keyRows1 := [][]interface{}{
+		{"[0.0, 1.0]"},
+		{"[3.0, 4.0]"},
+		{"[5.0, 6.0]"},
+		{"[7.0, 8.0]"},
+	}
+	keys1 := buildTuples(t, ctx, ns, pb, testKeyDesc, keyRows1)
+
+	valueRows1 := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}}
+	values1 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows1)
+
+	m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, logChunkSize)
+
+	mutableMap := newProximityMutableMap(m1)
+
+	bp := pool.NewBuffPool()
+
+	keyBuilder := val.NewTupleBuilder(testKeyDesc)
+
+	// delete leaf node
+	{
+		keyBuilder.PutJSONAddr(0, newJsonDocument(t, ctx, ns, "[0.0, 1.0]"))
+		nextKey := keyBuilder.Build(bp)
+
+		err := mutableMap.Put(ctx, nextKey, nil)
+		require.NoError(t, err)
+
+		newMap, err := flusher.Map(ctx, mutableMap)
+		require.NoError(t, err)
+
+		combinedKeyRows := [][]interface{}{
+			{"[3.0, 4.0]"},
+			{"[5.0, 6.0]"},
+			{"[7.0, 8.0]"},
+		}
+		combinedKeys := buildTuples(t, ctx, ns, pb, testKeyDesc, combinedKeyRows)
+		combinedValueRows := [][]interface{}{{int64(2)}, {int64(3)}, {int64(4)}}
+		combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
+
+		validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
+	}
+
+	// delete root node
+	{
+		keyBuilder.PutJSONAddr(0, newJsonDocument(t, ctx, ns, "[5.0, 6.0]"))
+		nextKey := keyBuilder.Build(bp)
+
+		err := mutableMap.Put(ctx, nextKey, nil)
+		require.NoError(t, err)
+
+		newMap, err := flusher.Map(ctx, mutableMap)
+		require.NoError(t, err)
+
+		combinedKeyRows := [][]interface{}{
+			{"[3.0, 4.0]"},
+			{"[7.0, 8.0]"},
+		}
+		combinedKeys := buildTuples(t, ctx, ns, pb, testKeyDesc, combinedKeyRows)
+		combinedValueRows := [][]interface{}{{int64(2)}, {int64(4)}}
+		combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
+
+		validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
+	}
+}
+
+// As part of the algorithm for building proximity maps, we store the map keys as bytestrings in a temporary table.
+// The sorting order of a key is not always the same as the lexicographic ordering of these bytestrings.
+// This test makes sure that we still generate correct output even when the two orderings disagree.
+func TestNonlexographicKey(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	pb := pool.NewBuffPool()
+
+	keyDesc := val.NewTupleDescriptor(
+		val.Type{Enc: val.JSONAddrEnc, Nullable: true},
+		val.Type{Enc: val.Int64Enc, Nullable: true},
+	)
+
+	valDesc := val.NewTupleDescriptor()
+
+	keyRows := [][]interface{}{
+		{"[0.0, 0.0]", int64(4 + 0*256)},
+		{"[0.0, 0.0]", int64(3 + 1*256)},
+		{"[0.0, 0.0]", int64(2 + 2*256)},
+		{"[0.0, 0.0]", int64(1 + 3*256)},
+		{"[0.0, 0.0]", int64(0 + 4*256)},
+	}
+	keys := buildTuples(t, ctx, ns, pb, keyDesc, keyRows)
+
+	valueRows := [][]interface{}{{}, {}, {}, {}, {}}
+	values := buildTuples(t, ctx, ns, pb, valDesc, valueRows)
+
+	// The validation test as currently written assumes that all vectors are unique, but this is not a
+	// requirement. Skip validation for now.
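+	// All five keys share the same vector, so their relative order is decided entirely by the
+	// Int64 field, which is constructed so that numeric order and raw-byte order can disagree.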
+	_ = createProximityMap(t, ctx, ns, keyDesc, keys, valDesc, values, 1)
+}
+
+func TestManyDimensions(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	numRows := 50
+	dimensions := 50
+	testManyDimensions(ctx, t, ns, numRows, dimensions)
+}
+
+func testManyDimensions(ctx context.Context, t *testing.T, ns tree.NodeStore, numRows int, dimensions int) {
+	pb := pool.NewBuffPool()
+	keyDesc := val.NewTupleDescriptor(
+		val.Type{Enc: val.JSONAddrEnc, Nullable: true},
+		val.Type{Enc: val.Int64Enc, Nullable: true},
+	)
+
+	valDesc := val.NewTupleDescriptor()
+
+	t.Run(fmt.Sprintf("numRows = %d, dimensions = %d", numRows, dimensions), func(t *testing.T) {
+		keyRows := make([][]interface{}, numRows)
+		valueRows := make([][]interface{}, numRows)
+		for i := 0; i < numRows; i++ {
+			keyRows[i] = []interface{}{makeManyDimensionalVector(dimensions, int64(i)), i}
+			valueRows[i] = []interface{}{}
+		}
+		keys := buildTuples(t, ctx, ns, pb, keyDesc, keyRows)
+		values := buildTuples(t, ctx, ns, pb, valDesc, valueRows)
+
+		_ = createAndValidateProximityMap(t, ctx, ns, keyDesc, keys, valDesc, values, 3)
+	})
+}
+
+func makeManyDimensionalVector(dimensions int, seed int64) interface{} {
+	var builder strings.Builder
+	rng := rand.New(rand.NewSource(seed))
+
+	builder.WriteRune('[')
+	if dimensions > 0 {
+		builder.WriteString(strconv.Itoa(rng.Int()))
+		for d := 1; d < dimensions; d++ {
+			builder.WriteRune(',')
+			builder.WriteString(strconv.Itoa(rng.Int()))
+		}
+	}
+	builder.WriteRune(']')
+	return builder.String()
+}
diff --git a/go/store/prolly/proximity_mutable_map.go b/go/store/prolly/proximity_mutable_map.go
new file mode 100644
index 00000000000..961f53699d6
--- /dev/null
+++ b/go/store/prolly/proximity_mutable_map.go
@@ -0,0 +1,460 @@
+// Copyright 2024 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package prolly
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/dolthub/go-mysql-server/sql"
+	"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
+
+	"github.com/dolthub/dolt/go/gen/fb/serial"
+	"github.com/dolthub/dolt/go/store/hash"
+	"github.com/dolthub/dolt/go/store/prolly/message"
+	"github.com/dolthub/dolt/go/store/prolly/tree"
+	"github.com/dolthub/dolt/go/store/skip"
+	"github.com/dolthub/dolt/go/store/val"
+)
+
+type ProximityMutableMap = GenericMutableMap[ProximityMap, tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]]
+
+type ProximityFlusher struct {
+	logChunkSize uint8
+	distanceType vector.DistanceType
+}
+
+var _ MutableMapFlusher[ProximityMap, tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]] = ProximityFlusher{}
+
+func (f ProximityFlusher) ApplyMutationsWithSerializer(
+	ctx context.Context,
+	serializer message.Serializer,
+	mutableMap *GenericMutableMap[ProximityMap, tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]],
+) (tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc], error) {
+	// Identify what parts of the tree need to be rebuilt:
+	// For each edit, identify the node closest to the root that is affected.
+	// Then, walk the tree creating a new one.
+	// In order to skip walking parts of the tree that aren't modified, we need to know when a node
+	// has no edits in any of its children.
+	// We can have a cursor that fast-forwards to the affected value.
+	// - How does this work with inserts?
+	// Do it recursively, starting with the root. Sort each edit into the affected child node (or the current node).
+	// If the current node is affected, rebuild.
+	// Otherwise visit each child node.
+	keyDesc := mutableMap.keyDesc
+	valDesc := mutableMap.valDesc
+	ns := mutableMap.NodeStore()
+	convert := func(ctx context.Context, bytes []byte) []float64 {
+		h, _ := keyDesc.GetJSONAddr(0, bytes)
+		doc := tree.NewJSONDoc(h, ns)
+		jsonWrapper, err := doc.ToIndexedJSONDocument(ctx)
+		if err != nil {
+			panic(err)
+		}
+		floats, err := sql.ConvertToVector(jsonWrapper)
+		if err != nil {
+			panic(err)
+		}
+		return floats
+	}
+	edits := make([]VectorIndexKV, 0, mutableMap.tuples.Edits.Count())
+	editIter := mutableMap.tuples.Mutations()
+	key, value := editIter.NextMutation(ctx)
+	maxEditLevel := uint8(0)
+	for key != nil {
+		keyLevel := tree.DeterministicHashLevel(f.logChunkSize, key)
+		if keyLevel > maxEditLevel {
+			maxEditLevel = keyLevel
+		}
+		edits = append(edits, VectorIndexKV{
+			key:   key,
+			value: value,
+			level: int(keyLevel),
+		})
+		key, value = editIter.NextMutation(ctx)
+	}
+	var newRoot tree.Node
+	var err error
+	root := mutableMap.tuples.Static.Root
+	distanceType := mutableMap.tuples.Static.DistanceType
+	if root.Count() == 0 {
+		// Original index was empty. We need to make a new index based on the edits.
+		newRoot, err = makeNewProximityMap(ctx, ns, edits, distanceType, keyDesc, valDesc, f.logChunkSize)
+	} else if maxEditLevel >= uint8(root.Level()) {
+		// The root node has changed, or there may be a new level to the tree. We need to rebuild the tree.
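+		// rebuildNode re-inserts every surviving entry of the subtree, so this path is O(n) in
+		// the size of the tree; it only runs when some edit hashes to the root's level or above.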
+		newRoot, _, err = f.rebuildNode(ctx, ns, root, edits, distanceType, keyDesc, valDesc, maxEditLevel)
+	} else {
+		newRoot, _, err = f.visitNode(ctx, serializer, ns, root, edits, convert, distanceType, keyDesc, valDesc)
+	}
+	if err != nil {
+		return tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]{}, err
+	}
+	return tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]{
+		Root:         newRoot,
+		NodeStore:    ns,
+		DistanceType: distanceType,
+		Convert:      convert,
+		Order:        keyDesc,
+	}, nil
+}
+
+type VectorIndexKV struct {
+	key, value tree.Item
+	level      int
+}
+
+type childEditList struct {
+	edits       []VectorIndexKV
+	mustRebuild bool
+}
+
+func makeNewProximityMap(
+	ctx context.Context,
+	ns tree.NodeStore,
+	edits []VectorIndexKV,
+	distanceType vector.DistanceType,
+	keyDesc val.TupleDesc,
+	valDesc val.TupleDesc,
+	logChunkSize uint8,
+) (newNode tree.Node, err error) {
+	proximityMapBuilder, err := NewProximityMapBuilder(ctx, ns, distanceType, keyDesc, valDesc, logChunkSize)
+	if err != nil {
+		return tree.Node{}, err
+	}
+	for _, edit := range edits {
+		// The original index was empty, so a delete (a nil value) is a no-op and every other edit is an insert.
+		if edit.value != nil {
+			err = proximityMapBuilder.InsertAtLevel(ctx, edit.key, edit.value, uint8(edit.level))
+			if err != nil {
+				return tree.Node{}, err
+			}
+		}
+	}
+	proximityMap, err := proximityMapBuilder.Flush(ctx)
+	if err != nil {
+		return tree.Node{}, err
+	}
+
+	return proximityMap.Node(), nil
+}
+
+// visitNode produces a new tree.Node that incorporates the provided edits to the provided node.
+// As a precondition, we have confirmed that the keys in the provided node will not change, but the
+// keys in children nodes might. If the keys in a child node would change, we call rebuildNode on that child.
+// Otherwise, we recursively call visitNode on the children.
+func (f ProximityFlusher) visitNode(
+	ctx context.Context,
+	serializer message.Serializer,
+	ns tree.NodeStore,
+	node tree.Node,
+	edits []VectorIndexKV,
+	convert func(context.Context, []byte) []float64,
+	distanceType vector.DistanceType,
+	keyDesc val.TupleDesc,
+	valDesc val.TupleDesc,
+) (newNode tree.Node, subtrees int, err error) {
+	var keys [][]byte
+	var values [][]byte
+	var nodeSubtrees []uint64
+
+	if node.IsLeaf() {
+		keys, values, nodeSubtrees = f.rebuildLeafNodeWithEdits(node, edits, keyDesc)
+	} else {
+		// Sort the list of edits based on which child node contains them.
+		childEdits := make(map[int]childEditList)
+		for _, edit := range edits {
+			key := edit.key
+			editVector := convert(ctx, key)
+			level := edit.level
+			// visit each child in the node to determine which is closest
+			closestIdx := 0
+			childKey := node.GetKey(0)
+			closestDistance, err := distanceType.Eval(convert(ctx, childKey), editVector)
+			if err != nil {
+				return tree.Node{}, 0, err
+			}
+			for i := 1; i < node.Count(); i++ {
+				childKey = node.GetKey(i)
+				newDistance, err := distanceType.Eval(convert(ctx, childKey), editVector)
+				if err != nil {
+					return tree.Node{}, 0, err
+				}
+				if newDistance < closestDistance {
+					closestDistance = newDistance
+					closestIdx = i
+				}
+			}
+			childEditList := childEdits[closestIdx]
+			childEditList.edits = append(childEditList.edits, edit)
+			if level == node.Level()-1 {
+				childEditList.mustRebuild = true
+			}
+			childEdits[closestIdx] = childEditList
+		}
+		// Recursively build the new tree.
+		// We need keys, values, subtrees, and levels.
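+		// Each edit has been assigned to the child whose key vector is closest to it;
+		// children with no assigned edits are reused unchanged.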
+		for i := 0; i < node.Count(); i++ {
+			childKey := node.GetKey(i)
+			keys = append(keys, childKey)
+			childValue := node.GetValue(i)
+
+			childEditList := childEdits[i]
+			if len(childEditList.edits) == 0 {
+				// No edits affected this child; reuse it as-is, but we still need to record its
+				// subtree count so that keys, values, and subtree counts stay in lockstep.
+				values = append(values, childValue)
+				childNode, err := ns.Read(ctx, hash.New(childValue))
+				if err != nil {
+					return tree.Node{}, 0, err
+				}
+				childSubtrees, err := childNode.TreeCount()
+				if err != nil {
+					return tree.Node{}, 0, err
+				}
+				nodeSubtrees = append(nodeSubtrees, uint64(childSubtrees))
+			} else {
+				childNodeAddress := hash.New(childValue)
+				childNode, err := ns.Read(ctx, childNodeAddress)
+				if err != nil {
+					return tree.Node{}, 0, err
+				}
+				var newChildNode tree.Node
+				var childSubtrees int
+				if childEditList.mustRebuild {
+					newChildNode, childSubtrees, err = f.rebuildNode(ctx, ns, childNode, childEditList.edits, distanceType, keyDesc, valDesc, uint8(childNode.Level()))
+				} else {
+					newChildNode, childSubtrees, err = f.visitNode(ctx, serializer, ns, childNode, childEditList.edits, convert, distanceType, keyDesc, valDesc)
+				}
+				if err != nil {
+					return tree.Node{}, 0, err
+				}
+				newChildAddress := newChildNode.HashOf()
+
+				values = append(values, newChildAddress[:])
+				nodeSubtrees = append(nodeSubtrees, uint64(childSubtrees))
+			}
+		}
+	}
+	newNode, err = serializeVectorIndexNode(ctx, serializer, ns, keys, values, nodeSubtrees, node.Level())
+	if err != nil {
+		return tree.Node{}, 0, err
+	}
+	subtrees, err = newNode.TreeCount()
+	if err != nil {
+		return tree.Node{}, 0, err
+	}
+	return newNode, subtrees, err
+}
+
+func serializeVectorIndexNode(
+	ctx context.Context,
+	serializer message.Serializer,
+	ns tree.NodeStore,
+	keys [][]byte,
+	values [][]byte,
+	nodeSubtrees []uint64,
+	level int,
+) (tree.Node, error) {
+	msg := serializer.Serialize(keys, values, nodeSubtrees, level)
+	newNode, fileId, err := tree.NodeFromBytes(msg)
+	if err != nil {
+		return tree.Node{}, err
+	}
+
+	if fileId != serial.VectorIndexNodeFileID {
+		return tree.Node{}, fmt.Errorf("expected file id %s, received %s", serial.VectorIndexNodeFileID, fileId)
+	}
+	_, err = ns.Write(ctx, newNode)
+	return newNode, err
+}
+
+// rebuildLeafNodeWithEdits creates a new leaf node by applying a list of edits to an existing node.
+func (f ProximityFlusher) rebuildLeafNodeWithEdits(
+	originalNode tree.Node,
+	edits []VectorIndexKV,
+	keyDesc val.TupleDesc,
+) (keys [][]byte, values [][]byte, nodeSubtrees []uint64) {
+	// Combine edits with node keys using a merge sort: both lists are ordered by keyDesc.
+	editIdx := 0
+	nodeIdx := 0
+	for editIdx < len(edits) || nodeIdx < originalNode.Count() {
+		// No more edits: copy over the remaining node entries.
+		if editIdx >= len(edits) {
+			keys = append(keys, originalNode.GetKey(nodeIdx))
+			values = append(values, originalNode.GetValue(nodeIdx))
+			nodeSubtrees = append(nodeSubtrees, 0)
+			nodeIdx++
+			continue
+		}
+		// No more node entries: each remaining edit is an insert, or a no-op delete of an absent key.
+		if nodeIdx >= originalNode.Count() {
+			if edits[editIdx].value != nil {
+				keys = append(keys, edits[editIdx].key)
+				values = append(values, edits[editIdx].value)
+				nodeSubtrees = append(nodeSubtrees, 0)
+			}
+			editIdx++
+			continue
+		}
+		editKey := val.Tuple(edits[editIdx].key)
+		nodeKey := val.Tuple(originalNode.GetKey(nodeIdx))
+		cmp := keyDesc.Compare(editKey, nodeKey)
+		if cmp < 0 {
+			// The edit comes first. It doesn't match an existing key, so it must be an insert
+			// (or a no-op delete of an absent key).
+			if edits[editIdx].value != nil {
+				keys = append(keys, edits[editIdx].key)
+				values = append(values, edits[editIdx].value)
+				nodeSubtrees = append(nodeSubtrees, 0)
+			}
+			editIdx++
+			continue
+		}
+		if cmp > 0 {
+			// The node's entry comes first.
+			keys = append(keys, originalNode.GetKey(nodeIdx))
+			values = append(values, originalNode.GetValue(nodeIdx))
+			nodeSubtrees = append(nodeSubtrees, 0)
+			nodeIdx++
+			continue
+		}
+		// edit to an existing key.
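+		// Both cursors advance: the edit's value replaces the node's entry, or removes it if nil.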
+		newValue := edits[editIdx].value
+		editIdx++
+		nodeIdx++
+		if newValue == nil {
+			// This is a delete. We simply skip to the next key, excluding this key from the new node.
+			continue
+		}
+		keys = append(keys, editKey)
+		values = append(values, newValue)
+		nodeSubtrees = append(nodeSubtrees, 0)
+	}
+	return
+}
+
+var DefaultLogChunkSize = uint8(8)
+
+func (f ProximityFlusher) rebuildNode(ctx context.Context, ns tree.NodeStore, node tree.Node, edits []VectorIndexKV, distanceType vector.DistanceType, keyDesc val.TupleDesc, valDesc val.TupleDesc, maxLevel uint8) (newNode tree.Node, subtrees int, err error) {
+	proximityMapBuilder, err := NewProximityMapBuilder(ctx, ns, distanceType, keyDesc, valDesc, f.logChunkSize)
+	if err != nil {
+		return tree.Node{}, 0, err
+	}
+	editSkipList := skip.NewSkipList(func(left, right []byte) int {
+		return keyDesc.Compare(left, right)
+	})
+	for _, edit := range edits {
+		editSkipList.Put(edit.key, edit.value)
+	}
+
+	insertFromNode := func(nd tree.Node, i int) error {
+		key := nd.GetKey(i)
+		value := nd.GetValue(i)
+		_, hasNewVal := editSkipList.Get(key)
+		if !hasNewVal {
+			// TODO: Is it faster if we fetch the level from the current tree?
+			keyLevel := tree.DeterministicHashLevel(f.logChunkSize, key)
+			if keyLevel > maxLevel {
+				keyLevel = maxLevel
+			}
+			err = proximityMapBuilder.InsertAtLevel(ctx, key, value, keyLevel)
+			if err != nil {
+				return err
+			}
+		}
+		return nil
+	}
+
+	var walk func(nd tree.Node) error
+	walk = func(nd tree.Node) (err error) {
+		if nd.IsLeaf() {
+			for i := 0; i < nd.Count(); i++ {
+				err = insertFromNode(nd, i)
+				if err != nil {
+					return err
+				}
+			}
+		} else {
+			for i := 0; i < nd.Count(); i++ {
+				childAddr := hash.New(nd.GetValue(i))
+				child, err := ns.Read(ctx, childAddr)
+				if err != nil {
+					return err
+				}
+				err = walk(child)
+				if err != nil {
+					return err
+				}
+			}
+		}
+
+		return nil
+	}
+
+	err = walk(node)
+	if err != nil {
+		return tree.Node{}, 0, err
+	}
+	for _, edit := range edits {
+		key := edit.key
+		value := edit.value
+		if value != nil {
+			err = proximityMapBuilder.Insert(ctx, key, value)
+			if err != nil {
+				return tree.Node{}, 0, err
+			}
+		}
+	}
+	newMap, err := proximityMapBuilder.Flush(ctx)
+	if err != nil {
+		return tree.Node{}, 0, err
+	}
+	newRoot := newMap.tuples.Root
+	newTreeCount, err := newRoot.TreeCount()
+	if err != nil {
+		return tree.Node{}, 0, err
+	}
+	return newRoot, newTreeCount, nil
+}
+
+func (f ProximityFlusher) GetDefaultSerializer(ctx context.Context, mutableMap *GenericMutableMap[ProximityMap, tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]]) message.Serializer {
+	return message.NewVectorIndexSerializer(mutableMap.NodeStore().Pool(), f.logChunkSize, f.distanceType)
+}
+
+// newProximityMutableMap returns a new ProximityMutableMap.
+func newProximityMutableMap(m ProximityMap) *ProximityMutableMap {
+	return &ProximityMutableMap{
+		tuples:     m.tuples.Mutate(),
+		keyDesc:    m.keyDesc,
+		valDesc:    m.valDesc,
+		maxPending: defaultMaxPending,
+		flusher:    ProximityFlusher{logChunkSize: m.logChunkSize, distanceType: m.tuples.DistanceType},
+	}
+}
+
+func (f ProximityFlusher) MapInterface(ctx context.Context, mut *ProximityMutableMap) (MapInterface, error) {
+	return f.Map(ctx, mut)
+}
+
+// TreeMap materializes all pending and applied mutations in the MutableMap.
+func (f ProximityFlusher) TreeMap(ctx context.Context, mut *ProximityMutableMap) (tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc], error) {
+	s := message.NewVectorIndexSerializer(mut.NodeStore().Pool(), f.logChunkSize, f.distanceType)
+	return mut.flushWithSerializer(ctx, s)
+}
+
+// Map materializes all pending and applied mutations in the MutableMap and wraps the result in a ProximityMap.
+func (f ProximityFlusher) Map(ctx context.Context, mut *ProximityMutableMap) (ProximityMap, error) {
+	treeMap, err := f.TreeMap(ctx, mut)
+	if err != nil {
+		return ProximityMap{}, err
+	}
+	return ProximityMap{
+		tuples:       treeMap,
+		keyDesc:      mut.keyDesc,
+		valDesc:      mut.valDesc,
+		logChunkSize: f.logChunkSize,
+	}, nil
+}
diff --git a/go/store/prolly/shim/shim.go b/go/store/prolly/shim/shim.go
index c788fda3c12..f51b6f48db5 100644
--- a/go/store/prolly/shim/shim.go
+++ b/go/store/prolly/shim/shim.go
@@ -15,6 +15,12 @@
 package shim
 
 import (
+	"context"
+	"fmt"
+
+	"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
+
+	"github.com/dolthub/dolt/go/gen/fb/serial"
 	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
 	"github.com/dolthub/dolt/go/store/prolly"
 	"github.com/dolthub/dolt/go/store/prolly/tree"
@@ -31,7 +37,10 @@ func ValueFromMap(m prolly.MapInterface) types.Value {
 }
 
 func MapFromValue(v types.Value, sch schema.Schema, ns tree.NodeStore, isKeylessSecondary bool) (prolly.Map, error) {
-	root, _, err := NodeFromValue(v)
+	root, fileId, err := NodeFromValue(v)
 	if err != nil {
 		return prolly.Map{}, err
 	}
+	if fileId == serial.VectorIndexNodeFileID {
+		return prolly.Map{}, fmt.Errorf("can't make a prolly.Map from a vector index node")
+	}
@@ -43,8 +52,8 @@ func MapFromValue(v types.Value, sch schema.Schema, ns tree.NodeStore, isKeyless
 	return prolly.NewMap(root, ns, kd, vd), nil
 }
 
-func MapInterfaceFromValue(v types.Value, sch schema.Schema, ns tree.NodeStore, isKeylessSecondary bool) (prolly.MapInterface, error) {
-	root, _, err := NodeFromValue(v)
+func MapInterfaceFromValue(ctx context.Context, v types.Value, sch schema.Schema, ns tree.NodeStore, isKeylessSecondary bool) (prolly.MapInterface, error) {
+	root, fileId, err := NodeFromValue(v)
 	if err != nil {
 		return nil, err
 	}
@@ -53,13 +62,29 @@ func MapInterfaceFromValue(v types.Value, sch schema.Schema, ns tree.NodeStore,
 		kd = prolly.AddHashToSchema(kd)
 	}
 	vd := sch.GetValueDescriptor()
-	return prolly.NewMap(root, ns, kd, vd), nil
+	switch fileId {
+	case serial.VectorIndexNodeFileID:
+		// TODO: We should read the distance function and chunk size from the message.
+		// Currently, vector.DistanceL2Squared{} and prolly.DefaultLogChunkSize are the only values that can be written,
+		// but this may not be true in the future.
+		return prolly.NewProximityMap(ns, root, kd, vd, vector.DistanceL2Squared{}, prolly.DefaultLogChunkSize), nil
+	default:
+		return prolly.NewMap(root, ns, kd, vd), nil
+	}
 }
 
 func MapFromValueWithDescriptors(v types.Value, kd, vd val.TupleDesc, ns tree.NodeStore) (prolly.MapInterface, error) {
-	root, _, err := NodeFromValue(v)
+	root, fileId, err := NodeFromValue(v)
 	if err != nil {
 		return prolly.Map{}, err
 	}
-	return prolly.NewMap(root, ns, kd, vd), nil
+	switch fileId {
+	case serial.VectorIndexNodeFileID:
+		// TODO: We should read the distance function and chunk size from the message.
+		// Currently, vector.DistanceL2Squared{} and prolly.DefaultLogChunkSize are the only values that can be written,
+		// but this may not be true in the future.
+		return prolly.NewProximityMap(ns, root, kd, vd, vector.DistanceL2Squared{}, prolly.DefaultLogChunkSize), nil
+	default:
+		return prolly.NewMap(root, ns, kd, vd), nil
+	}
 }
diff --git a/go/store/prolly/tree/node_splitter.go b/go/store/prolly/tree/node_splitter.go
index 5714dc0bd2a..db569db8ff3 100644
--- a/go/store/prolly/tree/node_splitter.go
+++ b/go/store/prolly/tree/node_splitter.go
@@ -25,6 +25,7 @@ import (
 	"crypto/sha512"
 	"encoding/binary"
 	"math"
+	"math/bits"
 
 	"github.com/kch42/buzhash"
 	"github.com/zeebo/xxh3"
@@ -264,3 +265,10 @@ func saltFromLevel(level uint8) (salt uint64) {
 	full := sha512.Sum512([]byte{level})
 	return binary.LittleEndian.Uint64(full[:8])
 }
+
+// DeterministicHashLevel hashes a key and divides the number of leading zeros in the hash by
+// |leadingZerosPerLevel|. This determines the level at which the key appears in a ProximityMap.
+func DeterministicHashLevel(leadingZerosPerLevel uint8, key Item) uint8 {
+	h := xxHash32(key, levelSalt[1])
+	return uint8(bits.LeadingZeros32(h)) / leadingZerosPerLevel
+}
diff --git a/go/store/prolly/tree/proximity_map.go b/go/store/prolly/tree/proximity_map.go
new file mode 100644
index 00000000000..db0906ccdc6
--- /dev/null
+++ b/go/store/prolly/tree/proximity_map.go
@@ -0,0 +1,279 @@
+// Copyright 2024 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tree
+
+import (
+	"container/heap"
+	"context"
+	"math"
+
+	"github.com/dolthub/go-mysql-server/sql"
+	"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
+	"github.com/esote/minmaxheap"
+
+	"github.com/dolthub/dolt/go/store/hash"
+	"github.com/dolthub/dolt/go/store/skip"
+)
+
+type KeyValueDistanceFn[K, V ~[]byte] func(key K, value V, distance float64) error
+
+// ProximityMap is a static Prolly Tree where the position of a key in the tree is based on proximity, as opposed to a traditional ordering.
+// O provides the ordering only within a node.
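+// A key chosen for a node at one level reappears in the child node it references, so a
+// lookup descends the tree by repeatedly choosing the key closest to the query vector.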
+type ProximityMap[K, V ~[]byte, O Ordering[K]] struct {
+	Root         Node
+	NodeStore    NodeStore
+	DistanceType vector.DistanceType
+	Convert      func(context.Context, []byte) []float64
+	Order        O
+}
+
+func (t ProximityMap[K, V, O]) GetRoot() Node {
+	return t.Root
+}
+
+func (t ProximityMap[K, V, O]) GetNodeStore() NodeStore {
+	return t.NodeStore
+}
+
+func (t ProximityMap[K, V, O]) GetPrefix(ctx context.Context, query K, prefixOrder O, cb KeyValueFn[K, V]) (err error) {
+	//TODO implement me
+	panic("implement me")
+}
+
+func (t ProximityMap[K, V, O]) HasPrefix(ctx context.Context, query K, prefixOrder O) (ok bool, err error) {
+	//TODO implement me
+	panic("implement me")
+}
+
+func (t ProximityMap[K, V, O]) Mutate() MutableMap[K, V, O, ProximityMap[K, V, O]] {
+	return MutableMap[K, V, O, ProximityMap[K, V, O]]{
+		Edits: skip.NewSkipList(func(left, right []byte) int {
+			return t.Order.Compare(left, right)
+		}),
+		Static: t,
+	}
+}
+
+func (t ProximityMap[K, V, O]) IterKeyRange(ctx context.Context, start, stop K) (*OrderedTreeIter[K, V], error) {
+	panic("Not implemented")
+}
+
+func (t ProximityMap[K, V, O]) Count() (int, error) {
+	return t.Root.TreeCount()
+}
+
+func (t ProximityMap[K, V, O]) Height() int {
+	return t.Root.Level() + 1
+}
+
+func (t ProximityMap[K, V, O]) HashOf() hash.Hash {
+	return t.Root.HashOf()
+}
+
+func (t ProximityMap[K, V, O]) WalkAddresses(ctx context.Context, cb AddressCb) error {
+	return WalkAddresses(ctx, t.Root, t.NodeStore, cb)
+}
+
+func (t ProximityMap[K, V, O]) WalkNodes(ctx context.Context, cb NodeCb) error {
+	return WalkNodes(ctx, t.Root, t.NodeStore, cb)
+}
+
+// Get searches for an exact vector in the index, calling |cb| with the matching key-value pair.
+func (t ProximityMap[K, V, O]) Get(ctx context.Context, query K, cb KeyValueFn[K, V]) (err error) {
+	nd := t.Root
+
+	if nd.Count() == 0 {
+		// The map is empty: there is nothing to match.
+		return nil
+	}
+
+	queryVector := t.Convert(ctx, query)
+
+	// Find the child with the minimum distance.
+	for {
+		var closestKey K
+		var closestIdx int
+		distance := math.Inf(1)
+
+		for i := 0; i < int(nd.count); i++ {
+			k := nd.GetKey(i)
+			newDistance, err := t.DistanceType.Eval(t.Convert(ctx, k), queryVector)
+			if err != nil {
+				return err
+			}
+			if newDistance < distance {
+				closestIdx = i
+				distance = newDistance
+				closestKey = []byte(k)
+			}
+		}
+
+		if nd.IsLeaf() {
+			return cb(closestKey, []byte(nd.GetValue(closestIdx)))
+		}
+
+		nd, err = fetchChild(ctx, t.NodeStore, nd.getAddress(closestIdx))
+		if err != nil {
+			return err
+		}
+	}
+}
+
+func (t ProximityMap[K, V, O]) Has(ctx context.Context, query K) (ok bool, err error) {
+	err = t.Get(ctx, query, func(_ K, _ V) error {
+		ok = true
+		return nil
+	})
+	return ok, err
+}
+
+type DistancePriorityHeapElem struct {
+	key      Item
+	value    Item
+	distance float64
+}
+
+type DistancePriorityHeap []DistancePriorityHeapElem
+
+var _ heap.Interface = (*DistancePriorityHeap)(nil)
+
+func newNodePriorityHeap(capacity int) DistancePriorityHeap {
+	// Allocate one extra slot: whenever this fills we remove the max element.
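+	// e.g. with capacity 3 we allocate 4 slots; once a 4th element is pushed, Insert evicts
+	// the element with the greatest distance, keeping the 3 nearest seen so far.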
+ return make(DistancePriorityHeap, 0, capacity+1) +} + +func (n DistancePriorityHeap) Len() int { + return len(n) +} + +func (n DistancePriorityHeap) Less(i, j int) bool { + return n[i].distance < n[j].distance +} + +func (n DistancePriorityHeap) Swap(i, j int) { + n[i], n[j] = n[j], n[i] +} + +func (n *DistancePriorityHeap) Push(x any) { + *n = append(*n, x.(DistancePriorityHeapElem)) +} + +func (n *DistancePriorityHeap) Pop() any { + length := len(*n) + last := (*n)[length-1] + *n = (*n)[:length-1] + return last +} + +func (n *DistancePriorityHeap) Insert(key Item, value Item, distance float64) { + minmaxheap.Push(n, DistancePriorityHeapElem{ + key: key, + value: value, + distance: distance, + }) + if len(*n) == cap(*n) { + minmaxheap.PopMax(n) + } +} + +// GetClosest performs an approximate nearest neighbors search. It finds |limit| vectors that are close to the query vector, +// and calls |cb| with the matching key-value pairs. +func (t ProximityMap[K, V, O]) GetClosest(ctx context.Context, query interface{}, cb KeyValueDistanceFn[K, V], limit int) (err error) { + if limit == 0 { + return nil + } + + queryVector, err := sql.ConvertToVector(query) + if err != nil { + return err + } + + // |nodes| holds the current candidates for closest vectors, up to |limit| + nodes := newNodePriorityHeap(limit) + + for i := 0; i < int(t.Root.count); i++ { + k := t.Root.GetKey(i) + newDistance, err := t.DistanceType.Eval(t.Convert(ctx, k), queryVector) + if err != nil { + return err + } + nodes.Insert(k, t.Root.GetValue(i), newDistance) + } + + for level := t.Root.Level() - 1; level >= 0; level-- { + // visit each candidate node at the current level, building a priority list of candidates for the next level. + nextLevelNodes := newNodePriorityHeap(limit) + + for _, keyAndDistance := range nodes { + address := keyAndDistance.value + + node, err := fetchChild(ctx, t.NodeStore, hash.New(address)) + if err != nil { + return err + } + // TODO: We don't need to recompute the distance when visiting the same key as the parent. 
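+			// Rank every key in this child against the query; the min-max heap keeps at most
+			// |limit| of the closest candidates for the next level.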
+ for i := 0; i < int(node.count); i++ { + k := node.GetKey(i) + newDistance, err := t.DistanceType.Eval(t.Convert(ctx, k), queryVector) + if err != nil { + return err + } + nextLevelNodes.Insert(k, node.GetValue(i), newDistance) + } + } + nodes = nextLevelNodes + } + + for nodes.Len() > 0 { + node := minmaxheap.Pop(&nodes).(DistancePriorityHeapElem) + err := cb([]byte(node.key), []byte(node.value), node.distance) + if err != nil { + return err + } + } + + return nil +} + +func (t ProximityMap[K, V, O]) IterAll(ctx context.Context) (*OrderedTreeIter[K, V], error) { + c, err := newCursorAtStart(ctx, t.NodeStore, t.Root) + if err != nil { + return nil, err + } + + s, err := newCursorPastEnd(ctx, t.NodeStore, t.Root) + if err != nil { + return nil, err + } + + stop := func(curr *cursor) bool { + return curr.compare(s) >= 0 + } + + if stop(c) { + // empty range + return &OrderedTreeIter[K, V]{curr: nil}, nil + } + + return &OrderedTreeIter[K, V]{curr: c, stop: stop, step: c.advance}, nil +} + +func getJsonValueFromHash(ctx context.Context, ns NodeStore, h hash.Hash) (interface{}, error) { + return NewJSONDoc(h, ns).ToIndexedJSONDocument(ctx) +} + +func getVectorFromHash(ctx context.Context, ns NodeStore, h hash.Hash) ([]float64, error) { + otherValue, err := getJsonValueFromHash(ctx, ns, h) + if err != nil { + return nil, err + } + return sql.ConvertToVector(otherValue) +} diff --git a/go/store/prolly/vector_index_chunker.go b/go/store/prolly/vector_index_chunker.go new file mode 100644 index 00000000000..638c0497020 --- /dev/null +++ b/go/store/prolly/vector_index_chunker.go @@ -0,0 +1,116 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package prolly + +import ( + "context" + "io" + + "github.com/dolthub/dolt/go/store/hash" + "github.com/dolthub/dolt/go/store/prolly/message" + "github.com/dolthub/dolt/go/store/prolly/tree" + "github.com/dolthub/dolt/go/store/val" +) + +// vectorIndexChunker is a stateful chunker that iterates over |pathMap|, a map that contains an element +// for every key-value pair for a given level of a ProximityMap, and provides the path of keys to reach +// that pair from the root. It uses this iterator to build each of the ProximityMap nodes for that level. +type vectorIndexChunker struct { + pathMap *MutableMap + pathMapIter MapIter + lastPathSegment hash.Hash + lastKey []byte + lastValue []byte + lastSubtreeCount uint64 + childChunker *vectorIndexChunker + atEnd bool +} + +func newVectorIndexChunker(ctx context.Context, pathMap *MutableMap, childChunker *vectorIndexChunker) (*vectorIndexChunker, error) { + pathMapIter, err := pathMap.IterAll(ctx) + if err != nil { + return nil, err + } + firstKey, firstValue, err := pathMapIter.Next(ctx) + if err == io.EOF { + // In rare situations, there aren't any vectors at a given level. 
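+		// Mark the chunker as exhausted up front; its Next method will then emit an empty
+		// node for this level.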
+ return &vectorIndexChunker{ + pathMap: pathMap, + pathMapIter: pathMapIter, + childChunker: childChunker, + atEnd: true, + }, nil + } + if err != nil { + return nil, err + } + path, _ := pathMap.keyDesc.GetBytes(0, firstKey) + lastPathSegment := hash.New(path[len(path)-20:]) + originalKey, _ := pathMap.keyDesc.GetBytes(1, firstKey) + return &vectorIndexChunker{ + pathMap: pathMap, + pathMapIter: pathMapIter, + childChunker: childChunker, + lastKey: originalKey, + lastValue: firstValue, + lastPathSegment: lastPathSegment, + atEnd: false, + }, nil +} + +func (c *vectorIndexChunker) Next(ctx context.Context, ns tree.NodeStore, serializer message.VectorIndexSerializer, parentPathSegment hash.Hash, level, depth int, originalKeyDesc val.TupleDesc) (tree.Node, uint64, hash.Hash, error) { + var indexMapKeys [][]byte + var indexMapValues [][]byte + var indexMapSubtrees []uint64 + subtreeSum := uint64(0) + + for { + if c.atEnd || c.lastPathSegment != parentPathSegment { + msg := serializer.Serialize(indexMapKeys, indexMapValues, indexMapSubtrees, level) + node, _, err := tree.NodeFromBytes(msg) + if err != nil { + return tree.Node{}, 0, hash.Hash{}, err + } + nodeHash, err := ns.Write(ctx, node) + return node, subtreeSum, nodeHash, err + } + vectorHash, _ := originalKeyDesc.GetJSONAddr(0, c.lastKey) + if c.childChunker != nil { + _, childCount, nodeHash, err := c.childChunker.Next(ctx, ns, serializer, vectorHash, level-1, depth+1, originalKeyDesc) + if err != nil { + return tree.Node{}, 0, hash.Hash{}, err + } + c.lastValue = nodeHash[:] + indexMapSubtrees = append(indexMapSubtrees, childCount) + subtreeSum += childCount + } else { + subtreeSum++ + } + indexMapKeys = append(indexMapKeys, c.lastKey) + indexMapValues = append(indexMapValues, c.lastValue) + + nextKey, nextValue, err := c.pathMapIter.Next(ctx) + if err == io.EOF { + c.atEnd = true + } else if err != nil { + return tree.Node{}, 0, hash.Hash{}, err + } else { + lastPath, _ := c.pathMap.keyDesc.GetBytes(0, nextKey) + c.lastPathSegment = hash.New(lastPath[len(lastPath)-20:]) + c.lastKey, _ = c.pathMap.keyDesc.GetBytes(1, nextKey) + c.lastValue = nextValue + } + } +} diff --git a/go/store/types/serial_message.go b/go/store/types/serial_message.go index 29c55e84241..50ee98c3261 100644 --- a/go/store/types/serial_message.go +++ b/go/store/types/serial_message.go @@ -342,6 +342,17 @@ func (sm SerialMessage) HumanReadableStringAtIndentationLevel(level int) string _ = OutputProllyNodeBytes(ret, serial.Message(sm)) + // deleting everything from a vector index turns it into a regular index? 
+		level -= 1
+		printWithIndendationLevel(level, ret, "}")
+		return ret.String()
+	case serial.VectorIndexNodeFileID:
+		ret := &strings.Builder{}
+		printWithIndendationLevel(level, ret, "Vector Index {\n")
+		level++
+
+		_ = OutputVectorIndexNodeBytes(ret, serial.Message(sm))
+		level -= 1
 		printWithIndendationLevel(level, ret, "}")
 		return ret.String()
@@ -493,6 +504,66 @@ func OutputProllyNodeBytes(w io.Writer, msg serial.Message) error {
 	return nil
 }
 
+func OutputVectorIndexNodeBytes(w io.Writer, msg serial.Message) error {
+	fileId, keys, values, treeLevel, count, err := message.UnpackFields(msg)
+	if err != nil {
+		return err
+	}
+	if fileId != serial.VectorIndexNodeFileID {
+		return fmt.Errorf("unexpected file ID, expected %s, got %s", serial.VectorIndexNodeFileID, fileId)
+	}
+	isLeaf := treeLevel == 0
+
+	for i := 0; i < int(count); i++ {
+		k := keys.GetItem(i, msg)
+		kt := val.Tuple(k)
+
+		w.Write([]byte("\n { key: "))
+		// The first key of a vector index is always a vector, which right now is JSON and thus addressable.
+		// This may not always be true in the future.
+		for j := 0; j < kt.Count(); j++ {
+			if j == 0 {
+				ref := hash.New(kt.GetField(0))
+
+				w.Write([]byte(" #"))
+				w.Write([]byte(ref.String()))
+				continue
+			}
+			// j > 0 here, so always write a separator before the field.
+			w.Write([]byte(", "))
+			w.Write([]byte(hex.EncodeToString(kt.GetField(j))))
+		}
+
+		if isLeaf {
+			v := values.GetItem(i, msg)
+			vt := val.Tuple(v)
+
+			w.Write([]byte(" value: "))
+			for j := 0; j < vt.Count(); j++ {
+				if j > 0 {
+					w.Write([]byte(", "))
+				}
+				field := vt.GetField(j)
+				w.Write([]byte(hex.EncodeToString(field)))
+			}
+
+			w.Write([]byte(" }"))
+		} else {
+			ref := hash.New(values.GetItem(i, msg))
+
+			w.Write([]byte(" ref: #"))
+			w.Write([]byte(ref.String()))
+			w.Write([]byte(" }"))
+		}
+	}
+
+	w.Write([]byte("\n"))
+	return nil
+}
+
 func (sm SerialMessage) Less(ctx context.Context, nbf *NomsBinFormat, other LesserValuable) (bool, error) {
 	if v2, ok := other.(SerialMessage); ok {
 		return bytes.Compare(sm, v2) == -1, nil
diff --git a/integration-tests/bats/helper/local-remote.bash b/integration-tests/bats/helper/local-remote.bash
index 0de595a44f0..2b686150a83 100644
--- a/integration-tests/bats/helper/local-remote.bash
+++ b/integration-tests/bats/helper/local-remote.bash
@@ -166,3 +166,10 @@ teardown_remote_server() {
     stop_sql_server
   fi
 }
+
+skip_if_remote() {
+  if [ "$SQL_ENGINE" = "remote-engine" ];
+  then
+    skip
+  fi
+}
\ No newline at end of file
diff --git a/integration-tests/bats/vector-index.bats b/integration-tests/bats/vector-index.bats
new file mode 100644
index 00000000000..5693b7d8628
--- /dev/null
+++ b/integration-tests/bats/vector-index.bats
@@ -0,0 +1,432 @@
+#!/usr/bin/env bats
+load $BATS_TEST_DIRNAME/helper/common.bash
+
+setup() {
+    setup_common
+
+    dolt sql <<SQL
+CREATE TABLE onepk (
+  pk1 BIGINT PRIMARY KEY,
+  v1 JSON,
+  v2 JSON
+);
+SQL
+}
+
+@test "vector-index: INSERT then UPDATE" {
+    dolt sql <<'SQL'
+CREATE VECTOR INDEX idx_v1 ON onepk(v1);
+INSERT INTO onepk VALUES (1, '[99, 51]'), (2, '[11, 55]'), (3, '[88, 52]'), (4, '[22, 54]'), (5, '[77, 53]');
+UPDATE onepk SET v1 = JSON_ARRAY(v1->"$[1]", v1->"$[0]") WHERE pk1 >= 3;
+SQL
+    run dolt index cat onepk idx_v1 -r=csv
+    [ "$status" -eq "0" ]
+    [[ "$output" =~ "v1,pk1" ]] || false
+    [[ "$output" =~ '"[11,55]",2' ]] || false
+    [[ "$output" =~ '"[54,22]",4' ]] || false
+    [[ "$output" =~ '"[53,77]",5' ]] || false
+    [[ "$output" =~ '"[52,88]",3' ]] || false
+    [[ "$output" =~ '"[99,51]",1' ]] || false
+    [[ "${#lines[@]}" = "6" ]] || false
+    run dolt sql -q "SELECT pk1 FROM onepk ORDER BY VEC_DISTANCE(v1, '[53,77]') LIMIT 1" -r=csv
+    [ "$status" -eq "0" ]
+    [[ "$output" =~ "pk1" ]] || false
+    [[ "$output" =~ "5" ]] || false
+    [[ "${#lines[@]}" = "2" ]] || false
+}
+
+@test "vector-index: INSERT then DELETE some" {
+    dolt sql <<'SQL'
+CREATE VECTOR INDEX idx_v1 ON onepk(v1);
+INSERT INTO onepk VALUES (1, 
'[99, 51]'), (2, '[11, 55]'), (3, '[88, 52]'), (4, '[22, 54]'), (5, '[77, 53]'); +DELETE FROM onepk WHERE v1->>"$[0]" % 2 = 0; +SQL + run dolt index cat onepk idx_v1 -r=csv + [ "$status" -eq "0" ] + [[ "$output" =~ "v1,pk1" ]] || false + [[ "$output" =~ '"[99,51]",1' ]] || false + [[ "$output" =~ '"[11,55]",2' ]] || false + [[ "$output" =~ '"[77,53]",5' ]] || false + [[ "${#lines[@]}" = "4" ]] || false + run dolt sql -q "SELECT pk1 FROM onepk ORDER BY VEC_DISTANCE(v1, '[77,53]') LIMIT 1;" -r=csv + [ "$status" -eq "0" ] + [[ "$output" =~ "pk1" ]] || false + [[ "$output" =~ "5" ]] || false + [[ "${#lines[@]}" = "2" ]] || false + run dolt sql -q "SELECT pk1 FROM onepk ORDER BY VEC_DISTANCE(v1, '[22,54]') LIMIT 1;" -r=csv + [ "$status" -eq "0" ] + [[ "$output" =~ "pk1" ]] || false + ! [[ "$output" =~ "4" ]] || false + [[ "$output" =~ "2" ]] || false + [[ "${#lines[@]}" = "2" ]] || false +} + +@test "vector-index: INSERT then DELETE all" { + dolt sql <<'SQL' +CREATE VECTOR INDEX idx_v1 ON onepk(v1); +INSERT INTO onepk VALUES (1, '[99, 51]'), (2, '[11, 55]'), (3, '[88, 52]'), (4, '[22, 54]'), (5, '[77, 53]'); +DELETE FROM onepk WHERE v1->>"$[0]" = 99; +SQL + run dolt index cat onepk idx_v1 -r=csv + [ "$status" -eq "0" ] + [[ "$output" =~ "v1,pk1" ]] || false + [[ "$output" =~ '"[11,55]",2' ]] || false + [[ "$output" =~ '"[88,52]",3' ]] || false + [[ "$output" =~ '"[22,54]",4' ]] || false + [[ "$output" =~ '"[77,53]",5' ]] || false + [[ "${#lines[@]}" = "5" ]] || false + run dolt sql -q "SELECT pk1 FROM onepk ORDER BY VEC_DISTANCE(v1, '[22,54]') LIMIT 1;" -r=csv + [ "$status" -eq "0" ] + [[ "$output" =~ "pk1" ]] || false + [[ "$output" =~ "4" ]] || false + [[ "${#lines[@]}" = "2" ]] || false + run dolt sql -q "SELECT pk1 FROM onepk ORDER BY VEC_DISTANCE(v1, '[99,51]') LIMIT 1;" -r=csv + [ "$status" -eq "0" ] + [[ "$output" =~ "pk1" ]] || false + [[ "$output" =~ "3" ]] || false + [[ "${#lines[@]}" = "2" ]] || false + + dolt sql <<'SQL' +DELETE FROM onepk; +SQL + run dolt index cat onepk idx_v1 -r=csv + [ "$status" -eq "0" ] + [[ "$output" =~ "v1,pk1" ]] || false + [[ "${#lines[@]}" = "1" ]] || false + run dolt sql -q "SELECT pk1 FROM onepk ORDER BY VEC_DISTANCE(v1, '[99,51]') LIMIT 1;" -r=csv + [ "$status" -eq "0" ] + [[ "$output" =~ "pk1" ]] || false + [[ "${#lines[@]}" = "1" ]] || false +} + +@test "vector-index: CREATE INDEX with same name" { + dolt sql <<'SQL' +INSERT INTO onepk VALUES (1, '[99, 51]'), (2, '[11, 55]'), (3, '[88, 52]'), (4, '[22, 54]'), (5, '[77, 53]'); +CREATE VECTOR INDEX idx_v1 ON onepk(v1); +SQL + run dolt sql -q "CREATE VECTOR INDEX idx_v1 ON onepk(v2)" + [ "$status" -eq "1" ] + run dolt index ls onepk + [ "$status" -eq "0" ] + [[ "$output" =~ "idx_v1(v1)" ]] || false + # Found bug where the above would error, yet somehow wipe the index table + run dolt index cat onepk idx_v1 -r=csv + [ "$status" -eq "0" ] + [[ "$output" =~ "v1,pk1" ]] || false + [[ "$output" =~ '"[88,52]",3' ]] || false + [[ "$output" =~ '"[99,51]",1' ]] || false + [[ "${#lines[@]}" = "6" ]] || false + run dolt schema show onepk + [ "$status" -eq "0" ] + [[ "$output" =~ 'VECTOR KEY `idx_v1` (`v1`)' ]] || false +} + +@test "vector-index: CREATE INDEX with same columns" { + dolt sql <<'SQL' +INSERT INTO onepk VALUES (1, '[99, 51]'), (2, '[11, 55]'), (3, '[88, 52]'), (4, '[22, 54]'), (5, '[77, 53]'); +CREATE VECTOR INDEX idx_v1 ON onepk(v1); +SQL + run dolt sql -q "CREATE VECTOR INDEX idx_v1_dup ON onepk(v1)" + [ "$status" -eq "0" ] + run dolt index ls onepk + [ "$status" -eq "0" ] + [[ "$output" =~ 
"idx_v1_dup(v1)" ]] || false + run dolt schema show onepk + [ "$status" -eq "0" ] + [[ "$output" =~ 'VECTOR KEY `idx_v1_dup` (`v1`)' ]] || false +} + +@test "vector-index: Disallow 'dolt_' name prefix" { + run dolt sql -q "CREATE VECTOR INDEX dolt_idx_v1 ON onepk(v1)" + [ "$status" -eq "1" ] + run dolt sql -q "ALTER TABLE onepk ADD INDEX dolt_idx_v1 (v1)" + [ "$status" -eq "1" ] +} + +@test "vector-index: DROP INDEX" { + dolt sql <