From 8b04210357a8089f769922c6bb67f45b0544251f Mon Sep 17 00:00:00 2001 From: Richard Stotz Date: Mon, 16 Sep 2024 09:12:08 -0700 Subject: [PATCH] [YDF] Allow configuring the truncation of NDCG losses PiperOrigin-RevId: 675173198 --- CHANGELOG.md | 3 + documentation/public/docs/hyperparameters.md | 195 ++---------------- .../gradient_boosted_trees.cc | 40 ++++ .../gradient_boosted_trees.h | 2 + .../gradient_boosted_trees.proto | 27 ++- .../learner/gradient_boosted_trees/loss/BUILD | 4 + .../loss/loss_imp_cross_entropy_ndcg.cc | 9 +- .../loss/loss_imp_cross_entropy_ndcg.h | 6 +- .../loss/loss_imp_mean_square_error.cc | 1 + .../loss/loss_imp_ndcg.cc | 19 +- .../loss/loss_imp_ndcg.h | 6 +- .../loss/loss_imp_ndcg_test.cc | 63 ++++++ .../gradient_boosted_trees/loss/loss_utils.h | 2 - .../gradient_boosted_trees.cc | 2 +- .../gradient_boosted_trees.proto | 2 +- .../port/python/CHANGELOG.md | 3 + .../specialized_learners_pre_generated.py | 14 +- .../port/python/ydf/model/generic_model.py | 5 +- 18 files changed, 206 insertions(+), 197 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2888b422..a192ddd9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ Changelog under `yggdrasil_decision_forests/port/python/CHANGELOG.md`. ### Features - Speed-up training of GBT models by ~10% +- Rename LAMBDA_MART_NDCG5 to LAMBDA_MART_NDCG. The old name is deprecated + but can still be used. +- Allow configuring the truncation of NDCG losses. ## 1.10.0 - 2024-08-21 diff --git a/documentation/public/docs/hyperparameters.md b/documentation/public/docs/hyperparameters.md index 8c0075ed..368529f5 100644 --- a/documentation/public/docs/hyperparameters.md +++ b/documentation/public/docs/hyperparameters.md @@ -133,6 +133,13 @@ reasonable time. of the training using the validation dataset. Enabling this feature can increase the training time significantly. +#### [cross_entropy_ndcg_truncation](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.proto) + +- **Type:** Integer **Default:** 5 **Possible values:** min:1 + +- Truncation of the cross-entropy NDCG loss. Only used with cross-entropy NDCG + loss i.e. `loss="XE_NDCG_MART"` + #### [dart_dropout](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.proto) - **Type:** Real **Default:** 0.01 **Possible values:** min:0 max:1 @@ -285,9 +292,9 @@ reasonable time. - **Type:** Categorical **Default:** DEFAULT **Possible values:** DEFAULT, BINOMIAL_LOG_LIKELIHOOD, SQUARED_ERROR, MULTINOMIAL_LOG_LIKELIHOOD, LAMBDA_MART_NDCG5, XE_NDCG_MART, BINARY_FOCAL_LOSS, POISSON, - MEAN_AVERAGE_ERROR + MEAN_AVERAGE_ERROR, LAMBDA_MART_NDCG -- The loss optimized by the model. If not specified (DEFAULT) the loss is selected automatically according to the \"task\" and label statistics. For example, if task=CLASSIFICATION and the label has two possible values, the loss will be set to BINOMIAL_LOG_LIKELIHOOD. Possible values are:
- `DEFAULT`: Select the loss automatically according to the task and label statistics.
- `BINOMIAL_LOG_LIKELIHOOD`: Binomial log likelihood. Only valid for binary classification.
- `SQUARED_ERROR`: Least square loss. Only valid for regression.
- `POISSON`: Poisson log likelihood loss. Mainly used for counting problems. Only valid for regression.
- `MULTINOMIAL_LOG_LIKELIHOOD`: Multinomial log likelihood i.e. cross-entropy. Only valid for binary or multi-class classification.
- `LAMBDA_MART_NDCG5`: LambdaMART with NDCG5.
- `XE_NDCG_MART`: Cross Entropy Loss NDCG. See arxiv.org/abs/1911.09798.
- `BINARY_FOCAL_LOSS`: Focal loss. Only valid for binary classification. See https://arxiv.org/pdf/1708.02002.pdf.
- `POISSON`: Poisson log likelihood. Only valid for regression.
- `MEAN_AVERAGE_ERROR`: Mean average error a.k.a. MAE.
+- The loss optimized by the model. If not specified (DEFAULT) the loss is selected automatically according to the \"task\" and label statistics. For example, if task=CLASSIFICATION and the label has two possible values, the loss will be set to BINOMIAL_LOG_LIKELIHOOD. Possible values are:
- `DEFAULT`: Select the loss automatically according to the task and label statistics.
- `BINOMIAL_LOG_LIKELIHOOD`: Binomial log likelihood. Only valid for binary classification.
- `SQUARED_ERROR`: Least square loss. Only valid for regression.
- `POISSON`: Poisson log likelihood loss. Mainly used for counting problems. Only valid for regression.
- `MULTINOMIAL_LOG_LIKELIHOOD`: Multinomial log likelihood i.e. cross-entropy. Only valid for binary or multi-class classification.
- `LAMBDA_MART_NDCG`: LambdaMART with NDCG@5.
- `XE_NDCG_MART`: Cross Entropy Loss NDCG. See arxiv.org/abs/1911.09798.
- `BINARY_FOCAL_LOSS`: Focal loss. Only valid for binary classification. See https://arxiv.org/pdf/1708.02002.pdf.
- `POISSON`: Poisson log likelihood. Only valid for regression.
- `MEAN_AVERAGE_ERROR`: Mean average error a.k.a. MAE.
- `LAMBDA_MART_NDCG5`: DEPRECATED, use LAMBDA_MART_NDCG. LambdaMART with NDCG@5.
#### [max_depth](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) @@ -354,6 +361,13 @@ reasonable time. - Method used to handle missing attribute values.
- `GLOBAL_IMPUTATION`: Missing attribute values are imputed, with the mean (in case of numerical attribute) or the most-frequent-item (in case of categorical attribute) computed on the entire dataset (i.e. the information contained in the data spec).
- `LOCAL_IMPUTATION`: Missing attribute values are imputed with the mean (numerical attribute) or most-frequent-item (in the case of categorical attribute) evaluated on the training examples in the current node.
- `RANDOM_LOCAL_IMPUTATION`: Missing attribute values are imputed from randomly sampled values from the training examples in the current node. This method was proposed by Clinic et al. in "Random Survival Forests" (https://projecteuclid.org/download/pdfview_1/euclid.aoas/1223908043). +#### [ndcg_truncation](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.proto) + +- **Type:** Integer **Default:** 5 **Possible values:** min:1 + +- Truncation of the NDCG loss. Only used with NDCG loss i.e. + `loss="LAMBDA_MART_NDCG"` + #### [num_candidate_attributes](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - **Type:** Integer **Default:** -1 **Possible values:** min:-1 @@ -1357,99 +1371,6 @@ The hyper-parameter protobuffers are used with the C++ and CLI APIs. ### Hyper-parameters -#### [allow_na_conditions](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Categorical **Default:** false **Possible values:** true, false - -- If true, the tree training evaluates conditions of the type `X is NA` i.e. - `X is missing`. - -#### [categorical_algorithm](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Categorical **Default:** CART **Possible values:** CART, ONE_HOT, - RANDOM - -- How to learn splits on categorical attributes.
- `CART`: CART algorithm. Find categorical splits of the form "value \in mask". The solution is exact for binary classification, regression and ranking. It is approximated for multi-class classification. This is a good first algorithm to use. In case of overfitting (very small dataset, large dictionary), the "random" algorithm is a good alternative.
- `ONE_HOT`: One-hot encoding. Find the optimal categorical split of the form "attribute == param". This method is similar (but more efficient) than converting converting each possible categorical value into a boolean feature. This method is available for comparison purpose and generally performs worse than other alternatives.
- `RANDOM`: Best splits among a set of random candidate. Find the a categorical split of the form "value \in mask" using a random search. This solution can be seen as an approximation of the CART algorithm. This method is a strong alternative to CART. This algorithm is inspired from section "5.1 Categorical Variables" of "Random Forest", 2001. - -#### [categorical_set_split_greedy_sampling](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Real **Default:** 0.1 **Possible values:** min:0 max:1 - -- For categorical set splits e.g. texts. Probability for a categorical value - to be a candidate for the positive set. The sampling is applied once per - node (i.e. not at every step of the greedy optimization). - -#### [categorical_set_split_max_num_items](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Integer **Default:** -1 **Possible values:** min:-1 - -- For categorical set splits e.g. texts. Maximum number of items (prior to the - sampling). If more items are available, the least frequent items are - ignored. Changing this value is similar to change the "max_vocab_count" - before loading the dataset, with the following exception: With - `max_vocab_count`, all the remaining items are grouped in a special - Out-of-vocabulary item. With `max_num_items`, this is not the case. - -#### [categorical_set_split_min_item_frequency](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Integer **Default:** 1 **Possible values:** min:1 - -- For categorical set splits e.g. texts. Minimum number of occurrences of an - item to be considered. - -#### [growing_strategy](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Categorical **Default:** LOCAL **Possible values:** LOCAL, - BEST_FIRST_GLOBAL - -- How to grow the tree.
- `LOCAL`: Each node is split independently of the other nodes. In other words, as long as a node satisfy the splits "constraints (e.g. maximum depth, minimum number of observations), the node will be split. This is the "classical" way to grow decision trees.
- `BEST_FIRST_GLOBAL`: The node with the best loss reduction among all the nodes of the tree is selected for splitting. This method is also called "best first" or "leaf-wise growth". See "Best-first decision tree learning", Shi and "Additive logistic regression : A statistical view of boosting", Friedman for more details. - -#### [honest](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Categorical **Default:** false **Possible values:** true, false - -- In honest trees, different training examples are used to infer the structure - and the leaf values. This regularization technique trades examples for bias - estimates. It might increase or reduce the quality of the model. See - "Generalized Random Forests", Athey et al. In this paper, Honest trees are - trained with the Random Forest algorithm with a sampling without - replacement. - -#### [honest_fixed_separation](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Categorical **Default:** false **Possible values:** true, false - -- For honest trees only i.e. honest=true. If true, a new random separation is - generated for each tree. If false, the same separation is used for all the - trees (e.g., in Gradient Boosted Trees containing multiple trees). - -#### [honest_ratio_leaf_examples](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Real **Default:** 0.5 **Possible values:** min:0 max:1 - -- For honest trees only i.e. honest=true. Ratio of examples used to set the - leaf values. - -#### [in_split_min_examples_check](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Categorical **Default:** true **Possible values:** true, false - -- Whether to check the `min_examples` constraint in the split search (i.e. - splits leading to one child having less than `min_examples` examples are - considered invalid) or before the split search (i.e. a node can be derived - only if it contains more than `min_examples` examples). If false, there can - be nodes with less than `min_examples` training examples. - -#### [keep_non_leaf_label_distribution](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Categorical **Default:** true **Possible values:** true, false - -- Whether to keep the node value (i.e. the distribution of the labels of the - training examples) of non-leaf nodes. This information is not used during - serving, however it can be used for model interpretation as well as hyper - parameter tuning. This can take lots of space, sometimes accounting for half - of the model size. - #### [max_depth](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - **Type:** Integer **Default:** -2 **Possible values:** min:-2 @@ -1459,75 +1380,12 @@ The hyper-parameter protobuffers are used with the C++ and CLI APIs. `max_depth=-2` means that the maximum depth is log2(number of sampled examples per tree) (default). -#### [max_num_nodes](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Integer **Default:** 31 **Possible values:** min:-1 - -- Maximum number of nodes in the tree. Set to -1 to disable this limit. Only - available for `growing_strategy=BEST_FIRST_GLOBAL`. - -#### [maximum_model_size_in_memory_in_bytes](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/abstract_learner.proto) - -- **Type:** Real **Default:** -1 - -- Limit the size of the model when stored in ram. Different algorithms can - enforce this limit differently. Note that when models are compiled into an - inference, the size of the inference engine is generally much smaller than - the original model. - -#### [maximum_training_duration_seconds](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/abstract_learner.proto) - -- **Type:** Real **Default:** -1 - -- Maximum training duration of the model expressed in seconds. Each learning - algorithm is free to use this parameter at it sees fit. Enabling maximum - training duration makes the model training non-deterministic. - -#### [mhld_oblique_max_num_attributes](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Integer **Default:** 4 **Possible values:** min:1 - -- For MHLD oblique splits i.e. `split_axis=MHLD_OBLIQUE`. Maximum number of - attributes in the projection. Increasing this value increases the training - time. Decreasing this value acts as a regularization. The value should be in - [2, num_numerical_features]. If the value is above the total number of - numerical features, the value is capped automatically. The value 1 is - allowed but results in ordinary (non-oblique) splits. - #### [min_examples](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - **Type:** Integer **Default:** 5 **Possible values:** min:1 - Minimum number of examples in a node. -#### [missing_value_policy](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Categorical **Default:** GLOBAL_IMPUTATION **Possible values:** - GLOBAL_IMPUTATION, LOCAL_IMPUTATION, RANDOM_LOCAL_IMPUTATION - -- Method used to handle missing attribute values.
- `GLOBAL_IMPUTATION`: Missing attribute values are imputed, with the mean (in case of numerical attribute) or the most-frequent-item (in case of categorical attribute) computed on the entire dataset (i.e. the information contained in the data spec).
- `LOCAL_IMPUTATION`: Missing attribute values are imputed with the mean (numerical attribute) or most-frequent-item (in the case of categorical attribute) evaluated on the training examples in the current node.
- `RANDOM_LOCAL_IMPUTATION`: Missing attribute values are imputed from randomly sampled values from the training examples in the current node. This method was proposed by Clinic et al. in "Random Survival Forests" (https://projecteuclid.org/download/pdfview_1/euclid.aoas/1223908043). - -#### [num_candidate_attributes](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Integer **Default:** 0 **Possible values:** min:-1 - -- Number of unique valid attributes tested for each node. An attribute is - valid if it has at least a valid split. If `num_candidate_attributes=0`, the - value is set to the classical default value for Random Forest: `sqrt(number - of input attributes)` in case of classification and - `number_of_input_attributes / 3` in case of regression. If - `num_candidate_attributes=-1`, all the attributes are tested. - -#### [num_candidate_attributes_ratio](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Real **Default:** -1 **Possible values:** min:-1 max:1 - -- Ratio of attributes tested at each node. If set, it is equivalent to - `num_candidate_attributes = number_of_input_features x - num_candidate_attributes_ratio`. The possible values are between ]0, and 1] - as well as -1. If not set or equal to -1, the `num_candidate_attributes` is - used. - #### [num_trees](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/isolation_forest/isolation_forest.proto) - **Type:** Integer **Default:** 300 **Possible values:** min:0 @@ -1553,13 +1411,6 @@ The hyper-parameter protobuffers are used with the C++ and CLI APIs. - Random seed for the training of the model. Learners are expected to be deterministic by the random seed. -#### [sorting_strategy](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Categorical **Default:** AUTO **Possible values:** IN_NODE, - PRESORT, FORCE_PRESORT, AUTO - -- How are sorted the numerical features in order to find the splits
- AUTO: Selects the most efficient method among IN_NODE, FORCE_PRESORT, and LAYER.
- IN_NODE: The features are sorted just before being used in the node. This solution is slow but consumes little amount of memory.
- FORCE_PRESORT: The features are pre-sorted at the start of the training. This solution is faster but consumes much more memory than IN_NODE.
- PRESORT: Automatically choose between FORCE_PRESORT and IN_NODE.
. - #### [sparse_oblique_normalization](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - **Type:** Categorical **Default:** NONE **Possible values:** NONE, @@ -1606,20 +1457,6 @@ The hyper-parameter protobuffers are used with the C++ and CLI APIs. maximum depth to log2(examples used per tree) unless max_depth is set explicitly. -#### [uplift_min_examples_in_treatment](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Integer **Default:** 5 **Possible values:** min:0 - -- For uplift models only. Minimum number of examples per treatment in a node. - -#### [uplift_split_score](https://github.com/google/yggdrasil-decision-forests/blob/main/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto) - -- **Type:** Categorical **Default:** KULLBACK_LEIBLER **Possible values:** - KULLBACK_LEIBLER, KL, EUCLIDEAN_DISTANCE, ED, CHI_SQUARED, CS, - CONSERVATIVE_EUCLIDEAN_DISTANCE, CED - -- For uplift models only. Splitter score i.e. score optimized by the splitters. The scores are introduced in "Decision trees for uplift modeling with single and multiple treatments", Rzepakowski et al. Notation: `p` probability / average value of the positive outcome, `q` probability / average value in the control group.
- `KULLBACK_LEIBLER` or `KL`: - p log (p/q)
- `EUCLIDEAN_DISTANCE` or `ED`: (p-q)^2
- `CHI_SQUARED` or `CS`: (p-q)^2/q
- ## HYPERPARAMETER_OPTIMIZER ### Protobuffer training configuration diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc b/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc index 5e9ee602..bf8563c2 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.cc @@ -133,6 +133,8 @@ constexpr char GradientBoostedTreesLearner::kHParamValidationIntervalInTrees[]; constexpr char GradientBoostedTreesLearner::kHParamLoss[]; constexpr char GradientBoostedTreesLearner::kHParamFocalLossGamma[]; constexpr char GradientBoostedTreesLearner::kHParamFocalLossAlpha[]; +constexpr char GradientBoostedTreesLearner::kHParamNDCGTruncation[]; +constexpr char GradientBoostedTreesLearner::kHParamXENDCGTruncation[]; using dataset::VerticalDataset; using CategoricalColumn = VerticalDataset::CategoricalColumn; @@ -1931,6 +1933,22 @@ absl::Status GradientBoostedTreesLearner::SetHyperParametersImpl( } } + { + const auto hparam = generic_hyper_params->Get(kHParamNDCGTruncation); + if (hparam.has_value()) { + gbt_config->mutable_ndcg_loss_options()->set_ndcg_truncation( + hparam.value().value().real()); + } + } + + { + const auto hparam = generic_hyper_params->Get(kHParamXENDCGTruncation); + if (hparam.has_value()) { + gbt_config->mutable_xe_ndcg()->set_ndcg_truncation( + hparam.value().value().real()); + } + } + return absl::OkStatus(); } @@ -2401,6 +2419,28 @@ For example, in the case of binary classification, the pre-link function output R"(EXPERIMENTAL. Weighting parameter for focal loss, positive samples weighted by alpha, negative samples by (1-alpha). The default 0.5 value means no active class-level weighting. Only used with focal loss i.e. `loss="BINARY_FOCAL_LOSS"`)"); } + { + auto& param = + hparam_def.mutable_fields()->operator[](kHParamNDCGTruncation); + param.mutable_integer()->set_minimum(1.f); + param.mutable_integer()->set_default_value( + gbt_config.ndcg_loss_options().ndcg_truncation()); + param.mutable_documentation()->set_proto_path(proto_path); + param.mutable_documentation()->set_description( + R"(Truncation of the NDCG loss. Only used with NDCG loss i.e. `loss="LAMBDA_MART_NDCG"`)"); + } + + { + auto& param = + hparam_def.mutable_fields()->operator[](kHParamXENDCGTruncation); + param.mutable_integer()->set_minimum(1.f); + param.mutable_integer()->set_default_value( + gbt_config.xe_ndcg().ndcg_truncation()); + param.mutable_documentation()->set_proto_path(proto_path); + param.mutable_documentation()->set_description( + R"(Truncation of the cross-entropy NDCG loss. Only used with cross-entropy NDCG loss i.e. `loss="XE_NDCG_MART"`)"); + } + RETURN_IF_ERROR(decision_tree::GetGenericHyperParameterSpecification( gbt_config.decision_tree(), &hparam_def)); return hparam_def; diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.h b/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.h index de568de6..ea05994d 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.h +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.h @@ -132,6 +132,8 @@ class GradientBoostedTreesLearner : public AbstractLearner { static constexpr char kHParamLoss[] = "loss"; static constexpr char kHParamFocalLossGamma[] = "focal_loss_gamma"; static constexpr char kHParamFocalLossAlpha[] = "focal_loss_alpha"; + static constexpr char kHParamNDCGTruncation[] = "ndcg_truncation"; + static constexpr char kHParamXENDCGTruncation[] = "cross_entropy_ndcg_truncation"; absl::StatusOr> TrainWithStatusImpl( const dataset::VerticalDataset& train_dataset, diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.proto b/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.proto index 1bcb28ee..94f413b8 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.proto +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.proto @@ -25,7 +25,7 @@ option java_outer_classname = "GradientBoostedTreesLearner"; // Training configuration for the Gradient Boosted Trees algorithm. message GradientBoostedTreesTrainingConfig { - // Next ID: 38 + // Next ID: 39 // Basic parameters. @@ -180,6 +180,7 @@ message GradientBoostedTreesTrainingConfig { LambdaMartNdcg lambda_mart_ndcg = 12; XeNdcg xe_ndcg = 26; BinaryFocalLossOptions binary_focal_loss_options = 36; + NDCGLossOptions ndcg_loss_options = 38; } message LambdaMartNdcg { @@ -199,6 +200,17 @@ message GradientBoostedTreesTrainingConfig { ONE = 2; } optional Gamma gamma = 1 [default = UNIFORM]; + + // Number of candidates considered when computing the NDCG loss. + // + // NDCG losses are usually truncated at a particular rank level (generally + // between 4 and 10), i.e. only the highly ranked documents are considered + // when computing the rank. A smaller values results in a model with + // increased emphasis on the first results of the ranking. + // + // Note that the NDCG truncation of the classic NDCG loss must be configured + // separately. + optional int32 ndcg_truncation = 2 [default = 5]; } message BinaryFocalLossOptions { @@ -217,6 +229,19 @@ message GradientBoostedTreesTrainingConfig { optional float positive_sample_coefficient = 2 [default = 0.5]; } + message NDCGLossOptions { + // Number of candidates considered when computing the NDCG loss. + // + // NDCG losses are usually truncated at a particular rank level (generally + // between 4 and 10), i.e. only the highly ranked documents are considered + // when computing the rank. A smaller values results in a model with + // increased emphasis on the first results of the ranking. + // + // Note that the NDCG truncation of the cross-entropy NDCG loss must be + // configured separately. + optional int32 ndcg_truncation = 1 [default = 5]; + } + // L2 regularization on the tree predictions i.e. on the value of the leaf. // See the equation 2 of the XGBoost paper for the definition // (https://arxiv.org/pdf/1603.02754.pdf). diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/BUILD b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/BUILD index 92f6534c..79df2a3c 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/BUILD +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/BUILD @@ -158,6 +158,7 @@ cc_library_ydf( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], alwayslink = 1, @@ -481,9 +482,12 @@ cc_test( ":loss_interface", "//yggdrasil_decision_forests/dataset:vertical_dataset", "//yggdrasil_decision_forests/learner/gradient_boosted_trees", + "//yggdrasil_decision_forests/learner/gradient_boosted_trees:gradient_boosted_trees_cc_proto", "//yggdrasil_decision_forests/model:abstract_model_cc_proto", + "//yggdrasil_decision_forests/utils:status_macros", "//yggdrasil_decision_forests/utils:test", "//yggdrasil_decision_forests/utils:testing_macros", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", ], diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_cross_entropy_ndcg.cc b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_cross_entropy_ndcg.cc index 2f97e1e6..3d0ff12d 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_cross_entropy_ndcg.cc +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_cross_entropy_ndcg.cc @@ -25,6 +25,7 @@ #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "yggdrasil_decision_forests/dataset/vertical_dataset.h" #include "yggdrasil_decision_forests/learner/abstract_learner.pb.h" @@ -46,6 +47,12 @@ absl::Status CrossEntropyNDCGLoss::Status() const { return absl::InvalidArgumentError( "Cross Entropy NDCG loss is only compatible with a ranking task."); } + if (ndcg_truncation_ < 1) { + return absl::InvalidArgumentError( + absl::StrCat("The NDCG truncation must be set to a positive integer, " + "currently found: ", + ndcg_truncation_)); + } return absl::OkStatus(); } @@ -183,7 +190,7 @@ absl::StatusOr CrossEntropyNDCGLoss::Loss( return absl::InternalError("Missing ranking index"); } float loss_value = - -ranking_index->NDCG(predictions, weights, kNDCG5Truncation); + -ranking_index->NDCG(predictions, weights, ndcg_truncation_); return LossResults{/*.loss =*/loss_value, /*.secondary_metrics =*/{}}; } diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_cross_entropy_ndcg.h b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_cross_entropy_ndcg.h index 14c6aea5..f85b88ad 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_cross_entropy_ndcg.h +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_cross_entropy_ndcg.h @@ -50,7 +50,8 @@ class CrossEntropyNDCGLoss : public AbstractLoss { CrossEntropyNDCGLoss( const proto::GradientBoostedTreesTrainingConfig& gbt_config, model::proto::Task task, const dataset::proto::Column& label_column) - : AbstractLoss(gbt_config, task, label_column) {} + : AbstractLoss(gbt_config, task, label_column), + ndcg_truncation_(gbt_config.xe_ndcg().ndcg_truncation()) {} absl::Status Status() const override; @@ -83,6 +84,9 @@ class CrossEntropyNDCGLoss : public AbstractLoss { const absl::Span weights, const RankingGroupsIndices* ranking_index, utils::concurrency::ThreadPool* thread_pool) const override; + + private: + const int ndcg_truncation_; }; REGISTER_AbstractGradientBoostedTreeLoss(CrossEntropyNDCGLoss, "XE_NDCG_MART"); diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_square_error.cc b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_square_error.cc index ccec39e5..bf7fc56c 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_square_error.cc +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_mean_square_error.cc @@ -134,6 +134,7 @@ absl::StatusOr MeanSquaredErrorLoss::Loss( const absl::Span weights, const RankingGroupsIndices* ranking_index, utils::concurrency::ThreadPool* thread_pool) const { + constexpr int kNDCG5Truncation = 5; float loss_value; // The RMSE is also the loss. ASSIGN_OR_RETURN(loss_value, metric::RMSE(labels, predictions, weights)); diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_ndcg.cc b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_ndcg.cc index b3d16191..5562398d 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_ndcg.cc +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_ndcg.cc @@ -24,6 +24,7 @@ #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "yggdrasil_decision_forests/dataset/vertical_dataset.h" #include "yggdrasil_decision_forests/learner/abstract_learner.pb.h" @@ -45,6 +46,12 @@ absl::Status NDCGLoss::Status() const { return absl::InvalidArgumentError( "NDCG loss is only compatible with a ranking task."); } + if (ndcg_truncation_ < 1) { + return absl::InvalidArgumentError( + absl::StrCat("The NDCG truncation must be set to a positive integer, " + "currently found: ", + ndcg_truncation_)); + } return absl::OkStatus(); } @@ -71,7 +78,7 @@ absl::Status NDCGLoss::UpdateGradients( std::vector& hessian_data = *(*gradients)[0].hessian; DCHECK_EQ(gradient_data.size(), hessian_data.size()); - metric::NDCGCalculator ndcg_calculator(kNDCG5Truncation); + metric::NDCGCalculator ndcg_calculator(ndcg_truncation_); const float lambda_loss = gbt_config_.lambda_loss(); const float lambda_loss_squared = lambda_loss * lambda_loss; @@ -97,7 +104,7 @@ absl::Status NDCGLoss::UpdateGradients( // i.e. ground truth. float utility_norm_factor = 1.; if (!gbt_config_.lambda_mart_ndcg().gradient_use_non_normalized_dcg()) { - const int max_rank = std::min(kNDCG5Truncation, group_size); + const int max_rank = std::min(ndcg_truncation_, group_size); float max_ndcg = 0; for (int rank = 0; rank < max_rank; rank++) { max_ndcg += ndcg_calculator.Term(group.items[rank].relevance, rank); @@ -144,11 +151,11 @@ absl::Status NDCGLoss::UpdateGradients( // "delta_utility" corresponds to "Z_{i,j}" in the paper. float delta_utility = 0; - if (item_1_idx < kNDCG5Truncation) { + if (item_1_idx < ndcg_truncation_) { delta_utility += ndcg_calculator.Term(relevance_2, item_1_idx) - ndcg_calculator.Term(relevance_1, item_1_idx); } - if (item_2_idx < kNDCG5Truncation) { + if (item_2_idx < ndcg_truncation_) { delta_utility += ndcg_calculator.Term(relevance_1, item_2_idx) - ndcg_calculator.Term(relevance_2, item_2_idx); } @@ -193,7 +200,7 @@ absl::Status NDCGLoss::UpdateGradients( } std::vector NDCGLoss::SecondaryMetricNames() const { - return {"NDCG@5"}; + return {absl::StrCat("NDCG@", ndcg_truncation_)}; } absl::StatusOr NDCGLoss::Loss( @@ -207,7 +214,7 @@ absl::StatusOr NDCGLoss::Loss( } const float ndcg = - ranking_index->NDCG(predictions, weights, kNDCG5Truncation); + ranking_index->NDCG(predictions, weights, ndcg_truncation_); return LossResults{/*.loss =*/-ndcg, /*.secondary_metrics =*/{ndcg}}; } diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_ndcg.h b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_ndcg.h index 837e1811..7767454a 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_ndcg.h +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_ndcg.h @@ -49,7 +49,8 @@ class NDCGLoss : public AbstractLoss { NDCGLoss(const proto::GradientBoostedTreesTrainingConfig& gbt_config, model::proto::Task task, const dataset::proto::Column& label_column) - : AbstractLoss(gbt_config, task, label_column) {} + : AbstractLoss(gbt_config, task, label_column), + ndcg_truncation_(gbt_config.ndcg_loss_options().ndcg_truncation()) {} absl::Status Status() const override; @@ -82,6 +83,9 @@ class NDCGLoss : public AbstractLoss { const absl::Span weights, const RankingGroupsIndices* ranking_index, utils::concurrency::ThreadPool* thread_pool) const override; + + private: + const int ndcg_truncation_; }; // LAMBDA_MART_NDCG5 also creates this loss. diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_ndcg_test.cc b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_ndcg_test.cc index a39dca03..60cd7273 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_ndcg_test.cc +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_ndcg_test.cc @@ -15,14 +15,19 @@ #include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_ndcg.h" +#include + #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "yggdrasil_decision_forests/dataset/vertical_dataset.h" #include "yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.h" +#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.pb.h" #include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_imp_cross_entropy_ndcg.h" #include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_interface.h" #include "yggdrasil_decision_forests/model/abstract_model.pb.h" +#include "yggdrasil_decision_forests/utils/status_macros.h" #include "yggdrasil_decision_forests/utils/test.h" #include "yggdrasil_decision_forests/utils/testing_macros.h" @@ -216,6 +221,46 @@ TEST_P(NDCGLossTest, ComputeRankingLossPerfectlyWrongPredictions) { } } +TEST_P(NDCGLossTest, ComputeRankingLossPerfectlyWrongPredictionsTruncation1) { + ASSERT_OK_AND_ASSIGN(const dataset::VerticalDataset dataset, + CreateToyDataset()); + const bool weighted = GetParam(); + std::vector weights; + if (weighted) { + weights = {1.f, 3.f, 1.f, 3.f}; + } + + // Perfectly wrong predictions. + std::vector predictions = {2., 2., 1., 1.}; + proto::GradientBoostedTreesTrainingConfig gbt_config; + gbt_config.mutable_ndcg_loss_options()->set_ndcg_truncation(1); + const NDCGLoss loss_imp(gbt_config, model::proto::Task::RANKING, + dataset.data_spec().columns(0)); + RankingGroupsIndices index; + EXPECT_OK(index.Initialize(dataset, 0, 1)); + ASSERT_OK_AND_ASSIGN( + LossResults loss_results, + loss_imp.Loss(dataset, + /* label_col_idx= */ 0, predictions, weights, &index)); + if (weighted) { + // R> 0.7238181 = (sum((2^c(1,3)-1)/log2(seq(2)+1)) / + // sum((2^c(3,1)-1)/log2(seq(2)+1)) + 3(sum((2^c(2,4)-1)/log2(seq(2)+1)) / + // sum((2^c(4,2)-1)/log2(seq(2)+1))) )/4 + EXPECT_NEAR(loss_results.loss, -(1. / 7. + 9. / 15.) / 4., kTestPrecision); + EXPECT_THAT( + loss_results.secondary_metrics, + ElementsAre(FloatNear((1. / 7. + 9. / 15.) / 4., kTestPrecision))); + } else { + // R> 0.7238181 = (sum((2^c(1,3)-1)/log2(seq(2)+1)) / + // sum((2^c(3,1)-1)/log2(seq(2)+1)) + sum((2^c(2,4)-1)/log2(seq(2)+1)) / + // sum((2^c(4,2)-1)/log2(seq(2)+1)) )/2 + EXPECT_NEAR(loss_results.loss, -(1. / 7. + 3. / 15.) / 2., kTestPrecision); + EXPECT_THAT( + loss_results.secondary_metrics, + ElementsAre(FloatNear((1. / 7. + 3. / 15.) / 2., kTestPrecision))); + } +} + TEST_P(NDCGLossTest, ComputeRankingLossPerfectPredictions) { ASSERT_OK_AND_ASSIGN(const dataset::VerticalDataset dataset, CreateToyDataset()); @@ -271,6 +316,24 @@ TEST(NDCGLossTest, SecondaryMetricName) { EXPECT_THAT(loss_imp.SecondaryMetricNames(), ElementsAre("NDCG@5")); } +TEST(NDCGLossTest, SecondaryMetricNameTrucation10) { + ASSERT_OK_AND_ASSIGN(const auto dataset, CreateToyDataset()); + proto::GradientBoostedTreesTrainingConfig gbt_config; + gbt_config.mutable_ndcg_loss_options()->set_ndcg_truncation(10); + const auto loss_imp = NDCGLoss(gbt_config, model::proto::Task::RANKING, + dataset.data_spec().columns(0)); + EXPECT_THAT(loss_imp.SecondaryMetricNames(), ElementsAre("NDCG@10")); +} + +TEST(NDCGLossTest, InvalidTruncation) { + ASSERT_OK_AND_ASSIGN(const auto dataset, CreateToyDataset()); + proto::GradientBoostedTreesTrainingConfig gbt_config; + gbt_config.mutable_ndcg_loss_options()->set_ndcg_truncation(0); + const auto loss_imp = NDCGLoss(gbt_config, model::proto::Task::RANKING, + dataset.data_spec().columns(0)); + EXPECT_EQ(loss_imp.Status().code(), absl::StatusCode::kInvalidArgument); +} + INSTANTIATE_TEST_SUITE_P(NDCGLossTestSuite, NDCGLossTest, testing::Bool()); } // namespace diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_utils.h b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_utils.h index bea5100e..3785a743 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_utils.h +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_utils.h @@ -47,8 +47,6 @@ namespace gradient_boosted_trees { // this maximum only triggers a stern warning. constexpr int64_t kMaximumItemsInRankingGroup = 2048; -constexpr int kNDCG5Truncation = 5; - // Index of the secondary metrics according to the type of loss. constexpr int kBinomialLossSecondaryMetricClassificationIdx = 0; diff --git a/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.cc b/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.cc index de6eb896..fb46cf9d 100644 --- a/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.cc +++ b/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.cc @@ -586,7 +586,7 @@ GradientBoostedTreesModel::ValidationEvaluation() const { validation_evaluation.mutable_regression()->set_sum_square_error( metric_value); validation_evaluation.set_count_predictions(1.f); - } else if (metric_name == "NDCG@5") { + } else if (absl::StartsWith(metric_name, "NDCG@")) { validation_evaluation.mutable_ranking()->mutable_ndcg()->set_value( metric_value); validation_evaluation.mutable_ranking()->set_ndcg_truncation(5); diff --git a/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.proto b/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.proto index 85e61ce3..c26f34f7 100644 --- a/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.proto +++ b/yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.proto @@ -70,7 +70,7 @@ enum Loss { POISSON = 7; // Mean average error (MAE). MEAN_AVERAGE_ERROR = 8; - // LambdaMART with NDCG@5. + // LambdaMART with NDCG loss. Truncation defaults to 5, configurable. LAMBDA_MART_NDCG = 9; } diff --git a/yggdrasil_decision_forests/port/python/CHANGELOG.md b/yggdrasil_decision_forests/port/python/CHANGELOG.md index 5f48f9c1..a70b7b2e 100644 --- a/yggdrasil_decision_forests/port/python/CHANGELOG.md +++ b/yggdrasil_decision_forests/port/python/CHANGELOG.md @@ -21,6 +21,9 @@ - Speed-up training of GBT models by ~10% - Add `ydf.util.read_tf_record` and `ydf.util.write_tf_record` to facilitate TF Record datasets usage. +- Rename LAMBDA_MART_NDCG5 to LAMBDA_MART_NDCG. The old name is deprecated + but can still be used. +- Allow configuring the truncation of NDCG losses. - Enable multi-threading when using `model.predict` and `model.evaluate`. - Default number of threads of `model.analyze` is equal to the number of cores. diff --git a/yggdrasil_decision_forests/port/python/ydf/learner/specialized_learners_pre_generated.py b/yggdrasil_decision_forests/port/python/ydf/learner/specialized_learners_pre_generated.py index 45f6cee8..223d7ea2 100644 --- a/yggdrasil_decision_forests/port/python/ydf/learner/specialized_learners_pre_generated.py +++ b/yggdrasil_decision_forests/port/python/ydf/learner/specialized_learners_pre_generated.py @@ -1119,6 +1119,9 @@ class GradientBoostedTreesLearner(generic_learner.GenericLearner): variable importance of the model at the end of the training using the validation dataset. Enabling this feature can increase the training time significantly. Default: False. + cross_entropy_ndcg_truncation: Truncation of the cross-entropy NDCG loss. + Only used with cross-entropy NDCG loss i.e. `loss="XE_NDCG_MART"` + Default: 5. dart_dropout: Dropout rate applied when using the DART i.e. when forest_extraction=DART. Default: None. early_stopping: Early stopping detects the overfitting of the model and @@ -1211,14 +1214,15 @@ class GradientBoostedTreesLearner(generic_learner.GenericLearner): likelihood loss. Mainly used for counting problems. Only valid for regression. - `MULTINOMIAL_LOG_LIKELIHOOD`: Multinomial log likelihood i.e. cross-entropy. Only valid for binary or multi-class classification. - - `LAMBDA_MART_NDCG5`: LambdaMART with NDCG5. - `XE_NDCG_MART`: Cross + `LAMBDA_MART_NDCG`: LambdaMART with NDCG@5. - `XE_NDCG_MART`: Cross Entropy Loss NDCG. See arxiv.org/abs/1911.09798. - `BINARY_FOCAL_LOSS`: Focal loss. Only valid for binary classification. See https://arxiv.org/pdf/1708.02002.pdf. - `POISSON`: Poisson log likelihood. Only valid for regression. - `MEAN_AVERAGE_ERROR`: Mean average error a.k.a. MAE. For custom losses, pass the loss object here. Note that when using custom losses, the link function is deactivated (aka - apply_link_function is always False). + apply_link_function is always False). - `LAMBDA_MART_NDCG5`: DEPRECATED, + use LAMBDA_MART_NDCG. LambdaMART with NDCG@5. Default: "DEFAULT". max_depth: Maximum depth of the tree. `max_depth=1` means that all trees will be roots. `max_depth=-1` means that tree depth is not restricted by @@ -1261,6 +1265,8 @@ class GradientBoostedTreesLearner(generic_learner.GenericLearner): et al. in "Random Survival Forests" (https://projecteuclid.org/download/pdfview_1/euclid.aoas/1223908043). Default: "GLOBAL_IMPUTATION". + ndcg_truncation: Truncation of the NDCG loss. Only used with NDCG loss i.e. + `loss="LAMBDA_MART_NDCG"` Default: 5. num_candidate_attributes: Number of unique valid attributes tested for each node. An attribute is valid if it has at least a valid split. If `num_candidate_attributes=0`, the value is set to the classical default @@ -1452,6 +1458,7 @@ def __init__( categorical_set_split_max_num_items: int = -1, categorical_set_split_min_item_frequency: int = 1, compute_permutation_variable_importance: bool = False, + cross_entropy_ndcg_truncation: int = 5, dart_dropout: Optional[float] = None, early_stopping: str = "LOSS_INCREASE", early_stopping_initial_iteration: int = 10, @@ -1480,6 +1487,7 @@ def __init__( mhld_oblique_sample_attributes: Optional[bool] = None, min_examples: int = 5, missing_value_policy: str = "GLOBAL_IMPUTATION", + ndcg_truncation: int = 5, num_candidate_attributes: Optional[int] = -1, num_candidate_attributes_ratio: Optional[float] = None, num_trees: int = 300, @@ -1529,6 +1537,7 @@ def __init__( "compute_permutation_variable_importance": ( compute_permutation_variable_importance ), + "cross_entropy_ndcg_truncation": cross_entropy_ndcg_truncation, "dart_dropout": dart_dropout, "early_stopping": early_stopping, "early_stopping_initial_iteration": early_stopping_initial_iteration, @@ -1561,6 +1570,7 @@ def __init__( "mhld_oblique_sample_attributes": mhld_oblique_sample_attributes, "min_examples": min_examples, "missing_value_policy": missing_value_policy, + "ndcg_truncation": ndcg_truncation, "num_candidate_attributes": num_candidate_attributes, "num_candidate_attributes_ratio": num_candidate_attributes_ratio, "num_trees": num_trees, diff --git a/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py b/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py index badf2e8c..d4c4671f 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py +++ b/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py @@ -60,8 +60,9 @@ class Task(enum.Enum): Attributes: CLASSIFICATION: Predict a categorical label i.e., an item of an enumeration. REGRESSION: Predict a numerical label i.e., a quantity. - RANKING: Rank items by label values. The label is expected to be between 0 - and 4 with NDCG semantic (0: completely unrelated, 4: perfect match). + RANKING: Rank items by label values. When using default NDCG settings, the + label is expected to be between 0 and 4 with NDCG semantic (0: completely + unrelated, 4: perfect match). CATEGORICAL_UPLIFT: Predicts the incremental impact of a treatment on a categorical outcome. NUMERICAL_UPLIFT: Predicts the incremental impact of a treatment on a