diff --git a/Cargo.toml b/Cargo.toml index 2544018b..66be29a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tskit" -version = "0.12.0" +version = "0.13.0-alpha.0" authors = ["tskit developers "] build = "build.rs" edition = "2021" @@ -50,6 +50,7 @@ pkg-config = "0.3" [features] provenance = ["humantime"] derive = ["tskit-derive", "serde", "serde_json", "bincode"] +edgebuffer = [] [package.metadata.docs.rs] all-features = true @@ -58,3 +59,7 @@ rustdoc-args = ["--cfg", "doc_cfg"] # Not run during tests [[example]] name = "tree_traversals" + +[[example]] +name = "haploid_wright_fisher_edge_buffering" +required-features = ["edgebuffer"] diff --git a/examples/haploid_wright_fisher.rs b/examples/haploid_wright_fisher.rs index 1d10b7fa..27053a20 100644 --- a/examples/haploid_wright_fisher.rs +++ b/examples/haploid_wright_fisher.rs @@ -8,12 +8,30 @@ use proptest::prelude::*; use rand::distributions::Distribution; use rand::SeedableRng; +fn rotate_edges(bookmark: &tskit::types::Bookmark, tables: &mut tskit::TableCollection) { + let num_edges = tables.edges().num_rows().as_usize(); + let left = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.left, num_edges) }; + let right = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.right, num_edges) }; + let parent = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.parent, num_edges) }; + let child = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.child, num_edges) }; + let mid = bookmark.edges().as_usize(); + left.rotate_left(mid); + right.rotate_left(mid); + parent.rotate_left(mid); + child.rotate_left(mid); +} + // ANCHOR: haploid_wright_fisher fn simulate( seed: u64, popsize: usize, num_generations: i32, simplify_interval: i32, + update_bookmark: bool, ) -> Result { if popsize == 0 { return Err(anyhow::Error::msg("popsize must be > 0")); @@ -46,6 +64,7 @@ fn simulate( let parent_picker = rand::distributions::Uniform::new(0, popsize); let breakpoint_generator = rand::distributions::Uniform::new(0.0, 1.0); let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + let mut bookmark = tskit::types::Bookmark::new(); for birth_time in (0..num_generations).rev() { for c in children.iter_mut() { @@ -64,7 +83,10 @@ fn simulate( } if birth_time % simplify_interval == 0 { - tables.full_sort(tskit::TableSortOptions::default())?; + tables.sort(&bookmark, tskit::TableSortOptions::default())?; + if update_bookmark { + rotate_edges(&bookmark, &mut tables); + } if let Some(idmap) = tables.simplify(children, tskit::SimplificationOptions::default(), true)? { @@ -73,6 +95,9 @@ fn simulate( *o = idmap[usize::try_from(*o)?]; } } + if update_bookmark { + bookmark.set_edges(tables.edges().num_rows()); + } } std::mem::swap(&mut parents, &mut children); } @@ -91,6 +116,8 @@ struct SimParams { num_generations: i32, simplify_interval: i32, treefile: Option, + #[clap(short, long, help = "Use bookmark to avoid sorting entire edge table.")] + bookmark: bool, } fn main() -> Result<()> { @@ -100,6 +127,7 @@ fn main() -> Result<()> { params.popsize, params.num_generations, params.simplify_interval, + params.bookmark, )?; if let Some(treefile) = ¶ms.treefile { @@ -114,8 +142,9 @@ proptest! { #[test] fn test_simulate_proptest(seed in any::(), num_generations in 50..100i32, - simplify_interval in 1..100i32) { - let ts = simulate(seed, 100, num_generations, simplify_interval).unwrap(); + simplify_interval in 1..100i32, + bookmark in proptest::bool::ANY) { + let ts = simulate(seed, 100, num_generations, simplify_interval, bookmark).unwrap(); // stress test the branch length fn b/c it is not a trivial // wrapper around the C API. diff --git a/examples/haploid_wright_fisher_edge_buffering.rs b/examples/haploid_wright_fisher_edge_buffering.rs new file mode 100644 index 00000000..8eb2ffb0 --- /dev/null +++ b/examples/haploid_wright_fisher_edge_buffering.rs @@ -0,0 +1,144 @@ +// This is a rust implementation of the example +// found in tskit-c + +use anyhow::Result; +use clap::Parser; +#[cfg(test)] +use proptest::prelude::*; +use rand::distributions::Distribution; +use rand::SeedableRng; + +// ANCHOR: haploid_wright_fisher_edge_buffering +fn simulate( + seed: u64, + popsize: usize, + num_generations: i32, + simplify_interval: i32, +) -> Result { + if popsize == 0 { + return Err(anyhow::Error::msg("popsize must be > 0")); + } + if num_generations == 0 { + return Err(anyhow::Error::msg("num_generations must be > 0")); + } + if simplify_interval == 0 { + return Err(anyhow::Error::msg("simplify_interval must be > 0")); + } + let mut tables = tskit::TableCollection::new(1.0)?; + + // create parental nodes + let mut parents_and_children = { + let mut temp = vec![]; + let parental_time = f64::from(num_generations); + for _ in 0..popsize { + let node = tables.add_node(0, parental_time, -1, -1)?; + temp.push(node); + } + temp + }; + + // allocate space for offspring nodes + parents_and_children.resize(2 * parents_and_children.len(), tskit::NodeId::NULL); + + // Construct non-overlapping mutable slices into our vector. + let (mut parents, mut children) = parents_and_children.split_at_mut(popsize); + + let parent_picker = rand::distributions::Uniform::new(0, popsize); + let breakpoint_generator = rand::distributions::Uniform::new(0.0, 1.0); + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + let mut buffer = tskit::EdgeBuffer::default(); + + for birth_time in (0..num_generations).rev() { + for c in children.iter_mut() { + let bt = f64::from(birth_time); + let child = tables.add_node(0, bt, -1, -1)?; + let left_parent = parents + .get(parent_picker.sample(&mut rng)) + .ok_or_else(|| anyhow::Error::msg("invalid left_parent index"))?; + let right_parent = parents + .get(parent_picker.sample(&mut rng)) + .ok_or_else(|| anyhow::Error::msg("invalid right_parent index"))?; + buffer.setup_births(&[*left_parent, *right_parent], &[child])?; + let breakpoint = breakpoint_generator.sample(&mut rng); + buffer.record_birth(*left_parent, child, 0., breakpoint)?; + buffer.record_birth(*right_parent, child, breakpoint, 1.0)?; + buffer.finalize_births(); + *c = child; + } + + if birth_time % simplify_interval == 0 { + buffer.pre_simplification(&mut tables)?; + //tables.full_sort(tskit::TableSortOptions::default())?; + if let Some(idmap) = + tables.simplify(children, tskit::SimplificationOptions::default(), true)? + { + // remap child nodes + for o in children.iter_mut() { + *o = idmap[usize::try_from(*o)?]; + } + } + buffer.post_simplification(children, &mut tables)?; + } + std::mem::swap(&mut parents, &mut children); + } + + tables.build_index()?; + let treeseq = tables.tree_sequence(tskit::TreeSequenceFlags::default())?; + + Ok(treeseq) +} +// ANCHOR_END: haploid_wright_fisher_edge_buffering + +#[derive(Clone, clap::Parser)] +struct SimParams { + seed: u64, + popsize: usize, + num_generations: i32, + simplify_interval: i32, + treefile: Option, +} + +fn main() -> Result<()> { + let params = SimParams::parse(); + let treeseq = simulate( + params.seed, + params.popsize, + params.num_generations, + params.simplify_interval, + )?; + + if let Some(treefile) = ¶ms.treefile { + treeseq.dump(treefile, 0)?; + } + + Ok(()) +} + +#[cfg(test)] +proptest! { +#[test] + fn test_simulate_proptest(seed in any::(), + num_generations in 50..100i32, + simplify_interval in 1..100i32) { + let ts = simulate(seed, 100, num_generations, simplify_interval).unwrap(); + + // stress test the branch length fn b/c it is not a trivial + // wrapper around the C API. + { + use streaming_iterator::StreamingIterator; + let mut x = f64::NAN; + if let Ok(mut tree_iter) = ts.tree_iterator(0) { + // We will only do the first tree to save time. + if let Some(tree) = tree_iter.next() { + let b = tree.total_branch_length(false).unwrap(); + let b2 = unsafe { + tskit::bindings::tsk_tree_get_total_branch_length(tree.as_ptr(), -1, &mut x) + }; + assert!(b2 >= 0, "{}", b2); + assert!(f64::from(b) - x <= 1e-8); + } + } + } + } +} + diff --git a/src/edgebuffer.rs b/src/edgebuffer.rs new file mode 100644 index 00000000..7abef80e --- /dev/null +++ b/src/edgebuffer.rs @@ -0,0 +1,504 @@ +use crate::NodeId; +use crate::Position; +use crate::TableCollection; +use crate::TskitError; + +// Design considerations: +// +// We should be able to do better than +// the fwdpp implementation by taking a +// time-sorted list of alive nodes and inserting +// their edges. +// After insertion, we can truncate the input +// edge table, eliminating all edges corresponding +// to the set of alive nodes. +// This procedure would only be done AFTER +// simplification, such that the copied +// edges are guaranteed correct. +// We'd need to hash the existence of these alive nodes. +// Then, when going over the edge buffer, we can ask +// if an edge parent is in the hashed set. +// We would also keep track of the smallest +// edge id, and that (maybe minus 1?) is our truncation point. + +fn swap_with_empty(vec: &mut Vec) { + let mut t = vec![]; + std::mem::swap(&mut t, vec); +} + +#[derive(Copy, Clone)] +struct AliveNodeTimes { + min: f64, + max: f64, +} + +impl AliveNodeTimes { + fn new(min: f64, max: f64) -> Self { + Self { min, max } + } + + fn non_overlapping(&self) -> bool { + self.min == self.max + } +} + +#[derive(Debug)] +struct PreExistingEdge { + first: usize, + last: usize, +} + +impl PreExistingEdge { + fn new(first: usize, last: usize) -> Self { + assert!(last > first); + Self { first, last } + } +} + +#[derive(Debug)] +struct Segment { + left: Position, + right: Position, +} + +type ChildSegments = std::collections::HashMap>; + +#[derive(Default, Debug)] +struct BufferedBirths { + children: Vec, + segments: std::collections::HashMap, +} + +impl BufferedBirths { + fn initialize(&mut self, parents: &[NodeId], children: &[NodeId]) -> Result<(), TskitError> { + self.children = children.to_vec(); + self.children.sort(); + self.segments.clear(); + // FIXME: don't do this work if the parent already exists + for p in parents { + let mut segments = ChildSegments::default(); + for c in &self.children { + if segments.insert(*c, vec![]).is_some() { + return Err(TskitError::LibraryError("redundant child ids".to_owned())); + } + } + self.segments.insert(*p, segments); + } + Ok(()) + } +} + +#[derive(Default, Debug)] +pub struct EdgeBuffer { + left: Vec, + right: Vec, + child: Vec, + // TODO: this should be + // an option so that we can use take. + buffered_births: BufferedBirths, + // NOTE: these vectors are wasteful: + // 1. usize is more than we need, + // but it is more convenient. + // 2. Worse, these vectors will + // contain N elements, where + // N is the total number of nodes, + // but likely many fewer nodes than that + // have actually had offspring. + // It is hard to fix this -- we cannot + // guarantee that parents are entered + // in any specific order. + // 3. Performance IMPROVES MEASURABLY + // if we use u32 here. But tsk_size_t + // is u64. + head: Vec, + tail: Vec, + next: Vec, +} + +impl EdgeBuffer { + fn insert_new_parent(&mut self, parent: usize, child: NodeId, left: Position, right: Position) { + self.left.push(left); + self.right.push(right); + self.child.push(child); + self.head[parent] = self.left.len() - 1; + self.tail[parent] = self.head[parent]; + self.next.push(usize::MAX); + } + + fn extend_parent(&mut self, parent: usize, child: NodeId, left: Position, right: Position) { + self.left.push(left); + self.right.push(right); + self.child.push(child); + let t = self.tail[parent]; + self.tail[parent] = self.left.len() - 1; + self.next[t] = self.left.len() - 1; + self.next.push(usize::MAX); + } + + fn clear(&mut self) { + self.left.clear(); + self.right.clear(); + self.child.clear(); + self.head.clear(); + self.tail.clear(); + self.next.clear(); + } + + fn release_memory(&mut self) { + swap_with_empty(&mut self.head); + swap_with_empty(&mut self.next); + swap_with_empty(&mut self.left); + swap_with_empty(&mut self.right); + swap_with_empty(&mut self.child); + swap_with_empty(&mut self.tail); + } + + fn extract_buffered_births(&mut self) -> BufferedBirths { + let mut b = BufferedBirths::default(); + std::mem::swap(&mut self.buffered_births, &mut b); + b + } + + // Should Err if prents/children not unique + pub fn setup_births( + &mut self, + parents: &[NodeId], + children: &[NodeId], + ) -> Result<(), TskitError> { + self.buffered_births.initialize(parents, children) + } + + pub fn finalize_births(&mut self) { + let buffered_births = self.extract_buffered_births(); + for (p, children) in buffered_births.segments.iter() { + for c in buffered_births.children.iter() { + if let Some(segs) = children.get(c) { + for s in segs { + self.buffer_birth(*p, *c, s.left, s.right).unwrap(); + } + } else { + // should be error + panic!(); + } + } + } + } + + pub fn record_birth( + &mut self, + parent: P, + child: C, + left: L, + right: R, + ) -> Result<(), TskitError> + where + P: Into, + C: Into, + L: Into, + R: Into, + { + let parent = parent.into(); + + let child = child.into(); + if let Some(parent_buffer) = self.buffered_births.segments.get_mut(&parent) { + if let Some(v) = parent_buffer.get_mut(&child) { + let left = left.into(); + let right = right.into(); + v.push(Segment { left, right }); + } else { + // should be an error + panic!(); + } + } else { + // should be an error + panic!(); + } + + Ok(()) + } + + // NOTE: tskit is overly strict during simplification, + // enforcing sorting requirements on the edge table + // that are not strictly necessary. + fn buffer_birth( + &mut self, + parent: P, + child: C, + left: L, + right: R, + ) -> Result<(), TskitError> + where + P: Into, + C: Into, + L: Into, + R: Into, + { + let parent = parent.into(); + if parent < 0 { + return Err(TskitError::IndexError); + } + + let parent = parent.as_usize(); + + if parent >= self.head.len() { + self.head.resize(parent + 1, usize::MAX); + self.tail.resize(parent + 1, usize::MAX); + } + + if self.head[parent] == usize::MAX { + self.insert_new_parent(parent, child.into(), left.into(), right.into()); + } else { + self.extend_parent(parent, child.into(), left.into(), right.into()); + } + Ok(()) + } + + // NOTE: we can probably have this function not error: + // the head array is populated by i32 converted to usize, + // so if things are getting out of range, we should be + // in trouble before this point. + // NOTE: we need a bitflags here for other options, like sorting the head + // contents based on birth time. + pub fn pre_simplification(&mut self, tables: &mut TableCollection) -> Result<(), TskitError> { + let num_input_edges = tables.edges().num_rows().as_usize(); + let mut head_index: Vec = self + .head + .iter() + .enumerate() + .filter(|(_, j)| **j != usize::MAX) + .map(|(i, _)| i) + .collect(); + + let node_time = tables.nodes().time_slice(); + head_index.sort_by(|a, b| node_time[*a].partial_cmp(&node_time[*b]).unwrap()); + //for (i, h) in self.head.iter().rev().enumerate() { + for h in head_index.into_iter() { + let parent = match i32::try_from(h) { + Ok(value) => value, + Err(_) => { + return Err(TskitError::RangeError( + "usize to i32 conversion failed".to_owned(), + )) + } + }; + tables.add_edge( + self.left[self.head[h]], + self.right[self.head[h]], + parent, + self.child[self.head[h]], + )?; + + let mut next = self.next[self.head[h]]; + while next != usize::MAX { + tables.add_edge(self.left[next], self.right[next], parent, self.child[next])?; + next = self.next[next]; + } + } + + self.release_memory(); + + // This assert is redundant b/c TableCollection + // works via MBox/NonNull. + assert!(!tables.as_ptr().is_null()); + // SAFETY: table collection pointer is not null and num_edges + // is the right length. + let num_edges = tables.edges().num_rows().as_usize(); + let edge_left = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.left, num_edges) }; + let edge_right = unsafe { + std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.right, num_edges) + }; + let edge_parent = unsafe { + std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.parent, num_edges) + }; + let edge_child = unsafe { + std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.child, num_edges) + }; + edge_left.rotate_left(num_input_edges); + edge_right.rotate_left(num_input_edges); + edge_parent.rotate_left(num_input_edges); + edge_child.rotate_left(num_input_edges); + Ok(()) + } + + fn alive_node_times(&self, alive: &[NodeId], tables: &mut TableCollection) -> AliveNodeTimes { + let node_times = tables.nodes().time_slice_raw(); + let mut max_alive_node_time = 0.0; + let mut min_alive_node_time = f64::MAX; + + for a in alive { + let time = node_times[a.as_usize()]; + max_alive_node_time = if time > max_alive_node_time { + time + } else { + max_alive_node_time + }; + min_alive_node_time = if time < min_alive_node_time { + time + } else { + min_alive_node_time + }; + } + AliveNodeTimes::new(min_alive_node_time, max_alive_node_time) + } + + // The method here ends up creating a problem: + // we are buffering nodes with increasing node id + // that are also more ancient. This is the opposite + // order from what happens during a forward-time simulation. + fn buffer_existing_edges( + &mut self, + pre_existing_edges: Vec, + tables: &mut TableCollection, + ) -> Result { + let parent = tables.edges().parent_slice(); + let child = tables.edges().child_slice(); + let left = tables.edges().left_slice(); + let right = tables.edges().right_slice(); + let mut rv = 0; + for pre in pre_existing_edges.iter() { + self.setup_births(&[parent[pre.first]], &child[pre.first..pre.last])?; + for e in pre.first..pre.last { + assert_eq!(parent[e], parent[pre.first]); + self.record_birth(parent[e], child[e], left[e], right[e])?; + rv += 1; + } + self.finalize_births(); + } + + Ok(rv) + } + + // FIXME: clean up commented-out code + // if we decide we don't need it. + fn collect_pre_existing_edges( + &self, + alive_node_times: AliveNodeTimes, + tables: &mut TableCollection, + ) -> Vec { + let mut edges = vec![]; + let mut i = 0; + let parent = tables.edges().parent_slice(); + //let child = tables.edges().child_slice(); + let node_time = tables.nodes().time_slice(); + while i < parent.len() { + let p = parent[i]; + // let c = child[i]; + if node_time[p.as_usize()] <= alive_node_times.max + //|| (node_time[c.as_usize()] < alive_node_times.max + // && node_time[p.as_usize()] > alive_node_times.max) + { + let mut j = 0_usize; + while i + j < parent.len() && parent[i + j] == p { + j += 1; + } + edges.push(PreExistingEdge::new(i, i + j)); + i += j; + } else { + break; + } + } + edges + } + + // FIXME: + // + // 1. If min/max parent alive times are equal, return. + // DONE + // 2. Else, we need to do a rotation at min_edge + // before truncation. + // DONE + // 3. However, we also have to respect our API + // and process each parent carefully, + // setting up the birth/death epochs. + // We need to use setup_births and finalize_births + // to get this right. + // DONE + // 4. We are doing this in the wrong temporal order. + // We need to pre-process all existing edge intervals, + // cache them, then go backwards through them, + // so that we buffer them present-to-past + pub fn post_simplification( + &mut self, + alive: &[NodeId], + tables: &mut TableCollection, + ) -> Result<(), TskitError> { + self.clear(); + + let alive_node_times = self.alive_node_times(alive, tables); + if alive_node_times.non_overlapping() { + // There can be no overlap between current + // edges and births that are about to happen, + // so we get out. + return Ok(()); + } + + let pre_existing_edges = self.collect_pre_existing_edges(alive_node_times, tables); + let min_edge = self.buffer_existing_edges(pre_existing_edges, tables)?; + let num_edges = tables.edges().num_rows().as_usize(); + let edge_left = + unsafe { std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.left, num_edges) }; + let edge_right = unsafe { + std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.right, num_edges) + }; + let edge_parent = unsafe { + std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.parent, num_edges) + }; + let edge_child = unsafe { + std::slice::from_raw_parts_mut((*tables.as_mut_ptr()).edges.child, num_edges) + }; + edge_left.rotate_left(min_edge); + edge_right.rotate_left(min_edge); + edge_parent.rotate_left(min_edge); + edge_child.rotate_left(min_edge); + // SAFETY: ????? + let rv = unsafe { + crate::bindings::tsk_edge_table_truncate( + &mut (*tables.as_mut_ptr()).edges, + (num_edges - min_edge) as crate::bindings::tsk_size_t, + ) + }; + handle_tsk_return_value!(rv, ()) + } +} + +#[test] +fn test_pre_simplification() { + let mut tables = TableCollection::new(10.).unwrap(); + let mut buffer = EdgeBuffer::default(); + let p0 = tables.add_node(0, 1.0, -1, -1).unwrap(); + let p1 = tables.add_node(0, 1.0, -1, -1).unwrap(); + let c0 = tables.add_node(0, 0.0, -1, -1).unwrap(); + let c1 = tables.add_node(0, 0.0, -1, -1).unwrap(); + buffer.setup_births(&[p0, p1], &[c0, c1]).unwrap(); + + // Record data in a way that intentionally + // breaks what tskit wants: + // * children are not sorted in increading order + // of id. + buffer.record_birth(0, 3, 5.0, 10.0).unwrap(); + buffer.record_birth(0, 2, 0.0, 5.0).unwrap(); + buffer.record_birth(1, 3, 0.0, 5.0).unwrap(); + buffer.record_birth(1, 2, 5.0, 10.0).unwrap(); + buffer.finalize_births(); + buffer.pre_simplification(&mut tables).unwrap(); + assert_eq!(tables.edges().num_rows(), 4); + tables.simplify(&[2, 3], 0, false).unwrap(); + assert_eq!(tables.edges().num_rows(), 0); +} + +#[test] +fn test_post_simplification() { + let mut tables = TableCollection::new(10.).unwrap(); + let p0 = tables.add_node(0, 1.0, -1, -1).unwrap(); + let p1 = tables.add_node(0, 1.0, -1, -1).unwrap(); + let c0 = tables.add_node(0, 0.0, -1, -1).unwrap(); + let c1 = tables.add_node(0, 0.0, -1, -1).unwrap(); + let _e0 = tables.add_edge(0.0, 10.0, p0, c0).unwrap(); + let _e1 = tables.add_edge(0.0, 10.0, p1, c1).unwrap(); + assert_eq!(tables.edges().num_rows(), 2); + let alive = vec![c0, c1]; // the children have replaced the parents + let mut buffer = EdgeBuffer::default(); + buffer.post_simplification(&alive, &mut tables).unwrap(); + assert_eq!(tables.edges().num_rows(), 2); +} diff --git a/src/lib.rs b/src/lib.rs index 81e5b122..31ae7757 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -140,6 +140,13 @@ pub use trees::{Tree, TreeSequence}; #[cfg_attr(doc_cfg, doc(cfg(feature = "provenance")))] pub mod provenance; +#[cfg(feature = "edgebuffer")] +mod edgebuffer; + +#[cfg(feature = "edgebuffer")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "edgebuffer")))] +pub use edgebuffer::EdgeBuffer; + /// Handles return codes from low-level tskit functions. /// /// When an error from the tskit C API is detected, diff --git a/tests/test_edge_buffer.rs b/tests/test_edge_buffer.rs new file mode 100644 index 00000000..249dfd29 --- /dev/null +++ b/tests/test_edge_buffer.rs @@ -0,0 +1,80 @@ +#![cfg(feature = "edgebuffer")] + +use proptest::prelude::*; +use rand::distributions::Distribution; +use rand::SeedableRng; + +use tskit::EdgeBuffer; +use tskit::TableCollection; +use tskit::TreeSequence; + +fn overlapping_generations(seed: u64, pdeath: f64, simplify: i32) -> TreeSequence { + let mut tables = TableCollection::new(1.0).unwrap(); + let mut buffer = EdgeBuffer::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + + let popsize = 10; + + let mut parents = vec![]; + + for _ in 0..popsize { + let node = tables.add_node(0, 100.0, -1, -1).unwrap(); + parents.push(node); + } + + let death = rand::distributions::Uniform::new(0., 1.0); + let parent_picker = rand::distributions::Uniform::new(0, popsize); + + for birth_time in (0..10).rev() { + let mut replacements = vec![]; + for i in 0..parents.len() { + if death.sample(&mut rng) <= pdeath { + replacements.push(i); + } + } + let mut births = vec![]; + + for _ in 0..replacements.len() { + let parent_index = parent_picker.sample(&mut rng); + let parent = parents[parent_index]; + let child = tables.add_node(0, birth_time as f64, -1, -1).unwrap(); + births.push(child); + buffer.setup_births(&[parent], &[child]).unwrap(); + buffer.record_birth(parent, child, 0., 1.).unwrap(); + buffer.finalize_births(); + } + + for (r, b) in replacements.iter().zip(births.iter()) { + assert!(*r < parents.len()); + parents[*r] = *b; + } + if birth_time % simplify == 0 { + buffer.pre_simplification(&mut tables).unwrap(); + //tables.full_sort(tskit::TableSortOptions::default()).unwrap(); + if let Some(idmap) = tables + .simplify(&parents, tskit::SimplificationOptions::default(), true) + .unwrap() + { + // remap child nodes + for o in parents.iter_mut() { + *o = idmap[usize::try_from(*o).unwrap()]; + } + } + buffer.post_simplification(&parents, &mut tables).unwrap(); + } + } + + tables.build_index().unwrap(); + + tables.tree_sequence(0.into()).unwrap() +} + +#[cfg(test)] +proptest! { + #[test] + fn test_edge_buffer_overlapping_generations(seed in any::(), + pdeath in 0.05..1.0, + simplify_interval in 1..100i32) { + let _ = overlapping_generations(seed, pdeath, simplify_interval); + } +}