Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DiscreteFactor multiply method #1963

Merged
merged 14 commits into from
Jan 6, 2025
Merged
15 changes: 15 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,21 @@ namespace gtsam {
return error(values.discrete());
}

/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::multiply(
const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
result = std::make_shared<TableFactor>((*tf) * TableFactor(*this));
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
result = std::make_shared<DecisionTreeFactor>(this->operator*(*dtf));
} else {
// Simulate double dispatch in C++
dellaert marked this conversation as resolved.
Show resolved Hide resolved
result = std::make_shared<DecisionTreeFactor>(f->operator*(*this));
}
return result;
}

/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
Expand Down
5 changes: 5 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/Ring.h>
#include <gtsam/discrete/TableFactor.h>
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
#include <gtsam/inference/Ordering.h>

#include <algorithm>
Expand Down Expand Up @@ -147,6 +148,10 @@ namespace gtsam {
/// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override;

/// Multiply factors, DiscreteFactor::shared_ptr edition
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;

/// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
return apply(f, Ring::mul);
Expand Down
10 changes: 10 additions & 0 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,16 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;

/**
* @brief Multiply in a DiscreteFactor and return the result as
* DiscreteFactor, both via shared pointers.
*
* @param df DiscreteFactor shared_ptr
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr multiply(
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const DiscreteFactor::shared_ptr& df) const = 0;

virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;

/// @}
Expand Down
15 changes: 11 additions & 4 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,18 @@ namespace gtsam {

/* ************************************************************************ */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;
for (const sharedFactor& factor : *this) {
if (factor) result = (*factor) * result;
DiscreteFactor::shared_ptr result;
for (auto it = this->begin(); it != this->end(); ++it) {
dellaert marked this conversation as resolved.
Show resolved Hide resolved
if (*it) {
if (result) {
result = result->multiply(*it);
} else {
// Assign to the first non-null factor
result = *it;
}
}
}
return result;
return result->toDecisionTreeFactor();
dellaert marked this conversation as resolved.
Show resolved Hide resolved
}

/* ************************************************************************ */
Expand Down
16 changes: 16 additions & 0 deletions gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,22 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
return toDecisionTreeFactor() * f;
}

/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::multiply(
const DiscreteFactor::shared_ptr& f) const {
DiscreteFactor::shared_ptr result;
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
result = std::make_shared<TableFactor>(this->operator*(*tf));
dellaert marked this conversation as resolved.
Show resolved Hide resolved
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
result = std::make_shared<TableFactor>(this->operator*(TableFactor(*dtf)));
} else {
// Simulate double dispatch in C++
result = std::make_shared<DecisionTreeFactor>(
f->operator*(this->toDecisionTreeFactor()));
}
return result;
}

/* ************************************************************************ */
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
Expand Down
4 changes: 4 additions & 0 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// multiply with DecisionTreeFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;

/// Multiply factors, DiscreteFactor::shared_ptr edition
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
virtual DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& f) const override;

static double safe_div(const double& a, const double& b);

/// divide by factor f (safely)
Expand Down
7 changes: 7 additions & 0 deletions gtsam_unstable/discrete/AllDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
/// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;

/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
varunagrawal marked this conversation as resolved.
Show resolved Hide resolved
const DiscreteFactor::shared_ptr& df) const override {
return std::make_shared<DecisionTreeFactor>(
this->operator*(df->toDecisionTreeFactor()));
}

/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("AllDiff::error not implemented");
Expand Down
7 changes: 7 additions & 0 deletions gtsam_unstable/discrete/BinaryAllDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ class BinaryAllDiff : public Constraint {
return toDecisionTreeFactor() * f;
}

/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
return std::make_shared<DecisionTreeFactor>(
this->operator*(df->toDecisionTreeFactor()));
}

/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked
Expand Down
7 changes: 7 additions & 0 deletions gtsam_unstable/discrete/Domain.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
/// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;

/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
return std::make_shared<DecisionTreeFactor>(
this->operator*(df->toDecisionTreeFactor()));
}

/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked
Expand Down
7 changes: 7 additions & 0 deletions gtsam_unstable/discrete/SingleValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
/// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;

/// Multiply factors, DiscreteFactor::shared_ptr edition
DiscreteFactor::shared_ptr multiply(
const DiscreteFactor::shared_ptr& df) const override {
return std::make_shared<DecisionTreeFactor>(
this->operator*(df->toDecisionTreeFactor()));
}

/*
* Ensure Arc-consistency: just sets domain[j] to {value_}.
* @param j domain to be checked
Expand Down
Loading