Skip to content

Commit

Permalink
Internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 654922896
  • Loading branch information
sourabh2k15 authored and copybara-github committed Jul 22, 2024
1 parent afbe0dc commit c2c1b69
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion init2winit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,14 @@ def append_scalar_metrics(self, metrics):
# TODO(gdahl,gilmer): Should this be an atomic file?
with gfile.GFile(self._csv_path, 'w') as csv_file:
measurements.to_csv(csv_file, index=False)

if 'global_step' in metrics:
if isinstance(metrics['global_step'], np.ndarray):
if metrics['global_step'].shape == (1,):
metrics['global_step'] = int(metrics['global_step'][0])
else:
metrics['global_step'] = int(metrics['global_step'])

if self._xm_work_unit:
for name, value in metrics.items():
if name not in self._measurements:
Expand All @@ -239,7 +247,7 @@ def append_scalar_metrics(self, metrics):

if self._tb_metric_writer:
self._tb_metric_writer.write_scalars(
step=int(metrics['global_step']), scalars=metrics)
step=metrics['global_step'], scalars=metrics)
# This gives a 1-2% slowdown in steps_per_sec on cifar-10 with batch
# size 512. We could only flush at the end of training to optimize this.
self._tb_metric_writer.flush()
Expand Down

0 comments on commit c2c1b69

Please sign in to comment.