Skip to content

Commit

Permalink
Merge pull request #1947 from borglab/discrete-improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Jan 2, 2025
2 parents 46f6cf0 + ebd523e commit 6c516cc
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 48 deletions.
16 changes: 7 additions & 9 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,24 +112,23 @@ namespace gtsam {
// }

/**
* @brief Multiply all the `factors` and normalize the
* product to prevent underflow.
* @brief Multiply all the `factors`.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor
*/
static DecisionTreeFactor ProductAndNormalize(
static DecisionTreeFactor DiscreteProduct(
const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product = factors.product();
gttoc(product);

// Max over all the potentials by pretending all keys are frontal:
auto normalization = product.max(product.size());
auto denominator = product.max(product.size());

// Normalize the product factor to prevent underflow.
product = product / (*normalization);
product = product / (*denominator);

return product;
}
Expand All @@ -139,7 +138,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = ProductAndNormalize(factors);
DecisionTreeFactor product = DiscreteProduct(factors);

// max out frontals, this is the factor on the separator
gttic(max);
Expand Down Expand Up @@ -207,8 +206,7 @@ namespace gtsam {
return dag.argmax();
}

DiscreteValues DiscreteFactorGraph::optimize(
const Ordering& ordering) const {
DiscreteValues DiscreteFactorGraph::optimize(const Ordering& ordering) const {
gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(ordering);
return dag.argmax();
Expand All @@ -218,7 +216,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
DecisionTreeFactor product = ProductAndNormalize(factors);
DecisionTreeFactor product = DiscreteProduct(factors);

// sum out frontals, this is the factor on the separator
gttic(sum);
Expand Down
37 changes: 4 additions & 33 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,41 +252,12 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();

// Record key assignment and value pairs in pair_table.
// The assignments are stored in descending order of keys so that the order of
// the values matches what is expected by a DecisionTree.
// This is why we reverse the keys and then
// query for the key value/assignment.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
std::vector<std::pair<uint64_t, double>> pair_table;
for (auto i = 0; i < sparse_table_.size(); i++) {
std::stringstream ss;
for (auto&& [key, _] : rdkeys) {
ss << keyValueForIndex(key, i);
}
// k will be in reverse key order already
uint64_t k;
ss >> k;
pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i)));
}

// Sort the pair_table (of assignment-value pairs) based on assignment so we
// get values in reverse key order.
std::sort(
pair_table.begin(), pair_table.end(),
[](const std::pair<uint64_t, double>& a,
const std::pair<uint64_t, double>& b) { return a.first < b.first; });

// Create the table vector by extracting the values from pair_table.
// The pair_table has already been sorted in the desired order,
// so the values will be in descending key order.
std::vector<double> table;
std::for_each(pair_table.begin(), pair_table.end(),
[&table](const std::pair<uint64_t, double>& pair) {
table.push_back(pair.second);
});
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));
}

AlgebraicDecisionTree<Key> tree(rdkeys, table);
AlgebraicDecisionTree<Key> tree(dkeys, table);
DecisionTreeFactor f(dkeys, tree);
return f;
}
Expand Down
10 changes: 5 additions & 5 deletions gtsam/discrete/tests/testDiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);

// Normalize newFactor by max for comparison with expected
auto normalization = newFactor.max(newFactor.size());
auto normalizer = newFactor.max(newFactor.size());

newFactor = newFactor / *normalization;
newFactor = newFactor / *normalizer;

// Check Conditional
CHECK(conditional);
Expand All @@ -131,9 +131,9 @@ TEST(DiscreteFactorGraph, test) {
CHECK(&newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
// Normalize by max.
normalization = expectedFactor.max(expectedFactor.size());
// Ensure normalization is correct.
expectedFactor = expectedFactor / *normalization;
normalizer = expectedFactor.max(expectedFactor.size());
// Ensure normalizer is correct.
expectedFactor = expectedFactor / *normalizer;
EXPECT(assert_equal(expectedFactor, newFactor));

// Test using elimination tree
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
auto dc = hc->asDiscrete();
if (!dc) throwRuntimeError("discreteElimination", dc);
dfg.push_back(hc->asDiscrete());
dfg.push_back(dc);
} else {
throwRuntimeError("discreteElimination", f);
}
Expand Down

0 comments on commit 6c516cc

Please sign in to comment.