diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 74eb3ddb38..2be8e077de 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -64,7 +64,7 @@ TableFactor::TableFactor(const DiscreteConditional& c) Eigen::SparseVector TableFactor::Convert( const std::vector& table) { Eigen::SparseVector sparse_table(table.size()); - // Count number of nonzero elements in table and reserving the space. + // Count number of nonzero elements in table and reserve the space. const uint64_t nnz = std::count_if(table.begin(), table.end(), [](uint64_t i) { return i != 0; }); sparse_table.reserve(nnz); @@ -218,6 +218,45 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const { cout << "number of nnzs: " << sparse_table_.nonZeros() << endl; } +/* ************************************************************************ */ +TableFactor TableFactor::apply(Unary op) const { + // Initialize new factor. + uint64_t cardi = 1; + for (auto [key, c] : cardinalities_) cardi *= c; + Eigen::SparseVector sparse_table(cardi); + sparse_table.reserve(sparse_table_.nonZeros()); + + // Populate + for (SparseIt it(sparse_table_); it; ++it) { + sparse_table.coeffRef(it.index()) = op(it.value()); + } + + // Free unused memory and return. + sparse_table.pruned(); + sparse_table.data().squeeze(); + return TableFactor(discreteKeys(), sparse_table); +} + +/* ************************************************************************ */ +TableFactor TableFactor::apply(UnaryAssignment op) const { + // Initialize new factor. + uint64_t cardi = 1; + for (auto [key, c] : cardinalities_) cardi *= c; + Eigen::SparseVector sparse_table(cardi); + sparse_table.reserve(sparse_table_.nonZeros()); + + // Populate + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + sparse_table.coeffRef(it.index()) = op(assignment, it.value()); + } + + // Free unused memory and return. + sparse_table.pruned(); + sparse_table.data().squeeze(); + return TableFactor(discreteKeys(), sparse_table); +} + /* ************************************************************************ */ TableFactor TableFactor::apply(const TableFactor& f, Binary op) const { if (keys_.empty() && sparse_table_.nonZeros() == 0) diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index bd637bb7d3..981e1507b4 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -93,6 +93,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { typedef std::shared_ptr shared_ptr; typedef Eigen::SparseVector::InnerIterator SparseIt; typedef std::vector> AssignValList; + using Unary = std::function; + using UnaryAssignment = + std::function&, const double&)>; using Binary = std::function; public: @@ -218,6 +221,18 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { /// @name Advanced Interface /// @{ + /** + * Apply unary operator `op(*this)` where `op` accepts the discrete value. + * @param op a unary operator that operates on TableFactor + */ + TableFactor apply(Unary op) const; + /** + * Apply unary operator `op(*this)` where `op` accepts the discrete assignment + * and the value at that assignment. + * @param op a unary operator that operates on TableFactor + */ + TableFactor apply(UnaryAssignment op) const; + /** * Apply binary operator (*this) "op" f * @param f the second argument for op @@ -225,10 +240,19 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { */ TableFactor apply(const TableFactor& f, Binary op) const; - /// Return keys in contract mode. + /** + * Return keys in contract mode. + * + * Modes are each of the dimensions of a sparse tensor, + * and the contract modes represent which dimensions will + * be involved in contraction (aka tensor multiplication). + */ DiscreteKeys contractDkeys(const TableFactor& f) const; - /// Return keys in free mode. + /** + * @brief Return keys in free mode which are the dimensions + * not involved in the contraction operation. + */ DiscreteKeys freeDkeys(const TableFactor& f) const; /// Return union of DiscreteKeys in two factors. diff --git a/gtsam/discrete/tests/testTableFactor.cpp b/gtsam/discrete/tests/testTableFactor.cpp index b307d78f6a..e85e4254c8 100644 --- a/gtsam/discrete/tests/testTableFactor.cpp +++ b/gtsam/discrete/tests/testTableFactor.cpp @@ -93,8 +93,7 @@ void printTime(map> for (auto&& kv : measured_time) { cout << "dropout: " << kv.first << " | TableFactor time: " << kv.second.first.count() - << " | DecisionTreeFactor time: " << kv.second.second.count() << - endl; + << " | DecisionTreeFactor time: " << kv.second.second.count() << endl; } } @@ -361,6 +360,39 @@ TEST(TableFactor, htmlWithValueFormatter) { EXPECT(actual == expected); } +/* ************************************************************************* */ +TEST(TableFactor, Unary) { + // Declare a bunch of keys + DiscreteKey X(0, 2), Y(1, 3); + + // Create factors + TableFactor f(X & Y, "2 5 3 6 2 7"); + auto op = [](const double x) { return 2 * x; }; + auto g = f.apply(op); + + TableFactor expected(X & Y, "4 10 6 12 4 14"); + EXPECT(assert_equal(g, expected)); + + auto sq_op = [](const double x) { return x * x; }; + auto g_sq = f.apply(sq_op); + TableFactor expected_sq(X & Y, "4 25 9 36 4 49"); + EXPECT(assert_equal(g_sq, expected_sq)); +} + +/* ************************************************************************* */ +TEST(TableFactor, UnaryAssignment) { + // Declare a bunch of keys + DiscreteKey X(0, 2), Y(1, 3); + + // Create factors + TableFactor f(X & Y, "2 5 3 6 2 7"); + auto op = [](const Assignment& key, const double x) { return 2 * x; }; + auto g = f.apply(op); + + TableFactor expected(X & Y, "4 10 6 12 4 14"); + EXPECT(assert_equal(g, expected)); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 2855dfb5e0..c1d57715ec 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -176,6 +176,7 @@ class HybridGaussianFactorGraph { void push_back(const gtsam::HybridBayesTree& bayesTree); void push_back(const gtsam::GaussianMixtureFactor* gmm); void push_back(gtsam::DecisionTreeFactor* factor); + void push_back(gtsam::TableFactor* factor); void push_back(gtsam::JacobianFactor* factor); bool empty() const;