Skip to content

Commit

Permalink
Make metric unit tests both more powerful and less prone to false err…
Browse files Browse the repository at this point in the history
…ors.

- When learners are expected to be deterministic, check metric against golden value (instead of valid range). This is the default behavior for the internal build.
- Make it possible to run test with random seed-values. This way, tests measure the learning variance from changing the random seed, or equivalently, use a different random number generator (e.g., same are going in the external build). In this case, metrics are tested again metric range.
- All metric range have been re-computed by running all tests 1000 times + adding 50% margin. In many cases, the new range is tighter than it was before.
- Remove non-deterministic in tests outside of the random seed (if the seed is not fixed). This significantly reduce the variance of the test results.

PiperOrigin-RevId: 568194822
  • Loading branch information
achoum authored and copybara-github committed Sep 25, 2023
1 parent 826e348 commit 5455343
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 177 deletions.
2 changes: 1 addition & 1 deletion yggdrasil_decision_forests/learner/cart/cart_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ TEST_F(CartOnAdult, Base) {
TrainAndEvaluateModel();
// Random Forest has an accuracy of ~0.860.
EXPECT_NEAR(metric::Accuracy(evaluation_), 0.8560, 0.01);
EXPECT_NEAR(metric::LogLoss(evaluation_), 0.4373, 0.04);
EXPECT_NEAR(metric::LogLoss(evaluation_), 0.4373, 0.05);

// Show the tree structure.
std::string description;
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ void BinomialLogLikelihoodLoss::TemplatedLossImp(
const std::vector<float>& weights, size_t begin_example_idx,
size_t end_example_idx, double* __restrict sum_loss,
utils::IntegersConfusionMatrixDouble* confusion_matrix) {
double local_sum_loss = 0;
for (size_t example_idx = begin_example_idx; example_idx < end_example_idx;
example_idx++) {
// The loss function expects a 0/1 label.
Expand All @@ -250,19 +251,19 @@ void BinomialLogLikelihoodLoss::TemplatedLossImp(
if constexpr (use_weights) {
const float weight = weights[example_idx];
confusion_matrix->Add(labels[example_idx], predicted_label, weight);
*sum_loss -=
local_sum_loss -=
2 * weight *
(label_for_loss * prediction - std::log(1.f + std::exp(prediction)));
} else {
confusion_matrix->Add(labels[example_idx], predicted_label, 1.f);
// Loss:
// -2 * ( label * prediction - log(1+exp(prediction)))
*sum_loss -= 2 * (label_for_loss * prediction -
std::log(1.f + std::exp(prediction)));
DCheckIsFinite(*sum_loss);
local_sum_loss -= 2 * (label_for_loss * prediction -
std::log(1.f + std::exp(prediction)));
}
DCheckIsFinite(*sum_loss);
DCheckIsFinite(local_sum_loss);
}
*sum_loss += local_sum_loss;
}

template <typename T>
Expand Down
17 changes: 12 additions & 5 deletions yggdrasil_decision_forests/learner/multitasker/multitasker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

#include "yggdrasil_decision_forests/learner/multitasker/multitasker.h"

#include <limits>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/flags/flag.h"
Expand Down Expand Up @@ -79,7 +81,8 @@ TEST_F(MultitaskerOnAdult, Base) {
t3->set_task(model::proto::Task::CLASSIFICATION);

TrainAndEvaluateModel();
YDF_EXPECT_METRIC_NEAR(metric::Accuracy(evaluation_), 0.860, 0.01);
YDF_TEST_METRIC(metric::Accuracy(evaluation_), 0.860, 0.01,
std::numeric_limits<double>::quiet_NaN());

utils::RandomEngine rnd(1234);

Expand All @@ -98,7 +101,8 @@ TEST_F(MultitaskerOnAdult, Base) {
metric::proto::EvaluationOptions eval_options;
eval_options.set_task(model::proto::Task::CLASSIFICATION);
auto eval = submodel->Evaluate(test_dataset_, eval_options, &rnd);
YDF_EXPECT_METRIC_NEAR(metric::Accuracy(eval), 0.860, 0.01);
YDF_TEST_METRIC(metric::Accuracy(eval), 0.860, 0.01,
std::numeric_limits<double>::quiet_NaN());
}

{
Expand All @@ -113,7 +117,8 @@ TEST_F(MultitaskerOnAdult, Base) {
metric::proto::EvaluationOptions eval_options;
eval_options.set_task(model::proto::Task::REGRESSION);
auto eval = submodel->Evaluate(test_dataset_, eval_options, &rnd);
YDF_EXPECT_METRIC_NEAR(metric::RMSE(eval), 10.2048, 0.05);
YDF_TEST_METRIC(metric::RMSE(eval), 10.2048, 0.05,
std::numeric_limits<double>::quiet_NaN());
}

{
Expand All @@ -128,7 +133,8 @@ TEST_F(MultitaskerOnAdult, Base) {
metric::proto::EvaluationOptions eval_options;
eval_options.set_task(model::proto::Task::CLASSIFICATION);
auto eval = submodel->Evaluate(test_dataset_, eval_options, &rnd);
YDF_EXPECT_METRIC_NEAR(metric::Accuracy(eval), 0.76474, 0.01);
YDF_TEST_METRIC(metric::Accuracy(eval), 0.76474, 0.01,
std::numeric_limits<double>::quiet_NaN());
}

{
Expand Down Expand Up @@ -185,7 +191,8 @@ TEST_F(MultitaskerOnAdult, Stacked) {
t3->set_task(model::proto::Task::CLASSIFICATION);

TrainAndEvaluateModel();
YDF_EXPECT_METRIC_NEAR(metric::Accuracy(evaluation_), 0.860, 0.01);
YDF_TEST_METRIC(metric::Accuracy(evaluation_), 0.860, 0.01,
std::numeric_limits<double>::quiet_NaN());

utils::RandomEngine rnd(1234);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ TEST_F(RandomForestOnAdult, Base) {

EXPECT_LE(rank_capital_gain, 5);
EXPECT_LE(rank_relationship, 5);
EXPECT_LE(rank_occupation, 5);
EXPECT_LE(rank_occupation, 6);

// Worst 2 variables.
const int rank_fnlwgt = utils::GetVariableImportanceRank(
Expand Down Expand Up @@ -397,7 +397,7 @@ TEST_F(RandomForestOnAdult, NoWinnerTakeAllRandomCategorical) {
rf_config->mutable_decision_tree()->mutable_categorical()->mutable_random();
TrainAndEvaluateModel();
EXPECT_NEAR(metric::Accuracy(evaluation_), 0.82618, 0.005);
EXPECT_NEAR(metric::LogLoss(evaluation_), 0.40623, 0.02);
EXPECT_NEAR(metric::LogLoss(evaluation_), 0.3817, 0.02);
}

TEST_F(RandomForestOnAdult, NoWinnerTakeAllExampleSampling) {
Expand Down Expand Up @@ -570,7 +570,7 @@ TEST_F(RandomForestOnAdult, MaxNumNodes) {

EXPECT_NEAR(metric::Accuracy(evaluation_), 0.862, 0.015);
// Disabling winner take all reduce the logloss (as expected).
EXPECT_NEAR(metric::LogLoss(evaluation_), 0.368, 0.045);
EXPECT_NEAR(metric::LogLoss(evaluation_), 0.368, 0.06);
}

TEST_F(RandomForestOnAdult, SparseOblique) {
Expand Down Expand Up @@ -634,7 +634,7 @@ TEST_F(RandomForestOnAbalone, Base) {
absl::StrCat("csv:", oob_prediction_path));

TrainAndEvaluateModel();
EXPECT_NEAR(metric::RMSE(evaluation_), 2.0825, 0.01);
EXPECT_NEAR(metric::RMSE(evaluation_), 2.0926, 0.01);

// Check the oob predictions.
const auto oob_predictions = file::GetContent(oob_prediction_path).value();
Expand Down
146 changes: 100 additions & 46 deletions yggdrasil_decision_forests/utils/test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <cxxabi.h>

#include <algorithm>
#include <cstring>
#include <memory>
#include <random>
#include <string>
Expand Down Expand Up @@ -69,6 +70,7 @@ namespace utils {

namespace {

// Shuffles a dataset randomly. Does not rely on a static seed.
void ShuffleDataset(dataset::VerticalDataset* dataset) {
absl::BitGen bitgen;
std::vector<dataset::VerticalDataset::row_t> example_idxs(dataset->nrow());
Expand All @@ -77,6 +79,45 @@ void ShuffleDataset(dataset::VerticalDataset* dataset) {
*dataset = dataset->Extract(example_idxs).value();
}

// Generates a random seed. Does not rely on a static seed.
int64_t RandomSeed() {
absl::BitGen bitgen;
return bitgen();
}

// Generates a deterministic sequence of boolean value approximating poorly a
// binomial distribution sampling.
class DeterministicBinomial {
public:
DeterministicBinomial(const float rate) : rate_(rate) {}

bool Sample() {
if (num_total_ == 0) {
// Always return false first, unless the rate is 1.
num_total_++;
if (rate_ == 1) {
num_pos_++;
return true;
}
return false;
}

if (num_pos_ > rate_ * num_total_) {
num_total_++;
return false;
} else {
num_pos_++;
num_total_++;
return true;
}
}

private:
float rate_;
int num_pos_ = 0;
int num_total_ = 0;
};

} // namespace

void TrainAndTestTester::ConfigureForSyntheticDataset() {
Expand Down Expand Up @@ -136,9 +177,8 @@ void TrainAndTestTester::TrainAndEvaluateModel(
// Configure the learner.
CHECK_OK(model::GetLearner(train_config_, &learner_, deployment_config_));

if (inject_random_noise_ && !learner_->training_config().has_random_seed()) {
absl::BitGen bitgen;
learner_->mutable_training_config()->set_random_seed(bitgen());
if (change_random_seed_ && !learner_->training_config().has_random_seed()) {
learner_->mutable_training_config()->set_random_seed(RandomSeed());
}

if (generic_parameters_.has_value()) {
Expand Down Expand Up @@ -204,7 +244,7 @@ void TrainAndTestTester::TrainAndEvaluateModel(
YDF_LOG(INFO) << "Training duration: " << training_duration_;

// Evaluate the model.
utils::RandomEngine rnd(1234);
utils::RandomEngine rnd(1234); // Not used
evaluation_ = model_->Evaluate(test_dataset_, eval_options_, &rnd);

// Print the model evaluation.
Expand Down Expand Up @@ -406,21 +446,22 @@ void TrainAndTestTester::BuildTrainValidTestDatasets(
CHECK_OK(LoadVerticalDataset(train_path, data_spec, &dataset));

// Split the dataset in two folds: training and testing.
std::vector<dataset::VerticalDataset::row_t> train_example_idxs,
test_example_idxs, valid_example_idxs;
std::vector<dataset::VerticalDataset::row_t> train_example_idxs;
std::vector<dataset::VerticalDataset::row_t> test_example_idxs;
std::vector<dataset::VerticalDataset::row_t> valid_example_idxs;

DeterministicBinomial sampling(dataset_sampling_);
DeterministicBinomial train_test_split(split_train_ratio_);
DeterministicBinomial test_valid_split(0.5f);

// TODO: Make deterministic.
utils::RandomEngine rnd(1234);
std::uniform_real_distribution<double> dist_01;
// If a validation example should be generated (i.e.
// pass_validation_dataset_=true), next_example_is_valid indicates if the next
// example will be used for validation or testing.
bool next_example_is_valid = true;

for (dataset::VerticalDataset::row_t example_idx = 0;
example_idx < dataset.nrow(); example_idx++) {
// Down-sampling of examples.
// TODO: Make the split deterministic.
if (dataset_sampling_ < dist_01(rnd)) {
if (!sampling.Sample()) {
continue;
}

Expand All @@ -438,21 +479,13 @@ void TrainAndTestTester::BuildTrainValidTestDatasets(
}
}

bool is_training_example;
if (split_train_ratio_ == 0.5f) {
// Deterministic split.
is_training_example = (example_idx % 2) == 0;
} else {
is_training_example = dist_01(rnd) < split_train_ratio_;
}
const bool is_training_example = train_test_split.Sample();

if (is_training_example) {
train_example_idxs.push_back(example_idx);
} else {
if (pass_validation_dataset_) {
(next_example_is_valid ? valid_example_idxs : test_example_idxs)
.push_back(example_idx);
next_example_is_valid ^= true;
if (pass_validation_dataset_ && test_valid_split.Sample()) {
valid_example_idxs.push_back(example_idx);
} else {
test_example_idxs.push_back(example_idx);
}
Expand Down Expand Up @@ -720,7 +753,7 @@ void TestPredefinedHyperParameters(

// Evaluate the model.
if (min_accuracy.has_value()) {
utils::RandomEngine rnd(1234);
utils::RandomEngine rnd(1234); // Not used.
const auto evaluation = model->Evaluate(test_ds, {}, &rnd);
EXPECT_GE(metric::Accuracy(evaluation), min_accuracy.value());
}
Expand Down Expand Up @@ -847,32 +880,53 @@ absl::Status ExportUpliftPredictionsToTFUpliftCsvFormat(

void InternalExportMetricCondition(const absl::string_view test,
const double value, const double center,
const double margin,
const double margin, const double golden,
const absl::string_view metric,
const int line,
const absl::string_view file) {
// Margin of error when comparing golden metric values.
constexpr double kGoldenMargin = 0.0001;

const auto filename = file::GetBasename(file);
const auto abs_diff = std::abs(value - center);
const auto success = abs_diff < margin;
#ifdef EXPORT_METRIC_CONDITION
const auto uid = GenUniqueId();
const auto path =
file::JoinPath(EXPORT_METRIC_CONDITION, absl::StrCat(uid, ".csv"));
std::string content =
absl::StrCat("test,value,center,margin,metric,line,filename,success\n",
test, ",", value, ",", center, ",", margin, ",", metric, ",",
line, ",", filename, ",", success);
CHECK_OK(file::SetContent(path, content));
#endif
if (!success) {
EXPECT_TRUE(false) << "Non satified range condition for " << metric
<< " in " << test << "\ndefined at\n"
<< file << ":" << line << "\nThe metric value " << value
<< " is not in " << center << " +- " << margin
<< ".\ni.e. not in [" << (center - margin) << " , "
<< (center + margin)
<< "].\nThe absolute value of the difference is "
<< abs_diff << ".";
const bool golden_test = kYdfTestMetricCheckGold && !std::isnan(golden);

double abs_diff_margin = std::abs(value - center);
double abs_diff_golden = std::abs(value - golden);
bool success_margin = abs_diff_margin < margin;
bool success_golden = abs_diff_golden < kGoldenMargin;

if (strlen(kYdfTestMetricDumpDir) > 0) {
// Export metric to csv file.
const auto uid = GenUniqueId();
const auto path =
file::JoinPath(kYdfTestMetricDumpDir, absl::StrCat(uid, ".csv"));
std::string content = absl::StrCat(
"test,value,center,margin,metric,line,filename,success_margin,success_"
"golden,golden\n",
test, ",", value, ",", center, ",", margin, ",", metric, ",", line, ",",
filename, ",", success_margin, ",", success_golden, ",", golden);
CHECK_OK(file::SetContent(path, content));
} else {
if (!success_margin) {
EXPECT_TRUE(false) << "Non satified range condition for " << metric
<< " in " << test << "\ndefined at\n"
<< file << ":" << line << "\nThe metric value "
<< value << " is not in " << center << " +- " << margin
<< ".\ni.e. not in [" << (center - margin) << " , "
<< (center + margin)
<< "].\nThe absolute value of the difference is "
<< abs_diff_margin << ".";
}

if (golden_test && !success_golden) {
EXPECT_TRUE(false) << "Non satified golden value condition for " << metric
<< " in " << test << "\ndefined at\n"
<< file << ":" << line << "\nThe metric value "
<< value << " is different from " << golden
<< " (margin:" << kGoldenMargin
<< ").\nThe absolute value of the difference is "
<< abs_diff_golden << ".";
}
}
}

Expand Down
Loading

0 comments on commit 5455343

Please sign in to comment.