Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Metrics and Collection class inherit from struct.PyTreeNode, so that we don't have to decorate them with struct.dataclass anymore. #254

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clu/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/asynclib.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/asynclib_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/checkpoint_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/data/dataset_iterator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/data/dataset_iterator_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/deterministic_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/deterministic_data_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/internal/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/internal/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/internal/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/async_writer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/async_writer_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/logging_writer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/logging_writer_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/multi_writer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/multi_writer_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/summary_writer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/summary_writer_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/torch_tensorboard_writer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/torch_tensorboard_writer_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion clu/metric_writers/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
41 changes: 13 additions & 28 deletions clu/metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 The CLU Authors.
# Copyright 2023 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,7 +35,6 @@
import flax
import jax

@flax.struct.dataclass # required for jax.tree_*
class Metrics(metrics.Collection):
accuracy: metrics.Accuracy
loss: metrics.Average.from_output("loss")
Expand Down Expand Up @@ -63,6 +62,7 @@ def evaluate(model, p_variables, test_ds):
from clu.internal import utils
import clu.values
import flax
from flax import struct
import jax
import jax.numpy as jnp
import numpy as np
Expand All @@ -78,7 +78,7 @@ def _assert_same_shape(a: jnp.array, b: jnp.array):
raise ValueError(f"Expected same shape: {a.shape} != {b.shape}")


class Metric:
class Metric(struct.PyTreeNode):
"""Interface for computing metrics from intermediate values.

Refer to `Collection` for computing multipel metrics at the same time.
Expand All @@ -88,7 +88,6 @@ class Metric:
import jax.numpy as jnp
import flax

@flax.struct.dataclass
class Average(Metric):
total: jnp.array
count: jnp.array
Expand Down Expand Up @@ -200,7 +199,6 @@ def from_fun(cls, fun: Callable): # pylint: disable=g-bare-generic
def get_head1(head1_loss, head1_mask, **_):
return dict(loss=head1_loss, mask=head1_mask)

@flax.struct.dataclass
class MultiHeadMetrics(metrics.Collection):
head1_loss: metrics.Average.from_output("loss").from_fun(get_head1)
...
Expand All @@ -218,7 +216,6 @@ class MultiHeadMetrics(metrics.Collection):
`model_output`.
"""

@flax.struct.dataclass
class FromFun(cls):
"""Wrapper Metric class that collects output after applying `fun`."""

Expand Down Expand Up @@ -261,7 +258,6 @@ def from_output(cls, name: str): # pylint: disable=g-bare-generic

Synopsis:

@flax.struct.dataclass
class Metrics(Collection):
loss: Average.from_output('loss')

Expand All @@ -280,7 +276,6 @@ class Metrics(Collection):
a first argument the model output specified by `name`.
"""

@flax.struct.dataclass
class FromOutput(cls):
"""Wrapper Metric class that collects output named `name`."""

Expand All @@ -299,7 +294,6 @@ def from_model_output(cls, **model_output) -> Metric:
return FromOutput


@flax.struct.dataclass
class CollectingMetric(Metric):
"""A special metric that collects model outputs.

Expand All @@ -323,7 +317,6 @@ class CollectingMetric(Metric):

Example to use compute average precision using `sklearn`:

@flax.struct.dataclass
class AveragePrecision(
metrics.CollectingMetric.from_outputs(("labels", "logits"))):

Expand Down Expand Up @@ -387,7 +380,6 @@ def compute(self) -> Dict[str, np.ndarray]:
def from_outputs(cls, names: Sequence[str]):
"""Returns a metric class that collects all model outputs named `names`."""

@flax.struct.dataclass
class FromOutputs(cls): # pylint:disable=missing-class-docstring

@classmethod
Expand All @@ -403,7 +395,6 @@ def make_array(value):
return FromOutputs


@flax.struct.dataclass
class _ReductionCounter(Metric):
"""Pseudo metric that keeps track of the total number of `.merge()`."""

Expand All @@ -425,15 +416,13 @@ def _check_reduction_counter_ndim(reduction_counter: _ReductionCounter):
f"call a flax.jax_utils.unreplicate() or a Collections.reduce()?")


@flax.struct.dataclass
class Collection:
class Collection(struct.PyTreeNode):
"""Updates a collection of `Metric` from model outputs.

Refer to the module documentation for a complete example.

Synopsis:

@flax.struct.dataclass
class Metrics(Collection):
accuracy: Accuracy

Expand All @@ -453,7 +442,6 @@ def create(cls, **metrics: Type[Metric]) -> Type["Collection"]:

Instead declaring a `Collection` dataclass:

@flax.struct.dataclass
class MyMetrics(metrics.Collection):
accuracy: metrics.Accuracy

Expand All @@ -470,7 +458,7 @@ class MyMetrics(metrics.Collection):
Returns:
A subclass of Collection with fields defined by provided `metrics`.
"""
return flax.struct.dataclass(
return struct.dataclass(
type("_InlineCollection", (Collection,), {"__annotations__": metrics}))

@classmethod
Expand All @@ -485,7 +473,6 @@ def create_collection(cls, **metrics: Metric) -> "Collection":

is equivalent to:

@flax.struct.dataclass
class MyMetrics(metrics.Collection):
accuracy: metrics.Accuracy
my_metrics = MyMetrics(_ReductionCounter(jnp.array(1)),
Expand Down Expand Up @@ -623,7 +610,6 @@ def unreplicate(self) -> "Collection":
return flax.jax_utils.unreplicate(self)


@flax.struct.dataclass
class LastValue(Metric):
"""Keeps the last average global batch value.

Expand Down Expand Up @@ -655,21 +641,23 @@ def __init__(self, total: Optional[jnp.array] = None,
count = count if count is not None else jnp.array(1, dtype=jnp.int32)
if value is not None:
if total is not None:
raise ValueError("Only one of 'total' and 'value' should be None. "
f'Got {total}, {value}')
raise ValueError(
"Only one of 'total' and 'value' should be None. "
f"Got {total}, {value}"
)
total = value * count
object.__setattr__(self, "total", total)
object.__setattr__(self, "count", count)
super().__init__()

@classmethod
def empty(cls):
return cls(jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32))

@classmethod
def from_model_output(cls,
value: jnp.array,
mask: Optional[jnp.array] = None,
**_) -> Metric:
def from_model_output(
cls, value: jnp.array, mask: Optional[jnp.array] = None, **_
) -> Metric:
if mask is None:
mask = jnp.ones((value.shape or [()])[0])
return cls(
Expand Down Expand Up @@ -699,7 +687,6 @@ def compute(self) -> Any:
return self.value


@flax.struct.dataclass
class Average(Metric):
"""Computes the average of a scalar or a batch of tensors.

Expand Down Expand Up @@ -760,7 +747,6 @@ def compute(self) -> Any:
return self.total / self.count


@flax.struct.dataclass
class Std(Metric):
"""Computes the standard deviation of a scalar or a batch of scalars.

Expand Down Expand Up @@ -817,7 +803,6 @@ def compute(self) -> Any:
return variance**.5


@flax.struct.dataclass
class Accuracy(Average):
"""Computes the accuracy from model outputs `logits` and `labels`.

Expand Down
Loading