Commit 8b04210
[YDF] Allow configuring the truncation of NDCG losses
PiperOrigin-RevId: 675173198
rstz authored and copybara-github committed Sep 16, 2024
1 parent d5630e6 commit 8b04210
Showing 18 changed files with 206 additions and 197 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,9 @@ Changelog under `yggdrasil_decision_forests/port/python/CHANGELOG.md`.
### Features

- Speed-up training of GBT models by ~10%
- Rename LAMBDA_MART_NDCG5 to LAMBDA_MART_NDCG. The old name is deprecated
but can still be used.
- Allow configuring the truncation of NDCG losses.

## 1.10.0 - 2024-08-21

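For orientation, a minimal C++ sketch of the new knobs from the caller's side. The two `set_ndcg_truncation` setters and their option messages are taken from the proto diff further down; the namespace alias and the standalone config construction are illustrative assumptions, not part of this commit:

#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.pb.h"

namespace gbt_proto = yggdrasil_decision_forests::model::gradient_boosted_trees::proto;

// Both ranking losses default to a truncation of 5 and are configured
// independently of each other.
gbt_proto::GradientBoostedTreesTrainingConfig gbt_config;
// Classic LambdaMART NDCG loss, i.e. loss = "LAMBDA_MART_NDCG".
gbt_config.mutable_ndcg_loss_options()->set_ndcg_truncation(10);
// Cross-entropy NDCG loss, i.e. loss = "XE_NDCG_MART".
gbt_config.mutable_xe_ndcg()->set_ndcg_truncation(10);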
195 changes: 16 additions & 179 deletions documentation/public/docs/hyperparameters.md

Large diffs are not rendered by default.

@@ -133,6 +133,8 @@ constexpr char GradientBoostedTreesLearner::kHParamValidationIntervalInTrees[];
constexpr char GradientBoostedTreesLearner::kHParamLoss[];
constexpr char GradientBoostedTreesLearner::kHParamFocalLossGamma[];
constexpr char GradientBoostedTreesLearner::kHParamFocalLossAlpha[];
constexpr char GradientBoostedTreesLearner::kHParamNDCGTruncation[];
constexpr char GradientBoostedTreesLearner::kHParamXENDCGTruncation[];

using dataset::VerticalDataset;
using CategoricalColumn = VerticalDataset::CategoricalColumn;
@@ -1931,6 +1933,22 @@ absl::Status GradientBoostedTreesLearner::SetHyperParametersImpl(
}
}

{
const auto hparam = generic_hyper_params->Get(kHParamNDCGTruncation);
if (hparam.has_value()) {
gbt_config->mutable_ndcg_loss_options()->set_ndcg_truncation(
hparam.value().value().real());
}
}

{
const auto hparam = generic_hyper_params->Get(kHParamXENDCGTruncation);
if (hparam.has_value()) {
gbt_config->mutable_xe_ndcg()->set_ndcg_truncation(
hparam.value().value().real());
}
}

return absl::OkStatus();
}

@@ -2401,6 +2419,28 @@ For example, in the case of binary classification, the pre-link function output
R"(EXPERIMENTAL. Weighting parameter for focal loss, positive samples weighted by alpha, negative samples by (1-alpha). The default 0.5 value means no active class-level weighting. Only used with focal loss i.e. `loss="BINARY_FOCAL_LOSS"`)");
}

{
auto& param =
hparam_def.mutable_fields()->operator[](kHParamNDCGTruncation);
param.mutable_integer()->set_minimum(1.f);
param.mutable_integer()->set_default_value(
gbt_config.ndcg_loss_options().ndcg_truncation());
param.mutable_documentation()->set_proto_path(proto_path);
param.mutable_documentation()->set_description(
R"(Truncation of the NDCG loss. Only used with NDCG loss i.e. `loss="LAMBDA_MART_NDCG"`)");
}

{
auto& param =
hparam_def.mutable_fields()->operator[](kHParamXENDCGTruncation);
param.mutable_integer()->set_minimum(1.f);
param.mutable_integer()->set_default_value(
gbt_config.xe_ndcg().ndcg_truncation());
param.mutable_documentation()->set_proto_path(proto_path);
param.mutable_documentation()->set_description(
R"(Truncation of the cross-entropy NDCG loss. Only used with cross-entropy NDCG loss i.e. `loss="XE_NDCG_MART"`)");
}

RETURN_IF_ERROR(decision_tree::GetGenericHyperParameterSpecification(
gbt_config.decision_tree(), &hparam_def));
return hparam_def;
@@ -132,6 +132,8 @@ class GradientBoostedTreesLearner : public AbstractLearner {
static constexpr char kHParamLoss[] = "loss";
static constexpr char kHParamFocalLossGamma[] = "focal_loss_gamma";
static constexpr char kHParamFocalLossAlpha[] = "focal_loss_alpha";
static constexpr char kHParamNDCGTruncation[] = "ndcg_truncation";
static constexpr char kHParamXENDCGTruncation[] = "cross_entropy_ndcg_truncation";

absl::StatusOr<std::unique_ptr<AbstractModel>> TrainWithStatusImpl(
const dataset::VerticalDataset& train_dataset,
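The two constants above ("ndcg_truncation" and "cross_entropy_ndcg_truncation") are the user-facing generic hyperparameter names. A hedged sketch of passing them through the generic API (the GenericHyperParameters proto shape comes from YDF's abstract learner definitions, which this commit does not touch):

model::proto::GenericHyperParameters hparams;
auto* field = hparams.add_fields();
field->set_name("ndcg_truncation");  // kHParamNDCGTruncation
field->mutable_value()->set_integer(10);
// Use "cross_entropy_ndcg_truncation" for the XE_NDCG_MART loss instead;
// SetHyperParametersImpl above routes each name to its option message.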
@@ -25,7 +25,7 @@ option java_outer_classname = "GradientBoostedTreesLearner";

// Training configuration for the Gradient Boosted Trees algorithm.
message GradientBoostedTreesTrainingConfig {
// Next ID: 38
// Next ID: 39

// Basic parameters.

@@ -180,6 +180,7 @@ message GradientBoostedTreesTrainingConfig {
LambdaMartNdcg lambda_mart_ndcg = 12;
XeNdcg xe_ndcg = 26;
BinaryFocalLossOptions binary_focal_loss_options = 36;
NDCGLossOptions ndcg_loss_options = 38;
}

message LambdaMartNdcg {
@@ -199,6 +200,17 @@ message GradientBoostedTreesTrainingConfig {
ONE = 2;
}
optional Gamma gamma = 1 [default = UNIFORM];

// Number of candidates considered when computing the NDCG loss.
//
// NDCG losses are usually truncated at a particular rank level (generally
// between 4 and 10), i.e. only the highly ranked documents are considered
// when computing the loss. A smaller value results in a model with
// increased emphasis on the first results of the ranking.
//
// Note that the NDCG truncation of the classic NDCG loss must be configured
// separately.
optional int32 ndcg_truncation = 2 [default = 5];
}

message BinaryFocalLossOptions {
@@ -217,6 +229,19 @@ message GradientBoostedTreesTrainingConfig {
optional float positive_sample_coefficient = 2 [default = 0.5];
}

message NDCGLossOptions {
// Number of candidates considered when computing the NDCG loss.
//
// NDCG losses are usually truncated at a particular rank level (generally
// between 4 and 10), i.e. only the highly ranked documents are considered
// when computing the loss. A smaller value results in a model with
// increased emphasis on the first results of the ranking.
//
// Note that the NDCG truncation of the cross-entropy NDCG loss must be
// configured separately.
optional int32 ndcg_truncation = 1 [default = 5];
}

// L2 regularization on the tree predictions i.e. on the value of the leaf.
// See the equation 2 of the XGBoost paper for the definition
// (https://arxiv.org/pdf/1603.02754.pdf).
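Between the proto definitions above and the build changes below, it is worth pinning down what the truncation rank k controls. The exponential-gain form is inferred from the NDCGCalculator terms and the test expectations further down, so treat it as a sketch of the convention rather than a normative definition:

\mathrm{DCG}@k = \sum_{i=1}^{\min(k,\,n)} \frac{2^{\mathrm{rel}_i} - 1}{\log_2(i+1)},
\qquad
\mathrm{NDCG}@k = \frac{\mathrm{DCG}@k}{\mathrm{IDCG}@k}

where rel_i is the relevance of the document at rank i and IDCG@k is DCG@k under the ideal ordering; the reported loss is -NDCG@k. Documents ranked below k contribute nothing, so a smaller k concentrates the training signal on the top of the ranking.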
@@ -158,6 +158,7 @@ cc_library_ydf(
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
alwayslink = 1,
@@ -481,9 +482,12 @@ cc_test(
":loss_interface",
"//yggdrasil_decision_forests/dataset:vertical_dataset",
"//yggdrasil_decision_forests/learner/gradient_boosted_trees",
"//yggdrasil_decision_forests/learner/gradient_boosted_trees:gradient_boosted_trees_cc_proto",
"//yggdrasil_decision_forests/model:abstract_model_cc_proto",
"//yggdrasil_decision_forests/utils:status_macros",
"//yggdrasil_decision_forests/utils:test",
"//yggdrasil_decision_forests/utils:testing_macros",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_googletest//:gtest_main",
],
@@ -25,6 +25,7 @@
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "yggdrasil_decision_forests/dataset/vertical_dataset.h"
#include "yggdrasil_decision_forests/learner/abstract_learner.pb.h"
@@ -46,6 +47,12 @@ absl::Status CrossEntropyNDCGLoss::Status() const {
return absl::InvalidArgumentError(
"Cross Entropy NDCG loss is only compatible with a ranking task.");
}
if (ndcg_truncation_ < 1) {
return absl::InvalidArgumentError(
absl::StrCat("The NDCG truncation must be set to a positive integer, "
"currently found: ",
ndcg_truncation_));
}
return absl::OkStatus();
}

@@ -183,7 +190,7 @@ absl::StatusOr<LossResults> CrossEntropyNDCGLoss::Loss(
return absl::InternalError("Missing ranking index");
}
float loss_value =
-ranking_index->NDCG(predictions, weights, kNDCG5Truncation);
-ranking_index->NDCG(predictions, weights, ndcg_truncation_);
return LossResults{/*.loss =*/loss_value, /*.secondary_metrics =*/{}};
}

@@ -50,7 +50,8 @@ class CrossEntropyNDCGLoss : public AbstractLoss {
CrossEntropyNDCGLoss(
const proto::GradientBoostedTreesTrainingConfig& gbt_config,
model::proto::Task task, const dataset::proto::Column& label_column)
: AbstractLoss(gbt_config, task, label_column) {}
: AbstractLoss(gbt_config, task, label_column),
ndcg_truncation_(gbt_config.xe_ndcg().ndcg_truncation()) {}

absl::Status Status() const override;

@@ -83,6 +84,9 @@ class CrossEntropyNDCGLoss : public AbstractLoss {
const absl::Span<const float> weights,
const RankingGroupsIndices* ranking_index,
utils::concurrency::ThreadPool* thread_pool) const override;

private:
const int ndcg_truncation_;
};

REGISTER_AbstractGradientBoostedTreeLoss(CrossEntropyNDCGLoss, "XE_NDCG_MART");
@@ -134,6 +134,7 @@ absl::StatusOr<LossResults> MeanSquaredErrorLoss::Loss(
const absl::Span<const float> weights,
const RankingGroupsIndices* ranking_index,
utils::concurrency::ThreadPool* thread_pool) const {
constexpr int kNDCG5Truncation = 5;
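// Ranking evaluations from the RMSE loss keep a fixed NDCG@5 secondary
// metric (the local constant above); the new truncation options do not
// apply to it.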
float loss_value;
// The RMSE is also the loss.
ASSIGN_OR_RETURN(loss_value, metric::RMSE(labels, predictions, weights));
@@ -24,6 +24,7 @@
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "yggdrasil_decision_forests/dataset/vertical_dataset.h"
#include "yggdrasil_decision_forests/learner/abstract_learner.pb.h"
@@ -45,6 +46,12 @@ absl::Status NDCGLoss::Status() const {
return absl::InvalidArgumentError(
"NDCG loss is only compatible with a ranking task.");
}
if (ndcg_truncation_ < 1) {
return absl::InvalidArgumentError(
absl::StrCat("The NDCG truncation must be set to a positive integer, "
"currently found: ",
ndcg_truncation_));
}
return absl::OkStatus();
}

@@ -71,7 +78,7 @@ absl::Status NDCGLoss::UpdateGradients(
std::vector<float>& hessian_data = *(*gradients)[0].hessian;
DCHECK_EQ(gradient_data.size(), hessian_data.size());

metric::NDCGCalculator ndcg_calculator(kNDCG5Truncation);
metric::NDCGCalculator ndcg_calculator(ndcg_truncation_);

const float lambda_loss = gbt_config_.lambda_loss();
const float lambda_loss_squared = lambda_loss * lambda_loss;
@@ -97,7 +104,7 @@
// i.e. ground truth.
float utility_norm_factor = 1.;
if (!gbt_config_.lambda_mart_ndcg().gradient_use_non_normalized_dcg()) {
const int max_rank = std::min(kNDCG5Truncation, group_size);
const int max_rank = std::min(ndcg_truncation_, group_size);
float max_ndcg = 0;
for (int rank = 0; rank < max_rank; rank++) {
max_ndcg += ndcg_calculator.Term(group.items[rank].relevance, rank);
@@ -144,11 +151,11 @@

// "delta_utility" corresponds to "Z_{i,j}" in the paper.
float delta_utility = 0;
if (item_1_idx < kNDCG5Truncation) {
if (item_1_idx < ndcg_truncation_) {
delta_utility += ndcg_calculator.Term(relevance_2, item_1_idx) -
ndcg_calculator.Term(relevance_1, item_1_idx);
}
if (item_2_idx < kNDCG5Truncation) {
if (item_2_idx < ndcg_truncation_) {
delta_utility += ndcg_calculator.Term(relevance_1, item_2_idx) -
ndcg_calculator.Term(relevance_2, item_2_idx);
}
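The two bounds checks above are where the truncation takes effect in the LambdaMART pair update: swapping two documents only changes NDCG@k through positions inside the top k. A self-contained sketch of the quantity being computed (the Term formula is inferred from the test expectations; an illustration, not the YDF implementation):

#include <cmath>

// Exponential NDCG gain of one document at a 0-based rank:
// (2^rel - 1) / log2(rank + 2).
double Term(double relevance, int rank) {
  return (std::exp2(relevance) - 1.0) / std::log2(rank + 2.0);
}

// "Z_{i,j}" in the LambdaMART paper: the (unnormalized) change in NDCG@k
// when the documents at 0-based ranks r1 and r2 swap places. Ranks at or
// beyond the truncation k contribute nothing.
double SwapDeltaNDCG(double rel1, double rel2, int r1, int r2, int k) {
  double delta = 0.0;
  if (r1 < k) delta += Term(rel2, r1) - Term(rel1, r1);
  if (r2 < k) delta += Term(rel1, r2) - Term(rel2, r2);
  return delta;
}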
@@ -193,7 +200,7 @@
}

std::vector<std::string> NDCGLoss::SecondaryMetricNames() const {
return {"NDCG@5"};
return {absl::StrCat("NDCG@", ndcg_truncation_)};
}

absl::StatusOr<LossResults> NDCGLoss::Loss(
@@ -207,7 +214,7 @@
}

const float ndcg =
ranking_index->NDCG(predictions, weights, kNDCG5Truncation);
ranking_index->NDCG(predictions, weights, ndcg_truncation_);
return LossResults{/*.loss =*/-ndcg, /*.secondary_metrics =*/{ndcg}};
}

@@ -49,7 +49,8 @@

NDCGLoss(const proto::GradientBoostedTreesTrainingConfig& gbt_config,
model::proto::Task task, const dataset::proto::Column& label_column)
: AbstractLoss(gbt_config, task, label_column) {}
: AbstractLoss(gbt_config, task, label_column),
ndcg_truncation_(gbt_config.ndcg_loss_options().ndcg_truncation()) {}

absl::Status Status() const override;

@@ -82,6 +83,9 @@ class NDCGLoss : public AbstractLoss {
const absl::Span<const float> weights,
const RankingGroupsIndices* ranking_index,
utils::concurrency::ThreadPool* thread_pool) const override;

private:
const int ndcg_truncation_;
};

// LAMBDA_MART_NDCG5 also creates this loss.
@@ -15,14 +15,19 @@

#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_ndcg.h"

#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "yggdrasil_decision_forests/dataset/vertical_dataset.h"
#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.h"
#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.pb.h"
#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_cross_entropy_ndcg.h"
#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_interface.h"
#include "yggdrasil_decision_forests/model/abstract_model.pb.h"
#include "yggdrasil_decision_forests/utils/status_macros.h"
#include "yggdrasil_decision_forests/utils/test.h"
#include "yggdrasil_decision_forests/utils/testing_macros.h"

@@ -216,6 +221,46 @@ TEST_P(NDCGLossTest, ComputeRankingLossPerfectlyWrongPredictions) {
}
}

TEST_P(NDCGLossTest, ComputeRankingLossPerfectlyWrongPredictionsTruncation1) {
ASSERT_OK_AND_ASSIGN(const dataset::VerticalDataset dataset,
CreateToyDataset());
const bool weighted = GetParam();
std::vector<float> weights;
if (weighted) {
weights = {1.f, 3.f, 1.f, 3.f};
}

// Perfectly wrong predictions.
std::vector<float> predictions = {2., 2., 1., 1.};
proto::GradientBoostedTreesTrainingConfig gbt_config;
gbt_config.mutable_ndcg_loss_options()->set_ndcg_truncation(1);
const NDCGLoss loss_imp(gbt_config, model::proto::Task::RANKING,
dataset.data_spec().columns(0));
RankingGroupsIndices index;
EXPECT_OK(index.Initialize(dataset, 0, 1));
ASSERT_OK_AND_ASSIGN(
LossResults loss_results,
loss_imp.Loss(dataset,
/* label_col_idx= */ 0, predictions, weights, &index));
if (weighted) {
// With truncation 1 only the top-ranked document counts and the rank
// discount cancels between DCG and IDCG:
// (1*(2^1-1)/(2^3-1) + 3*(2^2-1)/(2^4-1)) / (1+3) = (1/7 + 9/15) / 4
EXPECT_NEAR(loss_results.loss, -(1. / 7. + 9. / 15.) / 4., kTestPrecision);
EXPECT_THAT(
loss_results.secondary_metrics,
ElementsAre(FloatNear((1. / 7. + 9. / 15.) / 4., kTestPrecision)));
} else {
// ((2^1-1)/(2^3-1) + (2^2-1)/(2^4-1)) / 2 = (1/7 + 3/15) / 2
EXPECT_NEAR(loss_results.loss, -(1. / 7. + 3. / 15.) / 2., kTestPrecision);
EXPECT_THAT(
loss_results.secondary_metrics,
ElementsAre(FloatNear((1. / 7. + 3. / 15.) / 2., kTestPrecision)));
}
}
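A compact worked check of the two expectations, assuming (as the assertions imply) group relevances {1, 3} and {2, 4} with group weights 1 and 3; with perfectly wrong predictions, the lower-relevance document of each group lands at rank 1:

NDCG@1({1,3}) = (2^1 - 1) / (2^3 - 1) = 1/7
NDCG@1({2,4}) = (2^2 - 1) / (2^4 - 1) = 3/15
unweighted loss = -(1/7 + 3/15) / 2          ~ -0.1714
weighted loss   = -(1 * 1/7 + 3 * 3/15) / 4  ~ -0.1857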

TEST_P(NDCGLossTest, ComputeRankingLossPerfectPredictions) {
ASSERT_OK_AND_ASSIGN(const dataset::VerticalDataset dataset,
CreateToyDataset());
Expand Down Expand Up @@ -271,6 +316,24 @@ TEST(NDCGLossTest, SecondaryMetricName) {
EXPECT_THAT(loss_imp.SecondaryMetricNames(), ElementsAre("NDCG@5"));
}

TEST(NDCGLossTest, SecondaryMetricNameTruncation10) {
ASSERT_OK_AND_ASSIGN(const auto dataset, CreateToyDataset());
proto::GradientBoostedTreesTrainingConfig gbt_config;
gbt_config.mutable_ndcg_loss_options()->set_ndcg_truncation(10);
const auto loss_imp = NDCGLoss(gbt_config, model::proto::Task::RANKING,
dataset.data_spec().columns(0));
EXPECT_THAT(loss_imp.SecondaryMetricNames(), ElementsAre("NDCG@10"));
}

TEST(NDCGLossTest, InvalidTruncation) {
ASSERT_OK_AND_ASSIGN(const auto dataset, CreateToyDataset());
proto::GradientBoostedTreesTrainingConfig gbt_config;
gbt_config.mutable_ndcg_loss_options()->set_ndcg_truncation(0);
const auto loss_imp = NDCGLoss(gbt_config, model::proto::Task::RANKING,
dataset.data_spec().columns(0));
EXPECT_EQ(loss_imp.Status().code(), absl::StatusCode::kInvalidArgument);
}

INSTANTIATE_TEST_SUITE_P(NDCGLossTestSuite, NDCGLossTest, testing::Bool());

} // namespace