From 3a7a0b84fe73a1c20845b84a78dd7d9984a3fc81 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 15 Sep 2024 17:26:33 -0400 Subject: [PATCH 1/5] use enum to categorize HybridFactor --- gtsam/hybrid/HybridFactor.cpp | 39 ++++++++++++++++------- gtsam/hybrid/HybridFactor.h | 18 +++++------ gtsam/hybrid/tests/testHybridBayesNet.cpp | 10 +++--- 3 files changed, 42 insertions(+), 25 deletions(-) diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index b25e97f051..89b1943cdf 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -50,31 +50,37 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, /* ************************************************************************ */ HybridFactor::HybridFactor(const KeyVector &keys) - : Base(keys), isContinuous_(true), continuousKeys_(keys) {} + : Base(keys), + category_(HybridCategory::Continuous), + continuousKeys_(keys) {} /* ************************************************************************ */ HybridFactor::HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) : Base(CollectKeys(continuousKeys, discreteKeys)), - isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)), - isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)), - isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)), discreteKeys_(discreteKeys), - continuousKeys_(continuousKeys) {} + continuousKeys_(continuousKeys) { + if ((continuousKeys.size() == 0) && (discreteKeys.size() != 0)) { + category_ = HybridCategory::Discrete; + } else if ((continuousKeys.size() != 0) && (discreteKeys.size() == 0)) { + category_ = HybridCategory::Continuous; + } else { + category_ = HybridCategory::Hybrid; + } +} /* ************************************************************************ */ HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)), - isDiscrete_(true), + category_(HybridCategory::Discrete), discreteKeys_(discreteKeys), continuousKeys_({}) {} /* ************************************************************************ */ bool HybridFactor::equals(const HybridFactor &lf, double tol) const { const This *e = dynamic_cast(&lf); - return e != nullptr && Base::equals(*e, tol) && - isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ && - isHybrid_ == e->isHybrid_ && continuousKeys_ == e->continuousKeys_ && + return e != nullptr && Base::equals(*e, tol) && category_ == e->category_ && + continuousKeys_ == e->continuousKeys_ && discreteKeys_ == e->discreteKeys_; } @@ -82,9 +88,18 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const { void HybridFactor::print(const std::string &s, const KeyFormatter &formatter) const { std::cout << (s.empty() ? "" : s + "\n"); - if (isContinuous_) std::cout << "Continuous "; - if (isDiscrete_) std::cout << "Discrete "; - if (isHybrid_) std::cout << "Hybrid "; + switch (category_) { + case HybridCategory::Continuous: + std::cout << "Continuous "; + break; + case HybridCategory::Discrete: + std::cout << "Discrete "; + break; + case HybridCategory::Hybrid: + std::cout << "Hybrid "; + break; + } + std::cout << "["; for (size_t c = 0; c < continuousKeys_.size(); c++) { std::cout << formatter(continuousKeys_.at(c)); diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index c661165125..2cc7453f41 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -41,6 +41,9 @@ KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, const DiscreteKeys &key2); +/// Enum to help with categorizing hybrid factors. +enum class HybridCategory { Discrete, Continuous, Hybrid }; + /** * Base class for *truly* hybrid probabilistic factors * @@ -53,9 +56,8 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, */ class GTSAM_EXPORT HybridFactor : public Factor { private: - bool isDiscrete_ = false; - bool isContinuous_ = false; - bool isHybrid_ = false; + /// Record what category of HybridFactor this is. + HybridCategory category_; protected: // Set of DiscreteKeys for this factor. @@ -116,13 +118,13 @@ class GTSAM_EXPORT HybridFactor : public Factor { /// @{ /// True if this is a factor of discrete variables only. - bool isDiscrete() const { return isDiscrete_; } + bool isDiscrete() const { return category_ == HybridCategory::Discrete; } /// True if this is a factor of continuous variables only. - bool isContinuous() const { return isContinuous_; } + bool isContinuous() const { return category_ == HybridCategory::Continuous; } /// True is this is a Discrete-Continuous factor. - bool isHybrid() const { return isHybrid_; } + bool isHybrid() const { return category_ == HybridCategory::Hybrid; } /// Return the number of continuous variables in this factor. size_t nrContinuous() const { return continuousKeys_.size(); } @@ -142,9 +144,7 @@ class GTSAM_EXPORT HybridFactor : public Factor { template void serialize(ARCHIVE &ar, const unsigned int /*version*/) { ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); - ar &BOOST_SERIALIZATION_NVP(isDiscrete_); - ar &BOOST_SERIALIZATION_NVP(isContinuous_); - ar &BOOST_SERIALIZATION_NVP(isHybrid_); + ar &BOOST_SERIALIZATION_NVP(category_); ar &BOOST_SERIALIZATION_NVP(discreteKeys_); ar &BOOST_SERIALIZATION_NVP(continuousKeys_); } diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index cf4231dba2..99487a84ab 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -387,11 +387,13 @@ TEST(HybridBayesNet, Sampling) { std::make_shared>(X(0), X(1), 0, noise_model); auto one_motion = std::make_shared>(X(0), X(1), 1, noise_model); - std::vector factors = {{zero_motion, 0.0}, - {one_motion, 0.0}}; + + DiscreteKeys discreteKeys{DiscreteKey(M(0), 2)}; + HybridNonlinearFactor::Factors factors( + discreteKeys, {{zero_motion, 0.0}, {one_motion, 0.0}}); nfg.emplace_shared>(X(0), 0.0, noise_model); - nfg.emplace_shared( - KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors); + nfg.emplace_shared(KeyVector{X(0), X(1)}, discreteKeys, + factors); DiscreteKey mode(M(0), 2); nfg.emplace_shared(mode, "1/1"); From 97eb6bc8b9cb758cb782c1c1eb5011f9845a33f6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 16 Sep 2024 13:28:29 -0400 Subject: [PATCH 2/5] renaming --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 74091bf956..3621507451 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -114,10 +114,11 @@ void HybridGaussianFactorGraph::printErrors( << "\n"; } else { // Is hybrid - auto mixtureComponent = + auto conditionalComponent = hc->asMixture()->operator()(values.discrete()); - mixtureComponent->print(ss.str(), keyFormatter); - std::cout << "error = " << mixtureComponent->error(values) << "\n"; + conditionalComponent->print(ss.str(), keyFormatter); + std::cout << "error = " << conditionalComponent->error(values) + << "\n"; } } } else if (auto gf = std::dynamic_pointer_cast(factor)) { @@ -411,10 +412,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors, // Create the HybridGaussianConditional from the conditionals HybridGaussianConditional::Conditionals conditionals( eliminationResults, [](const Result &pair) { return pair.first; }); - auto gaussianMixture = std::make_shared( + auto hybridGaussian = std::make_shared( frontalKeys, continuousSeparator, discreteSeparator, conditionals); - return {std::make_shared(gaussianMixture), newFactor}; + return {std::make_shared(hybridGaussian), newFactor}; } /* ************************************************************************ @@ -465,7 +466,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, // Now we will need to know how to retrieve the corresponding continuous // densities for the assignment (c1,c2,c3) (OR (c2,c3,c1), note there is NO // defined order!). We also need to consider when there is pruning. Two - // mixture factors could have different pruning patterns - one could have + // hybrid factors could have different pruning patterns - one could have // (c1=0,c2=1) pruned, and another could have (c2=0,c3=1) pruned, and this // creates a big problem in how to identify the intersection of non-pruned // branches. From 4302ee33c96d5179f00394903869f80409246713 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 16 Sep 2024 13:30:07 -0400 Subject: [PATCH 3/5] make None the default HybridCategory --- gtsam/hybrid/HybridFactor.cpp | 27 +++++++++++++++++++-------- gtsam/hybrid/HybridFactor.h | 4 ++-- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index 89b1943cdf..5582166a3e 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -55,20 +55,28 @@ HybridFactor::HybridFactor(const KeyVector &keys) continuousKeys_(keys) {} /* ************************************************************************ */ -HybridFactor::HybridFactor(const KeyVector &continuousKeys, - const DiscreteKeys &discreteKeys) - : Base(CollectKeys(continuousKeys, discreteKeys)), - discreteKeys_(discreteKeys), - continuousKeys_(continuousKeys) { +HybridCategory GetCategory(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys) { if ((continuousKeys.size() == 0) && (discreteKeys.size() != 0)) { - category_ = HybridCategory::Discrete; + return HybridCategory::Discrete; } else if ((continuousKeys.size() != 0) && (discreteKeys.size() == 0)) { - category_ = HybridCategory::Continuous; + return HybridCategory::Continuous; + } else if ((continuousKeys.size() != 0) && (discreteKeys.size() != 0)) { + return HybridCategory::Hybrid; } else { - category_ = HybridCategory::Hybrid; + // Case where we have no keys. Should never happen. + return HybridCategory::None; } } +/* ************************************************************************ */ +HybridFactor::HybridFactor(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys) + : Base(CollectKeys(continuousKeys, discreteKeys)), + category_(GetCategory(continuousKeys, discreteKeys)), + discreteKeys_(discreteKeys), + continuousKeys_(continuousKeys) {} + /* ************************************************************************ */ HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)), @@ -98,6 +106,9 @@ void HybridFactor::print(const std::string &s, case HybridCategory::Hybrid: std::cout << "Hybrid "; break; + case HybridCategory::None: + std::cout << "None "; + break; } std::cout << "["; diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index 2cc7453f41..d0b9bbabe8 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -42,7 +42,7 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, const DiscreteKeys &key2); /// Enum to help with categorizing hybrid factors. -enum class HybridCategory { Discrete, Continuous, Hybrid }; +enum class HybridCategory { None, Discrete, Continuous, Hybrid }; /** * Base class for *truly* hybrid probabilistic factors @@ -57,7 +57,7 @@ enum class HybridCategory { Discrete, Continuous, Hybrid }; class GTSAM_EXPORT HybridFactor : public Factor { private: /// Record what category of HybridFactor this is. - HybridCategory category_; + HybridCategory category_ = HybridCategory::None; protected: // Set of DiscreteKeys for this factor. From 8cb95d5b5a7aad89fbad616cefb69d7f90b7a1e6 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 16 Sep 2024 13:31:03 -0400 Subject: [PATCH 4/5] remove redundancy from HybridConditional constructors --- gtsam/hybrid/HybridConditional.cpp | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 8a8511aef9..ed8125c2b7 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -28,14 +28,9 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals, const DiscreteKeys &discreteFrontals, const KeyVector &continuousParents, const DiscreteKeys &discreteParents) - : HybridConditional( - CollectKeys( - {continuousFrontals.begin(), continuousFrontals.end()}, - KeyVector{continuousParents.begin(), continuousParents.end()}), - CollectDiscreteKeys( - {discreteFrontals.begin(), discreteFrontals.end()}, - {discreteParents.begin(), discreteParents.end()}), - continuousFrontals.size() + discreteFrontals.size()) {} + : HybridConditional(CollectKeys(continuousFrontals, continuousParents), + CollectDiscreteKeys(discreteFrontals, discreteParents), + continuousFrontals.size() + discreteFrontals.size()) {} /* ************************************************************************ */ HybridConditional::HybridConditional( @@ -56,9 +51,7 @@ HybridConditional::HybridConditional( /* ************************************************************************ */ HybridConditional::HybridConditional( const std::shared_ptr &gaussianMixture) - : BaseFactor(KeyVector(gaussianMixture->keys().begin(), - gaussianMixture->keys().begin() + - gaussianMixture->nrContinuous()), + : BaseFactor(gaussianMixture->continuousKeys(), gaussianMixture->discreteKeys()), BaseConditional(gaussianMixture->nrFrontals()) { inner_ = gaussianMixture; From 4feec4ddaf070e0f7f0974ca458a1ad50f4ed386 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 18 Sep 2024 04:18:23 -0400 Subject: [PATCH 5/5] rename to Category and put inside HybridFactor class --- gtsam/hybrid/HybridFactor.cpp | 26 ++++++++++++-------------- gtsam/hybrid/HybridFactor.h | 15 ++++++++------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index 5582166a3e..3338951bf0 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -50,22 +50,20 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, /* ************************************************************************ */ HybridFactor::HybridFactor(const KeyVector &keys) - : Base(keys), - category_(HybridCategory::Continuous), - continuousKeys_(keys) {} + : Base(keys), category_(Category::Continuous), continuousKeys_(keys) {} /* ************************************************************************ */ -HybridCategory GetCategory(const KeyVector &continuousKeys, - const DiscreteKeys &discreteKeys) { +HybridFactor::Category GetCategory(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys) { if ((continuousKeys.size() == 0) && (discreteKeys.size() != 0)) { - return HybridCategory::Discrete; + return HybridFactor::Category::Discrete; } else if ((continuousKeys.size() != 0) && (discreteKeys.size() == 0)) { - return HybridCategory::Continuous; + return HybridFactor::Category::Continuous; } else if ((continuousKeys.size() != 0) && (discreteKeys.size() != 0)) { - return HybridCategory::Hybrid; + return HybridFactor::Category::Hybrid; } else { // Case where we have no keys. Should never happen. - return HybridCategory::None; + return HybridFactor::Category::None; } } @@ -80,7 +78,7 @@ HybridFactor::HybridFactor(const KeyVector &continuousKeys, /* ************************************************************************ */ HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)), - category_(HybridCategory::Discrete), + category_(Category::Discrete), discreteKeys_(discreteKeys), continuousKeys_({}) {} @@ -97,16 +95,16 @@ void HybridFactor::print(const std::string &s, const KeyFormatter &formatter) const { std::cout << (s.empty() ? "" : s + "\n"); switch (category_) { - case HybridCategory::Continuous: + case Category::Continuous: std::cout << "Continuous "; break; - case HybridCategory::Discrete: + case Category::Discrete: std::cout << "Discrete "; break; - case HybridCategory::Hybrid: + case Category::Hybrid: std::cout << "Hybrid "; break; - case HybridCategory::None: + case Category::None: std::cout << "None "; break; } diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index d0b9bbabe8..ad29dfdca9 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -41,9 +41,6 @@ KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, const DiscreteKeys &key2); -/// Enum to help with categorizing hybrid factors. -enum class HybridCategory { None, Discrete, Continuous, Hybrid }; - /** * Base class for *truly* hybrid probabilistic factors * @@ -55,9 +52,13 @@ enum class HybridCategory { None, Discrete, Continuous, Hybrid }; * @ingroup hybrid */ class GTSAM_EXPORT HybridFactor : public Factor { + public: + /// Enum to help with categorizing hybrid factors. + enum class Category { None, Discrete, Continuous, Hybrid }; + private: /// Record what category of HybridFactor this is. - HybridCategory category_ = HybridCategory::None; + Category category_ = Category::None; protected: // Set of DiscreteKeys for this factor. @@ -118,13 +119,13 @@ class GTSAM_EXPORT HybridFactor : public Factor { /// @{ /// True if this is a factor of discrete variables only. - bool isDiscrete() const { return category_ == HybridCategory::Discrete; } + bool isDiscrete() const { return category_ == Category::Discrete; } /// True if this is a factor of continuous variables only. - bool isContinuous() const { return category_ == HybridCategory::Continuous; } + bool isContinuous() const { return category_ == Category::Continuous; } /// True is this is a Discrete-Continuous factor. - bool isHybrid() const { return category_ == HybridCategory::Hybrid; } + bool isHybrid() const { return category_ == Category::Hybrid; } /// Return the number of continuous variables in this factor. size_t nrContinuous() const { return continuousKeys_.size(); }