Skip to content

Commit

Permalink
Prototype API for calculating ancestry proportion matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
molpopgen committed Nov 19, 2024
1 parent a7cc066 commit 8284905
Showing 1 changed file with 264 additions and 3 deletions.
267 changes: 264 additions & 3 deletions demes/src/specification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3853,7 +3853,7 @@ impl Graph {
buffer.iter_mut().for_each(|v| *v *= 1. - sum_migrates);
for i in input_migrations {
let source = self.deme_map[i.source()];
buffer[source] += f64::from(i.rate());
buffer[source] += f64::from(i.rate())
}
assert!(
buffer
Expand All @@ -3863,6 +3863,100 @@ impl Graph {
);
}
}

#[allow(missing_docs)]
pub fn ancestry_proportions_matrix(&self, at: Time) -> Result<Box<[f64]>, DemesError> {
let mut buffer = vec![0.; self.num_demes() * self.num_demes()];
self.fill_ancestry_proportions_matrix(at, &mut buffer)
.map(|_| buffer.into_boxed_slice())
}

#[allow(missing_docs)]
pub fn fill_ancestry_proportions_matrix(
&self,
at: Time,
buffer: &mut [f64],
) -> Result<(), DemesError> {
if at == 0.0 {
return Err(DemesError::ValueError(format!(
"time must be > 0.0, got {at:?}"
)));
}
buffer.fill_with(|| 0.);
for (deme_index, deme) in self.demes().iter().enumerate() {
if at == deme.start_time() {
for (a, p) in deme
.ancestor_indexes()
.iter()
.cloned()
.zip(deme.proportions().iter().cloned())
{
buffer[deme_index * self.num_demes() + a] += f64::from(p);
}
} else if at > deme.end_time() && at <= deme.start_time() {
buffer[deme_index * self.num_demes() + deme_index] = 1.0;
}
}
let mut temp = vec![0.0; self.num_demes() * self.num_demes()];
let pulses = self
.pulses()
.iter()
.filter(|&p| p.time() == at)
.collect::<Vec<_>>();
if !pulses.is_empty() {
for (deme_index, deme) in self.demes().iter().enumerate() {
for &pulse in pulses.iter().filter(|p| p.dest() == deme.name()) {
for (a, p) in pulse
.sources()
.iter()
.zip(pulse.proportions().iter().cloned())
{
let source_index = self.deme_index(a).unwrap();
temp[deme_index * self.num_demes() + source_index] = p.into();
}
}
}
for (i, t) in temp.chunks_exact(self.num_demes()).enumerate() {
let sum = t.iter().sum::<f64>();
for (j, tt) in buffer
.iter_mut()
.skip(i * self.num_demes())
.take(self.num_demes())
.zip(t.iter().cloned())
{
*j *= 1. - sum;
*j += tt;
}
}
}
let input_migrations = self
.migrations()
.iter()
.filter(|m| at <= m.start_time() && at > m.end_time())
.collect::<Vec<_>>();
if !input_migrations.is_empty() {
temp.fill_with(|| 0.);
for (deme_index, deme) in self.demes().iter().enumerate() {
for &m in input_migrations.iter().filter(|p| p.dest() == deme.name()) {
let source_index = self.deme_index(m.source()).unwrap();
temp[deme_index * self.num_demes() + source_index] = m.rate().into();
}
}
for (i, t) in temp.chunks_exact(self.num_demes()).enumerate() {
let sum = t.iter().sum::<f64>();
for (j, tt) in buffer
.iter_mut()
.skip(i * self.num_demes())
.take(self.num_demes())
.zip(t.iter().cloned())
{
*j *= 1. - sum;
*j += tt;
}
}
}
Ok(())
}
}

