Skip to content

Commit

Permalink
feat(protojson): message json unmarshal/marshal functions
Browse files Browse the repository at this point in the history
Call UnmarshalJSON or MarshalJSON if they are present on proto messages.

Add an option to disable this behavior.

Signed-off-by: Christian Stewart <[email protected]>
  • Loading branch information
paralin committed Mar 22, 2024
1 parent d24fc1b commit b46c935
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 0 deletions.
51 changes: 51 additions & 0 deletions encoding/protojson/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ type UnmarshalOptions struct {
// If DiscardUnknown is set, unknown fields and enum name values are ignored.
DiscardUnknown bool

// IgnoreUnmarshalJSON ignores UnmarshalJSON functions on messages.
// If unset: will call UnmarshalJSON on message types, if it exists.
IgnoreUnmarshalJSON bool

// Resolver is used for looking up types when unmarshaling
// google.protobuf.Any messages or extension fields.
// If nil, this defaults to using protoregistry.GlobalTypes.
Expand Down Expand Up @@ -120,6 +124,41 @@ func (d decoder) syntaxError(pos int, f string, x ...interface{}) error {
return errors.New(head+f, x...)
}

// unmarshalObjectToJSON unmarshals an entire object to raw json.
func (d decoder) unmarshalObjectToJSON() ([]byte, error) {
// advance to the end of this block of json
var objOpenDepth int
var startPos int
for {
tok, err := d.Read()
if err != nil {
return nil, err
}
tokKind := tok.Kind()
if objOpenDepth == 0 {
startPos = tok.Pos()
if tokKind == json.String {
// parse as string, single token.
break
}
if tokKind != json.ObjectOpen {
return nil, d.unexpectedTokenError(tok)
}
}
if tokKind == json.ObjectOpen {
objOpenDepth++
} else if tokKind == json.ObjectClose {
objOpenDepth--
}
if objOpenDepth == 0 {
break
}
}

endPos := d.CurrentPos()
return d.Original()[startPos:endPos], nil
}

// unmarshalMessage unmarshals a message into the given protoreflect.Message.
func (d decoder) unmarshalMessage(m protoreflect.Message, skipTypeURL bool) error {
d.opts.RecursionLimit--
Expand All @@ -130,6 +169,18 @@ func (d decoder) unmarshalMessage(m protoreflect.Message, skipTypeURL bool) erro
return unmarshal(d, m)
}

if !d.opts.IgnoreUnmarshalJSON {
if jdec, jdecOk := m.Interface().(protoreflect.ProtoMessageWithJSONUnmarshaler); jdecOk {
// advance the decoder to extract the raw json for this entire message
messageData, err := d.unmarshalObjectToJSON()
if err != nil {
return err
}

return jdec.UnmarshalJSON(messageData)
}
}

tok, err := d.Read()
if err != nil {
return err
Expand Down
24 changes: 24 additions & 0 deletions encoding/protojson/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ type MarshalOptions struct {
// Marshal will return error if there are any missing required fields.
AllowPartial bool

// IgnoreMarshalJSON ignores MarshalJSON functions on messages.
// If unset: will call MarshalJSON on message types, if it exists.
IgnoreMarshalJSON bool

// UseProtoNames uses proto field name instead of lowerCamelCase name in JSON
// field names.
UseProtoNames bool
Expand Down Expand Up @@ -239,6 +243,26 @@ func (e encoder) marshalMessage(m protoreflect.Message, typeURL string) error {
return marshal(e, m)
}

if !e.opts.IgnoreMarshalJSON {
// Marshaler is the interface implemented by types that
// can marshal themselves into valid JSON.
//
// Copied from encoding/json.Marshaler
type JSONMarshaler interface {
MarshalJSON() ([]byte, error)
}
if marshaler, ok := m.Interface().(JSONMarshaler); ok {
data, err := marshaler.MarshalJSON()
if err != nil {
return err
}
if len(data) != 0 {
e.Encoder.WriteRawObject(string(data))
}
return nil
}
}

e.StartObject()
defer e.EndObject()

Expand Down
10 changes: 10 additions & 0 deletions internal/encoding/json/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,16 @@ func (d *Decoder) currPos() int {
return len(d.orig) - len(d.in)
}

// CurrentPos returns the current position.
func (d *Decoder) CurrentPos() int {
return d.currPos()
}

// Original returns the original buffer.
func (d *Decoder) Original() []byte {
return d.orig
}

// matchWithDelim matches s with the input b and verifies that the match
// terminates with a delimiter of some form (e.g., r"[^-+_.a-zA-Z0-9]").
// As a special case, EOF is considered a delimiter. It returns the length of s
Expand Down
8 changes: 8 additions & 0 deletions internal/encoding/json/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ func (e *Encoder) Bytes() []byte {
return e.out
}

// WriteRawObject writes a raw object {...} to the marshaler.
// Assumes that data starts with { and ends with }.
func (e *Encoder) WriteRawObject(data string) {
e.prepareNext(objectOpen)
e.out = append(e.out, data...)
e.lastKind = objectClose
}

// WriteNull writes out the null value.
func (e *Encoder) WriteNull() {
e.prepareNext(scalar)
Expand Down
6 changes: 6 additions & 0 deletions reflect/protoreflect/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,12 @@ type doNotImplement pragma.DoNotImplement
// use the [google.golang.org/protobuf/proto.Message] type instead, which aliases this type.
type ProtoMessage interface{ ProtoReflect() Message }

// ProtoMessageWithJSONUnmarshaler is a ProtoMessage with custom JSON unmarshal logic.
type ProtoMessageWithJSONUnmarshaler interface {
ProtoMessage
UnmarshalJSON([]byte) error
}

// Syntax is the language version of the proto file.
type Syntax syntax

Expand Down

0 comments on commit b46c935

Please sign in to comment.