diff --git a/yggdrasil_decision_forests/learner/decision_tree/BUILD b/yggdrasil_decision_forests/learner/decision_tree/BUILD index 05e0255d..e3e5dbf1 100644 --- a/yggdrasil_decision_forests/learner/decision_tree/BUILD +++ b/yggdrasil_decision_forests/learner/decision_tree/BUILD @@ -251,6 +251,7 @@ cc_test( "//yggdrasil_decision_forests/utils:test", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", ], diff --git a/yggdrasil_decision_forests/learner/decision_tree/decision_tree_test.cc b/yggdrasil_decision_forests/learner/decision_tree/decision_tree_test.cc index 1edc8a5d..e0a05f4d 100644 --- a/yggdrasil_decision_forests/learner/decision_tree/decision_tree_test.cc +++ b/yggdrasil_decision_forests/learner/decision_tree/decision_tree_test.cc @@ -31,6 +31,7 @@ #include "gtest/gtest.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" +#include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "yggdrasil_decision_forests/dataset/data_spec.h" #include "yggdrasil_decision_forests/dataset/data_spec.pb.h" @@ -81,20 +82,17 @@ struct FakeLabelStats : LabelStats {}; // A fake consumer that persistently fails to find a valid attribute. SplitterWorkResponse FakeFindBestConditionConcurrentConsumerAlwaysInvalid( SplitterWorkRequest request) { - return SplitterWorkResponse{ - .manager_data = request.manager_data, - .status = SplitSearchResult::kInvalidAttribute, - }; + return SplitterWorkResponse(request.manager_data, + SplitSearchResult::kInvalidAttribute, {}); } // A fake consumer that sets the split score to 10 times the request index. 
SplitterWorkResponse FakeFindBestConditionConcurrentConsumerMultiplicative( SplitterWorkRequest request) { - SplitterWorkResponse response{ - .manager_data = request.manager_data, - .status = SplitSearchResult::kBetterSplitFound, - }; - request.condition->set_split_score(request.attribute_idx * 10.f); + SplitterWorkResponse response(request.manager_data, + SplitSearchResult::kBetterSplitFound, + absl::make_unique()); + response.condition->set_split_score(request.attribute_idx * 10.f); return response; } @@ -102,14 +100,13 @@ SplitterWorkResponse FakeFindBestConditionConcurrentConsumerMultiplicative( // 10 times the attribute_idx otherwise. SplitterWorkResponse FakeFindBestConditionConcurrentConsumerAlternate( SplitterWorkRequest request) { - auto response = SplitterWorkResponse{ - .manager_data = request.manager_data, - .status = SplitSearchResult::kBetterSplitFound, - }; + auto response = SplitterWorkResponse( + request.manager_data, SplitSearchResult::kBetterSplitFound, + absl::make_unique()); if (request.attribute_idx % 2 == 0) { response.status = SplitSearchResult::kInvalidAttribute; } - request.condition->set_split_score(request.attribute_idx * 10.f); + response.condition->set_split_score(request.attribute_idx * 10.f); return response; } @@ -2128,7 +2125,6 @@ TEST(DecisionTree, FindBestConditionConcurrentManager_AlwaysInvalid) { EXPECT_EQ(cache.splitter_cache_list.size(), 2); EXPECT_EQ(cache.durable_response_list.size(), 20); - EXPECT_EQ(cache.condition_list.size(), 4); EXPECT_FALSE(result); } @@ -2264,7 +2260,6 @@ TEST(DecisionTree, FindBestConditionConcurrentManagerScaled) { EXPECT_EQ(cache.splitter_cache_list.size(), 10); EXPECT_EQ(cache.durable_response_list.size(), 100); - EXPECT_EQ(cache.condition_list.size(), 20); EXPECT_FALSE(result); random.seed(4321); diff --git a/yggdrasil_decision_forests/learner/decision_tree/training.cc b/yggdrasil_decision_forests/learner/decision_tree/training.cc index 55552e43..47df1ea5 100644 --- 
a/yggdrasil_decision_forests/learner/decision_tree/training.cc +++ b/yggdrasil_decision_forests/learner/decision_tree/training.cc @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -33,6 +34,7 @@ #include "absl/base/optimization.h" #include "absl/log/log.h" +#include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -388,7 +390,7 @@ proto::DecisionTreeTrainingConfig::Internal::SortingStrategy EffectiveStrategy( case proto::DecisionTreeTrainingConfig::Internal::AUTO: CHECK(false); // The AUTO strategy should have been resolved before. - [[fallthrough]]; + break; case proto::DecisionTreeTrainingConfig::Internal::PRESORTED: { DCHECK(internal_config.preprocessing); const auto num_total_examples = @@ -405,7 +407,7 @@ proto::DecisionTreeTrainingConfig::Internal::SortingStrategy EffectiveStrategy( } // namespace // Specialization in the case of classification. -SplitSearchResult FindBestCondition( +SplitSearchResult FindBestConditionClassification( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, const std::vector& weights, @@ -550,7 +552,7 @@ SplitSearchResult FindBestCondition( return result; } -SplitSearchResult FindBestCondition( +SplitSearchResult FindBestConditionRegressionHessianGain( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, const std::vector& weights, @@ -750,7 +752,7 @@ SplitSearchResult FindBestCondition( } // Specialization in the case of regression. -SplitSearchResult FindBestCondition( +SplitSearchResult FindBestConditionRegression( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, const std::vector& weights, @@ -967,7 +969,7 @@ SplitSearchResult FindBestCondition( } // Specialization in the case of uplift with categorical outcome. 
-SplitSearchResult FindBestCondition( +SplitSearchResult FindBestConditionUpliftCategorical( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, const std::vector& weights, @@ -1038,7 +1040,7 @@ SplitSearchResult FindBestCondition( } // Specialization in the case of uplift with numerical outcome. -SplitSearchResult FindBestCondition( +SplitSearchResult FindBestConditionUpliftNumerical( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, const std::vector& weights, @@ -1118,16 +1120,19 @@ SplitterWorkResponse FindBestConditionFromSplitterWorkRequest( response.manager_data = request.manager_data; request.splitter_cache->random.seed(request.seed); - if (request.num_oblique_projections_to_run.has_value()) { + response.condition = absl::make_unique(); + response.condition->set_split_score(request.best_score); + + if (request.num_oblique_projections_to_run != -1) { DCHECK_EQ(request.attribute_idx, -1); const auto found_oblique_condition = FindBestConditionOblique( request.common->train_dataset, request.common->selected_examples, weights, config, config_link, dt_config, request.common->parent, internal_config, request.common->label_stats, - request.num_oblique_projections_to_run.value(), - request.common->constraints, request.condition, - &request.splitter_cache->random, request.splitter_cache) + request.num_oblique_projections_to_run, request.common->constraints, + response.condition.get(), &request.splitter_cache->random, + request.splitter_cache) .value(); // An oblique split cannot be invalid. 
@@ -1143,11 +1148,11 @@ SplitterWorkResponse FindBestConditionFromSplitterWorkRequest( utils::down_cast( request.common->label_stats); - response.status = FindBestCondition( + response.status = FindBestConditionClassification( request.common->train_dataset, request.common->selected_examples, weights, config, config_link, dt_config, request.common->parent, internal_config, label_stats, request.attribute_idx, - request.common->constraints, request.condition, + request.common->constraints, response.condition.get(), &request.splitter_cache->random, request.splitter_cache); } break; case model::proto::Task::REGRESSION: @@ -1156,22 +1161,22 @@ SplitterWorkResponse FindBestConditionFromSplitterWorkRequest( utils::down_cast( request.common->label_stats); - response.status = FindBestCondition( + response.status = FindBestConditionRegressionHessianGain( request.common->train_dataset, request.common->selected_examples, weights, config, config_link, dt_config, request.common->parent, internal_config, label_stats, request.attribute_idx, - request.common->constraints, request.condition, + request.common->constraints, response.condition.get(), &request.splitter_cache->random, request.splitter_cache); } else { const auto& label_stats = utils::down_cast( request.common->label_stats); - response.status = FindBestCondition( + response.status = FindBestConditionRegression( request.common->train_dataset, request.common->selected_examples, weights, config, config_link, dt_config, request.common->parent, internal_config, label_stats, request.attribute_idx, - request.common->constraints, request.condition, + request.common->constraints, response.condition.get(), &request.splitter_cache->random, request.splitter_cache); } break; @@ -1285,53 +1290,53 @@ absl::StatusOr FindBestConditionSingleThreadManager( const auto& class_label_stats = utils::down_cast(label_stats); - result = FindBestCondition(train_dataset, selected_examples, weights, - config, config_link, dt_config, parent, - 
internal_config, class_label_stats, - attribute_idx, constraints, best_condition, - random, &cache->splitter_cache_list[0]); + result = FindBestConditionClassification( + train_dataset, selected_examples, weights, config, config_link, + dt_config, parent, internal_config, class_label_stats, + attribute_idx, constraints, best_condition, random, + &cache->splitter_cache_list[0]); } break; case model::proto::Task::REGRESSION: if (internal_config.hessian_score) { const auto& reg_label_stats = utils::down_cast(label_stats); - result = FindBestCondition(train_dataset, selected_examples, weights, - config, config_link, dt_config, parent, - internal_config, reg_label_stats, - attribute_idx, constraints, best_condition, - random, &cache->splitter_cache_list[0]); + result = FindBestConditionRegressionHessianGain( + train_dataset, selected_examples, weights, config, config_link, + dt_config, parent, internal_config, reg_label_stats, + attribute_idx, constraints, best_condition, random, + &cache->splitter_cache_list[0]); } else { const auto& reg_label_stats = utils::down_cast(label_stats); - result = FindBestCondition(train_dataset, selected_examples, weights, - config, config_link, dt_config, parent, - internal_config, reg_label_stats, - attribute_idx, constraints, best_condition, - random, &cache->splitter_cache_list[0]); + result = FindBestConditionRegression( + train_dataset, selected_examples, weights, config, config_link, + dt_config, parent, internal_config, reg_label_stats, + attribute_idx, constraints, best_condition, random, + &cache->splitter_cache_list[0]); } break; case model::proto::Task::CATEGORICAL_UPLIFT: { const auto& uplift_label_stats = utils::down_cast(label_stats); - result = FindBestCondition(train_dataset, selected_examples, weights, - config, config_link, dt_config, parent, - internal_config, uplift_label_stats, - attribute_idx, constraints, best_condition, - random, &cache->splitter_cache_list[0]); + result = FindBestConditionUpliftCategorical( + 
train_dataset, selected_examples, weights, config, config_link, + dt_config, parent, internal_config, uplift_label_stats, + attribute_idx, constraints, best_condition, random, + &cache->splitter_cache_list[0]); } break; case model::proto::Task::NUMERICAL_UPLIFT: { const auto& uplift_label_stats = utils::down_cast(label_stats); - result = FindBestCondition(train_dataset, selected_examples, weights, - config, config_link, dt_config, parent, - internal_config, uplift_label_stats, - attribute_idx, constraints, best_condition, - random, &cache->splitter_cache_list[0]); + result = FindBestConditionUpliftNumerical( + train_dataset, selected_examples, weights, config, config_link, + dt_config, parent, internal_config, uplift_label_stats, + attribute_idx, constraints, best_condition, random, + &cache->splitter_cache_list[0]); } break; default: @@ -1390,13 +1395,13 @@ absl::StatusOr FindBestConditionConcurrentManager( // // Note that next_job_to_process < next_job_to_schedule always holds. - int num_threads = splitter_concurrency_setup.num_threads; - int num_features = config_link.features().size(); + const int num_threads = splitter_concurrency_setup.num_threads; - if (num_features == 0) { + if (config_link.features().empty()) { return false; } + // Constant and static part of the requests. SplitterWorkRequestCommon common{ .train_dataset = train_dataset, .selected_examples = selected_examples, @@ -1407,58 +1412,72 @@ absl::StatusOr FindBestConditionConcurrentManager( // Computes the number of oblique projections to evaluate and how to group // them into requests. - bool oblique = false; int num_oblique_jobs = 0; - int num_oblique_projections = 0; - if (dt_config.split_axis_case() == - proto::DecisionTreeTrainingConfig::kSparseObliqueSplit) { - num_oblique_projections = - GetNumProjections(dt_config, config_link.numerical_features_size()); - - // Arbitrary minimum number of oblique projections to test in each job. 
- // Because oblique jobs are expensive (more than non oblique jobs), it is - // not efficient to create a request with too little work to do. - // - // In most real cases, this parameter does not matter as the limit is - // effectively constraint by the number of threads. - const int min_projections_per_request = 10; - - DCHECK_GE(num_threads, 1); - num_oblique_jobs = std::min(num_threads, (num_oblique_projections + - min_projections_per_request - 1) / - min_projections_per_request); - oblique = num_oblique_projections > 0; - } else if (config_link.numerical_features_size() > 0 && - dt_config.split_axis_case() == - proto::DecisionTreeTrainingConfig::kMhldObliqueSplit) { - num_oblique_projections = 1; - num_oblique_jobs = 1; - oblique = true; + int num_oblique_projections; + int num_oblique_projections_per_oblique_job; + + if (config_link.numerical_features_size() > 0) { + if (dt_config.split_axis_case() == + proto::DecisionTreeTrainingConfig::kSparseObliqueSplit) { + num_oblique_projections = + GetNumProjections(dt_config, config_link.numerical_features_size()); + + if (num_oblique_projections > 0) { + // Arbitrary minimum number of oblique projections to test in each job. + // Because oblique jobs are expensive (more than non oblique jobs), it + // is not efficient to create a request with too little work to do. + // + // In most real cases, this parameter does not matter as the limit is + // effectively constraint by the number of threads. 
+ const int min_projections_per_request = 10; + + DCHECK_GE(num_threads, 1); + num_oblique_jobs = std::min( + num_threads, + (num_oblique_projections + min_projections_per_request - 1) / + min_projections_per_request); + num_oblique_projections_per_oblique_job = + (num_oblique_projections + num_oblique_jobs - 1) / num_oblique_jobs; + } + } else if (dt_config.split_axis_case() == + proto::DecisionTreeTrainingConfig::kMhldObliqueSplit) { + num_oblique_projections = 1; + num_oblique_projections_per_oblique_job = 1; + num_oblique_jobs = 1; + } } // Prepare caches. cache->splitter_cache_list.resize(num_threads); - cache->condition_list.resize(num_threads * kConditionPoolGrowthFactor); - const int num_jobs = num_features + num_oblique_jobs; - cache->durable_response_list.resize(num_jobs); // Get the ordered indices of the attributes to test. int min_num_jobs_to_test; std::vector& candidate_attributes = cache->candidate_attributes; GetCandidateAttributes(config, config_link, dt_config, &min_num_jobs_to_test, &candidate_attributes, random); - // All the oblique requests need to be tested. + + const int num_jobs = candidate_attributes.size() + num_oblique_jobs; + // All the oblique jobs need to be done. + // Note: When do look for oblique splits, we also run the classical numerical + // splitter. min_num_jobs_to_test += num_oblique_jobs; - // Marks all the caches and conditions as "available". - cache->available_cache_idxs.fill_iota(cache->splitter_cache_list.size(), 0); - cache->available_condition_idxs.fill_iota(cache->condition_list.size(), 0); + cache->durable_response_list.resize(num_jobs); + + // Marks all the caches "available". + cache->available_cache_idxs.resize(cache->splitter_cache_list.size()); + std::iota(cache->available_cache_idxs.begin(), + cache->available_cache_idxs.end(), 0); // Marks all the duration responses as "non set". for (auto& s : cache->durable_response_list) { s.set = false; } + // Score and value of the best found condition. 
+ std::atomic best_split_score = best_condition->split_score(); + std::unique_ptr best_condition_ptr; + // Get Channel readers and writers. auto& processor = *splitter_concurrency_setup.split_finder_processor; @@ -1469,78 +1488,64 @@ absl::StatusOr FindBestConditionConcurrentManager( // If attribute_idx is == -1 and num_oblique_projections_to_run != -1, create // a request for an oblique split. // - auto produce = - [&](const int job_idx, const float best_score, const int attribute_idx, + auto build_request = + [&](const int job_idx, const int attribute_idx, const int num_oblique_projections_to_run) -> SplitterWorkRequest { - // Get a cache and a condition. + DCHECK_NE(attribute_idx != -1, num_oblique_projections_to_run != -1); DCHECK(!cache->available_cache_idxs.empty()); - DCHECK(!cache->available_condition_idxs.empty()); - int32_t cache_idx = cache->available_cache_idxs.back(); + const int32_t cache_idx = cache->available_cache_idxs.back(); cache->available_cache_idxs.pop_back(); - int32_t condition_idx = cache->available_condition_idxs.back(); - DCHECK_GE(condition_idx, -1); - cache->available_condition_idxs.pop_back(); - - SplitterWorkRequest request; - request.manager_data.condition_idx = condition_idx; - request.manager_data.cache_idx = cache_idx; - request.manager_data.job_idx = job_idx; - DCHECK((attribute_idx == -1) != (num_oblique_projections_to_run == -1)); - if (attribute_idx != -1) { - request.attribute_idx = attribute_idx; - } else { - request.num_oblique_projections_to_run = num_oblique_projections_to_run; - request.attribute_idx = -1; - } - request.condition = &cache->condition_list[condition_idx]; - request.splitter_cache = &cache->splitter_cache_list[cache_idx]; - request.condition->set_split_score(best_score); // Best score so far. - request.common = &common; - request.seed = (*random)(); // Create a new seed. 
- - return request; + return SplitterWorkRequest( + /*manager_data=*/ + { + .cache_idx = cache_idx, + .job_idx = job_idx, + }, + /*best_score=*/best_split_score, + /*attribute_idx=*/attribute_idx, + /*splitter_cache=*/&cache->splitter_cache_list[cache_idx], + /*common=*/&common, + /*seed=*/(*random)(), + /*num_oblique_projections_to_run=*/num_oblique_projections_to_run); }; // Schedule all the oblique jobs. int next_job_to_schedule = 0; - if (oblique) { - const int num_oblique_projections_per_job = - (num_oblique_projections + num_oblique_jobs - 1) / num_oblique_jobs; - - for (int oblique_job_idx = 0; oblique_job_idx < num_oblique_jobs; - oblique_job_idx++) { - const int num_projections_in_request = - std::min((oblique_job_idx + 1) * num_oblique_projections_per_job, - num_oblique_projections) - - oblique_job_idx * num_oblique_projections_per_job; - processor.Submit(produce( - next_job_to_schedule++, best_condition->split_score(), - /*attribute_idx=*/-1, - /*num_oblique_projections_to_run=*/num_projections_in_request)); + for (int oblique_job_idx = 0; oblique_job_idx < num_oblique_jobs; + oblique_job_idx++) { + int num_projections_in_request; + if (oblique_job_idx == num_oblique_jobs - 1) { + num_projections_in_request = + num_oblique_projections - + oblique_job_idx * num_oblique_projections_per_oblique_job; + } else { + num_projections_in_request = num_oblique_projections_per_oblique_job; } + + processor.Submit(build_request( + next_job_to_schedule++, + /*attribute_idx=*/-1, + /*num_oblique_projections_to_run=*/num_projections_in_request)); } - // Schedule some non-oblique jobs. + // Schedule some non-oblique jobs if threads are still available. 
while (next_job_to_schedule < std::min(num_threads, num_jobs) && - !cache->available_condition_idxs.empty() && !cache->available_cache_idxs.empty()) { + DCHECK_GE(next_job_to_schedule, num_oblique_jobs); const int attribute_idx = candidate_attributes[next_job_to_schedule - num_oblique_jobs]; - processor.Submit(produce(next_job_to_schedule++, - best_condition->split_score(), - /*attribute_idx=*/attribute_idx, - /*num_oblique_projections_to_run=*/-1)); + + processor.Submit(build_request(next_job_to_schedule, + /*attribute_idx=*/attribute_idx, + /*num_oblique_projections_to_run=*/-1)); + next_job_to_schedule++; } int num_valid_job_tested = 0; int next_job_to_process = 0; - // Index of the best condition. -1 if not better condition was found. - int best_condition_idx = -1; - // Score of the best found condition, or minimum condition score to look for. - float best_split_score = best_condition->split_score(); - while (true) { + // Get a new result from a worker splitter. auto maybe_response = processor.GetResult(); if (!maybe_response.has_value()) { break; @@ -1551,73 +1556,36 @@ absl::StatusOr FindBestConditionConcurrentManager( SplitterWorkResponse& response = maybe_response.value(); // Release the cache immediately to be reused by other workers. - cache->available_cache_idxs.push_front(response.manager_data.cache_idx); + cache->available_cache_idxs.push_back(response.manager_data.cache_idx); + // Record response for further processing. auto& durable_response = cache->durable_response_list[response.manager_data.job_idx]; durable_response.status = response.status; durable_response.set = true; if (response.status == SplitSearchResult::kBetterSplitFound) { - // The worker found a better solution compared from when the worker - // started working. 
- - const float new_split_score = - cache->condition_list[response.manager_data.condition_idx] - .split_score(); - - if ((new_split_score > best_split_score) || - (new_split_score == best_split_score && - response.manager_data.condition_idx < - durable_response.condition_idx)) { - // This is the best condition so far. Keep it for processing. - durable_response.condition_idx = response.manager_data.condition_idx; - } else { - // Acctually, a better condition was found by another worker and - // processed in the mean time. No need to keep the condition. - cache->available_condition_idxs.push_front( - response.manager_data.condition_idx); - durable_response.condition_idx = -1; - durable_response.status = SplitSearchResult::kNoBetterSplitFound; - } - } else { - // Return the condition to the condition pool. - cache->available_condition_idxs.push_front( - response.manager_data.condition_idx); - durable_response.condition_idx = -1; + // The worker found a potentially better solution. + durable_response.condition = std::move(response.condition); } } - // Process all the responses that can be processed. - // Simulate a deterministic sequential processing of the responses. + // Process new responses that can be processed. while (next_job_to_process < next_job_to_schedule && num_valid_job_tested < min_num_jobs_to_test && cache->durable_response_list[next_job_to_process].set) { + // Something to process. auto durable_response = &cache->durable_response_list[next_job_to_process]; next_job_to_process++; - if (durable_response->status == SplitSearchResult::kNoBetterSplitFound) { - // Even if no better split was found, this is still a valid job. 
+ if (durable_response->status != SplitSearchResult::kInvalidAttribute) { num_valid_job_tested++; - } else if (durable_response->status == - SplitSearchResult::kBetterSplitFound) { - num_valid_job_tested++; - DCHECK_NE(durable_response->condition_idx, -1); - - const float process_split_score = - cache->condition_list[durable_response->condition_idx] - .split_score(); - - if (process_split_score > best_split_score) { - if (best_condition_idx != -1) { - cache->available_condition_idxs.push_front(best_condition_idx); - } - best_condition_idx = durable_response->condition_idx; - best_split_score = process_split_score; - } else { - // Return the condition to the condition pool. - cache->available_condition_idxs.push_front( - durable_response->condition_idx); + } + if (durable_response->status == SplitSearchResult::kBetterSplitFound) { + const float split_score = durable_response->condition->split_score(); + if (split_score > best_split_score) { + best_condition_ptr = std::move(durable_response->condition); + best_split_score = split_score; } } } @@ -1627,40 +1595,38 @@ absl::StatusOr FindBestConditionConcurrentManager( break; } - // Schedule the testing of more conditions. - - while (!cache->available_condition_idxs.empty() && - !cache->available_cache_idxs.empty() && - next_job_to_schedule < num_jobs) { - const int attribute_idx = - candidate_attributes[next_job_to_schedule - num_oblique_jobs]; - processor.Submit(produce(next_job_to_schedule++, best_split_score, - /*attribute_idx=*/attribute_idx, - /*num_oblique_projections_to_run=*/-1)); + if (next_job_to_process >= num_jobs) { + // We have processed all the jobs. + break; } - // The following condition means that no work is in the pipeline and no more - // work will be generated. - if (cache->available_cache_idxs.full()) { - break; + // Schedule the testing of more conditions. 
+ while (!cache->available_cache_idxs.empty() && + next_job_to_schedule < num_jobs) { + processor.Submit(build_request( + next_job_to_schedule, + /*attribute_idx=*/ + candidate_attributes[next_job_to_schedule - num_oblique_jobs], + /*num_oblique_projections_to_run=*/-1)); + next_job_to_schedule++; } } // Drain the response channel. - while (!cache->available_cache_idxs.full()) { + while (cache->available_cache_idxs.size() < num_threads) { auto maybe_response = processor.GetResult(); if (!maybe_response.has_value()) { break; } SplitterWorkResponse& response = maybe_response.value(); - cache->available_cache_idxs.push_front(response.manager_data.cache_idx); + cache->available_cache_idxs.push_back(response.manager_data.cache_idx); } - // Move the random generator state to facilitate deterministic behavior. + // Move the random generator state to make the behavior deterministic. random->discard(num_jobs - next_job_to_schedule); - if (best_condition_idx != -1) { - *best_condition = cache->condition_list[best_condition_idx]; + if (best_condition_ptr) { + *best_condition = std::move(*best_condition_ptr); return true; } return false; @@ -1679,19 +1645,20 @@ absl::StatusOr FindBestConditionManager( proto::NodeCondition* best_condition, utils::RandomEngine* random, PerThreadCache* cache) { if (splitter_concurrency_setup.concurrent_execution) { + // Multi-thread. return FindBestConditionConcurrentManager( train_dataset, selected_examples, weights, config, config_link, dt_config, splitter_concurrency_setup, parent, internal_config, label_stats, constraints, best_condition, random, cache); } + + // Single thread. return FindBestConditionSingleThreadManager( train_dataset, selected_examples, weights, config, config_link, dt_config, parent, internal_config, label_stats, constraints, best_condition, random, cache); } -// This is the entry point when searching for a condition. -// All other "FindBestCondition*" functions are called by this one. 
absl::StatusOr FindBestCondition( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, @@ -1706,13 +1673,11 @@ absl::StatusOr FindBestCondition( switch (config.task()) { case model::proto::Task::CLASSIFICATION: { STATUS_CHECK(!internal_config.hessian_score); - ClassificationLabelStats label_stat( - train_dataset - .ColumnWithCastWithStatus< - dataset::VerticalDataset::CategoricalColumn>( - config_link.label()) - .value() - ->values()); + ASSIGN_OR_RETURN(const auto labels, + train_dataset.ColumnWithCastWithStatus< + dataset::VerticalDataset::CategoricalColumn>( + config_link.label())); + ClassificationLabelStats label_stat(labels->values()); const auto& label_column_spec = train_dataset.data_spec().columns(config_link.label()); @@ -1737,25 +1702,23 @@ absl::StatusOr FindBestCondition( case model::proto::Task::REGRESSION: { if (internal_config.hessian_score) { - DCHECK_NE(internal_config.gradient_col_idx, -1); - DCHECK_NE(internal_config.hessian_col_idx, -1); - - DCHECK_EQ(internal_config.gradient_col_idx, config_link.label()); - RegressionHessianLabelStats label_stat( - train_dataset - .ColumnWithCastWithStatus< - dataset::VerticalDataset::NumericalColumn>( - internal_config.gradient_col_idx) - .value() - ->values(), - train_dataset - .ColumnWithCastWithStatus< - dataset::VerticalDataset::NumericalColumn>( - internal_config.hessian_col_idx) - .value() - ->values()); - - DCHECK(parent.regressor().has_sum_gradients()); + STATUS_CHECK_NE(internal_config.gradient_col_idx, -1); + STATUS_CHECK_NE(internal_config.hessian_col_idx, -1); + STATUS_CHECK_EQ(internal_config.gradient_col_idx, config_link.label()); + + ASSIGN_OR_RETURN(const auto gradients, + train_dataset.ColumnWithCastWithStatus< + dataset::VerticalDataset::NumericalColumn>( + internal_config.gradient_col_idx)); + ASSIGN_OR_RETURN(const auto hessians, + train_dataset.ColumnWithCastWithStatus< + dataset::VerticalDataset::NumericalColumn>( + internal_config.hessian_col_idx)); + + 
RegressionHessianLabelStats label_stat(gradients->values(), + hessians->values()); + + STATUS_CHECK(parent.regressor().has_sum_gradients()); label_stat.sum_gradient = parent.regressor().sum_gradients(); label_stat.sum_hessian = parent.regressor().sum_hessians(); label_stat.sum_weights = parent.regressor().sum_weights(); @@ -1765,15 +1728,13 @@ absl::StatusOr FindBestCondition( dt_config, splitter_concurrency_setup, parent, internal_config, label_stat, constraints, best_condition, random, cache); } else { - RegressionLabelStats label_stat( - train_dataset - .ColumnWithCastWithStatus< - dataset::VerticalDataset::NumericalColumn>( - config_link.label()) - .value() - ->values()); - - DCHECK(parent.regressor().has_distribution()); + ASSIGN_OR_RETURN(const auto labels, + train_dataset.ColumnWithCastWithStatus< + dataset::VerticalDataset::NumericalColumn>( + config_link.label())); + RegressionLabelStats label_stat(labels->values()); + + STATUS_CHECK(parent.regressor().has_distribution()); label_stat.label_distribution.Load(parent.regressor().distribution()); return FindBestConditionManager( @@ -1790,20 +1751,20 @@ absl::StatusOr FindBestCondition( const auto& treatment_spec = train_dataset.data_spec().columns(config_link.uplift_treatment()); + ASSIGN_OR_RETURN(const auto labels, + train_dataset.ColumnWithCastWithStatus< + dataset::VerticalDataset::CategoricalColumn>( + config_link.label())); + + ASSIGN_OR_RETURN(const auto treatments, + train_dataset.ColumnWithCastWithStatus< + dataset::VerticalDataset::CategoricalColumn>( + config_link.uplift_treatment())); + CategoricalUpliftLabelStats label_stat( - train_dataset - .ColumnWithCastWithStatus< - dataset::VerticalDataset::CategoricalColumn>( - config_link.label()) - .value() - ->values(), + labels->values(), outcome_spec.categorical().number_of_unique_values(), - train_dataset - .ColumnWithCastWithStatus< - dataset::VerticalDataset::CategoricalColumn>( - config_link.uplift_treatment()) - .value() - ->values(), + 
treatments->values(), treatment_spec.categorical().number_of_unique_values()); UpliftLeafToLabelDist(parent.uplift(), &label_stat.label_distribution); @@ -1819,15 +1780,18 @@ absl::StatusOr FindBestCondition( const auto& treatment_spec = train_dataset.data_spec().columns(config_link.uplift_treatment()); + ASSIGN_OR_RETURN( + const auto labels, + train_dataset.ColumnWithCastWithStatus< + dataset::VerticalDataset::NumericalColumn>(config_link.label())); + + ASSIGN_OR_RETURN(const auto treatments, + train_dataset.ColumnWithCastWithStatus< + dataset::VerticalDataset::CategoricalColumn>( + config_link.uplift_treatment())); + NumericalUpliftLabelStats label_stat( - train_dataset - .ColumnWithCast( - config_link.label()) - ->values(), - train_dataset - .ColumnWithCast( - config_link.uplift_treatment()) - ->values(), + labels->values(), treatments->values(), treatment_spec.categorical().number_of_unique_values()); UpliftLeafToLabelDist(parent.uplift(), &label_stat.label_distribution); @@ -4148,15 +4112,9 @@ absl::Status DecisionTreeTrain( splitter_concurrency_setup.num_threads = internal_config.num_threads; } - splitter_concurrency_setup.split_finder_processor = - std::make_unique( - "SplitFinder", internal_config.num_threads, - [&](SplitterWorkRequest request) -> SplitterWorkResponse { - return FindBestConditionFromSplitterWorkRequest( - weights, config, config_link, dt_config, - splitter_concurrency_setup, internal_config, request); - }); - splitter_concurrency_setup.split_finder_processor->StartWorkers(); + RETURN_IF_ERROR(FindBestConditionStartWorkers(config, config_link, dt_config, + internal_config, weights, + &splitter_concurrency_setup)); return DecisionTreeCoreTrain(train_dataset, *effective_selected_examples, leaf_examples, config, config_link, dt_config, @@ -4164,6 +4122,27 @@ absl::Status DecisionTreeTrain( random, internal_config, dt); } +absl::Status FindBestConditionStartWorkers( + const model::proto::TrainingConfig& config, + const 
model::proto::TrainingConfigLinking& config_link, + const proto::DecisionTreeTrainingConfig& dt_config, + const InternalTrainConfig& internal_config, + const std::vector& weights, + SplitterConcurrencySetup* splitter_concurrency_setup) { + auto find_condition = + [&](SplitterWorkRequest request) -> SplitterWorkResponse { + return FindBestConditionFromSplitterWorkRequest( + weights, config, config_link, dt_config, *splitter_concurrency_setup, + internal_config, request); + }; + splitter_concurrency_setup->split_finder_processor = + std::make_unique( + "SplitFinder", splitter_concurrency_setup->num_threads, + find_condition); + splitter_concurrency_setup->split_finder_processor->StartWorkers(); + return absl::OkStatus(); +} + absl::Status DecisionTreeCoreTrain( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, diff --git a/yggdrasil_decision_forests/learner/decision_tree/training.h b/yggdrasil_decision_forests/learner/decision_tree/training.h index a4415c4d..fc394999 100644 --- a/yggdrasil_decision_forests/learner/decision_tree/training.h +++ b/yggdrasil_decision_forests/learner/decision_tree/training.h @@ -16,6 +16,7 @@ #ifndef YGGDRASIL_DECISION_FORESTS_LEARNER_DECISION_TREE_TRAINING_H_ #define YGGDRASIL_DECISION_FORESTS_LEARNER_DECISION_TREE_TRAINING_H_ +#include #include #include #include @@ -38,7 +39,6 @@ #include "yggdrasil_decision_forests/learner/decision_tree/utils.h" #include "yggdrasil_decision_forests/model/decision_tree/decision_tree.h" #include "yggdrasil_decision_forests/model/decision_tree/decision_tree.pb.h" -#include "yggdrasil_decision_forests/utils/circular_buffer.h" #include "yggdrasil_decision_forests/utils/concurrency_streamprocessor.h" #include "yggdrasil_decision_forests/utils/distribution.h" #include "yggdrasil_decision_forests/utils/random.h" @@ -104,8 +104,6 @@ struct SplitterWorkRequestCommon { // Data packed with the work request that can be used by the manager to pass // information to itself. 
struct SplitterWorkManagerData { - // Index of the condition in the condition pool. - int condition_idx; // Index of the condition in the cache pool. int cache_idx; // Index of the job. @@ -117,10 +115,11 @@ struct SplitterWorkManagerData { struct SplitterWorkRequest { SplitterWorkManagerData manager_data; + std::atomic& best_score; + // The attribute index to pass onto splitters. int attribute_idx; - // Non-owning pointer to a "condition" in PerThreadCache.condition_list. - proto::NodeCondition* condition; + // Non-owning pointer to an entry in PerThreadCache.splitter_cache_list. SplitterPerThreadCache* splitter_cache; @@ -128,9 +127,28 @@ struct SplitterWorkRequest { SplitterWorkRequestCommon* common; // Seed used to initialize the random generator. utils::RandomEngine::result_type seed; - // If set, search for oblique split. In this case "attribute_idx" should be + // If not -1, search for oblique split. In this case "attribute_idx" should be // -1. - std::optional num_oblique_projections_to_run; + int num_oblique_projections_to_run; + + // Copy is not allowed. + SplitterWorkRequest(SplitterWorkManagerData manager_data, + std::atomic& best_score, int attribute_idx, + SplitterPerThreadCache* splitter_cache, + SplitterWorkRequestCommon* common, + utils::RandomEngine::result_type seed, + int num_oblique_projections_to_run) + : manager_data(manager_data), + best_score(best_score), + attribute_idx(attribute_idx), + splitter_cache(splitter_cache), + common(common), + seed(seed), + num_oblique_projections_to_run(num_oblique_projections_to_run) {} + SplitterWorkRequest(const SplitterWorkRequest&) = delete; + SplitterWorkRequest& operator=(const SplitterWorkRequest&) = delete; + SplitterWorkRequest(SplitterWorkRequest&&) = default; + SplitterWorkRequest& operator=(SplitterWorkRequest&&) = default; }; // Contains the result of a splitter. @@ -139,6 +157,21 @@ struct SplitterWorkResponse { // The status returned by a splitter. 
SplitSearchResult status; + + std::unique_ptr condition; + + // Copy is not allowed. + SplitterWorkResponse() = default; + SplitterWorkResponse(SplitterWorkManagerData manager_data, + SplitSearchResult status, + std::unique_ptr condition) + : manager_data(manager_data), + status(status), + condition(std::move(condition)) {} + SplitterWorkResponse(const SplitterWorkResponse&) = delete; + SplitterWorkResponse& operator=(const SplitterWorkResponse&) = delete; + SplitterWorkResponse(SplitterWorkResponse&&) = default; + SplitterWorkResponse& operator=(SplitterWorkResponse&&) = default; }; using SplitterFinderStreamProcessor = @@ -149,8 +182,7 @@ using SplitterFinderStreamProcessor = // Part of the worker response (SplitterWorkResponse) that need to be kept in // order to simulate sequential feature splitting. struct SplitterWorkDurableResponse { - // Index of the condition if status==kBetterSplitFound. - int condition_idx; + std::unique_ptr condition; // The status returned by a splitter. SplitSearchResult status; @@ -189,12 +221,9 @@ struct PerThreadCache { // A set of objects that are used by FindBestCondition. std::vector splitter_cache_list; std::vector durable_response_list; - std::vector condition_list; // List of available indices into splitter_cache_list. - utils::CircularBuffer available_cache_idxs; - // List of available indices into condition_list. - utils::CircularBuffer available_condition_idxs; + std::vector available_cache_idxs; }; // In a concurrent setup, this structure encapsulates all the objects that are @@ -215,10 +244,12 @@ typedef std::function SetLeafValueFromLabelStatsFunctor; -// Find the best condition for this node. Return true iff a good condition has -// been found. -// This is the entry point when searching for a condition. -// All other "FindBestCondition*" functions are called by this one. +// Find the best condition for a leaf node. Return true if a condition better +// than the one initially in `best_condition` was found. 
If `best_condition` is +// a newly created object, return true if a condition was found (since +// `best_condition` does not yet define a condition). +// +// This is the entry point / main function to call to find a condition. absl::StatusOr FindBestCondition( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, @@ -231,8 +262,11 @@ absl::StatusOr FindBestCondition( const NodeConstraints& constraints, proto::NodeCondition* best_condition, utils::RandomEngine* random, PerThreadCache* cache); -// Contains logic to switch between a single-threaded splitter and a concurrent -// implementation. +// Following are the methods to handle multithreading in FindBestCondition. +// ============================================================================= + +// Dispatches the condition search to either single thread or multithread +// computation. absl::StatusOr FindBestConditionManager( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, @@ -246,8 +280,7 @@ absl::StatusOr FindBestConditionManager( proto::NodeCondition* best_condition, utils::RandomEngine* random, PerThreadCache* cache); -// This is an implementation of FindBestConditionManager that is optimized for -// execution in a single thread. +// Single thread search for conditions. absl::StatusOr FindBestConditionSingleThreadManager( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, @@ -260,7 +293,7 @@ absl::StatusOr FindBestConditionSingleThreadManager( proto::NodeCondition* best_condition, utils::RandomEngine* random, PerThreadCache* cache); -// This is a concurrent implementation of FindBestConditionManager. +// Multi-thread search for conditions. 
absl::StatusOr FindBestConditionConcurrentManager( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, @@ -274,6 +307,15 @@ absl::StatusOr FindBestConditionConcurrentManager( proto::NodeCondition* best_condition, utils::RandomEngine* random, PerThreadCache* cache); +// Starts the worker threads needed for "FindBestConditionConcurrentManager". +absl::Status FindBestConditionStartWorkers( + const model::proto::TrainingConfig& config, + const model::proto::TrainingConfigLinking& config_link, + const proto::DecisionTreeTrainingConfig& dt_config, + const InternalTrainConfig& internal_config, + const std::vector& weights, + SplitterConcurrencySetup* splitter_concurrency_setup); + // A worker that receives splitter work requests and dispatches those to the // right specialized splitter function. // @@ -287,8 +329,10 @@ SplitterWorkResponse FindBestConditionFromSplitterWorkRequest( const InternalTrainConfig& internal_config, const SplitterWorkRequest& request); -// Specialization in the case of classification. -SplitSearchResult FindBestCondition( +// Following are the "FindBestCondition" specialized for specific tasks. +// ============================================================================= + +SplitSearchResult FindBestConditionClassification( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, const std::vector& weights, @@ -300,8 +344,7 @@ SplitSearchResult FindBestCondition( const NodeConstraints& constraints, proto::NodeCondition* best_condition, utils::RandomEngine* random, SplitterPerThreadCache* cache); -// Specialization in the case of regression. 
-SplitSearchResult FindBestCondition( +SplitSearchResult FindBestConditionRegression( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, const std::vector& weights, @@ -313,8 +356,7 @@ SplitSearchResult FindBestCondition( const NodeConstraints& constraints, proto::NodeCondition* best_condition, utils::RandomEngine* random, SplitterPerThreadCache* cache); -// Specialization in the case of regression with hessian gain. -SplitSearchResult FindBestCondition( +SplitSearchResult FindBestConditionRegressionHessianGain( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, const std::vector& weights, @@ -326,8 +368,7 @@ SplitSearchResult FindBestCondition( const NodeConstraints& constraints, proto::NodeCondition* best_condition, utils::RandomEngine* random, SplitterPerThreadCache* cache); -// Specialization in the case of uplift with categorical outcome. -SplitSearchResult FindBestCondition( +SplitSearchResult FindBestConditionUpliftCategorical( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, const std::vector& weights, @@ -339,8 +380,7 @@ SplitSearchResult FindBestCondition( const NodeConstraints& constraints, proto::NodeCondition* best_condition, utils::RandomEngine* random, SplitterPerThreadCache* cache); -// Specialization in the case of uplift with numerical outcome. -SplitSearchResult FindBestCondition( +SplitSearchResult FindBestConditionUpliftNumerical( const dataset::VerticalDataset& train_dataset, const std::vector& selected_examples, const std::vector& weights, @@ -352,8 +392,13 @@ SplitSearchResult FindBestCondition( const NodeConstraints& constraints, proto::NodeCondition* best_condition, utils::RandomEngine* random, SplitterPerThreadCache* cache); -// Following are the split finder functions. Their name follow the patter: +// Following are the "FindBestCondition" specialized for both a task (i.e. label +// semantic) and feature semantic. 
The function names follow the pattern: // FindSplitLabel{label_type}Feature{feature_type}{algorithm_name}. +// +// Some splitters are only specialized on the feature, but not on the label +// type (e.g. "FindBestConditionOblique"). +// ============================================================================= // Search for the best split of the type "Attribute is NA" (i.e. "Attribute is // missing") for classification. @@ -362,10 +407,10 @@ SplitSearchResult FindSplitLabelClassificationFeatureNA( const std::vector& weights, const dataset::VerticalDataset::AbstractColumn* attributes, const std::vector& labels, const int32_t num_label_classes, - const UnsignedExampleIdx min_num_obs, + UnsignedExampleIdx min_num_obs, const proto::DecisionTreeTrainingConfig& dt_config, const utils::IntegerDistributionDouble& label_distribution, - const int32_t attribute_idx, proto::NodeCondition* condition, + int32_t attribute_idx, proto::NodeCondition* condition, SplitterPerThreadCache* cache); // Search for the best split of the type "Attribute is NA" (i.e. "Attribute is @@ -763,6 +808,9 @@ absl::StatusOr FindBestConditionOblique( const NodeConstraints& constraints, proto::NodeCondition* best_condition, utils::RandomEngine* random, SplitterPerThreadCache* cache); +// End of the FindBestCondition specialization. +// ============================================================================= + // Returns the number of attributes to test ("num_attributes_to_test") and a // list of candidate attributes to test in order ("candidate_attributes"). 
// "candidate_attributes" is guaranteed to have at least diff --git a/yggdrasil_decision_forests/model/decision_tree/decision_tree.h b/yggdrasil_decision_forests/model/decision_tree/decision_tree.h index 79dc46d2..1be843fc 100644 --- a/yggdrasil_decision_forests/model/decision_tree/decision_tree.h +++ b/yggdrasil_decision_forests/model/decision_tree/decision_tree.h @@ -47,13 +47,6 @@ namespace decision_tree { using row_t = dataset::VerticalDataset::row_t; -// The total number of "conditions" allocated in PerThreadCache.condition_list -// is equal to the number of threads times this factor. -// -// A larger value might increase the memory usage while a lower value might slow -// down the training." -constexpr int32_t kConditionPoolGrowthFactor = 2; - // Variable importance names to be used for all decision tree based model. static constexpr char kVariableImportanceNumberOfNodes[] = "NUM_NODES"; static constexpr char kVariableImportanceNumberOfTimesAsRoot[] = "NUM_AS_ROOT";