Skip to content

Commit

Permalink
fixed: hypergcn model
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhihenpidehou committed Dec 10, 2024
1 parent 06c504f commit 6c133cf
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 17 deletions.
22 changes: 7 additions & 15 deletions easygraph/datasets/hypergraph/cat_edge_Cooking.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ def __init__(self, data_root=None):
self.edge_labels_path = "https://gitlab.com/easy-graph/easygraph-data-cat-edge-cooking/-/raw/main/hyperedge-labels.txt?ref_type=heads&inline=false"
self.node_names_path = "https://gitlab.com/easy-graph/easygraph-data-cat-edge-cooking/-/raw/main/main/node-labels.txt?ref_type=heads&inline=false"
self.label_names_path = "https://gitlab.com/easy-graph/easygraph-data-cat-edge-cooking/-/raw/main/hyperedge-label-identities.txt?ref_type=heads&inline=false"
self.hyperedges_path = []
self.edge_labels_path = []
self.node_names_path = []
self.label_names_path = []
# self.hyperedges_path = []
# self.edge_labels_path = []
# self.node_names_path = []
# self.label_names_path = []
self.generate_hypergraph(
hyperedges_path=self.hyperedges_path,
edge_labels_path=self.edge_labels_path,
Expand Down Expand Up @@ -89,27 +89,19 @@ def fun(data):
self._hyperedges.append(tuple(hyperedge))
# print(self.hyperedges)

edge_labels_info = request_text_from_url(edge_labels_path)
edge_labels_info = request_text_from_url(self.edge_labels_path)
process_node_labels_info = self.process_label_txt(
node_labels_info, transform_fun=fun
)
self._edge_labels = process_edge_labels_info
# print("process_node_labels_info:", process_node_labels_info)
self._edge_labels = process_edge_labels_info()

node_names_info = request_text_from_url(node_names_path)
process_node_names_info = self.process_label_txt(node_names_info)
self._node_names = process_node_names_info

# print("process_node_names_info:", process_node_names_info)
label_names_info = request_text_from_url(label_names_path)
process_label_names_info = self.process_label_txt(label_names_info)
self._label_names = process_label_names_info
# print("process_label_names_info:", process_label_names_info)


#
# if __name__ == "__main__":
# a = House_Committees()
# print(a.node_labels)
# print(a.label_names)
# print(a.node_names)

4 changes: 2 additions & 2 deletions easygraph/model/hypergraphs/hypergcn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn

from easygraph.classes import Graph
from easygraph.classes import Hypergraph
from easygraph.nn import HyperGCNConv


Expand Down Expand Up @@ -56,7 +56,7 @@ def forward(self, X: torch.Tensor, hg: "eg.Hypergraph") -> torch.Tensor:
"""
if self.fast:
if self.cached_g is None:
self.cached_g = Graph.from_hypergraph_hypergcn(
self.cached_g = Hypergraph.from_hypergraph_hypergcn(
hg, X, self.with_mediator
)
for layer in self.layers:
Expand Down

0 comments on commit 6c133cf

Please sign in to comment.