[WIP] Align gcn_norm #3711

Draft · wants to merge 10 commits into master
65 changes: 28 additions & 37 deletions torch_geometric/nn/conv/gcn_conv.py
@@ -14,21 +14,21 @@


 @torch.jit._overload
-def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
-             add_self_loops=True, dtype=None):
-    # type: (Tensor, OptTensor, Optional[int], bool, bool, Optional[int]) -> PairTensor  # noqa
+def gcn_norm(edge_index, edge_weight=None, num_nodes=None, add_self_loops=True,
+             dtype=None, improved=False):
+    # type: (Tensor, OptTensor, Optional[int], bool, Optional[int], bool) -> PairTensor  # noqa
     pass


 @torch.jit._overload
-def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
-             add_self_loops=True, dtype=None):
-    # type: (SparseTensor, OptTensor, Optional[int], bool, bool, Optional[int]) -> SparseTensor  # noqa
+def gcn_norm(edge_index, edge_weight=None, num_nodes=None, add_self_loops=True,
+             dtype=None, improved=False):
+    # type: (SparseTensor, OptTensor, Optional[int], bool, Optional[int], bool) -> Tuple[SparseTensor, OptTensor]  # noqa
     pass


-def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
-             add_self_loops=True, dtype=None):
+def gcn_norm(edge_index, edge_weight=None, num_nodes=None, add_self_loops=True,
+             dtype=None, improved=False):

     fill_value = 2. if improved else 1.

@@ -43,7 +43,7 @@ def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
         deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
         adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
         adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
-        return adj_t
+        return adj_t, None

     else:
         num_nodes = maybe_num_nodes(edge_index, num_nodes)
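
For reference, a minimal call sketch against the reordered signature (toy graph, values illustrative; assumes this branch): `improved` now trails `add_self_loops` and `dtype`, and both overloads return a pair, with the sparse path returning `None` as the weight since the values already live inside the `SparseTensor`.

```python
import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm

# Toy 3-node graph; edges are illustrative only.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])

# New keyword order: add_self_loops and dtype now come before improved.
edge_index, edge_weight = gcn_norm(edge_index, None, num_nodes=3,
                                   add_self_loops=True, dtype=torch.float,
                                   improved=False)
```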
@@ -115,8 +115,8 @@ class GCNConv(MessagePassing):
         :class:`torch_geometric.nn.conv.MessagePassing`.
     """

-    _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
-    _cached_adj_t: Optional[SparseTensor]
+    _edge_cache: Optional[Tuple[Tensor, Tensor]]
+    _adjt_cache: Optional[Tuple[SparseTensor, OptTensor]]

     def __init__(self, in_channels: int, out_channels: int,
                  improved: bool = False, cached: bool = False,
@@ -133,8 +133,7 @@ def __init__(self, in_channels: int, out_channels: int,
         self.add_self_loops = add_self_loops
         self.normalize = normalize

-        self._cached_edge_index = None
-        self._cached_adj_t = None
+        self._edge_cache = self._adjt_cache = None

         self.lin = Linear(in_channels, out_channels, bias=False,
                           weight_initializer='glorot')
@@ -149,35 +148,27 @@ def __init__(self, in_channels: int, out_channels: int,
     def reset_parameters(self):
         self.lin.reset_parameters()
         zeros(self.bias)
-        self._cached_edge_index = None
-        self._cached_adj_t = None
+        self._edge_cache = self._adjt_cache = None

     def forward(self, x: Tensor, edge_index: Adj,
                 edge_weight: OptTensor = None) -> Tensor:
         """"""

-        if self.normalize:
-            if isinstance(edge_index, Tensor):
-                cache = self._cached_edge_index
-                if cache is None:
-                    edge_index, edge_weight = gcn_norm(  # yapf: disable
-                        edge_index, edge_weight, x.size(self.node_dim),
-                        self.improved, self.add_self_loops)
-                    if self.cached:
-                        self._cached_edge_index = (edge_index, edge_weight)
-                else:
-                    edge_index, edge_weight = cache[0], cache[1]
-
-            elif isinstance(edge_index, SparseTensor):
-                cache = self._cached_adj_t
-                if cache is None:
-                    edge_index = gcn_norm(  # yapf: disable
-                        edge_index, edge_weight, x.size(self.node_dim),
-                        self.improved, self.add_self_loops)
-                    if self.cached:
-                        self._cached_adj_t = edge_index
-                else:
-                    edge_index = cache
+        edge_cache, adjt_cache = self._edge_cache, self._adjt_cache
+        if isinstance(edge_index, Tensor) and edge_cache is not None:
+            edge_index, edge_weight = edge_cache[0], edge_cache[1]
+        elif isinstance(edge_index, SparseTensor) and adjt_cache is not None:
+            edge_index, edge_weight = adjt_cache[0], adjt_cache[1]
+
+        elif self.normalize:
+            out = gcn_norm(edge_index, edge_weight, x.size(self.node_dim),
+                           self.add_self_loops, x.dtype, self.improved)
+            edge_index, edge_weight = out
+
+        if self.cached and isinstance(edge_index, Tensor):
+            self._edge_cache = (edge_index, edge_weight)
+        if self.cached and isinstance(edge_index, SparseTensor):
+            self._adjt_cache = (edge_index, edge_weight)

         x = self.lin(x)

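A short, hedged sketch of how the reworked cache behaves from the caller's side (toy values, behavior as on this branch): the first call normalizes via `gcn_norm` and, with `cached=True`, stores the `(edge_index, edge_weight)` pair; later calls read it back instead of re-normalizing.

```python
import torch
from torch_geometric.nn import GCNConv

conv = GCNConv(16, 32, cached=True, normalize=True)

x = torch.randn(3, 16)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])

out1 = conv(x, edge_index)  # normalizes via gcn_norm and fills _edge_cache
out2 = conv(x, edge_index)  # reuses the cached (edge_index, edge_weight) pair
```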
8 changes: 3 additions & 5 deletions torch_geometric/nn/models/label_prop.py
@@ -45,11 +45,9 @@ def forward(
         out = torch.zeros_like(y)
         out[mask] = y[mask]

-        if isinstance(edge_index, SparseTensor) and not edge_index.has_value():
-            edge_index = gcn_norm(edge_index, add_self_loops=False)
-        elif isinstance(edge_index, Tensor) and edge_weight is None:
-            edge_index, edge_weight = gcn_norm(edge_index, num_nodes=y.size(0),
-                                               add_self_loops=False)
+        edge_index, edge_weight = gcn_norm(edge_index, edge_weight,
+                                           num_nodes=y.size(0),
+                                           add_self_loops=False, dtype=y.dtype)

         res = (1 - self.alpha) * out
         for _ in range(self.num_layers):
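Since the unified `gcn_norm` now handles both input types and pre-existing weights itself, the type dispatch in `LabelPropagation.forward` collapses into a single call. A minimal usage sketch (graph and labels made up; assumes this branch):

```python
import torch
from torch_geometric.nn.models import LabelPropagation

model = LabelPropagation(num_layers=3, alpha=0.9)

y = torch.tensor([0, 1, 2])                     # integer class labels
train_mask = torch.tensor([True, False, True])  # seed nodes to propagate from
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])

out = model(y, edge_index, mask=train_mask)     # soft labels, shape [3, 3]
```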