Graph Cross Network #3650
base: master
Conversation
Codecov Report
@@            Coverage Diff             @@
##           master    #3650      +/-   ##
==========================================
- Coverage   81.58%   80.82%   -0.77%
==========================================
  Files         295      299       +4
  Lines       14861    15279     +418
==========================================
+ Hits        12125    12349     +224
- Misses       2736     2930     +194
Continue to review full report at Codecov.
        self.augment_self_loops()
        self.degree_as_feature()

    def degree_as_feature(self):
We can replace that with torch_geometric.transforms.OneHotDegree:

dataset = TUDataset(root, name, transform=T.OneHotDegree(...))
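For reference, a minimal sketch of that replacement; the root path, dataset name, and max_degree below are placeholders for illustration, not values from this PR:

```python
import torch_geometric.transforms as T
from torch_geometric.datasets import TUDataset

# Placeholder values: `root`, the dataset name, and `max_degree` are illustrative.
# `max_degree` should be at least the largest node degree occurring in the dataset.
dataset = TUDataset(
    root='data/TUDataset',
    name='IMDB-BINARY',
    transform=T.OneHotDegree(max_degree=135),
)
```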
        else:
            g.x = node_feat

    def augment_self_loops(self):
Not sure why we need this. Can you clarify?
    def forward(self, graph):
        embed, logits1, logits2 = self.gxn(graph)
        logits = F.relu(self.lin1(embed))
        if self.dropout > 0:
This can be removed. F.dropout is a no-op for p=0.0.
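A minimal sketch of the simplified branch, assuming self.dropout stores the dropout probability and the surrounding layers are as in the diff above:

```python
def forward(self, graph):
    embed, logits1, logits2 = self.gxn(graph)
    logits = F.relu(self.lin1(embed))
    # F.dropout returns its input unchanged when p=0.0, so the `if self.dropout > 0`
    # guard is unnecessary.
    logits = F.dropout(logits, p=self.dropout, training=self.training)
    ...  # remaining classification head unchanged
```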
# classification loss
classify_loss = F.nll_loss(cls_logits, labels.to(device))

# loss for vertex infomax pooling
Can we integrate the computation of the vertex infomax pooling loss into the GraphCrossNet model, e.g., GraphCrossNet.pooling_loss(...)?
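A hypothetical sketch of what that could look like; the method name follows the suggestion above, while the cached score tensors and the exact loss form are assumptions, not this PR's code:

```python
import torch
import torch.nn.functional as F

class GraphCrossNet(torch.nn.Module):
    def forward(self, graph):
        # ... existing computation ...
        # Cache the positive/negative infomax scores produced by the pooling layers
        # so the loss can be computed later without returning extra tensors:
        # self._infomax_scores = [(pos_1, neg_1), (pos_2, neg_2)]
        ...

    def pooling_loss(self):
        # Vertex infomax pooling loss accumulated over the pooling layers (sketch).
        loss = 0.0
        for pos, neg in self._infomax_scores:
            loss = loss + F.binary_cross_entropy_with_logits(pos, torch.ones_like(pos))
            loss = loss + F.binary_cross_entropy_with_logits(neg, torch.zeros_like(neg))
        return loss
```

The training script could then combine it with the classification loss, e.g. classify_loss plus a weighted model.pooling_loss(), with the weight being a hyperparameter.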
from typing import List, Optional, Union


class GCNConvWithDropout(torch.nn.Module):
Let's not create too many modules here :)
For example, GCNConvWithDropout and DenseLayer can be safely dropped, IMO.
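As an illustration, a plain GCNConv followed by inline dropout can replace the wrapper; the class name and layer sizes below are placeholders:

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class ExampleBlock(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
        super().__init__()
        # Plain GCNConv; no dedicated GCNConvWithDropout wrapper module needed.
        self.conv = GCNConv(in_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index):
        x = F.relu(self.conv(x, edge_index))
        # Dropout applied inline in the parent module's forward.
        return F.dropout(x, p=self.dropout, training=self.training)
```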
class GraphCrossModule(torch.nn.Module):
    r"""
Please follow the docstring conventions of our other modules to ensure that documentation is rendered correctly, see here.
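For illustration, a sketch of the Args-style docstring format used by other PyG modules; pool_ratios and fuse_weight come from the diff, while the class name, the remaining entries, the descriptions, and the <paper-url> placeholder are assumptions to be filled in:

```python
class GraphCrossNet(torch.nn.Module):
    r"""The Graph Cross Network (GXN) model from the
    `"Graph Cross Networks with Vertex Infomax Pooling" <paper-url>`_ paper.

    Args:
        in_channels (int): Size of each input sample.
        pool_ratios (float or [float], optional): Pooling ratio of each graph
            cross layer.
        fuse_weight (float, optional): The weight parameter used at the end of
            GXN for channel fusion. (default: :obj:`1.0`)
    """
```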
                 num_cross_layers: int = 2):
        super(GraphCrossModule, self).__init__()
        if isinstance(pool_ratios, float):
            pool_ratios = (pool_ratios, pool_ratios)
Suggested change:
-            pool_ratios = (pool_ratios, pool_ratios)
+            pool_ratios = [pool_ratios] * num_cross_layers
        Default: :obj:`1.0`
    fuse_weight : float, optional
        The weight parameter used at the end of GXN for channel fusion.
        Default: :obj:`1.0`
num_cross_layers is missing from the docstring.
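A possible entry, matching the style of the block above; the description wording is assumed, and the default of 2 comes from the signature in the diff:

```python
    num_cross_layers : int, optional
        The number of graph cross layers in GXN.
        Default: :obj:`2`
```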
Co-authored-by: Matthias Fey <[email protected]>
for more information, see https://pre-commit.ci
This PR contains the PyG implementation of the paper Graph Cross Networks with Vertex Infomax Pooling (NeurIPS 2020). It has been adapted from the DGL implementation here, and the authors' code is here. I have also included an example showing end-to-end training on a graph classification task.
Submitted as part of Stanford CS224W: Machine Learning with Graphs, Autumn 2021-2022.
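For orientation, a rough sketch of how such an end-to-end graph classification example typically looks; the classifier name GXNClassifier, its constructor arguments, and the three-value return are assumptions based on the snippets above, not the final API of this PR:

```python
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader

# Assumed names and arguments, for illustration only; `dataset` is a TUDataset
# prepared with degree features as discussed in the review thread.
model = GXNClassifier(in_channels=dataset.num_features,
                      hidden_channels=96,
                      num_classes=dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model.train()
for data in loader:
    optimizer.zero_grad()
    # Assumed return signature: class log-probabilities plus the two infomax logits.
    cls_logits, logits1, logits2 = model(data)
    classify_loss = F.nll_loss(cls_logits, data.y)
    # The total loss additionally includes the vertex infomax pooling term;
    # its exact form and weighting are implementation-specific.
    loss = classify_loss
    loss.backward()
    optimizer.step()
```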