Skip to content

Commit

Permalink
Use macros to simplify invoking operators with different input types
Browse files Browse the repository at this point in the history
Add macros to simplify instantiating and invoking a generic operator function
according to the type of an input. This initial version only handles the
case where an operator supports all available tensor types.
  • Loading branch information
robertknight committed Feb 5, 2025
1 parent 1dcc924 commit 6ffb5a5
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 252 deletions.
93 changes: 22 additions & 71 deletions src/ops/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use smallvec::SmallVec;
use crate::number::IsNaN;
use crate::ops::reduce::{cmp_nan_greater, cmp_nan_less};
use crate::ops::{
resolve_axis, resolve_index, Input, InputList, IntoOpResult, OpError, Operator, OutputList,
map_input, resolve_axis, resolve_index, Input, InputList, IntoOpResult, OpError, Operator,
OutputList,
};
use crate::tensor_pool::{AutoReturn, TensorPool};

Expand Down Expand Up @@ -136,12 +137,10 @@ impl Operator for Gather {
fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
let input = inputs.require(0)?;
let indices = inputs.require_as::<i32>(1)?;
match input {
Input::Int32Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
Input::FloatTensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
Input::UInt8Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
Input::Int8Tensor(input) => gather(pool, input, self.axis, indices).into_op_result(),
}

map_input!(input, x, {
gather(pool, x, self.axis, indices).into_op_result()
})
}
}

Expand Down Expand Up @@ -281,20 +280,10 @@ impl Operator for GatherElements {
fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
let input = inputs.require(0)?;
let indices = inputs.require_as::<i32>(1)?;
match input {
Input::Int32Tensor(input) => {
gather_elements(pool, input, indices, self.axis).into_op_result()
}
Input::FloatTensor(input) => {
gather_elements(pool, input, indices, self.axis).into_op_result()
}
Input::Int8Tensor(input) => {
gather_elements(pool, input, indices, self.axis).into_op_result()
}
Input::UInt8Tensor(input) => {
gather_elements(pool, input, indices, self.axis).into_op_result()
}
}

map_input!(input, x, {
gather_elements(pool, x, indices, self.axis).into_op_result()
})
}
}

Expand Down Expand Up @@ -408,20 +397,10 @@ impl Operator for GatherND {
fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
let input = inputs.require(0)?;
let indices = inputs.require_as::<i32>(1)?;
match input {
Input::Int32Tensor(input) => {
gather_nd(pool, input, indices, self.batch_dims).into_op_result()
}
Input::FloatTensor(input) => {
gather_nd(pool, input, indices, self.batch_dims).into_op_result()
}
Input::Int8Tensor(input) => {
gather_nd(pool, input, indices, self.batch_dims).into_op_result()
}
Input::UInt8Tensor(input) => {
gather_nd(pool, input, indices, self.batch_dims).into_op_result()
}
}

map_input!(input, x, {
gather_nd(pool, x, indices, self.batch_dims).into_op_result()
})
}
}

Expand Down Expand Up @@ -526,27 +505,11 @@ impl Operator for ScatterElements {
fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
let data = inputs.require(0)?;
let indices = inputs.require_as::<i32>(1)?;
let updates = inputs.require(2)?;

match (data, updates) {
(Input::Int32Tensor(data), Input::Int32Tensor(updates)) => {
scatter_elements(pool, data, indices, updates, self.axis, self.reduction)
.into_op_result()
}
(Input::FloatTensor(data), Input::FloatTensor(updates)) => {
scatter_elements(pool, data, indices, updates, self.axis, self.reduction)
.into_op_result()
}
(Input::Int8Tensor(data), Input::Int8Tensor(updates)) => {
scatter_elements(pool, data, indices, updates, self.axis, self.reduction)
.into_op_result()
}
(Input::UInt8Tensor(data), Input::UInt8Tensor(updates)) => {
scatter_elements(pool, data, indices, updates, self.axis, self.reduction)
.into_op_result()
}
_ => Err(OpError::UnsupportedType),
}
map_input!(data, x, {
let updates = inputs.require_as(2)?;
scatter_elements(pool, x, indices, updates, self.axis, self.reduction).into_op_result()
})
}
}

Expand Down Expand Up @@ -632,23 +595,11 @@ impl Operator for ScatterND {
fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
let data = inputs.require(0)?;
let indices = inputs.require_as::<i32>(1)?;
let updates = inputs.require(2)?;

match (data, updates) {
(Input::Int32Tensor(data), Input::Int32Tensor(updates)) => {
scatter_nd(pool, data, indices, updates, self.reduction).into_op_result()
}
(Input::FloatTensor(data), Input::FloatTensor(updates)) => {
scatter_nd(pool, data, indices, updates, self.reduction).into_op_result()
}
(Input::Int8Tensor(data), Input::Int8Tensor(updates)) => {
scatter_nd(pool, data, indices, updates, self.reduction).into_op_result()
}
(Input::UInt8Tensor(data), Input::UInt8Tensor(updates)) => {
scatter_nd(pool, data, indices, updates, self.reduction).into_op_result()
}
_ => Err(OpError::UnsupportedType),
}
map_input!(data, x, {
let updates = inputs.require_as(2)?;
scatter_nd(pool, x, indices, updates, self.reduction).into_op_result()
})
}
}

