Skip to content

Commit

Permalink
Merge pull request #477 from bilgin-kocak/feat-bitwise-xor
Browse files Browse the repository at this point in the history
Feat bitwise xor
  • Loading branch information
raphaelDkhn authored Nov 30, 2023
2 parents c1ca3f4 + bf30699 commit 51faa2b
Show file tree
Hide file tree
Showing 29 changed files with 314 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
* [tensor.and](framework/operators/tensor/tensor.and.md)
* [tensor.where](framework/operators/tensor/tensor.where.md)
* [tensor.bitwise_and](framework/operators/tensor/tensor.bitwise_and.md)
* [tensor.bitwise_xor](framework/operators/tensor/tensor.bitwise_xor.md)
* [tensor.bitwise_or](framework/operators/tensor/tensor.bitwise_or.md)
* [tensor.round](framework/operators/tensor/tensor.round.md)
* [tensor.scatter](framework/operators/tensor/tensor.scatter.md)
Expand Down
3 changes: 2 additions & 1 deletion docs/framework/compatibility.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ You can see below the list of current supported ONNX Operators:
| [Where](operators/tensor/tensor.where.md) | :white\_check\_mark: |
| [BitwiseAnd](operators/tensor/tensor.bitwise_and.md) | :white\_check\_mark: |
| [BitwiseOr](operators/tensor/tensor.bitwise_or.md) | :white\_check\_mark: |
| [BitwiseXor](operators/tensor/tensor.bitwise_xor.md) | :white\_check\_mark: |
| [Round](operators/tensor/tensor.round.md) | :white\_check\_mark: |
| [MaxInTensor](operators/tensor/tensor.max\_in\_tensor.md) | :white\_check\_mark: |
| [Max](operators/tensor/tensor.max.md) | :white\_check\_mark: |
Expand All @@ -99,4 +100,4 @@ You can see below the list of current supported ONNX Operators:
| [IsNaN](operators/tensor/tensor.is\_nan.md) | :white\_check\_mark: |
| [IsInf](operators/tensor/tensor.is\_inf.md) | :white\_check\_mark: |

Current Operators support: **88/156 (56%)**
Current Operators support: **83/156 (53%)**
4 changes: 1 addition & 3 deletions docs/framework/operators/tensor/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ use orion::operators::tensor::TensorTrait;
| [`tensor.identity`](tensor.identity.md) | Return a Tensor with the same shape and contents as input. |
| [`tensor.where`](tensor.where.md) | Return elements chosen from x or y depending on condition. |
| [`tensor.bitwise_and`](tensor.bitwise\_and.md) | Computes the bitwise AND of two tensors element-wise. |
| [`tensor.bitwise_xor`](tensor.bitwise\_xor.md) | Computes the bitwise XOR of two tensors element-wise. |
| [`tensor.bitwise_or`](tensor.bitwise\_or.md) | Computes the bitwise OR of two tensors element-wise. |
| [`tensor.round`](tensor.round.md) | Computes the round value of all elements in the input tensor. |
| [`tensor.reduce_l1`](tensor.reduce\_l1.md) | Computes the L1 norm of the input tensor's elements along the provided axes. |
Expand All @@ -106,15 +107,12 @@ use orion::operators::tensor::TensorTrait;
| [`tensor.reduce_min`](tensor.reduce\_min.md) | Computes the min of the input tensor's elements along the provided axes. |
| [`tensor.sequence_construct`](tensor.sequence\_construct.md) | Constructs a tensor sequence containing the input tensors. |
| [`tensor.sequence_length`](tensor.sequence\_length.md) | Returns the length of the input sequence. |
| [`tensor.shrink`](tensor.shrink.md) | Shrinks the input tensor element-wise to the output tensor with the same datatype and shape based on a defined formula. |
| [`tensor.sequence_empty`](tensor.sequence\_empty.md) | Returns an empty tensor sequence. |
| [`tensor.sequence_insert`](tensor.sequence\_insert.md) | Insert a tensor into a sequence. |
| [`tensor.sequence_at`](tensor.sequence\_at.md) | Outputs the tensor at the specified position in the input sequence. |
| [`tensor.sequence_construct`](tensor.sequence\_construct.md) | Constructs a tensor sequence containing the input tensors. |
| [`tensor.shrink`](tensor.shrink.md) | Shrinks the input tensor element-wise to the output tensor. |
| [`tensor.reduce_mean`](tensor.reduce\_mean.md) | Computes the mean of the input tensor's elements along the provided axes. |
| [`tensor.pow`](tensor.pow.md) | Pow takes input data (Tensor) and exponent Tensor, and produces one output data (Tensor) where the function f(x) = x^exponent, is applied to the data tensor elementwise. |
| [`tensor.sequence_erase`](tensor.sequence\_erase.md) | Outputs the tensor sequence with the erased tensor at the specified position. |
| [`tensor.sequence_empty`](tensor.sequence\_empty.md) | Returns an empty tensor sequence. |
| [`tensor.binarizer`](tensor.binarizer.md) | Maps the values of a tensor element-wise to 0 or 1 based on the comparison against a threshold value. |
| [`tensor.array_feature_extractor`](tensor.array\_feature\_extractor.md) | Selects elements of the input tensor based on the indices passed applied to the last tensor axis. |
Expand Down
44 changes: 44 additions & 0 deletions docs/framework/operators/tensor/tensor.bitwise_xor.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#tensor.bitwise_xor

