Commit

Feat: det

canacechan committed Mar 15, 2024
1 parent fb6f4a0 commit 070c6ee
Showing 30 changed files with 581 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/framework/operators/tensor/README.md
@@ -132,6 +132,7 @@ use orion::operators::tensor::TensorTrait;
| [`tensor.dynamic_quantize_linear`](tensor.dynamic\_quantize\_linear.md) | Computes the Scale, Zero Point and FP32->8Bit conversion of FP32 Input data. |
| [`tensor.scatter_nd`](tensor.scatter\_nd.md) | The output of the operation is produced by creating a copy of the input data, and then updating its value to values specified by updates at specific index positions specified by indices. Its output shape is the same as the shape of data |
| [`tensor.label_encoder`](tensor.label\_encoder.md) | Maps each element in the input tensor to another value. |
| [`tensor.det`](tensor.det.md) | Calculates the determinant of a square matrix or batches of square matrices. |

## Arithmetic Operations

56 changes: 56 additions & 0 deletions docs/framework/operators/tensor/tensor.det.md
@@ -0,0 +1,56 @@
# TensorTrait::det

```rust
fn det(self: @Tensor<T>) -> Tensor<T>;
```

Det calculates the determinant of a square matrix or batches of square matrices. Det takes one input tensor of shape [*, M, M], where * is zero or more batch dimensions and the innermost two dimensions form square matrices. The output is a tensor of shape [*], containing the determinants of all input submatrices.

## Args

* `self`(`@Tensor<T>`) - The input tensor of shape [*, M, M].

## Returns

* A `Tensor<T>` of shape [*], containing the determinants of all input submatrices.

## Examples

```rust
use orion::operators::tensor::{I32Tensor, I32TensorAdd};
use core::array::{ArrayTrait, SpanTrait};
use orion::operators::tensor::{TensorTrait, Tensor};
use orion::utils::{assert_eq, assert_seq_eq};
use orion::operators::tensor::I32TensorPartialEq;

fn example() -> Tensor<i32> {
let mut shape = ArrayTrait::<usize>::new();
shape.append(2);
shape.append(3);
shape.append(3);

let mut data = ArrayTrait::new();
data.append(1);
data.append(2);
data.append(3);
data.append(4);
data.append(5);
data.append(6);
data.append(7);
data.append(8);
data.append(9);
data.append(2);
data.append(2);
data.append(3);
data.append(4);
data.append(5);
data.append(6);
data.append(7);
data.append(8);
data.append(9);
let input_0 = TensorTrait::new(shape.span(), data.span());

return input_0.det();
}
>>> [0, -3]
```
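
The documented output can be cross-checked outside Cairo. The snippet below is an illustrative NumPy verification (not part of the library) of the same batch of two 3x3 matrices:

```python
import numpy as np

# Same batch of two 3x3 matrices as the Cairo example, shape [2, 3, 3].
x = np.array(
    [
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        [[2, 2, 3], [4, 5, 6], [7, 8, 9]],
    ],
    dtype=np.float64,
)

# np.linalg.det maps shape [*, M, M] to shape [*]: one determinant per submatrix.
dets = np.linalg.det(x)

print(np.rint(dets).astype(np.int64))  # [ 0 -3]
```

The first matrix is singular (its rows are in arithmetic progression), so its determinant is 0; the second differs only in the top-left entry and has determinant -3.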
65 changes: 65 additions & 0 deletions nodegen/node/det.py
@@ -0,0 +1,65 @@
import numpy as np
from nodegen.node import RunAll
from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait, get_data_statement


class Det(RunAll):

@staticmethod
# We test here with fp8x23 implementation.
def fp8x23():
x = np.random.randint(-2, 2, (2, 4, 4)).astype(np.float64)
y = np.linalg.det(x)

x = Tensor(Dtype.FP8x23, x.shape, to_fp(
x.flatten(), FixedImpl.FP8x23))
y = Tensor(Dtype.FP8x23, y.shape, to_fp(
y.flatten(), FixedImpl.FP8x23))

name = "det_fp8x23"
make_test([x], y, f"input_0.det()", name)

@staticmethod
# We test here with fp16x16 implementation.
def fp16x16():
x = np.random.randint(-3, 3, (1, 2, 2)).astype(np.float64)
y = np.linalg.det(x)

x = Tensor(Dtype.FP16x16, x.shape, to_fp(
x.flatten(), FixedImpl.FP16x16))
y = Tensor(Dtype.FP16x16, y.shape, to_fp(
y.flatten(), FixedImpl.FP16x16))

name = "det_fp16x16"
make_test([x], y, f"input_0.det()", name)

@staticmethod
# We test here with i8 implementation.
def i8():
x = np.random.randint(0, 6, (3, 1, 1)).astype(np.int8)
y = np.linalg.det(x)

x = Tensor(Dtype.I8, x.shape, x.flatten())
y = Tensor(Dtype.I8, y.shape, y.flatten())

name = "det_i8"
make_test([x], y, f"input_0.det()", name)

@staticmethod
# We test here with i32 implementation.
def i32():
x = np.array([[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]],
[[2, 2, 3],
[4, 5, 6],
[7, 8, 9]]
])
y = np.linalg.det(x)

x = Tensor(Dtype.I32, x.shape, x.flatten())
y = Tensor(Dtype.I32, y.shape, y.flatten())

name = "det_i32"
make_test([x], y, f"input_0.det()", name)
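
For the fixed-point cases, both the random inputs and the NumPy determinants are run through `to_fp` before being written into the generated fixtures. The helper below is only a sketch of that conversion, under the assumption that FP16x16 and FP8x23 are fixed-point formats with 16 and 23 fractional bits; `to_fixed_point` is illustrative and is not the library's `to_fp`:

```python
import numpy as np

def to_fixed_point(values: np.ndarray, frac_bits: int) -> np.ndarray:
    """Scale real values by 2**frac_bits and round to the nearest integer word."""
    return np.rint(values * (1 << frac_bits)).astype(np.int64)

# Mirror the fp16x16 case above: random 2x2 matrices and their determinants.
x = np.random.randint(-3, 3, (1, 2, 2)).astype(np.float64)
y = np.linalg.det(x)  # shape (1,): one determinant per matrix

x_words = to_fixed_point(x.flatten(), 16)  # raw FP16x16 words for the input
y_words = to_fixed_point(y.flatten(), 16)  # raw FP16x16 words for the expected output
print(x_words, y_words)
```

The integer cases (`i8`, `i32`) skip this conversion and pass the flattened values to `Tensor` directly.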

58 changes: 58 additions & 0 deletions src/operators/tensor/core.cairo
@@ -131,6 +131,7 @@ impl TensorSerde<T, impl TSerde: Serde<T>, impl TDrop: Drop<T>> of Serde<Tensor<
/// dynamic_quantize_linear - Computes the Scale, Zero Point and FP32->8Bit conversion of FP32 Input data.
/// scatter_nd - The output of the operation is produced by creating a copy of the input data, and then updating its value to values specified by updates at specific index positions specified by indices. Its output shape is the same as the shape of data
/// label_encoder - Maps each element in the input tensor to another value.
/// det - Calculates the determinant of a square matrix or batches of square matrices.
trait TensorTrait<T> {
/// # tensor.new
///
@@ -5850,6 +5851,63 @@ trait TensorTrait<T> {
values: Option<Span<T>>,
values_tensor: Option<Tensor<T>>
) -> Tensor<T>;
/// # TensorTrait::det
///
/// ```rust
/// fn det(self: @Tensor<T>) -> Tensor<T>;
/// ```
///
/// Det calculates the determinant of a square matrix or batches of square matrices. Det takes one input tensor of shape [*, M, M], where * is zero or more batch dimensions and the innermost two dimensions form square matrices. The output is a tensor of shape [*], containing the determinants of all input submatrices.
///
/// ## Args
///
/// * `self`(`@Tensor<T>`) - The input tensor of shape [*, M, M].
///
/// ## Returns
///
/// * A `Tensor<T>` of shape [*], containing the determinants of all input submatrices.
///
/// ## Examples
///
/// ```rust
/// use orion::operators::tensor::{I32Tensor, I32TensorAdd};
/// use core::array::{ArrayTrait, SpanTrait};
/// use orion::operators::tensor::{TensorTrait, Tensor};
/// use orion::utils::{assert_eq, assert_seq_eq};
/// use orion::operators::tensor::I32TensorPartialEq;
///
/// fn example() -> Tensor<i32> {
/// let mut shape = ArrayTrait::<usize>::new();
/// shape.append(2);
/// shape.append(3);
/// shape.append(3);
///
/// let mut data = ArrayTrait::new();
/// data.append(1);
/// data.append(2);
/// data.append(3);
/// data.append(4);
/// data.append(5);
/// data.append(6);
/// data.append(7);
/// data.append(8);
/// data.append(9);
/// data.append(2);
/// data.append(2);
/// data.append(3);
/// data.append(4);
/// data.append(5);
/// data.append(6);
/// data.append(7);
/// data.append(8);
/// data.append(9);
/// let input_0 = TensorTrait::new(shape.span(), data.span());
///
/// return input_0.det();
/// }
/// >>> [0, -3]
/// ```
fn det(self: @Tensor<T>) -> Tensor<T>;
}

/// Cf: TensorTrait::new docstring
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_bool.cairo
@@ -547,6 +547,10 @@ impl BoolTensor of TensorTrait<bool> {
) -> Tensor<bool> {
panic(array!['not supported!'])
}

fn det(self: @Tensor<bool>) -> Tensor<bool> {
panic(array!['not supported!'])
}
}

/// Implements partial equal for two `Tensor<bool>` using the `PartialEq` trait.
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_complex64.cairo
@@ -585,6 +585,10 @@ impl Complex64Tensor of TensorTrait<complex64> {
) -> Tensor<complex64> {
panic(array!['not supported!'])
}

fn det(self: @Tensor<complex64>) -> Tensor<complex64> {
panic(array!['not supported!'])
}
}

/// Implements addition for `Tensor<complex64>` using the `Add` trait.
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_fp16x16.cairo
@@ -639,6 +639,10 @@ impl FP16x16Tensor of TensorTrait<FP16x16> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn det(self: @Tensor<FP16x16>) -> Tensor<FP16x16> {
math::det::det(*self)
}
}

/// Implements addition for `Tensor<FP16x16>` using the `Add` trait.
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_fp16x16wide.cairo
@@ -599,6 +599,10 @@ impl FP16x16WTensor of TensorTrait<FP16x16W> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn det(self: @Tensor<FP16x16W>) -> Tensor<FP16x16W> {
math::det::det(*self)
}
}

/// Implements addition for `Tensor<FP16x16W>` using the `Add` trait.
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_fp32x32.cairo
@@ -635,6 +635,10 @@ impl FP32x32Tensor of TensorTrait<FP32x32> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn det(self: @Tensor<FP32x32>) -> Tensor<FP32x32> {
math::det::det(*self)
}
}

/// Implements addition for `Tensor<FP32x32>` using the `Add` trait.
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_fp64x64.cairo
@@ -635,6 +635,10 @@ impl FP64x64Tensor of TensorTrait<FP64x64> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn det(self: @Tensor<FP64x64>) -> Tensor<FP64x64> {
math::det::det(*self)
}
}

/// Implements addition for `Tensor<FP64x64>` using the `Add` trait.
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_fp8x23.cairo
@@ -633,6 +633,10 @@ impl FP8x23Tensor of TensorTrait<FP8x23> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn det(self: @Tensor<FP8x23>) -> Tensor<FP8x23> {
math::det::det(*self)
}
}

/// Implements addition for `Tensor<FP8x23>` using the `Add` trait.
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_fp8x23wide.cairo
@@ -576,6 +576,10 @@ impl FP8x23WTensor of TensorTrait<FP8x23W> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn det(self: @Tensor<FP8x23W>) -> Tensor<FP8x23W> {
math::det::det(*self)
}
}

/// Implements addition for `Tensor<FP8x23W>` using the `Add` trait.
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_i32.cairo
@@ -599,6 +599,10 @@ impl I32Tensor of TensorTrait<i32> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn det(self: @Tensor<i32>) -> Tensor<i32> {
math::det::det(*self)
}
}

/// Implements addition for `Tensor<i32>` using the `Add` trait.
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_i8.cairo
@@ -602,6 +602,10 @@ impl I8Tensor of TensorTrait<i8> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn det(self: @Tensor<i8>) -> Tensor<i8> {
math::det::det(*self)
}
}

/// Implements addition for `Tensor<i8>` using the `Add` trait.
4 changes: 4 additions & 0 deletions src/operators/tensor/implementations/tensor_u32.cairo
@@ -546,6 +546,10 @@ impl U32Tensor of TensorTrait<u32> {
self, default_list, default_tensor, keys, keys_tensor, values, values_tensor
)
}

fn det(self: @Tensor<u32>) -> Tensor<u32> {
panic(array!['not supported!'])
}
}

/// Implements addition for `Tensor<u32>` using the `Add` trait.
1 change: 1 addition & 0 deletions src/operators/tensor/math.cairo
@@ -68,3 +68,4 @@ mod hann_window;
mod hamming_window;
mod blackman_window;
mod scatter_nd;
mod det;
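
The body of the new `src/operators/tensor/math/det.cairo` module is not among the file diffs loaded on this page, so the algorithm itself is not visible here. As a plain reference for the contract it has to satisfy (shape [*, M, M] in, shape [*] out), here is a short Python sketch of a batched determinant via Laplace (cofactor) expansion; it is illustrative only and not a transcription of the Cairo implementation:

```python
from typing import List

def det_one(m: List[List[float]]) -> float:
    """Determinant of a single square matrix by cofactor expansion along row 0."""
    n = len(m)
    if n == 1:
        return m[0][0]
    total = 0.0
    for col in range(n):
        # Minor: drop row 0 and column `col`.
        minor = [[m[r][c] for c in range(n) if c != col] for r in range(1, n)]
        sign = 1.0 if col % 2 == 0 else -1.0
        total += sign * m[0][col] * det_one(minor)
    return total

def det_batched(batch: List[List[List[float]]]) -> List[float]:
    """Map a batch of square matrices [*, M, M] to their determinants [*]."""
    return [det_one(m) for m in batch]

print(det_batched([
    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    [[2, 2, 3], [4, 5, 6], [7, 8, 9]],
]))  # [0.0, -3.0]
```

Cofactor expansion is O(M!) and only practical for small matrices; an LU-style elimination scales better, but either approach meets the same shape contract.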