Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 543579662
  • Loading branch information
Marvin182 authored and copybara-github committed Jun 26, 2023
1 parent a766965 commit 2b28d3b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 35 deletions.
2 changes: 1 addition & 1 deletion clu/metric_writers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@
from clu.metric_writers.async_writer import AsyncMultiWriter
from clu.metric_writers.async_writer import AsyncWriter
from clu.metric_writers.async_writer import ensure_flushes
from clu.metric_writers.summary_writer import SummaryWriter
from clu.metric_writers.interface import MetricWriter
from clu.metric_writers.logging_writer import LoggingWriter
from clu.metric_writers.multi_writer import MultiWriter
from clu.metric_writers.summary_writer import SummaryWriter
from clu.metric_writers.utils import create_default_writer
from clu.metric_writers.utils import write_values

Expand Down
31 changes: 20 additions & 11 deletions clu/metric_writers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
from absl import logging
from clu import values
from clu.metric_writers.async_writer import AsyncMultiWriter
from clu.metric_writers.summary_writer import SummaryWriter
from clu.metric_writers.interface import MetricWriter
from clu.metric_writers.logging_writer import LoggingWriter
from clu.metric_writers.multi_writer import MultiWriter
from clu.metric_writers.summary_writer import SummaryWriter
from etils import epath
import jax.numpy as jnp
import numpy as np
Expand All @@ -44,17 +44,22 @@


def _is_scalar(value: Any) -> bool:
if isinstance(value, values.Scalar) or isinstance(value,
(int, float, np.number)):
if isinstance(value, values.Scalar) or isinstance(
value, (int, float, np.number)
):
return True
if isinstance(value, (np.ndarray, jnp.ndarray)):
return value.ndim == 0 or value.size <= 1
return False


