Skip to content

Commit

Permalink
fixed confusion-map->ds
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Mar 19, 2024
1 parent 30332db commit 36493ec
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 42 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
unreleased
- fixed and documented confusion-map->ds

0.7.8
- fixed default colors of error bands
Expand Down
74 changes: 33 additions & 41 deletions src/scicloj/metamorph/ml/classification.clj
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
1))

(defn confusion-map
"Creates a confusion-matrix in map form. Can be either as raw counts or normalized.
`normalize`: when :all (default) the counts are normalized,
when :none the raw counts are returned.
"
([predicted-labels labels normalize]

(let [answer-counts (frequencies labels)]
(->> (map vector predicted-labels labels)
(reduce (fn [total-map [pred actual]]
Expand All @@ -38,47 +43,34 @@


(defn confusion-map->ds
([conf-matrix-map normalize]
(let [all-labels (->> (keys conf-matrix-map)
sort)
header-column (merge {:column-name "column-name"}
(->> all-labels
(map #(vector % %))
(into {})))
column-names (concat [:column-name]
all-labels)]
(->> all-labels
(map (fn [label-name]
(let [entry (get conf-matrix-map label-name)]
(merge {:column-name label-name}
(->> all-labels
(map (fn [entry-name]
[entry-name (dtype-pp/format-object
(get entry entry-name
(case normalize
:none 0
:all 0.0)))]))
(into {}))))))
(concat [header-column])
(ds/->>dataset)
;;Ensure order is consistent
(#(ds/select-columns % column-names)))))
([conf-matrix-map]
(confusion-map->ds conf-matrix-map :all)))





#_(defn confusion-ds
[model test-ds]
(let [predictions (ml/predict model test-ds)
answers (ds/labels test-ds)]
(-> (probability-distributions->labels predictions)
(confusion-map (ds/labels test-ds))
(confusion-map->ds))))
(comment
(confusion-map [:a :b :c :a] [:a :c :c :a] :all))
"Converts teh confusion-matrix map obtained via `confusion-mao` into a dataset representation"
[conf-matrix-map]
(let [
conf-matrix-map conf-matrix-map
all-counts (flatten (map vals (vals conf-matrix-map)))
_ (assert (or
(every? float? all-counts)
(every? int? all-counts))
(str "All counts need to be either int? or float?, but are: " all-counts))
is-integer (integer? (first all-counts))
all-labels (->> (keys conf-matrix-map)
sort)
column-names (concat [:column-name]
all-labels)]
(->> all-labels
(map (fn [label-name]
(let [entry (get conf-matrix-map label-name)]
(merge {:column-name label-name}
(->> all-labels
(map (fn [entry-name]
[entry-name (dtype-pp/format-object
(get entry entry-name
(if is-integer
0
0.0)))]))
(into {}))))))
(ds/->>dataset)
(#(ds/select-columns % column-names)))))


(defn- get-majority-class [target-ds]
Expand Down
20 changes: 19 additions & 1 deletion test/scicloj/metamorph/classification_test.clj
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
(ns scicloj.metamorph.classification-test
(:require [scicloj.metamorph.ml.classification :refer [confusion-map]]
(:require [scicloj.metamorph.ml.classification :refer [confusion-map confusion-map->ds]]
[clojure.test :refer :all]
[scicloj.metamorph.ml :as ml]
[scicloj.metamorph.ml.loss :as loss]
Expand All @@ -24,6 +24,24 @@
:c {:b 0.5 :c 0.5}})))


;; Checks the dataset produced by `confusion-map->ds` for both raw and
;; normalized confusion maps, comparing the dataset rows (as maps) against
;; literal expectations. Cell values are expected as formatted strings.
;; Note: the expected rows only contain :a and :c; :b occurs solely among the
;; predicted labels of the input and does not show up as a row or a column.
(deftest test-confusion-map->ds
;; Raw counts (:none): cells are expected as integer-formatted strings.
(is (=
[{:column-name :a, :a "2", :c "0"}
{:column-name :c, :a "0", :c "1"}]
(-> (confusion-map [:a :b :c :a] [:a :c :c :a] :none)
(confusion-map->ds)
(tc/rows :as-maps))))


;; Normalized (:all): cells are expected as float-formatted strings.
;; NOTE(review): the exact digit counts ("0.000" vs "0.5000") presumably come
;; from the formatter used by `confusion-map->ds` — confirm if it changes.
(is (=
[{:column-name :a, :a "1.000", :c "0.000"}
{:column-name :c, :a "0.000", :c "0.5000"}]
(-> (confusion-map [:a :b :c :a] [:a :c :c :a] :all)
(confusion-map->ds)
(tc/rows :as-maps)))))



(deftest dummy-classification-fixed-label []
(let [ds (toydata/iris-ds)
model (ml/train ds {:model-type :metamorph.ml/dummy-classifier
Expand Down

0 comments on commit 36493ec

Please sign in to comment.