diff --git a/graphlearn_torch/python/distributed/dist_server.py b/graphlearn_torch/python/distributed/dist_server.py index 283893c4..264e650d 100644 --- a/graphlearn_torch/python/distributed/dist_server.py +++ b/graphlearn_torch/python/distributed/dist_server.py @@ -20,7 +20,7 @@ import warnings import torch -from graphscope.learning.graphlearn_torch.partition.base import PartitionBook +from ..partition import PartitionBook from ..channel import ShmChannel, QueueTimeoutError from ..sampler import NodeSamplerInput, EdgeSamplerInput, SamplingConfig, RemoteSamplerInput @@ -95,7 +95,7 @@ def get_node_partition_id(self, node_type, index): def get_node_feature(self, node_type, index): feature = self.dataset.get_node_feature(node_type) - return feature[index] + return feature[index].cpu() def get_tensor_size(self, node_type): feature = self.dataset.get_node_feature(node_type)