Multithreaded implementation of gradient computation for the MAE loss.
PiperOrigin-RevId: 574773796
achoum authored and copybara-github committed Oct 19, 2023
1 parent 46c961c commit 9149829
Showing 9 changed files with 154 additions and 98 deletions.
@@ -213,6 +213,7 @@ cc_library_ydf(
"//yggdrasil_decision_forests/utils:status_macros",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
],
alwayslink = 1,
)
@@ -15,12 +15,15 @@

#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_average_error.h"

#include <stddef.h>

#include <algorithm>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "yggdrasil_decision_forests/dataset/types.h"
#include "yggdrasil_decision_forests/dataset/vertical_dataset.h"
#include "yggdrasil_decision_forests/learner/abstract_learner.pb.h"
@@ -29,14 +32,38 @@
#include "yggdrasil_decision_forests/metric/metric.h"
#include "yggdrasil_decision_forests/model/abstract_model.pb.h"
#include "yggdrasil_decision_forests/utils/concurrency.h" // IWYU pragma: keep
#include "yggdrasil_decision_forests/utils/logging.h"
#include "yggdrasil_decision_forests/utils/logging.h" // IWYU pragma: keep
#include "yggdrasil_decision_forests/utils/math.h"
#include "yggdrasil_decision_forests/utils/random.h"
#include "yggdrasil_decision_forests/utils/status_macros.h"

namespace yggdrasil_decision_forests {
namespace model {
namespace gradient_boosted_trees {
namespace {
void UpdateGradientsSingleThread(const absl::Span<const float> labels,
const absl::Span<const float> predictions,
absl::Span<float> gradient_data,
absl::Span<float> hessian_data) {
DCHECK_EQ(labels.size(), predictions.size());
DCHECK_EQ(labels.size(), gradient_data.size());
DCHECK_EQ(labels.size(), hessian_data.size());

// We use "table" to avoid a branch in the following loop.
// This optimization was found to improve the code speed. This should be
// revisited as new compilers are likely to do this optimization
// automatically one day.
static float table[] = {-1.f, 1.f};

for (size_t example_idx = 0; example_idx < labels.size(); ++example_idx) {
const float label = labels[example_idx];
const float prediction = predictions[example_idx];
gradient_data[example_idx] = table[label >= prediction];
hessian_data[example_idx] = 1.f;
}
}

} // namespace
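
For context: the value stored by the helper above is the boosting gradient of the MAE loss, i.e. the negative derivative of |label - prediction| with respect to the prediction, which is +1 when label >= prediction and -1 otherwise (the subgradient at equality is taken as +1), with a constant hessian of 1. The table lookup is a branchless form of the ternary below; this is an illustrative sketch only, not part of the change:

// Illustrative only: branchy equivalent of table[label >= prediction] used in
// UpdateGradientsSingleThread above. The name MaeGradient is made up for this
// sketch and does not appear in the commit.
inline float MaeGradient(float label, float prediction) {
  // Negative derivative of |label - prediction| w.r.t. the prediction; the
  // subgradient at label == prediction is taken as +1, matching the table.
  return (label >= prediction) ? 1.f : -1.f;
}
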

absl::Status MeanAverageErrorLoss::Status() const {
if (task_ != model::proto::Task::REGRESSION) {
@@ -106,26 +133,33 @@ absl::Status MeanAverageErrorLoss::UpdateGradients(
const RankingGroupsIndices* ranking_index, GradientDataRef* gradients,
utils::RandomEngine* random,
utils::concurrency::ThreadPool* thread_pool) const {
// TODO: b/303811729 - Use "thread_pool" if set.

STATUS_CHECK_EQ(gradients->size(), 1);
const UnsignedExampleIdx num_examples = labels.size();
std::vector<float>& gradient_data = *(*gradients)[0].gradient;
std::vector<float>& hessian_data = *(*gradients)[0].hessian;
STATUS_CHECK_EQ(gradient_data.size(), hessian_data.size());
// We use "table" to avoid a branch in the following loop.
// This optimization was found to improve the code speed. This should be
// revisited as new compilers are likely to do this optimization
// automatically one day.
static float table[] = {-1.f, 1.f};

for (UnsignedExampleIdx example_idx = 0; example_idx < num_examples;
example_idx++) {
const float label = labels[example_idx];
const float prediction = predictions[example_idx];
gradient_data[example_idx] = table[label >= prediction];
hessian_data[example_idx] = 1.f;
if (thread_pool == nullptr) {
UpdateGradientsSingleThread(labels, predictions,
absl::Span<float>(gradient_data),
absl::Span<float>(hessian_data));
} else {
utils::concurrency::ConcurrentForLoop(
thread_pool->num_threads(), thread_pool, labels.size(),
[&labels, &predictions, &gradient_data, &hessian_data](
const size_t block_idx, const size_t begin_idx,
const size_t end_idx) -> void {
UpdateGradientsSingleThread(
absl::Span<const float>(labels).subspan(begin_idx,
end_idx - begin_idx),
absl::Span<const float>(predictions)
.subspan(begin_idx, end_idx - begin_idx),
absl::Span<float>(gradient_data)
.subspan(begin_idx, end_idx - begin_idx),
absl::Span<float>(hessian_data)
.subspan(begin_idx, end_idx - begin_idx));
});
}

return absl::OkStatus();
}

@@ -16,6 +16,7 @@
#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_average_error.h"

#include <cmath>
#include <memory>
#include <tuple>
#include <vector>

@@ -26,17 +27,16 @@
#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.h"
#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_interface.h"
#include "yggdrasil_decision_forests/model/abstract_model.pb.h"
#include "yggdrasil_decision_forests/utils/concurrency.h"
#include "yggdrasil_decision_forests/utils/concurrency.h" // IWYU pragma: keep
#include "yggdrasil_decision_forests/utils/random.h"
#include "yggdrasil_decision_forests/utils/status_macros.h"
#include "yggdrasil_decision_forests/utils/status_macros.h" // IWYU pragma: keep
#include "yggdrasil_decision_forests/utils/test.h"
#include "yggdrasil_decision_forests/utils/testing_macros.h"

namespace yggdrasil_decision_forests::model::gradient_boosted_trees {

namespace {

using ::testing::Bool;
using ::testing::Combine;
using ::testing::ElementsAre;
using ::testing::FloatNear;
@@ -89,6 +89,19 @@ enum class UseMultithreading : bool {

class MeanAverageErrorLossWeightAndThreadingTest
: public testing::TestWithParam<std::tuple<UseWeights, UseMultithreading>> {
protected:
void SetUp() override {
const bool threaded = std::get<1>(GetParam()) == UseMultithreading::kYes;
if (threaded) {
thread_pool_ = std::make_unique<utils::concurrency::ThreadPool>("", 4);
thread_pool_->StartWorkers();
}
}

void TearDown() override { thread_pool_.reset(); }

// The thread pool is only set if "UseMultithreading=kYes".
std::unique_ptr<utils::concurrency::ThreadPool> thread_pool_;
};
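
The corresponding INSTANTIATE_TEST_SUITE_P is outside the visible hunk; a hypothetical sketch of how such a two-parameter fixture is typically instantiated with Combine (the suite name and the kNo enumerators are assumptions, not taken from this diff):

// Hypothetical instantiation sketch; the real one is not shown in this hunk.
INSTANTIATE_TEST_SUITE_P(
    WeightsAndThreading, MeanAverageErrorLossWeightAndThreadingTest,
    Combine(testing::Values(UseWeights::kNo, UseWeights::kYes),
            testing::Values(UseMultithreading::kNo, UseMultithreading::kYes)));
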

class MeanAverageErrorLossWeightTest
@@ -125,10 +138,11 @@ TEST(MeanAverageErrorLossTestNonWeighted, InitialPredictionsOdd) {
ElementsAre(2.f));
}

TEST_P(MeanAverageErrorLossWeightTest, UpdateGradients) {
TEST_P(MeanAverageErrorLossWeightAndThreadingTest, UpdateGradients) {
ASSERT_OK_AND_ASSIGN(const dataset::VerticalDataset dataset,
CreateToyDataset());
const bool weighted = GetParam() == UseWeights::kYes;
const bool weighted = std::get<0>(GetParam()) == UseWeights::kYes;

const std::vector<float> weights = CreateToyWeights(weighted);

dataset::VerticalDataset gradient_dataset;
Expand All @@ -153,7 +167,7 @@ TEST_P(MeanAverageErrorLossWeightTest, UpdateGradients) {
ASSERT_OK(loss_imp.UpdateGradients(gradient_dataset,
/* label_col_idx= */ 0, predictions,
/*ranking_index=*/nullptr, &gradients,
&random));
&random, thread_pool_.get()));

ASSERT_THAT(gradients, Not(IsEmpty()));
if (weighted) {
@@ -167,27 +181,18 @@ TEST_P(MeanAverageErrorLossWeightAndThreadingTest, ComputeLoss) {
ASSERT_OK_AND_ASSIGN(const dataset::VerticalDataset dataset,
CreateToyDataset());
const bool weighted = std::get<0>(GetParam()) == UseWeights::kYes;
const bool threaded = std::get<1>(GetParam()) == UseMultithreading::kYes;
const std::vector<float> weights = CreateToyWeights(weighted);

const std::vector<float> predictions(4, 0.f);
const MeanAverageErrorLoss loss_imp(/*gbt_config=*/{},
model::proto::Task::REGRESSION,
dataset.data_spec().columns(0));
LossResults loss_results;
if (threaded) {
utils::concurrency::ThreadPool thread_pool("", 4);
thread_pool.StartWorkers();
ASSERT_OK_AND_ASSIGN(loss_results,
loss_imp.Loss(dataset,
/* label_col_idx= */ 0, predictions,
weights, nullptr, &thread_pool));
} else {
ASSERT_OK_AND_ASSIGN(
loss_results,
loss_imp.Loss(dataset,
/* label_col_idx= */ 0, predictions, weights, nullptr));
}
ASSERT_OK_AND_ASSIGN(
loss_results, loss_imp.Loss(dataset,
/* label_col_idx= */ 0, predictions, weights,
nullptr, thread_pool_.get()));

if (weighted) {
// MAE = \sum (abs(prediction_i - label_i) * weight_i) / \sum weight_i
const float expected_mae =
@@ -42,7 +42,8 @@ absl::Status AbstractLoss::UpdateGradients(
const dataset::VerticalDataset& dataset, int label_col_idx,
const std::vector<float>& predictions,
const RankingGroupsIndices* ranking_index,
std::vector<GradientData>* gradients, utils::RandomEngine* random) const {
std::vector<GradientData>* gradients, utils::RandomEngine* random,
utils::concurrency::ThreadPool* thread_pool) const {
GradientDataRef compact_gradient(gradients->size());
for (int i = 0; i < gradients->size(); i++) {
compact_gradient[i] = {&(*gradients)[i].gradient, &(*gradients)[i].hessian};
@@ -53,15 +54,17 @@
label_col_idx);
if (categorical_labels) {
return UpdateGradients(categorical_labels->values(), predictions,
ranking_index, &compact_gradient, random, nullptr);
ranking_index, &compact_gradient, random,
thread_pool);
}

const auto* numerical_labels =
dataset.ColumnWithCastOrNull<dataset::VerticalDataset::NumericalColumn>(
label_col_idx);
if (numerical_labels) {
return UpdateGradients(numerical_labels->values(), predictions,
ranking_index, &compact_gradient, random, nullptr);
ranking_index, &compact_gradient, random,
thread_pool);
}

return absl::InternalError(
@@ -231,12 +231,12 @@ class AbstractLoss {
// This method calls the UpdateGradients defined above depending on the type
// of the label column in the VerticalDataset (currently, only support float
// (Numerical) and int32 (Categorical)).
absl::Status UpdateGradients(const dataset::VerticalDataset& dataset,
int label_col_idx,
const std::vector<float>& predictions,
const RankingGroupsIndices* ranking_index,
std::vector<GradientData>* gradients,
utils::RandomEngine* random) const;
absl::Status UpdateGradients(
const dataset::VerticalDataset& dataset, int label_col_idx,
const std::vector<float>& predictions,
const RankingGroupsIndices* ranking_index,
std::vector<GradientData>* gradients, utils::RandomEngine* random,
utils::concurrency::ThreadPool* thread_pool = nullptr) const;

// Gets the name of the metrics returned in "secondary_metric" of the "Loss"
// method.
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/utils/BUILD
@@ -973,6 +973,7 @@ cc_test(
":concurrency",
# "@com_google_googletest//:gtest_main", # When fixed
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/types:optional",
],
)

32 changes: 7 additions & 25 deletions yggdrasil_decision_forests/utils/concurrency.cc
@@ -13,31 +13,17 @@
* limitations under the License.
*/

#ifndef YGGDRASIL_DECISION_FORESTS_UTILS_CONCURRENCY_UTILS_H_
#define YGGDRASIL_DECISION_FORESTS_UTILS_CONCURRENCY_UTILS_H_
#include "yggdrasil_decision_forests/utils/concurrency.h"

#include <stddef.h>

#include <algorithm>
#include <functional>

#include "yggdrasil_decision_forests/utils/concurrency.h"
#include "yggdrasil_decision_forests/utils/logging.h" // IWYU pragma: keep

namespace yggdrasil_decision_forests {
namespace utils {
namespace concurrency {
namespace yggdrasil_decision_forests::utils::concurrency {

// Utility to apply a function over a range of elements using multi-threading.
//
// Given "num_items" elements divided into "num_blocks" contiguous blocks of
// the same size (except possibly for the last one). This method calls
// "function" on each block in parallel using the thread-pool.
//
// The method blocks until all the "function" calls have returned.
//
// For example, suppose num_items=10 and num_blocks=3. This defines the
// blocks: [0,4), [4,8), [8,10). Then, "function" will be called in parallel on:
// function(block_idx=0, begin_item_idx=0, end_item_idx=4)
// function(block_idx=1, begin_item_idx=4, end_item_idx=8)
// function(block_idx=2, begin_item_idx=8, end_item_idx=10)
//
void ConcurrentForLoop(
const size_t num_blocks, ThreadPool* thread_pool, const size_t num_items,
const std::function<void(size_t block_idx, size_t begin_item_idx,
@@ -62,8 +48,4 @@ void ConcurrentForLoop(
blocker.Wait();
}

} // namespace concurrency
} // namespace utils
} // namespace yggdrasil_decision_forests

#endif // YGGDRASIL_DECISION_FORESTS_UTILS_CONCURRENCY_UTILS_H_
} // namespace yggdrasil_decision_forests::utils::concurrency
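
The body of ConcurrentForLoop is collapsed in this view. For orientation, a plausible sketch of the block partitioning such a helper performs, run sequentially here for simplicity; the partitioning logic and the name ConcurrentForLoopSketch are assumptions, not the actual implementation:

// Assumed partitioning logic, shown single-threaded; the real helper schedules
// each block on the thread pool and waits on a blocker.
#include <algorithm>
#include <cstddef>
#include <functional>

void ConcurrentForLoopSketch(
    size_t num_blocks, size_t num_items,
    const std::function<void(size_t, size_t, size_t)>& function) {
  // Ceiling division so that every item falls in exactly one block.
  const size_t block_size = (num_items + num_blocks - 1) / num_blocks;
  for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
    const size_t begin = block_idx * block_size;
    const size_t end = std::min(begin + block_size, num_items);
    if (begin >= end) break;  // No items remain for later blocks.
    function(block_idx, begin, end);
  }
}
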
39 changes: 32 additions & 7 deletions yggdrasil_decision_forests/utils/concurrency.h
@@ -45,24 +45,49 @@
#ifndef YGGDRASIL_DECISION_FORESTS_UTILS_CONCURRENCY_H_
#define YGGDRASIL_DECISION_FORESTS_UTILS_CONCURRENCY_H_

#include <stddef.h>

#include <functional>

#include "yggdrasil_decision_forests/utils/synchronization_primitives.h"

#include "yggdrasil_decision_forests/utils/concurrency_default.h"

#include "yggdrasil_decision_forests/utils/concurrency_channel.h"
#include "yggdrasil_decision_forests/utils/concurrency_streamprocessor.h"

namespace yggdrasil_decision_forests {
namespace utils {
namespace concurrency {
namespace yggdrasil_decision_forests::utils::concurrency {

// Applies "function" over a range of elements using multi-threading.
//
// Given "num_items" elements divided into "num_blocks" contiguous blocks of
// the same size (except possibly for the last one). This method calls
// "function" on each block in parallel using the thread-pool.
//
// The method is blocking until all the "function" call have returned.
//
// For example, support num_items=10 and num_blocks=3 defines the following
// blocks: [0,4), [4,8), [8,10). Then, "function" will be called in parallel on:
// function(block_idx=0, begin_item_idx=0, end_item_idx=0)
// function(block_idx=1, begin_item_idx=4, end_item_idx=8)
// function(block_idx=2, begin_item_idx=8, end_item_idx=10)
//
// Args:
// num_blocks: Number of range subsets.
// thread_pool: Already started thread pool.
// num_items: Number of items to process.
// function: Function to call.
//
// function's signature:
// block_idx: Index of the block in [0, num_blocks).
// begin_item_idx: First item to process (inclusive).
// end_item_idx: One past the last item to process (exclusive).
//
void ConcurrentForLoop(
const size_t num_blocks, ThreadPool* thread_pool, const size_t num_items,
size_t num_blocks, ThreadPool* thread_pool, size_t num_items,
const std::function<void(size_t block_idx, size_t begin_item_idx,
size_t end_item_idx)>& function);

} // namespace concurrency
} // namespace utils
} // namespace yggdrasil_decision_forests
} // namespace yggdrasil_decision_forests::utils::concurrency

#endif // YGGDRASIL_DECISION_FORESTS_UTILS_CONCURRENCY_H_
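
A minimal usage sketch of the relocated helper (illustrative only; the ThreadPool("", 4) + StartWorkers() pattern mirrors the test code earlier in this commit, while the SquareInParallel function, the ydfc alias, and the vector-squaring workload are made up for the example):

#include <cstddef>
#include <vector>

#include "yggdrasil_decision_forests/utils/concurrency.h"

namespace ydfc = yggdrasil_decision_forests::utils::concurrency;

// Squares every element of "values" in parallel using four worker threads.
void SquareInParallel(std::vector<float>& values) {
  ydfc::ThreadPool pool("", 4);
  pool.StartWorkers();
  // ConcurrentForLoop blocks until all blocks have been processed.
  ydfc::ConcurrentForLoop(
      /*num_blocks=*/4, &pool, /*num_items=*/values.size(),
      [&values](size_t /*block_idx*/, size_t begin, size_t end) {
        for (size_t i = begin; i < end; ++i) {
          values[i] *= values[i];
        }
      });
}
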