diff --git a/src/operators/nn/functional/deform_conv.cairo b/src/operators/nn/functional/deform_conv.cairo index 1e43e36b7..c8ffc7c3a 100644 --- a/src/operators/nn/functional/deform_conv.cairo +++ b/src/operators/nn/functional/deform_conv.cairo @@ -47,15 +47,14 @@ fn deform_conv< strides: Option>, ) -> Tensor { assert((*X).shape.len() >= 3, 'X must have at least 3 dim'); + assert((*W).shape.len() >= 3, 'X must have at least 3 dim'); + let dilations = match dilations { Option::Some(dilations) => dilations, Option::None => { let mut dilations = ArrayTrait::new(); let mut i = 2; - loop { - if i >= (*X).shape.len() { - break; - } + while i != (*X).shape.len() { dilations.append(1); i += 1; }; @@ -67,10 +66,7 @@ fn deform_conv< Option::None => { let mut kernel_shape = ArrayTrait::new(); let mut i = 2; - loop { - if i >= (*W).shape.len() { - break; - } + while i != (*W).shape.len() { kernel_shape.append(*(*W).shape.at(i)); i += 1; }; @@ -82,10 +78,7 @@ fn deform_conv< Option::None => { let mut pads = ArrayTrait::new(); let mut i = 2; - loop { - if i >= (*X).shape.len() { - break; - } + while i != (*X).shape.len() { pads.append(0); pads.append(0); i += 1; @@ -98,10 +91,7 @@ fn deform_conv< Option::None => { let mut strides = ArrayTrait::new(); let mut i = 2; - loop { - if i >= (*X).shape.len() { - break; - } + while i != (*X).shape.len() { strides.append(1); i += 1; }; @@ -149,11 +139,9 @@ fn deform_conv< Option::Some(mask) => mask, Option::None => { let mut mask = ArrayTrait::::new(); + let mask_end = n * offset_group * prod(kernel_shape, 0) * prod(output_shape, 0); let mut i = 0; - loop { - if i == n * offset_group * prod(kernel_shape, 0) * prod(output_shape, 0) { - break; - } + while i != mask_end { mask.append(NumberTrait::::one()); i += 1; }; @@ -220,26 +208,14 @@ fn deform_conv< match B { Option::Some(B) => { let mut i = 0; - loop { - if i == n { - break; - } + while i != n { let mut j = 0; - loop { - if j == oc { - break; - } + while j != oc { let b_j = *B.at(j); let mut k = 0; - loop { - if k == oh { - break; - } + while k != oh { let mut l = 0; - loop { - if l == ow { - break; - } + while l != ow { res .set( i * *res_stride.at(0) @@ -277,35 +253,20 @@ fn deform_conv< let two: T = NumberTrait::one() + NumberTrait::one(); let mut batch_idx = 0; - loop { - if batch_idx == n { - break; - } + while batch_idx != n { let mut oc_idx = 0; - loop { - if oc_idx == oc { - break; - } + while oc_idx != oc { let mut ic_idx = 0; - loop { - if ic_idx == ic { - break; - } + while ic_idx != ic { if (ic_idx / ics_per_group) == (oc_idx / ocs_per_group) { let offset_group_idx = ic_idx / ics_per_offset_group; let mut i = 0; - loop { - if i == oh { - break; - } + while i != oh { let index = NumberTrait::new_unscaled(i.into(), false); let h_coord = bh + sth * index; let mut j = 0; - loop { - if j == ow { - break; - } + while j != ow { let jndex = NumberTrait::new_unscaled(j.into(), false); let w_coord = bw + stw * jndex; @@ -315,15 +276,9 @@ fn deform_conv< let mut offset_TEST = ArrayTrait::new(); let mut hi = 0; - loop { - if hi == ks0 { - break; - } + while hi != ks0 { let mut wi = 0; - loop { - if wi == ks1 { - break; - } + while wi != ks1 { let elem1 = h_coord + *offset .data @@ -461,10 +416,7 @@ fn deform_conv< let mut res_data = ArrayTrait::new(); let mut i = 0; - loop { - if i == res.len() { - break; - } + while i != res.len() { res_data.append(res.at(i)); i += 1; }; @@ -480,17 +432,11 @@ fn meshgrid(x: Span, y: Span) -> (Span, Span) { let mut yv = ArrayTrait::new(); let mut i = 0; - loop { - if i == y.len() { - break; - } + while i != y.len() { xv.append_span(x); let mut j = 0; - loop { - if j == x.len() { - break; - } + while j != x.len() { yv.append(*y.at(i)); j += 1; }; @@ -503,10 +449,7 @@ fn stack(x: Span, y: Span) -> Span { let mut stack = ArrayTrait::new(); let mut i = 0; - loop { - if i == x.len() { - break; - } + while i != x.len() { stack.append(*x.at(i)); stack.append(*y.at(i)); i += 1; @@ -521,10 +464,7 @@ fn flip_mod_2, impl TCopy: Copy, +NumberTrait Span { let mut i = 0; let mut res = ArrayTrait::new(); - loop { - if i == x.len / 2 { - break; - } + while i != x.len / 2 { res.append(x.at(i * 2 + 1)); res.append(x.at(i * 2)); i += 1; @@ -541,10 +481,7 @@ fn copy_to_vec< let mut res = NullableVecImpl::new(); let mut i = 0; - loop { - if i == x.len() { - break; - } + while i != x.len() { res.set(i, NumberTrait::new_unscaled((*x.at(i)).into(), false)); i += 1; }; @@ -556,10 +493,7 @@ fn copy_to_vec< fn arange(start: usize, end: usize, step: usize) -> Span { let mut arr = ArrayTrait::new(); let mut i = start; - loop { - if i >= end { - break; - } + while i != end { arr.append(i); i += step; }; @@ -568,29 +502,27 @@ fn arange(start: usize, end: usize, step: usize) -> Span { fn prod, +Copy, +NumberTrait, +TensorTrait, +Mul,>( - pA: Span, start: usize + a: Span, start: usize ) -> T { - let mut i = start; + assert(a.len() > start, 'wrong input dim'); let mut prod = NumberTrait::one(); - loop { - if i == pA.len() { - break; - } - prod = prod * (*pA.at(i)); + let mut i = start; + while i != a.len() { + prod = prod * (*a.at(i)); i += 1; }; return prod; } + + fn sum, +Copy, +NumberTrait, +TensorTrait, +AddEq,>( a: Span, start: usize ) -> T { - let mut i = start; + assert(a.len() > start, 'wrong input dim'); let mut sum = NumberTrait::zero(); - loop { - if i == a.len() { - break; - } + let mut i = start; + while i != a.len() { sum += (*a.at(i)); i += 1; };