Skip to content

Commit

Permalink
refine code
Browse files Browse the repository at this point in the history
Signed-off-by: kaixuan <[email protected]>
  • Loading branch information
kaixuanliu committed Mar 15, 2024
1 parent 8a4f02f commit 3e94be1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 17 deletions.
4 changes: 2 additions & 2 deletions examples/igbh/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,12 @@ def partition_dataset(src_path: str,

if use_graph_caching:
base_path = osp.join(dst_path, f'{dataset_size}-partitions', 'graph')
convert_graph_layout(base_path, compress_edge_dict, layout, use_graph_caching)
convert_graph_layout(base_path, compress_edge_dict, layout)

else:
for pidx in range(num_partitions):
base_path = osp.join(dst_path, f'{dataset_size}-partitions', f'part{pidx}', 'graph')
convert_graph_layout(base_path, compress_edge_dict, layout, use_graph_caching)
convert_graph_layout(base_path, compress_edge_dict, layout)

if __name__ == '__main__':
root = osp.join(osp.dirname(osp.dirname(osp.dirname(osp.realpath(__file__)))), 'data', 'igbh')
Expand Down
42 changes: 27 additions & 15 deletions graphlearn_torch/python/partition/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def save_edge_pb(
def save_graph_cache(
output_dir: str,
graph_partition_list: List[GraphPartitionData],
etype: Optional[EdgeType] = None
etype: Optional[EdgeType] = None,
with_edge_feat: bool = False
):
r""" Save full graph topology into the output directory.
"""
Expand All @@ -116,6 +117,9 @@ def save_graph_cache(
weights = torch.cat([graph_partition.weights for graph_partition in graph_partition_list])
torch.save(rows, os.path.join(subdir, 'rows.pt'))
torch.save(cols, os.path.join(subdir, 'cols.pt'))
if with_edge_feat:
edge_ids = torch.cat([graph_partition.eids for graph_partition in graph_partition_list])
torch.save(edge_ids, os.path.join(subdir, 'eids.pt'))
if weights is not None:
torch.save(weights, os.path.join(subdir, 'weights.pt'))

Expand Down Expand Up @@ -545,14 +549,18 @@ def partition(self, with_feature=True, graph_caching=False):

for etype in self.edge_types:
graph_list, edge_pb = self._partition_graph(node_pb_dict, etype)
edge_feat = self.get_edge_feat(etype)
with_edge_feat = (edge_feat != None)
if graph_caching:
save_graph_cache(self.output_dir, graph_list, etype)
if with_edge_feat:
save_edge_pb(self.output_dir, edge_pb, etype)
save_graph_cache(self.output_dir, graph_list, etype, with_edge_feat)
else:
save_edge_pb(self.output_dir, edge_pb, etype)
for pidx in range(self.num_parts):
save_graph_partition(self.output_dir, pidx, graph_list[pidx], etype)
if with_feature:
self._partition_and_save_edge_feat(graph_list, etype)
if with_feature:
self._partition_and_save_edge_feat(graph_list, etype)

else:
node_ids_list, node_pb = self._partition_node()
Expand All @@ -561,14 +569,19 @@ def partition(self, with_feature=True, graph_caching=False):
self._partition_and_save_node_feat(node_ids_list)

graph_list, edge_pb = self._partition_graph(node_pb)
save_edge_pb(self.output_dir, edge_pb)
edge_feat = self.get_edge_feat()
with_edge_feat = (edge_feat != None)

if graph_caching:
save_graph_cache(self.output_dir, graph_list)
if with_edge_feat:
save_edge_pb(self.output_dir, edge_pb)
save_graph_cache(self.output_dir, graph_list, with_edge_feat)
else:
save_edge_pb(self.output_dir, edge_pb)
for pidx in range(self.num_parts):
save_graph_partition(self.output_dir, pidx, graph_list[pidx])
if with_feature:
self._partition_and_save_edge_feat(graph_list)
if with_feature:
self._partition_and_save_edge_feat(graph_list)

# save meta.
save_meta(self.output_dir, self.num_parts, self.data_cls,
Expand Down Expand Up @@ -843,13 +856,12 @@ def load_partition(
os.path.join(node_pb_dir, f'{as_str(ntype)}.pt'), map_location=device)

edge_pb_dict = {}
if not graph_caching:
edge_pb_dir = os.path.join(root_dir, 'edge_pb')
for etype in meta['edge_types']:
edge_pb_file = os.path.join(edge_pb_dir, f'{as_str(etype)}.pt')
if os.path.exists(edge_pb_file):
edge_pb_dict[etype] = torch.load(
edge_pb_file, map_location=device)
edge_pb_dir = os.path.join(root_dir, 'edge_pb')
for etype in meta['edge_types']:
edge_pb_file = os.path.join(edge_pb_dir, f'{as_str(etype)}.pt')
if os.path.exists(edge_pb_file):
edge_pb_dict[etype] = torch.load(
edge_pb_file, map_location=device)

return (
num_partitions, partition_idx,
Expand Down

0 comments on commit 3e94be1

Please sign in to comment.