Skip to content

Commit

Permalink
[YDF] [TF-DF] Minor fixes
Browse files Browse the repository at this point in the history
- Pass release_cpu_linux flag to test_bazel to on linux for manylinux compatibility
- Remove the MacOS Tensorflow warning
- Fix an incorrect warning on max_depth =-1 or -2, which are special cases (and always set)

PiperOrigin-RevId: 665904417
  • Loading branch information
rstz authored and copybara-github committed Aug 21, 2024
1 parent 8bf9db0 commit 0d4e307
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,8 @@ absl::Status SetHyperParameters(
const auto hparam = generic_hyper_params->Get(kHParamMaxDepth);
if (hparam.has_value()) {
dt_config->set_max_depth(hparam.value().value().integer());
if (dt_config->max_depth() < 2) {
if (dt_config->max_depth() < 2 && dt_config->max_depth() != -1 &&
dt_config->max_depth() != -2) {
LOG(WARNING)
<< "Setting max_depth=" << dt_config->max_depth()
<< " for a training model will not result in any learning (i.e. "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -984,18 +984,11 @@ def to_tensorflow_function( # pytype: disable=name-error
classification, outputs a tensorflow of shape [num examples, 2]
containing the probability of both the negative and positive classes.
Has no effect on non-binary classification models.
force: Try to export even in currently unsupported environments. WARNING:
Setting this to true may crash the Python runtime.
force: Try to export even in currently unsupported environments.
Returns:
A TensorFlow @tf.function.
"""
if platform.system() == "Darwin" and not force:
raise ValueError(
"Exporting to TensorFlow is currently broken on MacOS and may crash"
" the current Python process. To proceed anyway, add parameter"
" `force=True`."
)

return _get_export_tf().ydf_model_to_tf_function(
ydf_model=self,
Expand Down

0 comments on commit 0d4e307

Please sign in to comment.