Skip to content

Commit

Permalink
test and refactor equal
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Mar 21, 2024
1 parent 2dcb46c commit 4ae5a38
Show file tree
Hide file tree
Showing 97 changed files with 1,093 additions and 1,061 deletions.
8 changes: 4 additions & 4 deletions docs/framework/operators/tensor/tensor.equal.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#tensor.equal

```rust
fn equal(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<usize>;
fn equal(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<bool>;
```

Check if two tensors are equal element-wise.
Expand All @@ -20,7 +20,7 @@ The input tensors must have either:

## Returns

A new `Tensor<usize>` of booleans (1 if equal, 0 otherwise) with the same shape as the broadcasted inputs.
A new `Tensor<bool>` of booleans (1 if equal, 0 otherwise) with the same shape as the broadcasted inputs.

## Examples

Expand All @@ -43,7 +43,7 @@ fn eq_example() -> Tensor<usize> {
// We can call `equal` function as follows.
return tensor_1.equal(@tensor_2);
}
>>> [1,1,1,1,1,0,0,0]
>>> [true,true,true,true,true,false,false,false]
```

Case 2: Compare tensors with different shapes
Expand All @@ -63,5 +63,5 @@ fn eq_example() -> Tensor<usize> {
// We can call `equal` function as follows.
return tensor_1.equal(@tensor_2);
}
>>> [1,1,1,0,0,0,0,0,0]
>>> [true,true,true,false,false,false,false,false,false]
```
4 changes: 2 additions & 2 deletions docs/framework/operators/tensor/tensor.less.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use core::array::{ArrayTrait, SpanTrait};

use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor};

