Skip to content

Commit

Permalink
replace drop with destruct
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelDkhn committed Nov 6, 2023
1 parent f1c109d commit c41189e
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 11 deletions.
12 changes: 5 additions & 7 deletions src/operators/ml/tree_ensemble/core.cairo
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
use core::array::ArrayTrait;
use alexandria_data_structures::array_ext::SpanTraitExt;
use orion::numbers::NumberTrait;
use orion::operators::tensor::{Tensor, TensorTrait, U32Tensor};
use orion::utils::get_row;

use alexandria_data_structures::merkle_tree::{pedersen::PedersenHasherImpl};
use alexandria_data_structures::array_ext::ArrayTraitExt;

impl UsizeDictCopy of Copy<Felt252Dict<usize>>;
impl UsizeDictDrop of Drop<Felt252Dict<usize>>;


#[derive(Copy, Drop)]
#[derive(Copy, Drop, Destruct)]
struct TreeEnsembleAttributes<T> {
base_values: Option<Span<T>>,
nodes_falsenodeids: Span<usize>,
nodes_featureids: Span<usize>,
nodes_hitrates: Span<T>,
nodes_missing_value_tracks_true: Span<usize>,
nodes_modes: Span<NODE_MODES>,
nodes_nodeids: Span<usize>,
Expand All @@ -23,7 +20,7 @@ struct TreeEnsembleAttributes<T> {
nodes_values: Span<T>,
}

#[derive(Copy, Drop)]
#[derive(Destruct)]
struct TreeEnsemble<T> {
atts: TreeEnsembleAttributes<T>,
tree_ids: Span<usize>,
Expand All @@ -42,6 +39,7 @@ enum NODE_MODES {
LEAF
}


#[generate_trait]
impl TreeEnsembleImpl<
T, MAG, +Drop<T>, +Copy<T>, +NumberTrait<T, MAG>, +PartialOrd<T>, +PartialEq<T>
Expand Down
8 changes: 4 additions & 4 deletions src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use orion::numbers::NumberTrait;

use alexandria_data_structures::merkle_tree::{pedersen::PedersenHasherImpl};

#[derive(Copy, Drop)]
#[derive(Destruct)]
struct TreeEnsembleClassifier<T> {
ensemble: TreeEnsemble<T>,
class_ids: Span<usize>,
Expand Down Expand Up @@ -43,7 +43,7 @@ impl TreeEnsembleClassifierImpl<
+Copy<Felt252Dict<Nullable<Array<(usize, T)>>>>,
+Copy<Nullable<Array<(usize, T)>>>,
+Copy<Felt252Dict<Nullable<T>>>,
+Drop<(Span::<u32>, Felt252Dict::<Nullable<T>>)>
+Drop<(Span::<u32>, Felt252Dict::<Nullable<T>>)>,
> of TreeEnsembleClassifierTrait<T> {
fn predict(ref self: TreeEnsembleClassifier<T>, X: Tensor<T>) -> (Tensor<usize>, Tensor<T>) {
let leaf_indices = self.ensemble.leave_index_tree(X);
Expand All @@ -65,7 +65,7 @@ fn compute_scores<
+Add<T>,
+Copy<Felt252Dict<Nullable<Array<(usize, T)>>>>,
+Copy<Nullable<Array<(usize, T)>>>,
+Copy<Felt252Dict<Nullable<T>>>
+Copy<Felt252Dict<Nullable<T>>>,
>(
ref self: TreeEnsembleClassifier<T>, leaf_indices: Tensor<usize>
) -> (Span<usize>, Felt252Dict::<Nullable<T>>) {
Expand Down Expand Up @@ -201,7 +201,7 @@ fn classify<
+Copy<Felt252Dict<Nullable<Array<(usize, T)>>>>,
+Copy<Nullable<Array<(usize, T)>>>,
+Copy<Felt252Dict<Nullable<T>>>,
+Drop<(Span::<u32>, Felt252Dict::<Nullable<T>>)>
+Drop<(Span::<u32>, Felt252Dict::<Nullable<T>>)>,
>(
ref self: TreeEnsembleClassifier<T>, scores: (Span<usize>, Felt252Dict::<Nullable<T>>)
) -> (Tensor<usize>, Tensor<T>) {
Expand Down
1 change: 1 addition & 0 deletions tests/ml.cairo
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
mod tree_regressor;
mod tree_ensemble_classifier;
134 changes: 134 additions & 0 deletions tests/ml/tree_ensemble_classifier.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use core::dict::Felt252DictTrait;
use orion::numbers::FP16x16;
use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor};
use orion::operators::ml::tree_ensemble::core::{
NODE_MODES, TreeEnsembleAttributes, TreeEnsemble, TreeEnsembleImpl
//TreeEnsembleHelperTrait
};
// use orion::operators::ml::tree_ensemble::implementations::{FP16x16TreeEnsembleHelper};

#[test]
#[available_gas(2000000000)]
fn test_tree_ensemble_classifier_multi() {
let class_ids: Span<usize> = array![0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
.span();

let class_nodeids: Span<usize> = array![2, 2, 2, 3, 3, 3, 4, 4, 4, 1, 1, 1, 3, 3, 3, 4, 4, 4]
.span();

let class_treeids: Span<usize> = array![0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
.span();

let class_weights: Span<FP16x16> = array![
FP16x16 { mag: 30583, sign: false },
FP16x16 { mag: 0, sign: false },
FP16x16 { mag: 2185, sign: false },
FP16x16 { mag: 13107, sign: false },
FP16x16 { mag: 15729, sign: false },
FP16x16 { mag: 3932, sign: false },
FP16x16 { mag: 0, sign: false },
FP16x16 { mag: 32768, sign: false },
FP16x16 { mag: 0, sign: false },
FP16x16 { mag: 32768, sign: false },
FP16x16 { mag: 0, sign: false },
FP16x16 { mag: 0, sign: false },
FP16x16 { mag: 29491, sign: false },
FP16x16 { mag: 0, sign: false },
FP16x16 { mag: 3277, sign: false },
FP16x16 { mag: 6746, sign: false },
FP16x16 { mag: 12529, sign: false },
FP16x16 { mag: 13493, sign: false },
]
.span();

let class_labels: Span<usize> = array![0, 1, 2].span();

let nodes_falsenodeids: Span<usize> = array![4, 3, 0, 0, 0, 2, 0, 4, 0, 0].span();

let nodes_featureids: Span<usize> = array![1, 0, 0, 0, 0, 1, 0, 0, 0, 0].span();

let nodes_missing_value_tracks_true: Span<usize> = array![0, 0, 0, 0, 0, 0, 0, 0, 0, 0].span();

let nodes_modes: Span<NODE_MODES> = array![
NODE_MODES::BRANCH_LEQ,
NODE_MODES::BRANCH_LEQ,
NODE_MODES::LEAF,
NODE_MODES::LEAF,
NODE_MODES::LEAF,
NODE_MODES::BRANCH_LEQ,
NODE_MODES::LEAF,
NODE_MODES::BRANCH_LEQ,
NODE_MODES::LEAF,
NODE_MODES::LEAF,
]
.span();

let nodes_nodeids: Span<usize> = array![0, 1, 2, 3, 4, 0, 1, 2, 3, 4].span();

let nodes_treeids: Span<usize> = array![0, 0, 0, 0, 0, 1, 1, 1, 1, 1].span();

let nodes_truenodeids: Span<usize> = array![1, 2, 0, 0, 0, 1, 0, 3, 0, 0].span();

let nodes_values: Span<FP16x16> = array![
FP16x16 { mag: 81892, sign: false },
FP16x16 { mag: 19992, sign: true },
FP16x16 { mag: 0, sign: false },
FP16x16 { mag: 0, sign: false },
FP16x16 { mag: 0, sign: false },
FP16x16 { mag: 110300, sign: true },
FP16x16 { mag: 0, sign: false },
FP16x16 { mag: 44245, sign: true },
FP16x16 { mag: 0, sign: false },
FP16x16 { mag: 0, sign: false },
]
.span();

let tree_ids: Span<usize> = array![0, 1].span();

let mut root_index: Felt252Dict<usize> = Default::default();
root_index.insert(0, 0);
root_index.insert(1, 5);

let mut node_index: Felt252Dict<usize> = Default::default();
node_index
.insert(2089986280348253421170679821480865132823066470938446095505822317253594081284, 0);
node_index
.insert(2001140082530619239661729809084578298299223810202097622761632384561112390979, 1);
node_index
.insert(2592670241084192212354027440049085852792506518781954896144296316131790403900, 2);
node_index
.insert(2960591271376829378356567803618548672034867345123727178628869426548453833420, 3);
node_index
.insert(458933264452572171106695256465341160654132084710250671055261382009315664425, 4);
node_index
.insert(1089549915800264549621536909767699778745926517555586332772759280702396009108, 5);
node_index
.insert(1321142004022994845681377299801403567378503530250467610343381590909832171180, 6);
node_index
.insert(2592987851775965742543459319508348457290966253241455514226127639100457844774, 7);
node_index
.insert(2492755623019086109032247218615964389726368532160653497039005814484393419348, 8);
node_index
.insert(1323616023845704258113538348000047149470450086307731200728039607710316625916, 9);

let atts = TreeEnsembleAttributes {
nodes_falsenodeids,
nodes_featureids,
nodes_missing_value_tracks_true,
nodes_modes,
nodes_nodeids,
nodes_treeids,
nodes_truenodeids,
nodes_values
};

let mut ensemble: TreeEnsemble<FP16x16> = TreeEnsemble {
atts: atts, tree_ids: tree_ids, root_index: root_index, node_index: node_index
};

TreeEnsembleImpl::leave_index_tree(
ref ensemble,
TensorTrait::new(array![1].span(), array![FP16x16 { mag: 0, sign: false }].span())
);
}

0 comments on commit c41189e

Please sign in to comment.