diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index c56818448a..d1b68f4bfb 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -147,14 +147,14 @@ namespace gtsam { size_t i; ADT result(*this); for (i = 0; i < nrFrontals; i++) { - Key j = keys()[i]; + Key j = keys_[i]; result = result.combine(j, cardinality(j), op); } - // create new factor, note we start keys after nrFrontals + // Create new factor, note we start with keys after nrFrontals: DiscreteKeys dkeys; - for (; i < keys().size(); i++) { - Key j = keys()[i]; + for (; i < keys_.size(); i++) { + Key j = keys_[i]; dkeys.push_back(DiscreteKey(j, cardinality(j))); } return std::make_shared(dkeys, result); @@ -179,24 +179,22 @@ namespace gtsam { result = result.combine(j, cardinality(j), op); } - // create new factor, note we collect keys that are not in frontalKeys /* - Due to branch merging, the labels in `result` may be missing some keys + Create new factor, note we collect keys that are not in frontalKeys. + + Due to branch merging, the labels in `result` may be missing some keys. E.g. After branch merging, we may get a ADT like: Leaf [2] 1.0204082 - This is missing the key values used for branching. + Hence, code below traverses the original keys and omits those in + frontalKeys. We loop over cardinalities, which is O(n) even for a map, and + then "contains" is a binary search on a small vector. */ - KeyVector difference, frontalKeys_(frontalKeys), keys_(keys()); - // Get the difference of the frontalKeys and the factor keys using set_difference - std::sort(keys_.begin(), keys_.end()); - std::sort(frontalKeys_.begin(), frontalKeys_.end()); - std::set_difference(keys_.begin(), keys_.end(), frontalKeys_.begin(), - frontalKeys_.end(), back_inserter(difference)); - DiscreteKeys dkeys; - for (Key key : difference) { - dkeys.push_back(DiscreteKey(key, cardinality(key))); + for (auto&& [key, cardinality] : cardinalities_) { + if (!frontalKeys.contains(key)) { + dkeys.push_back(DiscreteKey(key, cardinality)); + } } return std::make_shared(dkeys, result); }