PyG 2.4.0: Model compilation, on-disk datasets, hierarchical sampling
We are excited to announce the release of PyG 2.4 🎉🎉🎉
PyG 2.4 is the culmination of work from 62 contributors who have worked on features and bug-fixes for a total of over 500 commits since torch-geometric==2.3.1.
Highlights
PyTorch 2.1 and torch.compile(dynamic=True) support
The long wait has come to an end! With the release of PyTorch 2.1, PyG 2.4 now brings full support for torch.compile to graphs of varying size via the dynamic=True option, which is especially useful for use cases involving DataLoader or NeighborLoader. Examples and tutorials have been updated to reflect this support accordingly (#8134), and models and layers in torch_geometric.nn have been tested to produce zero graph breaks:
import torch_geometric
model = torch_geometric.compile(model, dynamic=True)
When the dynamic=True option is enabled, PyTorch will attempt up front to generate a kernel that is as dynamic as possible, in order to avoid recompilations when sizes change across mini-batches. As such, you should only omit dynamic=True when graph sizes are guaranteed to never change. Note that dynamic=True requires PyTorch >= 2.1.0 to be installed.
PyG 2.4 is fully compatible with PyTorch 2.1, and supports the following combinations:
| PyTorch 2.1 | cpu | cu118 | cu121 |
|---|---|---|---|
| Linux | ✅ | ✅ | ✅ |
| macOS | ✅ | | |
| Windows | ✅ | ✅ | ✅ |
You can still install PyG 2.4 on older PyTorch releases, going back as far as PyTorch 1.11, in case you are not eager to update your PyTorch version.
OnDiskDataset Interface
We added the OnDiskDataset base class for creating large graph datasets (e.g., molecular databases with billions of graphs), which do not easily fit into CPU memory at once (#8028, #8044, #8046, #8051, #8052, #8054, #8057, #8058, #8066, #8088, #8092, #8106). OnDiskDataset leverages our newly introduced Database backend (sqlite3 by default) for on-disk storage and access of graphs, supports DataLoader out-of-the-box, and is optimized for maximum performance.
OnDiskDataset utilizes a user-specified schema to store data as efficiently as possible (instead of relying on Python pickling). The schema can take int, float, str, object, or a dictionary with dtype and size keys (for specifying tensor data) as input, and can be nested as a dictionary. For example,
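import torch
from torch_geometric.data import OnDiskDataset

# `root` is the directory in which the database file is stored.
dataset = OnDiskDataset(root, schema={
    'x': dict(dtype=torch.float, size=(-1, 16)),
    'edge_index': dict(dtype=torch.long, size=(2, -1)),
    'y': float,
})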
creates a database with three columns, where x and edge_index are stored as binary data, and y is stored as a float.
Afterwards, you can append data to the OnDiskDataset via dataset.append()/dataset.extend(), and retrieve data from it via dataset.get()/dataset.multi_get(). We added a fully working example on how to set up your own OnDiskDataset here (#8102). You can also convert in-memory dataset instances to an OnDiskDataset instance by running InMemoryDataset.to_on_disk_dataset() (#8116).
Neighbor Sampling Improvements
Hierarchical Sampling
One drawback of NeighborLoader is that it computes representations for all sampled nodes at all depths of the network. However, nodes sampled in later hops no longer contribute to the node representations of seed nodes in later GNN layers, so computing them is wasted work. As a result, NeighborLoader is marginally slower than necessary, since it computes node embeddings for nodes that are no longer needed. This is a trade-off we made to obtain a clean, modular and experimentation-friendly GNN design, which does not tie the definition of the model to its data loading routine.
With PyG 2.4, we introduced the option to eliminate this overhead and speed up training and inference in mini-batch GNNs further, which we call "Hierarchical Neighborhood Sampling" (see here for the full tutorial) (#6661, #7089, #7244, #7425, #7594, #7942). The main idea is to progressively trim the adjacency matrix of the returned subgraph before passing it to each GNN layer; this works seamlessly across several models, in both the homogeneous and heterogeneous graph settings. To support this trimming and implement it efficiently, the NeighborLoader implementations in PyG and in pyg-lib additionally return the number of nodes and edges sampled in each hop. These counts are then used on a per-layer basis to trim the adjacency matrix and the various feature matrices so that only the required entries are kept (see the trim_to_layer utility):
from typing import List

import torch
from torch import Tensor
from torch.nn import Linear, ModuleList
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import trim_to_layer

class GNN(torch.nn.Module):
    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int, num_layers: int):
        super().__init__()
        self.convs = ModuleList([SAGEConv(in_channels, hidden_channels)])
        for _ in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.lin = Linear(hidden_channels, out_channels)

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        num_sampled_nodes_per_hop: List[int],
        num_sampled_edges_per_hop: List[int],
    ) -> Tensor:
        for i, conv in enumerate(self.convs):
            # Trim edge and node information to the current layer `i`.
            x, edge_index, _ = trim_to_layer(
                i, num_sampled_nodes_per_hop, num_sampled_edges_per_hop,
                x, edge_index)
            x = conv(x, edge_index).relu()
        return self.lin(x)
Corresponding examples can be found here and here.
Biased Sampling
Additionally, we added support for weighted/biased sampling in NeighborLoader/LinkNeighborLoader scenarios. For this, simply store your edge weights as an attribute of your data (e.g., edge_weight) and pass its name via the weight_attr argument during NeighborLoader initialization. PyG will then pick up these weights to perform weighted/biased sampling (#8038):
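import torch
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

edge_index = torch.tensor([[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]])
edge_weight = torch.rand(edge_index.size(1))  # One sampling weight per edge.

data = Data(num_nodes=5, edge_index=edge_index, edge_weight=edge_weight)
loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],
    weight_attr='edge_weight',
)
batch = next(iter(loader))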
New models, datasets, examples & tutorials
As part of our algorithm and documentation sprints (#7892), we have added:
- Model components:
  - MixHopConv: "MixHop: Higher-Order Graph Convolutional Architectures via Sparsified Neighborhood Mixing" (examples/mixhop.py) (#8025)
  - LCMAggregation: "Learnable Commutative Monoids for Graph Neural Networks" (examples/lcm_aggr_2nd_min.py) (#7976, #8020, #8023, #8026, #8075)
  - DirGNNConv: "Edge Directionality Improves Learning on Heterophilic Graphs" (examples/dir_gnn.py) (#7458)
  - Support for Performer in GPSConv: "Recipe for a General, Powerful, Scalable Graph Transformer" (examples/graph_gps.py) (#7465)
  - PMLP: "Graph Neural Networks are Inherently Good Generalizers: Insights by Bridging GNNs and MLPs" (examples/pmlp.py) (#7470, #7543)
  - RotatE: "RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space" (examples/kge_fb15k_237.py) (#7026)
  - NeuralFingerprint: "Convolutional Networks on Graphs for Learning Molecular Fingerprints" (#7919)
- Datasets: HM (#7515), BrcaTcga (#7994), MyketDataset (#7959), Wikidata5M (#7864), OSE_GVCS (#7811), MovieLens1M (#7479), AmazonBook (#7483), GDELTLite (#7442), IGMCDataset (#7441), MovieLens100K (#7398), EllipticBitcoinTemporalDataset (#7011), NeuroGraphDataset (#8112), PCQM4Mv2 (#8102)
- Tutorials:
- Examples:
  - Heterogeneous link-level GNN explanations via CaptumExplainer (examples/captum_explainer_hetero_link.py) (#7096)
  - Training LightGCN on AmazonBook for recommendation (examples/lightgcn.py) (#7603)
  - Using the Kùzu remote backend as FeatureStore (examples/kuzu) (#7298)
  - Multi-GPU training on ogbn-papers100M (examples/papers100m_multigpu.py) (#7921)
  - The OGC model on Cora (examples/ogc.py) (#8168)
  - Distributed training via graphlearn-for-pytorch (examples/distributed/graphlearn_for_pytorch) (#7402)
Join our Slack here if you're interested in joining community sprints in the future!
Breaking Changes
- Data.keys() is now a method instead of a property (#7629):

<=2.3:
data = Data(x=x, edge_index=edge_index)
print(data.keys)
# ['x', 'edge_index']

2.4:
data = Data(x=x, edge_index=edge_index)
print(data.keys())
# ['x', 'edge_index']
- Dropped Python 3.7 support (#7939)
- Removed FastHGTConv in favor of HGTConv (#7117)
- Removed the layer_type argument from GraphMaskExplainer (#7445)
- Renamed the dest argument to dst in utils.geodesic_distance (#7708)
Deprecations
- Deprecated contrib.explain.GraphMaskExplainer in favor of explain.algorithm.GraphMaskExplainer (#7779)
Features
Data and HeteroData improvements
- Added a warning for isolated/non-existing node types in HeteroData.validate() (#7995)
- Added HeteroData support in to_networkx (#7713)
- Added Data.sort() and HeteroData.sort() (#7649)
- Added padding capabilities to HeteroData.to_homogeneous() in case feature dimensionalities do not match (#7374)
- Added torch.nested_tensor support in Data and Batch (#7643, #7647)
- Added keep_inter_cluster_edges option to ClusterData to support inter-subgraph edge connections when doing graph partitioning (#7326)
Data-loading improvements
- Added support for floating-point slicing in Dataset, e.g., dataset[:0.9] (#7915)
- Added save and load methods to InMemoryDataset (#7250, #7413)
- Beta: Added IBMBNodeLoader and IBMBBatchLoader data loaders (#6230)
- Beta: Added HyperGraphData to support hypergraphs (#7611)
- Added CachedLoader (#7896, #7897)
- Allowed GPU tensors as input to NodeLoader and LinkLoader (#7572)
- Added PrefetchLoader capabilities (#7376, #7378, #7383)
- Added manual sampling interface to NodeLoader and LinkLoader (#7197)
Better support for sparse tensors
- Added SparseTensor support to WLConvContinuous, GeneralConv, PDNConv and ARMAConv (#8013)
- Changed torch_sparse.SparseTensor logic to utilize torch.sparse_csr instead (#7041)
- Added support for torch.sparse.Tensor in DataLoader (#7252)
- Added support for torch.jit.script within MessagePassing layers without torch_sparse being installed (#7061, #7062)
- Added unbatching logic for torch.sparse.Tensor (#7037)
- Added support for Data.num_edges for native torch.sparse.Tensor adjacency matrices (#7104)
- Accelerated sparse tensor conversion routines (#7042, #7043)
- Added a sparse cross_entropy implementation (#7447, #7466)
Integration with 3rd-party libraries
- Added FlopsCount support via fvcore (#7693)
- Added to_dgl and from_dgl conversion functions (#7053)
torch_geometric.transforms
- All transforms are now immutable, i.e., they perform a shallow copy of the data and therefore no longer modify data in-place (#7429)
- Added the HalfHop graph upsampling augmentation (#7827)
- Added an interval argument to the Cartesian, LocalCartesian and Distance transformations (#7533, #7614, #7700)
- Added an optional add_pad_mask argument to the Pad transform (#7339)
- Added the NodePropertySplit transformation for creating node-level splits using structural node properties (#6894)
- Added an AddRemainingSelfLoops transformation (#7192)
Bugfixes
- Fixed HeteroConv for layers that have a non-default argument order, e.g., GCN2Conv (#8166)
- Handled reserved keywords as keys in ModuleDict and ParameterDict (#8163)
- Fixed DynamicBatchSampler.__len__ to raise an error in case num_steps is undefined (#8137)
- Enabled pickling of DimeNet models (#8019)
- Fixed a bug in which batch.e_id was not correctly computed on unsorted graph inputs (#7953)
- Fixed from_networkx conversion from nx.stochastic_block_model graphs (#7941)
- Fixed the usage of bias_initializer in HeteroLinear (#7923)
- Fixed broken URLs in HGBDataset (#7907)
- Fixed an issue where SetTransformerAggregation produced NaN values for isolated nodes (#7902)
- Fixed summary on modules with uninitialized parameters (#7884)
- Fixed tracing of add_self_loops for a dynamic number of nodes (#7330)
- Fixed device issue in PNAConv.get_degree_histogram (#7830)
- Fixed the shape of edge_label_time when using temporal sampling on homogeneous graphs (#7807)
- Fixed edge_label_index computation in LinkNeighborLoader for the homogeneous+disjoint mode (#7791)
- Fixed CaptumExplainer for binary classification tasks (#7787)
- Raised an error when collecting non-existing attributes in HeteroData (#7714)
- Fixed get_mesh_laplacian for normalization="sym" (#7544)
- Used dim_size to initialize the output size of the EquilibriumAggregation layer (#7530)
- Fixed empty edge indices handling in SparseTensor (#7519)
- Moved the scaler tensor in GeneralConv to the correct device (#7484)
- Fixed HeteroLinear bug when used via mixed precision (#7473)
- Fixed gradient computation of edge weights in utils.spmm (#7428)
- Fixed an index-out-of-range bug in QuantileAggregation when dim_size is passed (#7407)
- Fixed a bug in LightGCN.recommendation_loss() to only use the embeddings of the nodes involved in the current mini-batch (#7384)
- Fixed a bug in which inputs were modified in-place in to_hetero_with_bases (#7363)
- Stopped loading node_default and edge_default attributes in from_networkx (#7348)
- Fixed HGTConv utility function _construct_src_node_feat (#7194)
- Fixed subgraph on unordered inputs (#7187)
- Allowed missing node types in HeteroDictLinear (#7185)
- Fixed numpy incompatibility when reading files for Planetoid datasets (#7141)
- Fixed crash of heterogeneous data loaders if node or edge types are missing (#7060, #7087)
- Allowed CaptumExplainer to be called multiple times in a row (#7391)
Changes
- Enabled dense eigenvalue computation in AddLaplacianEigenvectorPE for small-scale graphs (#8143)
- Accelerated and simplified top_k computation in TopKPooling (#7737)
- Updated GIN implementation in benchmarks to apply sequential batch normalization (#7955)
- Updated QM9 data pre-processing to include the SMILES string (#7867)
- Warned the user when using the training flag in to_hetero modules (#7772)
- Changed add_random_edge to only add true negative edges (#7654)
- Allowed the usage of BasicGNN models in DeepGraphInfomax (#7648)
- Added a num_edges parameter to the forward method of HypergraphConv (#7560)
- Added a max_num_elements parameter to the forward method of GraphMultisetTransformer, GRUAggregation, LSTMAggregation, SetTransformerAggregation and SortAggregation (#7529, #7367)
- Re-factored ClusterLoader to integrate the pyg-lib METIS routine (#7416)
- The filter_per_worker option will now be automatically inferred by default based on the device of the underlying data (#7399)
- Added the option to pass fill_value as a torch.tensor to utils.to_dense_batch (#7367)
- Updated examples to use NeighborLoader instead of NeighborSampler (#7152)
- Extended dataset summary to create stats for each node/edge type (#7203)
- Added an optional batch_size argument to avg_pool_x and max_pool_x (#7216)
- Optimized from_networkx memory footprint by reducing unnecessary copies (#7119)
- Added an optional batch_size argument to LayerNorm, GraphNorm, InstanceNorm, GraphSizeNorm and PairNorm (#7135)
- Accelerated attention-based MultiAggregation (#7077)
- Edges in HeterophilousGraphDataset are now undirected by default (#7065)
- Added optional batch_size and max_num_nodes arguments to the MemPooling layer (#7239)
Full Changelog
Full Changelog: 2.3.0...2.4.0
New Contributors
- @zoryzhang made their first contribution in #7027
- @DomInvivo made their first contribution in #7037
- @OlegPlatonov made their first contribution in #7065
- @hbenedek made their first contribution in #7053
- @rishiagarwal2000 made their first contribution in #7011
- @sisaman made their first contribution in #7104
- @amorehead made their first contribution in #7110
- @EulerPascal404 made their first contribution in #7093
- @Looong01 made their first contribution in #7143
- @kamil-andrzejewski made their first contribution in #7135
- @andreazanetti made their first contribution in #7089
- @akihironitta made their first contribution in #7195
- @kjkozlowski made their first contribution in #7216
- @vstenby made their first contribution in #7221
- @piotrchmiel made their first contribution in #7239
- @vedal made their first contribution in #7272
- @gvbazhenov made their first contribution in #6894
- @Saydemr made their first contribution in #7313
- @HaoyuLu1022 made their first contribution in #7325
- @Vuenc made their first contribution in #7330
- @mewim made their first contribution in #7298
- @volltin made their first contribution in #7355
- @kasper-piskorski made their first contribution in #7377
- @happykygo made their first contribution in #7384
- @ThomasKLY made their first contribution in #7398
- @sky-2002 made their first contribution in #7421
- @denadai2 made their first contribution in #7456
- @chrisgo-gc made their first contribution in #7484
- @furkanakkurt1335 made their first contribution in #7507
- @mzamini92 made their first contribution in #7497
- @n-patricia made their first contribution in #7543
- @SalvishGoomanee made their first contribution in #7573
- @emalgorithm made their first contribution in #7458
- @marshka made their first contribution in #7595
- @djm93dev made their first contribution in #7598
- @NripeshN made their first contribution in #7770
- @ATheCoder made their first contribution in #7774
- @ebrahimpichka made their first contribution in #7775
- @kaidic made their first contribution in #7814
- @Wesxdz made their first contribution in #7811
- @daviddavo made their first contribution in #7888
- @frinkleko made their first contribution in #7907
- @chendiqian made their first contribution in #7917
- @rajveer43 made their first contribution in #7885
- @erfanloghmani made their first contribution in #7959
- @xnuohz made their first contribution in #7937
- @Favourj-bit made their first contribution in #7905
- @apfelsinecode made their first contribution in #7996
- @ArchieGertsman made their first contribution in #7976
- @bkmi made their first contribution in #8019
- @harshit5674 made their first contribution in #7919
- @erikhuck made their first contribution in #8024
- @jay-bhambhani made their first contribution in #8028
- @Barcavin made their first contribution in #8049
- @royvelich made their first contribution in #8048
- @CodeTal made their first contribution in #7611
- @filipekstrm made their first contribution in #8117
- @Anwar-Said made their first contribution in #8122
- @xYix made their first contribution in #8168