Skip to content

Commit

Permalink
use real tribuo RandomForest model
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Dec 27, 2024
1 parent 24b42f8 commit 5aa64ec
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions notebooks/noj_book/automl.clj
Original file line number Diff line number Diff line change
Expand Up @@ -334,22 +334,34 @@ logistic-regression-specs
;; The list of the model types we want to try:
(def models-specs
(concat logistic-regression-specs
[{:model-type :xgboost/classification :round 10}
[{:model-type :scicloj.ml.tribuo/classification
:tribuo-components [{:name "cart"
:type "org.tribuo.classification.dtree.CARTClassificationTrainer"
:properties {:maxDepth "8"
:useRandomSplitPoints "false"
:fractionFeaturesInSplit "0.5"}}
{:name "combiner"
:type "org.tribuo.classification.ensemble.VotingCombiner"}

{:name "random-forest"
:type "org.tribuo.common.tree.RandomForestTrainer"
:properties {:innerTrainer "cart"
:combiner "combiner"
:seed "1234"
:numMembers "500"}}]
:tribuo-trainer-name "random-forest"}

{:model-type :xgboost/classification :round 10}
{:model-type :sklearn.classification/decision-tree-classifier}
{:model-type :sklearn.classification/logistic-regression}
{:model-type :sklearn.classification/random-forest-classifier}
{:model-type :metamorph.ml/dummy-classifier}
{:model-type :scicloj.ml.tribuo/classification
:tribuo-components [{:name "logistic"
:type "org.tribuo.classification.sgd.linear.LinearSGDTrainer"}]
:type "org.tribuo.classification.sgd.linear.LogisticRegressionTrainer"}]
:tribuo-trainer-name "logistic"}
{:model-type :scicloj.ml.tribuo/classification
:tribuo-components [{:name "random-forest"
:type "org.tribuo.classification.dtree.CARTClassificationTrainer"
:properties {:maxDepth "8"
:useRandomSplitPoints "false"
:fractionFeaturesInSplit "0.5"}}]
:tribuo-trainer-name "random-forest"}]))

]))


;; This uses models from Smile, Tribuo and sklearn but could be any
Expand All @@ -374,6 +386,7 @@ logistic-regression-specs
;; Execute all pipelines for all splits in the cross-validations
;; and return best model by `classification-accuracy`

(add-tap println)
(def evaluation-results
(ml/evaluate-pipelines
pipe-fns
Expand Down

0 comments on commit 5aa64ec

Please sign in to comment.