From 7c9d04fb65b3d595a8d1b08a54d023d72cc87576 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 27 Dec 2024 12:02:21 -0500 Subject: [PATCH] conditional switch for hybrid timing --- gtsam/config.h.in | 4 +++ gtsam/discrete/DiscreteFactorGraph.cpp | 30 ++++++++++++---- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 41 ++++++++++++++++++++-- gtsam/hybrid/HybridGaussianISAM.cpp | 6 ++++ 4 files changed, 73 insertions(+), 8 deletions(-) diff --git a/gtsam/config.h.in b/gtsam/config.h.in index 8b4903d3af..db6dd2b34e 100644 --- a/gtsam/config.h.in +++ b/gtsam/config.h.in @@ -42,6 +42,10 @@ // Whether to enable merging of equal leaf nodes in the Discrete Decision Tree. #cmakedefine GTSAM_DT_MERGING +// Whether to enable timing in hybrid factor graph machinery +// #cmakedefine01 GTSAM_HYBRID_TIMING +#define GTSAM_HYBRID_TIMING + // Whether we are using TBB (if TBB was found and GTSAM_WITH_TBB is enabled in CMake) #cmakedefine GTSAM_USE_TBB diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 169259a36d..2037dd9514 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -121,15 +121,25 @@ namespace gtsam { static DecisionTreeFactor ProductAndNormalize( const DiscreteFactorGraph& factors) { // PRODUCT: multiply all factors - gttic(product); +#if GTSAM_HYBRID_TIMING + gttic_(DiscreteProduct); +#endif DecisionTreeFactor product = factors.product(); - gttoc(product); +#if GTSAM_HYBRID_TIMING + gttoc_(DiscreteProduct); +#endif // Max over all the potentials by pretending all keys are frontal: auto normalizer = product.max(product.size()); +#if GTSAM_HYBRID_TIMING + gttic_(DiscreteNormalize); +#endif // Normalize the product factor to prevent underflow. product = product / (*normalizer); +#if GTSAM_HYBRID_TIMING + gttoc_(DiscreteNormalize); +#endif return product; } @@ -220,9 +230,13 @@ namespace gtsam { DecisionTreeFactor product = ProductAndNormalize(factors); // sum out frontals, this is the factor on the separator - gttic(sum); +#if GTSAM_HYBRID_TIMING + gttic_(EliminateDiscreteSum); +#endif DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); - gttoc(sum); +#if GTSAM_HYBRID_TIMING + gttoc_(EliminateDiscreteSum); +#endif // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; @@ -232,10 +246,14 @@ namespace gtsam { sum->keys().end()); // now divide product/sum to get conditional - gttic(divide); +#if GTSAM_HYBRID_TIMING + gttic_(EliminateDiscreteToDiscreteConditional); +#endif auto conditional = std::make_shared(product, *sum, orderedKeys); - gttoc(divide); +#if GTSAM_HYBRID_TIMING + gttoc_(EliminateDiscreteToDiscreteConditional); +#endif return {conditional, sum}; } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index b9051554a4..703684c788 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -282,14 +282,28 @@ discreteElimination(const HybridGaussianFactorGraph &factors, } else if (auto hc = dynamic_pointer_cast(f)) { auto dc = hc->asDiscrete(); if (!dc) throwRuntimeError("discreteElimination", dc); - dfg.push_back(dc); +#if GTSAM_HYBRID_TIMING + gttic_(ConvertConditionalToTableFactor); +#endif + // Convert DiscreteConditional to TableFactor + auto tdc = std::make_shared(*dc); +#if GTSAM_HYBRID_TIMING + gttoc_(ConvertConditionalToTableFactor); +#endif + dfg.push_back(tdc); } else { throwRuntimeError("discreteElimination", f); } } +#if GTSAM_HYBRID_TIMING + gttic_(EliminateDiscrete); +#endif // NOTE: This does sum-product. For max-product, use EliminateForMPE. auto result = EliminateDiscrete(dfg, frontalKeys); +#if GTSAM_HYBRID_TIMING + gttoc_(EliminateDiscrete); +#endif return {std::make_shared(result.first), result.second}; } @@ -319,8 +333,19 @@ static std::shared_ptr createDiscreteFactor( } }; +#if GTSAM_HYBRID_TIMING + gttic_(DiscreteBoundaryErrors); +#endif AlgebraicDecisionTree errors(eliminationResults, calculateError); - return DiscreteFactorFromErrors(discreteSeparator, errors); +#if GTSAM_HYBRID_TIMING + gttoc_(DiscreteBoundaryErrors); + gttic_(DiscreteBoundaryResult); +#endif + auto result = DiscreteFactorFromErrors(discreteSeparator, errors); +#if GTSAM_HYBRID_TIMING + gttoc_(DiscreteBoundaryResult); +#endif + return result; } /* *******************************************************************************/ @@ -360,12 +385,18 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { // the discrete separator will be *all* the discrete keys. DiscreteKeys discreteSeparator = GetDiscreteKeys(*this); +#if GTSAM_HYBRID_TIMING + gttic_(HybridCollectProductFactor); +#endif // Collect all the factors to create a set of Gaussian factor graphs in a // decision tree indexed by all discrete keys involved. Just like any hybrid // factor, every assignment also has a scalar error, in this case the sum of // all errors in the graph. This error is assignment-specific and accounts for // any difference in noise models used. HybridGaussianProductFactor productFactor = collectProductFactor(); +#if GTSAM_HYBRID_TIMING + gttoc_(HybridCollectProductFactor); +#endif // Check if a factor is null auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; }; @@ -393,8 +424,14 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { return {conditional, conditional->negLogConstant(), factor, scalar}; }; +#if GTSAM_HYBRID_TIMING + gttic_(HybridEliminate); +#endif // Perform elimination! const ResultTree eliminationResults(productFactor, eliminate); +#if GTSAM_HYBRID_TIMING + gttoc_(HybridEliminate); +#endif // If there are no more continuous parents we create a DiscreteFactor with the // error for each discrete choice. Otherwise, create a HybridGaussianFactor diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index 28116df45d..f99d95c018 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -104,7 +104,13 @@ void HybridGaussianISAM::updateInternal( elimination_ordering, function, std::cref(index)); if (maxNrLeaves) { +#if GTSAM_HYBRID_TIMING + gttic_(HybridBayesTreePrune); +#endif bayesTree->prune(*maxNrLeaves); +#if GTSAM_HYBRID_TIMING + gttoc_(HybridBayesTreePrune); +#endif } // Re-add into Bayes tree data structures