From 662190d807ab5840ff67d86f27a3886d00c55be3 Mon Sep 17 00:00:00 2001 From: Kevin King Date: Fri, 15 Dec 2023 07:41:49 -0800 Subject: [PATCH] impl SettableValue for MsgView For the cpp runtime, call the `Message::CopyFrom` method. For the upb runtime, expose the message `MiniTable` and call `upb_Message_DeepCopy`. PiperOrigin-RevId: 591250339 --- CMakeLists.txt | 2 + rust/test/shared/accessors_test.rs | 14 +++- rust/upb.rs | 24 +++++++ rust/upb_kernel/BUILD | 1 + rust/upb_kernel/upb_api.c | 3 +- src/google/protobuf/compiler/rust/BUILD.bazel | 1 + src/google/protobuf/compiler/rust/message.cc | 64 ++++++++++++++++--- src/libprotoc.map | 1 + 8 files changed, 100 insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f78fccc27a5b..4fae55bb804c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -279,6 +279,8 @@ endif (MSVC) include_directories( ${ZLIB_INCLUDE_DIRECTORIES} ${protobuf_BINARY_DIR} + # Support #include-ing other top-level directories, i.e. upb_generator. + ${protobuf_SOURCE_DIR} ${protobuf_SOURCE_DIR}/src) set(protobuf_ABSL_PROVIDER "module" CACHE STRING "Provider of absl library") diff --git a/rust/test/shared/accessors_test.rs b/rust/test/shared/accessors_test.rs index 15758a1f2436..289f3765afd6 100644 --- a/rust/test/shared/accessors_test.rs +++ b/rust/test/shared/accessors_test.rs @@ -10,7 +10,7 @@ use googletest::prelude::*; use matchers::{is_set, is_unset}; use protobuf::Optional; -use unittest_proto::proto2_unittest::{TestAllTypes, TestAllTypes_}; +use unittest_proto::proto2_unittest::{NestedTestAllTypes, TestAllTypes, TestAllTypes_}; #[test] fn test_default_accessors() { @@ -735,3 +735,15 @@ fn test_oneof_mut_accessors() { msg.oneof_bytes_mut().set(b"123"); assert_that!(msg.oneof_field_mut(), matches_pattern!(OneofBytes(_))); } + +#[test] +fn test_set_message_from_view() { + use protobuf::MutProxy; + let mut m1 = NestedTestAllTypes::new(); + + let mut m2 = NestedTestAllTypes::new(); + m2.payload_mut().optional_int32_mut().set(1); + + m1.payload_mut().set(m2.payload()); + assert_that!(m1.payload().optional_int32(), eq(1)); +} diff --git a/rust/upb.rs b/rust/upb.rs index cef685c3a98f..9e4bdc15d002 100644 --- a/rust/upb.rs +++ b/rust/upb.rs @@ -277,6 +277,10 @@ impl<'msg> MutatorMessageRef<'msg> { pub fn msg(&self) -> RawMessage { self.msg } + + pub fn raw_arena(&self, _private: Private) -> RawArena { + self.arena.raw() + } } pub fn copy_bytes_in_arena_if_needed_by_runtime<'msg>( @@ -297,6 +301,26 @@ pub fn copy_bytes_in_arena_if_needed_by_runtime<'msg>( } } +/// Opaque struct containing a upb_MiniTable. +/// +/// This wrapper is a workaround until stabilization of extern C types. +#[repr(C)] +pub struct OpaqueMiniTable { + // TODO: consider importing a minitable struct declared in + // google3/third_party/upb/bits. + _data: [u8; 0], + _marker: std::marker::PhantomData<(*mut u8, ::std::marker::PhantomPinned)>, +} + +extern "C" { + pub fn upb_Message_DeepCopy( + dst: RawMessage, + src: RawMessage, + mini_table: *const OpaqueMiniTable, + arena: RawArena, + ); +} + /// The raw type-erased pointer version of `RepeatedMut`. /// /// Contains a `upb_Array*` as well as `RawArena`, most likely that of the diff --git a/rust/upb_kernel/BUILD b/rust/upb_kernel/BUILD index 3c373323e03a..ba92e612c9eb 100644 --- a/rust/upb_kernel/BUILD +++ b/rust/upb_kernel/BUILD @@ -10,5 +10,6 @@ cc_library( deps = [ "//upb:mem", "//upb:message", + "//upb:message_copy", ], ) diff --git a/rust/upb_kernel/upb_api.c b/rust/upb_kernel/upb_api.c index 7de8d0bd4da7..e168b02e69a0 100644 --- a/rust/upb_kernel/upb_api.c +++ b/rust/upb_kernel/upb_api.c @@ -10,4 +10,5 @@ #include "upb/mem/arena.h" // IWYU pragma: keep #include "upb/message/array.h" // IWYU pragma: keep -#include "upb/message/map.h" // IWYU pragma: keep +#include "upb/message/copy.h" // IWYU pragma: keep +#include "upb/message/map.h" // IWYU pragma: keep \ No newline at end of file diff --git a/src/google/protobuf/compiler/rust/BUILD.bazel b/src/google/protobuf/compiler/rust/BUILD.bazel index e9639700a1db..46d0723577ff 100644 --- a/src/google/protobuf/compiler/rust/BUILD.bazel +++ b/src/google/protobuf/compiler/rust/BUILD.bazel @@ -42,6 +42,7 @@ cc_library( ":oneof", "//src/google/protobuf:protobuf_nowkt", "//src/google/protobuf/compiler/cpp:names", + "//upb_generator:mangle", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", ], diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 6f5454d1a5de..28326536ac71 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -7,6 +7,8 @@ #include "google/protobuf/compiler/rust/message.h" +#include + #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" #include "absl/strings/string_view.h" @@ -17,6 +19,7 @@ #include "google/protobuf/compiler/rust/naming.h" #include "google/protobuf/compiler/rust/oneof.h" #include "google/protobuf/descriptor.h" +#include "upb_generator/mangle.h" namespace google { namespace protobuf { @@ -24,6 +27,10 @@ namespace compiler { namespace rust { namespace { +std::string UpbMinitableName(const Descriptor& msg) { + return upb::generator::MessageInit(msg.full_name()); +} + void MessageNew(Context& ctx, const Descriptor& msg) { switch (ctx.opts().kernel) { case Kernel::kCpp: @@ -124,12 +131,14 @@ void MessageExterns(Context& ctx, const Descriptor& msg) { {"delete_thunk", Thunk(ctx, msg, "delete")}, {"serialize_thunk", Thunk(ctx, msg, "serialize")}, {"deserialize_thunk", Thunk(ctx, msg, "deserialize")}, + {"copy_from_thunk", Thunk(ctx, msg, "copy_from")}, }, R"rs( fn $new_thunk$() -> $pbi$::RawMessage; fn $delete_thunk$(raw_msg: $pbi$::RawMessage); fn $serialize_thunk$(raw_msg: $pbi$::RawMessage) -> $pbr$::SerializedData; fn $deserialize_thunk$(raw_msg: $pbi$::RawMessage, data: $pbr$::SerializedData) -> bool; + fn $copy_from_thunk$(dst: $pbi$::RawMessage, src: $pbi$::RawMessage); )rs"); return; @@ -139,11 +148,15 @@ void MessageExterns(Context& ctx, const Descriptor& msg) { {"new_thunk", Thunk(ctx, msg, "new")}, {"serialize_thunk", Thunk(ctx, msg, "serialize")}, {"deserialize_thunk", Thunk(ctx, msg, "parse")}, + {"minitable", UpbMinitableName(msg)}, }, R"rs( fn $new_thunk$(arena: $pbi$::RawArena) -> $pbi$::RawMessage; fn $serialize_thunk$(msg: $pbi$::RawMessage, arena: $pbi$::RawArena, len: &mut usize) -> $NonNull$; fn $deserialize_thunk$(data: *const u8, size: usize, arena: $pbi$::RawArena) -> Option<$pbi$::RawMessage>; + /// Opaque wrapper for this message's MiniTable. The only valid way to + /// reference this static is with `std::ptr::addr_of!(..)`. + static $minitable$: $pbr$::OpaqueMiniTable; )rs"); return; } @@ -163,6 +176,41 @@ void MessageDrop(Context& ctx, const Descriptor& msg) { )rs"); } +void MessageSettableValue(Context& ctx, const Descriptor& msg) { + switch (ctx.opts().kernel) { + case Kernel::kCpp: + ctx.Emit({{"copy_from_thunk", Thunk(ctx, msg, "copy_from")}}, R"rs( + impl<'msg> $pb$::SettableValue<$Msg$> for $Msg$View<'msg> { + fn set_on<'dst>( + self, _private: $pbi$::Private, mutator: $pb$::Mut<'dst, $Msg$>) + where $Msg$: 'dst { + unsafe { $copy_from_thunk$(mutator.inner.msg(), self.msg) }; + } + } + )rs"); + return; + + case Kernel::kUpb: + ctx.Emit({{"minitable", UpbMinitableName(msg)}}, R"rs( + impl<'msg> $pb$::SettableValue<$Msg$> for $Msg$View<'msg> { + fn set_on<'dst>( + self, _private: $pbi$::Private, mutator: $pb$::Mut<'dst, $Msg$>) + where $Msg$: 'dst { + unsafe { $pbr$::upb_Message_DeepCopy( + mutator.inner.msg(), + self.msg, + $std$::ptr::addr_of!($minitable$), + mutator.inner.raw_arena($pbi$::Private), + ) }; + } + } + )rs"); + return; + } + + ABSL_LOG(FATAL) << "unreachable"; +} + void GetterForViewOrMut(Context& ctx, const FieldDescriptor& field, bool is_mut) { auto fieldName = field.name(); @@ -397,7 +445,8 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { {"accessor_fns_for_views", [&] { AccessorsForViewOrMut(ctx, msg, false); }}, {"accessor_fns_for_muts", - [&] { AccessorsForViewOrMut(ctx, msg, true); }}}, + [&] { AccessorsForViewOrMut(ctx, msg, true); }}, + {"settable_impl", [&] { MessageSettableValue(ctx, msg); }}}, R"rs( #[allow(non_camel_case_types)] // TODO: Implement support for debug redaction @@ -453,13 +502,7 @@ void GenerateRs(Context& ctx, const Descriptor& msg) { } } - impl<'a> $pb$::SettableValue<$Msg$> for $Msg$View<'a> { - fn set_on<'b>(self, _private: $pb$::__internal::Private, _mutator: $pb$::Mut<'b, $Msg$>) - where - $Msg$: 'b { - todo!() - } - } + $settable_impl$ #[derive(Debug)] #[allow(dead_code)] @@ -572,6 +615,7 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { {"delete_thunk", Thunk(ctx, msg, "delete")}, {"serialize_thunk", Thunk(ctx, msg, "serialize")}, {"deserialize_thunk", Thunk(ctx, msg, "deserialize")}, + {"copy_from_thunk", Thunk(ctx, msg, "copy_from")}, {"nested_msg_thunks", [&] { for (int i = 0; i < msg.nested_type_count(); ++i) { @@ -606,6 +650,10 @@ void GenerateThunksCc(Context& ctx, const Descriptor& msg) { return msg->ParseFromArray(data.data, data.len); } + void $copy_from_thunk$($QualifiedMsg$* dst, const $QualifiedMsg$* src) { + dst->CopyFrom(*src); + } + $accessor_thunks$ $oneof_thunks$ diff --git a/src/libprotoc.map b/src/libprotoc.map index 6f3a36e481e1..24a5f76a5593 100644 --- a/src/libprotoc.map +++ b/src/libprotoc.map @@ -3,6 +3,7 @@ extern "C++" { *google*; pb::*; + upb::*; }; scc_info_*; descriptor_table_*;