From 8727154bea4648a3bc1fd9af7e0fe5f861a0bd34 Mon Sep 17 00:00:00 2001 From: Joshua Humphries <2035234+jhump@users.noreply.github.com> Date: Thu, 7 Mar 2024 11:09:34 -0500 Subject: [PATCH] Add function for stripping source retention options from a descriptor (#250) This adds a helper function to the `options` sub-package that can strip "source retention" options from a descriptor. These are options that should only be retained in the descriptor in source form -- like when manipulated by a compiler or code generator -- and should not be present at runtime. Stripping these options results in a descriptor that could safely be embedded in generated code. --- options/source_retention_options.go | 385 +++++++++++++++++++ options/source_retention_options_test.go | 463 +++++++++++++++++++++++ 2 files changed, 848 insertions(+) create mode 100644 options/source_retention_options.go create mode 100644 options/source_retention_options_test.go diff --git a/options/source_retention_options.go b/options/source_retention_options.go new file mode 100644 index 00000000..1d196483 --- /dev/null +++ b/options/source_retention_options.go @@ -0,0 +1,385 @@ +// Copyright 2020-2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package options + +import ( + "fmt" + + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/descriptorpb" +) + +// StripSourceRetentionOptionsFromFile returns a file descriptor proto that omits any +// options in file that are defined to be retained only in source. If file has no +// such options, then it is returned as is. If it does have such options, a copy is +// made; the given file will not be mutated. +// +// Even when a copy is returned, it is not a deep copy: it may share data with the +// original file. So callers should not mutate the returned file unless mutating the +// input file is also safe. +func StripSourceRetentionOptionsFromFile(file *descriptorpb.FileDescriptorProto) (*descriptorpb.FileDescriptorProto, error) { + var dirty bool + newOpts, err := stripSourceRetentionOptions(file.GetOptions()) + if err != nil { + return nil, err + } + if newOpts != file.GetOptions() { + dirty = true + } + newMsgs, changed, err := updateAll(file.GetMessageType(), stripSourceRetentionOptionsFromMessage) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + newEnums, changed, err := updateAll(file.GetEnumType(), stripSourceRetentionOptionsFromEnum) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + newExts, changed, err := updateAll(file.GetExtension(), stripSourceRetentionOptionsFromField) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + newSvcs, changed, err := updateAll(file.GetService(), stripSourceRetentionOptionsFromService) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + + if !dirty { + return file, nil + } + + newFile, err := shallowCopy(file) + if err != nil { + return nil, err + } + newFile.Options = newOpts + newFile.MessageType = newMsgs + newFile.EnumType = newEnums + newFile.Extension = newExts + newFile.Service = newSvcs + return newFile, nil +} + +func stripSourceRetentionOptions[M proto.Message](options M) (M, error) { + optionsRef := options.ProtoReflect() + // See if there are any options to strip. + var found bool + var err error + optionsRef.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool { + fieldOpts, ok := field.Options().(*descriptorpb.FieldOptions) + if !ok { + err = fmt.Errorf("field options is unexpected type: got %T, want %T", field.Options(), fieldOpts) + return false + } + if fieldOpts.GetRetention() == descriptorpb.FieldOptions_RETENTION_SOURCE { + found = true + return false + } + return true + }) + var zero M + if err != nil { + return zero, err + } + if !found { + return options, nil + } + + // There is at least one. So we need to make a copy that does not have those options. + newOptions := optionsRef.New() + ret, ok := newOptions.Interface().(M) + if !ok { + return zero, fmt.Errorf("creating new message of same type resulted in unexpected type; got %T, want %T", newOptions.Interface(), zero) + } + optionsRef.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool { + fieldOpts, ok := field.Options().(*descriptorpb.FieldOptions) + if !ok { + err = fmt.Errorf("field options is unexpected type: got %T, want %T", field.Options(), fieldOpts) + return false + } + if fieldOpts.GetRetention() != descriptorpb.FieldOptions_RETENTION_SOURCE { + newOptions.Set(field, val) + } + return true + }) + if err != nil { + return zero, err + } + return ret, nil +} + +func stripSourceRetentionOptionsFromMessage(msg *descriptorpb.DescriptorProto) (*descriptorpb.DescriptorProto, error) { + var dirty bool + newOpts, err := stripSourceRetentionOptions(msg.Options) + if err != nil { + return nil, err + } + if newOpts != msg.Options { + dirty = true + } + newFields, changed, err := updateAll(msg.Field, stripSourceRetentionOptionsFromField) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + newOneofs, changed, err := updateAll(msg.OneofDecl, stripSourceRetentionOptionsFromOneof) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + newExtRanges, changed, err := updateAll(msg.ExtensionRange, stripSourceRetentionOptionsFromExtensionRange) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + newMsgs, changed, err := updateAll(msg.NestedType, stripSourceRetentionOptionsFromMessage) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + newEnums, changed, err := updateAll(msg.EnumType, stripSourceRetentionOptionsFromEnum) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + newExts, changed, err := updateAll(msg.Extension, stripSourceRetentionOptionsFromField) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + + if !dirty { + return msg, nil + } + + newMsg, err := shallowCopy(msg) + if err != nil { + return nil, err + } + newMsg.Options = newOpts + newMsg.Field = newFields + newMsg.OneofDecl = newOneofs + newMsg.ExtensionRange = newExtRanges + newMsg.NestedType = newMsgs + newMsg.EnumType = newEnums + newMsg.Extension = newExts + return newMsg, nil +} + +func stripSourceRetentionOptionsFromField(field *descriptorpb.FieldDescriptorProto) (*descriptorpb.FieldDescriptorProto, error) { + newOpts, err := stripSourceRetentionOptions(field.Options) + if err != nil { + return nil, err + } + if newOpts == field.Options { + return field, nil + } + newField, err := shallowCopy(field) + if err != nil { + return nil, err + } + newField.Options = newOpts + return newField, nil +} + +func stripSourceRetentionOptionsFromOneof(oneof *descriptorpb.OneofDescriptorProto) (*descriptorpb.OneofDescriptorProto, error) { + newOpts, err := stripSourceRetentionOptions(oneof.Options) + if err != nil { + return nil, err + } + if newOpts == oneof.Options { + return oneof, nil + } + newOneof, err := shallowCopy(oneof) + if err != nil { + return nil, err + } + newOneof.Options = newOpts + return newOneof, nil +} + +func stripSourceRetentionOptionsFromExtensionRange(extRange *descriptorpb.DescriptorProto_ExtensionRange) (*descriptorpb.DescriptorProto_ExtensionRange, error) { + newOpts, err := stripSourceRetentionOptions(extRange.Options) + if err != nil { + return nil, err + } + if newOpts == extRange.Options { + return extRange, nil + } + newExtRange, err := shallowCopy(extRange) + if err != nil { + return nil, err + } + newExtRange.Options = newOpts + return newExtRange, nil +} + +func stripSourceRetentionOptionsFromEnum(enum *descriptorpb.EnumDescriptorProto) (*descriptorpb.EnumDescriptorProto, error) { + var dirty bool + newOpts, err := stripSourceRetentionOptions(enum.Options) + if err != nil { + return nil, err + } + if newOpts != enum.Options { + dirty = true + } + newVals, changed, err := updateAll(enum.Value, stripSourceRetentionOptionsFromEnumValue) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + + if !dirty { + return enum, nil + } + + newEnum, err := shallowCopy(enum) + if err != nil { + return nil, err + } + newEnum.Options = newOpts + newEnum.Value = newVals + return newEnum, nil +} + +func stripSourceRetentionOptionsFromEnumValue(enumVal *descriptorpb.EnumValueDescriptorProto) (*descriptorpb.EnumValueDescriptorProto, error) { + newOpts, err := stripSourceRetentionOptions(enumVal.Options) + if err != nil { + return nil, err + } + if newOpts == enumVal.Options { + return enumVal, nil + } + newEnumVal, err := shallowCopy(enumVal) + if err != nil { + return nil, err + } + newEnumVal.Options = newOpts + return newEnumVal, nil +} + +func stripSourceRetentionOptionsFromService(svc *descriptorpb.ServiceDescriptorProto) (*descriptorpb.ServiceDescriptorProto, error) { + var dirty bool + newOpts, err := stripSourceRetentionOptions(svc.Options) + if err != nil { + return nil, err + } + if newOpts != svc.Options { + dirty = true + } + newMethods, changed, err := updateAll(svc.Method, stripSourceRetentionOptionsFromMethod) + if err != nil { + return nil, err + } + if changed { + dirty = true + } + + if !dirty { + return svc, nil + } + + newSvc, err := shallowCopy(svc) + if err != nil { + return nil, err + } + newSvc.Options = newOpts + newSvc.Method = newMethods + return newSvc, nil +} + +func stripSourceRetentionOptionsFromMethod(method *descriptorpb.MethodDescriptorProto) (*descriptorpb.MethodDescriptorProto, error) { + newOpts, err := stripSourceRetentionOptions(method.Options) + if err != nil { + return nil, err + } + if newOpts == method.Options { + return method, nil + } + newMethod, err := shallowCopy(method) + if err != nil { + return nil, err + } + newMethod.Options = newOpts + return newMethod, nil +} + +func shallowCopy[M proto.Message](msg M) (M, error) { + msgRef := msg.ProtoReflect() + other := msgRef.New() + ret, ok := other.Interface().(M) + if !ok { + return ret, fmt.Errorf("creating new message of same type resulted in unexpected type; got %T, want %T", other.Interface(), ret) + } + msgRef.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool { + other.Set(field, val) + return true + }) + return ret, nil +} + +// updateAll applies the given function to each element in the given slice. It +// returns the new slice and a bool indicating whether anything was actually +// changed. If the second value is false, then the returned slice is the same +// slice as the input slice. Usually, T is a pointer type, in which case the +// given updateFunc should NOT mutate the input value. Instead, it should return +// the input value if only if there is no update needed. If a mutation is needed, +// it should return a new value. +func updateAll[T comparable](slice []T, updateFunc func(T) (T, error)) ([]T, bool, error) { + var updated []T // initialized lazily, only when/if a copy is needed + for i, item := range slice { + newItem, err := updateFunc(item) + if err != nil { + return nil, false, err + } + if updated != nil { + updated[i] = newItem + } else if newItem != item { + updated = make([]T, len(slice)) + copy(updated[:i], slice) + updated[i] = newItem + } + } + if updated != nil { + return updated, true, nil + } + return slice, false, nil +} diff --git a/options/source_retention_options_test.go b/options/source_retention_options_test.go new file mode 100644 index 00000000..38c7d0fe --- /dev/null +++ b/options/source_retention_options_test.go @@ -0,0 +1,463 @@ +// Copyright 2020-2024 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package options + +import ( + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/descriptorpb" + "google.golang.org/protobuf/types/dynamicpb" +) + +func TestStripSourceOnlyOptions(t *testing.T) { + t.Parallel() + optsFileProto := &descriptorpb.FileDescriptorProto{ + Name: proto.String("opts.proto"), + Package: proto.String("foo.bar"), + Dependency: []string{"google/protobuf/descriptor.proto"}, + Extension: []*descriptorpb.FieldDescriptorProto{ + { + Extendee: proto.String(".google.protobuf.FileOptions"), + Name: proto.String("no_retention"), + Number: proto.Int32(10000), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + // No option + }, + { + Extendee: proto.String(".google.protobuf.FileOptions"), + Name: proto.String("unknown_retention"), + Number: proto.Int32(10001), + Label: descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + Options: &descriptorpb.FieldOptions{ + Retention: descriptorpb.FieldOptions_RETENTION_UNKNOWN.Enum(), + }, + }, + { + Extendee: proto.String(".google.protobuf.FileOptions"), + Name: proto.String("runtime_retention"), + Number: proto.Int32(10002), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_BYTES.Enum(), + Options: &descriptorpb.FieldOptions{ + Retention: descriptorpb.FieldOptions_RETENTION_RUNTIME.Enum(), + }, + }, + { + Extendee: proto.String(".google.protobuf.FileOptions"), + Name: proto.String("source_retention"), + Number: proto.Int32(10003), + Label: descriptorpb.FieldDescriptorProto_LABEL_REPEATED.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_INT32.Enum(), + Options: &descriptorpb.FieldOptions{ + Retention: descriptorpb.FieldOptions_RETENTION_SOURCE.Enum(), + }, + }, + }, + } + optsFile, err := protodesc.NewFile(optsFileProto, protoregistry.GlobalFiles) + require.NoError(t, err) + extNoRetention := dynamicpb.NewExtensionType(optsFile.Extensions().ByName("no_retention")) + extUnknownRetention := dynamicpb.NewExtensionType(optsFile.Extensions().ByName("unknown_retention")) + extRuntimeRetention := dynamicpb.NewExtensionType(optsFile.Extensions().ByName("runtime_retention")) + extSourceRetention := dynamicpb.NewExtensionType(optsFile.Extensions().ByName("source_retention")) + + // Create a message with these options. + options := (&descriptorpb.FileOptions{}).ProtoReflect() + options.Set(extNoRetention.TypeDescriptor(), protoreflect.ValueOfString("abc")) + listVal := extUnknownRetention.New().List() + listVal.Append(protoreflect.ValueOfString("foo")) + listVal.Append(protoreflect.ValueOfString("bar")) + options.Set(extUnknownRetention.TypeDescriptor(), protoreflect.ValueOfList(listVal)) + options.Set(extRuntimeRetention.TypeDescriptor(), protoreflect.ValueOfBytes([]byte("xyz"))) + // The above will be retained, so create a copy now to serve as the expected result. + optionsAfterStrip := proto.Clone(options.Interface()) + // The below option will get stripped because it's retention policy is source. + listVal = extSourceRetention.New().List() + listVal.Append(protoreflect.ValueOfInt32(123)) + listVal.Append(protoreflect.ValueOfInt32(-456)) + options.Set(extSourceRetention.TypeDescriptor(), protoreflect.ValueOfList(listVal)) + + optionsMsg := options.Interface() + actualOptionsAfterStrip, err := stripSourceRetentionOptions(optionsMsg) + require.NoError(t, err) + + require.NotSame(t, actualOptionsAfterStrip, optionsMsg) + require.Empty(t, cmp.Diff(optionsAfterStrip, actualOptionsAfterStrip, protocmp.Transform())) + + // If we do it again, there are no changes to made (since source-only options were + // already stripped). So we should get back unmodified value. + optionsMsg = actualOptionsAfterStrip + actualOptionsAfterStrip, err = stripSourceRetentionOptions(optionsMsg) + require.NoError(t, err) + + require.Same(t, actualOptionsAfterStrip, optionsMsg) + require.Empty(t, cmp.Diff(optionsAfterStrip, actualOptionsAfterStrip, protocmp.Transform())) +} + +func TestStripSourceOnlyOptionsFromFile(t *testing.T) { + t.Parallel() + makeCustomOptionSet := func(startTag int32, extendee string, prefix string, label descriptorpb.FieldDescriptorProto_Label) []*descriptorpb.FieldDescriptorProto { + return []*descriptorpb.FieldDescriptorProto{ + { + Extendee: proto.String(extendee), + Name: proto.String(prefix + "no_retention"), + Number: proto.Int32(startTag), + Label: label.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + // No option + }, + { + Extendee: proto.String(extendee), + Name: proto.String(prefix + "unknown_retention"), + Number: proto.Int32(startTag + 1), + Label: label.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_BOOL.Enum(), + Options: &descriptorpb.FieldOptions{ + Retention: descriptorpb.FieldOptions_RETENTION_UNKNOWN.Enum(), + }, + }, + { + Extendee: proto.String(extendee), + Name: proto.String(prefix + "runtime_retention"), + Number: proto.Int32(startTag + 2), + Label: label.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_BYTES.Enum(), + Options: &descriptorpb.FieldOptions{ + Retention: descriptorpb.FieldOptions_RETENTION_RUNTIME.Enum(), + }, + }, + { + Extendee: proto.String(extendee), + Name: proto.String(prefix + "source_retention"), + Number: proto.Int32(startTag + 3), + Label: label.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_INT32.Enum(), + Options: &descriptorpb.FieldOptions{ + Retention: descriptorpb.FieldOptions_RETENTION_SOURCE.Enum(), + }, + }, + } + } + makeCustomOptions := func(extendee string, prefix string) []*descriptorpb.FieldDescriptorProto { + return append( + makeCustomOptionSet(10000, extendee, prefix, descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL), + makeCustomOptionSet(20000, extendee, prefix+"rep_", descriptorpb.FieldDescriptorProto_LABEL_REPEATED)..., + ) + } + combineAll := func(exts ...[]*descriptorpb.FieldDescriptorProto) []*descriptorpb.FieldDescriptorProto { + result := exts[0] + for _, exts := range exts[1:] { + result = append(result, exts...) + } + return result + } + optsFileProto := &descriptorpb.FileDescriptorProto{ + Name: proto.String("opts.proto"), + Package: proto.String("foo.bar"), + Dependency: []string{"google/protobuf/descriptor.proto"}, + Extension: combineAll( + makeCustomOptions(".google.protobuf.FileOptions", "file_"), + makeCustomOptions(".google.protobuf.MessageOptions", "msg_"), + makeCustomOptions(".google.protobuf.FieldOptions", "field_"), + makeCustomOptions(".google.protobuf.OneofOptions", "oneof_"), + makeCustomOptions(".google.protobuf.ExtensionRangeOptions", "extrange_"), + makeCustomOptions(".google.protobuf.EnumOptions", "enum_"), + makeCustomOptions(".google.protobuf.EnumValueOptions", "enumval_"), + makeCustomOptions(".google.protobuf.ServiceOptions", "svc_"), + makeCustomOptions(".google.protobuf.MethodOptions", "method_"), + ), + } + optsFile, err := protodesc.NewFile(optsFileProto, protoregistry.GlobalFiles) + require.NoError(t, err) + + applyCustomOptionSet := func(all, retained protoreflect.Message, prefix protoreflect.Name, isList bool, file protoreflect.FileDescriptor) { + extType := dynamicpb.NewExtensionType(file.Extensions().ByName(prefix + "no_retention")) + var val protoreflect.Value + if isList { + listVal := extType.New().List() + listVal.Append(protoreflect.ValueOfString("foo")) + listVal.Append(protoreflect.ValueOfString("bar")) + val = protoreflect.ValueOfList(listVal) + } else { + val = protoreflect.ValueOfString("abc") + } + all.Set(extType.TypeDescriptor(), val) + retained.Set(extType.TypeDescriptor(), val) + + extType = dynamicpb.NewExtensionType(file.Extensions().ByName(prefix + "unknown_retention")) + if isList { + listVal := extType.New().List() + listVal.Append(protoreflect.ValueOfBool(false)) + listVal.Append(protoreflect.ValueOfBool(true)) + val = protoreflect.ValueOfList(listVal) + } else { + val = protoreflect.ValueOfBool(true) + } + all.Set(extType.TypeDescriptor(), val) + retained.Set(extType.TypeDescriptor(), val) + + extType = dynamicpb.NewExtensionType(file.Extensions().ByName(prefix + "runtime_retention")) + if isList { + listVal := extType.New().List() + listVal.Append(protoreflect.ValueOfBytes([]byte{0, 1, 2, 3})) + listVal.Append(protoreflect.ValueOfBytes([]byte{4, 5, 6, 7})) + val = protoreflect.ValueOfList(listVal) + } else { + val = protoreflect.ValueOfBytes([]byte{0, 1, 2, 3}) + } + all.Set(extType.TypeDescriptor(), val) + retained.Set(extType.TypeDescriptor(), val) + + extType = dynamicpb.NewExtensionType(file.Extensions().ByName(prefix + "source_retention")) + if isList { + listVal := extType.New().List() + listVal.Append(protoreflect.ValueOfInt32(123)) + listVal.Append(protoreflect.ValueOfInt32(-456)) + val = protoreflect.ValueOfList(listVal) + } else { + val = protoreflect.ValueOfInt32(123) + } + all.Set(extType.TypeDescriptor(), val) + // don't set retained because this is a source-only option (won't be retained) + } + applyCustomOptions := func(message proto.Message, prefix protoreflect.Name, file protoreflect.FileDescriptor) (all, retained proto.Message) { + allRef := message.ProtoReflect() + strippedRef := proto.Clone(message).ProtoReflect() + applyCustomOptionSet(allRef, strippedRef, prefix, false, file) + applyCustomOptionSet(allRef, strippedRef, prefix+"rep_", true, file) + return allRef.Interface(), strippedRef.Interface() + } + + fileOpts, fileOptsStripped := applyCustomOptions(&descriptorpb.FileOptions{}, "file_", optsFile) + msgOpts, msgOptsStripped := applyCustomOptions(&descriptorpb.MessageOptions{}, "msg_", optsFile) + fieldOpts, fieldOptsStripped := applyCustomOptions(&descriptorpb.FieldOptions{}, "field_", optsFile) + oneofOpts, oneofOptsStripped := applyCustomOptions(&descriptorpb.OneofOptions{}, "oneof_", optsFile) + extRangeOpts, extRangeOptsStripped := applyCustomOptions(&descriptorpb.ExtensionRangeOptions{}, "extrange_", optsFile) + enumOpts, enumOptsStripped := applyCustomOptions(&descriptorpb.EnumOptions{}, "enum_", optsFile) + enumValOpts, enumValOptsStripped := applyCustomOptions(&descriptorpb.EnumValueOptions{}, "enumval_", optsFile) + svcOpts, svcOptsStripped := applyCustomOptions(&descriptorpb.ServiceOptions{}, "svc_", optsFile) + methodOpts, methodOptsStripped := applyCustomOptions(&descriptorpb.MethodOptions{}, "method_", optsFile) + + beforeFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("foo.proto"), + Options: fileOpts.(*descriptorpb.FileOptions), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Message"), + Options: msgOpts.(*descriptorpb.MessageOptions), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("field"), + Number: proto.Int32(1), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + JsonName: proto.String("field"), + Options: fieldOpts.(*descriptorpb.FieldOptions), + OneofIndex: proto.Int32(0), + }, + }, + OneofDecl: []*descriptorpb.OneofDescriptorProto{ + { + Name: proto.String("oo"), + Options: oneofOpts.(*descriptorpb.OneofOptions), + }, + }, + ExtensionRange: []*descriptorpb.DescriptorProto_ExtensionRange{ + { + Start: proto.Int32(100), + End: proto.Int32(200), + Options: extRangeOpts.(*descriptorpb.ExtensionRangeOptions), + }, + }, + }, + }, + EnumType: []*descriptorpb.EnumDescriptorProto{ + { + Name: proto.String("Enum"), + Options: enumOpts.(*descriptorpb.EnumOptions), + Value: []*descriptorpb.EnumValueDescriptorProto{ + { + Name: proto.String("ZERO"), + Number: proto.Int32(0), + Options: enumValOpts.(*descriptorpb.EnumValueOptions), + }, + }, + }, + }, + Service: []*descriptorpb.ServiceDescriptorProto{ + { + Name: proto.String("Service"), + Options: svcOpts.(*descriptorpb.ServiceOptions), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: proto.String("Do"), + InputType: proto.String(".Message"), + OutputType: proto.String(".Message"), + Options: methodOpts.(*descriptorpb.MethodOptions), + }, + }, + }, + }, + } + + // This one is the same as above, but uses the stripped option messages + afterFile := &descriptorpb.FileDescriptorProto{ + Name: proto.String("foo.proto"), + Options: fileOptsStripped.(*descriptorpb.FileOptions), + MessageType: []*descriptorpb.DescriptorProto{ + { + Name: proto.String("Message"), + Options: msgOptsStripped.(*descriptorpb.MessageOptions), + Field: []*descriptorpb.FieldDescriptorProto{ + { + Name: proto.String("field"), + Number: proto.Int32(1), + Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(), + Type: descriptorpb.FieldDescriptorProto_TYPE_STRING.Enum(), + JsonName: proto.String("field"), + Options: fieldOptsStripped.(*descriptorpb.FieldOptions), + OneofIndex: proto.Int32(0), + }, + }, + OneofDecl: []*descriptorpb.OneofDescriptorProto{ + { + Name: proto.String("oo"), + Options: oneofOptsStripped.(*descriptorpb.OneofOptions), + }, + }, + ExtensionRange: []*descriptorpb.DescriptorProto_ExtensionRange{ + { + Start: proto.Int32(100), + End: proto.Int32(200), + Options: extRangeOptsStripped.(*descriptorpb.ExtensionRangeOptions), + }, + }, + }, + }, + EnumType: []*descriptorpb.EnumDescriptorProto{ + { + Name: proto.String("Enum"), + Options: enumOptsStripped.(*descriptorpb.EnumOptions), + Value: []*descriptorpb.EnumValueDescriptorProto{ + { + Name: proto.String("ZERO"), + Number: proto.Int32(0), + Options: enumValOptsStripped.(*descriptorpb.EnumValueOptions), + }, + }, + }, + }, + Service: []*descriptorpb.ServiceDescriptorProto{ + { + Name: proto.String("Service"), + Options: svcOptsStripped.(*descriptorpb.ServiceOptions), + Method: []*descriptorpb.MethodDescriptorProto{ + { + Name: proto.String("Do"), + InputType: proto.String(".Message"), + OutputType: proto.String(".Message"), + Options: methodOptsStripped.(*descriptorpb.MethodOptions), + }, + }, + }, + }, + } + + actualStrippedFile, err := StripSourceRetentionOptionsFromFile(beforeFile) + require.NoError(t, err) + require.NotSame(t, actualStrippedFile, beforeFile) + require.Empty(t, cmp.Diff(afterFile, actualStrippedFile, protocmp.Transform())) + + // If we repeat the operation, we get back the same descriptor unchanged because + // it doesn't have any source-only options. + doubleStrippedFile, err := StripSourceRetentionOptionsFromFile(actualStrippedFile) + require.NoError(t, err) + require.Same(t, doubleStrippedFile, actualStrippedFile) + require.Empty(t, cmp.Diff(afterFile, doubleStrippedFile, protocmp.Transform())) +} + +func TestUpdateAll(t *testing.T) { + t.Parallel() + + errInvalid := errors.New("invalid value") + updateFunc := func(i *int32) (*int32, error) { + if i == nil { + return proto.Int32(-1), nil + } + if *i <= -100 { + return nil, errInvalid + } + if *i > 5 { + return proto.Int32(*i * 2), nil + } + return i, nil + } + + vals := []*int32{ + proto.Int32(0), proto.Int32(1), proto.Int32(2), + proto.Int32(3), proto.Int32(4), proto.Int32(5), + proto.Int32(6), proto.Int32(7), proto.Int32(8), + } + newVals, changed, err := updateAll(vals, updateFunc) + require.NoError(t, err) + require.True(t, changed) + expected := []*int32{ + proto.Int32(0), proto.Int32(1), proto.Int32(2), + proto.Int32(3), proto.Int32(4), proto.Int32(5), + proto.Int32(12), proto.Int32(14), proto.Int32(16), + } + require.Equal(t, expected, newVals) + + vals = []*int32{ + nil, proto.Int32(1), proto.Int32(2), + proto.Int32(3), proto.Int32(4), proto.Int32(5), + } + newVals, changed, err = updateAll(vals, updateFunc) + require.NoError(t, err) + require.True(t, changed) + expected = []*int32{ + proto.Int32(-1), proto.Int32(1), proto.Int32(2), + proto.Int32(3), proto.Int32(4), proto.Int32(5), + } + require.Equal(t, expected, newVals) + + // No changes + vals = []*int32{ + proto.Int32(0), proto.Int32(1), proto.Int32(2), + proto.Int32(3), proto.Int32(4), proto.Int32(5), + } + newVals, changed, err = updateAll(vals, updateFunc) + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, vals, newVals) + + // Propagate error + vals = []*int32{ + proto.Int32(0), proto.Int32(1), proto.Int32(2), + proto.Int32(3), proto.Int32(-101), proto.Int32(5), + } + _, _, err = updateAll(vals, updateFunc) + require.ErrorIs(t, err, errInvalid) +}