diff --git a/rust/test/nested.proto b/rust/test/nested.proto index e6d4f29a7ae9..59aee229db27 100644 --- a/rust/test/nested.proto +++ b/rust/test/nested.proto @@ -11,6 +11,10 @@ package nest; message Outer { message Inner { + message InnerSubMsg { + optional bool flag = 1; + } + optional double double = 1; optional float float = 2; optional int32 int32 = 3; @@ -26,6 +30,7 @@ message Outer { optional bool bool = 13; optional string string = 14; optional bytes bytes = 15; + optional InnerSubMsg innersubmsg = 16; message SuperInner { message DuperInner { @@ -40,4 +45,15 @@ message Outer { optional Inner inner = 1; optional .nest.Outer.Inner.SuperInner.DuperInner.EvenMoreInner .CantBelieveItsSoInner deep = 2; + + optional NotInside notinside = 3; +} + +message NotInside { + optional int32 num = 1; +} + +message Recursive { + optional Recursive rec = 1; + optional int32 num = 2; } diff --git a/rust/test/shared/simple_nested_test.rs b/rust/test/shared/simple_nested_test.rs index cd877996a704..a13d4b53f2c6 100644 --- a/rust/test/shared/simple_nested_test.rs +++ b/rust/test/shared/simple_nested_test.rs @@ -39,6 +39,7 @@ fn test_nested_views() { assert_that!(inner_msg.bool(), eq(false)); assert_that!(*inner_msg.string().as_bytes(), empty()); assert_that!(*inner_msg.bytes(), empty()); + assert_that!(inner_msg.innersubmsg().flag(), eq(false)); } #[test] @@ -96,3 +97,20 @@ fn test_nested_muts() { ); // TODO: add mutation tests for strings and bytes } + +#[test] +fn test_msg_from_outside() { + // let's make sure that we're not just working for messages nested inside + // messages, messages from without and within should work + let outer = nested_proto::nest::Outer::new(); + assert_that!(outer.notinside().num(), eq(0)); +} + +#[test] +fn test_recursive_msg() { + let rec = nested_proto::nest::Recursive::new(); + assert_that!(rec.num(), eq(0)); + assert_that!(rec.rec().num(), eq(0)); + assert_that!(rec.rec().rec().num(), eq(0)); // turtles all the way down... + assert_that!(rec.rec().rec().rec().num(), eq(0)); // ... ad infinitum +} diff --git a/src/google/protobuf/compiler/rust/message.cc b/src/google/protobuf/compiler/rust/message.cc index 6bc142fc2468..9cc35239c59e 100644 --- a/src/google/protobuf/compiler/rust/message.cc +++ b/src/google/protobuf/compiler/rust/message.cc @@ -9,7 +9,7 @@ #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" -#include "absl/strings/str_cat.h" +#include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "google/protobuf/compiler/cpp/helpers.h" #include "google/protobuf/compiler/cpp/names.h" @@ -173,8 +173,50 @@ void GetterForViewOrMut(Context field, bool is_mut) { // If we're dealing with a Mut, the getter must be supplied // self.inner.msg() whereas a View has to be supplied self.msg auto self = is_mut ? "self.inner.msg()" : "self.msg"; - auto rsType = PrimitiveRsTypeName(field.desc()); + if (fieldType == FieldDescriptor::TYPE_MESSAGE) { + Context d = field.WithDesc(field.desc().message_type()); + auto prefix = "crate::" + GetCrateRelativeQualifiedPath(d); + // TODO: investigate imports breaking submsg accessors + if (absl::StrContains(prefix, "import")) { + return; + } + field.Emit( + { + {"prefix", prefix}, + {"field", fieldName}, + {"self", self}, + {"getter_thunk", getter_thunk}, + // TODO: dedupe with singular_message.cc + { + "view_body", + [&] { + if (field.is_upb()) { + field.Emit({}, R"rs( + let submsg = unsafe { $getter_thunk$($self$) }; + match submsg { + None => $prefix$View::new($pbi$::Private, + $pbr$::ScratchSpace::zeroed_block($pbi$::Private)), + Some(field) => $prefix$View::new($pbi$::Private, field), + } + )rs"); + } else { + field.Emit({}, R"rs( + let submsg = unsafe { $getter_thunk$($self$) }; + $prefix$View::new($pbi$::Private, submsg) + )rs"); + } + }, + }, + }, + R"rs( + pub fn r#$field$(&self) -> $prefix$View { + $view_body$ + } + )rs"); + return; + } + auto rsType = PrimitiveRsTypeName(field.desc()); if (fieldType == FieldDescriptor::TYPE_STRING) { field.Emit( { @@ -250,8 +292,7 @@ void AccessorsForViewOrMut(Context msg, bool is_mut) { // TODO - add cord support if (field.desc().options().has_ctype()) continue; // TODO - if (field.desc().type() == FieldDescriptor::TYPE_MESSAGE || - field.desc().type() == FieldDescriptor::TYPE_ENUM || + if (field.desc().type() == FieldDescriptor::TYPE_ENUM || field.desc().type() == FieldDescriptor::TYPE_GROUP) continue; GetterForViewOrMut(field, is_mut);