Skip to content

Commit

Permalink
feat:try to add reverse_csr
Browse files Browse the repository at this point in the history
  • Loading branch information
jessicawwen committed Jan 1, 2025
1 parent 3097f32 commit d77eb80
Show file tree
Hide file tree
Showing 7 changed files with 599 additions and 271 deletions.
266 changes: 266 additions & 0 deletions cpp_easygraph/classes/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Graph::Graph() {
this->dirty_adj = true;
this->linkgraph_dirty = true;
this->csr_graph = nullptr;
this->in_csr_graph = nullptr;
this->node_to_id = py::dict();
this->id_to_node = py::dict();
this->graph = py::dict();
Expand Down Expand Up @@ -715,6 +716,81 @@ void Graph::drop_cache() {
csr_graph = nullptr;
}

std::shared_ptr<CSRGraph> Graph::gen_CSR_fast(const std::string& weight_key) {
// Step 2: 收集所有节点 ID 并排序
std::vector<node_t> node_vec;
for (const auto& item : this->node_to_id) {
node_vec.push_back(item.second.cast<node_t>());
}
std::sort(node_vec.begin(), node_vec.end());

// Step 3: 遍历邻接表,构造 edges
int num_nodes = node_vec.size();
std::vector<int> rowPtrOut(num_nodes + 1, 0);
std::vector<int> rowPtrIn(num_nodes + 1, 0);
std::vector<std::tuple<int, int, double>> edges;

for (const auto& adj_item : this->adj) {
node_t u_id = adj_item.first;
for (const auto& neighbor : adj_item.second) {
node_t v_id = neighbor.first;
const auto& edge_attrs = neighbor.second;

double weight = 1.0; // 默认权重
if (edge_attrs.find(weight_key) != edge_attrs.end()) {
weight = edge_attrs.at(weight_key);
}

edges.emplace_back(u_id, v_id, weight);
rowPtrOut[u_id + 1]++;
rowPtrIn[v_id + 1]++;
}
}

// Step 4: 计算 rowPtr 的前缀和
for (int i = 1; i <= num_nodes; i++) {
rowPtrOut[i] += rowPtrOut[i - 1];
rowPtrIn[i] += rowPtrIn[i - 1];
}

// Step 5: 填充 colIdx 和 val
std::vector<int> colIdxOut(edges.size());
std::vector<double> valOut(edges.size());
std::vector<int> colIdxIn(edges.size());
std::vector<double> valIn(edges.size());
std::vector<int> offsetOut(num_nodes, 0);
std::vector<int> offsetIn(num_nodes, 0);

for (const auto& [u_id, v_id, weight] : edges) {
int posOut = rowPtrOut[u_id] + offsetOut[u_id];
colIdxOut[posOut] = v_id;
valOut[posOut] = weight;
offsetOut[u_id]++;

int posIn = rowPtrIn[v_id] + offsetIn[v_id];
colIdxIn[posIn] = u_id;
valIn[posIn] = weight;
offsetIn[v_id]++;
}

// Step 6: 保存到 CSRGraph 对象
this->csr_graph = std::make_shared<CSRGraph>();
auto& csr = *this->csr_graph;
csr.nodes = node_vec;
csr.V = rowPtrOut;
csr.E = colIdxOut;
csr.unweighted_W = valOut;

this->in_csr_graph = std::make_shared<CSRGraph>();
auto& in_csr = *this->in_csr_graph;
in_csr.nodes = node_vec;
in_csr.V = rowPtrIn;
in_csr.E = colIdxIn;
in_csr.unweighted_W = valIn;

return this->csr_graph;
}

std::shared_ptr<CSRGraph> Graph::gen_CSR(const std::string& weight) {
if (csr_graph != nullptr) {
if (csr_graph->W_map.find(weight) == csr_graph->W_map.end()) {
Expand Down Expand Up @@ -849,6 +925,196 @@ std::shared_ptr<std::vector<int>> Graph::gen_CSR_sources(const py::object& py_so
return sources;
}

std::shared_ptr<CSRGraph> Graph::gen_reverse_CSR(const std::string& weight) {
// 如果 in_csr_graph 已经存在,检查有没有对应 weight 的权重
if (in_csr_graph != nullptr) {
// 如果没有该字段的权重,则需要补一份
if (in_csr_graph->W_map.find(weight) == in_csr_graph->W_map.end()) {
// 构造一个新的权重向量
auto W = std::make_shared<std::vector<double>>();
W->reserve(in_csr_graph->E.size());

// 反向 CSR 的节点顺序、V、E 在 in_csr_graph 中已经确定
// 这里只需要根据 E 来补上权重
// E[i] 表示从第 u = in_csr_graph->idx2node[...] 个节点
// 指向 E[i] 对应的节点,这里是反向边,所以在原图中是 E[i] -> u
//
// 我们可以通过 V 来找出每个节点对应的边区间
const auto& V = in_csr_graph->V;
const auto& E = in_csr_graph->E;
const auto& idx2node = in_csr_graph->nodes; // 排过序的节点列表
const auto& node2idx = in_csr_graph->node2idx; // 节点到索引

// 遍历每个节点在 in_csr_graph 的所有反向边
for (size_t u_idx = 0; u_idx < idx2node.size(); ++u_idx) {
int start = V[u_idx];
int end = V[u_idx + 1];
node_t u_node = idx2node[u_idx]; // 反向图中的“目标”节点

for (int edge_pos = start; edge_pos < end; ++edge_pos) {
int v_idx = E[edge_pos];
// v_idx 对应的是原图中的“源”节点
node_t v_node = idx2node[v_idx];

// 在原图中 v_node -> u_node
// 查找它的 edge_attr
const auto& v_adjs = adj.find(v_node)->second;
const auto& edge_attr = v_adjs.find(u_node)->second;

auto edge_it = edge_attr.find(weight);
double w = (edge_it != edge_attr.end()) ? edge_it->second : 1.0;

W->push_back(w);
}
}

// 存到 in_csr_graph->W_map
in_csr_graph->W_map[weight] = W;
}
}
else {
// 如果从未构造过 in_csr_graph 或者图数据修改过,就需要重新构造
in_csr_graph = std::make_shared<CSRGraph>();

// 1. 收集并排序所有节点
auto& nodes = in_csr_graph->nodes; // 用于存储排好序的节点
nodes.reserve(node.size());
for (auto it = node.begin(); it != node.end(); ++it) {
nodes.push_back(it->first);
}
std::sort(nodes.begin(), nodes.end());

// 2. 建立 node2idx
auto& node2idx = in_csr_graph->node2idx;
for (int i = 0; i < (int)nodes.size(); ++i) {
node2idx[nodes[i]] = i;
}

// 3. 构造【反向】邻接表(临时结构),以便后续统一写入 V、E、W
// 这里我们需要把 adj 中的 (src -> dst) 翻转成 (dst -> src)
std::unordered_map<node_t, std::unordered_map<node_t, edge_attr_dict_factory>> rev_adj;
for (auto& kv : adj) {
node_t src = kv.first;
for (const auto& kv : adj) {
node_t src = kv.first;
for (const auto& kv2 : kv.second) {
node_t dst = kv2.first;
// kv2.second 是 edge_attr_dict_factory 类型
rev_adj[dst][src] = kv2.second;
}
}
}

// 4. 分配 V、E、W 并写入
auto& V = in_csr_graph->V;
auto& E = in_csr_graph->E;
auto W = std::make_shared<std::vector<double>>();

V.reserve(nodes.size() + 1);

// 对每个节点(按排序后顺序)填充 V, E, W
for (int i = 0; i < (int)nodes.size(); ++i) {
node_t cur_node = nodes[i];
// 记录当前 E 的起始位置
V.push_back((int)E.size());

// 查看反向邻接 rev_adj[cur_node],即有哪些节点指向 cur_node
auto rev_it = rev_adj.find(cur_node);
if (rev_it != rev_adj.end()) {
const auto& neighbors = rev_it->second; // map<node_t, edge_attr_dict_factory>
for (auto& nb_kv : neighbors) {
node_t src_node = nb_kv.first;
const auto& edge_attr = nb_kv.second;

// src_node -> cur_node 在原图中
// 在反向 CSR 中,我们在“cur_node”这行添加一个从 cur_node -> src_node 的记录
E.push_back(node2idx[src_node]);

// 权重处理
auto edge_it = edge_attr.find(weight);
double w_val = (edge_it != edge_attr.end()) ? edge_it->second : 1.0;
W->push_back(w_val);
}
}
}

// 最后一个节点的“结束位置”,相当于 E.size()
V.push_back((int)E.size());

// 存储权重向量
in_csr_graph->W_map[weight] = W;
}

return in_csr_graph;
}

// 生成无权反向 CSR
std::shared_ptr<CSRGraph> Graph::gen_reverse_CSR() {
// 如果已经有 in_csr_graph,则只需要检查 unweighted_W 是否跟 E.size() 对齐
if (in_csr_graph != nullptr) {
if (in_csr_graph->unweighted_W.size() != in_csr_graph->E.size()) {
in_csr_graph->unweighted_W = std::vector<double>(in_csr_graph->E.size(), 1.0);
}
}
else {
// 和上面同理,重新构造
in_csr_graph = std::make_shared<CSRGraph>();

// 1. 收集并排序节点
auto& nodes = in_csr_graph->nodes;
nodes.reserve(node.size());
for (auto it = node.begin(); it != node.end(); ++it) {
nodes.push_back(it->first);
}
std::sort(nodes.begin(), nodes.end());

// 2. 建立 node2idx
auto& node2idx = in_csr_graph->node2idx;
for (int i = 0; i < (int)nodes.size(); ++i) {
node2idx[nodes[i]] = i;
}

// 3. 构造反向邻接表
std::unordered_map<node_t, std::unordered_map<node_t, edge_attr_dict_factory>> rev_adj;
for (auto& kv : adj) {
node_t src = kv.first;
for (const auto& kv : adj) {
node_t src = kv.first;
for (const auto& kv2 : kv.second) {
node_t dst = kv2.first;
// kv2.second 是 edge_attr_dict_factory 类型
rev_adj[dst][src] = kv2.second;
}
}
}

// 4. 写入 V、E
auto& V = in_csr_graph->V;
auto& E = in_csr_graph->E;
V.reserve(nodes.size() + 1);

for (int i = 0; i < (int)nodes.size(); ++i) {
node_t cur_node = nodes[i];
V.push_back((int)E.size());

auto rev_it = rev_adj.find(cur_node);
if (rev_it != rev_adj.end()) {
const auto& neighbors = rev_it->second;
for (auto& nb_kv : neighbors) {
node_t src_node = nb_kv.first;
E.push_back(node2idx[src_node]);
}
}
}
V.push_back((int)E.size());

// 无权时,统一填 1.0
in_csr_graph->unweighted_W = std::vector<double>(E.size(), 1.0);
}

return in_csr_graph;
}

std::shared_ptr<COOGraph> Graph::gen_COO() {
if (coo_graph != nullptr) {
if (coo_graph->unweighted_W.size() != coo_graph->row.size()) {
Expand Down
4 changes: 4 additions & 0 deletions cpp_easygraph/classes/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ struct Graph {
adj_dict_factory adj;
Graph_L linkgraph_structure;
std::shared_ptr<CSRGraph> csr_graph;
std::shared_ptr<CSRGraph> in_csr_graph;
py::kwargs node_to_id, id_to_node, graph;
node_t id;
bool dirty_nodes, dirty_adj, linkgraph_dirty;
Expand All @@ -32,10 +33,13 @@ struct Graph {

std::shared_ptr<CSRGraph> gen_CSR(const std::string& weight);
std::shared_ptr<CSRGraph> gen_CSR();
std::shared_ptr<CSRGraph> gen_reverse_CSR(const std::string& weight);
std::shared_ptr<CSRGraph> gen_reverse_CSR();
std::shared_ptr<std::vector<int>> gen_CSR_sources(const py::object& py_sources);
std::shared_ptr<COOGraph> gen_COO();
std::shared_ptr<COOGraph> gen_COO(const std::string& weight);
std::shared_ptr<COOGraph> transfer_csr_to_coo(const std::shared_ptr<CSRGraph>& csr_graph);
std::shared_ptr<CSRGraph> gen_CSR_fast(const std::string& weight_key);
};

py::object Graph__init__(py::args args, py::kwargs kwargs);
Expand Down
36 changes: 22 additions & 14 deletions cpp_easygraph/functions/structural_holes/evaluation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,22 +212,27 @@ static py::object invoke_gpu_constraint(py::object G, py::object nodes, py::obje
Graph& G_ = G.cast<Graph&>();
if (weight.is_none()) {
G_.gen_CSR();
G_.gen_reverse_CSR();
} else {
G_.gen_CSR(weight_to_string(weight));
}
auto csr_graph = G_.csr_graph;
auto coo_graph = G_.transfer_csr_to_coo(csr_graph);
std::vector<int>& V = csr_graph->V;
std::vector<int>& E = csr_graph->E;
std::vector<int>& row = coo_graph->row;
std::vector<int>& col = coo_graph->col;
std::vector<double> *W_p = weight.is_none() ? &(coo_graph->unweighted_W)
: coo_graph->W_map.find(weight_to_string(weight))->second.get();
std::unordered_map<node_t, int>& node2idx = coo_graph->node2idx;
int num_nodes = coo_graph->node2idx.size();
G_.gen_reverse_CSR(weight_to_string(weight));
}
// G_.gen_CSR_fast(weight_to_string(weight));
auto out_csr_graph = G_.csr_graph;
auto in_csr_graph = G_.in_csr_graph;
std::vector<int>& rowPtrOut = out_csr_graph->V;
std::vector<int>& colIdxOut = out_csr_graph->E;
std::vector<double> *valOut = weight.is_none() ? &(out_csr_graph->unweighted_W)
: out_csr_graph->W_map.find(weight_to_string(weight))->second.get();
std::vector<int>& rowPtrIn = in_csr_graph->V;

std::vector<int>& colIdxIn = in_csr_graph->E;
std::vector<double> *valIn = weight.is_none() ? &(in_csr_graph->unweighted_W)
: in_csr_graph->W_map.find(weight_to_string(weight))->second.get();
std::unordered_map<node_t, int>& node2idx = out_csr_graph->node2idx;
int num_nodes = out_csr_graph->node2idx.size();
bool is_directed = G.attr("is_directed")().cast<bool>();
std::vector<double> constraint_results(num_nodes, 0.0);

std::vector<int> node_mask(num_nodes, 0);
py::list nodes_list;
if (!nodes.is_none()) {
Expand All @@ -240,8 +245,11 @@ static py::object invoke_gpu_constraint(py::object G, py::object nodes, py::obje
nodes_list = py::list(G.attr("nodes"));
std::fill(node_mask.begin(), node_mask.end(), 1);
}

int gpu_r = gpu_easygraph::constraint(V, E, row, col, num_nodes, *W_p, is_directed, node_mask, constraint_results);

int gpu_r = gpu_easygraph::constraint(num_nodes,
rowPtrOut, colIdxOut, *valOut,
rowPtrIn, colIdxIn, *valIn,
is_directed, node_mask, constraint_results);
if (gpu_r != gpu_easygraph::EG_GPU_SUCC) {
py::pybind11_fail(gpu_easygraph::err_code_detail(gpu_r));
}
Expand Down
20 changes: 11 additions & 9 deletions gpu_easygraph/functions/structural_holes/constraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,22 @@ namespace gpu_easygraph {
using std::vector;

int constraint(
_IN_ const vector<int>& V,
_IN_ const vector<int>& E,
_IN_ const vector<int>& row,
_IN_ const vector<int>& col,
_IN_ int num_nodes,
_IN_ const vector<double>& W,
_IN_ const std::vector<int>& rowPtrOut,
_IN_ const std::vector<int>& colIdxOut,
_IN_ const std::vector<double>& valOut,
_IN_ const std::vector<int>& rowPtrIn,
_IN_ const std::vector<int>& colIdxIn,
_IN_ const std::vector<double>& valIn,
_IN_ bool is_directed,
_IN_ vector<int>& node_mask,
_OUT_ vector<double>& constraint
_OUT_ vector<double>& constraints
) {
int num_edges = row.size();
int len_rowPtrOut = rowPtrOut.size();
int len_colIdxOut = colIdxOut.size();

constraint = vector<double>(num_nodes);
int r = cuda_constraint(V.data(), E.data(), row.data(), col.data(), W.data(), num_nodes, num_edges, is_directed, node_mask.data(), constraint.data());
constraints = vector<double>(num_nodes);
int r = cuda_constraint(num_nodes, len_rowPtrOut, len_colIdxOut, rowPtrOut.data(), colIdxOut.data(), valOut.data(), rowPtrIn.data(), colIdxIn.data(), valIn.data(), is_directed, node_mask.data(), constraints.data());

return r;
}
Expand Down
Loading

0 comments on commit d77eb80

Please sign in to comment.