diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index 15dcb93a8f679..8fee94ebe76a3 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -788,7 +788,7 @@ def main(): logger.error("fp16 is for GPU only") return - if args.precision == Precision.INT8 and args.use_gpu and args.provider != "migraphx": + if args.precision == Precision.INT8 and args.use_gpu and args.provider not in ["migraphx", "rocm"]: logger.error("int8 is for CPU only") return