Skip to content

Commit

Permalink
[gbdt] enhance error handling for forced splits file loading
Browse files Browse the repository at this point in the history
  • Loading branch information
KYash03 committed Feb 17, 2025
1 parent d02a01a commit 05430e5
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,19 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
// load forced_splits file
if (!config->forcedsplits_filename.empty()) {
std::ifstream forced_splits_file(config->forcedsplits_filename.c_str());
std::stringstream buffer;
buffer << forced_splits_file.rdbuf();
std::string err;
forced_splits_json_ = Json::parse(buffer.str(), &err);
if (!forced_splits_file.good()) {
Log::Warning("Forced splits file '%s' does not exist. Forced splits will be ignored.",
config->forcedsplits_filename.c_str());
} else {
std::stringstream buffer;
buffer << forced_splits_file.rdbuf();
std::string err;
forced_splits_json_ = Json::parse(buffer.str(), &err);
if (!err.empty()) {
Log::Fatal("Failed to parse forced splits file '%s': %s",
config->forcedsplits_filename.c_str(), err.c_str());
}
}
}

objective_function_ = objective_function;
Expand Down Expand Up @@ -823,13 +832,23 @@ void GBDT::ResetConfig(const Config* config) {
if (config_.get() != nullptr && config_->forcedsplits_filename != new_config->forcedsplits_filename) {
// load forced_splits file
if (!new_config->forcedsplits_filename.empty()) {
std::ifstream forced_splits_file(
new_config->forcedsplits_filename.c_str());
std::stringstream buffer;
buffer << forced_splits_file.rdbuf();
std::string err;
forced_splits_json_ = Json::parse(buffer.str(), &err);
tree_learner_->SetForcedSplit(&forced_splits_json_);
std::ifstream forced_splits_file(new_config->forcedsplits_filename.c_str());
if (!forced_splits_file.good()) {
Log::Warning("Forced splits file '%s' does not exist. Forced splits will be ignored.",
new_config->forcedsplits_filename.c_str());
forced_splits_json_ = Json();
tree_learner_->SetForcedSplit(nullptr);
} else {
std::stringstream buffer;
buffer << forced_splits_file.rdbuf();
std::string err;
forced_splits_json_ = Json::parse(buffer.str(), &err);
if (!err.empty()) {
Log::Fatal("Failed to parse forced splits file '%s': %s",
new_config->forcedsplits_filename.c_str(), err.c_str());
}
tree_learner_->SetForcedSplit(&forced_splits_json_);
}
} else {
forced_splits_json_ = Json();
tree_learner_->SetForcedSplit(nullptr);
Expand Down

0 comments on commit 05430e5

Please sign in to comment.