# TODO support edge features, edge weights, node features and more;
# currently supports timestamp, source and destination only.
def export_full_data(self) -> dict:
    r"""
    Convert ``self.data`` into a dictionary of numpy arrays, similar to the
    ``full_data`` of a TGB LinkPropPredDataset.

    ``self.data`` is read as a mapping ``{timestamp: {(u, v): ...}}``; every
    ``(u, v)`` key contributes one row, in the iteration order of the
    mapping. Only the keys of the inner mapping are used, the values stored
    for each edge are ignored.

    Returns:
        dict with the keys ``"sources"``, ``"destinations"`` and
        ``"timestamps"``; each value is a 1-D ``np.int64`` array with one
        entry per edge, and the three arrays are aligned by index.
    """
    edgelist = self.data

    # Size the output from the very mapping that is iterated below instead of
    # a separate edge counter, so the arrays can neither overflow nor end with
    # unfilled zero rows if the two ever disagree.
    num_edge = sum(len(edge_data) for edge_data in edgelist.values())
    sources = np.zeros(num_edge, dtype=np.int64)
    destinations = np.zeros(num_edge, dtype=np.int64)
    timestamps = np.zeros(num_edge, dtype=np.int64)

    # Flatten the nested mapping into (timestamp, source, destination) rows.
    rows = (
        (ts, u, v)
        for ts, edge_data in edgelist.items()
        for u, v in edge_data
    )
    for idx, (ts, u, v) in enumerate(rows):
        sources[idx] = u
        destinations[idx] = v
        timestamps[idx] = ts

    return {
        "sources": sources,
        "destinations": destinations,
        "timestamps": timestamps,
    }