Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No public description #640

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion tensorflow_gnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@
is_ragged_tensor = tensor_utils.is_ragged_tensor
is_graph_tensor = graph_tensor_ops.is_graph_tensor

# Global state flags that control GraphPieces checks.
enable_graph_tensor_inputs_validation = graph_constants.enable_graph_tensor_inputs_validation
disable_graph_tensor_inputs_validation = graph_constants.disable_graph_tensor_inputs_validation

# Prune imported module symbols so they're not accessible implicitly,
# except those meant to be used as subpackages, like tfgnn.keras.*.
# Please use the same order as for the import statements at the top.
Expand All @@ -237,4 +241,3 @@
del tag_utils
del tensor_utils
del graph_schema

2 changes: 1 addition & 1 deletion tensorflow_gnn/graph/batching_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class DynamicBatchTest(tu.GraphTensorTestBase):
'f1': as_tensor([1.]),
'f2': as_tensor([[1., 2.]]),
'i3': as_tensor([[[1, 2], [3, 4]]]),
'r1': as_ragged([[], ['a', 'b']]),
'r1': as_ragged([[[], ['a', 'b']]]),
}),
])
def testFeaturesBatching(self, target_num_components: int,
Expand Down
16 changes: 16 additions & 0 deletions tensorflow_gnn/graph/graph_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@
# the wider type.
allow_indices_auto_casting = True

# If set, validates `GraphTensor` and its pieces and raises an exception on an
# attempt to construct invalid graph tensors.
validate_graph_tensor_inputs = True

# The default choice for `indices_dtype`.
# Can be either tf.int32 or tf.int64.
#
Expand All @@ -118,3 +122,15 @@

# An older name used before tensorflow_gnn 0.2.
DEFAULT_STATE_NAME = HIDDEN_STATE


def _set_graph_tensor_inputs_validation(enabled: bool) -> None:
  """Assigns `enabled` to the `validate_graph_tensor_inputs` module flag."""
  global validate_graph_tensor_inputs
  validate_graph_tensor_inputs = enabled


def disable_graph_tensor_inputs_validation():
  """Switches off validation of `GraphTensor` inputs.

  Sets the module-level `validate_graph_tensor_inputs` flag to `False`.
  """
  _set_graph_tensor_inputs_validation(False)


def enable_graph_tensor_inputs_validation():
  """Switches on validation of `GraphTensor` inputs.

  Sets the module-level `validate_graph_tensor_inputs` flag to `True`.
  """
  _set_graph_tensor_inputs_validation(True)
80 changes: 77 additions & 3 deletions tensorflow_gnn/graph/graph_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def num_components(self) -> tf.Tensor:
tf.debugging.assert_equal(
tf.size(result),
tf.size(result_dense),
message='`sizes` shape is not compatible with the piece shape'))
message='`sizes` shape is incompatible with the piece shape'))

