Skip to content

Commit

Permalink
Add kUpb_DecodeOption_AlwaysValidateUtf8 decode option, to force UTF-…
Browse files Browse the repository at this point in the history
…8 validation of proto2 strings.

Also slightly optimized _upb_Decoder_GetDelimitedOp, which should result in a modest speedup parsing protos with many bytes fields (or proto2 string fields), and submessages.

PiperOrigin-RevId: 584944649
  • Loading branch information
protobuf-github-bot authored and copybara-github committed Dec 30, 2023
1 parent 499c748 commit 5e049b7
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 23 deletions.
21 changes: 20 additions & 1 deletion upb/message/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,12 @@ proto_library(
name = "utf8_test_proto",
testonly = 1,
srcs = ["utf8_test.proto"],
deps = ["//src/google/protobuf:test_messages_proto3_proto"],
)

proto_library(
name = "utf8_test_proto2_proto",
testonly = 1,
srcs = ["utf8_test_proto2.proto"],
)

upb_minitable_proto_library(
Expand All @@ -336,16 +341,30 @@ upb_minitable_proto_library(
deps = [":utf8_test_proto"],
)

upb_minitable_proto_library(
name = "utf8_test_proto2_upb_minitable_proto",
testonly = 1,
deps = [":utf8_test_proto2_proto"],
)

upb_c_proto_library(
name = "utf8_test_upb_proto",
testonly = 1,
deps = [":utf8_test_proto"],
)

upb_c_proto_library(
name = "utf8_test_proto2_upb_proto",
testonly = 1,
deps = [":utf8_test_proto2_proto"],
)

cc_test(
name = "utf8_test",
srcs = ["utf8_test.cc"],
deps = [
":utf8_test_proto2_upb_minitable_proto",
":utf8_test_proto2_upb_proto",
":utf8_test_upb_minitable_proto",
":utf8_test_upb_proto",
"//upb:base",
Expand Down
96 changes: 96 additions & 0 deletions upb/message/utf8_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include "upb/mem/arena.hpp"
#include "upb/message/utf8_test.upb.h"
#include "upb/message/utf8_test.upb_minitable.h"
#include "upb/message/utf8_test_proto2.upb.h"
#include "upb/message/utf8_test_proto2.upb_minitable.h"
#include "upb/wire/decode.h"

namespace {
Expand Down Expand Up @@ -72,6 +74,100 @@ TEST(Utf8Test, RepeatedProto3FieldValidates) {
ASSERT_EQ(kUpb_DecodeStatus_BadUtf8, status);
}

TEST(Utf8Test, Proto2BytesValidates) {
upb::Arena arena;
size_t size;
char* data = GetBadUtf8Payload(arena.ptr(), &size);

upb_test_TestUtf8Proto2Bytes* msg =
upb_test_TestUtf8Proto2Bytes_new(arena.ptr());

upb_DecodeStatus status;
status = upb_Decode(data, size, UPB_UPCAST(msg),
&upb_0test__TestUtf8Proto2Bytes_msg_init, nullptr, 0,
arena.ptr());

// Parse succeeds, because proto2 bytes fields don't validate UTF-8.
ASSERT_EQ(kUpb_DecodeStatus_Ok, status);
}

TEST(Utf8Test, Proto2RepeatedBytesValidates) {
upb::Arena arena;
size_t size;
char* data = GetBadUtf8Payload(arena.ptr(), &size);

upb_test_TestUtf8RepeatedProto2Bytes* msg =
upb_test_TestUtf8RepeatedProto2Bytes_new(arena.ptr());

upb_DecodeStatus status;
status = upb_Decode(data, size, UPB_UPCAST(msg),
&upb_0test__TestUtf8RepeatedProto2Bytes_msg_init, nullptr,
0, arena.ptr());

// Parse succeeds, because proto2 bytes fields don't validate UTF-8.
ASSERT_EQ(kUpb_DecodeStatus_Ok, status);
}

TEST(Utf8Test, Proto2StringValidates) {
upb::Arena arena;
size_t size;
char* data = GetBadUtf8Payload(arena.ptr(), &size);

upb_test_TestUtf8Proto2String* msg =
upb_test_TestUtf8Proto2String_new(arena.ptr());

upb_DecodeStatus status;
status = upb_Decode(data, size, UPB_UPCAST(msg),
&upb_0test__TestUtf8Proto2String_msg_init, nullptr, 0,
arena.ptr());

// Parse succeeds, because proto2 string fields don't validate UTF-8.
ASSERT_EQ(kUpb_DecodeStatus_Ok, status);
}

TEST(Utf8Test, Proto2FieldFailsValidation) {
upb::Arena arena;
size_t size;
char* data = GetBadUtf8Payload(arena.ptr(), &size);

upb_test_TestUtf8Proto2String* msg =
upb_test_TestUtf8Proto2String_new(arena.ptr());

upb_DecodeStatus status;
status = upb_Decode(data, size, UPB_UPCAST(msg),
&upb_0test__TestUtf8Proto2String_msg_init, nullptr, 0,
arena.ptr());

// Parse fails, because we pass in kUpb_DecodeOption_AlwaysValidateUtf8 to
// force validation of proto2 string fields.
status = upb_Decode(data, size, UPB_UPCAST(msg),
&upb_0test__TestUtf8Proto2String_msg_init, nullptr,
kUpb_DecodeOption_AlwaysValidateUtf8, arena.ptr());
ASSERT_EQ(kUpb_DecodeStatus_BadUtf8, status);
}

TEST(Utf8Test, Proto2RepeatedFieldFailsValidation) {
upb::Arena arena;
size_t size;
char* data = GetBadUtf8Payload(arena.ptr(), &size);

upb_test_TestUtf8RepeatedProto2String* msg =
upb_test_TestUtf8RepeatedProto2String_new(arena.ptr());

upb_DecodeStatus status;
status = upb_Decode(data, size, UPB_UPCAST(msg),
&upb_0test__TestUtf8RepeatedProto2String_msg_init,
nullptr, 0, arena.ptr());

// Parse fails, because we pass in kUpb_DecodeOption_AlwaysValidateUtf8 to
// force validation of proto2 string fields.
status =
upb_Decode(data, size, UPB_UPCAST(msg),
&upb_0test__TestUtf8RepeatedProto2String_msg_init, nullptr,
kUpb_DecodeOption_AlwaysValidateUtf8, arena.ptr());
ASSERT_EQ(kUpb_DecodeStatus_BadUtf8, status);
}

// begin:google_only
// TEST(Utf8Test, Proto3MixedFieldValidates) {
// upb::Arena arena;
Expand Down
2 changes: 1 addition & 1 deletion upb/message/utf8_test.proto
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ message TestUtf8Proto3StringEnforceUtf8False {
}

message TestUtf8RepeatedProto3StringEnforceUtf8False {
optional string data = 1;
repeated string data = 1;
}

message TestUtf8Proto3StringEnforceUtf8FalseMixed {
Expand Down
26 changes: 26 additions & 0 deletions upb/message/utf8_test_proto2.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Protocol Buffers - Google's data interchange format
// Copyright 2023 Google LLC. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd

syntax = "proto2";

package upb_test;

message TestUtf8Proto2Bytes {
optional bytes data = 1;
}

message TestUtf8RepeatedProto2Bytes {
optional bytes data = 1;
}

message TestUtf8Proto2String {
optional string data = 1;
}

message TestUtf8RepeatedProto2String {
repeated string data = 1;
}
62 changes: 41 additions & 21 deletions upb/wire/decode.c
Original file line number Diff line number Diff line change
Expand Up @@ -1005,18 +1005,24 @@ int _upb_Decoder_GetVarintOp(const upb_MiniTableField* field) {
return kVarintOps[field->UPB_PRIVATE(descriptortype)];
}

UPB_FORCEINLINE
static void _upb_Decoder_CheckUnlinked(upb_Decoder* d, const upb_MiniTable* mt,
const upb_MiniTableField* field,
int* op) {
// If sub-message is not linked, treat as unknown.
if (field->UPB_PRIVATE(mode) & kUpb_LabelFlags_IsExtension) return;
// This returns a DecodeOp enum (kUpb_DecodeOp_SubMessage,
// kUpb_DecodeOp_UnknownField, etc).
static int _upb_Decoder_CheckUnlinked(upb_Decoder* d, const upb_MiniTable* mt,
const upb_MiniTableField* field) {
if (field->UPB_PRIVATE(mode) & kUpb_LabelFlags_IsExtension) {
return kUpb_DecodeOp_SubMessage;
}

const upb_MiniTable* mt_sub =
_upb_MiniTableSubs_MessageByField(mt->UPB_PRIVATE(subs), field);
if ((d->options & kUpb_DecodeOption_ExperimentalAllowUnlinked) ||
!UPB_PRIVATE(_upb_MiniTable_IsEmpty)(mt_sub)) {
return;
if (!UPB_PRIVATE(_upb_MiniTable_IsEmpty)(mt_sub)) {
return kUpb_DecodeOp_SubMessage;
}

if (d->options & kUpb_DecodeOption_ExperimentalAllowUnlinked) {
return kUpb_DecodeOp_SubMessage;
}

#ifndef NDEBUG
const upb_MiniTableField* oneof = upb_MiniTable_GetOneof(mt, field);
if (oneof) {
Expand All @@ -1030,13 +1036,26 @@ static void _upb_Decoder_CheckUnlinked(upb_Decoder* d, const upb_MiniTable* mt,
} while (upb_MiniTable_NextOneofField(mt, &oneof));
}
#endif // NDEBUG
*op = kUpb_DecodeOp_UnknownField;

return kUpb_DecodeOp_UnknownField;
}

static int _upb_Decoder_MaybeVerifyUtf8(upb_Decoder* d,
const upb_MiniTableField* field) {
if (!(field->UPB_ONLYBITS(mode) & kUpb_LabelFlags_IsAlternate) ||
!UPB_UNLIKELY(d->options & kUpb_DecodeOption_AlwaysValidateUtf8))
return kUpb_DecodeOp_Bytes;

return kUpb_DecodeOp_String;
}

int _upb_Decoder_GetDelimitedOp(upb_Decoder* d, const upb_MiniTable* mt,
const upb_MiniTableField* field) {
static int _upb_Decoder_GetDelimitedOp(upb_Decoder* d, const upb_MiniTable* mt,
const upb_MiniTableField* field) {
enum { kRepeatedBase = 19 };

// This table is used to map field types to decode ops. However, not all the
// field types here may be used, since specialized logic in this function may
// determine the decode op based on more than just the field type.
static const int8_t kDelimitedOps[] = {
/* For non-repeated field type. */
[kUpb_FakeFieldType_FieldNotFound] =
Expand Down Expand Up @@ -1083,15 +1102,17 @@ int _upb_Decoder_GetDelimitedOp(upb_Decoder* d, const upb_MiniTable* mt,
// repeated msgset type
};

int ndx = field->UPB_PRIVATE(descriptortype);
if (upb_MiniTableField_IsArray(field)) ndx += kRepeatedBase;
int op = kDelimitedOps[ndx];
const int ndx = field->UPB_PRIVATE(descriptortype);
const bool is_array = upb_MiniTableField_IsArray(field);

if (op == kUpb_DecodeOp_SubMessage) {
_upb_Decoder_CheckUnlinked(d, mt, field, &op);
if (ndx == kUpb_FieldType_Message ||
(ndx == kUpb_FieldType_Group && is_array)) {
return _upb_Decoder_CheckUnlinked(d, mt, field);
} else if (ndx == kUpb_FieldType_Bytes) {
return _upb_Decoder_MaybeVerifyUtf8(d, field);
} else {
return kDelimitedOps[ndx + (is_array ? kRepeatedBase : 0)];
}

return op;
}

UPB_FORCEINLINE
Expand Down Expand Up @@ -1133,8 +1154,7 @@ static const char* _upb_Decoder_DecodeWireValue(upb_Decoder* d, const char* ptr,
case kUpb_WireType_StartGroup:
val->uint32_val = field->UPB_PRIVATE(number);
if (field->UPB_PRIVATE(descriptortype) == kUpb_FieldType_Group) {
*op = kUpb_DecodeOp_SubMessage;
_upb_Decoder_CheckUnlinked(d, mt, field, op);
*op = _upb_Decoder_CheckUnlinked(d, mt, field);
} else if (field->UPB_PRIVATE(descriptortype) ==
kUpb_FakeFieldType_MessageSetItem) {
*op = kUpb_DecodeOp_MessageSetItem;
Expand Down
10 changes: 10 additions & 0 deletions upb/wire/decode.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ enum {
* be created by the parser or the message-copying logic in message/copy.h.
*/
kUpb_DecodeOption_ExperimentalAllowUnlinked = 4,

/* EXPERIMENTAL:
*
* If set, decoding will enforce UTF-8 validation for string fields, even for
* proto2 or fields with `features.utf8_validation = NONE`. Normally, only
* proto3 string fields will be validated for UTF-8. Decoding will return
* kUpb_DecodeStatus_BadUtf8 for non-UTF-8 strings, which is the same behavior
* as non-UTF-8 proto3 string fields.
*/
kUpb_DecodeOption_AlwaysValidateUtf8 = 8,
};

UPB_INLINE uint32_t upb_DecodeOptions_MaxDepth(uint16_t depth) {
Expand Down

0 comments on commit 5e049b7

Please sign in to comment.