diff --git a/firestore/document.go b/firestore/document.go index 8d63e23edd33..cc384db8409e 100644 --- a/firestore/document.go +++ b/firestore/document.go @@ -97,6 +97,7 @@ func (d *DocumentSnapshot) Data() map[string]interface{} { // Slices are resized to the incoming value's size, while arrays that are too // long have excess elements filled with zero values. If the array is too short, // excess incoming values will be dropped. +// - Vectors convert to []float64 // - Maps convert to map[string]interface{}. When setting a struct field, // maps of key type string and any value type are permitted, and are populated // recursively. diff --git a/firestore/from_value.go b/firestore/from_value.go index df68465a9e57..8ff05c7410bd 100644 --- a/firestore/from_value.go +++ b/firestore/from_value.go @@ -32,63 +32,63 @@ func setFromProtoValue(x interface{}, vproto *pb.Value, c *Client) error { return setReflectFromProtoValue(v.Elem(), vproto, c) } -// setReflectFromProtoValue sets v from a Firestore Value. -// v must be a settable value. -func setReflectFromProtoValue(v reflect.Value, vproto *pb.Value, c *Client) error { +// setReflectFromProtoValue sets vDest from a Firestore Value. +// vDest must be a settable value. +func setReflectFromProtoValue(vDest reflect.Value, vprotoSrc *pb.Value, c *Client) error { typeErr := func() error { - return fmt.Errorf("firestore: cannot set type %s to %s", v.Type(), typeString(vproto)) + return fmt.Errorf("firestore: cannot set type %s to %s", vDest.Type(), typeString(vprotoSrc)) } - val := vproto.ValueType + valTypeSrc := vprotoSrc.ValueType // A Null value sets anything nullable to nil, and has no effect // on anything else. - if _, ok := val.(*pb.Value_NullValue); ok { - switch v.Kind() { + if _, ok := valTypeSrc.(*pb.Value_NullValue); ok { + switch vDest.Kind() { case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice: - v.Set(reflect.Zero(v.Type())) + vDest.Set(reflect.Zero(vDest.Type())) } return nil } // Handle special types first. - switch v.Type() { + switch vDest.Type() { case typeOfByteSlice: - x, ok := val.(*pb.Value_BytesValue) + x, ok := valTypeSrc.(*pb.Value_BytesValue) if !ok { return typeErr() } - v.SetBytes(x.BytesValue) + vDest.SetBytes(x.BytesValue) return nil case typeOfGoTime: - x, ok := val.(*pb.Value_TimestampValue) + x, ok := valTypeSrc.(*pb.Value_TimestampValue) if !ok { return typeErr() } if err := x.TimestampValue.CheckValid(); err != nil { return err } - v.Set(reflect.ValueOf(x.TimestampValue.AsTime())) + vDest.Set(reflect.ValueOf(x.TimestampValue.AsTime())) return nil case typeOfProtoTimestamp: - x, ok := val.(*pb.Value_TimestampValue) + x, ok := valTypeSrc.(*pb.Value_TimestampValue) if !ok { return typeErr() } - v.Set(reflect.ValueOf(x.TimestampValue)) + vDest.Set(reflect.ValueOf(x.TimestampValue)) return nil case typeOfLatLng: - x, ok := val.(*pb.Value_GeoPointValue) + x, ok := valTypeSrc.(*pb.Value_GeoPointValue) if !ok { return typeErr() } - v.Set(reflect.ValueOf(x.GeoPointValue)) + vDest.Set(reflect.ValueOf(x.GeoPointValue)) return nil case typeOfDocumentRef: - x, ok := val.(*pb.Value_ReferenceValue) + x, ok := valTypeSrc.(*pb.Value_ReferenceValue) if !ok { return typeErr() } @@ -96,157 +96,172 @@ func setReflectFromProtoValue(v reflect.Value, vproto *pb.Value, c *Client) erro if err != nil { return err } - v.Set(reflect.ValueOf(dr)) + vDest.Set(reflect.ValueOf(dr)) + return nil + + case typeOfVector32: + val, err := vector32FromProtoValue(vprotoSrc) + if err != nil { + return err + } + vDest.Set(reflect.ValueOf(val)) + return nil + case typeOfVector64: + val, err := vector64FromProtoValue(vprotoSrc) + if err != nil { + return err + } + vDest.Set(reflect.ValueOf(val)) return nil } - switch v.Kind() { + switch vDest.Kind() { case reflect.Bool: - x, ok := val.(*pb.Value_BooleanValue) + x, ok := valTypeSrc.(*pb.Value_BooleanValue) if !ok { return typeErr() } - v.SetBool(x.BooleanValue) + vDest.SetBool(x.BooleanValue) case reflect.String: - x, ok := val.(*pb.Value_StringValue) + x, ok := valTypeSrc.(*pb.Value_StringValue) if !ok { return typeErr() } - v.SetString(x.StringValue) + vDest.SetString(x.StringValue) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: var i int64 - switch x := val.(type) { + switch x := valTypeSrc.(type) { case *pb.Value_IntegerValue: i = x.IntegerValue case *pb.Value_DoubleValue: f := x.DoubleValue i = int64(f) if float64(i) != f { - return fmt.Errorf("firestore: float %f does not fit into %s", f, v.Type()) + return fmt.Errorf("firestore: float %f does not fit into %s", f, vDest.Type()) } default: return typeErr() } - if v.OverflowInt(i) { - return overflowErr(v, i) + if vDest.OverflowInt(i) { + return overflowErr(vDest, i) } - v.SetInt(i) + vDest.SetInt(i) case reflect.Uint8, reflect.Uint16, reflect.Uint32: var u uint64 - switch x := val.(type) { + switch x := valTypeSrc.(type) { case *pb.Value_IntegerValue: u = uint64(x.IntegerValue) case *pb.Value_DoubleValue: f := x.DoubleValue u = uint64(f) if float64(u) != f { - return fmt.Errorf("firestore: float %f does not fit into %s", f, v.Type()) + return fmt.Errorf("firestore: float %f does not fit into %s", f, vDest.Type()) } default: return typeErr() } - if v.OverflowUint(u) { - return overflowErr(v, u) + if vDest.OverflowUint(u) { + return overflowErr(vDest, u) } - v.SetUint(u) + vDest.SetUint(u) case reflect.Float32, reflect.Float64: var f float64 - switch x := val.(type) { + switch x := valTypeSrc.(type) { case *pb.Value_DoubleValue: f = x.DoubleValue case *pb.Value_IntegerValue: f = float64(x.IntegerValue) if int64(f) != x.IntegerValue { - return overflowErr(v, x.IntegerValue) + return overflowErr(vDest, x.IntegerValue) } default: return typeErr() } - if v.OverflowFloat(f) { - return overflowErr(v, f) + if vDest.OverflowFloat(f) { + return overflowErr(vDest, f) } - v.SetFloat(f) + vDest.SetFloat(f) case reflect.Slice: - x, ok := val.(*pb.Value_ArrayValue) + x, ok := valTypeSrc.(*pb.Value_ArrayValue) if !ok { return typeErr() } vals := x.ArrayValue.Values - vlen := v.Len() + vlen := vDest.Len() xlen := len(vals) // Make a slice of the right size, avoiding allocation if possible. switch { case vlen < xlen: - v.Set(reflect.MakeSlice(v.Type(), xlen, xlen)) + vDest.Set(reflect.MakeSlice(vDest.Type(), xlen, xlen)) case vlen > xlen: - v.SetLen(xlen) + vDest.SetLen(xlen) } - return populateRepeated(v, vals, xlen, c) + return populateRepeated(vDest, vals, xlen, c) case reflect.Array: - x, ok := val.(*pb.Value_ArrayValue) + x, ok := valTypeSrc.(*pb.Value_ArrayValue) if !ok { return typeErr() } vals := x.ArrayValue.Values xlen := len(vals) - vlen := v.Len() + vlen := vDest.Len() minlen := vlen // Set extra elements to their zero value. if vlen > xlen { - z := reflect.Zero(v.Type().Elem()) + z := reflect.Zero(vDest.Type().Elem()) for i := xlen; i < vlen; i++ { - v.Index(i).Set(z) + vDest.Index(i).Set(z) } minlen = xlen } - return populateRepeated(v, vals, minlen, c) + return populateRepeated(vDest, vals, minlen, c) case reflect.Map: - x, ok := val.(*pb.Value_MapValue) + x, ok := valTypeSrc.(*pb.Value_MapValue) if !ok { return typeErr() } - return populateMap(v, x.MapValue.Fields, c) + return populateMap(vDest, x.MapValue.Fields, c) case reflect.Ptr: // If the pointer is nil, set it to a zero value. - if v.IsNil() { - v.Set(reflect.New(v.Type().Elem())) + if vDest.IsNil() { + vDest.Set(reflect.New(vDest.Type().Elem())) } - return setReflectFromProtoValue(v.Elem(), vproto, c) + return setReflectFromProtoValue(vDest.Elem(), vprotoSrc, c) case reflect.Struct: - x, ok := val.(*pb.Value_MapValue) + x, ok := valTypeSrc.(*pb.Value_MapValue) if !ok { return typeErr() } - return populateStruct(v, x.MapValue.Fields, c) + return populateStruct(vDest, x.MapValue.Fields, c) case reflect.Interface: - if v.NumMethod() == 0 { // empty interface + if vDest.NumMethod() == 0 { // empty interface // If v holds a pointer, set the pointer. - if !v.IsNil() && v.Elem().Kind() == reflect.Ptr { - return setReflectFromProtoValue(v.Elem(), vproto, c) + if !vDest.IsNil() && vDest.Elem().Kind() == reflect.Ptr { + return setReflectFromProtoValue(vDest.Elem(), vprotoSrc, c) } // Otherwise, create a fresh value. - x, err := createFromProtoValue(vproto, c) + x, err := createFromProtoValue(vprotoSrc, c) if err != nil { return err } - v.Set(reflect.ValueOf(x)) + vDest.Set(reflect.ValueOf(x)) return nil } // Any other kind of interface is an error. fallthrough default: - return fmt.Errorf("firestore: cannot set type %s", v.Type()) + return fmt.Errorf("firestore: cannot set type %s", vDest.Type()) } return nil } @@ -389,8 +404,15 @@ func createFromProtoValue(vproto *pb.Value, c *Client) (interface{}, error) { } ret[k] = r } - return ret, nil + typeVal, ok := ret[typeKey] + if !ok || typeVal != typeValVector { + // Map is not a vector. Return the map + return ret, nil + } + + // Special handling for vector + return vectorFromProtoValue(vproto) default: return nil, fmt.Errorf("firestore: unknown value type %T", v) } diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 36c710378b8e..bfa690d5524e 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -75,13 +75,13 @@ const ( ) var ( - iClient *Client - iAdminClient *apiv1.FirestoreAdminClient - iColl *CollectionRef - collectionIDs = uid.NewSpace("go-integration-test", nil) - wantDBPath string - indexNames []string - testParams map[string]interface{} + iClient *Client + iAdminClient *apiv1.FirestoreAdminClient + iColl *CollectionRef + collectionIDs = uid.NewSpace("go-integration-test", nil) + wantDBPath string + testParams map[string]interface{} + seededFirstIndex bool ) func initIntegrationTest() { @@ -146,16 +146,68 @@ func initIntegrationTest() { integrationTestStruct.Ref = refDoc } +type vectorIndex struct { + dimension int32 + fieldPath string +} + +func createVectorIndexes(ctx context.Context, t *testing.T, dbPath string, vectorModeIndexes []vectorIndex) []string { + collRef := integrationColl(t) + indexNames := make([]string, len(vectorModeIndexes)) + indexParent := fmt.Sprintf("%s/collectionGroups/%s", dbPath, collRef.ID) + + var wg sync.WaitGroup + + // create vectore mode indexes + for i, vectorModeIndex := range vectorModeIndexes { + wg.Add(1) + req := &adminpb.CreateIndexRequest{ + Parent: indexParent, + Index: &adminpb.Index{ + QueryScope: adminpb.Index_COLLECTION, + Fields: []*adminpb.Index_IndexField{ + { + FieldPath: vectorModeIndex.fieldPath, + ValueMode: &adminpb.Index_IndexField_VectorConfig_{ + VectorConfig: &adminpb.Index_IndexField_VectorConfig{ + Dimension: vectorModeIndex.dimension, + Type: &adminpb.Index_IndexField_VectorConfig_Flat{ + Flat: &adminpb.Index_IndexField_VectorConfig_FlatIndex{}, + }, + }, + }, + }, + }, + }, + } + op, createErr := iAdminClient.CreateIndex(ctx, req) + if createErr != nil { + log.Fatalf("CreateIndex vectorindexes: %v", createErr) + } + if i == 0 && !seededFirstIndex { + seededFirstIndex = true + handleCreateIndexResp(ctx, indexNames, &wg, i, op) + } else { + go handleCreateIndexResp(ctx, indexNames, &wg, i, op) + } + } + + wg.Wait() + return indexNames +} + // createIndexes creates composite indexes on provided Firestore database // Indexes are required to run queries with composite filters on multiple fields. // Without indexes, FailedPrecondition rpc error is seen with // desc 'The query requires multiple indexes'. -func createIndexes(ctx context.Context, dbPath string, indexFields [][]string) { - indexNames = make([]string, len(indexFields)) +func createIndexes(ctx context.Context, dbPath string, orderModeindexFields [][]string) []string { + indexNames := make([]string, len(orderModeindexFields)) indexParent := fmt.Sprintf("%s/collectionGroups/%s", dbPath, iColl.ID) var wg sync.WaitGroup - for i, fields := range indexFields { + + // Create order mode indexes + for i, fields := range orderModeindexFields { wg.Add(1) var adminPbIndexFields []*adminpb.Index_IndexField for _, field := range fields { @@ -177,17 +229,21 @@ func createIndexes(ctx context.Context, dbPath string, indexFields [][]string) { if createErr != nil { log.Fatalf("CreateIndex: %v", createErr) } - if i == 0 { + if i == 0 && !seededFirstIndex { + seededFirstIndex = true // Seed first index to prevent FirestoreMetadataWrite.BootstrapDatabase Concurrent access error - handleCreateIndexResp(ctx, &wg, i, op) + handleCreateIndexResp(ctx, indexNames, &wg, i, op) } else { - go handleCreateIndexResp(ctx, &wg, i, op) + go handleCreateIndexResp(ctx, indexNames, &wg, i, op) } } + wg.Wait() + return indexNames } -func handleCreateIndexResp(ctx context.Context, wg *sync.WaitGroup, i int, op *apiv1.CreateIndexOperation) { +// handleCreateIndexResp handles create index response and puts the created index name at index i in the indexNames array +func handleCreateIndexResp(ctx context.Context, indexNames []string, wg *sync.WaitGroup, i int, op *apiv1.CreateIndexOperation) { defer wg.Done() createdIndex, waitErr := op.Wait(ctx) if waitErr != nil { @@ -197,7 +253,7 @@ func handleCreateIndexResp(ctx context.Context, wg *sync.WaitGroup, i int, op *a } // deleteIndexes deletes composite indexes created in createIndexes function -func deleteIndexes(ctx context.Context) { +func deleteIndexes(ctx context.Context, indexNames []string) { for _, indexName := range indexNames { err := iAdminClient.DeleteIndex(ctx, &adminpb.DeleteIndexRequest{ Name: indexName, @@ -293,10 +349,6 @@ func deleteDocument(ctx context.Context, docRef *DocumentRef, bulkwriter *BulkWr func cleanupIntegrationTest() { if iClient != nil { - adminCtx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) - defer cancel() - deleteIndexes(adminCtx) - ctx := context.Background() deleteCollection(ctx, iColl) iClient.Close() @@ -346,44 +398,46 @@ var ( // Use this when writing a doc. integrationTestMap = map[string]interface{}{ - "int": 1, - "int8": int8(2), - "int16": int16(3), - "int32": int32(4), - "int64": int64(5), - "uint8": uint8(6), - "uint16": uint16(7), - "uint32": uint32(8), - "str": "two", - "bool": true, - "float": 3.14, - "null": nil, - "bytes": []byte("bytes"), - "*": map[string]interface{}{"`": 4}, - "time": integrationTime, - "geo": integrationGeo, - "ref": nil, // populated by initIntegrationTest + "int": 1, + "int8": int8(2), + "int16": int16(3), + "int32": int32(4), + "int64": int64(5), + "uint8": uint8(6), + "uint16": uint16(7), + "uint32": uint32(8), + "str": "two", + "bool": true, + "float": 3.14, + "null": nil, + "bytes": []byte("bytes"), + "*": map[string]interface{}{"`": 4}, + "time": integrationTime, + "geo": integrationGeo, + "ref": nil, // populated by initIntegrationTest + "embeddedField": Vector64{1.0, 2.0, 3.0}, } // The returned data is slightly different. wantIntegrationTestMap = map[string]interface{}{ - "int": int64(1), - "int8": int64(2), - "int16": int64(3), - "int32": int64(4), - "int64": int64(5), - "uint8": int64(6), - "uint16": int64(7), - "uint32": int64(8), - "str": "two", - "bool": true, - "float": 3.14, - "null": nil, - "bytes": []byte("bytes"), - "*": map[string]interface{}{"`": int64(4)}, - "time": wantIntegrationTime, - "geo": integrationGeo, - "ref": nil, // populated by initIntegrationTest + "int": int64(1), + "int8": int64(2), + "int16": int64(3), + "int32": int64(4), + "int64": int64(5), + "uint8": int64(6), + "uint16": int64(7), + "uint32": int64(8), + "str": "two", + "bool": true, + "float": 3.14, + "null": nil, + "bytes": []byte("bytes"), + "*": map[string]interface{}{"`": int64(4)}, + "time": wantIntegrationTime, + "geo": integrationGeo, + "ref": nil, // populated by initIntegrationTest + "embeddedField": Vector64{1.0, 2.0, 3.0}, } integrationTestStruct = integrationTestStructType{ @@ -873,7 +927,8 @@ func TestIntegration_QueryDocuments_WhereEntity(t *testing.T) { {"weight", "height"}} adminCtx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) defer cancel() - createIndexes(adminCtx, wantDBPath, indexFields) + indexNames := createIndexes(adminCtx, wantDBPath, indexFields) + defer deleteIndexes(adminCtx, indexNames) h := testHelper{t} nowTime := time.Now() @@ -2462,10 +2517,13 @@ func TestIntegration_AggregationQueries(t *testing.T) { client := integrationClient(t) indexFields := [][]string{ - {"weight", "model"}} + {"weight", "model"}, + {"weight", "height"}, + } adminCtx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) defer cancel() - createIndexes(adminCtx, wantDBPath, indexFields) + indexNames := createIndexes(adminCtx, wantDBPath, indexFields) + defer deleteIndexes(adminCtx, indexNames) h := testHelper{t} docs := []map[string]interface{}{ @@ -2767,7 +2825,11 @@ func TestIntegration_ClientReadTime(t *testing.T) { } tm := time.Now().Add(-time.Minute) + oldReadSettings := *c.readSettings c.WithReadOptions(ReadTime(tm)) + t.Cleanup(func() { + c.readSettings = &oldReadSettings + }) ds, err := c.GetAll(ctx, docs) if err != nil { @@ -2782,3 +2844,94 @@ func TestIntegration_ClientReadTime(t *testing.T) { } } } + +func TestIntegration_FindNearest(t *testing.T) { + collRef := integrationColl(t) + adminCtx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + t.Cleanup(func() { + cancel() + }) + queryField := "EmbeddedField64" + indexNames := createVectorIndexes(adminCtx, t, wantDBPath, []vectorIndex{ + { + fieldPath: queryField, + dimension: 3, + }, + }) + t.Cleanup(func() { + deleteIndexes(adminCtx, indexNames) + }) + + type coffeeBean struct { + ID string + EmbeddedField64 Vector64 + EmbeddedField32 Vector32 + Float32s []float32 // When querying, saving and retrieving, this should be retrieved as []float32 and not Vector32 + } + + beans := []coffeeBean{ + { + ID: "Robusta", + EmbeddedField64: []float64{1, 2, 3}, + EmbeddedField32: []float32{1, 2, 3}, + Float32s: []float32{1, 2, 3}, + }, + { + ID: "Excelsa", + EmbeddedField64: []float64{4, 5, 6}, + EmbeddedField32: []float32{4, 5, 6}, + Float32s: []float32{4, 5, 6}, + }, + { + ID: "Arabica", + EmbeddedField64: []float64{100, 200, 300}, // too far from query vector. not within findNearest limit + EmbeddedField32: []float32{100, 200, 300}, + Float32s: []float32{100, 200, 300}, + }, + + { + ID: "Liberica", + EmbeddedField64: []float64{1, 2}, // Not enough dimensions as compared to query vector. + EmbeddedField32: []float32{1, 2}, + Float32s: []float32{1, 2}, + }, + } + h := testHelper{t} + coll := integrationColl(t) + ctx := context.Background() + var docRefs []*DocumentRef + t.Cleanup(func() { + deleteDocuments(docRefs) + }) + + // create documents with vector field + for i := 0; i < len(beans); i++ { + doc := coll.NewDoc() + docRefs = append(docRefs, doc) + h.mustCreate(doc, beans[i]) + } + + // Query documents with a vector field + vectorQuery := collRef.FindNearest(queryField, []float64{1, 2, 3}, 2, DistanceMeasureEuclidean, nil) + + iter := vectorQuery.Documents(ctx) + gotDocs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %+v", err) + } + + if len(gotDocs) != 2 { + t.Fatalf("Expected 2 results, got %d", len(gotDocs)) + } + + for i, doc := range gotDocs { + gotBean := coffeeBean{} + err := doc.DataTo(&gotBean) + if err != nil { + t.Errorf("#%v: DataTo: %+v", doc.Ref.ID, err) + } + if beans[i].ID != gotBean.ID { + t.Errorf("#%v: want: %v, got: %v", i, beans[i].ID, gotBean.ID) + } + } +} diff --git a/firestore/query.go b/firestore/query.go index 73ef4b4cca01..4a1254d27306 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -61,6 +61,8 @@ type Query struct { // readOptions specifies constraints for reading results from the query // e.g. read time readSettings *readSettings + + findNearest *pb.StructuredQuery_FindNearest } // DocumentID is the special field name representing the ID of a document @@ -364,6 +366,106 @@ func (q Query) Deserialize(bytes []byte) (Query, error) { return q.fromProto(&runQueryRequest) } +// DistanceMeasure is the distance measure to use when comparing vectors with [Query.FindNearest] or [Query.FindNearestPath]. +type DistanceMeasure int32 + +const ( + // DistanceMeasureEuclidean is used to measures the Euclidean distance between the vectors. See + // [Euclidean] to learn more + // + // [Euclidean]: https://en.wikipedia.org/wiki/Euclidean_distance + DistanceMeasureEuclidean DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_EUCLIDEAN) + + // DistanceMeasureCosine compares vectors based on the angle between them, which allows you to + // measure similarity that isn't based on the vectors magnitude. + // We recommend using dot product with unit normalized vectors instead of + // cosine distance, which is mathematically equivalent with better + // performance. See [Cosine Similarity] to learn more. + // + // [Cosine Similarity]: https://en.wikipedia.org/wiki/Cosine_similarity + DistanceMeasureCosine DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_COSINE) + + // DistanceMeasureDotProduct is similar to cosine but is affected by the magnitude of the vectors. See + // [Dot Product] to learn more. + // + // [Dot Product]: https://en.wikipedia.org/wiki/Dot_product + DistanceMeasureDotProduct DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_DOT_PRODUCT) +) + +// FindNearestOptions are options for a FindNearest vector query. +type FindNearestOptions struct { +} + +// VectorQuery represents a vector query +type VectorQuery struct { + q Query +} + +// FindNearest returns a query that can perform vector distance (similarity) search with given parameters. +// +// The returned query, when executed, performs a distance (similarity) search on the specified +// vectorField against the given queryVector and returns the top documents that are closest +// to the queryVector;. +// +// Only documents whose vectorField field is a Vector of the same dimension as queryVector +// participate in the query, all other documents are ignored. +// +// The vectorField argument can be a single field or a dot-separated sequence of +// fields, and must not contain any of the runes "˜*/[]". +func (q Query) FindNearest(vectorField string, queryVector any, limit int, measure DistanceMeasure, options *FindNearestOptions) VectorQuery { + // Validate field path + fieldPath, err := parseDotSeparatedString(vectorField) + if err != nil { + q.err = err + return VectorQuery{ + q: q, + } + } + return q.FindNearestPath(fieldPath, queryVector, limit, measure, options) +} + +// Documents returns an iterator over the vector query's resulting documents. +func (vq VectorQuery) Documents(ctx context.Context) *DocumentIterator { + return vq.q.Documents(ctx) +} + +// FindNearestPath is similar to FindNearest but it accepts a [FieldPath]. +func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector any, limit int, measure DistanceMeasure, options *FindNearestOptions) VectorQuery { + vq := VectorQuery{ + q: q, + } + + // Convert field path to field reference + vectorFieldRef, err := fref(vectorFieldPath) + if err != nil { + vq.q.err = err + return vq + } + + var fnvq *pb.Value + switch v := queryVector.(type) { + case Vector32: + fnvq = vectorToProtoValue([]float32(v)) + case []float32: + fnvq = vectorToProtoValue(v) + case Vector64: + fnvq = vectorToProtoValue([]float64(v)) + case []float64: + fnvq = vectorToProtoValue(v) + default: + vq.q.err = errors.New("firestore: queryVector must be Vector32 or Vector64") + return vq + } + + vq.q.findNearest = &pb.StructuredQuery_FindNearest{ + VectorField: vectorFieldRef, + QueryVector: fnvq, + Limit: &wrapperspb.Int32Value{Value: trunc32(limit)}, + DistanceMeasure: pb.StructuredQuery_FindNearest_DistanceMeasure(measure), + } + return vq +} + // NewAggregationQuery returns an AggregationQuery with this query as its // base query. func (q *Query) NewAggregationQuery() *AggregationQuery { @@ -475,6 +577,8 @@ func (q Query) fromProto(pbQuery *pb.RunQueryRequest) (Query, error) { q.limit = limit } + q.findNearest = pbq.GetFindNearest() + // NOTE: limit to last isn't part of the proto, this is a client-side concept // limitToLast bool return q, q.err @@ -556,6 +660,7 @@ func (q Query) toProto() (*pb.StructuredQuery, error) { return nil, err } p.EndAt = cursor + p.FindNearest = q.findNearest return p, nil } diff --git a/firestore/query_test.go b/firestore/query_test.go index 228fb8ef9e29..106a1bbe15bd 100644 --- a/firestore/query_test.go +++ b/firestore/query_test.go @@ -650,6 +650,72 @@ func createTestScenarios(t *testing.T) []toProtoScenario { }, }, }, + { + desc: `q.Where("a", ">", 5).FindNearest float64 vector`, + in: q.Where("a", ">", 5). + FindNearest("embeddedField", []float64{100, 200, 300}, 2, DistanceMeasureEuclidean, nil).q, + want: &pb.StructuredQuery{ + Where: filtr([]string{"a"}, ">", 5), + FindNearest: &pb.StructuredQuery_FindNearest{ + VectorField: fref1("embeddedField"), + QueryVector: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{ + Values: []*pb.Value{ + {ValueType: &pb.Value_DoubleValue{DoubleValue: 100}}, + {ValueType: &pb.Value_DoubleValue{DoubleValue: 200}}, + {ValueType: &pb.Value_DoubleValue{DoubleValue: 300}}, + }, + }, + }, + }, + }, + }, + }, + }, + Limit: &wrapperspb.Int32Value{Value: trunc32(2)}, + DistanceMeasure: pb.StructuredQuery_FindNearest_EUCLIDEAN, + }, + }, + }, + { + desc: `q.Where("a", ">", 5).FindNearest float32 vector`, + in: q.Where("a", ">", 5). + FindNearest("embeddedField", []float32{100, 200, 300}, 2, DistanceMeasureEuclidean, nil).q, + want: &pb.StructuredQuery{ + Where: filtr([]string{"a"}, ">", 5), + FindNearest: &pb.StructuredQuery_FindNearest{ + VectorField: fref1("embeddedField"), + QueryVector: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{ + Values: []*pb.Value{ + {ValueType: &pb.Value_DoubleValue{DoubleValue: 100}}, + {ValueType: &pb.Value_DoubleValue{DoubleValue: 200}}, + {ValueType: &pb.Value_DoubleValue{DoubleValue: 300}}, + }, + }, + }, + }, + }, + }, + }, + }, + Limit: &wrapperspb.Int32Value{Value: trunc32(2)}, + DistanceMeasure: pb.StructuredQuery_FindNearest_EUCLIDEAN, + }, + }, + }, } } @@ -676,7 +742,6 @@ func TestQueryFromProtoRoundTrip(t *testing.T) { if err != nil { t.Fatalf("%s: %v", test.desc, err) } - fmt.Printf("proto: %v\n", proto) got, err := Query{c: c}.Deserialize(proto) if err != nil { t.Fatalf("%s: %v", test.desc, err) @@ -1359,3 +1424,78 @@ func TestWithAvgPath(t *testing.T) { } } } + +func TestFindNearest(t *testing.T) { + ctx := context.Background() + c, srv, cleanup := newMock(t) + defer cleanup() + + const dbPath = "projects/projectID/databases/(default)" + mapFields := map[string]*pb.Value{ + typeKey: {ValueType: &pb.Value_StringValue{StringValue: typeValVector}}, + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{ + Values: []*pb.Value{ + {ValueType: &pb.Value_DoubleValue{DoubleValue: 1}}, + {ValueType: &pb.Value_DoubleValue{DoubleValue: 2}}, + {ValueType: &pb.Value_DoubleValue{DoubleValue: 2}}, + }, + }, + }, + }, + } + wantPBDocs := []*pb.Document{ + { + Name: dbPath + "/documents/C/a", + CreateTime: aTimestamp, + UpdateTime: aTimestamp, + Fields: map[string]*pb.Value{"EmbeddedField": mapval(mapFields)}, + }, + } + + testcases := []struct { + desc string + path string + queryVector interface{} + wantErr bool + }{ + { + desc: "Invalid path", + path: "path*", + wantErr: true, + }, + { + desc: "Valid path", + path: "path", + queryVector: []float64{5, 6, 7}, + wantErr: false, + }, + { + desc: "Invalid vector type", + path: "path", + queryVector: "abcd", + wantErr: true, + }, + { + desc: "Valid vector type", + path: "path", + queryVector: []float32{5, 6, 7}, + wantErr: false, + }, + } + for _, tc := range testcases { + srv.reset() + srv.addRPC(nil, []interface{}{ + &pb.RunQueryResponse{Document: wantPBDocs[0]}, + }) + vQuery := c.Collection("C").FindNearest(tc.path, tc.queryVector, 2, DistanceMeasureEuclidean, nil) + + _, err := vQuery.Documents(ctx).GetAll() + if err == nil && tc.wantErr { + t.Fatalf("%s: got nil wanted error", tc.desc) + } else if err != nil && !tc.wantErr { + t.Fatalf("%s: got %v, want nil", tc.desc, err) + } + } +} diff --git a/firestore/to_value.go b/firestore/to_value.go index 0921ef9e6c51..9976b16e7251 100644 --- a/firestore/to_value.go +++ b/firestore/to_value.go @@ -34,6 +34,8 @@ var ( typeOfLatLng = reflect.TypeOf((*latlng.LatLng)(nil)) typeOfDocumentRef = reflect.TypeOf((*DocumentRef)(nil)) typeOfProtoTimestamp = reflect.TypeOf((*ts.Timestamp)(nil)) + typeOfVector64 = reflect.TypeOf(Vector64{}) + typeOfVector32 = reflect.TypeOf(Vector32{}) ) // toProtoValue converts a Go value to a Firestore Value protobuf. @@ -69,6 +71,10 @@ func toProtoValue(v reflect.Value) (pbv *pb.Value, sawTransform bool, err error) return nullValue, false, nil } return &pb.Value{ValueType: &pb.Value_TimestampValue{TimestampValue: x}}, false, nil + case Vector32: + return vectorToProtoValue(x), false, nil + case Vector64: + return vectorToProtoValue(x), false, nil case *latlng.LatLng: if x == nil { // gRPC doesn't like nil oneofs. Use NullValue. @@ -95,9 +101,9 @@ func toProtoValue(v reflect.Value) (pbv *pb.Value, sawTransform bool, err error) case reflect.Uint8, reflect.Uint16, reflect.Uint32: return &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(v.Uint())}}, false, nil case reflect.Float32, reflect.Float64: - return &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: v.Float()}}, false, nil + return floatToProtoValue(v.Float()), false, nil case reflect.String: - return &pb.Value{ValueType: &pb.Value_StringValue{StringValue: v.String()}}, false, nil + return stringToProtoValue(v.String()), false, nil case reflect.Array: return arrayToProtoValue(v) case reflect.Slice: @@ -122,6 +128,14 @@ func toProtoValue(v reflect.Value) (pbv *pb.Value, sawTransform bool, err error) } } +func stringToProtoValue(s string) *pb.Value { + return &pb.Value{ValueType: &pb.Value_StringValue{StringValue: s}} +} + +func floatToProtoValue(f float64) *pb.Value { + return &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: f}} +} + // arrayToProtoValue converts a array to a Firestore Value protobuf and reports // whether a transform was encountered. func arrayToProtoValue(v reflect.Value) (*pb.Value, bool, error) { diff --git a/firestore/vector.go b/firestore/vector.go new file mode 100644 index 000000000000..3b89d2772573 --- /dev/null +++ b/firestore/vector.go @@ -0,0 +1,146 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "fmt" + + pb "cloud.google.com/go/firestore/apiv1/firestorepb" +) + +const ( + typeKey = "__type__" + typeValVector = "__vector__" + valueKey = "value" +) + +// Vector64 is an embedding vector of float64s. +type Vector64 []float64 + +// Vector32 is an embedding vector of float32s. +type Vector32 []float32 + +// vectorToProtoValue returns a Firestore [pb.Value] representing the Vector. +// The calling function should check for type safety +func vectorToProtoValue[vType float32 | float64](v []vType) *pb.Value { + if v == nil { + return nullValue + } + pbVals := make([]*pb.Value, len(v)) + for i, val := range v { + pbVals[i] = floatToProtoValue(float64(val)) + } + + return &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{Values: pbVals}, + }, + }, + }, + }, + }, + } +} + +func vectorFromProtoValue(v *pb.Value) (interface{}, error) { + return vector64FromProtoValue(v) +} + +func vector32FromProtoValue(v *pb.Value) (Vector32, error) { + pbArrVals, err := pbValToVectorVals(v) + if err != nil { + return nil, err + } + + floats := make([]float32, len(pbArrVals)) + for i, fval := range pbArrVals { + dv, ok := fval.ValueType.(*pb.Value_DoubleValue) + if !ok { + return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_DoubleValue", fval.ValueType) + } + floats[i] = float32(dv.DoubleValue) + } + return floats, nil +} + +func vector64FromProtoValue(v *pb.Value) (Vector64, error) { + pbArrVals, err := pbValToVectorVals(v) + if err != nil { + return nil, err + } + + floats := make([]float64, len(pbArrVals)) + for i, fval := range pbArrVals { + dv, ok := fval.ValueType.(*pb.Value_DoubleValue) + if !ok { + return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_DoubleValue", fval.ValueType) + } + floats[i] = dv.DoubleValue + } + return floats, nil +} + +func pbValToVectorVals(v *pb.Value) ([]*pb.Value, error) { + /* + Vector is stored as: + { + "__type__": "__vector__", + "value": []float64{}, + } + */ + if v == nil { + return nil, nil + } + pbMap, ok := v.ValueType.(*pb.Value_MapValue) + if !ok { + return nil, fmt.Errorf("firestore: cannot convert %v to *pb.Value_MapValue", v.ValueType) + } + m := pbMap.MapValue.Fields + var typeVal string + typeVal, err := stringFromProtoValue(m[typeKey]) + if err != nil { + return nil, err + } + if typeVal != typeValVector { + return nil, fmt.Errorf("firestore: value of %v : %v is not %v", typeKey, typeVal, typeValVector) + } + pbVal, ok := m[valueKey] + if !ok { + return nil, fmt.Errorf("firestore: %v not present in %v", valueKey, m) + } + + pbArr, ok := pbVal.ValueType.(*pb.Value_ArrayValue) + if !ok { + return nil, fmt.Errorf("firestore: failed to convert %v to *pb.Value_ArrayValue", pbVal.ValueType) + } + + return pbArr.ArrayValue.Values, nil +} + +func stringFromProtoValue(v *pb.Value) (string, error) { + if v == nil { + return "", fmt.Errorf("firestore: failed to convert %v to string", v) + } + sv, ok := v.ValueType.(*pb.Value_StringValue) + if !ok { + return "", fmt.Errorf("firestore: failed to convert %v to *pb.Value_StringValue", v.ValueType) + } + return sv.StringValue, nil +} diff --git a/firestore/vector_test.go b/firestore/vector_test.go new file mode 100644 index 000000000000..9e1497b7ba0e --- /dev/null +++ b/firestore/vector_test.go @@ -0,0 +1,261 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "testing" + + pb "cloud.google.com/go/firestore/apiv1/firestorepb" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +func TestVectorToProtoValue(t *testing.T) { + tests := []struct { + name string + v Vector64 + want *pb.Value + }{ + { + name: "nil vector", + v: nil, + want: nullValue, + }, + { + name: "empty vector", + v: Vector64{}, + want: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{Values: []*pb.Value{}}, + }, + }, + }, + }, + }, + }, + }, + { + name: "multiple element vector", + v: Vector64{1.0, 2.0, 3.0}, + want: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{Values: []*pb.Value{floatToProtoValue(1.0), floatToProtoValue(2.0), floatToProtoValue(3.0)}}, + }, + }, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := vectorToProtoValue(tt.v) + if !testEqual(got, tt.want) { + t.Errorf("vectorToProtoValue() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestVectorFromProtoValue(t *testing.T) { + tests := []struct { + name string + v *pb.Value + want Vector64 + wantErr bool + }{ + { + name: "nil value", + v: nil, + want: nil, + }, + { + name: "empty vector", + v: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{Values: []*pb.Value{}}, + }, + }, + }, + }, + }, + }, + want: Vector64{}, + }, + { + name: "multiple element vector", + v: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{Values: []*pb.Value{floatToProtoValue(1.0), floatToProtoValue(2.0), floatToProtoValue(3.0)}}, + }, + }, + }, + }, + }, + }, + want: Vector64{1.0, 2.0, 3.0}, + }, + { + name: "invalid type", + v: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue("invalid_type"), + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{Values: []*pb.Value{floatToProtoValue(1.0), floatToProtoValue(2.0), floatToProtoValue(3.0)}}, + }, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "missing type", + v: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + valueKey: { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{Values: []*pb.Value{floatToProtoValue(1.0), floatToProtoValue(2.0), floatToProtoValue(3.0)}}, + }, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "missing value", + v: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "invalid value", + v: &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + typeKey: stringToProtoValue(typeValVector), + valueKey: { + ValueType: &pb.Value_StringValue{ + StringValue: "invalid_value", + }, + }, + }, + }, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := vectorFromProtoValue(tt.v) + if (err != nil) != tt.wantErr { + t.Errorf("vectorFromProtoValue() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if !cmp.Equal(got, tt.want, cmpopts.EquateEmpty()) { + t.Errorf("vectorFromProtoValue() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestStringFromProtoValue(t *testing.T) { + tests := []struct { + name string + v *pb.Value + want string + wantErr bool + }{ + { + name: "nil value", + v: nil, + wantErr: true, + }, + { + name: "string value", + v: &pb.Value{ + ValueType: &pb.Value_StringValue{ + StringValue: "test_string", + }, + }, + want: "test_string", + }, + { + name: "invalid value", + v: &pb.Value{ + ValueType: &pb.Value_IntegerValue{ + IntegerValue: 123, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := stringFromProtoValue(tt.v) + if (err != nil) != tt.wantErr { + t.Errorf("stringFromProtoValue() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if got != tt.want { + t.Errorf("stringFromProtoValue() = %v, want %v", got, tt.want) + } + }) + } +}