diff --git a/CHANGELOG.md b/CHANGELOG.md index 46e28a6..bd42bf2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ unreleased +- fixed and dcoumented confusion-map->ds 0.7.8 - fixed default colors of error bands diff --git a/src/scicloj/metamorph/ml/classification.clj b/src/scicloj/metamorph/ml/classification.clj index 295e666..3c72204 100644 --- a/src/scicloj/metamorph/ml/classification.clj +++ b/src/scicloj/metamorph/ml/classification.clj @@ -13,7 +13,12 @@ 1)) (defn confusion-map + "Creates a confusion-matrix in map form. Can be either as raw counts or normalized. + `normalized` when :all (default) it is normalized + :none otherwise + " ([predicted-labels labels normalize] + (let [answer-counts (frequencies labels)] (->> (map vector predicted-labels labels) (reduce (fn [total-map [pred actual]] @@ -38,47 +43,34 @@ (defn confusion-map->ds - ([conf-matrix-map normalize] - (let [all-labels (->> (keys conf-matrix-map) - sort) - header-column (merge {:column-name "column-name"} - (->> all-labels - (map #(vector % %)) - (into {}))) - column-names (concat [:column-name] - all-labels)] - (->> all-labels - (map (fn [label-name] - (let [entry (get conf-matrix-map label-name)] - (merge {:column-name label-name} - (->> all-labels - (map (fn [entry-name] - [entry-name (dtype-pp/format-object - (get entry entry-name - (case normalize - :none 0 - :all 0.0)))])) - (into {})))))) - (concat [header-column]) - (ds/->>dataset) - ;;Ensure order is consistent - (#(ds/select-columns % column-names))))) - ([conf-matrix-map] - (confusion-map->ds conf-matrix-map :all))) - - - - - -#_(defn confusion-ds - [model test-ds] - (let [predictions (ml/predict model test-ds) - answers (ds/labels test-ds)] - (-> (probability-distributions->labels predictions) - (confusion-map (ds/labels test-ds)) - (confusion-map->ds)))) -(comment - (confusion-map [:a :b :c :a] [:a :c :c :a] :all)) + "Converts teh confusion-matrix map obtained via `confusion-mao` into a dataset representation" + [conf-matrix-map] + (let [ + conf-matrix-map conf-matrix-map + all-counts (flatten (map vals (vals conf-matrix-map))) + _ (assert (or + (every? float? all-counts) + (every? int? all-counts)) + (str "All counts need to be either int? or float?, but are: " all-counts)) + is-integer (integer? (first all-counts)) + all-labels (->> (keys conf-matrix-map) + sort) + column-names (concat [:column-name] + all-labels)] + (->> all-labels + (map (fn [label-name] + (let [entry (get conf-matrix-map label-name)] + (merge {:column-name label-name} + (->> all-labels + (map (fn [entry-name] + [entry-name (dtype-pp/format-object + (get entry entry-name + (if is-integer + 0 + 0.0)))])) + (into {})))))) + (ds/->>dataset) + (#(ds/select-columns % column-names))))) (defn- get-majority-class [target-ds] diff --git a/test/scicloj/metamorph/classification_test.clj b/test/scicloj/metamorph/classification_test.clj index f8cd722..1cc5aed 100644 --- a/test/scicloj/metamorph/classification_test.clj +++ b/test/scicloj/metamorph/classification_test.clj @@ -1,5 +1,5 @@ (ns scicloj.metamorph.classification-test - (:require [scicloj.metamorph.ml.classification :refer [confusion-map]] + (:require [scicloj.metamorph.ml.classification :refer [confusion-map confusion-map->ds]] [clojure.test :refer :all] [scicloj.metamorph.ml :as ml] [scicloj.metamorph.ml.loss :as loss] @@ -24,6 +24,24 @@ :c {:b 0.5 :c 0.5}}))) +(deftest test-confusion-map->ds + (is (= + [{:column-name :a, :a "2", :c "0"} + {:column-name :c, :a "0", :c "1"}] + (-> (confusion-map [:a :b :c :a] [:a :c :c :a] :none) + (confusion-map->ds) + (tc/rows :as-maps)))) + + + (is (= + [{:column-name :a, :a "1.000", :c "0.000"} + {:column-name :c, :a "0.000", :c "0.5000"}] + (-> (confusion-map [:a :b :c :a] [:a :c :c :a] :all) + (confusion-map->ds) + (tc/rows :as-maps))))) + + + (deftest dummy-classification-fixed-label [] (let [ds (toydata/iris-ds) model (ml/train ds {:model-type :metamorph.ml/dummy-classifier