From aea2c8e8f16496417cbdea5dac64ce569d0060d1 Mon Sep 17 00:00:00 2001 From: Leon Camus Date: Mon, 4 Apr 2022 17:52:32 +0200 Subject: [PATCH] feat: Add Mixed Array Support --- onnxruntime/src/lib.rs | 30 ++++ onnxruntime/src/session.rs | 308 +++++++++++++++++++++++++++++++++++-- 2 files changed, 325 insertions(+), 13 deletions(-) diff --git a/onnxruntime/src/lib.rs b/onnxruntime/src/lib.rs index 7ad71db4..cdb53dc0 100644 --- a/onnxruntime/src/lib.rs +++ b/onnxruntime/src/lib.rs @@ -154,6 +154,8 @@ use sys::OnnxEnumInt; // Re-export ndarray as it's part of the public API anyway pub use ndarray; +use ndarray::Array; +use crate::tensor::OrtTensor; lazy_static! { // static ref G_ORT: Arc>> = @@ -459,6 +461,34 @@ impl_type_trait!(u64, Uint64); // impl_type_trait!(, Complex128); // impl_type_trait!(, Bfloat16); +#[derive(Debug)] +pub enum TypedArray { + F32(Array), + U8(Array), + I8(Array), + U16(Array), + I16(Array), + I32(Array), + I64(Array), + F64(Array), + U32(Array), + U64(Array), +} + +#[derive(Debug)] +pub enum TypedOrtTensor<'t, D: ndarray::Dimension> { + F32(OrtTensor<'t, f32, D>), + U8(OrtTensor<'t, u8, D>), + I8(OrtTensor<'t, i8, D>), + U16(OrtTensor<'t, u16, D>), + I16(OrtTensor<'t, i16, D>), + I32(OrtTensor<'t, i32, D>), + I64(OrtTensor<'t, i64, D>), + F64(OrtTensor<'t, f64, D>), + U32(OrtTensor<'t, u32, D>), + U64(OrtTensor<'t, u64, D>), +} + /// Adapter for common Rust string types to Onnx strings. /// /// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 9a846c01..82172855 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -30,6 +30,7 @@ use crate::{ }, AllocatorType, GraphOptimizationLevel, MemType, TensorElementDataType, TypeToTensorElementDataType, + TypedArray, TypedOrtTensor }; #[cfg(feature = "model-fetching")] @@ -475,6 +476,131 @@ impl<'a> Session<'a> { outputs } + /// Run the input data through the ONNX graph, performing inference. + /// + /// Note that ONNX models can have multiple inputs; a `Vec` is thus + /// used for the input data here. + pub fn run_mixed<'s, 't, 'm, TOut, D>( + &'s mut self, + input_arrays: Vec>, + ) -> Result>> + where + TOut: TypeToTensorElementDataType + Debug + Clone, + D: ndarray::Dimension, + 'm: 't, // 'm outlives 't (memory info outlives tensor) + 's: 'm, // 's outlives 'm (session outlives memory info) + { + self.validate_untyped_input_shapes(&input_arrays)?; + + // Build arguments to Run() + + let input_names_ptr: Vec<*const i8> = self + .inputs + .iter() + .map(|input| input.name.clone()) + .map(|n| CString::new(n).unwrap()) + .map(|n| n.into_raw() as *const i8) + .collect(); + + let output_names_cstring: Vec = self + .outputs + .iter() + .map(|output| output.name.clone()) + .map(|n| CString::new(n).unwrap()) + .collect(); + let output_names_ptr: Vec<*const i8> = output_names_cstring + .iter() + .map(|n| n.as_ptr() as *const i8) + .collect(); + + let mut output_tensor_extractors_ptrs: Vec<*mut sys::OrtValue> = + vec![std::ptr::null_mut(); self.outputs.len()]; + + // The C API expects pointers for the arrays (pointers to C-arrays) + let input_ort_tensors: Vec> = input_arrays + .into_iter() + .map(|input_array| { + match input_array { + TypedArray::F32(input_array) => OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array).map(|t| TypedOrtTensor::F32(t)), + TypedArray::U8(input_array) => OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array).map(|t| TypedOrtTensor::U8(t)), + TypedArray::I8(input_array) => OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array).map(|t| TypedOrtTensor::I8(t)), + TypedArray::U16(input_array) => OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array).map(|t| TypedOrtTensor::U16(t)), + TypedArray::I16(input_array) => OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array).map(|t| TypedOrtTensor::I16(t)), + TypedArray::I32(input_array) => OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array).map(|t| TypedOrtTensor::I32(t)), + TypedArray::I64(input_array) => OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array).map(|t| TypedOrtTensor::I64(t)), + TypedArray::F64(input_array) => OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array).map(|t| TypedOrtTensor::F64(t)), + TypedArray::U32(input_array) => OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array).map(|t| TypedOrtTensor::U32(t)), + TypedArray::U64(input_array) => OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array).map(|t| TypedOrtTensor::U64(t)), + } + }) + .collect::>>>()?; + let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors + .iter() + .map(|input_array_ort| match input_array_ort { + TypedOrtTensor::F32(input_array_ort) => input_array_ort.c_ptr as *const sys::OrtValue, + TypedOrtTensor::U8(input_array_ort) => input_array_ort.c_ptr as *const sys::OrtValue, + TypedOrtTensor::I8(input_array_ort) => input_array_ort.c_ptr as *const sys::OrtValue, + TypedOrtTensor::U16(input_array_ort) => input_array_ort.c_ptr as *const sys::OrtValue, + TypedOrtTensor::I16(input_array_ort) => input_array_ort.c_ptr as *const sys::OrtValue, + TypedOrtTensor::I32(input_array_ort) => input_array_ort.c_ptr as *const sys::OrtValue, + TypedOrtTensor::I64(input_array_ort) => input_array_ort.c_ptr as *const sys::OrtValue, + TypedOrtTensor::F64(input_array_ort) => input_array_ort.c_ptr as *const sys::OrtValue, + TypedOrtTensor::U32(input_array_ort) => input_array_ort.c_ptr as *const sys::OrtValue, + TypedOrtTensor::U64(input_array_ort) => input_array_ort.c_ptr as *const sys::OrtValue, + }) + .collect(); + + let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null(); + + let status = unsafe { + g_ort().Run.unwrap()( + self.session_ptr, + run_options_ptr, + input_names_ptr.as_ptr(), + input_ort_values.as_ptr(), + input_ort_values.len(), + output_names_ptr.as_ptr(), + output_names_ptr.len(), + output_tensor_extractors_ptrs.as_mut_ptr(), + ) + }; + status_to_result(status).map_err(OrtError::Run)?; + + let memory_info_ref = &self.memory_info; + let outputs: Result>>> = + output_tensor_extractors_ptrs + .into_iter() + .map(|ptr| { + let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = + std::ptr::null_mut(); + let status = unsafe { + g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _) + }; + status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?; + let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) }; + unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) }; + let dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect(); + + let mut output_tensor_extractor = + OrtOwnedTensorExtractor::new(memory_info_ref, ndarray::IxDyn(&dims)); + output_tensor_extractor.tensor_ptr = ptr; + output_tensor_extractor.extract::() + }) + .collect(); + + // Reconvert to CString so drop impl is called and memory is freed + let cstrings: Result> = input_names_ptr + .into_iter() + .map(|p| { + assert_not_null_pointer(p, "i8 for CString")?; + unsafe { Ok(CString::from_raw(p as *mut i8)) } + }) + .collect(); + cstrings?; + + outputs + } + // pub fn tensor_from_array<'a, 'b, T, D>(&'a self, array: Array) -> Tensor<'b, T, D> // where // 'a: 'b, // 'a outlives 'b @@ -482,11 +608,62 @@ impl<'a> Session<'a> { // Tensor::from_array(self, array) // } - fn validate_input_shapes(&mut self, input_arrays: &[Array]) -> Result<()> + fn validate_input_shape( + &self, + input_array: &Array, + input: &Input, + input_arrays_dimensions: &Vec>, + input_arrays: &[DEB], + ) -> Result<()> where TIn: TypeToTensorElementDataType + Debug + Clone, D: ndarray::Dimension, + DEB: Debug, { + // Verify length + if input_array.shape().len() != input.dimensions.len() { + error!( + "Different input lengths: {:?} vs {:?}", + self.inputs, input_arrays + ); + return Err(OrtError::NonMatchingDimensions( + NonMatchingDimensionsError::InputsLength { + inference_input: input_arrays_dimensions.clone(), + model_input: self + .inputs + .iter() + .map(|input| input.dimensions.clone()) + .collect(), + }, + )); + } + + // Verify shape + let inputs_different_shape = input_array.shape().iter().zip(input.dimensions.iter()).any(|(l2, r2)| match r2 { + Some(r3) => *r3 as usize != *l2, + None => false, // None means dynamic size; in that case shape always match + }); + if inputs_different_shape { + error!( + "Different input lengths: {:?} vs {:?}", + self.inputs, input_arrays + ); + return Err(OrtError::NonMatchingDimensions( + NonMatchingDimensionsError::InputsLength { + inference_input: input_arrays_dimensions.clone(), + model_input: self + .inputs + .iter() + .map(|input| input.dimensions.clone()) + .collect(), + }, + )); + } + + Ok(()) + } + + fn validate_untyped_input_shapes(&mut self, input_arrays: &[TypedArray]) -> Result<()> { // ****************************************************************** // FIXME: Properly handle errors here // Make sure all dimensions match (except dynamic ones) @@ -504,7 +681,18 @@ impl<'a> Session<'a> { model_input_count: 0, inference_input: input_arrays .iter() - .map(|input_array| input_array.shape().to_vec()) + .map(|input_array| match input_array { + TypedArray::F32(array) => array.shape().to_vec(), + TypedArray::U8(array) => array.shape().to_vec(), + TypedArray::I8(array) => array.shape().to_vec(), + TypedArray::U16(array) => array.shape().to_vec(), + TypedArray::I16(array) => array.shape().to_vec(), + TypedArray::I32(array) => array.shape().to_vec(), + TypedArray::I64(array) => array.shape().to_vec(), + TypedArray::F64(array) => array.shape().to_vec(), + TypedArray::U32(array) => array.shape().to_vec(), + TypedArray::U64(array) => array.shape().to_vec(), + }) .collect(), model_input: self .inputs @@ -519,7 +707,18 @@ impl<'a> Session<'a> { let inputs_different_length = input_arrays .iter() .zip(self.inputs.iter()) - .any(|(l, r)| l.shape().len() != r.dimensions.len()); + .any(|(l, r)| match l { + TypedArray::F32(array) => array.shape().len(), + TypedArray::U8(array) => array.shape().len(), + TypedArray::I8(array) => array.shape().len(), + TypedArray::U16(array) => array.shape().len(), + TypedArray::I16(array) => array.shape().len(), + TypedArray::I32(array) => array.shape().len(), + TypedArray::I64(array) => array.shape().len(), + TypedArray::F64(array) => array.shape().len(), + TypedArray::U32(array) => array.shape().len(), + TypedArray::U64(array) => array.shape().len(), + } != r.dimensions.len()); if inputs_different_length { error!( "Different input lengths: {:?} vs {:?}", @@ -529,7 +728,18 @@ impl<'a> Session<'a> { NonMatchingDimensionsError::InputsLength { inference_input: input_arrays .iter() - .map(|input_array| input_array.shape().to_vec()) + .map(|input_array| match input_array { + TypedArray::F32(array) => array.shape().to_vec(), + TypedArray::U8(array) => array.shape().to_vec(), + TypedArray::I8(array) => array.shape().to_vec(), + TypedArray::U16(array) => array.shape().to_vec(), + TypedArray::I16(array) => array.shape().to_vec(), + TypedArray::I32(array) => array.shape().to_vec(), + TypedArray::I64(array) => array.shape().to_vec(), + TypedArray::F64(array) => array.shape().to_vec(), + TypedArray::U32(array) => array.shape().to_vec(), + TypedArray::U64(array) => array.shape().to_vec(), + }) .collect(), model_input: self .inputs @@ -540,16 +750,79 @@ impl<'a> Session<'a> { )); } - // Verify shape of each individual inputs - let inputs_different_shape = input_arrays.iter().zip(self.inputs.iter()).any(|(l, r)| { - let l_shape = l.shape(); - let r_shape = r.dimensions.as_slice(); - l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 { - Some(r3) => *r3 as usize != *l2, - None => false, // None means dynamic size; in that case shape always match + let input_arrays_dimensions = input_arrays + .iter() + .map(|input_array| match input_array { + TypedArray::F32(input_array) => input_array.shape().to_vec(), + TypedArray::U8(input_array) => input_array.shape().to_vec(), + TypedArray::I8(input_array) => input_array.shape().to_vec(), + TypedArray::U16(input_array) => input_array.shape().to_vec(), + TypedArray::I16(input_array) => input_array.shape().to_vec(), + TypedArray::I32(input_array) => input_array.shape().to_vec(), + TypedArray::I64(input_array) => input_array.shape().to_vec(), + TypedArray::F64(input_array) => input_array.shape().to_vec(), + TypedArray::U32(input_array) => input_array.shape().to_vec(), + TypedArray::U64(input_array) => input_array.shape().to_vec(), }) - }); - if inputs_different_shape { + .collect(); + + for (input_array, input) in input_arrays.iter().zip(self.inputs.iter()) { + match input_array { + TypedArray::F32(input_array) => self.validate_input_shape(input_array, input, &input_arrays_dimensions, &input_arrays), + TypedArray::U8(input_array) => self.validate_input_shape(input_array, input, &input_arrays_dimensions, &input_arrays), + TypedArray::I8(input_array) => self.validate_input_shape(input_array, input, &input_arrays_dimensions, &input_arrays), + TypedArray::U16(input_array) => self.validate_input_shape(input_array, input, &input_arrays_dimensions, &input_arrays), + TypedArray::I16(input_array) => self.validate_input_shape(input_array, input, &input_arrays_dimensions, &input_arrays), + TypedArray::I32(input_array) => self.validate_input_shape(input_array, input, &input_arrays_dimensions, &input_arrays), + TypedArray::I64(input_array) => self.validate_input_shape(input_array, input, &input_arrays_dimensions, &input_arrays), + TypedArray::F64(input_array) => self.validate_input_shape(input_array, input, &input_arrays_dimensions, &input_arrays), + TypedArray::U32(input_array) => self.validate_input_shape(input_array, input, &input_arrays_dimensions, &input_arrays), + TypedArray::U64(input_array) => self.validate_input_shape(input_array, input, &input_arrays_dimensions, &input_arrays), + }?; + } + + Ok(()) + } + + fn validate_input_shapes(&mut self, input_arrays: &[Array]) -> Result<()> + where + TIn: TypeToTensorElementDataType + Debug + Clone, + D: ndarray::Dimension, + { + // ****************************************************************** + // FIXME: Properly handle errors here + // Make sure all dimensions match (except dynamic ones) + + // Verify length of inputs + if input_arrays.len() != self.inputs.len() { + error!( + "Non-matching number of inputs: {} (inference) vs {} (model)", + input_arrays.len(), + self.inputs.len() + ); + return Err(OrtError::NonMatchingDimensions( + NonMatchingDimensionsError::InputsCount { + inference_input_count: 0, + model_input_count: 0, + inference_input: input_arrays + .iter() + .map(|input_array| input_array.shape().to_vec()) + .collect(), + model_input: self + .inputs + .iter() + .map(|input| input.dimensions.clone()) + .collect(), + }, + )); + } + + // Verify length of each individual inputs + let inputs_different_length = input_arrays + .iter() + .zip(self.inputs.iter()) + .any(|(l, r)| l.shape().len() != r.dimensions.len()); + if inputs_different_length { error!( "Different input lengths: {:?} vs {:?}", self.inputs, input_arrays @@ -569,6 +842,15 @@ impl<'a> Session<'a> { )); } + let input_arrays_dimensions = input_arrays + .iter() + .map(|input_array| input_array.shape().to_vec()) + .collect(); + + for (input_array, input) in input_arrays.iter().zip(self.inputs.iter()) { + self.validate_input_shape(input_array, input, &input_arrays_dimensions, input_arrays)?; + } + Ok(()) } }