Skip to content

Commit

Permalink
refactor operator
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Mar 25, 2024
1 parent dc86183 commit e7438fa
Show file tree
Hide file tree
Showing 13 changed files with 15 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/operators/tensor/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1386,7 +1386,7 @@ trait TensorTrait<T> {
/// >>> [1,1,1,0,0,0,1,1,1]
/// ```
///
fn less_equal(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<usize>;
fn less_equal(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<i32>;
/// #tensor.abs
///
/// ```rust
Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_bool.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl BoolTensor of TensorTrait<bool> {
panic(array!['not supported!'])
}

fn less_equal(self: @Tensor<bool>, other: @Tensor<bool>) -> Tensor<usize> {
fn less_equal(self: @Tensor<bool>, other: @Tensor<bool>) -> Tensor<i32> {
panic(array!['not supported!'])
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ impl Complex64Tensor of TensorTrait<complex64> {
panic(array!['not supported!'])
}

fn less_equal(self: @Tensor<complex64>, other: @Tensor<complex64>) -> Tensor<usize> {
fn less_equal(self: @Tensor<complex64>, other: @Tensor<complex64>) -> Tensor<i32> {
panic(array!['not supported!'])
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp16x16.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl FP16x16Tensor of TensorTrait<FP16x16> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<FP16x16>, other: @Tensor<FP16x16>) -> Tensor<usize> {
fn less_equal(self: @Tensor<FP16x16>, other: @Tensor<FP16x16>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl FP16x16WTensor of TensorTrait<FP16x16W> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<FP16x16W>, other: @Tensor<FP16x16W>) -> Tensor<usize> {
fn less_equal(self: @Tensor<FP16x16W>, other: @Tensor<FP16x16W>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp32x32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl FP32x32Tensor of TensorTrait<FP32x32> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<FP32x32>, other: @Tensor<FP32x32>) -> Tensor<usize> {
fn less_equal(self: @Tensor<FP32x32>, other: @Tensor<FP32x32>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp64x64.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl FP64x64Tensor of TensorTrait<FP64x64> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<FP64x64>, other: @Tensor<FP64x64>) -> Tensor<usize> {
fn less_equal(self: @Tensor<FP64x64>, other: @Tensor<FP64x64>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp8x23.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl FP8x23Tensor of TensorTrait<FP8x23> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<FP8x23>, other: @Tensor<FP8x23>) -> Tensor<usize> {
fn less_equal(self: @Tensor<FP8x23>, other: @Tensor<FP8x23>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl FP8x23WTensor of TensorTrait<FP8x23W> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<FP8x23W>, other: @Tensor<FP8x23W>) -> Tensor<usize> {
fn less_equal(self: @Tensor<FP8x23W>, other: @Tensor<FP8x23W>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_i32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl I32Tensor of TensorTrait<i32> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<i32>, other: @Tensor<i32>) -> Tensor<usize> {
fn less_equal(self: @Tensor<i32>, other: @Tensor<i32>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_i8.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl I8Tensor of TensorTrait<i8> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<i8>, other: @Tensor<i8>) -> Tensor<usize> {
fn less_equal(self: @Tensor<i8>, other: @Tensor<i8>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_u32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ impl U32Tensor of TensorTrait<u32> {
math::less::less(self, other)
}

fn less_equal(self: @Tensor<u32>, other: @Tensor<u32>) -> Tensor<usize> {
fn less_equal(self: @Tensor<u32>, other: @Tensor<u32>) -> Tensor<i32> {
math::less_equal::less_equal(self, other)
}

Expand Down
7 changes: 3 additions & 4 deletions src/operators/tensor/math/less_equal.cairo
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
use orion::operators::tensor::core::{Tensor, TensorTrait, unravel_index};
use orion::operators::tensor::{core::{Tensor, TensorTrait, unravel_index}, I32Tensor};
use orion::operators::tensor::helpers::{
broadcast_shape, broadcast_index_mapping, len_from_shape, check_compatibility
};

/// Cf: TensorTrait::less_equal docstring
fn less_equal<
T,
impl UsizeFTensor: TensorTrait<usize>,
impl TPartialOrd: PartialOrd<T>,
impl TCopy: Copy<T>,
impl TDrop: Drop<T>
>(
y: @Tensor<T>, z: @Tensor<T>
) -> Tensor<usize> {
) -> Tensor<i32> {
let broadcasted_shape = broadcast_shape(*y.shape, *z.shape);
let mut result: Array<usize> = array![];
let mut result: Array<i32> = array![];

let num_elements = len_from_shape(broadcasted_shape);

Expand Down

0 comments on commit e7438fa

Please sign in to comment.