Skip to content

Commit

Permalink
fix: to match reshape new signature
Browse files Browse the repository at this point in the history
  • Loading branch information
chachaleo committed Apr 22, 2024
1 parent 5d3802a commit a0ce46e
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 13 deletions.
36 changes: 27 additions & 9 deletions src/operators/nn/functional/deform_conv.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ fn deform_conv<
'offset_group inconsistencies'
);

let mut offset_shape = array![n, offset_group];
offset_shape.append_span(kernel_shape);
offset_shape.append(kernel_shape.len());
offset_shape.append_span(output_shape);
let mut offset_shape = array![n.into(), offset_group.into()];
offset_shape.append_span(span_U32_to_span_I32(kernel_shape.clone()));
offset_shape.append(kernel_shape.len().into());
offset_shape.append_span(span_U32_to_span_I32(output_shape.clone()));

let offset = offset.reshape(offset_shape.span());
let offset = offset.reshape(offset_shape.span(), false);

let mask = match mask {
Option::Some(mask) => mask,
Expand All @@ -151,10 +151,10 @@ fn deform_conv<
},
};

let mut mask_shape = array![n, offset_group];
mask_shape.append_span(kernel_shape);
mask_shape.append_span(output_shape);
let mask = mask.reshape(mask_shape.span());
let mut mask_shape = array![n.into(), offset_group.into()];
mask_shape.append_span(span_U32_to_span_I32(kernel_shape.clone()));
mask_shape.append_span(span_U32_to_span_I32(output_shape.clone()));
let mask = mask.reshape(mask_shape.span(), false);

if (*X).shape.len() == 4 {
let ih: T = NumberTrait::new_unscaled((*(*X).shape.at(2)).into(), false);
Expand Down Expand Up @@ -528,3 +528,21 @@ fn sum<T, MAG, +Drop<T>, +Copy<T>, +NumberTrait<T, MAG>, +TensorTrait<T>, +AddEq
};
return sum;
}


