Skip to content

Commit

Permalink
Merge branch 'develop' into direct-hybrid-fg
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Sep 19, 2024
2 parents f875b86 + 63c4e33 commit 2937533
Show file tree
Hide file tree
Showing 29 changed files with 198 additions and 243 deletions.
16 changes: 8 additions & 8 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &prunedDiscreteProbs,
const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree
// and the Gaussian mixture.
// and the hybrid Gaussian conditional.
std::set<DiscreteKey> discreteProbsKeySet =
DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
std::set<DiscreteKey> conditionalKeySet =
Expand All @@ -63,7 +63,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(

// typecast so we can use this to get probability value
DiscreteValues values(choices);
// Case where the Gaussian mixture has the same
// Case where the hybrid Gaussian conditional has the same
// discrete keys as the decision tree.
if (conditionalKeySet == discreteProbsKeySet) {
if (prunedDiscreteProbs(values) == 0) {
Expand Down Expand Up @@ -180,8 +180,8 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Go through all the conditionals in the
// Bayes Net and prune them as per prunedDiscreteProbs.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// Make a copy of the Gaussian mixture and prune it!
if (auto gm = conditional->asHybrid()) {
// Make a copy of the hybrid Gaussian conditional and prune it!
auto prunedHybridGaussianConditional =
std::make_shared<HybridGaussianConditional>(*gm);
prunedHybridGaussianConditional->prune(
Expand All @@ -204,7 +204,7 @@ GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
if (auto gm = conditional->asHybrid()) {
// If conditional is hybrid, select based on assignment.
gbn.push_back((*gm)(assignment));
} else if (auto gc = conditional->asGaussian()) {
Expand Down Expand Up @@ -291,7 +291,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(

// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
if (auto gm = conditional->asHybrid()) {
// If conditional is hybrid, compute error for all assignments.
result = result + gm->errorTree(continuousValues);

Expand Down Expand Up @@ -321,7 +321,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(

// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
if (auto gm = conditional->asHybrid()) {
// If conditional is hybrid, select based on assignment and compute
// logProbability.
result = result + gm->logProbability(continuousValues);
Expand Down Expand Up @@ -369,7 +369,7 @@ HybridGaussianFactorGraph HybridBayesNet::toFactorGraph(
if (conditional->frontalsIn(measurements)) {
if (auto gc = conditional->asGaussian()) {
fg.push_back(gc->likelihood(measurements));
} else if (auto gm = conditional->asMixture()) {
} else if (auto gm = conditional->asHybrid()) {
fg.push_back(gm->likelihood(measurements));
} else {
throw std::runtime_error("Unknown conditional type");
Expand Down
3 changes: 2 additions & 1 deletion gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ namespace gtsam {

/**
* A hybrid Bayes net is a collection of HybridConditionals, which can have
* discrete conditionals, Gaussian mixtures, or pure Gaussian conditionals.
* discrete conditionals, hybrid Gaussian conditionals,
* or pure Gaussian conditionals.
*
* @ingroup hybrid
*/
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ struct HybridAssignmentData {

GaussianConditional::shared_ptr conditional;
if (hybrid_conditional->isHybrid()) {
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
conditional = (*hybrid_conditional->asHybrid())(parentData.assignment_);
} else if (hybrid_conditional->isContinuous()) {
conditional = hybrid_conditional->asGaussian();
} else {
Expand Down Expand Up @@ -205,9 +205,9 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {

// If conditional is hybrid, we prune it.
if (conditional->isHybrid()) {
auto gaussianMixture = conditional->asMixture();
auto hybridGaussianCond = conditional->asHybrid();

gaussianMixture->prune(parentData.prunedDiscreteProbs);
hybridGaussianCond->prune(parentData.prunedDiscreteProbs);
}
return parentData;
}
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/HybridBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/**
* @brief Recursively optimize the BayesTree to produce a vector solution.
*
* @param assignment The discrete values assignment to select the Gaussian
* mixtures.
* @param assignment The discrete values assignment to select
* the hybrid conditional.
* @return VectorValues
*/
VectorValues optimize(const DiscreteValues& assignment) const;
Expand Down Expand Up @@ -170,7 +170,7 @@ class BayesTreeOrphanWrapper<HybridBayesTreeClique> : public HybridConditional {
void print(
const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
clique->print(s + "stored clique", formatter);
clique->print(s + " stored clique ", formatter);
}
};

Expand Down
20 changes: 10 additions & 10 deletions gtsam/hybrid/HybridConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ HybridConditional::HybridConditional(

/* ************************************************************************ */
HybridConditional::HybridConditional(
const std::shared_ptr<HybridGaussianConditional> &gaussianMixture)
: BaseFactor(gaussianMixture->continuousKeys(),
gaussianMixture->discreteKeys()),
BaseConditional(gaussianMixture->nrFrontals()) {
inner_ = gaussianMixture;
const std::shared_ptr<HybridGaussianConditional> &hybridGaussianCond)
: BaseFactor(hybridGaussianCond->continuousKeys(),
hybridGaussianCond->discreteKeys()),
BaseConditional(hybridGaussianCond->nrFrontals()) {
inner_ = hybridGaussianCond;
}

/* ************************************************************************ */
Expand Down Expand Up @@ -97,8 +97,8 @@ void HybridConditional::print(const std::string &s,
bool HybridConditional::equals(const HybridFactor &other, double tol) const {
const This *e = dynamic_cast<const This *>(&other);
if (e == nullptr) return false;
if (auto gm = asMixture()) {
auto other = e->asMixture();
if (auto gm = asHybrid()) {
auto other = e->asHybrid();
return other != nullptr && gm->equals(*other, tol);
}
if (auto gc = asGaussian()) {
Expand All @@ -119,7 +119,7 @@ double HybridConditional::error(const HybridValues &values) const {
if (auto gc = asGaussian()) {
return gc->error(values.continuous());
}
if (auto gm = asMixture()) {
if (auto gm = asHybrid()) {
return gm->error(values);
}
if (auto dc = asDiscrete()) {
Expand All @@ -134,7 +134,7 @@ double HybridConditional::logProbability(const HybridValues &values) const {
if (auto gc = asGaussian()) {
return gc->logProbability(values.continuous());
}
if (auto gm = asMixture()) {
if (auto gm = asHybrid()) {
return gm->logProbability(values);
}
if (auto dc = asDiscrete()) {
Expand All @@ -149,7 +149,7 @@ double HybridConditional::logNormalizationConstant() const {
if (auto gc = asGaussian()) {
return gc->logNormalizationConstant();
}
if (auto gm = asMixture()) {
if (auto gm = asHybrid()) {
return gm->logNormalizationConstant(); // 0.0!
}
if (auto dc = asDiscrete()) {
Expand Down
8 changes: 4 additions & 4 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,11 @@ class GTSAM_EXPORT HybridConditional
/**
* @brief Construct a new Hybrid Conditional object
*
* @param gaussianMixture Gaussian Mixture Conditional used to create the
* @param hybridGaussianCond Hybrid Gaussian Conditional used to create the
* HybridConditional.
*/
HybridConditional(
const std::shared_ptr<HybridGaussianConditional>& gaussianMixture);
const std::shared_ptr<HybridGaussianConditional>& hybridGaussianCond);

/// @}
/// @name Testable
Expand All @@ -148,10 +148,10 @@ class GTSAM_EXPORT HybridConditional

/**
* @brief Return HybridConditional as a HybridGaussianConditional
* @return nullptr if not a mixture
* @return nullptr if not a conditional
* @return HybridGaussianConditional::shared_ptr otherwise
*/
HybridGaussianConditional::shared_ptr asMixture() const {
HybridGaussianConditional::shared_ptr asHybrid() const {
return std::dynamic_pointer_cast<HybridGaussianConditional>(inner_);
}

Expand Down
12 changes: 6 additions & 6 deletions gtsam/hybrid/HybridGaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,20 +248,20 @@ std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) {
// Get the discrete keys as sets for the decision tree
// and the gaussian mixture.
// and the hybrid gaussian conditional.
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys());

auto pruner = [discreteProbs, discreteProbsKeySet, gaussianMixtureKeySet](
auto pruner = [discreteProbs, discreteProbsKeySet, hybridGaussianCondKeySet](
const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
const DiscreteValues values(choices);

// Case where the gaussian mixture has the same
// Case where the hybrid gaussian conditional has the same
// discrete keys as the decision tree.
if (gaussianMixtureKeySet == discreteProbsKeySet) {
if (hybridGaussianCondKeySet == discreteProbsKeySet) {
if (discreteProbs(values) == 0.0) {
// empty aka null pointer
std::shared_ptr<GaussianConditional> null;
Expand All @@ -273,7 +273,7 @@ HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) {
std::vector<DiscreteKey> set_diff;
std::set_difference(
discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(),
std::back_inserter(set_diff));

const std::vector<DiscreteValues> assignments =
Expand Down
16 changes: 8 additions & 8 deletions gtsam/hybrid/HybridGaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ namespace gtsam {
class HybridValues;

/**
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
* part of a Bayes Network. This is the result of the elimination of a
* @brief A conditional of gaussian conditionals indexed by discrete variables,
* as part of a Bayes Network. This is the result of the elimination of a
* continuous variable in a hybrid scheme, such that the remaining variables are
* discrete+continuous.
*
Expand Down Expand Up @@ -107,7 +107,7 @@ class GTSAM_EXPORT HybridGaussianConditional
const Conditionals &conditionals);

/**
* @brief Make a Gaussian Mixture from a vector of Gaussian conditionals.
* @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian conditionals.
* The DecisionTree-based constructor is preferred over this one.
*
* @param continuousFrontals The continuous frontal variables
Expand Down Expand Up @@ -152,8 +152,8 @@ class GTSAM_EXPORT HybridGaussianConditional
double logNormalizationConstant() const override { return logConstant_; }

/**
* Create a likelihood factor for a Gaussian mixture, return a Mixture factor
* on the parents.
* Create a likelihood factor for a hybrid Gaussian conditional,
* return a hybrid Gaussian factor on the parents.
*/
std::shared_ptr<HybridGaussianFactor> likelihood(
const VectorValues &given) const;
Expand All @@ -172,9 +172,9 @@ class GTSAM_EXPORT HybridGaussianConditional
const VectorValues &continuousValues) const;

/**
* @brief Compute the error of this Gaussian Mixture.
* @brief Compute the error of this hybrid Gaussian conditional.
*
* This requires some care, as different mixture components may have
* This requires some care, as different components may have
* different normalization constants. Let's consider p(x|y,m), where m is
* discrete. We need the error to satisfy the invariant:
*
Expand Down Expand Up @@ -209,7 +209,7 @@ class GTSAM_EXPORT HybridGaussianConditional
const VectorValues &continuousValues) const;

/**
* @brief Compute the logProbability of this Gaussian Mixture.
* @brief Compute the logProbability of this hybrid Gaussian conditional.
*
* @param values Continuous values and discrete assignment.
* @return double
Expand Down
12 changes: 6 additions & 6 deletions gtsam/hybrid/HybridGaussianFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ class VectorValues;
using GaussianFactorValuePair = std::pair<GaussianFactor::shared_ptr, double>;

/**
* @brief Implementation of a discrete conditional mixture factor.
* @brief Implementation of a discrete-conditioned hybrid factor.
* Implements a joint discrete-continuous factor where the discrete variable
* serves to "select" a mixture component corresponding to a GaussianFactor type
* of measurement.
* serves to "select" a component corresponding to a GaussianFactor.
*
* Represents the underlying Gaussian mixture as a Decision Tree, where the set
* of discrete variables indexes to the continuous gaussian distribution.
* Represents the underlying hybrid Gaussian components as a Decision Tree,
* where the set of discrete variables indexes to
* the continuous gaussian distribution.
*
* @ingroup hybrid
*/
Expand Down Expand Up @@ -80,7 +80,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
HybridGaussianFactor() = default;

/**
* @brief Construct a new Gaussian mixture factor.
* @brief Construct a new hybrid Gaussian factor.
*
* @param continuousKeys A vector of keys representing continuous variables.
* @param discreteKeys A vector of keys representing discrete variables and
Expand Down
13 changes: 7 additions & 6 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ void HybridGaussianFactorGraph::printErrors(
} else {
// Is hybrid
auto conditionalComponent =
hc->asMixture()->operator()(values.discrete());
hc->asHybrid()->operator()(values.discrete());
conditionalComponent->print(ss.str(), keyFormatter);
std::cout << "error = " << conditionalComponent->error(values)
<< "\n";
Expand Down Expand Up @@ -184,7 +184,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
} else if (auto gm = dynamic_pointer_cast<HybridGaussianConditional>(f)) {
result = gm->add(result);
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
if (auto gm = hc->asHybrid()) {
result = gm->add(result);
} else if (auto g = hc->asGaussian()) {
result = addGaussian(result, g);
Expand Down Expand Up @@ -437,8 +437,8 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
// NOTE: Because we are in the Conditional Gaussian regime there are only
// a few cases:
// 1. continuous variable, make a Gaussian Mixture if there are hybrid
// factors;
// 1. continuous variable, make a hybrid Gaussian conditional if there are
// hybrid factors;
// 2. continuous variable, we make a Gaussian Factor if there are no hybrid
// factors;
// 3. discrete variable, no continuous factor is allowed
Expand Down Expand Up @@ -550,9 +550,10 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
f = hc->inner();
}

if (auto gaussianMixture = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
if (auto hybridGaussianCond =
dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Compute factor error and add it.
error_tree = error_tree + gaussianMixture->errorTree(continuousValues);
error_tree = error_tree + hybridGaussianCond->errorTree(continuousValues);
} else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
// If continuous only, get the (double) error
// and add it to the error_tree
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @brief Create a decision tree of factor graphs out of this hybrid factor
* graph.
*
* For example, if there are two mixture factors, one with a discrete key A
* For example, if there are two hybrid factors, one with a discrete key A
* and one with a discrete key B, then the decision tree will have two levels,
* one for A and one for B. The leaves of the tree will be the Gaussian
* factors that have only continuous keys.
Expand Down
Loading

0 comments on commit 2937533

Please sign in to comment.