diff --git a/src/operators/nn/functional/deform_conv.cairo b/src/operators/nn/functional/deform_conv.cairo index c8ffc7c3a..bb04b11c8 100644 --- a/src/operators/nn/functional/deform_conv.cairo +++ b/src/operators/nn/functional/deform_conv.cairo @@ -128,12 +128,12 @@ fn deform_conv< 'offset_group inconsistencies' ); - let mut offset_shape = array![n, offset_group]; - offset_shape.append_span(kernel_shape); - offset_shape.append(kernel_shape.len()); - offset_shape.append_span(output_shape); + let mut offset_shape = array![n.into(), offset_group.into()]; + offset_shape.append_span(span_U32_to_span_I32(kernel_shape.clone())); + offset_shape.append(kernel_shape.len().into()); + offset_shape.append_span(span_U32_to_span_I32(output_shape.clone())); - let offset = offset.reshape(offset_shape.span()); + let offset = offset.reshape(offset_shape.span(), false); let mask = match mask { Option::Some(mask) => mask, @@ -151,10 +151,10 @@ fn deform_conv< }, }; - let mut mask_shape = array![n, offset_group]; - mask_shape.append_span(kernel_shape); - mask_shape.append_span(output_shape); - let mask = mask.reshape(mask_shape.span()); + let mut mask_shape = array![n.into(), offset_group.into()]; + mask_shape.append_span(span_U32_to_span_I32(kernel_shape.clone())); + mask_shape.append_span(span_U32_to_span_I32(output_shape.clone())); + let mask = mask.reshape(mask_shape.span(), false); if (*X).shape.len() == 4 { let ih: T = NumberTrait::new_unscaled((*(*X).shape.at(2)).into(), false); @@ -528,3 +528,21 @@ fn sum, +Copy, +NumberTrait, +TensorTrait, +AddEq }; return sum; } + + +fn span_U32_to_span_I32( + mut x: Span +) -> Span { + let mut res = ArrayTrait::new(); + + loop { + match x.pop_front() { + Option::Some(v) => { + res.append((*v).into()); + }, + Option::None => { break; } + }; + }; + + return res.span(); +} \ No newline at end of file diff --git a/src/operators/nn/implementations/nn_fp16x16.cairo b/src/operators/nn/implementations/nn_fp16x16.cairo index a882dedd7..05faad2ba 100644 --- a/src/operators/nn/implementations/nn_fp16x16.cairo +++ b/src/operators/nn/implementations/nn_fp16x16.cairo @@ -165,7 +165,8 @@ impl FP16x16NN of NNTrait { pads, storage_order, strides, - output_len + output_len) + } fn deform_conv( X: @Tensor, W: @Tensor, diff --git a/src/operators/nn/implementations/nn_fp32x32.cairo b/src/operators/nn/implementations/nn_fp32x32.cairo index f12276a0c..a1ca177dd 100644 --- a/src/operators/nn/implementations/nn_fp32x32.cairo +++ b/src/operators/nn/implementations/nn_fp32x32.cairo @@ -159,7 +159,8 @@ impl FP32x32NN of NNTrait { pads, storage_order, strides, - output_len + output_len) + } fn deform_conv( X: @Tensor, W: @Tensor, diff --git a/src/operators/nn/implementations/nn_fp64x64.cairo b/src/operators/nn/implementations/nn_fp64x64.cairo index 84fb5d604..6d6770551 100644 --- a/src/operators/nn/implementations/nn_fp64x64.cairo +++ b/src/operators/nn/implementations/nn_fp64x64.cairo @@ -159,7 +159,8 @@ impl FP64x64NN of NNTrait { pads, storage_order, strides, - output_len + output_len) + } fn deform_conv( X: @Tensor, W: @Tensor, diff --git a/src/operators/nn/implementations/nn_fp8x23.cairo b/src/operators/nn/implementations/nn_fp8x23.cairo index f25bfa86e..924b16d34 100644 --- a/src/operators/nn/implementations/nn_fp8x23.cairo +++ b/src/operators/nn/implementations/nn_fp8x23.cairo @@ -161,7 +161,8 @@ impl FP8x23NN of NNTrait { pads, storage_order, strides, - output_len + output_len) + } fn deform_conv( X: @Tensor, W: @Tensor, diff --git a/src/operators/nn/implementations/nn_i32.cairo b/src/operators/nn/implementations/nn_i32.cairo index f09e34ef7..973dfb552 100644 --- a/src/operators/nn/implementations/nn_i32.cairo +++ b/src/operators/nn/implementations/nn_i32.cairo @@ -2,6 +2,8 @@ use orion::operators::tensor::core::Tensor; use orion::operators::nn::core::NNTrait; use orion::operators::nn::functional; use orion::operators::tensor::implementations::tensor_i32::{I32Tensor, I32TensorAdd}; +use orion::operators::nn::AUTO_PAD; + impl I32NN of NNTrait { fn relu(tensor: @Tensor) -> Tensor { @@ -130,4 +132,33 @@ impl I32NN of NNTrait { ) -> Tensor { functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides) } + + fn max_pool( + X: @Tensor, + auto_pad: Option, + ceil_mode: Option, + dilations: Option>, + kernel_shape: Span, + pads: Option>, + storage_order: Option, + strides: Option>, + output_len: usize, + ) -> (Tensor, Option>) { + panic(array!['not supported!']) + } + fn deform_conv( + X: @Tensor, + W: @Tensor, + offset: @Tensor, + B: Option>, + mask: Option>, + dilations: Option>, + group: Option, + kernel_shape: Option>, + offset_group: Option, + pads: Option>, + strides: Option>, + ) -> Tensor { + panic(array!['not supported!']) + } } diff --git a/src/operators/nn/implementations/nn_i8.cairo b/src/operators/nn/implementations/nn_i8.cairo index befc8b8b0..d48e398df 100644 --- a/src/operators/nn/implementations/nn_i8.cairo +++ b/src/operators/nn/implementations/nn_i8.cairo @@ -2,6 +2,8 @@ use orion::operators::tensor::core::Tensor; use orion::operators::nn::core::NNTrait; use orion::operators::nn::functional; use orion::operators::tensor::implementations::tensor_i8::{I8Tensor, I8TensorAdd}; +use orion::operators::nn::AUTO_PAD; + impl I8NN of NNTrait { fn relu(tensor: @Tensor) -> Tensor { @@ -130,4 +132,33 @@ impl I8NN of NNTrait { ) -> Tensor { functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides) } + + fn max_pool( + X: @Tensor, + auto_pad: Option, + ceil_mode: Option, + dilations: Option>, + kernel_shape: Span, + pads: Option>, + storage_order: Option, + strides: Option>, + output_len: usize, + ) -> (Tensor, Option>) { + panic(array!['not supported!']) + } + fn deform_conv( + X: @Tensor, + W: @Tensor, + offset: @Tensor, + B: Option>, + mask: Option>, + dilations: Option>, + group: Option, + kernel_shape: Option>, + offset_group: Option, + pads: Option>, + strides: Option>, + ) -> Tensor { + panic(array!['not supported!']) + } } diff --git a/src/operators/nn/implementations/nn_u32.cairo b/src/operators/nn/implementations/nn_u32.cairo index 1cd4c926f..504a8199b 100644 --- a/src/operators/nn/implementations/nn_u32.cairo +++ b/src/operators/nn/implementations/nn_u32.cairo @@ -2,6 +2,8 @@ use orion::operators::tensor::core::Tensor; use orion::operators::nn::core::NNTrait; use orion::operators::nn::functional; use orion::operators::tensor::implementations::tensor_u32::{U32Tensor, U32TensorAdd}; +use orion::operators::nn::AUTO_PAD; + impl U32NN of NNTrait { fn relu(tensor: @Tensor) -> Tensor { @@ -130,4 +132,33 @@ impl U32NN of NNTrait { ) -> Tensor { functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides) } + + fn max_pool( + X: @Tensor, + auto_pad: Option, + ceil_mode: Option, + dilations: Option>, + kernel_shape: Span, + pads: Option>, + storage_order: Option, + strides: Option>, + output_len: usize, + ) -> (Tensor, Option>) { + panic(array!['not supported!']) + } + fn deform_conv( + X: @Tensor, + W: @Tensor, + offset: @Tensor, + B: Option>, + mask: Option>, + dilations: Option>, + group: Option, + kernel_shape: Option>, + offset_group: Option, + pads: Option>, + strides: Option>, + ) -> Tensor { + panic(array!['not supported!']) + } }