with tf.control_dependencies(check_ops):
result = tf.identity(result_dense)
Expand Down Expand Up @@ -163,13 +163,22 @@ def _from_features_and_sizes(cls, features: Fields, sizes: Field,
raise NotImplementedError

# Note that this graph piece does not use any metadata fields.
return cls._from_data(
result = cls._from_data(
data=data,
shape=sizes.shape[:-1],
indices_dtype=indices_dtype,
row_splits_dtype=row_splits_dtype,
)

if const.validate_graph_tensor_inputs:
# NOTE: The batch dimensions are already validated by the
# `GraphPieceBase._from_data()`. At this point we are checking only
# invariants specific to the `_GraphPieceWithFeatures`.
_static_check_sizes(result.sizes, result.shape.rank)
_static_check_items_dim(result.features, result.shape.rank)

return result

def get_features_dict(self) -> Dict[FieldName, Field]:
"""Returns features copy as a dictionary."""
return dict(self._get_features_ref)
Expand Down Expand Up @@ -231,6 +240,7 @@ def _from_feature_and_size_specs(
gp.check_indices_dtype(sizes_spec.dtype, what='`sizes_spec`')

features_spec = features_spec.copy()

data_spec = {
_NodeOrEdgeSet._DATAKEY_FEATURES: features_spec,
_NodeOrEdgeSet._DATAKEY_SIZES: sizes_spec,
Expand All @@ -251,13 +261,19 @@ def _from_feature_and_size_specs(
raise NotImplementedError

# Note that this graph piece does not use any metadata fields.
return cls._from_data_spec(
result = cls._from_data_spec(
data_spec,
shape=sizes_spec.shape[:-1],
indices_dtype=indices_dtype,
row_splits_dtype=row_splits_dtype,
)

if const.validate_graph_tensor_inputs:
_static_check_sizes(result.sizes_spec, result.rank)
_static_check_items_dim(result.features_spec, result.rank)

return result

@classmethod
def _data_spec_with_indices_dtype(
cls, data_spec: gp.DataSpec, dtype: tf.dtypes.DType
Expand Down Expand Up @@ -1792,3 +1808,61 @@ def check_homogeneous_graph_tensor(
"""Raises ValueError when tfgnn.get_homogeneous_node_and_edge_set_name() does.
"""
_ = get_homogeneous_node_and_edge_set_name(graph, name=name)


def _static_check_sizes(
    sizes: Union[Field, FieldSpec], graph_rank: int
) -> None:
  """Verifies the static rank of a graph piece's `sizes`.

  Args:
    sizes: graph piece sizes (or their spec), shaped
      `[*graph_shape, num_components]`.
    graph_rank: the number of batch dimensions, i.e. `graph_shape.rank`.

  Raises:
    ValueError: if the rank of `sizes` differs from `graph_rank + 1`.
  """
  actual_rank = sizes.shape.rank
  expected_rank = graph_rank + 1
  if actual_rank == expected_rank:
    return

  raise ValueError(
      f'`sizes` must be of rank {expected_rank} (number of batch dimensions'
      f' plus 1), got {actual_rank}.'
  )


def _static_check_items_dim(
    features: Union[Fields, FieldsSpec], graph_rank: int
) -> None:
  """Verifies that statically known items dimensions of features agree.

  NOTE: only a subset of the graph tensor shape rules is enforced here.
  Features may mix statically known and unknown items dimensions; the check
  fails only when two known dimensions differ. Unknown dimensions can come
  from imperfect static shape inference in Tensorflow, or from ragged tensors
  whose inner dimensions are uniform but which were built from ragged row
  partitions. Strict validation of statically unknown dimensions is left to
  the dynamic checks.

  Args:
    features: graph piece features or their type specs, each shaped
      `[*graph_shape, num_items, *feature_shape]`.
    graph_rank: the number of batch dimensions, i.e. `graph_shape.rank`.

  Raises:
    ValueError: if two statically known items dimensions do not match.
  """
  if not features:
    return

  # Maps every statically known items dimension to the name of the last
  # feature seen with it. Features with an unknown items dimension are skipped.
  name_by_dim = {}
  for name, value in features.items():
    dim = value.shape[graph_rank]
    if dim is not None:
      name_by_dim[dim] = name

  if len(name_by_dim) < 2:
    return

  # Report the first two distinct dimensions, in order of first appearance.
  conflicts = iter(name_by_dim.items())
  dim_a, name_a = next(conflicts)
  dim_b, name_b = next(conflicts)
  raise ValueError(
      f'Features "{name_a}" and "{name_b}" have shapes with'
      f' incompatible items dimension (dim={graph_rank}): {dim_a} != {dim_b}.'
  )
18 changes: 9 additions & 9 deletions tensorflow_gnn/graph/graph_tensor_io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,29 +68,29 @@ class TfExampleParsingFromSpecTest(TfExampleParsingTestBase):
context_spec=gt.ContextSpec.from_field_specs(features_spec={
'v': tf.TensorSpec(shape=(2,), dtype=tf.int16),
'm': tf.TensorSpec(shape=(2, 3), dtype=tf.int32),
't': tf.TensorSpec(shape=(1, 1, 2), dtype=tf.int64),
't': tf.TensorSpec(shape=(2, 1, 2), dtype=tf.int64),
})),
examples=[
r"""
features {
feature {key: "context/v" value {int64_list {value: [1, 2]} } }
feature {key: "context/m" value {int64_list {value: [1, 2, 3, 4, 5, 6]} } }
feature {key: "context/t" value {int64_list {value: [1, 2] } } }
feature {key: "context/t" value {int64_list {value: [1, 2, 3, 4] } } }
}""", r"""
features {
feature {key: "context/v" value {int64_list {value: [9, 8]} } }
feature {key: "context/m" value {int64_list {value: [9, 8, 7, 6, 5, 4]} } }
feature {key: "context/t" value {int64_list {value: [9, 8]} } }
feature {key: "context/t" value {int64_list {value: [9, 8, 7, 6]} } }
}"""
],
expected_values=[{
'context/v': as_tensor([1, 2]),
'context/m': as_tensor([[1, 2, 3], [4, 5, 6]]),
'context/t': as_tensor([[[1, 2]]])
'context/t': as_tensor([[[1, 2]], [[3, 4]]])
}, {
'context/v': as_tensor([9, 8]),
'context/m': as_tensor([[9, 8, 7], [6, 5, 4]]),
'context/t': as_tensor([[[9, 8]]])
'context/t': as_tensor([[[9, 8]], [[7, 6]]])
}]),
dict(
description='context ragged features parsing',
Expand Down Expand Up @@ -211,26 +211,26 @@ def testSingleExampleParsing(
context_spec=gt.ContextSpec.from_field_specs(features_spec={
'v': tf.TensorSpec(shape=(2,), dtype=tf.int16),
'm': tf.TensorSpec(shape=(2, 3), dtype=tf.int32),
't': tf.TensorSpec(shape=(1, 1, 2), dtype=tf.int64),
't': tf.TensorSpec(shape=(2, 1, 2), dtype=tf.int64),
})),
examples=[
r"""
features {
feature {key: "context/v" value {int64_list {value: [1, 2]} } }
feature {key: "context/m" value {int64_list {value: [1, 2, 3, 4, 5, 6]} } }
feature {key: "context/t" value {int64_list {value: [1, 2] } } }
feature {key: "context/t" value {int64_list {value: [1, 2, 3, 4] } } }
}""", r"""
features {
feature {key: "context/v" value {int64_list {value: [9, 8]} } }
feature {key: "context/m" value {int64_list {value: [9, 8, 7, 6, 5, 4]} } }
feature {key: "context/t" value {int64_list {value: [9, 8]} } }
feature {key: "context/t" value {int64_list {value: [9, 8, 7, 6]} } }
}"""
],
expected={
'context/v': as_tensor([[1, 2], [9, 8]]),
'context/m': as_tensor([[[1, 2, 3], [4, 5, 6]],
[[9, 8, 7], [6, 5, 4]]]),
'context/t': as_tensor([[[[1, 2]]], [[[9, 8]]]])
'context/t': as_tensor([[[[1, 2]], [[3, 4]]], [[[9, 8]], [[7, 6]]]])
},
prefix=None,
validate=True)
Expand Down
76 changes: 71 additions & 5 deletions tensorflow_gnn/graph/graph_tensor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,23 @@ def testContext(self, features, shape):
self.assertAllEqual(context.spec['a'],
type_spec.type_spec_from_value(features['a']))

def testValidationOnOff(self):
  """Checks that the global switch turns graph tensor input checks on and off.

  Features "a" and "b" disagree on the items dimension (2 vs 3). Building a
  context from them must succeed while validation is disabled and must raise
  `ValueError` once validation is enabled again.
  """
  invalid_features = {'a': [1, 2], 'b': [1, 2, 3]}

  # The switch is process-wide state. Register the restore step before
  # disabling, so validation is re-enabled even if a step below fails and a
  # failure here cannot leak disabled validation into other tests.
  self.addCleanup(const.enable_graph_tensor_inputs_validation)

  const.disable_graph_tensor_inputs_validation()
  _ = gt.Context.from_fields(features=invalid_features)

  const.enable_graph_tensor_inputs_validation()
  with self.assertRaises(ValueError):
    gt.Context.from_fields(features=invalid_features)

def testRaisesOnInvalidContextInit(self):
  """Context creation must reject features whose items dimensions differ."""
  expected_message = (
      'Features "a" and "b" have shapes with incompatible items dimension'
      ' (dim=0): 2 != 3.'
  )
  with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
    gt.Context.from_fields(features={'a': [1, 2], 'b': [1, 2, 3]})

def testCreationChain(self):
source = gt.Context.from_fields(features={'x': as_tensor([1.])})
copy1 = gt.Context.from_fields(features=source.features)
Expand Down Expand Up @@ -122,6 +139,30 @@ def testNodeSet(self, features, sizes, expected_shape):
self.assertAllEqual(node_set.spec['a'],
type_spec.type_spec_from_value(features['a']))

def testRaisesOnInvalidNodeSetInit(self):
  """NodeSet creation must reject inconsistent `sizes` and features."""
  # Case 1: the feature rank does not exceed the batch rank.
  rank_message = (
      'Field rank must be greater than the batch rank: field shape=(1,),'
      ' batch_rank=1'
  )
  with self.assertRaisesWithLiteralMatch(ValueError, rank_message):
    gt.NodeSet.from_fields(sizes=[[3]], features={'a': [1]})

  # Case 2: the feature and `sizes` disagree on the batch dimension.
  batch_message = (
      'Fields batch dimensions do not match: batch_rank=1, 1st field shape:'
      ' (1, 2), 2nd field shape: (2, 1)'
  )
  with self.assertRaisesWithLiteralMatch(ValueError, batch_message):
    gt.NodeSet.from_fields(sizes=[[3], [2]], features={'x': [[1, 2]]})

  # Case 3: two features disagree on the items dimension.
  items_message = (
      'Features "a" and "b" have shapes with incompatible items'
      ' dimension (dim=0): 1 != 2.'
  )
  with self.assertRaisesWithLiteralMatch(ValueError, items_message):
    gt.NodeSet.from_fields(
        sizes=[3], features={'a': [1], 'b': as_ragged([[1], [2, 3]])}
    )

@parameterized.parameters([
dict(
features={},
Expand Down Expand Up @@ -159,6 +200,21 @@ def testEdgeSet(self, features, sizes, adjacency, expected_shape):
self.assertAllEqual(edge_set.adjacency[const.SOURCE],
adjacency[const.SOURCE])

def testRaisesOnInvalidEdgeSetInit(self):
  """EdgeSet creation must reject a scalar `sizes` argument."""
  expected_message = (
      'Field rank must be greater than the batch rank: field shape=(),'
      ' batch_rank=0'
  )
  with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
    gt.EdgeSet.from_fields(
        sizes=3,
        adjacency=adj.Adjacency.from_indices(('node', [0]), ('node', [0])),
        features={},
    )

def testEmptyGraphTensor(self):
result = gt.GraphTensor.from_pieces()

Expand Down Expand Up @@ -1010,12 +1066,15 @@ def testVarSizeBatching(self, row_splits_dtype: tf.DType):

@tf.function
def generate(num_nodes):
ones = tf.ones(tf.stack([num_nodes], 0), dtype=row_splits_dtype)
zeros = tf.convert_to_tensor([0], dtype=row_splits_dtype)
row_lengths = tf.concat([zeros, ones, zeros], axis=0)
return gt.Context.from_fields(
features={
'x': tf.range(num_nodes),
'r': tf.RaggedTensor.from_row_lengths(
tf.ones(tf.stack([num_nodes], 0), dtype=tf.float32),
tf.cast(tf.stack([0, num_nodes, 0], 0), row_splits_dtype),
row_lengths,
),
}
)
Expand All @@ -1038,10 +1097,17 @@ def generate(num_nodes):
self.assertAllEqual(
element['r'],
as_ragged([
[[[], [], []], [[], [1], []], [[], [1, 1], []]],
[[[], [1, 1, 1], []], [[], [1, 1, 1, 1], []],
[[], [1, 1, 1, 1, 1], []]],
]))
[
[[], []],
[[], [1.0], []],
[[], [1.0], [1.0], []]],
[
[[], [1.0], [1.0], [1.0], []],
[[], [1.0], [1.0], [1.0], [1.0], []],
[[], [1.0], [1.0], [1.0], [1.0], [1.0], []],
],
]),
)

self.assertAllEqual(
type_spec.type_spec_from_value(element['x']),
Expand Down
8 changes: 4 additions & 4 deletions tensorflow_gnn/graph/readout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def test(self, remove_input_feature):
"zeros": tf.zeros([2])}),
"unrelated": gt.NodeSet.from_fields(
sizes=tf.constant([1]),
features={"labels": tf.constant([9, 9]),
features={"labels": tf.constant([9]),
"stuff": tf.constant([[3.14, 2.71]])}),
"_readout": gt.NodeSet.from_fields(
sizes=tf.constant([2]),
Expand Down Expand Up @@ -378,7 +378,7 @@ def test(self, remove_input_feature):
"unrelated": gt.NodeSet.from_fields(
sizes=tf.constant([1]),
features={"labels": tf.constant([9, 9]),
"stuff": tf.constant([[3.14, 2.71]])})})
"stuff": tf.constant([[3.14], [2.71]])})})

graph = readout.context_readout_into_feature(
test_graph, feature_name="labels", new_feature_name="target",
Expand All @@ -404,9 +404,9 @@ def testExistingReadout(self):
features={"labels": tf.constant([11, 22, 33, 44, 55]),
"ones": tf.ones([5, 1])}),
"unrelated": gt.NodeSet.from_fields(
sizes=tf.constant([1]),
sizes=tf.constant([1, 1]),
features={"labels": tf.constant([9, 9]),
"stuff": tf.constant([[3.14, 2.71]])})})
"stuff": tf.constant([3.14, 2.71])})})
# Put a readout node set like a multi-task training pipeline with
# context and seed node features would.
test_graph = readout.add_readout_from_first_node(test_graph, "seed",
Expand Down
Loading