Merge pull request #59 from Xtra-Computing/layerwise
Layerwise
QinbinLi authored Jan 26, 2023
2 parents 6a9a873 + 82da2f3 commit 3701d11
Showing 11 changed files with 482 additions and 69 deletions.
2 changes: 2 additions & 0 deletions include/FedTree/Tree/function_builder.h
@@ -16,6 +16,8 @@ class FunctionBuilder {
public:
virtual vector<Tree> build_approximate(const SyncArray<GHPair> &gradients, bool update_y_predict = true) = 0;

virtual vector<Tree> build_a_subtree_approximate(const SyncArray<GHPair> &gradients, int n_layer) = 0;

virtual Tree get_tree()= 0;

virtual void set_tree(Tree tree) = 0;
4 changes: 4 additions & 0 deletions include/FedTree/Tree/gbdt.h
@@ -20,12 +20,16 @@ class GBDT {

void train(GBDTParam &param, DataSet &dataset);

void train_a_subtree(GBDTParam &param, DataSet &dataset, int n_layer, int *id_list, int *nins_list, float *gradient_g_list, float *gradient_h_list, int *n_node, int *nodeid_list, float *input_gradient_g, float *input_gradient_h);

vector<float_type> predict(const GBDTParam &model_param, const DataSet &dataSet);

vector<float_type> predict(const GBDTParam &model_param, const vector<DataSet> &dataSet);

void predict_raw(const GBDTParam &model_param, const DataSet &dataSet, SyncArray<float_type> &y_predict);

void predict_leaf(const GBDTParam &model_param, const DataSet &dataSet, SyncArray<float_type> &y_predict, int *ins2leaf);

void predict_raw_vertical(const GBDTParam &model_param, const DataSet &dataSet, SyncArray<float_type> &y_predict, std::map<int, vector<int>> &batch_idxs);

void predict_raw_vertical(const GBDTParam &model_param, const vector<DataSet> &dataSet, SyncArray<float_type> &y_predict);
2 changes: 2 additions & 0 deletions include/FedTree/Tree/tree_builder.h
@@ -23,6 +23,8 @@ class TreeBuilder : public FunctionBuilder{

vector<Tree> build_approximate(const SyncArray<GHPair> &gradients, bool update_y_predict = true) override;

vector<Tree> build_a_subtree_approximate(const SyncArray<GHPair> &gradients, int n_layer) override;

void build_tree_by_predefined_structure(const SyncArray<GHPair> &gradients, vector<Tree> &trees);

void build_init(const GHPair sum_gh, int k) override;
3 changes: 3 additions & 0 deletions include/FedTree/booster.h
@@ -43,6 +43,9 @@ class Booster {

void boost(vector<vector<Tree>> &boosted_model);

void boost_a_subtree(vector<vector<Tree>> &trees, int n_layer, int *id_list, int *nins_list, float *gradient_g_list,
float *gradient_h_list, int *n_node, int *nodeid_list, float *input_gradient_g, float *input_gradient_h);

void boost_without_prediction(vector<vector<Tree>> &boosted_model);

GBDTParam param;
117 changes: 116 additions & 1 deletion python/fedtree/fedtree.py
@@ -96,6 +96,7 @@ def fit(self, X, y, groups=None):
self.model = None
sparse = sp.issparse(X)
if sparse is False:
# potential bug: csr_matrix ignores all zero values in X
X = sp.csr_matrix(X)
X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')

@@ -182,6 +183,52 @@ def predict(self, X, groups=None):
predict_label = [self.predict_label_ptr[index] for index in range(0, X.shape[0])]
self.predict_label = np.asarray(predict_label)
return self.predict_label

def predict_leaf(self, X, groups=None):
if self.model is None:
print("Please train the model first or load model from file!")
raise ValueError
sparse = sp.isspmatrix(X)
if sparse is False:
X = sp.csr_matrix(X)
X.data = np.asarray(X.data, dtype=np.float32, order='C')
X.sort_indices()
data = X.data.ctypes.data_as(POINTER(c_float))
indices = X.indices.ctypes.data_as(POINTER(c_int32))
indptr = X.indptr.ctypes.data_as(POINTER(c_int32))
if(self.objective != 'multi:softprob'):
self.predict_label_ptr = (c_float * X.shape[0])()
else:
temp_size = X.shape[0] * self.num_class
self.predict_label_ptr = (c_float * temp_size)()
if self.group_label is not None:
group_label = (c_float * len(self.group_label))()
group_label[:] = self.group_label
else:
group_label = None
in_groups, num_groups = self._construct_groups(groups)
ins2leaf_c = (c_int32 * (X.shape[0] * self.n_trees))()
fedtree.predict_leaf(
X.shape[0],
data,
indptr,
indices,
self.predict_label_ptr,
byref(self.model),
self.n_trees,
self.tree_per_iter,
self.objective.encode('utf-8'),
self.num_class,
c_float(self.learning_rate),
group_label,
in_groups,
ins2leaf_c,
num_groups, self.verbose, self.bagging,
)
self.ins2leaf = np.array([ins2leaf_c[i] for i in range(X.shape[0] * self.n_trees)])
# predict_label = [self.predict_label_ptr[index] for index in range(0, X.shape[0])]
# self.predict_label = np.asarray(predict_label)
return self.ins2leaf
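
A usage sketch of the new predict_leaf binding (illustrative, not part of the commit; the constructor arguments follow FedTree's documented Python API and the data is synthetic). Since the C++ side fills ins2leaf[iter * n_instances + iid] (see gbdt.cpp below), the flat result reshapes into one row per tree:

import numpy as np
from fedtree import FLClassifier

X = np.random.rand(100, 8)
y = np.random.randint(0, 2, 100)
clf = FLClassifier(n_trees=10, mode="horizontal", n_parties=2, objective="binary:logistic")
clf.fit(X, y)
leaf_ids = clf.predict_leaf(X)                # flat array of length n_instances * n_trees
per_tree = leaf_ids.reshape(clf.n_trees, -1)  # row t: each instance's leaf node id in tree t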

def predict_proba(self, X, groups=None):
if self.model is None:
@@ -235,7 +282,6 @@ def predict_proba(self, X, groups=None):
return self.predict_proba



def save_model(self, model_path):
if self.model is None:
print("Please train the model first or load model from file!")
@@ -350,6 +396,75 @@ def cv(self, X, y, folds=None, nfold=5, shuffle=True, seed=0):
print("mean test RMSE:%.6f+%.6f" %(statistics.mean(test_score_list), statistics.stdev(test_score_list)))
return self.eval_res

def centralize_train_a_subtree(self, X, y, n_layer, input_gradient_g=None, input_gradient_h=None, groups=None):
if self.model is not None:
fedtree.model_free(byref(self.model))
self.model = None
sparse = sp.issparse(X)
if sparse is False:
# potential bug: csr_matrix ignores all zero values in X
X = sp.csr_matrix(X)
X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')
n_ins = X.shape[0]  # number of instances; taken after conversion so sparse inputs work too

X.data = np.asarray(X.data, dtype=np.float32, order='C')
X.sort_indices()
data = X.data.ctypes.data_as(POINTER(c_float))
indices = X.indices.ctypes.data_as(POINTER(c_int32))
indptr = X.indptr.ctypes.data_as(POINTER(c_int32))
y = np.asarray(y, dtype=np.float32, order='C')
label = y.ctypes.data_as(POINTER(c_float))
in_groups, num_groups = self._construct_groups(groups)
group_label = (c_float * len(set(y)))()
n_class = (c_int * 1)()
n_class[0] = self.num_class
tree_per_iter_ptr = (c_int * 1)()
self.model = (c_long * 1)()
n_max_node = pow(2, n_layer)
# needs to represent instance ID as int
insid_list = (c_int * n_ins)()
n_ins_list = (c_int * n_max_node)()
gradient_g_list = (c_float * n_ins)()
gradient_h_list = (c_float * n_ins)()
n_node = (c_int * 1)()
nodeid_list = (c_int * n_max_node)()
input_gradient_g = np.asarray(input_gradient_g, dtype=np.float32, order='C')
input_g = input_gradient_g.ctypes.data_as(POINTER(c_float))
input_gradient_h = np.asarray(input_gradient_h, dtype=np.float32, order='C')
input_h = input_gradient_h.ctypes.data_as(POINTER(c_float))
fedtree.centralize_train_a_subtree(c_float(self.variance), c_float(self.privacy_budget),
self.max_depth, self.n_trees, c_float(self.min_child_weight), c_float(self.lambda_ft), c_float(self.gamma), c_float(self.column_sampling_rate),
self.verbose, self.bagging, self.n_parallel_trees, c_float(self.learning_rate), self.objective.encode('utf-8'), n_class, self.n_device, self.max_num_bin,
self.seed, c_float(self.ins_bagging_fraction), self.reorder_label, c_float(self.constant_h),
X.shape[0], data, indptr, indices, label, self.tree_method, byref(self.model), tree_per_iter_ptr, group_label,
in_groups, num_groups, n_layer, insid_list, n_ins_list, gradient_g_list, gradient_h_list, n_node, nodeid_list, input_g, input_h)
self.num_class = n_class[0]
self.tree_per_iter = tree_per_iter_ptr[0]
self.group_label = [group_label[idx] for idx in range(len(set(y)))]

self.insid_list = [insid_list[i] for i in range(n_ins)]
self.n_ins_list = [n_ins_list[i] for i in range(n_node[0])]
self.gradient_g_list = [gradient_g_list[i] for i in range(n_ins)]
self.gradient_h_list = [gradient_h_list[i] for i in range(n_ins)]
self.n_node = n_node[0]
self.nodeid_list = [nodeid_list[i] for i in range(n_node[0])]
if self.model is None:
print("The model returned is empty!")
exit()

return self
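
A sketch of how the layer-wise entry point might be driven (illustrative; FLRegressor and all parameter values are assumptions, while the output attributes are the ones populated above). Note that the wrapper converts input_gradient_g/h unconditionally, so callers must supply them; here they are synthetic stand-ins for externally computed gradients:

import numpy as np
from fedtree import FLRegressor

X = np.random.rand(200, 6)
y = np.random.rand(200)
g = np.random.rand(200).astype(np.float32)  # external per-instance gradients (illustrative)
h = np.ones(200, dtype=np.float32)
model = FLRegressor(n_trees=1, max_depth=6, objective="reg:linear")
model.centralize_train_a_subtree(X, y, n_layer=3, input_gradient_g=g, input_gradient_h=h)
print(model.n_node)       # number of frontier nodes after 3 layers
print(model.nodeid_list)  # their node ids within the tree
print(model.n_ins_list)   # how many instances fall into each frontier node
# model.insid_list, model.gradient_g_list and model.gradient_h_list carry the
# per-instance assignment and g/h values needed to grow the remaining layers elsewhere.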

def update_a_layer_cpp(self, X, ins, nins, gradient_g, gradient_h, n_node, lamb):
# keep the converted arrays in locals so the ctypes pointers remain valid for the call
x_arr = np.asarray(X, dtype=np.int32)
ins_arr = np.asarray(ins, dtype=np.int32)
nins_arr = np.asarray(nins, dtype=np.int32)
g_arr = np.asarray(gradient_g, dtype=np.float32)
h_arr = np.asarray(gradient_h, dtype=np.float32)
c_x = x_arr.ctypes.data_as(POINTER(c_int32))
c_ins = ins_arr.ctypes.data_as(POINTER(c_int32))
c_nins = nins_arr.ctypes.data_as(POINTER(c_int32))
c_gradient_g = g_arr.ctypes.data_as(POINTER(c_float))
c_gradient_h = h_arr.ctypes.data_as(POINTER(c_float))
leaf_val = (c_float * (n_node * 2))()
fedtree.update_a_layer_with_flag(c_x, c_ins, c_nins, c_gradient_g, c_gradient_h, n_node, leaf_val)
self.leaf_val = [leaf_val[i] for i in range(n_node * 2)]
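
update_a_layer_cpp is the companion call for growing one more layer: given per-node instance partitions and gradients, it returns candidate leaf values for each node's children. A hedged continuation of the sketch above; the meaning of the first argument is an assumption (update_a_layer_with_flag appears to take one integer flag per instance indicating its split side):

flags = np.random.randint(0, 2, size=200)  # assumed: per-instance left/right split flag
model.update_a_layer_cpp(flags, model.insid_list, model.n_ins_list,
                         model.gradient_g_list, model.gradient_h_list,
                         model.n_node, lamb=1.0)
print(model.leaf_val)  # 2 * n_node values: one candidate leaf weight per child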


class FLClassifier(FLModel, fedtreeClassifierBase):
_impl = 'classifier'

5 changes: 4 additions & 1 deletion python/setup.py
@@ -18,8 +18,11 @@
if not path.exists(path.join(dirname, "fedtree", path.basename(lib_path))):
copyfile(lib_path, path.join(dirname, "fedtree", path.basename(lib_path)))

# lib_path = "./fedtree/libFedTree.so"


setuptools.setup(name="fedtree",
version="1.0.4",
version="1.0.5",
packages=["fedtree"],
package_dir={"python": "fedtree"},
description="A federated learning library for trees",
132 changes: 129 additions & 3 deletions src/FedTree/Tree/gbdt.cpp
@@ -39,13 +39,50 @@ void GBDT::train(GBDTParam &param, DataSet &dataset) {
// float_type score = predict_score(param, dataset);
// LOG(INFO) << score;

auto stop = timer.now();
std::chrono::duration<float> training_time = stop - start;
LOG(INFO) << "training time = " << training_time.count();
return;
}

void GBDT::train_a_subtree(GBDTParam &param, DataSet &dataset, int n_layer, int *id_list, int *nins_list, float *gradient_g_list,
float *gradient_h_list, int *n_node, int *node_id_list, float *input_gradient_g, float *input_gradient_h) {
if (param.tree_method == "auto")
param.tree_method = "hist";
else if (param.tree_method != "hist") {
std::cout << "FedTree only supports histogram-based training for now" << std::endl;
exit(1);
}

if (param.objective.find("multi:") != std::string::npos || param.objective.find("binary:") != std::string::npos) {
int num_class = dataset.label.size();
if (param.num_class != num_class) {
LOG(INFO) << "updating number of classes from " << param.num_class << " to " << num_class;
param.num_class = num_class;
}
if (param.num_class > 2)
param.tree_per_round = param.num_class;
} else if (param.objective.find("reg:") != std::string::npos) {
param.num_class = 1;
}

Booster booster;
booster.init(dataset, param);
std::chrono::high_resolution_clock timer;
auto start = timer.now();
std::cout << "start boosting a subtree" << std::endl;
booster.boost_a_subtree(trees, n_layer, id_list, nins_list, gradient_g_list, gradient_h_list, n_node, node_id_list, input_gradient_g, input_gradient_h);
//booster.boost(trees);
// float_type score = predict_score(param, dataset);
// LOG(INFO) << score;

auto stop = timer.now();
std::chrono::duration<float> training_time = stop - start;
LOG(INFO) << "training time = " << training_time.count();
return;
}


vector<float_type> GBDT::predict(const GBDTParam &model_param, const DataSet &dataSet) {
SyncArray<float_type> y_predict;
predict_raw(model_param, dataSet, y_predict);
@@ -157,9 +194,7 @@ void GBDT::predict_raw(const GBDTParam &model_param, const DataSet &dataSet, Syn
int num_node = trees[0][0].nodes.size();

int total_num_node = num_iter * num_class * num_node;
//TODO: reduce the output size for binary classification
y_predict.resize(n_instances * num_class);

SyncArray<Tree::TreeNode> model(total_num_node);
auto model_data = model.host_data();
int tree_cnt = 0;
@@ -462,4 +497,95 @@ void GBDT::predict_raw_vertical(const GBDTParam &model_param, const vector<DataS
predict_data_class[iid] += sum;
}//end all tree prediction
}
}
}

void GBDT::predict_leaf(const GBDTParam &model_param, const DataSet &dataSet, SyncArray<float_type> &y_predict, int *ins2leaf) {
TIMED_SCOPE(timerObj, "predict");
int n_instances = dataSet.n_instances();
// int n_features = dataSet.n_features();

//the whole model to an array
int num_iter = trees.size();
int num_class = trees.front().size();
int num_node = trees[0][0].nodes.size();

int total_num_node = num_iter * num_class * num_node;
// y_predict.resize(n_instances * num_class);
std::cout << "num_class in predict_leaf: " << num_class << std::endl;
SyncArray<Tree::TreeNode> model(total_num_node);
auto model_data = model.host_data();
int tree_cnt = 0;
for (auto &vtree:trees) {
for (auto &t:vtree) {
memcpy(model_data + num_node * tree_cnt, t.nodes.host_data(), sizeof(Tree::TreeNode) * num_node);
tree_cnt++;
}
}

PERFORMANCE_CHECKPOINT_WITH_ID(timerObj, "init trees");

//do prediction
auto model_host_data = model.host_data();
// auto predict_data = y_predict.host_data();
auto csr_col_idx_data = dataSet.csr_col_idx.data();
auto csr_val_data = dataSet.csr_val.data();
auto csr_row_ptr_data = dataSet.csr_row_ptr.data();
auto lr = model_param.learning_rate;
PERFORMANCE_CHECKPOINT_WITH_ID(timerObj, "copy data");

#pragma omp parallel for
for (int iid = 0; iid < n_instances; iid++) {
auto get_next_child = [&](Tree::TreeNode node, float_type feaValue) {
//return feaValue < node.split_value ? node.lch_index : node.rch_index;
return (feaValue - node.split_value) >= -1e-6 ? node.rch_index : node.lch_index;
};
auto get_val = [&](const int *row_idx, const float_type *row_val, int row_len, int idx,
bool *is_missing) -> float_type {
//binary search to get feature value
const int *left = row_idx;
const int *right = row_idx + row_len;

while (left != right) {
const int *mid = left + (right - left) / 2;
if (*mid == idx) {
*is_missing = false;
return row_val[mid - row_idx];
}
if (*mid > idx)
right = mid;
else left = mid + 1;
}
*is_missing = true;
return 0;
};
const int *col_idx = csr_col_idx_data + csr_row_ptr_data[iid];
const float_type *row_val = csr_val_data + csr_row_ptr_data[iid];
int row_len = csr_row_ptr_data[iid + 1] - csr_row_ptr_data[iid];
for (int t = 0; t < num_class; t++) {
// auto predict_data_class = predict_data + t * n_instances;
// float_type sum = 0;
for (int iter = 0; iter < num_iter; iter++) {
const Tree::TreeNode *node_data = model_host_data + iter * num_class * num_node + t * num_node;
Tree::TreeNode curNode = node_data[0];
int cur_nid = 0; //node id
while (!curNode.is_leaf) {
int fid = curNode.split_feature_id;
bool is_missing;
float_type fval = get_val(col_idx, row_val, row_len, fid, &is_missing);
if (!is_missing)
cur_nid = get_next_child(curNode, fval);
else if (curNode.default_right)
cur_nid = curNode.rch_index;
else
cur_nid = curNode.lch_index;

curNode = node_data[cur_nid];
}
ins2leaf[iter * n_instances + iid] = cur_nid;
// sum += lr * curNode.base_weight;
}
// if (model_param.bagging)
// sum /= num_iter;
}//end all tree prediction
}
}
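
For reference, the per-instance traversal above in compact Python form (a sketch of the same logic, not repository code): binary-search the instance's CSR row for the split feature, compare against split_value with the same -1e-6 tolerance, and fall back to the node's default direction when the feature is missing.

import bisect

def leaf_of(nodes, col_idx, row_val, row_len):
    # nodes mirrors Tree::TreeNode: is_leaf, split_feature_id, split_value,
    # lch_index, rch_index, default_right
    cur_nid = 0
    node = nodes[0]
    while not node.is_leaf:
        pos = bisect.bisect_left(col_idx, node.split_feature_id, 0, row_len)
        if pos < row_len and col_idx[pos] == node.split_feature_id:  # feature present
            fval = row_val[pos]
            cur_nid = node.rch_index if fval - node.split_value >= -1e-6 else node.lch_index
        elif node.default_right:  # feature missing: follow default direction
            cur_nid = node.rch_index
        else:
            cur_nid = node.lch_index
        node = nodes[cur_nid]
    return cur_nid  # leaf node id, as stored in ins2leaf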