From 03ad2ffa1ba9fc7f8d2889b51c3819def37e4cae Mon Sep 17 00:00:00 2001 From: Oleksandr Ferludin Date: Wed, 25 Oct 2023 05:20:16 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 576497631 --- tensorflow_gnn/__init__.py | 5 +- tensorflow_gnn/graph/batching_utils_test.py | 2 +- tensorflow_gnn/graph/graph_constants.py | 16 ++++ tensorflow_gnn/graph/graph_tensor.py | 80 ++++++++++++++++++- tensorflow_gnn/graph/graph_tensor_io_test.py | 18 ++--- tensorflow_gnn/graph/graph_tensor_test.py | 76 ++++++++++++++++-- tensorflow_gnn/graph/readout_test.py | 8 +- .../runner/utils/attribution_test.py | 10 +-- 8 files changed, 187 insertions(+), 28 deletions(-) diff --git a/tensorflow_gnn/__init__.py b/tensorflow_gnn/__init__.py index 3421023c..155d34e7 100644 --- a/tensorflow_gnn/__init__.py +++ b/tensorflow_gnn/__init__.py @@ -213,6 +213,10 @@ is_ragged_tensor = tensor_utils.is_ragged_tensor is_graph_tensor = graph_tensor_ops.is_graph_tensor +# Global state flags that controls GraphPieces checks. +enable_graph_tensor_inputs_validation = graph_constants.enable_graph_tensor_inputs_validation +disable_graph_tensor_inputs_validation = graph_constants.disable_graph_tensor_inputs_validation + # Prune imported module symbols so they're not accessible implicitly, # except those meant to be used as subpackages, like tfgnn.keras.*. # Please use the same order as for the import statements at the top. @@ -237,4 +241,3 @@ del tag_utils del tensor_utils del graph_schema - diff --git a/tensorflow_gnn/graph/batching_utils_test.py b/tensorflow_gnn/graph/batching_utils_test.py index 815edd58..9b512614 100644 --- a/tensorflow_gnn/graph/batching_utils_test.py +++ b/tensorflow_gnn/graph/batching_utils_test.py @@ -42,7 +42,7 @@ class DynamicBatchTest(tu.GraphTensorTestBase): 'f1': as_tensor([1.]), 'f2': as_tensor([[1., 2.]]), 'i3': as_tensor([[[1, 2], [3, 4]]]), - 'r1': as_ragged([[], ['a', 'b']]), + 'r1': as_ragged([[[], ['a', 'b']]]), }), ]) def testFeaturesBatching(self, target_num_components: int, diff --git a/tensorflow_gnn/graph/graph_constants.py b/tensorflow_gnn/graph/graph_constants.py index e6994de3..a23d7d8f 100644 --- a/tensorflow_gnn/graph/graph_constants.py +++ b/tensorflow_gnn/graph/graph_constants.py @@ -98,6 +98,10 @@ # the wider type. allow_indices_auto_casting = True +# If set, validates `GraphTensor` and its pieces and raises exception on an +# attempt to construct invalid graph tensors. +validate_graph_tensor_inputs = True + # The default choice for `indices_dtype`. # Can be either tf.int32 or tf.int64. # @@ -118,3 +122,15 @@ # An older name used before tensorflow_gnn 0.2. DEFAULT_STATE_NAME = HIDDEN_STATE + + +def disable_graph_tensor_inputs_validation(): + """Disables validation of `GraphTensor` inputs.""" + global validate_graph_tensor_inputs + validate_graph_tensor_inputs = False + + +def enable_graph_tensor_inputs_validation(): + """Enables validation for `GraphTensor` inputs.""" + global validate_graph_tensor_inputs + validate_graph_tensor_inputs = True diff --git a/tensorflow_gnn/graph/graph_tensor.py b/tensorflow_gnn/graph/graph_tensor.py index e3913d2d..ecb38b7f 100644 --- a/tensorflow_gnn/graph/graph_tensor.py +++ b/tensorflow_gnn/graph/graph_tensor.py @@ -104,7 +104,7 @@ def num_components(self) -> tf.Tensor: tf.debugging.assert_equal( tf.size(result), tf.size(result_dense), - message='`sizes` shape is not compatible with the piece shape')) + message='`sizes` shape is incompatible with the piece shape')) with tf.control_dependencies(check_ops): result = tf.identity(result_dense) @@ -163,13 +163,22 @@ def _from_features_and_sizes(cls, features: Fields, sizes: Field, raise NotImplementedError # Note that this graph piece does not use any metadata fields. - return cls._from_data( + result = cls._from_data( data=data, shape=sizes.shape[:-1], indices_dtype=indices_dtype, row_splits_dtype=row_splits_dtype, ) + if const.validate_graph_tensor_inputs: + # NOTE: The batch dimensions are already validated by the + # `GraphPieceBase._from_data()`. At this point we are checking only + # invariants specific to the `_GraphPieceWithFeatures`. + _static_check_sizes(result.sizes, result.shape.rank) + _static_check_items_dim(result.features, result.shape.rank) + + return result + def get_features_dict(self) -> Dict[FieldName, Field]: """Returns features copy as a dictionary.""" return dict(self._get_features_ref) @@ -231,6 +240,7 @@ def _from_feature_and_size_specs( gp.check_indices_dtype(sizes_spec.dtype, what='`sizes_spec`') features_spec = features_spec.copy() + data_spec = { _NodeOrEdgeSet._DATAKEY_FEATURES: features_spec, _NodeOrEdgeSet._DATAKEY_SIZES: sizes_spec, @@ -251,13 +261,19 @@ def _from_feature_and_size_specs( raise NotImplementedError # Note that this graph piece does not use any metadata fields. - return cls._from_data_spec( + result = cls._from_data_spec( data_spec, shape=sizes_spec.shape[:-1], indices_dtype=indices_dtype, row_splits_dtype=row_splits_dtype, ) + if const.validate_graph_tensor_inputs: + _static_check_sizes(result.sizes_spec, result.rank) + _static_check_items_dim(result.features_spec, result.rank) + + return result + @classmethod def _data_spec_with_indices_dtype( cls, data_spec: gp.DataSpec, dtype: tf.dtypes.DType @@ -1792,3 +1808,61 @@ def check_homogeneous_graph_tensor( """Raises ValueError when tfgnn.get_homogeneous_node_and_edge_set_name() does. """ _ = get_homogeneous_node_and_edge_set_name(graph, name=name) + + +def _static_check_sizes( + sizes: Union[Field, FieldSpec], graph_rank: int +) -> None: + """Checks graph component sizes rank. + + Args: + sizes: graph piece sizes, as `[*graph_shape, num_components]`. + graph_rank: The number of batch dimensions, as `graph_shape.rank`. + + Raises: + ValueError: if `sizes.shape.rank != graph_rank + 1`. + """ + expected_rank = graph_rank + 1 + if sizes.shape.rank != expected_rank: + raise ValueError( + f'`sizes` must be of rank {expected_rank} (number of batch dimensions' + f' plus 1), got {sizes.shape.rank}.' + ) + + +def _static_check_items_dim( + features: Union[Fields, FieldsSpec], graph_rank: int +) -> None: + """Checks items dimension in features shape. + + NOTE: here we check the subset of graph tensor shape rules and allow a mix of + fully defined and undefined item dimensions, as long as all fully defined + dimensions have the same sizes. The reason for this is that unknown dimensions + can originate from the imperfect static shape inference in Tensorflow or from + ragged tensors that have uniform inner dimensions but were constructed using + ragged row partitions. We delegate pedantic validation of the statically + unknown dimensions to the dynamic checks. + + Args: + features: graph pieces features or their type specs. Must have shapes + `[*graph_shape, num_items, *feature_shape]`. + graph_rank: the number of batch dimensions as `graph_shape.rank`. + + Raises: + ValueError: if some fully defined item dimensions do not match. + """ + + if not features: + return + + num_items = { + fvalue.shape[graph_rank]: fname + for fname, fvalue in features.items() + if fvalue.shape[graph_rank] is not None + } + if len(num_items) >= 2: + (da, fa), (db, fb) = list(num_items.items())[:2] + raise ValueError( + f'Features "{fa}" and "{fb}" have shapes with' + f' incompatible items dimension (dim={graph_rank}): {da} != {db}.' + ) diff --git a/tensorflow_gnn/graph/graph_tensor_io_test.py b/tensorflow_gnn/graph/graph_tensor_io_test.py index eef32c1b..a8f88c34 100644 --- a/tensorflow_gnn/graph/graph_tensor_io_test.py +++ b/tensorflow_gnn/graph/graph_tensor_io_test.py @@ -68,29 +68,29 @@ class TfExampleParsingFromSpecTest(TfExampleParsingTestBase): context_spec=gt.ContextSpec.from_field_specs(features_spec={ 'v': tf.TensorSpec(shape=(2,), dtype=tf.int16), 'm': tf.TensorSpec(shape=(2, 3), dtype=tf.int32), - 't': tf.TensorSpec(shape=(1, 1, 2), dtype=tf.int64), + 't': tf.TensorSpec(shape=(2, 1, 2), dtype=tf.int64), })), examples=[ r""" features { feature {key: "context/v" value {int64_list {value: [1, 2]} } } feature {key: "context/m" value {int64_list {value: [1, 2, 3, 4, 5, 6]} } } - feature {key: "context/t" value {int64_list {value: [1, 2] } } } + feature {key: "context/t" value {int64_list {value: [1, 2, 3, 4] } } } }""", r""" features { feature {key: "context/v" value {int64_list {value: [9, 8]} } } feature {key: "context/m" value {int64_list {value: [9, 8, 7, 6, 5, 4]} } } - feature {key: "context/t" value {int64_list {value: [9, 8]} } } + feature {key: "context/t" value {int64_list {value: [9, 8, 7, 6]} } } }""" ], expected_values=[{ 'context/v': as_tensor([1, 2]), 'context/m': as_tensor([[1, 2, 3], [4, 5, 6]]), - 'context/t': as_tensor([[[1, 2]]]) + 'context/t': as_tensor([[[1, 2]], [[3, 4]]]) }, { 'context/v': as_tensor([9, 8]), 'context/m': as_tensor([[9, 8, 7], [6, 5, 4]]), - 'context/t': as_tensor([[[9, 8]]]) + 'context/t': as_tensor([[[9, 8]], [[7, 6]]]) }]), dict( description='context ragged features parsing', @@ -211,26 +211,26 @@ def testSingleExampleParsing( context_spec=gt.ContextSpec.from_field_specs(features_spec={ 'v': tf.TensorSpec(shape=(2,), dtype=tf.int16), 'm': tf.TensorSpec(shape=(2, 3), dtype=tf.int32), - 't': tf.TensorSpec(shape=(1, 1, 2), dtype=tf.int64), + 't': tf.TensorSpec(shape=(2, 1, 2), dtype=tf.int64), })), examples=[ r""" features { feature {key: "context/v" value {int64_list {value: [1, 2]} } } feature {key: "context/m" value {int64_list {value: [1, 2, 3, 4, 5, 6]} } } - feature {key: "context/t" value {int64_list {value: [1, 2] } } } + feature {key: "context/t" value {int64_list {value: [1, 2, 3, 4] } } } }""", r""" features { feature {key: "context/v" value {int64_list {value: [9, 8]} } } feature {key: "context/m" value {int64_list {value: [9, 8, 7, 6, 5, 4]} } } - feature {key: "context/t" value {int64_list {value: [9, 8]} } } + feature {key: "context/t" value {int64_list {value: [9, 8, 7, 6]} } } }""" ], expected={ 'context/v': as_tensor([[1, 2], [9, 8]]), 'context/m': as_tensor([[[1, 2, 3], [4, 5, 6]], [[9, 8, 7], [6, 5, 4]]]), - 'context/t': as_tensor([[[[1, 2]]], [[[9, 8]]]]) + 'context/t': as_tensor([[[[1, 2]], [[3, 4]]], [[[9, 8]], [[7, 6]]]]) }, prefix=None, validate=True) diff --git a/tensorflow_gnn/graph/graph_tensor_test.py b/tensorflow_gnn/graph/graph_tensor_test.py index 9dd0c5a6..0bc8ebac 100644 --- a/tensorflow_gnn/graph/graph_tensor_test.py +++ b/tensorflow_gnn/graph/graph_tensor_test.py @@ -70,6 +70,23 @@ def testContext(self, features, shape): self.assertAllEqual(context.spec['a'], type_spec.type_spec_from_value(features['a'])) + def testValidationOnOff(self): + invalid_features = {'a': [1, 2], 'b': [1, 2, 3]} + const.disable_graph_tensor_inputs_validation() + _ = gt.Context.from_fields(features=invalid_features) + + const.enable_graph_tensor_inputs_validation() + with self.assertRaises(ValueError): + gt.Context.from_fields(features=invalid_features) + + def testRaisesOnInvalidContextInit(self): + with self.assertRaisesWithLiteralMatch( + ValueError, + 'Features "a" and "b" have shapes with incompatible items dimension' + ' (dim=0): 2 != 3.', + ): + gt.Context.from_fields(features={'a': [1, 2], 'b': [1, 2, 3]}) + def testCreationChain(self): source = gt.Context.from_fields(features={'x': as_tensor([1.])}) copy1 = gt.Context.from_fields(features=source.features) @@ -122,6 +139,30 @@ def testNodeSet(self, features, sizes, expected_shape): self.assertAllEqual(node_set.spec['a'], type_spec.type_spec_from_value(features['a'])) + def testRaisesOnInvalidNodeSetInit(self): + with self.assertRaisesWithLiteralMatch( + ValueError, + 'Field rank must be greater than the batch rank: field shape=(1,),' + ' batch_rank=1', + ): + gt.NodeSet.from_fields(sizes=[[3]], features={'a': [1]}) + + with self.assertRaisesWithLiteralMatch( + ValueError, + 'Fields batch dimensions do not match: batch_rank=1, 1st field shape:' + ' (1, 2), 2nd field shape: (2, 1)', + ): + gt.NodeSet.from_fields(sizes=[[3], [2]], features={'x': [[1, 2]]}) + + with self.assertRaisesWithLiteralMatch( + ValueError, + 'Features "a" and "b" have shapes with incompatible items' + ' dimension (dim=0): 1 != 2.', + ): + gt.NodeSet.from_fields( + sizes=[3], features={'a': [1], 'b': as_ragged([[1], [2, 3]])} + ) + @parameterized.parameters([ dict( features={}, @@ -159,6 +200,21 @@ def testEdgeSet(self, features, sizes, adjacency, expected_shape): self.assertAllEqual(edge_set.adjacency[const.SOURCE], adjacency[const.SOURCE]) + def testRaisesOnInvalidEdgeSetInit(self): + with self.assertRaisesWithLiteralMatch( + ValueError, + 'Field rank must be greater than the batch rank: field shape=(),' + ' batch_rank=0', + ): + gt.EdgeSet.from_fields( + sizes=3, + adjacency=adj.Adjacency.from_indices( + ('node', [0]), + ('node', [0]), + ), + features={}, + ) + def testEmptyGraphTensor(self): result = gt.GraphTensor.from_pieces() @@ -1010,12 +1066,15 @@ def testVarSizeBatching(self, row_splits_dtype: tf.DType): @tf.function def generate(num_nodes): + ones = tf.ones(tf.stack([num_nodes], 0), dtype=row_splits_dtype) + zeros = tf.convert_to_tensor([0], dtype=row_splits_dtype) + row_lengths = tf.concat([zeros, ones, zeros], axis=0) return gt.Context.from_fields( features={ 'x': tf.range(num_nodes), 'r': tf.RaggedTensor.from_row_lengths( tf.ones(tf.stack([num_nodes], 0), dtype=tf.float32), - tf.cast(tf.stack([0, num_nodes, 0], 0), row_splits_dtype), + row_lengths, ), } ) @@ -1038,10 +1097,17 @@ def generate(num_nodes): self.assertAllEqual( element['r'], as_ragged([ - [[[], [], []], [[], [1], []], [[], [1, 1], []]], - [[[], [1, 1, 1], []], [[], [1, 1, 1, 1], []], - [[], [1, 1, 1, 1, 1], []]], - ])) + [ + [[], []], + [[], [1.0], []], + [[], [1.0], [1.0], []]], + [ + [[], [1.0], [1.0], [1.0], []], + [[], [1.0], [1.0], [1.0], [1.0], []], + [[], [1.0], [1.0], [1.0], [1.0], [1.0], []], + ], + ]), + ) self.assertAllEqual( type_spec.type_spec_from_value(element['x']), diff --git a/tensorflow_gnn/graph/readout_test.py b/tensorflow_gnn/graph/readout_test.py index eb5ebafe..0f1e6012 100644 --- a/tensorflow_gnn/graph/readout_test.py +++ b/tensorflow_gnn/graph/readout_test.py @@ -262,7 +262,7 @@ def test(self, remove_input_feature): "zeros": tf.zeros([2])}), "unrelated": gt.NodeSet.from_fields( sizes=tf.constant([1]), - features={"labels": tf.constant([9, 9]), + features={"labels": tf.constant([9]), "stuff": tf.constant([[3.14, 2.71]])}), "_readout": gt.NodeSet.from_fields( sizes=tf.constant([2]), @@ -378,7 +378,7 @@ def test(self, remove_input_feature): "unrelated": gt.NodeSet.from_fields( sizes=tf.constant([1]), features={"labels": tf.constant([9, 9]), - "stuff": tf.constant([[3.14, 2.71]])})}) + "stuff": tf.constant([[3.14], [2.71]])})}) graph = readout.context_readout_into_feature( test_graph, feature_name="labels", new_feature_name="target", @@ -404,9 +404,9 @@ def testExistingReadout(self): features={"labels": tf.constant([11, 22, 33, 44, 55]), "ones": tf.ones([5, 1])}), "unrelated": gt.NodeSet.from_fields( - sizes=tf.constant([1]), + sizes=tf.constant([1, 1]), features={"labels": tf.constant([9, 9]), - "stuff": tf.constant([[3.14, 2.71]])})}) + "stuff": tf.constant([3.14, 2.71])})}) # Put a readout node set like a multi-task training pipeline with # context and seed node features would. test_graph = readout.add_readout_from_first_node(test_graph, "seed", diff --git a/tensorflow_gnn/runner/utils/attribution_test.py b/tensorflow_gnn/runner/utils/attribution_test.py index b34c05ab..cb4d33d8 100644 --- a/tensorflow_gnn/runner/utils/attribution_test.py +++ b/tensorflow_gnn/runner/utils/attribution_test.py @@ -28,7 +28,7 @@ class AttributionTest(tf.test.TestCase): context=tfgnn.Context.from_fields(features={ "h": tf.convert_to_tensor((.514, .433)), # An integer feature with uniform values. - "labels": tf.convert_to_tensor((0,)), + "labels": tf.convert_to_tensor((0, 1)), }), node_sets={ "node": @@ -58,7 +58,7 @@ def test_counterfactual_random(self): tf.convert_to_tensor((0.49280962, 0.466383))) self.assertAllEqual( counterfactual.context.features["labels"], - tf.convert_to_tensor((0,))) + tf.convert_to_tensor((1, 0))) self.assertAllEqual( counterfactual.edge_sets["edge"].features["weight"], @@ -76,7 +76,7 @@ def test_counterfactual_zeros(self): tf.convert_to_tensor((0, 0))) self.assertAllEqual( counterfactual.context.features["labels"], - tf.convert_to_tensor((0,))) + tf.convert_to_tensor((0, 0))) self.assertAllEqual( counterfactual.edge_sets["edge"].features["weight"], @@ -92,7 +92,7 @@ def test_subtract_graph_features(self): self.gt.replace_features( context={ "h": tf.convert_to_tensor((.4, .8)), - "labels": tf.convert_to_tensor((1,)) + "labels": tf.convert_to_tensor((1, 1)) }, node_sets={ "node": { @@ -110,7 +110,7 @@ def test_subtract_graph_features(self): tf.convert_to_tensor((.514 - .4, .433 - .8))) self.assertAllClose( deltas.context.features["labels"], - tf.convert_to_tensor((0 - 1,))) + tf.convert_to_tensor((0 - 1, 1 - 1))) self.assertAllClose( deltas.edge_sets["edge"].features["weight"],