-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fb6f4a0
commit 070c6ee
Showing
30 changed files
with
581 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# TensorTrait::det | ||
|
||
```rust | ||
fn det(tensor: @Tensor<T>) -> T; | ||
``` | ||
|
||
Det calculates determinant of a square matrix or batches of square matrices. Det takes one input tensor of shape [*, M, M], where * is zero or more batch dimensions, and the inner-most 2 dimensions form square matrices. The output is a tensor of shape [*], containing the determinants of all input submatrices. | ||
|
||
## Args | ||
|
||
* `tensor`(`@Tensor<T>`) - The input tensor of shape [*, M, M]. | ||
|
||
## Returns | ||
|
||
* The output is a tensor of shape [*] | ||
|
||
## Examples | ||
|
||
```rust | ||
use orion::operators::tensor::{I32Tensor, I32TensorAdd}; | ||
use core::array::{ArrayTrait, SpanTrait}; | ||
use orion::operators::tensor::{TensorTrait, Tensor}; | ||
use orion::utils::{assert_eq, assert_seq_eq}; | ||
use orion::operators::tensor::I32TensorPartialEq; | ||
|
||
fn example() -> Tensor<i32> { | ||
let mut shape = ArrayTrait::<usize>::new(); | ||
shape.append(2); | ||
shape.append(3); | ||
shape.append(3); | ||
|
||
let mut data = ArrayTrait::new(); | ||
data.append(1); | ||
data.append(2); | ||
data.append(3); | ||
data.append(4); | ||
data.append(5); | ||
data.append(6); | ||
data.append(7); | ||
data.append(8); | ||
data.append(9); | ||
data.append(2); | ||
data.append(2); | ||
data.append(3); | ||
data.append(4); | ||
data.append(5); | ||
data.append(6); | ||
data.append(7); | ||
data.append(8); | ||
data.append(9); | ||
let input_0 = TensorTrait::new(shape.span(), data.span()); | ||
|
||
return input_0.det(); | ||
} | ||
>>> [0, -3] | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import numpy as np | ||
from nodegen.node import RunAll | ||
from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait, get_data_statement | ||
|
||
|
||
class Det(RunAll): | ||
|
||
@staticmethod | ||
# We test here with fp8x23 implementation. | ||
def fp8x23(): | ||
x = np.random.randint(-2, 2, (2, 4, 4)).astype(np.float64) | ||
y = np.linalg.det(x) | ||
|
||
x = Tensor(Dtype.FP8x23, x.shape, to_fp( | ||
x.flatten(), FixedImpl.FP8x23)) | ||
y = Tensor(Dtype.FP8x23, y.shape, to_fp( | ||
y.flatten(), FixedImpl.FP8x23)) | ||
|
||
name = "det_fp8x23" | ||
make_test([x], y, f"input_0.det()", name) | ||
|
||
@staticmethod | ||
# We test here with fp16x16 implementation. | ||
def fp16x16(): | ||
x = np.random.randint(-3, 3, (1, 2, 2)).astype(np.float64) | ||
y = np.linalg.det(x) | ||
|
||
x = Tensor(Dtype.FP16x16, x.shape, to_fp( | ||
x.flatten(), FixedImpl.FP16x16)) | ||
y = Tensor(Dtype.FP16x16, y.shape, to_fp( | ||
y.flatten(), FixedImpl.FP16x16)) | ||
|
||
name = "det_fp16x16" | ||
make_test([x], y, f"input_0.det()", name) | ||
|
||
@staticmethod | ||
# We test here with i8 implementation. | ||
def i8(): | ||
x = np.random.randint(0, 6, (3, 1, 1)).astype(np.int8) | ||
y = np.linalg.det(x) | ||
|
||
x = Tensor(Dtype.I8, x.shape, x.flatten()) | ||
y = Tensor(Dtype.I8, y.shape, y.flatten()) | ||
|
||
name = "det_i8" | ||
make_test([x], y, f"input_0.det()", name) | ||
|
||
@staticmethod | ||
# We test here with i32 implementation. | ||
def i32(): | ||
x = np.array([[[1, 2, 3], | ||
[4, 5, 6], | ||
[7, 8, 9]], | ||
[[2, 2, 3], | ||
[4, 5, 6], | ||
[7, 8, 9]] | ||
]) | ||
y = np.linalg.det(x) | ||
|
||
x = Tensor(Dtype.I32, x.shape, x.flatten()) | ||
y = Tensor(Dtype.I32, y.shape, y.flatten()) | ||
|
||
name = "det_i32" | ||
make_test([x], y, f"input_0.det()", name) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,3 +68,4 @@ mod hann_window; | |
mod hamming_window; | ||
mod blackman_window; | ||
mod scatter_nd; | ||
mod det; |
Oops, something went wrong.