Skip to content

Commit

Permalink
Use fast median computation in MAE loss
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 574181519
  • Loading branch information
achoum authored and copybara-github committed Oct 17, 2023
1 parent be81665 commit ebf6709
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ TEST_F(GradientBoostedTreesOnAdult, ValidPathDataset) {
inject_random_noise_ = true;
TrainAndEvaluateModel();
YDF_TEST_METRIC(metric::Accuracy(evaluation_), 0.8732, 0.0023, 0.8747);
YDF_TEST_METRIC(metric::LogLoss(evaluation_), 0.2794, 0.0027, 0.2776);
YDF_TEST_METRIC(metric::LogLoss(evaluation_), 0.2794, 0.0057, 0.2776);
}

TEST_F(GradientBoostedTreesOnAdult, DISABLED_VariableImportance) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ cc_library_ydf(
"//yggdrasil_decision_forests/model:abstract_model_cc_proto",
"//yggdrasil_decision_forests/utils:concurrency",
"//yggdrasil_decision_forests/utils:logging",
"//yggdrasil_decision_forests/utils:math",
"//yggdrasil_decision_forests/utils:random",
"//yggdrasil_decision_forests/utils:status_macros",
"@com_google_absl//absl/status",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "yggdrasil_decision_forests/model/abstract_model.pb.h"
#include "yggdrasil_decision_forests/utils/concurrency.h" // IWYU pragma: keep
#include "yggdrasil_decision_forests/utils/logging.h"
#include "yggdrasil_decision_forests/utils/math.h"
#include "yggdrasil_decision_forests/utils/random.h"
#include "yggdrasil_decision_forests/utils/status_macros.h"

Expand Down Expand Up @@ -60,15 +61,7 @@ absl::StatusOr<std::vector<float>> MeanAverageErrorLoss::InitialPredictions(

float initial_prediction;
if (weights.empty()) {
auto sorted_labels = labels;
std::sort(sorted_labels.begin(), sorted_labels.end());
if ((labels.size() % 2) == 0) {
initial_prediction = (sorted_labels[sorted_labels.size() / 2] +
sorted_labels[(sorted_labels.size() / 2) - 1]) /
2;
} else {
initial_prediction = sorted_labels[sorted_labels.size() / 2];
}
initial_prediction = utils::Median(labels);
} else {
struct Item {
float label;
Expand Down
5 changes: 4 additions & 1 deletion yggdrasil_decision_forests/utils/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@ namespace {
// The content of "values" is reordered during the computation.
// "values" cannot be empty.
float QuickSelect(std::vector<float>& values, size_t target_idx) {
DCHECK_GT(values.size(), 0);
// Boundaries of the search window.
size_t left = 0;
// Using a "right" instead of an "end" simplifies the code.
size_t right = values.size() - 1;

while (true) {
DCHECK_LE(right, values.size());
DCHECK_LE(left, values.size());
DCHECK_LE(left, right) << "The left index cannot move past the right index";
DCHECK_GE(target_idx, left) << "target_idx should be in [left, right]";
DCHECK_LE(target_idx, right) << "target_idx should be in [left, right]";
Expand Down Expand Up @@ -68,7 +71,7 @@ float QuickSelect(std::vector<float>& values, size_t target_idx) {
return values[pivot_idx];
} else if (target_idx < pivot_idx) {
right = pivot_idx - 1;
} else if (pivot_idx == 0) {
} else {
left = pivot_idx + 1;
}
}
Expand Down

0 comments on commit ebf6709

Please sign in to comment.