From b515d753c49756c96a0f8b26a7eb3309ae32d70f Mon Sep 17 00:00:00 2001
From: Joe Richey
Date: Wed, 3 May 2023 16:36:06 -0700
Subject: [PATCH] metrics: Use Array and ArrayLike types throughout

Currently the inputs to `from_model_output` are not typed. However, these
functions cannot accept arbitrary inputs; they need values convertible to a
`jax.Array`. This change makes the types explicit so that:

- `from_model_output` takes inputs of type `Array` or `ArrayLike`
- `jnp.array` is no longer used as a type, since it is equivalent to `Any`
- Members of Metric classes have type `Array`
- Mask checking code is moved into its own function

While we could make everything use `Array` (instead of `ArrayLike`), this
would break code like:

```
@flax.struct.dataclass
class Collection(metrics.Collection):
  train_accuracy: metrics.Accuracy
  learning_rate: metrics.LastValue.from_output("learning_rate")

Collection.gather_from_model_output(learning_rate=0.02, ...)
```

which seems undesirable.

Note that `count` and `value` for `LastValue` have type `ArrayLike`, because
callers need to be able to pass a plain number for `value` or `count`. Also,
the base `Metric.compute()` method returns `Any`, because some metrics return
`Array` while others return `dict[str, Array]`.

PiperOrigin-RevId: 529227218
---
 clu/metrics.py      | 206 ++++++++++++++++++++++++++------------------
 clu/metrics_test.py |  26 ++++++
 clu/values.py       |   8 +-
 3 files changed, 150 insertions(+), 90 deletions(-)

diff --git a/clu/metrics.py b/clu/metrics.py
index e8d41f7..e2bf9ec 100644
--- a/clu/metrics.py
+++ b/clu/metrics.py
@@ -27,7 +27,7 @@
 The "model output" is a dictionary of values with unique keys that all have a
 specific meaning (such as `loss`, `logits`, and `labels`) and every metric
 depends on at least one such model output by name. These outputs are usually
-expected to be instances of `jnp.array`.
+expected to be instances of `jax.Array`.
 
 Synopsis:
 
@@ -55,8 +55,9 @@ def evaluate(model, p_variables, test_ds):
       ms = p_eval_step(ms, model, p_variables, inputs, labels)
     return ms.unreplicate().compute()
 """
-from collections.abc import Callable, Mapping, Sequence
-from typing import Any, Optional, TypeVar
+from __future__ import annotations
+from collections.abc import Mapping, Sequence
+from typing import Any, Optional, Protocol, TypeVar
 
 from absl import logging
 
@@ -67,17 +68,51 @@ def evaluate(model, p_variables, test_ds):
 import jax.numpy as jnp
 import numpy as np
 
+Array = jax.Array
+# For backwards compatibility, allow explicit None in model_output.
+ArrayLike = Optional[jax.typing.ArrayLike]
+
+
+class FromFunCallable(Protocol):
+  """The type of functions that can be passed to `Metrics.from_fun()`."""
+
+  def __call__(self, **kwargs: ArrayLike) -> Array | Mapping[str, Array]:
+    ...
+
+
 # TODO(b/200953513): Migrate away from logging imports (on module level)
 # to logging the actual usage. See b/200953513.
 
 
-def _assert_same_shape(a: jnp.array, b: jnp.array):
+def _get_mask(output: Array, mask: ArrayLike | None, msg: str) -> Array | None:
+  """For a given mask and output to be masked, return the mask we should use."""
+  if mask is None:
+    return None
+  output_shape = jnp.shape(output)
+  leading_output = output_shape[0] if output_shape else 0
+  mask_shape = jnp.shape(mask)
+  if mask_shape and mask_shape[0] == leading_output:
+    return jnp.array(mask)
+
+  logging.warning(
+      "Ignoring mismatched mask in %s: output.shape=%s vs. 
mask.shape=%s", + msg, + output_shape, + mask_shape, + ) + return None + + +def _assert_same_shape(a: Array, b: Array): """Raises a `ValueError` if shapes of `a` and `b` don't match.""" if a.shape != b.shape: raise ValueError(f"Expected same shape: {a.shape} != {b.shape}") +M = TypeVar("M", bound="Metric") + + class Metric: """Interface for computing metrics from intermediate values. @@ -90,11 +125,11 @@ class Metric: @flax.struct.dataclass class Average(Metric): - total: jnp.array - count: jnp.array + total: jax.Array + count: jax.Array @classmethod - def from_model_output(cls, value: jnp.array, **_) -> Metric: + def from_model_output(cls, value: jax.Array, **_) -> Metric: return cls(total=value.sum(), count=np.prod(value.shape)) def merge(self, other: Metric) -> Metric: @@ -114,11 +149,13 @@ def compute(self): """ @classmethod - def from_model_output(cls, *args, **kwargs) -> "Metric": + def from_model_output( + cls: type[M], *args: ArrayLike, **kwargs: ArrayLike + ) -> M: """Creates a `Metric` from model outputs.""" raise NotImplementedError("Must override from_model_output()") - def merge(self, other: "Metric") -> "Metric": + def merge(self: M, other: M) -> M: """Returns `Metric` that is the accumulation of `self` and `other`. Args: @@ -141,15 +178,15 @@ def merge(self, other: "Metric") -> "Metric": # `_reduce_merge()` must be associative[1], otherwise we would get # different results when using different devices. # [1] https://en.wikipedia.org/wiki/Associative_property - def _reduce_merge(self, other: "Metric") -> "Metric": + def _reduce_merge(self: M, other: M) -> M: return self.merge(other) - def compute(self) -> jnp.array: + def compute(self) -> Any: """Computes final metrics from intermediate values.""" raise NotImplementedError("Must override compute()") @classmethod - def empty(cls) -> "Metric": + def empty(cls: type[M]) -> M: """Returns an empty instance (i.e. `.merge(Metric.empty())` is a no-op).""" raise NotImplementedError("Must override empty()") @@ -157,7 +194,7 @@ def compute_value(self) -> clu.values.Value: """Wraps compute() and returns a values.Value.""" return clu.values.Scalar(self.compute()) - def reduce(self) -> "Metric": + def reduce(self: M) -> M: """Reduces the metric along it first axis by calling `_reduce_merge()`. This function primary use case is to aggregate metrics collected across @@ -173,7 +210,7 @@ def reduce(self) -> "Metric": reduced metric. """ - def reduce_step(reduced: Metric, metric: Metric) -> tuple[Metric, None]: + def reduce_step(reduced: M, metric: M) -> tuple[M, None]: # pylint: disable-next=protected-access return reduced._reduce_merge(metric), None @@ -183,7 +220,7 @@ def reduce_step(reduced: Metric, metric: Metric) -> tuple[Metric, None]: return jax.lax.scan(reduce_step, first, remainder)[0] @classmethod - def from_fun(cls, fun: Callable): # pylint: disable=g-bare-generic + def from_fun(cls, fun: FromFunCallable): # No way to annotate return type """Calls `cls.from_model_output` with the return value from `fun`. 
Returns a `Metric` derived from `cls` whose `.from_model_output` (1) calls @@ -233,7 +270,7 @@ class FromFun(cls): """Wrapper Metric class that collects output after applying `fun`.""" @classmethod - def from_model_output(cls, **model_output) -> Metric: + def from_model_output(cls: type[M], **model_output: ArrayLike) -> M: mask = model_output.get("mask") output = fun(**model_output) if isinstance(output, Mapping) and "mask" in output: @@ -252,12 +289,7 @@ def from_model_output(cls, **model_output) -> Metric: first_output = next(iter(output.values())) else: first_output = output - if (first_output.shape or [0])[0] != mask.shape[0]: - logging.warning( - "Ignoring mask for fun(**model output) because of shape " - "mismatch: output.shape=%s vs. mask.shape=%s", - first_output.shape, mask.shape) - mask = None + mask = _get_mask(first_output, mask, "fun(**model output)") if isinstance(output, Mapping): return super().from_model_output(**output, mask=mask) else: @@ -266,7 +298,7 @@ def from_model_output(cls, **model_output) -> Metric: return FromFun @classmethod - def from_output(cls, name: str): # pylint: disable=g-bare-generic + def from_output(cls, name: str): # No way to annotate return type """Calls `cls.from_model_output` with model output named `name`. Synopsis: @@ -295,15 +327,13 @@ class FromOutput(cls): """Wrapper Metric class that collects output named `name`.""" @classmethod - def from_model_output(cls, **model_output) -> Metric: - output = jnp.array(model_output[name]) + def from_model_output(cls: type[M], **model_output: ArrayLike) -> M: + output = model_output.get(name) + if output is None: + raise KeyError(f"'{name}' is not present in the model output") + output = jnp.array(output) mask = model_output.get("mask") - if mask is not None and (output.shape or [0])[0] != mask.shape[0]: - logging.warning( - "Ignoring mask for model output '%s' because of shape mismatch: " - "output.shape=%s vs. mask.shape=%s", name, output.shape, - mask.shape) - mask = None + mask = _get_mask(output, mask, f"model output {name}") return super().from_model_output(output, mask=mask) return FromOutput @@ -366,10 +396,10 @@ def merge(update): values: dict[str, tuple[np.ndarray, ...]] @classmethod - def empty(cls) -> "CollectingMetric": + def empty(cls) -> CollectingMetric: return cls(values={}) - def merge(self, other: "CollectingMetric") -> "CollectingMetric": + def merge(self, other: CollectingMetric) -> CollectingMetric: values = { name: (*value, *other.values[name]) for name, value in self.values.items() @@ -384,31 +414,33 @@ def merge(self, other: "CollectingMetric") -> "CollectingMetric": return self return type(self)(jax.tree_map(np.asarray, values)) - def reduce(self) -> "CollectingMetric": + def reduce(self) -> CollectingMetric: # Note that this is usually called from inside a `pmap()` via # `Collection.gather_from_model_output()` so we concatenate using jnp. 
return type(self)( {name: jnp.concatenate(values) for name, values in self.values.items()}) - def compute(self) -> dict[str, np.ndarray]: + def compute(self): # No return type annotation, so subclasses can override return {k: np.concatenate(v) for k, v in self.values.items()} @classmethod - def from_outputs(cls, names: Sequence[str]): + def from_outputs(cls, names: Sequence[str]) -> type[CollectingMetric]: """Returns a metric class that collects all model outputs named `names`.""" @flax.struct.dataclass class FromOutputs(cls): # pylint:disable=missing-class-docstring @classmethod - def from_model_output(cls, **model_output) -> Metric: - - def make_array(value): + def from_model_output(cls: type[M], **model_output: ArrayLike) -> M: + def get_value(name: str) -> Array: + value = model_output.get(name) + if value is None: + raise KeyError(f"'{name}' is not present in the model output") value = jnp.array(value) # Can't jnp.concatenate() scalars, promote to shape=(1,) in that case. return value[None] if value.ndim == 0 else value - return cls({name: (make_array(model_output[name]),) for name in names}) + return cls({name: (get_value(name),) for name in names}) return FromOutputs @@ -417,13 +449,13 @@ def make_array(value): class _ReductionCounter(Metric): """Pseudo metric that keeps track of the total number of `.merge()`.""" - value: jnp.array + value: Array @classmethod - def empty(cls): + def empty(cls) -> _ReductionCounter: return cls(value=jnp.array(1, jnp.int32)) - def merge(self, other: "_ReductionCounter") -> "_ReductionCounter": + def merge(self, other: _ReductionCounter) -> _ReductionCounter: return _ReductionCounter(self.value + other.value) @@ -461,7 +493,7 @@ class Metrics(Collection): _reduction_counter: _ReductionCounter @classmethod - def create(cls, **metrics: type[Metric]) -> type["Collection"]: + def create(cls, **metrics: type[Metric]) -> type[Collection]: """Handy short-cut to define a `Collection` inline. Instead declaring a `Collection` dataclass: @@ -487,7 +519,7 @@ class MyMetrics(metrics.Collection): type("_InlineCollection", (Collection,), {"__annotations__": metrics})) @classmethod - def create_collection(cls, **metrics: Metric) -> "Collection": + def create_collection(cls, **metrics: Metric) -> Collection: """Creates a custom collection object with fields metrics. This object will be an instance of custom subclass of `Collection` with @@ -524,7 +556,7 @@ def empty(cls: type[C]) -> C: }) @classmethod - def _from_model_output(cls: type[C], **kwargs) -> C: + def _from_model_output(cls: type[C], **kwargs: ArrayLike) -> C: """Creates a `Collection` from model outputs.""" return cls( _reduction_counter=_ReductionCounter(jnp.array(1, dtype=jnp.int32)), @@ -534,7 +566,7 @@ def _from_model_output(cls: type[C], **kwargs) -> C: }) @classmethod - def single_from_model_output(cls: type[C], **kwargs) -> C: + def single_from_model_output(cls: type[C], **kwargs: ArrayLike) -> C: """Creates a `Collection` from model outputs. Note: This function should only be called when metrics are collected in a @@ -549,7 +581,9 @@ def single_from_model_output(cls: type[C], **kwargs) -> C: return cls._from_model_output(**kwargs) @classmethod - def gather_from_model_output(cls: type[C], axis_name="batch", **kwargs) -> C: + def gather_from_model_output( + cls: type[C], axis_name="batch", **kwargs: ArrayLike + ) -> C: """Creates a `Collection` from model outputs in a distributed setting. 
Args: @@ -603,7 +637,7 @@ def reduce(self: C) -> C: for metric_name, metric in vars(self).items() }) - def compute(self) -> dict[str, jnp.array]: + def compute(self) -> dict[str, Array]: """Returns a dictionary mapping metric field name to `Metric.compute()`.""" _check_reduction_counter_ndim(self._reduction_counter) return { @@ -647,13 +681,15 @@ class LastValue(Metric): check. For backward compatibility this class can be initialized using the keyword `LastValue(value=10)` or `total` and `count`. """ - total: jnp.array - count: jnp.array - - def __init__(self, total: Optional[jnp.array] = None, - count: Optional[jnp.array] = None, - value: Optional[jnp.array] = None, - ): + total: Array + count: ArrayLike + + def __init__( + self, + total: ArrayLike | None = None, + count: ArrayLike | None = None, + value: ArrayLike | None = None, + ): """Constructor which supports keyword argument value as initializer. If "value" is provided, then "total" should *not* be provided. @@ -663,24 +699,23 @@ def __init__(self, total: Optional[jnp.array] = None, count: Count of examples, 1 if not provided value: Value, if provided, will be assumed to be "count" of values. """ - count = count if count is not None else jnp.array(1, dtype=jnp.int32) + count = count if count is not None else 1 if value is not None: if total is not None: raise ValueError("Only one of 'total' and 'value' should be None. " f'Got {total}, {value}') total = value * count - object.__setattr__(self, "total", total) + object.__setattr__(self, "total", jnp.array(total)) object.__setattr__(self, "count", count) @classmethod - def empty(cls): + def empty(cls) -> LastValue: return cls(jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32)) @classmethod - def from_model_output(cls, - value: jnp.array, - mask: Optional[jnp.array] = None, - **_) -> Metric: + def from_model_output( + cls, value: Array, mask: Array | None = None, **_: ArrayLike + ) -> LastValue: if mask is None: mask = jnp.ones((value.shape or [()])[0]) return cls( @@ -688,11 +723,11 @@ def from_model_output(cls, count=mask.sum().astype(jnp.int32), ) - def merge(self, other: "LastValue") -> "LastValue": + def merge(self, other: LastValue) -> LastValue: _assert_same_shape(self.value, other.value) return other - def _reduce_merge(self, other: "LastValue") -> "LastValue": + def _reduce_merge(self, other: LastValue) -> LastValue: # We need to average during reduction. _assert_same_shape(self.total, other.total) return type(self)( @@ -701,12 +736,12 @@ def _reduce_merge(self, other: "LastValue") -> "LastValue": ) @property - def value(self) -> jnp.array: + def value(self) -> Array: # Explicitly allow NaN division as it is part of normal computation here. with jax.debug_nans(False): return self.total / self.count - def compute(self) -> Any: + def compute(self) -> Array: return self.value @@ -726,18 +761,17 @@ class Average(Metric): See also documentation of `Metric`. 
""" - total: jnp.array - count: jnp.array + total: Array + count: ArrayLike @classmethod - def empty(cls) -> Metric: + def empty(cls) -> Average: return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32)) @classmethod - def from_model_output(cls, - values: jnp.array, - mask: Optional[jnp.array] = None, - **_) -> Metric: + def from_model_output( + cls, values: Array, mask: Array | None = None, **_: ArrayLike + ) -> Average: if values.ndim == 0: values = values[None] if mask is None: @@ -760,7 +794,7 @@ def from_model_output(cls, jnp.zeros_like(values, dtype=jnp.int32)).sum(), ) - def merge(self, other: "Average") -> "Average": + def merge(self, other: Average) -> Average: _assert_same_shape(self.total, other.total) return type(self)( total=self.total + other.total, @@ -778,22 +812,21 @@ class Std(Metric): See also documentation of `Metric`. """ - total: jnp.array - sum_of_squares: jnp.array - count: jnp.array + total: Array + sum_of_squares: Array + count: ArrayLike @classmethod - def empty(cls): + def empty(cls) -> Std: return cls( total=jnp.array(0, jnp.float32), sum_of_squares=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32)) @classmethod - def from_model_output(cls, - values: jnp.array, - mask: Optional[jnp.array] = None, - **_) -> Metric: + def from_model_output( + cls, values: Array, mask: Array | None = None, **_: ArrayLike + ) -> Std: if values.ndim == 0: values = values[None] utils.check_param(values, ndim=1) @@ -805,7 +838,7 @@ def from_model_output(cls, count=mask.sum(), ) - def merge(self, other: "Std") -> "Std": + def merge(self, other: Std) -> Std: _assert_same_shape(self.total, other.total) return type(self)( total=self.total + other.total, @@ -839,8 +872,9 @@ class Accuracy(Average): """ @classmethod - def from_model_output(cls, *, logits: jnp.array, labels: jnp.array, - **kwargs) -> Metric: + def from_model_output( + cls, *, logits: Array, labels: Array, **kwargs: ArrayLike + ) -> Metric: if logits.ndim != labels.ndim + 1 or labels.dtype != jnp.int32: raise ValueError( f"Expected labels.dtype==jnp.int32 and logits.ndim={logits.ndim}==" diff --git a/clu/metrics_test.py b/clu/metrics_test.py index c0464c1..1278a0b 100644 --- a/clu/metrics_test.py +++ b/clu/metrics_test.py @@ -270,12 +270,38 @@ def test_accuracy(self, reduce): self.make_compute_metric(metrics.Accuracy, reduce)(self.model_outputs), self.results["train_accuracy"]) + # Make sure mask properly forwards to Average (including explicit None) + def test_accuracy_masked(self): + logits = jnp.array([[0, 0.1], [0, -1.0]]) + labels = jnp.array([1, 1]) # 1st correct, 2nd incorrect + + accuracy_no_mask = metrics.Accuracy.from_model_output( + logits=logits, labels=labels, mask=None + ).compute() + self.assertEqual(accuracy_no_mask, 0.5) + + accuracy_with_mask = metrics.Accuracy.from_model_output( + logits=logits, labels=labels, mask=jnp.array([True, False]) + ).compute() + self.assertEqual(accuracy_with_mask, 1.0) + def test_last_value_asserts_shape(self): metric1 = metrics.LastValue.from_model_output(jnp.arange(3.)) metric2 = jax.tree_map(lambda *args: jnp.stack(args), metric1, metric1) with self.assertRaisesRegex(ValueError, r"^Expected same shape"): metric1.merge(metric2) + def test_last_value_from_output(self): + metric_class = metrics.LastValue.from_output("good") + metric = metric_class.from_model_output(good=jnp.arange(5)) + self.assertEqual(metric.compute(), 2) # Average of 0,1,2,3,4 + self.assertEqual(metric.value, 2) + + def test_missing_value_raises_error(self): + metric = 
metrics.LastValue.from_output("bad") + with self.assertRaisesRegex(KeyError, "bad"): + metric.from_model_output(good=jnp.arange(5)).compute() + @parameterized.named_parameters( ("", False), ("_reduce", True), diff --git a/clu/values.py b/clu/values.py index ca9f4a1..a5e5fb4 100644 --- a/clu/values.py +++ b/clu/values.py @@ -17,9 +17,8 @@ A Metric should return one of the following types when compute() is called. """ -import abc import dataclasses -from typing import Any, Union +from typing import Any, Union, Protocol, runtime_checkable import jax.numpy as jnp import numpy as np @@ -28,13 +27,14 @@ ScalarType = Union[int, float, np.number, np.ndarray, jnp.ndarray] -class Value(abc.ABC): +@runtime_checkable +class Value(Protocol): """Class defining available metric computation return values. Types mirror those available in MetricWriter. See clu/metric_writers/interface.py """ - pass + value: Any @dataclasses.dataclass
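
Usage sketch (not part of the diff): a minimal illustration of the typing
behavior described in the commit message above, assuming a `clu` build that
includes this change and a working `jax` install. Inputs to
`from_model_output` are `ArrayLike`, so plain Python numbers are accepted, and
an explicit `mask=None` remains valid; the expected values follow the new
tests in `clu/metrics_test.py`.

```
import jax.numpy as jnp
from clu import metrics

# A plain float is ArrayLike, so it can be passed straight to a metric built
# with `from_output`; the stored `total` is converted to a jax.Array.
LearningRate = metrics.LastValue.from_output("learning_rate")
lr = LearningRate.from_model_output(learning_rate=0.02)
print(lr.compute())  # -> 0.02

# An explicit `mask=None` is still accepted for backwards compatibility.
acc = metrics.Accuracy.from_model_output(
    logits=jnp.array([[0.0, 0.1], [0.0, -1.0]]),
    labels=jnp.array([1, 1], dtype=jnp.int32),
    mask=None,
)
print(acc.compute())  # -> 0.5, matching test_accuracy_masked
```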