From 0699dbbe2b7f0c2e7277b3f79d2cf69ab68ba5a9 Mon Sep 17 00:00:00 2001 From: Mrityunjay Tripathi Date: Thu, 27 Feb 2020 23:43:37 +0530 Subject: [PATCH 1/7] adding dirichlet distribution --- .../boost/math/distributions/dirichlet.hpp | 405 ++++++++++++++++++ .../compile_test/dist_dirichlet_incl_test.cpp | 25 ++ 2 files changed, 430 insertions(+) create mode 100644 include/boost/math/distributions/dirichlet.hpp create mode 100644 test/compile_test/dist_dirichlet_incl_test.cpp diff --git a/include/boost/math/distributions/dirichlet.hpp b/include/boost/math/distributions/dirichlet.hpp new file mode 100644 index 0000000000..12707cf156 --- /dev/null +++ b/include/boost/math/distributions/dirichlet.hpp @@ -0,0 +1,405 @@ +// boost/math/distributions/dirichlet.hpp + +// Copyright Mrityunjay Tripathi 2020. + +// Use, modification and distribution are subject to the +// Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt +// or copy at http://www.boost.org/LICENSE_1_0.txt) + +// https://en.wikipedia.org/wiki/Dirichlet_distribution +// https://mast.queensu.ca/~communications/Papers/msc-jiayu-lin.pdf + +// The Dirichlet distribution is a family of continuous multivariate probability +// distributions parameterized by a vector 'alpha' of positive reals. +// It is a multivariate generalization of the dirichlet distribution, hence its +// alternative name of multivariate dirichlet distribution (MBD). +// Dirichlet distributions are commonly used as prior distributions in +// Bayesian statistics, and in fact the Dirichlet distribution is the +// conjugate prior of the categorical distribution and multinomial distribution. + +#ifndef BOOST_MATH_DIST_DIRICHLET_HPP +#define BOOST_MATH_DIST_DIRICHLET_HPP + +#include +#include +#include +#include +#include +#include + +#if defined(BOOST_MSVC) +#pragma warning(push) +#pragma warning(disable : 4702) // unreachable code +// in domain_error_imp in error_handling +#endif + +#include + +namespace boost +{ +namespace math +{ +namespace dirichlet_detail +{ +// Common error checking routines for dirichlet distribution function: +template +inline bool check_concentration(const char *function, + const VectorType &concentration, + RealType *result, + const Policy &pol) +{ + for (size_t i = 0; i < concentration.size(); ++i) + { + if (!(boost::math::isfinite)(i) || (i <= 0)) + { + *result = policies::raise_domain_error( + function, + "Concentration Parameter is %1%, but must be > 0 !", concentration, pol); + return false; + } + } + return true; +} // bool check_concentration + +template +inline bool check_prob(const char *function, + const VectorType &p, + RealType *result, + const Policy &pol) +{ + for (size_t i = 0; i < p.size(); ++i) + { + if ((i < 0) || (i > 1) || !(boost::math::isfinite)(i)) + { + *result = policies::raise_domain_error( + function, + "Probability argument is %1%, but must be >= 0 and <= 1 !", i, pol); + return false; + } + } + return true; +} // bool check_prob + +template +inline bool check_x(const char *function, + const VectorType &x, + RealType *result, + const Policy &pol) +{ + for (size_t i = 0; i < x.size(); ++i) + { + if (!(boost::math::isfinite)(x) || (x < 0) || (x > 1)) + { + *result = policies::raise_domain_error( + function, + "x argument is %1%, but must be >= 0 and <= 1 !", x, pol); + return false; + } + } + return true; +} // bool check_x + +template +inline bool check_dist(const char *function, + const VectorType &concentration, + RealType *result, + const Policy &pol) +{ + return check_concentration(function, concentration, result, pol); +} // bool check_dist + +template +inline bool check_dist_and_x(const char *function, + const VectorType &concentration, + const VectorType &x, + RealType *result, + const Policy &pol) +{ + return check_dist(function, concentration, result, pol) && check_x(function, x, result, pol); +} // bool check_dist_and_x + +template +inline bool check_dist_and_prob(const char *function, + const VectorType &concentration, + const VectorType &p, + RealType *result, + const Policy &pol) +{ + return check_dist(function, concentration, result, pol) && check_prob(function, p, result, pol); +} // bool check_dist_and_prob + +template +inline bool check_mean(const char *function, + const VectorType &mean, + RealType *result, + const Policy &pol) +{ + for (size_t i = 0; i < mean.size(); ++i) + { + if (!(boost::math::isfinite)(i) || (i <= 0)) + { + *result = policies::raise_domain_error( + function, + "mean argument is %1%, but must be > 0 !", i, pol); + return false; + } + } + return true; +} // bool check_mean + +template +inline bool check_variance(const char *function, + const VectorType &variance, + RealType *result, + const Policy &pol) +{ + for (size_t i = 0; i < variance.size(); ++i) + { + if (!(boost::math::isfinite)(i) || (i <= 0)) + { + *result = policies::raise_domain_error( + function, + "variance argument is %1%, but must be > 0 !", i, pol); + return false; + } + } + return true; +} // bool check_variance +} // namespace dirichlet_detail + +template , class RealType = double, class Policy = policies::policy<>> +class dirichlet_distribution +{ +public: + dirichlet_distribution(VectorType concentration) : concentration(concentration) + { + RealType result; + dirichlet_detail::check_dist( + "boost::math::dirichlet_distribution<%1%>::dirichlet_distribution", + concentration, + &result, Policy()); + sum_concentration = accumulate(concentration.begin(), concentration.end(), 0); + } // dirichlet_distribution constructor. + + // Accessor functions: + VectorType Concentration() const + { + return concentration; + } + + size_t Order() const + { + return concentration.size(); + } + + static VectorType find_concentration( + VectorType mean, // Expected value of mean. + VectorType variance) // Expected value of variance. + { + assert(("Dimensions of mean and variance must be same!", mean.size() == variance.size())); + static const char *function = "boost::math::dirichlet_distribution<%1%>::find_concentration"; + RealType result = 0; // of error checks. + if (!(dirichlet_detail::check_mean(function, mean, &result, Policy()) && dirichlet_detail::check_variance(function, variance, &result, Policy()))) + { + return result; + } + VectorType c; + for (size_t i = 0; i < mean.size(); ++i) + { + c.push_back(mean[i] * (((mean[i] * (1 - mean[i])) / variance[i]) - 1)); + } + return c; + } // RealType find_concentration + + // TODO + // static VectorType find_concentration( + // VectorType x, // x. + // VectorType probability) // cdf + // { + // assert(("", x.size() == probability.size())); + // static const char *function = "boost::math::dirichlet_distribution<%1%>::find_conentration"; + // RealType result = 0; // of error checks. + // if (!(dirichlet_detail::check_prob(function, probability, &result, Policy()) && dirichlet_detail::check_x(function, x, &result, Policy()))) + // { + // return result; + // } + // return ; + // } // RealType find_concentration(x, probability) + +private: + VectorType concentration; // https://en.wikipedia.org/wiki/Concentration_parameter. + RealType sum_concentration; +}; // template class dirichlet_distribution + + +template +inline const std::pair range(const dirichlet_distribution & /* dist */) +{ // Range of permissible values for random variable x. + using boost::math::tools::max_value; + return std::pair(static_cast(0), static_cast(1)); +} + + +template +inline const std::pair support(const dirichlet_distribution & /* dist */) +{ // Range of supported values for random variable x. + // This is range where cdf rises from 0 to 1, and outside it, the pdf is zero. + return std::pair(static_cast(0), static_cast(1)); +} + + +template +inline VectorType mean(const dirichlet_distribution &dist) +{ // Mean of dirichlet distribution = c[i]/sum(c). + VectorType m; + for (size_t i = 0; i < dist.Order(); ++i) + { + m.push_back(dist.concentration[i] / dist.sum_concentration); + } + return m; + +} // mean + + +template +inline VectorType variance(const dirichlet_distribution &dist) +{ + VectorType v; + for (size_t i = 0; i < dist.Order(); ++i) + { + v.push_back(dist.concentration[i] / dist.sum_concentration * (1 - dist.concentration[i] / dist.sum_concentration) / (1 + dist.sum_concentration)); + } + return v; +} // variance + + +template +inline VectorType mode(const dirichlet_distribution &dist) +{ + static const char *function = "boost::math::mode(dirichlet_distribution<%1%> const&)"; + VectorType m; + for (size_t i = 0; i < dist.Order(); ++i) + { + if ((dist.concentration[i] <= 1)) + { + result = policies::raise_domain_error( + function, + "mode undefined for alpha = %1%, must be > 1!", dist.alpha(), Policy()); + return result; + } + else + { + m.push_back((dist.concentration[i] - 1) / (dist.sum_concentration - dist.Order())); + } + } + return m; +} // mode + + +template +inline RealType entropy(const dirichlet_distribution &dist) +{ + RealType t1 = 1; + for (size_t i = 0; i < dist.Order(); ++i) + { + t1 *= tgamma(dist.concentration[i]); + } + t1 = std::log(t1 / tgamma(dist.sum_concentration)); + RealType t2 = (dist.sum_concentration - dist.Order()) * digamma(dist.sum_concentration); + RealType t3 = 0; + for (size_t i = 0; i < dist.Order(); ++i) + { + t3 += (dist.concentration[i] - 1) * digamma(dist.concentration[i]); + } + return t1 + t2 - t3; +} + + +template +inline RealType pdf(const dirichlet_distribution &dist, const VectorType &x) +{ // Probability Density/Mass Function. + BOOST_FPU_EXCEPTION_GUARD + + static const char *function = "boost::math::pdf(dirichlet_distribution<%1%> const&, %1%)"; + + BOOST_MATH_STD_USING // for ADL of std functions + + // Argument checks: + RealType result = 0; + if (!dirichlet_detail::check_dist_and_x(function, x, &result, Policy())) + { + return result; + } + using boost::math::tgamma; + RealType f = 1; + for (size_t i = 0; i < dist.Order(); ++i) + { + f *= std::pow(x[i], dist.concentration[i] - 1); + } + f /= dist.normalizing_factor; + return f; +} // pdf + + +template +inline RealType cdf(const dirichlet_distribution &dist, const VectorType &x) +{ // Cumulative Distribution Function dirichlet. + BOOST_MATH_STD_USING // for ADL of std functions + + static const char *function = "boost::math::cdf(dirichlet_distribution<%1%> const&, %1%)"; + + // Argument checks: + RealType result = 0; + if (!dirichlet_detail::check_dist_and_x(function, dist.concentration, x, &result, Policy())) + { + return result; + } + RealType c = 1; + for (size_t i = 0; i < dist.Order(); ++i) + { + c *= std::pow(x[i], dist.concentration[i]) / tgamma(dist.concentration[i]) / dist.concentration[i]; + } + c *= tgamma(dist.sum_concentration); + return c; +} // dirichlet cdf + +template +inline RealType cdf(const complemented2_type, RealType> &c) +{ // Complemented Cumulative Distribution Function dirichlet. + + BOOST_MATH_STD_USING // for ADL of std functions + + static const char *function = "boost::math::cdf(dirichlet_distribution<%1%> const&, %1%)"; + + RealType const &x = c.param; + dirichlet_distribution const &dist = c.dist; + + // Argument checks: + RealType result = 0; + if (!dirichlet_detail::check_dist_and_x(function, x, &result, Policy())) + { + return result; + } + RealType cumm = 1; + for (size_t i = 0; i < dist.Order(); ++i) + { + cumm *= std::pow(x[i], dist.concentration[i]) / tgamma(dist.concentration[i]) / dist.concentration[i]; + } + cumm *= tgamma(dist.sum_concentration); + return cumm; +} // dirichlet cdf + +} // namespace math +} // namespace boost + +// This include must be at the end, *after* the accessors +// for this distribution have been defined, in order to +// keep compilers that support two-phase lookup happy. +#include + +#if defined(BOOST_MSVC) +#pragma warning(pop) +#endif + +#endif // BOOST_MATH_DIST_dirichlet_HPP diff --git a/test/compile_test/dist_dirichlet_incl_test.cpp b/test/compile_test/dist_dirichlet_incl_test.cpp new file mode 100644 index 0000000000..8f40e6bbb2 --- /dev/null +++ b/test/compile_test/dist_dirichlet_incl_test.cpp @@ -0,0 +1,25 @@ +// Copyright Mrityunjay Tripathi 2020. +// Use, modification and distribution are subject to the +// Boost Software License, Version 1.0. (See accompanying file +// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +// Basic sanity check that header +// #includes all the files that it needs to. +// +#include +// +// Note this header includes no other headers, this is +// important if this test is to be meaningful: +// +#include "test_compile_result.hpp" + +void compile_and_link_test() +{ + TEST_DIST_FUNC(dirichlet) +} + +template class boost::math::dirichlet_distribution, double, boost::math::policies::policy<>>; +template class boost::math::dirichlet_distribution, float, boost::math::policies::policy<>>; +#ifndef BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS +template class boost::math::dirichlet_distribution, long double, boost::math::policies::policy<>>; +#endif From ddd2c1ad879643c61b1e64851532a372564e4c1a Mon Sep 17 00:00:00 2001 From: Mrityunjay Tripathi Date: Sat, 29 Feb 2020 16:51:40 +0530 Subject: [PATCH 2/7] added skewness, kurtosis, kurtosis_excess functions and RandomAccessContainer type for vector type. --- .../boost/math/distributions/dirichlet.hpp | 391 ++++++++++-------- .../compile_test/dist_dirichlet_incl_test.cpp | 6 +- 2 files changed, 233 insertions(+), 164 deletions(-) diff --git a/include/boost/math/distributions/dirichlet.hpp b/include/boost/math/distributions/dirichlet.hpp index 12707cf156..b286933257 100644 --- a/include/boost/math/distributions/dirichlet.hpp +++ b/include/boost/math/distributions/dirichlet.hpp @@ -12,19 +12,20 @@ // The Dirichlet distribution is a family of continuous multivariate probability // distributions parameterized by a vector 'alpha' of positive reals. -// It is a multivariate generalization of the dirichlet distribution, hence its -// alternative name of multivariate dirichlet distribution (MBD). +// It is a multivariate generalization of the beta distribution, hence its +// alternative name of multivariate beta distribution (MBD). // Dirichlet distributions are commonly used as prior distributions in // Bayesian statistics, and in fact the Dirichlet distribution is the // conjugate prior of the categorical distribution and multinomial distribution. -#ifndef BOOST_MATH_DIST_DIRICHLET_HPP -#define BOOST_MATH_DIST_DIRICHLET_HPP +#ifndef BOOST_MATH_DISTRIBUTIONS_DIRICHLET_HPP +#define BOOST_MATH_DISTRIBUTIONS_DIRICHLET_HPP #include +#include #include -#include #include +#include #include #include @@ -43,350 +44,418 @@ namespace math namespace dirichlet_detail { // Common error checking routines for dirichlet distribution function: -template -inline bool check_concentration(const char *function, - const VectorType &concentration, - RealType *result, - const Policy &pol) +template +inline bool check_alpha(const char *function, + const RandomAccessContainer &alpha, + typename RandomAccessContainer::value_type *result, + const Policy &pol) { - for (size_t i = 0; i < concentration.size(); ++i) + using RealType = RandomAccessContainer::value_type; + for (decltype(alpha.size()) i = 0; i < alpha.size(); ++i) { - if (!(boost::math::isfinite)(i) || (i <= 0)) + if (!(boost::math::isfinite)(alpha[i]) || (alpha[i] <= 0)) { *result = policies::raise_domain_error( function, - "Concentration Parameter is %1%, but must be > 0 !", concentration, pol); + "alpha Parameter is %1%, but must be > 0 !", alpha[i], pol); return false; } } return true; -} // bool check_concentration +} // bool check_alpha -template +template inline bool check_prob(const char *function, - const VectorType &p, - RealType *result, + const RandomAccessContainer &p, + typename RandomAccessContainer::value_type *result, const Policy &pol) { - for (size_t i = 0; i < p.size(); ++i) + using RealType = RandomAccessContainer::value_type; + for (decltype(p.size()) i = 0; i < p.size(); ++i) { - if ((i < 0) || (i > 1) || !(boost::math::isfinite)(i)) + if ((p[i] < 0) || (p[i] > 1) || !(boost::math::isfinite)(p[i])) { *result = policies::raise_domain_error( function, - "Probability argument is %1%, but must be >= 0 and <= 1 !", i, pol); + "Probability argument is %1%, but must be >= 0 and <= 1 !", p[i], pol); return false; } } return true; } // bool check_prob -template +template inline bool check_x(const char *function, - const VectorType &x, - RealType *result, + const RandomAccessContainer &x, + typename RandomAccessContainer::value_type *result, const Policy &pol) { - for (size_t i = 0; i < x.size(); ++i) + using RealType = typename RandomAccessContainer::value_type; + for (decltype(x.size()) i = 0; i < x.size(); ++i) { - if (!(boost::math::isfinite)(x) || (x < 0) || (x > 1)) + if (!(boost::math::isfinite)(x[i]) || (x[i] < 0) || (x[i] > 1)) { *result = policies::raise_domain_error( function, - "x argument is %1%, but must be >= 0 and <= 1 !", x, pol); + "x argument is %1%, but must be >= 0 and <= 1 !", x[i], pol); return false; } } return true; } // bool check_x -template +template inline bool check_dist(const char *function, - const VectorType &concentration, - RealType *result, + const RandomAccessContainer &alpha, + typename RandomAccessContainer::value_type *result, const Policy &pol) { - return check_concentration(function, concentration, result, pol); + return check_alpha(function, alpha, result, pol); } // bool check_dist -template +template inline bool check_dist_and_x(const char *function, - const VectorType &concentration, - const VectorType &x, - RealType *result, + const RandomAccessContainer &alpha, + const RandomAccessContainer &x, + typename RandomAccessContainer::value_type *result, const Policy &pol) { - return check_dist(function, concentration, result, pol) && check_x(function, x, result, pol); + return check_dist(function, alpha, result, pol) && check_x(function, x, result, pol); } // bool check_dist_and_x -template +template inline bool check_dist_and_prob(const char *function, - const VectorType &concentration, - const VectorType &p, - RealType *result, + const RandomAccessContainer &alpha, + const RandomAccessContainer &p, + typename RandomAccessContainer::value_type *result, const Policy &pol) { - return check_dist(function, concentration, result, pol) && check_prob(function, p, result, pol); + return check_dist(function, alpha, result, pol) && check_prob(function, p, result, pol); } // bool check_dist_and_prob -template +template inline bool check_mean(const char *function, - const VectorType &mean, - RealType *result, + const RandomAccessContainer &mean, + typename RandomAccessContainer::value_type *result, const Policy &pol) { - for (size_t i = 0; i < mean.size(); ++i) + using RealType = typename RandomAccessContainer::value_type; + for (decltype(mean.size()) i = 0; i < mean.size(); ++i) { - if (!(boost::math::isfinite)(i) || (i <= 0)) + if (!(boost::math::isfinite)(mean[i]) || (mean[i] <= 0)) { *result = policies::raise_domain_error( function, - "mean argument is %1%, but must be > 0 !", i, pol); + "mean argument is %1%, but must be > 0 !", mean[i], pol); return false; } } return true; } // bool check_mean -template +template inline bool check_variance(const char *function, - const VectorType &variance, - RealType *result, + const RandomAccessContainer &variance, + typename RandomAccessContainer::value_type *result, const Policy &pol) { - for (size_t i = 0; i < variance.size(); ++i) + using RealType = typename RandomAccessContainer::value_type; + for (decltype(variance.size()) i = 0; i < variance.size(); ++i) { - if (!(boost::math::isfinite)(i) || (i <= 0)) + if (!(boost::math::isfinite)(variance[i]) || (variance[i] <= 0)) { *result = policies::raise_domain_error( function, - "variance argument is %1%, but must be > 0 !", i, pol); + "variance argument is %1%, but must be > 0 !", variance[i], pol); return false; } } return true; } // bool check_variance + +template +inline bool check_mean_and_variance(const char *function, + const RandomAccessContainer &mean, + const RandomAccessContainer &variance, + typename RandomAccessContainer::value_type *result, + const Policy &pol) +{ + return check_mean(function, mean, result, pol) && check_variance(function, variance, result, pol); +} // bool check_mean_and_variance + +template +inline typename RandomAccessContainer::value_type mvar_beta( + const RandomAccessContainer &alpha, + const typename RandomAccessContainer::value_type &b) +{ + // B(a1,a2,...ak) = tgamma(a1+a2+...+ak)/(tgamma(a1)*tgamma(a2)...*tgamma(ak) + using RealType = typename RandomAccessContainer::value_type; + RealType mb; + RealType alpha_sum = accumulate(alpha.begin(), alpha.end(), b * alpha.size()); + for (decltype(alpha.size()) i = 0; i < alpha.size(); ++i) + { + mb *= tgamma(alpha[i] + b); + } + mb /= tgamma(alpha_sum); + return mb; +} // mvar_beta + +template +inline typename RandomAccessContainer::value_type alpha0(const RandomAccessContainer &alpha) +{ + return accumulate(alpha.begin(), alpha.end(), 0); +} } // namespace dirichlet_detail -template , class RealType = double, class Policy = policies::policy<>> +template , class Policy = policies::policy<>> class dirichlet_distribution { + using RealType = typename RandomAccessContainer::value_type; + public: - dirichlet_distribution(VectorType concentration) : concentration(concentration) + dirichlet_distribution(RandomAccessContainer &&alpha) : m_alpha(alpha) { RealType result; dirichlet_detail::check_dist( "boost::math::dirichlet_distribution<%1%>::dirichlet_distribution", - concentration, + alpha, &result, Policy()); - sum_concentration = accumulate(concentration.begin(), concentration.end(), 0); } // dirichlet_distribution constructor. - // Accessor functions: - VectorType Concentration() const - { - return concentration; - } + // Get the concentration parameters. + RandomAccessContainer &alpha() const { return m_alpha; } + // Set the concentration parameters. + RandomAccessContainer &alpha() { return m_alpha; } - size_t Order() const - { - return concentration.size(); - } + // Get the order of concentration parameters. + decltype(m_alpha.size()) &order() const { return m_alpha.size(); } - static VectorType find_concentration( - VectorType mean, // Expected value of mean. - VectorType variance) // Expected value of variance. + // Get alpha from mean and variance. + static void find_alpha( + RandomAccessContainer &&mean, // Expected value of mean. + RandomAccessContainer &&variance) // Expected value of variance. { assert(("Dimensions of mean and variance must be same!", mean.size() == variance.size())); - static const char *function = "boost::math::dirichlet_distribution<%1%>::find_concentration"; + static const char *function = "boost::math::dirichlet_distribution<%1%>::alpha_from_mean_and_variance"; RealType result = 0; // of error checks. if (!(dirichlet_detail::check_mean(function, mean, &result, Policy()) && dirichlet_detail::check_variance(function, variance, &result, Policy()))) { return result; } - VectorType c; - for (size_t i = 0; i < mean.size(); ++i) + for (decltype(mean.size()) i = 0; i < mean.size(); ++i) { - c.push_back(mean[i] * (((mean[i] * (1 - mean[i])) / variance[i]) - 1)); + m_alpha[i] = mean[i] * (((mean[i] * (1 - mean[i])) / variance[i]) - 1); } - return c; - } // RealType find_concentration - - // TODO - // static VectorType find_concentration( - // VectorType x, // x. - // VectorType probability) // cdf - // { - // assert(("", x.size() == probability.size())); - // static const char *function = "boost::math::dirichlet_distribution<%1%>::find_conentration"; - // RealType result = 0; // of error checks. - // if (!(dirichlet_detail::check_prob(function, probability, &result, Policy()) && dirichlet_detail::check_x(function, x, &result, Policy()))) - // { - // return result; - // } - // return ; - // } // RealType find_concentration(x, probability) + } // void find_alpha private: - VectorType concentration; // https://en.wikipedia.org/wiki/Concentration_parameter. - RealType sum_concentration; + RandomAccessContainer m_alpha; // https://en.wikipedia.org/wiki/Concentration_parameter. }; // template class dirichlet_distribution - -template -inline const std::pair range(const dirichlet_distribution & /* dist */) +template +inline const std::pair< + typename RandomAccessContainer::value_type, + typename RandomAccessContainer::value_type> +range(const dirichlet_distribution & /* dist */) { // Range of permissible values for random variable x. using boost::math::tools::max_value; + using RealType = typename RandomAccessContainer::value_type; return std::pair(static_cast(0), static_cast(1)); } - -template -inline const std::pair support(const dirichlet_distribution & /* dist */) +template +inline const std::pair< + typename RandomAccessContainer::value_type, + typename RandomAccessContainer::value_type> +support(const dirichlet_distribution & /* dist */) { // Range of supported values for random variable x. // This is range where cdf rises from 0 to 1, and outside it, the pdf is zero. + using RealType = typename RandomAccessContainer::value_type; return std::pair(static_cast(0), static_cast(1)); } - -template -inline VectorType mean(const dirichlet_distribution &dist) +template +inline RandomAccessContainer mean(const dirichlet_distribution &dist) { // Mean of dirichlet distribution = c[i]/sum(c). - VectorType m; - for (size_t i = 0; i < dist.Order(); ++i) + // using RealType = typename RandomAccessContainer::value_type; + RandomAccessContainer m(dist.order()); + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) { - m.push_back(dist.concentration[i] / dist.sum_concentration); + m[i] = dist.m_alpha[i] / dirichlet_detail::alpha0(dist.m_alpha); } return m; - } // mean - -template -inline VectorType variance(const dirichlet_distribution &dist) +template +inline RandomAccessContainer variance(const dirichlet_distribution &dist) { - VectorType v; - for (size_t i = 0; i < dist.Order(); ++i) + RandomAccessContainer v(dist.order()); + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) { - v.push_back(dist.concentration[i] / dist.sum_concentration * (1 - dist.concentration[i] / dist.sum_concentration) / (1 + dist.sum_concentration)); + v[i] = dist.m_alpha[i] / dirichlet_detail::alpha0(dist.m_alpha) * (1 - dist.m_alpha[i] / dirichlet_detail::alpha0(dist.alpha)) / (1 + dirichlet_detail::alpha0(dist.alpha)); } return v; } // variance +template +inline RandomAccessContainer standard_deviation(const dirichlet_distribution &dist) +{ + RandomAccessContainer std = variance(dist); + for (decltype(dist.order()) i = 0; i < s.size(); ++i) + { + std[i] = std::sqrt(std[i]); + } + return std; +} // standard_deviation -template -inline VectorType mode(const dirichlet_distribution &dist) +template +inline RandomAccessContainer mode(const dirichlet_distribution &dist) { + using RealType = typename RandomAccessContainer::value_type; static const char *function = "boost::math::mode(dirichlet_distribution<%1%> const&)"; - VectorType m; - for (size_t i = 0; i < dist.Order(); ++i) + RandomAccessContainer m(dist.order()); + for (decltype(dist.order()) i = 0; i < m.size(); ++i) { - if ((dist.concentration[i] <= 1)) + if (dist.m_alpha[i] <= 1) { result = policies::raise_domain_error( function, - "mode undefined for alpha = %1%, must be > 1!", dist.alpha(), Policy()); + "mode undefined for alpha = %1%, must be > 1!", dist.m_alpha[i], Policy()); return result; } else { - m.push_back((dist.concentration[i] - 1) / (dist.sum_concentration - dist.Order())); + m[i] = (dist.m_alpha[i] - 1) / (dirichlet_detail::alpha0(dist.m_alpha) - dist.order()); } } return m; } // mode +template +inline typename RandomAccessContainer::value_type entropy(const dirichlet_distribution &dist) +{ + using RealType = typename RandomAccessContainer::value_type; + RealType ent = std::log(dirichlet_detail::mvar_beta(dist.m_alpha, 0)) + (dirichlet_detail::alpha0(dist.m_alpha) - dist.order()) * digamma(dirichlet_detail::alpha0(dist.m_alpha)); + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) + { + ent += (dist.m_alpha[i] - 1) * digamma(dist.m_alpha[i]); + } + return ent; +} -template -inline RealType entropy(const dirichlet_distribution &dist) +template +inline RandomAccessContainer skewness(const dirichlet_distribution &dist) { - RealType t1 = 1; - for (size_t i = 0; i < dist.Order(); ++i) + using RealType = typename RandomAccessContainer::value_type; + RandomAccessContainer s(dist.order()); + RealType A = dirichlet_detail::alpha0(dist.m_alpha); + RealType aj; + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) { - t1 *= tgamma(dist.concentration[i]); + aj = dist.m_alpha[i]; + s[i] = std::sqrt(aj * (A + 1) / (A - aj)) * ((aj + 2) * (aj + 1) * A * A / (aj * (A + 2) * (A - aj)) - 3 - aj * (A + 1) / (A - aj)); } - t1 = std::log(t1 / tgamma(dist.sum_concentration)); - RealType t2 = (dist.sum_concentration - dist.Order()) * digamma(dist.sum_concentration); - RealType t3 = 0; - for (size_t i = 0; i < dist.Order(); ++i) + return s; +} + +template +inline RandomAccessContainer kurtosis(const dirichlet_distribution &dist) +{ + using RealType = typename RandomAccessContainer::value_type; + using std::pow; + RandomAccessContainer k(dist.order()); + RealType A = dirichlet_detail::alpha0(dist.m_alpha); + RealType aj; + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) { - t3 += (dist.concentration[i] - 1) * digamma(dist.concentration[i]); + aj = dist.m_alpha[i]; + k[i] = ((aj + 2) * (aj + 1) * ((aj + 3) * A / (A + 3) / aj - 4) + 6 * (aj + 1) * aj / (A + 1) / A - 3 * pow(aj / A, 2)) / std::pow((A - aj) / A / (A + 1), 2); } - return t1 + t2 - t3; + return k; } +template +inline RandomAccessContainer kurtosis_excess(const dirichlet_distribution &dist) +{ + RandomAccessContainer ke = kurtosis(dist); + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) + { + ke[i] = ke[i] - 3; + } + return ke; +} -template -inline RealType pdf(const dirichlet_distribution &dist, const VectorType &x) +template +inline typename RandomAccessContainer::value_type pdf( + const dirichlet_distribution &dist, + const RandomAccessContainer &x) { // Probability Density/Mass Function. + using RealType = typename RandomAccessContainer::value_type; + using std::pow; BOOST_FPU_EXCEPTION_GUARD - static const char *function = "boost::math::pdf(dirichlet_distribution<%1%> const&, %1%)"; - BOOST_MATH_STD_USING // for ADL of std functions - - // Argument checks: RealType result = 0; if (!dirichlet_detail::check_dist_and_x(function, x, &result, Policy())) { return result; } - using boost::math::tgamma; + RealType f = 1; - for (size_t i = 0; i < dist.Order(); ++i) + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) { - f *= std::pow(x[i], dist.concentration[i] - 1); + f *= pow(x[i], dist.m_alpha[i] - 1); } - f /= dist.normalizing_factor; + f /= dirichlet_detail::mvar_beta(dist.m_alpha, 0); return f; } // pdf - -template -inline RealType cdf(const dirichlet_distribution &dist, const VectorType &x) +template +inline typename RandomAccessContainer::value_type cdf( + const dirichlet_distribution &dist, + const RandomAccessContainer &x) { // Cumulative Distribution Function dirichlet. + using RealType = typename RandomAccessContainer::value_type; + using std::pow; BOOST_MATH_STD_USING // for ADL of std functions - static const char *function = "boost::math::cdf(dirichlet_distribution<%1%> const&, %1%)"; - - // Argument checks: - RealType result = 0; - if (!dirichlet_detail::check_dist_and_x(function, dist.concentration, x, &result, Policy())) + static const char *function = "boost::math::cdf(dirichlet_distribution<%1%> const&, %1%)"; + RealType result = 0; // Arguments check. + if (!dirichlet_detail::check_dist_and_x(function, dist.alpha, x, &result, Policy())) { return result; } RealType c = 1; - for (size_t i = 0; i < dist.Order(); ++i) + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) { - c *= std::pow(x[i], dist.concentration[i]) / tgamma(dist.concentration[i]) / dist.concentration[i]; + c *= pow(x[i], dist.m_alpha[i]) / tgamma(dist.m_alpha[i]) / dist.m_alpha[i]; } - c *= tgamma(dist.sum_concentration); + c *= tgamma(dirichlet_detail::alpha0(dist.m_alpha)); return c; } // dirichlet cdf -template -inline RealType cdf(const complemented2_type, RealType> &c) +template +inline typename RandomAccessContainer::value_type cdf( + const complemented2_type, + typename RandomAccessContainer::value_type> &c) { // Complemented Cumulative Distribution Function dirichlet. - + using RealType = typename RandomAccessContainer::value_type; + using std::pow; BOOST_MATH_STD_USING // for ADL of std functions - static const char *function = "boost::math::cdf(dirichlet_distribution<%1%> const&, %1%)"; - RealType const &x = c.param; dirichlet_distribution const &dist = c.dist; - - // Argument checks: - RealType result = 0; + RealType result = 0; // Argument checks. if (!dirichlet_detail::check_dist_and_x(function, x, &result, Policy())) { return result; } RealType cumm = 1; - for (size_t i = 0; i < dist.Order(); ++i) + for (decltype(dist.order()) i = 0; i < dist.order(); ++i) { - cumm *= std::pow(x[i], dist.concentration[i]) / tgamma(dist.concentration[i]) / dist.concentration[i]; + cumm *= pow(x[i], dist.m_alpha[i]) / tgamma(dist.m_alpha[i]) / dist.m_alpha[i]; } - cumm *= tgamma(dist.sum_concentration); + cumm *= tgamma(dirichlet_detail::alpha0(dist.m_alpha)); return cumm; } // dirichlet cdf @@ -402,4 +471,4 @@ inline RealType cdf(const complemented2_type, double, boost::math::policies::policy<>>; -template class boost::math::dirichlet_distribution, float, boost::math::policies::policy<>>; +template class boost::math::dirichlet_distribution, boost::math::policies::policy<>>; +template class boost::math::dirichlet_distribution, boost::math::policies::policy<>>; #ifndef BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS -template class boost::math::dirichlet_distribution, long double, boost::math::policies::policy<>>; +template class boost::math::dirichlet_distribution, boost::math::policies::policy<>>; #endif From a73ff6b70ba66911ad09d2305a9e7818d55a348c Mon Sep 17 00:00:00 2001 From: Mrityunjay Tripathi Date: Mon, 2 Mar 2020 16:06:31 +0530 Subject: [PATCH 3/7] slight errors removed --- .../boost/math/distributions/dirichlet.hpp | 124 +++++++++--------- 1 file changed, 63 insertions(+), 61 deletions(-) diff --git a/include/boost/math/distributions/dirichlet.hpp b/include/boost/math/distributions/dirichlet.hpp index b286933257..576e14b524 100644 --- a/include/boost/math/distributions/dirichlet.hpp +++ b/include/boost/math/distributions/dirichlet.hpp @@ -22,7 +22,6 @@ #define BOOST_MATH_DISTRIBUTIONS_DIRICHLET_HPP #include -#include #include #include #include @@ -50,7 +49,7 @@ inline bool check_alpha(const char *function, typename RandomAccessContainer::value_type *result, const Policy &pol) { - using RealType = RandomAccessContainer::value_type; + using RealType = typename RandomAccessContainer::value_type; for (decltype(alpha.size()) i = 0; i < alpha.size(); ++i) { if (!(boost::math::isfinite)(alpha[i]) || (alpha[i] <= 0)) @@ -70,7 +69,7 @@ inline bool check_prob(const char *function, typename RandomAccessContainer::value_type *result, const Policy &pol) { - using RealType = RandomAccessContainer::value_type; + using RealType = typename RandomAccessContainer::value_type; for (decltype(p.size()) i = 0; i < p.size(); ++i) { if ((p[i] < 0) || (p[i] > 1) || !(boost::math::isfinite)(p[i])) @@ -182,29 +181,6 @@ inline bool check_mean_and_variance(const char *function, { return check_mean(function, mean, result, pol) && check_variance(function, variance, result, pol); } // bool check_mean_and_variance - -template -inline typename RandomAccessContainer::value_type mvar_beta( - const RandomAccessContainer &alpha, - const typename RandomAccessContainer::value_type &b) -{ - // B(a1,a2,...ak) = tgamma(a1+a2+...+ak)/(tgamma(a1)*tgamma(a2)...*tgamma(ak) - using RealType = typename RandomAccessContainer::value_type; - RealType mb; - RealType alpha_sum = accumulate(alpha.begin(), alpha.end(), b * alpha.size()); - for (decltype(alpha.size()) i = 0; i < alpha.size(); ++i) - { - mb *= tgamma(alpha[i] + b); - } - mb /= tgamma(alpha_sum); - return mb; -} // mvar_beta - -template -inline typename RandomAccessContainer::value_type alpha0(const RandomAccessContainer &alpha) -{ - return accumulate(alpha.begin(), alpha.end(), 0); -} } // namespace dirichlet_detail template , class Policy = policies::policy<>> @@ -223,20 +199,20 @@ class dirichlet_distribution } // dirichlet_distribution constructor. // Get the concentration parameters. - RandomAccessContainer &alpha() const { return m_alpha; } + const RandomAccessContainer &get_alpha() const { return m_alpha; } // Set the concentration parameters. - RandomAccessContainer &alpha() { return m_alpha; } + RandomAccessContainer &set_alpha() { return m_alpha; } // Get the order of concentration parameters. - decltype(m_alpha.size()) &order() const { return m_alpha.size(); } + size_t order() const { return m_alpha.size(); } // Get alpha from mean and variance. - static void find_alpha( + void find_alpha( RandomAccessContainer &&mean, // Expected value of mean. RandomAccessContainer &&variance) // Expected value of variance. { assert(("Dimensions of mean and variance must be same!", mean.size() == variance.size())); - static const char *function = "boost::math::dirichlet_distribution<%1%>::alpha_from_mean_and_variance"; + static const char *function = "boost::math::dirichlet_distribution<%1%>::find_alpha"; RealType result = 0; // of error checks. if (!(dirichlet_detail::check_mean(function, mean, &result, Policy()) && dirichlet_detail::check_variance(function, variance, &result, Policy()))) { @@ -248,9 +224,27 @@ class dirichlet_distribution } } // void find_alpha + RealType normalizing_constant(RealType b = 0.0) const + { + // B(a1,a2,...ak) = tgamma(a1+a2+...+ak)/(tgamma(a1)*tgamma(a2)...*tgamma(ak) + RealType mb; + RealType alpha_sum = accumulate(m_alpha.begin(), m_alpha.end(), b * m_alpha.size()); + for (decltype(m_alpha.size()) i = 0; i < m_alpha.size(); ++i) + { + mb *= tgamma(m_alpha[i] + b); + } + mb /= tgamma(alpha_sum); + return mb; + } // normalizing_constant + + RealType sum_alpha() const + { + return accumulate(m_alpha.begin(), m_alpha.end(), 0); + } // sum_alpha + private: RandomAccessContainer m_alpha; // https://en.wikipedia.org/wiki/Concentration_parameter. -}; // template class dirichlet_distribution +}; // template class dirichlet_distribution template inline const std::pair< @@ -277,11 +271,12 @@ support(const dirichlet_distribution & /* dist */ template inline RandomAccessContainer mean(const dirichlet_distribution &dist) { // Mean of dirichlet distribution = c[i]/sum(c). - // using RealType = typename RandomAccessContainer::value_type; + using RealType = typename RandomAccessContainer::value_type; + RealType A = dist.sum_alpha(); RandomAccessContainer m(dist.order()); for (decltype(dist.order()) i = 0; i < dist.order(); ++i) { - m[i] = dist.m_alpha[i] / dirichlet_detail::alpha0(dist.m_alpha); + m[i] = dist.get_alpha()[i] / A; } return m; } // mean @@ -289,10 +284,12 @@ inline RandomAccessContainer mean(const dirichlet_distribution inline RandomAccessContainer variance(const dirichlet_distribution &dist) { + using RealType = typename RandomAccessContainer::value_type; + RealType A = dist.sum_alpha(); RandomAccessContainer v(dist.order()); for (decltype(dist.order()) i = 0; i < dist.order(); ++i) { - v[i] = dist.m_alpha[i] / dirichlet_detail::alpha0(dist.m_alpha) * (1 - dist.m_alpha[i] / dirichlet_detail::alpha0(dist.alpha)) / (1 + dirichlet_detail::alpha0(dist.alpha)); + v[i] = dist.get_alpha()[i] / A * (1 - dist.get_alpha()[i] / A) / (1 + A); } return v; } // variance @@ -301,7 +298,7 @@ template inline RandomAccessContainer standard_deviation(const dirichlet_distribution &dist) { RandomAccessContainer std = variance(dist); - for (decltype(dist.order()) i = 0; i < s.size(); ++i) + for (decltype(dist.order()) i = 0; i < std.size(); ++i) { std[i] = std::sqrt(std[i]); } @@ -313,19 +310,21 @@ inline RandomAccessContainer mode(const dirichlet_distribution( function, - "mode undefined for alpha = %1%, must be > 1!", dist.m_alpha[i], Policy()); + "mode undefined for alpha = %1%, must be > 1!", dist.get_alpha()[i], Policy()); return result; } else { - m[i] = (dist.m_alpha[i] - 1) / (dirichlet_detail::alpha0(dist.m_alpha) - dist.order()); + m[i] = (dist.get_alpha()[i] - 1) / (A - dist.order()); } } return m; @@ -335,10 +334,11 @@ template inline typename RandomAccessContainer::value_type entropy(const dirichlet_distribution &dist) { using RealType = typename RandomAccessContainer::value_type; - RealType ent = std::log(dirichlet_detail::mvar_beta(dist.m_alpha, 0)) + (dirichlet_detail::alpha0(dist.m_alpha) - dist.order()) * digamma(dirichlet_detail::alpha0(dist.m_alpha)); + using std::log; + RealType ent = log(dist.normalizing_constant()) + (dist.sum_alpha() - dist.order()) * digamma(dist.sum_alpha()); for (decltype(dist.order()) i = 0; i < dist.order(); ++i) { - ent += (dist.m_alpha[i] - 1) * digamma(dist.m_alpha[i]); + ent += (dist.get_alpha()[i] - 1) * digamma(dist.get_alpha()[i]); } return ent; } @@ -348,11 +348,11 @@ inline RandomAccessContainer skewness(const dirichlet_distribution inline typename RandomAccessContainer::value_type cdf( const complemented2_type, - typename RandomAccessContainer::value_type> &c) + typename RandomAccessContainer::value_type> &c) { // Complemented Cumulative Distribution Function dirichlet. using RealType = typename RandomAccessContainer::value_type; using std::pow; BOOST_MATH_STD_USING // for ADL of std functions - static const char *function = "boost::math::cdf(dirichlet_distribution<%1%> const&, %1%)"; + const char *function = "boost::math::cdf(dirichlet_distribution<%1%> const&, %1%)"; RealType const &x = c.param; dirichlet_distribution const &dist = c.dist; + RealType A = dist.sum_alpha(); RealType result = 0; // Argument checks. - if (!dirichlet_detail::check_dist_and_x(function, x, &result, Policy())) + if (!dirichlet_detail::check_dist_and_x(function, dist.get_alpha(), x, &result, Policy())) { return result; } RealType cumm = 1; for (decltype(dist.order()) i = 0; i < dist.order(); ++i) { - cumm *= pow(x[i], dist.m_alpha[i]) / tgamma(dist.m_alpha[i]) / dist.m_alpha[i]; + cumm *= (1 - pow(x[i], dist.get_alpha()[i])) / tgamma(dist.get_alpha()[i]) / dist.get_alpha()[i]; } - cumm *= tgamma(dirichlet_detail::alpha0(dist.m_alpha)); + cumm *= tgamma(A); return cumm; } // dirichlet cdf @@ -471,4 +473,4 @@ inline typename RandomAccessContainer::value_type cdf( #pragma warning(pop) #endif -#endif // BOOST_MATH_DIST_DIRICHLET_HPP +#endif // BOOST_MATH_DISTRIBUTIONS_DIRICHLET_HPP From 658afe91aa6f027cac4e4af1a570d51fdbca8226 Mon Sep 17 00:00:00 2001 From: Mrityunjay Tripathi Date: Tue, 3 Mar 2020 20:59:45 +0530 Subject: [PATCH 4/7] error correction --- .../boost/math/distributions/dirichlet.hpp | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/include/boost/math/distributions/dirichlet.hpp b/include/boost/math/distributions/dirichlet.hpp index 576e14b524..9a81b06fb2 100644 --- a/include/boost/math/distributions/dirichlet.hpp +++ b/include/boost/math/distributions/dirichlet.hpp @@ -100,7 +100,7 @@ inline bool check_x(const char *function, return false; } } - return true; + return std::accumulate(x.begin(), x.end(), 0.0) <= 1.0; } // bool check_x template @@ -200,16 +200,14 @@ class dirichlet_distribution // Get the concentration parameters. const RandomAccessContainer &get_alpha() const { return m_alpha; } - // Set the concentration parameters. - RandomAccessContainer &set_alpha() { return m_alpha; } // Get the order of concentration parameters. size_t order() const { return m_alpha.size(); } // Get alpha from mean and variance. void find_alpha( - RandomAccessContainer &&mean, // Expected value of mean. - RandomAccessContainer &&variance) // Expected value of variance. + RandomAccessContainer &mean, // Expected value of mean. + RandomAccessContainer &variance) const// Expected value of variance. { assert(("Dimensions of mean and variance must be same!", mean.size() == variance.size())); static const char *function = "boost::math::dirichlet_distribution<%1%>::find_alpha"; @@ -226,8 +224,8 @@ class dirichlet_distribution RealType normalizing_constant(RealType b = 0.0) const { - // B(a1,a2,...ak) = tgamma(a1+a2+...+ak)/(tgamma(a1)*tgamma(a2)...*tgamma(ak) - RealType mb; + // B(a1,a2,...ak) = (tgamma(a1)*tgamma(a2)...*tgamma(ak)/tgamma(a1+a2+...+ak) + RealType mb = 1.0; RealType alpha_sum = accumulate(m_alpha.begin(), m_alpha.end(), b * m_alpha.size()); for (decltype(m_alpha.size()) i = 0; i < m_alpha.size(); ++i) { @@ -239,7 +237,8 @@ class dirichlet_distribution RealType sum_alpha() const { - return accumulate(m_alpha.begin(), m_alpha.end(), 0); + RealType init = 0.0; + return std::accumulate(m_alpha.begin(), m_alpha.end(), init); } // sum_alpha private: @@ -289,7 +288,7 @@ inline RandomAccessContainer variance(const dirichlet_distribution inline RandomAccessContainer standard_deviation(const dirichlet_distribution &dist) { + using std::sqrt; RandomAccessContainer std = variance(dist); for (decltype(dist.order()) i = 0; i < std.size(); ++i) { - std[i] = std::sqrt(std[i]); + std[i] = sqrt(std[i]); } return std; } // standard_deviation @@ -347,13 +347,14 @@ template inline RandomAccessContainer skewness(const dirichlet_distribution &dist) { using RealType = typename RandomAccessContainer::value_type; + using std::sqrt; RandomAccessContainer s(dist.order()); RealType A = dist.sum_alpha(); RealType aj; for (decltype(dist.order()) i = 0; i < dist.order(); ++i) { aj = dist.get_alpha()[i]; - s[i] = std::sqrt(aj * (A + 1) / (A - aj)) * ((aj + 2) * (aj + 1) * A * A / (aj * (A + 2) * (A - aj)) - 3 - aj * (A + 1) / (A - aj)); + s[i] = sqrt(aj * (A + 1) / (A - aj)) * ((aj + 2) * (aj + 1) * A * A / (aj * (A + 2) * (A - aj)) - 3 - aj * (A + 1) / (A - aj)); } return s; } @@ -396,7 +397,7 @@ inline typename RandomAccessContainer::value_type pdf( const char *function = "boost::math::pdf(dirichlet_distribution<%1%> const&, %1%)"; BOOST_MATH_STD_USING // for ADL of std functions - RealType result = 0; + RealType result = 0; if (!dirichlet_detail::check_dist_and_x(function, dist.get_alpha(), x, &result, Policy())) { return result; From c773293fd43f93e7e136eef9b0bbb43263cdb5c9 Mon Sep 17 00:00:00 2001 From: Mrityunjay Tripathi Date: Fri, 6 Mar 2020 14:22:59 +0530 Subject: [PATCH 5/7] adding tests dirichlet --- .../boost/math/distributions/dirichlet.hpp | 124 +++------ .../compile_test/dist_dirichlet_incl_test.cpp | 25 -- test/test_dirichlet_dist.cpp | 258 ++++++++++++++++++ 3 files changed, 299 insertions(+), 108 deletions(-) delete mode 100644 test/compile_test/dist_dirichlet_incl_test.cpp create mode 100644 test/test_dirichlet_dist.cpp diff --git a/include/boost/math/distributions/dirichlet.hpp b/include/boost/math/distributions/dirichlet.hpp index 9a81b06fb2..d014cfda2a 100644 --- a/include/boost/math/distributions/dirichlet.hpp +++ b/include/boost/math/distributions/dirichlet.hpp @@ -50,6 +50,12 @@ inline bool check_alpha(const char *function, const Policy &pol) { using RealType = typename RandomAccessContainer::value_type; + using std::invalid_argument; + if (alpha.size() < 1) + { + throw invalid_argument("Size of 'concentration parameters' must be greater than 0."); + return false; + } for (decltype(alpha.size()) i = 0; i < alpha.size(); ++i) { if (!(boost::math::isfinite)(alpha[i]) || (alpha[i] <= 0)) @@ -63,26 +69,6 @@ inline bool check_alpha(const char *function, return true; } // bool check_alpha -template -inline bool check_prob(const char *function, - const RandomAccessContainer &p, - typename RandomAccessContainer::value_type *result, - const Policy &pol) -{ - using RealType = typename RandomAccessContainer::value_type; - for (decltype(p.size()) i = 0; i < p.size(); ++i) - { - if ((p[i] < 0) || (p[i] > 1) || !(boost::math::isfinite)(p[i])) - { - *result = policies::raise_domain_error( - function, - "Probability argument is %1%, but must be >= 0 and <= 1 !", p[i], pol); - return false; - } - } - return true; -} // bool check_prob - template inline bool check_x(const char *function, const RandomAccessContainer &x, @@ -90,6 +76,12 @@ inline bool check_x(const char *function, const Policy &pol) { using RealType = typename RandomAccessContainer::value_type; + using std::invalid_argument; + if (x.size() < 1) + { + throw invalid_argument("Size of 'quantiles' vector must be greater than 0."); + return false; + } for (decltype(x.size()) i = 0; i < x.size(); ++i) { if (!(boost::math::isfinite)(x[i]) || (x[i] < 0) || (x[i] > 1)) @@ -100,38 +92,20 @@ inline bool check_x(const char *function, return false; } } - return std::accumulate(x.begin(), x.end(), 0.0) <= 1.0; + return accumulate(x.begin(), x.end(), 0.0) <= 1.0; } // bool check_x -template -inline bool check_dist(const char *function, - const RandomAccessContainer &alpha, - typename RandomAccessContainer::value_type *result, - const Policy &pol) -{ - return check_alpha(function, alpha, result, pol); -} // bool check_dist template -inline bool check_dist_and_x(const char *function, +inline bool check_alpha_and_x(const char *function, const RandomAccessContainer &alpha, const RandomAccessContainer &x, typename RandomAccessContainer::value_type *result, const Policy &pol) { - return check_dist(function, alpha, result, pol) && check_x(function, x, result, pol); + return check_alpha(function, alpha, result, pol) && check_x(function, x, result, pol); } // bool check_dist_and_x -template -inline bool check_dist_and_prob(const char *function, - const RandomAccessContainer &alpha, - const RandomAccessContainer &p, - typename RandomAccessContainer::value_type *result, - const Policy &pol) -{ - return check_dist(function, alpha, result, pol) && check_prob(function, p, result, pol); -} // bool check_dist_and_prob - template inline bool check_mean(const char *function, const RandomAccessContainer &mean, @@ -139,6 +113,12 @@ inline bool check_mean(const char *function, const Policy &pol) { using RealType = typename RandomAccessContainer::value_type; + using std::invalid_argument; + if (mean.size() < 1) + { + throw invalid_argument("Size of 'mean' vector must be greater than 0."); + return false; + } for (decltype(mean.size()) i = 0; i < mean.size(); ++i) { if (!(boost::math::isfinite)(mean[i]) || (mean[i] <= 0)) @@ -159,6 +139,12 @@ inline bool check_variance(const char *function, const Policy &pol) { using RealType = typename RandomAccessContainer::value_type; + using std::invalid_argument; + if (variance.size() < 1) + { + throw invalid_argument("Size of 'variance' vector must be greater than 0."); + return false; + } for (decltype(variance.size()) i = 0; i < variance.size(); ++i) { if (!(boost::math::isfinite)(variance[i]) || (variance[i] <= 0)) @@ -191,28 +177,26 @@ class dirichlet_distribution public: dirichlet_distribution(RandomAccessContainer &&alpha) : m_alpha(alpha) { - RealType result; - dirichlet_detail::check_dist( - "boost::math::dirichlet_distribution<%1%>::dirichlet_distribution", - alpha, - &result, Policy()); + RealType result = 0; + const char *function = "boost::math::dirichlet_distribution<%1%>::dirichlet_distribution"; + dirichlet_detail::check_alpha(function, alpha, &result, Policy()); } // dirichlet_distribution constructor. // Get the concentration parameters. const RandomAccessContainer &get_alpha() const { return m_alpha; } // Get the order of concentration parameters. - size_t order() const { return m_alpha.size(); } + auto order() const { return m_alpha.size(); } // Get alpha from mean and variance. - void find_alpha( + auto find_alpha( RandomAccessContainer &mean, // Expected value of mean. - RandomAccessContainer &variance) const// Expected value of variance. + RandomAccessContainer &variance) // Expected value of variance. { assert(("Dimensions of mean and variance must be same!", mean.size() == variance.size())); static const char *function = "boost::math::dirichlet_distribution<%1%>::find_alpha"; RealType result = 0; // of error checks. - if (!(dirichlet_detail::check_mean(function, mean, &result, Policy()) && dirichlet_detail::check_variance(function, variance, &result, Policy()))) + if (!dirichlet_detail::check_mean_and_variance(function, mean, variance, &result, Policy())) { return result; } @@ -238,7 +222,7 @@ class dirichlet_distribution RealType sum_alpha() const { RealType init = 0.0; - return std::accumulate(m_alpha.begin(), m_alpha.end(), init); + return accumulate(m_alpha.begin(), m_alpha.end(), init); } // sum_alpha private: @@ -310,14 +294,14 @@ inline RandomAccessContainer mode(const dirichlet_distribution( + result[0] = policies::raise_domain_error( function, "mode undefined for alpha = %1%, must be > 1!", dist.get_alpha()[i], Policy()); return result; @@ -394,11 +378,11 @@ inline typename RandomAccessContainer::value_type pdf( using RealType = typename RandomAccessContainer::value_type; using std::pow; BOOST_FPU_EXCEPTION_GUARD + BOOST_MATH_STD_USING // for ADL of std functions const char *function = "boost::math::pdf(dirichlet_distribution<%1%> const&, %1%)"; - BOOST_MATH_STD_USING // for ADL of std functions RealType result = 0; - if (!dirichlet_detail::check_dist_and_x(function, dist.get_alpha(), x, &result, Policy())) + if (!dirichlet_detail::check_x(function, x, &result, Policy())) { return result; } @@ -420,10 +404,10 @@ inline typename RandomAccessContainer::value_type cdf( using RealType = typename RandomAccessContainer::value_type; using std::pow; BOOST_MATH_STD_USING // for ADL of std functions - RealType A = dist.sum_alpha(); + RealType A = dist.sum_alpha(); const char *function = "boost::math::cdf(dirichlet_distribution<%1%> const&, %1%)"; RealType result = 0; // Arguments check. - if (!dirichlet_detail::check_dist_and_x(function, dist.get_alpha(), x, &result, Policy())) + if (!dirichlet_detail::check_x(function, x, &result, Policy())) { return result; } @@ -436,32 +420,6 @@ inline typename RandomAccessContainer::value_type cdf( return c; } // dirichlet cdf -template -inline typename RandomAccessContainer::value_type cdf( - const complemented2_type, - typename RandomAccessContainer::value_type> &c) -{ // Complemented Cumulative Distribution Function dirichlet. - using RealType = typename RandomAccessContainer::value_type; - using std::pow; - BOOST_MATH_STD_USING // for ADL of std functions - const char *function = "boost::math::cdf(dirichlet_distribution<%1%> const&, %1%)"; - RealType const &x = c.param; - dirichlet_distribution const &dist = c.dist; - RealType A = dist.sum_alpha(); - RealType result = 0; // Argument checks. - if (!dirichlet_detail::check_dist_and_x(function, dist.get_alpha(), x, &result, Policy())) - { - return result; - } - RealType cumm = 1; - for (decltype(dist.order()) i = 0; i < dist.order(); ++i) - { - cumm *= (1 - pow(x[i], dist.get_alpha()[i])) / tgamma(dist.get_alpha()[i]) / dist.get_alpha()[i]; - } - cumm *= tgamma(A); - return cumm; -} // dirichlet cdf - } // namespace math } // namespace boost diff --git a/test/compile_test/dist_dirichlet_incl_test.cpp b/test/compile_test/dist_dirichlet_incl_test.cpp deleted file mode 100644 index e8b6364c93..0000000000 --- a/test/compile_test/dist_dirichlet_incl_test.cpp +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright Mrityunjay Tripathi 2020. -// Use, modification and distribution are subject to the -// Boost Software License, Version 1.0. (See accompanying file -// LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) -// -// Basic sanity check that header -// #includes all the files that it needs to. -// -#include -// -// Note this header includes no other headers, this is -// important if this test is to be meaningful: -// -#include "test_compile_result.hpp" - -void compile_and_link_test() -{ - TEST_DIST_FUNC(dirichlet) -} - -template class boost::math::dirichlet_distribution, boost::math::policies::policy<>>; -template class boost::math::dirichlet_distribution, boost::math::policies::policy<>>; -#ifndef BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS -template class boost::math::dirichlet_distribution, boost::math::policies::policy<>>; -#endif diff --git a/test/test_dirichlet_dist.cpp b/test/test_dirichlet_dist.cpp new file mode 100644 index 0000000000..fa6234d494 --- /dev/null +++ b/test/test_dirichlet_dist.cpp @@ -0,0 +1,258 @@ +// test_dirichlet_dist.cpp + +// Copyright Mrityunjay Tripathi 2020. + +// Use, modification and distribution are subject to the +// Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt +// or copy at http://www.boost.org/LICENSE_1_0.txt) + +// Basic sanity tests for the Dirichlet Distribution. + +#ifdef _MSC_VER +#pragma warning(disable : 4127) // conditional expression is constant. +#pragma warning(disable : 4996) // POSIX name for this item is deprecated. +#pragma warning(disable : 4224) // nonstandard extension used : formal parameter 'arg' was previously defined as a type. +#endif + +#include +#include +#include +#include // for test_main +#include +#include // for real_concept +#include // for dirichlet_distribution +#include // for BOOST_CHECK_CLOSE_FRACTION +#include "test_out_of_range.hpp" +#include "math_unit_test.hpp" + +using boost::math::dirichlet_distribution; +using boost::math::concepts::real_concept; +using std::numeric_limits; +using std::domain_error; + +#define BOOST_TEST_MAIN +#define BOOST_MATH_CHECK_THROW + +template +void test_spot( + RandomAccessContainer &&alpha, // concentration parameters 'a' + RandomAccessContainer &&x, // quantiles 'x' + RandomAccessContainer &&mean, // mean + RandomAccessContainer &&mode, // mode + RandomAccessContainer &&var, // variance + RandomAccessContainer &&skewness, // skewness + RandomAccessContainer &&kurtosis, // kurtosis + typename RandomAccessContainer::value_type entropy, // entropy + typename RandomAccessContainer::value_type pdf, // pdf + typename RandomAccessContainer::value_type cdf, // cdf + typename RandomAccessContainer::value_type tol) // Test tolerance. +{ + // using RealType = typename RandomAccessContainer::value_type; + typedef RandomAccessContainer V; + boost::math::dirichlet_distribution diri(std::move(alpha)); + + V calc_mean = boost::math::mean(diri); + V calc_variance = boost::math::variance(diri); + V calc_mode = boost::math::mode(diri); + V calc_kurtosis = boost::math::kurtosis(diri); + V calc_skewness = boost::math::skewness(diri); + + BOOST_CHECK_CLOSE_FRACTION(boost::math::pdf(diri, x), pdf, tol); + BOOST_CHECK_CLOSE_FRACTION(boost::math::cdf(diri, x), cdf, tol); + BOOST_CHECK_CLOSE_FRACTION(boost::math::entropy(diri), entropy, tol); + + for (decltype(alpha.size()) i = 0; i < alpha.size(); ++i) + { + BOOST_CHECK_CLOSE_FRACTION(calc_mean[i], mean[i], tol); + BOOST_CHECK_CLOSE_FRACTION(calc_variance[i], var[i], tol); + BOOST_CHECK_CLOSE_FRACTION(calc_kurtosis[i], kurtosis[i], tol); + BOOST_CHECK_CLOSE_FRACTION(calc_skewness[i], skewness[i], tol); + } +} // template void test_spot + +template +void test_spots(RandomAccessContainer) +{ + typedef RandomAccessContainer V; + using RealType = typename V::value_type; + RealType tolerance = std::max(boost::math::tools::epsilon(), + static_cast(std::numeric_limits::epsilon())); + V alpha(2); + V x(2); + + // Error checks: + // Necessary conditions for instantiation: + // 1. alpha[i] > 0. + alpha[0] = 0.35; + alpha[1] = -1.72; // alpha[1] < 0. + BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); + BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); + + // Domain test for pdf. Necessary conditions for pdf: + // 1. alpha[i] > 0 + // 2. 0 <= x[i] <=1 + // 3. sum(x) <= 1. + alpha[0] = -0.2; + alpha[1] = 1.7; // alpha[0] < 0. + x[0] = 0.5; + x[1] = 0.5; + BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = 1.36; + alpha[1] = 0.0; // alpha[1] = 0. + x[0] = 0.47; + x[1] = 0.53; + BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = 1.26; + alpha[1] = 2.99; + x[0] = 0.5; + x[1] = 0.75; // sum(x) > 1.0 + BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = 1.56; + alpha[1] = 4.00; + x[0] = 0.31; + x[1] = -0.03; // x[1] < 0. + BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = 1.56; + alpha[1] = 4.00; + x[0] = 0.31; + x[1] = 1.06; // x[1] > 1. + BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + // Domain test for cdf. Necessary conditions for cdf: + // 1. alpha[i] > 0 + // 2. 0 <= x[i] <= 1 + // 3. sum(x) <= 1. + alpha[0] = 1.56; + alpha[1] = 4.00; + x[0] = 0.31; + x[1] = 1.06; // x[1] > 1. + BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = 3.756; + alpha[1] = 4.91; + x[0] = 0.31; + x[1] = -1.06; // x[1] < 0. + BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = 1.56; + alpha[1] = -4.00; // alpha[1] < 0 + x[0] = 0.31; + x[1] = 0.69; + BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = 0.0; + alpha[1] = 4.00; // alpha[0] = 0. + x[0] = 0.25; + x[1] = 0.75; + BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + alpha[0] = 1.56; + alpha[1] = 4.00; + x[0] = 0.31; + x[1] = 0.71; // sum(x) > 1. + BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + + // Domain test for mode. Necessary conditions for mode: + // 1. alpha[i] > 1. + alpha[0] = 1.0; + alpha[1] = 1.4; // alpha[0] = 1. + BOOST_MATH_CHECK_THROW(boost::math::mode(dirichlet_distribution(std::move(alpha))), std::domain_error); + + alpha[0] = 1.56; + alpha[1] = 0.92; // alpha[1] < 1. + BOOST_MATH_CHECK_THROW(boost::math::mode(dirichlet_distribution(std::move(alpha))), std::domain_error); + + // Some exact values of pdf. + alpha[0] = 1.0, alpha[1] = 1.0; + x[0] = 0.5, x[1] = 0.5; + BOOST_CHECK_EQUAL(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), static_cast(1.0)); + + alpha[0] = 2.0, alpha[1] = 2.0; + x[0] = 0.5, x[1] = 0.5; + BOOST_CHECK_EQUAL(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), static_cast(1.5)); + + // Checking precalculated values on scipy. + // https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.dirichlet.html + alpha[0] = static_cast(5.778238829L); alpha[1] = static_cast(2.55821892973L); + x[0] = static_cast(0.23667289213L); x[1] = static_cast(1.0L - 0.23667289213L); + V mean = {static_cast(0.69312878L), static_cast(0.30687122L)}; + V mode = {static_cast(0.75408675L), static_cast(0.24591325L)}; + V var = {static_cast(0.0227818L), static_cast(0.0227818L)}; + V skewness = {static_cast(-0.49515513L), static_cast(0.49515513L)}; + V kurtosis = {static_cast(-139231.64518864L), static_cast(-6993.41057616L)}; + RealType entropy = static_cast(17.6747L); + RealType pdf = static_cast(0.05866154L); + RealType cdf = static_cast(0.00071693L); + tolerance *= 1E+07; + + test_spot(std::move(alpha), + std::move(x), + std::move(mean), + std::move(mode), + std::move(var), + std::move(skewness), + std::move(kurtosis), + entropy, pdf, cdf, tolerance); + + // No longer allow any parameter to be NaN or inf. + if (std::numeric_limits::has_quiet_NaN) + { + RealType not_a_num = std::numeric_limits::quiet_NaN(); + alpha[0] = not_a_num; alpha[1] = 0.37; +#ifndef BOOST_NO_EXCEPTIONS + BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); +#else + BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); +#endif + + // Non-finite parameters should throw. + alpha[0] = 1.67; alpha[1] = 3.8; + x[0] = not_a_num; x[1] = 0.5; + dirichlet_distribution w(std::move(alpha)); + BOOST_MATH_CHECK_THROW(boost::math::pdf(w, x), std::domain_error); // x = NaN + BOOST_MATH_CHECK_THROW(boost::math::cdf(w, x), std::domain_error); // x = NaN + } // has_quiet_NaN + + if (std::numeric_limits::has_infinity) + { + // Attempt to construct from non-finite should throw. + RealType infinite = std::numeric_limits::infinity(); + alpha[0] = infinite; + alpha[1] = 7.2; +#ifndef BOOST_NO_EXCEPTIONS + BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); +#else + BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); +#endif + alpha[0] = 1.42; alpha[1] = 7.91; + x[0] = 0.25; x[1] = infinite; + dirichlet_distribution w(std::move(alpha)); + BOOST_MATH_CHECK_THROW(boost::math::pdf(w, x), std::domain_error); // x = inf + BOOST_MATH_CHECK_THROW(boost::math::cdf(w, x), std::domain_error); // x = inf + x[1] = -infinite; + BOOST_MATH_CHECK_THROW(boost::math::pdf(w, x), std::domain_error); // x = -inf + BOOST_MATH_CHECK_THROW(boost::math::cdf(w, x), std::domain_error); // x = -inf + } +} // test_spots() + + +BOOST_AUTO_TEST_CASE(test_main) +{ + BOOST_MATH_CONTROL_FP; + test_spots(std::vector(0.0L)); + + test_spots(std::vector(0.0)); + + test_spots(std::vector(0.0F)); + +// #ifndef BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS +// test_spots(); // Test long double. +// #if !BOOST_WORKAROUND(__BORLANDC__, BOOST_TESTED_AT(0x582)) +// test_spots(boost::math::concepts::real_concept(0.)); // Test real concept. +// #endif +} // BOOST_AUTO_TEST_CASE( test_main ) \ No newline at end of file From 6a9b8a5a656e8bbc358100741a8536ac77c83d39 Mon Sep 17 00:00:00 2001 From: Mrityunjay Tripathi Date: Sun, 8 Mar 2020 22:17:21 +0530 Subject: [PATCH 6/7] entropy corrected [CI SKIP] --- .../boost/math/distributions/dirichlet.hpp | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/include/boost/math/distributions/dirichlet.hpp b/include/boost/math/distributions/dirichlet.hpp index d014cfda2a..b9daf21ff5 100644 --- a/include/boost/math/distributions/dirichlet.hpp +++ b/include/boost/math/distributions/dirichlet.hpp @@ -21,6 +21,7 @@ #ifndef BOOST_MATH_DISTRIBUTIONS_DIRICHLET_HPP #define BOOST_MATH_DISTRIBUTIONS_DIRICHLET_HPP +#include #include #include #include @@ -50,10 +51,11 @@ inline bool check_alpha(const char *function, const Policy &pol) { using RealType = typename RandomAccessContainer::value_type; - using std::invalid_argument; if (alpha.size() < 1) { - throw invalid_argument("Size of 'concentration parameters' must be greater than 0."); + *result = policies::raise_domain_error( + function, + "Size of alpha vector is %1%, but must be > 0 !", alpha.size(), pol); return false; } for (decltype(alpha.size()) i = 0; i < alpha.size(); ++i) @@ -76,10 +78,11 @@ inline bool check_x(const char *function, const Policy &pol) { using RealType = typename RandomAccessContainer::value_type; - using std::invalid_argument; if (x.size() < 1) { - throw invalid_argument("Size of 'quantiles' vector must be greater than 0."); + *result = policies::raise_domain_error( + function, + "Size of x is %1%, but must be > 0 !", x.size(), pol); return false; } for (decltype(x.size()) i = 0; i < x.size(); ++i) @@ -92,7 +95,15 @@ inline bool check_x(const char *function, return false; } } - return accumulate(x.begin(), x.end(), 0.0) <= 1.0; + RealType s = accumulate(x.begin(), x.end(), RealType(0)); + if (s > static_cast(1.0)) + { + *result = policies::raise_domain_error( + function, + "Sum of quantiles is %1%, but must be <= 1 !", s, pol); + return false; + } + return true; } // bool check_x @@ -113,10 +124,11 @@ inline bool check_mean(const char *function, const Policy &pol) { using RealType = typename RandomAccessContainer::value_type; - using std::invalid_argument; if (mean.size() < 1) { - throw invalid_argument("Size of 'mean' vector must be greater than 0."); + *result = policies::raise_domain_error( + function, + "Size of mean vector is %1%, but must be > 0 !", mean.size(), pol); return false; } for (decltype(mean.size()) i = 0; i < mean.size(); ++i) @@ -142,7 +154,9 @@ inline bool check_variance(const char *function, using std::invalid_argument; if (variance.size() < 1) { - throw invalid_argument("Size of 'variance' vector must be greater than 0."); + *result = policies::raise_domain_error( + function, + "Size of variance vector is %1%, but must be > 0 !", variance.size(), pol); return false; } for (decltype(variance.size()) i = 0; i < variance.size(); ++i) @@ -314,6 +328,7 @@ inline RandomAccessContainer mode(const dirichlet_distribution inline typename RandomAccessContainer::value_type entropy(const dirichlet_distribution &dist) { @@ -322,7 +337,7 @@ inline typename RandomAccessContainer::value_type entropy(const dirichlet_distri RealType ent = log(dist.normalizing_constant()) + (dist.sum_alpha() - dist.order()) * digamma(dist.sum_alpha()); for (decltype(dist.order()) i = 0; i < dist.order(); ++i) { - ent += (dist.get_alpha()[i] - 1) * digamma(dist.get_alpha()[i]); + ent -= (dist.get_alpha()[i] - 1) * digamma(dist.get_alpha()[i]); } return ent; } From 366a9bedbfb22e485492f5a78c71921a787fb6c9 Mon Sep 17 00:00:00 2001 From: Mrityunjay Tripathi Date: Mon, 9 Mar 2020 23:23:36 +0530 Subject: [PATCH 7/7] adding more test samples [CI SKIP] --- test/test_dirichlet_dist.cpp | 324 +++++++++++++++++++++++------------ 1 file changed, 215 insertions(+), 109 deletions(-) diff --git a/test/test_dirichlet_dist.cpp b/test/test_dirichlet_dist.cpp index fa6234d494..eb9f8369e7 100644 --- a/test/test_dirichlet_dist.cpp +++ b/test/test_dirichlet_dist.cpp @@ -15,10 +15,14 @@ #pragma warning(disable : 4224) // nonstandard extension used : formal parameter 'arg' was previously defined as a type. #endif +#define BOOST_TEST_MAIN +#define BOOST_MATH_CHECK_THROW +#define BOOST_LIB_DIAGNOSTIC +#define BOOST_TEST_MODULE + #include #include -#include -#include // for test_main +#include // for test_main #include #include // for real_concept #include // for dirichlet_distribution @@ -28,24 +32,17 @@ using boost::math::dirichlet_distribution; using boost::math::concepts::real_concept; -using std::numeric_limits; using std::domain_error; - -#define BOOST_TEST_MAIN -#define BOOST_MATH_CHECK_THROW +using std::numeric_limits; template void test_spot( RandomAccessContainer &&alpha, // concentration parameters 'a' RandomAccessContainer &&x, // quantiles 'x' RandomAccessContainer &&mean, // mean - RandomAccessContainer &&mode, // mode RandomAccessContainer &&var, // variance - RandomAccessContainer &&skewness, // skewness - RandomAccessContainer &&kurtosis, // kurtosis typename RandomAccessContainer::value_type entropy, // entropy typename RandomAccessContainer::value_type pdf, // pdf - typename RandomAccessContainer::value_type cdf, // cdf typename RandomAccessContainer::value_type tol) // Test tolerance. { // using RealType = typename RandomAccessContainer::value_type; @@ -54,156 +51,262 @@ void test_spot( V calc_mean = boost::math::mean(diri); V calc_variance = boost::math::variance(diri); - V calc_mode = boost::math::mode(diri); - V calc_kurtosis = boost::math::kurtosis(diri); - V calc_skewness = boost::math::skewness(diri); BOOST_CHECK_CLOSE_FRACTION(boost::math::pdf(diri, x), pdf, tol); - BOOST_CHECK_CLOSE_FRACTION(boost::math::cdf(diri, x), cdf, tol); BOOST_CHECK_CLOSE_FRACTION(boost::math::entropy(diri), entropy, tol); for (decltype(alpha.size()) i = 0; i < alpha.size(); ++i) { BOOST_CHECK_CLOSE_FRACTION(calc_mean[i], mean[i], tol); BOOST_CHECK_CLOSE_FRACTION(calc_variance[i], var[i], tol); - BOOST_CHECK_CLOSE_FRACTION(calc_kurtosis[i], kurtosis[i], tol); - BOOST_CHECK_CLOSE_FRACTION(calc_skewness[i], skewness[i], tol); } } // template void test_spot template -void test_spots(RandomAccessContainer) +void test_spots() { typedef RandomAccessContainer V; using RealType = typename V::value_type; - RealType tolerance = std::max(boost::math::tools::epsilon(), - static_cast(std::numeric_limits::epsilon())); - V alpha(2); - V x(2); + RealType tolerance = 1e-8; // Error checks: // Necessary conditions for instantiation: - // 1. alpha[i] > 0. - alpha[0] = 0.35; - alpha[1] = -1.72; // alpha[1] < 0. + // 1. alpha.size() > 0. + // 2. alpha[i] > 0. + + V alpha; // alpha.size() == 0. + V x; // x.size() == 0. BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); + alpha.resize(2); + alpha[0] = static_cast(0.35); + alpha[1] = static_cast(-1.72); // alpha[1] < 0. BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); // Domain test for pdf. Necessary conditions for pdf: - // 1. alpha[i] > 0 - // 2. 0 <= x[i] <=1 - // 3. sum(x) <= 1. - alpha[0] = -0.2; - alpha[1] = 1.7; // alpha[0] < 0. - x[0] = 0.5; - x[1] = 0.5; + // 1. alpha[i] > 0. + // 2. x.size() > 0. + // 3. 0 <= x[i] <=1. + // 4. sum(x) <= 1. + alpha[0] = static_cast(0.2); + alpha[1] = static_cast(1.7); + BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); + x[0] = static_cast(0.5); + x[1] = static_cast(0.5); BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); - alpha[0] = 1.36; - alpha[1] = 0.0; // alpha[1] = 0. - x[0] = 0.47; - x[1] = 0.53; + alpha[0] = static_cast(1.36); + alpha[1] = static_cast(0.0); // alpha[1] = 0. + x[0] = static_cast(0.47); + x[1] = static_cast(0.53); BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); - alpha[0] = 1.26; - alpha[1] = 2.99; - x[0] = 0.5; - x[1] = 0.75; // sum(x) > 1.0 + alpha[0] = static_cast(1.26); + alpha[1] = static_cast(2.99); + x[0] = static_cast(0.5); + x[1] = static_cast(0.75); // sum(x) > 1.0 BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); - alpha[0] = 1.56; - alpha[1] = 4.00; - x[0] = 0.31; - x[1] = -0.03; // x[1] < 0. + alpha[0] = static_cast(1.56); + alpha[1] = static_cast(4.00); + x[0] = static_cast(0.31); + x[1] = static_cast(-0.03); // x[1] < 0. BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); - alpha[0] = 1.56; - alpha[1] = 4.00; - x[0] = 0.31; - x[1] = 1.06; // x[1] > 1. + alpha[0] = static_cast(1.56); + alpha[1] = static_cast(4.00); + x[0] = static_cast(0.31); + x[1] = static_cast(1.06); // x[1] > 1. BOOST_MATH_CHECK_THROW(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); // Domain test for cdf. Necessary conditions for cdf: // 1. alpha[i] > 0 // 2. 0 <= x[i] <= 1 // 3. sum(x) <= 1. - alpha[0] = 1.56; - alpha[1] = 4.00; - x[0] = 0.31; - x[1] = 1.06; // x[1] > 1. + alpha[0] = static_cast(1.56); + alpha[1] = static_cast(4.00); + x[0] = static_cast(0.31); + x[1] = static_cast(1.06); // x[1] > 1. BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); - alpha[0] = 3.756; - alpha[1] = 4.91; - x[0] = 0.31; - x[1] = -1.06; // x[1] < 0. + alpha[0] = static_cast(3.756); + alpha[1] = static_cast(4.91); + x[0] = static_cast(0.31); + x[1] = static_cast(-1.06); // x[1] < 0. BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); - alpha[0] = 1.56; - alpha[1] = -4.00; // alpha[1] < 0 - x[0] = 0.31; - x[1] = 0.69; + alpha[0] = static_cast(1.56); + alpha[1] = static_cast(-4.00); // alpha[1] < 0 + x[0] = static_cast(0.31); + x[1] = static_cast(0.69); BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); - alpha[0] = 0.0; - alpha[1] = 4.00; // alpha[0] = 0. - x[0] = 0.25; - x[1] = 0.75; + alpha[0] = static_cast(0.0); + alpha[1] = static_cast(4.00); // alpha[0] = 0. + x[0] = static_cast(0.25); + x[1] = static_cast(0.75); BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); - alpha[0] = 1.56; - alpha[1] = 4.00; - x[0] = 0.31; - x[1] = 0.71; // sum(x) > 1. + alpha[0] = static_cast(1.56); + alpha[1] = static_cast(4.00); + x[0] = static_cast(0.31); + x[1] = static_cast(0.71); // sum(x) > 1. BOOST_MATH_CHECK_THROW(boost::math::cdf(dirichlet_distribution(std::move(alpha)), x), std::domain_error); // Domain test for mode. Necessary conditions for mode: // 1. alpha[i] > 1. - alpha[0] = 1.0; - alpha[1] = 1.4; // alpha[0] = 1. + alpha[0] = static_cast(1.0); + alpha[1] = static_cast(1.4); // alpha[0] = 1. BOOST_MATH_CHECK_THROW(boost::math::mode(dirichlet_distribution(std::move(alpha))), std::domain_error); - alpha[0] = 1.56; - alpha[1] = 0.92; // alpha[1] < 1. + alpha[0] = static_cast(1.56); + alpha[1] = static_cast(0.92); // alpha[1] < 1. BOOST_MATH_CHECK_THROW(boost::math::mode(dirichlet_distribution(std::move(alpha))), std::domain_error); // Some exact values of pdf. - alpha[0] = 1.0, alpha[1] = 1.0; - x[0] = 0.5, x[1] = 0.5; + alpha[0] = static_cast(1.0), alpha[1] = static_cast(1.0); + x[0] = static_cast(0.5), x[1] = static_cast(0.5); BOOST_CHECK_EQUAL(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), static_cast(1.0)); - alpha[0] = 2.0, alpha[1] = 2.0; - x[0] = 0.5, x[1] = 0.5; + alpha[0] = static_cast(2.0), alpha[1] = static_cast(2.0); + x[0] = static_cast(0.5), x[1] = static_cast(0.5); BOOST_CHECK_EQUAL(boost::math::pdf(dirichlet_distribution(std::move(alpha)), x), static_cast(1.5)); // Checking precalculated values on scipy. // https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.dirichlet.html - alpha[0] = static_cast(5.778238829L); alpha[1] = static_cast(2.55821892973L); - x[0] = static_cast(0.23667289213L); x[1] = static_cast(1.0L - 0.23667289213L); - V mean = {static_cast(0.69312878L), static_cast(0.30687122L)}; - V mode = {static_cast(0.75408675L), static_cast(0.24591325L)}; - V var = {static_cast(0.0227818L), static_cast(0.0227818L)}; - V skewness = {static_cast(-0.49515513L), static_cast(0.49515513L)}; - V kurtosis = {static_cast(-139231.64518864L), static_cast(-6993.41057616L)}; - RealType entropy = static_cast(17.6747L); - RealType pdf = static_cast(0.05866154L); - RealType cdf = static_cast(0.00071693L); - tolerance *= 1E+07; + alpha[0] = static_cast(5.778238829); + alpha[1] = static_cast(2.55821892973); + x[0] = static_cast(0.23667289213); + x[1] = static_cast(0.76332710787); + V mean = {static_cast(0.693128783978901), static_cast(0.3068712160210989)}; + V var = {static_cast(0.022781795654775592), static_cast(0.022781795654775592)}; + RealType entropy = static_cast(-0.516646371355904); + RealType pdf = static_cast(0.05866153821852176); + + test_spot(std::move(alpha), + std::move(x), + std::move(mean), + std::move(var), + entropy, pdf, tolerance); + + alpha[0] = static_cast(5.310948003052013); + alpha[1] = static_cast(8.003963132298916); + x[0] = static_cast(0.35042614416132284); + x[1] = static_cast(0.64957385583867716); + mean[0] = static_cast(0.398872207937724); + mean[1] = static_cast(0.601127792062276); + var[0] = static_cast(0.016749888798155716); + var[1] = static_cast(0.016749888798155716); + pdf = static_cast(2.870121181949622); + entropy = static_cast(-0.6347509574442718); + + test_spot(std::move(alpha), + std::move(x), + std::move(mean), + std::move(var), + entropy, pdf, tolerance); + + alpha[0] = static_cast(8.764102220201394); + alpha[1] = static_cast(4.348446856921846); + x[0] = static_cast(0.6037585982123262); + x[1] = static_cast(0.39624140178767375); + mean[0] = static_cast(0.6683751701255137); + mean[1] = static_cast(0.33162482987448627); + var[0] = static_cast(0.015705865813037533); + var[1] = static_cast(0.015705865813037533); + pdf = static_cast(2.473329499915834); + entropy = static_cast(-0.6769547381491741); + + test_spot(std::move(alpha), + std::move(x), + std::move(mean), + std::move(var), + entropy, pdf, tolerance); + + alpha.resize(3); + x.resize(3); + mean.resize(3); + var.resize(3); + alpha[0] = static_cast(5.622313698848736); + alpha[1] = static_cast(0.3516907178071482); + alpha[2] = static_cast(9.15629985496498); + x[0] = static_cast(0.6571425803855344); + x[1] = static_cast(0.2972004956337586); + x[2] = static_cast(0.04565692398070697); + mean[0] = static_cast(0.37159290374577736); + mean[1] = static_cast(0.023244127249099442); + mean[2] = static_cast(0.6051629690051231); + var[0] = static_cast(0.014476578600094457); + var[1] = static_cast(0.0014075269390591417); + var[2] = static_cast(0.014813158259538361); + pdf = static_cast(4.97846312846897e-08); + entropy = static_cast(-4.047215462643532); + + test_spot(std::move(alpha), + std::move(x), + std::move(mean), + std::move(var), + entropy, pdf, tolerance); + + alpha.resize(4); + x.resize(4); + mean.resize(4); + var.resize(4); + alpha[0] = static_cast(5.958168192947443); + alpha[1] = static_cast(6.823198187239482); + alpha[2] = static_cast(6.297779996686504); + alpha[3] = static_cast(4.396226676824867); + x[0] = static_cast(0.15589020332495018); + x[1] = static_cast(0.3893497609653562); + x[2] = static_cast(0.060839680922786556); + x[3] = static_cast(0.393920354786907); + mean[0] = static_cast(0.2538050483508204); + mean[1] = static_cast(0.2906534508155371); + mean[2] = static_cast(0.26827177494818794); + mean[3] = static_cast(0.1872697258854546); + var[0] = static_cast(0.007737902313764369); + var[1] = static_cast(0.00842373359916587); + var[2] = static_cast(0.008020389690635378); + var[3] = static_cast(0.006218486448329886); + pdf = static_cast(0.2649374226055107); + entropy = static_cast(-3.4416182654031537); test_spot(std::move(alpha), - std::move(x), - std::move(mean), - std::move(mode), - std::move(var), - std::move(skewness), - std::move(kurtosis), - entropy, pdf, cdf, tolerance); + std::move(x), + std::move(mean), + std::move(var), + entropy, pdf, tolerance); + + alpha[0] = static_cast(3.1779256968768976); + alpha[1] = static_cast(1.355989101047721); + alpha[2] = static_cast(5.594207813755373); + alpha[3] = static_cast(5.9897453525066355); + x[0] = static_cast(0.3388848203529338); + x[1] = static_cast(0.36731530174264704); + x[2] = static_cast(0.11166014002460622); + x[3] = static_cast(0.1821397378798129); + mean[0] = static_cast(0.19716787008915473); + mean[1] = static_cast(0.08412955758545015); + mean[2] = static_cast(0.34708112922785567); + mean[3] = static_cast(0.37162144309753953); + var[0] = static_cast(0.009247220589902615); + var[1] = static_cast(0.004501248361485874); + var[2] = static_cast(0.013238553973887957); + var[3] = static_cast(0.013641824239806118); + pdf = static_cast(0.06803159432725718); + entropy = static_cast(-3.398201562087422); + + test_spot(std::move(alpha), + std::move(x), + std::move(mean), + std::move(var), + entropy, pdf, tolerance); // No longer allow any parameter to be NaN or inf. if (std::numeric_limits::has_quiet_NaN) { RealType not_a_num = std::numeric_limits::quiet_NaN(); - alpha[0] = not_a_num; alpha[1] = 0.37; + alpha[0] = not_a_num; + alpha[1] = static_cast(0.37); #ifndef BOOST_NO_EXCEPTIONS BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); #else @@ -211,26 +314,30 @@ void test_spots(RandomAccessContainer) #endif // Non-finite parameters should throw. - alpha[0] = 1.67; alpha[1] = 3.8; - x[0] = not_a_num; x[1] = 0.5; + alpha[0] = static_cast(1.67); + alpha[1] = static_cast(3.8); + x[0] = not_a_num; + x[1] = static_cast(0.5); dirichlet_distribution w(std::move(alpha)); BOOST_MATH_CHECK_THROW(boost::math::pdf(w, x), std::domain_error); // x = NaN BOOST_MATH_CHECK_THROW(boost::math::cdf(w, x), std::domain_error); // x = NaN - } // has_quiet_NaN + } // has_quiet_NaN if (std::numeric_limits::has_infinity) { // Attempt to construct from non-finite should throw. RealType infinite = std::numeric_limits::infinity(); alpha[0] = infinite; - alpha[1] = 7.2; + alpha[1] = static_cast(7.2); #ifndef BOOST_NO_EXCEPTIONS BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); #else BOOST_MATH_CHECK_THROW(dirichlet_distribution(std::move(alpha)), std::domain_error); #endif - alpha[0] = 1.42; alpha[1] = 7.91; - x[0] = 0.25; x[1] = infinite; + alpha[0] = static_cast(1.42); + alpha[1] = static_cast(7.91); + x[0] = static_cast(0.25); + x[1] = infinite; dirichlet_distribution w(std::move(alpha)); BOOST_MATH_CHECK_THROW(boost::math::pdf(w, x), std::domain_error); // x = inf BOOST_MATH_CHECK_THROW(boost::math::cdf(w, x), std::domain_error); // x = inf @@ -240,19 +347,18 @@ void test_spots(RandomAccessContainer) } } // test_spots() - BOOST_AUTO_TEST_CASE(test_main) { BOOST_MATH_CONTROL_FP; - test_spots(std::vector(0.0L)); + test_spots>(); - test_spots(std::vector(0.0)); + test_spots>(); - test_spots(std::vector(0.0F)); + test_spots>(); -// #ifndef BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS -// test_spots(); // Test long double. -// #if !BOOST_WORKAROUND(__BORLANDC__, BOOST_TESTED_AT(0x582)) -// test_spots(boost::math::concepts::real_concept(0.)); // Test real concept. -// #endif + // #ifndef BOOST_MATH_NO_LONG_DOUBLE_MATH_FUNCTIONS + // test_spots(); // Test long double. + // #if !BOOST_WORKAROUND(__BORLANDC__, BOOST_TESTED_AT(0x582)) + // test_spots(boost::math::concepts::real_concept(0.)); // Test real concept. + // #endif } // BOOST_AUTO_TEST_CASE( test_main ) \ No newline at end of file