Skip to content

Commit

Permalink
fix: while loops
Browse files Browse the repository at this point in the history
  • Loading branch information
chachaleo committed Apr 21, 2024
1 parent 6494920 commit 5d3802a
Showing 1 changed file with 36 additions and 104 deletions.
140 changes: 36 additions & 104 deletions src/operators/nn/functional/deform_conv.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,14 @@ fn deform_conv<
strides: Option<Span<usize>>,
) -> Tensor<T> {
assert((*X).shape.len() >= 3, 'X must have at least 3 dim');
assert((*W).shape.len() >= 3, 'X must have at least 3 dim');

let dilations = match dilations {
Option::Some(dilations) => dilations,
Option::None => {
let mut dilations = ArrayTrait::new();
let mut i = 2;
loop {
if i >= (*X).shape.len() {
break;
}
while i != (*X).shape.len() {
dilations.append(1);
i += 1;
};
Expand All @@ -67,10 +66,7 @@ fn deform_conv<
Option::None => {
let mut kernel_shape = ArrayTrait::new();
let mut i = 2;
loop {
if i >= (*W).shape.len() {
break;
}
while i != (*W).shape.len() {
kernel_shape.append(*(*W).shape.at(i));
i += 1;
};
Expand All @@ -82,10 +78,7 @@ fn deform_conv<
Option::None => {
let mut pads = ArrayTrait::new();
let mut i = 2;
loop {
if i >= (*X).shape.len() {
break;
}
while i != (*X).shape.len() {
pads.append(0);
pads.append(0);
i += 1;
Expand All @@ -98,10 +91,7 @@ fn deform_conv<
Option::None => {
let mut strides = ArrayTrait::new();
let mut i = 2;
loop {
if i >= (*X).shape.len() {
break;
}
while i != (*X).shape.len() {
strides.append(1);
i += 1;
};
Expand Down Expand Up @@ -149,11 +139,9 @@ fn deform_conv<
Option::Some(mask) => mask,
Option::None => {
let mut mask = ArrayTrait::<T>::new();
let mask_end = n * offset_group * prod(kernel_shape, 0) * prod(output_shape, 0);
let mut i = 0;
loop {
if i == n * offset_group * prod(kernel_shape, 0) * prod(output_shape, 0) {
break;
}
while i != mask_end {
mask.append(NumberTrait::<T>::one());
i += 1;
};
Expand Down Expand Up @@ -220,26 +208,14 @@ fn deform_conv<
match B {
Option::Some(B) => {
let mut i = 0;
loop {
if i == n {
break;
}
while i != n {
let mut j = 0;
loop {
if j == oc {
break;
}
while j != oc {
let b_j = *B.at(j);
let mut k = 0;
loop {
if k == oh {
break;
}
while k != oh {
let mut l = 0;
loop {
if l == ow {
break;
}
while l != ow {
res
.set(
i * *res_stride.at(0)
Expand Down Expand Up @@ -277,35 +253,20 @@ fn deform_conv<
let two: T = NumberTrait::one() + NumberTrait::one();

let mut batch_idx = 0;
loop {
if batch_idx == n {
break;
}
while batch_idx != n {
let mut oc_idx = 0;
loop {
if oc_idx == oc {
break;
}
while oc_idx != oc {
let mut ic_idx = 0;
loop {
if ic_idx == ic {
break;
}
while ic_idx != ic {
if (ic_idx / ics_per_group) == (oc_idx / ocs_per_group) {
let offset_group_idx = ic_idx / ics_per_offset_group;

let mut i = 0;
loop {
if i == oh {
break;
}
while i != oh {
let index = NumberTrait::new_unscaled(i.into(), false);
let h_coord = bh + sth * index;
let mut j = 0;
loop {
if j == ow {
break;
}
while j != ow {
let jndex = NumberTrait::new_unscaled(j.into(), false);
let w_coord = bw + stw * jndex;

Expand All @@ -315,15 +276,9 @@ fn deform_conv<
let mut offset_TEST = ArrayTrait::new();

let mut hi = 0;
loop {
if hi == ks0 {
break;
}
while hi != ks0 {
let mut wi = 0;
loop {
if wi == ks1 {
break;
}
while wi != ks1 {
let elem1 = h_coord
+ *offset
.data
Expand Down Expand Up @@ -461,10 +416,7 @@ fn deform_conv<

let mut res_data = ArrayTrait::new();
let mut i = 0;
loop {
if i == res.len() {
break;
}
while i != res.len() {
res_data.append(res.at(i));
i += 1;
};
Expand All @@ -480,17 +432,11 @@ fn meshgrid(x: Span<usize>, y: Span<usize>) -> (Span<usize>, Span<usize>) {
let mut yv = ArrayTrait::new();

let mut i = 0;
loop {
if i == y.len() {
break;
}
while i != y.len() {

xv.append_span(x);
let mut j = 0;
loop {
if j == x.len() {
break;
}
while j != x.len() {
yv.append(*y.at(i));
j += 1;
};
Expand All @@ -503,10 +449,7 @@ fn stack(x: Span<usize>, y: Span<usize>) -> Span<usize> {
let mut stack = ArrayTrait::new();

let mut i = 0;
loop {
if i == x.len() {
break;
}
while i != x.len() {
stack.append(*x.at(i));
stack.append(*y.at(i));
i += 1;
Expand All @@ -521,10 +464,7 @@ fn flip_mod_2<T, MAG, impl TDrop: Drop<T>, impl TCopy: Copy<T>, +NumberTrait<T,
) -> Span<T> {
let mut i = 0;
let mut res = ArrayTrait::new();
loop {
if i == x.len / 2 {
break;
}
while i != x.len / 2 {
res.append(x.at(i * 2 + 1));
res.append(x.at(i * 2));
i += 1;
Expand All @@ -541,10 +481,7 @@ fn copy_to_vec<
let mut res = NullableVecImpl::new();

let mut i = 0;
loop {
if i == x.len() {
break;
}
while i != x.len() {
res.set(i, NumberTrait::new_unscaled((*x.at(i)).into(), false));
i += 1;
};
Expand All @@ -556,10 +493,7 @@ fn copy_to_vec<
fn arange(start: usize, end: usize, step: usize) -> Span<usize> {
let mut arr = ArrayTrait::new();
let mut i = start;
loop {
if i >= end {
break;
}
while i != end {
arr.append(i);
i += step;
};
Expand All @@ -568,29 +502,27 @@ fn arange(start: usize, end: usize, step: usize) -> Span<usize> {


fn prod<T, MAG, +Drop<T>, +Copy<T>, +NumberTrait<T, MAG>, +TensorTrait<T>, +Mul<T>,>(
pA: Span<T>, start: usize
a: Span<T>, start: usize
) -> T {
let mut i = start;
assert(a.len() > start, 'wrong input dim');
let mut prod = NumberTrait::one();
loop {
if i == pA.len() {
break;
}
prod = prod * (*pA.at(i));
let mut i = start;
while i != a.len() {
prod = prod * (*a.at(i));
i += 1;
};
return prod;
}



fn sum<T, MAG, +Drop<T>, +Copy<T>, +NumberTrait<T, MAG>, +TensorTrait<T>, +AddEq<T>,>(
a: Span<T>, start: usize
) -> T {
let mut i = start;
assert(a.len() > start, 'wrong input dim');
let mut sum = NumberTrait::zero();
loop {
if i == a.len() {
break;
}
let mut i = start;
while i != a.len() {
sum += (*a.at(i));
i += 1;
};
Expand Down

0 comments on commit 5d3802a

Please sign in to comment.