Skip to content

Commit

Permalink
fix testGaussianMixture
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Jan 2, 2025
1 parent e620729 commit d18569b
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions gtsam/hybrid/tests/testGaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteTableConditional.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
Expand Down Expand Up @@ -79,8 +80,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
double midway = mu1 - mu0;
auto eliminationResult =
gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential();
auto pMid = eliminationResult->at(0)->asDiscrete();
EXPECT(assert_equal(DiscreteConditional(m, "60/40"), *pMid));
auto pMid = eliminationResult->at(0)->asDiscrete<DiscreteTableConditional>();
EXPECT(assert_equal(DiscreteTableConditional(m, "60/40"), *pMid));

// Everywhere else, the result should be a sigmoid.
for (const double shift : {-4, -2, 0, 2, 4}) {
Expand All @@ -90,7 +91,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
// Workflow 1: convert HBN to HFG and solve
auto eliminationResult1 =
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
auto posterior1 = *eliminationResult1->at(0)->asDiscrete();
auto posterior1 =
*eliminationResult1->at(0)->asDiscrete<DiscreteTableConditional>();
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);

// Workflow 2: directly specify HFG and solve
Expand All @@ -99,7 +101,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
m, std::vector{Gaussian(mu0, sigma, z), Gaussian(mu1, sigma, z)});
hfg1.push_back(mixing);
auto eliminationResult2 = hfg1.eliminateSequential();
auto posterior2 = *eliminationResult2->at(0)->asDiscrete();
auto posterior2 =
*eliminationResult2->at(0)->asDiscrete<DiscreteTableConditional>();
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
}
}
Expand Down Expand Up @@ -138,8 +141,9 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
EXPECT(assert_equal(expectedDiscretePosterior,
eliminationResultMax->discretePosterior(vv)));

auto pMax = *eliminationResultMax->at(0)->asDiscrete();
EXPECT(assert_equal(DiscreteConditional(m, "42/58"), pMax, 1e-4));
auto pMax =
*eliminationResultMax->at(0)->asDiscrete<DiscreteTableConditional>();
EXPECT(assert_equal(DiscreteTableConditional(m, "42/58"), pMax, 1e-4));

// Everywhere else, the result should be a bell curve like function.
for (const double shift : {-4, -2, 0, 2, 4}) {
Expand All @@ -149,7 +153,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
// Workflow 1: convert HBN to HFG and solve
auto eliminationResult1 =
gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
auto posterior1 = *eliminationResult1->at(0)->asDiscrete();
auto posterior1 =
*eliminationResult1->at(0)->asDiscrete<DiscreteTableConditional>();
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);

// Workflow 2: directly specify HFG and solve
Expand All @@ -158,7 +163,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
m, std::vector{Gaussian(mu0, sigma0, z), Gaussian(mu1, sigma1, z)});
hfg.push_back(mixing);
auto eliminationResult2 = hfg.eliminateSequential();
auto posterior2 = *eliminationResult2->at(0)->asDiscrete();
auto posterior2 =
*eliminationResult2->at(0)->asDiscrete<DiscreteTableConditional>();
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
}
}
Expand Down

0 comments on commit d18569b

Please sign in to comment.