Speed-up and fix bug in condition evaluation during training.
The training condition evaluation consists of a loop over the examples and a switch over the condition type (plus a few other checks). Prior to this change, the example loop was outside the condition switch, forcing the algorithm to re-check the condition type (and the other checks) for each example. After this change, the condition switch is outside the example loop, so the condition type is no longer re-checked per example.
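
The restructuring described above can be illustrated with a minimal sketch. This is not the YDF implementation: `Condition`, `ConditionType`, `ExampleIdx`, `SplitSwitchInside`, `SplitSwitchOutside`, and `SplitWith` are hypothetical names invented for the illustration, and a single float column stands in for the dataset. The first function dispatches on the condition type inside the example loop; the second dispatches once and then runs a loop specialized for that type.

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-ins (not the YDF types).
using ExampleIdx = std::uint32_t;
enum class ConditionType { kHigher, kIsMissing };
struct Condition {
  ConditionType type;
  float threshold;  // Only used by kHigher.
};

// Before: the example loop is outside, so the switch runs once per example.
void SplitSwitchInside(const std::vector<float>& column,
                       const std::vector<ExampleIdx>& examples,
                       const Condition& condition,
                       std::vector<ExampleIdx>* positive,
                       std::vector<ExampleIdx>* negative) {
  positive->clear();
  negative->clear();
  for (const ExampleIdx idx : examples) {
    bool result = false;
    switch (condition.type) {
      case ConditionType::kHigher:
        result = column[idx] >= condition.threshold;
        break;
      case ConditionType::kIsMissing:
        result = std::isnan(column[idx]);
        break;
    }
    (result ? positive : negative)->push_back(idx);
  }
}

// Example loop specialized on a predicate known at compile time.
template <typename Predicate>
void SplitWith(const std::vector<float>& column,
               const std::vector<ExampleIdx>& examples, Predicate predicate,
               std::vector<ExampleIdx>* positive,
               std::vector<ExampleIdx>* negative) {
  for (const ExampleIdx idx : examples) {
    (predicate(column[idx]) ? positive : negative)->push_back(idx);
  }
}

// After: the switch is outside, so it runs once per call.
void SplitSwitchOutside(const std::vector<float>& column,
                        const std::vector<ExampleIdx>& examples,
                        const Condition& condition,
                        std::vector<ExampleIdx>* positive,
                        std::vector<ExampleIdx>* negative) {
  positive->clear();
  negative->clear();
  switch (condition.type) {
    case ConditionType::kHigher: {
      const float threshold = condition.threshold;
      SplitWith(column, examples,
                [threshold](float v) { return v >= threshold; }, positive,
                negative);
      break;
    }
    case ConditionType::kIsMissing:
      SplitWith(column, examples, [](float v) { return std::isnan(v); },
                positive, negative);
      break;
  }
}
```

Both functions produce the same positive and negative sets. The second form only changes where the dispatch happens, which lets the compiler generate a tight loop for each condition type.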

Examples of speed-ups:
1. Average speed-up of 3.7% across all benchmarks.
2. Speed-up of 10-15% on the Adult dataset with GBT.
3. Speed-up of 8% on the Adult dataset with RF.
4. Speed-up of 9% on a 4M dataset with 200 features with discretized GBT.
5. No measurable speed difference (<1% gain) on a 4M dataset with 200 features with non-discretized GBT. Note: The absolute gain is the same as in 4., but since training in 5. takes longer, the relative gain is negligible.
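
The note on item 5 follows from how a relative gain is computed. As a sketch, write $\Delta > 0$ for the absolute training time saved (the same in items 4 and 5) and $T_4 < T_5$ for the respective baseline training times; these symbols are introduced here for illustration and do not appear in the commit:

$$
\text{relative gain}_i = \frac{\Delta}{T_i}, \qquad T_5 > T_4 \;\Longrightarrow\; \frac{\Delta}{T_5} < \frac{\Delta}{T_4}.
$$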

PiperOrigin-RevId: 670888905
achoum authored and copybara-github committed Sep 4, 2024
1 parent abe57f6 commit c407154
Showing 8 changed files with 411 additions and 24 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -3,6 +3,12 @@
 Note: This is the changelog of the C++ library. The Python port has a separate
 Changelog under `yggdrasil_decision_forests/port/python/CHANGELOG.md`.

+## HEAD
+
+### Features
+
+- Speed-up training of GBT models by ~10%
+
 ## 1.10.0 - 2024-08-21

 ### Features
@@ -54,6 +54,15 @@ def main(argv) -> None:
   ):
     print(data)

+  print("Mean delta% (smaller is better):", data["delta%"].mean())
+  print("Median delta% (smaller is better):", data["delta%"].median())
+  print("Min delta% (smaller is better):", data["delta%"].min())
+  print("Max delta% (smaller is better):", data["delta%"].max())
+  print(
+      "Rate of negative delta% (larger is better):",
+      (data["delta%"] <= 0).mean(),
+  )
+

 if __name__ == "__main__":
   app.run(main)
27 changes: 4 additions & 23 deletions yggdrasil_decision_forests/learner/decision_tree/training.cc
@@ -3819,8 +3819,7 @@ void SetInternalDefaultHyperParameters(
     const model::proto::TrainingConfig& config,
     const model::proto::TrainingConfigLinking& link_config,
     const dataset::proto::DataSpecification& data_spec,
-    proto::DecisionTreeTrainingConfig* dt_config) {
-}
+    proto::DecisionTreeTrainingConfig* dt_config) {}

 void SetDefaultHyperParameters(proto::DecisionTreeTrainingConfig* config) {
   // Emulation of histogram splits.
@@ -4524,27 +4523,9 @@ absl::Status SplitExamples(const dataset::VerticalDataset& dataset,
   positive_examples->clear();
   negative_examples->clear();

-  std::vector<UnsignedExampleIdx>* example_sets[] = {negative_examples,
-                                                     positive_examples};
-
-  // Index of the example selected for this node.
-  const auto column_data = dataset.column(condition.attribute());
-
-  if (!dataset_is_dense) {
-    for (const UnsignedExampleIdx example_idx : examples) {
-      const auto dst = example_sets[EvalConditionFromColumn(
-          condition, column_data, dataset, example_idx)];
-      dst->push_back(example_idx);
-    }
-  } else {
-    UnsignedExampleIdx dense_example_idx = 0;
-    for (const UnsignedExampleIdx example_idx : examples) {
-      const auto dst = example_sets[EvalConditionFromColumn(
-          condition, column_data, dataset, dense_example_idx)];
-      dense_example_idx++;
-      dst->push_back(example_idx);
-    }
-  }
+  RETURN_IF_ERROR(EvalConditionOnDataset(dataset, examples, condition,
+                                         dataset_is_dense, positive_examples,
+                                         negative_examples));

   // The following test ensure that the effective number of positive examples is
   // equal to the expected number of positive examples. A miss alignment
2 changes: 2 additions & 0 deletions yggdrasil_decision_forests/model/decision_tree/BUILD
@@ -27,6 +27,7 @@ cc_library_ydf(
         "//yggdrasil_decision_forests/dataset:data_spec",
         "//yggdrasil_decision_forests/dataset:data_spec_cc_proto",
         "//yggdrasil_decision_forests/dataset:example_cc_proto",
+        "//yggdrasil_decision_forests/dataset:types",
         "//yggdrasil_decision_forests/dataset:vertical_dataset",
         "//yggdrasil_decision_forests/model:abstract_model_cc_proto",
         "//yggdrasil_decision_forests/utils:bitmap",
@@ -152,6 +153,7 @@ cc_test(
         "//yggdrasil_decision_forests/dataset:data_spec_cc_proto",
         "//yggdrasil_decision_forests/dataset:data_spec_inference",
         "//yggdrasil_decision_forests/dataset:example_cc_proto",
+        "//yggdrasil_decision_forests/dataset:types",
         "//yggdrasil_decision_forests/dataset:vertical_dataset",
         "//yggdrasil_decision_forests/dataset:vertical_dataset_io",
         "//yggdrasil_decision_forests/utils:filesystem",