Skip to content

Commit

Permalink
Merge branch 'master' into tagdataset/add-llm-exp-pred
Browse files Browse the repository at this point in the history
  • Loading branch information
xnuohz authored Feb 12, 2025
2 parents 9670e2b + 1ab3993 commit 486f880
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added llm generated explanations to `TAGDataset` ([#9918](https://github.com/pyg-team/pytorch_geometric/pull/9918))
- Added `Personalization` metric for link prediction ([#10015](https://github.com/pyg-team/pytorch_geometric/pull/10015))
- Added `HitRatio` metric for link prediction ([#10013](https://github.com/pyg-team/pytorch_geometric/pull/10013))
- Added Data Splitting Tutorial ([#8366](https://github.com/pyg-team/pytorch_geometric/pull/8366))
- Added `Diversity` metric for link prediction ([#10009](https://github.com/pyg-team/pytorch_geometric/pull/10009))
Expand Down
24 changes: 24 additions & 0 deletions test/metrics/test_link_pred_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
LinkPredMetricCollection,
LinkPredMRR,
LinkPredNDCG,
LinkPredPersonalization,
LinkPredPrecision,
LinkPredRecall,
)
from torch_geometric.testing import withCUDA


@pytest.mark.parametrize('num_src_nodes', [100])
Expand Down Expand Up @@ -235,6 +237,28 @@ def test_diversity():
assert pytest.approx(float(result)) == (1 + 2 / 3) / 2


@withCUDA
def test_personalization(device):
pred_index_mat = torch.tensor([[0, 1, 2, 3], [2, 1, 0, 4], [1, 0, 2, 5]],
device=device)
edge_label_index = torch.empty(2, 0, dtype=torch.long, device=device)

metric = LinkPredPersonalization(k=4).to(device)
assert str(metric) == 'LinkPredPersonalization(k=4)'
metric.update(pred_index_mat, edge_label_index)
result = metric.compute()
assert result.device == device
assert float(result) == 0.25
metric.reset()
assert metric.preds == []

metric.update(pred_index_mat[:0], edge_label_index)
result = metric.compute()
assert result.device == device
assert float(result) == 0.0
metric.reset()


@pytest.mark.parametrize('num_src_nodes', [10])
@pytest.mark.parametrize('num_dst_nodes', [50])
@pytest.mark.parametrize('num_edges', [200])
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
LinkPredHitRatio,
LinkPredCoverage,
LinkPredDiversity,
LinkPredPersonalization,
)

link_pred_metrics = [
Expand All @@ -26,6 +27,7 @@
'LinkPredHitRatio',
'LinkPredCoverage',
'LinkPredDiversity',
'LinkPredPersonalization',
]

__all__ = link_pred_metrics
83 changes: 83 additions & 0 deletions torch_geometric/metrics/link_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,3 +704,86 @@ def compute(self) -> Tensor:
def _reset(self) -> None:
self.accum.zero_()
self.total.zero_()


class LinkPredPersonalization(_LinkPredMetric):
r"""A link prediction metric to compute the Personalization @ :math:`k`,
*i.e.* the dissimilarity of recommendations across different users.
Higher personalization suggests that the model tailors recommendations to
individual user preferences rather than providing generic results.
Dissimilarity is defined by the average inverse cosine similarity between
users' lists of recommendations.
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
batch_size (int, optional): The batch size to determine how many pairs
of user recommendations should be processed at once.
(default: :obj:`2**16`)
"""
def __init__(self, k: int, batch_size: int = 2**16) -> None:
super().__init__(k)
self.batch_size = batch_size

if WITH_TORCHMETRICS:
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state('dev_tensor', torch.empty(0), dist_reduce_fx='sum')
else:
self.preds: List[Tensor] = []
self.register_buffer('dev_tensor', torch.empty(0))

def update(
self,
pred_index_mat: Tensor,
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
edge_label_weight: Optional[Tensor] = None,
) -> None:
# NOTE Move to CPU to avoid memory blowup.
self.preds.append(pred_index_mat[:, :self.k].cpu())

def compute(self) -> Tensor:
device = self.dev_tensor.device
score = torch.tensor(0.0, device=device)
total = torch.tensor(0, device=device)

if len(self.preds) == 0:
return score

pred = torch.cat(self.preds, dim=0)

if pred.size(0) == 0:
return score

# Calculate all pairs of nodes (e.g., triu_indices with offset=1).
# NOTE We do this in chunks to avoid memory blow-up, which leads to a
# more efficient but trickier implementation.
num_pairs = (pred.size(0) * (pred.size(0) - 1)) // 2
offset = torch.arange(pred.size(0) - 1, 0, -1, device=device)
rowptr = cumsum(offset)
for start in range(0, num_pairs, self.batch_size):
end = min(start + self.batch_size, num_pairs)
idx = torch.arange(start, end, device=device)

# Find the corresponding row:
row = torch.searchsorted(rowptr, idx, right=True) - 1
# Find the corresponding column:
col = idx - rowptr[row] + (pred.size(0) - offset[row])

left = pred[row.cpu()].to(device)
right = pred[col.cpu()].to(device)

# Use offset to work around applying `isin` along a specific dim:
i = max(left.max(), right.max()) + 1 # type: ignore
i = torch.arange(0, i * row.size(0), i, device=device).view(-1, 1)
isin = torch.isin(left + i, right + i)

# Compute personalization via average inverse cosine similarity:
cos = isin.sum(dim=-1) / pred.size(1)
score += (1 - cos).sum()
total += cos.numel()

return score / total

def _reset(self) -> None:
self.preds = []

0 comments on commit 486f880

Please sign in to comment.