From b2fbbbb26fc994b04b8cc5bcf692cdfe2dedee81 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 22 Jan 2025 12:05:27 -0500 Subject: [PATCH 1/8] New DiscreteBayesTree tests --- .../discrete/tests/testDiscreteBayesTree.cpp | 35 +++++++++++++++++++ python/gtsam/tests/test_DiscreteBayesTree.py | 31 ++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 617eb7c9d5..95d6f03701 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -369,6 +369,41 @@ TEST(DiscreteBayesTree, Lookup) { EXPECT_DOUBLES_EQUAL(1.0, (*lookup_a2_x3)({{X(2),2},{A(2),1},{X(3),2}}), 1e-9); } +/* ************************************************************************* */ +// Test creating a Bayes tree directly from cliques +TEST(DiscreteBayesTree, DirectFromCliques) { + // Create a BayesNet + DiscreteBayesNet bayesNet; + DiscreteKey key0(0, 2), key1(1, 2), key2(2, 2); + bayesNet.add(key0 % "1/3"); + bayesNet.add(key1 | key0 = "1/3 3/1"); + bayesNet.add(key2 | key1 = "3/1 3/1"); + + // Create cliques directly + auto clique2 = std::make_shared( + std::make_shared(key2 | key1 = "3/1 3/1")); + auto clique1 = std::make_shared( + std::make_shared(key1 | key0 = "1/3 3/1")); + auto clique0 = std::make_shared( + std::make_shared(key0 % "1/3")); + + // Create a BayesTree + DiscreteBayesTree bayesTree; + bayesTree.insertRoot(clique2); + bayesTree.addClique(clique1, clique2); + bayesTree.addClique(clique0, clique1); + + // Check that the BayesTree is correct + DiscreteValues values; + values[0] = 1; + values[1] = 1; + values[2] = 1; + + double expected = bayesNet.evaluate(values); + double actual = bayesTree.evaluate(values); + DOUBLES_EQUAL(expected, actual, 1e-9); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py index e08491faba..d327b5a1b2 100644 --- a/python/gtsam/tests/test_DiscreteBayesTree.py +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -156,5 +156,36 @@ def test_discrete_bayes_tree_lookup(self): values[X(3)] = 2 self.assertAlmostEqual(lookup_a2_x3(values), 1.0) # not 10... + def test_direct_from_cliques(self): + """Test creating a Bayes tree directly from cliques.""" + # Create a BayesNet + bayesNet = DiscreteBayesNet() + key0, key1, key2 = (0, 2), (1, 2), (2, 2) + bayesNet.add(key0, "1/3") + bayesNet.add(key1, [key0], "1/3 3/1") + bayesNet.add(key2, [key1], "3/1 3/1") + + # Create cliques directly + clique2 = DiscreteBayesTreeClique(DiscreteConditional(key2, [key1], "3/1 3/1")) + clique1 = DiscreteBayesTreeClique(DiscreteConditional(key1, [key0], "1/3 3/1")) + clique0 = DiscreteBayesTreeClique(DiscreteConditional(key0, "1/3")) + + # Create a BayesTree + bayesTree = gtsam.DiscreteBayesTree() + bayesTree.insertRoot(clique2) + bayesTree.addClique(clique1, clique2) + bayesTree.addClique(clique0, clique1) + + # Check that the BayesTree is correct + values = DiscreteValues() + values[0] = 1 + values[1] = 1 + values[2] = 1 + + expected = bayesNet.evaluate(values) + actual = bayesTree.evaluate(values) + self.assertAlmostEqual(expected, actual, places=9) + + if __name__ == "__main__": unittest.main() From abac726c3560e42a7b4447335614d03ea3e8bb19 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 22 Jan 2025 23:11:16 -0500 Subject: [PATCH 2/8] Fix some docs and visiblity --- gtsam/inference/BayesTree.h | 22 ++++++++++++---------- gtsam/inference/BayesTreeCliqueBase-inst.h | 16 +++++++++------- gtsam/inference/BayesTreeCliqueBase.h | 4 ++-- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 03f79c8cf1..4a2ae7560f 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -119,13 +119,14 @@ namespace gtsam { /** Assignment operator */ This& operator=(const This& other); + public: + /// @name Testable /// @{ /** check equality */ bool equals(const This& other, double tol = 1e-9) const; - public: /** print */ void print(const std::string& s = "", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; @@ -185,18 +186,19 @@ namespace gtsam { */ sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const; - /// @name Graph Display - /// @{ + /// @} + /// @name Graph Display + /// @{ - /// Output to graphviz format, stream version. - void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// Output to graphviz format, stream version. + void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - /// Output to graphviz format string. - std::string dot( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /// Output to graphviz format string. + std::string dot( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; - /// output to file with graphviz format. - void saveGraph(const std::string& filename, + /// output to file with graphviz format. + void saveGraph(const std::string& filename, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; /// @} diff --git a/gtsam/inference/BayesTreeCliqueBase-inst.h b/gtsam/inference/BayesTreeCliqueBase-inst.h index a91fa4f78b..d335e4b5e0 100644 --- a/gtsam/inference/BayesTreeCliqueBase-inst.h +++ b/gtsam/inference/BayesTreeCliqueBase-inst.h @@ -104,14 +104,16 @@ namespace gtsam { } /* ************************************************************************* */ - // The shortcut density is a conditional P(S|R) of the separator of this - // clique on the root. We can compute it recursively from the parent shortcut - // P(Sp|R) as \int P(Fp|Sp) P(Sp|R), where Fp are the frontal nodes in p - /* ************************************************************************* */ - template + // The shortcut density is a conditional P(S|B) of the separator of this + // clique on the root or common ancestor B. We can compute it recursively from + // the parent shortcut P(Sp|B) as \int P(Fp|Sp) P(Sp|B), where Fp are the + // frontal nodes in p + /* ************************************************************************* + */ + template typename BayesTreeCliqueBase::BayesNetType - BayesTreeCliqueBase::shortcut(const derived_ptr& B, Eliminate function) const - { + BayesTreeCliqueBase::shortcut( + const derived_ptr& B, Eliminate function) const { gttic(BayesTreeCliqueBase_shortcut); // We only calculate the shortcut when this clique is not B // and when the S\B is not empty diff --git a/gtsam/inference/BayesTreeCliqueBase.h b/gtsam/inference/BayesTreeCliqueBase.h index 0ccb04e908..c674fb13a5 100644 --- a/gtsam/inference/BayesTreeCliqueBase.h +++ b/gtsam/inference/BayesTreeCliqueBase.h @@ -190,11 +190,11 @@ namespace gtsam { friend class BayesTree; - protected: - /// Calculate set \f$ S \setminus B \f$ for shortcut calculations KeyVector separator_setminus_B(const derived_ptr& B) const; + protected: + /** Determine variable indices to keep in recursive separator shortcut calculation The factor * graph p_Cp_B has keys from the parent clique Cp and from B. But we only keep the variables * not in S union B. */ From 52e3faa250ed161fdc7f46e19a50c65543985d16 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 22 Jan 2025 23:22:17 -0500 Subject: [PATCH 3/8] Refactor joint marginal --- gtsam/inference/BayesTree-inst.h | 155 +++++++++------------ gtsam/inference/BayesTreeCliqueBase-inst.h | 23 +-- 2 files changed, 72 insertions(+), 106 deletions(-) diff --git a/gtsam/inference/BayesTree-inst.h b/gtsam/inference/BayesTree-inst.h index 0648a90f64..8058127274 100644 --- a/gtsam/inference/BayesTree-inst.h +++ b/gtsam/inference/BayesTree-inst.h @@ -28,6 +28,8 @@ #include #include #include +#include + namespace gtsam { /* ************************************************************************* */ @@ -335,112 +337,85 @@ namespace gtsam { } /* ************************************************************************* */ - template - typename BayesTree::sharedBayesNet - BayesTree::jointBayesNet(Key j1, Key j2, const Eliminate& function) const - { + // Find the lowest common ancestor of two cliques + template + static std::shared_ptr findLowestCommonAncestor( + const std::shared_ptr& C1, const std::shared_ptr& C2) { + // Collect all ancestors of C1 + std::unordered_set> ancestors; + for (auto p = C1; p; p = p->parent()) { + ancestors.insert(p); + } + + // Find the first common ancestor in C2's lineage + std::shared_ptr B; + for (auto p = C2; p; p = p->parent()) { + if (ancestors.count(p)) { + return p; // Return the common ancestor when found + } + } + + return nullptr; // Return nullptr if no common ancestor is found + } + + /* ************************************************************************* */ + // Given the clique P(F:S) and the ancestor clique B + // Return the Bayes tree P(S\B | S \cap B) + template + static auto factorInto( + const std::shared_ptr& p_F_S, const std::shared_ptr& B, + const typename CLIQUE::FactorGraphType::Eliminate& eliminate) { + gttic(Full_root_factoring); + + // Get the shortcut P(S|B) + auto p_S_B = p_F_S->shortcut(B, eliminate); + + // Compute S\B + KeyVector S_setminus_B = p_F_S->separator_setminus_B(B); + + // Factor P(S|B) into P(S\B|S \cap B) and P(S \cap B) + auto [bayesTree, fg] = + typename CLIQUE::FactorGraphType(p_S_B).eliminatePartialMultifrontal( + Ordering(S_setminus_B), eliminate); + return bayesTree; + }; + + /* ************************************************************************* */ + template + typename BayesTree::sharedBayesNet BayesTree::jointBayesNet( + Key j1, Key j2, const Eliminate& eliminate) const { gttic(BayesTree_jointBayesNet); // get clique C1 and C2 sharedClique C1 = (*this)[j1], C2 = (*this)[j2]; - gttic(Lowest_common_ancestor); - // Find lowest common ancestor clique - sharedClique B; { - // Build two paths to the root - FastList path1, path2; { - sharedClique p = C1; - while(p) { - path1.push_front(p); - p = p->parent(); - } - } { - sharedClique p = C2; - while(p) { - path2.push_front(p); - p = p->parent(); - } - } - // Find the path intersection - typename FastList::const_iterator p1 = path1.begin(), p2 = path2.begin(); - if(*p1 == *p2) - B = *p1; - while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) { - B = *p1; - ++p1; - ++p2; - } - } - gttoc(Lowest_common_ancestor); + // Find the lowest common ancestor clique + auto B = findLowestCommonAncestor(C1, C2); // Build joint on all involved variables FactorGraphType p_BC1C2; - if(B) - { + if (B) { // Compute marginal on lowest common ancestor clique - gttic(LCA_marginal); - FactorGraphType p_B = B->marginal2(function); - gttoc(LCA_marginal); - - // Compute shortcuts of the requested cliques given the lowest common ancestor - gttic(Clique_shortcuts); - BayesNetType p_C1_Bred = C1->shortcut(B, function); - BayesNetType p_C2_Bred = C2->shortcut(B, function); - gttoc(Clique_shortcuts); - - // Factor the shortcuts to be conditioned on the full root - // Get the set of variables to eliminate, which is C1\B. - gttic(Full_root_factoring); - std::shared_ptr p_C1_B; { - KeyVector C1_minus_B; { - KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents()); - for(const Key j: *B->conditional()) { - C1_minus_B_set.erase(j); } - C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end()); - } - // Factor into C1\B | B. - p_C1_B = - FactorGraphType(p_C1_Bred) - .eliminatePartialMultifrontal(Ordering(C1_minus_B), function) - .first; - } - std::shared_ptr p_C2_B; { - KeyVector C2_minus_B; { - KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents()); - for(const Key j: *B->conditional()) { - C2_minus_B_set.erase(j); } - C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end()); - } - // Factor into C2\B | B. - p_C2_B = - FactorGraphType(p_C2_Bred) - .eliminatePartialMultifrontal(Ordering(C2_minus_B), function) - .first; - } - gttoc(Full_root_factoring); + FactorGraphType p_B = B->marginal2(eliminate); + + // Factor the shortcuts to be conditioned on lowest common ancestor + auto p_C1_B = factorInto(C1, B, eliminate); + auto p_C2_B = factorInto(C2, B, eliminate); - gttic(Variable_joint); p_BC1C2.push_back(p_B); p_BC1C2.push_back(*p_C1_B); p_BC1C2.push_back(*p_C2_B); - if(C1 != B) - p_BC1C2.push_back(C1->conditional()); - if(C2 != B) - p_BC1C2.push_back(C2->conditional()); - gttoc(Variable_joint); - } - else - { - // The nodes have no common ancestor, they're in different trees, so they're joint is just the - // product of their marginals. - gttic(Disjoint_marginals); - p_BC1C2.push_back(C1->marginal2(function)); - p_BC1C2.push_back(C2->marginal2(function)); - gttoc(Disjoint_marginals); + if (C1 != B) p_BC1C2.push_back(C1->conditional()); + if (C2 != B) p_BC1C2.push_back(C2->conditional()); + } else { + // The nodes have no common ancestor, they're in different trees, so + // they're joint is just the product of their marginals. + p_BC1C2.push_back(C1->marginal2(eliminate)); + p_BC1C2.push_back(C2->marginal2(eliminate)); } // now, marginalize out everything that is not variable j1 or j2 - return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, function); + return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, eliminate); } /* ************************************************************************* */ diff --git a/gtsam/inference/BayesTreeCliqueBase-inst.h b/gtsam/inference/BayesTreeCliqueBase-inst.h index d335e4b5e0..9e687be6b6 100644 --- a/gtsam/inference/BayesTreeCliqueBase-inst.h +++ b/gtsam/inference/BayesTreeCliqueBase-inst.h @@ -122,12 +122,10 @@ namespace gtsam { { // Obtain P(Cp||B) = P(Fp|Sp) * P(Sp||B) as a factor graph derived_ptr parent(parent_.lock()); - gttoc(BayesTreeCliqueBase_shortcut); FactorGraphType p_Cp_B(parent->shortcut(B, function)); // P(Sp||B) - gttic(BayesTreeCliqueBase_shortcut); p_Cp_B.push_back(parent->conditional_); // P(Fp|Sp) - // Determine the variables we want to keepSet, S union B + // Determine the variables we want to keep, S union B KeyVector keep = shortcut_indices(B, p_Cp_B); // Marginalize out everything except S union B @@ -141,8 +139,9 @@ namespace gtsam { } /* *********************************************************************** */ - // separator marginal, uses separator marginal of parent recursively - // P(C) = P(F|S) P(S) + // Separator marginal, uses separator marginal of parent recursively + // Calculates P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp) + // if P(Sp) is not cached, it will call separatorMarginal on the parent /* *********************************************************************** */ template typename BayesTreeCliqueBase::FactorGraphType @@ -152,30 +151,22 @@ namespace gtsam { gttic(BayesTreeCliqueBase_separatorMarginal); // Check if the Separator marginal was already calculated if (!cachedSeparatorMarginal_) { - gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss); - // If this is the root, there is no separator if (parent_.expired() /*(if we're the root)*/) { // we are root, return empty FactorGraphType empty; cachedSeparatorMarginal_ = empty; } else { - // Flatten recursion in timing outline - gttoc(BayesTreeCliqueBase_separatorMarginal_cachemiss); - gttoc(BayesTreeCliqueBase_separatorMarginal); - // Obtain P(S) = \int P(Cp) = \int P(Fp|Sp) P(Sp) // initialize P(Cp) with the parent separator marginal derived_ptr parent(parent_.lock()); - FactorGraphType p_Cp(parent->separatorMarginal(function)); // P(Sp) - - gttic(BayesTreeCliqueBase_separatorMarginal); - gttic(BayesTreeCliqueBase_separatorMarginal_cachemiss); + FactorGraphType p_Cp( + parent->separatorMarginal(function)); // recursive P(Sp) // now add the parent conditional p_Cp.push_back(parent->conditional_); // P(Fp|Sp) - // The variables we want to keepSet are exactly the ones in S + // The variables we want to keep are exactly the ones in S KeyVector indicesS(this->conditional()->beginParents(), this->conditional()->endParents()); auto separatorMarginal = From 9532feadc01a7be8a695dc744cc7de732f62d8a1 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 22 Jan 2025 23:43:33 -0500 Subject: [PATCH 4/8] Add missing BayesTree methods in wrapper --- gtsam/discrete/discrete.i | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 40f1822cf2..b84ac69a0d 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -268,6 +268,10 @@ class DiscreteBayesTreeClique { class DiscreteBayesTree { DiscreteBayesTree(); + void insertRoot(const gtsam::DiscreteBayesTreeClique* subtree); + void addClique(const gtsam::DiscreteBayesTreeClique* clique); + void addClique(const gtsam::DiscreteBayesTreeClique* clique, const gtsam::DiscreteBayesTreeClique* parent_clique); + void print(string s = "DiscreteBayesTree\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; @@ -276,6 +280,12 @@ class DiscreteBayesTree { size_t size() const; bool empty() const; const DiscreteBayesTreeClique* operator[](size_t j) const; + const DiscreteBayesTreeClique* clique(size_t j) const; + size_t numCachedSeparatorMarginals() const; + + gtsam::DiscreteConditional marginalFactor(size_t key) const; + gtsam::DiscreteFactorGraph* joint(size_t j1, size_t j2) const; + gtsam::DiscreteBayesNet* jointBayesNet(size_t j1, size_t j2) const; double evaluate(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const; @@ -285,7 +295,6 @@ class DiscreteBayesTree { void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; - double operator()(const gtsam::DiscreteValues& values) const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; From 1c2bc4513b835096ee96a266249f2282320641f4 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 23 Jan 2025 08:55:12 -0500 Subject: [PATCH 5/8] Remove extra ; --- gtsam/inference/BayesTree-inst.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gtsam/inference/BayesTree-inst.h b/gtsam/inference/BayesTree-inst.h index 8058127274..c65e2ddc26 100644 --- a/gtsam/inference/BayesTree-inst.h +++ b/gtsam/inference/BayesTree-inst.h @@ -378,7 +378,7 @@ namespace gtsam { typename CLIQUE::FactorGraphType(p_S_B).eliminatePartialMultifrontal( Ordering(S_setminus_B), eliminate); return bayesTree; - }; + } /* ************************************************************************* */ template From fff14ab0b7cde1963059d2d573617a202c1a8a27 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 23 Jan 2025 15:15:27 -0500 Subject: [PATCH 6/8] Isolate Asia network --- gtsam/discrete/tests/testDiscreteBayesNet.cpp | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 49a360cbb6..d2033909c5 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -36,6 +36,25 @@ static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), using ADT = AlgebraicDecisionTree; +// Function to construct the Asia example +DiscreteBayesNet constructAsiaExample() { + DiscreteBayesNet asia; + + asia.add(Asia, "99/1"); + asia.add(Smoking % "50/50"); // Signature version + + asia.add(Tuberculosis | Asia = "99/1 95/5"); + asia.add(LungCancer | Smoking = "99/1 90/10"); + asia.add(Bronchitis | Smoking = "70/30 40/60"); + + asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); + + asia.add(XRay | Either = "95/5 2/98"); + asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + + return asia; +} + /* ************************************************************************* */ TEST(DiscreteBayesNet, bayesNet) { DiscreteBayesNet bayesNet; @@ -67,19 +86,7 @@ TEST(DiscreteBayesNet, bayesNet) { /* ************************************************************************* */ TEST(DiscreteBayesNet, Asia) { - DiscreteBayesNet asia; - - asia.add(Asia, "99/1"); - asia.add(Smoking % "50/50"); // Signature version - - asia.add(Tuberculosis | Asia = "99/1 95/5"); - asia.add(LungCancer | Smoking = "99/1 90/10"); - asia.add(Bronchitis | Smoking = "70/30 40/60"); - - asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); - - asia.add(XRay | Either = "95/5 2/98"); - asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); + DiscreteBayesNet asia = constructAsiaExample(); // Convert to factor graph DiscreteFactorGraph fg(asia); From 2b1f51f098ecdaadf3a3c048b9bc3405793ba244 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 24 Jan 2025 09:14:30 -0500 Subject: [PATCH 7/8] Fix tests --- .../discrete/tests/testDiscreteBayesTree.cpp | 41 ++++++++++--------- python/gtsam/tests/test_DiscreteBayesTree.py | 20 ++++----- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index 95d6f03701..bc205c96cf 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -28,7 +28,7 @@ #include using namespace gtsam; -static constexpr bool debug = false; +static constexpr bool debug = true; /* ************************************************************************* */ struct TestFixture { @@ -186,11 +186,11 @@ TEST(DiscreteBayesTree, Shortcuts) { shortcut = clique->shortcut(R, EliminateDiscrete); DOUBLES_EQUAL(joint_8_11_12_13 / joint_11_13, shortcut.evaluate(all1), 1e-9); - // calculate all shortcuts to root - DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes(); - for (auto clique : cliques) { - DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete); - if (debug) { + if (debug) { + // print all shortcuts to root + DiscreteBayesTree::Nodes cliques = self.bayesTree->nodes(); + for (auto clique : cliques) { + DiscreteBayesNet shortcut = clique.second->shortcut(R, EliminateDiscrete); clique.second->conditional_->printSignature(); shortcut.print("shortcut:"); } @@ -202,6 +202,7 @@ TEST(DiscreteBayesTree, Shortcuts) { TEST(DiscreteBayesTree, MarginalFactors) { TestFixture self; + // Caclulate marginals with brute force enumeration. Vector marginals = Vector::Zero(15); for (size_t i = 0; i < self.assignments.size(); ++i) { DiscreteValues& x = self.assignments[i]; @@ -287,6 +288,8 @@ TEST(DiscreteBayesTree, Joints) { TEST(DiscreteBayesTree, Dot) { TestFixture self; std::string actual = self.bayesTree->dot(); + // print actual: + if (debug) std::cout << actual << std::endl; EXPECT(actual == "digraph G{\n" "0[label=\"13, 11, 6, 7\"];\n" @@ -374,18 +377,18 @@ TEST(DiscreteBayesTree, Lookup) { TEST(DiscreteBayesTree, DirectFromCliques) { // Create a BayesNet DiscreteBayesNet bayesNet; - DiscreteKey key0(0, 2), key1(1, 2), key2(2, 2); - bayesNet.add(key0 % "1/3"); - bayesNet.add(key1 | key0 = "1/3 3/1"); - bayesNet.add(key2 | key1 = "3/1 3/1"); + DiscreteKey A(0, 2), B(1, 2), C(2, 2); + bayesNet.add(A % "1/3"); + bayesNet.add(B | A = "1/3 3/1"); + bayesNet.add(C | B = "3/1 3/1"); // Create cliques directly auto clique2 = std::make_shared( - std::make_shared(key2 | key1 = "3/1 3/1")); + std::make_shared(C | B = "3/1 3/1")); auto clique1 = std::make_shared( - std::make_shared(key1 | key0 = "1/3 3/1")); + std::make_shared(B | A = "1/3 3/1")); auto clique0 = std::make_shared( - std::make_shared(key0 % "1/3")); + std::make_shared(A % "1/3")); // Create a BayesTree DiscreteBayesTree bayesTree; @@ -395,13 +398,13 @@ TEST(DiscreteBayesTree, DirectFromCliques) { // Check that the BayesTree is correct DiscreteValues values; - values[0] = 1; - values[1] = 1; - values[2] = 1; + values[A.first] = 1; + values[A.first] = 1; + values[A.first] = 1; - double expected = bayesNet.evaluate(values); - double actual = bayesTree.evaluate(values); - DOUBLES_EQUAL(expected, actual, 1e-9); + // Regression + double expected = .046875; + DOUBLES_EQUAL(expected, bayesTree.evaluate(values), 1e-9); } /* ************************************************************************* */ diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py index d327b5a1b2..e8943fc803 100644 --- a/python/gtsam/tests/test_DiscreteBayesTree.py +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -160,15 +160,15 @@ def test_direct_from_cliques(self): """Test creating a Bayes tree directly from cliques.""" # Create a BayesNet bayesNet = DiscreteBayesNet() - key0, key1, key2 = (0, 2), (1, 2), (2, 2) - bayesNet.add(key0, "1/3") - bayesNet.add(key1, [key0], "1/3 3/1") - bayesNet.add(key2, [key1], "3/1 3/1") + A, B, C = (0, 2), (1, 2), (2, 2) + bayesNet.add(A, "1/3") + bayesNet.add(B, [A], "1/3 3/1") + bayesNet.add(C, [B], "3/1 3/1") # Create cliques directly - clique2 = DiscreteBayesTreeClique(DiscreteConditional(key2, [key1], "3/1 3/1")) - clique1 = DiscreteBayesTreeClique(DiscreteConditional(key1, [key0], "1/3 3/1")) - clique0 = DiscreteBayesTreeClique(DiscreteConditional(key0, "1/3")) + clique2 = DiscreteBayesTreeClique(DiscreteConditional(C, [B], "3/1 3/1")) + clique1 = DiscreteBayesTreeClique(DiscreteConditional(B, [A], "1/3 3/1")) + clique0 = DiscreteBayesTreeClique(DiscreteConditional(A, "1/3")) # Create a BayesTree bayesTree = gtsam.DiscreteBayesTree() @@ -182,9 +182,9 @@ def test_direct_from_cliques(self): values[1] = 1 values[2] = 1 - expected = bayesNet.evaluate(values) - actual = bayesTree.evaluate(values) - self.assertAlmostEqual(expected, actual, places=9) + # regression + expected = .046875 + self.assertAlmostEqual(expected, bayesNet.evaluate(values)) if __name__ == "__main__": From 21cb31e51999a480e8a46295f2ed032e4f311c67 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 24 Jan 2025 09:58:00 -0500 Subject: [PATCH 8/8] Fix test --- gtsam/discrete/tests/testDiscreteBayesTree.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteBayesTree.cpp b/gtsam/discrete/tests/testDiscreteBayesTree.cpp index bc205c96cf..e0402969dd 100644 --- a/gtsam/discrete/tests/testDiscreteBayesTree.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesTree.cpp @@ -28,7 +28,7 @@ #include using namespace gtsam; -static constexpr bool debug = true; +static constexpr bool debug = false; /* ************************************************************************* */ struct TestFixture { @@ -399,8 +399,8 @@ TEST(DiscreteBayesTree, DirectFromCliques) { // Check that the BayesTree is correct DiscreteValues values; values[A.first] = 1; - values[A.first] = 1; - values[A.first] = 1; + values[B.first] = 1; + values[C.first] = 1; // Regression double expected = .046875;