diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 1ac782b884..ef7979d0aa 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -18,9 +18,10 @@ */ #include -#include #include #include +#include +#include #include @@ -62,6 +63,30 @@ namespace gtsam { return error(values.discrete()); } + /* ************************************************************************ */ + DiscreteFactor::shared_ptr DecisionTreeFactor::multiply( + const DiscreteFactor::shared_ptr& f) const { + DiscreteFactor::shared_ptr result; + if (auto tf = std::dynamic_pointer_cast(f)) { + // If f is a TableFactor, we convert `this` to a TableFactor since this + // conversion is cheaper than converting `f` to a DecisionTreeFactor. We + // then return a TableFactor. + result = std::make_shared((*tf) * TableFactor(*this)); + + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + // If `f` is a DecisionTreeFactor, simply call operator*. + result = std::make_shared(this->operator*(*dtf)); + + } else { + // Simulate double dispatch in C++ + // Useful for other classes which inherit from DiscreteFactor and have + // only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't + // need to be updated. + result = std::make_shared(f->operator*(*this)); + } + return result; + } + /* ************************************************************************ */ double DecisionTreeFactor::safe_div(const double& a, const double& b) { // The use for safe_div is when we divide the product factor by the sum diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 80ee10a7b2..907f29a45b 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -147,6 +147,23 @@ namespace gtsam { /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; + /** + * @brief Multiply factors, DiscreteFactor::shared_ptr edition. + * + * This method accepts `DiscreteFactor::shared_ptr` and uses dynamic + * dispatch and specializations to perform the most efficient + * multiplication. + * + * While converting a DecisionTreeFactor to a TableFactor is efficient, the + * reverse is not. Hence we specialize the code to return a TableFactor if + * `f` is a TableFactor, and DecisionTreeFactor otherwise. + * + * @param f The factor to multiply with. + * @return DiscreteFactor::shared_ptr + */ + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& f) const override; + /// multiply two factors DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { return apply(f, Ring::mul); diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index a1fde0f864..c18eaae2fd 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -129,6 +129,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// DecisionTreeFactor virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; + /** + * @brief Multiply in a DiscreteFactor and return the result as + * DiscreteFactor, both via shared pointers. + * + * @param df DiscreteFactor shared_ptr + * @return DiscreteFactor::shared_ptr + */ + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const = 0; + virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; /// @} diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 227bb4da37..a2b8962862 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -65,11 +65,18 @@ namespace gtsam { /* ************************************************************************ */ DecisionTreeFactor DiscreteFactorGraph::product() const { - DecisionTreeFactor result; - for (const sharedFactor& factor : *this) { - if (factor) result = (*factor) * result; + DiscreteFactor::shared_ptr result; + for (auto it = this->begin(); it != this->end(); ++it) { + if (*it) { + if (result) { + result = result->multiply(*it); + } else { + // Assign to the first non-null factor + result = *it; + } + } } - return result; + return result->toDecisionTreeFactor(); } /* ************************************************************************ */ diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index a59095d406..fe901aac1c 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -254,6 +254,32 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { return toDecisionTreeFactor() * f; } +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::multiply( + const DiscreteFactor::shared_ptr& f) const { + DiscreteFactor::shared_ptr result; + if (auto tf = std::dynamic_pointer_cast(f)) { + // If `f` is a TableFactor, we can simply call `operator*`. + result = std::make_shared(this->operator*(*tf)); + + } else if (auto dtf = std::dynamic_pointer_cast(f)) { + // If `f` is a DecisionTreeFactor, we convert to a TableFactor which is + // cheaper than converting `this` to a DecisionTreeFactor. + result = std::make_shared(this->operator*(TableFactor(*dtf))); + + } else { + // Simulate double dispatch in C++ + // Useful for other classes which inherit from DiscreteFactor and have + // only `operator*(DecisionTreeFactor)` defined. Thus, other classes don't + // need to be updated to know about TableFactor. + // Those classes can be specialized to use TableFactor + // if efficiency is a problem. + result = std::make_shared( + f->operator*(this->toDecisionTreeFactor())); + } + return result; +} + /* ************************************************************************ */ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DiscreteKeys dkeys = discreteKeys(); diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index a2fdb4d325..a2e89b3028 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -178,6 +178,23 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// multiply with DecisionTreeFactor DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /** + * @brief Multiply factors, DiscreteFactor::shared_ptr edition. + * + * This method accepts `DiscreteFactor::shared_ptr` and uses dynamic + * dispatch and specializations to perform the most efficient + * multiplication. + * + * While converting a DecisionTreeFactor to a TableFactor is efficient, the + * reverse is not. + * Hence we specialize the code to return a TableFactor always. + * + * @param f The factor to multiply with. + * @return DiscreteFactor::shared_ptr + */ + virtual DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& f) const override; + static double safe_div(const double& a, const double& b); /// divide by factor f (safely) diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index 3526a282d7..71ed7647a9 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -78,6 +78,14 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor { /// Partially apply known values, domain version virtual shared_ptr partiallyApply(const Domains&) const = 0; + + /// Multiply factors, DiscreteFactor::shared_ptr edition + DiscreteFactor::shared_ptr multiply( + const DiscreteFactor::shared_ptr& df) const override { + return std::make_shared( + this->operator*(df->toDecisionTreeFactor())); + } + /// @} /// @name Wrapper support /// @{