Skip to content

Commit

Permalink
Add parameter to control the maximum duration of the model analysis. …
Browse files Browse the repository at this point in the history
…Default to 20 seconds.

PiperOrigin-RevId: 675552053
  • Loading branch information
achoum authored and copybara-github committed Sep 17, 2024
1 parent fa0bd2f commit ec48a7b
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 22 deletions.
3 changes: 3 additions & 0 deletions yggdrasil_decision_forests/port/python/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
learner.
- Remove parameters for distributed training and resuming training from
learners that do not support these capabilities.
- By default, `model.analyze` runs for a maximum of 20 seconds (i.e.
`maximum_duration=20` by default).

### Feature

Expand All @@ -28,6 +30,7 @@
- Default number of threads of `model.analyze` is equal to the number of
cores.
- Add multi-threaded results in `model.benchmark`.
- Add argument to control the maximum duration of `model.analyze`.

### Fix

Expand Down
15 changes: 10 additions & 5 deletions yggdrasil_decision_forests/port/python/ydf/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,19 @@ def _add_column(
):
column_data = column_data.astype(np.bytes_)
elif np.issubdtype(column_data.dtype, np.floating):
raise ValueError(
message = (
f"Cannot import column {column.name!r} with"
f" semantic={column.semantic} as it contains floating point values."
f" Got {original_column_data!r}.\nNote: If the column is a label,"
" make sure the correct task is selected. For example, you cannot"
" train a classification model (task=ydf.Task.CLASSIFICATION) with"
" floating point labels."
)
if is_label:
message += (
"\nNote: This is a label column. Try one of the following"
" solutions: (1) To train a classification model, cast the label"
" values as integers. (2) To train a regression or a ranking"
" model, configure the learner with `task=ydf.Task.REGRESSION`)."
)
message += f"\nGot {original_column_data!r}."
raise ValueError(message)

if column_data.dtype.type == np.bytes_:
if inference_args is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import os
import signal
from typing import Any, Optional, Tuple
from typing import Any, Dict, Optional, Tuple

from absl import logging
from absl.testing import absltest
Expand Down Expand Up @@ -867,6 +867,19 @@ def test_wrong_shape_singledim_model(self):
"feature": np.array([[0], [1]]),
})

def test_analyze_ensure_maximum_duration(self):
# Smoke test: trains a small regression model, then calls `model.analyze` on a
# much larger dataset without passing `maximum_duration`.
# NOTE(review): there is no explicit assertion on the elapsed time; presumably
# the test relies on the default time limit of `analyze` (and on the test
# runner's timeout) to detect an analysis that is not interrupted -- confirm.
# Create an analysis that would take a lot of time if not limited in time.
# Builds an in-memory dataset with `n` rows: a 2D "feature" array with 100
# columns and a 1D "label" array, both filled with uniform random values.
def create_dataset(n: int) -> Dict[str, np.ndarray]:
return {
"feature": np.random.uniform(size=(n, 100)),
"label": np.random.uniform(size=(n,)),
}

# Train on 1,000 examples; the label is a float, hence the regression task.
model = specialized_learners.RandomForestLearner(
label="label", task=generic_learner.Task.REGRESSION
).train(create_dataset(1_000))
# Analyze on 100,000 examples; the result is discarded, only completion of the
# call matters here.
_ = model.analyze(create_dataset(100_000))