fn span_U32_to_span_I32(
mut x: Span<usize>
) -> Span<i32> {
let mut res = ArrayTrait::new();

loop {
match x.pop_front() {
Option::Some(v) => {
res.append((*v).into());
},
Option::None => { break; }
};
};

return res.span();
}
3 changes: 2 additions & 1 deletion src/operators/nn/implementations/nn_fp16x16.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ impl FP16x16NN of NNTrait<FP16x16> {
pads,
storage_order,
strides,
output_len
output_len)
}
fn deform_conv(
X: @Tensor<FP16x16>,
W: @Tensor<FP16x16>,
Expand Down
3 changes: 2 additions & 1 deletion src/operators/nn/implementations/nn_fp32x32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ impl FP32x32NN of NNTrait<FP32x32> {
pads,
storage_order,
strides,
output_len
output_len)
}
fn deform_conv(
X: @Tensor<FP32x32>,
W: @Tensor<FP32x32>,
Expand Down
3 changes: 2 additions & 1 deletion src/operators/nn/implementations/nn_fp64x64.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ impl FP64x64NN of NNTrait<FP64x64> {
pads,
storage_order,
strides,
output_len
output_len)
}
fn deform_conv(
X: @Tensor<FP64x64>,
W: @Tensor<FP64x64>,
Expand Down
3 changes: 2 additions & 1 deletion src/operators/nn/implementations/nn_fp8x23.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ impl FP8x23NN of NNTrait<FP8x23> {
pads,
storage_order,
strides,
output_len
output_len)
}
fn deform_conv(
X: @Tensor<FP8x23>,
W: @Tensor<FP8x23>,
Expand Down
31 changes: 31 additions & 0 deletions src/operators/nn/implementations/nn_i32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use orion::operators::tensor::core::Tensor;
use orion::operators::nn::core::NNTrait;
use orion::operators::nn::functional;
use orion::operators::tensor::implementations::tensor_i32::{I32Tensor, I32TensorAdd};
use orion::operators::nn::AUTO_PAD;


impl I32NN of NNTrait<i32> {
fn relu(tensor: @Tensor<i32>) -> Tensor<i32> {
Expand Down Expand Up @@ -130,4 +132,33 @@ impl I32NN of NNTrait<i32> {
) -> Tensor<i32> {
functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides)
}

fn max_pool(
X: @Tensor<i32>,
auto_pad: Option<AUTO_PAD>,
ceil_mode: Option<usize>,
dilations: Option<Span<usize>>,
kernel_shape: Span<usize>,
pads: Option<Span<usize>>,
storage_order: Option<usize>,
strides: Option<Span<usize>>,
output_len: usize,
) -> (Tensor<i32>, Option<Tensor<usize>>) {
panic(array!['not supported!'])
}
fn deform_conv(
X: @Tensor<i32>,
W: @Tensor<i32>,
offset: @Tensor<i32>,
B: Option<Span<i32>>,
mask: Option<Tensor<i32>>,
dilations: Option<Span<usize>>,
group: Option<usize>,
kernel_shape: Option<Span<usize>>,
offset_group: Option<usize>,
pads: Option<Span<usize>>,
strides: Option<Span<usize>>,
) -> Tensor<i32> {
panic(array!['not supported!'])
}
}
31 changes: 31 additions & 0 deletions src/operators/nn/implementations/nn_i8.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use orion::operators::tensor::core::Tensor;
use orion::operators::nn::core::NNTrait;
use orion::operators::nn::functional;
use orion::operators::tensor::implementations::tensor_i8::{I8Tensor, I8TensorAdd};
use orion::operators::nn::AUTO_PAD;


impl I8NN of NNTrait<i8> {
fn relu(tensor: @Tensor<i8>) -> Tensor<i8> {
Expand Down Expand Up @@ -130,4 +132,33 @@ impl I8NN of NNTrait<i8> {
) -> Tensor<i8> {
functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides)
}

fn max_pool(
X: @Tensor<i8>,
auto_pad: Option<AUTO_PAD>,
ceil_mode: Option<usize>,
dilations: Option<Span<usize>>,
kernel_shape: Span<usize>,
pads: Option<Span<usize>>,
storage_order: Option<usize>,
strides: Option<Span<usize>>,
output_len: usize,
) -> (Tensor<i8>, Option<Tensor<usize>>) {
panic(array!['not supported!'])
}
fn deform_conv(
X: @Tensor<i8>,
W: @Tensor<i8>,
offset: @Tensor<i8>,
B: Option<Span<i8>>,
mask: Option<Tensor<i8>>,
dilations: Option<Span<usize>>,
group: Option<usize>,
kernel_shape: Option<Span<usize>>,
offset_group: Option<usize>,
pads: Option<Span<usize>>,
strides: Option<Span<usize>>,
) -> Tensor<i8> {
panic(array!['not supported!'])
}
}
31 changes: 31 additions & 0 deletions src/operators/nn/implementations/nn_u32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use orion::operators::tensor::core::Tensor;
use orion::operators::nn::core::NNTrait;
use orion::operators::nn::functional;
use orion::operators::tensor::implementations::tensor_u32::{U32Tensor, U32TensorAdd};
use orion::operators::nn::AUTO_PAD;


impl U32NN of NNTrait<u32> {
fn relu(tensor: @Tensor<u32>) -> Tensor<u32> {
Expand Down Expand Up @@ -130,4 +132,33 @@ impl U32NN of NNTrait<u32> {
) -> Tensor<u32> {
functional::conv::conv(X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides)
}

fn max_pool(
X: @Tensor<u32>,
auto_pad: Option<AUTO_PAD>,
ceil_mode: Option<usize>,
dilations: Option<Span<usize>>,
kernel_shape: Span<usize>,
pads: Option<Span<usize>>,
storage_order: Option<usize>,
strides: Option<Span<usize>>,
output_len: usize,
) -> (Tensor<u32>, Option<Tensor<usize>>) {
panic(array!['not supported!'])
}
fn deform_conv(
X: @Tensor<u32>,
W: @Tensor<u32>,
offset: @Tensor<u32>,
B: Option<Span<u32>>,
mask: Option<Tensor<u32>>,
dilations: Option<Span<usize>>,
group: Option<usize>,
kernel_shape: Option<Span<usize>>,
offset_group: Option<usize>,
pads: Option<Span<usize>>,
strides: Option<Span<usize>>,
) -> Tensor<u32> {
panic(array!['not supported!'])
}
}

0 comments on commit a0ce46e

Please sign in to comment.