Unable to run batch inference with C++ #23450

Closed
davave1693 opened this issue Jan 21, 2025 · 9 comments

@davave1693

Describe the issue

I am trying to run my object detection model in C++.

The model is in ONNX format and has a dynamic batch size on its input. I was able to run preprocessing and inference in Python, but I am encountering difficulties in C++.

When running:

		mod.session.Run(Ort::RunOptions{ nullptr },
			mod.input_names.data(),
			input_tensor.data(),
			mod.num_input_nodes,
			mod.output_names.data(),
			output_tensors.data(),
			mod.num_output_nodes);

The following error occurs:

[screenshot of the error message]

To reproduce

model.h

#pragma once
#include <string>
#include "onnxruntime_cxx_api.h"
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/opencv.hpp>

struct OnnxENV {
	Ort::Env env = Ort::Env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "Default");
};

class model
{
private:
	std::string model_path = "";
	int model_width = 416;
	int model_height = 416;
	bool use_cuda = false;
	bool use_trt = true;
	int batch_size;

	Ort::SessionOptions session_options;
	OrtCUDAProviderOptions cuda_options;
	OrtTensorRTProviderOptions trt_options;
	Ort::MemoryInfo memory_info{ nullptr };



	void transpose_and_convert(cv::Mat& img_in);

public:
	void set_session_options();
	void initialize(OnnxENV* env, std::string model_path);
	void SetUseCuda();
	void SetUseTRT();
	void preprocessing(cv::Mat& image);

	Ort::Session session = Ort::Session(nullptr);
	
	size_t num_input_nodes;
	size_t num_output_nodes;

	std::vector<std::vector<int64_t>> input_dims;
	std::vector<int64_t> input_node_dims;
	std::vector<const char*> input_names;
	std::vector<const char*> output_names;
	size_t input_node_size;
	std::vector<size_t> input_sizes;
};

model.cpp

#include "model.h"



void model::set_session_options()
{
	session_options.SetInterOpNumThreads(1);
	session_options.SetIntraOpNumThreads(1);
	// Optimization will take time and memory during startup
	session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL);

	SetUseTRT();
}

void model::initialize(OnnxENV* env, std::string model_path)
{
	batch_size = 1;

	std::wstring widestr = std::wstring(model_path.begin(), model_path.end());
	const wchar_t* widecstr = widestr.c_str();

	session = Ort::Session(env->env, widecstr, session_options);

	num_input_nodes = session.GetInputCount();
	num_output_nodes = session.GetOutputCount();

	input_node_dims =
		session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();

	//if (input_node_dims[0] != -1 || input_node_dims[2] != -1 || input_node_dims[3] != -1)
	//	throw std::invalid_argument("Model was not exported with dynamic shape");
	input_node_dims[0] = batch_size;

	input_dims.push_back(input_node_dims);

	
	Ort::AllocatorWithDefaultOptions allocator;
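	// Note: GetInputNameAllocated returns an owning smart pointer; taking .get()
	// from the temporary below leaves input_names holding a freed pointer.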
	input_names.push_back(session.GetInputNameAllocated(0, allocator).get());
	input_names.push_back(nullptr);

	
	for (size_t i = 0; i < num_output_nodes; i++)
		output_names.push_back(session.GetOutputNameAllocated(i, allocator).get());
	output_names.push_back(nullptr);

	input_node_size =
		input_node_dims[0] * input_node_dims[1] * input_node_dims[2] * input_node_dims[3];

	input_sizes.push_back(input_node_size);
}

void model::SetUseCuda() {
	cuda_options.device_id = 0;  //GPU_ID
	cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchExhaustive; // Algo to search for Cudnn
	cuda_options.arena_extend_strategy = 0; //kNextPowerOfTwo


	cuda_options.do_copy_in_default_stream = 1;

	session_options.AppendExecutionProvider_CUDA(cuda_options); // Add CUDA options to session options
}

void model::SetUseTRT() {

	trt_options.device_id = 0;

	trt_options.trt_engine_cache_enable = true;
	trt_options.trt_engine_cache_path = "./trt_cache";

	trt_options.trt_max_partition_iterations = 1000;
	trt_options.trt_min_subgraph_size = 1;

	session_options.AppendExecutionProvider_TensorRT(trt_options);
}

void model::preprocessing(cv::Mat& image)
{
	cv::resize(image, image, cv::Size(416, 416));
	cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
	transpose_and_convert(image);

	// Add a new dimension at the beginning
	cv::dnn::blobFromImage(image, image, 1.0 / 255.0, cv::Size(416, 416));

}

void model::transpose_and_convert(cv::Mat& img_in) {
	cv::Mat img_out;
	img_in.convertTo(img_in, CV_32FC3); // Convert to float32
	cv::transpose(img_in, img_in); // Transpose
}

main.cpp

#include <opencv2/videoio.hpp>
#include "model.h"
#include "opencv2/core.hpp"

int main()
{
	OnnxENV Env;
	model mod;

	std::string model_path = "model.onnx";
	mod.initialize(&Env, model_path);

	std::string video_path = "video.mp4";
	cv::VideoCapture cap(video_path);
	while (true)
	{
		cv::Mat bgr_image;
		cap.read(bgr_image);

		//preprocess
		std::vector<cv::Mat> preprocessed;

		mod.preprocessing(bgr_image);
		preprocessed.push_back(bgr_image);

		std::vector<float> input_tensor_values(mod.input_node_size);
		input_tensor_values.assign(bgr_image.begin<float>(), bgr_image.end<float>());

		std::vector<Ort::Value> input_tensor;
		Ort::MemoryInfo memory_info =
			Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator,
				OrtMemType::OrtMemTypeDefault);
		input_tensor.emplace_back(Ort::Value::CreateTensor<float>(memory_info,
			input_tensor_values.data(),
			input_tensor_values.size(),
			mod.input_node_dims.data(),
			mod.input_node_dims.size()));

		std::vector<Ort::Value> output_tensors;

		for (size_t i = 0; i < mod.num_output_nodes; i++)
			output_tensors.emplace_back(nullptr);

		mod.session.Run(Ort::RunOptions{ nullptr },
			mod.input_names.data(),
			input_tensor.data(),
			mod.num_input_nodes,
			mod.output_names.data(),
			output_tensors.data(),
			mod.num_output_nodes);
		
	}

	return 0;
}

Below is the Python version:

import onnxruntime as ort
import numpy as np
import cv2
import torch  # needed below for torch.cuda.is_available()

def PrepareImageForDetection(frame, im_w, im_h):

    # Input
    resized = cv2.resize(frame, (im_w, im_h), interpolation=cv2.INTER_LINEAR)
    img_in = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    img_in = np.transpose(img_in, (2, 0, 1)).astype(np.float32)
    img_in = np.expand_dims(img_in, axis=0)
    img_in /= 255.0
    return img_in

tensorrt_enabled = True
ort_provider = []  # defined outside the if so it exists even without CUDA

if torch.cuda.is_available():

    if tensorrt_enabled:
        ort_provider += [('TensorrtExecutionProvider',
         {
            'trt_engine_cache_enable' : True,
            'trt_engine_cache_path' : './trt_cache/'
         })]

    ort_provider += [
        ('CUDAExecutionProvider', {
            'device_id': 0,
            'arena_extend_strategy': 'kNextPowerOfTwo',
            # 'gpu_mem_limit': 7 * 1024 * 1024 * 1024,
            'cudnn_conv_algo_search': 'EXHAUSTIVE',
            'do_copy_in_default_stream': True,
        })
    ]

ort_provider += ['CPUExecutionProvider']

batch_size = 1
preprocessed_batch_imgs = np.empty((batch_size, 3, 416, 416), dtype=np.float32)

session_options = ort.SessionOptions()
session = ort.InferenceSession("model.onnx",sess_options=session_options ,providers=ort_provider)

input_name = session.get_inputs()[0].name  
input_type = session.get_inputs()[0].type
input_shape = session.get_inputs()[0].shape

image = cv2.imread("image.png")

raw_batch_imgs = [image]  # batch_size = 1
preprocessed_batch_imgs[0] = PrepareImageForDetection(raw_batch_imgs[0], 416, 416)[0]

inputs = {input_name: preprocessed_batch_imgs}
result = session.run(None, inputs)

Urgency

Blocking problem for project

Platform

Windows

OS Version

11 Pro

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.20.1

ONNX Runtime API

C++

Architecture

X64

Execution Provider

TensorRT

Execution Provider Library Version

10.5.0.18

@eKevinHoang

It seems that the input data has not been handled correctly.
The input for batch inference must be a tensor in the format (batch_size, channels, height, width); for a single 416×416 RGB image, that is (1, 3, 416, 416).

You can refer to the Python code here for guidance:
https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/predictor.py

@davave1693 (Author)

Thank you, @eKevinHoang.

The fact is that I am already able to run batch inference in Python, but I cannot understand what's wrong in C++.

I am honestly struggling to find both documentation and a working example of batch inference for networks that accept images as input.
The only example I found was in the onnxruntime-inference-examples repo, but that example uses experimental_onnxruntime_cxx_api.h, and:

  • that file is not included in the released package
  • I read somewhere that it is not recommended to include that header instead of onnxruntime_cxx_api.h

How can I inspect the shape of bgr_image and input_tensor in my example?

Thank you

@rodrigovimieiro

I am honestly struggling to find both documentation and a working example of batch inference for networks that accept images as input.

Check this one: ONNX-Runtime-GPU-image-classifciation-example.

@eKevinHoang

@davave1693 You can refer to the following example code. Based on your Python code, you also need to perform cv::resize and cv::cvtColor before calling this function.

Ort::Value prepareInputTensor(const std::vector<cv::Mat>& images, int batchSize, int channels, int height, int width) {
    // Initializes a vector to hold the input data.
    std::vector<float> inputData(batchSize * channels * height * width);

    for (int i = 0; i < batchSize; ++i) {
        const cv::Mat& img = images[i];
        for (int c = 0; c < channels; ++c) {
            for (int h = 0; h < height; ++h) {
                for (int w = 0; w < width; ++w) {
                    // normalize the pixel values to the range [0, 1]
                    inputData[i * channels * height * width + c * height * width + h * width + w] =
                        img.at<cv::Vec3b>(h, w)[c] / 255.0f;
                }
            }
        }
    }

    Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    std::vector<int64_t> inputShape = {batchSize, channels, height, width};
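    // Caveat: CreateTensor does not copy the data; the returned tensor points
    // into inputData, so the vector must outlive the returned Ort::Value.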
    return Ort::Value::CreateTensor<float>(memoryInfo, inputData.data(), inputData.size(), inputShape.data(), inputShape.size());
}
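
As for inspecting the shapes you asked about, a quick sketch (requires #include <iostream>; cv::Mat exposes dims/size, and an Ort::Value exposes GetTensorTypeAndShapeInfo()):

// Sketch: print the shape of bgr_image and of the first input tensor.
// After blobFromImage, bgr_image is 4-D, so use dims/size rather than rows/cols.
std::cout << "bgr_image shape: ";
for (int d = 0; d < bgr_image.dims; ++d)
    std::cout << bgr_image.size[d] << " ";
std::cout << std::endl;

std::vector<int64_t> shape = input_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
std::cout << "input_tensor shape: ";
for (int64_t s : shape)
    std::cout << s << " ";
std::cout << std::endl;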

@davave1693 (Author)

Thanks @rodrigovimieiro and @eKevinHoang.

@rodrigovimieiro I will try to compile and run the example at your link and then let you know if I succeed.

@eKevinHoang, I updated my preprocessing method as you suggested:

Ort::Value model::preprocessing(const cv::Mat& image)
{
	cv::Mat new_img;
	cv::resize(image, new_img, cv::Size(416, 416));
	cv::cvtColor(new_img, new_img, cv::COLOR_BGR2RGB);

	Ort::Value ort_value = prepareInputTensor(image, 1, 3, 416, 416);

	return ort_value;
}

The line Ort::Value ort_value = prepareInputTensor(image, 1, 3, 416, 416); raises an error when trying to copy the method output. Do you know why I cannot return an Ort::Value?

The same error occurs if I return a std::vector and call (in main):

  auto inputData = mod.preprocessing(bgr_image);
  Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  std::vector<int64_t> inputShape = { 1, 3, 416, 416 };
  input_tensor.emplace_back( Ort::Value::CreateTensor<float>(memoryInfo, inputData.data(), inputData.size(), inputShape.data(), inputShape.size()));

@davave1693 (Author)

Hello @rodrigovimieiro, I was able to build the example, but I get the exact same error

[screenshot of the error message]

when I call

 mSession->Run(Ort::RunOptions{nullptr}, inputNames.data(),
               inputTensors.data(), 1, outputNames.data(),
               outputTensors.data(), 1);

in image_classifier.cpp

@eKevinHoang commented Jan 22, 2025

@davave1693 I think you are using a C++ version earlier than C++17. I am currently using C++17. 😄

To simplify testing, you can use the following code:
cv::Mat blob = cv::dnn::blobFromImages(images, 1.0 / 255.0, target_size, cv::Scalar(), true);

Then create the tensor as follows:
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, blob.ptr<float>(), blob.total(), input_shape.data(), input_shape.size());

This way, you don’t need to perform cv2.resize or cv2.cvtColor.
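
Putting the two lines together, a minimal sketch (bgr_image, target_size, input_shape, and memory_info are placeholders for your own variables):

// Sketch: blobFromImages resizes, swaps BGR->RGB (swapRB=true), scales to
// [0, 1], and packs the batch into a single NCHW float blob in one call.
std::vector<cv::Mat> images = { bgr_image };   // batch of 1
cv::Size target_size(416, 416);
cv::Mat blob = cv::dnn::blobFromImages(images, 1.0 / 255.0, target_size,
                                       cv::Scalar(), /*swapRB=*/true);

std::vector<int64_t> input_shape = { (int64_t)images.size(), 3, 416, 416 };
Ort::MemoryInfo memory_info =
    Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
// Note: the tensor references blob's data, so blob must stay alive while it is used.
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
    memory_info, blob.ptr<float>(), blob.total(), input_shape.data(), input_shape.size());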

I hope this helps you test the results more effectively.

@yuslepukhin (Member)

Catch the exception and print its message. C++ terminates the program on an unhandled exception.
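
For example, a minimal sketch around the Run call from your main.cpp (Ort::Exception derives from std::exception; add #include <iostream>):

try {
    mod.session.Run(Ort::RunOptions{ nullptr },
        mod.input_names.data(), input_tensor.data(), mod.num_input_nodes,
        mod.output_names.data(), output_tensors.data(), mod.num_output_nodes);
}
catch (const Ort::Exception& e) {
    std::cerr << "ONNX Runtime error: " << e.what() << std::endl;
}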

@davave1693 (Author)

@eKevinHoang yes, I am currently using C++14.

Anyway, the preprocessing method you provided made my inference work!

I just had to modify it so that a vector is returned from preprocessing instead of an Ort::Value.
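
For reference, a sketch of the pattern that worked (the caller keeps the buffer alive and creates the tensor itself; sizes are hard-coded for my 416x416, 3-channel model):

// preprocessing now returns the raw float buffer instead of an Ort::Value,
// so the pixel data outlives the tensor created in main.
std::vector<float> model::preprocessing(const cv::Mat& image)
{
	cv::Mat img;
	cv::resize(image, img, cv::Size(416, 416));
	cv::cvtColor(img, img, cv::COLOR_BGR2RGB);

	std::vector<float> inputData(1 * 3 * 416 * 416);
	for (int c = 0; c < 3; ++c)          // NCHW layout, normalized to [0, 1]
		for (int h = 0; h < 416; ++h)
			for (int w = 0; w < 416; ++w)
				inputData[c * 416 * 416 + h * 416 + w] =
					img.at<cv::Vec3b>(h, w)[c] / 255.0f;
	return inputData;
}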

Thank you again!
