Skip to content

Commit

Permalink
[Testing] Fix bf16->f32 utility to preserve shape of the input (nod-a…
Browse files Browse the repository at this point in the history
…i#835)

-- This commit adds a quick fix for bf16->f32 utility to preserve the
shape of the input.
-- Without this the utility is linearizing the input which causes issues
during comparison.

Signed-off-by: Abhishek Varma <[email protected]>
  • Loading branch information
Abhishek-Varma authored Oct 8, 2024
1 parent ebc7c4e commit 941596b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion build_tools/ci/cpu_comparison/input_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def convert_bf16_to_f32(bfloat16_array):
bit of info on the mantissa/exponent manipulation.
"""
v0 = bfloat16_array.astype(np.uint32) << 16
return np.frombuffer(v0.tobytes(), dtype=np.float32)
return np.frombuffer(v0.tobytes(), dtype=np.float32).reshape(bfloat16_array.shape)


def generate_bfloat16_data(num_values, lower_bound, upper_bound, rng):
Expand Down
5 changes: 3 additions & 2 deletions build_tools/ci/cpu_comparison/test_input_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ def test_conversion():
"""
Check that float(bfloat(a)) is (almost) a.
"""
expected = np.array([1.5, 3.125, -1.5, -32.0, 0.0, -3.125], dtype=np.float32)
a = np.array([1.5, 3.14, -1.5, -32, 0, -3.14], np.float32)
expected = np.array([1.5, 3.125, -1.5, -32.0, 0.0, -3.125], dtype=np.float32).reshape([2,3])
a = np.array([1.5, 3.14, -1.5, -32, 0, -3.14], np.float32).reshape([2,3])
b = [convert_f32_to_bf16(x) for x in a]
c = convert_bf16_to_f32(np.array(b))
assert np.allclose(c, expected, 0, 0)

0 comments on commit 941596b

Please sign in to comment.