Skip to content

Commit

Permalink
made predict datatype symmetric to train
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Feb 4, 2025
1 parent 0bbb4db commit 6629da8
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 31 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# ConstantChangeLog

# unreleased
- made prediction target datatype symmetric to train

# 7.5.0
- imported code for TMS<->smile dataframe conversions from tech.v3.libs.smile
- options as malli
Expand Down
6 changes: 4 additions & 2 deletions src/scicloj/ml/smile/classification.clj
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,10 @@
(defn- predict
[feature-ds thawed-model {:keys [target-columns
target-categorical-maps
target-datatypes
options
model-data]}]
model-data
]}]
;; (errors/when-not-error target-categorical-maps "target-categorical-maps not found. Target column need to be categorical.")
(let [n-labels (model-data :n-labels)
entry-metadata (model-type->classification-model
Expand All @@ -473,7 +475,7 @@
n-labels
target-categorical-maps))
mapped-predictions
(-> (ds-mod/probability-distributions->label-column finalised-predictions target-colname)
(-> (ds-mod/probability-distributions->label-column finalised-predictions target-colname (get target-datatypes target-colname))
(ds/update-column target-colname
#(vary-meta % assoc :column-type :prediction)))]
mapped-predictions))
Expand Down
2 changes: 1 addition & 1 deletion test/scicloj/ml/smile/categorical_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,5 @@
(ds/->dataset {:col-1 [:a :a :b :b :a :a :b :b]})
(ds/categorical->number [:col-1] [:a :b]))
model)]
(t/is (= (repeat 8 0.0)
(t/is (= (repeat 8 0)
(-> prediction :label)))))
52 changes: 24 additions & 28 deletions test/scicloj/ml/smile/logistic_regression_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -89,53 +89,47 @@



;; Trains a :smile.classification/logistic-regression model on the iris dataset
;; after clearing the categorical metadata of :species and casting it to :int,
;; then predicts on the same dataset and checks the model keys and the datatype
;; of the predicted :species column.
(t/deftest allow-numeric-int-target
(let [iris-ds-traget-is-non-categorical
(-> (datasets/iris-ds)
(ds/assoc-metadata [:species]
:categorical-map nil
:categorical? nil)
(ds/column-cast :species :int)
)


model (ml/train
iris-ds-traget-is-non-categorical
{:model-type :smile.classification/logistic-regression})
prediction (ml/predict iris-ds-traget-is-non-categorical model)]


;; The trained model map is expected to expose exactly these keys, in this order.
(t/is (= [:model-data :options :train-input-hash :id :feature-columns :target-columns :target-datatypes]
(keys model)))

;; TODO https://github.com/scicloj/scicloj.ml.smile/issues/16
;; Note: the target was cast to :int above, yet the assertion expects the
;; predicted column's metadata to report :float64.
(t/is (= :float64
(-> prediction
:species
meta
:datatype)))))

(t/deftest allow-numeric-float-target
(defn validate-numeric-target [datatype]
(let [iris-ds-traget-is-non-categorical
(-> (datasets/iris-ds)
(ds/assoc-metadata [:species]
:categorical-map nil
:categorical? nil)
(ds/column-cast :species :float))
(ds/column-cast :species datatype))
model (ml/train
iris-ds-traget-is-non-categorical
{:model-type :smile.classification/logistic-regression})
prediction (ml/predict iris-ds-traget-is-non-categorical model)]

(t/is (= [:model-data :options :train-input-hash :id :feature-columns :target-columns :target-datatypes]
(keys model)))
(t/is (= :float64
(t/is (= datatype
(-> prediction
:species
meta
:datatype)))))



;; Runs validate-numeric-target once for each numeric datatype listed below
;; (signed integer widths and both float widths).
(t/deftest allow-numeric-target


;; NOTE(review): the plain :int case is disabled; the reason is not visible
;; here -- confirm before re-enabling.
;(validate-numeric-target :int)

(validate-numeric-target :int8)
(validate-numeric-target :int16)
(validate-numeric-target :int32)
(validate-numeric-target :int64)

(validate-numeric-target :float32)
(validate-numeric-target :float64)



)


(t/deftest fail-on-string-target
(let [iris-ds-traget-is-string
(-> (datasets/iris-ds)
Expand All @@ -150,3 +144,5 @@





0 comments on commit 6629da8

Please sign in to comment.