Skip to content

Commit

Permalink
Fix to convert measurement arrays to scalars.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 662993694
  • Loading branch information
priyakasimbeg authored and copybara-github committed Aug 14, 2024
1 parent c2c1b69 commit 689a815
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions init2winit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ def array_append(full_array, to_append):
return jnp.concatenate((full_array, to_append))


def reduce_to_scalar(value):
"""Helper function to reduce an numpy array to a scalar by extracting the first element."""
if isinstance(value, np.ndarray):
if value.shape == (1,):
value = value[0]
return value


def dtype_from_str(dtype_string):
# We use strings to avoid having to import jnp into the config files.
if dtype_string == 'float32':
Expand Down Expand Up @@ -231,19 +239,16 @@ def append_scalar_metrics(self, metrics):
measurements.to_csv(csv_file, index=False)

if 'global_step' in metrics:
if isinstance(metrics['global_step'], np.ndarray):
if metrics['global_step'].shape == (1,):
metrics['global_step'] = int(metrics['global_step'][0])
else:
metrics['global_step'] = int(metrics['global_step'])
metrics['global_step'] = int(reduce_to_scalar(metrics['global_step']))

if self._xm_work_unit:
for name, value in metrics.items():
if name not in self._measurements:
self._measurements[name] = self._xm_work_unit.get_measurement_series(
label=name)
self._measurements[name].create_measurement(
objective_value=value, step=metrics['global_step'])
objective_value=reduce_to_scalar(value), step=metrics['global_step']
)

if self._tb_metric_writer:
self._tb_metric_writer.write_scalars(
Expand Down

0 comments on commit 689a815

Please sign in to comment.