From 0d4e3073b264738634c27ff474d2ceb6a9572712 Mon Sep 17 00:00:00 2001 From: Richard Stotz Date: Wed, 21 Aug 2024 09:07:42 -0700 Subject: [PATCH] [YDF] [TF-DF] Minor fixes - Pass release_cpu_linux flag to test_bazel to on linux for manylinux compatibility - Remove the MacOS Tensorflow warning - Fix an incorrect warning on max_depth =-1 or -2, which are special cases (and always set) PiperOrigin-RevId: 665904417 --- .../learner/decision_tree/generic_parameters.cc | 3 ++- .../port/python/ydf/model/generic_model.py | 9 +-------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.cc b/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.cc index 4a8d9b5d..d339b6eb 100644 --- a/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.cc +++ b/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.cc @@ -460,7 +460,8 @@ absl::Status SetHyperParameters( const auto hparam = generic_hyper_params->Get(kHParamMaxDepth); if (hparam.has_value()) { dt_config->set_max_depth(hparam.value().value().integer()); - if (dt_config->max_depth() < 2) { + if (dt_config->max_depth() < 2 && dt_config->max_depth() != -1 && + dt_config->max_depth() != -2) { LOG(WARNING) << "Setting max_depth=" << dt_config->max_depth() << " for a training model will not result in any learning (i.e. " diff --git a/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py b/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py index 558b6e2b..702c8c47 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py +++ b/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py @@ -984,18 +984,11 @@ def to_tensorflow_function( # pytype: disable=name-error classification, outputs a tensorflow of shape [num examples, 2] containing the probability of both the negative and positive classes. Has no effect on non-binary classification models. - force: Try to export even in currently unsupported environments. WARNING: - Setting this to true may crash the Python runtime. + force: Try to export even in currently unsupported environments. Returns: A TensorFlow @tf.function. """ - if platform.system() == "Darwin" and not force: - raise ValueError( - "Exporting to TensorFlow is currently broken on MacOS and may crash" - " the current Python process. To proceed anyway, add parameter" - " `force=True`." - ) return _get_export_tf().ydf_model_to_tf_function( ydf_model=self,