class CARTLearnerTest(LearnerTest):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,7 @@ def analyze(
conditional_expectation_plot: bool = True,
permutation_variable_importance_rounds: int = 1,
num_threads: Optional[int] = None,
maximum_duration: Optional[float] = 20,
) -> analysis.Analysis:
"""Analyzes a model on a test dataset.
Expand Down Expand Up @@ -724,6 +725,8 @@ def analyze(
If permutation_variable_importance_rounds=0, disables the computation of
permutation variable importances.
num_threads: Number of threads to use to compute the analysis.
maximum_duration: Maximum duration of the analysis in seconds. Note that
the analysis can last a little longer than this value.
Returns:
Model analysis.
Expand All @@ -739,6 +742,7 @@ def analyze(

options_proto = model_analysis_pb2.Options(
num_threads=num_threads,
maximum_duration_seconds=maximum_duration,
pdp=model_analysis_pb2.Options.PlotConfig(
enabled=partial_depepence_plot,
example_sampling=sampling,
Expand Down
4 changes: 4 additions & 0 deletions yggdrasil_decision_forests/tools/local_copybara_export.sh
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ run_test() {
set -e

if [[ "$INTERACTIVE" = 1 ]]; then
echo "With TensorFlow build:"
echo "INSTALL_DEPENDENCIES=1 TF_SUPPORT="ON" COMPILERS="clang-12" CPP_VERSIONS="17" RUN_TESTS=1 ./tools/test_bazel.sh"
echo "Without Tensorflow build:"
echo "INSTALL_DEPENDENCIES=1 TF_SUPPORT="OFF" COMPILERS="clang-12" GO_PORT="0" PY_PORT="0" CPP_VERSIONS="14" ./tools/test_bazel.sh"
CMD='$SHELL'
else
CMD='INSTALL_DEPENDENCIES=1 TF_SUPPORT="ON" COMPILERS="clang-12" CPP_VERSIONS="17" RUN_TESTS=1 ./tools/test_bazel.sh;$SHELL'
Expand Down
2 changes: 2 additions & 0 deletions yggdrasil_decision_forests/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,8 @@ cc_library_ydf(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:optional",
],
)

Expand Down
27 changes: 17 additions & 10 deletions yggdrasil_decision_forests/utils/model_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,11 @@ absl::StatusOr<proto::AnalysisResult> Analyse(
return absl::InvalidArgumentError("The dataset is empty.");
}

const absl::optional<float> maximum_duration_seconds =
options.has_maximum_duration_seconds()
? options.maximum_duration_seconds()
: absl::optional<float>{};

// Try to create a fast engine.
const model::AbstractModel* effective_model = &model;
auto engine_or = model.BuildFastEngine();
Expand All @@ -737,11 +742,12 @@ absl::StatusOr<proto::AnalysisResult> Analyse(
/*flag_2d=*/false,
/*flag_2d_categorical_numerical=*/false));

ASSIGN_OR_RETURN(*analysis.mutable_pdp_set(),
utils::ComputePartialDependencePlotSet(
dataset, *effective_model, attribute_idxs,
options.pdp().num_numerical_bins(),
options.pdp().example_sampling()));
ASSIGN_OR_RETURN(
*analysis.mutable_pdp_set(),
utils::ComputePartialDependencePlotSet(
dataset, *effective_model, attribute_idxs,
options.pdp().num_numerical_bins(),
options.pdp().example_sampling(), maximum_duration_seconds));
}

// Conditional Expectation Plot
Expand All @@ -752,11 +758,12 @@ absl::StatusOr<proto::AnalysisResult> Analyse(
/*flag_2d=*/false,
/*flag_2d_categorical_numerical=*/false));

ASSIGN_OR_RETURN(*analysis.mutable_cep_set(),
utils::ComputeConditionalExpectationPlotSet(
dataset, *effective_model, attribute_idxs,
options.cep().num_numerical_bins(),
options.cep().example_sampling()));
ASSIGN_OR_RETURN(
*analysis.mutable_cep_set(),
utils::ComputeConditionalExpectationPlotSet(
dataset, *effective_model, attribute_idxs,
options.cep().num_numerical_bins(),
options.cep().example_sampling(), maximum_duration_seconds));
}

// TODO: Implement permuted variable importances for anomaly detection.
Expand Down
3 changes: 3 additions & 0 deletions yggdrasil_decision_forests/utils/model_analysis.proto
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ message Options {
// Prefix used to generate unique html element ids. If not set, use a random
// prefix.
optional string html_id_prefix = 18;

// Maximum duration of the analysis in seconds.
optional float maximum_duration_seconds = 19;
}

// Results of a model analysis.
Expand Down
28 changes: 24 additions & 4 deletions yggdrasil_decision_forests/utils/partial_dependence_plot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "yggdrasil_decision_forests/dataset/data_spec.h"
#include "yggdrasil_decision_forests/dataset/data_spec.pb.h"
#include "yggdrasil_decision_forests/dataset/example.pb.h"
Expand Down Expand Up @@ -495,7 +495,8 @@ absl::Status AppendAttributesCombinations2D(
absl::StatusOr<proto::PartialDependencePlotSet> ComputePartialDependencePlotSet(
const dataset::VerticalDataset& dataset, const model::AbstractModel& model,
const std::vector<std::vector<int>>& attribute_idxs,
const int num_numerical_bins, const float example_sampling) {
const int num_numerical_bins, const float example_sampling,
const absl::optional<float> maximum_duration_seconds) {
LOG(INFO) << "Initiate PDP accumulator";
ASSIGN_OR_RETURN(auto pdp_set,
InitializePartialDependencePlotSet(
Expand All @@ -508,6 +509,11 @@ absl::StatusOr<proto::PartialDependencePlotSet> ComputePartialDependencePlotSet(
std::default_random_engine random;
std::uniform_real_distribution<float> dist_unif_unit;

absl::optional<absl::Time> cutoff_time;
if (maximum_duration_seconds.has_value()) {
cutoff_time = absl::Now() + absl::Seconds(maximum_duration_seconds.value());
}

// TODO: Multi-thread.
dataset::proto::Example example;
for (size_t example_idx = 0; example_idx < dataset.nrow(); example_idx++) {
Expand All @@ -516,6 +522,10 @@ absl::StatusOr<proto::PartialDependencePlotSet> ComputePartialDependencePlotSet(
}
if ((example_idx % 100) == 0) {
LOG_EVERY_N_SEC(INFO, 30) << example_idx + 1 << " examples scanned.";
if (cutoff_time.has_value() && absl::Now() > cutoff_time) {
LOG(INFO) << "Maximum duration reached. Interrupting analysis early.";
break;
}
}
dataset.ExtractExample(example_idx, &example);

Expand All @@ -528,8 +538,9 @@ absl::StatusOr<proto::PartialDependencePlotSet> ComputePartialDependencePlotSet(
absl::StatusOr<ConditionalExpectationPlotSet>
ComputeConditionalExpectationPlotSet(
const dataset::VerticalDataset& dataset, const model::AbstractModel& model,
const std::vector<std::vector<int>>& attribute_idxs, int num_numerical_bins,
float example_sampling) {
const std::vector<std::vector<int>>& attribute_idxs,
const int num_numerical_bins, const float example_sampling,
const absl::optional<float> maximum_duration_seconds) {
LOG(INFO) << "Initiate CEP accumulator";
ASSIGN_OR_RETURN(auto pdp_set,
InitializeConditionalExpectationPlotSet(
Expand All @@ -542,6 +553,11 @@ ComputeConditionalExpectationPlotSet(
std::default_random_engine random;
std::uniform_real_distribution<float> dist_unif_01;

absl::optional<absl::Time> cutoff_time;
if (maximum_duration_seconds.has_value()) {
cutoff_time = absl::Now() + absl::Seconds(maximum_duration_seconds.value());
}

// TODO: Multi-thread.
dataset::proto::Example example;
for (size_t example_idx = 0; example_idx < dataset.nrow(); example_idx++) {
Expand All @@ -550,6 +566,10 @@ ComputeConditionalExpectationPlotSet(
}
if ((example_idx % 100) == 0) {
LOG_EVERY_N_SEC(INFO, 30) << example_idx + 1 << " examples scanned.";
if (cutoff_time.has_value() && absl::Now() > cutoff_time) {
LOG(INFO) << "Maximum duration reached. Interrupting analysis early.";
break;
}
}
dataset.ExtractExample(example_idx, &example);
RETURN_IF_ERROR(
Expand Down
7 changes: 5 additions & 2 deletions yggdrasil_decision_forests/utils/partial_dependence_plot.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "yggdrasil_decision_forests/dataset/data_spec.pb.h"
#include "yggdrasil_decision_forests/dataset/example.pb.h"
#include "yggdrasil_decision_forests/dataset/vertical_dataset.h"
Expand Down Expand Up @@ -78,13 +79,15 @@ absl::Status UpdateConditionalExpectationPlotSet(
absl::StatusOr<PartialDependencePlotSet> ComputePartialDependencePlotSet(
const dataset::VerticalDataset& dataset, const model::AbstractModel& model,
const std::vector<std::vector<int>>& attribute_idxs, int num_numerical_bins,
float example_sampling);
float example_sampling,
absl::optional<float> maximum_duration_seconds = {});

absl::StatusOr<ConditionalExpectationPlotSet>
ComputeConditionalExpectationPlotSet(
const dataset::VerticalDataset& dataset, const model::AbstractModel& model,
const std::vector<std::vector<int>>& attribute_idxs, int num_numerical_bins,
float example_sampling);
float example_sampling,
absl::optional<float> maximum_duration_seconds = {});

// Appends all the "num_dims"-dimensional combinations of input features.
absl::Status AppendAttributesCombinations(
Expand Down

0 comments on commit ec48a7b

Please sign in to comment.