Skip to content

Commit

Permalink
move tensor arithmetic to trait
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Nov 30, 2023
1 parent 27229b2 commit 75cf8e9
Show file tree
Hide file tree
Showing 11 changed files with 165 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/operators/tensor/core.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,10 @@ trait TensorTrait<T> {
/// ```
///
fn min_in_tensor(self: @Tensor<T>) -> T;
fn add(lhs: Tensor<T>, rhs: Tensor<T>) -> Tensor<T>;
fn sub(lhs: Tensor<T>, rhs: Tensor<T>) -> Tensor<T>;
fn mul(lhs: Tensor<T>, rhs: Tensor<T>) -> Tensor<T>;
fn div(lhs: Tensor<T>, rhs: Tensor<T>) -> Tensor<T>;
/// # tensor.min
///
/// ```rust
Expand Down
17 changes: 17 additions & 0 deletions src/operators/tensor/implementations/tensor_bool.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@ impl BoolTensor of TensorTrait<bool> {
*at_tensor(self, indices)
}

fn add(lhs: Tensor<bool>, rhs: Tensor<bool>) -> Tensor<bool> {
panic(array!['not supported!'])
}

fn sub(lhs: Tensor<bool>, rhs: Tensor<bool>) -> Tensor<bool> {
panic(array!['not supported!'])
}

fn mul(lhs: Tensor<bool>, rhs: Tensor<bool>) -> Tensor<bool> {
panic(array!['not supported!'])
}

fn div(lhs: Tensor<bool>, rhs: Tensor<bool>) -> Tensor<bool> {
panic(array!['not supported!'])
}


fn min_in_tensor(self: @Tensor<bool>) -> bool {
panic(array!['not supported!'])
}
Expand Down
16 changes: 16 additions & 0 deletions src/operators/tensor/implementations/tensor_fp16x16.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,22 @@ impl FP16x16Tensor of TensorTrait<FP16x16> {
constant_of_shape(shape, value)
}

fn add(lhs: Tensor<FP16x16>, rhs: Tensor<FP16x16>) -> Tensor<FP16x16> {
math::arithmetic::add(@lhs, @rhs)
}

fn sub(lhs: Tensor<FP16x16>, rhs: Tensor<FP16x16>) -> Tensor<FP16x16> {
math::arithmetic::sub(@lhs, @rhs)
}

fn mul(lhs: Tensor<FP16x16>, rhs: Tensor<FP16x16>) -> Tensor<FP16x16> {
math::arithmetic::mul(@lhs, @rhs)
}

fn div(lhs: Tensor<FP16x16>, rhs: Tensor<FP16x16>) -> Tensor<FP16x16> {
math::arithmetic::div(@lhs, @rhs)
}

fn at(self: @Tensor<FP16x16>, indices: Span<usize>) -> FP16x16 {
*at_tensor(self, indices)
}
Expand Down
16 changes: 16 additions & 0 deletions src/operators/tensor/implementations/tensor_fp16x16wide.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@ impl FP16x16WTensor of TensorTrait<FP16x16W> {
*at_tensor(self, indices)
}

fn add(lhs: Tensor<FP16x16W>, rhs: Tensor<FP16x16W>) -> Tensor<FP16x16W> {
math::arithmetic::add(@lhs, @rhs)
}

fn sub(lhs: Tensor<FP16x16W>, rhs: Tensor<FP16x16W>) -> Tensor<FP16x16W> {
math::arithmetic::sub(@lhs, @rhs)
}

fn mul(lhs: Tensor<FP16x16W>, rhs: Tensor<FP16x16W>) -> Tensor<FP16x16W> {
math::arithmetic::mul(@lhs, @rhs)
}

fn div(lhs: Tensor<FP16x16W>, rhs: Tensor<FP16x16W>) -> Tensor<FP16x16W> {
math::arithmetic::div(@lhs, @rhs)
}

fn min_in_tensor(self: @Tensor<FP16x16W>) -> FP16x16W {
math::min_in_tensor::min_in_tensor::<FP16x16W, u64>(*self.data)
}
Expand Down
16 changes: 16 additions & 0 deletions src/operators/tensor/implementations/tensor_fp32x32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@ impl FP32x32Tensor of TensorTrait<FP32x32> {
constant_of_shape(shape, value)
}

fn add(lhs: Tensor<FP32x32>, rhs: Tensor<FP32x32>) -> Tensor<FP32x32> {
math::arithmetic::add(@lhs, @rhs)
}

fn sub(lhs: Tensor<FP32x32>, rhs: Tensor<FP32x32>) -> Tensor<FP32x32> {
math::arithmetic::sub(@lhs, @rhs)
}

fn mul(lhs: Tensor<FP32x32>, rhs: Tensor<FP32x32>) -> Tensor<FP32x32> {
math::arithmetic::mul(@lhs, @rhs)
}

fn div(lhs: Tensor<FP32x32>, rhs: Tensor<FP32x32>) -> Tensor<FP32x32> {
math::arithmetic::div(@lhs, @rhs)
}

fn at(self: @Tensor<FP32x32>, indices: Span<usize>) -> FP32x32 {
*at_tensor(self, indices)
}
Expand Down
16 changes: 16 additions & 0 deletions src/operators/tensor/implementations/tensor_fp64x64.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@ impl FP64x64Tensor of TensorTrait<FP64x64> {
*at_tensor(self, indices)
}

fn add(lhs: Tensor<FP64x64>, rhs: Tensor<FP64x64>) -> Tensor<FP64x64> {
math::arithmetic::add(@lhs, @rhs)
}

fn sub(lhs: Tensor<FP64x64>, rhs: Tensor<FP64x64>) -> Tensor<FP64x64> {
math::arithmetic::sub(@lhs, @rhs)
}

fn mul(lhs: Tensor<FP64x64>, rhs: Tensor<FP64x64>) -> Tensor<FP64x64> {
math::arithmetic::mul(@lhs, @rhs)
}

fn div(lhs: Tensor<FP64x64>, rhs: Tensor<FP64x64>) -> Tensor<FP64x64> {
math::arithmetic::div(@lhs, @rhs)
}

fn min_in_tensor(self: @Tensor<FP64x64>) -> FP64x64 {
math::min_in_tensor::min_in_tensor::<FP64x64, u128>(*self.data)
}
Expand Down
16 changes: 16 additions & 0 deletions src/operators/tensor/implementations/tensor_fp8x23.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@ impl FP8x23Tensor of TensorTrait<FP8x23> {
*at_tensor(self, indices)
}

fn add(lhs: Tensor<FP8x23>, rhs: Tensor<FP8x23>) -> Tensor<FP8x23> {
math::arithmetic::add(@lhs, @rhs)
}

fn sub(lhs: Tensor<FP8x23>, rhs: Tensor<FP8x23>) -> Tensor<FP8x23> {
math::arithmetic::sub(@lhs, @rhs)
}

fn mul(lhs: Tensor<FP8x23>, rhs: Tensor<FP8x23>) -> Tensor<FP8x23> {
math::arithmetic::mul(@lhs, @rhs)
}

fn div(lhs: Tensor<FP8x23>, rhs: Tensor<FP8x23>) -> Tensor<FP8x23> {
math::arithmetic::div(@lhs, @rhs)
}

fn min_in_tensor(self: @Tensor<FP8x23>) -> FP8x23 {
math::min_in_tensor::min_in_tensor::<FP8x23, u32>(*self.data)
}
Expand Down
16 changes: 16 additions & 0 deletions src/operators/tensor/implementations/tensor_fp8x23wide.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@ impl FP8x23WTensor of TensorTrait<FP8x23W> {
*at_tensor(self, indices)
}

fn add(lhs: Tensor<FP8x23W>, rhs: Tensor<FP8x23W>) -> Tensor<FP8x23W> {
math::arithmetic::add(@lhs, @rhs)
}

fn sub(lhs: Tensor<FP8x23W>, rhs: Tensor<FP8x23W>) -> Tensor<FP8x23W> {
math::arithmetic::sub(@lhs, @rhs)
}

fn mul(lhs: Tensor<FP8x23W>, rhs: Tensor<FP8x23W>) -> Tensor<FP8x23W> {
math::arithmetic::mul(@lhs, @rhs)
}

fn div(lhs: Tensor<FP8x23W>, rhs: Tensor<FP8x23W>) -> Tensor<FP8x23W> {
math::arithmetic::div(@lhs, @rhs)
}

fn min_in_tensor(self: @Tensor<FP8x23W>) -> FP8x23W {
math::min_in_tensor::min_in_tensor::<FP8x23W, u64>(*self.data)
}
Expand Down
16 changes: 16 additions & 0 deletions src/operators/tensor/implementations/tensor_i32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@ impl I32Tensor of TensorTrait<i32> {
*at_tensor(self, indices)
}

fn add(lhs: Tensor<i32>, rhs: Tensor<i32>) -> Tensor<i32> {
math::arithmetic::add(@lhs, @rhs)
}

fn sub(lhs: Tensor<i32>, rhs: Tensor<i32>) -> Tensor<i32> {
math::arithmetic::sub(@lhs, @rhs)
}

fn mul(lhs: Tensor<i32>, rhs: Tensor<i32>) -> Tensor<i32> {
math::arithmetic::mul(@lhs, @rhs)
}

fn div(lhs: Tensor<i32>, rhs: Tensor<i32>) -> Tensor<i32> {
math::arithmetic::div(@lhs, @rhs)
}

fn min_in_tensor(self: @Tensor<i32>) -> i32 {
math::min_in_tensor::min_in_tensor::<i32, u32>(*self.data)
}
Expand Down
16 changes: 16 additions & 0 deletions src/operators/tensor/implementations/tensor_i8.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@ impl I8Tensor of TensorTrait<i8> {
*at_tensor(self, indices)
}

fn add(lhs: Tensor<i8>, rhs: Tensor<i8>) -> Tensor<i8> {
math::arithmetic::add(@lhs, @rhs)
}

fn sub(lhs: Tensor<i8>, rhs: Tensor<i8>) -> Tensor<i8> {
math::arithmetic::sub(@lhs, @rhs)
}

fn mul(lhs: Tensor<i8>, rhs: Tensor<i8>) -> Tensor<i8> {
math::arithmetic::mul(@lhs, @rhs)
}

fn div(lhs: Tensor<i8>, rhs: Tensor<i8>) -> Tensor<i8> {
math::arithmetic::div(@lhs, @rhs)
}

fn min_in_tensor(self: @Tensor<i8>) -> i8 {
math::min_in_tensor::min_in_tensor::<i8, u8>(*self.data)
}
Expand Down
16 changes: 16 additions & 0 deletions src/operators/tensor/implementations/tensor_u32.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@ impl U32Tensor of TensorTrait<u32> {
*at_tensor(self, indices)
}

fn add(lhs: Tensor<u32>, rhs: Tensor<u32>) -> Tensor<u32> {
math::arithmetic::add(@lhs, @rhs)
}

fn sub(lhs: Tensor<u32>, rhs: Tensor<u32>) -> Tensor<u32> {
math::arithmetic::sub(@lhs, @rhs)
}

fn mul(lhs: Tensor<u32>, rhs: Tensor<u32>) -> Tensor<u32> {
math::arithmetic::mul(@lhs, @rhs)
}

fn div(lhs: Tensor<u32>, rhs: Tensor<u32>) -> Tensor<u32> {
math::arithmetic::div(@lhs, @rhs)
}

fn min_in_tensor(self: @Tensor<u32>) -> u32 {
math::min_in_tensor::min_in_tensor::<u32, u32>(*self.data)
}
Expand Down

0 comments on commit 75cf8e9

Please sign in to comment.