Skip to content

Commit

Permalink
save random state across fossil extractions
Browse files Browse the repository at this point in the history
  • Loading branch information
vivaansinghvi07 committed Feb 6, 2025
1 parent 0ba900b commit b787652
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 33 deletions.
53 changes: 45 additions & 8 deletions examples/end2end_tree_reconstruction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,15 @@ def sample_reference_and_reconstruction(
differentia_bitwidth: int,
surface_size: int,
fossil_interval: typing.Optional[int],
) -> typing.Tuple[pd.DataFrame, pd.DataFrame]:
) -> typing.Dict[
typing.Union[
typing.Literal["true"],
typing.Literal["reconst"],
typing.Literal["true_dropped_fossils"],
typing.Literal["reconst_dropped_fossils"],
],
pd.DataFrame,
]:
"""Sample a reference phylogeny and corresponding reconstruction."""
paths = subprocess.run(
[
Expand Down Expand Up @@ -91,7 +99,26 @@ def sample_reference_and_reconstruction(
true_phylo_df
) == alifestd_count_leaf_nodes(reconst_phylo_df)

return true_phylo_df, reconst_phylo_df
reconst_phylo_df_extant = reconst_phylo_df.copy()
reconst_phylo_df_extant["extant"] = reconst_phylo_df["is_fossil"] == False
reconst_phylo_df_no_fossils = alifestd_prune_extinct_lineages_asexual(
reconst_phylo_df_extant
)

true_phylo_df_no_fossils = alifestd_prune_extinct_lineages_asexual(
true_phylo_df.set_index("taxon_label")
.drop(
reconst_phylo_df["taxon_label"][reconst_phylo_df["is_fossil"] == True] # type: ignore
)
.reset_index()
)

return {
"true": true_phylo_df,
"reconst": reconst_phylo_df,
"true_dropped_fossils": true_phylo_df_no_fossils,
"reconst_dropped_fossils": reconst_phylo_df_no_fossils,
}


def plot_colorclade_comparison(
Expand Down Expand Up @@ -167,31 +194,41 @@ def test_reconstruct_one(
print(f"differentia_bitwidth: {differentia_bitwidth}")
print(f"fossil_interval: {fossil_interval}")

true_phylo_df, reconst_phylo_df = sample_reference_and_reconstruction(
frames = sample_reference_and_reconstruction(
differentia_bitwidth,
surface_size,
fossil_interval,
)

visualize_reconstruction(
true_phylo_df,
reconst_phylo_df,
frames["true_dropped_fossils"],
frames["reconst_dropped_fossils"],
differentia_bitwidth=differentia_bitwidth,
surface_size=surface_size,
fossil_interval=fossil_interval,
visualize=visualize,
)
reconstruction_error = alifestd_calc_triplet_distance_asexual(
alifestd_collapse_unifurcations(true_phylo_df), reconst_phylo_df
alifestd_collapse_unifurcations(frames["true"]), frames["reconst"]
)

reconstruction_error_dropped_fossils = (
alifestd_calc_triplet_distance_asexual(
alifestd_collapse_unifurcations(frames["true_dropped_fossils"]),
frames["reconst_dropped_fossils"],
)
)

print(f"{reconstruction_error=}")
assert reconstruction_error < alifestd_count_leaf_nodes(true_phylo_df)
print(f"{reconstruction_error_dropped_fossils=}")
assert reconstruction_error < alifestd_count_leaf_nodes(frames["true"])

return {
"differentia_bitwidth": differentia_bitwidth,
"surface_size": surface_size,
"fossil_interval": fossil_interval,
"error": reconstruction_error,
"error_dropped_fossils": reconstruction_error_dropped_fossils,
}


Expand All @@ -216,7 +253,7 @@ def _parse_args():
fossil_interval,
surface_size,
differentia_bitwidth,
) in product((None, 50, 200), (256, 64, 16), (64, 8, 1))
) in product((None, 50, 200), (16,), (64, 8, 1))
]
)

Expand Down
58 changes: 33 additions & 25 deletions examples/evolve_dstream_surf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import functools
import gc
import os
from pathlib import Path
import random
import types
import typing
Expand All @@ -25,15 +24,38 @@
raise e


class SaveRandomState:
    """Context manager that snapshots and later restores global RNG state.

    On entry, the current state of the stdlib ``random`` module and of
    NumPy's legacy global generator (``np.random``) is recorded. On exit,
    both are put back, so any random draws made inside the ``with`` body do
    not shift the random sequence observed by code running afterwards.
    """

    def __enter__(self):
        """Record the stdlib and NumPy global random states."""
        self.st, self.np_st = random.getstate(), np.random.get_state()

    def __exit__(self, *exc_info):
        """Reinstate the recorded states; exceptions are not suppressed."""
        # the two generators are independent, so restore order is irrelevant
        np.random.set_state(self.np_st)
        random.setstate(self.st)


def make_uuid4_fast() -> str:
    """Return a version-4 UUID string drawn from the global `random` module.

    Uses `random.getrandbits` rather than an OS entropy source, so it is
    quick and reproducible under seeding but NOT suitable for anything
    security-sensitive.
    """
    raw_bits = random.getrandbits(128)
    # version=4 makes uuid.UUID overwrite the variant/version bits
    fast_uuid = uuid.UUID(int=raw_bits, version=4)
    return str(fast_uuid)


def extract_fossils(
    pop: typing.List,
    fossil_sample_percentage: float = 0.1,
) -> typing.List:
    """Create fossil offspring from a random subset of a population.

    Parameters
    ----------
    pop : list
        Population members; each must provide ``CreateOffspring(fossil=...)``.
    fossil_sample_percentage : float, default 0.1
        Share of the population to sample. Despite the name, it is applied
        as a fraction: the sample size is
        ``int(len(pop) * fossil_sample_percentage)``.

    Returns
    -------
    list
        One ``CreateOffspring(fossil=True)`` result per sampled parent, in
        sampling order.
    """
    num_fossils = int(len(pop) * fossil_sample_percentage)
    # sampling without replacement, drawn from the global `random` module
    sampled_parents = random.sample(pop, k=num_fossils)

    fossils = []
    for organism in sampled_parents:
        fossils.append(organism.CreateOffspring(fossil=True))
    return fossils


def evolve_drift(
population: typing.List,
fossil_interval: typing.Optional[int] = None,
fossil_sample_percentage: float = 0.1,
) -> typing.List:
"""
Simple asexual evolutionary algorithm under drift conditions.
Expand All @@ -49,21 +71,14 @@ def evolve_drift(
for generation in tqdm(range(500)):
population = [
parent.CreateOffspring()
for parent in selector.choices(population, k=len(population))
for parent in random.choices(population, k=len(population))
]
if fossil_interval and generation % fossil_interval == 0:
# note: we extract CreateOffspring() instead of the parent itself,
# because parents with surviving children are not treated as leaf
# nodes by phylotrackpy; simplifies true/reconst phylo comparison
fossils.extend(
[
parent.CreateOffspring(fossil=True)
for parent in selector.sample(
population,
k=int(len(population) * fossil_sample_percentage),
)
]
)
with SaveRandomState():
# note: we extract CreateOffspring() instead of the parent itself,
# because parents with surviving children are not treated as leaf
# nodes by phylotrackpy; simplifies true/reconst phylo comparison
fossils.extend(extract_fossils(population))

# asynchronous generations
nsplit = len(population) // 2
Expand All @@ -73,16 +88,9 @@ def evolve_drift(
for parent in selector.choices(population[:nsplit], k=nsplit)
]
if fossil_interval and generation % fossil_interval == 0:
# see above
fossils.extend(
[
parent.CreateOffspring(fossil=True)
for parent in selector.sample(
population,
k=int(len(population) * fossil_sample_percentage),
)
]
)
with SaveRandomState():
# see above
fossils.extend(extract_fossils(population))
selector.shuffle(population)

return [*fossils, *population]
Expand Down

0 comments on commit b787652

Please sign in to comment.