Fix for `LargeGraphIndexer.from_triplets` ordering issue #9952

Closed · 12 commits
CHANGELOG.md (1 change: 1 addition & 0 deletions)

@@ -40,6 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed unique ordering in `LargeGraphIndexer.from_triplets` ([#9952](https://github.com/pyg-team/pytorch_geometric/pull/9952))
 - Fixed the `k_hop_subgraph()` method for directed graphs ([#9756](https://github.com/pyg-team/pytorch_geometric/pull/9756))
 - Fixed `utils.group_cat` concatenating dimension ([#9766](https://github.com/pyg-team/pytorch_geometric/pull/9766))
 - Fixed `WebQSDataset.process` raising exceptions ([#9665](https://github.com/pyg-team/pytorch_geometric/pull/9665))
test/data/test_large_graph_indexer.py (11 changes: 10 additions & 1 deletion)

@@ -20,7 +20,10 @@
 
 # create possible nodes and edges for graph
 strkeys = string.ascii_letters + string.digits
+
 NODE_POOL = list({"".join(random.sample(strkeys, 10)) for i in range(1000)})
+# insert a duplicate node to test the unique ordering of from_triplets
+NODE_POOL[1] = NODE_POOL[0]
 EDGE_POOL = list({"".join(random.sample(strkeys, 10)) for i in range(50)})


@@ -64,9 +67,15 @@ def test_basic_collate():
     assert len(set(big_indexer._nodes.values())) == len(big_indexer._nodes)
     assert len(set(big_indexer._edges.values())) == len(big_indexer._edges)
 
-    for node in (indexer_0._nodes.keys() | indexer_1._nodes.keys()):
+    for i, node in enumerate(indexer_0._nodes.keys()
+                             | indexer_1._nodes.keys()):
         assert big_indexer.node_attr[NODE_PID][
             big_indexer._nodes[node]] == node
+        if i == 1:
+            node_0 = list(indexer_0._nodes.keys())[0]
+            assert big_indexer.node_attr[NODE_PID][
+                big_indexer._nodes[node_0]] != big_indexer.node_attr[NODE_PID][
+                    big_indexer._nodes[node]]
 
 
 def test_large_graph_index():
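Because `NODE_POOL[1] = NODE_POOL[0]` plants a duplicate in the pool, the `i == 1` branch checks that the second unique node in the collated index holds a different value than the first, i.e. that deduplication kept the first occurrence and its position. A minimal sketch of the property being asserted (illustrative values, not the real test fixtures):

```python
pool = ["n0", "n0", "n1"]  # mirrors NODE_POOL[1] = NODE_POOL[0]

# Order-preserving dedup, as ordered_set() does internally:
unique = list(dict.fromkeys(pool))
assert unique == ["n0", "n1"]  # duplicate collapsed, first-seen order kept
assert unique[0] != unique[1]  # the i == 1 check: slots 0 and 1 differ
```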
torch_geometric/data/large_graph_indexer.py (16 changes: 10 additions & 6 deletions)

@@ -31,6 +31,7 @@
 
 
 def ordered_set(values: Iterable[str]) -> List[str]:
+    # dicts maintain order while keeping unique keys
    return list(dict.fromkeys(values))
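
For context on the helper: dict keys are unique and, since Python 3.7, preserve insertion order, while set iteration order depends on hashing and can change between runs. A quick sketch of the difference (standalone, not part of the diff):

```python
values = ["b", "a", "b", "c", "a"]

# Deterministic: always ['b', 'a', 'c'], in first-seen order.
print(list(dict.fromkeys(values)))
# list(set(values)) also deduplicates, but str hashes are randomized per
# process (PYTHONHASHSEED), so its order can differ from run to run.
```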


@@ -172,8 +173,9 @@ def from_triplets(
             LargeGraphIndexer: Index of unique nodes and edges.
         """
         # NOTE: Right now assumes that all trips can be loaded into memory
-        nodes = set()
-        edges = set()
+        # initialize as lists and then remove duplicates
+        nodes = []
+        edges = []
 
         if pre_transform is not None:
 
@@ -182,16 +184,18 @@ def apply_transform(
                 for trip in trips:
                     yield pre_transform(trip)
 
-            triplets = apply_transform(triplets)
+            triplets = list(apply_transform(triplets))
 
         for h, r, t in triplets:
 
             for node in (h, t):
-                nodes.add(node)
+                nodes.append(node)
 
             edge_idx = (h, r, t)
-            edges.add(edge_idx)
-
+            edges.append(edge_idx)
+        # removes duplicates while maintaining order
+        nodes = ordered_set(nodes)
+        edges = ordered_set(edges)
         return cls(list(nodes), list(edges))
 
     @classmethod
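
With lists plus `ordered_set`, node and edge indices now follow first appearance in the input, so repeated runs over the same triplets yield identical indices. A small usage sketch of the fixed behavior (assuming `LargeGraphIndexer` is exposed via `torch_geometric.data`, as in this codebase):

```python
from torch_geometric.data import LargeGraphIndexer

triplets = [
    ("alice", "knows", "bob"),
    ("bob", "knows", "carol"),
    ("alice", "likes", "carol"),  # "alice" and "carol" repeat; indexed once each
]

indexer = LargeGraphIndexer.from_triplets(triplets)
# Nodes are indexed in first-seen order on every run: alice, bob, carol.
# With the old set()-based collection, these ids could shuffle between runs.
```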