[YDF] Rename LAMBDA_MART_NDCG5 to LAMBDA_MART_NDCG
We want to make the truncation parameter configurable. Renaming the loss is the first step.

PiperOrigin-RevId: 675162089
rstz authored and copybara-github committed Sep 16, 2024
1 parent b192787 commit d5630e6
Showing 15 changed files with 60 additions and 14 deletions.
@@ -2106,11 +2106,12 @@ GradientBoostedTreesLearner::GetGenericHyperParameterSpecification() const {
- `SQUARED_ERROR`: Least square loss. Only valid for regression.
- `POISSON`: Poisson log likelihood loss. Mainly used for counting problems. Only valid for regression.
- `MULTINOMIAL_LOG_LIKELIHOOD`: Multinomial log likelihood i.e. cross-entropy. Only valid for binary or multi-class classification.
- `LAMBDA_MART_NDCG5`: LambdaMART with NDCG5.
- `LAMBDA_MART_NDCG`: LambdaMART with NDCG@5.
- `XE_NDCG_MART`: Cross Entropy Loss NDCG. See arxiv.org/abs/1911.09798.
- `BINARY_FOCAL_LOSS`: Focal loss. Only valid for binary classification. See https://arxiv.org/pdf/1708.02002.pdf.
- `POISSON`: Poisson log likelihood. Only valid for regression.
- `MEAN_AVERAGE_ERROR`: Mean average error a.k.a. MAE.
- `LAMBDA_MART_NDCG5`: DEPRECATED, use LAMBDA_MART_NDCG. LambdaMART with NDCG@5.
)");
}

@@ -2469,7 +2470,7 @@ absl::StatusOr<proto::Loss> DefaultLoss(

if (task == model::proto::Task::RANKING &&
label_spec.type() == dataset::proto::ColumnType::NUMERICAL) {
return proto::Loss::LAMBDA_MART_NDCG5;
return proto::Loss::LAMBDA_MART_NDCG;
}

return absl::InvalidArgumentError(
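For context, here is a minimal sketch of selecting the renamed loss through the training configuration, mirroring the test added later in this commit. The include paths, namespace alias, learner name string, and helper function are assumptions for illustration, not part of the diff:

#include "yggdrasil_decision_forests/learner/abstract_learner.pb.h"
#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.pb.h"

namespace ydf = yggdrasil_decision_forests;

// Hypothetical helper: builds a GBT ranking training config with the renamed loss.
ydf::model::proto::TrainingConfig MakeRankingConfig() {
  ydf::model::proto::TrainingConfig train_config;
  train_config.set_learner("GRADIENT_BOOSTED_TREES");
  train_config.set_task(ydf::model::proto::Task::RANKING);
  auto* gbt_config = train_config.MutableExtension(
      ydf::model::gradient_boosted_trees::proto::gradient_boosted_trees_config);
  // New name; LAMBDA_MART_NDCG5 is still accepted but deprecated (see below).
  gbt_config->set_loss(
      ydf::model::gradient_boosted_trees::proto::LAMBDA_MART_NDCG);
  return train_config;
}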
@@ -225,7 +225,7 @@ message GradientBoostedTreesTrainingConfig {
// of models trained with and without l2 regularization can be compared.
//
// Used for the following losses: BINOMIAL_LOG_LIKELIHOOD, SQUARED_ERROR,
// MULTINOMIAL_LOG_LIKELIHOOD, LAMBDA_MART_NDCG5, or if use_hessian_gain is
// MULTINOMIAL_LOG_LIKELIHOOD, LAMBDA_MART_NDCG, or if use_hessian_gain is
// true.
//
// Note: In the case of RMSE loss for regression, the L2 regularization play
@@ -237,7 +237,7 @@ message GradientBoostedTreesTrainingConfig {

// L1 regularization on the tree predictions i.e. on the value of the leaf.
//
// Used for the following losses: LAMBDA_MART_NDCG5, or if use_hessian_gain is
// Used for the following losses: LAMBDA_MART_NDCG, or if use_hessian_gain is
// true.
optional float l1_regularization = 19 [default = 0.0];

@@ -256,7 +256,7 @@ message GradientBoostedTreesTrainingConfig {
// beneficial.
//
// Currently only used for the losses:
// - LAMBDA_MART_NDCG5
// - LAMBDA_MART_NDCG
optional float lambda_loss = 14 [default = 1.0];

// How is the forest of tree built. Defaults to "mart".
@@ -298,7 +298,7 @@ message GradientBoostedTreesTrainingConfig {
// the splits to minimize the variance of "gradient / hessian".
//
// Hessian gain is available for the losses: BINOMIAL_LOG_LIKELIHOOD,
// SQUARED_ERROR, MULTINOMIAL_LOG_LIKELIHOOD, LAMBDA_MART_NDCG5.
// SQUARED_ERROR, MULTINOMIAL_LOG_LIKELIHOOD, LAMBDA_MART_NDCG.
optional bool use_hessian_gain = 20 [default = false];

// Minimum value of the sum of the hessians in the leafs. Splits that would
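The comments above spell out which losses use l1_regularization, lambda_loss, and use_hessian_gain. As a usage sketch only, continuing the config example above: setter names follow standard proto C++ codegen, the numeric values are illustrative, and l2_regularization itself is not shown in the hunk but is the field its comment documents.

// Sketch: tune the fields documented above for a LAMBDA_MART_NDCG config.
void TuneRankingParameters(ydf::model::proto::TrainingConfig& train_config) {
  auto* gbt_config = train_config.MutableExtension(
      ydf::model::gradient_boosted_trees::proto::gradient_boosted_trees_config);
  gbt_config->set_l2_regularization(0.01f);  // Applied for LAMBDA_MART_NDCG.
  gbt_config->set_l1_regularization(0.0f);   // Proto default.
  gbt_config->set_lambda_loss(1.0f);         // Only read by LAMBDA_MART_NDCG.
  gbt_config->set_use_hessian_gain(true);    // Supported for LAMBDA_MART_NDCG.
}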
@@ -2035,6 +2035,16 @@ TEST(GradientBoostedTrees, PredefinedHyperParametersRanking) {
absl::nullopt);
}

TEST(GradientBoostedTrees, RankingDeprecatedLoss) {
model::proto::TrainingConfig train_config;
train_config.set_learner(GradientBoostedTreesLearner::kRegisteredName);
auto* gbt_config =
train_config.MutableExtension(proto::gradient_boosted_trees_config);
gbt_config->set_loss(proto::LAMBDA_MART_NDCG5);
utils::TestPredefinedHyperParametersRankingDataset(train_config, 2,
absl::nullopt);
}

TEST_F(GradientBoostedTreesOnAdult, InterruptAndResumeTraining) {
// Train a model for a few seconds, interrupt its training, and resume it.

@@ -263,6 +263,7 @@ cc_library_ydf(
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
alwayslink = 1,
@@ -84,8 +84,8 @@ class NDCGLoss : public AbstractLoss {
utils::concurrency::ThreadPool* thread_pool) const override;
};

REGISTER_AbstractGradientBoostedTreeLoss(NDCGLoss, "LAMBDA_MART_NDCG5");

// LAMBDA_MART_NDCG5 also creates this loss.
REGISTER_AbstractGradientBoostedTreeLoss(NDCGLoss, "LAMBDA_MART_NDCG");

} // namespace gradient_boosted_trees
} // namespace model
@@ -77,6 +77,9 @@ absl::StatusOr<std::unique_ptr<AbstractLoss>> CreateLoss(
}

auto loss_key = proto::Loss_Name(loss);
if (loss == proto::LAMBDA_MART_NDCG5) {
loss_key = "LAMBDA_MART_NDCG";
}
ASSIGN_OR_RETURN(auto loss_imp, AbstractLossRegisterer::Create(
loss_key, config, task, label_column));
RETURN_IF_ERROR(loss_imp->Status());
@@ -61,6 +61,10 @@ TEST(LossLibrary, CanonicalLosses) {
model::proto::Task::RANKING, numerical_label_column,
config));

EXPECT_OK(CreateLoss(proto::Loss::LAMBDA_MART_NDCG,
model::proto::Task::RANKING, numerical_label_column,
config));

EXPECT_OK(CreateLoss(proto::Loss::XE_NDCG_MART, model::proto::Task::RANKING,
numerical_label_column, config));

@@ -119,6 +123,11 @@ TEST(LossLibrary, CustomLosses) {
CustomRegressionLossFunctions{})
.ok());

EXPECT_FALSE(CreateLoss(proto::Loss::LAMBDA_MART_NDCG,
model::proto::Task::RANKING, numerical_label_column,
config, CustomBinaryClassificationLossFunctions{})
.ok());

EXPECT_FALSE(CreateLoss(proto::Loss::LAMBDA_MART_NDCG5,
model::proto::Task::RANKING, numerical_label_column,
config, CustomBinaryClassificationLossFunctions{})
@@ -129,6 +138,11 @@
config, CustomBinaryClassificationLossFunctions{})
.ok());

EXPECT_FALSE(CreateLoss(proto::Loss::LAMBDA_MART_NDCG,
model::proto::Task::RANKING, numerical_label_column,
config, CustomRegressionLossFunctions{})
.ok());

EXPECT_FALSE(CreateLoss(proto::Loss::LAMBDA_MART_NDCG5,
model::proto::Task::RANKING, numerical_label_column,
config, CustomRegressionLossFunctions{})
@@ -31,6 +31,7 @@
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
@@ -400,6 +401,7 @@ void GradientBoostedTreesModel::Predict(
LOG(FATAL) << "Non supported task";
}
} break;
case proto::Loss::LAMBDA_MART_NDCG:
case proto::Loss::LAMBDA_MART_NDCG5:
case proto::Loss::XE_NDCG_MART: {
double accumulator = initial_predictions_[0];
@@ -493,6 +495,7 @@ void GradientBoostedTreesModel::Predict(
});
prediction->mutable_regression()->set_value(accumulator);
} break;
case proto::Loss::LAMBDA_MART_NDCG:
case proto::Loss::LAMBDA_MART_NDCG5:
case proto::Loss::XE_NDCG_MART: {
double accumulator = initial_predictions_[0];
@@ -58,8 +58,9 @@ enum Loss {
SQUARED_ERROR = 2;
// Multinomial log likelihood i.e. cross-entropy.
MULTINOMIAL_LOG_LIKELIHOOD = 3;
// DEPRECATED: Use LAMBDA_MART_NDCG.
// LambdaMART with NDCG5
LAMBDA_MART_NDCG5 = 4;
LAMBDA_MART_NDCG5 = 4 [deprecated = true];
// XE_NDCG_MART [arxiv.org/abs/1911.09798]
XE_NDCG_MART = 5;
// EXPERIMENTAl. Focal loss. Only valid for binary classification.
@@ -69,6 +70,8 @@ enum Loss {
POISSON = 7;
// Mean average error (MAE).
MEAN_AVERAGE_ERROR = 8;
// LambdaMART with NDCG@5.
LAMBDA_MART_NDCG = 9;
}

// Log of the training. This proto is generated during the training of the

Some generated files are not rendered by default.

@@ -475,8 +475,9 @@ func NewRankingGBDTGenericEngine(model *gbt.Model, compatibility example.Compati
return nil, fmt.Errorf("Invalid initial predictions")
}
if model.GbtHeader.GetLoss() != gbt_pb.Loss_LAMBDA_MART_NDCG5 &&
model.GbtHeader.GetLoss() != gbt_pb.Loss_LAMBDA_MART_NDCG &&
model.GbtHeader.GetLoss() != gbt_pb.Loss_XE_NDCG_MART {
return nil, fmt.Errorf("Incompatible loss. Expecting squared error")
return nil, fmt.Errorf("Incompatible loss. Expecting ranking loss")
}

engine, err := newOneDimensionEngine(activationIdentity,
@@ -126,6 +126,7 @@ def activation(self) -> custom_loss.Activation:
return custom_loss.Activation.SOFTMAX
elif loss in [
gradient_boosted_trees_pb2.Loss.SQUARED_ERROR,
gradient_boosted_trees_pb2.Loss.LAMBDA_MART_NDCG,
gradient_boosted_trees_pb2.Loss.LAMBDA_MART_NDCG5,
gradient_boosted_trees_pb2.Loss.XE_NDCG_MART,
gradient_boosted_trees_pb2.Loss.POISSON,
@@ -338,6 +338,7 @@ absl::Status GenericToSpecializedModel(
GradientBoostedTreesBinaryRegressiveModel* dst) {
if (src.loss() != Loss::BINOMIAL_LOG_LIKELIHOOD &&
src.loss() != Loss::SQUARED_ERROR &&
src.loss() != Loss::LAMBDA_MART_NDCG &&
src.loss() != Loss::LAMBDA_MART_NDCG5 &&
src.loss() != Loss::XE_NDCG_MART) {
return absl::InvalidArgumentError(
@@ -944,7 +944,8 @@ absl::Status GenericToSpecializedModel(
absl::Status GenericToSpecializedModel(
const GradientBoostedTreesModel& src,
GradientBoostedTreesRankingNumericalOnly* dst) {
if (src.loss() != Loss::LAMBDA_MART_NDCG5 ||
if ((src.loss() != Loss::LAMBDA_MART_NDCG5 &&
src.loss() != Loss::LAMBDA_MART_NDCG) ||
src.initial_predictions().size() != 1) {
return absl::InvalidArgumentError("The GBT is not trained for ranking.");
}
@@ -964,7 +965,8 @@ absl::Status GenericToSpecializedModel(
absl::Status GenericToSpecializedModel(
const GradientBoostedTreesModel& src,
GradientBoostedTreesRankingNumericalAndCategorical* dst) {
if (src.loss() != Loss::LAMBDA_MART_NDCG5 ||
if ((src.loss() != Loss::LAMBDA_MART_NDCG5 &&
src.loss() != Loss::LAMBDA_MART_NDCG) ||
src.initial_predictions().size() != 1) {
return absl::InvalidArgumentError("The GBT is not trained for ranking.");
}
@@ -1168,7 +1170,8 @@ absl::Status GenericToSpecializedModel(const GradientBoostedTreesModel& src,
template <>
absl::Status GenericToSpecializedModel(const GradientBoostedTreesModel& src,
GradientBoostedTreesRanking* dst) {
if (src.loss() != Loss::LAMBDA_MART_NDCG5 ||
if ((src.loss() != Loss::LAMBDA_MART_NDCG5 &&
src.loss() != Loss::LAMBDA_MART_NDCG) ||
src.initial_predictions().size() != 1) {
return absl::InvalidArgumentError(
"The Gradient Boosted Tree is not trained for ranking.");
@@ -968,6 +968,7 @@ absl::Status GenericToSpecializedModel(
const model::gradient_boosted_trees::GradientBoostedTreesModel& src,
GradientBoostedTreesRankingQuickScorerExtended* dst) {
if (src.loss() != Loss::LAMBDA_MART_NDCG5 &&
src.loss() != Loss::LAMBDA_MART_NDCG &&
src.loss() != Loss::XE_NDCG_MART) {
return absl::InvalidArgumentError(
"The GBDT is not trained for ranking with ranking loss.");
