Skip to content

Commit

Permalink
metrics: Use Array and ArrayLike types thoughout
Browse files Browse the repository at this point in the history
Currently the inputs to `from_model_output` are not typed. However,
these functions cannot accept arbitrary inputs, they need to be a value
convertable to a `jax.Array`. This change fixes this so that:
  - `from_model_output` takes in types of `Array` or `ArrayLike`
  - Removes use of `jnp.array` as a type as it's equivalent to `Any`
  - Makes members of Metric classes have type `Array`
  - Moves mask checking code into its own function

While we could make everything use `Array` (instead of `ArrayLike`),
this would break code like:
```
@flax.struct.dataclass
class Collection(metrics.Collection):
  train_accuracy: metrics.Accuracy
  learning_rate: metrics.LastValue.from_output("learning_rate")

Collection.gather_from_model_output(learning_rate=0.02, ...)
```
which seems undesirable.

Note that `count` and `value` for `LastValue` have type `ArrayLike`,
as this code needs to support passing a plain number for `value` or
`count`. Also, the base `Metric.compute()` method has type `Any`,
because some metrics return `Array` while others use `dict[str, Array]`.

PiperOrigin-RevId: 529227218
  • Loading branch information
josephlr authored and copybara-github committed May 4, 2023
1 parent f8eec70 commit 77471a2
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 87 deletions.
Loading

0 comments on commit 77471a2

Please sign in to comment.