From 7412704bebb128e4b86e19af9c709cdc5c6407a5 Mon Sep 17 00:00:00 2001
From: Vincent Dumoulin
Date: Tue, 28 Mar 2023 11:58:26 -0700
Subject: [PATCH] Refactor CMAP into RankBasedMetrics and add generalized mean
 rank metric

PiperOrigin-RevId: 520093902
---
 chirp/models/metrics.py                       | 45 ++++++++++++++++
 .../models/{cmap.py => rank_based_metrics.py} | 36 +++++++++----
 chirp/projects/multicluster/classify.py       |  6 +--
 chirp/tests/metrics_test.py                   | 52 ++++++++++++++-----
 chirp/train/classifier.py                     | 10 ++--
 chirp/train/hubert.py                         |  6 ++-
 chirp/train/separator.py                      | 16 ++----
 7 files changed, 130 insertions(+), 41 deletions(-)
 rename chirp/models/{cmap.py => rank_based_metrics.py} (62%)

diff --git a/chirp/models/metrics.py b/chirp/models/metrics.py
index 3fba55af..37ac2f0a 100644
--- a/chirp/models/metrics.py
+++ b/chirp/models/metrics.py
@@ -185,6 +185,51 @@ def average_precision(
   return mask * raw_av_prec
 
 
+def generalized_mean_rank(
+    scores: jnp.ndarray,
+    labels: jnp.ndarray,
+    label_mask: jnp.ndarray | None = None,
+    sort_descending: bool = True,
+) -> tuple[jnp.ndarray, jnp.ndarray]:
+  """Computes the generalized mean rank and its variance over the last axis.
+
+  The generalized mean rank can be expressed as
+
+    (sum_i #FP ranked above TP_i) / (#FP * #TP).
+
+  We treat all labels as either true positives (if the label is 1) or false
+  positives (if the label is zero).
+
+  Args:
+    scores: A score for each label which can be ranked.
+    labels: A multi-hot encoding of the ground truth positives. Must match the
+      shape of scores.
+    label_mask: A mask indicating which labels to involve in the calculation.
+    sort_descending: If True, ranks scores in descending order (e.g. for
+      similarity metrics, where higher scores are preferred). If False, ranks
+      scores in ascending order (e.g. for distance metrics, where lower scores
+      are preferred).
+
+  Returns:
+    The generalized mean rank and its variance.
+  """
+  # TODO(vdumoulin): add support for `label_mask`.
+  if label_mask is not None:
+    raise NotImplementedError
+
+  idx = jnp.argsort(scores, axis=-1)
+  if sort_descending:
+    idx = jnp.flip(idx, axis=-1)
+  labels = jnp.take_along_axis(labels, idx, axis=-1)
+
+  num_fp = (labels == 0).sum(axis=-1)
+  num_fp_above = jnp.cumsum(labels == 0, axis=-1)
+
+  gmr = num_fp_above.mean(axis=-1, where=(labels > 0)) / num_fp
+  gmr_var = num_fp_above.var(axis=-1, where=(labels > 0)) / num_fp
+  return gmr, gmr_var
+
+
 def least_squares_solve_mix(matrix, rhs, diag_loading=1e-3):
   # Assumes a real-valued matrix, with zero mean.
   adj_matrix = jnp.conjugate(jnp.swapaxes(matrix, -1, -2))
diff --git a/chirp/models/cmap.py b/chirp/models/rank_based_metrics.py
similarity index 62%
rename from chirp/models/cmap.py
rename to chirp/models/rank_based_metrics.py
index 15964eaf..09f2c8c8 100644
--- a/chirp/models/cmap.py
+++ b/chirp/models/rank_based_metrics.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Metric for Class Mean Average Precision (CMAP)."""
+"""Rank-based metrics, including cmAP and generalized mean rank."""
 from chirp.models import metrics
 from clu import metrics as clu_metrics
 import flax.struct
@@ -22,16 +22,19 @@
 
 
 @flax.struct.dataclass
-class CMAP(
+class RankBasedMetrics(
     # TODO(bartvm): Create class factory for calculating over different outputs
     clu_metrics.CollectingMetric.from_outputs(("label", "label_logits"))
 ):
-  """(Class-wise) mean average precision.
+  """(Class-wise) rank-based metrics.
 
   This metric calculates the average precision score of each class, and also
   returns the average of those values. This is sometimes referred to as the
   macro-averaged average precision, or the class-wise mean average precision
   (CmAP).
+
+  It also calculates the generalized mean rank (GMR) for each class and
+  summarizes these values as 1 - geometric_mean(1 - GMR) over classes.
   """
 
   def compute(self, sample_threshold: int = 0):
@@ -40,28 +43,43 @@ def compute(self, sample_threshold: int = 0):
     with jax.default_device(jax.devices("cpu")[0]):
       mask = jnp.sum(values["label"], axis=0) > sample_threshold
       if jnp.sum(mask) == 0:
-        return {"macro": 0.0}
+        return {"macro_cmap": 0.0, "macro_gmr": 0.0}
       # Same as sklearn's average_precision_score(label, logits, average=None)
       # but that implementation doesn't scale to 10k+ classes
       class_aps = metrics.average_precision(
           values["label_logits"].T, values["label"].T
       )
       class_aps = jnp.where(mask, class_aps, jnp.nan)
+
+      class_gmr, class_gmr_var = metrics.generalized_mean_rank(
+          values["label_logits"].T, values["label"].T
+      )
+      class_gmr = jnp.where(mask, class_gmr, jnp.nan)
+      class_gmr_var = jnp.where(mask, class_gmr_var, jnp.nan)
+
       return {
-          "macro": jnp.mean(class_aps, where=mask),
-          "individual": class_aps,
+          "macro_cmap": jnp.mean(class_aps, where=mask),
+          "individual_cmap": class_aps,
+          # If the GMR is 0.0 for at least one class, then the geometric average
+          # goes to zero. Instead, we take the geometric average of 1 - GMR and
+          # then take 1 - geometric_average.
+          "macro_gmr": 1.0 - jnp.exp(
+              jnp.mean(jnp.log(1.0 - class_gmr), where=mask)
+          ),
+          "individual_gmr": class_gmr,
+          "individual_gmr_var": class_gmr_var,
       }
 
 
-def add_cmap_to_metrics_collection(name: str, metrics_collection):
-  """Adds a CMAP instance to an existing CLU metrics collection."""
+def add_rank_based_metrics_to_metrics_collection(name: str, metrics_collection):
+  """Adds a RankBasedMetrics instance to an existing CLU metrics collection."""
   new_collection = flax.struct.dataclass(
       type(
           "_ValidCollection",
           (metrics_collection,),
           {
               "__annotations__": {
-                  f"{name}_cmap": CMAP,
+                  f"{name}_rank_based": RankBasedMetrics,
                   **metrics_collection.__annotations__,
               }
           },
diff --git a/chirp/projects/multicluster/classify.py b/chirp/projects/multicluster/classify.py
index 9f7e9081..f676820a 100644
--- a/chirp/projects/multicluster/classify.py
+++ b/chirp/projects/multicluster/classify.py
@@ -17,7 +17,7 @@
 
 import dataclasses
 
-from chirp.models import cmap
+from chirp.models import rank_based_metrics
 from chirp.projects.multicluster import data_lib
 import tensorflow as tf
 
@@ -83,8 +83,8 @@ def train_embedding_model(
   # Manually compute per-class mAP and CmAP scores.
test_logits = model.predict(test_ds, verbose=0) test_labels = merged.data['label_hot'][test_locs] - maps = cmap.CMAP.from_model_output( + maps = rank_based_metrics.RankBasedMetrics.from_model_output( label_logits=test_logits, label=test_labels ).compute() - cmap_value = maps.pop('macro') + cmap_value = maps.pop('macro_cmap') return ClassifierMetrics(acc, auc_roc, recall, cmap_value, maps) diff --git a/chirp/tests/metrics_test.py b/chirp/tests/metrics_test.py index 5527abb4..faffef7c 100644 --- a/chirp/tests/metrics_test.py +++ b/chirp/tests/metrics_test.py @@ -17,9 +17,9 @@ import functools import os -from chirp.models import cmap from chirp.models import cwt from chirp.models import metrics +from chirp.models import rank_based_metrics from clu import metrics as clu_metrics import flax import jax @@ -150,33 +150,61 @@ def test_cmap(self): [0, 0, 0], [1, 0, 0], ]) - full_cmap_value = cmap.CMAP.from_model_output( + full_cmap_value = rank_based_metrics.RankBasedMetrics.from_model_output( label=labels, label_logits=scores - ).compute()["macro"] + ).compute()["macro_cmap"] # Check against the manually verified outcome. self.assertAlmostEqual(full_cmap_value, 0.49687502) - batched_cmap_metric = cmap.CMAP.empty() + batched_cmap_metric = rank_based_metrics.RankBasedMetrics.empty() batched_cmap_metric = batched_cmap_metric.merge( - cmap.CMAP.from_model_output(label_logits=scores[:5], label=labels[:5]) + rank_based_metrics.RankBasedMetrics.from_model_output( + label_logits=scores[:5], label=labels[:5] + ) ) batched_cmap_metric = batched_cmap_metric.merge( - cmap.CMAP.from_model_output(label_logits=scores[5:], label=labels[5:]) + rank_based_metrics.RankBasedMetrics.from_model_output( + label_logits=scores[5:], label=labels[5:] + ) ) - batched_cmap_value = batched_cmap_metric.compute()["macro"] + batched_cmap_value = batched_cmap_metric.compute()["macro_cmap"] self.assertEqual(batched_cmap_value, full_cmap_value) # Check that when setting a threshold to 3, the cmap is only computed # taking into account column 1 (the only one with >3 samples). self.assertEqual( - cmap.CMAP.from_model_output(label_logits=scores, label=labels).compute( - sample_threshold=3 - )["macro"], - cmap.CMAP.from_model_output( + rank_based_metrics.RankBasedMetrics.from_model_output( + label_logits=scores, label=labels + ).compute(sample_threshold=3)["macro_cmap"], + rank_based_metrics.RankBasedMetrics.from_model_output( label_logits=scores[:, 1:2], label=labels[:, 1:2] - ).compute()["macro"], + ).compute()["macro_cmap"], ) + def test_gmr(self): + # The following example was worked out manually and verified. + scores = jnp.array([[0.9, 0.2, 0.3, 0.6, 0.5, 0.7, 0.1, 0.4, 0.8]]).T + labels = jnp.array([[0, 0, 0, 0, 0, 1, 0, 1, 1]]).T + full_gmr_value = rank_based_metrics.RankBasedMetrics.from_model_output( + label=labels, label_logits=scores + ).compute()["macro_gmr"] + # Check against the manually verified outcome. 
+ self.assertAlmostEqual(full_gmr_value, 5.0 / 18.0) + + batched_gmr_metric = rank_based_metrics.RankBasedMetrics.empty() + batched_gmr_metric = batched_gmr_metric.merge( + rank_based_metrics.RankBasedMetrics.from_model_output( + label_logits=scores[:5], label=labels[:5] + ) + ) + batched_gmr_metric = batched_gmr_metric.merge( + rank_based_metrics.RankBasedMetrics.from_model_output( + label_logits=scores[5:], label=labels[5:] + ) + ) + batched_gmr_value = batched_gmr_metric.compute()["macro_gmr"] + self.assertEqual(batched_gmr_value, full_gmr_value) + if __name__ == "__main__": absltest.main() diff --git a/chirp/train/classifier.py b/chirp/train/classifier.py index c5202f8e..c9f10978 100644 --- a/chirp/train/classifier.py +++ b/chirp/train/classifier.py @@ -21,9 +21,9 @@ from absl import logging from chirp import export_utils from chirp.data import pipeline -from chirp.models import cmap from chirp.models import metrics from chirp.models import output +from chirp.models import rank_based_metrics from chirp.models import taxonomy_model from chirp.taxonomy import class_utils from chirp.train import utils @@ -277,12 +277,14 @@ def evaluate( if taxonomy_loss_weight != 0.0: taxonomy_keys += utils.TAXONOMY_KEYS - # The metrics are the same as for training, but with CmAP added + # The metrics are the same as for training, but with rank-based metrics added. base_metrics_collection = make_metrics_collection( name, taxonomy_keys, model_bundle.model.num_classes ) - valid_metrics_collection = cmap.add_cmap_to_metrics_collection( - name, base_metrics_collection + valid_metrics_collection = ( + rank_based_metrics.add_rank_based_metrics_to_metrics_collection( + name, base_metrics_collection + ) ) @functools.partial(jax.pmap, axis_name="batch") diff --git a/chirp/train/hubert.py b/chirp/train/hubert.py index a5d7c121..a8842896 100644 --- a/chirp/train/hubert.py +++ b/chirp/train/hubert.py @@ -23,7 +23,7 @@ from absl import logging from chirp.data import pipeline -from chirp.models import cmap +from chirp.models import rank_based_metrics from chirp.models import frontend as frontend_models from chirp.models import hubert from chirp.models import layers @@ -912,7 +912,9 @@ def evaluate( (base_metrics_collection,), { "__annotations__": { - f"{name}_cmap": cmap.CMAP, + f"{name}_rank_based_metrics": ( + rank_based_metrics.RankBasedMetrics + ), **base_metrics_collection.__annotations__, } }, diff --git a/chirp/train/separator.py b/chirp/train/separator.py index 776dd79f..2b9ac212 100644 --- a/chirp/train/separator.py +++ b/chirp/train/separator.py @@ -22,9 +22,9 @@ from absl import logging from chirp import export_utils from chirp.data import pipeline -from chirp.models import cmap from chirp.models import metrics from chirp.models import output +from chirp.models import rank_based_metrics from chirp.models import separation_model from chirp.taxonomy import class_utils from chirp.train import utils @@ -82,14 +82,6 @@ def p_log_sisnr_loss( ) -@flax.struct.dataclass -class ValidationMetrics(clu_metrics.Collection): - valid_loss: clu_metrics.Average.from_fun(p_log_snr_loss) - valid_mixit_log_mse: clu_metrics.Average.from_fun(p_log_mse_loss) - valid_mixit_neg_snr: clu_metrics.Average.from_fun(p_log_snr_loss) - valid_cmap: cmap.CMAP - - def keyed_cross_entropy( key: str, outputs: separation_model.SeparatorOutput, @@ -288,8 +280,10 @@ def evaluate( ): """Run evaluation.""" base_metrics_collection = make_metrics_collection('valid__') - valid_metrics_collection = cmap.add_cmap_to_metrics_collection( - 'valid', 
base_metrics_collection + valid_metrics_collection = ( + rank_based_metrics.add_rank_based_metrics_to_metrics_collection( + 'valid', base_metrics_collection + ) ) @functools.partial(jax.pmap, axis_name='batch')
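
Usage sketch (not part of the patch): a minimal example of the new metrics, assuming the patched chirp modules are importable. It reproduces the worked example from test_gmr above, so both the direct call and the CLU metric should come out to 5/18.

    from chirp.models import metrics
    from chirp.models import rank_based_metrics
    from jax import numpy as jnp

    # One class, nine examples; ones in `labels` mark the true positives.
    scores = jnp.array([[0.9, 0.2, 0.3, 0.6, 0.5, 0.7, 0.1, 0.4, 0.8]])
    labels = jnp.array([[0, 0, 0, 0, 0, 1, 0, 1, 1]])

    # Direct call: 6 false positives, 3 true positives, and 5 false positives
    # ranked above true positives in total, so the GMR is 5 / (6 * 3) = 5 / 18.
    gmr, gmr_var = metrics.generalized_mean_rank(scores, labels)

    # Through the CLU metric: from_model_output expects (num_examples,
    # num_classes) arrays, hence the transpose relative to the arrays above.
    results = rank_based_metrics.RankBasedMetrics.from_model_output(
        label_logits=scores.T, label=labels.T
    ).compute()
    print(gmr, results["macro_gmr"], results["macro_cmap"])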