#[cfg(test)]
Expand Down Expand Up @@ -4824,7 +4918,7 @@ mod deme_equality {

#[cfg(test)]
mod test_rescaling {
static SIMPLE_TEST_GRAPH_0: &str = "
pub static SIMPLE_TEST_GRAPH_0: &str = "
time_units: generations
demes:
- name: ancestor1
Expand Down Expand Up @@ -5163,7 +5257,174 @@ mod test_forward_ancestry_proportions {
fn test_time_0() {
let graph = crate::loads(super::test_rescaling::SIMPLE_TEST_GRAPH_2).unwrap();
assert!(graph
.ancestry_proportions(usize::MAX, 0.0.try_into().unwrap())
.ancestry_proportions(0, 0.0.try_into().unwrap())
.is_none());
let mut buffer = vec![9.; graph.num_demes()];
assert!(graph
.fill_ancestry_proportions(0, 0.0.try_into().unwrap(), &mut buffer)
.is_none())
}
}

#[cfg(test)]
mod test_ancestry_proportion_matrix {
use std::ops::Deref;
#[test]
fn test_simpler_graph() {
let graph = crate::loads(super::test_rescaling::SIMPLE_TEST_GRAPH_0).unwrap();
let proportions = graph
.ancestry_proportions_matrix(21.0.try_into().unwrap())
.unwrap();
for i in [0, 1] {
let mut e = vec![0.; graph.num_demes()];
e[i] = 1.;
assert_eq!(
proportions[i * graph.num_demes()..(i + 1) * graph.num_demes()],
e
);
}
let e = vec![0.; graph.num_demes()];
assert_eq!(proportions[2 * graph.num_demes()..3 * graph.num_demes()], e);

let proportions = graph
.ancestry_proportions_matrix(20.0.try_into().unwrap())
.unwrap();
for i in [0, 1] {
let e = vec![0.; graph.num_demes()];
assert_eq!(
proportions[i * graph.num_demes()..(i + 1) * graph.num_demes()],
e
);
}
let e = vec![0.5, 0.5, 0.];
assert_eq!(proportions[2 * graph.num_demes()..3 * graph.num_demes()], e);

let proportions = graph
.ancestry_proportions_matrix(19.0.try_into().unwrap())
.unwrap();
for i in [0, 1] {
let e = vec![0.; graph.num_demes()];
assert_eq!(
proportions[i * graph.num_demes()..(i + 1) * graph.num_demes()],
e
);
}
let e = vec![0., 0., 1.];
assert_eq!(proportions[2 * graph.num_demes()..3 * graph.num_demes()], e);
}

#[test]
fn test_simple_graph() {
let graph = crate::loads(super::test_rescaling::SIMPLE_TEST_GRAPH_2).unwrap();
let proportions = graph
.ancestry_proportions_matrix(21.0.try_into().unwrap())
.unwrap();
assert_eq!(proportions.len(), 16);
assert_eq!(&proportions[..graph.num_demes()], &[1., 0., 0., 0.]);
assert_eq!(
&proportions[graph.num_demes()..2 * graph.num_demes()],
&[0., 1., 0., 0.]
);
for i in [2, 3] {
assert_eq!(
&proportions[i * graph.num_demes()..(i + 1) * graph.num_demes()],
&[0., 0., 0., 0.],
"{i}"
);
}
for (i, ap) in proportions.chunks(4).enumerate() {
if let Some(dp) = graph.ancestry_proportions(i, 21.0.try_into().unwrap()) {
assert_eq!(ap, dp.deref());
}
}

let proportions = graph
.ancestry_proportions_matrix(20.0.try_into().unwrap())
.unwrap();
for i in [0, 1] {
assert_eq!(
&proportions[i * graph.num_demes()..(i + 1) * graph.num_demes()],
&[0., 0., 0., 0.],
"{i}"
);
}
assert_eq!(
&proportions[2 * graph.num_demes()..3 * graph.num_demes()],
&[1., 0., 0., 0.],
"{proportions:?}"
);
assert_eq!(
&proportions[3 * graph.num_demes()..4 * graph.num_demes()],
&[0., 1., 0., 0.],
);
for i in [0, 1] {
assert_eq!(
&proportions[i * graph.num_demes()..(i + 1) * graph.num_demes()],
&[0., 0., 0., 0.],
);
}
for (i, ap) in proportions.chunks(4).enumerate() {
if let Some(dp) = graph.ancestry_proportions(i, 20.0.try_into().unwrap()) {
assert_eq!(ap, dp.deref());
}
}

let mut proportions = proportions;
graph
.fill_ancestry_proportions_matrix(19.0.try_into().unwrap(), &mut proportions)
.unwrap();
for (i, ap) in proportions.chunks(4).enumerate() {
if let Some(dp) = graph.ancestry_proportions(i, 19.0.try_into().unwrap()) {
assert_eq!(ap, dp.deref());
}
}
}

#[test]
fn test_simple_graph_with_migrations() {
let graph = crate::loads(super::test_rescaling::SIMPLE_TEST_GRAPH_1).unwrap();
for time in [21.0, 20.0, 19.0, 18.0] {
let at = crate::Time::try_from(time).unwrap();
let proportions = graph.ancestry_proportions_matrix(at).unwrap();
for deme_index in 0..graph.num_demes() {
let pslice = &proportions
[deme_index * graph.num_demes()..(deme_index + 1) * graph.num_demes()];
if let Some(prop) = graph.ancestry_proportions(deme_index, at) {
assert_eq!(pslice, prop.deref())
} else {
assert!(pslice.iter().all(|&p| p == 0.))
}
}
}

// Again, now with a buffer
let mut buffer = vec![666.0; graph.num_demes() * graph.num_demes()];
for time in [21.0, 20.0, 19.0, 18.0] {
let at = crate::Time::try_from(time).unwrap();
graph
.fill_ancestry_proportions_matrix(at, &mut buffer)
.unwrap();
for deme_index in 0..graph.num_demes() {
let pslice =
&buffer[deme_index * graph.num_demes()..(deme_index + 1) * graph.num_demes()];
if let Some(prop) = graph.ancestry_proportions(deme_index, at) {
assert_eq!(pslice, prop.deref())
} else {
assert!(pslice.iter().all(|&p| p == 0.))
}
}
}
}

#[test]
fn test_time_0() {
let graph = crate::loads(super::test_rescaling::SIMPLE_TEST_GRAPH_1).unwrap();
assert!(graph
.ancestry_proportions_matrix(0.0.try_into().unwrap())
.is_err());
let mut buffer = vec![f64::NAN; graph.num_demes() * graph.num_demes()];
assert!(graph
.fill_ancestry_proportions_matrix(0.0.try_into().unwrap(), &mut buffer)
.is_err());
}
}

0 comments on commit 8284905

Please sign in to comment.