Skip to content

Commit

Permalink
fix: SetGNN edge-index generation; feat: add load_line_expansion_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhihenpidehou committed Apr 27, 2024
1 parent 9bbe8cc commit b1ed256
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,5 @@ Icon
Network Trash Folder
Temporary Items
.apdisk
allset_test.py

74 changes: 74 additions & 0 deletions easygraph/datasets/hypergraph/loadDeepSetDatasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import torch
import os.path as osp
import numpy as np
import scipy.sparse as sp
from torch_geometric.data import Data
from torch_sparse import coalesce

__all__ = ["load_line_expansion_dataset"]
def load_line_expansion_dataset(path=None, dataset="cocitation-cora", train_percent=0.5):
    """Load a hypergraph dataset stored in AllSet ``.content``/``.edges`` format.

    ``<dataset>.content`` holds one node per row
    (``node_id feature_1 ... feature_k label``); ``<dataset>.edges`` holds one
    ``(node_id, hyperedge_id)`` incidence pair per row.  The returned bipartite
    ``edge_index`` contains both (node, hyperedge) and (hyperedge, node)
    directions.

    Parameters
    ----------
    path : str
        Directory containing the ``<dataset>`` sub-directory.  Required.
    dataset : str
        Dataset name, e.g. ``"cocitation-cora"``.
    train_percent : float
        Fraction of nodes used for training; stored on the returned object.

    Returns
    -------
    torch_geometric.data.Data
        Graph with ``x``, ``edge_index``, ``y`` plus the extra attributes
        ``n_x``, ``train_percent`` and ``num_hyperedges``.

    Raises
    ------
    ValueError
        If ``path`` is not given, or the id layout in the edge file is not
        the expected consecutive node-then-hyperedge numbering.
    """
    # Fail early with a clear message instead of crashing inside osp.join.
    if path is None:
        raise ValueError("'path' must point to the dataset root directory")

    print(f"Loading {dataset} dataset...")

    # Node table: id, feature vector, label.
    p2idx_features_labels = osp.join(path, dataset, f"{dataset}.content")
    idx_features_labels = np.genfromtxt(p2idx_features_labels, dtype=np.dtype(str))
    features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
    # Labels are stored as numeric strings; go through float to tolerate "1.0".
    labels = torch.LongTensor(idx_features_labels[:, -1].astype(float))

    print("load features")

    # Map raw node ids to consecutive integer indices.
    idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
    idx_map = {j: i for i, j in enumerate(idx)}

    # Incidence pairs (node_id, hyperedge_id).
    p2edges_unordered = osp.join(path, dataset, f"{dataset}.edges")
    edges_unordered = np.genfromtxt(p2edges_unordered, dtype=np.int32)
    edges = np.array(
        list(map(idx_map.get, edges_unordered.flatten())), dtype=np.int32
    ).reshape(edges_unordered.shape)

    print("load edges")

    edge_index = edges.T
    # Validate with explicit raises rather than `assert`, which is stripped
    # under `python -O`.  Hyperedge ids must start right after the last node
    # id, and the combined id range must have no gaps.
    if edge_index[0].max() != edge_index[1].min() - 1:
        raise ValueError("hyperedge ids must start right after the last node id")
    if len(np.unique(edge_index)) != edge_index.max() + 1:
        raise ValueError("node/hyperedge ids must be consecutive (no missing ids)")

    num_nodes = edge_index[0].max() + 1
    num_he = edge_index[1].max() - num_nodes + 1
    # Add the reversed (hyperedge -> node) direction of every incidence pair.
    edge_index = np.hstack((edge_index, edge_index[::-1, :]))

    data = Data(
        x=torch.FloatTensor(np.array(features[:num_nodes].todense())),
        edge_index=torch.LongTensor(edge_index),
        y=labels[:num_nodes],
    )

    # coalesce sorts edge_index and removes duplicate incidence pairs.
    total_num_node_id_he_id = len(np.unique(edge_index))
    data.edge_index, data.edge_attr = coalesce(
        data.edge_index, None, total_num_node_id_he_id, total_num_node_id_he_id
    )

    # Extra attributes consumed by downstream models.
    data.n_x = num_nodes
    data.train_percent = train_percent
    data.num_hyperedges = num_he

    return data