```rust
fn bitwise_xor(self: @Tensor<T>, other: @Tensor<T>) -> Tensor<usize>;
```

Computes the bitwise XOR of two tensors element-wise.
The input tensors must have either:
* Exactly the same shape
* The same number of dimensions and the length of each dimension is either a common length or 1.

## Args

* `self`(`@Tensor<T>`) - The first tensor to be compared
* `other`(`@Tensor<T>`) - The second tensor to be compared

## Panics

* Panics if the shapes are not equal or broadcastable

## Returns

A new `Tensor<T>` with the same shape as the broadcasted inputs.

## Example

```rust
use array::{ArrayTrait, SpanTrait};

use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor};

fn xor_example() -> Tensor<usize> {
let tensor_1 = TensorTrait::<u32>::new(
shape: array![3, 3].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7, 8].span(),
);

let tensor_2 = TensorTrait::<u32>::new(
shape: array![3, 3].span(), data: array![0, 1, 2, 0, 4, 5, 0, 6, 2].span(),
);

return tensor_1.bitwise_xor(@tensor_2);
}
>>> [0,0,0,3,0,0,6,1,10]
```
49 changes: 49 additions & 0 deletions src/numbers.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ trait NumberTrait<T, MAG> {
fn is_pos_inf(self: T) -> bool;
fn is_neg_inf(self: T) -> bool;
fn bitwise_and(lhs: T, rhs: T) -> T;
fn bitwise_xor(lhs: T, rhs: T) -> T;
fn bitwise_or(lhs: T, rhs: T) -> T;
fn add(lhs: T, rhs: T) -> T;
fn sub(lhs: T, rhs: T) -> T;
Expand Down Expand Up @@ -273,6 +274,10 @@ impl FP8x23Number of NumberTrait<FP8x23, u32> {
comp_fp8x23::bitwise_and(lhs, rhs)
}

fn bitwise_xor(lhs: FP8x23, rhs: FP8x23) -> FP8x23 {
comp_fp8x23::bitwise_xor(lhs, rhs)
}

fn bitwise_or(lhs: FP8x23, rhs: FP8x23) -> FP8x23 {
comp_fp8x23::bitwise_or(lhs, rhs)
}
Expand Down Expand Up @@ -496,6 +501,10 @@ impl FP8x23WNumber of NumberTrait<FP8x23W, u64> {
comp_fp8x23wide::bitwise_and(lhs, rhs)
}

fn bitwise_xor(lhs: FP8x23W, rhs: FP8x23W) -> FP8x23W {
comp_fp8x23wide::bitwise_xor(lhs, rhs)
}

fn bitwise_or(lhs: FP8x23W, rhs: FP8x23W) -> FP8x23W {
comp_fp8x23wide::bitwise_or(lhs, rhs)
}
Expand Down Expand Up @@ -719,6 +728,10 @@ impl FP16x16Number of NumberTrait<FP16x16, u32> {
comp_fp16x16::bitwise_and(lhs, rhs)
}

fn bitwise_xor(lhs: FP16x16, rhs: FP16x16) -> FP16x16 {
comp_fp16x16::bitwise_xor(lhs, rhs)
}

fn bitwise_or(lhs: FP16x16, rhs: FP16x16) -> FP16x16 {
comp_fp16x16::bitwise_or(lhs, rhs)
}
Expand Down Expand Up @@ -942,6 +955,10 @@ impl FP16x16WNumber of NumberTrait<FP16x16W, u64> {
comp_fp16x16wide::bitwise_and(lhs, rhs)
}

fn bitwise_xor(lhs: FP16x16W, rhs: FP16x16W) -> FP16x16W {
comp_fp16x16wide::bitwise_xor(lhs, rhs)
}

fn bitwise_or(lhs: FP16x16W, rhs: FP16x16W) -> FP16x16W {
comp_fp16x16wide::bitwise_or(lhs, rhs)
}
Expand Down Expand Up @@ -1166,6 +1183,10 @@ impl FP64x64Number of NumberTrait<FP64x64, u128> {
comp_fp64x64::bitwise_and(lhs, rhs)
}

fn bitwise_xor(lhs: FP64x64, rhs: FP64x64) -> FP64x64 {
comp_fp64x64::bitwise_xor(lhs, rhs)
}

fn bitwise_or(lhs: FP64x64, rhs: FP64x64) -> FP64x64 {
comp_fp64x64::bitwise_or(lhs, rhs)
}
Expand Down Expand Up @@ -1390,6 +1411,10 @@ impl FP32x32Number of NumberTrait<FP32x32, u64> {
comp_fp32x32::bitwise_and(lhs, rhs)
}

fn bitwise_xor(lhs: FP32x32, rhs: FP32x32) -> FP32x32 {
comp_fp32x32::bitwise_xor(lhs, rhs)
}

fn bitwise_or(lhs: FP32x32, rhs: FP32x32) -> FP32x32 {
comp_fp32x32::bitwise_or(lhs, rhs)
}
Expand Down Expand Up @@ -1627,6 +1652,10 @@ impl I8Number of NumberTrait<i8, u8> {
i8_core::i8_bitwise_and(lhs, rhs)
}

fn bitwise_xor(lhs: i8, rhs: i8) -> i8 {
i8_core::i8_bitwise_xor(lhs, rhs)
}

fn bitwise_or(lhs: i8, rhs: i8) -> i8 {
i8_core::i8_bitwise_or(lhs, rhs)
}
Expand Down Expand Up @@ -1864,6 +1893,10 @@ impl i16Number of NumberTrait<i16, u16> {
i16_core::i16_bitwise_and(lhs, rhs)
}

fn bitwise_xor(lhs: i16, rhs: i16) -> i16 {
i16_core::i16_bitwise_xor(lhs, rhs)
}

fn bitwise_or(lhs: i16, rhs: i16) -> i16 {
i16_core::i16_bitwise_or(lhs, rhs)
}
Expand Down Expand Up @@ -2101,6 +2134,10 @@ impl i32Number of NumberTrait<i32, u32> {
i32_core::i32_bitwise_and(lhs, rhs)
}

fn bitwise_xor(lhs: i32, rhs: i32) -> i32 {
i32_core::i32_bitwise_xor(lhs, rhs)
}

fn bitwise_or(lhs: i32, rhs: i32) -> i32 {
i32_core::i32_bitwise_or(lhs, rhs)
}
Expand Down Expand Up @@ -2338,6 +2375,10 @@ impl i64Number of NumberTrait<i64, u64> {
i64_core::i64_bitwise_and(lhs, rhs)
}

fn bitwise_xor(lhs: i64, rhs: i64) -> i64 {
i64_core::i64_bitwise_xor(lhs, rhs)
}

fn bitwise_or(lhs: i64, rhs: i64) -> i64 {
i64_core::i64_bitwise_or(lhs, rhs)
}
Expand Down Expand Up @@ -2576,6 +2617,10 @@ impl i128Number of NumberTrait<i128, u128> {
i128_core::i128_bitwise_and(lhs, rhs)
}

fn bitwise_xor(lhs: i128, rhs: i128) -> i128 {
i128_core::i128_bitwise_xor(lhs, rhs)
}

fn bitwise_or(lhs: i128, rhs: i128) -> i128 {
i128_core::i128_bitwise_or(lhs, rhs)
}
Expand Down Expand Up @@ -2818,6 +2863,10 @@ impl u32Number of NumberTrait<u32, u32> {
lhs & rhs
}

fn bitwise_xor(lhs: u32, rhs: u32) -> u32 {
lhs ^ rhs
}

fn bitwise_or(lhs: u32, rhs: u32) -> u32 {
lhs | rhs
}
Expand Down
14 changes: 13 additions & 1 deletion src/numbers/fixed_point/implementations/fp16x16/math/comp.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ fn bitwise_and(a: FP16x16, b: FP16x16) -> FP16x16 {
return FixedTrait::new(a.mag & b.mag, a.sign & b.sign);
}

fn bitwise_xor(a: FP16x16, b: FP16x16) -> FP16x16 {
return FixedTrait::new(a.mag ^ b.mag, a.sign ^ b.sign);
}

fn bitwise_or(a: FP16x16, b: FP16x16) -> FP16x16 {
return FixedTrait::new(a.mag | b.mag, a.sign | b.sign);
}
Expand All @@ -64,7 +68,7 @@ fn bitwise_or(a: FP16x16, b: FP16x16) -> FP16x16 {

#[cfg(test)]
mod tests {
use super::{FixedTrait, max, min, bitwise_and, bitwise_or};
use super::{FixedTrait, max, min, bitwise_and, bitwise_xor, bitwise_or};


#[test]
Expand Down Expand Up @@ -115,6 +119,14 @@ mod tests {
}

#[test]
fn test_bitwise_xor() {
let a = FixedTrait::new(225280, false); // 3.4375
let b = FixedTrait::new(4160843776, true); // -2046.5625
let c = FixedTrait::new(4160880640, true);

assert(bitwise_xor(a, b) == c, 'bitwise_xor(a,b)')
}

fn test_bitwise_or() {
let a = FixedTrait::new(225280, false); // 3.4375
let b = FixedTrait::new(4160843776, true); // -2046.5625
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ fn bitwise_and(a: FP16x16W, b: FP16x16W) -> FP16x16W {
return FixedTrait::new(a.mag & b.mag, a.sign & b.sign);
}

fn bitwise_xor(a: FP16x16W, b: FP16x16W) -> FP16x16W {
return FixedTrait::new(a.mag ^ b.mag, a.sign ^ b.sign);
}

fn bitwise_or(a: FP16x16W, b: FP16x16W) -> FP16x16W {
return FixedTrait::new(a.mag | b.mag, a.sign | b.sign);
}
Expand All @@ -64,7 +68,7 @@ fn bitwise_or(a: FP16x16W, b: FP16x16W) -> FP16x16W {

#[cfg(test)]
mod tests {
use super::{FixedTrait, max, min, bitwise_and, bitwise_or};
use super::{FixedTrait, max, min, bitwise_and, bitwise_xor, bitwise_or};


#[test]
Expand Down Expand Up @@ -115,6 +119,14 @@ mod tests {
}

#[test]
fn test_bitwise_xor() {
let a = FixedTrait::new(225280, false); // 3.4375
let b = FixedTrait::new(4160843776, true); // -2046.5625
let c = FixedTrait::new(4160880640, true);

assert(bitwise_xor(a, b) == c, 'bitwise_xor(a,b)')
}

fn test_bitwise_or() {
let a = FixedTrait::new(225280, false); // 3.4375
let b = FixedTrait::new(4160843776, true); // -2046.5625
Expand Down
6 changes: 5 additions & 1 deletion src/numbers/fixed_point/implementations/fp32x32/comp.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ fn bitwise_and(a: FP32x32, b: FP32x32) -> FP32x32 {
return FixedTrait::new(a.mag & b.mag, a.sign & b.sign);
}

fn bitwise_xor(a: FP32x32, b: FP32x32) -> FP32x32 {
return FixedTrait::new(a.mag ^ b.mag, a.sign ^ b.sign);
}

fn bitwise_or(a: FP32x32, b: FP32x32) -> FP32x32 {
return FixedTrait::new(a.mag | b.mag, a.sign | b.sign);
}
}
4 changes: 4 additions & 0 deletions src/numbers/fixed_point/implementations/fp64x64/comp.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ fn bitwise_and(a: FP64x64, b: FP64x64) -> FP64x64 {
return FixedTrait::new(a.mag & b.mag, a.sign & b.sign);
}

fn bitwise_xor(a: FP64x64, b: FP64x64) -> FP64x64 {
return FixedTrait::new(a.mag ^ b.mag, a.sign ^ b.sign);
}

fn bitwise_or(a: FP64x64, b: FP64x64) -> FP64x64 {
return FixedTrait::new(a.mag | b.mag, a.sign | b.sign);
}
14 changes: 13 additions & 1 deletion src/numbers/fixed_point/implementations/fp8x23/math/comp.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ fn bitwise_and(a: FP8x23, b: FP8x23) -> FP8x23 {
return FixedTrait::new(a.mag & b.mag, a.sign & b.sign);
}

fn bitwise_xor(a: FP8x23, b: FP8x23) -> FP8x23 {
return FixedTrait::new(a.mag ^ b.mag, a.sign ^ b.sign);
}

fn bitwise_or(a: FP8x23, b: FP8x23) -> FP8x23 {
return FixedTrait::new(a.mag | b.mag, a.sign | b.sign);
}
Expand All @@ -64,7 +68,7 @@ fn bitwise_or(a: FP8x23, b: FP8x23) -> FP8x23 {

#[cfg(test)]
mod tests {
use super::{FixedTrait, max, min, bitwise_and, bitwise_or};
use super::{FixedTrait, max, min, bitwise_and, bitwise_xor, bitwise_or};

#[test]
fn test_max() {
Expand Down Expand Up @@ -112,6 +116,14 @@ mod tests {
}

#[test]
fn test_bitwise_xor() {
let a = FixedTrait::new(28835840, false); // 3.4375
let b = FixedTrait::new(1639448576, true); // -60.5625
let c = FixedTrait::new(1610612736, true);

assert(bitwise_xor(a, b) == c, 'bitwise_xor(a,b)')
}

fn test_bitwise_or() {
let a = FixedTrait::new(28835840, false); // 3.4375
let b = FixedTrait::new(1639448576, true); // -60.5625
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ fn bitwise_and(a: FP8x23W, b: FP8x23W) -> FP8x23W {
return FixedTrait::new(a.mag & b.mag, a.sign & b.sign);
}

fn bitwise_xor(a: FP8x23W, b: FP8x23W) -> FP8x23W {
return FixedTrait::new(a.mag ^ b.mag, a.sign ^ b.sign);
}

fn bitwise_or(a: FP8x23W, b: FP8x23W) -> FP8x23W {
return FixedTrait::new(a.mag | b.mag, a.sign | b.sign);
}
Expand All @@ -64,7 +68,7 @@ fn bitwise_or(a: FP8x23W, b: FP8x23W) -> FP8x23W {

#[cfg(test)]
mod tests {
use super::{FixedTrait, max, min, bitwise_and, bitwise_or};
use super::{FixedTrait, max, min, bitwise_and, bitwise_xor, bitwise_or};


#[test]
Expand Down Expand Up @@ -114,6 +118,14 @@ mod tests {
}

#[test]
fn test_bitwise_xor() {
let a = FixedTrait::new(28835840, false); // 3.4375
let b = FixedTrait::new(1639448576, true); // -60.5625
let c = FixedTrait::new(1610612736, true);

assert(bitwise_xor(a, b) == c, 'bitwise_xor(a,b)')
}

fn test_bitwise_or() {
let a = FixedTrait::new(28835840, false); // 3.4375
let b = FixedTrait::new(1639448576, true); // -60.5625
Expand Down
Loading

0 comments on commit 51faa2b

Please sign in to comment.