Skip to content

Commit

Permalink
Simplify reduce_to_scalar function in utils for measurements.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 672737311
  • Loading branch information
priyakasimbeg authored and copybara-github committed Sep 10, 2024
1 parent 7d078d1 commit 7fba9e2
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions init2winit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,9 @@ def array_append(full_array, to_append):


def reduce_to_scalar(value):
"""Helper function to reduce an numpy array to a scalar by extracting the first element."""
while isinstance(value, np.ndarray) and value.ndim > 0:
value = value[0]
if isinstance(value, np.ndarray) and value.ndim == 0:
value = value.item()
"""Reduce an numpy array to a scalar by extracting the first element."""
if isinstance(value, np.ndarray) or isinstance(value, jnp.ndarray):
value = value.item(0)
return value


Expand Down

0 comments on commit 7fba9e2

Please sign in to comment.