Skip to content

Commit

Permalink
impl SettableValue for MsgView
Browse files Browse the repository at this point in the history
For the cpp runtime, call the `Message::CopyFrom` method.

For the upb runtime, expose the message `MiniTable` and call `upb_Message_DeepCopy`.

PiperOrigin-RevId: 591250339
  • Loading branch information
kcking authored and copybara-github committed Dec 19, 2023
1 parent f73985a commit cd745eb
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ endif (MSVC)
include_directories(
${ZLIB_INCLUDE_DIRECTORIES}
${protobuf_BINARY_DIR}
# Support #include-ing other top-level directories, i.e. upb_generator.
${protobuf_SOURCE_DIR}
${protobuf_SOURCE_DIR}/src)

set(protobuf_ABSL_PROVIDER "module" CACHE STRING "Provider of absl library")
Expand Down
14 changes: 13 additions & 1 deletion rust/test/shared/accessors_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
use googletest::prelude::*;
use matchers::{is_set, is_unset};
use protobuf::Optional;
use unittest_proto::proto2_unittest::{TestAllTypes, TestAllTypes_};
use unittest_proto::proto2_unittest::{NestedTestAllTypes, TestAllTypes, TestAllTypes_};

#[test]
fn test_default_accessors() {
Expand Down Expand Up @@ -735,3 +735,15 @@ fn test_oneof_mut_accessors() {
msg.oneof_bytes_mut().set(b"123");
assert_that!(msg.oneof_field_mut(), matches_pattern!(OneofBytes(_)));
}

#[test]
fn test_set_message_from_view() {
use protobuf::MutProxy;
let mut m1 = NestedTestAllTypes::new();

let mut m2 = NestedTestAllTypes::new();
m2.payload_mut().optional_int32_mut().set(1);

m1.payload_mut().set(m2.payload());
assert_that!(m1.payload().optional_int32(), eq(1));
}
22 changes: 22 additions & 0 deletions rust/upb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,10 @@ impl<'msg> MutatorMessageRef<'msg> {
pub fn msg(&self) -> RawMessage {
self.msg
}

pub fn arena(&self) -> RawArena {
self.arena.raw()
}
}

pub fn copy_bytes_in_arena_if_needed_by_runtime<'msg>(
Expand All @@ -297,6 +301,24 @@ pub fn copy_bytes_in_arena_if_needed_by_runtime<'msg>(
}
}

/// Opaque struct containing a upb_MiniTable.
///
/// This wrapper is a workaround until stabilization of extern C types.
#[repr(C)]
pub struct RawMiniTable {
_data: [u8; 0],
_marker: std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>,
}

extern "C" {
pub fn upb_Message_DeepCopy(
dst: RawMessage,
src: RawMessage,
mini_table: *const RawMiniTable,
arena: RawArena,
);
}

/// The raw type-erased pointer version of `RepeatedMut`.
///
/// Contains a `upb_Array*` as well as `RawArena`, most likely that of the
Expand Down
1 change: 1 addition & 0 deletions rust/upb_kernel/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ cc_library(
deps = [
"//upb:mem",
"//upb:message",
"//upb:message_copy", # buildcleaner: keep
],
)
3 changes: 2 additions & 1 deletion rust/upb_kernel/upb_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@

#include "upb/mem/arena.h" // IWYU pragma: keep
#include "upb/message/array.h" // IWYU pragma: keep
#include "upb/message/map.h" // IWYU pragma: keep
#include "upb/message/copy.h" // IWYU pragma: keep
#include "upb/message/map.h" // IWYU pragma: keep
1 change: 1 addition & 0 deletions src/google/protobuf/compiler/rust/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ cc_library(
":oneof",
"//src/google/protobuf:protobuf_nowkt",
"//src/google/protobuf/compiler/cpp:names",
"//upb_generator:mangle",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
],
Expand Down
59 changes: 51 additions & 8 deletions src/google/protobuf/compiler/rust/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#include "google/protobuf/compiler/rust/message.h"

#include <string>

#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/strings/string_view.h"
Expand All @@ -17,13 +19,18 @@
#include "google/protobuf/compiler/rust/naming.h"
#include "google/protobuf/compiler/rust/oneof.h"
#include "google/protobuf/descriptor.h"
#include "upb_generator/mangle.h"

