Skip to content

Commit

Permalink
Add option to hide the title when plotting evaluation reports.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571061008
  • Loading branch information
achoum authored and copybara-github committed Oct 5, 2023
1 parent b4f039f commit 0cd75ba
Show file tree
Hide file tree
Showing 14 changed files with 418 additions and 480 deletions.
4 changes: 3 additions & 1 deletion yggdrasil_decision_forests/metric/report.cc
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,9 @@ absl::Status AppendHtmlReport(const proto::EvaluationResults& eval,

h::Html html;

html.Append(h::H1("Evaluation report"));
if (options.include_title) {
html.Append(h::H1("Evaluation report"));
}

if (options.include_text_report) {
ASSIGN_OR_RETURN(const auto text_report, TextReport(eval));
Expand Down
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/metric/report.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ absl::Status AppendTextReportUplift(const proto::EvaluationResults& eval,

// Add the report in a html format.
struct HtmlReportOptions {
bool include_title = true;
bool include_text_report = true;

// Size of the plots.
Expand Down
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/port/python/ydf/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pybind_extension(
deps = [
"//ydf/dataset:dataset_cc",
"//ydf/learner:learner_cc",
"//ydf/metric:metric_cc",
"//ydf/model:model_cc",
"@com_google_pybind11_abseil//pybind11_abseil:import_status_module",
"@com_google_pybind11_abseil//pybind11_abseil:status_casters",
Expand Down
2 changes: 2 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/cc/ydf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "pybind11_protobuf/native_proto_caster.h"
#include "ydf/dataset/dataset.h"
#include "ydf/learner/learner.h"
#include "ydf/metric/metric.h"
#include "ydf/model/model.h"

namespace py = ::pybind11;
Expand All @@ -35,6 +36,7 @@ PYBIND11_MODULE(ydf, m) {
init_dataset(m);
init_model(m);
init_learner(m);
init_metric(m);
}

} // namespace yggdrasil_decision_forests::port::python
13 changes: 13 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,16 @@ def GetLearner(
deployment_config: abstract_learner_pb2.DeploymentConfig,
) -> GenericCCLearner: ...


# Metric bindings
# ================


def EvaluationToStr(
evaluation: metric_pb2.EvaluationResults
) -> str: ...

def EvaluationPlotToHtml(
evaluation: metric_pb2.EvaluationResults
) -> str: ...

17 changes: 15 additions & 2 deletions yggdrasil_decision_forests/port/python/ydf/metric/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# pytype test and library
load("@pybind11_bazel//:build_defs.bzl", "pybind_library")

package(
default_visibility = ["//visibility:public"],
Expand All @@ -15,16 +16,29 @@ py_library(
"metric.py",
],
deps = [
# matplotlib dep,
# numpy dep,
"@ydf_cc//yggdrasil_decision_forests/metric:metric_py_proto",
"//ydf/cc:ydf",
"//ydf/dataset:dataspec",
"//ydf/utils:documentation",
"//ydf/utils:html",
"//ydf/utils:string_lib",
],
)

pybind_library(
name = "metric_cc",
srcs = ["metric.cc"],
hdrs = ["metric.h"],
deps = [
"@com_google_absl//absl/status:statusor",
"@com_google_pybind11_abseil//pybind11_abseil:status_casters",
"@com_google_pybind11_protobuf//pybind11_protobuf:native_proto_caster",
"@ydf_cc//yggdrasil_decision_forests/metric:metric_cc_proto",
"@ydf_cc//yggdrasil_decision_forests/metric:report",
],
)

# Tests
# =====

Expand All @@ -39,7 +53,6 @@ py_test(
# numpy dep,
"@ydf_cc//yggdrasil_decision_forests/dataset:data_spec_py_proto",
"@ydf_cc//yggdrasil_decision_forests/metric:metric_py_proto",
"//ydf/utils:test_utils",
"@ydf_cc//yggdrasil_decision_forests/utils:distribution_py_proto",
],
)
114 changes: 29 additions & 85 deletions yggdrasil_decision_forests/port/python/ydf/metric/display_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
from typing import Any, Optional, Tuple
from xml.dom import minidom

# TODO: Add matplotlib as a requirement, or fail.
import matplotlib.pyplot as plt

from ydf.cc import ydf
from ydf.metric import metric
from ydf.utils import documentation
from ydf.utils import html
Expand Down Expand Up @@ -243,10 +241,6 @@ def evaluation_to_html_str(e: metric.Evaluation, add_style: bool = True) -> str:
documentation_url=documentation.URL_WEIGHTED_NUM_EXAMPLES,
)

# Curves

# Classification

_object_to_html(
doc,
html_metric_box,
Expand All @@ -255,27 +249,9 @@ def evaluation_to_html_str(e: metric.Evaluation, add_style: bool = True) -> str:
documentation_url=documentation.URL_CONFUSION_MATRIX,
)

if e.characteristics:
for characteristic in e.characteristics:
# ROC
fig = _plot_roc(characteristic)
image = _fig_to_dom(doc, fig)
_object_to_html(
doc,
html_metric_box,
f"ROC: {characteristic.name} (AUC:{characteristic.roc_auc:g})",
image,
)

# PR
fig = _plot_pr(characteristic)
image = _fig_to_dom(doc, fig)
_object_to_html(
doc,
html_metric_box,
f"PR: {characteristic.name} (AUC:{characteristic.roc_auc:g})",
image,
)
# Curves
plot_html = ydf.EvaluationPlotToHtml(e._evaluation_proto)
_object_to_html(doc, html_metric_box, None, plot_html, raw_html=True)

return root.toprettyxml(indent=" ")

Expand Down Expand Up @@ -437,9 +413,10 @@ def _field_to_html(
def _object_to_html(
doc: html.Doc,
parent,
key: str,
key: Optional[str],
value: Any,
documentation_url: Optional[str] = None,
raw_html: bool = False,
) -> None:
"""Friendly html print a "key" and a complex element.
Expand All @@ -452,6 +429,7 @@ def _object_to_html(
key: Name of the field.
value: Complex object to display.
documentation_url: Url to the documentation of this field.
raw_html: If true, "value" is interpreted as raw html.
"""

if value is None:
Expand All @@ -472,7 +450,8 @@ def _object_to_html(
html_key.appendChild(link)
html_key = link

html_key.appendChild(doc.createTextNode(key))
if key:
html_key.appendChild(doc.createTextNode(key))

html_value = doc.createElement("div")
html_value.setAttribute("class", "value")
Expand All @@ -481,60 +460,25 @@ def _object_to_html(
if isinstance(value, minidom.Element):
html_value.appendChild(value)
else:
html_pre_value = doc.createElement("pre")
html_value.appendChild(html_pre_value)
html_pre_value.appendChild(doc.createTextNode(str_value))
if raw_html:
node = _RawXMLNode(value, doc)
html_value.appendChild(node)
else:
html_pre_value = doc.createElement("pre")
html_value.appendChild(html_pre_value)
html_pre_value.appendChild(doc.createTextNode(str_value))


class _RawXMLNode(minidom.Node):
# Required by Minidom
nodeType = 1

def __init__(self, data, parent):
self.data = data
self.ownerDocument = parent

def _plot_roc(characteristic: metric.Characteristic):
"""Plots a ROC curve."""

with plt.ioff():
fig, ax = plt.subplots(1, figsize=(4, 4))
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_box_aspect(1)
ax.plot([0, 1], [0, 1], linestyle="--", color="black", linewidth=0.5)
ax.plot(
characteristic.false_positive_rates,
characteristic.recalls,
color="red",
linewidth=0.5,
)
ax.set_xlabel("false positive rate")
ax.set_ylabel("true positive rate (recall)")
ax.grid()
fig.tight_layout()
return fig


def _plot_pr(characteristic: metric.Characteristic):
"""Plots a precision-recall curve."""

with plt.ioff():
fig, ax = plt.subplots(1, figsize=(4, 4))
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_box_aspect(1)
ax.plot(
characteristic.recalls,
characteristic.precisions,
color="red",
linewidth=0.5,
)
ax.set_xlabel("recall")
ax.set_ylabel("precision")
ax.grid()
fig.tight_layout()
return fig


def _fig_to_dom(doc: html.Doc, fig) -> html.Elem:
"""Converts a Matplotlib figure into a Dom object."""

tmpfile = io.BytesIO()
fig.savefig(tmpfile, format="png")
encoded = base64.b64encode(tmpfile.getvalue()).decode("utf-8")
image = doc.createElement("img")
image.setAttribute("src", "data:image/png;base64," + encoded)
return image
def writexml(self, writer, indent, addindent, newl):
del indent
del addindent
del newl
writer.write(self.data)
56 changes: 56 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/metric/metric.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright 2022 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <pybind11/pybind11.h>

#include <string>

#include "absl/status/statusor.h"
#include "pybind11_abseil/status_casters.h"
#include "pybind11_protobuf/native_proto_caster.h"
#include "yggdrasil_decision_forests/metric/metric.pb.h"
#include "yggdrasil_decision_forests/metric/report.h"

namespace py = ::pybind11;

namespace yggdrasil_decision_forests::port::python {
namespace {

absl::StatusOr<std::string> EvaluationToStr(
const metric::proto::EvaluationResults& evaluation) {
return metric::TextReport(evaluation);
}

absl::StatusOr<std::string> EvaluationPlotToHtml(
const metric::proto::EvaluationResults& evaluation) {
std::string html;
metric::HtmlReportOptions options;
options.plot_width = 500;
options.plot_height = 400;
options.include_text_report = false;
options.include_title = false;
options.num_plots_per_columns = 2;
RETURN_IF_ERROR(metric::AppendHtmlReport(evaluation, &html, options));
return html;
}

} // namespace

void init_metric(py::module_& m) {
m.def("EvaluationToStr", EvaluationToStr, py::arg("evaluation"));
m.def("EvaluationPlotToHtml", EvaluationPlotToHtml, py::arg("evaluation"));
}

} // namespace yggdrasil_decision_forests::port::python
27 changes: 27 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/metric/metric.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright 2022 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef YGGDRASIL_DECISION_FORESTS_PORT_PYTHON_METRIC_METRIC_H_
#define YGGDRASIL_DECISION_FORESTS_PORT_PYTHON_METRIC_METRIC_H_

#include <pybind11/pybind11.h>

namespace yggdrasil_decision_forests::port::python {

void init_metric(pybind11::module_ &m);

} // namespace yggdrasil_decision_forests::port::python

#endif // YGGDRASIL_DECISION_FORESTS_PORT_PYTHON_METRIC_METRIC_H_
Loading

0 comments on commit 0cd75ba

Please sign in to comment.