Skip to content

Commit

Permalink
Merge pull request tskit-dev#145 from jeromekelleher/improve-mutation…
Browse files Browse the repository at this point in the history
…-pop-perf

Improve mutation pop perf
  • Loading branch information
jeromekelleher authored Apr 4, 2024
2 parents a1d6c70 + 6277d76 commit 3256e3f
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 14 deletions.
33 changes: 33 additions & 0 deletions tests/test_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,39 @@ def test_simulated_mutations(self, seed):
assert ts.num_mutations > 0
self.check_ts(ts)

def test_no_metadata_schema(self):
ts = msprime.sim_mutations(self.example_ts(), rate=1e-6, random_seed=43)
assert ts.num_mutations > 0
tables = ts.dump_tables()
tables.populations.metadata_schema = tskit.MetadataSchema(None)
self.check_ts(tables.tree_sequence())

def test_no_populations(self):
tables = single_tree_example_ts().dump_tables()
tables.populations.add_row(b"{}")
tsm = model.TSModel(tables.tree_sequence())
with pytest.raises(ValueError, match="must be assigned to populations"):
tsm.mutations_df


class TestNodeIsSample:
def test_simple_example(self):
ts = single_tree_example_ts()
is_sample = model.node_is_sample(ts)
for node in ts.nodes():
assert node.is_sample() == is_sample[node.id]

@pytest.mark.parametrize("bit", [1, 2, 17, 31])
def test_sample_and_other_flags(self, bit):
tables = single_tree_example_ts().dump_tables()
flags = tables.nodes.flags
tables.nodes.flags = flags | (1 << bit)
ts = tables.tree_sequence()
is_sample = model.node_is_sample(ts)
for node in ts.nodes():
assert node.is_sample() == is_sample[node.id]
assert (node.flags & (1 << bit)) != 0


class TestTreesDataTable:
def test_single_tree_example(self):
Expand Down
40 changes: 26 additions & 14 deletions tsqc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ class MutationCounts:


def compute_mutation_counts(ts):
logger.info("Computing mutation inheritance counts")
tree_pos = alloc_tree_position(ts)
mutations_position = ts.sites_position[ts.mutations_site].astype(int)
num_descendants, num_inheritors = _compute_mutation_inheritance_counts(
Expand All @@ -266,17 +267,22 @@ def _compute_population_mutation_counts(
num_populations,
edges_parent,
edges_child,
num_pop_samples,
nodes_is_sample,
nodes_population,
mutations_position,
mutations_node,
mutations_parent,
):
num_pop_samples = np.zeros((num_nodes, num_populations), dtype=np.int32)

pop_mutation_count = np.zeros((num_populations, num_mutations), dtype=np.int32)
parent = np.zeros(num_nodes, dtype=np.int32) - 1

mut_id = 0
for u in range(num_nodes):
if nodes_is_sample[u]:
num_pop_samples[u, nodes_population[u]] = 1

mut_id = 0
while tree_pos.next():
for j in range(tree_pos.out_range[0], tree_pos.out_range[1]):
e = tree_pos.edge_removal_order[j]
Expand All @@ -285,7 +291,8 @@ def _compute_population_mutation_counts(
parent[c] = -1
u = p
while u != -1:
num_pop_samples[u] -= num_pop_samples[c]
for k in range(num_populations):
num_pop_samples[u, k] -= num_pop_samples[c, k]
u = parent[u]

for j in range(tree_pos.in_range[0], tree_pos.in_range[1]):
Expand All @@ -295,7 +302,8 @@ def _compute_population_mutation_counts(
parent[c] = p
u = p
while u != -1:
num_pop_samples[u] += num_pop_samples[c]
for k in range(num_populations):
num_pop_samples[u, k] += num_pop_samples[c, k]
u = parent[u]

left, right = tree_pos.interval
Expand All @@ -309,18 +317,23 @@ def _compute_population_mutation_counts(
return pop_mutation_count


def node_is_sample(ts):
sample_flag = np.full_like(ts.nodes_flags, tskit.NODE_IS_SAMPLE)
return np.bitwise_and(ts.nodes_flags, sample_flag) != 0


def compute_population_mutation_counts(ts):
"""
Return a dataframe that gives the frequency of each mutation
in each of the populations in the specified tree sequence.
Return a (num_populations, num_mutations) array that gives the frequency
of each mutation in each of the populations in the specified tree sequence.
"""
logger.info(
f"Computing mutation frequencies within {ts.num_populations} populations"
)
mutations_position = ts.sites_position[ts.mutations_site].astype(int)
num_pop_samples = np.zeros((ts.num_nodes, ts.num_populations), dtype=np.int32)
for pop in range(ts.num_populations):
samples = np.logical_and(
ts.nodes_population == pop, ts.nodes_flags == 1 # Not quite right!
)
num_pop_samples[samples, pop] = 1

if np.any(ts.nodes_population[ts.samples()] == -1):
raise ValueError("Sample nodes must be assigned to populations")

return _compute_population_mutation_counts(
alloc_tree_position(ts),
Expand All @@ -329,7 +342,7 @@ def compute_population_mutation_counts(ts):
ts.num_populations,
ts.edges_parent,
ts.edges_child,
num_pop_samples,
node_is_sample(ts),
ts.nodes_population,
mutations_position,
ts.mutations_node,
Expand Down Expand Up @@ -409,7 +422,6 @@ def mutations_df(self):
unknown = tskit.is_unknown_time(mutations_time)
mutations_time[unknown] = self.ts.nodes_time[mutations_node[unknown]]

# node_flag = ts.nodes_flags[mutations_node]
position = ts.sites_position[ts.mutations_site]

tables = self.ts.tables
Expand Down

0 comments on commit 3256e3f

Please sign in to comment.