Commit 5b17608: test operator
raphaelDkhn committed Mar 25, 2024
1 parent ca3b413 commit 5b17608
Showing 19 changed files with 706 additions and 114 deletions.
60 changes: 43 additions & 17 deletions nodegen/node/softmax.py
@@ -11,33 +11,59 @@ def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:


 class Softmax(RunAll):

     @staticmethod
+    def axis_0():
+        x = np.abs(np.random.randn(3, 4, 5).astype(np.float32))
+        y = softmax(x, axis=0)
+
+        x = Tensor(Dtype.FP16x16, x.shape, to_fp(
+            x.flatten(), FixedImpl.FP16x16))
+        y = Tensor(Dtype.FP16x16, y.shape, to_fp(
+            y.flatten(), FixedImpl.FP16x16))
+
+        name = "softmax_axis_0"
+        make_test([x], y, "NNTrait::softmax(@input_0, Option::Some(0))",
+                  name, Trait.NN)
+
+    @staticmethod
-    def fp8x23():
-        x = np.random.uniform(-3, 3, (2, 2)).astype(np.float64)
-        y = softmax(x, 0)
-
-        x = Tensor(Dtype.FP8x23, x.shape, to_fp(
-            x.flatten(), FixedImpl.FP8x23))
-        y = Tensor(Dtype.FP8x23, y.shape, to_fp(
-            y.flatten(), FixedImpl.FP8x23))
+    def axis_1():
+        x = np.abs(np.random.randn(3, 4, 5).astype(np.float32))
+        y = softmax(x, axis=1)

-        name = "softmax_fp8x23"
-        make_test([x], y, "NNTrait::softmax(@input_0, 0)",
-                  name, Trait.NN)
+        x = Tensor(Dtype.FP16x16, x.shape, to_fp(
+            x.flatten(), FixedImpl.FP16x16))
+        y = Tensor(Dtype.FP16x16, y.shape, to_fp(
+            y.flatten(), FixedImpl.FP16x16))
+
+        name = "softmax_axis_1"
+        make_test([x], y, "NNTrait::softmax(@input_0, Option::Some(1))",
+                  name, Trait.NN)

     @staticmethod
-    def fp16x16():
-        x = np.random.uniform(-3, 3, (2, 2)).astype(np.float64)
-        y = softmax(x, 1)
+    def axis_2():
+        x = np.abs(np.random.randn(3, 4, 5).astype(np.float32))
+        y = softmax(x, axis=2)

         x = Tensor(Dtype.FP16x16, x.shape, to_fp(
             x.flatten(), FixedImpl.FP16x16))
         y = Tensor(Dtype.FP16x16, y.shape, to_fp(
             y.flatten(), FixedImpl.FP16x16))

-        name = "softmax_fp16x16"
-        make_test([x], y, "NNTrait::softmax(@input_0, 1)",
-                  name, Trait.NN)
+        name = "softmax_axis_2"
+        make_test([x], y, "NNTrait::softmax(@input_0, Option::Some(2))",
+                  name, Trait.NN)
+
+    @staticmethod
+    def axis_minus_1():
+        x = np.abs(np.random.randn(3, 4, 5).astype(np.float32))
+        y = softmax(x, axis=-1)
+
+        x = Tensor(Dtype.FP16x16, x.shape, to_fp(
+            x.flatten(), FixedImpl.FP16x16))
+        y = Tensor(Dtype.FP16x16, y.shape, to_fp(
+            y.flatten(), FixedImpl.FP16x16))
+
+        name = "softmax_axis_minus_1"
+        make_test([x], y, "NNTrait::softmax(@input_0, Option::None)",
+                  name, Trait.NN)
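
Editor's note: the `softmax` helper whose signature appears in the hunk header above is defined earlier in nodegen/node/softmax.py and is not part of this diff. A minimal sketch of what such a reference implementation typically looks like, assuming the standard numerically stable formulation; the actual body is not shown here:

import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Shift by the per-axis max before exponentiating so large inputs
    # cannot overflow float32; the shift cancels out in the ratio.
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

# Example: for x of shape (3, 4, 5), softmax(x, axis=0) makes every
# column of 3 entries sum to 1, which the FP16x16 fixtures below
# preserve up to fixed-point rounding.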
6 changes: 4 additions & 2 deletions tests/nodes.cairo
@@ -288,8 +288,6 @@
 // mod sin_fp8x23;
 // mod sinh_fp16x16;
 // mod sinh_fp8x23;
-// mod softmax_fp16x16;
-// mod softmax_fp8x23;
 // mod softplus_fp8x23;
 // mod softplus_fp16x16;
 // mod softsign_fp8x23;
@@ -1020,3 +1018,7 @@ mod gather_elements_default;
 mod gather_elements_axis1;
 mod gather_elements_axis2;
 mod gather_elements_negative_indices;
+mod softmax_axis_0;
+mod softmax_axis_1;
+mod softmax_axis_2;
+mod softmax_axis_minus_1;
tests/nodes/softmax_axis_0.cairo
@@ -2,19 +2,19 @@ mod input_0;
 mod output_0;


-use orion::operators::nn::NNTrait;
-use orion::operators::tensor::FP16x16TensorPartialEq;
+use orion::utils::{assert_eq, assert_seq_eq};
 use orion::operators::nn::FP16x16NN;
+use orion::operators::nn::NNTrait;
 use orion::numbers::FixedTrait;
-use orion::utils::{assert_eq, assert_seq_eq};
+use orion::operators::tensor::FP16x16TensorPartialEq;

 #[test]
 #[available_gas(2000000000)]
-fn test_softmax_fp16x16() {
+fn test_softmax_axis_0() {
     let input_0 = input_0::input_0();
-    let z = output_0::output_0();
+    let z_0 = output_0::output_0();

-    let y = NNTrait::softmax(@input_0, 1);
+    let y_0 = NNTrait::softmax(@input_0, Option::Some(0));

-    assert_eq(y, z);
+    assert_eq(y_0, z_0);
 }
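
Editor's note: this hunk shows the signature change the commit exercises. The axis argument to NNTrait::softmax was a bare integer and is now an Option, where Option::Some(i) selects axis i; judging by the softmax_axis_minus_1 test passing Option::None, omitting the axis falls back to the last dimension. A sketch of that dispatch in Python, reusing the softmax helper sketched earlier (the default-to-last-axis behavior is inferred from the test names, not confirmed by this diff):

from typing import Optional

def softmax_optional_axis(x, axis: Optional[int] = None):
    # Option::Some(i) -> softmax over axis i
    # Option::None    -> softmax over the last axis (axis = -1)
    return softmax(x, axis=-1 if axis is None else axis)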
74 changes: 74 additions & 0 deletions tests/nodes/softmax_axis_0/input_0.cairo
@@ -0,0 +1,74 @@
use core::array::{ArrayTrait, SpanTrait};
use orion::operators::tensor::{TensorTrait, Tensor};
use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd};
use orion::numbers::{FixedTrait, FP16x16};

fn input_0() -> Tensor<FP16x16> {
    let mut shape = ArrayTrait::<usize>::new();
    shape.append(3);
    shape.append(4);
    shape.append(5);

    let mut data = ArrayTrait::new();
    data.append(FP16x16 { mag: 77748, sign: false });
    data.append(FP16x16 { mag: 20181, sign: false });
    data.append(FP16x16 { mag: 66586, sign: false });
    data.append(FP16x16 { mag: 39564, sign: false });
    data.append(FP16x16 { mag: 55469, sign: false });
    data.append(FP16x16 { mag: 15765, sign: false });
    data.append(FP16x16 { mag: 31745, sign: false });
    data.append(FP16x16 { mag: 64291, sign: false });
    data.append(FP16x16 { mag: 64704, sign: false });
    data.append(FP16x16 { mag: 95806, sign: false });
    data.append(FP16x16 { mag: 42434, sign: false });
    data.append(FP16x16 { mag: 107711, sign: false });
    data.append(FP16x16 { mag: 63051, sign: false });
    data.append(FP16x16 { mag: 93445, sign: false });
    data.append(FP16x16 { mag: 241, sign: false });
    data.append(FP16x16 { mag: 131759, sign: false });
    data.append(FP16x16 { mag: 74671, sign: false });
    data.append(FP16x16 { mag: 44973, sign: false });
    data.append(FP16x16 { mag: 92338, sign: false });
    data.append(FP16x16 { mag: 36204, sign: false });
    data.append(FP16x16 { mag: 12200, sign: false });
    data.append(FP16x16 { mag: 73821, sign: false });
    data.append(FP16x16 { mag: 13038, sign: false });
    data.append(FP16x16 { mag: 21598, sign: false });
    data.append(FP16x16 { mag: 75353, sign: false });
    data.append(FP16x16 { mag: 41470, sign: false });
    data.append(FP16x16 { mag: 11370, sign: false });
    data.append(FP16x16 { mag: 62793, sign: false });
    data.append(FP16x16 { mag: 19117, sign: false });
    data.append(FP16x16 { mag: 95800, sign: false });
    data.append(FP16x16 { mag: 40696, sign: false });
    data.append(FP16x16 { mag: 95240, sign: false });
    data.append(FP16x16 { mag: 103492, sign: false });
    data.append(FP16x16 { mag: 36412, sign: false });
    data.append(FP16x16 { mag: 22269, sign: false });
    data.append(FP16x16 { mag: 201968, sign: false });
    data.append(FP16x16 { mag: 40874, sign: false });
    data.append(FP16x16 { mag: 14038, sign: false });
    data.append(FP16x16 { mag: 55733, sign: false });
    data.append(FP16x16 { mag: 65120, sign: false });
    data.append(FP16x16 { mag: 128415, sign: false });
    data.append(FP16x16 { mag: 86247, sign: false });
    data.append(FP16x16 { mag: 47611, sign: false });
    data.append(FP16x16 { mag: 34746, sign: false });
    data.append(FP16x16 { mag: 23589, sign: false });
    data.append(FP16x16 { mag: 51498, sign: false });
    data.append(FP16x16 { mag: 6664, sign: false });
    data.append(FP16x16 { mag: 32348, sign: false });
    data.append(FP16x16 { mag: 31728, sign: false });
    data.append(FP16x16 { mag: 43457, sign: false });
    data.append(FP16x16 { mag: 41874, sign: false });
    data.append(FP16x16 { mag: 17514, sign: false });
    data.append(FP16x16 { mag: 42083, sign: false });
    data.append(FP16x16 { mag: 30365, sign: false });
    data.append(FP16x16 { mag: 133274, sign: false });
    data.append(FP16x16 { mag: 54633, sign: false });
    data.append(FP16x16 { mag: 168600, sign: false });
    data.append(FP16x16 { mag: 15559, sign: false });
    data.append(FP16x16 { mag: 50448, sign: false });
    data.append(FP16x16 { mag: 70775, sign: false });
    TensorTrait::new(shape.span(), data.span())
}
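
Editor's note on reading these literals: FP16x16 is a 16.16 fixed-point format, so mag is the absolute value scaled by 2^16 = 65536 and sign carries the sign separately; the first entry, mag: 77748, encodes 77748 / 65536 ≈ 1.186. A minimal round-trip sketch assuming that scaling (the helper names here are illustrative, not the real signatures of the nodegen to_fp utilities):

ONE = 2 ** 16  # FP16x16 scale: 16 integer bits, 16 fractional bits

def to_fp16x16(value: float) -> tuple[int, bool]:
    # (mag, sign), as written in the Cairo struct literals above
    return round(abs(value) * ONE), value < 0

def from_fp16x16(mag: int, sign: bool) -> float:
    return (-mag if sign else mag) / ONE

print(from_fp16x16(77748, False))  # ~1.18634, the first input value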
74 changes: 74 additions & 0 deletions tests/nodes/softmax_axis_0/output_0.cairo
@@ -0,0 +1,74 @@
use core::array::{ArrayTrait, SpanTrait};
use orion::operators::tensor::{TensorTrait, Tensor};
use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd};
use orion::numbers::{FixedTrait, FP16x16};

fn output_0() -> Tensor<FP16x16> {
    let mut shape = ArrayTrait::<usize>::new();
    shape.append(3);
    shape.append(4);
    shape.append(5);

    let mut data = ArrayTrait::new();
    data.append(FP16x16 { mag: 18542, sign: false });
    data.append(FP16x16 { mag: 10909, sign: false });
    data.append(FP16x16 { mag: 29920, sign: false });
    data.append(FP16x16 { mag: 24368, sign: false });
    data.append(FP16x16 { mag: 22071, sign: false });
    data.append(FP16x16 { mag: 15584, sign: false });
    data.append(FP16x16 { mag: 27139, sign: false });
    data.append(FP16x16 { mag: 25287, sign: false });
    data.append(FP16x16 { mag: 31157, sign: false });
    data.append(FP16x16 { mag: 26751, sign: false });
    data.append(FP16x16 { mag: 22100, sign: false });
    data.append(FP16x16 { mag: 31519, sign: false });
    data.append(FP16x16 { mag: 18307, sign: false });
    data.append(FP16x16 { mag: 36393, sign: false });
    data.append(FP16x16 { mag: 6545, sign: false });
    data.append(FP16x16 { mag: 15502, sign: false });
    data.append(FP16x16 { mag: 11319, sign: false });
    data.append(FP16x16 { mag: 28971, sign: false });
    data.append(FP16x16 { mag: 31211, sign: false });
    data.append(FP16x16 { mag: 15422, sign: false });
    data.append(FP16x16 { mag: 6820, sign: false });
    data.append(FP16x16 { mag: 24731, sign: false });
    data.append(FP16x16 { mag: 13216, sign: false });
    data.append(FP16x16 { mag: 18525, sign: false });
    data.append(FP16x16 { mag: 29894, sign: false });
    data.append(FP16x16 { mag: 23068, sign: false });
    data.append(FP16x16 { mag: 19887, sign: false });
    data.append(FP16x16 { mag: 24716, sign: false });
    data.append(FP16x16 { mag: 15540, sign: false });
    data.append(FP16x16 { mag: 26749, sign: false });
    data.append(FP16x16 { mag: 21522, sign: false });
    data.append(FP16x16 { mag: 26057, sign: false });
    data.append(FP16x16 { mag: 33933, sign: false });
    data.append(FP16x16 { mag: 15242, sign: false });
    data.append(FP16x16 { mag: 9159, sign: false });
    data.append(FP16x16 { mag: 45254, sign: false });
    data.append(FP16x16 { mag: 6759, sign: false });
    data.append(FP16x16 { mag: 18070, sign: false });
    data.append(FP16x16 { mag: 17854, sign: false });
    data.append(FP16x16 { mag: 23976, sign: false });
    data.append(FP16x16 { mag: 40173, sign: false });
    data.append(FP16x16 { mag: 29894, sign: false });
    data.append(FP16x16 { mag: 22398, sign: false });
    data.append(FP16x16 { mag: 22641, sign: false });
    data.append(FP16x16 { mag: 13569, sign: false });
    data.append(FP16x16 { mag: 26883, sign: false });
    data.append(FP16x16 { mag: 18509, sign: false });
    data.append(FP16x16 { mag: 15532, sign: false });
    data.append(FP16x16 { mag: 18838, sign: false });
    data.append(FP16x16 { mag: 12034, sign: false });
    data.append(FP16x16 { mag: 21912, sign: false });
    data.append(FP16x16 { mag: 7959, sign: false });
    data.append(FP16x16 { mag: 13294, sign: false });
    data.append(FP16x16 { mag: 13899, sign: false });
    data.append(FP16x16 { mag: 49831, sign: false });
    data.append(FP16x16 { mag: 4778, sign: false });
    data.append(FP16x16 { mag: 47456, sign: false });
    data.append(FP16x16 { mag: 18494, sign: false });
    data.append(FP16x16 { mag: 16470, sign: false });
    data.append(FP16x16 { mag: 26136, sign: false });
    TensorTrait::new(shape.span(), data.span())
}
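
Editor's note, a quick consistency check on this fixture: softmax normalizes along the chosen axis, so for shape (3, 4, 5) and axis 0 every triple of entries 20 apart in the flattened data should sum to about 2^16. The first triple here is 18542 + 6820 + 40173 = 65535, i.e. 1.0 up to one unit of rounding. A small verification sketch under that assumption (the tolerance of a few units is a guess at the fixed-point rounding budget):

ONE = 2 ** 16

def check_softmax_axis0(mags: list[int], shape=(3, 4, 5), tol=4) -> None:
    # In row-major order, one step along axis 0 jumps
    # shape[1] * shape[2] = 20 positions in the flattened data.
    stride = shape[1] * shape[2]
    for i in range(stride):
        total = sum(mags[i + k * stride] for k in range(shape[0]))
        assert abs(total - ONE) <= tol, f"column {i} sums to {total}"

# e.g. run check_softmax_axis0 on the 60 magnitudes listed above;
# the first axis-0 column alone: 18542 + 6820 + 40173 = 65535.
assert abs((18542 + 6820 + 40173) - ONE) <= 4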
20 changes: 20 additions & 0 deletions tests/nodes/softmax_axis_1.cairo
@@ -0,0 +1,20 @@
mod input_0;
mod output_0;


use orion::utils::{assert_eq, assert_seq_eq};
use orion::operators::nn::FP16x16NN;
use orion::operators::nn::NNTrait;
use orion::numbers::FixedTrait;
use orion::operators::tensor::FP16x16TensorPartialEq;

#[test]
#[available_gas(2000000000)]
fn test_softmax_axis_1() {
    let input_0 = input_0::input_0();
    let z_0 = output_0::output_0();

    let y_0 = NNTrait::softmax(@input_0, Option::Some(1));

    assert_eq(y_0, z_0);
}
74 changes: 74 additions & 0 deletions tests/nodes/softmax_axis_1/input_0.cairo
@@ -0,0 +1,74 @@
use core::array::{ArrayTrait, SpanTrait};
use orion::operators::tensor::{TensorTrait, Tensor};
use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd};
use orion::numbers::{FixedTrait, FP16x16};

fn input_0() -> Tensor<FP16x16> {
    let mut shape = ArrayTrait::<usize>::new();
    shape.append(3);
    shape.append(4);
    shape.append(5);

    let mut data = ArrayTrait::new();
    data.append(FP16x16 { mag: 55504, sign: false });
    data.append(FP16x16 { mag: 131012, sign: false });
    data.append(FP16x16 { mag: 66466, sign: false });
    data.append(FP16x16 { mag: 15137, sign: false });
    data.append(FP16x16 { mag: 184134, sign: false });
    data.append(FP16x16 { mag: 45919, sign: false });
    data.append(FP16x16 { mag: 61072, sign: false });
    data.append(FP16x16 { mag: 18808, sign: false });
    data.append(FP16x16 { mag: 10438, sign: false });
    data.append(FP16x16 { mag: 28335, sign: false });
    data.append(FP16x16 { mag: 19320, sign: false });
    data.append(FP16x16 { mag: 18945, sign: false });
    data.append(FP16x16 { mag: 51241, sign: false });
    data.append(FP16x16 { mag: 29903, sign: false });
    data.append(FP16x16 { mag: 9030, sign: false });
    data.append(FP16x16 { mag: 112806, sign: false });
    data.append(FP16x16 { mag: 28939, sign: false });
    data.append(FP16x16 { mag: 112572, sign: false });
    data.append(FP16x16 { mag: 89990, sign: false });
    data.append(FP16x16 { mag: 87594, sign: false });
    data.append(FP16x16 { mag: 56996, sign: false });
    data.append(FP16x16 { mag: 31238, sign: false });
    data.append(FP16x16 { mag: 66896, sign: false });
    data.append(FP16x16 { mag: 37962, sign: false });
    data.append(FP16x16 { mag: 26194, sign: false });
    data.append(FP16x16 { mag: 59208, sign: false });
    data.append(FP16x16 { mag: 6005, sign: false });
    data.append(FP16x16 { mag: 16581, sign: false });
    data.append(FP16x16 { mag: 27378, sign: false });
    data.append(FP16x16 { mag: 59336, sign: false });
    data.append(FP16x16 { mag: 11513, sign: false });
    data.append(FP16x16 { mag: 12294, sign: false });
    data.append(FP16x16 { mag: 4336, sign: false });
    data.append(FP16x16 { mag: 111725, sign: false });
    data.append(FP16x16 { mag: 45307, sign: false });
    data.append(FP16x16 { mag: 145057, sign: false });
    data.append(FP16x16 { mag: 44365, sign: false });
    data.append(FP16x16 { mag: 80274, sign: false });
    data.append(FP16x16 { mag: 50643, sign: false });
    data.append(FP16x16 { mag: 39432, sign: false });
    data.append(FP16x16 { mag: 53176, sign: false });
    data.append(FP16x16 { mag: 202691, sign: false });
    data.append(FP16x16 { mag: 54389, sign: false });
    data.append(FP16x16 { mag: 125453, sign: false });
    data.append(FP16x16 { mag: 101533, sign: false });
    data.append(FP16x16 { mag: 2658, sign: false });
    data.append(FP16x16 { mag: 31411, sign: false });
    data.append(FP16x16 { mag: 44406, sign: false });
    data.append(FP16x16 { mag: 82774, sign: false });
    data.append(FP16x16 { mag: 36316, sign: false });
    data.append(FP16x16 { mag: 37737, sign: false });
    data.append(FP16x16 { mag: 5076, sign: false });
    data.append(FP16x16 { mag: 48499, sign: false });
    data.append(FP16x16 { mag: 3099, sign: false });
    data.append(FP16x16 { mag: 168018, sign: false });
    data.append(FP16x16 { mag: 18863, sign: false });
    data.append(FP16x16 { mag: 16555, sign: false });
    data.append(FP16x16 { mag: 4096, sign: false });
    data.append(FP16x16 { mag: 227, sign: false });
    data.append(FP16x16 { mag: 35060, sign: false });
    TensorTrait::new(shape.span(), data.span())
}