Skip to content

Commit

Permalink
fixed label count in train
Browse files Browse the repository at this point in the history
fixes  #450
  • Loading branch information
behrica committed Feb 5, 2025
1 parent 6629da8 commit 1704001
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 15 deletions.
7 changes: 6 additions & 1 deletion src/scicloj/ml/smile/classification.clj
Original file line number Diff line number Diff line change
Expand Up @@ -439,14 +439,19 @@
properties (smile-proto/options->properties entry-metadata dataset options)
ctor (:constructor entry-metadata)
model (ctor formula data properties)]
{:n-labels (-> label-ds (get target-colname) distinct count)
{:n-labels (-> label-ds (get target-colname)
vec ;; see https://github.com/techascent/tech.ml.dataset/issues/450
distinct
count)
:smile-df-used data
:smile-props-used properties
:smile-formula-used formula
:model-as-bytes
(model/model->byte-array model)}))




(defn- thaw
  "Rebuilds the model object from the serialized bytes stored under
  `:model-as-bytes` in `model-data`."
  [model-data]
  (-> model-data
      :model-as-bytes
      model/byte-array->model))
Expand Down
4 changes: 2 additions & 2 deletions test/scicloj/ml/smile/categorical_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
:label [:x :y :x :y :x :y :x :y]})

(ds/categorical->number [:col-1] [:a :b])
(ds/categorical->number [:label] [:x :y])
(ds/categorical->number [:label] [:x :y] :int32)
(ds-mod/set-inference-target :label)
(ml/train {:model-type :smile.classification/decision-tree}))

Expand All @@ -94,7 +94,7 @@
(ds/categorical->number [:col-1] [:a :b]))
model)]

(t/is (= (repeat 8 0.0)
(t/is (= (repeat 8 0)
(-> prediction :label)))))


Expand Down
46 changes: 34 additions & 12 deletions test/scicloj/ml/smile/smile_ml_test.clj
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
(ns scicloj.ml.smile.smile-ml-test
(:require [scicloj.metamorph.ml.verify :as verify]
[scicloj.metamorph.ml :as ml]
[scicloj.ml.smile.regression]
[scicloj.ml.smile.classification]
[tech.v3.dataset :as ds]
[tech.v3.dataset.modelling :as ds-mod]
[tech.v3.dataset.utils :as ds-utils]
[tech.v3.dataset.column-filters :as cf]
[clojure.test :refer [deftest is]]
[scicloj.metamorph.ml.malli]
[scicloj.metamorph.ml.gridsearch :as ml-gs]
[malli.core :as m]))
(:require
[clojure.test :refer [deftest is]]
[malli.core :as m]
[scicloj.metamorph.ml :as ml]
[scicloj.metamorph.ml.gridsearch :as ml-gs]
[scicloj.metamorph.ml.malli]
[scicloj.metamorph.ml.verify :as verify]
[scicloj.ml.smile.classification]
[scicloj.ml.smile.regression]
[tech.v3.dataset :as ds]
[tech.v3.dataset.categorical :as ds-cat]
[tech.v3.dataset.column-filters :as cf]
[tech.v3.dataset.modelling :as ds-mod]
[tech.v3.dataset.utils :as ds-utils]))


;;shut that shit up.
Expand Down Expand Up @@ -95,3 +97,23 @@
(is map?
(ml/train titanic {:model-type :smile.classification/random-forest
:trees 10}))))

;; Trains a kNN classifier on a tiny dataset whose categorical target :y is
;; encoded as :float64 numbers, predicts on the same feature columns, and checks
;; the first two prediction rows after mapping the labels back to keywords.
;; NOTE(review): presumably a regression test for the label count computed in
;; train when the target column is float-encoded -- confirm against issue #450.
(deftest test-labels
  ;; `deftest` takes no argument vector; the former `[]` here was only a
  ;; discarded expression, so it has been dropped.
  (let [features {:x1 [1 2 4 5 6 5 6 7]
                  :x2 [5 6 6 7 8 2 4 6]}
        trained-model
        (-> (ds/->dataset (assoc features :y [:a :b :b :a :a :a :b :b]))
            (ds/categorical->number [:y] [] :float64)
            (ds-mod/set-inference-target :y)
            (ml/train {:model-type :smile.classification/knn}))]
    (is (= [{:a 0.3333333333333333, :b 0.6666666666666666, :y :b}
            {:a 0.3333333333333333, :b 0.6666666666666666, :y :b}]
           ;; Predict on the same feature columns that were used for training.
           (-> (ml/predict (ds/->dataset features) trained-model)
               (ds-cat/reverse-map-categorical-xforms)
               (ds/head 2)
               (ds/rows))))))

0 comments on commit 1704001

Please sign in to comment.