Distributed support for v6d GraphScope #116

Merged · 24 commits · Feb 28, 2024 · Changes from 17 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -15,3 +15,4 @@ detect_cuda_*
.idea/
cmake-build-release/
cmake-build-debug/
/graphlearn_torch/python/*.so
40 changes: 30 additions & 10 deletions graphlearn_torch/python/data/dataset.py
@@ -15,8 +15,8 @@

import logging
from multiprocessing.reduction import ForkingPickler
from typing import Dict, List, Optional, Union, Literal, Tuple
from enum import Enum
from typing import Dict, List, Optional, Union, Literal, Tuple, Callable
from collections.abc import Sequence

import torch

@@ -125,6 +125,7 @@ def random_node_split(
self,
num_val: Union[float, int],
num_test: Union[float, int],
id_filter: Callable = None
):
r"""Performs a node-level random split by adding :obj:`train_idx`,
:obj:`val_idx` and :obj:`test_idx` attributes to the
@@ -161,6 +162,7 @@ def load_vineyard(
node_features: Dict[NodeType, List[str]] = None,
edge_features: Dict[EdgeType, List[str]] = None,
node_labels: Dict[NodeType, str] = None,
id2idx: Dict[NodeType, Sequence] = None,
):
# TODO(hongyi): GPU support
is_homo = len(edges) == 1 and edges[0][0] == edges[0][2]
@@ -198,7 +200,7 @@ def load_vineyard(
load_vertex_feature_from_vineyard(vineyard_socket, vineyard_id, property_names, ntype)
if is_homo:
node_feature_data = node_feature_data[edges[0][0]]
self.init_node_features(node_feature_data=node_feature_data, with_gpu=False)
self.init_node_features(node_feature_data=node_feature_data, id2idx=id2idx, with_gpu=False)

# load edge features
if edge_features:
@@ -221,12 +223,13 @@ def load_vineyard(

if is_homo:
node_label_data = node_label_data[edges[0][0]]
self.init_node_labels(node_label_data=node_label_data)
self.init_node_labels(node_label_data=node_label_data, id2idx=id2idx, with_gpu=False)

def init_node_features(
self,
node_feature_data: Union[TensorDataType, Dict[NodeType, TensorDataType]] = None,
id2idx: Union[TensorDataType, Dict[NodeType, TensorDataType]] = None,
id2idx: Union[TensorDataType, Dict[NodeType, TensorDataType],
Sequence, Dict[NodeType, Sequence]] = None,
sort_func = None,
split_ratio: Union[float, Dict[NodeType, float]] = 0.0,
device_group_list: Optional[List[DeviceGroup]] = None,
@@ -331,7 +334,10 @@ def init_edge_features(

def init_node_labels(
self,
node_label_data: Union[TensorDataType, Dict[NodeType, TensorDataType]] = None
node_label_data: Union[TensorDataType, Dict[NodeType, TensorDataType]] = None,
id2idx: Union[TensorDataType, Dict[NodeType, TensorDataType],
Sequence, Dict[NodeType, Sequence]] = None,
with_gpu: bool = False,
):
r""" Initialize the node label storage.

@@ -341,7 +347,12 @@
(default: ``None``)
"""
if node_label_data is not None:
self.node_labels = squeeze(convert_to_tensor(node_label_data))
self.node_labels = convert_to_tensor(node_label_data, dtype=torch.int64)
id2idx = convert_to_tensor(id2idx)
if id2idx is not None:
self.node_labels = _build_features(
self.node_labels, id2idx, 0.0, None, None, with_gpu, None
)

def init_node_split(
self,
@@ -413,11 +424,20 @@ def get_edge_feature(self, etype: Optional[EdgeType] = None):
return None

def get_node_label(self, ntype: Optional[NodeType] = None):
if isinstance(self.node_labels, torch.Tensor):
if isinstance(self.node_labels, dict) and ntype is not None:
if isinstance(self.node_labels[ntype], torch.Tensor):
self.node_labels[ntype] = Feature(self.node_labels[ntype].reshape(-1,1),
dtype=self.node_labels[ntype].dtype)
return self.node_labels.get(ntype, None)
if isinstance(self.node_labels, Feature):
return self.node_labels
if isinstance(self.node_labels, torch.Tensor):
return Feature(self.node_labels.reshape(-1,1), dtype=self.node_labels.dtype)
if isinstance(self.node_labels, dict):
assert ntype is not None
return self.node_labels.get(ntype, None)
for ntype, labels in self.node_labels.items():
if isinstance(labels, torch.Tensor):
self.node_labels[ntype] = Feature(labels.reshape(-1,1), dtype=labels.dtype)
return self.node_labels
return None

def __getitem__(self, key):
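Note: the dataset.py changes above thread an optional id2idx mapping through feature and label initialization, so lookups keyed by vineyard global vertex ids can be translated to local row offsets before indexing. Below is a minimal, self-contained sketch of that remapping semantics; the Gid2Lid class and all concrete values are illustrative assumptions (cf. VineyardGid2Lid later in this PR), not GLT APIs.

import torch

class Gid2Lid:
    """Hypothetical stand-in for a gid -> local-index mapping:
    local index = global id minus the fragment's vertex offset."""
    def __init__(self, offset: int, vnum: int):
        self.offset, self.vnum = offset, vnum

    def __getitem__(self, gids: torch.Tensor) -> torch.Tensor:
        return gids - self.offset  # element-wise on tensors

    def __len__(self) -> int:
        return self.vnum

labels = torch.tensor([3, 1, 2])      # label rows stored by local offset 0..2
id2idx = Gid2Lid(offset=100, vnum=3)  # this fragment owns gids 100..102
gids = torch.tensor([102, 100])
print(labels[id2idx[gids]])           # tensor([2, 3])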
7 changes: 4 additions & 3 deletions graphlearn_torch/python/data/feature.py
@@ -15,7 +15,8 @@

import threading
from multiprocessing.reduction import ForkingPickler
from typing import List, Optional
from typing import List, Optional, Union
from collections.abc import Sequence

import torch

@@ -100,7 +101,7 @@ class Feature(object):
"""
def __init__(self,
feature_tensor: TensorDataType,
id2index: Optional[torch.Tensor] = None,
id2index: Optional[Union[torch.Tensor, Sequence]] = None,
split_ratio: float = 0.0,
device_group_list: Optional[List[DeviceGroup]] = None,
device: Optional[int] = None,
@@ -210,7 +211,7 @@ def share_ipc(self):
if self._ipc_handle is not None:
return self._ipc_handle

if self.id2index is not None:
if self.id2index is not None and isinstance(self.id2index, torch.Tensor):
self.id2index = self.id2index.cpu()
self.id2index.share_memory_()

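Note: feature.py now accepts any Sequence as id2index, not only a torch.Tensor, and share_ipc moves id2index into shared memory only when it actually is a tensor. The sketch below (the stand-in class and values are assumptions) shows why a lazy Sequence suffices: Feature only needs an element-wise __getitem__ over a tensor of ids, and a computed mapping has no backing storage to share over IPC.

import torch
from collections.abc import Sequence

class OffsetId2Index(Sequence):
    """Hypothetical lazy id2index: rebases ids by a fixed offset on access,
    so there is no backing tensor to place in shared memory."""
    def __init__(self, offset: int, length: int):
        self._offset, self._length = offset, length

    def __getitem__(self, ids):
        return ids - self._offset  # works element-wise on tensors

    def __len__(self):
        return self._length

feat = torch.arange(12.0).reshape(4, 3)          # 4 feature rows
id2index = OffsetId2Index(offset=1000, length=4)  # ids 1000..1003
ids = torch.tensor([1002, 1001])
print(feat[id2index[ids]])                        # rows 2 and 1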
92 changes: 91 additions & 1 deletion graphlearn_torch/python/data/vineyard_utils.py
@@ -14,10 +14,19 @@
# ==============================================================================

try:
from .. import py_graphlearn_torch_vineyard as pywrap
import torch
from typing import Dict
from collections.abc import Sequence

from .. import py_graphlearn_torch_vineyard as pywrap


except ImportError:
pass

from ..partition import PartitionBook


def vineyard_to_csr(sock, fid, v_label_name, e_label_name, edge_dir, haseid=0):
'''
Wrap to_csr function to read graph from vineyard
@@ -42,3 +51,84 @@ def load_edge_feature_from_vineyard(sock, fid, ecols, e_label_name):
return edge_feature(torch.Tensor)
'''
return pywrap.load_edge_feature_from_vineyard(sock, fid, e_label_name, ecols)


def get_fid_from_gid(gid):
'''
Wrap get_fid_from_gid function to get fid from gid
'''
return pywrap.get_fid_from_gid(gid)


def get_frag_vertex_offset(sock, fid, v_label_name):
'''
Wrap GetFragVertexOffset function to get vertex offset of a fragment.
'''
return pywrap.get_frag_vertex_offset(sock, fid, v_label_name)


def get_frag_vertex_num(sock, fid, v_label_name):
'''
Wrap GetFragVertexNum function to get vertex number of a fragment.
'''
return pywrap.get_frag_vertex_num(sock, fid, v_label_name)


class VineyardPartitionBook(PartitionBook):
def __init__(self, sock, obj_id, v_label_name, fid2pid: Dict=None):
self._sock = sock
self._obj_id = obj_id
self._v_label_name = v_label_name
self._frag = None
self._offset = get_frag_vertex_offset(sock, obj_id, v_label_name)
# TODO: optimise this query process if too slow
self._fid2pid = fid2pid

def __getitem__(self, gids) -> torch.Tensor:
fids = self.gid2fid(gids)
if self._fid2pid is not None:
pids = torch.tensor([self._fid2pid[fid] for fid in fids])
return pids.to(torch.int32)
return fids.to(torch.int32)

@property
def device(self):
return torch.device('cpu')

@property
def offset(self):
return self._offset

def gid2fid(self, gids):
'''
Parse gid to get fid
'''
if self._frag is None:
self._frag = pywrap.VineyardFragHandle(self._sock, self._obj_id)

fids = self._frag.get_fid_from_gid(gids.tolist())

return fids


class VineyardGid2Lid(Sequence):
def __init__(self, sock, fid, v_label_name):
self._offset = get_frag_vertex_offset(sock, fid, v_label_name)
self._vnum = get_frag_vertex_num(sock, fid, v_label_name)

def __getitem__(self, gids):
return gids - self._offset

def __len__(self):
return self._vnum


def v6d_id_select(srcs, p_mask, node_pb: PartitionBook):
gids = torch.masked_select(srcs, p_mask)
offsets = gids - node_pb.offset
return offsets

def v6d_id_filter(node_pb: VineyardPartitionBook, partition_idx):
frag = pywrap.VineyardFragHandle(node_pb._sock, node_pb._obj_id)
inner_vertices = frag.get_inner_vertices(node_pb._v_label_name)
return inner_vertices
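Note: v6d_id_select above selects the global ids a partition mask claims and rebases them to fragment-local offsets. A self-contained sketch of that arithmetic follows; the concrete ids, mask, and offset are made up for illustration, with node_pb.offset standing in as a plain int.

import torch

srcs = torch.tensor([100, 205, 101, 207])          # global ids seen by the sampler
p_mask = torch.tensor([True, False, True, False])  # which ids this partition owns
offset = 100                                       # this fragment's vertex offset

gids = torch.masked_select(srcs, p_mask)  # tensor([100, 101])
local = gids - offset                     # tensor([0, 1]), local row offsets
print(local)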
55 changes: 25 additions & 30 deletions graphlearn_torch/python/distributed/dist_dataset.py
@@ -14,17 +14,21 @@
# ==============================================================================

from multiprocessing.reduction import ForkingPickler
from typing import Dict, List, Optional, Union, Literal, Tuple
from typing import Dict, List, Optional, Union, Literal, Tuple, Callable
from collections.abc import Sequence

import torch

from ..data import Dataset, Graph, Feature, DeviceGroup
from ..partition import load_partition, cat_feature_cache
from ..partition import (
load_partition, cat_feature_cache,
PartitionBook, HeteroNodePartitionDict, HeteroEdgePartitionDict
)
from ..typing import (
NodeType, EdgeType, TensorDataType, NodeLabel, NodeIndex,
PartitionBook, HeteroNodePartitionDict, HeteroEdgePartitionDict
)
from ..utils import share_memory

from ..utils import share_memory, default_id_filter


class DistDataset(Dataset):
@@ -175,6 +179,7 @@ def random_node_split(
self,
num_val: Union[float, int],
num_test: Union[float, int],
id_filter: Callable = default_id_filter,
):
r"""Performs a node-level random split by adding :obj:`train_idx`,
:obj:`val_idx` and :obj:`test_idx` attributes to the
@@ -196,10 +201,10 @@
test_idx = {}

for node_type, _ in self.node_labels.items():
indices = torch.where(self.node_pb[node_type] == self.partition_idx)[0]
indices = id_filter(self.node_pb[node_type], self.partition_idx)
train_idx[node_type], val_idx[node_type], test_idx[node_type] = random_split(indices, num_val, num_test)
else:
indices = torch.where(self.node_pb == self.partition_idx)[0]
indices = id_filter(self.node_pb, self.partition_idx)
train_idx, val_idx, test_idx = random_split(indices, num_val, num_test)
self.init_node_split((train_idx, val_idx, test_idx))

@@ -212,43 +217,33 @@ def load_vineyard(
node_features: Dict[NodeType, List[str]] = None,
edge_features: Dict[EdgeType, List[str]] = None,
node_labels: Dict[NodeType, str] = None,
id2idx: Dict[NodeType, Sequence] = None,
):
# TODO(hongyi): support more than one partition
super().load_vineyard(vineyard_id=vineyard_id, vineyard_socket=vineyard_socket,
edges=edges, edge_weights=edge_weights, node_features=node_features,
edge_features=edge_features, node_labels=node_labels,)
edge_features=edge_features, node_labels=node_labels, id2idx=id2idx)
if isinstance(self.graph, dict):
# hetero
self.node_pb = {}
self.edge_pb = {}
for etype, graph in self.graph.items():
self.node_pb[etype[0]] = torch.zeros(graph.row_count)
self.edge_pb[etype] = torch.zeros(graph.edge_count)

self._node_feat_pb = {}
if node_features:
for ntype, nfeat in self.node_features.items():
self._node_feat_pb[ntype] = torch.zeros(nfeat.shape[0])

self._edge_feat_pb = {}
if edge_features:
for etype, efeat in self.edge_features.items():
self._edge_feat_pb[etype] = torch.zeros(efeat.shape[0])
for ntype, _ in self.node_features.items():
if self.node_pb is not None:
self._node_feat_pb[ntype] = self.node_pb[ntype]
else:
self._node_feat_pb[ntype] = None
else:
# homo
self.node_pb = torch.zeros(self.graph.row_count)
self.edge_pb = torch.zeros(self.graph.edge_count)
if node_features:
self._node_feat_pb = torch.zeros(self.node_features.shape[0])
if edge_features:
self._edge_feat_pb = torch.zeros(self.edge_features.shape[0])
self._node_feat_pb = self.node_pb

def share_ipc(self):
super().share_ipc()
self.node_pb = share_memory(self.node_pb)
self.edge_pb = share_memory(self.edge_pb)
self._node_feat_pb = share_memory(self._node_feat_pb)
self._edge_feat_pb = share_memory(self._edge_feat_pb)
if isinstance(self.node_pb, torch.Tensor):
self.node_pb = share_memory(self.node_pb)
self.edge_pb = share_memory(self.edge_pb)
self._node_feat_pb = share_memory(self._node_feat_pb)
self._edge_feat_pb = share_memory(self._edge_feat_pb)
ipc_hanlde = (
self.num_partitions, self.partition_idx,
self.graph, self.node_features, self.edge_features, self.node_labels,
@@ -315,4 +310,4 @@ def random_split(
val_idx = indices[perm[:num_val]].clone()
test_idx = indices[perm[num_val:num_val + num_test]].clone()
train_idx = indices[perm[num_val + num_test:]].clone()
return train_idx, val_idx, test_idx
return train_idx, val_idx, test_idx
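Note: random_node_split now takes an id_filter so that vineyard-backed partition books, which are offset-based rather than dense node-to-partition tensors, can supply their own notion of "nodes in this partition". A sketch of the contract follows; the body of default_id_filter is an assumption matching the inline logic it replaces above.

import torch

def default_id_filter(node_pb: torch.Tensor, partition_idx: int) -> torch.Tensor:
    # Assumed equivalent of the old inline logic: local indices of the
    # nodes whose partition book entry equals this partition.
    return torch.where(node_pb == partition_idx)[0]

node_pb = torch.tensor([0, 1, 0, 0, 1])  # node -> partition assignment
print(default_id_filter(node_pb, 0))     # tensor([0, 2, 3])

A vineyard deployment would instead pass something like v6d_id_filter, which asks the fragment handle for its inner vertices rather than scanning a tensor.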
14 changes: 12 additions & 2 deletions graphlearn_torch/python/distributed/dist_feature.py
@@ -20,8 +20,11 @@
from ..data import Feature
from ..typing import (
EdgeType, NodeType,
PartitionBook, HeteroNodePartitionDict, HeteroEdgePartitionDict
)
from ..partition import (
PartitionBook, GLTPartitionBook, HeteroNodePartitionDict, HeteroEdgePartitionDict
)

from ..utils import get_available_device, ensure_device

from .rpc import (
@@ -87,7 +90,8 @@ def __init__(self,
if isinstance(self.local_feature, dict):
self.data_cls = 'hetero'
for _, feat in self.local_feature.items():
feat.lazy_init_with_ipc_handle()
if isinstance(feat, Feature):
feat.lazy_init_with_ipc_handle()
elif isinstance(self.local_feature, Feature):
self.data_cls = 'homo'
self.local_feature.lazy_init_with_ipc_handle()
@@ -97,8 +101,14 @@
self.feature_pb = feature_pb
if isinstance(self.feature_pb, dict):
assert self.data_cls == 'hetero'
for key, feat in self.feature_pb.items():
if not isinstance(feat, PartitionBook):
self.feature_pb[key] = GLTPartitionBook(feat)
elif isinstance(self.feature_pb, PartitionBook):
assert self.data_cls == 'homo'
elif isinstance(self.feature_pb, torch.Tensor):
self.feature_pb = GLTPartitionBook(self.feature_pb)
assert self.data_cls == 'homo'
else:
raise ValueError(f"'{self.__class__.__name__}': found invalid input "
f"patition book type '{type(self.feature_pb)}'")
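Note: DistFeature now normalizes raw tensor partition books into GLTPartitionBook so every book answers id lookups through the same interface. A self-contained sketch of that idea follows; both classes here are simplified assumptions for illustration, not the GLT definitions.

import torch

class PartitionBook:
    """Minimal assumed interface: map a tensor of node ids to partition ids."""
    def __getitem__(self, ids: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

class GLTPartitionBook(PartitionBook):
    """Tensor-backed book: row i stores the partition of node i."""
    def __init__(self, pb: torch.Tensor):
        self._pb = pb

    def __getitem__(self, ids: torch.Tensor) -> torch.Tensor:
        return self._pb[ids]

feature_pb = torch.tensor([0, 0, 1, 1])        # raw tensor, as before this PR
if not isinstance(feature_pb, PartitionBook):  # wrap it once, up front
    feature_pb = GLTPartitionBook(feature_pb)
print(feature_pb[torch.tensor([2, 0])])        # tensor([1, 0])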