From 2d7690dbb7c2468a86df5b63ae585a81c10d71cb Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 23 Jan 2025 20:32:49 -0500 Subject: [PATCH 01/17] update addConditionals to only use factor graph keys and remove an extra loop --- gtsam/hybrid/HybridSmoother.cpp | 33 ++++++++++++++------------------- gtsam/hybrid/HybridSmoother.h | 2 +- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index ca3e272521..3e06d98ae2 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -68,8 +68,7 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph, } // Add the necessary conditionals from the previous timestep(s). - std::tie(graph, hybridBayesNet_) = - addConditionals(graph, hybridBayesNet_, ordering); + std::tie(graph, hybridBayesNet_) = addConditionals(graph, hybridBayesNet_); // Eliminate. HybridBayesNet bayesNetFragment = *graph.eliminateSequential(ordering); @@ -88,10 +87,11 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph, /* ************************************************************************* */ std::pair HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, - const HybridBayesNet &originalHybridBayesNet, - const Ordering &ordering) const { + const HybridBayesNet &hybridBayesNet) const { HybridGaussianFactorGraph graph(originalGraph); - HybridBayesNet hybridBayesNet(originalHybridBayesNet); + HybridBayesNet updatedHybridBayesNet(hybridBayesNet); + + KeySet factorKeys = graph.keys(); // If hybridBayesNet is not empty, // it means we have conditionals to add to the factor graph. @@ -99,10 +99,6 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, // We add all relevant hybrid conditionals on the last continuous variable // in the previous `hybridBayesNet` to the graph - // Conditionals to remove from the bayes net - // since the conditional will be updated. - std::vector conditionals_to_erase; - // New conditionals to add to the graph gtsam::HybridBayesNet newConditionals; @@ -112,25 +108,24 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, auto conditional = hybridBayesNet.at(i); for (auto &key : conditional->frontals()) { - if (std::find(ordering.begin(), ordering.end(), key) != - ordering.end()) { + if (std::find(factorKeys.begin(), factorKeys.end(), key) != + factorKeys.end()) { newConditionals.push_back(conditional); - conditionals_to_erase.push_back(conditional); + + // Remove the conditional from the updated Bayes net + auto it = find(updatedHybridBayesNet.begin(), + updatedHybridBayesNet.end(), conditional); + updatedHybridBayesNet.erase(it); break; } } } - // Remove conditionals at the end so we don't affect the order in the - // original bayes net. - for (auto &&conditional : conditionals_to_erase) { - auto it = find(hybridBayesNet.begin(), hybridBayesNet.end(), conditional); - hybridBayesNet.erase(it); - } graph.push_back(newConditionals); } - return {graph, hybridBayesNet}; + + return {graph, updatedHybridBayesNet}; } /* ************************************************************************* */ diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index 66edf86d6b..4669b1d8fd 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -66,7 +66,7 @@ class GTSAM_EXPORT HybridSmoother { */ std::pair addConditionals( const HybridGaussianFactorGraph& graph, - const HybridBayesNet& hybridBayesNet, const Ordering& ordering) const; + const HybridBayesNet& hybridBayesNet) const; /** * @brief Get the hybrid Gaussian conditional from From b4020ed67b5aa70f189cbc762a05a35b254ce52b Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 23 Jan 2025 21:04:12 -0500 Subject: [PATCH 02/17] fix up getOrdering and update to be more efficient --- gtsam/hybrid/HybridSmoother.cpp | 28 ++++++++++++++++------------ gtsam/hybrid/HybridSmoother.h | 28 ++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 3e06d98ae2..3f429198ef 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -24,17 +24,14 @@ namespace gtsam { /* ************************************************************************* */ -Ordering HybridSmoother::getOrdering( - const HybridGaussianFactorGraph &newFactors) { - HybridGaussianFactorGraph factors(hybridBayesNet()); - factors.push_back(newFactors); - +Ordering HybridSmoother::getOrdering(const HybridGaussianFactorGraph &factors, + const KeySet &newFactorKeys) { // Get all the discrete keys from the factors KeySet allDiscrete = factors.discreteKeySet(); // Create KeyVector with continuous keys followed by discrete keys. KeyVector newKeysDiscreteLast; - const KeySet newFactorKeys = newFactors.keys(); + // Insert continuous keys first. for (auto &k : newFactorKeys) { if (!allDiscrete.exists(k)) { @@ -56,22 +53,29 @@ Ordering HybridSmoother::getOrdering( } /* ************************************************************************* */ -void HybridSmoother::update(HybridGaussianFactorGraph graph, +void HybridSmoother::update(const HybridGaussianFactorGraph &graph, std::optional maxNrLeaves, const std::optional given_ordering) { + HybridGaussianFactorGraph updatedGraph; + // Add the necessary conditionals from the previous timestep(s). + std::tie(updatedGraph, hybridBayesNet_) = + addConditionals(graph, hybridBayesNet_); + Ordering ordering; // If no ordering provided, then we compute one if (!given_ordering.has_value()) { - ordering = this->getOrdering(graph); + // Get the keys from the new factors + const KeySet newFactorKeys = graph.keys(); + + // Since updatedGraph now has all the connected conditionals, + // we can get the correct ordering. + ordering = this->getOrdering(updatedGraph, newFactorKeys); } else { ordering = *given_ordering; } - // Add the necessary conditionals from the previous timestep(s). - std::tie(graph, hybridBayesNet_) = addConditionals(graph, hybridBayesNet_); - // Eliminate. - HybridBayesNet bayesNetFragment = *graph.eliminateSequential(ordering); + HybridBayesNet bayesNetFragment = *updatedGraph.eliminateSequential(ordering); /// Prune if (maxNrLeaves) { diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index 4669b1d8fd..f941385bd5 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -49,11 +49,35 @@ class GTSAM_EXPORT HybridSmoother { * @param given_ordering The (optional) ordering for elimination, only * continuous variables are allowed */ - void update(HybridGaussianFactorGraph graph, + void update(const HybridGaussianFactorGraph& graph, std::optional maxNrLeaves = {}, const std::optional given_ordering = {}); - Ordering getOrdering(const HybridGaussianFactorGraph& newFactors); + /** + * @brief Get an elimination ordering which eliminates continuous and then + * discrete. + * + * Expects `newFactors` to already have the necessary conditionals connected + * to the + * + * @param factors + * @return Ordering + */ + + /** + * @brief Get an elimination ordering which eliminates continuous + * and then discrete. + * + * Expects `factors` to already have the necessary conditionals + * which were connected to the variables in the newly added factors. + * Those variables should be in `newFactorKeys`. + * + * @param factors All the new factors and connected conditionals. + * @param newFactorKeys The keys/variables in the newly added factors. + * @return Ordering + */ + Ordering getOrdering(const HybridGaussianFactorGraph& factors, + const KeySet& newFactorKeys); /** * @brief Add conditionals from previous timestep as part of liquefication. From 8ad3cb6ba167883abbbcb058bc9ddda941357fd7 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 23 Jan 2025 21:13:39 -0500 Subject: [PATCH 03/17] update HybridSmoother tests --- gtsam/hybrid/tests/testHybridSmoother.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridSmoother.cpp b/gtsam/hybrid/tests/testHybridSmoother.cpp index 145f44d1e6..5493ce53e0 100644 --- a/gtsam/hybrid/tests/testHybridSmoother.cpp +++ b/gtsam/hybrid/tests/testHybridSmoother.cpp @@ -95,16 +95,15 @@ TEST(HybridSmoother, IncrementalSmoother) { initial.insert(X(k), switching.linearizationPoint.at(X(k))); HybridGaussianFactorGraph linearized = *graph.linearize(initial); - Ordering ordering = smoother.getOrdering(linearized); - smoother.update(linearized, maxNrLeaves, ordering); + smoother.update(linearized, maxNrLeaves); // Clear all the factors from the graph graph.resize(0); } EXPECT_LONGS_EQUAL(11, - smoother.hybridBayesNet().at(0)->asDiscrete()->nrValues()); + smoother.hybridBayesNet().at(3)->asDiscrete()->nrValues()); // Get the continuous delta update as well as // the optimal discrete assignment. @@ -150,16 +149,15 @@ TEST(HybridSmoother, ValidPruningError) { initial.insert(X(k), switching.linearizationPoint.at(X(k))); HybridGaussianFactorGraph linearized = *graph.linearize(initial); - Ordering ordering = smoother.getOrdering(linearized); - smoother.update(linearized, maxNrLeaves, ordering); + smoother.update(linearized, maxNrLeaves); // Clear all the factors from the graph graph.resize(0); } EXPECT_LONGS_EQUAL(14, - smoother.hybridBayesNet().at(0)->asDiscrete()->nrValues()); + smoother.hybridBayesNet().at(6)->asDiscrete()->nrValues()); // Get the continuous delta update as well as // the optimal discrete assignment. From 22fc8238ce2f7fccbf25aceb8bf8b3a4fa7df9db Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 18:14:38 -0500 Subject: [PATCH 04/17] dead mode removal flag and new constructor --- gtsam/hybrid/HybridSmoother.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index f941385bd5..355dd3b731 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -29,7 +29,23 @@ class GTSAM_EXPORT HybridSmoother { HybridBayesNet hybridBayesNet_; HybridGaussianFactorGraph remainingFactorGraph_; + /// Flag indicating that we should remove dead discrete modes. + bool removeDeadModes_; + /// The threshold above which we make a decision about a mode. + double deadModeThreshold_; + public: + /** + * @brief Constructor + * + * @param removeDeadModes Flag indicating whether to remove dead modes. + * @param deadModeThreshold The threshold above which a mode gets assigned a + * value and is considered "dead". + */ + HybridSmoother(bool removeDeadModes = false, double deadModeThreshold = 0.99) + : removeDeadModes_(removeDeadModes), + deadModeThreshold_(deadModeThreshold) {} + /** * Given new factors, perform an incremental update. * The relevant densities in the `hybridBayesNet` will be added to the input From 3ebeff149f759bfde1d5ba70d519f390902401cc Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 18:16:04 -0500 Subject: [PATCH 05/17] update factorKeys with parents --- gtsam/hybrid/HybridSmoother.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 3f429198ef..b463c044e3 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -116,6 +116,14 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, factorKeys.end()) { newConditionals.push_back(conditional); + // Add the conditional parents to factorKeys + // so we add those conditionals too. + // NOTE: This assumes we have a structure where + // variables depend on those in the future. + for (auto &&parentKey : conditional->parents()) { + factorKeys.insert(parentKey); + } + // Remove the conditional from the updated Bayes net auto it = find(updatedHybridBayesNet.begin(), updatedHybridBayesNet.end(), conditional); From fe38776dc46cb11f6371e4fdb55be924ce0b1a3a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 18:16:31 -0500 Subject: [PATCH 06/17] use flag for dead mode removal --- gtsam/hybrid/HybridSmoother.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index b463c044e3..a50be28baa 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, if (maxNrLeaves) { // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in // all the conditionals with the same keys in bayesNetFragment. - bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves); + bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, removeDeadModes_); } // Add the partial bayes net to the posterior bayes net. From 938ae060317535677ed8d15f295651446eaae438 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 19:56:32 -0500 Subject: [PATCH 07/17] remove extra docstring --- gtsam/hybrid/HybridSmoother.h | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index 355dd3b731..b348019226 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -69,17 +69,6 @@ class GTSAM_EXPORT HybridSmoother { std::optional maxNrLeaves = {}, const std::optional given_ordering = {}); - /** - * @brief Get an elimination ordering which eliminates continuous and then - * discrete. - * - * Expects `newFactors` to already have the necessary conditionals connected - * to the - * - * @param factors - * @return Ordering - */ - /** * @brief Get an elimination ordering which eliminates continuous * and then discrete. From 7ca7e4549e198580f59cc17701b44505fa8e96ee Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 19:57:47 -0500 Subject: [PATCH 08/17] improve dead mode removal by checking for empty discrete joints and adding the marginals for future factors --- gtsam/hybrid/HybridBayesNet.cpp | 34 ++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index b6622980be..d27b1026e0 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -58,15 +58,12 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, joint = joint * (*conditional); } - // Create the result starting with the pruned joint. + // Initialize the resulting HybridBayesNet. HybridBayesNet result; - result.emplace_shared(joint); - // Prune the joint. NOTE: imperative and, again, possibly quite expensive. - result.back()->asDiscrete()->prune(maxNrLeaves); - // Get pruned discrete probabilities so - // we can prune HybridGaussianConditionals. - DiscreteConditional pruned = *result.back()->asDiscrete(); + // Prune the joint. NOTE: imperative and, again, possibly quite expensive. + DiscreteConditional pruned = joint; + joint.prune(maxNrLeaves); DiscreteValues deadModesValues; if (removeDeadModes) { @@ -88,8 +85,26 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, } // Remove the modes (imperative) - result.back()->asDiscrete()->removeDiscreteModes(deadModesValues); - pruned = *result.back()->asDiscrete(); + pruned.removeDiscreteModes(deadModesValues); + + /* + If the pruned discrete conditional has any keys left, + we add it to the HybridBayesNet. + If not, it means it is an orphan so we don't add this pruned joint, + and instead add only the marginals below. + */ + if (pruned.keys().size() > 0) { + result.emplace_shared(pruned); + } + + // Add the marginals for future factors + for (auto &&[key, _] : deadModesValues) { + result.push_back( + std::dynamic_pointer_cast(marginals(key))); + } + + } else { + result.emplace_shared(pruned); } /* To prune, we visitWith every leaf in the HybridGaussianConditional. @@ -186,6 +201,7 @@ DiscreteValues HybridBayesNet::mpe() const { } } } + return discrete_fg.optimize(); } From 8725361fd2a778afb8a2f2cce1f94af225e411bd Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 19:59:54 -0500 Subject: [PATCH 09/17] new test for dead mode removal in smoother --- gtsam/hybrid/tests/testHybridSmoother.cpp | 55 +++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/gtsam/hybrid/tests/testHybridSmoother.cpp b/gtsam/hybrid/tests/testHybridSmoother.cpp index 5493ce53e0..815e325606 100644 --- a/gtsam/hybrid/tests/testHybridSmoother.cpp +++ b/gtsam/hybrid/tests/testHybridSmoother.cpp @@ -167,6 +167,61 @@ TEST(HybridSmoother, ValidPruningError) { EXPECT_DOUBLES_EQUAL(1e-8, errorTree(delta.discrete()), 1e-8); } +/****************************************************************************/ +// Test if dead mode removal works. +TEST(HybridSmoother, DeadModeRemoval) { + using namespace estimation_fixture; + + size_t K = 8; + + // Switching example of robot moving in 1D + // with given measurements and equal mode priors. + HybridNonlinearFactorGraph graph; + Values initial; + Switching switching = InitializeEstimationProblem( + K, 0.1, 0.1, measurements, "1/1 1/1", &graph, &initial); + + // Smoother with dead mode removal enabled. + HybridSmoother smoother(true); + + constexpr size_t maxNrLeaves = 3; + for (size_t k = 1; k < K; k++) { + if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain + graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model + graph.push_back(switching.unaryFactors.at(k)); // Measurement + + initial.insert(X(k), switching.linearizationPoint.at(X(k))); + + HybridGaussianFactorGraph linearized = *graph.linearize(initial); + + // std::cout << "\n\n\nk" << std::endl; + // GTSAM_PRINT(linearized); + smoother.update(linearized, maxNrLeaves); + + // Clear all the factors from the graph + graph.resize(0); + } + + // Get the continuous delta update as well as + // the optimal discrete assignment. + HybridValues delta = smoother.hybridBayesNet().optimize(); + + // Check discrete assignment + DiscreteValues expected_discrete; + for (size_t k = 0; k < K - 1; k++) { + expected_discrete[M(k)] = discrete_seq[k]; + } + EXPECT(assert_equal(expected_discrete, delta.discrete())); + + // Update nonlinear solution and verify + Values result = initial.retract(delta.continuous()); + Values expected_continuous; + for (size_t k = 0; k < K; k++) { + expected_continuous.insert(X(k), measurements[k]); + } + EXPECT(assert_equal(expected_continuous, result)); +} + /* ************************************************************************* */ int main() { TestResult tr; From 1b79e8800ff4a8e06635f6d3a4872e38c23845af Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 20:08:09 -0500 Subject: [PATCH 10/17] add deadModeThreshold argument to HybridBayesNet::prune --- gtsam/hybrid/HybridBayesNet.cpp | 4 ++-- gtsam/hybrid/HybridBayesNet.h | 6 +++++- gtsam/hybrid/HybridSmoother.cpp | 8 +++++++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index d27b1026e0..66661e845d 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { // TODO(Frank): This can be quite expensive *unless* the factors have already // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. -HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, - bool removeDeadModes) const { +HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, + double deadModeThreshold) const { // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 5d3270f4cd..0546a74222 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -219,9 +219,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @param maxNrLeaves Continuous values at which to compute the error. * @param removeDeadModes Flag to enable removal of modes which only have a * single possible assignment. + * @param deadModeThreshold The threshold to check the mode marginals against. + * If greater than this threshold, the mode gets assigned that value and is + * considered "dead" for hybrid elimination. * @return A pruned HybridBayesNet */ - HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false) const; + HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false, + double deadModeThreshold = 0.99) const; /** * @brief Error method using HybridValues which returns specific error for diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index a50be28baa..34f28ff803 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -74,6 +74,11 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, ordering = *given_ordering; } + // graph.print("Original GRAPH"); + // GTSAM_PRINT(updatedGraph); + // GTSAM_PRINT(hybridBayesNet_); + // GTSAM_PRINT(ordering); + // Eliminate. HybridBayesNet bayesNetFragment = *updatedGraph.eliminateSequential(ordering); @@ -81,7 +86,8 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, if (maxNrLeaves) { // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in // all the conditionals with the same keys in bayesNetFragment. - bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, removeDeadModes_); + bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, removeDeadModes_, + deadModeThreshold_); } // Add the partial bayes net to the posterior bayes net. From ddb430cdebf1d9cb7d3068e07c5c4d04588ea8e6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 20:17:36 -0500 Subject: [PATCH 11/17] use deadModeThreshold --- gtsam/hybrid/HybridBayesNet.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 66661e845d..d21387312f 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -72,7 +72,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, Vector probabilities = marginals.marginalProbabilities(dkey); int index = -1; - auto threshold = (probabilities.array() > 0.99); + auto threshold = (probabilities.array() > deadModeThreshold); // If atleast 1 value is non-zero, then we can find the index // Else if all are zero, index would be set to 0 which is incorrect if (!threshold.isZero()) { From 1d807db0a42c08ad7fe4be01f1f51ebbbb1dffcc Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 20:17:58 -0500 Subject: [PATCH 12/17] remove surplus prints --- gtsam/hybrid/HybridSmoother.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 34f28ff803..918b20341c 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -74,11 +74,6 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, ordering = *given_ordering; } - // graph.print("Original GRAPH"); - // GTSAM_PRINT(updatedGraph); - // GTSAM_PRINT(hybridBayesNet_); - // GTSAM_PRINT(ordering); - // Eliminate. HybridBayesNet bayesNetFragment = *updatedGraph.eliminateSequential(ordering); From d2b9eb5df6c111ba4dd035254cba18a76094bbe0 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 20:19:59 -0500 Subject: [PATCH 13/17] fix which conditional is getting pruned --- gtsam/hybrid/HybridBayesNet.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index d21387312f..e9583b8458 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -63,7 +63,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, // Prune the joint. NOTE: imperative and, again, possibly quite expensive. DiscreteConditional pruned = joint; - joint.prune(maxNrLeaves); + pruned.prune(maxNrLeaves); DiscreteValues deadModesValues; if (removeDeadModes) { @@ -115,7 +115,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, */ // Go through all the Gaussian conditionals in the Bayes Net and prune them as - // per pruned Discrete joint. + // per pruned discrete joint. for (auto &&conditional : *this) { if (auto hgc = conditional->asHybrid()) { // Prune the hybrid Gaussian conditional! From 1764b58e8cf47ac0f5c3b8acc0c0da1b1e13f957 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 23:11:42 -0500 Subject: [PATCH 14/17] use std::optional for specifying dead mode threshold --- gtsam/hybrid/HybridBayesNet.cpp | 10 +++++----- gtsam/hybrid/HybridBayesNet.h | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index e9583b8458..68c248119a 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -47,8 +47,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { // TODO(Frank): This can be quite expensive *unless* the factors have already // been pruned before. Another, possibly faster approach is branch and bound // search to find the K-best leaves and then create a single pruned conditional. -HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, - double deadModeThreshold) const { +HybridBayesNet HybridBayesNet::prune( + size_t maxNrLeaves, const std::optional &deadModeThreshold) const { // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); @@ -66,13 +66,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, pruned.prune(maxNrLeaves); DiscreteValues deadModesValues; - if (removeDeadModes) { + if (deadModeThreshold.has_value()) { DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); for (auto dkey : pruned.discreteKeys()) { Vector probabilities = marginals.marginalProbabilities(dkey); int index = -1; - auto threshold = (probabilities.array() > deadModeThreshold); + auto threshold = (probabilities.array() > *deadModeThreshold); // If atleast 1 value is non-zero, then we can find the index // Else if all are zero, index would be set to 0 which is incorrect if (!threshold.isZero()) { @@ -121,7 +121,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves, bool removeDeadModes, // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); - if (removeDeadModes) { + if (deadModeThreshold.has_value()) { KeyVector deadKeys, conditionalDiscreteKeys; for (const auto &kv : deadModesValues) { deadKeys.push_back(kv.first); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 0546a74222..86fc9527a3 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -217,15 +217,15 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves. * * @param maxNrLeaves Continuous values at which to compute the error. - * @param removeDeadModes Flag to enable removal of modes which only have a - * single possible assignment. * @param deadModeThreshold The threshold to check the mode marginals against. * If greater than this threshold, the mode gets assigned that value and is * considered "dead" for hybrid elimination. + * The mode can then be removed since it only has a single possible + * assignment. * @return A pruned HybridBayesNet */ - HybridBayesNet prune(size_t maxNrLeaves, bool removeDeadModes = false, - double deadModeThreshold = 0.99) const; + HybridBayesNet prune(size_t maxNrLeaves, + const std::optional &deadModeThreshold) const; /** * @brief Error method using HybridValues which returns specific error for From 4fcfe6493f46b854f3d75a8b3a12fcc32a17aab8 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 23:15:40 -0500 Subject: [PATCH 15/17] update HybridSmoother to use std::optional deadModeThreshold --- gtsam/hybrid/HybridSmoother.cpp | 3 +-- gtsam/hybrid/HybridSmoother.h | 11 ++++------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 918b20341c..45320896a7 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -81,8 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph, if (maxNrLeaves) { // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in // all the conditionals with the same keys in bayesNetFragment. - bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, removeDeadModes_, - deadModeThreshold_); + bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_); } // Add the partial bayes net to the posterior bayes net. diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index b348019226..96c8391b4e 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -29,10 +29,8 @@ class GTSAM_EXPORT HybridSmoother { HybridBayesNet hybridBayesNet_; HybridGaussianFactorGraph remainingFactorGraph_; - /// Flag indicating that we should remove dead discrete modes. - bool removeDeadModes_; /// The threshold above which we make a decision about a mode. - double deadModeThreshold_; + std::optional deadModeThreshold_; public: /** @@ -40,11 +38,10 @@ class GTSAM_EXPORT HybridSmoother { * * @param removeDeadModes Flag indicating whether to remove dead modes. * @param deadModeThreshold The threshold above which a mode gets assigned a - * value and is considered "dead". + * value and is considered "dead". 0.99 is a good starting value. */ - HybridSmoother(bool removeDeadModes = false, double deadModeThreshold = 0.99) - : removeDeadModes_(removeDeadModes), - deadModeThreshold_(deadModeThreshold) {} + HybridSmoother(const std::optional deadModeThreshold) + : deadModeThreshold_(deadModeThreshold) {} /** * Given new factors, perform an incremental update. From c91abb26441e0ed65cf3137a32c58c773f640234 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 23:21:32 -0500 Subject: [PATCH 16/17] default arg for HybridSmoother constructor and clean up tests --- gtsam/hybrid/HybridSmoother.h | 2 +- gtsam/hybrid/tests/testHybridSmoother.cpp | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/gtsam/hybrid/HybridSmoother.h b/gtsam/hybrid/HybridSmoother.h index 96c8391b4e..c3f022c62b 100644 --- a/gtsam/hybrid/HybridSmoother.h +++ b/gtsam/hybrid/HybridSmoother.h @@ -40,7 +40,7 @@ class GTSAM_EXPORT HybridSmoother { * @param deadModeThreshold The threshold above which a mode gets assigned a * value and is considered "dead". 0.99 is a good starting value. */ - HybridSmoother(const std::optional deadModeThreshold) + HybridSmoother(const std::optional deadModeThreshold = {}) : deadModeThreshold_(deadModeThreshold) {} /** diff --git a/gtsam/hybrid/tests/testHybridSmoother.cpp b/gtsam/hybrid/tests/testHybridSmoother.cpp index 815e325606..3a0f376cc2 100644 --- a/gtsam/hybrid/tests/testHybridSmoother.cpp +++ b/gtsam/hybrid/tests/testHybridSmoother.cpp @@ -194,8 +194,6 @@ TEST(HybridSmoother, DeadModeRemoval) { HybridGaussianFactorGraph linearized = *graph.linearize(initial); - // std::cout << "\n\n\nk" << std::endl; - // GTSAM_PRINT(linearized); smoother.update(linearized, maxNrLeaves); // Clear all the factors from the graph From 2caf1173a4cec860c140335b640ea481b7485d54 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jan 2025 23:46:09 -0500 Subject: [PATCH 17/17] default argument and update tests --- gtsam/hybrid/HybridBayesNet.h | 5 +++-- gtsam/hybrid/tests/testHybridBayesNet.cpp | 9 +++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 86fc9527a3..fb05e24076 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -224,8 +224,9 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * assignment. * @return A pruned HybridBayesNet */ - HybridBayesNet prune(size_t maxNrLeaves, - const std::optional &deadModeThreshold) const; + HybridBayesNet prune( + size_t maxNrLeaves, + const std::optional &deadModeThreshold = {}) const; /** * @brief Error method using HybridValues which returns specific error for diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 56e93499b6..86dcd48e44 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -434,7 +434,7 @@ TEST(HybridBayesNet, RemoveDeadNodes) { HybridValues delta = posterior->optimize(); // Prune the Bayes net - const bool pruneDeadVariables = true; + const double pruneDeadVariables = 0.99; auto prunedBayesNet = posterior->prune(2, pruneDeadVariables); // Check that discrete joint only has M0 and not (M0, M1) @@ -445,11 +445,12 @@ TEST(HybridBayesNet, RemoveDeadNodes) { // Check that hybrid conditionals that only depend on M1 // are now Gaussian and not Hybrid EXPECT(prunedBayesNet.at(0)->isDiscrete()); - EXPECT(prunedBayesNet.at(1)->isHybrid()); + EXPECT(prunedBayesNet.at(1)->isDiscrete()); + EXPECT(prunedBayesNet.at(2)->isHybrid()); // Only P(X2 | X1, M1) depends on M1, // so it gets convert to a Gaussian P(X2 | X1) - EXPECT(prunedBayesNet.at(2)->isContinuous()); - EXPECT(prunedBayesNet.at(3)->isHybrid()); + EXPECT(prunedBayesNet.at(3)->isContinuous()); + EXPECT(prunedBayesNet.at(4)->isHybrid()); } /* ****************************************************************************/