CS224W - Adds TransD KGE, and Bernoulli corruption strategy for all KGE #9864

Closed · wants to merge 39 commits

Commits
eef03e5
Initial non-interactive implementation for GATConv
mattjhayes3 Nov 6, 2024
4d1ac24
initial copy for diffing
mattjhayes3 Dec 13, 2024
c55de3e
Check in the TransD and bern impls
mattjhayes3 Dec 14, 2024
1082557
full training set bernoulli
mattjhayes3 Dec 14, 2024
c450f41
back to batchwise bern
mattjhayes3 Dec 14, 2024
6755ec8
addl tests, add bern to other variants
mattjhayes3 Dec 14, 2024
2354e7d
fix last commit
mattjhayes3 Dec 14, 2024
9a3e3a8
Ensure *_idx are fully distributed across nodes/devices (#9753)
Kh4L Nov 13, 2024
c98e204
Fix `utils.group_cat` concatenating dimension (#9766)
tzuhanchang Nov 13, 2024
e62a075
Add an error message for batching when `num_nodes` is unknown (#9743)
wzm2256 Nov 13, 2024
150ec97
Improve `__inc__` error message and add tests (#9778)
rusty1s Nov 13, 2024
4884c35
Fix arrow direction in `_visualize_graph_via_networkx` (#9773)
darabos Nov 13, 2024
c29c2b8
Added functionality of `FaceToEdge` transform to work for 3D tetrahed…
Aiik Nov 13, 2024
65d3e2c
Upgrade CI to PyTorch 2.5 (#9779)
rusty1s Nov 13, 2024
6b17856
Cancel intermediate CI builds (#9781)
rusty1s Nov 13, 2024
dd67b98
Add support for `torch_delaunay` package in `Delaunay` transformation…
ybubnov Nov 13, 2024
5b7325f
Fixing `edge_mask` handling for directed graphs in `k_hop_subgraph` (…
ryoji-kubo Nov 13, 2024
7119785
Added PyTorch 2.5 support (#9780)
rusty1s Nov 13, 2024
197eac5
Drop `TensorAttr.fully_specify` (#9782)
akihironitta Nov 13, 2024
23c4844
solve issue 9755 (fix typo) (#9790)
goelzva Nov 18, 2024
b8c4278
add GLEM model, TAGDataset and example of GLEM (#9662)
ECMGit Nov 19, 2024
8a501a8
Add MoleculeGPT (#9710)
xnuohz Nov 20, 2024
aa61d21
Add comment in `g_retriever.py` pointing to `Neo4j` Graph DB integrat…
puririshi98 Nov 20, 2024
ec09b86
Add GIT-Mol (#9730)
xnuohz Nov 25, 2024
073aa52
Run `GitMolDataset` tests only in full test mode (#9804)
rusty1s Nov 25, 2024
625561b
fix for cugraph (#9803)
puririshi98 Nov 25, 2024
72cea9a
G-retriever API updates (NVTX, Remote Backend, Large Graph Indexer, E…
zaristei Nov 26, 2024
d163641
Check that custom edge types actually exist in `NumNeighbors` definit…
rusty1s Nov 26, 2024
25d755f
Fix typo in Dataset docstring (#9813)
abertics Nov 28, 2024
1e1142a
updated Dockerfile based on NGC PyG 24.09 image (#9794)
sbhavani Dec 6, 2024
45bcbf2
Revert "Initial non-interactive implementation for GATConv"
mattjhayes3 Dec 14, 2024
8cf1629
Fix Docstring Typos for LargeGraphIndexer (#9837)
zaristei Dec 10, 2024
575825f
feat: store reverse mapping within `EdgeTypeStr` (#9844)
mananshah99 Dec 11, 2024
78426fa
minor cleanups
mattjhayes3 Dec 15, 2024
e20f018
checkin gat opt
mattjhayes3 Dec 14, 2024
1c4e33d
Merge branch 'master' into kge
mattjhayes3 Dec 15, 2024
f38aa94
Revert "checkin gat opt"
mattjhayes3 Dec 15, 2024
d33dd69
changelog
mattjhayes3 Dec 15, 2024
a65235c
documentation fixes
mattjhayes3 Dec 15, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Add TransD KGE and Bernoulli corruption ([#9864](https://github.com/pyg-team/pytorch_geometric/pull/9864))
- Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
- Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
18 changes: 15 additions & 3 deletions examples/kge_fb15k_237.py
@@ -1,22 +1,25 @@
import argparse
import os.path as osp
import time

import torch
import torch.optim as optim

from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import ComplEx, DistMult, RotatE, TransE
from torch_geometric.nn import ComplEx, DistMult, RotatE, TransD, TransE

model_map = {
'transe': TransE,
'complex': ComplEx,
'distmult': DistMult,
'rotate': RotatE,
'transd': TransD,
}

parser = argparse.ArgumentParser()
parser.add_argument('--model', choices=model_map.keys(), type=str.lower,
required=True)
parser.add_argument('--bern', action='store_true')
args = parser.parse_args()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -26,11 +29,17 @@
val_data = FB15k_237(path, split='val')[0].to(device)
test_data = FB15k_237(path, split='test')[0].to(device)

model_arg_map = {'rotate': {'margin': 9.0}}
model_arg_map = {model: {'hidden_channels': 50} for model in model_map.keys()}
model_arg_map['rotate']['margin'] = 9.0
model_arg_map['transd'] = {
'hidden_channels_node': 50,
'hidden_channels_rel': 50,
'bern': args.bern,
}

model = model_map[args.model](
num_nodes=train_data.num_nodes,
num_relations=train_data.num_edge_types,
hidden_channels=50,
**model_arg_map.get(args.model, {}),
).to(device)

@@ -44,6 +53,7 @@

optimizer_map = {
'transe': optim.Adam(model.parameters(), lr=0.01),
'transd': optim.Adam(model.parameters(), lr=0.01),
'complex': optim.Adagrad(model.parameters(), lr=0.001, weight_decay=1e-6),
'distmult': optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-6),
'rotate': optim.Adam(model.parameters(), lr=1e-3),
@@ -76,11 +86,13 @@ def test(data):
)


start = time.time()
for epoch in range(1, 501):
loss = train()
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
if epoch % 25 == 0:
rank, mrr, hits = test(val_data)
print(f"Time: {(time.time() - start) / epoch}")
print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, '
f'Val MRR: {mrr:.4f}, Val Hits@10: {hits:.4f}')

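With these changes the example can be run as `python examples/kge_fb15k_237.py --model transd --bern`. For orientation, the snippet below is a minimal sketch of the equivalent direct usage; the dataset root `'data/FB15k'` and the batch size are placeholders, and the remaining values simply mirror the hyperparameters hard-coded in the example script above:

import torch

from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import TransD

device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_data = FB15k_237('data/FB15k', split='train')[0].to(device)

model = TransD(
    num_nodes=train_data.num_nodes,
    num_relations=train_data.num_edge_types,
    hidden_channels_node=50,  # entity embedding dimensionality
    hidden_channels_rel=50,   # relation embedding dimensionality
    bern=True,                # Bernoulli (rather than uniform) corruption
).to(device)

loader = model.loader(
    head_index=train_data.edge_index[0],
    rel_type=train_data.edge_type,
    tail_index=train_data.edge_index[1],
    batch_size=1000,
    shuffle=True,
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for head_index, rel_type, tail_index in loader:
    optimizer.zero_grad()
    loss = model.loss(head_index, rel_type, tail_index)
    loss.backward()
    optimizer.step()
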
7 changes: 5 additions & 2 deletions test/nn/kge/test_complex.py
@@ -1,3 +1,4 @@
import pytest
import torch

from torch_geometric.nn import ComplEx
@@ -37,8 +38,10 @@ def test_complex_scoring():
assert score.tolist() == [58., 8.]


def test_complex():
model = ComplEx(num_nodes=10, num_relations=5, hidden_channels=32)
@pytest.mark.parametrize('bern', [False, True])
def test_complex(bern):
model = ComplEx(num_nodes=10, num_relations=5, hidden_channels=32,
bern=bern)
assert str(model) == 'ComplEx(10, num_relations=5, hidden_channels=32)'

head_index = torch.tensor([0, 2, 4, 6, 8])
7 changes: 5 additions & 2 deletions test/nn/kge/test_distmult.py
@@ -1,10 +1,13 @@
import pytest
import torch

from torch_geometric.nn import DistMult


def test_distmult():
model = DistMult(num_nodes=10, num_relations=5, hidden_channels=32)
@pytest.mark.parametrize('bern', [False, True])
def test_distmult(bern):
model = DistMult(num_nodes=10, num_relations=5, hidden_channels=32,
bern=bern)
assert str(model) == 'DistMult(10, num_relations=5, hidden_channels=32)'

head_index = torch.tensor([0, 2, 4, 6, 8])
7 changes: 5 additions & 2 deletions test/nn/kge/test_rotate.py
@@ -1,10 +1,13 @@
import pytest
import torch

from torch_geometric.nn import RotatE


def test_rotate():
model = RotatE(num_nodes=10, num_relations=5, hidden_channels=32)
@pytest.mark.parametrize('bern', [False, True])
def test_rotate(bern):
model = RotatE(num_nodes=10, num_relations=5, hidden_channels=32,
bern=bern)
assert str(model) == 'RotatE(10, num_relations=5, hidden_channels=32)'

head_index = torch.tensor([0, 2, 4, 6, 8])
33 changes: 33 additions & 0 deletions test/nn/kge/test_transd.py
@@ -0,0 +1,33 @@
import pytest
import torch

from torch_geometric.nn import TransD


@pytest.mark.parametrize('channels_node_rel', [(16, 32), (32, 16)])
@pytest.mark.parametrize('bern', [False, True])
def test_transd(channels_node_rel, bern):
channels_node, channels_rel = channels_node_rel
model = TransD(num_nodes=10, num_relations=5,
hidden_channels_node=channels_node,
hidden_channels_rel=channels_rel, bern=bern)
assert str(model) == ('TransD(10, num_relations=5,'
f' hidden_channels_node={channels_node},'
f' hidden_channels_rel={channels_rel})')

head_index = torch.tensor([0, 2, 4, 6, 8])
rel_type = torch.tensor([0, 1, 2, 3, 4])
tail_index = torch.tensor([1, 3, 5, 7, 9])

loader = model.loader(head_index, rel_type, tail_index, batch_size=5)
for h, r, t in loader:
out = model(h, r, t)
assert out.size() == (5, )

loss = model.loss(h, r, t)
assert loss >= 0.

mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)
assert 0 <= mean_rank <= 10
assert 0 < mrr <= 1
assert hits == 1.0
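
The new `torch_geometric/nn/kge/transd.py` module itself is not included in the portion of the diff shown here. For orientation only, the sketch below illustrates the TransD score from the "Knowledge Graph Embedding via Dynamic Mapping Matrix" paper that the model is expected to implement; it uses the vector form of the projection, assumes equal node and relation dimensionality, and its function and tensor names are illustrative rather than the PR's actual code:

import torch


def transd_score(h, h_p, r, r_p, t, t_p, p_norm: float = 1.0):
    # Vector form of the TransD projection M_r(e) = (r_p e_p^T + I) e,
    # which reduces to (e_p . e) * r_p + e when node and relation
    # embeddings share the same dimensionality.
    h_proj = (h_p * h).sum(dim=-1, keepdim=True) * r_p + h
    t_proj = (t_p * t).sum(dim=-1, keepdim=True) * r_p + t
    # Higher (less negative) scores indicate more plausible triplets.
    return -(h_proj + r - t_proj).norm(p=p_norm, dim=-1)


# Toy usage with a batch of two triplets and 32-dimensional embeddings:
h, h_p, r, r_p, t, t_p = (torch.randn(2, 32) for _ in range(6))
assert transd_score(h, h_p, r, r_p, t, t_p).size() == (2, )
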
7 changes: 5 additions & 2 deletions test/nn/kge/test_transe.py
@@ -1,10 +1,13 @@
import pytest
import torch

from torch_geometric.nn import TransE


def test_transe():
model = TransE(num_nodes=10, num_relations=5, hidden_channels=32)
@pytest.mark.parametrize('bern', [False, True])
def test_transe(bern):
model = TransE(num_nodes=10, num_relations=5, hidden_channels=32,
bern=bern)
assert str(model) == 'TransE(10, num_relations=5, hidden_channels=32)'

head_index = torch.tensor([0, 2, 4, 6, 8])
2 changes: 2 additions & 0 deletions torch_geometric/nn/kge/__init__.py
@@ -1,6 +1,7 @@
r"""Knowledge Graph Embedding (KGE) package."""

from .base import KGEModel
from .transd import TransD
from .transe import TransE
from .complex import ComplEx
from .distmult import DistMult
@@ -9,6 +10,7 @@
__all__ = classes = [
'KGEModel',
'TransE',
'TransD',
'ComplEx',
'DistMult',
'RotatE',
42 changes: 37 additions & 5 deletions torch_geometric/nn/kge/base.py
@@ -6,6 +6,21 @@
from tqdm import tqdm

from torch_geometric.nn.kge.loader import KGTripletLoader
from torch_geometric.utils import scatter


def _avg_count_per_r(x_idx, r_idx):
# Assume no duplicate triples, so e.g. each occurrence of a tail index
# represents a different head to count for that tail.

# Map the tuple (x_idx, r_idx) to unique indices in a new combined index.
num_x = x_idx.max() + 1
rx_idx = r_idx * num_x + x_idx
# Get counts of each unique (x_idx, r_idx) pair.
unique_rx, rx_counts = torch.unique(rx_idx, return_counts=True)
# Average those counts grouped by r_idx.
r_idx = unique_rx // num_x
return scatter(rx_counts, r_idx, reduce='mean')


class KGEModel(torch.nn.Module):
@@ -24,12 +39,14 @@ def __init__(
num_relations: int,
hidden_channels: int,
sparse: bool = False,
bern: bool = False,
):
super().__init__()

self.num_nodes = num_nodes
self.num_relations = num_relations
self.hidden_channels = hidden_channels
self.bern = bern

self.node_emb = Embedding(num_nodes, hidden_channels, sparse=sparse)
self.rel_emb = Embedding(num_relations, hidden_channels, sparse=sparse)
@@ -150,16 +167,31 @@ def random_sample(
rel_type (torch.Tensor): The relation type.
tail_index (torch.Tensor): The tail indices.
"""
# Random sample either `head_index` or `tail_index` (but not both):
num_negatives = head_index.numel() // 2
rnd_index = torch.randint(self.num_nodes, head_index.size(),
device=head_index.device)

head_index = head_index.clone()
head_index[:num_negatives] = rnd_index[:num_negatives]
tail_index = tail_index.clone()
tail_index[num_negatives:] = rnd_index[num_negatives:]

if not self.bern:
# Random sample either `head_index` or `tail_index` (but not both):
num_negatives = head_index.numel() // 2

head_index[:num_negatives] = rnd_index[:num_negatives]
tail_index[num_negatives:] = rnd_index[num_negatives:]

return head_index, rel_type, tail_index

# Bernoulli: decide whether to corrupt the head or the tail based on the
# average number of heads per tail and tails per head for each relation.
# That is, if a relation has more tails per head than heads per tail, the
# head should be corrupted more often to produce fewer false negatives.
hpt = _avg_count_per_r(tail_index, rel_type)
tph = _avg_count_per_r(head_index, rel_type)
berns = tph / (tph + hpt)
head_mask = berns[rel_type].bernoulli().type(torch.bool)
tail_mask = ~head_mask
head_index[head_mask] = rnd_index[head_mask]
tail_index[tail_mask] = rnd_index[tail_mask]
return head_index, rel_type, tail_index

def __repr__(self) -> str:
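
To make the Bernoulli strategy concrete, the following self-contained sketch mirrors `_avg_count_per_r` and the corruption-probability computation above on a toy set of triplets; the triplets and the resulting numbers are invented purely for illustration:

import torch

from torch_geometric.utils import scatter


def avg_count_per_r(x_idx, r_idx):
    # Mirrors `_avg_count_per_r` above: average, per relation, how many
    # triplets share each (entity, relation) pair.
    num_x = x_idx.max() + 1
    rx_idx = r_idx * num_x + x_idx
    unique_rx, rx_counts = torch.unique(rx_idx, return_counts=True)
    return scatter(rx_counts, unique_rx // num_x, reduce='mean')


# A single relation 0 in which head 0 links to tails 1, 2 and 3, so on
# average there are 3 tails per head but only 1 head per tail:
head_index = torch.tensor([0, 0, 0])
rel_type = torch.tensor([0, 0, 0])
tail_index = torch.tensor([1, 2, 3])

tph = avg_count_per_r(head_index, rel_type)  # tails per head: 3
hpt = avg_count_per_r(tail_index, rel_type)  # heads per tail: 1
berns = tph / (tph + hpt)                    # P(corrupt head) = 0.75

# Corrupting the head 75% of the time makes it less likely that a sampled
# negative triplet is in fact a true triplet (a false negative).
head_mask = berns[rel_type].bernoulli().bool()
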
21 changes: 13 additions & 8 deletions torch_geometric/nn/kge/complex.py
@@ -18,6 +18,14 @@ class ComplEx(KGEModel):
.. math::
d(h, r, t) = Re(< \mathbf{e}_h, \mathbf{e}_r, \mathbf{e}_t>)

During training, the score is optimized by contrasting positive triplets
with corrupted (negative) triplets. By default, either the head or the
tail is corrupted uniformly at random. When :obj:`bern=True`, the head or
the tail is chosen for corruption with probability proportional to the
relation's average number of tails per head or heads per tail,
respectively, as described in the `"Knowledge Graph Embedding by
Translating on Hyperplanes"
<https://cdn.aaai.org/ojs/8870/8870-13-12398-1-2-20201228.pdf>`_ paper.

.. note::

For an example of using the :class:`ComplEx` model, see
@@ -32,14 +40,11 @@
sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to
the embedding matrices will be sparse. (default: :obj:`False`)
"""
def __init__(
self,
num_nodes: int,
num_relations: int,
hidden_channels: int,
sparse: bool = False,
):
super().__init__(num_nodes, num_relations, hidden_channels, sparse)
def __init__(self, num_nodes: int, num_relations: int,
hidden_channels: int, sparse: bool = False,
bern: bool = False):
super().__init__(num_nodes, num_relations, hidden_channels, sparse,
bern)

self.node_emb_im = Embedding(num_nodes, hidden_channels, sparse=sparse)
self.rel_emb_im = Embedding(num_relations, hidden_channels,
12 changes: 11 additions & 1 deletion torch_geometric/nn/kge/distmult.py
@@ -17,6 +17,14 @@ class DistMult(KGEModel):
.. math::
d(h, r, t) = < \mathbf{e}_h, \mathbf{e}_r, \mathbf{e}_t >

During training, the score is optimized by contrasting positive triplets
with corrupted (negative) triplets. By default, either the head or the
tail is corrupted uniformly at random. When :obj:`bern=True`, the head or
the tail is chosen for corruption with probability proportional to the
relation's average number of tails per head or heads per tail,
respectively, as described in the `"Knowledge Graph Embedding by
Translating on Hyperplanes"
<https://cdn.aaai.org/ojs/8870/8870-13-12398-1-2-20201228.pdf>`_ paper.

.. note::

For an example of using the :class:`DistMult` model, see
@@ -40,8 +48,10 @@ def __init__(
hidden_channels: int,
margin: float = 1.0,
sparse: bool = False,
bern: bool = False,
):
super().__init__(num_nodes, num_relations, hidden_channels, sparse)
super().__init__(num_nodes, num_relations, hidden_channels, sparse,
bern)

self.margin = margin

22 changes: 13 additions & 9 deletions torch_geometric/nn/kge/rotate.py
@@ -24,6 +24,14 @@ class RotatE(KGEModel):
.. math::
d(h, r, t) = - {\| \mathbf{e}_h \circ \mathbf{e}_r - \mathbf{e}_t \|}_p

During training, the score is optimized by contrasting positive triplets
with corrupted (negative) triplets. By default, either the head or the
tail is corrupted uniformly at random. When :obj:`bern=True`, the head or
the tail is chosen for corruption with probability proportional to the
relation's average number of tails per head or heads per tail,
respectively, as described in the `"Knowledge Graph Embedding by
Translating on Hyperplanes"
<https://cdn.aaai.org/ojs/8870/8870-13-12398-1-2-20201228.pdf>`_ paper.

.. note::

For an example of using the :class:`RotatE` model, see
@@ -39,15 +47,11 @@ class RotatE(KGEModel):
sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to
the embedding matrices will be sparse. (default: :obj:`False`)
"""
def __init__(
self,
num_nodes: int,
num_relations: int,
hidden_channels: int,
margin: float = 1.0,
sparse: bool = False,
):
super().__init__(num_nodes, num_relations, hidden_channels, sparse)
def __init__(self, num_nodes: int, num_relations: int,
hidden_channels: int, margin: float = 1.0,
sparse: bool = False, bern: bool = False):
super().__init__(num_nodes, num_relations, hidden_channels, sparse,
bern)

self.margin = margin
self.node_emb_im = Embedding(num_nodes, hidden_channels, sparse=sparse)