Support "valid_mask" for sampled edges. #118

Open · wants to merge 1 commit into base: main
178 changes: 127 additions & 51 deletions examples/in_memory/int_arithmetic_sampler.py
@@ -16,7 +16,7 @@

The entry point is method `make_sampled_subgraphs_dataset()`, which accepts as
input an in-memory graph dataset (from dataset.py) and a `SamplingSpec`, and
outputs tf.data.Dataset that generates subgraphs according to `SamplingSpec`.
returns tf.data.Dataset that generates subgraphs according to `SamplingSpec`.

Specifically, the `tf.data.Dataset` made by `make_sampled_subgraphs_dataset`
wraps a generator that yields `GraphTensor` instances, consisting of subgraphs
rooted at
@@ -32,9 +32,10 @@
inmem_ds = datasets.get_dataset(dataset_name)

# Craft sampling specification.
sample_size = sample_size1 = 5
graph_schema = inmem_ds.export_graph_schema()
sampling_spec = (tfgnn.SamplingSpecBuilder(graph_schema)
.seed().sample([3, 3]).to_sampling_spec())
.seed().sample([sample_size, sample_size1]).to_sampling_spec())

train_data = make_sampled_subgraphs_dataset(inmem_ds, sampling_spec)
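# A hypothetical way to peek at one yielded element (assuming, as described
# above, that elements are `tfgnn.GraphTensor` instances):
#
#   for graph_tensor in train_data.take(1):
#     print(graph_tensor.node_sets['paper'].sizes)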

@@ -46,20 +47,6 @@
# composed of `tfgnn.keras.layers`.
```

# Note

This particular sampler expects that there are *no orphan nodes*. In
particular, if the sampling specification samples from an edge-set named "E",
then every node must have *at least one* outgoing edge in edge-set "E". This
"feature" could be fixed, e.g., by allowing zero-degree nodes to jump to a
special node that gets filtered upon output. However, we delay such
completeness until we compare other sampling implementations, e.g., ones that
use RaggedTensors to naturally accommodate variable-length neighborhoods.

Nonetheless, if each node has at least one edge, then sampling will be correct.
If some node has fewer neighbors than the requested number of samples, then the
selection will contain repetitions.

# Algorithm & Implementation

`make_sampled_subgraphs_dataset(ds)` returns a generator over object
@@ -99,11 +86,11 @@ class exposes function `random_walk_tree`, which describe below.

Both of which are stored as tf.Tensor.

After initialization, function `random_walk_tree` accepts(*) seed nodes
After initialization, function `random_walk_tree` accepts seed nodes
`[n1, n2, n3, ..., nB]`, i.e. with batch size `B`.


NOTE: (*) generator `make_sampled_subgraphs_dataset` yields `GraphTensor`
NOTE: generator `make_sampled_subgraphs_dataset` yields `GraphTensor`
instances; each instance contains subgraphs rooted at a batch of nodes, and the
batches cycle through `ds.node_split().train`.

@@ -112,20 +99,31 @@ class exposes function `random_walk_tree`, which describe below.
```
                 sample(f1, 'cites')
     paper --------------------------> paper
 V1    \                                V2
        \ sample(f2, 'rev_writes')              sample(f3, 'affiliated_with')
         ---------------------------> author ------------------> institution
                                       V3                        V4
```

Instance nodes of `TypedWalkTree` (above) have attribute `nodes` with shapes:
(B), (B, f1), (B, f2), (B, f2, f3) -- (left-to-right). All are `tf.Tensor`s
with dtype `tf.int{32, 64}`, matching the dtype of its input argument.
Instance nodes of `TypedWalkTree` (above) have attribute `nodes`, which is a
`tf.Tensor`, depicted as V1, V2, V3, V4, with respective shapes (B), (B, f1),
(B, f2), (B, f2, f3). All have dtype `tf.int{32, 64}`, matching the dtype of
the input argument to function `random_walk_tree`. For a node position `i`,
node `V1[i]` has sampled edges pointing to nodes `V2[i, 0], V2[i, 1], ...`.
The int `B` is the batch size, and the ints `f1, f2, ...` correspond to the
`sample_size` values configured in the `SamplingSpec` proto (below).

Further, if the `sampling` strategy is `EdgeSampling.W_REPLACEMENT_W_ORPHANS`
or `EdgeSampling.WO_REPLACEMENT_WO_REPEAT`, then each `TypedWalkTree` node will
also contain attribute `valid_mask` (tf.Tensor with dtype tf.bool) with the
same shape as `nodes`, marking the positions in `nodes` that correspond to
valid edges.
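For intuition, here is a small sketch with made-up values (not from a real
run). It assumes `sampler` is a `GraphSampler` instance and `sampling_spec`
samples a single hop of size 3, with a seed batch of two nodes: the first seed
has only two distinct neighbors, and the second seed is an orphan:

```
tree = sampler.random_walk_tree(
    tf.constant([5, 9]), sampling_spec,
    sampling=EdgeSampling.WO_REPLACEMENT_WO_REPEAT)
edge_set_name, child = tree.next_steps[0]
child.nodes       # e.g. [[17, 23, 4], [4, 4, 4]], shape (2, 3).
child.valid_mask  # [[True, True, False], [False, False, False]].
```

Entries of `nodes` at positions where `valid_mask` is False are arbitrary
placeholders and should only be consumed together with the mask.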


## Building SamplingSpec
Function `random_walk_tree` also requires argument `sampling_spec`, which
controls the size of the subgraph sampled around the seed nodes. For the above
example, a `sampling_spec` instance can be built, e.g., as:


```
f2 = f1 = 5
f3 = 3 # somewhat arbitrary.
@@ -143,7 +141,7 @@
import collections
import enum
import functools
from typing import Any, Tuple, Callable, Mapping, Optional, MutableMapping, List
from typing import Any, Tuple, Callable, Mapping, Optional, MutableMapping, List, Union

import numpy as np
import scipy.sparse as ssp
@@ -225,21 +223,45 @@ class TypedWalkTree:
`TypedWalkTree`) with node features & labels, into `GraphTensor` instances.
"""

def __init__(self, nodes, owner=None):
def __init__(self, nodes: tf.Tensor, owner: Optional['GraphSampler'] = None,
valid_mask: Optional[tf.Tensor] = None):
self._nodes = nodes
self._next_steps = []
self._owner = owner
if valid_mask is None:
self._valid_mask = tf.ones(shape=nodes.shape, dtype=tf.bool)
else:
self._valid_mask = valid_mask

@property
def nodes(self) -> tf.Tensor:
"""int tf.Tensor with shape `[b, s1, s2, ..., sH]` where `b` is batch size.

`H` is the number of hops (up to this sampling step). Each int `si` is the
number of nodes sampled at step `i`.
"""
return self._nodes

@property
def valid_mask(self) -> tf.Tensor:
"""bool tf.Tensor with the same shape as `nodes`, marking "correct" samples.

If entry `valid_mask[i, j, k]` is True, then `nodes[i, j, k]` corresponds to
a node that is indeed a sampled neighbor of `previous_step.nodes[i, j]`.
"""
return self._valid_mask

@property
def next_steps(self) -> List[Tuple[tfgnn.EdgeSetName, 'TypedWalkTree']]:
return self._next_steps

def add_step(self, edge_set_name: tfgnn.EdgeSetName, nodes: tf.Tensor):
child_tree = TypedWalkTree(nodes, owner=self._owner)
def add_step(self, edge_set_name: tfgnn.EdgeSetName, nodes: tf.Tensor,
valid_mask: Optional[tf.Tensor] = None,
propagate_validation: bool = True) -> 'TypedWalkTree':
if propagate_validation and valid_mask is not None:
valid_mask = tf.logical_and(tf.expand_dims(self.valid_mask, -1),
valid_mask)
child_tree = TypedWalkTree(nodes, owner=self._owner, valid_mask=valid_mask)
self._next_steps.append((edge_set_name, child_tree))
return child_tree
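# Illustrative sketch (not part of this file) of how `add_step` combines
# masks. With two seeds, where every sample from the second seed is invalid:
#
#   root = TypedWalkTree(tf.constant([5, 9]))  # Mask defaults to all-True.
#   child = root.add_step(
#       'cites', nodes=tf.constant([[1, 2, 3], [0, 0, 0]]),
#       valid_mask=tf.constant([[True, True, True],
#                               [False, False, False]]))
#
# Since the root mask is all-True, `child.valid_mask` equals the given mask.
# If a root position were itself invalid, `propagate_validation=True` would
# mark that seed's entire row of samples as invalid, too.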

@@ -342,8 +364,30 @@ def to_graph_tensor(


class EdgeSampling(enum.Enum):
WITH_REPLACEMENT = 'with_replacement'
WITHOUT_REPLACEMENT = 'without_replacement'
"""Enum for randomized strategies for sampling neighbors."""
# Samples each neighbor independently. It assumes that *every node* has at
# least one outgoing neighbor, for all sampled edge-sets.
W_REPLACEMENT = 'w_replacement'

# Samples each neighbor independently, while allowing some nodes to have
# zero outgoing edges. This option causes `sample_one_hop()` to also return
# `valid_mask` (boolean tf.Tensor) marking the positions that correspond to an
# actual edge; a position is False iff it was sampled from an orphan node.
W_REPLACEMENT_W_ORPHANS = 'w_replacement_w_orphans'

# Samples neighbors without replacement. However, if `S` (int) neighbors are
# requested and a node has only `s < S` neighbors, then the samples repeat.
# You *must* ensure that each node has at least one outgoing neighbor. If your
# graph has orphan nodes, use `WO_REPLACEMENT_WO_REPEAT` or
# `W_REPLACEMENT_W_ORPHANS`.
WO_REPLACEMENT = 'wo_replacement'

# Like the above, except that a node with fewer than `sample_size` neighbors
# has each neighbor sampled only once. This option also works when some nodes
# have zero outgoing edges.
# This option causes `sample_one_hop()` to also return `valid_mask` (boolean
# tf.Tensor) marking positions corresponding to an actual edge.
WO_REPLACEMENT_WO_REPEAT = 'wo_replacement_wo_repeat'
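# Illustrative comparison of the strategies (assumed behavior) on a node with
# neighbors {a, b} and `sample_size=4`:
#
#   W_REPLACEMENT:            e.g. [a, a, b, a]; i.i.d. uniform picks.
#   WO_REPLACEMENT:           e.g. [b, a, a, b]; random round-robin, so each
#                             neighbor appears 4/2 = 2 times.
#   WO_REPLACEMENT_WO_REPEAT: e.g. [b, a, _, _] with `valid_mask`
#                             [True, True, False, False]; the `_` entries are
#                             arbitrary placeholder node ids.
#   W_REPLACEMENT_W_ORPHANS:  like W_REPLACEMENT, plus `valid_mask`; a node
#                             with zero neighbors gets an all-False row.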


class GraphSampler:
@@ -359,7 +403,7 @@ def __init__(self,
make_undirected: bool = False,
ensure_self_loops: bool = False,
reduce_memory_footprint: bool = True,
sampling: EdgeSampling = EdgeSampling.WITHOUT_REPLACEMENT):
sampling: EdgeSampling = EdgeSampling.WO_REPLACEMENT):
self.dataset = dataset
self.sampling = sampling
self.edge_types = {} # edge set name -> (src node set name, dst *).
@@ -416,44 +460,65 @@ def make_sample_layer(self, edge_set_name, sample_size=3, sampling=None):

def sample_one_hop(
self, source_nodes: tf.Tensor, edge_set_name: tfgnn.EdgeSetName,
sample_size: int, sampling: Optional[EdgeSampling] = None) -> tf.Tensor:
sample_size: int, sampling: Optional[EdgeSampling] = None,
) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
"""Samples one-hop from source-nodes using edge `edge_set_name`."""
if sampling is None:
sampling = EdgeSampling.WITH_REPLACEMENT
sampling = EdgeSampling.WO_REPLACEMENT

all_degrees = self.degrees[edge_set_name]
node_degrees = tf.gather(all_degrees, source_nodes)

offsets = self.degrees_cumsum[edge_set_name]

if sampling == EdgeSampling.WITH_REPLACEMENT:
next_nodes = valid_mask = None  # Populated below.

if sampling in (EdgeSampling.W_REPLACEMENT,
EdgeSampling.W_REPLACEMENT_W_ORPHANS):
sample_indices = tf.random.uniform(
shape=source_nodes.shape + [sample_size], minval=0, maxval=1,
dtype=tf.float32)

sample_indices = sample_indices * tf.cast(
tf.expand_dims(node_degrees, -1), tf.float32)
node_degrees_expanded = tf.expand_dims(node_degrees, -1)
sample_indices = sample_indices * tf.cast(node_degrees_expanded,
tf.float32)

# According to https://www.pcg-random.org/posts/bounded-rands.html, this
# sample is biased. NOTE: we plan to adopt one of the linked alternatives.
sample_indices = tf.cast(tf.math.floor(sample_indices), tf.int64)

if sampling == EdgeSampling.W_REPLACEMENT_W_ORPHANS:
valid_mask = sample_indices < node_degrees_expanded

# Shape: source_nodes.shape + [sample_size].
sample_indices += tf.expand_dims(tf.gather(offsets, source_nodes), -1)
nonzero_cols = self.edge_lists[edge_set_name][1]
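# At positions marked invalid (samples from orphan sources), redirect the
# gather to index 0 as a safe placeholder; those entries are arbitrary and
# must be interpreted together with `valid_mask`.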
if sampling == EdgeSampling.W_REPLACEMENT_W_ORPHANS:
sample_indices = tf.where(
valid_mask, sample_indices, tf.zeros_like(sample_indices))
next_nodes = tf.gather(nonzero_cols, sample_indices)
elif sampling == EdgeSampling.WITHOUT_REPLACEMENT:
elif sampling in (EdgeSampling.WO_REPLACEMENT,
EdgeSampling.WO_REPLACEMENT_WO_REPEAT):
# shape=(total_input_nodes).
nodes_reshaped = tf.reshape(source_nodes, [-1])
# shape=(total_input_nodes).
reshaped_node_degrees = tf.reshape(node_degrees, [-1])
reshaped_node_degrees_or_1 = tf.maximum(
reshaped_node_degrees, tf.ones_like(reshaped_node_degrees))
# shape=(sample_size, total_input_nodes).
sample_upto = tf.stack([reshaped_node_degrees] * sample_size, axis=0)

# [[0, 1, 2, ..., f], <repeated>].T
subtract_mod = tf.stack(
[tf.range(sample_size, dtype=tf.int64)] * nodes_reshaped.shape[0],
axis=-1)
subtract_mod = subtract_mod % sample_upto
if sampling == EdgeSampling.WO_REPLACEMENT_WO_REPEAT:
valid_mask = subtract_mod < reshaped_node_degrees
valid_mask = tf.reshape(
tf.transpose(valid_mask), source_nodes.shape + [sample_size])

subtract_mod = subtract_mod % tf.maximum(
sample_upto, tf.ones_like(sample_upto))

# [[d, d-1, d-2, ... 1, d, d-1, ...]].T
# where 'd' is degree of node in row corresponding to nodes_reshaped.
@@ -475,7 +540,7 @@

for i in range(1, sample_size):
already_sampled = tf.where(
i % reshaped_node_degrees == 0,
i % reshaped_node_degrees_or_1 == 0,
tf.ones_like(already_sampled) * max_degree, already_sampled)
next_sample = sample_indices[i]
for j in range(i):
@@ -493,24 +558,30 @@

sample_indices += tf.expand_dims(tf.gather(offsets, nodes_reshaped), 0)
sample_indices = tf.reshape(tf.transpose(sample_indices),
[source_nodes.shape[0], -1])
source_nodes.shape + [sample_size])
nonzero_cols = self.edge_lists[edge_set_name][1]
if sampling == EdgeSampling.WO_REPLACEMENT_WO_REPEAT:
sample_indices = tf.where(
valid_mask, sample_indices, tf.zeros_like(sample_indices))

next_nodes = tf.gather(nonzero_cols, sample_indices)
next_nodes = tf.reshape(next_nodes, source_nodes.shape + [sample_size])
else:
raise ValueError('Unknown sampling ' + str(sampling))

if next_nodes.dtype != source_nodes.dtype:
# This can happen, e.g., if the edge list is int32 and the input seeds are
# int64.
next_nodes = tf.cast(next_nodes, source_nodes.dtype)

return next_nodes
if valid_mask is None:
return next_nodes
else:
return next_nodes, valid_mask
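# Caller-side sketch (illustrative): the return type depends on `sampling`.
#
#   result = sampler.sample_one_hop(
#       seeds, 'cites', sample_size=3,
#       sampling=EdgeSampling.WO_REPLACEMENT_WO_REPEAT)
#   if isinstance(result, tuple):
#     next_nodes, valid_mask = result  # *_W_ORPHANS and *_WO_REPEAT modes.
#   else:
#     next_nodes = result  # Mask-free modes.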

def generate_subgraphs(
self, batch_size: int,
sampling_spec: sampling_spec_pb2.SamplingSpec,
split: str = 'train',
sampling=EdgeSampling.WITH_REPLACEMENT):
sampling=EdgeSampling.WO_REPLACEMENT):
"""Infinitely yields random subgraphs each rooted on node in train set."""
if isinstance(split, bytes):
split = split.decode()
@@ -532,7 +603,7 @@

def random_walk_tree(
self, node_idx: tf.Tensor, sampling_spec: sampling_spec_pb2.SamplingSpec,
sampling: EdgeSampling = EdgeSampling.WITH_REPLACEMENT) -> TypedWalkTree:
sampling: EdgeSampling = EdgeSampling.WO_REPLACEMENT) -> TypedWalkTree:
"""Returns `TypedWalkTree` where `nodes` are seed root-nodes.

Args:
@@ -566,8 +637,13 @@ def process_sampling_op(sampling_op: sampling_spec_pb2.SamplingOp):
next_nodes = self.sample_one_hop(
parent_nodes, sampling_op.edge_set_name,
sample_size=sampling_op.sample_size, sampling=sampling)
child_tree = parent_trees[0].add_step(
sampling_op.edge_set_name, next_nodes)
if isinstance(next_nodes, tuple):
next_nodes, valid_mask = next_nodes
child_tree = parent_trees[0].add_step(
sampling_op.edge_set_name, next_nodes, valid_mask=valid_mask)
else:
child_tree = parent_trees[0].add_step(
sampling_op.edge_set_name, next_nodes)

op_name_to_tree[sampling_op.op_name] = child_tree

@@ -581,17 +657,17 @@

def sample_sub_graph_tensor(
self, node_idx: tf.Tensor, sampling_spec: sampling_spec_pb2.SamplingSpec,
sampling: EdgeSampling = EdgeSampling.WITH_REPLACEMENT
sampling: EdgeSampling = EdgeSampling.WO_REPLACEMENT
) -> tfgnn.GraphTensor:
"""Samples GraphTensor starting from seed nodes `node_idx`.

Args:
node_idx: (int) tf.Tensor of node indices to seed random-walk trees.
sampling_spec: Specifies the hops (edge set names) to be sampled, and the
number of sampled edges per hop.
sampling: If `== EdgeSampling.WITH_REPLACEMENT`, then neighbors for a node
sampling: If `== EdgeSampling.W_REPLACEMENT`, then neighbors for a node
will be sampled uniformly and independently. If
`== EdgeSampling.WITHOUT_REPLACEMENT`, then a node's neighbors will be
`== EdgeSampling.WO_REPLACEMENT`, then a node's neighbors will be
chosen in (random) round-robin order. If more samples are requested than
there are neighbors, then the samples will be repeated (each time in a
different random order), such that all neighbors appear exactly the
@@ -621,7 +697,7 @@ def make_sampled_subgraphs_dataset(
batch_size: int = 64,
split='train',
make_undirected: bool = False,
sampling=EdgeSampling.WITH_REPLACEMENT
sampling=EdgeSampling.WO_REPLACEMENT
) -> Tuple[tf.TensorSpec, tf.data.Dataset]:
"""Infinite tf.data.Dataset wrapping generate_subgraphs."""
subgraph_generator = GraphSampler(dataset, make_undirected=make_undirected)