52 changes: 47 additions & 5 deletions easygraph/model/hypergraphs/setgnn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -50,6 +52,7 @@ def __init__(
GPR=False,
LearnMask=False,
norm=None,
self_loop=True,
):
super(SetGNN, self).__init__()
"""
Expand All @@ -76,7 +79,8 @@ def __init__(
self.E2VConvs = nn.ModuleList()
self.bnV2Es = nn.ModuleList()
self.bnE2Vs = nn.ModuleList()

self.edge_index = None
self.self_loop = self_loop
if self.LearnMask:
self.Importance = nn.Parameter(torch.ones(norm.size()))

Expand Down Expand Up @@ -180,6 +184,43 @@ def __init__(
InputNorm=False,
)

def generate_edge_index(self, dataset, self_loop=False):
    """Build the bipartite (node, hyperedge) incidence ``edge_index``.

    Parameters
    ----------
    dataset : dict
        Must contain ``"edge_list"`` (iterable of hyperedges, each an
        iterable of node ids) and, when ``self_loop`` is True,
        ``"num_vertices"``.
    self_loop : bool
        If True, add a singleton hyperedge for every node that is not
        already the sole member of some hyperedge, then sort columns by
        node id.

    Returns
    -------
    torch.LongTensor of shape (2, num_incidences)
        Row 0 holds node ids, row 1 holds hyperedge ids.
    """
    node_row, he_row = [], []
    for he_id, he in enumerate(dataset["edge_list"]):
        for node in he:
            node_row.append(node)
            he_row.append(he_id)
    edge_index = torch.tensor([node_row, he_row], dtype=torch.long)

    if self_loop:
        # Nodes already covered by a singleton hyperedge do not need an
        # extra self-loop hyperedge.  Use plain ints (the original kept
        # 1-element tensors in a list, making `in` comparisons fragile).
        hyperedge_size = Counter(edge_index[1].tolist())
        skip_nodes = {
            int(edge_index[0][torch.where(edge_index[1] == he_id)[0]])
            for he_id, size in hyperedge_size.items()
            if size == 1
        }
        num_nodes = dataset["num_vertices"]
        # BUGFIX: new hyperedge ids must continue right after the largest
        # existing hyperedge id.  The original used len(edge_index[1]) + 1
        # (the incidence-entry count), which left gaps in the id range.
        next_he_id = int(edge_index[1].max()) + 1 if edge_index.numel() else 0
        new_edges = torch.zeros(
            (2, num_nodes - len(skip_nodes)), dtype=edge_index.dtype
        )
        col = 0
        for node in range(num_nodes):
            if node not in skip_nodes:
                new_edges[0][col] = node
                new_edges[1][col] = next_he_id
                next_he_id += 1
                col += 1

        edge_index = torch.cat((edge_index, new_edges), dim=1)
        # Stable sort so columns with equal node id keep insertion order
        # (torch.sort is otherwise not guaranteed stable).
        _, sorted_idx = torch.sort(edge_index[0], stable=True)
        edge_index = edge_index[:, sorted_idx]

    return edge_index

def reset_parameters(self):
for layer in self.V2EConvs:
layer.reset_parameters()
Expand All @@ -206,17 +247,18 @@ def forward(self, data):
data.norm: The weight for edges in bipartite graphs, correspond to data.edge_index
!!! Note that we output final node representation. Loss should be defined outside.
"""

x, edge_index= data["features"], data["edge_index"]
if self.edge_index is None:
self.edge_index = self.generate_edge_index(data, self.self_loop)
# print("generate_edge_index:", self.edge_index.shape)
x, edge_index = data["features"], self.edge_index
if data["weight"] == None:
norm = torch.ones(edge_index.size()[1])
else:
norm = data["weight"]

if self.LearnMask:
norm = self.Importance * norm
cidx = min(edge_index[1])
# edge_index[1] -= cidx # make sure we do not waste memory

reversed_edge_index = torch.stack([edge_index[1], edge_index[0]], dim=0)
if self.GPR:
xs = []
Expand Down
6 changes: 3 additions & 3 deletions easygraph/nn/convs/hypergraphs/halfnlh_conv.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

Expand Down Expand Up @@ -95,7 +96,7 @@ def forward(self, x, edge_index, norm, aggr="add"):
def message(self, x_j, norm):
    """Scale each incoming message by its per-edge normalization weight."""
    weight = norm.view(-1, 1)
    return x_j * weight

def aggregate(self, inputs, index, dim_size=None, aggr="add"):
def aggregate(self, inputs, index, dim_size=None, aggr="sum"):
r"""Aggregates messages from neighbors as
:math:`\square_{j \in \mathcal{N}(i)}`.
Expand All @@ -107,6 +108,5 @@ def aggregate(self, inputs, index, dim_size=None, aggr="add"):
:meth:`__init__` by the :obj:`aggr` argument.
"""
# ipdb.set_trace()
if aggr is None:
raise ValueError("aggr was not passed!")

return scatter(inputs, index, dim=self.node_dim, reduce=aggr)

0 comments on commit b1ed256

Please sign in to comment.