def write_values(writer: MetricWriter, step: int,
metrics: Mapping[str, Union[values.Value, values.ArrayType,
values.ScalarType]]):
def write_values(
writer: MetricWriter,
step: int,
metrics: Mapping[
str, Union[values.Value, values.ArrayType, values.ScalarType]
],
):
"""Writes all provided metrics.
Allows providing a mapping of name to Value object, where each Value
Expand All @@ -70,8 +75,9 @@ def write_values(writer: MetricWriter, step: int,
histogram_num_buckets = collections.defaultdict(int)
for k, v in metrics.items():
if isinstance(v, values.Summary):
writes[(writer.write_summaries, frozenset({"metadata": v.metadata
}.items()))][k] = v.value
writes[
(writer.write_summaries, frozenset({"metadata": v.metadata}.items()))
][k] = v.value
elif _is_scalar(v):
if isinstance(v, values.Scalar):
writes[(writer.write_scalars, frozenset())][k] = v.value
Expand All @@ -87,8 +93,10 @@ def write_values(writer: MetricWriter, step: int,
writes[(writer.write_histograms, frozenset())][k] = v.value
histogram_num_buckets[k] = v.num_buckets
elif isinstance(v, values.Audio):
writes[(writer.write_audios,
frozenset({"sample_rate": v.sample_rate}.items()))][k] = v.value
writes[(
writer.write_audios,
frozenset({"sample_rate": v.sample_rate}.items()),
)][k] = v.value
else:
raise ValueError("Metric: ", k, " has unsupported value: ", v)

Expand All @@ -107,7 +115,8 @@ def create_default_writer(
*,
just_logging: bool = False,
asynchronous: bool = True,
collection: Optional[str] = None) -> MultiWriter:
collection: Optional[str] = None,
) -> MultiWriter:
"""Create the default writer for the platform.
On most platforms this will create a MultiWriter that writes to multiple back
Expand Down
56 changes: 33 additions & 23 deletions clu/metric_writers/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
from clu.metric_writers import utils
from clu.metric_writers.async_writer import AsyncMultiWriter
from clu.metric_writers.async_writer import AsyncWriter
from clu.metric_writers.summary_writer import SummaryWriter
from clu.metric_writers.interface import MetricWriter
from clu.metric_writers.logging_writer import LoggingWriter
from clu.metric_writers.multi_writer import MultiWriter
from clu.metric_writers.summary_writer import SummaryWriter
import clu.metrics
import flax.struct
import jax.numpy as jnp
Expand Down Expand Up @@ -129,17 +129,20 @@ def test_write(self):
"image": ImageMetric(jnp.asarray([[4, 5], [1, 2]])),
}
histogram_metrics = {
"hist":
HistogramMetric(value=jnp.asarray([7, 8]), num_buckets=num_buckets),
"hist2":
HistogramMetric(
value=jnp.asarray([9, 10]), num_buckets=num_buckets),
"hist": HistogramMetric(
value=jnp.asarray([7, 8]), num_buckets=num_buckets
),
"hist2": HistogramMetric(
value=jnp.asarray([9, 10]), num_buckets=num_buckets
),
}
audio_metrics = {
"audio":
AudioMetric(value=jnp.asarray([1, 5]), sample_rate=sample_rate),
"audio2":
AudioMetric(value=jnp.asarray([1, 5]), sample_rate=sample_rate + 2),
"audio": AudioMetric(
value=jnp.asarray([1, 5]), sample_rate=sample_rate
),
"audio2": AudioMetric(
value=jnp.asarray([1, 5]), sample_rate=sample_rate + 2
),
}
text_metrics = {
"text": TextMetric(value="hello"),
Expand All @@ -148,10 +151,10 @@ def test_write(self):
"lr": HyperParamMetric(value=0.01),
}
summary_metrics = {
"summary":
SummaryMetric(value=jnp.asarray([2, 3, 10]), metadata="some info"),
"summary2":
SummaryMetric(value=jnp.asarray([2, 3, 10]), metadata=5),
"summary": SummaryMetric(
value=jnp.asarray([2, 3, 10]), metadata="some info"
),
"summary2": SummaryMetric(value=jnp.asarray([2, 3, 10]), metadata=5),
}
metrics = {
**scalar_metrics,
Expand All @@ -166,29 +169,36 @@ def test_write(self):
utils.write_values(writer, step, metrics)

writer.write_scalars.assert_called_once_with(
step, {k: m.compute() for k, m in scalar_metrics.items()})
writer.write_images.assert_called_once_with(step,
_to_summary(image_metrics))
step, {k: m.compute() for k, m in scalar_metrics.items()}
)
writer.write_images.assert_called_once_with(
step, _to_summary(image_metrics)
)
writer.write_histograms.assert_called_once_with(
step,
_to_summary(histogram_metrics),
num_buckets={k: v.num_buckets for k, v in histogram_metrics.items()})
num_buckets={k: v.num_buckets for k, v in histogram_metrics.items()},
)
writer.write_audios.assert_called_with(
step,
ONEOF(_to_list_of_dicts(_to_summary(audio_metrics))),
sample_rate=ONEOF([sample_rate, sample_rate + 2]))
sample_rate=ONEOF([sample_rate, sample_rate + 2]),
)
writer.write_texts.assert_called_once_with(step, _to_summary(text_metrics))
writer.write_hparams.assert_called_once_with(step,
_to_summary(hparam_metrics))
writer.write_hparams.assert_called_once_with(
step, _to_summary(hparam_metrics)
)
writer.write_summaries.assert_called_with(
step,
ONEOF(_to_list_of_dicts(_to_summary(summary_metrics))),
metadata=ONEOF(["some info", 5]))
metadata=ONEOF(["some info", 5]),
)


def test_create_default_writer_summary_writer_is_added(self):
writer = utils.create_default_writer(
logdir=self.get_temp_dir(), asynchronous=False)
logdir=self.get_temp_dir(), asynchronous=False
)
self.assertTrue(any(isinstance(w, SummaryWriter) for w in writer._writers))
writer = utils.create_default_writer(logdir=None, asynchronous=False)
self.assertFalse(any(isinstance(w, SummaryWriter) for w in writer._writers))
Expand Down

0 comments on commit 2b28d3b

Please sign in to comment.