Expand Down
11 changes: 4 additions & 7 deletions src/ops/identity.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use rten_tensor::prelude::*;
use rten_tensor::{Tensor, TensorView};

use crate::ops::{Input, InputList, IntoOpResult, OpError, Operator, Output, OutputList};
use crate::ops::{
map_input, Input, InputList, IntoOpResult, OpError, Operator, Output, OutputList,
};
use crate::tensor_pool::TensorPool;

fn identity<T: Copy>(pool: &TensorPool, src: TensorView<T>) -> Tensor<T> {
Expand All @@ -18,12 +20,7 @@ impl Operator for Identity {

fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
let input = inputs.require(0)?;
let result: Output = match input {
Input::Int32Tensor(t) => identity(pool, t).into(),
Input::FloatTensor(t) => identity(pool, t).into(),
_ => return Err(OpError::UnsupportedType),
};
result.into_op_result()
map_input!(input, x, { identity(pool, x).into_op_result() })
}

fn can_run_in_place(&self) -> bool {
Expand Down
140 changes: 33 additions & 107 deletions src/ops/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use smallvec::SmallVec;

use crate::ops::binary_elementwise::{broadcast_shapes, fast_broadcast_cycles_repeats};
use crate::ops::{
resolve_axes, resolve_axis, static_dims, Input, InputList, IntoOpResult, OpError, Operator,
Output, OutputList,
map_input, map_output, resolve_axes, resolve_axis, static_dims, Input, InputList, IntoOpResult,
OpError, Operator, Output, OutputList,
};
use crate::tensor_pool::TensorPool;

Expand Down Expand Up @@ -155,12 +155,7 @@ impl Operator for Expand {
let shape = inputs.require_as(1)?;
let shape = static_dims!(shape, 1)?;

match input {
Input::FloatTensor(input) => expand(pool, input, &shape).into_op_result(),
Input::Int32Tensor(input) => expand(pool, input, &shape).into_op_result(),
Input::UInt8Tensor(input) => expand(pool, input, &shape).into_op_result(),
Input::Int8Tensor(input) => expand(pool, input, &shape).into_op_result(),
}
map_input!(input, x, { expand(pool, x, &shape).into_op_result() })
}

fn can_run_in_place(&self) -> bool {
Expand All @@ -183,13 +178,10 @@ impl Operator for Expand {
return Ok(input);
}

let output: Output = match input {
Output::FloatTensor(input) => expand_to(pool, input.view(), &out_shape).into(),
Output::Int32Tensor(input) => expand_to(pool, input.view(), &out_shape).into(),
Output::Int8Tensor(input) => expand_to(pool, input.view(), &out_shape).into(),
Output::UInt8Tensor(input) => expand_to(pool, input.view(), &out_shape).into(),
};
Ok(output)
map_output!(input, x, {
let output = expand_to(pool, x.view(), &out_shape);
Ok(output.into())
})
}
}

Expand Down Expand Up @@ -237,13 +229,7 @@ impl Operator for Flatten {

fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
let input = inputs.require(0)?;

match input {
Input::FloatTensor(input) => flatten(pool, input, self.axis).into_op_result(),
Input::Int32Tensor(input) => flatten(pool, input, self.axis).into_op_result(),
Input::Int8Tensor(input) => flatten(pool, input, self.axis).into_op_result(),
Input::UInt8Tensor(input) => flatten(pool, input, self.axis).into_op_result(),
}
map_input!(input, x, { flatten(pool, x, self.axis).into_op_result() })
}

fn can_run_in_place(&self) -> bool {
Expand All @@ -256,24 +242,10 @@ impl Operator for Flatten {
input: Output,
_: InputList,
) -> Result<Output, OpError> {
match input {
Output::Int32Tensor(mut output) => {
flatten_in_place(pool, &mut output, self.axis)?;
Ok(output.into())
}
Output::FloatTensor(mut output) => {
flatten_in_place(pool, &mut output, self.axis)?;
Ok(output.into())
}
Output::Int8Tensor(mut output) => {
flatten_in_place(pool, &mut output, self.axis)?;
Ok(output.into())
}
Output::UInt8Tensor(mut output) => {
flatten_in_place(pool, &mut output, self.axis)?;
Ok(output.into())
}
}
map_output!(input, x, {
flatten_in_place(pool, &mut x, self.axis)?;
Ok(x.into())
})
}
}

Expand Down Expand Up @@ -388,12 +360,9 @@ impl Operator for Reshape {
let shape = inputs.require_as(1)?;
let shape = static_dims!(shape, 1)?;

match input {
Input::Int32Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(),
Input::FloatTensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(),
Input::Int8Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(),
Input::UInt8Tensor(t) => reshape(pool, t, &shape, self.allow_zero).into_op_result(),
}
map_input!(input, x, {
reshape(pool, x, &shape, self.allow_zero).into_op_result()
})
}

fn can_run_in_place(&self) -> bool {
Expand All @@ -409,24 +378,10 @@ impl Operator for Reshape {
let shape = other.require_as(0)?;
let shape = static_dims!(shape, 1)?;

match input {
Output::Int32Tensor(mut output) => {
reshape_in_place(pool, &mut output, &shape, self.allow_zero)?;
Ok(output.into())
}
Output::FloatTensor(mut output) => {
reshape_in_place(pool, &mut output, &shape, self.allow_zero)?;
Ok(output.into())
}
Output::Int8Tensor(mut output) => {
reshape_in_place(pool, &mut output, &shape, self.allow_zero)?;
Ok(output.into())
}
Output::UInt8Tensor(mut output) => {
reshape_in_place(pool, &mut output, &shape, self.allow_zero)?;
Ok(output.into())
}
}
map_output!(input, output, {
reshape_in_place(pool, &mut output, &shape, self.allow_zero)?;
Ok(output.into())
})
}
}

Expand Down Expand Up @@ -531,12 +486,7 @@ impl Operator for Squeeze {
let axes = inputs.get_as(1)?;
let axes = axes.map(|axes| static_dims!(axes, 1)).transpose()?;

match input {
Input::FloatTensor(t) => squeeze(pool, t, axes).into_op_result(),
Input::Int32Tensor(t) => squeeze(pool, t, axes).into_op_result(),
Input::Int8Tensor(t) => squeeze(pool, t, axes).into_op_result(),
Input::UInt8Tensor(t) => squeeze(pool, t, axes).into_op_result(),
}
map_input!(input, x, { squeeze(pool, x, axes).into_op_result() })
}

fn can_run_in_place(&self) -> bool {
Expand All @@ -552,24 +502,10 @@ impl Operator for Squeeze {
let axes = other.get_as(0)?;
let axes = axes.map(|axes| static_dims!(axes, 1)).transpose()?;

match input {
Output::FloatTensor(mut t) => {
squeeze_in_place(&mut t, axes)?;
Ok(t.into())
}
Output::Int32Tensor(mut t) => {
squeeze_in_place(&mut t, axes)?;
Ok(t.into())
}
Output::UInt8Tensor(mut t) => {
squeeze_in_place(&mut t, axes)?;
Ok(t.into())
}
Output::Int8Tensor(mut t) => {
squeeze_in_place(&mut t, axes)?;
Ok(t.into())
}
}
map_output!(input, output, {
squeeze_in_place(&mut output, axes)?;
Ok(output.into())
})
}
}

Expand Down Expand Up @@ -609,12 +545,10 @@ impl Operator for Transpose {
fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
let input = inputs.require(0)?;
let perm_slice = self.perm.as_deref();
match input {
Input::FloatTensor(input) => transpose(pool, input, perm_slice).into_op_result(),
Input::Int32Tensor(input) => transpose(pool, input, perm_slice).into_op_result(),
Input::Int8Tensor(input) => transpose(pool, input, perm_slice).into_op_result(),
Input::UInt8Tensor(input) => transpose(pool, input, perm_slice).into_op_result(),
}

map_input!(input, x, {
transpose(pool, x, perm_slice).into_op_result()
})
}
}

