From 9be3f41ca2164d77fe5d8b9cb46a580569cc2058 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 1 Nov 2024 19:58:23 -0400 Subject: [PATCH 1/5] Correct the second term in the pruner value so that the minNegLogConstant term is set correctly --- gtsam/hybrid/HybridGaussianConditional.cpp | 7 +++++-- gtsam/hybrid/tests/testHybridGaussianConditional.cpp | 8 ++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index ac03bd3a3e..1bec428107 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -322,8 +322,11 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( const GaussianFactorValuePair &pair) -> GaussianFactorValuePair { if (max->evaluate(choices) == 0.0) return {nullptr, std::numeric_limits::infinity()}; - else - return pair; + else { + // Add negLogConstant_ back so that the minimum negLogConstant in the + // HybridGaussianConditional is set correctly. + return {pair.first, pair.second + negLogConstant_}; + } }; FactorValuePairs prunedConditionals = factors().apply(pruner); diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index e29c485afd..350bc91848 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -275,6 +275,11 @@ TEST(HybridGaussianConditional, Prune) { // Check that the pruned HybridGaussianConditional has 2 conditionals EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); + + // Check that the minimum negLogConstant is set correctly + EXPECT_DOUBLES_EQUAL( + hgc.conditionals()({{M(1), 0}, {M(2), 1}})->negLogConstant(), + pruned->negLogConstant(), 1e-9); } { const std::vector potentials{0.2, 0, 0.3, 0, // @@ -285,6 +290,9 @@ TEST(HybridGaussianConditional, Prune) { // Check that the pruned HybridGaussianConditional has 3 conditionals EXPECT_LONGS_EQUAL(3, pruned->nrComponents()); + + // Check that the minimum negLogConstant is correct + EXPECT_DOUBLES_EQUAL(hgc.negLogConstant(), pruned->negLogConstant(), 1e-9); } } From e52970aa9269c79e1b508cf6cb64b31fe50b13e0 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 1 Nov 2024 20:23:04 -0400 Subject: [PATCH 2/5] negLogConstant methods for HybridBayesNet --- gtsam/hybrid/HybridBayesNet.cpp | 29 +++++++++++++++++++++++++++++ gtsam/hybrid/HybridBayesNet.h | 17 +++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index f57cda28d9..e5748366c3 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -197,6 +197,35 @@ AlgebraicDecisionTree HybridBayesNet::errorTree( return result; } +/* ************************************************************************* */ +double HybridBayesNet::negLogConstant() const { + double negLogNormConst = 0.0; + // Iterate over each conditional. + for (auto &&conditional : *this) { + negLogNormConst += conditional->negLogConstant(); + } + return negLogNormConst; +} + +/* ************************************************************************* */ +double HybridBayesNet::negLogConstant(const DiscreteValues &discrete) const { + double negLogNormConst = 0.0; + // Iterate over each conditional. + for (auto &&conditional : *this) { + if (auto gm = conditional->asHybrid()) { + negLogNormConst += gm->choose(discrete)->negLogConstant(); + } else if (auto gc = conditional->asGaussian()) { + negLogNormConst += gc->negLogConstant(); + } else if (auto dc = conditional->asDiscrete()) { + negLogNormConst += dc->choose(discrete)->negLogConstant(); + } else { + throw std::runtime_error( + "Unknown conditional type when computing negLogConstant"); + } + } + return negLogNormConst; +} + /* ************************************************************************* */ AlgebraicDecisionTree HybridBayesNet::discretePosterior( const VectorValues &continuousValues) const { diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index bba301be2f..451f7f6757 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -237,6 +237,23 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { using BayesNet::logProbability; // expose HybridValues version + /** + * @brief Get the negative log of the normalization constant corresponding + * to the joint density represented by this Bayes net. + * + * @return double + */ + double negLogConstant() const; + + /** + * @brief Get the negative log of the normalization constant + * corresponding to the joint Gaussian density represented by + * this Bayes net indexed by `discrete`. + * + * @return double + */ + double negLogConstant(const DiscreteValues &discrete) const; + /** * @brief Compute normalized posterior P(M|X=x) and return as a tree. * From 44e848536033c951a809a9bb4cabc3bbf7f17fa5 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 1 Nov 2024 20:23:32 -0400 Subject: [PATCH 3/5] get failing tests in testHybridBayesNet to pass --- gtsam/hybrid/tests/testHybridBayesNet.cpp | 42 ++++++++++++++++++----- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 16d0ae1a12..135da5dc73 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -363,10 +363,6 @@ TEST(HybridBayesNet, Pruning) { AlgebraicDecisionTree expected(s.modes, leaves); EXPECT(assert_equal(expected, discretePosterior, 1e-6)); - // Prune and get probabilities - auto prunedBayesNet = posterior->prune(2); - auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous()); - // Verify logProbability computation and check specific logProbability value const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; const HybridValues hybridValues{delta.continuous(), discrete_values}; @@ -381,10 +377,21 @@ TEST(HybridBayesNet, Pruning) { EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues), 1e-9); + double negLogConstant = posterior->negLogConstant(discrete_values); + + // The sum of all the mode densities + double normalizer = + AlgebraicDecisionTree(posterior->errorTree(delta.continuous()), + [](double error) { return exp(-error); }) + .sum(); + // Check agreement with discrete posterior - // double density = exp(logProbability); - // FAILS: EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), - // 1e-6); + double density = exp(logProbability + negLogConstant) / normalizer; + EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values), 1e-6); + + // Prune and get probabilities + auto prunedBayesNet = posterior->prune(2); + auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous()); // Regression test on pruned logProbability tree std::vector pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578}; @@ -392,7 +399,26 @@ TEST(HybridBayesNet, Pruning) { EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6)); // Regression - // FAILS: EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9); + double pruned_logProbability = 0; + pruned_logProbability += + prunedBayesNet.at(0)->asDiscrete()->logProbability(hybridValues); + pruned_logProbability += + prunedBayesNet.at(1)->asHybrid()->logProbability(hybridValues); + pruned_logProbability += + prunedBayesNet.at(2)->asHybrid()->logProbability(hybridValues); + pruned_logProbability += + prunedBayesNet.at(3)->asHybrid()->logProbability(hybridValues); + + double pruned_negLogConstant = prunedBayesNet.negLogConstant(discrete_values); + + // The sum of all the mode densities + double pruned_normalizer = + AlgebraicDecisionTree(prunedBayesNet.errorTree(delta.continuous()), + [](double error) { return exp(-error); }) + .sum(); + double pruned_density = + exp(pruned_logProbability + pruned_negLogConstant) / pruned_normalizer; + EXPECT_DOUBLES_EQUAL(pruned_density, prunedTree(discrete_values), 1e-9); } /* ****************************************************************************/ From 8aacfa95f32ab6459d7888146b406820a25bade0 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 1 Nov 2024 20:24:36 -0400 Subject: [PATCH 4/5] add docstrings for elimination results --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index ceabe0871a..9ca7a3938e 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -59,10 +59,11 @@ using OrphanWrapper = BayesTreeOrphanWrapper; /// Result from elimination. struct Result { + // Gaussian conditional resulting from elimination. GaussianConditional::shared_ptr conditional; - double negLogK; - GaussianFactor::shared_ptr factor; - double scalar; + double negLogK; // Negative log of the normalization constant K. + GaussianFactor::shared_ptr factor; // Leftover factor 𝜏. + double scalar; // Scalar value associated with factor 𝜏. bool operator==(const Result &other) const { return conditional == other.conditional && negLogK == other.negLogK && From 5c63ac833c56d52b1558d0be876065939394d2c6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 3 Nov 2024 15:32:21 -0500 Subject: [PATCH 5/5] use optional DiscreteValues --- gtsam/hybrid/HybridBayesNet.cpp | 33 ++++++++++++++------------------- gtsam/hybrid/HybridBayesNet.h | 15 ++++----------- 2 files changed, 18 insertions(+), 30 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index e5748366c3..623b82eea7 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -198,29 +198,24 @@ AlgebraicDecisionTree HybridBayesNet::errorTree( } /* ************************************************************************* */ -double HybridBayesNet::negLogConstant() const { +double HybridBayesNet::negLogConstant( + const std::optional &discrete) const { double negLogNormConst = 0.0; // Iterate over each conditional. for (auto &&conditional : *this) { - negLogNormConst += conditional->negLogConstant(); - } - return negLogNormConst; -} - -/* ************************************************************************* */ -double HybridBayesNet::negLogConstant(const DiscreteValues &discrete) const { - double negLogNormConst = 0.0; - // Iterate over each conditional. - for (auto &&conditional : *this) { - if (auto gm = conditional->asHybrid()) { - negLogNormConst += gm->choose(discrete)->negLogConstant(); - } else if (auto gc = conditional->asGaussian()) { - negLogNormConst += gc->negLogConstant(); - } else if (auto dc = conditional->asDiscrete()) { - negLogNormConst += dc->choose(discrete)->negLogConstant(); + if (discrete.has_value()) { + if (auto gm = conditional->asHybrid()) { + negLogNormConst += gm->choose(*discrete)->negLogConstant(); + } else if (auto gc = conditional->asGaussian()) { + negLogNormConst += gc->negLogConstant(); + } else if (auto dc = conditional->asDiscrete()) { + negLogNormConst += dc->choose(*discrete)->negLogConstant(); + } else { + throw std::runtime_error( + "Unknown conditional type when computing negLogConstant"); + } } else { - throw std::runtime_error( - "Unknown conditional type when computing negLogConstant"); + negLogNormConst += conditional->negLogConstant(); } } return negLogNormConst; diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 451f7f6757..96afb87d6d 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -237,22 +237,15 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { using BayesNet::logProbability; // expose HybridValues version - /** - * @brief Get the negative log of the normalization constant corresponding - * to the joint density represented by this Bayes net. - * - * @return double - */ - double negLogConstant() const; - /** * @brief Get the negative log of the normalization constant - * corresponding to the joint Gaussian density represented by - * this Bayes net indexed by `discrete`. + * corresponding to the joint density represented by this Bayes net. + * Optionally index by `discrete`. * + * @param discrete Optional DiscreteValues * @return double */ - double negLogConstant(const DiscreteValues &discrete) const; + double negLogConstant(const std::optional &discrete) const; /** * @brief Compute normalized posterior P(M|X=x) and return as a tree.