From b7876524c637747dc1016ce73a82ecff480fa941 Mon Sep 17 00:00:00 2001 From: vivaansinghvi07 Date: Thu, 6 Feb 2025 15:17:05 -0500 Subject: [PATCH] save random state across fossil extractions --- examples/end2end_tree_reconstruction_test.py | 53 +++++++++++++++--- examples/evolve_dstream_surf.py | 58 +++++++++++--------- 2 files changed, 78 insertions(+), 33 deletions(-) diff --git a/examples/end2end_tree_reconstruction_test.py b/examples/end2end_tree_reconstruction_test.py index 1bf64822..5e035ba6 100755 --- a/examples/end2end_tree_reconstruction_test.py +++ b/examples/end2end_tree_reconstruction_test.py @@ -60,7 +60,15 @@ def sample_reference_and_reconstruction( differentia_bitwidth: int, surface_size: int, fossil_interval: typing.Optional[int], -) -> typing.Tuple[pd.DataFrame, pd.DataFrame]: +) -> typing.Dict[ + typing.Union[ + typing.Literal["true"], + typing.Literal["reconst"], + typing.Literal["true_dropped_fossils"], + typing.Literal["reconst_dropped_fossils"], + ], + pd.DataFrame, +]: """Sample a reference phylogeny and corresponding reconstruction.""" paths = subprocess.run( [ @@ -91,7 +99,26 @@ def sample_reference_and_reconstruction( true_phylo_df ) == alifestd_count_leaf_nodes(reconst_phylo_df) - return true_phylo_df, reconst_phylo_df + reconst_phylo_df_extant = reconst_phylo_df.copy() + reconst_phylo_df_extant["extant"] = reconst_phylo_df["is_fossil"] == False + reconst_phylo_df_no_fossils = alifestd_prune_extinct_lineages_asexual( + reconst_phylo_df_extant + ) + + true_phylo_df_no_fossils = alifestd_prune_extinct_lineages_asexual( + true_phylo_df.set_index("taxon_label") + .drop( + reconst_phylo_df["taxon_label"][reconst_phylo_df["is_fossil"] == True] # type: ignore + ) + .reset_index() + ) + + return { + "true": true_phylo_df, + "reconst": reconst_phylo_df, + "true_dropped_fossils": true_phylo_df_no_fossils, + "reconst_dropped_fossils": reconst_phylo_df_no_fossils, + } def plot_colorclade_comparison( @@ -167,31 +194,41 @@ def test_reconstruct_one( print(f"differentia_bitwidth: {differentia_bitwidth}") print(f"fossil_interval: {fossil_interval}") - true_phylo_df, reconst_phylo_df = sample_reference_and_reconstruction( + frames = sample_reference_and_reconstruction( differentia_bitwidth, surface_size, fossil_interval, ) visualize_reconstruction( - true_phylo_df, - reconst_phylo_df, + frames["true_dropped_fossils"], + frames["reconst_dropped_fossils"], differentia_bitwidth=differentia_bitwidth, surface_size=surface_size, fossil_interval=fossil_interval, visualize=visualize, ) reconstruction_error = alifestd_calc_triplet_distance_asexual( - alifestd_collapse_unifurcations(true_phylo_df), reconst_phylo_df + alifestd_collapse_unifurcations(frames["true"]), frames["reconst"] ) + + reconstruction_error_dropped_fossils = ( + alifestd_calc_triplet_distance_asexual( + alifestd_collapse_unifurcations(frames["true_dropped_fossils"]), + frames["reconst_dropped_fossils"], + ) + ) + print(f"{reconstruction_error=}") - assert reconstruction_error < alifestd_count_leaf_nodes(true_phylo_df) + print(f"{reconstruction_error_dropped_fossils=}") + assert reconstruction_error < alifestd_count_leaf_nodes(frames["true"]) return { "differentia_bitwidth": differentia_bitwidth, "surface_size": surface_size, "fossil_interval": fossil_interval, "error": reconstruction_error, + "error_dropped_fossils": reconstruction_error_dropped_fossils, } @@ -216,7 +253,7 @@ def _parse_args(): fossil_interval, surface_size, differentia_bitwidth, - ) in product((None, 50, 200), (256, 64, 16), (64, 8, 1)) + ) in product((None, 50, 200), (16,), (64, 8, 1)) ] ) diff --git a/examples/evolve_dstream_surf.py b/examples/evolve_dstream_surf.py index 8c557cb2..e258ba77 100755 --- a/examples/evolve_dstream_surf.py +++ b/examples/evolve_dstream_surf.py @@ -4,7 +4,6 @@ import functools import gc import os -from pathlib import Path import random import types import typing @@ -25,15 +24,38 @@ raise e +class SaveRandomState: + + def __enter__(self): + self.st = random.getstate() + self.np_st = np.random.get_state() + + def __exit__(self, *args): + random.setstate(self.st) + np.random.set_state(self.np_st) + + def make_uuid4_fast() -> str: """Fast UUID4 generator, using lower-quality randomness.""" return str(uuid.UUID(int=random.getrandbits(128), version=4)) +def extract_fossils( + pop: typing.List, + fossil_sample_percentage: float = 0.1, +) -> typing.List: + return [ + parent.CreateOffspring(fossil=True) + for parent in random.sample( + pop, + k=int(len(pop) * fossil_sample_percentage), + ) + ] + + def evolve_drift( population: typing.List, fossil_interval: typing.Optional[int] = None, - fossil_sample_percentage: float = 0.1, ) -> typing.List: """ Simple asexual evolutionary algorithm under drift conditions. @@ -49,21 +71,14 @@ def evolve_drift( for generation in tqdm(range(500)): population = [ parent.CreateOffspring() - for parent in selector.choices(population, k=len(population)) + for parent in random.choices(population, k=len(population)) ] if fossil_interval and generation % fossil_interval == 0: - # note: we extract CreateOffspring() instead of the parent itself, - # beause parents with surviving children are not treated as leaf - # nodes by phylotrackpy; simplifies true/reconst phylo comparison - fossils.extend( - [ - parent.CreateOffspring(fossil=True) - for parent in selector.sample( - population, - k=int(len(population) * fossil_sample_percentage), - ) - ] - ) + with SaveRandomState(): + # note: we extract CreateOffspring() instead of the parent itself, + # beause parents with surviving children are not treated as leaf + # nodes by phylotrackpy; simplifies true/reconst phylo comparison + fossils.extend(extract_fossils(population)) # asyncrhonous generations nsplit = len(population) // 2 @@ -73,16 +88,9 @@ def evolve_drift( for parent in selector.choices(population[:nsplit], k=nsplit) ] if fossil_interval and generation % fossil_interval == 0: - # see above - fossils.extend( - [ - parent.CreateOffspring(fossil=True) - for parent in selector.sample( - population, - k=int(len(population) * fossil_sample_percentage), - ) - ] - ) + with SaveRandomState(): + # see above + fossils.extend(extract_fossils(population)) selector.shuffle(population) return [*fossils, *population]