Skip to content

Commit

Permalink
[CI] Support bf16 output in end-to-end numerical testing (nod-ai#829)
Browse files Browse the repository at this point in the history
See comments in nod-ai#822
  • Loading branch information
newling authored Oct 7, 2024
1 parent 0757023 commit 82841cd
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 32 deletions.
91 changes: 86 additions & 5 deletions build_tools/ci/cpu_comparison/input_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from numpy.random import Generator, MT19937, SeedSequence


def convert_f32_to_bf16(float32_value):
def convert_f32_to_bf16(float32_array):
"""
IEEE float32 to bfloat16
Expand All @@ -59,17 +59,25 @@ def convert_f32_to_bf16(float32_value):
to: [SEEEEEEEEMMMMMMM]
================= remove 16 bits of mantissa
"""
int32_repr = float32_value.view(np.int32)
bf16_int_repr = int32_repr >> 16
return np.uint16(bf16_int_repr)
v0 = float32_array.view(np.uint32) >> 16
return v0.astype(np.uint16)


def convert_bf16_to_f32(bfloat16_array):
"""
IEEE bfloat16 to float32. See docstring of convert_f32_to_bf16 for a
bit of info on the mantissa/exponent manipulation.
"""
v0 = bfloat16_array.astype(np.uint32) << 16
return np.frombuffer(v0.tobytes(), dtype=np.float32)


def generate_bfloat16_data(num_values, lower_bound, upper_bound, rng):

float_data = rng.integers(lower_bound, upper_bound, num_values).astype(np.float32)

# Convert float32 data to bfloat16
bf16_data = [convert_f32_to_bf16(f) for f in float_data]
bf16_data = convert_f32_to_bf16(float_data)

# Pack bfloat16 data into binary format
binary_data = struct.pack(f"{len(bf16_data)}H", *bf16_data)
Expand Down Expand Up @@ -165,6 +173,79 @@ def write_input(bin_filename, num_elements, element_type, input_number, input_se
file.write(data)


def get_output_type(filename):
"""
Reads the contents of 'filename' which must contain an MLIR function with
a single returned value, a tensor.
If there's a line of the form '// output 4xf32' then
just return the string '4xf32'.
Otherwise find the return op at the end of the function, and get the
type from the tensor type. i.e. get '3xf32' from 'tensor<3xf32>'
"""

with open(filename, "r") as file:
# First attempt: find line of the form '// output 4xf32'
# This is fail safe for developers: Just add this line to IR being
# tested.
for line in file:
line = line.strip()
tokens = line.split()
if len(tokens) > 2 and tokens[0] == "//":
if tokens[1] == "output":
return tokens[2].strip()

# Second attempt (for legacy test files)
# Find a line of the form
# 'return %foo : tensor<1x2x3x4xsi32>'
with open(filename, "r") as file:
for line in file:
if "return " in line:
line = line.strip()
lines = line.split("tensor<")
assert len(lines) == 2
line = lines[-1]
line = line[0:-1]
return line

raise ValueError(
"Could not find output from the MLIR file. Consider adding a line of the form // output to the file."
)


def np_from_binfile(bin_file, type_str):
"""
Load a numpy array from the binary file bin_file.
Not much interesting here, but the case where element_type_str is 'bf16' is
possibly not obvious: there is no native numpy element type for brainfloat,
so we load it as uint16 and then convert it to float32 (by just packing
extra mantissa 0 bits).
"""

element_type_str = type_str.strip().split("x")[-1]

# Get a numpy type from the string.
np_type = None
if element_type_str == "bf16":
np_type = np.uint16
else:
np_type = get_numpy_type(element_type_str)

shape = [int(x) for x in type_str.strip().split("x")[0:-1]]

# Load data with the numpy type specified.
array = np.fromfile(bin_file, dtype=np_type)
array = array.reshape(shape)

# If the numpy type was just a proxy, do some extra processing.
if element_type_str == "bf16":
array = convert_bf16_to_f32(array)

return array


def generate_inputs(filename, write_dir, seed):
"""
Parse the input file 'filename' and generate binary files for the inputs of
Expand Down
42 changes: 30 additions & 12 deletions build_tools/ci/cpu_comparison/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@

import numpy as np

from input_generator import generate_inputs, verify_determinism, load_input
from input_generator import (
generate_inputs,
verify_determinism,
load_input,
get_output_type,
np_from_binfile,
)
from matmul_template.matmul_generator import generate_matmul_test
from convolution_template.convolution_generator import ConvolutionMlirGenerator
from output_comparer import compare
Expand Down Expand Up @@ -103,12 +109,11 @@ def shell_out(cmd: list, workdir=None, verbose: int = 0, raise_on_error=True, en
if not raise_on_error and handle.returncode != 0:
print(
f"Error executing script, error code was {handle.returncode}. Not raising an error.",
file=sys.stderr
file=sys.stderr,
)
if raise_on_error and handle.returncode != 0:
raise RuntimeError(
f"Error executing script, error code was {handle.returncode}",
file=sys.stderr
f"Error executing script, error code was {handle.returncode}"
)
return stdout_decode, stderr_decode

Expand Down Expand Up @@ -176,18 +181,18 @@ def generate_aie_vmfb(
return aie_vmfb


def generate_aie_output(config, aie_vmfb, input_args, function_name, name):
def generate_aie_output(config, aie_vmfb, input_args, function_name, name, output_type):
"""
Run a compiled AIE module (aie_vmfb), returning a numpy array of the output.
"""

aie_npy = config.output_dir / f"{name}_aie.npy"
aie_bin = config.output_dir / f"{name}_aie.bin"
run_args = [
config.iree_run_exe,
f"--module={aie_vmfb}",
*input_args,
"--device=xrt",
f"--output=@{aie_npy}",
f"--output=@{aie_bin}",
]
if function_name:
run_args += [f"--function={function_name}"]
Expand All @@ -201,7 +206,7 @@ def generate_aie_output(config, aie_vmfb, input_args, function_name, name):
if config.verbose:
print(f"Time spent in running the model: {run_time // 1e6} [ms]")

return np.load(aie_npy)
return np_from_binfile(aie_bin, output_type)


def generate_llvm_cpu_output(
Expand All @@ -210,6 +215,7 @@ def generate_llvm_cpu_output(
test_file,
input_args,
function_name,
output_type,
):
"""
Compile and run a test file for IREE's CPU backend, returning a numpy array
Expand All @@ -227,17 +233,17 @@ def generate_llvm_cpu_output(
]
shell_out(compilation_flags, workdir=config.output_dir, verbose=config.verbose)

cpu_npy = config.output_dir / f"{name}_cpu.npy"
cpu_bin = config.output_dir / f"{name}_cpu.bin"
run_args = [
config.iree_run_exe,
f"--module={cpu_vmfb}",
*input_args,
f"--output=@{cpu_npy}",
f"--output=@{cpu_bin}",
]
if function_name:
run_args += [f"--function={function_name}"]
shell_out(run_args, workdir=config.output_dir, verbose=config.verbose)
return np.load(cpu_npy)
return np_from_binfile(cpu_bin, output_type)


class TestConfig:
Expand Down Expand Up @@ -447,6 +453,7 @@ def aie_vs_baseline(
rtol,
atol,
n_repeats,
output_type,
):
"""
If the outputs differ, add the test file to a list of failures.
Expand Down Expand Up @@ -503,6 +510,7 @@ def aie_vs_baseline(
input_args,
function_name,
name,
output_type,
)

summary_string = compare(baseline_value, aie_output, rtol, atol)
Expand Down Expand Up @@ -538,9 +546,15 @@ def aie_vs_llvm_cpu(
print(f"Running {name} test")

input_args = generate_inputs(test_file, config.output_dir, seed)
output_type = get_output_type(test_file)

cpu_output = generate_llvm_cpu_output(
config, name, test_file, input_args, function_name
config,
name,
test_file,
input_args,
function_name,
output_type,
)

aie_vs_baseline(
Expand All @@ -556,6 +570,7 @@ def aie_vs_llvm_cpu(
rtol,
atol,
n_repeats,
output_type,
)


Expand All @@ -578,6 +593,8 @@ def aie_vs_np_matmul(

name = name_from_mlir_filename(test_file)
input_args = generate_inputs(test_file, config.output_dir, seed)
output_type = get_output_type(test_file)

numpy_output = matmul_from_input_strings(input_args)
aie_vs_baseline(
config,
Expand All @@ -592,6 +609,7 @@ def aie_vs_np_matmul(
rtol,
atol,
n_repeats,
output_type,
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// These 2 lines are required by the script which generates input data:
// These lines are required for e2e numerical testing:
// input 2x14x14x32xi8
// input 3x3x32x64xi8

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
// and we successively call the same matmul.
//
// This test is a part of three test files that developers can look at together.
// 1. two_matmul_switching.mlir => switches calls between two matmuls M,N,K size (8,4,8) and (8,8,4)
// 1. two_matmul_switching.mlir => switches calls between two matmuls M,N,K size (8,4,8) and (8,8,4)
// 4 times each (8 calls total). The graph ((8,4,8) -> (8,8,4)) x 4
// 2. matmul_f32_8_4_8.mlir => calls the (8,4,8) matmul 4 times hence doesnt have a switching cost so we
// have a baseline for it. The graph is (8,4,8) x 4
// 3. matmul_f32_8_8_4.mlir => calls the (8,8,4) matmul 4 times hence doesnt have a switching cost so we
// have a baseline for it. The graph is (8,8,4) x 4

// These 2 lines are required by the script which generates input data:
// These lines are required for e2e numerical testing:
//
// input 8x8xf32
// input 8x4xf32
// output 8x4xf32

!A_TYPE = tensor<8x8xf32>
!B_TYPE = tensor<8x4xf32>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
// This test is useful to compare against the `two_matmul_switching` when no switching happens
// and we successively call the same matmul
// This test is a part of three test files that developers can look at together.
// 1. two_matmul_switching.mlir => switches calls between two matmuls M,N,K size (8,4,8) and (8,8,4)
// 1. two_matmul_switching.mlir => switches calls between two matmuls M,N,K size (8,4,8) and (8,8,4)
// 4 times each (8 calls total). The graph ((8,4,8) -> (8,8,4)) x 4
// 2. matmul_f32_8_4_8.mlir => calls the (8,4,8) matmul 4 times hence doesnt have a switching cost so we
// have a baseline for it. The graph is (8,4,8) x 4
// 3. matmul_f32_8_8_4.mlir => calls the (8,8,4) matmul 4 times hence doesnt have a switching cost so we
// have a baseline for it. The graph is (8,8,4) x 4

// These 2 lines are required by the script which generates input data:
// These lines are required for e2e numerical testing:
//
// input 8x4xf32
// input 4x8xf32
// output 8x8xf32

!A_TYPE = tensor<8x4xf32>
!B_TYPE = tensor<4x8xf32>
Expand Down
3 changes: 2 additions & 1 deletion build_tools/ci/cpu_comparison/test_files/matmul_int32.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// These 2 lines are required by the script which generates input data:
// These lines are required for e2e numerical testing:
// input 128x128xi32
// input 128x128xi32
// output 128x128xi32

!lhs = tensor<128x128xi32>
!rhs = tensor<128x128xi32>
Expand Down
3 changes: 2 additions & 1 deletion build_tools/ci/cpu_comparison/test_files/three_matmuls.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
// This test shows arbitrary matmuls that would have producer consumer relationships
// across different dispatches running on CI.

// These 4 lines are required by the script which generates input data:
// These lines are required for e2e numerical testing:
//
// input 32x32xf32
// input 32x32xf32
// input 32x4xf32
// input 4x32xf32
// output 4x4xf32

!A_TYPE = tensor<32x32xf32>
!B_TYPE = tensor<32x4xf32>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
// This test shows switching between two matmuls and is useful to model the switching cost.
// This test is a part of three test files that developers can look at together.
// 1. two_matmul_switching.mlir => switches calls between two matmuls M,N,K size (8,4,8) and (8,8,4)
// 1. two_matmul_switching.mlir => switches calls between two matmuls M,N,K size (8,4,8) and (8,8,4)
// 4 times each (8 calls total). The graph ((8,4,8) -> (8,8,4)) x 4
// 2. matmul_f32_8_4_8.mlir => calls the (8,4,8) matmul 4 times hence doesnt have a switching cost so we
// have a baseline for it. The graph is (8,4,8) x 4
// 3. matmul_f32_8_8_4.mlir => calls the (8,8,4) matmul 4 times hence doesnt have a switching cost so we
// have a baseline for it. The graph is (8,8,4) x 4

// These 2 lines are required by the script which generates input data:
// These lines are required for e2e numerical testing:
//
// input 8x4xf32
// input 4x8xf32
// output 8x4xf32

!A_TYPE = tensor<8x4xf32>
!B_TYPE = tensor<4x8xf32>
Expand Down
13 changes: 13 additions & 0 deletions build_tools/ci/cpu_comparison/test_input_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from .input_generator import *
import numpy as np


def test_conversion():
"""
Check that float(bfloat(a)) is (almost) a.
"""
expected = np.array([1.5, 3.125, -1.5, -32.0, 0.0, -3.125], dtype=np.float32)
a = np.array([1.5, 3.14, -1.5, -32, 0, -3.14], np.float32)
b = [convert_f32_to_bf16(x) for x in a]
c = convert_bf16_to_f32(np.array(b))
assert np.allclose(c, expected, 0, 0)
Loading

0 comments on commit 82841cd

Please sign in to comment.