Add support for the mean absolute error (MAE) metric.
PiperOrigin-RevId: 571267619
achoum authored and copybara-github committed Oct 6, 2023
1 parent 2f1d1c4 commit 9dda105
Showing 6 changed files with 230 additions and 10 deletions.
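
For reference, the metric added by this change is the weighted mean absolute error: given labels $y_i$, predictions $\hat{y}_i$, and example weights $w_i$,

$$\mathrm{MAE} = \frac{\sum_i w_i \,\lvert y_i - \hat{y}_i \rvert}{\sum_i w_i}$$

When no weights are supplied, the denominator reduces to the number of examples.
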
3 changes: 3 additions & 0 deletions yggdrasil_decision_forests/metric/BUILD
@@ -38,6 +38,7 @@ cc_library_ydf(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@org_boost//:boost",
] + select({
"//conditions:default": [
@@ -153,6 +154,8 @@ cc_test(
"//yggdrasil_decision_forests/utils:distribution",
"//yggdrasil_decision_forests/utils:random",
"//yggdrasil_decision_forests/utils:test",
"//yggdrasil_decision_forests/utils:testing_macros",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
112 changes: 110 additions & 2 deletions yggdrasil_decision_forests/metric/metric.cc
@@ -15,12 +15,14 @@

#include "yggdrasil_decision_forests/metric/metric.h"

#include <cmath>
#include <cstddef>
#include <functional>
#include <limits>
#include <random>
#include <vector>

#include "absl/types/span.h"
#include "yggdrasil_decision_forests/dataset/data_spec.pb.h"
#include "yggdrasil_decision_forests/metric/metric.pb.h"
#include "yggdrasil_decision_forests/model/abstract_model.pb.h"
@@ -375,6 +377,7 @@ void MergeEvaluationClassification(
void MergeEvaluationRegression(const proto::EvaluationResults::Regression& src,
proto::EvaluationResults::Regression* dst) {
dst->set_sum_square_error(dst->sum_square_error() + src.sum_square_error());
dst->set_sum_abs_error(dst->sum_abs_error() + src.sum_abs_error());
dst->set_sum_label(dst->sum_label() + src.sum_label());
dst->set_sum_square_label(dst->sum_square_label() + src.sum_square_label());
}
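
Keeping MAE as a running weighted sum of absolute errors (rather than a final average) is what makes evaluations mergeable: summing sum_abs_error across shards and dividing by the merged prediction count gives exactly the same value as a single-pass computation.
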
@@ -884,6 +887,8 @@ absl::Status AddPrediction(const proto::EvaluationOptions& option,
const float error = pred_reg.value() - pred_reg.ground_truth();
eval_reg->set_sum_square_error(eval_reg->sum_square_error() +
error * error * pred.weight());
eval_reg->set_sum_abs_error(eval_reg->sum_abs_error() +
std::abs(error) * pred.weight());
eval_reg->set_sum_label(eval_reg->sum_label() +
pred_reg.ground_truth() * pred.weight());
eval_reg->set_sum_square_label(
@@ -1012,6 +1017,13 @@ float RMSE(const proto::EvaluationResults& eval) {
return sqrt(eval.regression().sum_square_error() / eval.count_predictions());
}

float MAE(const proto::EvaluationResults& eval) {
if (eval.count_predictions() == 0) {
return std::numeric_limits<float>::quiet_NaN();
}
return eval.regression().sum_abs_error() / eval.count_predictions();
}

float NDCG(const proto::EvaluationResults& eval) {
return eval.ranking().ndcg().value();
}
@@ -1605,6 +1617,8 @@ absl::StatusOr<double> GetMetricRegression(
switch (metric.Type_case()) {
case proto::MetricAccessor::Regression::kRmse:
return RMSE(evaluation);
case proto::MetricAccessor::Regression::kMae:
return MAE(evaluation);
default:
return absl::InvalidArgumentError("Not implemented");
}
@@ -1639,8 +1653,13 @@ absl::StatusOr<double> GetMetricUplift(
absl::StatusOr<double> GetUserCustomizedMetrics(
const proto::EvaluationResults& evaluation,
const proto::MetricAccessor::UserMetric& metric) {
// "user_metrics" maps a metric name to its value.
const auto it = evaluation.user_metrics().find(metric.metrics_name());
if (it == evaluation.user_metrics().end()) {
return absl::InvalidArgumentError(
absl::StrCat("Cannot find metric: ", metric.metrics_name()));
}
return it->second;
}

// Returns an absl invalid status with an error message about the
@@ -1816,6 +1835,95 @@ proto::EvaluationResults BinaryClassificationEvaluationHelper(
return eval;
}

// Value returned by "MAEImp".
struct MAEImpResult {
// Weighted sum of absolute errors.
double sum_err = 0;
// Sum of weights.
double sum_weights = 0;
};

// Implementation of the mean absolute error metric. Returns the sum of weights
// and the weighted sum of absolute errors. If "use_weights" is false, "weights"
// is ignored.
//
// The use of weights is templated to improve execution speed.
template <bool use_weights>
MAEImpResult MAEImp(const absl::Span<const float> labels,
const absl::Span<const float> predictions,
const absl::Span<const float> weights) {
DCHECK_EQ(labels.size(), predictions.size());

MAEImpResult accumulator;

// Note: "example_idx" is not the global example index.
for (size_t example_idx = 0; example_idx < labels.size(); ++example_idx) {
const float label = labels[example_idx];
const float prediction = predictions[example_idx];
if constexpr (use_weights) {
const float weight = weights[example_idx];
accumulator.sum_weights += weight;
accumulator.sum_err += weight * std::abs(label - prediction);
} else {
accumulator.sum_err += std::abs(label - prediction);
}
}
if constexpr (!use_weights) {
accumulator.sum_weights = labels.size();
}
return accumulator;
}
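
Because "use_weights" is a compile-time constant, the "if constexpr" branches are resolved at instantiation time: the unweighted variant contains no per-example weight reads or branches, and its sum of weights collapses to the example count.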

absl::StatusOr<double> MAE(const std::vector<float>& labels,
const std::vector<float>& predictions,
const std::vector<float>& weights,
utils::concurrency::ThreadPool* thread_pool) {
MAEImpResult global;
if (thread_pool == nullptr) {
if (weights.empty()) {
global = MAEImp<false>(labels, predictions, weights);
} else {
global = MAEImp<true>(labels, predictions, weights);
}
} else {
const auto num_threads = thread_pool->num_threads();

std::vector<MAEImpResult> per_threads(num_threads);

utils::concurrency::ConcurrentForLoop(
num_threads, thread_pool, labels.size(),
[&labels, &predictions, &per_threads, &weights](
size_t block_idx, size_t begin_idx, size_t end_idx) -> void {
auto& block = per_threads[block_idx];
if (weights.empty()) {
block = MAEImp<false>(absl::Span<const float>(labels).subspan(
begin_idx, end_idx - begin_idx),
absl::Span<const float>(predictions)
.subspan(begin_idx, end_idx - begin_idx),
weights);
} else {
block = MAEImp<true>(absl::Span<const float>(labels).subspan(
begin_idx, end_idx - begin_idx),
absl::Span<const float>(predictions)
.subspan(begin_idx, end_idx - begin_idx),
absl::Span<const float>(weights).subspan(
begin_idx, end_idx - begin_idx));
}
});

for (const auto& block : per_threads) {
global.sum_err += block.sum_err;
global.sum_weights += block.sum_weights;
}
}

if (global.sum_weights > 0) {
return global.sum_err / global.sum_weights;
} else {
return std::numeric_limits<double>::quiet_NaN();
}
}
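
Each worker accumulates an independent MAEImpResult over a contiguous block of examples, and the per-thread partial sums are combined afterwards; since both the numerator and the denominator are plain sums, the threaded result matches the single-threaded one up to floating-point rounding.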

template <bool use_weights>
void RMSEImp(const std::vector<float>& labels,
const std::vector<float>& predictions,
12 changes: 12 additions & 0 deletions yggdrasil_decision_forests/metric/metric.h
@@ -116,6 +116,7 @@ void ComputeXAtYMetrics(
float Accuracy(const proto::EvaluationResults& eval);
float LogLoss(const proto::EvaluationResults& eval);
float RMSE(const proto::EvaluationResults& eval);
float MAE(const proto::EvaluationResults& eval);
float ErrorRate(const proto::EvaluationResults& eval);

// Loss of the model. Can have different semantic for different models.
@@ -292,6 +293,17 @@ absl::StatusOr<double> RMSE(
const std::vector<float>& weights,
utils::concurrency::ThreadPool* thread_pool = nullptr);

// Computes the mean absolute error (MAE) of a set of predictions.
//
// "labels" and "predictions" must have the same size. If "weights" is not
// empty, it must also have that size. Returns NaN if the weighted sum of
// examples is zero (including when "labels" is empty).
absl::StatusOr<double> MAE(
const std::vector<float>& labels, const std::vector<float>& predictions,
const std::vector<float>& weights,
utils::concurrency::ThreadPool* thread_pool = nullptr);
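
As an illustration only (not part of the commit), a minimal caller of the new overload could look like the following; the function name and values are hypothetical:

#include <vector>

#include "absl/status/statusor.h"
#include "yggdrasil_decision_forests/metric/metric.h"

// Hypothetical example. Empty weights: every example counts equally.
// Expected result: (0 + 1 + 1) / 3.
absl::StatusOr<double> ExampleMae() {
  const std::vector<float> labels = {1.f, 2.f, 3.f};
  const std::vector<float> predictions = {1.f, 3.f, 4.f};
  return yggdrasil_decision_forests::metric::MAE(labels, predictions,
                                                 /*weights=*/{});
}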

// Gets the threshold on a binary classifier output that maximize accuracy.
float ComputeThresholdForMaxAccuracy(
const google::protobuf::RepeatedPtrField<proto::Roc::Point>& curve);
6 changes: 5 additions & 1 deletion yggdrasil_decision_forests/metric/metric.proto
@@ -143,7 +143,9 @@ message EvaluationResults {
// bootstrapping.
optional double bootstrap_rmse_lower_bounds_95p = 4;
optional double bootstrap_rmse_upper_bounds_95p = 5;
// Sum of the absolute values of the errors.
optional double sum_abs_error = 6 [default = 0];
// Next ID: 7
}

message Ranking {
@@ -306,8 +308,10 @@ message MetricAccessor {
message Regression {
oneof Type {
Rmse rmse = 1;
Mae mae = 2;
}
message Rmse {}
message Mae {}
}

message Loss {}
106 changes: 99 additions & 7 deletions yggdrasil_decision_forests/metric/metric_test.cc
@@ -15,8 +15,11 @@

#include "yggdrasil_decision_forests/metric/metric.h"

#include <cmath>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/substitute.h"
#include "yggdrasil_decision_forests/dataset/data_spec.pb.h"
@@ -28,13 +31,15 @@
#include "yggdrasil_decision_forests/utils/distribution.h"
#include "yggdrasil_decision_forests/utils/random.h"
#include "yggdrasil_decision_forests/utils/test.h"
#include "yggdrasil_decision_forests/utils/testing_macros.h"

namespace yggdrasil_decision_forests {
namespace metric {
namespace {

using test::EqualsProto;
using ::testing::Bool;
using ::testing::IsNan;

TEST(Metric, EvaluationOfClassification) {
// Create a fake column specification.
@@ -243,6 +248,7 @@ TEST(Metric, EvaluationOfRegression) {
EXPECT_EQ(eval.count_sampled_predictions(), 4);
EXPECT_EQ(eval.task(), model::proto::Task::REGRESSION);
EXPECT_NEAR(RMSE(eval), sqrt(0.5), 0.0001);
EXPECT_NEAR(MAE(eval), (0. + 1. + 0. + 1.) / 4, 0.0001);
EXPECT_NEAR(DefaultRMSE(eval), 0.5, 0.0001);

// Create reports.
@@ -762,18 +768,28 @@ TEST(Metric, RMSEConfidenceIntervals) {
0.03);
}

TEST(Metric, GetMetricRegression) {
const proto::EvaluationResults results_regression = PARSE_TEST_PROTO(R"pb(
task: REGRESSION
label_column { type: NUMERICAL }
regression { sum_square_error: 10 sum_abs_error: 5 }
count_predictions: 10
)pb");

ASSERT_OK_AND_ASSIGN(
const auto rmse,
GetMetric(results_regression,
PARSE_TEST_PROTO(R"pb(regression { rmse {} })pb")));
EXPECT_NEAR(rmse, RMSE(results_regression), 0.0001);

ASSERT_OK_AND_ASSIGN(
const auto mae,
GetMetric(results_regression,
PARSE_TEST_PROTO(R"pb(regression { mae {} })pb")));
EXPECT_NEAR(mae, MAE(results_regression), 0.0001);
}

TEST(Metric, GetMetricClassification) {
const proto::EvaluationResults results_classification = PARSE_TEST_PROTO(R"pb(
task: CLASSIFICATION
label_column {
@@ -932,6 +948,46 @@ TEST(Metric, GetMetric) {
0.0001);
}

TEST(Metric, GetMetricCustomMetric) {
const proto::EvaluationResults evaluation = PARSE_TEST_PROTO(R"pb(
user_metrics { key: "my_custom_metric" value: 2 }
)pb");

ASSERT_OK_AND_ASSIGN(
const auto custom,
GetMetric(evaluation, PARSE_TEST_PROTO(R"pb(
user_metric { metrics_name: "my_custom_metric" })pb")));
EXPECT_NEAR(custom, 2, 0.0001);

EXPECT_THAT(
GetMetric(evaluation, PARSE_TEST_PROTO(R"pb(
user_metric { metrics_name: "non_existing_metric" })pb"))
.status(),
test::StatusIs(absl::StatusCode::kInvalidArgument,
"Cannot find metric: non_existing_metric"));
}

TEST(Metric, GetMetricEmpty) {
const proto::EvaluationResults results_regression = PARSE_TEST_PROTO(R"pb(
task: REGRESSION
label_column { type: NUMERICAL }
regression { sum_square_error: 0 sum_abs_error: 0 }
count_predictions: 0
)pb");

EXPECT_THAT(GetMetric(results_regression,
PARSE_TEST_PROTO(R"pb(regression { rmse {} })pb"))
.value(),
IsNan());
EXPECT_THAT(RMSE(results_regression), IsNan());

EXPECT_THAT(GetMetric(results_regression,
PARSE_TEST_PROTO(R"pb(regression { mae {} })pb"))
.value(),
IsNan());
EXPECT_THAT(MAE(results_regression), IsNan());
}

TEST(Metric, MinMaxStream) {
MinMaxStream<int> bounds;
EXPECT_TRUE(bounds.empty());
@@ -1426,6 +1482,42 @@ TEST(Metric, RMSEThreaded) {
0.8164966, 0.0001);
}

class MAETest : public testing::TestWithParam<bool> {};

TEST_P(MAETest, MAE) {
const bool threaded = GetParam();
const double test_precision = 0.0001;

utils::concurrency::ThreadPool thread_pool("", 4);
if (threaded) {
thread_pool.StartWorkers();
}
auto* effective_threadpool = threaded ? &thread_pool : nullptr;

{
ASSERT_OK_AND_ASSIGN(const auto mae,
MAE(/*labels=*/{1, 2, 3}, /*predictions=*/{1, 3, 4},
/*weights=*/{1, 1, 1}, effective_threadpool));
EXPECT_NEAR(mae, (0. + 1. + 1.) / 3, test_precision);
}

{
ASSERT_OK_AND_ASSIGN(const auto mae,
MAE(/*labels=*/{1, 2, 3}, /*predictions=*/{1, 3, 4},
/*weights=*/{1, 2, 3}, effective_threadpool));
EXPECT_NEAR(mae, (0. * 1 + 1. * 2 + 1. * 3) / (1 + 2 + 3), test_precision);
}

{
ASSERT_OK_AND_ASSIGN(const auto mae,
MAE(/*labels=*/{1, 2, 3}, /*predictions=*/{1, 3, 4},
/*weights=*/{}, effective_threadpool));
EXPECT_NEAR(mae, (0. + 1. + 1.) / 3, test_precision);
}
}

INSTANTIATE_TEST_SUITE_P(WithAndWithoutThreading, MAETest, Bool());

TEST(DefaultMetrics, Classification) {
const dataset::proto::Column label = PARSE_TEST_PROTO(
R"pb(