Skip to content

Commit

Permalink
feat: Add Mixed Array Support
Browse files Browse the repository at this point in the history
  • Loading branch information
28Smiles committed Apr 4, 2022
1 parent cd5a6eb commit e1be572
Show file tree
Hide file tree
Showing 2 changed files with 346 additions and 41 deletions.
30 changes: 30 additions & 0 deletions onnxruntime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ use sys::OnnxEnumInt;

// Re-export ndarray as it's part of the public API anyway
pub use ndarray;
use ndarray::Array;
use crate::tensor::OrtTensor;

lazy_static! {
// static ref G_ORT: Arc<Mutex<AtomicPtr<sys::OrtApi>>> =
Expand Down Expand Up @@ -459,6 +461,34 @@ impl_type_trait!(u64, Uint64);
// impl_type_trait!(, Complex128);
// impl_type_trait!(, Bfloat16);

#[derive(Debug)]
pub enum TypedArray<D: ndarray::Dimension> {
F32(Array<f32, D>),
U8(Array<u8, D>),
I8(Array<i8, D>),
U16(Array<u16, D>),
I16(Array<i16, D>),
I32(Array<i32, D>),
I64(Array<i64, D>),
F64(Array<f64, D>),
U32(Array<u32, D>),
U64(Array<u64, D>),
}

#[derive(Debug)]
pub enum TypedOrtTensor<'t, D: ndarray::Dimension> {
F32(OrtTensor<'t, f32, D>),
U8(OrtTensor<'t, u8, D>),
I8(OrtTensor<'t, i8, D>),
U16(OrtTensor<'t, u16, D>),
I16(OrtTensor<'t, i16, D>),
I32(OrtTensor<'t, i32, D>),
I64(OrtTensor<'t, i64, D>),
F64(OrtTensor<'t, f64, D>),
U32(OrtTensor<'t, u32, D>),
U64(OrtTensor<'t, u64, D>),
}

/// Adapter for common Rust string types to Onnx strings.
///
/// It should be easy to use both `String` and `&str` as [TensorElementDataType::String] data, but
Expand Down
Loading

0 comments on commit e1be572

Please sign in to comment.