Skip to content

Commit

Permalink
feat(firestore): Adding vector search (googleapis#10548)
Browse files Browse the repository at this point in the history
* feat(firestore): Adding vector search

* feat(firestore): refactoring code

* feat(firestore): Resolving vet failures

* feat(firestore): Adding unit and integration tests

* feat(firestore): Fixing tests and refactoring code

* feat(firestore): Resolving vet failures

* feat(firestore): Refactoring code

* feat(firestore): Resolving review comments
  • Loading branch information
bhshkh authored Jul 22, 2024
1 parent 6b32871 commit 5c0d6df
Show file tree
Hide file tree
Showing 8 changed files with 964 additions and 122 deletions.
1 change: 1 addition & 0 deletions firestore/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ func (d *DocumentSnapshot) Data() map[string]interface{} {
// Slices are resized to the incoming value's size, while arrays that are too
// long have excess elements filled with zero values. If the array is too short,
// excess incoming values will be dropped.
// - Vectors convert to []float64
// - Maps convert to map[string]interface{}. When setting a struct field,
// maps of key type string and any value type are permitted, and are populated
// recursively.
Expand Down
148 changes: 85 additions & 63 deletions firestore/from_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,221 +32,236 @@ func setFromProtoValue(x interface{}, vproto *pb.Value, c *Client) error {
return setReflectFromProtoValue(v.Elem(), vproto, c)
}

// setReflectFromProtoValue sets v from a Firestore Value.
// v must be a settable value.
func setReflectFromProtoValue(v reflect.Value, vproto *pb.Value, c *Client) error {
// setReflectFromProtoValue sets vDest from a Firestore Value.
// vDest must be a settable value.
func setReflectFromProtoValue(vDest reflect.Value, vprotoSrc *pb.Value, c *Client) error {
typeErr := func() error {
return fmt.Errorf("firestore: cannot set type %s to %s", v.Type(), typeString(vproto))
return fmt.Errorf("firestore: cannot set type %s to %s", vDest.Type(), typeString(vprotoSrc))
}

val := vproto.ValueType
valTypeSrc := vprotoSrc.ValueType
// A Null value sets anything nullable to nil, and has no effect
// on anything else.
if _, ok := val.(*pb.Value_NullValue); ok {
switch v.Kind() {
if _, ok := valTypeSrc.(*pb.Value_NullValue); ok {
switch vDest.Kind() {
case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice:
v.Set(reflect.Zero(v.Type()))
vDest.Set(reflect.Zero(vDest.Type()))
}
return nil
}

// Handle special types first.
switch v.Type() {
switch vDest.Type() {
case typeOfByteSlice:
x, ok := val.(*pb.Value_BytesValue)
x, ok := valTypeSrc.(*pb.Value_BytesValue)
if !ok {
return typeErr()
}
v.SetBytes(x.BytesValue)
vDest.SetBytes(x.BytesValue)
return nil

case typeOfGoTime:
x, ok := val.(*pb.Value_TimestampValue)
x, ok := valTypeSrc.(*pb.Value_TimestampValue)
if !ok {
return typeErr()
}
if err := x.TimestampValue.CheckValid(); err != nil {
return err
}
v.Set(reflect.ValueOf(x.TimestampValue.AsTime()))
vDest.Set(reflect.ValueOf(x.TimestampValue.AsTime()))
return nil

case typeOfProtoTimestamp:
x, ok := val.(*pb.Value_TimestampValue)
x, ok := valTypeSrc.(*pb.Value_TimestampValue)
if !ok {
return typeErr()
}
v.Set(reflect.ValueOf(x.TimestampValue))
vDest.Set(reflect.ValueOf(x.TimestampValue))
return nil

case typeOfLatLng:
x, ok := val.(*pb.Value_GeoPointValue)
x, ok := valTypeSrc.(*pb.Value_GeoPointValue)
if !ok {
return typeErr()
}
v.Set(reflect.ValueOf(x.GeoPointValue))
vDest.Set(reflect.ValueOf(x.GeoPointValue))
return nil

case typeOfDocumentRef:
x, ok := val.(*pb.Value_ReferenceValue)
x, ok := valTypeSrc.(*pb.Value_ReferenceValue)
if !ok {
return typeErr()
}
dr, err := pathToDoc(x.ReferenceValue, c)
if err != nil {
return err
}
v.Set(reflect.ValueOf(dr))
vDest.Set(reflect.ValueOf(dr))
return nil

case typeOfVector32:
val, err := vector32FromProtoValue(vprotoSrc)
if err != nil {
return err
}
vDest.Set(reflect.ValueOf(val))
return nil
case typeOfVector64:
val, err := vector64FromProtoValue(vprotoSrc)
if err != nil {
return err
}
vDest.Set(reflect.ValueOf(val))
return nil
}

switch v.Kind() {
switch vDest.Kind() {
case reflect.Bool:
x, ok := val.(*pb.Value_BooleanValue)
x, ok := valTypeSrc.(*pb.Value_BooleanValue)
if !ok {
return typeErr()
}
v.SetBool(x.BooleanValue)
vDest.SetBool(x.BooleanValue)

case reflect.String:
x, ok := val.(*pb.Value_StringValue)
x, ok := valTypeSrc.(*pb.Value_StringValue)
if !ok {
return typeErr()
}
v.SetString(x.StringValue)
vDest.SetString(x.StringValue)

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var i int64
switch x := val.(type) {
switch x := valTypeSrc.(type) {
case *pb.Value_IntegerValue:
i = x.IntegerValue
case *pb.Value_DoubleValue:
f := x.DoubleValue
i = int64(f)
if float64(i) != f {
return fmt.Errorf("firestore: float %f does not fit into %s", f, v.Type())
return fmt.Errorf("firestore: float %f does not fit into %s", f, vDest.Type())
}
default:
return typeErr()
}
if v.OverflowInt(i) {
return overflowErr(v, i)
if vDest.OverflowInt(i) {
return overflowErr(vDest, i)
}
v.SetInt(i)
vDest.SetInt(i)

case reflect.Uint8, reflect.Uint16, reflect.Uint32:
var u uint64
switch x := val.(type) {
switch x := valTypeSrc.(type) {
case *pb.Value_IntegerValue:
u = uint64(x.IntegerValue)
case *pb.Value_DoubleValue:
f := x.DoubleValue
u = uint64(f)
if float64(u) != f {
return fmt.Errorf("firestore: float %f does not fit into %s", f, v.Type())
return fmt.Errorf("firestore: float %f does not fit into %s", f, vDest.Type())
}
default:
return typeErr()
}
if v.OverflowUint(u) {
return overflowErr(v, u)
if vDest.OverflowUint(u) {
return overflowErr(vDest, u)
}
v.SetUint(u)
vDest.SetUint(u)

case reflect.Float32, reflect.Float64:
var f float64
switch x := val.(type) {
switch x := valTypeSrc.(type) {
case *pb.Value_DoubleValue:
f = x.DoubleValue
case *pb.Value_IntegerValue:
f = float64(x.IntegerValue)
if int64(f) != x.IntegerValue {
return overflowErr(v, x.IntegerValue)
return overflowErr(vDest, x.IntegerValue)
}
default:
return typeErr()
}
if v.OverflowFloat(f) {
return overflowErr(v, f)
if vDest.OverflowFloat(f) {
return overflowErr(vDest, f)
}
v.SetFloat(f)
vDest.SetFloat(f)

case reflect.Slice:
x, ok := val.(*pb.Value_ArrayValue)
x, ok := valTypeSrc.(*pb.Value_ArrayValue)
if !ok {
return typeErr()
}
vals := x.ArrayValue.Values
vlen := v.Len()
vlen := vDest.Len()
xlen := len(vals)
// Make a slice of the right size, avoiding allocation if possible.
switch {
case vlen < xlen:
v.Set(reflect.MakeSlice(v.Type(), xlen, xlen))
vDest.Set(reflect.MakeSlice(vDest.Type(), xlen, xlen))
case vlen > xlen:
v.SetLen(xlen)
vDest.SetLen(xlen)
}
return populateRepeated(v, vals, xlen, c)
return populateRepeated(vDest, vals, xlen, c)

case reflect.Array:
x, ok := val.(*pb.Value_ArrayValue)
x, ok := valTypeSrc.(*pb.Value_ArrayValue)
if !ok {
return typeErr()
}
vals := x.ArrayValue.Values
xlen := len(vals)
vlen := v.Len()
vlen := vDest.Len()
minlen := vlen
// Set extra elements to their zero value.
if vlen > xlen {
z := reflect.Zero(v.Type().Elem())
z := reflect.Zero(vDest.Type().Elem())
for i := xlen; i < vlen; i++ {
v.Index(i).Set(z)
vDest.Index(i).Set(z)
}
minlen = xlen
}
return populateRepeated(v, vals, minlen, c)
return populateRepeated(vDest, vals, minlen, c)

case reflect.Map:
x, ok := val.(*pb.Value_MapValue)
x, ok := valTypeSrc.(*pb.Value_MapValue)
if !ok {
return typeErr()
}
return populateMap(v, x.MapValue.Fields, c)
return populateMap(vDest, x.MapValue.Fields, c)

case reflect.Ptr:
// If the pointer is nil, set it to a zero value.
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
if vDest.IsNil() {
vDest.Set(reflect.New(vDest.Type().Elem()))
}
return setReflectFromProtoValue(v.Elem(), vproto, c)
return setReflectFromProtoValue(vDest.Elem(), vprotoSrc, c)

case reflect.Struct:
x, ok := val.(*pb.Value_MapValue)
x, ok := valTypeSrc.(*pb.Value_MapValue)
if !ok {
return typeErr()
}
return populateStruct(v, x.MapValue.Fields, c)
return populateStruct(vDest, x.MapValue.Fields, c)

case reflect.Interface:
if v.NumMethod() == 0 { // empty interface
if vDest.NumMethod() == 0 { // empty interface
// If v holds a pointer, set the pointer.
if !v.IsNil() && v.Elem().Kind() == reflect.Ptr {
return setReflectFromProtoValue(v.Elem(), vproto, c)
if !vDest.IsNil() && vDest.Elem().Kind() == reflect.Ptr {
return setReflectFromProtoValue(vDest.Elem(), vprotoSrc, c)
}
// Otherwise, create a fresh value.
x, err := createFromProtoValue(vproto, c)
x, err := createFromProtoValue(vprotoSrc, c)
if err != nil {
return err
}
v.Set(reflect.ValueOf(x))
vDest.Set(reflect.ValueOf(x))
return nil
}
// Any other kind of interface is an error.
fallthrough

default:
return fmt.Errorf("firestore: cannot set type %s", v.Type())
return fmt.Errorf("firestore: cannot set type %s", vDest.Type())
}
return nil
}
Expand Down Expand Up @@ -389,8 +404,15 @@ func createFromProtoValue(vproto *pb.Value, c *Client) (interface{}, error) {
}
ret[k] = r
}
return ret, nil

typeVal, ok := ret[typeKey]
if !ok || typeVal != typeValVector {
// Map is not a vector. Return the map
return ret, nil
}

// Special handling for vector
return vectorFromProtoValue(vproto)
default:
return nil, fmt.Errorf("firestore: unknown value type %T", v)
}
Expand Down
Loading

0 comments on commit 5c0d6df

Please sign in to comment.