diff --git a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/AnyUtils.java b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/AnyUtils.java index 34e1e3e73..454f50d1a 100644 --- a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/AnyUtils.java +++ b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/AnyUtils.java @@ -69,8 +69,9 @@ static ExtensionRegistry defaultExtensionRegistry() { return DEFAULT_EXTENSION_REGISTRY; } - /** Unpack an `Any` proto using the TypeRegistry and ExtensionRegistry on `config`. */ - static Optional unpack(Message any, FluentEqualityConfig config) { + /** Unpack an `Any` proto using the given TypeRegistry and ExtensionRegistry. */ + static Optional unpack( + Message any, TypeRegistry typeRegistry, ExtensionRegistry extensionRegistry) { Preconditions.checkArgument( any.getDescriptorForType().equals(Any.getDescriptor()), "Expected type google.protobuf.Any, but was: %s", @@ -80,13 +81,12 @@ static Optional unpack(Message any, FluentEqualityConfig config) { ByteString value = (ByteString) any.getField(valueFieldDescriptor()); try { - Descriptor descriptor = config.useTypeRegistry().getDescriptorForTypeUrl(typeUrl); + Descriptor descriptor = typeRegistry.getDescriptorForTypeUrl(typeUrl); if (descriptor == null) { return Optional.absent(); } - Message defaultMessage = - DynamicMessage.parseFrom(descriptor, value, config.useExtensionRegistry()); + Message defaultMessage = DynamicMessage.parseFrom(descriptor, value, extensionRegistry); return Optional.of(defaultMessage); } catch (InvalidProtocolBufferException e) { return Optional.absent(); diff --git a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/DiffResult.java b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/DiffResult.java index e6c0af64a..990d350d1 100644 --- a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/DiffResult.java +++ b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/DiffResult.java @@ -557,6 +557,9 @@ default void printFieldValue(SubScopeId subScopeId, Object o, StringBuilder sb) case UNKNOWN_FIELD_DESCRIPTOR: printFieldValue(subScopeId.unknownFieldDescriptor(), o, sb); return; + case UNPACKED_ANY_VALUE_TYPE: + printFieldValue(AnyUtils.valueFieldDescriptor(), o, sb); + return; } throw new AssertionError(subScopeId.kind()); } diff --git a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldNumberTree.java b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldNumberTree.java index d33e2fee1..698b9d948 100644 --- a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldNumberTree.java +++ b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldNumberTree.java @@ -16,9 +16,12 @@ package com.google.common.truth.extensions.proto; +import com.google.common.base.Optional; import com.google.common.collect.Maps; import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; +import com.google.protobuf.TypeRegistry; import com.google.protobuf.UnknownFieldSet; import java.util.List; import java.util.Map; @@ -62,7 +65,8 @@ boolean hasChild(SubScopeId subScopeId) { return children.containsKey(subScopeId); } - static FieldNumberTree fromMessage(Message message) { + static FieldNumberTree fromMessage( + Message message, TypeRegistry typeRegistry, ExtensionRegistry extensionRegistry) { FieldNumberTree tree = new FieldNumberTree(); // Known fields. @@ -72,15 +76,25 @@ static FieldNumberTree fromMessage(Message message) { FieldNumberTree childTree = new FieldNumberTree(); tree.children.put(subScopeId, childTree); - Object fieldValue = knownFieldValues.get(field); - if (field.getJavaType() == FieldDescriptor.JavaType.MESSAGE) { - if (field.isRepeated()) { - List valueList = (List) fieldValue; - for (Object value : valueList) { - childTree.merge(fromMessage((Message) value)); + if (field.equals(AnyUtils.valueFieldDescriptor())) { + // Handle Any protos specially. + Optional unpackedAny = AnyUtils.unpack(message, typeRegistry, extensionRegistry); + if (unpackedAny.isPresent()) { + tree.children.put( + SubScopeId.ofUnpackedAnyValueType(unpackedAny.get().getDescriptorForType()), + fromMessage(unpackedAny.get(), typeRegistry, extensionRegistry)); + } + } else { + Object fieldValue = knownFieldValues.get(field); + if (field.getJavaType() == FieldDescriptor.JavaType.MESSAGE) { + if (field.isRepeated()) { + List valueList = (List) fieldValue; + for (Object value : valueList) { + childTree.merge(fromMessage((Message) value, typeRegistry, extensionRegistry)); + } + } else { + childTree.merge(fromMessage((Message) fieldValue, typeRegistry, extensionRegistry)); } - } else { - childTree.merge(fromMessage((Message) fieldValue)); } } } @@ -91,11 +105,14 @@ static FieldNumberTree fromMessage(Message message) { return tree; } - static FieldNumberTree fromMessages(Iterable messages) { + static FieldNumberTree fromMessages( + Iterable messages, + TypeRegistry typeRegistry, + ExtensionRegistry extensionRegistry) { FieldNumberTree tree = new FieldNumberTree(); for (Message message : messages) { if (message != null) { - tree.merge(fromMessage(message)); + tree.merge(fromMessage(message, typeRegistry, extensionRegistry)); } } return tree; diff --git a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeImpl.java b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeImpl.java index 4acf9916b..0eadd855c 100644 --- a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeImpl.java +++ b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeImpl.java @@ -28,7 +28,9 @@ import com.google.common.collect.Lists; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; +import com.google.protobuf.TypeRegistry; import java.util.List; /** @@ -62,13 +64,17 @@ private static FieldScope create( // Instantiation methods. ////////////////////////////////////////////////////////////////////////////////////////////////// - static FieldScope createFromSetFields(Message message) { + static FieldScope createFromSetFields( + Message message, TypeRegistry typeRegistry, ExtensionRegistry extensionRegistry) { return create( - FieldScopeLogic.partialScope(message), + FieldScopeLogic.partialScope(message, typeRegistry, extensionRegistry), Functions.constant(String.format("FieldScopes.fromSetFields({%s})", message.toString()))); } - static FieldScope createFromSetFields(Iterable messages) { + static FieldScope createFromSetFields( + Iterable messages, + TypeRegistry typeRegistry, + ExtensionRegistry extensionRegistry) { if (emptyOrAllNull(messages)) { return create( FieldScopeLogic.none(), @@ -82,7 +88,8 @@ static FieldScope createFromSetFields(Iterable messages) { getDescriptors(messages)); return create( - FieldScopeLogic.partialScope(messages, optDescriptor.get()), + FieldScopeLogic.partialScope( + messages, optDescriptor.get(), typeRegistry, extensionRegistry), Functions.constant(String.format("FieldScopes.fromSetFields(%s)", formatList(messages)))); } diff --git a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeLogic.java b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeLogic.java index 31fd0563b..dfca1f805 100644 --- a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeLogic.java +++ b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeLogic.java @@ -28,7 +28,9 @@ import com.google.errorprone.annotations.ForOverride; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; +import com.google.protobuf.TypeRegistry; import java.util.List; /** @@ -267,14 +269,21 @@ public String toString() { } } - static FieldScopeLogic partialScope(Message message) { + static FieldScopeLogic partialScope( + Message message, TypeRegistry typeRegistry, ExtensionRegistry extensionRegistry) { return new RootPartialScopeLogic( - FieldNumberTree.fromMessage(message), message.toString(), message.getDescriptorForType()); + FieldNumberTree.fromMessage(message, typeRegistry, extensionRegistry), + message.toString(), + message.getDescriptorForType()); } - static FieldScopeLogic partialScope(Iterable messages, Descriptor descriptor) { + static FieldScopeLogic partialScope( + Iterable messages, + Descriptor descriptor, + TypeRegistry typeRegistry, + ExtensionRegistry extensionRegistry) { return new RootPartialScopeLogic( - FieldNumberTree.fromMessages(messages), + FieldNumberTree.fromMessages(messages, typeRegistry, extensionRegistry), Joiner.on(", ").useForNull("null").join(messages), descriptor); } @@ -304,11 +313,18 @@ protected FieldMatcherLogicBase(boolean isRecursive) { @Override final FieldScopeResult policyFor(Descriptor rootDescriptor, SubScopeId subScopeId) { - if (subScopeId.kind() == SubScopeId.Kind.UNKNOWN_FIELD_DESCRIPTOR) { - return FieldScopeResult.EXCLUDED_RECURSIVELY; + FieldDescriptor fieldDescriptor = null; + switch (subScopeId.kind()) { + case FIELD_DESCRIPTOR: + fieldDescriptor = subScopeId.fieldDescriptor(); + break; + case UNPACKED_ANY_VALUE_TYPE: + fieldDescriptor = AnyUtils.valueFieldDescriptor(); + break; + case UNKNOWN_FIELD_DESCRIPTOR: + return FieldScopeResult.EXCLUDED_RECURSIVELY; } - FieldDescriptor fieldDescriptor = subScopeId.fieldDescriptor(); if (matchesFieldDescriptor(rootDescriptor, fieldDescriptor)) { return FieldScopeResult.of(/* included = */ true, isRecursive); } diff --git a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopes.java b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopes.java index 9b709e550..0ba6b4402 100644 --- a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopes.java +++ b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopes.java @@ -19,7 +19,9 @@ import static com.google.common.truth.extensions.proto.FieldScopeUtil.asList; import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.ExtensionRegistry; import com.google.protobuf.Message; +import com.google.protobuf.TypeRegistry; /** Factory class for {@link FieldScope} instances. */ public final class FieldScopes { @@ -66,7 +68,58 @@ public final class FieldScopes { // Alternatively II, add Scope.PARTIAL support to ProtoFluentEquals, but with a different name and // explicit documentation that it may cause issues with Proto 3. public static FieldScope fromSetFields(Message message) { - return FieldScopeImpl.createFromSetFields(message); + return fromSetFields( + message, AnyUtils.defaultTypeRegistry(), AnyUtils.defaultExtensionRegistry()); + } + + /** + * Returns a {@link FieldScope} which is constrained to precisely those specific field paths that + * are explicitly set in the message. Note that, for version 3 protobufs, such a {@link + * FieldScope} will omit fields in the provided message which are set to default values. + * + *

This can be used limit the scope of a comparison to a complex set of fields in a very brief + * statement. Often, {@code message} is the expected half of a comparison about to be performed. + * + *

Example usage: + * + *

{@code
+   * Foo actual = Foo.newBuilder().setBar(3).setBaz(4).build();
+   * Foo expected = Foo.newBuilder().setBar(3).setBaz(5).build();
+   * // Fails, because actual.getBaz() != expected.getBaz().
+   * assertThat(actual).isEqualTo(expected);
+   *
+   * Foo scope = Foo.newBuilder().setBar(2).build();
+   * // Succeeds, because only the field 'bar' is compared.
+   * assertThat(actual).withPartialScope(FieldScopes.fromSetFields(scope)).isEqualTo(expected);
+   *
+   * }
+ * + *

The returned {@link FieldScope} does not respect repeated field indices nor map keys. For + * example, if the provided message sets different field values for different elements of a + * repeated field, like so: + * + *

{@code
+   * sub_message: {
+   *   foo: "foo"
+   * }
+   * sub_message: {
+   *   bar: "bar"
+   * }
+   * }
+ * + *

The {@link FieldScope} will contain {@code sub_message.foo} and {@code sub_message.bar} for + * *all* repeated {@code sub_messages}, including those beyond index 1. + * + *

If there are {@code google.protobuf.Any} protos anywhere within these messages, they will be + * unpacked using the provided {@link TypeRegistry} and {@link ExtensionRegistry} to determine + * which fields within them should be compared. + * + * @see ProtoFluentAssertion#unpackingAnyUsing + * @since 1.2 + */ + public static FieldScope fromSetFields( + Message message, TypeRegistry typeRegistry, ExtensionRegistry extensionRegistry) { + return FieldScopeImpl.createFromSetFields(message, typeRegistry, extensionRegistry); } /** @@ -89,7 +142,29 @@ public static FieldScope fromSetFields( * or the {@link FieldScope} for the merge of all the messages. These are equivalent. */ public static FieldScope fromSetFields(Iterable messages) { - return FieldScopeImpl.createFromSetFields(messages); + return fromSetFields( + messages, AnyUtils.defaultTypeRegistry(), AnyUtils.defaultExtensionRegistry()); + } + + /** + * Creates a {@link FieldScope} covering the fields set in every message in the provided list of + * messages, with the same semantics as in {@link #fromSetFields(Message)}. + * + *

This can be thought of as the union of the {@link FieldScope}s for each individual message, + * or the {@link FieldScope} for the merge of all the messages. These are equivalent. + * + *

If there are {@code google.protobuf.Any} protos anywhere within these messages, they will be + * unpacked using the provided {@link TypeRegistry} and {@link ExtensionRegistry} to determine + * which fields within them should be compared. + * + * @see ProtoFluentAssertion#unpackingAnyUsing + * @since 1.2 + */ + public static FieldScope fromSetFields( + Iterable messages, + TypeRegistry typeRegistry, + ExtensionRegistry extensionRegistry) { + return FieldScopeImpl.createFromSetFields(messages, typeRegistry, extensionRegistry); } /** diff --git a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FluentEqualityConfig.java b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FluentEqualityConfig.java index 25efe7212..1c425bd46 100644 --- a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FluentEqualityConfig.java +++ b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FluentEqualityConfig.java @@ -273,7 +273,11 @@ final FluentEqualityConfig withExpectedMessages(Iterable mess Builder builder = toBuilder().setHasExpectedMessages(true); if (compareExpectedFieldsOnly()) { builder.setCompareFieldsScope( - FieldScopeLogic.and(compareFieldsScope(), FieldScopes.fromSetFields(messages).logic())); + FieldScopeLogic.and( + compareFieldsScope(), + FieldScopeImpl.createFromSetFields( + messages, useTypeRegistry(), useExtensionRegistry()) + .logic())); } return builder.build(); } diff --git a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/ProtoTruthMessageDifferencer.java b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/ProtoTruthMessageDifferencer.java index 5ccf6dc16..8b6ebf26e 100644 --- a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/ProtoTruthMessageDifferencer.java +++ b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/ProtoTruthMessageDifferencer.java @@ -221,8 +221,10 @@ private DiffResult diffAnyMessages( if (shouldCompareValue == FieldScopeResult.EXCLUDED_RECURSIVELY) { valueDiffResult = SingularField.ignored(name(AnyUtils.valueFieldDescriptor())); } else { - Optional unpackedActual = AnyUtils.unpack(actual, config); - Optional unpackedExpected = AnyUtils.unpack(expected, config); + Optional unpackedActual = + AnyUtils.unpack(actual, config.useTypeRegistry(), config.useExtensionRegistry()); + Optional unpackedExpected = + AnyUtils.unpack(expected, config.useTypeRegistry(), config.useExtensionRegistry()); if (unpackedActual.isPresent() && unpackedExpected.isPresent() && descriptorsMatch(unpackedActual.get(), unpackedExpected.get())) { @@ -235,7 +237,10 @@ && descriptorsMatch(unpackedActual.get(), unpackedExpected.get())) { shouldCompareValue == FieldScopeResult.EXCLUDED_NONRECURSIVELY, AnyUtils.valueFieldDescriptor(), name(AnyUtils.valueFieldDescriptor()), - config.subScope(rootDescriptor, AnyUtils.valueSubScopeId())); + config.subScope( + rootDescriptor, + SubScopeId.ofUnpackedAnyValueType( + unpackedActual.get().getDescriptorForType()))); } else { valueDiffResult = compareSingularValue( diff --git a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/SubScopeId.java b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/SubScopeId.java index 4860969e7..925c1569f 100644 --- a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/SubScopeId.java +++ b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/SubScopeId.java @@ -17,13 +17,15 @@ package com.google.common.truth.extensions.proto; import com.google.auto.value.AutoOneOf; +import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor; @AutoOneOf(SubScopeId.Kind.class) abstract class SubScopeId { enum Kind { FIELD_DESCRIPTOR, - UNKNOWN_FIELD_DESCRIPTOR; + UNKNOWN_FIELD_DESCRIPTOR, + UNPACKED_ANY_VALUE_TYPE; } abstract Kind kind(); @@ -32,6 +34,8 @@ enum Kind { abstract UnknownFieldDescriptor unknownFieldDescriptor(); + abstract Descriptor unpackedAnyValueType(); + /** Returns a short, human-readable version of this identifier. */ final String shortName() { switch (kind()) { @@ -41,6 +45,8 @@ final String shortName() { : fieldDescriptor().getName(); case UNKNOWN_FIELD_DESCRIPTOR: return String.valueOf(unknownFieldDescriptor().fieldNumber()); + case UNPACKED_ANY_VALUE_TYPE: + return AnyUtils.valueFieldDescriptor().getName(); } throw new AssertionError(kind()); } @@ -52,4 +58,8 @@ static SubScopeId of(FieldDescriptor fieldDescriptor) { static SubScopeId of(UnknownFieldDescriptor unknownFieldDescriptor) { return AutoOneOf_SubScopeId.unknownFieldDescriptor(unknownFieldDescriptor); } + + static SubScopeId ofUnpackedAnyValueType(Descriptor unpackedAnyValueType) { + return AutoOneOf_SubScopeId.unpackedAnyValueType(unpackedAnyValueType); + } } diff --git a/extensions/proto/src/test/java/com/google/common/truth/extensions/proto/FieldScopesTest.java b/extensions/proto/src/test/java/com/google/common/truth/extensions/proto/FieldScopesTest.java index fb0a07dec..f99e7554c 100644 --- a/extensions/proto/src/test/java/com/google/common/truth/extensions/proto/FieldScopesTest.java +++ b/extensions/proto/src/test/java/com/google/common/truth/extensions/proto/FieldScopesTest.java @@ -384,6 +384,96 @@ public void testIgnoringFieldOfAnyMessage() throws Exception { .contains("modified: o_any_message.value.r_string[0]: \"foo\" -> \"bar\""); } + @Test + public void testAnyMessageComparingExpectedFieldsOnly() throws Exception { + + String typeUrl = + isProto3() + ? "type.googleapis.com/com.google.common.truth.extensions.proto.SubTestMessage3" + : "type.googleapis.com/com.google.common.truth.extensions.proto.SubTestMessage2"; + + Message message = parse("o_any_message { [" + typeUrl + "]: { o_int: 2 } }"); + Message eqMessage = + parse("o_any_message { [" + typeUrl + "]: { o_int: 2 r_string: \"foo\" } }"); + Message diffMessage = + parse("o_any_message { [" + typeUrl + "]: { o_int: 3 r_string: \"bar\" } }"); + + expectThat(eqMessage) + .unpackingAnyUsing(getTypeRegistry(), getExtensionRegistry()) + .comparingExpectedFieldsOnly() + .isEqualTo(message); + expectThat(diffMessage) + .unpackingAnyUsing(getTypeRegistry(), getExtensionRegistry()) + .comparingExpectedFieldsOnly() + .isNotEqualTo(message); + } + + @Test + public void testInvalidAnyMessageComparingExpectedFieldsOnly() throws Exception { + + Message message = parse("o_any_message { type_url: 'invalid-type' value: 'abc123' }"); + Message eqMessage = parse("o_any_message { type_url: 'invalid-type' value: 'abc123' }"); + Message diffMessage = parse("o_any_message { type_url: 'invalid-type' value: 'def456' }"); + + expectThat(eqMessage) + .unpackingAnyUsing(getTypeRegistry(), getExtensionRegistry()) + .comparingExpectedFieldsOnly() + .isEqualTo(message); + expectThat(diffMessage) + .unpackingAnyUsing(getTypeRegistry(), getExtensionRegistry()) + .comparingExpectedFieldsOnly() + .isNotEqualTo(message); + } + + @Test + public void testDifferentAnyMessagesComparingExpectedFieldsOnly() throws Exception { + + // 'o_int' and 'o_float' have the same field numbers in both messages. However, to compare + // accurately, we incorporate the unpacked Descriptor type into the FieldNumberTree as well to + // disambiguate. + String typeUrl1 = + isProto3() + ? "type.googleapis.com/com.google.common.truth.extensions.proto.SubTestMessage3" + : "type.googleapis.com/com.google.common.truth.extensions.proto.SubTestMessage2"; + String typeUrl2 = + isProto3() + ? "type.googleapis.com/com.google.common.truth.extensions.proto.SubSubTestMessage3" + : "type.googleapis.com/com.google.common.truth.extensions.proto.SubSubTestMessage2"; + + Message message = + parse( + "r_any_message { [" + + typeUrl1 + + "]: { o_int: 2 } } r_any_message { [" + + typeUrl2 + + "]: { o_float: 3.1 } }"); + Message eqMessage = + parse( + "r_any_message { [" + + typeUrl1 + + "]: { o_int: 2 o_float: 1.9 } } r_any_message { [" + + typeUrl2 + + "]: { o_int: 5 o_float: 3.1 } }"); + Message diffMessage = + parse( + "r_any_message { [" + + typeUrl1 + + "]: { o_int: 5 o_float: 3.1 } } r_any_message { [" + + typeUrl2 + + "]: { o_int: 2 o_float: 1.9 } }"); + + expectThat(eqMessage) + .unpackingAnyUsing(getTypeRegistry(), getExtensionRegistry()) + .ignoringRepeatedFieldOrder() + .comparingExpectedFieldsOnly() + .isEqualTo(message); + expectThat(diffMessage) + .unpackingAnyUsing(getTypeRegistry(), getExtensionRegistry()) + .ignoringRepeatedFieldOrder() + .comparingExpectedFieldsOnly() + .isNotEqualTo(message); + } + @Test public void testIgnoringAllButOneFieldOfSubMessage() { // Consider all of TestMessage, but none of o_sub_test_message, except