namespace google {
namespace protobuf {
namespace compiler {
namespace rust {
namespace {

static std::string MinitableName(const Descriptor& msg) {
return upb::generator::MessageInit(msg.full_name());
}

void MessageNew(Context& ctx, const Descriptor& msg) {
switch (ctx.opts().kernel) {
case Kernel::kCpp:
Expand Down Expand Up @@ -124,12 +131,14 @@ void MessageExterns(Context& ctx, const Descriptor& msg) {
{"delete_thunk", Thunk(ctx, msg, "delete")},
{"serialize_thunk", Thunk(ctx, msg, "serialize")},
{"deserialize_thunk", Thunk(ctx, msg, "deserialize")},
{"copy_from_thunk", Thunk(ctx, msg, "copy_from")},
},
R"rs(
fn $new_thunk$() -> $pbi$::RawMessage;
fn $delete_thunk$(raw_msg: $pbi$::RawMessage);
fn $serialize_thunk$(raw_msg: $pbi$::RawMessage) -> $pbr$::SerializedData;
fn $deserialize_thunk$(raw_msg: $pbi$::RawMessage, data: $pbr$::SerializedData) -> bool;
fn $copy_from_thunk$(dst: $pbi$::RawMessage, src: $pbi$::RawMessage);
)rs");
return;

Expand All @@ -139,11 +148,13 @@ void MessageExterns(Context& ctx, const Descriptor& msg) {
{"new_thunk", Thunk(ctx, msg, "new")},
{"serialize_thunk", Thunk(ctx, msg, "serialize")},
{"deserialize_thunk", Thunk(ctx, msg, "parse")},
{"minitable", MinitableName(msg)},
},
R"rs(
fn $new_thunk$(arena: $pbi$::RawArena) -> $pbi$::RawMessage;
fn $serialize_thunk$(msg: $pbi$::RawMessage, arena: $pbi$::RawArena, len: &mut usize) -> $NonNull$<u8>;
fn $deserialize_thunk$(data: *const u8, size: usize, arena: $pbi$::RawArena) -> Option<$pbi$::RawMessage>;
static $minitable$: $pbr$::RawMiniTable;
)rs");
return;
}
Expand All @@ -163,6 +174,38 @@ void MessageDrop(Context& ctx, const Descriptor& msg) {
)rs");
}

void MessageSettableValue(Context& ctx, const Descriptor& msg) {
switch (ctx.opts().kernel) {
case Kernel::kCpp:
ctx.Emit({{"copy_from_thunk", Thunk(ctx, msg, "copy_from")}},
R"rs(
impl<'msg> $pb$::SettableValue<$Msg$> for $Msg$View<'msg> {
fn set_on<'dst>(
self, _private: $pb$::__internal::Private, mutator: $pb$::Mut<'dst, $Msg$>)
where $Msg$: 'dst {
unsafe { $copy_from_thunk$(mutator.inner.msg(), self.msg) };
}
}
)rs");
return;

case Kernel::kUpb:
ctx.Emit({{"minitable", MinitableName(msg)}}, R"rs(
impl<'msg> $pb$::SettableValue<$Msg$> for $Msg$View<'msg> {
fn set_on<'dst>(
self, _private: $pb$::__internal::Private, mutator: $pb$::Mut<'dst, $Msg$>)
where $Msg$: 'dst {
unsafe { $pbr$::upb_Message_DeepCopy(
mutator.inner.msg(), self.msg, &$minitable$, mutator.inner.arena()) };
}
}
)rs");
return;
}

ABSL_LOG(FATAL) << "unreachable";
}

void GetterForViewOrMut(Context& ctx, const FieldDescriptor& field,
bool is_mut) {
auto fieldName = field.name();
Expand Down Expand Up @@ -397,7 +440,8 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
{"accessor_fns_for_views",
[&] { AccessorsForViewOrMut(ctx, msg, false); }},
{"accessor_fns_for_muts",
[&] { AccessorsForViewOrMut(ctx, msg, true); }}},
[&] { AccessorsForViewOrMut(ctx, msg, true); }},
{"settable_impl", [&] { MessageSettableValue(ctx, msg); }}},
R"rs(
#[allow(non_camel_case_types)]
// TODO: Implement support for debug redaction
Expand Down Expand Up @@ -453,13 +497,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) {
}
}
impl<'a> $pb$::SettableValue<$Msg$> for $Msg$View<'a> {
fn set_on<'b>(self, _private: $pb$::__internal::Private, _mutator: $pb$::Mut<'b, $Msg$>)
where
$Msg$: 'b {
todo!()
}
}
$settable_impl$
#[derive(Debug)]
#[allow(dead_code)]
Expand Down Expand Up @@ -572,6 +610,7 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) {
{"delete_thunk", Thunk(ctx, msg, "delete")},
{"serialize_thunk", Thunk(ctx, msg, "serialize")},
{"deserialize_thunk", Thunk(ctx, msg, "deserialize")},
{"copy_from_thunk", Thunk(ctx, msg, "copy_from")},
{"nested_msg_thunks",
[&] {
for (int i = 0; i < msg.nested_type_count(); ++i) {
Expand Down Expand Up @@ -606,6 +645,10 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) {
return msg->ParseFromArray(data.data, data.len);
}
void $copy_from_thunk$($QualifiedMsg$* dst, const $QualifiedMsg$* src) {
dst->CopyFrom(*src);
}
$accessor_thunks$
$oneof_thunks$
Expand Down
1 change: 1 addition & 0 deletions src/libprotoc.map
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
extern "C++" {
*google*;
pb::*;
upb::*;
};
scc_info_*;
descriptor_table_*;
Expand Down

0 comments on commit cd745eb

Please sign in to comment.