From 2b28d3b3002ff86951af496278f18b1b6b1d0dcd Mon Sep 17 00:00:00 2001 From: Marvin Ritter Date: Mon, 26 Jun 2023 16:53:39 -0700 Subject: [PATCH] Internal change. PiperOrigin-RevId: 543579662 --- clu/metric_writers/__init__.py | 2 +- clu/metric_writers/utils.py | 31 +++++++++++------- clu/metric_writers/utils_test.py | 56 +++++++++++++++++++------------- 3 files changed, 54 insertions(+), 35 deletions(-) diff --git a/clu/metric_writers/__init__.py b/clu/metric_writers/__init__.py index 7232bff..a5d2739 100644 --- a/clu/metric_writers/__init__.py +++ b/clu/metric_writers/__init__.py @@ -47,10 +47,10 @@ from clu.metric_writers.async_writer import AsyncMultiWriter from clu.metric_writers.async_writer import AsyncWriter from clu.metric_writers.async_writer import ensure_flushes -from clu.metric_writers.summary_writer import SummaryWriter from clu.metric_writers.interface import MetricWriter from clu.metric_writers.logging_writer import LoggingWriter from clu.metric_writers.multi_writer import MultiWriter +from clu.metric_writers.summary_writer import SummaryWriter from clu.metric_writers.utils import create_default_writer from clu.metric_writers.utils import write_values diff --git a/clu/metric_writers/utils.py b/clu/metric_writers/utils.py index c4fe38e..d4769ae 100644 --- a/clu/metric_writers/utils.py +++ b/clu/metric_writers/utils.py @@ -31,10 +31,10 @@ from absl import logging from clu import values from clu.metric_writers.async_writer import AsyncMultiWriter -from clu.metric_writers.summary_writer import SummaryWriter from clu.metric_writers.interface import MetricWriter from clu.metric_writers.logging_writer import LoggingWriter from clu.metric_writers.multi_writer import MultiWriter +from clu.metric_writers.summary_writer import SummaryWriter from etils import epath import jax.numpy as jnp import numpy as np @@ -44,17 +44,22 @@ def _is_scalar(value: Any) -> bool: - if isinstance(value, values.Scalar) or isinstance(value, - (int, float, np.number)): + if isinstance(value, values.Scalar) or isinstance( + value, (int, float, np.number) + ): return True if isinstance(value, (np.ndarray, jnp.ndarray)): return value.ndim == 0 or value.size <= 1 return False -def write_values(writer: MetricWriter, step: int, - metrics: Mapping[str, Union[values.Value, values.ArrayType, - values.ScalarType]]): +def write_values( + writer: MetricWriter, + step: int, + metrics: Mapping[ + str, Union[values.Value, values.ArrayType, values.ScalarType] + ], +): """Writes all provided metrics. Allows providing a mapping of name to Value object, where each Value @@ -70,8 +75,9 @@ def write_values(writer: MetricWriter, step: int, histogram_num_buckets = collections.defaultdict(int) for k, v in metrics.items(): if isinstance(v, values.Summary): - writes[(writer.write_summaries, frozenset({"metadata": v.metadata - }.items()))][k] = v.value + writes[ + (writer.write_summaries, frozenset({"metadata": v.metadata}.items())) + ][k] = v.value elif _is_scalar(v): if isinstance(v, values.Scalar): writes[(writer.write_scalars, frozenset())][k] = v.value @@ -87,8 +93,10 @@ def write_values(writer: MetricWriter, step: int, writes[(writer.write_histograms, frozenset())][k] = v.value histogram_num_buckets[k] = v.num_buckets elif isinstance(v, values.Audio): - writes[(writer.write_audios, - frozenset({"sample_rate": v.sample_rate}.items()))][k] = v.value + writes[( + writer.write_audios, + frozenset({"sample_rate": v.sample_rate}.items()), + )][k] = v.value else: raise ValueError("Metric: ", k, " has unsupported value: ", v) @@ -107,7 +115,8 @@ def create_default_writer( *, just_logging: bool = False, asynchronous: bool = True, - collection: Optional[str] = None) -> MultiWriter: + collection: Optional[str] = None, +) -> MultiWriter: """Create the default writer for the platform. On most platforms this will create a MultiWriter that writes to multiple back diff --git a/clu/metric_writers/utils_test.py b/clu/metric_writers/utils_test.py index eda6cef..fe7a396 100644 --- a/clu/metric_writers/utils_test.py +++ b/clu/metric_writers/utils_test.py @@ -25,10 +25,10 @@ from clu.metric_writers import utils from clu.metric_writers.async_writer import AsyncMultiWriter from clu.metric_writers.async_writer import AsyncWriter -from clu.metric_writers.summary_writer import SummaryWriter from clu.metric_writers.interface import MetricWriter from clu.metric_writers.logging_writer import LoggingWriter from clu.metric_writers.multi_writer import MultiWriter +from clu.metric_writers.summary_writer import SummaryWriter import clu.metrics import flax.struct import jax.numpy as jnp @@ -129,17 +129,20 @@ def test_write(self): "image": ImageMetric(jnp.asarray([[4, 5], [1, 2]])), } histogram_metrics = { - "hist": - HistogramMetric(value=jnp.asarray([7, 8]), num_buckets=num_buckets), - "hist2": - HistogramMetric( - value=jnp.asarray([9, 10]), num_buckets=num_buckets), + "hist": HistogramMetric( + value=jnp.asarray([7, 8]), num_buckets=num_buckets + ), + "hist2": HistogramMetric( + value=jnp.asarray([9, 10]), num_buckets=num_buckets + ), } audio_metrics = { - "audio": - AudioMetric(value=jnp.asarray([1, 5]), sample_rate=sample_rate), - "audio2": - AudioMetric(value=jnp.asarray([1, 5]), sample_rate=sample_rate + 2), + "audio": AudioMetric( + value=jnp.asarray([1, 5]), sample_rate=sample_rate + ), + "audio2": AudioMetric( + value=jnp.asarray([1, 5]), sample_rate=sample_rate + 2 + ), } text_metrics = { "text": TextMetric(value="hello"), @@ -148,10 +151,10 @@ def test_write(self): "lr": HyperParamMetric(value=0.01), } summary_metrics = { - "summary": - SummaryMetric(value=jnp.asarray([2, 3, 10]), metadata="some info"), - "summary2": - SummaryMetric(value=jnp.asarray([2, 3, 10]), metadata=5), + "summary": SummaryMetric( + value=jnp.asarray([2, 3, 10]), metadata="some info" + ), + "summary2": SummaryMetric(value=jnp.asarray([2, 3, 10]), metadata=5), } metrics = { **scalar_metrics, @@ -166,29 +169,36 @@ def test_write(self): utils.write_values(writer, step, metrics) writer.write_scalars.assert_called_once_with( - step, {k: m.compute() for k, m in scalar_metrics.items()}) - writer.write_images.assert_called_once_with(step, - _to_summary(image_metrics)) + step, {k: m.compute() for k, m in scalar_metrics.items()} + ) + writer.write_images.assert_called_once_with( + step, _to_summary(image_metrics) + ) writer.write_histograms.assert_called_once_with( step, _to_summary(histogram_metrics), - num_buckets={k: v.num_buckets for k, v in histogram_metrics.items()}) + num_buckets={k: v.num_buckets for k, v in histogram_metrics.items()}, + ) writer.write_audios.assert_called_with( step, ONEOF(_to_list_of_dicts(_to_summary(audio_metrics))), - sample_rate=ONEOF([sample_rate, sample_rate + 2])) + sample_rate=ONEOF([sample_rate, sample_rate + 2]), + ) writer.write_texts.assert_called_once_with(step, _to_summary(text_metrics)) - writer.write_hparams.assert_called_once_with(step, - _to_summary(hparam_metrics)) + writer.write_hparams.assert_called_once_with( + step, _to_summary(hparam_metrics) + ) writer.write_summaries.assert_called_with( step, ONEOF(_to_list_of_dicts(_to_summary(summary_metrics))), - metadata=ONEOF(["some info", 5])) + metadata=ONEOF(["some info", 5]), + ) def test_create_default_writer_summary_writer_is_added(self): writer = utils.create_default_writer( - logdir=self.get_temp_dir(), asynchronous=False) + logdir=self.get_temp_dir(), asynchronous=False + ) self.assertTrue(any(isinstance(w, SummaryWriter) for w in writer._writers)) writer = utils.create_default_writer(logdir=None, asynchronous=False) self.assertFalse(any(isinstance(w, SummaryWriter) for w in writer._writers))