Expand Down Expand Up @@ -668,12 +602,7 @@ impl Operator for Unsqueeze {
let axes = inputs.require_as(1)?;
let axes = static_dims!(axes, 1)?;

match input {
Input::FloatTensor(input) => unsqueeze(pool, input, &axes).into_op_result(),
Input::Int32Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(),
Input::Int8Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(),
Input::UInt8Tensor(input) => unsqueeze(pool, input, &axes).into_op_result(),
}
map_input!(input, x, { unsqueeze(pool, x, &axes).into_op_result() })
}

fn can_run_in_place(&self) -> bool {
Expand All @@ -683,18 +612,15 @@ impl Operator for Unsqueeze {
fn run_in_place(
&self,
_pool: &TensorPool,
output: Output,
input: Output,
inputs: InputList,
) -> Result<Output, OpError> {
let axes = inputs.require_as(0)?;
let axes = static_dims!(axes, 1)?;

match output {
Output::FloatTensor(t) => unsqueeze_in_place(t, &axes).map(Output::FloatTensor),
Output::Int32Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::Int32Tensor),
Output::Int8Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::Int8Tensor),
Output::UInt8Tensor(t) => unsqueeze_in_place(t, &axes).map(Output::UInt8Tensor),
}
map_output!(input, output, {
Ok(unsqueeze_in_place(output, &axes)?.into())
})
}
}

Expand Down
Loading

0 comments on commit 6ffb5a5

Please sign in to comment.