Skip to content

Commit

Permalink
Uniform sampling (#31)
Browse files Browse the repository at this point in the history
Sampling uniformly
  • Loading branch information
ognian- authored Oct 19, 2022
1 parent 5bac38f commit 9cfefad
Show file tree
Hide file tree
Showing 13 changed files with 211 additions and 154 deletions.
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ add_library(larch
src/mutation_annotated_dag.cpp
src/node_label.cpp
src/node_storage.cpp
src/node.cpp
src/post_order_iterator.cpp
src/pre_order_iterator.cpp)
larch_compile_opts(larch)
Expand Down
80 changes: 80 additions & 0 deletions include/impl/node_impl.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,86 @@
// Functions defined here are documented where declared in `include/node.hpp`
#include <range/v3/view/join.hpp>

// Binds a view to node `id` inside `dag`. The id must be a real id (not NoId)
// and must index an existing entry of the DAG's node storage.
template <typename T>
NodeView<T>::NodeView(T dag, NodeId id) : dag_{dag}, id_{id} {
  // A view may only be backed by a mutable or a const reference to DAG.
  constexpr bool backed_by_mutable_dag = std::is_same_v<T, DAG&>;
  constexpr bool backed_by_const_dag = std::is_same_v<T, const DAG&>;
  static_assert(backed_by_mutable_dag or backed_by_const_dag);

  Assert(id.value != NoId);
  Assert(id.value < dag_.nodes_.size());
}

// Converts this view into a `Node` that refers to the same DAG and node id.
template <typename T>
NodeView<T>::operator Node() const {
  return Node{dag_, id_};
}

// Conversion to the node's identifier: yields the id this view was
// constructed with.
template <typename T>
NodeView<T>::operator NodeId() const {
return id_;
}

// Returns the DAG this view was constructed with. The return type is `T`
// itself, which the constructor restricts to `DAG&` or `const DAG&`.
template <typename T>
T NodeView<T>::GetDAG() const {
return dag_;
}

// Returns the id of the node this view refers to (by value).
template <typename T>
NodeId NodeView<T>::GetId() const {
return id_;
}

// Returns the one and only parent edge of this node. Asserts that the node
// has exactly one parent edge.
template <typename T>
typename NodeView<T>::EdgeType NodeView<T>::GetSingleParent() const {
  auto parent_edges = GetParents();
  Assert(parent_edges.size() == 1);
  return *parent_edges.begin();
}

// A node is a root when its storage records no parent edges.
template <typename T>
bool NodeView<T>::IsRoot() const {
  const auto& parent_ids = GetStorage().GetParents();
  return parent_ids.empty();
}

// A node is a leaf when it has no clades at all, or when its clades exist
// but the range of child edges is empty.
template <typename T>
bool NodeView<T>::IsLeaf() const {
  if (not GetClades().empty()) {
    // Clades are present; the node is a leaf only if none of them holds an
    // edge, i.e. the children range is empty.
    auto child_edges = GetChildren();
    return child_edges.begin() == child_edges.end();
  }
  return true;
}

// Registers `edge` with this node's storage as a parent edge.
// On read-only views (`is_mutable == false`) the body is discarded and the
// call does nothing.
template <typename T>
void NodeView<T>::AddParentEdge(Edge edge) const {
  if constexpr (is_mutable) {
    // The last argument is `false` here and `true` in AddChildEdge; it
    // presumably tells the storage which side of the edge this node is on —
    // TODO confirm against NodeStorage::AddEdge.
    constexpr bool as_child_edge = false;
    auto& storage = GetStorage();
    storage.AddEdge(edge.GetClade(), edge.GetId(), as_child_edge);
  }
}

// Registers `edge` with this node's storage as a child edge.
// On read-only views (`is_mutable == false`) the body is discarded and the
// call does nothing.
template <typename T>
void NodeView<T>::AddChildEdge(Edge edge) const {
  if constexpr (is_mutable) {
    // The last argument is `true` here and `false` in AddParentEdge; it
    // presumably tells the storage which side of the edge this node is on —
    // TODO confirm against NodeStorage::AddEdge.
    constexpr bool as_child_edge = true;
    auto& storage = GetStorage();
    storage.AddEdge(edge.GetClade(), edge.GetId(), as_child_edge);
  }
}

// Removes `edge` from this node's storage, passing `false` as the second
// argument of NodeStorage::RemoveEdge (the same flag value AddParentEdge
// uses). On read-only views the body is discarded and the call does nothing.
template <typename T>
void NodeView<T>::RemoveParentEdge(Edge edge) const {
  if constexpr (is_mutable) {
    constexpr bool as_child_edge = false;
    auto& storage = GetStorage();
    storage.RemoveEdge(edge, as_child_edge);
  }
}

// Returns the node's optional sample id by const reference, forwarding the
// reference handed out by the node storage (no copy is made here).
// NOTE(review): the reference presumably stays valid only while the
// underlying storage is alive and unmodified — confirm before caching it.
template <typename T>
const std::optional<std::string>& NodeView<T>::GetSampleId() const {
return GetStorage().GetSampleId();
}

// Hands `sample_id` over to the node's storage. The argument is an rvalue
// reference, so callers must pass a temporary or `std::move` their value.
// On read-only views (`is_mutable == false`) the body is discarded and the
// argument is left untouched.
template <typename T>
void NodeView<T>::SetSampleId(std::optional<std::string>&& sample_id) {
  if constexpr (is_mutable) {
    // `sample_id` is a plain rvalue reference, not a forwarding reference, so
    // std::move (rather than std::forward) is the correct way to pass it on.
    GetStorage().SetSampleId(std::move(sample_id));
  }
}

template <typename T>
auto NodeView<T>::GetParents() const {
return GetStorage().GetParents() | Transform::ToEdges(dag_);
Expand Down
67 changes: 51 additions & 16 deletions include/impl/subtree_weight_impl.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <algorithm>
#include <type_traits>

template <typename WeightOps>
SubtreeWeight<WeightOps>::SubtreeWeight(const MADAG& dag)
Expand Down Expand Up @@ -54,20 +55,35 @@ MADAG SubtreeWeight<WeightOps>::TrimToMinWeight(WeightOps&& weight_ops) {
template <typename WeightOps>
std::pair<MADAG, std::vector<NodeId>> SubtreeWeight<WeightOps>::SampleTree(
WeightOps&& weight_ops) {
MADAG result{dag_.GetReferenceSequence()};
std::vector<NodeId> result_dag_ids;

ExtractTree(
dag_, dag_.GetDAG().GetRoot(), std::forward<WeightOps>(weight_ops),
[this](Node node, CladeIdx clade_idx) {
auto clade = node.GetClade(clade_idx);
Assert(not clade.empty());
std::uniform_int_distribution<size_t> distribuition{0, clade.size() - 1};
return clade.at(distribuition(random_generator_));
},
result, result_dag_ids);
return SampleTreeImpl(std::forward<WeightOps>(weight_ops), [](auto clade) {
return std::uniform_int_distribution<size_t>{0, clade.size() - 1};
});
}

return {std::move(result), std::move(result_dag_ids)};
struct TreeCount;
// Samples one tree from the DAG, choosing within every clade an edge with
// probability proportional to the cached weight of the edge's child node.
// Only usable with the TreeCount weight operations (enforced at compile time).
//
// Returns the sampled tree together with the vector of node ids produced by
// SampleTreeImpl.
template <typename WeightOps>
std::pair<MADAG, std::vector<NodeId>> SubtreeWeight<WeightOps>::UniformSampleTree(
    WeightOps&& weight_ops) {
  static_assert(std::is_same_v<std::decay_t<WeightOps>, TreeCount>,
                "UniformSampleTree needs TreeCount");
  // Ensure cache is filled
  // NOTE(review): `weight_ops` is forwarded here and again below; this is only
  // safe if ComputeWeightBelow does not move from it — confirm.
  ComputeWeightBelow(dag_.GetDAG().GetRoot(), std::forward<WeightOps>(weight_ops));
  return SampleTreeImpl(std::forward<WeightOps>(weight_ops), [this](auto clade) {
    typename WeightOps::Weight sum{};
    for (NodeId child : clade | Transform::GetChild()) {
      sum += cached_weights_.at(child.value).value();
    }

    std::vector<double> probabilities;
    if (sum > 0) {
      probabilities.reserve(clade.size());
      // Convert to double BEFORE dividing. Dividing in the Weight type first
      // would truncate every quotient to 0 (or 1 for a single-child clade)
      // when Weight is integral, leaving discrete_distribution with an
      // all-zero weight list, which violates its precondition.
      // NOTE(review): precision and range are limited to those of double.
      const double total = static_cast<double>(sum);
      for (NodeId child : clade | Transform::GetChild()) {
        probabilities.push_back(
            static_cast<double>(cached_weights_.at(child.value).value()) / total);
      }
    }
    // An empty weight list (sum == 0) yields the default distribution, which
    // always produces index 0.
    return std::discrete_distribution<size_t>{probabilities.begin(),
                                              probabilities.end()};
  });
}

template <typename WeightOps>
Expand All @@ -94,6 +110,25 @@ typename WeightOps::Weight SubtreeWeight<WeightOps>::CladeWeight(
return clade_result.first;
}

// Shared implementation of the tree samplers. For every clade visited by
// ExtractTree, `distribution_maker(clade)` must return a distribution object
// that, when invoked with the random generator, produces the index of the
// edge to follow within that clade.
//
// Returns the extracted tree and the vector of node ids filled in by
// ExtractTree.
template <typename WeightOps>
template <typename DistributionMaker>
std::pair<MADAG, std::vector<NodeId>> SubtreeWeight<WeightOps>::SampleTreeImpl(
    WeightOps&& weight_ops, DistributionMaker&& distribution_maker) {
  MADAG sampled_tree{dag_.GetReferenceSequence()};
  std::vector<NodeId> sampled_dag_ids;

  // Edge selector: build the distribution for the clade, draw one index from
  // it and return the edge stored at that position.
  auto pick_edge = [this, &distribution_maker](Node node, CladeIdx clade_idx) {
    auto clade = node.GetClade(clade_idx);
    Assert(not clade.empty());
    auto distribution = distribution_maker(clade);
    const auto chosen_index = distribution(random_generator_);
    return clade.at(chosen_index);
  };

  ExtractTree(dag_, dag_.GetDAG().GetRoot(), std::forward<WeightOps>(weight_ops),
              std::move(pick_edge), sampled_tree, sampled_dag_ids);

  return {std::move(sampled_tree), std::move(sampled_dag_ids)};
}

template <typename WeightOps>
template <typename EdgeSelector>
void SubtreeWeight<WeightOps>::ExtractTree(const MADAG& input_dag, Node node,
Expand Down Expand Up @@ -136,10 +171,10 @@ void SubtreeWeight<WeightOps>::ExtractTree(const MADAG& input_dag, Node node,

for (auto node : result.GetDAG().GetNodes()) {
size_t idx = node.GetId().value;
std::optional<std::string> old_sample_id =
const std::optional<std::string>& old_sample_id =
input_dag.GetDAG().GetNodes().at(idx).GetSampleId();
if (node.IsLeaf() and (bool) old_sample_id) {
node.SetSampleId(old_sample_id);
if (node.IsLeaf() and old_sample_id.has_value()) {
node.SetSampleId(std::optional<std::string>{old_sample_id});
}
}
}
29 changes: 24 additions & 5 deletions include/leaf_set.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once

#include "common.hpp"
#include "compact_genome.hpp"

class CompactGenome;
class NodeLabel;

/**
Expand All @@ -27,9 +27,9 @@ class LeafSet {
LeafSet(Node node, const std::vector<NodeLabel>& labels,
std::vector<LeafSet>& computed_leafsets);

LeafSet(std::vector<std::vector<const CompactGenome*>>&& clades);
inline LeafSet(std::vector<std::vector<const CompactGenome*>>&& clades);

bool operator==(const LeafSet& rhs) const noexcept;
inline bool operator==(const LeafSet& rhs) const noexcept;

[[nodiscard]] size_t Hash() const noexcept;

Expand All @@ -43,8 +43,8 @@ class LeafSet {
const std::vector<std::vector<const CompactGenome*>>& GetClades() const;

private:
static size_t ComputeHash(
const std::vector<std::vector<const CompactGenome*>>& clades);
inline static size_t ComputeHash(
const std::vector<std::vector<const CompactGenome*>>& clades) noexcept;
};

template <>
Expand All @@ -58,3 +58,22 @@ struct std::equal_to<LeafSet> {
return lhs == rhs;
}
};

// Two leaf sets are equal when their clade vectors compare equal. The clades
// store `const CompactGenome*`, so elements are compared by pointer value,
// not by the contents of the pointed-to genomes.
bool LeafSet::operator==(const LeafSet& rhs) const noexcept {
return clades_ == rhs.clades_;
}

// Takes ownership of `clades` (moved in) and computes the hash over the
// stored clades.
// NOTE(review): `hash_` is initialized from `clades_`, which relies on
// `clades_` being declared before `hash_` in the class — confirm.
LeafSet::LeafSet(std::vector<std::vector<const CompactGenome*>>&& clades)
    // `clades` is a plain rvalue reference, not a forwarding reference, so
    // std::move (rather than std::forward) is the correct way to pass it on.
    : clades_{std::move(clades)}, hash_{ComputeHash(clades_)} {}

// Folds the hash of every leaf genome, clade by clade and in stored order,
// into one value with HashCombine, starting from 0. An empty clade list
// therefore hashes to 0.
size_t LeafSet::ComputeHash(
    const std::vector<std::vector<const CompactGenome*>>& clades) noexcept {
  size_t combined = 0;
  for (const std::vector<const CompactGenome*>& clade : clades) {
    for (const CompactGenome* genome : clade) {
      combined = HashCombine(combined, genome->Hash());
    }
  }
  return combined;
}
28 changes: 14 additions & 14 deletions include/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ class NodeView {
constexpr static const bool is_mutable = std::is_same_v<T, DAG&>;
using NodeType = std::conditional_t<is_mutable, MutableNode, Node>;
using EdgeType = std::conditional_t<is_mutable, MutableEdge, Edge>;
NodeView(T dag, NodeId id);
operator Node() const;
operator NodeId() const;
inline NodeView(T dag, NodeId id);
inline operator Node() const;
inline operator NodeId() const;
/**
* Return DAG-like object containing this node
*/
T GetDAG() const;
NodeId GetId() const;
inline T GetDAG() const;
inline NodeId GetId() const;
/**
* Return a range containing parent Edge objects
*/
Expand All @@ -50,29 +50,29 @@ class NodeView {
/**
* Return the count of child clades
*/
size_t GetCladesCount() const;
inline size_t GetCladesCount() const;
/**
* Return a range containing child Edges
*/
auto GetChildren() const;
/**
* Return a single parent edge of this node
*/
EdgeType GetSingleParent() const;
inline EdgeType GetSingleParent() const;
/**
* Checks if node has no parents
*/
bool IsRoot() const;
inline bool IsRoot() const;
/**
* Checks if node has no children
*/
bool IsLeaf() const;
void AddParentEdge(Edge edge) const;
void AddChildEdge(Edge edge) const;
void RemoveParentEdge(Edge edge) const;
inline bool IsLeaf() const;
inline void AddParentEdge(Edge edge) const;
inline void AddChildEdge(Edge edge) const;
inline void RemoveParentEdge(Edge edge) const;

const std::optional<std::string> GetSampleId() const;
void SetSampleId(std::optional<std::string> sample_id);
inline const std::optional<std::string>& GetSampleId() const;
inline void SetSampleId(std::optional<std::string>&& sample_id);

private:
auto& GetStorage() const;
Expand Down
4 changes: 2 additions & 2 deletions include/node_storage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class NodeStorage {
*/
const std::vector<std::vector<EdgeId>>& GetClades() const;

const std::optional<std::string> GetSampleId() const;
void SetSampleId(std::optional<std::string> sample_id);
const std::optional<std::string>& GetSampleId() const;
void SetSampleId(std::optional<std::string>&& sample_id);

/**
* Remove all parent and child edges
Expand Down
7 changes: 7 additions & 0 deletions include/subtree_weight.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,17 @@ class SubtreeWeight {
[[nodiscard]] std::pair<MADAG, std::vector<NodeId>> SampleTree(
WeightOps&& weight_ops);

[[nodiscard]] std::pair<MADAG, std::vector<NodeId>> UniformSampleTree(
WeightOps&& weight_ops);

private:
template <typename CladeRange>
typename WeightOps::Weight CladeWeight(CladeRange&& clade, WeightOps&& weight_ops);

template <typename DistributionMaker>
[[nodiscard]] std::pair<MADAG, std::vector<NodeId>> SampleTreeImpl(
WeightOps&& weight_ops, DistributionMaker&& distribution_maker);

template <typename EdgeSelector>
void ExtractTree(const MADAG& input_dag, Node node, WeightOps&& weight_ops,
EdgeSelector&& edge_selector, MADAG& result,
Expand Down
2 changes: 1 addition & 1 deletion src/dag_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ MADAG LoadTreeFromProtobuf(std::string_view path, std::string_view reference_seq
result.BuildConnections();
for (auto node : result.GetDAG().GetNodes()) {
if (node.IsLeaf()) {
node.SetSampleId(seq_ids[node.GetId().value]);
node.SetSampleId(std::move(seq_ids[node.GetId().value]));
}
}

Expand Down
19 changes: 0 additions & 19 deletions src/leaf_set.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include <utility>

#include <range/v3/range/conversion.hpp>

#include "dag.hpp"
#include "compact_genome.hpp"
#include "node_label.hpp"

const LeafSet* LeafSet::Empty() {
Expand Down Expand Up @@ -42,13 +41,6 @@ LeafSet::LeafSet(Node node, const std::vector<NodeLabel>& labels,
}()},
hash_{ComputeHash(clades_)} {}

// Takes ownership of `clades` and computes the hash over the stored clades.
// NOTE(review): `hash_` is initialized from `clades_`, which relies on
// `clades_` being declared before `hash_` in the class — confirm.
LeafSet::LeafSet(std::vector<std::vector<const CompactGenome*>>&& clades)
    // A named rvalue reference is an lvalue; without std::move the nested
    // vectors would be deep-copied even though the caller gave them up.
    : clades_{std::move(clades)}, hash_{ComputeHash(clades_)} {}

// Two leaf sets are equal when their clade vectors compare equal. The clades
// store `const CompactGenome*`, so elements are compared by pointer value,
// not by the contents of the pointed-to genomes.
bool LeafSet::operator==(const LeafSet& rhs) const noexcept {
return clades_ == rhs.clades_;
}

// Returns the stored hash value; the constructors initialize it from
// ComputeHash(clades_), so nothing is recomputed here.
size_t LeafSet::Hash() const noexcept { return hash_; }

// Iterator to the first clade; makes a LeafSet iterable over its clades.
auto LeafSet::begin() const -> decltype(clades_.begin()) { return clades_.begin(); }
Expand All @@ -69,14 +61,3 @@ std::vector<const CompactGenome*> LeafSet::ToParentClade() const {
// Read-only access to the stored clades: one inner vector of leaf genome
// pointers per clade. Returned by reference, no copy is made.
const std::vector<std::vector<const CompactGenome*>>& LeafSet::GetClades() const {
return clades_;
}

// Folds the hash of every leaf genome, clade by clade and in stored order,
// into one value with HashCombine, starting from 0. An empty clade list
// therefore hashes to 0.
size_t LeafSet::ComputeHash(
    const std::vector<std::vector<const CompactGenome*>>& clades) {
  size_t combined = 0;
  for (const std::vector<const CompactGenome*>& clade : clades) {
    for (const CompactGenome* genome : clade) {
      combined = HashCombine(combined, genome->Hash());
    }
  }
  return combined;
}
Loading

0 comments on commit 9cfefad

Please sign in to comment.