diff --git a/init2winit/utils.py b/init2winit/utils.py index 911718b2..66f5181e 100644 --- a/init2winit/utils.py +++ b/init2winit/utils.py @@ -229,6 +229,14 @@ def append_scalar_metrics(self, metrics): # TODO(gdahl,gilmer): Should this be an atomic file? with gfile.GFile(self._csv_path, 'w') as csv_file: measurements.to_csv(csv_file, index=False) + + if 'global_step' in metrics: + if isinstance(metrics['global_step'], np.ndarray): + if metrics['global_step'].shape == (1,): + metrics['global_step'] = int(metrics['global_step'][0]) + else: + metrics['global_step'] = int(metrics['global_step']) + if self._xm_work_unit: for name, value in metrics.items(): if name not in self._measurements: @@ -239,7 +247,7 @@ def append_scalar_metrics(self, metrics): if self._tb_metric_writer: self._tb_metric_writer.write_scalars( - step=int(metrics['global_step']), scalars=metrics) + step=metrics['global_step'], scalars=metrics) # This gives a 1-2% slowdown in steps_per_sec on cifar-10 with batch # size 512. We could only flush at the end of training to optimize this. self._tb_metric_writer.flush()