diff --git a/yggdrasil_decision_forests/metric/report.cc b/yggdrasil_decision_forests/metric/report.cc index 0838c643..60f8842f 100644 --- a/yggdrasil_decision_forests/metric/report.cc +++ b/yggdrasil_decision_forests/metric/report.cc @@ -578,7 +578,9 @@ absl::Status AppendHtmlReport(const proto::EvaluationResults& eval, h::Html html; - html.Append(h::H1("Evaluation report")); + if (options.include_title) { + html.Append(h::H1("Evaluation report")); + } if (options.include_text_report) { ASSIGN_OR_RETURN(const auto text_report, TextReport(eval)); diff --git a/yggdrasil_decision_forests/metric/report.h b/yggdrasil_decision_forests/metric/report.h index 7d1def8c..25f398d7 100644 --- a/yggdrasil_decision_forests/metric/report.h +++ b/yggdrasil_decision_forests/metric/report.h @@ -46,6 +46,7 @@ absl::Status AppendTextReportUplift(const proto::EvaluationResults& eval, // Add the report in a html format. struct HtmlReportOptions { + bool include_title = true; bool include_text_report = true; // Size of the plots. diff --git a/yggdrasil_decision_forests/port/python/ydf/cc/BUILD b/yggdrasil_decision_forests/port/python/ydf/cc/BUILD index aed870e1..c1c44ae0 100644 --- a/yggdrasil_decision_forests/port/python/ydf/cc/BUILD +++ b/yggdrasil_decision_forests/port/python/ydf/cc/BUILD @@ -14,6 +14,7 @@ pybind_extension( deps = [ "//ydf/dataset:dataset_cc", "//ydf/learner:learner_cc", + "//ydf/metric:metric_cc", "//ydf/model:model_cc", "@com_google_pybind11_abseil//pybind11_abseil:import_status_module", "@com_google_pybind11_abseil//pybind11_abseil:status_casters", diff --git a/yggdrasil_decision_forests/port/python/ydf/cc/ydf.cc b/yggdrasil_decision_forests/port/python/ydf/cc/ydf.cc index cbfe7ff0..fa78bf0f 100644 --- a/yggdrasil_decision_forests/port/python/ydf/cc/ydf.cc +++ b/yggdrasil_decision_forests/port/python/ydf/cc/ydf.cc @@ -20,6 +20,7 @@ #include "pybind11_protobuf/native_proto_caster.h" #include "ydf/dataset/dataset.h" #include "ydf/learner/learner.h" +#include "ydf/metric/metric.h" #include "ydf/model/model.h" namespace py = ::pybind11; @@ -35,6 +36,7 @@ PYBIND11_MODULE(ydf, m) { init_dataset(m); init_model(m); init_learner(m); + init_metric(m); } } // namespace yggdrasil_decision_forests::port::python diff --git a/yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi b/yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi index 2acf81f3..e179abe4 100644 --- a/yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi +++ b/yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi @@ -102,3 +102,16 @@ def GetLearner( deployment_config: abstract_learner_pb2.DeploymentConfig, ) -> GenericCCLearner: ... + +# Metric bindings +# ================ + + +def EvaluationToStr( + evaluation: metric_pb2.EvaluationResults +) -> str: ... + +def EvaluationPlotToHtml( + evaluation: metric_pb2.EvaluationResults +) -> str: ... 
+ diff --git a/yggdrasil_decision_forests/port/python/ydf/metric/BUILD b/yggdrasil_decision_forests/port/python/ydf/metric/BUILD index c1b235f3..cbf0e681 100644 --- a/yggdrasil_decision_forests/port/python/ydf/metric/BUILD +++ b/yggdrasil_decision_forests/port/python/ydf/metric/BUILD @@ -1,4 +1,5 @@ # pytype test and library +load("@pybind11_bazel//:build_defs.bzl", "pybind_library") package( default_visibility = ["//visibility:public"], @@ -15,9 +16,9 @@ py_library( "metric.py", ], deps = [ - # matplotlib dep, # numpy dep, "@ydf_cc//yggdrasil_decision_forests/metric:metric_py_proto", + "//ydf/cc:ydf", "//ydf/dataset:dataspec", "//ydf/utils:documentation", "//ydf/utils:html", @@ -25,6 +26,19 @@ py_library( ], ) +pybind_library( + name = "metric_cc", + srcs = ["metric.cc"], + hdrs = ["metric.h"], + deps = [ + "@com_google_absl//absl/status:statusor", + "@com_google_pybind11_abseil//pybind11_abseil:status_casters", + "@com_google_pybind11_protobuf//pybind11_protobuf:native_proto_caster", + "@ydf_cc//yggdrasil_decision_forests/metric:metric_cc_proto", + "@ydf_cc//yggdrasil_decision_forests/metric:report", + ], +) + # Tests # ===== @@ -39,7 +53,6 @@ py_test( # numpy dep, "@ydf_cc//yggdrasil_decision_forests/dataset:data_spec_py_proto", "@ydf_cc//yggdrasil_decision_forests/metric:metric_py_proto", - "//ydf/utils:test_utils", "@ydf_cc//yggdrasil_decision_forests/utils:distribution_py_proto", ], ) diff --git a/yggdrasil_decision_forests/port/python/ydf/metric/display_metric.py b/yggdrasil_decision_forests/port/python/ydf/metric/display_metric.py index 9db454b8..9e3af62d 100644 --- a/yggdrasil_decision_forests/port/python/ydf/metric/display_metric.py +++ b/yggdrasil_decision_forests/port/python/ydf/metric/display_metric.py @@ -20,9 +20,7 @@ from typing import Any, Optional, Tuple from xml.dom import minidom -# TODO: Add matplotlib as a requirement, or fail. -import matplotlib.pyplot as plt - +from ydf.cc import ydf from ydf.metric import metric from ydf.utils import documentation from ydf.utils import html @@ -243,10 +241,6 @@ def evaluation_to_html_str(e: metric.Evaluation, add_style: bool = True) -> str: documentation_url=documentation.URL_WEIGHTED_NUM_EXAMPLES, ) - # Curves - - # Classification - _object_to_html( doc, html_metric_box, @@ -255,27 +249,9 @@ def evaluation_to_html_str(e: metric.Evaluation, add_style: bool = True) -> str: documentation_url=documentation.URL_CONFUSION_MATRIX, ) - if e.characteristics: - for characteristic in e.characteristics: - # ROC - fig = _plot_roc(characteristic) - image = _fig_to_dom(doc, fig) - _object_to_html( - doc, - html_metric_box, - f"ROC: {characteristic.name} (AUC:{characteristic.roc_auc:g})", - image, - ) - - # PR - fig = _plot_pr(characteristic) - image = _fig_to_dom(doc, fig) - _object_to_html( - doc, - html_metric_box, - f"PR: {characteristic.name} (AUC:{characteristic.roc_auc:g})", - image, - ) + # Curves + plot_html = ydf.EvaluationPlotToHtml(e._evaluation_proto) + _object_to_html(doc, html_metric_box, None, plot_html, raw_html=True) return root.toprettyxml(indent=" ") @@ -437,9 +413,10 @@ def _field_to_html( def _object_to_html( doc: html.Doc, parent, - key: str, + key: Optional[str], value: Any, documentation_url: Optional[str] = None, + raw_html: bool = False, ) -> None: """Friendly html print a "key" and a complex element. @@ -452,6 +429,7 @@ def _object_to_html( key: Name of the field. value: Complex object to display. documentation_url: Url to the documentation of this field. + raw_html: If true, "value" is interpreted as raw html. 
""" if value is None: @@ -472,7 +450,8 @@ def _object_to_html( html_key.appendChild(link) html_key = link - html_key.appendChild(doc.createTextNode(key)) + if key: + html_key.appendChild(doc.createTextNode(key)) html_value = doc.createElement("div") html_value.setAttribute("class", "value") @@ -481,60 +460,25 @@ def _object_to_html( if isinstance(value, minidom.Element): html_value.appendChild(value) else: - html_pre_value = doc.createElement("pre") - html_value.appendChild(html_pre_value) - html_pre_value.appendChild(doc.createTextNode(str_value)) + if raw_html: + node = _RawXMLNode(value, doc) + html_value.appendChild(node) + else: + html_pre_value = doc.createElement("pre") + html_value.appendChild(html_pre_value) + html_pre_value.appendChild(doc.createTextNode(str_value)) + + +class _RawXMLNode(minidom.Node): + # Required by Minidom + nodeType = 1 + def __init__(self, data, parent): + self.data = data + self.ownerDocument = parent -def _plot_roc(characteristic: metric.Characteristic): - """Plots a ROC curve.""" - - with plt.ioff(): - fig, ax = plt.subplots(1, figsize=(4, 4)) - ax.set_xlim(0, 1) - ax.set_ylim(0, 1) - ax.set_box_aspect(1) - ax.plot([0, 1], [0, 1], linestyle="--", color="black", linewidth=0.5) - ax.plot( - characteristic.false_positive_rates, - characteristic.recalls, - color="red", - linewidth=0.5, - ) - ax.set_xlabel("false positive rate") - ax.set_ylabel("true positive rate (recall)") - ax.grid() - fig.tight_layout() - return fig - - -def _plot_pr(characteristic: metric.Characteristic): - """Plots a precision-recall curve.""" - - with plt.ioff(): - fig, ax = plt.subplots(1, figsize=(4, 4)) - ax.set_xlim(0, 1) - ax.set_ylim(0, 1) - ax.set_box_aspect(1) - ax.plot( - characteristic.recalls, - characteristic.precisions, - color="red", - linewidth=0.5, - ) - ax.set_xlabel("recall") - ax.set_ylabel("precision") - ax.grid() - fig.tight_layout() - return fig - - -def _fig_to_dom(doc: html.Doc, fig) -> html.Elem: - """Converts a Matplotlib figure into a Dom object.""" - - tmpfile = io.BytesIO() - fig.savefig(tmpfile, format="png") - encoded = base64.b64encode(tmpfile.getvalue()).decode("utf-8") - image = doc.createElement("img") - image.setAttribute("src", "data:image/png;base64," + encoded) - return image + def writexml(self, writer, indent, addindent, newl): + del indent + del addindent + del newl + writer.write(self.data) diff --git a/yggdrasil_decision_forests/port/python/ydf/metric/metric.cc b/yggdrasil_decision_forests/port/python/ydf/metric/metric.cc new file mode 100644 index 00000000..b0a611d8 --- /dev/null +++ b/yggdrasil_decision_forests/port/python/ydf/metric/metric.cc @@ -0,0 +1,56 @@ +/* + * Copyright 2022 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <pybind11/pybind11.h> + +#include <string> + +#include "absl/status/statusor.h" +#include "pybind11_abseil/status_casters.h" +#include "pybind11_protobuf/native_proto_caster.h" +#include "yggdrasil_decision_forests/metric/metric.pb.h" +#include "yggdrasil_decision_forests/metric/report.h" + +namespace py = ::pybind11; + +namespace yggdrasil_decision_forests::port::python { +namespace { + +absl::StatusOr<std::string> EvaluationToStr( + const metric::proto::EvaluationResults& evaluation) { + return metric::TextReport(evaluation); +} + +absl::StatusOr<std::string> EvaluationPlotToHtml( + const metric::proto::EvaluationResults& evaluation) { + std::string html; + metric::HtmlReportOptions options; + options.plot_width = 500; + options.plot_height = 400; + options.include_text_report = false; + options.include_title = false; + options.num_plots_per_columns = 2; + RETURN_IF_ERROR(metric::AppendHtmlReport(evaluation, &html, options)); + return html; +} + +} // namespace + +void init_metric(py::module_& m) { + m.def("EvaluationToStr", EvaluationToStr, py::arg("evaluation")); + m.def("EvaluationPlotToHtml", EvaluationPlotToHtml, py::arg("evaluation")); +} + +} // namespace yggdrasil_decision_forests::port::python diff --git a/yggdrasil_decision_forests/port/python/ydf/metric/metric.h b/yggdrasil_decision_forests/port/python/ydf/metric/metric.h new file mode 100644 index 00000000..d0966fa6 --- /dev/null +++ b/yggdrasil_decision_forests/port/python/ydf/metric/metric.h @@ -0,0 +1,27 @@ +/* + * Copyright 2022 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef YGGDRASIL_DECISION_FORESTS_PORT_PYTHON_METRIC_METRIC_H_ +#define YGGDRASIL_DECISION_FORESTS_PORT_PYTHON_METRIC_METRIC_H_ + +#include <pybind11/pybind11.h> + +namespace yggdrasil_decision_forests::port::python { + +void init_metric(pybind11::module_ &m); + +} // namespace yggdrasil_decision_forests::port::python + +#endif // YGGDRASIL_DECISION_FORESTS_PORT_PYTHON_METRIC_METRIC_H_ diff --git a/yggdrasil_decision_forests/port/python/ydf/metric/metric.py b/yggdrasil_decision_forests/port/python/ydf/metric/metric.py index c46e19a9..7c166de4 100644 --- a/yggdrasil_decision_forests/port/python/ydf/metric/metric.py +++ b/yggdrasil_decision_forests/port/python/ydf/metric/metric.py @@ -39,7 +39,6 @@ class ConfusionMatrix: See https://developers.google.com/machine-learning/glossary#confusion-matrix - Attributes: classes: The label classes. The number of elements should match the size of `matrix`. @@ -264,26 +263,10 @@ class Evaluation: Basic usage example: ```python - evaluation = ydf.metric.Evaluation() - evaluation["accuracy"] = 0.6 - ``` - - An evaluation can be constructed with constructor arguments: - - ```python - evaluation = ydf.metric.Evaluation(accuracy=0.6, num_examples=10) - ``` - - An evaluation contains properties to easily access common metrics, as well as - checks to make sure metrics are used correctly.
- - ```python - evaluation = ydf.metric.Evaluation() - evaluation.accuracy = 0.6 - - evaluation.accuracy = "hello" - >> Warning: The "accuracy" is generally expected to be a float. Instead got a - str. + evaluation = model.evaluate(test_ds) + print(evaluation) + print(evaluation.accuracy) + evaluation # Html evaluation in notebook ``` Attributes: @@ -296,29 +279,17 @@ class Evaluation: rmse: Root Mean Square Error. Only available for regression task. rmse_ci95_bootstrap: 95% confidence interval of the RMSE computed using bootstrapping. Only available for regression task. - ndcg: - qini: - auuc: + ndcg: Normalized Discounted Cumulative Gain. For Ranking. + qini: For uplifting. + auuc: For uplifting. custom_metrics: User custom metrics dictionary. """ - # Model generic - loss: Optional[float] = None - num_examples: Optional[int] = None - num_examples_weighted: Optional[float] = None - custom_metrics: Dict[str, Any] = dataclasses.field(default_factory=dict) - # Classification - accuracy: Optional[float] = None - confusion_matrix: Optional[ConfusionMatrix] = None - characteristics: Optional[List[Characteristic]] = None - # Regression - rmse: Optional[float] = None - rmse_ci95_bootstrap: Optional[ConfidenceInterval] = None - # Ranking - ndcg: Optional[float] = None - # Uplift - qini: Optional[float] = None - auuc: Optional[float] = None + def __init__( + self, + evaluation_proto: metric_pb2.EvaluationResults, + ): + self._evaluation_proto = evaluation_proto def __str__(self) -> str: """Returns the string representation of an evaluation.""" @@ -334,6 +305,169 @@ def _repr_html_(self) -> str: return display_metric.evaluation_to_html_str(self) + def _get_proto_field_float(self, key: str) -> Optional[float]: + if self._evaluation_proto.HasField(key): + return getattr(self._evaluation_proto, key) + return None + + def _get_proto_field_int(self, key: str) -> Optional[int]: + if self._evaluation_proto.HasField(key): + return getattr(self._evaluation_proto, key) + return None + + @property + def loss(self) -> Optional[float]: + if self._evaluation_proto.HasField("loss_value"): + return self._evaluation_proto.loss_value + + if self._evaluation_proto.HasField("classification"): + clas = self._evaluation_proto.classification + if clas.HasField("sum_log_loss"): + return clas.sum_log_loss / self._evaluation_proto.count_predictions + + return None + + @property + def num_examples(self) -> Optional[float]: + return self._get_proto_field_int("count_predictions_no_weight") + + @property + def num_examples_weighted(self) -> Optional[float]: + return self._get_proto_field_float("count_predictions") + + @property + def custom_metrics(self) -> Dict[str, Any]: + return {k: v for k, v in self._evaluation_proto.user_metrics.items()} + + @property + def accuracy(self) -> Optional[float]: + if self._evaluation_proto.HasField("classification"): + clas = self._evaluation_proto.classification + classes = dataspec.categorical_column_dictionary_to_list( + self._evaluation_proto.label_column + ) + + if clas.HasField("confusion"): + confusion = clas.confusion + assert confusion.nrow == confusion.ncol, "Invalid confusion matrix" + assert confusion.nrow == len(classes), "Invalid confusion matrix" + assert confusion.nrow >= 1, "Invalid confusion matrix" + raw_confusion = np.array(confusion.counts).reshape( + confusion.nrow, confusion.nrow + ) + + return safe_div(np.trace(raw_confusion), np.sum(raw_confusion)) + return None + + @property + def confusion_matrix(self) -> Optional[ConfusionMatrix]: + if 
self._evaluation_proto.HasField("classification"): + clas = self._evaluation_proto.classification + classes = dataspec.categorical_column_dictionary_to_list( + self._evaluation_proto.label_column + ) + classes_wo_oov = classes[_OUT_OF_DICTIONARY_OFFSET:] + + if clas.HasField("confusion"): + confusion = clas.confusion + assert confusion.nrow == confusion.ncol, "Invalid confusion matrix" + assert confusion.nrow == len(classes), "Invalid confusion matrix" + assert confusion.nrow >= 1, "Invalid confusion matrix" + raw_confusion = np.array(confusion.counts).reshape( + confusion.nrow, confusion.nrow + ) + + return ConfusionMatrix( + classes=tuple(classes_wo_oov), + matrix=raw_confusion[ + _OUT_OF_DICTIONARY_OFFSET:, _OUT_OF_DICTIONARY_OFFSET: + ], + ) + return None + + @property + def characteristics(self) -> Optional[List[Characteristic]]: + if self._evaluation_proto.HasField("classification"): + clas = self._evaluation_proto.classification + classes = dataspec.categorical_column_dictionary_to_list( + self._evaluation_proto.label_column + ) + if clas.rocs: + characteristics = [] + for roc_idx, roc in enumerate(clas.rocs): + if roc_idx == 0: + # Skip the OOV item + continue + if roc_idx == 1 and len(clas.rocs) == 3: + # In case of binary classification, skip the negative class + continue + name = f"'{classes[roc_idx]}' vs others" + characteristics.append( + Characteristic( + name=name, + roc_auc=roc.auc, + pr_auc=roc.pr_auc, + per_threshold=[ + CharacteristicPerThreshold( + true_positive=x.tp, + false_positive=x.fp, + true_negative=x.tn, + false_negative=x.fn, + threshold=x.threshold, + ) + for x in roc.curve + ], + ) + ) + return characteristics + return None + + @property + def rmse(self) -> Optional[float]: + if self._evaluation_proto.HasField("regression"): + reg = self._evaluation_proto.regression + if reg.HasField("sum_square_error"): + return math.sqrt( + safe_div( + reg.sum_square_error, self._evaluation_proto.count_predictions + ) + ) + return None + + @property + def rmse_ci95_bootstrap(self) -> Optional[ConfidenceInterval]: + if self._evaluation_proto.HasField("regression"): + reg = self._evaluation_proto.regression + if reg.HasField("bootstrap_rmse_lower_bounds_95p") and reg.HasField( + "bootstrap_rmse_upper_bounds_95p" + ): + return ( + reg.bootstrap_rmse_lower_bounds_95p, + reg.bootstrap_rmse_upper_bounds_95p, + ) + return None + + @property + def ndcg(self) -> Optional[float]: + if self._evaluation_proto.HasField("ranking"): + rank = self._evaluation_proto.ranking + if rank.HasField("ndcg"): + return rank.ndcg.value + + @property + def qini(self) -> Optional[float]: + if self._evaluation_proto.HasField("uplift"): + uplift = self._evaluation_proto.uplift + if uplift.HasField("qini"): + return uplift.qini + + @property + def auuc(self) -> Optional[float]: + if self._evaluation_proto.HasField("uplift"): + uplift = self._evaluation_proto.uplift + if uplift.HasField("auuc"): + return uplift.auuc + def to_dict(self) -> Dict[str, Any]: """Metrics in a dictionary.""" @@ -364,133 +498,6 @@ def add_item(key, value): return output -def evaluation_proto_to_evaluation( - src: metric_pb2.EvaluationResults, -) -> Evaluation: - """Converts an evaluation from proto to python wrapper format. - - This function does not copy all the fields from the input evaluation proto. - Instead, only metrics targeted as PYDF general users are exported. For - instance, prediction samples are not exported. - - Currently, this function does not export characteristics (e.g. ROC curve) and - confidence bounds. 
- - Metrics related to the out-of-dictionary (OOD) item in classification label - column are not reported. - - Args: - src: Evaluation in proto format. - - Returns: - Evaluation object. - """ - - evaluation = Evaluation() - - if src.HasField("count_predictions_no_weight"): - evaluation.num_examples = src.count_predictions_no_weight - - if src.HasField("count_predictions"): - evaluation.num_examples_weighted = src.count_predictions - - if src.HasField("loss_value"): - evaluation.loss = src.loss_value - - if src.HasField("classification"): - classes = dataspec.categorical_column_dictionary_to_list(src.label_column) - classes_wo_oov = classes[_OUT_OF_DICTIONARY_OFFSET:] - - if src.classification.HasField("confusion"): - confusion = src.classification.confusion - assert confusion.nrow == confusion.ncol, "Invalid confusion matrix" - assert confusion.nrow == len(classes), "Invalid confusion matrix" - assert confusion.nrow >= 1, "Invalid confusion matrix" - raw_confusion = np.array(confusion.counts).reshape( - confusion.nrow, confusion.nrow - ) - - evaluation.accuracy = safe_div( - np.trace(raw_confusion), np.sum(raw_confusion) - ) - - evaluation.confusion_matrix = ConfusionMatrix( - classes=tuple(classes_wo_oov), - matrix=raw_confusion[ - _OUT_OF_DICTIONARY_OFFSET:, _OUT_OF_DICTIONARY_OFFSET: - ], - ) - - if src.classification.rocs: - characteristics = [] - for roc_idx, roc in enumerate(src.classification.rocs): - if roc_idx == 0: - # Skip the OOV item - continue - if roc_idx == 1 and len(src.classification.rocs) == 3: - # In case of binary classification, skip the negative class - continue - name = f"'{classes[roc_idx]}' vs others" - characteristics.append( - Characteristic( - name=name, - roc_auc=roc.auc, - pr_auc=roc.pr_auc, - per_threshold=[ - CharacteristicPerThreshold( - true_positive=x.tp, - false_positive=x.fp, - true_negative=x.tn, - false_negative=x.fn, - threshold=x.threshold, - ) - for x in roc.curve - ], - ) - ) - evaluation.characteristics = characteristics - - if "loss" not in evaluation.to_dict() and src.classification.HasField( - "sum_log_loss" - ): - evaluation.loss = src.classification.sum_log_loss / src.count_predictions - - if src.HasField("regression"): - reg = src.regression - if reg.HasField("sum_square_error"): - # Note: The RMSE is not the empirical variance of the error i.e., there is - # not corrective term to the denominator. This implementation is similar - # to the ones in sciket-learn, tensorflow and ydf cc. - evaluation.rmse = math.sqrt( - safe_div(reg.sum_square_error, src.count_predictions) - ) - - if reg.HasField("bootstrap_rmse_lower_bounds_95p") and reg.HasField( - "bootstrap_rmse_upper_bounds_95p" - ): - evaluation.rmse_ci95_bootstrap = ( - reg.bootstrap_rmse_lower_bounds_95p, - reg.bootstrap_rmse_upper_bounds_95p, - ) - - if src.HasField("ranking"): - rank = src.ranking - if rank.HasField("ndcg"): - evaluation.ndcg = rank.ndcg.value - - if src.HasField("uplift"): - uplift = src.uplift - if uplift.HasField("qini"): - evaluation.qini = uplift.qini - if uplift.HasField("auuc"): - evaluation.auuc = uplift.auuc - - for k, v in src.user_metrics.items(): - evaluation.custom_metrics[k] = v - - return evaluation - - def safe_div(a: float, b: float) -> float: """Returns a/b. If a==b==0, returns 0. 
diff --git a/yggdrasil_decision_forests/port/python/ydf/metric/metric_test.py b/yggdrasil_decision_forests/port/python/ydf/metric/metric_test.py index cb0facfe..0bc85b58 100644 --- a/yggdrasil_decision_forests/port/python/ydf/metric/metric_test.py +++ b/yggdrasil_decision_forests/port/python/ydf/metric/metric_test.py @@ -14,7 +14,6 @@ """Testing Metrics.""" -import os import textwrap from absl.testing import absltest @@ -24,102 +23,9 @@ from yggdrasil_decision_forests.dataset import data_spec_pb2 as ds_pb from yggdrasil_decision_forests.metric import metric_pb2 from ydf.metric import metric -from ydf.utils import test_utils from yggdrasil_decision_forests.utils import distribution_pb2 -class EvaluationTest(absltest.TestCase): - - def test_no_metrics(self): - e = metric.Evaluation() - self.assertEqual(str(e), "No metrics") - self.assertEqual(e.to_dict(), {}) - self.assertIsNone(e.accuracy) - - def test_set_and_get(self): - e = metric.Evaluation() - e.accuracy = 0.6 - self.assertEqual(e.accuracy, 0.6) - self.assertEqual(e.to_dict()["accuracy"], 0.6) - - e.num_examples = 50 - self.assertEqual(e.accuracy, 0.6) - self.assertEqual(e.num_examples, 50) - - def test_str(self): - e = metric.Evaluation(accuracy=0.6) - self.assertEqual(str(e), "accuracy: 0.6\n") - - e.num_examples = 50 - self.assertEqual( - str(e), - """accuracy: 0.6 -num examples: 50 -""", - ) - - e.custom_metrics["my_complex_metric"] = "hello\nworld" - self.assertEqual( - str(e), - """accuracy: 0.6 -my_complex_metric: - hello - world -num examples: 50 -""", - ) - - def test_all_metrics(self): - e = metric.Evaluation() - e.loss = 0.1 - e.num_examples = 10 - e.accuracy = 0.2 - e.confusion_matrix = metric.ConfusionMatrix( - classes=("a",), matrix=np.array([[1]]) - ) - e.rmse = 0.3 - e.rmse_ci95_bootstrap = (0.1, 0.4) - e.ndcg = 0.4 - e.qini = 0.5 - e.auuc = 0.6 - e.num_examples_weighted = 0.7 - - print(str(e)) - - self.assertEqual( - str(e), - textwrap.dedent("""\ - accuracy: 0.2 - confusion matrix: - label (row) \\ prediction (col) - +---+---+ - | | a | - +---+---+ - | a | 1 | - +---+---+ - RMSE: 0.3 - RMSE 95% CI [B]: (0.1, 0.4) - NDCG: 0.4 - QINI: 0.5 - AUUC: 0.6 - loss: 0.1 - num examples: 10 - num examples (weighted): 0.7 - """), - ) - - test_utils.golden_check_string( - self, - e._repr_html_(), - os.path.join( - test_utils.pydf_test_data_path(), - "golden", - "display_metric_to_html.html.expected", - ), - postfix=".html", - ) - - class ConfusionTest(absltest.TestCase): def test_str(self): @@ -198,15 +104,13 @@ def test_base(self): ) -class EvaluationProtoTest(absltest.TestCase): +class EvaluationTest(absltest.TestCase): - def test_convert_empty(self): + def test_empty(self): proto_eval = metric_pb2.EvaluationResults() - self.assertEqual( - metric.evaluation_proto_to_evaluation(proto_eval).to_dict(), {} - ) + self.assertEqual(metric.Evaluation(proto_eval).to_dict(), {}) - def test_convert_classification(self): + def test_classification(self): proto_eval = metric_pb2.EvaluationResults( count_predictions_no_weight=1, count_predictions=1, @@ -224,10 +128,28 @@ def test_convert_classification(self): ncol=3, ), sum_log_loss=2, + rocs=[ + metric_pb2.Roc(), + metric_pb2.Roc(), + metric_pb2.Roc( + count_predictions=10, + auc=0.8, + pr_auc=0.7, + curve=[ + metric_pb2.Roc.Point( + threshold=1, tp=2, fp=3, tn=4, fn=6 + ), + metric_pb2.Roc.Point( + threshold=2, tp=1, fp=2, tn=3, fn=4 + ), + ], + ), + ], ), ) - print(metric.evaluation_proto_to_evaluation(proto_eval)) - dict_eval = metric.evaluation_proto_to_evaluation(proto_eval).to_dict() + 
evaluation = metric.Evaluation(proto_eval) + print(evaluation) + dict_eval = evaluation.to_dict() self.assertDictContainsSubset( {"accuracy": (1 + 4) / (1 + 2 + 3 + 4), "loss": 2.0, "num_examples": 1}, dict_eval, @@ -238,7 +160,31 @@ def test_convert_classification(self): dict_eval["confusion_matrix"].matrix, [[1, 2], [3, 4]] ) - def test_convert_regression(self): + self.assertEqual( + str(evaluation), + textwrap.dedent("""\ + accuracy: 0.5 + confusion matrix: + label (row) \\ prediction (col) + +---+---+---+ + | | 1 | 2 | + +---+---+---+ + | 1 | 1 | 2 | + +---+---+---+ + | 2 | 3 | 4 | + +---+---+---+ + characteristics: + name: '2' vs others + ROC AUC: 0.8 + PR AUC: 0.7 + Num thresholds: 2 + loss: 2 + num examples: 1 + num examples (weighted): 1 + """), + ) + + def test_regression(self): proto_eval = metric_pb2.EvaluationResults( count_predictions_no_weight=1, loss_value=2, @@ -250,8 +196,10 @@ def test_convert_regression(self): bootstrap_rmse_upper_bounds_95p=10, ), ) + evaluation = metric.Evaluation(proto_eval) + print(evaluation) self.assertDictEqual( - metric.evaluation_proto_to_evaluation(proto_eval).to_dict(), + evaluation.to_dict(), { "loss": 2.0, "num_examples": 1, @@ -261,7 +209,18 @@ def test_convert_regression(self): }, ) - def test_convert_ranking(self): + self.assertEqual( + str(evaluation), + textwrap.dedent("""\ + RMSE: 2 + RMSE 95% CI [B]: (9.0, 10.0) + loss: 2 + num examples: 1 + num examples (weighted): 2 + """), + ) + + def test_ranking(self): proto_eval = metric_pb2.EvaluationResults( count_predictions_no_weight=1, loss_value=2, @@ -271,8 +230,10 @@ def test_convert_ranking(self): ndcg=metric_pb2.MetricEstimate(value=5) ), ) + evaluation = metric.Evaluation(proto_eval) + print(evaluation) self.assertDictEqual( - metric.evaluation_proto_to_evaluation(proto_eval).to_dict(), + evaluation.to_dict(), { "loss": 2.0, "ndcg": 5.0, @@ -281,7 +242,17 @@ def test_convert_ranking(self): }, ) - def test_convert_uplift(self): + self.assertEqual( + str(evaluation), + textwrap.dedent("""\ + NDCG: 5 + loss: 2 + num examples: 1 + num examples (weighted): 3 + """), + ) + + def test_uplift(self): proto_eval = metric_pb2.EvaluationResults( count_predictions_no_weight=1, loss_value=2, @@ -289,8 +260,10 @@ def test_convert_uplift(self): label_column=ds_pb.Column(name="my_label"), uplift=metric_pb2.EvaluationResults.Uplift(qini=6, auuc=7), ) + evaluation = metric.Evaluation(proto_eval) + print(evaluation) self.assertDictEqual( - metric.evaluation_proto_to_evaluation(proto_eval).to_dict(), + evaluation.to_dict(), { "auuc": 7.0, "loss": 2.0, @@ -300,6 +273,17 @@ def test_convert_uplift(self): }, ) + self.assertEqual( + str(evaluation), + textwrap.dedent("""\ + QINI: 6 + AUUC: 7 + loss: 2 + num examples: 1 + num examples (weighted): 3 + """), + ) + if __name__ == "__main__": absltest.main() diff --git a/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py b/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py index 03e3180f..846a6b1f 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py +++ b/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py @@ -176,7 +176,7 @@ def evaluate( task=self._model.task(), ) evaluation_proto = self._model.Evaluate(ds._dataset, options_proto) # pylint: disable=protected-access - return metric.evaluation_proto_to_evaluation(evaluation_proto) + return metric.Evaluation(evaluation_proto) def analyze( self, diff --git a/yggdrasil_decision_forests/port/python/ydf/model/model_test.py 
b/yggdrasil_decision_forests/port/python/ydf/model/model_test.py index 17872160..dd1b75d3 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/model_test.py +++ b/yggdrasil_decision_forests/port/python/ydf/model/model_test.py @@ -122,6 +122,9 @@ def test_evaluate_adult_gbt(self): """), ) + # with open("/tmp/evaluation.html", "w") as f: + # f.write(evaluation._repr_html_()) + def test_analize_adult_gbt(self): model_path = os.path.join( ydf_test_data_path(), "model", "adult_binary_class_gbdt" diff --git a/yggdrasil_decision_forests/port/python/ydf/test_data/golden/display_metric_to_html.html.expected b/yggdrasil_decision_forests/port/python/ydf/test_data/golden/display_metric_to_html.html.expected deleted file mode 100644 index 1a770756..00000000 --- a/yggdrasil_decision_forests/port/python/ydf/test_data/golden/display_metric_to_html.html.expected +++ /dev/null @@ -1,115 +0,0 @@ -
-[115 deleted lines of golden HTML: the rendered evaluation with accuracy 0.2, RMSE 0.3, RMSE 95% CI [B] (0.1, 0.4), NDCG 0.4, QINI 0.5, AUUC 0.6, loss 0.1, 10 examples (0.7 weighted), and a 1x1 confusion matrix for class "a"]
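
For context beyond the diff itself, here is a minimal sketch of how the reworked evaluation path is meant to be used from Python after this change. The `model.evaluate(...)` call, the `Evaluation` properties, and `_repr_html_()` are taken from the diff; the learner construction and CSV paths are illustrative assumptions, not part of this patch.

```python
import pandas as pd
import ydf  # PYDF package built from this repository

# Hypothetical dataset and model setup (not part of this diff).
train_ds = pd.read_csv("train.csv")
test_ds = pd.read_csv("test.csv")
model = ydf.GradientBoostedTreesLearner(label="label").train(train_ds)

# evaluate() now returns a metric.Evaluation that wraps the C++
# EvaluationResults proto instead of copying it field by field.
evaluation = model.evaluate(test_ds)

print(evaluation)           # text report
print(evaluation.accuracy)  # metrics are read lazily from the proto

# In a notebook, the HTML representation embeds the ROC / PR plots
# produced by the new EvaluationPlotToHtml binding (no matplotlib needed).
html = evaluation._repr_html_()
```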