Skip to content

Commit

Permalink
Add LinkPredCoverage metric (#10006)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Feb 7, 2025
1 parent 2f1e4f2 commit b432363
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `Coverage` metric for link prediction ([#10006](https://github.com/pyg-team/pytorch_geometric/pull/10006))
- Added Graph Transformer Tutorial ([#8144](https://github.com/pyg-team/pytorch_geometric/pull/8144))
- Consolidate Cugraph examples into ogbn_train_cugraph.py and ogbn_train_cugraph_multigpu.py for ogbn-arxiv, ogbn-products and ogbn-papers100M ([#9953](https://github.com/pyg-team/pytorch_geometric/pull/9953))
- Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975))
Expand Down
24 changes: 24 additions & 0 deletions test/metrics/test_link_pred_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

from torch_geometric.metrics import (
LinkPredCoverage,
LinkPredF1,
LinkPredMAP,
LinkPredMetricCollection,
Expand Down Expand Up @@ -178,6 +179,29 @@ def test_mrr():
metric.reset()


def test_coverage():
pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]])
edge_label_index = torch.tensor([[0, 0, 2, 2, 3], [0, 1, 2, 1, 2]])

metric = LinkPredCoverage(k=2, num_dst_nodes=3)
assert str(metric) == 'LinkPredCoverage(k=2, num_dst_nodes=3)'
metric.update(pred_index_mat, edge_label_index)
result = metric.compute()
metric.reset()
assert metric.mask.sum() == 0

assert float(result) == 1.0

metric = LinkPredCoverage(k=1, num_dst_nodes=4)
assert str(metric) == 'LinkPredCoverage(k=1, num_dst_nodes=4)'
metric.update(pred_index_mat, edge_label_index)
result = metric.compute()
metric.reset()
assert metric.mask.sum() == 0

assert float(result) == 2 / 4


@pytest.mark.parametrize('num_src_nodes', [10])
@pytest.mark.parametrize('num_dst_nodes', [50])
@pytest.mark.parametrize('num_edges', [200])
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
LinkPredMAP,
LinkPredNDCG,
LinkPredMRR,
LinkPredCoverage,
)

link_pred_metrics = [
Expand All @@ -20,6 +21,7 @@
'LinkPredMAP',
'LinkPredNDCG',
'LinkPredMRR',
'LinkPredCoverage',
]

__all__ = link_pred_metrics
49 changes: 49 additions & 0 deletions torch_geometric/metrics/link_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,52 @@ def _compute(self, data: LinkPredMetricData) -> Tensor:
device = pred_rel_mat.device
arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)
return (pred_rel_mat / arange).max(dim=-1)[0]


class LinkPredCoverage(BaseMetric):
r"""A link prediction metric to compute the Coverage @ :math:`k`.
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
is_differentiable: bool = False
full_state_update: bool = True
higher_is_better: bool = True
weighted: bool = False

def __init__(self, k: int, num_dst_nodes: int) -> None:
super().__init__()

if k <= 0:
raise ValueError(f"'k' needs to be a positive integer in "
f"'{self.__class__.__name__}' (got {k})")

self.k = k
self.num_dst_nodes = num_dst_nodes

mask = torch.zeros(num_dst_nodes, dtype=torch.bool)
if WITH_TORCHMETRICS:
self.add_state('mask', mask, dist_reduce_fx='max')
else:
self.register_buffer('mask', mask)

def update(
self,
pred_index_mat: Tensor,
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
edge_label_weight: Optional[Tensor] = None,
) -> None:
self.mask[pred_index_mat[:, :self.k].view(-1)] = True

def compute(self) -> Tensor:
return self.mask.to(torch.get_default_dtype()).mean()

def reset(self) -> None:
if WITH_TORCHMETRICS:
super().reset()
else:
self.mask.zero_()

def __repr__(self) -> str:
return (f'{self.__class__.__name__}(k={self.k}, '
f'num_dst_nodes={self.num_dst_nodes})')

0 comments on commit b432363

Please sign in to comment.