diff --git a/docgen/src/main.rs b/docgen/src/main.rs index 8d1f90f4b..fb7dba8f1 100644 --- a/docgen/src/main.rs +++ b/docgen/src/main.rs @@ -59,6 +59,14 @@ fn main() { doc_trait(trait_path, doc_path, label); doc_functions(trait_path, doc_path, trait_name, label); + // TREE ENSEMBLE DOC + let trait_path = "src/operators/ml/tree_ensemble/tree_ensemble.cairo"; + let doc_path = "docs/framework/operators/machine-learning/tree-ensemble"; + let label = "tree_ensemble"; + let trait_name: &str = "TreeEnsembleTrait"; + doc_trait(trait_path, doc_path, label); + doc_functions(trait_path, doc_path, trait_name, label); + // LINEAR REGRESSOR DOC let trait_path = "src/operators/ml/linear/linear_regressor.cairo"; let doc_path = "docs/framework/operators/machine-learning/linear-regressor"; diff --git a/docs/framework/operators/machine-learning/tree-ensemble/README.md b/docs/framework/operators/machine-learning/tree-ensemble/README.md new file mode 100644 index 000000000..26fcfb205 --- /dev/null +++ b/docs/framework/operators/machine-learning/tree-ensemble/README.md @@ -0,0 +1,22 @@ +# Tree Ensemble + +`TreeEnsembleTrait` provides a trait definition for tree ensemble problem. + +```rust +use orion::operators::ml::TreeEnsembleTrait; +``` + +### Data types + +Orion supports currently only fixed point data types for `TreeEnsembleTrait`. + +| Data type | dtype | +| -------------------- | ------------------------------------------------------------- | +| Fixed point (signed) | `TreeEnsembleTrait` | + + +*** + +| function | description | +| --- | --- | +| [`tree_ensemble.predict`](tree_ensemble.predict.md) | Returns the regressed values for each input in a batch. | \ No newline at end of file diff --git a/docs/framework/operators/machine-learning/tree-ensemble/tree_ensemble.predict.md b/docs/framework/operators/machine-learning/tree-ensemble/tree_ensemble.predict.md new file mode 100644 index 000000000..a7f97e96d --- /dev/null +++ b/docs/framework/operators/machine-learning/tree-ensemble/tree_ensemble.predict.md @@ -0,0 +1,139 @@ +# TreeEnsemble::predict + +```rust + fn predict(X: @Tensor, + nodes_splits: Tensor, + nodes_featureids: Span, + nodes_modes: Span, + nodes_truenodeids: Span, + nodes_falsenodeids: Span, + nodes_trueleafs: Span, + nodes_falseleafs: Span, + leaf_targetids: Span, + leaf_weights: Tensor, + tree_roots: Span, + post_transform: POST_TRANSFORM, + aggregate_function: AGGREGATE_FUNCTION, + nodes_hitrates: Option>, + nodes_missing_value_tracks_true: Option>, + membership_values: Option>, + n_targets: usize + ) -> MutMatrix::; +``` + +Tree Ensemble operator. Returns the regressed values for each input in a batch. Inputs have dimensions [N, F] where N is the input batch size and F is the number of input features. Outputs have dimensions [N, num_targets] where N is the batch size and num_targets is the number of targets, which is a configurable attribute. + +## Args + +* `X`: Input 2D tensor. +* `nodes_splits`: Thresholds to do the splitting on for each node with mode that is not 'BRANCH_MEMBER'. +* `nodes_featureids`: Feature id for each node. +* `nodes_modes`: The comparison operation performed by the node. This is encoded as an enumeration of 'NODE_MODE::LEQ', 'NODE_MODE::LT', 'NODE_MODE::GTE', 'NODE_MODE::GT', 'NODE_MODE::EQ', 'NODE_MODE::NEQ', and 'NODE_MODE::MEMBER' +* `nodes_truenodeids`: If `nodes_trueleafs` is 0 (false) at an entry, this represents the position of the true branch node. 
+* `nodes_falsenodeids`: If `nodes_falseleafs` is 0 (false) at an entry, this represents the position of the false branch node. +* `nodes_trueleafs`: 1 if true branch is leaf for each node and 0 an interior node. +* `nodes_falseleafs`: 1 if true branch is leaf for each node and 0 an interior node. +* `leaf_targetids`: The index of the target that this leaf contributes to (this must be in range `[0, n_targets)`). +* `leaf_weights`: The weight for each leaf. +* `tree_roots`: Index into `nodes_*` for the root of each tree. The tree structure is derived from the branching of each node. +* `post_transform`: Indicates the transform to apply to the score.One of 'POST_TRANSFORM::NONE', 'POST_TRANSFORM::SOFTMAX', 'POST_TRANSFORM::LOGISTIC', 'POST_TRANSFORM::SOFTMAX_ZERO' or 'POST_TRANSFORM::PROBIT' , +* `aggregate_function`: Defines how to aggregate leaf values within a target. One of 'AGGREGATE_FUNCTION::AVERAGE', 'AGGREGATE_FUNCTION::SUM', 'AGGREGATE_FUNCTION::MIN', 'AGGREGATE_FUNCTION::MAX` defaults to 'AGGREGATE_FUNCTION::SUM' +* `nodes_hitrates`: Popularity of each node, used for performance and may be omitted. +* `nodes_missing_value_tracks_true`: For each node, define whether to follow the true branch (if attribute value is 1) or false branch (if attribute value is 0) in the presence of a NaN input feature. This attribute may be left undefined and the default value is false (0) for all nodes. +* `membership_values`: Members to test membership of for each set membership node. List all of the members to test again in the order that the 'BRANCH_MEMBER' mode appears in `node_modes`, delimited by `NaN`s. Will have the same number of sets of values as nodes with mode 'BRANCH_MEMBER'. This may be omitted if the node doesn't contain any 'BRANCH_MEMBER' nodes. +* `n_targets`: The total number of targets. 
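+
+To make the flat encoding above concrete, the following is a minimal, illustrative Python sketch (not part of the Orion API) of how a single input row is routed through one tree, assuming plain floats and `NODE_MODE::LEQ` nodes only. Note that `nodes_truenodeids[i]` / `nodes_falsenodeids[i]` index into the `nodes_*` arrays while the branch is an interior node, and into the `leaf_*` arrays once the corresponding `nodes_trueleafs[i]` / `nodes_falseleafs[i]` flag is 1.
+
+```python
+def predict_one(x, root, nodes_splits, nodes_featureids,
+                nodes_truenodeids, nodes_falsenodeids,
+                nodes_trueleafs, nodes_falseleafs,
+                leaf_targetids, leaf_weights):
+    idx, is_leaf = root, False
+    while not is_leaf:
+        # NODE_MODE::LEQ comparison against the node's threshold
+        if x[nodes_featureids[idx]] <= nodes_splits[idx]:
+            is_leaf = nodes_trueleafs[idx] == 1
+            idx = nodes_truenodeids[idx]
+        else:
+            is_leaf = nodes_falseleafs[idx] == 1
+            idx = nodes_falsenodeids[idx]
+    # idx now indexes the leaf_* arrays
+    return leaf_targetids[idx], leaf_weights[idx]
+
+# Approximate float equivalents of the third input row of the FP16x16 example below:
+print(predict_one([4.14, 1.77], 0,
+                  [3.14, 1.2, 4.2], [0, 0, 0],
+                  [1, 0, 1], [2, 2, 3],
+                  [0, 1, 1], [0, 1, 1],
+                  [0, 1, 0, 1], [5.23, 12.12, -12.23, 7.21]))
+# -> (1, 12.12): target 1 receives weight 12.12
+```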
+ + +## Returns + +* Output of shape [Batch Size, Number of targets] + +## Type Constraints + +`TreeEnsembleClassifier` and `X` must be fixed points + +## Examples + +```rust +use orion::numbers::FP16x16; +use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor}; +use orion::operators::ml::{TreeEnsembleTrait,POST_TRANSFORM, AGGREGATE_FUNCTION, NODE_MODE}; +use orion::operators::matrix::{MutMatrix, MutMatrixImpl}; +use orion::numbers::NumberTrait; + +fn example_tree_ensemble_one_tree() -> MutMatrix:: { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 222822, sign: false }); + data.append(FP16x16 { mag: 7864, sign: true }); + data.append(FP16x16 { mag: 108789, sign: false }); + data.append(FP16x16 { mag: 271319, sign: false }); + data.append(FP16x16 { mag: 115998, sign: false }); + let mut X = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 342753, sign: false }); + data.append(FP16x16 { mag: 794296, sign: false }); + data.append(FP16x16 { mag: 801505, sign: true }); + data.append(FP16x16 { mag: 472514, sign: false }); + let leaf_weights = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 205783, sign: false }); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 275251, sign: false }); + let nodes_splits = TensorTrait::new(shape.span(), data.span()); + + let membership_values = Option::None; + + let n_targets = 2; + let aggregate_function = AGGREGATE_FUNCTION::SUM; + let nodes_missing_value_tracks_true = Option::None; + let nodes_hitrates = Option::None; + let post_transform = POST_TRANSFORM::NONE; + + let tree_roots: Span = array![0].span(); + let nodes_modes: Span = array![MODE::LEQ, MODE::LEQ, MODE::LEQ].span(); + + let nodes_featureids: Span = array![0, 0, 0].span(); + let nodes_truenodeids: Span = array![1, 0, 1].span(); + let nodes_trueleafs: Span = array![0, 1, 1].span(); + let nodes_falsenodeids: Span = array![2, 2, 3].span(); + let nodes_falseleafs: Span = array![0, 1, 1].span(); + let leaf_targetids: Span = array![0, 1, 0, 1].span(); + + return TreeEnsembleTrait::predict( + @X, + nodes_splits, + nodes_featureids, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + leaf_targetids, + leaf_weights, + tree_roots, + post_transform, + aggregate_function, + nodes_hitrates, + nodes_missing_value_tracks_true, + membership_values, + n_targets + ); +} + +>>> [[ 5.23 0. ] + [ 5.23 0. ] + [ 0. 12.12]] +``` diff --git a/docs/framework/operators/tensor/README.md b/docs/framework/operators/tensor/README.md index fe2995096..96fa266e6 100644 --- a/docs/framework/operators/tensor/README.md +++ b/docs/framework/operators/tensor/README.md @@ -132,6 +132,7 @@ use orion::operators::tensor::TensorTrait; | [`tensor.dynamic_quantize_linear`](tensor.dynamic\_quantize\_linear.md) | Computes the Scale, Zero Point and FP32->8Bit conversion of FP32 Input data. | | [`tensor.scatter_nd`](tensor.scatter\_nd.md) | The output of the operation is produced by creating a copy of the input data, and then updating its value to values specified by updates at specific index positions specified by indices. 
Its output shape is the same as the shape of data | | [`tensor.label_encoder`](tensor.label\_encoder.md) | Maps each element in the input tensor to another value. | +| [`tensor.mean`](tensor.mean.md) | Element-wise mean of each of the input tensors. | ## Arithmetic Operations diff --git a/docs/framework/operators/tensor/tensor.mean.md b/docs/framework/operators/tensor/tensor.mean.md new file mode 100644 index 000000000..e5fa8e940 --- /dev/null +++ b/docs/framework/operators/tensor/tensor.mean.md @@ -0,0 +1,52 @@ +# tensor.mean + +```rust + fn mean(args: Span>) -> Tensor; +``` + +Element-wise mean of each of the input tensors. + + +* `args`(`Span>`) - List of tensors for mean. + +## Returns + +Output tensor. + +## Examples + +```rust +use orion::operators::tensor::{FP8x23Tensor, FP8x23TensorAdd}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::numbers::{FixedTrait, FP8x23}; + + +fn example() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 16777216, sign: false }); + let tensor1 = TensorTrait::new(shape.span(), data.span()); + + let mut shape2 = ArrayTrait::::new(); + shape2.append(2); + shape2.append(2); + + let mut data2 = ArrayTrait::new(); + data2.append(FP8x23 { mag: 8388608, sign: false }); + data2.append(FP8x23 { mag: 0, sign: false }); + data2.append(FP8x23 { mag: 0, sign: false }); + data2.append(FP8x23 { mag: 8388608, sign: false }); + let tensor2 = TensorTrait::new(shape2.span(), data2.span()); + return TensorTrait::mean(array![tensor1, tensor2].span()); +} +>>> [FP8x23 { mag: 4194304, sign: false }, FP8x23 { mag: 8388608, sign: true }, FP8x23 { mag: 8388608, sign: false }, FP8x23 { mag: 12582912, sign: true }] +``` diff --git a/nodegen/node/mean.py b/nodegen/node/mean.py new file mode 100644 index 000000000..ad001b7ae --- /dev/null +++ b/nodegen/node/mean.py @@ -0,0 +1,94 @@ +import numpy as np +from nodegen.node import RunAll +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait, get_data_statement + +def mean(*args) -> np.ndarray: # type: ignore + res = args[0].copy() + for m in args[1:]: + res += m + return (res / len(args)).astype(args[0].dtype) + +class Mean(RunAll): + + @staticmethod + # We test here with fp8x23 implementation. + def fp8x23(): + x = np.random.randint(-3, 3, (2, 2)).astype(np.float64) + y = np.random.randint(-3, 3, (2, 2)).astype(np.float64) + z = mean(x, y) + + x = Tensor(Dtype.FP8x23, x.shape, to_fp( + x.flatten(), FixedImpl.FP8x23)) + y = Tensor(Dtype.FP8x23, y.shape, to_fp( + y.flatten(), FixedImpl.FP8x23)) + z = Tensor(Dtype.FP8x23, z.shape, to_fp( + z.flatten(), FixedImpl.FP8x23)) + + name = "mean_fp8x23" + make_test([x, y], z, "TensorTrait::mean(array![input_0, input_1].span())", name) + + @staticmethod + # We test here with fp16x16 implementation. 
+ def fp16x16(): + x = np.random.randint(-3, 3, (2, 2)).astype(np.float64) + y = np.random.randint(-3, 3, (2, 2)).astype(np.float64) + z = mean(x, y) + + x = Tensor(Dtype.FP16x16, x.shape, to_fp( + x.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + z = Tensor(Dtype.FP16x16, z.shape, to_fp( + z.flatten(), FixedImpl.FP16x16)) + + name = "mean_fp16x16" + make_test([x, y], z, "TensorTrait::mean(array![input_0, input_1].span())", name) + + @staticmethod + # We test here with i8 implementation. + def i8(): + x = np.random.randint(0, 6, (2, 2)).astype(np.int8) + y = np.random.randint(0, 6, (2, 2)).astype(np.int8) + z = np.random.randint(0, 6, (2, 2)).astype(np.int8) + m = mean(x, y, z) + + x = Tensor(Dtype.I8, x.shape, x.flatten()) + y = Tensor(Dtype.I8, y.shape, y.flatten()) + z = Tensor(Dtype.I8, z.shape, z.flatten()) + m = Tensor(Dtype.I8, m.shape, m.flatten()) + + name = "mean_i8" + make_test([x, y, z], m, "TensorTrait::mean(array![input_0, input_1, input_2].span())", name) + + @staticmethod + # We test here with i32 implementation. + def i32(): + x = np.random.randint(0, 6, (2, 2)).astype(np.int32) + y = np.random.randint(0, 6, (2, 2)).astype(np.int32) + z = np.random.randint(0, 6, (2, 2)).astype(np.int32) + m = mean(x, y, z) + + x = Tensor(Dtype.I32, x.shape, x.flatten()) + y = Tensor(Dtype.I32, y.shape, y.flatten()) + z = Tensor(Dtype.I32, z.shape, z.flatten()) + m = Tensor(Dtype.I32, m.shape, m.flatten()) + + name = "mean_i32" + make_test([x, y, z], m, "TensorTrait::mean(array![input_0, input_1, input_2].span())", name) + + @staticmethod + # We test here with u32 implementation. + def u32(): + x = np.random.randint(0, 6, (2, 2)).astype(np.uint32) + y = np.random.randint(0, 6, (2, 2)).astype(np.uint32) + z = np.random.randint(0, 6, (2, 2)).astype(np.uint32) + m = mean(x, y, z) + + x = Tensor(Dtype.U32, x.shape, x.flatten()) + y = Tensor(Dtype.U32, y.shape, y.flatten()) + z = Tensor(Dtype.U32, z.shape, z.flatten()) + m = Tensor(Dtype.U32, m.shape, m.flatten()) + + name = "mean_u32" + make_test([x, y, z], m, "TensorTrait::mean(array![input_0, input_1, input_2].span())", name) + \ No newline at end of file diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index 08e9e40fb..93a0394a6 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -3,6 +3,8 @@ mod linear; mod svm; mod normalizer; +use orion::operators::ml::tree_ensemble::tree_ensemble::{TreeEnsembleTrait}; + use orion::operators::ml::tree_ensemble::core::{ TreeEnsemble, TreeEnsembleAttributes, TreeEnsembleImpl, NODE_MODES }; @@ -32,3 +34,4 @@ enum POST_TRANSFORM { SOFTMAXZERO, PROBIT, } + diff --git a/src/operators/ml/tree_ensemble.cairo b/src/operators/ml/tree_ensemble.cairo index 32c96c0bd..925c1ea7e 100644 --- a/src/operators/ml/tree_ensemble.cairo +++ b/src/operators/ml/tree_ensemble.cairo @@ -1,3 +1,4 @@ mod core; mod tree_ensemble_classifier; mod tree_ensemble_regressor; +mod tree_ensemble; diff --git a/src/operators/ml/tree_ensemble/tree_ensemble.cairo b/src/operators/ml/tree_ensemble/tree_ensemble.cairo new file mode 100644 index 000000000..51e3f9ec2 --- /dev/null +++ b/src/operators/ml/tree_ensemble/tree_ensemble.cairo @@ -0,0 +1,602 @@ +use orion::operators::tensor::{Tensor, TensorTrait}; +use orion::numbers::NumberTrait; + +use orion::operators::matrix::{MutMatrix, MutMatrixImpl, MutMatrixTrait}; + +#[derive(Copy, Drop)] +enum AGGREGATE_FUNCTION { + AVERAGE, + SUM, + MIN, + MAX, +} + +#[derive(Copy, Drop)] +enum POST_TRANSFORM { + NONE, + 
SOFTMAX, + LOGISTIC, + SOFTMAX_ZERO, + PROBIT, +} + +#[derive(Copy, Drop)] +enum NODE_MODE { + LEQ, + LT, + GTE, + GT, + EQ, + NEQ, + MEMBER, +} + +/// Trait +/// +/// predict - Returns the regressed values for each input in a batch. +trait TreeEnsembleTrait { + /// # TreeEnsemble::predict + /// + /// ```rust + /// fn predict(X: @Tensor, + /// nodes_splits: Tensor, + /// nodes_featureids: Span, + /// nodes_modes: Span, + /// nodes_truenodeids: Span, + /// nodes_falsenodeids: Span, + /// nodes_trueleafs: Span, + /// nodes_falseleafs: Span, + /// leaf_targetids: Span, + /// leaf_weights: Tensor, + /// tree_roots: Span, + /// post_transform: POST_TRANSFORM, + /// aggregate_function: AGGREGATE_FUNCTION, + /// nodes_hitrates: Option>, + /// nodes_missing_value_tracks_true: Option>, + /// membership_values: Option>, + /// n_targets: usize + /// ) -> MutMatrix::; + /// ``` + /// + /// Tree Ensemble operator. Returns the regressed values for each input in a batch. Inputs have dimensions [N, F] where N is the input batch size and F is the number of input features. Outputs have dimensions [N, num_targets] where N is the batch size and num_targets is the number of targets, which is a configurable attribute. + /// + /// ## Args + /// + /// * `X`: Input 2D tensor. + /// * `nodes_splits`: Thresholds to do the splitting on for each node with mode that is not 'BRANCH_MEMBER'. + /// * `nodes_featureids`: Feature id for each node. + /// * `nodes_modes`: The comparison operation performed by the node. This is encoded as an enumeration of 'NODE_MODE::LEQ', 'NODE_MODE::LT', 'NODE_MODE::GTE', 'NODE_MODE::GT', 'NODE_MODE::EQ', 'NODE_MODE::NEQ', and 'NODE_MODE::MEMBER' + /// * `nodes_truenodeids`: If `nodes_trueleafs` is 0 (false) at an entry, this represents the position of the true branch node. + /// * `nodes_falsenodeids`: If `nodes_falseleafs` is 0 (false) at an entry, this represents the position of the false branch node. + /// * `nodes_trueleafs`: 1 if true branch is leaf for each node and 0 an interior node. + /// * `nodes_falseleafs`: 1 if true branch is leaf for each node and 0 an interior node. + /// * `leaf_targetids`: The index of the target that this leaf contributes to (this must be in range `[0, n_targets)`). + /// * `leaf_weights`: The weight for each leaf. + /// * `tree_roots`: Index into `nodes_*` for the root of each tree. The tree structure is derived from the branching of each node. + /// * `post_transform`: Indicates the transform to apply to the score.One of 'POST_TRANSFORM::NONE', 'POST_TRANSFORM::SOFTMAX', 'POST_TRANSFORM::LOGISTIC', 'POST_TRANSFORM::SOFTMAX_ZERO' or 'POST_TRANSFORM::PROBIT' , + /// * `aggregate_function`: Defines how to aggregate leaf values within a target. One of 'AGGREGATE_FUNCTION::AVERAGE', 'AGGREGATE_FUNCTION::SUM', 'AGGREGATE_FUNCTION::MIN', 'AGGREGATE_FUNCTION::MAX` defaults to 'AGGREGATE_FUNCTION::SUM' + /// * `nodes_hitrates`: Popularity of each node, used for performance and may be omitted. + /// * `nodes_missing_value_tracks_true`: For each node, define whether to follow the true branch (if attribute value is 1) or false branch (if attribute value is 0) in the presence of a NaN input feature. This attribute may be left undefined and the default value is false (0) for all nodes. + /// * `membership_values`: Members to test membership of for each set membership node. List all of the members to test again in the order that the 'BRANCH_MEMBER' mode appears in `node_modes`, delimited by `NaN`s. 
Will have the same number of sets of values as nodes with mode 'BRANCH_MEMBER'. This may be omitted if the node doesn't contain any 'BRANCH_MEMBER' nodes. + /// * `n_targets`: The total number of targets. + + /// + /// ## Returns + /// + /// * Output of shape [Batch Size, Number of targets] + /// + /// ## Type Constraints + /// + /// `T` must be fixed point + /// + /// ## Examples + /// + /// ```rust + /// use orion::numbers::FP16x16; + /// use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor}; + /// use orion::operators::ml::{TreeEnsembleTrait,POST_TRANSFORM, AGGREGATE_FUNCTION, NODE_MODE}; + /// use orion::operators::matrix::{MutMatrix, MutMatrixImpl}; + /// use orion::numbers::NumberTrait; + /// + /// fn example_tree_ensemble_one_tree() -> MutMatrix:: { + /// let mut shape = ArrayTrait::::new(); + /// shape.append(3); + /// shape.append(2); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 78643, sign: false }); + /// data.append(FP16x16 { mag: 222822, sign: false }); + /// data.append(FP16x16 { mag: 7864, sign: true }); + /// data.append(FP16x16 { mag: 108789, sign: false }); + /// data.append(FP16x16 { mag: 271319, sign: false }); + /// data.append(FP16x16 { mag: 115998, sign: false }); + /// let mut X = TensorTrait::new(shape.span(), data.span()); + /// + /// let mut shape = ArrayTrait::::new(); + /// shape.append(4); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 342753, sign: false }); + /// data.append(FP16x16 { mag: 794296, sign: false }); + /// data.append(FP16x16 { mag: 801505, sign: true }); + /// data.append(FP16x16 { mag: 472514, sign: false }); + /// let leaf_weights = TensorTrait::new(shape.span(), data.span()); + /// + /// let mut shape = ArrayTrait::::new(); + /// shape.append(3); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 205783, sign: false }); + /// data.append(FP16x16 { mag: 78643, sign: false }); + /// data.append(FP16x16 { mag: 275251, sign: false }); + /// let nodes_splits = TensorTrait::new(shape.span(), data.span()); + /// + /// let membership_values = Option::None; + /// + /// let n_targets = 2; + /// let aggregate_function = AGGREGATE_FUNCTION::SUM; + /// let nodes_missing_value_tracks_true = Option::None; + /// let nodes_hitrates = Option::None; + /// let post_transform = POST_TRANSFORM::NONE; + /// + /// let tree_roots: Span = array![0].span(); + /// let nodes_modes: Span = array![MODE::LEQ, MODE::LEQ, MODE::LEQ].span(); + /// + /// let nodes_featureids: Span = array![0, 0, 0].span(); + /// let nodes_truenodeids: Span = array![1, 0, 1].span(); + /// let nodes_trueleafs: Span = array![0, 1, 1].span(); + /// let nodes_falsenodeids: Span = array![2, 2, 3].span(); + /// let nodes_falseleafs: Span = array![0, 1, 1].span(); + /// let leaf_targetids: Span = array![0, 1, 0, 1].span(); + /// + /// return TreeEnsembleTrait::predict( + /// @X, + /// nodes_splits, + /// nodes_featureids, + /// nodes_modes, + /// nodes_truenodeids, + /// nodes_falsenodeids, + /// nodes_trueleafs, + /// nodes_falseleafs, + /// leaf_targetids, + /// leaf_weights, + /// tree_roots, + /// post_transform, + /// aggregate_function, + /// nodes_hitrates, + /// nodes_missing_value_tracks_true, + /// membership_values, + /// n_targets + /// ); + /// } + /// + /// >>> [[ 5.23 0. ] + /// [ 5.23 0. ] + /// [ 0. 
12.12]] + /// ``` + /// + fn predict( + X: @Tensor, + nodes_splits: Tensor, + nodes_featureids: Span, + nodes_modes: Span, + nodes_truenodeids: Span, + nodes_falsenodeids: Span, + nodes_trueleafs: Span, + nodes_falseleafs: Span, + leaf_targetids: Span, + leaf_weights: Tensor, + tree_roots: Span, + post_transform: POST_TRANSFORM, + aggregate_function: AGGREGATE_FUNCTION, + nodes_hitrates: Option>, + nodes_missing_value_tracks_true: Option>, + membership_values: Option>, + n_targets: usize + ) -> MutMatrix::; +} + + +impl TreeEnsembleImpl< + T, + MAG, + +TensorTrait, + +NumberTrait, + +Copy, + +Drop, + +PartialOrd, + +PartialEq, + +Add, + +Div, + +Mul, + +Into, + +AddEq, +> of TreeEnsembleTrait { + fn predict( + X: @Tensor, + nodes_splits: Tensor, + nodes_featureids: Span, + nodes_modes: Span, + nodes_truenodeids: Span, + nodes_falsenodeids: Span, + nodes_trueleafs: Span, + nodes_falseleafs: Span, + leaf_targetids: Span, + leaf_weights: Tensor, + tree_roots: Span, + post_transform: POST_TRANSFORM, + aggregate_function: AGGREGATE_FUNCTION, + nodes_hitrates: Option>, + nodes_missing_value_tracks_true: Option>, + membership_values: Option>, + n_targets: usize + ) -> MutMatrix:: { + let batch_size = *(*X).shape.at(0); + let n_features = *(*X).shape.at(1); + let n_trees = tree_roots.len(); + + let mut set_membership_iter = array![].span(); + let mut map_member_to_nodeid = Default::default(); + + let mut res: MutMatrix = MutMatrixImpl::new(batch_size, n_targets); + + let (nodes_missing_value_tracks_true, nodes_missing_value_tracks_true_flag) = + match nodes_missing_value_tracks_true { + Option::Some(nodes_missing_value_tracks_true) => { + (nodes_missing_value_tracks_true, true) + }, + Option::None => { (array![].span(), false) } + }; + + match membership_values { + Option::Some(membership_values) => { + set_membership_iter = membership_values.data.clone(); + + let mut tree_roots_iter = tree_roots.clone(); + loop { + match tree_roots_iter.pop_front() { + Option::Some(root_index) => { + let root_index = *root_index; + let is_leaf = (*nodes_trueleafs.at(root_index) == 1 + && *nodes_falseleafs.at(root_index) == 1 + && *nodes_truenodeids + .at(root_index) == *nodes_falsenodeids + .at(root_index)); + map_members_to_nodeids( + root_index, + is_leaf, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + ref set_membership_iter, + ref map_member_to_nodeid, + ); + }, + Option::None => { break; } + } + }; + }, + Option::None => {} + }; + + match aggregate_function { + AGGREGATE_FUNCTION::AVERAGE => { res.set(batch_size, n_targets, NumberTrait::zero()); }, + AGGREGATE_FUNCTION::SUM => { res.set(batch_size, n_targets, NumberTrait::zero()); }, + AGGREGATE_FUNCTION::MIN => { + let mut i = 0; + while i != batch_size { + let mut j = 0; + while j != n_targets { + res.set(i, j, NumberTrait::min_value()); + j += 1; + }; + i += 1; + }; + }, + AGGREGATE_FUNCTION::MAX => { + let mut i = 0; + while i != batch_size { + let mut j = 0; + while j != n_targets { + res.set(i, j, NumberTrait::max_value()); + j += 1; + }; + i += 1; + }; + }, + } + + let mut weights = ArrayTrait::new(); + let mut target_ids = ArrayTrait::new(); + + let mut tree_roots_iter = tree_roots.clone(); + loop { + match tree_roots_iter.pop_front() { + Option::Some(root_index) => { + let root_index = *root_index; + let is_leaf = (*nodes_trueleafs.at(root_index) == 1 + && *nodes_falseleafs.at(root_index) == 1 + && *nodes_truenodeids.at(root_index) == *nodes_falsenodeids.at(root_index)); + + let mut batch_num = 0; 
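+                    // For this tree root, route every row of X through the tree and record the
+                    // (weight, target id) pair reached at its leaf; aggregation across trees happens further below.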
+ while batch_num != batch_size { + let x_batch = SpanTrait::slice( + (*X).data, batch_num * n_features, n_features + ); + + let (weight, target) = iterate_node( + x_batch, + root_index, + is_leaf, + nodes_splits.data, + nodes_featureids, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + leaf_targetids, + leaf_weights, + nodes_hitrates, + nodes_missing_value_tracks_true, + nodes_missing_value_tracks_true_flag, + ref map_member_to_nodeid, + ); + weights.append(weight); + target_ids.append(target); + batch_num += 1; + }; + }, + Option::None => { break; } + } + }; + + let weights = weights.span(); + let target_ids = target_ids.span(); + + let mut batch_num = 0; + while batch_num != batch_size { + match aggregate_function { + AGGREGATE_FUNCTION::AVERAGE => { + let mut i = 0; + while i != n_trees { + let index = i * batch_size + batch_num; + res + .set( + batch_num, + *target_ids.at(index), + res.at(batch_num, *target_ids.at(index)) + + *weights.at(index) + / NumberTrait::new_unscaled(n_trees.into(), false) + ); + i += 1; + }; + }, + AGGREGATE_FUNCTION::SUM => { + let mut i = 0; + while i != n_trees { + let index = i * batch_size + batch_num; + res + .set( + batch_num, + *target_ids.at(index), + res.at(batch_num, *target_ids.at(index)) + *weights.at(index) + ); + i += 1; + }; + }, + AGGREGATE_FUNCTION::MIN => { + let mut i = 0; + while i != n_targets { + let val = NumberTrait::min( + res.at(batch_num, *target_ids.at(batch_num)), *weights.at(batch_num) + ); + res.set(batch_num, *target_ids.at(batch_num), val); + i += 1; + }; + }, + AGGREGATE_FUNCTION::MAX => { + let mut i = 0; + while i != n_targets { + let val = NumberTrait::max( + res.at(batch_num, *target_ids.at(batch_num)), *weights.at(batch_num) + ); + res.set(batch_num, *target_ids.at(batch_num), val); + i += 1; + }; + } + } + + batch_num += 1; + }; + + // Post Transform + let mut res = match post_transform { + POST_TRANSFORM::NONE => res, + POST_TRANSFORM::SOFTMAX => res.softmax(1), + POST_TRANSFORM::LOGISTIC => res.sigmoid(), + POST_TRANSFORM::SOFTMAX_ZERO => res.softmax_zero(1), + POST_TRANSFORM::PROBIT => core::panic_with_felt252('Probit not supported yet'), + }; + + return res; + } +} +fn iterate_node< + T, + MAG, + +TensorTrait, + +NumberTrait, + +Copy, + +Drop, + +PartialOrd, + +PartialEq, +>( + X: Span, + current_node_index: usize, + is_leaf: bool, + nodes_splits: Span, + nodes_featureids: Span, + nodes_modes: Span, + nodes_truenodeids: Span, + nodes_falsenodeids: Span, + nodes_trueleafs: Span, + nodes_falseleafs: Span, + leaf_targetids: Span, + leaf_weights: Tensor, + nodes_hitrates: Option>, + nodes_missing_value_tracks_true: Span, + nodes_missing_value_tracks_true_flag: bool, + ref map_member_to_nodeid: Felt252Dict>>, +) -> (T, usize) { + let mut current_node_index = current_node_index; + let mut is_leaf = is_leaf; + + while !is_leaf { + let nmvtt_flag = if nodes_missing_value_tracks_true_flag { + *nodes_missing_value_tracks_true.at(current_node_index) == 1 + } else { + nodes_missing_value_tracks_true_flag + }; + if compare( + *X.at(*nodes_featureids.at(current_node_index)), + current_node_index, + *nodes_splits.at(current_node_index), + *nodes_modes.at(current_node_index), + ref map_member_to_nodeid, + nmvtt_flag + ) { + is_leaf = *nodes_trueleafs.at(current_node_index) == 1; + current_node_index = *nodes_truenodeids.at(current_node_index); + } else { + is_leaf = *nodes_falseleafs.at(current_node_index) == 1; + current_node_index = *nodes_falsenodeids.at(current_node_index); + }; 
+ }; + + return (*leaf_weights.data.at(current_node_index), *leaf_targetids.at(current_node_index)); +} + +fn map_members_to_nodeids< + T, + MAG, + +TensorTrait, + +NumberTrait, + +Copy, + +Drop, + +PartialOrd, + +PartialEq, +>( + current_node_index: usize, + is_leaf: bool, + nodes_modes: Span, + nodes_truenodeids: Span, + nodes_falsenodeids: Span, + nodes_trueleafs: Span, + nodes_falseleafs: Span, + ref set_membership_iter: Span, + ref map_member_to_nodeid: Felt252Dict>>, +) { + let mut current_node_index = current_node_index; + let mut is_leaf = is_leaf; + + if is_leaf { + return; + } + + match *nodes_modes.at(current_node_index) { + NODE_MODE::LEQ => {}, + NODE_MODE::LT => {}, + NODE_MODE::GTE => {}, + NODE_MODE::GT => {}, + NODE_MODE::EQ => {}, + NODE_MODE::NEQ => {}, + NODE_MODE::MEMBER => { + let mut subset_members = ArrayTrait::new(); + loop { + match set_membership_iter.pop_front() { + Option::Some(v) => { + if *v == NumberTrait::NaN() { + break; + } + subset_members.append(*v) + }, + Option::None => { break; } + } + }; + map_member_to_nodeid + .insert(current_node_index.into(), NullableTrait::new(subset_members.span())); + }, + } + // true branch + map_members_to_nodeids( + *nodes_truenodeids.at(current_node_index), + *nodes_trueleafs.at(current_node_index) == 1, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + ref set_membership_iter, + ref map_member_to_nodeid, + ); + + // false branch + map_members_to_nodeids( + *nodes_falsenodeids.at(current_node_index), + *nodes_falseleafs.at(current_node_index) == 1, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + ref set_membership_iter, + ref map_member_to_nodeid, + ); +} + + +fn compare< + T, MAG, +TensorTrait, +NumberTrait, +Copy, +Drop, +PartialOrd, +PartialEq +>( + x_feat: T, + current_node_index: usize, + value: T, + mode: NODE_MODE, + ref map_members_to_nodeids: Felt252Dict>>, + nodes_missing_value_tracks_true_flag: bool, +) -> bool { + match mode { + NODE_MODE::LEQ => { + (x_feat <= value && !x_feat.is_nan()) || nodes_missing_value_tracks_true_flag + }, + NODE_MODE::LT => { + (x_feat < value && !x_feat.is_nan()) || nodes_missing_value_tracks_true_flag + }, + NODE_MODE::GTE => { + (x_feat >= value && !x_feat.is_nan()) || nodes_missing_value_tracks_true_flag + }, + NODE_MODE::GT => { + (x_feat > value && !x_feat.is_nan()) || nodes_missing_value_tracks_true_flag + }, + NODE_MODE::EQ => { + (x_feat == value && !x_feat.is_nan()) || nodes_missing_value_tracks_true_flag + }, + NODE_MODE::NEQ => { + (x_feat != value && !x_feat.is_nan()) || nodes_missing_value_tracks_true_flag + }, + NODE_MODE::MEMBER => { + let mut set_members = map_members_to_nodeids.get(current_node_index.into()).deref(); + loop { + match set_members.pop_front() { + Option::Some(v) => { if x_feat == *v { + break true; + } }, + Option::None => { break false; } + } + } + }, + } +} diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 02f9cc6e4..bccd350f0 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -131,6 +131,7 @@ impl TensorSerde, impl TDrop: Drop> of Serde8Bit conversion of FP32 Input data. /// scatter_nd - The output of the operation is produced by creating a copy of the input data, and then updating its value to values specified by updates at specific index positions specified by indices. 
Its output shape is the same as the shape of data /// label_encoder - Maps each element in the input tensor to another value. +/// mean - Element-wise mean of each of the input tensors. trait TensorTrait { /// # tensor.new /// @@ -5858,6 +5859,60 @@ trait TensorTrait { values: Option>, values_tensor: Option> ) -> Tensor; + /// # tensor.mean + /// + /// ```rust + /// fn mean(args: Span>) -> Tensor; + /// ``` + /// + /// Element-wise mean of each of the input tensors. + /// + /// + /// * `args`(`Span>`) - List of tensors for mean. + /// + /// ## Returns + /// + /// Output tensor. + /// + /// ## Examples + /// + /// ```rust + /// use orion::operators::tensor::{FP8x23Tensor, FP8x23TensorAdd}; + /// use core::array::{ArrayTrait, SpanTrait}; + /// use orion::operators::tensor::{TensorTrait, Tensor}; + /// use orion::utils::{assert_eq, assert_seq_eq}; + /// use orion::operators::tensor::FP8x23TensorPartialEq; + /// use orion::numbers::{FixedTrait, FP8x23}; + /// + /// + /// fn example() -> Tensor { + /// let mut shape = ArrayTrait::::new(); + /// shape.append(2); + /// shape.append(2); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FP8x23 { mag: 16777216, sign: true }); + /// data.append(FP8x23 { mag: 16777216, sign: false }); + /// data.append(FP8x23 { mag: 16777216, sign: true }); + /// data.append(FP8x23 { mag: 16777216, sign: false }); + /// let tensor1 = TensorTrait::new(shape.span(), data.span()); + /// + /// let mut shape2 = ArrayTrait::::new(); + /// shape2.append(2); + /// shape2.append(2); + /// + /// let mut data2 = ArrayTrait::new(); + /// data2.append(FP8x23 { mag: 8388608, sign: false }); + /// data2.append(FP8x23 { mag: 0, sign: false }); + /// data2.append(FP8x23 { mag: 0, sign: false }); + /// data2.append(FP8x23 { mag: 8388608, sign: false }); + /// let tensor2 = TensorTrait::new(shape2.span(), data2.span()); + /// return TensorTrait::mean(array![tensor1, tensor2].span()); + /// } + /// >>> [FP8x23 { mag: 4194304, sign: false }, FP8x23 { mag: 8388608, sign: true }, FP8x23 { mag: 8388608, sign: false }, FP8x23 { mag: 12582912, sign: true }] + /// ``` + /// + fn mean(args: Span>) -> Tensor; } /// Cf: TensorTrait::new docstring diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index e8ca7e2d8..87eae7c14 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -552,6 +552,10 @@ impl BoolTensor of TensorTrait { ) -> Tensor { panic(array!['not supported!']) } + + fn mean(args: Span>) -> Tensor { + panic(array!['not supported!']) + } } /// Implements partial equal for two `Tensor` using the `PartialEq` trait. diff --git a/src/operators/tensor/implementations/tensor_complex64.cairo b/src/operators/tensor/implementations/tensor_complex64.cairo index 8acb0891e..4411a7306 100644 --- a/src/operators/tensor/implementations/tensor_complex64.cairo +++ b/src/operators/tensor/implementations/tensor_complex64.cairo @@ -592,6 +592,10 @@ impl Complex64Tensor of TensorTrait { ) -> Tensor { panic(array!['not supported!']) } + + fn mean(args: Span>) -> Tensor { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. 
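The bool and complex64 implementations above simply panic, while every fixed-point and integer implementation that follows forwards to `math::mean::mean`. As a hedged reference for the shared behavior (this mirrors `nodegen/node/mean.py` above, not the Cairo code itself): sum the inputs element-wise, divide by their count, and stay in the input dtype, so the integer tensors (`i8`, `i32`, `u32`) truncate the result.

```python
import numpy as np

def mean(*args):
    res = args[0].copy()
    for m in args[1:]:
        res += m
    # dividing and casting back keeps the input dtype, so integer inputs truncate
    return (res / len(args)).astype(args[0].dtype)

print(mean(np.array([1, 2, 4], dtype=np.int32),
           np.array([2, 3, 5], dtype=np.int32)))  # -> [1 2 4]
```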
diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index 27f853df5..4de8936c9 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -644,6 +644,10 @@ impl FP16x16Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn mean(args: Span>) -> Tensor { + math::mean::mean(args) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index 61485bae6..17307a3f3 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -604,6 +604,10 @@ impl FP16x16WTensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn mean(args: Span>) -> Tensor { + math::mean::mean(args) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 6ea3c7d94..c721c35ca 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -640,6 +640,10 @@ impl FP32x32Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn mean(args: Span>) -> Tensor { + math::mean::mean(args) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index af955fff1..c3228ccc2 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -640,6 +640,10 @@ impl FP64x64Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn mean(args: Span>) -> Tensor { + math::mean::mean(args) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 19681e641..900f75092 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -636,6 +636,10 @@ impl FP8x23Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn mean(args: Span>) -> Tensor { + math::mean::mean(args) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index ef65871d4..57bba428e 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -581,6 +581,10 @@ impl FP8x23WTensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn mean(args: Span>) -> Tensor { + math::mean::mean(args) + } } /// Implements addition for `Tensor` using the `Add` trait. 
diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index 924a6b1fd..43d950762 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -604,6 +604,10 @@ impl I32Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn mean(args: Span>) -> Tensor { + math::mean::mean(args) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index f523c47b2..2c67bc611 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -607,6 +607,10 @@ impl I8Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn mean(args: Span>) -> Tensor { + math::mean::mean(args) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 7aa2ade26..353c60a18 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -551,6 +551,10 @@ impl U32Tensor of TensorTrait { self, default_list, default_tensor, keys, keys_tensor, values, values_tensor ) } + + fn mean(args: Span>) -> Tensor { + math::mean::mean(args) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/math.cairo b/src/operators/tensor/math.cairo index b73f6d102..03f846e07 100644 --- a/src/operators/tensor/math.cairo +++ b/src/operators/tensor/math.cairo @@ -68,3 +68,4 @@ mod hann_window; mod hamming_window; mod blackman_window; mod scatter_nd; +mod mean; diff --git a/src/operators/tensor/math/gather_elements.cairo b/src/operators/tensor/math/gather_elements.cairo index cc8b9ae20..e4b624e42 100644 --- a/src/operators/tensor/math/gather_elements.cairo +++ b/src/operators/tensor/math/gather_elements.cairo @@ -1,7 +1,9 @@ +use core::option::OptionTrait; +use core::traits::TryInto; use alexandria_data_structures::array_ext::SpanTraitExt; use orion::numbers::NumberTrait; -use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::{{TensorTrait, Tensor}, core::{unravel_index, stride}}; /// Cf: TensorTrait::gather_elements docstring fn gather_elements, impl TCopy: Copy, impl TDrop: Drop,>( @@ -19,71 +21,45 @@ fn gather_elements, impl TCopy: Copy, im }; assert(axis < (*self.shape).len(), 'axis out of dimensions'); - let axis_shape = *(*self.shape).at(axis); - - // Adjust indices that are negative - let mut adjusted_indices = array![]; - let mut indices_data = indices.data.clone(); - loop { - match indices_data.pop_front() { - Option::Some(index) => { - let adjusted_index: usize = if *index < 0 { - let val: u32 = (axis_shape.try_into().unwrap() + *index).try_into().unwrap(); - val - } else { - let val: u32 = (*index).try_into().unwrap(); - val - }; - assert(adjusted_index >= 0 && adjusted_index < axis_shape, 'Index out of bounds'); - adjusted_indices.append(adjusted_index); - }, - Option::None => { break; } - }; - }; + let data_strides = stride(*self.shape); let mut output_data = array![]; - let mut data_shape_clone = (*self.shape).clone(); - let mut multiplier = 1; - let mut looper = 1; - let mut ind = 0; - loop { - match data_shape_clone.pop_front() { - 
Option::Some(val) => { - if ind >= axis { - multiplier *= *val; - } - if ind > axis { - looper *= *val; - } - ind += 1; - }, - Option::None => { break; } - }; - }; + let mut i: usize = 0; + while i < indices + .data + .len() { + let indice = *indices.data.at(i); + let adjusted_indice: u32 = if indice < 0 { + ((*(*self.shape).at(axis)).try_into().unwrap() + indice).try_into().unwrap() + } else { + indice.try_into().unwrap() + }; - let inner_loop = multiplier / axis_shape; - let mut adjusted_indices_iter = adjusted_indices.clone(); + assert(adjusted_indice < (*(*self.shape).at(axis)), 'Index out of bounds'); - let mut i: usize = 0; - loop { - match adjusted_indices_iter.pop_front() { - Option::Some(indice) => { - let value = if axis == 0 { - indice * inner_loop + (i % inner_loop) - } else if axis == (*self.shape).len() - 1 { - indice + axis_shape * (i / axis_shape) - } else { - indice * looper - + (i % looper) - + (multiplier / axis_shape) * (i / (multiplier / axis_shape)) + let multidim_index = unravel_index(i, indices.shape); + let mut flat_index_for_data = 0; + + let mut j: usize = 0; + while j < multidim_index + .len() { + let dim_index = *multidim_index.at(j); + if j == axis { + flat_index_for_data += adjusted_indice * (*data_strides.at(j)); + } else { + flat_index_for_data += (dim_index * *data_strides.at(j)) + } + j += 1; }; - output_data.append(*self.data[value]); - i += 1; - }, - Option::None => { break; } + assert( + flat_index_for_data < (*self.data).len().try_into().unwrap(), + 'Flat index out of bounds' + ); + + output_data.append(*(*self.data).at(flat_index_for_data)); + i += 1; }; - }; TensorTrait::::new(indices.shape, output_data.span()) } diff --git a/src/operators/tensor/math/mean.cairo b/src/operators/tensor/math/mean.cairo new file mode 100644 index 000000000..f4e56ed53 --- /dev/null +++ b/src/operators/tensor/math/mean.cairo @@ -0,0 +1,44 @@ +use orion::numbers::fixed_point::core::FixedTrait; +use orion::numbers::NumberTrait; +use orion::operators::tensor::core::{Tensor, TensorTrait}; + + +fn mean< + T, + MAG, + impl TTensor: TensorTrait, + impl TNumber: NumberTrait, + impl TAdd: Add, + impl TSub: Sub, + impl TMul: Mul, + impl TDiv: Div, + impl TTensorAdd: Add>, + impl TPartialOrd: PartialOrd, + impl TAddEq: AddEq, + impl TCopy: Copy, + impl TDrop: Drop, +>( + args: Span> +) -> Tensor { + let len = args.len(); + let mut i: usize = 1; + let mut t = *args.at(0); + let mut len_t: T = NumberTrait::one(); + while i != len { + let v = *args.at(i); + t = t + v; + len_t += NumberTrait::one(); + i += 1; + }; + + let mut arr: Array = array![]; + let count = (t.data).len(); + i = 0; + while i != count { + let v = *(t.data).at(i); + let r = v / len_t; + arr.append(r); + i += 1; + }; + TensorTrait::::new(t.shape, arr.span()) +} diff --git a/tests/lib.cairo b/tests/lib.cairo index f5cecb77d..c408347ef 100644 --- a/tests/lib.cairo +++ b/tests/lib.cairo @@ -5,3 +5,4 @@ mod nodes; mod ml; mod operators; + diff --git a/tests/ml.cairo b/tests/ml.cairo index 4e3e0781e..b92dbcd83 100644 --- a/tests/ml.cairo +++ b/tests/ml.cairo @@ -5,3 +5,4 @@ mod linear_classifier_test; mod svm_regressor_test; mod svm_classifier_test; mod normalizer_test; +mod tree_ensemble_test; diff --git a/tests/ml/tree_ensemble_test.cairo b/tests/ml/tree_ensemble_test.cairo new file mode 100644 index 000000000..59a5592f6 --- /dev/null +++ b/tests/ml/tree_ensemble_test.cairo @@ -0,0 +1,300 @@ +use orion::numbers::FP16x16; +use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor}; +use 
orion::operators::ml::tree_ensemble::tree_ensemble::{ + TreeEnsembleTrait, POST_TRANSFORM, AGGREGATE_FUNCTION, NODE_MODE +}; +use orion::operators::matrix::{MutMatrix, MutMatrixImpl, MutMatrixTrait}; +use orion::numbers::NumberTrait; + + +#[test] +#[available_gas(200000000000)] +fn export_tree_ensemble_two_trees() { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 32768, sign: true }); + data.append(FP16x16 { mag: 26214, sign: true }); + data.append(FP16x16 { mag: 19660, sign: true }); + data.append(FP16x16 { mag: 13107, sign: true }); + data.append(FP16x16 { mag: 6553, sign: true }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 6553, sign: false }); + data.append(FP16x16 { mag: 13107, sign: false }); + data.append(FP16x16 { mag: 19660, sign: false }); + let mut X = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 5041, sign: false }); + data.append(FP16x16 { mag: 32768, sign: false }); + data.append(FP16x16 { mag: 32768, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 18724, sign: false }); + data.append(FP16x16 { mag: 32768, sign: false }); + let leaf_weights = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 17462, sign: false }); + data.append(FP16x16 { mag: 40726, sign: false }); + data.append(FP16x16 { mag: 36652, sign: true }); + data.append(FP16x16 { mag: 47240, sign: true }); + let nodes_splits = TensorTrait::new(shape.span(), data.span()); + + let n_targets = 1; + let aggregate_function = AGGREGATE_FUNCTION::AVERAGE; + let nodes_missing_value_tracks_true = Option::None; + let nodes_hitrates = Option::None; + let post_transform = POST_TRANSFORM::NONE; + + let tree_roots: Span = array![0, 2].span(); + let nodes_modes: Span = array![ + NODE_MODE::LEQ, NODE_MODE::LEQ, NODE_MODE::LEQ, NODE_MODE::LEQ + ] + .span(); + + let nodes_featureids: Span = array![0, 2, 0, 0].span(); + let nodes_truenodeids: Span = array![1, 0, 3, 4].span(); + let nodes_trueleafs: Span = array![0, 1, 1, 1].span(); + let nodes_falsenodeids: Span = array![2, 1, 3, 5].span(); + let nodes_falseleafs: Span = array![1, 1, 0, 1].span(); + let leaf_targetids: Span = array![0, 0, 0, 0, 0, 0].span(); + + let mut scores = TreeEnsembleTrait::predict( + @X, + nodes_splits, + nodes_featureids, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + leaf_targetids, + leaf_weights, + tree_roots, + post_transform, + aggregate_function, + nodes_hitrates, + nodes_missing_value_tracks_true, + Option::None, + n_targets + ); + + // ASSERT SCOREs + assert(scores.at(0, 0) == FP16x16 { mag: 18904, sign: false }, 'scores.at(0, 0)'); + assert(scores.at(1, 0) == FP16x16 { mag: 18904, sign: false }, 'scores.at(1, 0)'); + assert(scores.at(2, 0) == FP16x16 { mag: 18904, sign: false }, 'scores.at(2, 0)'); +} + + +#[test] +#[available_gas(200000000000)] +fn export_tree_ensemble_one_tree() { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 222822, sign: false }); + data.append(FP16x16 { mag: 7864, sign: true }); + data.append(FP16x16 { mag: 108789, 
sign: false }); + data.append(FP16x16 { mag: 271319, sign: false }); + data.append(FP16x16 { mag: 115998, sign: false }); + let mut X = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 342753, sign: false }); + data.append(FP16x16 { mag: 794296, sign: false }); + data.append(FP16x16 { mag: 801505, sign: true }); + data.append(FP16x16 { mag: 472514, sign: false }); + let leaf_weights = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 205783, sign: false }); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 275251, sign: false }); + let nodes_splits = TensorTrait::new(shape.span(), data.span()); + + let membership_values = Option::None; + + let n_targets = 2; + let aggregate_function = AGGREGATE_FUNCTION::SUM; + let nodes_missing_value_tracks_true = Option::None; + let nodes_hitrates = Option::None; + let post_transform = POST_TRANSFORM::NONE; + + let tree_roots: Span = array![0].span(); + let nodes_modes: Span = array![NODE_MODE::LEQ, NODE_MODE::LEQ, NODE_MODE::LEQ] + .span(); + + let nodes_featureids: Span = array![0, 0, 0].span(); + let nodes_truenodeids: Span = array![1, 0, 1].span(); + let nodes_trueleafs: Span = array![0, 1, 1].span(); + let nodes_falsenodeids: Span = array![2, 2, 3].span(); + let nodes_falseleafs: Span = array![0, 1, 1].span(); + let leaf_targetids: Span = array![0, 1, 0, 1].span(); + + let mut scores = TreeEnsembleTrait::predict( + @X, + nodes_splits, + nodes_featureids, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + leaf_targetids, + leaf_weights, + tree_roots, + post_transform, + aggregate_function, + nodes_hitrates, + nodes_missing_value_tracks_true, + membership_values, + n_targets + ); + + // ASSERT SCOREs + assert(scores.at(0, 0) == FP16x16 { mag: 342753, sign: false }, 'scores.at(0, 0)'); + assert(scores.at(0, 1) == FP16x16 { mag: 0, sign: false }, 'scores.at(0, 1)'); + + assert(scores.at(1, 0) == FP16x16 { mag: 342753, sign: false }, 'scores.at(1, 0)'); + assert(scores.at(1, 1) == FP16x16 { mag: 0, sign: false }, 'scores.at(1, 1)'); + + assert(scores.at(2, 0) == FP16x16 { mag: 0, sign: false }, 'scores.at(2, 0)'); + assert(scores.at(2, 1) == FP16x16 { mag: 794296, sign: false }, 'scores.at(2, 1)'); +} + + +#[test] +#[available_gas(200000000000)] +fn export_tree_ensemble_set_membership() { + let mut shape = ArrayTrait::::new(); + shape.append(6); + shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 222822, sign: false }); + data.append(FP16x16 { mag: 7864, sign: true }); + data.append(NumberTrait::::NaN()); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + let mut X = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 65536000, sign: false }); + data.append(FP16x16 { mag: 6553600, sign: false }); + let leaf_weights = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + 
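+    // nodes_splits for this tree; nodes with NODE_MODE::MEMBER ignore the split value
+    // (compare() only checks membership_values), so the NaN entry below is effectively a placeholder.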
data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 1522663424, sign: false }); + data.append(NumberTrait::::NaN()); + let nodes_splits = TensorTrait::new(shape.span(), data.span()); + + let mut shape = ArrayTrait::::new(); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 242483, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(NumberTrait::::NaN()); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(NumberTrait::::NaN()); + let membership_values = Option::Some(TensorTrait::new(shape.span(), data.span())); + + let n_targets = 4; + let aggregate_function = AGGREGATE_FUNCTION::SUM; + let nodes_missing_value_tracks_true = Option::None; + let nodes_hitrates = Option::None; + let post_transform = POST_TRANSFORM::NONE; + + let tree_roots: Span = array![0].span(); + let nodes_modes: Span = array![NODE_MODE::LEQ, NODE_MODE::MEMBER, NODE_MODE::MEMBER] + .span(); + + let nodes_featureids: Span = array![0, 0, 0].span(); + let nodes_truenodeids: Span = array![1, 0, 1].span(); + let nodes_trueleafs: Span = array![0, 1, 1].span(); + let nodes_falsenodeids: Span = array![2, 2, 3].span(); + let nodes_falseleafs: Span = array![1, 0, 1].span(); + let leaf_targetids: Span = array![0, 1, 2, 3].span(); + + let mut scores = TreeEnsembleTrait::predict( + @X, + nodes_splits, + nodes_featureids, + nodes_modes, + nodes_truenodeids, + nodes_falsenodeids, + nodes_trueleafs, + nodes_falseleafs, + leaf_targetids, + leaf_weights, + tree_roots, + post_transform, + aggregate_function, + nodes_hitrates, + nodes_missing_value_tracks_true, + membership_values, + n_targets + ); + + // ASSERT SCOREs + assert(scores.at(0, 0) == FP16x16 { mag: 65536, sign: false }, 'scores.at(0, 0)'); + assert(scores.at(0, 1) == FP16x16 { mag: 0, sign: false }, 'scores.at(0, 1)'); + assert(scores.at(0, 2) == FP16x16 { mag: 0, sign: false }, 'scores.at(0, 2)'); + assert(scores.at(0, 3) == FP16x16 { mag: 0, sign: false }, 'scores.at(0, 3)'); + + assert(scores.at(1, 0) == FP16x16 { mag: 0, sign: false }, 'scores.at(1, 0)'); + assert(scores.at(1, 1) == FP16x16 { mag: 0, sign: false }, 'scores.at(1, 1)'); + assert(scores.at(1, 2) == FP16x16 { mag: 0, sign: false }, 'scores.at(1, 2)'); + assert(scores.at(1, 3) == FP16x16 { mag: 6553600, sign: false }, 'scores.at(1, 3)'); + + assert(scores.at(2, 0) == FP16x16 { mag: 0, sign: false }, 'scores.at(2, 0)'); + assert(scores.at(2, 1) == FP16x16 { mag: 0, sign: false }, 'scores.at(2, 1)'); + assert(scores.at(2, 2) == FP16x16 { mag: 0, sign: false }, 'scores.at(2, 2)'); + assert(scores.at(2, 3) == FP16x16 { mag: 6553600, sign: false }, 'scores.at(2, 3)'); + + assert(scores.at(3, 0) == FP16x16 { mag: 0, sign: false }, 'scores.at(3, 0)'); + assert(scores.at(3, 1) == FP16x16 { mag: 0, sign: false }, 'scores.at(3, 1)'); + assert(scores.at(3, 2) == FP16x16 { mag: 65536000, sign: false }, 'scores.at(3, 2)'); + assert(scores.at(3, 3) == FP16x16 { mag: 0, sign: false }, 'scores.at(3, 3)'); + + assert(scores.at(4, 0) == FP16x16 { mag: 0, sign: false }, 'scores.at(4, 0)'); + assert(scores.at(4, 1) == FP16x16 { mag: 0, sign: false }, 'scores.at(4, 1)'); + assert(scores.at(4, 2) == FP16x16 { mag: 65536000, sign: false }, 'scores.at(4, 2)'); + assert(scores.at(4, 3) == FP16x16 { mag: 0, sign: false }, 'scores.at(4, 3)'); + + assert(scores.at(5, 0) == 
+    assert(scores.at(5, 0) == FP16x16 { mag: 0, sign: false }, 'scores.at(5, 0)');
+    assert(scores.at(5, 1) == FP16x16 { mag: 655360, sign: false }, 'scores.at(5, 1)');
+    assert(scores.at(5, 2) == FP16x16 { mag: 0, sign: false }, 'scores.at(5, 2)');
+    assert(scores.at(5, 3) == FP16x16 { mag: 0, sign: false }, 'scores.at(5, 3)');
+}
+
diff --git a/tests/nodes.cairo b/tests/nodes.cairo
index 244d8b0c9..d672d4577 100644
--- a/tests/nodes.cairo
+++ b/tests/nodes.cairo
@@ -985,3 +985,8 @@ mod argmax_negative_axis_keepdims;
 mod argmax_negative_axis_keepdims_select_last_index;
 mod argmax_no_keepdims;
 mod argmax_no_keepdims_select_last_index;
+mod mean_fp16x16;
+mod mean_fp8x23;
+mod mean_i32;
+mod mean_i8;
+mod mean_u32;
diff --git a/tests/nodes/mean_fp16x16.cairo b/tests/nodes/mean_fp16x16.cairo
new file mode 100644
index 000000000..f7d2f86bf
--- /dev/null
+++ b/tests/nodes/mean_fp16x16.cairo
@@ -0,0 +1,22 @@
+mod input_0;
+mod input_1;
+mod output_0;
+
+
+use orion::operators::tensor::FP16x16TensorPartialEq;
+use orion::utils::{assert_eq, assert_seq_eq};
+use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd};
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+
+#[test]
+#[available_gas(2000000000)]
+fn test_mean_fp16x16() {
+    let input_0 = input_0::input_0();
+    let input_1 = input_1::input_1();
+    let z_0 = output_0::output_0();
+
+    let y_0 = TensorTrait::mean(array![input_0, input_1].span());
+
+    assert_eq(y_0, z_0);
+}
diff --git a/tests/nodes/mean_fp16x16/input_0.cairo b/tests/nodes/mean_fp16x16/input_0.cairo
new file mode 100644
index 000000000..7349384a7
--- /dev/null
+++ b/tests/nodes/mean_fp16x16/input_0.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd};
+use orion::numbers::{FixedTrait, FP16x16};
+
+fn input_0() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 196608, sign: true });
+    data.append(FP16x16 { mag: 65536, sign: false });
+    data.append(FP16x16 { mag: 196608, sign: true });
+    data.append(FP16x16 { mag: 196608, sign: true });
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_fp16x16/input_1.cairo b/tests/nodes/mean_fp16x16/input_1.cairo
new file mode 100644
index 000000000..7ef6601b6
--- /dev/null
+++ b/tests/nodes/mean_fp16x16/input_1.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd};
+use orion::numbers::{FixedTrait, FP16x16};
+
+fn input_1() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 196608, sign: true });
+    data.append(FP16x16 { mag: 65536, sign: false });
+    data.append(FP16x16 { mag: 196608, sign: true });
+    data.append(FP16x16 { mag: 131072, sign: true });
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_fp16x16/output_0.cairo b/tests/nodes/mean_fp16x16/output_0.cairo
new file mode 100644
index 000000000..330a52626
--- /dev/null
+++ b/tests/nodes/mean_fp16x16/output_0.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{FP16x16Tensor, FP16x16TensorAdd};
+use orion::numbers::{FixedTrait, FP16x16};
+
+fn output_0() -> Tensor<FP16x16> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP16x16 { mag: 196608, sign: true });
+    data.append(FP16x16 { mag: 65536, sign: false });
+    data.append(FP16x16 { mag: 196608, sign: true });
+    data.append(FP16x16 { mag: 163840, sign: true });
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_fp8x23.cairo b/tests/nodes/mean_fp8x23.cairo
new file mode 100644
index 000000000..5f8aeef9e
--- /dev/null
+++ b/tests/nodes/mean_fp8x23.cairo
@@ -0,0 +1,22 @@
+mod input_0;
+mod input_1;
+mod output_0;
+
+
+use orion::operators::tensor::FP8x23TensorPartialEq;
+use orion::utils::{assert_eq, assert_seq_eq};
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{FP8x23Tensor, FP8x23TensorAdd};
+
+#[test]
+#[available_gas(2000000000)]
+fn test_mean_fp8x23() {
+    let input_0 = input_0::input_0();
+    let input_1 = input_1::input_1();
+    let z_0 = output_0::output_0();
+
+    let y_0 = TensorTrait::mean(array![input_0, input_1].span());
+
+    assert_eq(y_0, z_0);
+}
diff --git a/tests/nodes/mean_fp8x23/input_0.cairo b/tests/nodes/mean_fp8x23/input_0.cairo
new file mode 100644
index 000000000..1b61c0605
--- /dev/null
+++ b/tests/nodes/mean_fp8x23/input_0.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{FP8x23Tensor, FP8x23TensorAdd};
+use orion::numbers::{FixedTrait, FP8x23};
+
+fn input_0() -> Tensor<FP8x23> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP8x23 { mag: 16777216, sign: false });
+    data.append(FP8x23 { mag: 8388608, sign: true });
+    data.append(FP8x23 { mag: 0, sign: false });
+    data.append(FP8x23 { mag: 0, sign: false });
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_fp8x23/input_1.cairo b/tests/nodes/mean_fp8x23/input_1.cairo
new file mode 100644
index 000000000..0b92e2c22
--- /dev/null
+++ b/tests/nodes/mean_fp8x23/input_1.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{FP8x23Tensor, FP8x23TensorAdd};
+use orion::numbers::{FixedTrait, FP8x23};
+
+fn input_1() -> Tensor<FP8x23> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP8x23 { mag: 16777216, sign: false });
+    data.append(FP8x23 { mag: 25165824, sign: true });
+    data.append(FP8x23 { mag: 0, sign: false });
+    data.append(FP8x23 { mag: 25165824, sign: true });
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_fp8x23/output_0.cairo b/tests/nodes/mean_fp8x23/output_0.cairo
new file mode 100644
index 000000000..7d92be5b2
--- /dev/null
+++ b/tests/nodes/mean_fp8x23/output_0.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{FP8x23Tensor, FP8x23TensorAdd};
+use orion::numbers::{FixedTrait, FP8x23};
+
+fn output_0() -> Tensor<FP8x23> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(FP8x23 { mag: 16777216, sign: false });
+    data.append(FP8x23 { mag: 16777216, sign: true });
+    data.append(FP8x23 { mag: 0, sign: false });
+    data.append(FP8x23 { mag: 12582912, sign: true });
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_i32.cairo b/tests/nodes/mean_i32.cairo
new file mode 100644
index 000000000..cbe69f8b9
--- /dev/null
+++ b/tests/nodes/mean_i32.cairo
@@ -0,0 +1,24 @@
+mod input_0;
+mod input_1;
+mod input_2;
+mod output_0;
+
+
+use orion::utils::{assert_eq, assert_seq_eq};
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{I32Tensor, I32TensorAdd};
+use orion::operators::tensor::I32TensorPartialEq;
+use orion::operators::tensor::{TensorTrait, Tensor};
+
+#[test]
+#[available_gas(2000000000)]
+fn test_mean_i32() {
+    let input_0 = input_0::input_0();
+    let input_1 = input_1::input_1();
+    let input_2 = input_2::input_2();
+    let z_0 = output_0::output_0();
+
+    let y_0 = TensorTrait::mean(array![input_0, input_1, input_2].span());
+
+    assert_eq(y_0, z_0);
+}
diff --git a/tests/nodes/mean_i32/input_0.cairo b/tests/nodes/mean_i32/input_0.cairo
new file mode 100644
index 000000000..d010fec3b
--- /dev/null
+++ b/tests/nodes/mean_i32/input_0.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{I32Tensor, I32TensorAdd};
+use orion::numbers::NumberTrait;
+
+fn input_0() -> Tensor<i32> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(3);
+    data.append(0);
+    data.append(2);
+    data.append(4);
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_i32/input_1.cairo b/tests/nodes/mean_i32/input_1.cairo
new file mode 100644
index 000000000..db7edcd9a
--- /dev/null
+++ b/tests/nodes/mean_i32/input_1.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{I32Tensor, I32TensorAdd};
+use orion::numbers::NumberTrait;
+
+fn input_1() -> Tensor<i32> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(1);
+    data.append(0);
+    data.append(5);
+    data.append(4);
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_i32/input_2.cairo b/tests/nodes/mean_i32/input_2.cairo
new file mode 100644
index 000000000..7dd95c697
--- /dev/null
+++ b/tests/nodes/mean_i32/input_2.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{I32Tensor, I32TensorAdd};
+use orion::numbers::NumberTrait;
+
+fn input_2() -> Tensor<i32> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(5);
+    data.append(0);
+    data.append(0);
+    data.append(4);
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_i32/output_0.cairo b/tests/nodes/mean_i32/output_0.cairo
new file mode 100644
index 000000000..6856c7ed4
--- /dev/null
+++ b/tests/nodes/mean_i32/output_0.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{I32Tensor, I32TensorAdd};
+use orion::numbers::NumberTrait;
+
+fn output_0() -> Tensor<i32> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(3);
+    data.append(0);
+    data.append(2);
+    data.append(4);
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_i8.cairo b/tests/nodes/mean_i8.cairo
new file mode 100644
index 000000000..58a4f0066
--- /dev/null
+++ b/tests/nodes/mean_i8.cairo
@@ -0,0 +1,24 @@
+mod input_0;
+mod input_1;
+mod input_2;
+mod output_0;
+
+
+use orion::operators::tensor::{I8Tensor, I8TensorAdd};
+use orion::utils::{assert_eq, assert_seq_eq};
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::I8TensorPartialEq;
+use orion::operators::tensor::{TensorTrait, Tensor};
+
+#[test]
+#[available_gas(2000000000)]
+fn test_mean_i8() {
+    let input_0 = input_0::input_0();
+    let input_1 = input_1::input_1();
+    let input_2 = input_2::input_2();
+    let z_0 = output_0::output_0();
+
+    let y_0 = TensorTrait::mean(array![input_0, input_1, input_2].span());
+
+    assert_eq(y_0, z_0);
+}
diff --git a/tests/nodes/mean_i8/input_0.cairo b/tests/nodes/mean_i8/input_0.cairo
new file mode 100644
index 000000000..67b69b1d1
--- /dev/null
+++ b/tests/nodes/mean_i8/input_0.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{I8Tensor, I8TensorAdd};
+use orion::numbers::NumberTrait;
+
+fn input_0() -> Tensor<i8> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(2);
+    data.append(5);
+    data.append(1);
+    data.append(1);
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_i8/input_1.cairo b/tests/nodes/mean_i8/input_1.cairo
new file mode 100644
index 000000000..4b13414e2
--- /dev/null
+++ b/tests/nodes/mean_i8/input_1.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{I8Tensor, I8TensorAdd};
+use orion::numbers::NumberTrait;
+
+fn input_1() -> Tensor<i8> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(1);
+    data.append(0);
+    data.append(3);
+    data.append(0);
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_i8/input_2.cairo b/tests/nodes/mean_i8/input_2.cairo
new file mode 100644
index 000000000..5c2d34ecd
--- /dev/null
+++ b/tests/nodes/mean_i8/input_2.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{I8Tensor, I8TensorAdd};
+use orion::numbers::NumberTrait;
+
+fn input_2() -> Tensor<i8> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(2);
+    data.append(5);
+    data.append(2);
+    data.append(2);
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_i8/output_0.cairo b/tests/nodes/mean_i8/output_0.cairo
new file mode 100644
index 000000000..841de50bf
--- /dev/null
+++ b/tests/nodes/mean_i8/output_0.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{I8Tensor, I8TensorAdd};
+use orion::numbers::NumberTrait;
+
+fn output_0() -> Tensor<i8> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(1);
+    data.append(3);
+    data.append(2);
+    data.append(1);
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_u32.cairo b/tests/nodes/mean_u32.cairo
new file mode 100644
index 000000000..a05eebede
--- /dev/null
+++ b/tests/nodes/mean_u32.cairo
@@ -0,0 +1,24 @@
+mod input_0;
+mod input_1;
+mod input_2;
+mod output_0;
+
+
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::U32TensorPartialEq;
+use orion::operators::tensor::{U32Tensor, U32TensorAdd};
+use core::array::{ArrayTrait, SpanTrait};
+use orion::utils::{assert_eq, assert_seq_eq};
+
+#[test]
+#[available_gas(2000000000)]
+fn test_mean_u32() {
+    let input_0 = input_0::input_0();
+    let input_1 = input_1::input_1();
+    let input_2 = input_2::input_2();
+    let z_0 = output_0::output_0();
+
+    let y_0 = TensorTrait::mean(array![input_0, input_1, input_2].span());
+
+    assert_eq(y_0, z_0);
+}
diff --git a/tests/nodes/mean_u32/input_0.cairo b/tests/nodes/mean_u32/input_0.cairo
new file mode 100644
index 000000000..49c5bc040
--- /dev/null
+++ b/tests/nodes/mean_u32/input_0.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{U32Tensor, U32TensorAdd};
+use orion::numbers::NumberTrait;
+
+fn input_0() -> Tensor<u32> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(2);
+    data.append(5);
+    data.append(4);
+    data.append(2);
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_u32/input_1.cairo b/tests/nodes/mean_u32/input_1.cairo
new file mode 100644
index 000000000..e512234c4
--- /dev/null
+++ b/tests/nodes/mean_u32/input_1.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{U32Tensor, U32TensorAdd};
+use orion::numbers::NumberTrait;
+
+fn input_1() -> Tensor<u32> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(3);
+    data.append(4);
+    data.append(5);
+    data.append(3);
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_u32/input_2.cairo b/tests/nodes/mean_u32/input_2.cairo
new file mode 100644
index 000000000..3c94a962c
--- /dev/null
+++ b/tests/nodes/mean_u32/input_2.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{U32Tensor, U32TensorAdd};
+use orion::numbers::NumberTrait;
+
+fn input_2() -> Tensor<u32> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(1);
+    data.append(5);
+    data.append(5);
+    data.append(4);
+    TensorTrait::new(shape.span(), data.span())
+}
diff --git a/tests/nodes/mean_u32/output_0.cairo b/tests/nodes/mean_u32/output_0.cairo
new file mode 100644
index 000000000..70afa26b0
--- /dev/null
+++ b/tests/nodes/mean_u32/output_0.cairo
@@ -0,0 +1,17 @@
+use core::array::{ArrayTrait, SpanTrait};
+use orion::operators::tensor::{TensorTrait, Tensor};
+use orion::operators::tensor::{U32Tensor, U32TensorAdd};
+use orion::numbers::NumberTrait;
+
+fn output_0() -> Tensor<u32> {
+    let mut shape = ArrayTrait::<usize>::new();
+    shape.append(2);
+    shape.append(2);
+
+    let mut data = ArrayTrait::new();
+    data.append(2);
+    data.append(4);
+    data.append(4);
+    data.append(3);
+    TensorTrait::new(shape.span(), data.span())
+}