Skip to content

Commit

Permalink
Fix incorrect model prediction for binary TF and CART models without …
Browse files Browse the repository at this point in the history
…features.

The error was spotted in: tensorflow/decision-forests#188
Fixed #188

PiperOrigin-RevId: 556206372
  • Loading branch information
achoum authored and copybara-github committed Aug 12, 2023
1 parent 39f12e9 commit 36b7de4
Showing 1 changed file with 10 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -179,19 +179,20 @@ inline void PredictHelper(
utils::usage::OnInference(num_examples, model.metadata);
const int num_features = model.features().fixed_length_features().size();
predictions->resize(num_examples);
const typename Model::ValueType* sample = examples.data();
for (int example_idx = 0; example_idx < num_examples; ++example_idx) {
float output = 0.f;
if (num_features > 0) {
const auto* sample = &examples[example_idx * num_features];
for (const auto root_node_idx : model.root_offsets) {
const auto* node = &model.nodes[root_node_idx];
while (node->right_idx) {
node += EvalCondition(node, sample) ? node->right_idx : 1;
}
output += node->label;

for (const auto root_node_idx : model.root_offsets) {
const auto* node = &model.nodes[root_node_idx];
while (node->right_idx) {
node += EvalCondition(node, sample) ? node->right_idx : 1;
}
output += node->label;
}

(*predictions)[example_idx] = FinalTransform(model, output);
sample += num_features;
}
}

Expand Down Expand Up @@ -301,7 +302,7 @@ inline void PredictHelperOptimizedV1(

// Select the first example.
// Note: The examples are stored example-major/feature-minor.
const typename Model::ValueType* sample = &examples[0];
const typename Model::ValueType* sample = examples.data();
for (size_t example_idx = 0; example_idx < num_examples; ++example_idx) {
// Accumulator of the predictions for the current example.
float output = 0.f;
Expand Down

0 comments on commit 36b7de4

Please sign in to comment.