From 36b7de4f6ef8a1fdc5e1d5595ffa63f37cc4765e Mon Sep 17 00:00:00 2001 From: Mathieu Guillame-Bert Date: Fri, 11 Aug 2023 21:34:16 -0700 Subject: [PATCH] Fix incorrect model prediction for binary TF and CART models without features. The error was spotted in: https://github.com/tensorflow/decision-forests/issues/188 Fixed #188 PiperOrigin-RevId: 556206372 --- .../decision_forest_serving.cc | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/yggdrasil_decision_forests/serving/decision_forest/decision_forest_serving.cc b/yggdrasil_decision_forests/serving/decision_forest/decision_forest_serving.cc index d91a9edd..4c953ca3 100644 --- a/yggdrasil_decision_forests/serving/decision_forest/decision_forest_serving.cc +++ b/yggdrasil_decision_forests/serving/decision_forest/decision_forest_serving.cc @@ -179,19 +179,20 @@ inline void PredictHelper( utils::usage::OnInference(num_examples, model.metadata); const int num_features = model.features().fixed_length_features().size(); predictions->resize(num_examples); + const typename Model::ValueType* sample = examples.data(); for (int example_idx = 0; example_idx < num_examples; ++example_idx) { float output = 0.f; - if (num_features > 0) { - const auto* sample = &examples[example_idx * num_features]; - for (const auto root_node_idx : model.root_offsets) { - const auto* node = &model.nodes[root_node_idx]; - while (node->right_idx) { - node += EvalCondition(node, sample) ? node->right_idx : 1; - } - output += node->label; + + for (const auto root_node_idx : model.root_offsets) { + const auto* node = &model.nodes[root_node_idx]; + while (node->right_idx) { + node += EvalCondition(node, sample) ? node->right_idx : 1; } + output += node->label; } + (*predictions)[example_idx] = FinalTransform(model, output); + sample += num_features; } } @@ -301,7 +302,7 @@ inline void PredictHelperOptimizedV1( // Select the first example. // Note: The examples are stored example-major/feature-minor. - const typename Model::ValueType* sample = &examples[0]; + const typename Model::ValueType* sample = examples.data(); for (size_t example_idx = 0; example_idx < num_examples; ++example_idx) { // Accumulator of the predictions for the current example. float output = 0.f;