diff --git a/fmutils.go b/fmutils.go index d73cdcd..2f5d101 100644 --- a/fmutils.go +++ b/fmutils.go @@ -170,30 +170,64 @@ func (mask NestedMask) Overwrite(src, dest proto.Message) { mask.overwrite(src.ProtoReflect(), dest.ProtoReflect()) } -func (mask NestedMask) overwrite(src, dest protoreflect.Message) { - for k, v := range mask { - srcFD := src.Descriptor().Fields().ByName(protoreflect.Name(k)) - destFD := dest.Descriptor().Fields().ByName(protoreflect.Name(k)) - if srcFD == nil || destFD == nil { - continue - } +func (mask NestedMask) overwrite(srcRft, destRft protoreflect.Message) { + for srcFDName, submask := range mask { + srcFD := srcRft.Descriptor().Fields().ByName(protoreflect.Name(srcFDName)) + srcVal := srcRft.Get(srcFD) + if len(submask) == 0 { + if isValid(srcFD, srcVal) { + destRft.Set(srcFD, srcVal) + } else { + destRft.Clear(srcFD) + } + } else if srcFD.IsMap() && srcFD.Kind() == protoreflect.MessageKind { + srcMap := srcRft.Get(srcFD).Map() + destMap := destRft.Get(srcFD).Map() + if !destMap.IsValid() { + destRft.Set(srcFD, protoreflect.ValueOf(srcMap)) + destMap = destRft.Get(srcFD).Map() + } + srcMap.Range(func(mk protoreflect.MapKey, mv protoreflect.Value) bool { + if mi, ok := submask[mk.String()]; ok { + if i, ok := mv.Interface().(protoreflect.Message); ok && len(mi) > 0 { + newVal := protoreflect.ValueOf(i.New()) + destMap.Set(mk, newVal) + mi.overwrite(mv.Message(), newVal.Message()) + } else { - // Leaf mask -> copy value from src to dest - if len(v) == 0 { - if srcFD.Kind() == destFD.Kind() { // TODO: Full type equality check - val := src.Get(srcFD) - if isValid(srcFD, val) { - dest.Set(destFD, val) + destMap.Set(mk, mv) + } } else { - dest.Clear(destFD) + destMap.Clear(mk) } + return true + }) + } else if srcFD.IsList() && srcFD.Kind() == protoreflect.MessageKind { + srcList := srcRft.Get(srcFD).List() + destList := destRft.Mutable(srcFD).List() + // Truncate anything in dest that exceeds the length of src + if srcList.Len() < destList.Len() { + destList.Truncate(srcList.Len()) } + for i := 0; i < srcList.Len(); i++ { + srcListItem := srcList.Get(i) + var destListItem protoreflect.Message + if destList.Len() > i { + // Overwrite existing items. + destListItem = destList.Get(i).Message() + } else { + // Append new items to overwrite. + destListItem = destList.AppendMutable().Message() + } + submask.overwrite(srcListItem.Message(), destListItem) + } + } else if srcFD.Kind() == protoreflect.MessageKind { - // If dest field is nil - if !dest.Get(destFD).Message().IsValid() { - dest.Set(destFD, protoreflect.ValueOf(dest.Get(destFD).Message().New())) + // If the dest field is nil + if !destRft.Get(srcFD).Message().IsValid() { + destRft.Set(srcFD, protoreflect.ValueOf(destRft.Get(srcFD).Message().New())) } - v.overwrite(src.Get(srcFD).Message(), dest.Get(destFD).Message()) + submask.overwrite(srcRft.Get(srcFD).Message(), destRft.Get(srcFD).Message()) } } } diff --git a/fmutils_test.go b/fmutils_test.go index d3b4add..3953d5f 100644 --- a/fmutils_test.go +++ b/fmutils_test.go @@ -883,6 +883,217 @@ func TestOverwrite(t *testing.T) { }, }, }, + { + name: "overwrite map with message values", + paths: []string{"attributes.src1.tags.key1", "attributes.src2"}, + src: &testproto.Profile{ + User: nil, + Attributes: map[string]*testproto.Attribute{ + "src1": { + Tags: map[string]string{"key1": "value1", "key2": "value2"}, + }, + "src2": { + Tags: map[string]string{"key3": "value3"}, + }, + }, + }, + dest: &testproto.Profile{ + User: &testproto.User{ + Name: "name", + }, + Attributes: map[string]*testproto.Attribute{ + "dest1": { + Tags: map[string]string{"key4": "value4"}, + }, + }, + }, + want: &testproto.Profile{ + User: &testproto.User{ + Name: "name", + }, + Attributes: map[string]*testproto.Attribute{ + "src1": { + Tags: map[string]string{"key1": "value1"}, + }, + "src2": { + Tags: map[string]string{"key3": "value3"}, + }, + "dest1": { + Tags: map[string]string{"key4": "value4"}, + }, + }, + }, + }, + { + name: "overwrite repeated message fields", + paths: []string{"gallery.path"}, + src: &testproto.Profile{ + User: &testproto.User{ + UserId: 567, + Name: "different-name", + }, + Photo: &testproto.Photo{ + Path: "photo-path", + }, + LoginTimestamps: []int64{1, 2, 3}, + Attributes: map[string]*testproto.Attribute{ + "src": {}, + }, + Gallery: []*testproto.Photo{ + { + PhotoId: 123, + Path: "test-path-1", + Dimensions: &testproto.Dimensions{ + Width: 345, + Height: 456, + }, + }, + { + PhotoId: 234, + Path: "test-path-2", + Dimensions: &testproto.Dimensions{ + Width: 3456, + Height: 4567, + }, + }, + { + PhotoId: 345, + Path: "test-path-3", + Dimensions: &testproto.Dimensions{ + Width: 34567, + Height: 45678, + }, + }, + }, + }, + dest: &testproto.Profile{ + User: &testproto.User{ + Name: "name", + }, + Gallery: []*testproto.Photo{ + { + PhotoId: 123, + Path: "test-path-7", + Dimensions: &testproto.Dimensions{ + Width: 345, + Height: 456, + }, + }, + { + PhotoId: 234, + Path: "test-path-6", + Dimensions: &testproto.Dimensions{ + Width: 3456, + Height: 4567, + }, + }, + { + PhotoId: 345, + Path: "test-path-5", + Dimensions: &testproto.Dimensions{ + Width: 34567, + Height: 45678, + }, + }, + { + PhotoId: 345, + Path: "test-path-4", + Dimensions: &testproto.Dimensions{ + Width: 34567, + Height: 45678, + }, + }, + }, + }, + want: &testproto.Profile{ + User: &testproto.User{ + Name: "name", + }, + Gallery: []*testproto.Photo{ + { + PhotoId: 123, + Path: "test-path-1", + Dimensions: &testproto.Dimensions{ + Width: 345, + Height: 456, + }, + }, + { + PhotoId: 234, + Path: "test-path-2", + Dimensions: &testproto.Dimensions{ + Width: 3456, + Height: 4567, + }, + }, + { + PhotoId: 345, + Path: "test-path-3", + Dimensions: &testproto.Dimensions{ + Width: 34567, + Height: 45678, + }, + }, + }, + }, + }, + { + name: "overwrite repeated message fields to empty list", + paths: []string{"gallery.path"}, + src: &testproto.Profile{ + User: &testproto.User{ + UserId: 567, + Name: "different-name", + }, + Photo: &testproto.Photo{ + Path: "photo-path", + }, + LoginTimestamps: []int64{1, 2, 3}, + Attributes: map[string]*testproto.Attribute{ + "src": {}, + }, + Gallery: []*testproto.Photo{ + { + PhotoId: 123, + Path: "test-path-1", + Dimensions: &testproto.Dimensions{ + Width: 345, + Height: 456, + }, + }, + { + PhotoId: 234, + Path: "test-path-2", + Dimensions: &testproto.Dimensions{ + Width: 3456, + Height: 4567, + }, + }, + { + PhotoId: 345, + Path: "test-path-3", + Dimensions: &testproto.Dimensions{ + Width: 34567, + Height: 45678, + }, + }, + }, + }, + dest: &testproto.Profile{}, + want: &testproto.Profile{ + Gallery: []*testproto.Photo{ + { + Path: "test-path-1", + }, + { + Path: "test-path-2", + }, + { + Path: "test-path-3", + }, + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {