Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
metrics: Use Array and ArrayLike types thoughout
Currently the inputs to `from_model_output` are not typed. However, these functions cannot accept arbitrary inputs, they need to be a value convertable to a `jax.Array`. This change fixes this so that: - `from_model_output` takes in types of `Array` or `ArrayLike` - Removes use of `jnp.array` as a type as it's equivalent to `Any` - Makes members of Metric classes have type `Array` - Moves mask checking code into its own function While we could make everything use `Array` (instead of `ArrayLike`) this would break code like: ``` @flax.struct.dataclass class Collection(metrics.Collection): train_accuracy: metrics.Accuracy learning_rate: metrics.LastValue.from_output("learning_rate") Collection.gather_from_model_output(learning_rate=0.02, ...) ``` which seems undesirable. Note that `count` and `value` for `LastValue` have type `ArrayLike`, as this code needs to support passing a plain number for `value` or `count`. Also, the base `Metric.compute()` method has type `Any`, because some metrics return `Array` while others use `dict[str, Array]`. PiperOrigin-RevId: 529227218
- Loading branch information