Skip to content

Commit

Permalink
updated test config
Browse files Browse the repository at this point in the history
  • Loading branch information
mmrahorovic committed Dec 12, 2023
1 parent 45074d9 commit 0ed3681
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions tests/fpgadataflow/test_fpgadataflow_mvau_rtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,13 @@ def make_single_matmul_modelwrapper(W, ofm_shape, mh, ifm, weights, idt, wdt):
def prepare_inputs(input_tensor):
return {"ifm": input_tensor}

@pytest.mark.parametrize("mh", [31])
@pytest.mark.parametrize("mw", [279])
#@pytest.mark.parametrize("pe", [1,2,4,8])
@pytest.mark.parametrize("pe", [31])
#@pytest.mark.parametrize("simd", [1,3,6,9,18,36])
@pytest.mark.parametrize("simd", [9])
@pytest.mark.parametrize("mh", [4])
# @pytest.mark.parametrize("mw", [36])
@pytest.mark.parametrize("mw", [18])
# @pytest.mark.parametrize("pe", [1,2,4,8])
@pytest.mark.parametrize("pe", [2])
# @pytest.mark.parametrize("simd", [1,3,6,9,18,36])
@pytest.mark.parametrize("simd", [6])
#@pytest.mark.parametrize("idt", [DataType["UINT4"], DataType["UINT8"]])
@pytest.mark.parametrize("idt", [DataType["UINT8"]])
#@pytest.mark.parametrize("wdt", [DataType["INT4"], DataType["INT6"]])
Expand Down Expand Up @@ -121,13 +122,19 @@ def test_fpgadataflow_mvau_rtl(mh, mw, pe, simd, idt, wdt, part, segmentlen):
[mw, mh]
)
W = gen_finn_dt_tensor(wdt, (mw, mh))
# np.save("weights.npy", W)
##
W = np.load("weights.npy")
model = make_single_matmul_modelwrapper(W, ofm_shape, mh, ifm, weights, idt, wdt)
model = model.transform(GiveUniqueNodeNames())

model.save(build_dir+"/matmul.onnx")

# Create MatMul & obtain golden reference output
A = gen_finn_dt_tensor(model.get_tensor_datatype("ifm"), model.get_tensor_shape("ifm"))
# np.save("activations.npy", A)
##
# A = np.load("activations.npy")
input_dict = prepare_inputs(A)

## Execute ONNX model
Expand Down Expand Up @@ -198,5 +205,6 @@ def test_fpgadataflow_mvau_rtl(mh, mw, pe, simd, idt, wdt, part, segmentlen):
# model = model.transform(CreateStitchedIP(fpgapart=part, clk_ns=clk_ns, vitis=True))
# model.save(build_dir+"/stitched_ip.onnx")

assert (output_mvau_hls == output_mvau_rtl).all()
#assert (output_mvau_hls == output_mvau_rtl).all()
assert (output_matmul['ofm'] == output_mvau_rtl).all()
# assert (output_mvau_hls.size > 0)

0 comments on commit 0ed3681

Please sign in to comment.