Skip to content

Commit

Permalink
Reparse all extensions, not just unrecognized ones (#3344)
Browse files Browse the repository at this point in the history
  • Loading branch information
jhump authored Sep 26, 2024
1 parent ee73df5 commit dacb299
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 15 deletions.
2 changes: 1 addition & 1 deletion private/bufpkg/bufimage/bufimage.go
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ type newImageForProtoOptions struct {
}

func reparseImageProto(protoImage *imagev1.Image, resolver protoencoding.Resolver, computeUnusedImports bool) error {
if err := protoencoding.ReparseUnrecognized(resolver, protoImage.ProtoReflect()); err != nil {
if err := protoencoding.ReparseExtensions(resolver, protoImage.ProtoReflect()); err != nil {
return fmt.Errorf("could not reparse image: %v", err)
}
if computeUnusedImports {
Expand Down
2 changes: 1 addition & 1 deletion private/pkg/protoencoding/json_marshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func newJSONMarshaler(resolver Resolver, options ...JSONMarshalerOption) Marshal
}

func (m *jsonMarshaler) Marshal(message proto.Message) ([]byte, error) {
if err := ReparseUnrecognized(m.resolver, message.ProtoReflect()); err != nil {
if err := ReparseExtensions(m.resolver, message.ProtoReflect()); err != nil {
return nil, err
}
options := protojson.MarshalOptions{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,62 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
)

// ReparseUnrecognized uses the given resolver to parse any unrecognized fields in the
// given reflectMessage. It does so recursively, resolving any unrecognized fields in
// nested messages.
func ReparseUnrecognized(resolver Resolver, reflectMessage protoreflect.Message) error {
// ReparseExtensions uses the given resolver to parse any unrecognized fields in the
// given reflectMessage as well as re-parse any extensions.
func ReparseExtensions(resolver Resolver, reflectMessage protoreflect.Message) error {
if resolver == nil {
return nil
}
unknown := reflectMessage.GetUnknown()
if len(unknown) > 0 {
reparseBytes := reflectMessage.GetUnknown()

if reflectMessage.Descriptor().ExtensionRanges().Len() > 0 {
// Collect extensions into separate message so we can serialize
// *just* the extensions and then re-parse them below.
var msgExts protoreflect.Message
reflectMessage.Range(func(field protoreflect.FieldDescriptor, value protoreflect.Value) bool {
if !field.IsExtension() {
return true
}
if msgExts == nil {
msgExts = reflectMessage.Type().New()
}
msgExts.Set(field, value)
reflectMessage.Clear(field)
return true
})
if msgExts != nil {
options := proto.MarshalOptions{AllowPartial: true}
var err error
reparseBytes, err = options.MarshalAppend(reparseBytes, msgExts.Interface())
if err != nil {
return err
}
}
}

if len(reparseBytes) > 0 {
reflectMessage.SetUnknown(nil)
options := proto.UnmarshalOptions{
Resolver: resolver,
Merge: true,
}
if err := options.Unmarshal(unknown, reflectMessage.Interface()); err != nil {
if err := options.Unmarshal(reparseBytes, reflectMessage.Interface()); err != nil {
return err
}
}
var err error
reflectMessage.Range(func(fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) bool {
err = reparseUnrecognizedInField(resolver, fieldDescriptor, value)
err = reparseInField(resolver, fieldDescriptor, value)
return err == nil
})
return err
}

func reparseUnrecognizedInField(resolver Resolver, fieldDescriptor protoreflect.FieldDescriptor, value protoreflect.Value) error {
func reparseInField(
resolver Resolver,
fieldDescriptor protoreflect.FieldDescriptor,
value protoreflect.Value,
) error {
if fieldDescriptor.IsMap() {
valDesc := fieldDescriptor.MapValue()
if valDesc.Kind() != protoreflect.MessageKind && valDesc.Kind() != protoreflect.GroupKind {
Expand All @@ -54,7 +83,7 @@ func reparseUnrecognizedInField(resolver Resolver, fieldDescriptor protoreflect.
}
var err error
value.Map().Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
err = ReparseUnrecognized(resolver, v.Message())
err = ReparseExtensions(resolver, v.Message())
return err == nil
})
return err
Expand All @@ -66,11 +95,11 @@ func reparseUnrecognizedInField(resolver Resolver, fieldDescriptor protoreflect.
if fieldDescriptor.IsList() {
list := value.List()
for i := 0; i < list.Len(); i++ {
if err := ReparseUnrecognized(resolver, list.Get(i).Message()); err != nil {
if err := ReparseExtensions(resolver, list.Get(i).Message()); err != nil {
return err
}
}
return nil
}
return ReparseUnrecognized(resolver, value.Message())
return ReparseExtensions(resolver, value.Message())
}
128 changes: 128 additions & 0 deletions private/pkg/protoencoding/reparse_extensions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright 2020-2024 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package protoencoding

import (
"math"
"testing"

"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/descriptorpb"
"google.golang.org/protobuf/types/dynamicpb"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
)

func TestReparseExtensions(t *testing.T) {
t.Parallel()

descriptorFile := protodesc.ToFileDescriptorProto(descriptorpb.File_google_protobuf_descriptor_proto)
durationFile := protodesc.ToFileDescriptorProto(durationpb.File_google_protobuf_duration_proto)
timestampFile := protodesc.ToFileDescriptorProto(timestamppb.File_google_protobuf_timestamp_proto)
validateFile := protodesc.ToFileDescriptorProto(validate.File_buf_validate_validate_proto)

// The file will include one custom option with a known/generated type.
fieldOpts := &descriptorpb.FieldOptions{}
fieldConstraints := &validate.FieldConstraints{
Required: proto.Bool(true),
Type: &validate.FieldConstraints_Int32{
Int32: &validate.Int32Rules{
GreaterThan: &validate.Int32Rules_Gt{
Gt: 0,
},
},
},
}
proto.SetExtension(fieldOpts, validate.E_Field, fieldConstraints)
// The file will also contain an unrecognized custom option.
const customOptionNum = 54321
const customOptionVal = float32(3.14159)
var unknownOption []byte
unknownOption = protowire.AppendTag(unknownOption, customOptionNum, protowire.Fixed32Type)
unknownOption = protowire.AppendFixed32(unknownOption, math.Float32bits(customOptionVal))
fieldOpts.ProtoReflect().SetUnknown(unknownOption)

testFile := &descriptorpb.FileDescriptorProto{
Name: proto.String("test.proto"),
Syntax: proto.String("proto3"),
Package: proto.String("blah.blah"),
Dependency: []string{"buf/validate/validate.proto", "google/protobuf/descriptor.proto"},
MessageType: []*descriptorpb.DescriptorProto{
{
Name: proto.String("Foo"),
Field: []*descriptorpb.FieldDescriptorProto{
{
Name: proto.String("bar"),
Number: proto.Int32(1),
Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
Type: descriptorpb.FieldDescriptorProto_TYPE_INT32.Enum(),
JsonName: proto.String("bar"),
Options: fieldOpts,
},
},
},
},
Extension: []*descriptorpb.FieldDescriptorProto{
{
Extendee: proto.String(".google.protobuf.FieldOptions"),
Name: proto.String("baz"),
Number: proto.Int32(customOptionNum),
Label: descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL.Enum(),
Type: descriptorpb.FieldDescriptorProto_TYPE_FLOAT.Enum(),
},
},
}

resolver, err := NewResolver(descriptorFile, durationFile, timestampFile, validateFile, testFile)
require.NoError(t, err)
err = ReparseExtensions(resolver, testFile.ProtoReflect())
require.NoError(t, err)

require.Empty(t, fieldOpts.ProtoReflect().GetUnknown())
var found int
fieldOpts.ProtoReflect().Range(func(field protoreflect.FieldDescriptor, value protoreflect.Value) bool {
switch field.Number() {
case customOptionNum:
found++
assert.Equal(t, customOptionVal, value.Interface())
case protoreflect.FieldNumber(validate.E_Field.Field):

Check failure on line 108 in private/pkg/protoencoding/reparse_extensions_test.go

View workflow job for this annotation

GitHub Actions / codeql

SA1019: validate.E_Field.Field is deprecated: Use the Descriptor().Number method instead. (staticcheck)
found++
msg := value.Message().Interface()
assert.NotSame(t, fieldConstraints, msg)
_, isGenType := msg.(*validate.FieldConstraints)
assert.False(t, isGenType)
_, isDynamicType := msg.(*dynamicpb.Message)
assert.True(t, isDynamicType)

// round-trip back to gen type to check for equality with original
data, err := proto.Marshal(msg)
require.NoError(t, err)
roundTrippedConstraints := &validate.FieldConstraints{}
err = proto.Unmarshal(data, roundTrippedConstraints)
require.NoError(t, err)
require.Empty(t, cmp.Diff(fieldConstraints, roundTrippedConstraints, protocmp.Transform()))
}
return true
})
assert.Equal(t, 2, found)
}
2 changes: 1 addition & 1 deletion private/pkg/protoencoding/yaml_marshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func newYAMLMarshaler(resolver Resolver, options ...YAMLMarshalerOption) Marshal
}

func (m *yamlMarshaler) Marshal(message proto.Message) ([]byte, error) {
if err := ReparseUnrecognized(m.resolver, message.ProtoReflect()); err != nil {
if err := ReparseExtensions(m.resolver, message.ProtoReflect()); err != nil {
return nil, err
}
options := protoyaml.MarshalOptions{
Expand Down

0 comments on commit dacb299

Please sign in to comment.