fn less_example() -> Tensor<usize> {
fn less_example() -> Tensor<bool> {
let tensor_1 = TensorTrait::<u32>::new(
shape: array![3, 3, 3].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7, 8].span(),
);
Expand All @@ -63,5 +63,5 @@ fn less_example() -> Tensor<usize> {
// We can call `less` function as follows.
return tensor_1.less(@tensor_2);
}
>>> [0,0,0,0,0,0,0,1,1]
>>> [false,false,false,false,false,false,false,true,true]
```
20 changes: 10 additions & 10 deletions nodegen/node/equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def default():

x = Tensor(Dtype.U32, x.shape, x.flatten())
y = Tensor(Dtype.U32, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "equal_u32"
make_test([x, y], z, "input_0.equal(@input_1)", name)
Expand All @@ -25,7 +25,7 @@ def broadcast():

x = Tensor(Dtype.U32, x.shape, x.flatten())
y = Tensor(Dtype.U32, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "equal_u32_broadcast"
make_test([x, y], z, "input_0.equal(@input_1)", name)
Expand All @@ -42,7 +42,7 @@ def default():

x = Tensor(Dtype.I32, x.shape, x.flatten())
y = Tensor(Dtype.I32, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "equal_i32"
make_test([x, y], z, "input_0.equal(@input_1)", name)
Expand All @@ -54,7 +54,7 @@ def broadcast():

x = Tensor(Dtype.I32, x.shape, x.flatten())
y = Tensor(Dtype.I32, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "equal_i32_broadcast"
make_test([x, y], z, "input_0.equal(@input_1)", name)
Expand All @@ -71,7 +71,7 @@ def default():

x = Tensor(Dtype.I8, x.shape, x.flatten())
y = Tensor(Dtype.I8, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "equal_i8"
make_test([x, y], z, "input_0.equal(@input_1)", name)
Expand All @@ -83,7 +83,7 @@ def broadcast():

x = Tensor(Dtype.I8, x.shape, x.flatten())
y = Tensor(Dtype.I8, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "equal_i8_broadcast"
make_test([x, y], z, "input_0.equal(@input_1)", name)
Expand All @@ -102,7 +102,7 @@ def default():
x.flatten(), FixedImpl.FP8x23))
y = Tensor(Dtype.FP8x23, y.shape, to_fp(
y.flatten(), FixedImpl.FP8x23))
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "equal_fp8x23"
make_test([x, y], z, "input_0.equal(@input_1)", name)
Expand All @@ -116,7 +116,7 @@ def broadcast():
x.flatten(), FixedImpl.FP8x23))
y = Tensor(Dtype.FP8x23, y.shape, to_fp(
y.flatten(), FixedImpl.FP8x23))
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "equal_fp8x23_broadcast"
make_test([x, y], z, "input_0.equal(@input_1)", name)
Expand All @@ -135,7 +135,7 @@ def default():
x.flatten(), FixedImpl.FP16x16))
y = Tensor(Dtype.FP16x16, y.shape, to_fp(
y.flatten(), FixedImpl.FP16x16))
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "equal_fp16x16"
make_test([x, y], z, "input_0.equal(@input_1)", name)
Expand All @@ -149,7 +149,7 @@ def broadcast():
x.flatten(), FixedImpl.FP16x16))
y = Tensor(Dtype.FP16x16, y.shape, to_fp(
y.flatten(), FixedImpl.FP16x16))
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "equal_fp16x16_broadcast"
make_test([x, y], z, "input_0.equal(@input_1)", name)
Expand Down
20 changes: 10 additions & 10 deletions nodegen/node/less.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def default():

x = Tensor(Dtype.U32, x.shape, x.flatten())
y = Tensor(Dtype.U32, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "less_u32"
make_test([x, y], z, "input_0.less(@input_1)", name)
Expand All @@ -25,7 +25,7 @@ def broadcast():

x = Tensor(Dtype.U32, x.shape, x.flatten())
y = Tensor(Dtype.U32, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "less_u32_broadcast"
make_test([x, y], z, "input_0.less(@input_1)", name)
Expand All @@ -42,7 +42,7 @@ def default():

x = Tensor(Dtype.I32, x.shape, x.flatten())
y = Tensor(Dtype.I32, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "less_i32"
make_test([x, y], z, "input_0.less(@input_1)", name)
Expand All @@ -54,7 +54,7 @@ def broadcast():

x = Tensor(Dtype.I32, x.shape, x.flatten())
y = Tensor(Dtype.I32, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "less_i32_broadcast"
make_test([x, y], z, "input_0.less(@input_1)", name)
Expand All @@ -71,7 +71,7 @@ def default():

x = Tensor(Dtype.I8, x.shape, x.flatten())
y = Tensor(Dtype.I8, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "less_i8"
make_test([x, y], z, "input_0.less(@input_1)", name)
Expand All @@ -83,7 +83,7 @@ def broadcast():

x = Tensor(Dtype.I8, x.shape, x.flatten())
y = Tensor(Dtype.I8, y.shape, y.flatten())
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "less_i8_broadcast"
make_test([x, y], z, "input_0.less(@input_1)", name)
Expand All @@ -102,7 +102,7 @@ def default():
x.flatten(), FixedImpl.FP8x23))
y = Tensor(Dtype.FP8x23, y.shape, to_fp(
y.flatten(), FixedImpl.FP8x23))
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "less_fp8x23"
make_test([x, y], z, "input_0.less(@input_1)", name)
Expand All @@ -116,7 +116,7 @@ def broadcast():
x.flatten(), FixedImpl.FP8x23))
y = Tensor(Dtype.FP8x23, y.shape, to_fp(
y.flatten(), FixedImpl.FP8x23))
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "less_fp8x23_broadcast"
make_test([x, y], z, "input_0.less(@input_1)", name)
Expand All @@ -135,7 +135,7 @@ def default():
x.flatten(), FixedImpl.FP16x16))
y = Tensor(Dtype.FP16x16, y.shape, to_fp(
y.flatten(), FixedImpl.FP16x16))
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "less_fp16x16"
make_test([x, y], z, "input_0.less(@input_1)", name)
Expand All @@ -149,7 +149,7 @@ def broadcast():
x.flatten(), FixedImpl.FP16x16))
y = Tensor(Dtype.FP16x16, y.shape, to_fp(
y.flatten(), FixedImpl.FP16x16))
z = Tensor(Dtype.U32, z.shape, z.flatten())
z = Tensor(Dtype.BOOL, z.shape, z.flatten())

name = "less_fp16x16_broadcast"
make_test([x, y], z, "input_0.less(@input_1)", name)
Expand Down
10 changes: 5 additions & 5 deletions src/operators/tensor/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,7 @@ trait TensorTrait<T> {
/// #tensor.equal
///
/// ```rust
/// fn equal(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<usize>;
/// fn equal(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<bool>;
/// ```
///
/// Check if two tensors are equal element-wise.
Expand All @@ -1104,7 +1104,7 @@ trait TensorTrait<T> {
///
/// ## Returns
///
/// A new `Tensor<usize>` of booleans (1 if equal, 0 otherwise) with the same shape as the broadcasted inputs.
/// A new `Tensor<bool>` of booleans (1 if equal, 0 otherwise) with the same shape as the broadcasted inputs.
///
/// ## Examples
///
Expand All @@ -1127,7 +1127,7 @@ trait TensorTrait<T> {
/// // We can call `equal` function as follows.
/// return tensor_1.equal(@tensor_2);
/// }
/// >>> [1,1,1,1,1,0,0,0]
/// >>> [true,true,true,true,true,false,false,false]
/// ```
///
/// Case 2: Compare tensors with different shapes
Expand All @@ -1147,10 +1147,10 @@ trait TensorTrait<T> {
/// // We can call `equal` function as follows.
/// return tensor_1.equal(@tensor_2);
/// }
/// >>> [1,1,1,0,0,0,0,0,0]
/// >>> [true,true,true,false,false,false,false,false,false]
/// ```
///
fn equal(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<usize>;
fn equal(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<bool>;
/// #tensor.greater
///
/// ```rust
Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_bool.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl BoolTensor of TensorTrait<bool> {
panic(array!['not supported!'])
}

fn equal(self: @Tensor<bool>, other: @Tensor<bool>) -> Tensor<usize> {
fn equal(self: @Tensor<bool>, other: @Tensor<bool>) -> Tensor<bool> {
math::equal::equal(self, other)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl Complex64Tensor of TensorTrait<complex64> {
math::log::log(*self)
}

fn equal(self: @Tensor<complex64>, other: @Tensor<complex64>) -> Tensor<usize> {
fn equal(self: @Tensor<complex64>, other: @Tensor<complex64>) -> Tensor<bool> {
math::equal::equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp16x16.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl FP16x16Tensor of TensorTrait<FP16x16> {
math::log::log(*self)
}

fn equal(self: @Tensor<FP16x16>, other: @Tensor<FP16x16>) -> Tensor<usize> {
fn equal(self: @Tensor<FP16x16>, other: @Tensor<FP16x16>) -> Tensor<bool> {
math::equal::equal(self, other)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl FP16x16WTensor of TensorTrait<FP16x16W> {
math::log::log(*self)
}

fn equal(self: @Tensor<FP16x16W>, other: @Tensor<FP16x16W>) -> Tensor<usize> {
fn equal(self: @Tensor<FP16x16W>, other: @Tensor<FP16x16W>) -> Tensor<bool> {
math::equal::equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp32x32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl FP32x32Tensor of TensorTrait<FP32x32> {
math::log::log(*self)
}

fn equal(self: @Tensor<FP32x32>, other: @Tensor<FP32x32>) -> Tensor<usize> {
fn equal(self: @Tensor<FP32x32>, other: @Tensor<FP32x32>) -> Tensor<bool> {
math::equal::equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp64x64.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ impl FP64x64Tensor of TensorTrait<FP64x64> {
math::log::log(*self)
}

fn equal(self: @Tensor<FP64x64>, other: @Tensor<FP64x64>) -> Tensor<usize> {
fn equal(self: @Tensor<FP64x64>, other: @Tensor<FP64x64>) -> Tensor<bool> {
math::equal::equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_fp8x23.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl FP8x23Tensor of TensorTrait<FP8x23> {
math::log::log(*self)
}

fn equal(self: @Tensor<FP8x23>, other: @Tensor<FP8x23>) -> Tensor<usize> {
fn equal(self: @Tensor<FP8x23>, other: @Tensor<FP8x23>) -> Tensor<bool> {
math::equal::equal(self, other)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl FP8x23WTensor of TensorTrait<FP8x23W> {
math::log::log(*self)
}

fn equal(self: @Tensor<FP8x23W>, other: @Tensor<FP8x23W>) -> Tensor<usize> {
fn equal(self: @Tensor<FP8x23W>, other: @Tensor<FP8x23W>) -> Tensor<bool> {
math::equal::equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_i32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ impl I32Tensor of TensorTrait<i32> {
panic(array!['not supported!'])
}

fn equal(self: @Tensor<i32>, other: @Tensor<i32>) -> Tensor<usize> {
fn equal(self: @Tensor<i32>, other: @Tensor<i32>) -> Tensor<bool> {
math::equal::equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_i8.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl I8Tensor of TensorTrait<i8> {
panic(array!['not supported!'])
}

fn equal(self: @Tensor<i8>, other: @Tensor<i8>) -> Tensor<usize> {
fn equal(self: @Tensor<i8>, other: @Tensor<i8>) -> Tensor<bool> {
math::equal::equal(self, other)
}

Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/implementations/tensor_u32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl U32Tensor of TensorTrait<u32> {
panic(array!['not supported!'])
}

fn equal(self: @Tensor<u32>, other: @Tensor<u32>) -> Tensor<usize> {
fn equal(self: @Tensor<u32>, other: @Tensor<u32>) -> Tensor<bool> {
math::equal::equal(self, other)
}

Expand Down
10 changes: 5 additions & 5 deletions src/operators/tensor/math/equal.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ use orion::operators::tensor::helpers::{
/// Cf: TensorTrait::equal docstring
fn equal<
T,
impl UsizeFTensor: TensorTrait<usize>,
impl BoolTensor: TensorTrait<bool>,
impl TPartialEq: PartialEq<T>,
impl TCopy: Copy<T>,
impl TDrop: Drop<T>
>(
y: @Tensor<T>, z: @Tensor<T>
) -> Tensor<usize> {
) -> Tensor<bool> {
let broadcasted_shape = broadcast_shape(*y.shape, *z.shape);
let mut result: Array<usize> = array![];
let mut result: Array<bool> = array![];

let num_elements = len_from_shape(broadcasted_shape);

Expand All @@ -26,9 +26,9 @@ fn equal<
let indices_other = broadcast_index_mapping(*z.shape, indices_broadcasted);

if *(*y.data)[indices_self] == *(*z.data)[indices_other] {
result.append(1);
result.append(true);
} else {
result.append(0);
result.append(false);
}

n += 1;
Expand Down
Loading

0 comments on commit 4ae5a38

Please sign in to comment.