From 39240959fadbe7a7d1f6f132e35a425f6359c4c4 Mon Sep 17 00:00:00 2001 From: Davis King Date: Tue, 22 Oct 2024 22:26:14 -0400 Subject: [PATCH] Make event_correlation() work on fractional counts --- dlib/statistics/statistics.h | 24 ++++++++++++------------ dlib/statistics/statistics_abstract.h | 16 ++++++++-------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/dlib/statistics/statistics.h b/dlib/statistics/statistics.h index 492ddf8e7e..914f5482ff 100644 --- a/dlib/statistics/statistics.h +++ b/dlib/statistics/statistics.h @@ -1827,20 +1827,20 @@ namespace dlib // ---------------------------------------------------------------------------------------- inline double binomial_random_vars_are_different ( - uint64_t k1, - uint64_t n1, - uint64_t k2, - uint64_t n2 + double k1, + double n1, + double k2, + double n2 ) { DLIB_ASSERT(k1 <= n1, "k1: "<< k1 << " n1: "<< n1); DLIB_ASSERT(k2 <= n2, "k2: "<< k2 << " n2: "<< n2); - const double p1 = k1/(double)n1; - const double p2 = k2/(double)n2; - const double p = (k1+k2)/(double)(n1+n2); + const double p1 = n1 != 0 ? k1/n1 : 0; + const double p2 = n2 != 0 ? k2/n2 : 0; + const double p = (k1+k2)/(n1+n2); - auto ll = [](double p, uint64_t k, uint64_t n) { + auto ll = [](double p, double k, double n) { if (p == 0 || p == 1) return 0.0; return k*std::log(p) + (n-k)*std::log(1-p); @@ -1860,10 +1860,10 @@ namespace dlib // ---------------------------------------------------------------------------------------- inline double event_correlation ( - uint64_t A_count, - uint64_t B_count, - uint64_t AB_count, - uint64_t total_num_observations + double A_count, + double B_count, + double AB_count, + double total_num_observations ) { DLIB_ASSERT(AB_count <= A_count && A_count <= total_num_observations, diff --git a/dlib/statistics/statistics_abstract.h b/dlib/statistics/statistics_abstract.h index b5738196d7..24432ded26 100644 --- a/dlib/statistics/statistics_abstract.h +++ b/dlib/statistics/statistics_abstract.h @@ -108,10 +108,10 @@ namespace dlib // ---------------------------------------------------------------------------------------- double binomial_random_vars_are_different ( - uint64_t k1, - uint64_t n1, - uint64_t k2, - uint64_t n2 + double k1, + double n1, + double k2, + double n2 ); /*! requires @@ -138,10 +138,10 @@ namespace dlib // ---------------------------------------------------------------------------------------- double event_correlation ( - uint64_t A_count, - uint64_t B_count, - uint64_t AB_count, - uint64_t total_num_observations + double A_count, + double B_count, + double AB_count, + double total_num_observations ); /*! requires