diff --git a/build_tools/ci/cpu_comparison/input_generator.py b/build_tools/ci/cpu_comparison/input_generator.py index 8064aed89..73c377484 100644 --- a/build_tools/ci/cpu_comparison/input_generator.py +++ b/build_tools/ci/cpu_comparison/input_generator.py @@ -69,7 +69,7 @@ def convert_bf16_to_f32(bfloat16_array): bit of info on the mantissa/exponent manipulation. """ v0 = bfloat16_array.astype(np.uint32) << 16 - return np.frombuffer(v0.tobytes(), dtype=np.float32) + return np.frombuffer(v0.tobytes(), dtype=np.float32).reshape(bfloat16_array.shape) def generate_bfloat16_data(num_values, lower_bound, upper_bound, rng): diff --git a/build_tools/ci/cpu_comparison/test_input_generator.py b/build_tools/ci/cpu_comparison/test_input_generator.py index ee605e177..50d23bbb1 100644 --- a/build_tools/ci/cpu_comparison/test_input_generator.py +++ b/build_tools/ci/cpu_comparison/test_input_generator.py @@ -6,8 +6,9 @@ def test_conversion(): """ Check that float(bfloat(a)) is (almost) a. """ - expected = np.array([1.5, 3.125, -1.5, -32.0, 0.0, -3.125], dtype=np.float32) - a = np.array([1.5, 3.14, -1.5, -32, 0, -3.14], np.float32) + expected = np.array([1.5, 3.125, -1.5, -32.0, 0.0, -3.125], dtype=np.float32).reshape([2,3]) + a = np.array([1.5, 3.14, -1.5, -32, 0, -3.14], np.float32).reshape([2,3]) b = [convert_f32_to_bf16(x) for x in a] c = convert_bf16_to_f32(np.array(b)) assert np.allclose(c, expected, 0, 0) +