Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DiscreteFactor multiply method #1963

Merged
merged 14 commits into from
Jan 6, 2025
Merged
12 changes: 12 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@ namespace gtsam {
return error(values.discrete());
}

/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::multiply(
const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
result = std::make_shared<TableFactor>((*tf) * TableFactor(*this));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf));
}
return result;
}

/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
Expand Down
5 changes: 5 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/Ring.h>
#include <gtsam/discrete/TableFactor.h>
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
#include <gtsam/inference/Ordering.h>

#include <algorithm>
Expand Down Expand Up @@ -147,6 +148,10 @@ namespace gtsam {
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;

/// Multiply factors, DiscreteFactor::shared_ptr edition
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;

/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, Ring::mul);
Expand Down
10 changes: 10 additions & 0 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;

/**
* @brief Multiply in a DiscreteFactor and return the result as
* DiscreteFactor, both via shared pointers.
*
* @param df DiscreteFactor shared_ptr
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const DiscreteFactor::shared_ptr& df) const = 0;

virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;

/// @}
Expand Down
8 changes: 4 additions & 4 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ namespace gtsam {

/* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;
for (const sharedFactor& factor : *this) {
if (factor) result = (*factor) * result;
DiscreteFactor::shared_ptr result = *this->begin();
for (auto it = this->begin() + 1; it != this->end(); ++it) {
if (*it) result = result->multiply(*it);
}
return result;
return result->toDecisionTreeFactor();
dellaert marked this conversation as resolved.
Show resolved Hide resolved
}

/* ************************************************************************ */
Expand Down
12 changes: 12 additions & 0 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,18 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}

/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::multiply(
const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
result = std::make_shared<TableFactor>(this->operator*(*tf));
dellaert marked this conversation as resolved.
Show resolved Hide resolved
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf)));
}
return result;
}

/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
Expand Down
4 changes: 4 additions & 0 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// multiply with DecisionTreeFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;

/// Multiply factors, DiscreteFactor::shared_ptr edition
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;

static double safe_div(const double& a, const double& b);

/// divide by factor f (safely)
Expand Down
Loading