diff --git a/scitsr/data/loader.py b/scitsr/data/loader.py index a88c6d1..38f4e25 100644 --- a/scitsr/data/loader.py +++ b/scitsr/data/loader.py @@ -221,7 +221,7 @@ def get_vertexes(self, chunks): def get_vertex_features(self, vertexes): vertex_features = [] for vertex in vertexes: - features = [v for v in vertex.get_features().values()] + features = list(vertex.features.values()) vertex_features.append(features) return vertex_features @@ -238,7 +238,7 @@ def get_edges(self, relations, vertexes): edge_features = [] for i, j, _ in relations: edge = Edge(vertexes[i], vertexes[j]) - features = [v for v in edge.get_features().values()] + features = list(edge.features.values()) edge_features.append(features) return edge_features @@ -310,4 +310,4 @@ def to_tensors(self, nodes, edges, adj, incidence): edges = torch.tensor(edges, dtype=torch.float) adj = torch.tensor(adj, dtype=torch.long) incidence = torch.tensor(incidence, dtype=torch.long) - return nodes, edges, adj, incidence \ No newline at end of file + return nodes, edges, adj, incidence