diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/BUILD b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/BUILD index 76365a8e..eb509cc3 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/BUILD +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/BUILD @@ -213,6 +213,7 @@ cc_library_ydf( "//yggdrasil_decision_forests/utils:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", ], alwayslink = 1, ) diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_average_error.cc b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_average_error.cc index 802119f3..7a18dfed 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_average_error.cc +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_average_error.cc @@ -15,12 +15,15 @@ #include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_average_error.h" +#include + #include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "yggdrasil_decision_forests/dataset/types.h" #include "yggdrasil_decision_forests/dataset/vertical_dataset.h" #include "yggdrasil_decision_forests/learner/abstract_learner.pb.h" @@ -29,7 +32,7 @@ #include "yggdrasil_decision_forests/metric/metric.h" #include "yggdrasil_decision_forests/model/abstract_model.pb.h" #include "yggdrasil_decision_forests/utils/concurrency.h" // IWYU pragma: keep -#include "yggdrasil_decision_forests/utils/logging.h" +#include "yggdrasil_decision_forests/utils/logging.h" // IWYU pragma: keep #include "yggdrasil_decision_forests/utils/math.h" #include "yggdrasil_decision_forests/utils/random.h" #include "yggdrasil_decision_forests/utils/status_macros.h" @@ -37,6 +40,30 @@ namespace yggdrasil_decision_forests { namespace model { namespace gradient_boosted_trees { +namespace { +void UpdateGradientsSingleThread(const absl::Span labels, + const absl::Span predictions, + absl::Span gradient_data, + absl::Span hessian_data) { + DCHECK_EQ(labels.size(), predictions.size()); + DCHECK_EQ(labels.size(), gradient_data.size()); + DCHECK_EQ(labels.size(), hessian_data.size()); + + // We use "table" to avoid a branch in the following loop. + // This optimization was found to improve the code speed. This should be + // revisited as new compilers are likely to do this optimization + // automatically one day. + static float table[] = {-1.f, 1.f}; + + for (size_t example_idx = 0; example_idx < labels.size(); ++example_idx) { + const float label = labels[example_idx]; + const float prediction = predictions[example_idx]; + gradient_data[example_idx] = table[label >= prediction]; + hessian_data[example_idx] = 1.f; + } +} + +} // namespace absl::Status MeanAverageErrorLoss::Status() const { if (task_ != model::proto::Task::REGRESSION) { @@ -106,26 +133,33 @@ absl::Status MeanAverageErrorLoss::UpdateGradients( const RankingGroupsIndices* ranking_index, GradientDataRef* gradients, utils::RandomEngine* random, utils::concurrency::ThreadPool* thread_pool) const { - // TODO: b/303811729 - Use "thread_pool" if set. - STATUS_CHECK_EQ(gradients->size(), 1); - const UnsignedExampleIdx num_examples = labels.size(); std::vector& gradient_data = *(*gradients)[0].gradient; std::vector& hessian_data = *(*gradients)[0].hessian; STATUS_CHECK_EQ(gradient_data.size(), hessian_data.size()); - // We use "table" to avoid a branch in the following loop. - // This optimization was found to improve the code speed. This should be - // revisited as new compilers are likely to do this optimization - // automatically one day. - static float table[] = {-1.f, 1.f}; - for (UnsignedExampleIdx example_idx = 0; example_idx < num_examples; - example_idx++) { - const float label = labels[example_idx]; - const float prediction = predictions[example_idx]; - gradient_data[example_idx] = table[label >= prediction]; - hessian_data[example_idx] = 1.f; + if (thread_pool == nullptr) { + UpdateGradientsSingleThread(labels, predictions, + absl::Span(gradient_data), + absl::Span(hessian_data)); + } else { + utils::concurrency::ConcurrentForLoop( + thread_pool->num_threads(), thread_pool, labels.size(), + [&labels, &predictions, &gradient_data, &hessian_data]( + const size_t block_idx, const size_t begin_idx, + const size_t end_idx) -> void { + UpdateGradientsSingleThread( + absl::Span(labels).subspan(begin_idx, + end_idx - begin_idx), + absl::Span(predictions) + .subspan(begin_idx, end_idx - begin_idx), + absl::Span(gradient_data) + .subspan(begin_idx, end_idx - begin_idx), + absl::Span(hessian_data) + .subspan(begin_idx, end_idx - begin_idx)); + }); } + return absl::OkStatus(); } diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_average_error_test.cc b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_average_error_test.cc index 3690ba99..25680438 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_average_error_test.cc +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_average_error_test.cc @@ -16,6 +16,7 @@ #include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_average_error.h" #include +#include #include #include @@ -26,9 +27,9 @@ #include "yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.h" #include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_interface.h" #include "yggdrasil_decision_forests/model/abstract_model.pb.h" -#include "yggdrasil_decision_forests/utils/concurrency.h" +#include "yggdrasil_decision_forests/utils/concurrency.h" // IWYU pragma: keep #include "yggdrasil_decision_forests/utils/random.h" -#include "yggdrasil_decision_forests/utils/status_macros.h" +#include "yggdrasil_decision_forests/utils/status_macros.h" // IWYU pragma: keep #include "yggdrasil_decision_forests/utils/test.h" #include "yggdrasil_decision_forests/utils/testing_macros.h" @@ -36,7 +37,6 @@ namespace yggdrasil_decision_forests::model::gradient_boosted_trees { namespace { -using ::testing::Bool; using ::testing::Combine; using ::testing::ElementsAre; using ::testing::FloatNear; @@ -89,6 +89,19 @@ enum class UseMultithreading : bool { class MeanAverageErrorLossWeightAndThreadingTest : public testing::TestWithParam> { + protected: + void SetUp() override { + const bool threaded = std::get<1>(GetParam()) == UseMultithreading::kYes; + if (threaded) { + thread_pool_ = std::make_unique("", 4); + thread_pool_->StartWorkers(); + } + } + + void TearDown() override { thread_pool_.reset(); } + + // The thread pool is only set if "UseMultithreading=kYes". + std::unique_ptr thread_pool_; }; class MeanAverageErrorLossWeightTest @@ -125,10 +138,11 @@ TEST(MeanAverageErrorLossTestNonWeighted, InitialPredictionsOdd) { ElementsAre(2.f)); } -TEST_P(MeanAverageErrorLossWeightTest, UpdateGradients) { +TEST_P(MeanAverageErrorLossWeightAndThreadingTest, UpdateGradients) { ASSERT_OK_AND_ASSIGN(const dataset::VerticalDataset dataset, CreateToyDataset()); - const bool weighted = GetParam() == UseWeights::kYes; + const bool weighted = std::get<0>(GetParam()) == UseWeights::kYes; + const std::vector weights = CreateToyWeights(weighted); dataset::VerticalDataset gradient_dataset; @@ -153,7 +167,7 @@ TEST_P(MeanAverageErrorLossWeightTest, UpdateGradients) { ASSERT_OK(loss_imp.UpdateGradients(gradient_dataset, /* label_col_idx= */ 0, predictions, /*ranking_index=*/nullptr, &gradients, - &random)); + &random, thread_pool_.get())); ASSERT_THAT(gradients, Not(IsEmpty())); if (weighted) { @@ -167,7 +181,6 @@ TEST_P(MeanAverageErrorLossWeightAndThreadingTest, ComputeLoss) { ASSERT_OK_AND_ASSIGN(const dataset::VerticalDataset dataset, CreateToyDataset()); const bool weighted = std::get<0>(GetParam()) == UseWeights::kYes; - const bool threaded = std::get<1>(GetParam()) == UseMultithreading::kYes; const std::vector weights = CreateToyWeights(weighted); const std::vector predictions(4, 0.f); @@ -175,19 +188,11 @@ TEST_P(MeanAverageErrorLossWeightAndThreadingTest, ComputeLoss) { model::proto::Task::REGRESSION, dataset.data_spec().columns(0)); LossResults loss_results; - if (threaded) { - utils::concurrency::ThreadPool thread_pool("", 4); - thread_pool.StartWorkers(); - ASSERT_OK_AND_ASSIGN(loss_results, - loss_imp.Loss(dataset, - /* label_col_idx= */ 0, predictions, - weights, nullptr, &thread_pool)); - } else { - ASSERT_OK_AND_ASSIGN( - loss_results, - loss_imp.Loss(dataset, - /* label_col_idx= */ 0, predictions, weights, nullptr)); - } + ASSERT_OK_AND_ASSIGN( + loss_results, loss_imp.Loss(dataset, + /* label_col_idx= */ 0, predictions, weights, + nullptr, thread_pool_.get())); + if (weighted) { // MAE = \sum (abs(prediction_i - label_i) * weight_i) / \sum weight_i const float expected_mae = diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_interface.cc b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_interface.cc index 3a37440e..fe3315df 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_interface.cc +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_interface.cc @@ -42,7 +42,8 @@ absl::Status AbstractLoss::UpdateGradients( const dataset::VerticalDataset& dataset, int label_col_idx, const std::vector& predictions, const RankingGroupsIndices* ranking_index, - std::vector* gradients, utils::RandomEngine* random) const { + std::vector* gradients, utils::RandomEngine* random, + utils::concurrency::ThreadPool* thread_pool) const { GradientDataRef compact_gradient(gradients->size()); for (int i = 0; i < gradients->size(); i++) { compact_gradient[i] = {&(*gradients)[i].gradient, &(*gradients)[i].hessian}; @@ -53,7 +54,8 @@ absl::Status AbstractLoss::UpdateGradients( label_col_idx); if (categorical_labels) { return UpdateGradients(categorical_labels->values(), predictions, - ranking_index, &compact_gradient, random, nullptr); + ranking_index, &compact_gradient, random, + thread_pool); } const auto* numerical_labels = @@ -61,7 +63,8 @@ absl::Status AbstractLoss::UpdateGradients( label_col_idx); if (numerical_labels) { return UpdateGradients(numerical_labels->values(), predictions, - ranking_index, &compact_gradient, random, nullptr); + ranking_index, &compact_gradient, random, + thread_pool); } return absl::InternalError( diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_interface.h b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_interface.h index 74f672d5..da5aa925 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_interface.h +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_interface.h @@ -231,12 +231,12 @@ class AbstractLoss { // This method calls the UpdateGradients defined above depending on the type // of the label column in the VerticalDataset (currently, only support float // (Numerical) and int32 (Categorical)). - absl::Status UpdateGradients(const dataset::VerticalDataset& dataset, - int label_col_idx, - const std::vector& predictions, - const RankingGroupsIndices* ranking_index, - std::vector* gradients, - utils::RandomEngine* random) const; + absl::Status UpdateGradients( + const dataset::VerticalDataset& dataset, int label_col_idx, + const std::vector& predictions, + const RankingGroupsIndices* ranking_index, + std::vector* gradients, utils::RandomEngine* random, + utils::concurrency::ThreadPool* thread_pool = nullptr) const; // Gets the name of the metrics returned in "secondary_metric" of the "Loss" // method. diff --git a/yggdrasil_decision_forests/utils/BUILD b/yggdrasil_decision_forests/utils/BUILD index 7d2cba14..3ca0f7c2 100644 --- a/yggdrasil_decision_forests/utils/BUILD +++ b/yggdrasil_decision_forests/utils/BUILD @@ -973,6 +973,7 @@ cc_test( ":concurrency", # "@com_google_googletest//:gtest_main", # When fixed "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/types:optional", ], ) diff --git a/yggdrasil_decision_forests/utils/concurrency.cc b/yggdrasil_decision_forests/utils/concurrency.cc index 80065e4e..ef7b48af 100644 --- a/yggdrasil_decision_forests/utils/concurrency.cc +++ b/yggdrasil_decision_forests/utils/concurrency.cc @@ -13,31 +13,17 @@ * limitations under the License. */ -#ifndef YGGDRASIL_DECISION_FORESTS_UTILS_CONCURRENCY_UTILS_H_ -#define YGGDRASIL_DECISION_FORESTS_UTILS_CONCURRENCY_UTILS_H_ +#include "yggdrasil_decision_forests/utils/concurrency.h" + +#include +#include #include -#include "yggdrasil_decision_forests/utils/concurrency.h" +#include "yggdrasil_decision_forests/utils/logging.h" // IWYU pragma: keep -namespace yggdrasil_decision_forests { -namespace utils { -namespace concurrency { +namespace yggdrasil_decision_forests::utils::concurrency { -// Utility to apply a function over a range of elements using multi-threading. -// -// Given "num_items" elements divided into "num_blocks" contiguous blocks of -// the same size (except possibly for the last one). This method calls -// "function" on each block in parallel using the thread-pool. -// -// The method is blocking until all the "function" call have returned. -// -// For example, support num_items=10 and num_blocks=3 defines the following -// blocks: [0,4), [4,8), [8,10). Then, "function" will be called in parallel on: -// function(block_idx=0, begin_item_idx=0, end_item_idx=0) -// function(block_idx=1, begin_item_idx=4, end_item_idx=8) -// function(block_idx=2, begin_item_idx=8, end_item_idx=10) -// void ConcurrentForLoop( const size_t num_blocks, ThreadPool* thread_pool, const size_t num_items, const std::function + +#include + #include "yggdrasil_decision_forests/utils/synchronization_primitives.h" #include "yggdrasil_decision_forests/utils/concurrency_default.h" @@ -52,17 +56,38 @@ #include "yggdrasil_decision_forests/utils/concurrency_channel.h" #include "yggdrasil_decision_forests/utils/concurrency_streamprocessor.h" -namespace yggdrasil_decision_forests { -namespace utils { -namespace concurrency { +namespace yggdrasil_decision_forests::utils::concurrency { +// Applies "function" over a range of elements using multi-threading. +// +// Given "num_items" elements divided into "num_blocks" contiguous blocks of +// the same size (except possibly for the last one). This method calls +// "function" on each block in parallel using the thread-pool. +// +// The method is blocking until all the "function" call have returned. +// +// For example, support num_items=10 and num_blocks=3 defines the following +// blocks: [0,4), [4,8), [8,10). Then, "function" will be called in parallel on: +// function(block_idx=0, begin_item_idx=0, end_item_idx=0) +// function(block_idx=1, begin_item_idx=4, end_item_idx=8) +// function(block_idx=2, begin_item_idx=8, end_item_idx=10) +// +// Args: +// num_blocks: Number of range subsets. +// thread_pool: Already started thread pool. +// num_items: Number of items to process. +// function: Function to call. +// +// function's signature: +// block_idx: Index of the block in [0, num_blocks). +// begin_item_idx: First item to process (inclusive). +// end_item_idx: Last item to process (exclusive). +// void ConcurrentForLoop( - const size_t num_blocks, ThreadPool* thread_pool, const size_t num_items, + size_t num_blocks, ThreadPool* thread_pool, size_t num_items, const std::function& function); -} // namespace concurrency -} // namespace utils -} // namespace yggdrasil_decision_forests +} // namespace yggdrasil_decision_forests::utils::concurrency #endif // YGGDRASIL_DECISION_FORESTS_UTILS_CONCURRENCY_H_ diff --git a/yggdrasil_decision_forests/utils/concurrency_test.cc b/yggdrasil_decision_forests/utils/concurrency_test.cc index 022462b0..20a0dae7 100644 --- a/yggdrasil_decision_forests/utils/concurrency_test.cc +++ b/yggdrasil_decision_forests/utils/concurrency_test.cc @@ -13,14 +13,19 @@ * limitations under the License. */ +#include + +#include +#include +#include + #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/types/optional.h" -#include "yggdrasil_decision_forests/utils/concurrency.h" +#include "yggdrasil_decision_forests/utils/concurrency.h" // IWYU pragma: keep -namespace yggdrasil_decision_forests { -namespace utils { -namespace concurrency { +namespace yggdrasil_decision_forests::utils::concurrency { namespace { TEST(ThreadPool, Empty) { @@ -28,8 +33,8 @@ TEST(ThreadPool, Empty) { } TEST(ThreadPool, Simple) { - std::atomic counter = {0}; - int n = 100; + std::atomic counter{0}; + const int n = 100; { ThreadPool pool("MyPool", 1); pool.StartWorkers(); @@ -58,8 +63,9 @@ TEST(StreamProcessor, Simple) { // Continuously consume a result, and restart a new job. for (int i = 0; i < num_jobs; i++) { - const auto result = processor.GetResult().value(); - sum += result; + const absl::optional result_or = processor.GetResult(); + ASSERT_TRUE(result_or.has_value()); + sum += *result_or; if (i < num_jobs - num_initially_planned_jobs) { processor.Submit(i + num_initially_planned_jobs); } @@ -77,9 +83,9 @@ TEST(StreamProcessor, NonCopiableData) { [](Question x) { return x; }); processor.StartWorkers(); - processor.Submit(absl::make_unique(10)); - auto result = processor.GetResult().value(); - CHECK_EQ(*result, 10); + processor.Submit(std::make_unique(10)); + const absl::optional> result_or = processor.GetResult(); + EXPECT_THAT(result_or, testing::Optional(testing::Pointee(10))); } TEST(StreamProcessor, InOrder) { @@ -104,10 +110,8 @@ TEST(StreamProcessor, InOrder) { // Continuously consume a result, and restart a new job. for (int i = 0; i < num_jobs; i++) { - const auto result = processor.GetResult().value(); - const auto expected_result = next_expected_result++; - EXPECT_EQ(result, expected_result); - + const absl::optional result = processor.GetResult(); + EXPECT_THAT(result, testing::Optional(next_expected_result++)); processor.Submit(next_query++); } } @@ -122,10 +126,13 @@ TEST(StreamProcessor, EarlyClose) { processor.Submit(3); processor.CloseSubmits(); - CHECK_EQ(processor.GetResult().value(), 1); - CHECK_EQ(processor.GetResult().value(), 2); - CHECK_EQ(processor.GetResult().value(), 3); - CHECK(!processor.GetResult().has_value()); + absl::optional result = processor.GetResult(); + EXPECT_THAT(result, testing::Optional(1)); + result = processor.GetResult(); + EXPECT_THAT(result, testing::Optional(2)); + result = processor.GetResult(); + EXPECT_THAT(result, testing::Optional(3)); + EXPECT_FALSE(processor.GetResult().has_value()); processor.JoinAllAndStopThreads(); } @@ -150,6 +157,4 @@ TEST(Utils, ConcurrentForLoop) { } } // namespace -} // namespace concurrency -} // namespace utils -} // namespace yggdrasil_decision_forests +} // namespace yggdrasil_decision_forests::utils::concurrency