diff --git a/onnxruntime/examples/issue22.rs b/onnxruntime/examples/issue22.rs index b2879b91..9dbd5d5b 100644 --- a/onnxruntime/examples/issue22.rs +++ b/onnxruntime/examples/issue22.rs @@ -34,7 +34,12 @@ fn main() { let input_ids = Array2::<i64>::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap(); let attention_mask = Array2::<i64>::from_shape_vec((1, 3), vec![1, 1, 1]).unwrap(); - let outputs: Vec<OrtOwnedTensor<f32, _>> = - session.run(vec![input_ids, attention_mask]).unwrap(); + let outputs: Vec<OrtOwnedTensor<f32, _>> = session + .run(vec![input_ids, attention_mask]) + .unwrap() + .into_iter() + .map(|dyn_tensor| dyn_tensor.try_extract()) + .collect::<Result<_, _>>() + .unwrap(); print!("outputs: {:#?}", outputs); } diff --git a/onnxruntime/examples/print_structure.rs b/onnxruntime/examples/print_structure.rs index d86a0e54..12d81e62 100644 --- a/onnxruntime/examples/print_structure.rs +++ b/onnxruntime/examples/print_structure.rs @@ -5,8 +5,7 @@ use std::error::Error; fn main() -> Result<(), Box<dyn Error>> { // provide path to .onnx model on disk let path = std::env::args() - .skip(1) - .next() + .nth(1) .expect("Must provide an .onnx file as the first arg"); let environment = environment::Environment::builder() diff --git a/onnxruntime/examples/sample.rs b/onnxruntime/examples/sample.rs index d16d08da..5563ab4b 100644 --- a/onnxruntime/examples/sample.rs +++ b/onnxruntime/examples/sample.rs @@ -1,8 +1,10 @@ #![forbid(unsafe_code)] use onnxruntime::{ - environment::Environment, ndarray::Array, tensor::OrtOwnedTensor, GraphOptimizationLevel, - LoggingLevel, + environment::Environment, + ndarray::Array, + tensor::{DynOrtTensor, OrtOwnedTensor}, + GraphOptimizationLevel, LoggingLevel, }; use tracing::Level; use tracing_subscriber::FmtSubscriber; @@ -61,11 +63,12 @@ fn run() -> Result<(), Error> { .unwrap(); let input_tensor_values = vec![array]; - let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor_values)?; + let outputs: Vec<DynOrtTensor<_>> = 
session.run(input_tensor_values)?; - assert_eq!(outputs[0].shape(), output0_shape.as_slice()); + let output: OrtOwnedTensor<f32, _> = outputs[0].try_extract().unwrap(); + assert_eq!(output.view().shape(), output0_shape.as_slice()); for i in 0..5 { - println!("Score for class [{}] = {}", i, outputs[0][[0, i, 0, 0]]); + println!("Score for class [{}] = {}", i, output.view()[[0, i, 0, 0]]); } Ok(()) diff --git a/onnxruntime/src/error.rs b/onnxruntime/src/error.rs index f49613fe..a4e2b7a1 100644 --- a/onnxruntime/src/error.rs +++ b/onnxruntime/src/error.rs @@ -1,6 +1,6 @@ //! Module containing error definitions. -use std::{io, path::PathBuf}; +use std::{io, path::PathBuf, string}; use thiserror::Error; @@ -53,6 +53,12 @@ pub enum OrtError { /// Error occurred when getting ONNX dimensions #[error("Failed to get dimensions: {0}")] GetDimensions(OrtApiError), + /// Error occurred when getting string length + #[error("Failed to get string tensor length: {0}")] + GetStringTensorDataLength(OrtApiError), + /// Error occurred when getting tensor element count + #[error("Failed to get tensor element count: {0}")] + GetTensorShapeElementCount(OrtApiError), /// Error occurred when creating CPU memory information #[error("Failed to get dimensions: {0}")] CreateCpuMemoryInfo(OrtApiError), @@ -77,6 +83,12 @@ pub enum OrtError { /// Error occurred when extracting data from an ONNX tensor into an C array to be used as an `ndarray::ArrayView` #[error("Failed to get tensor data: {0}")] GetTensorMutableData(OrtApiError), + /// Error occurred when extracting string data from an ONNX tensor + #[error("Failed to get tensor string data: {0}")] + GetStringTensorContent(OrtApiError), + /// Error occurred when converting data to a String + #[error("Data was not UTF-8: {0}")] + StringFromUtf8Error(#[from] string::FromUtf8Error), /// Error occurred when downloading a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models) #[error("Failed to download ONNX model: {0}")] @@ 
-108,16 +120,16 @@ pub enum OrtError { #[derive(Error, Debug)] pub enum NonMatchingDimensionsError { /// Number of inputs from model does not match number of inputs from inference call - #[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?} for model (inputs: {inference_input:?}, model: {model_input:?})")] + #[error("Non-matching number of inputs: {inference_input_count:?} for input vs {model_input_count:?}")] InputsCount { /// Number of input dimensions used by inference call inference_input_count: usize, /// Number of input dimensions defined in model model_input_count: usize, - /// Input dimensions used by inference call - inference_input: Vec<Vec<usize>>, - /// Input dimensions defined in model - model_input: Vec<Vec<Option<u32>>>, + // Input dimensions used by inference call + // inference_input: Vec<Vec<usize>>, + // Input dimensions defined in model + // model_input: Vec<Vec<Option<u32>>>, }, } diff --git a/onnxruntime/src/lib.rs b/onnxruntime/src/lib.rs index 0d575b5e..6ae7c333 100644 --- a/onnxruntime/src/lib.rs +++ b/onnxruntime/src/lib.rs @@ -104,7 +104,10 @@ to download. //! let array = ndarray::Array::linspace(0.0_f32, 1.0, 100); //! // Multiple inputs and outputs are possible //! let input_tensor = vec![array]; -//! let outputs: Vec<OrtOwnedTensor<f32,_>> = session.run(input_tensor)?; +//! let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(input_tensor)? +//! .into_iter() +//! .map(|dyn_tensor| dyn_tensor.try_extract()) +//! .collect::<Result<_, _>>()?; //! # Ok(()) //! # } //! ``` @@ -115,7 +118,10 @@ to download. //! See the [`sample.rs`](https://github.com/nbigaouette/onnxruntime-rs/blob/master/onnxruntime/examples/sample.rs) //! example for more details. -use std::sync::{atomic::AtomicPtr, Arc, Mutex}; +use std::{ + ffi, ptr, + sync::{atomic::AtomicPtr, Arc, Mutex}, +}; use lazy_static::lazy_static; @@ -142,7 +148,7 @@ lazy_static! 
{ // } as *mut sys::OrtApi))); static ref G_ORT_API: Arc<Mutex<AtomicPtr<sys::OrtApi>>> = { let base: *const sys::OrtApiBase = unsafe { sys::OrtGetApiBase() }; - assert_ne!(base, std::ptr::null()); + assert_ne!(base, ptr::null()); let get_api: unsafe extern "C" fn(u32) -> *const onnxruntime_sys::OrtApi = unsafe { (*base).GetApi.unwrap() }; let api: *const sys::OrtApi = unsafe { get_api(sys::ORT_API_VERSION) }; @@ -157,13 +163,13 @@ fn g_ort() -> sys::OrtApi { let api_ref_mut: &mut *mut sys::OrtApi = api_ref.get_mut(); let api_ptr_mut: *mut sys::OrtApi = *api_ref_mut; - assert_ne!(api_ptr_mut, std::ptr::null_mut()); + assert_ne!(api_ptr_mut, ptr::null_mut()); unsafe { *api_ptr_mut } } fn char_p_to_string(raw: *const i8) -> Result<String> { - let c_string = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8).to_owned() }; + let c_string = unsafe { ffi::CStr::from_ptr(raw as *mut i8).to_owned() }; match c_string.into_string() { Ok(string) => Ok(string), @@ -176,7 +182,7 @@ mod onnxruntime { //! Module containing a custom logger, used to catch the runtime's own logging and send it //! to Rust's tracing logging instead. - use std::ffi::CStr; + use std::{ffi, ffi::CStr, ptr}; use tracing::{debug, error, info, span, trace, warn, Level}; use onnxruntime_sys as sys; @@ -212,7 +218,7 @@ mod onnxruntime { /// Callback from C that will handle the logging, forwarding the runtime's logs to the tracing crate. 
pub(crate) extern "C" fn custom_logger( - _params: *mut std::ffi::c_void, + _params: *mut ffi::c_void, severity: sys::OrtLoggingLevel, category: *const i8, logid: *const i8, @@ -227,16 +233,16 @@ mod onnxruntime { sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => Level::ERROR, }; - assert_ne!(category, std::ptr::null()); + assert_ne!(category, ptr::null()); let category = unsafe { CStr::from_ptr(category) }; - assert_ne!(code_location, std::ptr::null()); + assert_ne!(code_location, ptr::null()); let code_location = unsafe { CStr::from_ptr(code_location) } .to_str() .unwrap_or("unknown"); - assert_ne!(message, std::ptr::null()); + assert_ne!(message, ptr::null()); let message = unsafe { CStr::from_ptr(message) }; - assert_ne!(logid, std::ptr::null()); + assert_ne!(logid, ptr::null()); let logid = unsafe { CStr::from_ptr(logid) }; // Parse the code location @@ -322,154 +328,6 @@ impl Into<sys::GraphOptimizationLevel> for GraphOptimizationLevel { } } -// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum -// FIXME: Add tests to cover the commented out types -/// Enum mapping ONNX Runtime's supported tensor types -#[derive(Debug)] -#[cfg_attr(not(windows), repr(u32))] -#[cfg_attr(windows, repr(i32))] -pub enum TensorElementDataType { - /// 32-bit floating point, equivalent to Rust's `f32` - Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt, - /// Unsigned 8-bit int, equivalent to Rust's `u8` - Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt, - /// Signed 8-bit int, equivalent to Rust's `i8` - Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt, - /// Unsigned 16-bit int, equivalent to Rust's `u16` - Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt, - /// Signed 16-bit int, equivalent to Rust's `i16` - Int16 = 
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt, - /// Signed 32-bit int, equivalent to Rust's `i32` - Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt, - /// Signed 64-bit int, equivalent to Rust's `i64` - Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt, - /// String, equivalent to Rust's `String` - String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt, - // /// Boolean, equivalent to Rust's `bool` - // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt, - // /// 16-bit floating point, equivalent to Rust's `f16` - // Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt, - /// 64-bit floating point, equivalent to Rust's `f64` - Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt, - /// Unsigned 32-bit int, equivalent to Rust's `u32` - Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt, - /// Unsigned 64-bit int, equivalent to Rust's `u64` - Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt, - // /// Complex 64-bit floating point, equivalent to Rust's `???` - // Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt, - // /// Complex 128-bit floating point, equivalent to Rust's `???` - // Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt, - // /// Brain 16-bit floating point - // Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt, -} - -impl Into<sys::ONNXTensorElementDataType> for TensorElementDataType { - fn into(self) -> sys::ONNXTensorElementDataType { - use TensorElementDataType::*; - match self { - Float => 
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, - Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, - Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, - Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, - Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, - Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, - Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, - String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, - // Bool => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL - // } - // Float16 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 - // } - Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, - Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, - Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, - // Complex64 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 - // } - // Complex128 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 - // } - // Bfloat16 => { - // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 - // } - } - } -} - -/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`) -pub trait TypeToTensorElementDataType { - /// Return the ONNX type for a Rust type - fn tensor_element_data_type() -> TensorElementDataType; - - /// If the type is `String`, returns `Some` with utf8 contents, else `None`. - fn try_utf8_bytes(&self) -> Option<&[u8]>; -} - -macro_rules! 
impl_type_trait { - ($type_:ty, $variant:ident) => { - impl TypeToTensorElementDataType for $type_ { - fn tensor_element_data_type() -> TensorElementDataType { - // unsafe { std::mem::transmute(TensorElementDataType::$variant) } - TensorElementDataType::$variant - } - - fn try_utf8_bytes(&self) -> Option<&[u8]> { - None - } - } - }; -} - -impl_type_trait!(f32, Float); -impl_type_trait!(u8, Uint8); -impl_type_trait!(i8, Int8); -impl_type_trait!(u16, Uint16); -impl_type_trait!(i16, Int16); -impl_type_trait!(i32, Int32); -impl_type_trait!(i64, Int64); -// impl_type_trait!(bool, Bool); -// impl_type_trait!(f16, Float16); -impl_type_trait!(f64, Double); -impl_type_trait!(u32, Uint32); -impl_type_trait!(u64, Uint64); -// impl_type_trait!(, Complex64); -// impl_type_trait!(, Complex128); -// impl_type_trait!(, Bfloat16); - -/// Adapter for common Rust string types to Onnx strings. -/// -/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but -/// we can't define an automatic implementation for anything that implements `AsRef<str>` as it -/// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric -/// types (which might implement `AsRef<str>` at some point in the future). -pub trait Utf8Data { - /// Returns the utf8 contents. 
- fn utf8_bytes(&self) -> &[u8]; -} - -impl Utf8Data for String { - fn utf8_bytes(&self) -> &[u8] { - self.as_bytes() - } -} - -impl<'a> Utf8Data for &'a str { - fn utf8_bytes(&self) -> &[u8] { - self.as_bytes() - } -} - -impl<T: Utf8Data> TypeToTensorElementDataType for T { - fn tensor_element_data_type() -> TensorElementDataType { - TensorElementDataType::String - } - - fn try_utf8_bytes(&self) -> Option<&[u8]> { - Some(self.utf8_bytes()) - } -} - /// Allocator type #[derive(Debug, Clone)] #[repr(i32)] @@ -524,7 +382,7 @@ mod test { #[test] fn test_char_p_to_string() { - let s = std::ffi::CString::new("foo").unwrap(); + let s = ffi::CString::new("foo").unwrap(); let ptr = s.as_c_str().as_ptr(); assert_eq!("foo", char_p_to_string(ptr).unwrap()); } diff --git a/onnxruntime/src/session.rs b/onnxruntime/src/session.rs index 04f9cf1c..0f02cb87 100644 --- a/onnxruntime/src/session.rs +++ b/onnxruntime/src/session.rs @@ -1,6 +1,6 @@ //! Module containing session types -use std::{ffi::CString, fmt::Debug, path::Path}; +use std::{convert::TryInto as _, ffi::CString, fmt::Debug, path::Path}; #[cfg(not(target_family = "windows"))] use std::os::unix::ffi::OsStrExt; @@ -18,15 +18,11 @@ use onnxruntime_sys as sys; use crate::{ char_p_to_string, environment::Environment, - error::{status_to_result, NonMatchingDimensionsError, OrtError, Result}, + error::{call_ort, status_to_result, NonMatchingDimensionsError, OrtError, Result}, g_ort, memory::MemoryInfo, - tensor::{ - ort_owned_tensor::{OrtOwnedTensor, OrtOwnedTensorExtractor}, - OrtTensor, - }, - AllocatorType, GraphOptimizationLevel, MemType, TensorElementDataType, - TypeToTensorElementDataType, + tensor::{DynOrtTensor, OrtTensor, TensorElementDataType, TypeToTensorElementDataType}, + AllocatorType, GraphOptimizationLevel, MemType, }; #[cfg(feature = "model-fetching")] @@ -219,12 +215,14 @@ impl<'a> SessionBuilder<'a> { let outputs = (0..num_output_nodes) .map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i)) 
.collect::<Result<Vec<Output>>>()?; + let input_ort_values = Vec::with_capacity(num_output_nodes as usize); Ok(Session { env: self.env, session_ptr, allocator_ptr, memory_info, + input_ort_values, inputs, outputs, }) @@ -275,12 +273,14 @@ impl<'a> SessionBuilder<'a> { let outputs = (0..num_output_nodes) .map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i)) .collect::<Result<Vec<Output>>>()?; + let input_ort_values = Vec::with_capacity(num_output_nodes as usize); Ok(Session { env: self.env, session_ptr, allocator_ptr, memory_info, + input_ort_values, inputs, outputs, }) @@ -294,6 +294,7 @@ pub struct Session<'a> { session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, memory_info: MemoryInfo, + input_ort_values: Vec<*const sys::OrtValue>, /// Information about the ONNX's inputs as stored in loaded file pub inputs: Vec<Input>, /// Information about the ONNX's outputs as stored in loaded file @@ -361,24 +362,80 @@ impl<'a> Drop for Session<'a> { } impl<'a> Session<'a> { + /// Validates one input array against the next expected model input and queues it for the next call to `inner_run`; panics on a shape mismatch. + pub fn feed<'s, 't, 'm, TIn, D>(&'s mut self, input_array: Array<TIn, D>) -> Result<()> + where + TIn: TypeToTensorElementDataType + Debug + Clone, + D: ndarray::Dimension, + 'm: 't, // 'm outlives 't (memory info outlives tensor) + 's: 'm, // 's outlives 'm (session outlives memory info) + { + self.validate_input_shapes(&input_array); + // The C API expects pointers for the arrays (pointers to C-arrays) + let input_ort_tensor: OrtTensor<TIn, D> = + OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array)?; + + let input_ort_value: *const sys::OrtValue = input_ort_tensor.c_ptr as *const sys::OrtValue; + std::mem::forget(input_ort_tensor); + self.input_ort_values.push(input_ort_value); + + Ok(()) + } + /// Run the input data through the ONNX graph, performing inference. /// /// Note that ONNX models can have multiple inputs; a `Vec<_>` is thus /// used for the input data here.
- pub fn run<'s, 't, 'm, TIn, TOut, D>( + pub fn run<'s, 't, 'm, TIn, D>( &'s mut self, input_arrays: Vec<Array<TIn, D>>, - ) -> Result<Vec<OrtOwnedTensor<'t, 'm, TOut, ndarray::IxDyn>>> + ) -> Result<Vec<DynOrtTensor<'m, ndarray::IxDyn>>> where TIn: TypeToTensorElementDataType + Debug + Clone, - TOut: TypeToTensorElementDataType + Debug + Clone, D: ndarray::Dimension, 'm: 't, // 'm outlives 't (memory info outlives tensor) 's: 'm, // 's outlives 'm (session outlives memory info) { - self.validate_input_shapes(&input_arrays)?; - + input_arrays + .into_iter() + .for_each(|input_array| self.feed(input_array).unwrap()); + self.inner_run() + } + /// Run the inputs previously queued with `feed` through the ONNX graph, performing inference. + /// + /// The number of queued inputs must match the number of model inputs; the queue + /// is cleared after a successful run. + pub fn inner_run<'s, 't, 'm>( + &'s mut self, + // input_arrays: Vec<Array<TIn, D>>, + ) -> Result<Vec<DynOrtTensor<'m, ndarray::IxDyn>>> + where + 'm: 't, // 'm outlives 't (memory info outlives tensor) + 's: 'm, // 's outlives 'm (session outlives memory info) + { // Build arguments to Run() + if self.input_ort_values.len() != self.inputs.len() { + error!( + "Non-matching number of inputs: {} (inference) vs {} (model)", + self.input_ort_values.len(), + self.inputs.len() + ); + return Err(OrtError::NonMatchingDimensions( + NonMatchingDimensionsError::InputsCount { + inference_input_count: 0, + model_input_count: 0, + // inference_input: input_arrays + // .iter() + // .map(|input_array| input_array.shape().to_vec()) + // .collect(), + // model_input: self + // .inputs + // .iter() + // .map(|input| input.dimensions.clone()) + // .collect(), + }, + )); + } let input_names: Vec<String> = self.inputs.iter().map(|input| input.name.clone()).collect(); let input_names_cstring: Vec<CString> = input_names @@ -405,21 +462,9 @@ impl<'a> Session<'a> { .map(|n| n.as_ptr() as *const i8) .collect(); - let mut output_tensor_extractors_ptrs: Vec<*mut
sys::OrtValue> = + let mut output_tensor_ptrs: Vec<*mut sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()]; - // The C API expects pointers for the arrays (pointers to C-arrays) - let input_ort_tensors: Vec<OrtTensor<TIn, D>> = input_arrays - .into_iter() - .map(|input_array| { - OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array) - }) - .collect::<Result<Vec<OrtTensor<TIn, D>>>>()?; - let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors - .iter() - .map(|input_array_ort| input_array_ort.c_ptr as *const sys::OrtValue) - .collect(); - let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null(); let status = unsafe { @@ -427,34 +472,58 @@ impl<'a> Session<'a> { self.session_ptr, run_options_ptr, input_names_ptr.as_ptr(), - input_ort_values.as_ptr(), - input_ort_values.len() as u64, // C API expects a u64, not isize + self.input_ort_values.as_ptr(), + self.input_ort_values.len() as u64, // C API expects a u64, not isize output_names_ptr.as_ptr(), output_names_ptr.len() as u64, // C API expects a u64, not isize - output_tensor_extractors_ptrs.as_mut_ptr(), + output_tensor_ptrs.as_mut_ptr(), ) }; status_to_result(status).map_err(OrtError::Run)?; + self.input_ort_values.iter().for_each(std::mem::drop); + self.input_ort_values.clear(); let memory_info_ref = &self.memory_info; - let outputs: Result<Vec<OrtOwnedTensor<TOut, ndarray::Dim<ndarray::IxDynImpl>>>> = - output_tensor_extractors_ptrs + let outputs: Result<Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>>> = + output_tensor_ptrs .into_iter() - .map(|ptr| { - let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = - std::ptr::null_mut(); - let status = unsafe { - g_ort().GetTensorTypeAndShape.unwrap()(ptr, &mut tensor_info_ptr as _) - }; - status_to_result(status).map_err(OrtError::GetTensorTypeAndShape)?; - let dims = unsafe { get_tensor_dimensions(tensor_info_ptr) }; - unsafe { g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr) }; - let 
dims: Vec<_> = dims?.iter().map(|&n| n as usize).collect(); - - let mut output_tensor_extractor = - OrtOwnedTensorExtractor::new(memory_info_ref, ndarray::IxDyn(&dims)); - output_tensor_extractor.tensor_ptr = ptr; - output_tensor_extractor.extract::<TOut>() + .map(|tensor_ptr| { + let (dims, data_type, len) = unsafe { + call_with_tensor_info(tensor_ptr, |tensor_info_ptr| { + get_tensor_dimensions(tensor_info_ptr) + .map(|dims| dims.iter().map(|&n| n as usize).collect::<Vec<_>>()) + .and_then(|dims| { + extract_data_type(tensor_info_ptr) + .map(|data_type| (dims, data_type)) + }) + .and_then(|(dims, data_type)| { + let mut len = 0_u64; + + call_ort(|ort| { + ort.GetTensorShapeElementCount.unwrap()( + tensor_info_ptr, + &mut len, + ) + }) + .map_err(OrtError::GetTensorShapeElementCount)?; + + Ok(( + dims, + data_type, + len.try_into() + .expect("u64 length could not fit into usize"), + )) + }) + }) + }?; + + Ok(DynOrtTensor::new( + tensor_ptr, + memory_info_ref, + ndarray::IxDyn(&dims), + len, + data_type, + )) }) .collect(); @@ -477,7 +546,7 @@ impl<'a> Session<'a> { // Tensor::from_array(self, array) // } - fn validate_input_shapes<TIn, D>(&mut self, input_arrays: &[Array<TIn, D>]) -> Result<()> + fn validate_input_shapes<TIn, D>(&mut self, input_array: &Array<TIn, D>) where TIn: TypeToTensorElementDataType + Debug + Clone, D: ndarray::Dimension, @@ -487,66 +556,43 @@ impl<'a> Session<'a> { // Make sure all dimensions match (except dynamic ones) // Verify length of inputs - if input_arrays.len() != self.inputs.len() { - error!( - "Non-matching number of inputs: {} (inference) vs {} (model)", - input_arrays.len(), - self.inputs.len() - ); - return Err(OrtError::NonMatchingDimensions( - NonMatchingDimensionsError::InputsCount { - inference_input_count: 0, - model_input_count: 0, - inference_input: input_arrays - .iter() - .map(|input_array| input_array.shape().to_vec()) - .collect(), - model_input: self - .inputs - .iter() - .map(|input| input.dimensions.clone()) - 
.collect(), - }, - )); - } // Verify length of each individual inputs - let inputs_different_length = input_arrays - .iter() - .zip(self.inputs.iter()) - .any(|(l, r)| l.shape().len() != r.dimensions.len()); - if inputs_different_length { + let current_input = self.input_ort_values.len(); + if current_input > self.inputs.len() { error!( - "Different input lengths: {:?} vs {:?}", - self.inputs, input_arrays + "Attempting to feed too many inputs, expecting {:?} inputs", + self.inputs.len() ); panic!( - "Different input lengths: {:?} vs {:?}", - self.inputs, input_arrays + "Attempting to feed too many inputs, expecting {:?} inputs", + self.inputs.len() ); } + let input = &self.inputs[current_input]; + if input_array.shape().len() != input.dimensions().count() { + error!("Different input lengths: {:?} vs {:?}", input, input_array); + panic!("Different input lengths: {:?} vs {:?}", input, input_array); + } - // Verify shape of each individual inputs - let inputs_different_shape = input_arrays.iter().zip(self.inputs.iter()).any(|(l, r)| { - let l_shape = l.shape(); - let r_shape = r.dimensions.as_slice(); - l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 { - Some(r3) => *r3 as usize != *l2, - None => false, // None means dynamic size; in that case shape always match - }) + let l = input_array; + let r = input; + let l_shape = l.shape(); + let r_shape = r.dimensions.as_slice(); + let inputs_different_shape = l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 { + Some(r3) => *r3 as usize != *l2, + None => false, // None means dynamic size; in that case shape always match }); if inputs_different_shape { error!( "Different input lengths: {:?} vs {:?}", - self.inputs, input_arrays + self.inputs, input_array ); panic!( "Different input lengths: {:?} vs {:?}", - self.inputs, input_arrays + self.inputs, input_array ); } - - Ok(()) } } @@ -554,25 +600,60 @@ unsafe fn get_tensor_dimensions( tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo, ) -> 
Result<Vec<i64>> { let mut num_dims = 0; - let status = g_ort().GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims); - status_to_result(status).map_err(OrtError::GetDimensionsCount)?; + call_ort(|ort| ort.GetDimensionsCount.unwrap()(tensor_info_ptr, &mut num_dims)) + .map_err(OrtError::GetDimensionsCount)?; assert_ne!(num_dims, 0); let mut node_dims: Vec<i64> = vec![0; num_dims as usize]; - let status = g_ort().GetDimensions.unwrap()( - tensor_info_ptr, - node_dims.as_mut_ptr(), // FIXME: UB? - num_dims, - ); - status_to_result(status).map_err(OrtError::GetDimensions)?; + call_ort(|ort| { + ort.GetDimensions.unwrap()( + tensor_info_ptr, + node_dims.as_mut_ptr(), // FIXME: UB? + num_dims, + ) + }) + .map_err(OrtError::GetDimensions)?; Ok(node_dims) } +unsafe fn extract_data_type( + tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo, +) -> Result<TensorElementDataType> { + let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + call_ort(|ort| ort.GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys)) + .map_err(OrtError::TensorElementType)?; + assert_ne!( + type_sys, + sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED + ); + // This transmute should be safe since its value is read from GetTensorElementType which we must trust. + Ok(std::mem::transmute(type_sys)) +} + +/// Calls the provided closure with the result of `GetTensorTypeAndShape`, deallocating the +/// resulting `*OrtTensorTypeAndShapeInfo` before returning. 
+unsafe fn call_with_tensor_info<F, T>(tensor_ptr: *const sys::OrtValue, mut f: F) -> Result<T> +where + F: FnMut(*const sys::OrtTensorTypeAndShapeInfo) -> Result<T>, +{ + let mut tensor_info_ptr: *mut sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut(); + call_ort(|ort| ort.GetTensorTypeAndShape.unwrap()(tensor_ptr, &mut tensor_info_ptr as _)) + .map_err(OrtError::GetTensorTypeAndShape)?; + + let res = f(tensor_info_ptr); + + // no return code, so no errors to check for + g_ort().ReleaseTensorTypeAndShapeInfo.unwrap()(tensor_info_ptr); + + res +} + /// This module contains dangerous functions working on raw pointers. /// Those functions are only to be used from inside the /// `SessionBuilder::with_model_from_file()` method. mod dangerous { use super::*; + use crate::tensor::TensorElementDataType; pub(super) fn extract_inputs_count(session_ptr: *mut sys::OrtSession) -> Result<u64> { let f = g_ort().SessionGetInputCount.unwrap(); @@ -689,16 +770,7 @@ mod dangerous { status_to_result(status).map_err(OrtError::CastTypeInfoToTensorInfo)?; assert_ne!(tensor_info_ptr, std::ptr::null_mut()); - let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - let status = - unsafe { g_ort().GetTensorElementType.unwrap()(tensor_info_ptr, &mut type_sys) }; - status_to_result(status).map_err(OrtError::TensorElementType)?; - assert_ne!( - type_sys, - sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED - ); - // This transmute should be safe since its value is read from GetTensorElementType which we must trust. - let io_type: TensorElementDataType = unsafe { std::mem::transmute(type_sys) }; + let io_type: TensorElementDataType = unsafe { extract_data_type(tensor_info_ptr)? 
}; // info!("{} : type={}", i, type_); diff --git a/onnxruntime/src/tensor.rs b/onnxruntime/src/tensor.rs index 92404842..74e8329c 100644 --- a/onnxruntime/src/tensor.rs +++ b/onnxruntime/src/tensor.rs @@ -27,5 +27,332 @@ pub mod ndarray_tensor; pub mod ort_owned_tensor; pub mod ort_tensor; -pub use ort_owned_tensor::OrtOwnedTensor; +pub use ort_owned_tensor::{DynOrtTensor, OrtOwnedTensor}; pub use ort_tensor::OrtTensor; + +use crate::tensor::ort_owned_tensor::TensorPointerHolder; +use crate::{error::call_ort, OrtError, Result}; +use onnxruntime_sys::{self as sys, OnnxEnumInt}; +use std::{convert::TryInto as _, ffi, fmt, ptr, rc, result, string}; + +// FIXME: Use https://docs.rs/bindgen/0.54.1/bindgen/struct.Builder.html#method.rustified_enum +// FIXME: Add tests to cover the commented out types +/// Enum mapping ONNX Runtime's supported tensor types +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(not(windows), repr(u32))] +#[cfg_attr(windows, repr(i32))] +pub enum TensorElementDataType { + /// 32-bit floating point, equivalent to Rust's `f32` + Float = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt, + /// Unsigned 8-bit int, equivalent to Rust's `u8` + Uint8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt, + /// Signed 8-bit int, equivalent to Rust's `i8` + Int8 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt, + /// Unsigned 16-bit int, equivalent to Rust's `u16` + Uint16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt, + /// Signed 16-bit int, equivalent to Rust's `i16` + Int16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt, + /// Signed 32-bit int, equivalent to Rust's `i32` + Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt, + /// Signed 64-bit int, equivalent to Rust's `i64` + Int64 = 
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt, + /// String, equivalent to Rust's `String` + String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt, + // /// Boolean, equivalent to Rust's `bool` + // Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt, + // /// 16-bit floating point, equivalent to Rust's `f16` + // Float16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt, + /// 64-bit floating point, equivalent to Rust's `f64` + Double = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt, + /// Unsigned 32-bit int, equivalent to Rust's `u32` + Uint32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt, + /// Unsigned 64-bit int, equivalent to Rust's `u64` + Uint64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt, + // /// Complex 64-bit floating point, equivalent to Rust's `???` + // Complex64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 as OnnxEnumInt, + // /// Complex 128-bit floating point, equivalent to Rust's `???` + // Complex128 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 as OnnxEnumInt, + // /// Brain 16-bit floating point + // Bfloat16 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt, +} + +impl Into<sys::ONNXTensorElementDataType> for TensorElementDataType { + fn into(self) -> sys::ONNXTensorElementDataType { + use TensorElementDataType::*; + match self { + Float => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + Uint8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, + Int8 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, + Uint16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, + Int16 => 
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, + Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, + Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, + // Bool => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL + // } + // Float16 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 + // } + Double => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, + Uint32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, + Uint64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, + // Complex64 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 + // } + // Complex128 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 + // } + // Bfloat16 => { + // sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 + // } + } + } +} + +/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`) +pub trait TypeToTensorElementDataType { + /// Return the ONNX type for a Rust type + fn tensor_element_data_type() -> TensorElementDataType; + + /// If the type is `String`, returns `Some` with utf8 contents, else `None`. + fn try_utf8_bytes(&self) -> Option<&[u8]>; +} + +macro_rules! 
impl_prim_type_to_ort_trait { + ($type_:ty, $variant:ident) => { + impl TypeToTensorElementDataType for $type_ { + fn tensor_element_data_type() -> TensorElementDataType { + // unsafe { std::mem::transmute(TensorElementDataType::$variant) } + TensorElementDataType::$variant + } + + fn try_utf8_bytes(&self) -> Option<&[u8]> { + None + } + } + }; +} + +impl_prim_type_to_ort_trait!(f32, Float); +impl_prim_type_to_ort_trait!(u8, Uint8); +impl_prim_type_to_ort_trait!(i8, Int8); +impl_prim_type_to_ort_trait!(u16, Uint16); +impl_prim_type_to_ort_trait!(i16, Int16); +impl_prim_type_to_ort_trait!(i32, Int32); +impl_prim_type_to_ort_trait!(i64, Int64); +// impl_type_trait!(bool, Bool); +// impl_type_trait!(f16, Float16); +impl_prim_type_to_ort_trait!(f64, Double); +impl_prim_type_to_ort_trait!(u32, Uint32); +impl_prim_type_to_ort_trait!(u64, Uint64); +// impl_type_trait!(, Complex64); +// impl_type_trait!(, Complex128); +// impl_type_trait!(, Bfloat16); + +/// Adapter for common Rust string types to Onnx strings. +/// +/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but +/// we can't define an automatic implementation for anything that implements `AsRef<str>` as it +/// would conflict with the implementations of [TypeToTensorElementDataType] for primitive numeric +/// types (which might implement `AsRef<str>` at some point in the future). +pub trait Utf8Data { + /// Returns the utf8 contents. 
+ fn utf8_bytes(&self) -> &[u8]; +} + +impl Utf8Data for String { + fn utf8_bytes(&self) -> &[u8] { + self.as_bytes() + } +} + +impl<'a> Utf8Data for &'a str { + fn utf8_bytes(&self) -> &[u8] { + self.as_bytes() + } +} + +impl<T: Utf8Data> TypeToTensorElementDataType for T { + fn tensor_element_data_type() -> TensorElementDataType { + TensorElementDataType::String + } + + fn try_utf8_bytes(&self) -> Option<&[u8]> { + Some(self.utf8_bytes()) + } +} + +/// Trait used to map onnxruntime types to Rust types +pub trait TensorDataToType: Sized + fmt::Debug { + /// The tensor element type that this type can extract from + fn tensor_element_data_type() -> TensorElementDataType; + + /// Extract an `ArrayView` from the ort-owned tensor. + fn extract_data<'t, D>( + shape: D, + tensor_element_len: usize, + tensor_ptr: rc::Rc<TensorPointerHolder>, + ) -> Result<TensorData<'t, Self, D>> + where + D: ndarray::Dimension; +} + +/// Represents the possible ways tensor data can be accessed. +/// +/// This should only be used internally. +#[derive(Debug)] +pub enum TensorData<'t, T, D> +where + D: ndarray::Dimension, +{ + /// Data resides in ort's tensor, in which case the 't lifetime is what makes this valid. + /// This is used for data types whose in-memory form from ort is compatible with Rust's, like + /// primitive numeric types. + TensorPtr { + /// The pointer ort produced. Kept alive so that `array_view` is valid. + ptr: rc::Rc<TensorPointerHolder>, + /// A view into `ptr` + array_view: ndarray::ArrayView<'t, T, D>, + }, + /// String data is output differently by ort, and of course is also variable size, so it cannot + /// use the same simple pointer representation. + // Since 't outlives this struct, the 't lifetime is more than we need, but no harm done. + Strings { + /// Owned Strings copied out of ort's output + strings: ndarray::Array<T, D>, + }, +} + +/// Implements `TensorDataToType` for primitives, which can use `GetTensorMutableData` +macro_rules! 
impl_prim_type_from_ort_trait { + ($type_:ty, $variant:ident) => { + impl TensorDataToType for $type_ { + fn tensor_element_data_type() -> TensorElementDataType { + TensorElementDataType::$variant + } + + fn extract_data<'t, D>( + shape: D, + _tensor_element_len: usize, + tensor_ptr: rc::Rc<TensorPointerHolder>, + ) -> Result<TensorData<'t, Self, D>> + where + D: ndarray::Dimension, + { + extract_primitive_array(shape, tensor_ptr.tensor_ptr).map(|v| { + TensorData::TensorPtr { + ptr: tensor_ptr, + array_view: v, + } + }) + } + } + }; +} + +/// Construct an [ndarray::ArrayView] over an Ort tensor. +/// +/// Only to be used on types whose Rust in-memory representation matches Ort's (e.g. primitive +/// numeric types like u32). +fn extract_primitive_array<'t, D, T: TensorDataToType>( + shape: D, + tensor: *mut sys::OrtValue, +) -> Result<ndarray::ArrayView<'t, T, D>> +where + D: ndarray::Dimension, +{ + // Get pointer to output tensor float values + let mut output_array_ptr: *mut T = ptr::null_mut(); + let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; + let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = + output_array_ptr_ptr as *mut *mut std::ffi::c_void; + unsafe { + crate::error::call_ort(|ort| { + ort.GetTensorMutableData.unwrap()(tensor, output_array_ptr_ptr_void) + }) + } + .map_err(OrtError::GetTensorMutableData)?; + assert_ne!(output_array_ptr, ptr::null_mut()); + + let array_view = unsafe { ndarray::ArrayView::from_shape_ptr(shape, output_array_ptr) }; + Ok(array_view) +} + +impl_prim_type_from_ort_trait!(f32, Float); +impl_prim_type_from_ort_trait!(u8, Uint8); +impl_prim_type_from_ort_trait!(i8, Int8); +impl_prim_type_from_ort_trait!(u16, Uint16); +impl_prim_type_from_ort_trait!(i16, Int16); +impl_prim_type_from_ort_trait!(i32, Int32); +impl_prim_type_from_ort_trait!(i64, Int64); +impl_prim_type_from_ort_trait!(f64, Double); +impl_prim_type_from_ort_trait!(u32, Uint32); +impl_prim_type_from_ort_trait!(u64, Uint64); + +impl 
TensorDataToType for String { + fn tensor_element_data_type() -> TensorElementDataType { + TensorElementDataType::String + } + + fn extract_data<'t, D: ndarray::Dimension>( + shape: D, + tensor_element_len: usize, + tensor_ptr: rc::Rc<TensorPointerHolder>, + ) -> Result<TensorData<'t, Self, D>> { + // Total length of string data, not including \0 suffix + let mut total_length = 0_u64; + unsafe { + call_ort(|ort| { + ort.GetStringTensorDataLength.unwrap()(tensor_ptr.tensor_ptr, &mut total_length) + }) + .map_err(OrtError::GetStringTensorDataLength)? + } + + // In the JNI impl of this, tensor_element_len was included in addition to total_length, + // but that seems contrary to the docs of GetStringTensorDataLength, and those extra bytes + // don't seem to be written to in practice either. + // If the string data actually did go farther, it would panic below when using the offset + // data to get slices for each string. + let mut string_contents = vec![0_u8; total_length as usize]; + // one extra slot so that the total length can go in the last one, making all per-string + // length calculations easy + let mut offsets = vec![0_u64; tensor_element_len as usize + 1]; + + unsafe { + call_ort(|ort| { + ort.GetStringTensorContent.unwrap()( + tensor_ptr.tensor_ptr, + string_contents.as_mut_ptr() as *mut ffi::c_void, + total_length, + offsets.as_mut_ptr(), + tensor_element_len as u64, + ) + }) + .map_err(OrtError::GetStringTensorContent)? 
+ } + + // final offset = overall length so that per-string length calculations work for the last + // string + debug_assert_eq!(0, offsets[tensor_element_len]); + offsets[tensor_element_len] = total_length; + + let strings = offsets + // offsets has 1 extra offset past the end so that all windows work + .windows(2) + .map(|w| { + let start: usize = w[0].try_into().expect("Offset didn't fit into usize"); + let next_start: usize = w[1].try_into().expect("Offset didn't fit into usize"); + + let slice = &string_contents[start..next_start]; + String::from_utf8(slice.into()) + }) + .collect::<result::Result<Vec<String>, string::FromUtf8Error>>() + .map_err(OrtError::StringFromUtf8Error)?; + + let array = ndarray::Array::from_shape_vec(shape, strings) + .expect("Shape extracted from tensor didn't match tensor contents"); + + Ok(TensorData::Strings { strings: array }) + } +} diff --git a/onnxruntime/src/tensor/ort_owned_tensor.rs b/onnxruntime/src/tensor/ort_owned_tensor.rs index 161fe105..f782df1b 100644 --- a/onnxruntime/src/tensor/ort_owned_tensor.rs +++ b/onnxruntime/src/tensor/ort_owned_tensor.rs @@ -1,134 +1,229 @@ //! Module containing tensor with memory owned by the ONNX Runtime -use std::{fmt::Debug, ops::Deref}; +use std::{fmt::Debug, ops::Deref, ptr, rc, result}; -use ndarray::{Array, ArrayView}; +use ndarray::ArrayView; +use thiserror::Error; use tracing::debug; use onnxruntime_sys as sys; use crate::{ - error::status_to_result, g_ort, memory::MemoryInfo, tensor::ndarray_tensor::NdArrayTensor, - OrtError, Result, TypeToTensorElementDataType, + error::call_ort, + g_ort, + memory::MemoryInfo, + tensor::{TensorData, TensorDataToType, TensorElementDataType}, + OrtError, }; -/// Tensor containing data owned by the ONNX Runtime C library, used to return values from inference. -/// -/// This tensor type is returned by the [`Session::run()`](../session/struct.Session.html#method.run) method. -/// It is not meant to be created directly. 
-/// -/// The tensor hosts an [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html) -/// of the data on the C side. This allows manipulation on the Rust side using `ndarray` without copying the data. +/// Errors that can occur while extracting a tensor from ort output. +#[derive(Error, Debug)] +pub enum TensorExtractError { + /// The user tried to extract the wrong type of tensor from the underlying data + #[error( + "Data type mismatch: was {:?}, tried to convert to {:?}", + actual, + requested + )] + DataTypeMismatch { + /// The actual type of the ort output + actual: TensorElementDataType, + /// The type corresponding to the attempted conversion into a Rust type, not equal to `actual` + requested: TensorElementDataType, + }, + /// An onnxruntime error occurred + #[error("Onnxruntime error: {:?}", 0)] + OrtError(#[from] OrtError), +} + +/// A wrapper around a tensor produced by onnxruntime inference. /// -/// `OrtOwnedTensor` implements the [`std::deref::Deref`](#impl-Deref) trait for ergonomic access to -/// the underlying [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html). +/// Since different outputs for the same model can have different types, this type is used to allow +/// the user to dynamically query each output's type and extract the appropriate tensor type with +/// [try_extract]. #[derive(Debug)] -pub struct OrtOwnedTensor<'t, 'm, T, D> +pub struct DynOrtTensor<'m, D> where - T: TypeToTensorElementDataType + Debug + Clone, D: ndarray::Dimension, - 'm: 't, // 'm outlives 't { - pub(crate) tensor_ptr: *mut sys::OrtValue, - array_view: ArrayView<'t, T, D>, + // TODO could this also hold a Vec<u8> for strings so that the extracted tensor could then + // hold a Vec<&str>? 
+ tensor_ptr_holder: rc::Rc<TensorPointerHolder>, memory_info: &'m MemoryInfo, + shape: D, + tensor_element_len: usize, + data_type: TensorElementDataType, } -impl<'t, 'm, T, D> Deref for OrtOwnedTensor<'t, 'm, T, D> +impl<'m, D> DynOrtTensor<'m, D> where - T: TypeToTensorElementDataType + Debug + Clone, D: ndarray::Dimension, { - type Target = ArrayView<'t, T, D>; + pub(crate) fn new( + tensor_ptr: *mut sys::OrtValue, + memory_info: &'m MemoryInfo, + shape: D, + tensor_element_len: usize, + data_type: TensorElementDataType, + ) -> DynOrtTensor<'m, D> { + DynOrtTensor { + tensor_ptr_holder: rc::Rc::from(TensorPointerHolder { tensor_ptr }), + memory_info, + shape, + tensor_element_len, + data_type, + } + } - fn deref(&self) -> &Self::Target { - &self.array_view + /// The ONNX data type this tensor contains. + pub fn data_type(&self) -> TensorElementDataType { + self.data_type + } + + /// Extract a tensor containing `T`. + /// + /// Where the type permits it, the tensor will be a view into existing memory. + /// + /// # Errors + /// + /// An error will be returned if `T`'s ONNX type doesn't match this tensor's type, or if an + /// onnxruntime error occurs. + pub fn try_extract<'t, T>(&self) -> result::Result<OrtOwnedTensor<'t, T, D>, TensorExtractError> + where + T: TensorDataToType + Clone + Debug, + 'm: 't, // mem info outlives tensor + D: 't, // not clear why this is needed since we clone the shape, but it doesn't make + // a difference in practice since the shape is extracted from the tensor + { + if self.data_type != T::tensor_element_data_type() { + Err(TensorExtractError::DataTypeMismatch { + actual: self.data_type, + requested: T::tensor_element_data_type(), + }) + } else { + // Note: Both tensor and array will point to the same data, nothing is copied. + // As such, there is no need to free the pointer used to create the ArrayView. 
+ assert_ne!(self.tensor_ptr_holder.tensor_ptr, ptr::null_mut()); + + let mut is_tensor = 0; + unsafe { + call_ort(|ort| { + ort.IsTensor.unwrap()(self.tensor_ptr_holder.tensor_ptr, &mut is_tensor) + }) + } + .map_err(OrtError::IsTensor)?; + assert_eq!(is_tensor, 1); + + let data = T::extract_data( + self.shape.clone(), + self.tensor_element_len, + rc::Rc::clone(&self.tensor_ptr_holder), + )?; + + Ok(OrtOwnedTensor { data }) + } } } -impl<'t, 'm, T, D> OrtOwnedTensor<'t, 'm, T, D> +/// Tensor containing data owned by the ONNX Runtime C library, used to return values from inference. +/// +/// This tensor type is produced by [`DynOrtTensor::try_extract`] from the outputs of [`Session::run()`](../session/struct.Session.html#method.run). +/// It is not meant to be created directly. +/// +/// Where the element type permits it, the tensor hosts an [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html) +/// of the data on the C side, allowing manipulation on the Rust side using `ndarray` without copying the data (string data is copied into owned `String`s). +/// +/// Call [`OrtOwnedTensor::view`] to obtain a [`ViewHolder`], which implements [`std::ops::Deref`] for ergonomic access to +/// the underlying [`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html). 
+#[derive(Debug)] +pub struct OrtOwnedTensor<'t, T, D> where - T: TypeToTensorElementDataType + Debug + Clone, + T: TensorDataToType, D: ndarray::Dimension, { - /// Apply a softmax on the specified axis - pub fn softmax(&self, axis: ndarray::Axis) -> Array<T, D> + data: TensorData<'t, T, D>, +} + +impl<'t, T, D> OrtOwnedTensor<'t, T, D> +where + T: TensorDataToType, + D: ndarray::Dimension + 't, +{ + /// Produce a [ViewHolder] for the underlying data, which derefs to an `ndarray::ArrayView`. + pub fn view<'s>(&'s self) -> ViewHolder<'s, T, D> where - D: ndarray::RemoveAxis, - T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign, + 't: 's, // tensor ptr can outlive the TensorData { - self.array_view.softmax(axis) + ViewHolder::new(&self.data) } } -#[derive(Debug)] -pub(crate) struct OrtOwnedTensorExtractor<'m, D> +/// An intermediate step on the way to an ArrayView. +// Since Deref has to produce a reference, and the referent can't be a local in deref(), it must +// be a field in a struct. This struct exists only to hold that field. +// Its lifetime 's is bound to the TensorData its view was created around, not the underlying tensor +// pointer, since in the case of strings the data is the Array in the TensorData, not the pointer. 
+pub struct ViewHolder<'s, T, D> where + T: TensorDataToType, D: ndarray::Dimension, { - pub(crate) tensor_ptr: *mut sys::OrtValue, - memory_info: &'m MemoryInfo, - shape: D, + array_view: ndarray::ArrayView<'s, T, D>, } -impl<'m, D> OrtOwnedTensorExtractor<'m, D> +impl<'s, T, D> ViewHolder<'s, T, D> where + T: TensorDataToType, D: ndarray::Dimension, { - pub(crate) fn new(memory_info: &'m MemoryInfo, shape: D) -> OrtOwnedTensorExtractor<'m, D> { - OrtOwnedTensorExtractor { - tensor_ptr: std::ptr::null_mut(), - memory_info, - shape, - } - } - - pub(crate) fn extract<'t, T>(self) -> Result<OrtOwnedTensor<'t, 'm, T, D>> + fn new<'t>(data: &'s TensorData<'t, T, D>) -> ViewHolder<'s, T, D> where - T: TypeToTensorElementDataType + Debug + Clone, + 't: 's, // underlying tensor ptr lives at least as long as TensorData { - // Note: Both tensor and array will point to the same data, nothing is copied. - // As such, there is no need too free the pointer used to create the ArrayView. - - assert_ne!(self.tensor_ptr, std::ptr::null_mut()); - - let mut is_tensor = 0; - let status = unsafe { g_ort().IsTensor.unwrap()(self.tensor_ptr, &mut is_tensor) }; - status_to_result(status).map_err(OrtError::IsTensor)?; - assert_eq!(is_tensor, 1); - - // Get pointer to output tensor float values - let mut output_array_ptr: *mut T = std::ptr::null_mut(); - let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr; - let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = - output_array_ptr_ptr as *mut *mut std::ffi::c_void; - let status = unsafe { - g_ort().GetTensorMutableData.unwrap()(self.tensor_ptr, output_array_ptr_ptr_void) - }; - status_to_result(status).map_err(OrtError::IsTensor)?; - assert_ne!(output_array_ptr, std::ptr::null_mut()); - - let array_view = unsafe { ArrayView::from_shape_ptr(self.shape, output_array_ptr) }; - - Ok(OrtOwnedTensor { - tensor_ptr: self.tensor_ptr, - array_view, - memory_info: self.memory_info, - }) + match data { + TensorData::TensorPtr { 
array_view, .. } => ViewHolder { + // we already have a view, but creating a view from a view is cheap + array_view: array_view.view(), + }, + TensorData::Strings { strings } => ViewHolder { + // This view creation has to happen here, not at new()'s callsite, because + // a field can't be a reference to another field in the same struct. Thus, we have + // this separate struct to hold the view that refers to the `Array`. + array_view: strings.view(), + }, + } } } -impl<'t, 'm, T, D> Drop for OrtOwnedTensor<'t, 'm, T, D> +impl<'t, T, D> Deref for ViewHolder<'t, T, D> where - T: TypeToTensorElementDataType + Debug + Clone, + T: TensorDataToType, D: ndarray::Dimension, - 'm: 't, // 'm outlives 't { + type Target = ArrayView<'t, T, D>; + + fn deref(&self) -> &Self::Target { + &self.array_view + } +} + +/// Holds on to a tensor pointer until dropped. +/// +/// This allows creating an [OrtOwnedTensor] from a [DynOrtTensor] without consuming `self`, which +/// would prevent retrying extraction and also make interacting with outputs `Vec` awkward. +/// It also avoids needing `OrtOwnedTensor` to keep a reference to `DynOrtTensor`, which would be +/// inconvenient. 
+#[derive(Debug)] +pub struct TensorPointerHolder { + pub(crate) tensor_ptr: *mut sys::OrtValue, +} + +impl Drop for TensorPointerHolder { #[tracing::instrument] fn drop(&mut self) { debug!("Dropping OrtOwnedTensor."); unsafe { g_ort().ReleaseValue.unwrap()(self.tensor_ptr) } - self.tensor_ptr = std::ptr::null_mut(); + self.tensor_ptr = ptr::null_mut(); } } diff --git a/onnxruntime/src/tensor/ort_tensor.rs b/onnxruntime/src/tensor/ort_tensor.rs index 437e2e86..0937afe1 100644 --- a/onnxruntime/src/tensor/ort_tensor.rs +++ b/onnxruntime/src/tensor/ort_tensor.rs @@ -8,9 +8,11 @@ use tracing::{debug, error}; use onnxruntime_sys as sys; use crate::{ - error::call_ort, error::status_to_result, g_ort, memory::MemoryInfo, - tensor::ndarray_tensor::NdArrayTensor, OrtError, Result, TensorElementDataType, - TypeToTensorElementDataType, + error::{call_ort, status_to_result}, + g_ort, + memory::MemoryInfo, + tensor::{ndarray_tensor::NdArrayTensor, TensorElementDataType, TypeToTensorElementDataType}, + OrtError, Result, }; /// Owned tensor, backed by an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html) diff --git a/onnxruntime/tests/integration_tests.rs b/onnxruntime/tests/integration_tests.rs index ee531feb..c332e7ce 100644 --- a/onnxruntime/tests/integration_tests.rs +++ b/onnxruntime/tests/integration_tests.rs @@ -12,9 +12,11 @@ mod download { use ndarray::s; use test_env_log::test; + use onnxruntime::tensor::ndarray_tensor::NdArrayTensor; use onnxruntime::{ download::vision::{DomainBasedImageClassification, ImageClassification}, environment::Environment, + tensor::{DynOrtTensor, OrtOwnedTensor}, GraphOptimizationLevel, LoggingLevel, }; @@ -62,7 +64,7 @@ mod download { input0_shape[3] as u32, FilterType::Nearest, ) - .to_rgb(); + .to_rgb8(); // Python: // # image[y, x, RGB] @@ -93,13 +95,14 @@ mod download { let input_tensor_values = vec![array]; // Perform the inference - let outputs: Vec< - onnxruntime::tensor::OrtOwnedTensor<f32, 
ndarray::Dim<ndarray::IxDynImpl>>, - > = session.run(input_tensor_values).unwrap(); + let outputs: Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>> = + session.run(input_tensor_values).unwrap(); // Downloaded model does not have a softmax as final layer; call softmax on second axis // and iterate on resulting probabilities, creating an index to later access labels. - let mut probabilities: Vec<(usize, f32)> = outputs[0] + let output: OrtOwnedTensor<_, _> = outputs[0].try_extract().unwrap(); + let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .into_iter() .copied() @@ -170,7 +173,7 @@ mod download { input0_shape[3] as u32, FilterType::Nearest, ) - .to_luma(); + .to_luma8(); let array = ndarray::Array::from_shape_fn((1, 1, 28, 28), |(_, c, j, i)| { let pixel = image_buffer.get_pixel(i as u32, j as u32); @@ -184,11 +187,12 @@ mod download { let input_tensor_values = vec![array]; // Perform the inference - let outputs: Vec< - onnxruntime::tensor::OrtOwnedTensor<f32, ndarray::Dim<ndarray::IxDynImpl>>, - > = session.run(input_tensor_values).unwrap(); + let outputs: Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>> = + session.run(input_tensor_values).unwrap(); - let mut probabilities: Vec<(usize, f32)> = outputs[0] + let output: OrtOwnedTensor<_, _> = outputs[0].try_extract().unwrap(); + let mut probabilities: Vec<(usize, f32)> = output + .view() .softmax(ndarray::Axis(1)) .into_iter() .copied() @@ -268,7 +272,7 @@ mod download { .join(IMAGE_TO_LOAD), ) .unwrap() - .to_rgb(); + .to_rgb8(); let array = ndarray::Array::from_shape_fn((1, 224, 224, 3), |(_, j, i, c)| { let pixel = image_buffer.get_pixel(i as u32, j as u32); @@ -282,15 +286,15 @@ mod download { let input_tensor_values = vec![array]; // Perform the inference - let outputs: Vec< - onnxruntime::tensor::OrtOwnedTensor<f32, ndarray::Dim<ndarray::IxDynImpl>>, - > = session.run(input_tensor_values).unwrap(); + let outputs: Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>> 
= + session.run(input_tensor_values).unwrap(); assert_eq!(outputs.len(), 1); - let output = &outputs[0]; + let output: OrtOwnedTensor<'_, f32, ndarray::Dim<ndarray::IxDynImpl>> = + outputs[0].try_extract().unwrap(); // The image should have doubled in size - assert_eq!(output.shape(), [1, 448, 448, 3]); + assert_eq!(output.view().shape(), [1, 448, 448, 3]); } } diff --git a/onnxruntime/tests/string_type.rs b/onnxruntime/tests/string_type.rs new file mode 100644 index 00000000..628db05e --- /dev/null +++ b/onnxruntime/tests/string_type.rs @@ -0,0 +1,56 @@ +use std::error::Error; + +use onnxruntime::tensor::{OrtOwnedTensor, TensorElementDataType}; +use onnxruntime::{environment::Environment, tensor::DynOrtTensor, LoggingLevel}; + +#[test] +fn run_model_with_string_1d_input_output() -> Result<(), Box<dyn Error>> { + let environment = Environment::builder() + .with_name("test") + .with_log_level(LoggingLevel::Verbose) + .build()?; + + let mut session = environment + .new_session_builder()? + .with_model_from_file("../test-models/tensorflow/unique_model.onnx")?; + + // Inputs: + // 0: + // name = input_1:0 + // type = String + // dimensions = [None] + // Outputs: + // 0: + // name = Identity:0 + // type = Int32 + // dimensions = [None] + // 1: + // name = Identity_1:0 + // type = String + // dimensions = [None] + + let array = ndarray::Array::from(vec!["foo", "bar", "foo", "foo", "baz"]); + let input_tensor_values = vec![array]; + + let outputs: Vec<DynOrtTensor<_>> = session.run(input_tensor_values)?; + + assert_eq!(TensorElementDataType::Int32, outputs[0].data_type()); + assert_eq!(TensorElementDataType::String, outputs[1].data_type()); + + let int_output: OrtOwnedTensor<i32, _> = outputs[0].try_extract()?; + let string_output: OrtOwnedTensor<String, _> = outputs[1].try_extract()?; + + assert_eq!(&[5], int_output.view().shape()); + assert_eq!(&[3], string_output.view().shape()); + + assert_eq!(&[0, 1, 0, 0, 2], int_output.view().as_slice().unwrap()); + assert_eq!( + 
vec!["foo", "bar", "baz"] + .into_iter() + .map(|s| s.to_owned()) + .collect::<Vec<_>>(), + string_output.view().as_slice().unwrap() + ); + + Ok(()) +} diff --git a/test-models/tensorflow/.gitignore b/test-models/tensorflow/.gitignore new file mode 100644 index 00000000..aea6a084 --- /dev/null +++ b/test-models/tensorflow/.gitignore @@ -0,0 +1,2 @@ +/Pipfile.lock +/models diff --git a/test-models/tensorflow/Pipfile b/test-models/tensorflow/Pipfile new file mode 100644 index 00000000..a7b370ab --- /dev/null +++ b/test-models/tensorflow/Pipfile @@ -0,0 +1,13 @@ +[[source]] +url = "https://pypi.org/simple" +verify_ssl = true +name = "pypi" + +[packages] +tensorflow = "==2.4.1" +tf2onnx = "==1.8.3" + +[dev-packages] + +[requires] +python_version = "3.8" diff --git a/test-models/tensorflow/README.md b/test-models/tensorflow/README.md new file mode 100644 index 00000000..4f2e68f2 --- /dev/null +++ b/test-models/tensorflow/README.md @@ -0,0 +1,18 @@ +# Setup + +Have Pipenv make the virtualenv for you: + +``` +pipenv install +``` + +# Model: Unique + +A TensorFlow model that removes duplicate tensor elements. + +This supports strings, and doesn't require custom operators. 
+ +``` +pipenv run python src/unique_model.py +pipenv run python -m tf2onnx.convert --saved-model models/unique_model --output unique_model.onnx --opset 11 +``` diff --git a/test-models/tensorflow/src/unique_model.py b/test-models/tensorflow/src/unique_model.py new file mode 100644 index 00000000..fb79dc8b --- /dev/null +++ b/test-models/tensorflow/src/unique_model.py @@ -0,0 +1,19 @@ +import tensorflow as tf +import numpy as np +import tf2onnx + + +class UniqueModel(tf.keras.Model): + + def __init__(self, name='model1', **kwargs): + super(UniqueModel, self).__init__(name=name, **kwargs) + + def call(self, inputs): + return tf.unique(inputs) + + +model1 = UniqueModel() + +print(model1(tf.constant(["foo", "bar", "foo", "baz"]))) + +model1.save("models/unique_model") diff --git a/test-models/tensorflow/unique_model.onnx b/test-models/tensorflow/unique_model.onnx new file mode 100644 index 00000000..320f6200 Binary files /dev/null and b/test-models/tensorflow/unique_model.onnx differ