Graph Cross Network #3650

Open · wants to merge 9 commits into base: master

Conversation

@paridhimaheshwari2708 commented Dec 8, 2021

This PR contains the PyG implementation of the paper Graph Cross Networks with Vertex Infomax Pooling (NeurIPS 2020). It has been adapted from the DGL implementation here, and the authors' code is here. I have also included an example showing end-to-end training on a graph classification task.

Submitted as part of Stanford CS224W: Machine Learning with Graphs, Autumn 2021-2022.

@codecov-commenter commented Dec 8, 2021

Codecov Report

Merging #3650 (2ac96ee) into master (560df33) will decrease coverage by 0.76%.
The diff coverage is 16.29%.


@@            Coverage Diff             @@
##           master    #3650      +/-   ##
==========================================
- Coverage   81.58%   80.82%   -0.77%     
==========================================
  Files         295      299       +4     
  Lines       14861    15279     +418     
==========================================
+ Hits        12125    12349     +224     
- Misses       2736     2930     +194     
Impacted Files | Coverage Δ
torch_geometric/nn/models/graph_cross_network.py | 15.92% <15.92%> (ø)
torch_geometric/nn/models/__init__.py | 100.00% <100.00%> (ø)
torch_geometric/loader/utils.py | 81.42% <0.00%> (-5.34%) ⬇️
torch_geometric/nn/conv/gatv2_conv.py | 91.96% <0.00%> (-3.37%) ⬇️
torch_geometric/graphgym/models/encoder.py | 31.81% <0.00%> (-1.52%) ⬇️
torch_geometric/graphgym/model_builder.py | 92.85% <0.00%> (-0.48%) ⬇️
torch_geometric/transforms/random_link_split.py | 93.61% <0.00%> (-0.05%) ⬇️
torch_geometric/data/storage.py | 80.96% <0.00%> (-0.01%) ⬇️
torch_geometric/nn/pool/pool.py | 100.00% <0.00%> (ø)
torch_geometric/nn/conv/appnp.py | 78.33% <0.00%> (ø)
... and 39 more


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 560df33...2ac96ee.

    self.augment_self_loops()
    self.degree_as_feature()

def degree_as_feature(self):
Member:
We can replace that with torch_geometric.transforms.OneHotDegree:

dataset = TUDataset(root, name, transform=T.OneHotDegree(...))
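For context, a minimal sketch of that suggestion (the dataset name and root below are placeholders; max_degree has to be computed from the data first):

import torch_geometric.transforms as T
from torch_geometric.datasets import TUDataset
from torch_geometric.utils import degree

# Placeholder dataset without node features; the transform adds one-hot degree features.
dataset = TUDataset(root='data/IMDB-BINARY', name='IMDB-BINARY')
# Largest node degree over the whole dataset, needed to size the one-hot encoding.
max_degree = int(max(degree(data.edge_index[0], data.num_nodes).max() for data in dataset))
dataset.transform = T.OneHotDegree(max_degree)
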

    else:
        g.x = node_feat

def augment_self_loops(self):
Member:
Not sure why we need this. Can you clarify?

examples/graph_cross_network.py (outdated review comment, resolved)
def forward(self, graph):
    embed, logits1, logits2 = self.gxn(graph)
    logits = F.relu(self.lin1(embed))
    if self.dropout > 0:
Member:
This can be removed. F.dropout is a no-op for p=0.0.
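In other words, the forward pass above could collapse to a single unconditional call, along these lines (a sketch based on the snippet shown, not the final code):

import torch.nn.functional as F

def forward(self, graph):
    embed, logits1, logits2 = self.gxn(graph)
    logits = F.relu(self.lin1(embed))
    # F.dropout with p=0.0 leaves the input unchanged, so no `if self.dropout > 0` guard is needed.
    logits = F.dropout(logits, p=self.dropout, training=self.training)
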

# classification loss
classify_loss = F.nll_loss(cls_logits, labels.to(device))

# loss for vertex infomax pooling
Member:
Can we integrate the computation of the VIX loss into GraphCrossNet model, e.g., GraphCrossNet.pooling_loss(...)?
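A hypothetical sketch of such an interface (the method name pooling_loss and the binary cross-entropy form are illustrative assumptions; the actual computation would be whatever the example script does today):

import torch
import torch.nn.functional as F

class GraphCrossNet(torch.nn.Module):
    # ... existing layers and forward() ...

    def pooling_loss(self, logits1, logits2):
        # Infomax-style objective: push positive scores towards 1 and negative scores towards 0.
        pos = F.binary_cross_entropy_with_logits(logits1, torch.ones_like(logits1))
        neg = F.binary_cross_entropy_with_logits(logits2, torch.zeros_like(logits2))
        return pos + neg

# The training script then reduces to something like:
# loss = F.nll_loss(cls_logits, labels) + model.pooling_loss(logits1, logits2)
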

from typing import List, Optional, Union


class GCNConvWithDropout(torch.nn.Module):
Member:
Let's not create too many modules here :)

For example, GCNConvWithDropout, DenseLayer can be safely dropped, IMO.
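For instance, a GCNConvWithDropout wrapper can usually be replaced by composing the existing pieces inline (a sketch; the channel sizes and dropout rate below are placeholders):

import torch.nn.functional as F
from torch_geometric.nn import GCNConv

in_channels, hidden_channels, dropout = 32, 64, 0.5  # placeholder values
conv = GCNConv(in_channels, hidden_channels)

def apply_conv(x, edge_index):
    x = F.relu(conv(x, edge_index))
    # Dropout applied inline, so no dedicated wrapper module is required.
    return F.dropout(x, p=dropout, training=True)
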



class GraphCrossModule(torch.nn.Module):
    r"""
Member:
Please follow the docstring conventions of our other modules to ensure that documentation is rendered correctly, see here.
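A short sketch of that style for this module (the summary line and parameter descriptions below are illustrative, not the PR's actual docstring):

import torch

class GraphCrossModule(torch.nn.Module):
    r"""The graph cross module from the "Graph Cross Networks with Vertex
    Infomax Pooling" paper.

    Args:
        in_channels (int): Size of each input sample.
        pool_ratios (float or [float], optional): The pooling ratio used at
            each scale.
        fuse_weight (float, optional): The weight parameter used at the end
            of GXN for channel fusion. (default: :obj:`1.0`)
        num_cross_layers (int, optional): The number of cross layers.
            (default: :obj:`2`)
    """
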

torch_geometric/nn/models/graph_cross_network.py (outdated review comment, resolved)
num_cross_layers: int = 2):
    super(GraphCrossModule, self).__init__()
    if isinstance(pool_ratios, float):
        pool_ratios = (pool_ratios, pool_ratios)
Member:
Suggested change:
-    pool_ratios = (pool_ratios, pool_ratios)
+    pool_ratios = [pool_ratios] * num_cross_layers

        Default: :obj:`1.0`
    fuse_weight : float, optional
        The weight parameter used at the end of GXN for channel fusion.
        Default: :obj:`1.0`
Member:
num_cross_layers is missing.

@rusty1s changed the title from "Adding code for Graph Cross Network" to "Graph Cross Network" on Feb 7, 2022
3 participants