diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index be2e658d..614307e0 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,12 +22,14 @@ jobs: - uses: actions/checkout@v3 - name: Run black formatting check uses: psf/black@stable - - name: Run snakefmt formatting check - uses: super-linter/super-linter@v5 - env: - VALIDATE_ALL_CODEBASE: false - DEFAULT_BRANCH: main - VALIDATE_SNAKEMAKE_SNAKEFMT: true +# TODO (Gordon): Add snakefmt back in when/if fixed. See +# https://github.com/snakemake/snakefmt/issues/197 +# - name: Run snakefmt formatting check +# uses: super-linter/super-linter@v5 +# env: +# VALIDATE_ALL_CODEBASE: false +# DEFAULT_BRANCH: main +# VALIDATE_SNAKEMAKE_SNAKEFMT: true - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 171043f2..b9809f67 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,10 +15,12 @@ repos: rev: 23.1.0 hooks: - id: black - - repo: https://github.com/snakemake/snakefmt - rev: 'v0.8.4' - hooks: - - id: snakefmt +# TODO (Gordon): Add snakefmt back in when/if fixed. See +# https://github.com/snakemake/snakefmt/issues/197 +# - repo: https://github.com/snakemake/snakefmt +# rev: 'v0.8.4' +# hooks: +# - id: snakefmt - repo: https://github.com/econchick/interrogate rev: 1.5.0 hooks: diff --git a/README.md b/README.md index a3ebc03a..0a79ce79 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![build](https://github.com/cbg-ethz/PYggdrasil/actions/workflows/test.yml/badge.svg)](https://github.com/cbg-ethz/PYggdrasil/actions/workflows/test.yml) [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json)](https://github.com/charliermarsh/ruff) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -[![Code style: snakefmt](https://img.shields.io/badge/code%20style-snakefmt-000000.svg)](https://github.com/snakemake/snakefmt) + # PYggdrasil @@ -50,10 +50,11 @@ The code quality checks run during on GitHub can be seen in ``.github/workflows/ We are using: - [Ruff](https://github.com/charliermarsh/ruff) to lint the code. - [Black](https://github.com/psf/black) to format the code. - - [Snakefmt](https://github.com/snakemake/snakefmt) to format Snakemake workflows. - [Pyright](https://github.com/microsoft/pyright) to check the types. - [Pytest](https://docs.pytest.org/) to run the unit tests. - [Interrogate](https://interrogate.readthedocs.io/) to check the documentation. + + ### Workflow diff --git a/pyproject.toml b/pyproject.toml index 81caa7c2..356117d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,9 @@ pytest-xdist = "^3.2.0" pre-commit = "^3.1.0" interrogate = "^1.5.0" pyright = "^1.1.309" -snakefmt = "^0.8.4" +# TODO (Gordon): Add snakefmt back in when/if fixed. See +# https://github.com/snakemake/snakefmt/issues/197 +# snakefmt = "^0.8.4" [tool.coverage.report] fail_under = 85.0 diff --git a/scripts/make_huntress.py b/scripts/make_huntress.py index 4b4d4589..bcf512cc 100644 --- a/scripts/make_huntress.py +++ b/scripts/make_huntress.py @@ -4,6 +4,8 @@ Make a TreeNode tree given a mutation matrix to generate a huntress tree. +Note: used 4 threads as default for the huntress tree inference. + Example Usage: poetry run python ../scripts/make_huntress.py @@ -134,7 +136,7 @@ def main() -> None: mut_mat = cell_simulation_data["noisy_mutation_mat"] # run huntress tree inference - tree_n = huntress_tree_inference(mut_mat, args.fpr, args.fnr, n_threads=2) + tree_n = huntress_tree_inference(mut_mat, args.fpr, args.fnr, n_threads=4) tree_tn = TreeNode(name=tree_n.name, parent=None, children=tree_n.children) # Save the tree - make path diff --git a/scripts/run_mcmc.py b/scripts/run_mcmc.py index cf8ca8c3..c55bc8f7 100644 --- a/scripts/run_mcmc.py +++ b/scripts/run_mcmc.py @@ -138,7 +138,7 @@ def run_chain( init_tree_node = serialize.read_tree_node(params.init_tree_fp) # convert TreeNode to Tree - init_tree = tree_inf.tree_from_tree_node(init_tree_node) + init_tree = tree_inf.Tree.tree_from_tree_node(init_tree_node) logging.info("Loaded tree (TreeNode) from file.") # Make Move Probabilities diff --git a/src/pyggdrasil/tree_inference/__init__.py b/src/pyggdrasil/tree_inference/__init__.py index 74700da6..58b6dfa3 100644 --- a/src/pyggdrasil/tree_inference/__init__.py +++ b/src/pyggdrasil/tree_inference/__init__.py @@ -14,6 +14,7 @@ TreeAdjacencyMatrix, AncestorMatrix, CellAttachmentVector, + MoveProbabilities, ) from pyggdrasil.tree_inference._tree_generator import ( @@ -22,7 +23,7 @@ generate_random_TreeNode, ) -from pyggdrasil.tree_inference._tree import Tree, tree_from_tree_node, get_descendants +from pyggdrasil.tree_inference._tree import Tree, get_descendants from pyggdrasil.tree_inference._simulate import ( CellAttachmentStrategy, @@ -52,7 +53,9 @@ from pyggdrasil.tree_inference._huntress import huntress_tree_inference -from pyggdrasil.tree_inference._mcmc_sampler import mcmc_sampler, MoveProbabilities +from pyggdrasil.tree_inference._mcmc_sampler import mcmc_sampler + +from pyggdrasil.tree_inference._tree_mcmc import evolve_tree_mcmc __all__ = [ @@ -67,7 +70,6 @@ "MutationMatrix", "Tree", "MoveProbabilities", - "tree_from_tree_node", "unpack_sample", "gen_sim_data", "huntress_tree_inference", @@ -93,4 +95,5 @@ "MoveProbConfigOptions", "McmcConfigOptions", "ErrorCombinations", + "evolve_tree_mcmc", ] diff --git a/src/pyggdrasil/tree_inference/_file_id.py b/src/pyggdrasil/tree_inference/_file_id.py index 44a8cf06..afe5ef48 100644 --- a/src/pyggdrasil/tree_inference/_file_id.py +++ b/src/pyggdrasil/tree_inference/_file_id.py @@ -1,6 +1,8 @@ """Provides classes for naming files Tree, Cell Simulation and MCMC run files uniquely """ +import re + from enum import Enum from typing import Union, Optional @@ -16,12 +18,14 @@ class TreeType(Enum): - STAR (star tree) - DEEP (deep tree) - HUNTRESS (Huntress tree) - inferred from real / cell simulation data + - MCMC - generated tree evolve by MCMC moves """ RANDOM = "r" STAR = "s" DEEP = "d" HUNTRESS = "h" + MCMC = "m" class MutationDataId: @@ -111,13 +115,14 @@ def from_str(cls, str_id: str): # split string by underscore and assign to attributes split_elements = str_id.split("_") seed = None - mutation_data = None + rest_id = None if len(split_elements) == 3: _, tree_type, n_nodes = split_elements elif len(split_elements) == 4: _, tree_type, n_nodes, seed = split_elements - elif len(split_elements) == 5: - _, tree_type, n_nodes, seed, mutation_data = split_elements + elif len(split_elements) >= 5: + _, tree_type, n_nodes, *rest = split_elements + rest_id = "_".join(rest) else: raise AssertionError("Tree id has invalid format") @@ -125,17 +130,109 @@ def from_str(cls, str_id: str): tree_id = TreeId(TreeType(tree_type), int(n_nodes), int(seed)) return tree_id else: - if mutation_data is not None: - try: - mutation_data = CellSimulationId.from_str(mutation_data) - except AssertionError: - mutation_data = MutationDataId(mutation_data) - - tree_id = TreeId(TreeType(tree_type), int(n_nodes), None, mutation_data) + if rest_id is not None: + # check if tree is MCMC tree + if tree_type == TreeType.MCMC.value: + try: + tree_id = McmcTreeId.from_str(str_id) + return tree_id + except AssertionError: + raise AssertionError( + "Tree id has invalid format for an MCMC tree" + ) + + # check if tree is Huntress tree + elif tree_type == TreeType.HUNTRESS.value: + try: + mutation_data = CellSimulationId.from_str(rest_id) + except AssertionError: + mutation_data = MutationDataId(rest_id) + + tree_id = TreeId( + TreeType(tree_type), int(n_nodes), None, mutation_data + ) + return tree_id else: tree_id = TreeId(TreeType(tree_type), int(n_nodes)) + return tree_id + + +class McmcTreeId(TreeId): + """Class for tree ids of trees evolved by MCMC moves under SCITE. + + MCMC move probabilities are not specified in the id! + ID is not unique, fully reproducible only with the MCMC config. + Assumed default values for MCMC config. + """ + + tree_type: TreeType + n_moves: int + n_nodes: int + mcmc_rng_seed: int + initial_tree_id: TreeId + + def __init__( + self, + n_moves: int, + n_nodes: int, + mcmc_rng_seed: int, + initial_tree_id: TreeId, + tree_type: TreeType = TreeType.MCMC, + ): + self.initial_tree_id = initial_tree_id + self.n_nodes = n_nodes + self.n_moves = n_moves + self.mcmc_rng_seed = mcmc_rng_seed + self.tree_type = tree_type + super().__init__(TreeType.MCMC, n_nodes) + + self.id = self._create_id() + + def _create_id(self) -> str: + """Creates a unique id for the tree, + by concatenating the values of the attributes""" + + str_rep = "T" + str_rep = str_rep + "_" + str(self.tree_type.value) + str_rep = str_rep + "_" + str(self.n_nodes) + str_rep = str_rep + "_" + str(self.n_moves) + str_rep = str_rep + "_" + str(self.mcmc_rng_seed) + str_rep = str_rep + "_o" + str(self.initial_tree_id) + + return str_rep + + def __str__(self) -> str: + return self.id + + @classmethod + def from_str(cls, str_id: str): + """Creates a tree id from a string representation of the id. + + Args: + str_id: str + """ + + # Define the regular expression pattern to match the variables + pattern = r"T_m_(\d+)_(\d+)_(\d+)_o(T_[a-zA-Z]_\d+_\d+)" + + # Use re.findall() to extract the matched variables + matches = re.findall(pattern, str_id) + + # The 'matches' variable now contains the extracted variables. + # Let's unpack the matches to get individual variable values. + if matches: + n_nodes, n_moves, mcmc_move_seed, initial_tree_id = matches[0] + + tree_id = McmcTreeId( + int(n_moves), + int(n_nodes), + int(mcmc_move_seed), + TreeId.from_str(initial_tree_id), + ) return tree_id + else: + raise AssertionError("MCMC tree id has invalid format") class CellSimulationId(MutationDataId): @@ -243,10 +340,9 @@ def from_str(cls, str_id: str): # create tree id tree_id = TreeId.from_str(tree_id) - # TODO: remove type ignore once PR #64 is merged return cls( seed, - tree_id, # type: ignore + tree_id, n_cells, fpr, fnr, diff --git a/src/pyggdrasil/tree_inference/_interface.py b/src/pyggdrasil/tree_inference/_interface.py index 73366873..566377fb 100644 --- a/src/pyggdrasil/tree_inference/_interface.py +++ b/src/pyggdrasil/tree_inference/_interface.py @@ -5,6 +5,7 @@ so we do not introduce circular imports. """ from typing import Union +import dataclasses import jax import numpy as np @@ -50,3 +51,15 @@ # Observational Error rates # tuple of (fpr, fnr) ErrorRates = tuple[float, float] + + +@dataclasses.dataclass +class MoveProbabilities: + """Move probabilities. The default values were taken from + the paragraph **Combining the three MCMC moves** of page 14 + of the SCITE paper supplement. + """ + + prune_and_reattach: float = 0.1 + swap_node_labels: float = 0.65 + swap_subtrees: float = 0.25 diff --git a/src/pyggdrasil/tree_inference/_mcmc.py b/src/pyggdrasil/tree_inference/_mcmc.py index 9da43135..e8a42855 100644 --- a/src/pyggdrasil/tree_inference/_mcmc.py +++ b/src/pyggdrasil/tree_inference/_mcmc.py @@ -11,16 +11,12 @@ import math from jax import random import jax.numpy as jnp -import dataclasses import logging from pyggdrasil.interface import JAXRandomKey -from pyggdrasil.tree_inference import ( - Tree, - get_descendants, -) +from pyggdrasil.tree_inference import Tree, get_descendants, MoveProbabilities logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -322,18 +318,6 @@ def _swap_subtrees_proposal(key: JAXRandomKey, tree: Tree) -> Tuple[Tree, float] return _swap_subtrees_move_diff_lineage(tree, node1, node2), 0.0 -@dataclasses.dataclass -class MoveProbabilities: - """Move probabilities. The default values were taken from - the paragraph **Combining the three MCMC moves** of page 14 - of the SCITE paper supplement. - """ - - prune_and_reattach: float = 0.1 - swap_node_labels: float = 0.65 - swap_subtrees: float = 0.25 - - def _validate_move_probabilities(move_probabilities: MoveProbabilities, /) -> None: """Validates if ``move_probabilities`` are valid. @@ -446,3 +430,38 @@ def _mcmc_kernel( else: logger.info("Move Rejected") return tree, logprobability + + +def _evolve_tree_mcmc( + tree: Tree, + n_moves: int, + rng: JAXRandomKey, + move_probs: MoveProbabilities, +) -> Tree: + """Evolves a tree using the SCITE MCMC moves, + + Args: + tree: Tree + tree to evolve + n_moves: int + number of moves to perform + rng: JAXRandomKey + random number generator + move_probs: MoveProbabilities + move probabilities to use + + Returns: + Tree: evolved tree + """ + + # make random log prob function + def log_prob_fn(t: Tree) -> float: + """Log prob function for testing. - dummy function""" + return jax.random.uniform(rng, shape=()).__float__() + + # Run the kernel + for i in range(n_moves): + rng, rng_now = jax.random.split(rng) + tree, _ = _mcmc_kernel(rng_now, tree, move_probs, log_prob_fn) # type: ignore + + return tree diff --git a/src/pyggdrasil/tree_inference/_mcmc_sampler.py b/src/pyggdrasil/tree_inference/_mcmc_sampler.py index fdb92958..e6996313 100644 --- a/src/pyggdrasil/tree_inference/_mcmc_sampler.py +++ b/src/pyggdrasil/tree_inference/_mcmc_sampler.py @@ -78,6 +78,14 @@ def mcmc_sampler( "Log-probability calculation does not yet support" ) + # assert that the number of mutations and the data matrix size match + # no of nodes must equal the number of rows in the data matrix plus root truncated + if not init_tree.labels.shape[0] == data.shape[0] + 1: + raise AssertionError( + "Number of mutations and data matrix size do not match.\n" + f"tree {init_tree.labels.shape[0]} != data {data.shape[0]}" + ) + # curry logprobability function logprobability_fn = logprob.create_logprob(data, error_rates) diff --git a/src/pyggdrasil/tree_inference/_mcmc_util.py b/src/pyggdrasil/tree_inference/_mcmc_util.py index 958bfd4a..71064084 100644 --- a/src/pyggdrasil/tree_inference/_mcmc_util.py +++ b/src/pyggdrasil/tree_inference/_mcmc_util.py @@ -5,6 +5,7 @@ import jax.scipy as jsp import xarray as xr + from pyggdrasil.tree_inference._tree import Tree import pyggdrasil.tree_inference._tree as tr diff --git a/src/pyggdrasil/tree_inference/_simulate.py b/src/pyggdrasil/tree_inference/_simulate.py index e3008916..4b7dab4d 100644 --- a/src/pyggdrasil/tree_inference/_simulate.py +++ b/src/pyggdrasil/tree_inference/_simulate.py @@ -20,7 +20,7 @@ TreeAdjacencyMatrix, AncestorMatrix, CellAttachmentVector, - tree_from_tree_node, + Tree, ) from pyggdrasil.tree import TreeNode @@ -644,7 +644,7 @@ def gen_sim_data( rng_tree, rng_cell_attachment, rng_noise = random.split(rng, 3) # Take Tree and convert to local format - tree_adj_mat = tree_from_tree_node(tree_tn).tree_topology + tree_adj_mat = Tree.tree_from_tree_node(tree_tn).tree_topology # Attach Cells To Tree # convert adjacency matrix to numpy array diff --git a/src/pyggdrasil/tree_inference/_tree.py b/src/pyggdrasil/tree_inference/_tree.py index ab7dccc1..4606e9dc 100644 --- a/src/pyggdrasil/tree_inference/_tree.py +++ b/src/pyggdrasil/tree_inference/_tree.py @@ -77,6 +77,39 @@ def __str__(self): ) return df.__str__() + @staticmethod + def tree_from_tree_node(tree_node: TreeNode) -> "Tree": + """Converts a tree node to a tree""" + + # Get all nodes in the tree - sort descendants by name and add root node to end + nodes = sorted(list(tree_node.descendants), key=lambda x: x.name) + [tree_node] + + # Create an empty adjacency matrix + n = len(nodes) + adj_matrix = np.zeros((n, n)) + + # Assign indices to nodes + node_indices = {node.name: i for i, node in enumerate(nodes)} + + node_indices = jnp.array(list(node_indices.values())) + + # Populate adjacency matrix + for node in nodes: + i = node_indices[node.name] + for child in node.children: + j = node_indices[child.name] + adj_matrix[i, j] = 1 + + # ensure is jax array + adj_matrix = jnp.array(adj_matrix) + + node_labels = jnp.array([node.name for node in nodes]) + + # make tree object + tree = Tree(tree_topology=adj_matrix, labels=node_labels) + + return tree + def _resort_root_to_end(tree: Tree, root: int) -> Tree: """Resorts tree so that root is at the end of the adjacency matrix. @@ -263,39 +296,6 @@ def _reorder_tree(tree: Tree, from_labels, to_labels): return reordered_tree -def tree_from_tree_node(tree_node: TreeNode) -> Tree: - """Converts a tree node to a tree""" - - # Get all nodes in the tree - sort descendants by name and add root node to end - nodes = sorted(list(tree_node.descendants), key=lambda x: x.name) + [tree_node] - - # Create an empty adjacency matrix - n = len(nodes) - adj_matrix = np.zeros((n, n)) - - # Assign indices to nodes - node_indices = {node.name: i for i, node in enumerate(nodes)} - - node_indices = jnp.array(list(node_indices.values())) - - # Populate adjacency matrix - for node in nodes: - i = node_indices[node.name] - for child in node.children: - j = node_indices[child.name] - adj_matrix[i, j] = 1 - - # ensure is jax array - adj_matrix = jnp.array(adj_matrix) - - node_labels = jnp.array([node.name for node in nodes]) - - # make tree object - tree = Tree(tree_topology=adj_matrix, labels=node_labels) - - return tree - - def is_same_tree(tree1: Tree, tree2: Tree) -> bool: """Check if two trees are the same. diff --git a/src/pyggdrasil/tree_inference/_tree_generator.py b/src/pyggdrasil/tree_inference/_tree_generator.py index af0c97a5..215747de 100644 --- a/src/pyggdrasil/tree_inference/_tree_generator.py +++ b/src/pyggdrasil/tree_inference/_tree_generator.py @@ -10,6 +10,7 @@ import numpy as np + from pyggdrasil import TreeNode from pyggdrasil.interface import JAXRandomKey diff --git a/src/pyggdrasil/tree_inference/_tree_mcmc.py b/src/pyggdrasil/tree_inference/_tree_mcmc.py new file mode 100644 index 00000000..64f4cc51 --- /dev/null +++ b/src/pyggdrasil/tree_inference/_tree_mcmc.py @@ -0,0 +1,40 @@ +"""Implements Tree operations that rely on _mcmc.py""" + +from pyggdrasil import TreeNode + +from pyggdrasil.interface import JAXRandomKey + +from pyggdrasil.tree_inference._interface import MoveProbabilities +from pyggdrasil.tree_inference._config import MoveProbConfigOptions +import pyggdrasil.tree_inference._mcmc as mcmc +from pyggdrasil.tree_inference._tree import Tree + + +def evolve_tree_mcmc( + init_tree: TreeNode, + n_moves: int, + rng: JAXRandomKey, + move_probs: MoveProbabilities = MoveProbConfigOptions.DEFAULT.value, # type: ignore +) -> TreeNode: + """Evolves a tree using the SCITE MCMC moves, assumes default move weights. + + Args: + init_tree: TreeNode + tree to evolve + n_moves: int + number of moves to perform + rng: JAXRandomKey + random number generator + move_probs: MoveProbabilities + move probabilities to use + + Returns: + tree_ev: TreeNode + evolved tree + """ + + tree = Tree.tree_from_tree_node(init_tree) + + tree_ev = mcmc._evolve_tree_mcmc(tree, n_moves, rng, move_probs) + + return tree_ev.to_TreeNode() diff --git a/tests/tree_inference/test_file_id.py b/tests/tree_inference/test_file_id.py index 8865209e..a29d8b69 100644 --- a/tests/tree_inference/test_file_id.py +++ b/tests/tree_inference/test_file_id.py @@ -102,7 +102,7 @@ def test_mcmc_run_id(mcmc_run_id) -> None: def test_tree_id_from_str(tree_id) -> None: """Tests for tree id.""" - test_id = TreeId.from_str(str(tree_id)) + test_id: TreeId = TreeId.from_str(str(tree_id)) # type: ignore assert test_id.tree_type == tree_id.tree_type assert test_id.n_nodes == tree_id.n_nodes @@ -145,3 +145,25 @@ def test_cell_simulation_id_from_str_scientific_notation( == cell_simulation_id_scientific_notation.observe_homozygous ) assert csi.strategy == cell_simulation_id_scientific_notation.strategy + + +def test_huntrees_tree_id_from_str() -> None: + """Tests for tree id.""" + + str = "T_h_6_CS_42-T_r_6_42-200_0.1_0.1_0.0_f_UXR" + + test_id: TreeId = TreeId.from_str(str) # type: ignore + + assert test_id.tree_type == TreeType("h") + assert test_id.n_nodes == 6 + + +def test_mcmc_tree_id_from_str() -> None: + """Tests for tree id.""" + + str = "iT_m_6_5_99_oT_r_6_42" + + test_id: TreeId = TreeId.from_str(str) # type: ignore + + assert test_id.tree_type == TreeType.MCMC + assert test_id.n_nodes == 6 diff --git a/tests/tree_inference/test_mcmc.py b/tests/tree_inference/test_mcmc.py index 1557d6b4..0984528b 100644 --- a/tests/tree_inference/test_mcmc.py +++ b/tests/tree_inference/test_mcmc.py @@ -5,11 +5,11 @@ import jax.numpy as jnp import pyggdrasil.tree_inference._mcmc as mcmc +import pyggdrasil.tree_inference._tree_generator import pyggdrasil.tree_inference._tree_generator as tree_gen import pyggdrasil.tree_inference._tree as tr import pyggdrasil.tree_inference._mcmc_util as mcmc_util -from pyggdrasil.tree_inference import MoveProbConfigOptions - +from pyggdrasil.tree_inference import MoveProbConfigOptions, MoveProbabilities from pyggdrasil.tree_inference._tree import Tree @@ -478,3 +478,23 @@ def log_prob_fn(t: Tree) -> float: rng_now, tree, move_probs, log_prob_fn # type: ignore ) assert tr.is_valid_tree(tree) + + +def test_evolve_tree_mcmc(): + """Static Test evolve_tree_mcmc. - assures that imports are correct + and that the function runs""" + + seed = 42 + rng = random.PRNGKey(seed) + + # generate random tree + tree = pyggdrasil.tree_inference._tree_generator.generate_random_Tree(rng, 10) + + # define move probabilities + move_probs = MoveProbabilities() + + # evolve tree + tree_ev = mcmc._evolve_tree_mcmc(tree, 2, rng, move_probs) + + # check if the tree is still a tree + assert not tr.is_same_tree(tree, tree_ev) diff --git a/tests/tree_inference/test_tree.py b/tests/tree_inference/test_tree.py index be6fdd5c..4c78f67c 100644 --- a/tests/tree_inference/test_tree.py +++ b/tests/tree_inference/test_tree.py @@ -129,7 +129,7 @@ def test_tree_node_to_tree(): json_obj, deserialize_data=lambda x: x ) # convert to tree - tree = tr.tree_from_tree_node(tree_node) + tree = tr.Tree.tree_from_tree_node(tree_node) # check that tree is correct assert jnp.all( diff --git a/workflows/README.md b/workflows/README.md index 2765edad..0063a565 100644 --- a/workflows/README.md +++ b/workflows/README.md @@ -41,6 +41,7 @@ rm -rf ~/miniconda3/miniconda.sh ~/miniconda3/bin/conda init bash ~/miniconda3/bin/conda init zsh ``` +Note that the above installs miniconda in your home directory. Further, enhance by adding mamba to the conda environment, for faster resolving of dependencies: ```commandline @@ -49,6 +50,7 @@ conda install mamba -n base -c conda-forge Then, create a new environment for the project: ```commandline +cd PYggdrasil/ mamba env create -f environment.yml ``` @@ -57,6 +59,7 @@ mamba env create -f environment.yml Then add in all project specific dependencies via: ```commandline cd PYggdrasil/ +conda activate PYggdrasil pip install -e . ``` This should install all the dependencies, and make the package available in the environment `PYggdrasil` that is currently active by running the prior command. diff --git a/workflows/Snakefile b/workflows/Snakefile index dac825ab..37df5312 100644 --- a/workflows/Snakefile +++ b/workflows/Snakefile @@ -10,6 +10,4 @@ include: "visualize.smk" include: "mark00.smk" include: "mark01.smk" include: "mark02.smk" - - -################################################################################ +include: "mark03.smk" diff --git a/workflows/analyze.smk b/workflows/analyze.smk index c46fdd3f..22ec7ca6 100644 --- a/workflows/analyze.smk +++ b/workflows/analyze.smk @@ -12,8 +12,10 @@ rule analyze_metric: an MCMC sample as input i.e. all distances /similarity metrics. Note: this includes ``is_true_tree``.""" input: - mcmc_samples="{DATADIR}/{experiment}/mcmc/MCMC_{mcmc_seed,\d+}-{mutation_data_id}-i{init_tree_id}-{mcmc_config_id}.json", - base_tree="{DATADIR}/{experiment}/trees/{base_tree_id}.json", + mcmc_samples = '{DATADIR}/{experiment}/mcmc/MCMC_{mcmc_seed,\d+}-{mutation_data_id}-i{init_tree_id}-{mcmc_config_id}.json', + base_tree = '{DATADIR}/{experiment}/trees/{base_tree_id}.json' + wildcard_constraints: + mcmc_config_id = "MC.*", output: result="{DATADIR}/{experiment}/analysis/MCMC_{mcmc_seed,\d+}-{mutation_data_id}-i{init_tree_id}-{mcmc_config_id}/{base_tree_id}/{metric}.json", log="{DATADIR}/{experiment}/analysis/MCMC_{mcmc_seed,\d+}-{mutation_data_id}-i{init_tree_id}-{mcmc_config_id}/{base_tree_id}/{metric}.log", diff --git a/workflows/mark01.smk b/workflows/mark01.smk index 8d36b10e..b592f48a 100644 --- a/workflows/mark01.smk +++ b/workflows/mark01.smk @@ -129,38 +129,3 @@ rule calculate_huntress_distances: # save the distances and the huntress tree ids yg.serialize.save_metric_result(axis=huntress_tree_ids, result=distances, out_fp=Path(output.distances), axis_name="huntress_tree_id") - -# below rule input will trigger gen_cell_simulation rule, which will trigger tree generation rule -rule run_huntress: - """Run HUNTRESS on the true tree. - - - Cell Simulation data requires - - no missing entries - - no homozygous mutations - """ - # TODO (Gordon): this rule should be general enough to be used for all experiments, i.e. it should be moved to the common workflow - input: - mutation_data = "{DATADIR}/{experiment}/mutations/{mutation_data_id}.json", - output: - huntrees_tree = "{DATADIR}/{experiment}/huntress/HUN-{mutation_data_id}.json" - run: - import pyggdrasil as yg - # load data of mutation matrix - with open(input.mutation_data,"r") as f: - cell_simulation_data = json.load(f) - # TODO (Gordon): modify to allow non-simulated data - cell_simulation_data = yg.tree_inference.get_simulation_data(cell_simulation_data) - # get the mutation matrix - mut_mat = cell_simulation_data["noisy_mutation_mat"] - # get error rates from the cell simulation id - # get name of file without extension - data_fn = Path(input.mutation_data).stem - # try to match the cell simulation id - cell_sim_id = yg.tree_inference.CellSimulationId.from_str(data_fn) - # run huntress - huntress_tree = yg.tree_inference.huntress_tree_inference(mut_mat, cell_sim_id.fpr, cell_sim_id.fnr) - # make TreeNode from Node - huntress_treeNode = yg.TreeNode.convert_anytree_to_treenode(huntress_tree) - # save the huntress tree - yg.serialize.save_tree_node(huntress_treeNode, Path(output.huntrees_tree)) - diff --git a/workflows/mark02.smk b/workflows/mark02.smk index d168fe02..eac79c2d 100644 --- a/workflows/mark02.smk +++ b/workflows/mark02.smk @@ -150,13 +150,13 @@ rule combined_chain_histogram: """ input: # calls analyze_metric rule - all_chain_metrics = ['{DATADIR}/{experiment}/analysis/MCMC_' + str(mcmc_seed) + '-{mutation_data_id}-iT_' + all_chain_metrics = ['{DATADIR}/mark02/analysis/MCMC_' + str(mcmc_seed) + '-{mutation_data_id}-iT_' + str(init_tree_type)+ '_{n_nodes,\d+}_' + str(init_tree_seed) + '-{mcmc_config_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}.json' for mcmc_seed, init_tree_type, init_tree_seed in initial_points] output: - combined_chain_histogram = '{DATADIR}/{experiment}/plots/{mcmc_config_id}/{mutation_data_id}/' + combined_chain_histogram = '{DATADIR}/mark02/plots/{mcmc_config_id}/{mutation_data_id}/' 'T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}.svg', run: diff --git a/workflows/mark03.smk b/workflows/mark03.smk new file mode 100644 index 00000000..584ba9a6 --- /dev/null +++ b/workflows/mark03.smk @@ -0,0 +1,454 @@ +"""Experiment mark03 + + Investigate convergence of SCITE MCMC chains, + given different initial points with tree + distances. + """ + +# imports +import matplotlib.pyplot as plt +import jax +from pathlib import Path + +import pyggdrasil as yg + +from pyggdrasil.tree_inference import CellSimulationId, TreeType, TreeId, McmcConfig + +##################### +# Environment variables +DATADIR = "../data" +# DATADIR = "/cluster/work/bewi/members/gkoehn/data" + +##################### +experiment = "mark03" + +# Metrics: Distances / Similarity Measure to use +metrics = ["MP3", "AD", "log_prob"] # <-- configure distances here + +##################### +# Error Parameters +# used for both cell simulation and MCMC inference + +# Errors <--- set the error rates here +errors = { # get the pre-defined error rate combinations + member.name: member.value.dict() for member in yg.tree_inference.ErrorCombinations +} + +rate_na = 0.0 # <-- configure NA rate here + +##################### +##################### +# Cell Simulation Parameters + +n_mutations = [5, 10, 30, 50] # <-- configure number of mutations here +n_cells = [200, 1000, 5000] # <-- configure number of cells here + +# Homozygous mutations +observe_homozygous = False # <-- configure whether to observe homozygous mutations here + +# cell attachment strategy +cell_attachment_strategy = ( + yg.tree_inference.CellAttachmentStrategy.UNIFORM_EXCLUDE_ROOT +) # <-- configure cell attachment strategy here + +# cell simulation seed +CS_seed = 42 # <-- configure cell simulation seed here + +##################### +# True Tree Parameters +tree_types = ["r"] # <-- configure tree type here ["r","s","d"] +tree_seeds = [42] # <-- configure tree seed here + +##################### +##################### +# MCMC Parameters + +# define 4 initial points, different chains +# given each error rate, true tree, no of cells and mutations +# make random trees and mcmc seeds +desired_counts = { + "d": 10, # Deep Trees + "r": 10, # Random Trees + "s": 1, # Star Tree + "h": 1, # Huntress Tree, derived from the cell simulation + "m": 10, # MCMC Move Trees +} + +# number of mcmc moves applied on random initial trees +n_mcmc_tree_moves = 5 + +# MCMC config +n_samples = 2000 # <-- configure number of samples here + +##################### +##################### + + +def make_initial_points_mark03(desired_counts: dict): + """Make initial mcmc points for mark03 experiment. + + Args: + desired_counts: dict + A dictionary of the form + { + 'd': 10, # Deep Trees + 'r': 10, # Random Trees + 's': 1, # Star Tree + 'h': 10, # Huntress Trees + 'mcmc': 5 # MCMC Move Trees + } + indicating the number of initial points to generate for each type of tree. + + Returns: + list of tuples (mcmc_seed, init_tree_type, init_tree_seed) + """ + + key = jax.random.PRNGKey(0) # Set the initial PRNG key + new_trees = [] + seed_pool = set() + for init_tree_type, count in desired_counts.items(): + for _ in range(count): + key, subkey = jax.random.split(key) # Split the PRNG key + mcmc_seed = jax.random.randint( + subkey, (), 1, 100 + ) # Generate a random MCMC seed + key, subkey = jax.random.split(key) # Split the PRNG key + init_tree_seed = jax.random.randint( + subkey, (), 1, 100 + ) # Generate a random seed for init_tree + while init_tree_seed.item() in seed_pool: # Ensure the seed is unique + key, subkey = jax.random.split(key) # Split the PRNG key + init_tree_seed = jax.random.randint(subkey, (), 1, 100) + new_trees.append((mcmc_seed.item(), init_tree_type, init_tree_seed.item())) + seed_pool.add(init_tree_seed.item()) + return new_trees + + +# Generate the initial points of mcmc chains +initial_points = make_initial_points_mark03(desired_counts) + + +def make_all_mark03(): + """Make all final output file names.""" + + # f"{DATADIR}/{experiment}/plots/{McmcConfig}/{CellSimulationId}/" + + # "AD.svg" and "MP3.svg" or log_prob.svg + + filepaths = [] + filepath = f"{DATADIR}/{experiment}/plots/" + # add +1 to n_mutation to account for the root mutation + n_nodes = [n_mutation + 1 for n_mutation in n_mutations] + + # make true tree ids for cell simulation - true trees + tree_id_ls = [] + for tree_type in tree_types: + for tree_seed in tree_seeds: + for n_node in n_nodes: + # if star tree, ignore tree_seed + if tree_type == "s": + tree_id_ls.append( + TreeId(tree_type=TreeType(tree_type), n_nodes=n_node) + ) + else: + tree_id_ls.append( + TreeId( + tree_type=TreeType(tree_type), + n_nodes=n_node, + seed=tree_seed, + ) + ) + + # make cell simulation ids + for true_tree_id in tree_id_ls: + for n_cell in n_cells: + for error_name, error in errors.items(): + # make cell simulation id + cs = CellSimulationId( + seed=CS_seed, + tree_id=true_tree_id, + n_cells=n_cell, + fpr=error["fpr"], + fnr=error["fnr"], + na_rate=rate_na, + observe_homozygous=observe_homozygous, + strategy=cell_attachment_strategy, + ) + # make mcmc config id + mc = McmcConfig( + n_samples=n_samples, fpr=error["fpr"], fnr=error["fnr"] + ).id() + # make filepaths for each metric + for each_metric in metrics: + filepaths.append( + filepath + + mc + + "/" + + str(cs) + + "/" + + str(true_tree_id) + + "/" + + each_metric + + "_iter.svg" + ) + return filepaths + + +rule mark03: + """Main mark03 rule.""" + input: + make_all_mark03(), + + +def make_combined_metric_iteration_in(): + """Make input for combined_metric_iteration rule.""" + input = [] + tree_type = [] + + for mcmc_seed, init_tree_type, init_tree_seed in initial_points: + # make variables strings dependent on tree type + # catch the case where init_tree_type is star tree + if init_tree_type == "s": + input.append( + "{DATADIR}/mark03/analysis/MCMC_" + + str(mcmc_seed) + + "-{mutation_data_id}-iT_" + + str(init_tree_type) + + "_{n_nodes,\d+}" + + "-{mcmc_config_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}.json" + ) + # catch the case where init_tree_type is huntress tree + elif init_tree_type == "h": + input.append( + "{DATADIR}/mark03/analysis/MCMC_" + + str(mcmc_seed) + + "-{mutation_data_id}-" + + "iT_h_" + + "{n_nodes,\d+}" + + "_{mutation_data_id}" + + "-{mcmc_config_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}.json" + ) + # if mcmc tree + elif init_tree_type == "m": + # split the mcmc seed int into 2 parts: tree_seed, mcmc_seed + tree_seed, mcmc_move_seed = init_tree_seed // 100, init_tree_seed % 100 + input.append( + "{DATADIR}/mark03/analysis/MCMC_" + + str(mcmc_seed) + + "-{mutation_data_id}-" + + "iT_m_{n_nodes}_" + + str(n_mcmc_tree_moves) + + "_" + + str(mcmc_move_seed) + + "_oT_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}" + + "-{mcmc_config_id}" + + "/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}.json" + ) + # all other cases + else: + input.append( + "{DATADIR}/mark03/analysis/MCMC_" + + str(mcmc_seed) + + "-{mutation_data_id}-iT_" + + str(init_tree_type) + + "_{n_nodes,\d+}_" + + str(init_tree_seed) + + "-{mcmc_config_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}.json" + ) + tree_type.append(init_tree_type) + + return input, tree_type + + +rule combined_metric_iteration_plot: + """Make combined metric iteration plot. + + For each metric, make a plot with all the chains, where + each initial tree type is a different color. + """ + input: + # calls analyze_metric rule + all_chain_metrics=make_combined_metric_iteration_in()[0], + wildcard_constraints: + # metric wildcard cannot be log_prob + metric=r"(?!(log_prob))\w+", + output: + combined_metric_iter="{DATADIR}/{experiment}/plots/{mcmc_config_id}/{mutation_data_id}/" + "T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}_iter.svg", + run: + # load the data + distances_chains = [] + # get the initial tree type, same order as the input + initial_tree_type = make_combined_metric_iteration_in()[1] + # for each chain + for each_chain_metric in input.all_chain_metrics: + # load the distances + _, distances = yg.serialize.read_metric_result(Path(each_chain_metric)) + # append to the list + distances_chains.append(distances) + + # Create a figure and axis + fig, ax = plt.subplots() + + # Define the list of colors to repeat + colors = {"h": "red", "s": "green", "d": "blue", "r": "orange", "m": "purple"} + labels = { + "h": "Huntress", + "s": "Star", + "d": "Deep", + "r": "Random", + "m": "MCMC5", + } + + # Define opacity and line style + alpha = 0.6 + line_style = "solid" + + # Plot each entry of distance chain as a line with a color unique to the + # initial tree type onto one axis + + # Plot each entry of distance chain as a line with a color unique to the + # initial tree type onto one axis + for i, distances in enumerate(distances_chains): + color = colors[initial_tree_type[i]] + ax.plot( + distances, + color=color, + label=f"{labels[initial_tree_type[i]]}", + alpha=alpha, + linestyle=line_style, + ) + + # Set labels and title + ax.set_ylabel(f"Distance/Similarity: {wildcards.metric}") + ax.set_xlabel("Iteration") + + # Add a legend of fixed legend position and size + ax.legend(loc="upper right") + + # save the histogram + fig.savefig(Path(output.combined_metric_iter)) + + +def make_combined_log_prob_iteration_in(): + """Make input for combined_metric_iteration rule.""" + input = [] + + for mcmc_seed, init_tree_type, init_tree_seed in initial_points: + # make variables strings dependent on tree type + # catch the case where init_tree_type is star tree + if init_tree_type == "s": + input.append( + "{DATADIR}/mark03/analysis/MCMC_" + + str(mcmc_seed) + + "-{mutation_data_id}-iT_" + + str(init_tree_type) + + "_{n_nodes,\d+}" + + "-{mcmc_config_id}/log_prob.json" + ) + # catch the case where init_tree_type is huntress tree + elif init_tree_type == "h": + input.append( + "{DATADIR}/mark03/analysis/MCMC_" + + str(mcmc_seed) + + "-{mutation_data_id}-" + + "iT_h_" + + "{n_nodes,\d+}" + + "_{mutation_data_id}" + + "-{mcmc_config_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/log_prob.json" + ) + # if mcmc tree + elif init_tree_type == "m": + # split the mcmc seed int into 2 parts: tree_seed, mcmc_seed + tree_seed, mcmc_move_seed = init_tree_seed // 100, init_tree_seed % 100 + input.append( + "{DATADIR}/mark03/analysis/MCMC_" + + str(mcmc_seed) + + "-{mutation_data_id}-" + + "iT_m_{n_nodes}_" + + str(n_mcmc_tree_moves) + + "_" + + str(mcmc_move_seed) + + "_oT_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}" + + "-{mcmc_config_id}" + + "/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/log_prob.json" + ) + + # all other cases + else: + input.append( + "{DATADIR}/mark03/analysis/MCMC_" + + str(mcmc_seed) + + "-{mutation_data_id}-iT_" + + str(init_tree_type) + + "_{n_nodes,\d+}_" + + str(init_tree_seed) + + "-{mcmc_config_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/log_prob.json" + ) + return input + + +rule combined_logProb_iteration_plot: + """Make combined logProb iteration plot.""" + input: + # calls analyze_metric rule + all_chain_logProb=make_combined_log_prob_iteration_in(), + output: + combined_logP_iter="{DATADIR}/{experiment}/plots/{mcmc_config_id}/{mutation_data_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/log_prob_iter.svg", + run: + # load the data + logP_chains = [] + # get the initial tree type, same order as the input + initial_tree_type = make_combined_metric_iteration_in()[1] + # for each chain + for each_chain_metric in input.all_chain_logProb: + # load the distances + _, logP = yg.serialize.read_metric_result(Path(each_chain_metric)) + # append to the list + logP_chains.append(logP) + + # Create a figure and axis + fig, ax = plt.subplots() + + # Define the list of colors to repeat + colors = { + "h": "red", + "s": "green", + "d": "blue", + "r": "orange", + "mcmc": "purple", + } + + labels = { + "h": "Huntress", + "s": "Star", + "d": "Deep", + "r": "Random", + "mcmc": "MCMC5", + } + + # Define opacity and line style + alpha = 0.6 + line_style = "solid" + + # Plot each entry of distance chain as a line with a color unique to the + # initial tree type onto one axis + for i, logP in enumerate(logP_chains): + color = colors[initial_tree_type[i]] + ax.plot( + logP, + color=color, + label=f"{labels[initial_tree_type[i]]}", + alpha=alpha, + linestyle=line_style, + ) + + # Set labels and title + ax.set_ylabel(f"Log Probability:" + r"$\log(P(D|T,\theta))$") + ax.set_xlabel("Iteration") + + # Add a legend of fixed legend position + ax.legend(loc="upper right") + + # save the histogram + fig.savefig(Path(output.combined_logP_iter)) diff --git a/workflows/tree_inference.smk b/workflows/tree_inference.smk index a18ff323..7ad3be24 100644 --- a/workflows/tree_inference.smk +++ b/workflows/tree_inference.smk @@ -1,6 +1,13 @@ """Snakemake rules for the tree inference pipeline.""" import json +import shutil +import jax + +from pathlib import Path + +import pyggdrasil as yg + from pyggdrasil.tree_inference import ( McmcConfig, @@ -9,10 +16,12 @@ from pyggdrasil.tree_inference import ( McmcConfigOptions, ) + ############################################### ## Relative path from DATADIR to the repo root -REPODIR = "/cluster/work/bewi/members/gkoehn/repos/PYggdrasil" +#REPODIR = "/cluster/work/bewi/members/gkoehn/repos/PYggdrasil" +REPODIR = ".." ############################################### @@ -141,8 +150,8 @@ rule mcmc: init_tree="{DATADIR}/{experiment}/trees/{init_tree_id}.json", mcmc_config="{DATADIR}/{experiment}/mcmc/config/{mcmc_config_id}.json", wildcard_constraints: - mcmc_config_id="MC.*", - init_tree_id="T.*", + mcmc_config_id = "MC.*", + init_tree_id = "(HUN|T).*" # allowing both generated and huntress trees output: mcmc_log="{DATADIR}/{experiment}/mcmc/MCMC_{mcmc_seed,\d+}-{mutation_data_id}-i{init_tree_id}-{mcmc_config_id}.log", mcmc_samples="{DATADIR}/{experiment}/mcmc/MCMC_{mcmc_seed,\d+}-{mutation_data_id}-i{init_tree_id}-{mcmc_config_id}.json", @@ -155,3 +164,110 @@ rule mcmc: --data_fp {input.mutation_data} \ --init_tree_fp {input.init_tree} """ + + +# below rule input will trigger gen_cell_simulation rule, which will trigger tree generation rule +rule run_huntress: + """Run HUNTRESS on the true tree. + + Output is saved in huntress directory, intentionally not in the tree directory. + HUNTRESS output may vary in the number of mutations from the mutation matrix. + + - Cell Simulation data requires + - no missing entries + - no homozygous mutations + """ + input: + mutation_data="{DATADIR}/{experiment}/mutations/{mutation_data_id}.json", + output: + huntrees_tree="{DATADIR}/{experiment}/huntress/HUN-{mutation_data_id}.json" + threads: 4 # as many threads as defined in make_huntress.py + run: + # load data of mutation matrix + with open(input.mutation_data,"r") as f: + cell_simulation_data = json.load(f) + # TODO (Gordon): modify to allow non-simulated data + cell_simulation_data = yg.tree_inference.get_simulation_data(cell_simulation_data) + # get the mutation matrix + mut_mat = cell_simulation_data["noisy_mutation_mat"] + # get error rates from the cell simulation id + # get name of file without extension + data_fn = Path(input.mutation_data).stem + # try to match the cell simulation id + cell_sim_id = yg.tree_inference.CellSimulationId.from_str(data_fn) + # run huntress + huntress_tree = yg.tree_inference.huntress_tree_inference(mut_mat,cell_sim_id.fpr,cell_sim_id.fnr) + # make TreeNode from Node + huntress_treeNode = yg.TreeNode.convert_anytree_to_treenode(huntress_tree) + # save the huntress tree + yg.serialize.save_tree_node(huntress_treeNode,Path(output.huntrees_tree)) + + +rule copy_simulated_huntress_r_d_tree: + """Copy the simulated huntress tree to the tree directory, + with information about the number of nodes from the true tree. + + Validates that the number of nodes in the tree matches the number of nodes in the true tree. + """ + input: + huntrees_tree="{DATADIR}/{experiment}/huntress/HUN-CS_{CS_seed}-T_{tree_type}_{n_nodes}_{tree_seed}-{n_cells}_{CS_fpr}_{CS_fnr}_{CS_na}_{observe_homozygous}_{cell_attachment_strategy}.json" + output: + huntrees_tree="{DATADIR}/{experiment}/trees/T_h_{n_nodes}_CS_{CS_seed}-T_{tree_type}_{n_nodes}_{tree_seed}-{n_cells}_{CS_fpr}_{CS_fnr}_{CS_na}_{observe_homozygous}_{cell_attachment_strategy}.json" + run: + # validate the number of nodes in the tree + init_tree_node = yg.serialize.read_tree_node(Path(input.huntrees_tree)) + # convert TreeNode to Tree + init_tree = yg.tree_inference.Tree.tree_from_tree_node(init_tree_node) + + # assert that the number of mutations and the data matrix size match + # no of nodes must equal the number of rows in the data matrix plus root truncated + if not int(init_tree.labels.shape[0]) == int(wildcards.n_nodes): + raise ValueError(f"Number of nodes in the tree {init_tree.labels.shape[0]} does not match the number of nodes in the filename {wildcards.n_nodes}, Huntress may have not included all mutations.") + + # copy and rename the file from the huntress tree directory to the tree directory + shutil.copy(input.huntrees_tree,output.huntrees_tree) + + +rule copy_simulated_huntress_s_tree: + """Copy the simulated huntress tree to the tree directory, + with information about the number of nodes from the true tree. + + Validates that the number of nodes in the tree matches the number of nodes in the true tree. + """ + input: + huntrees_tree="{DATADIR}/{experiment}/huntress/HUN-CS_{CS_seed}-T_s_{n_nodes}-{n_cells}_{CS_fpr}_{CS_fnr}_{CS_na}_{observe_homozygous}_{cell_attachment_strategy}.json" + output: + huntrees_tree="{DATADIR}/{experiment}/trees/T_h_{n_nodes}_CS_{CS_seed}-T_s_{n_nodes}-{n_cells}_{CS_fpr}_{CS_fnr}_{CS_na}_{observe_homozygous}_{cell_attachment_strategy}.json" + run: + # validate the number of nodes in the tree + init_tree_node = yg.serialize.read_tree_node(Path(input.huntrees_tree)) + # convert TreeNode to Tree + init_tree = yg.tree_inference.Tree.tree_from_tree_node(init_tree_node) + + # assert that the number of mutations and the data matrix size match + # no of nodes must equal the number of rows in the data matrix plus root truncated + if not int(init_tree.labels.shape[0]) == int(wildcards.n_nodes): + raise ValueError(f"Number of nodes in the tree {init_tree.labels.shape[0]} does not match the number of nodes in the filename {wildcards.n_nodes}, Huntress may have not included all mutations.") + + # copy and rename the file from the huntress tree directory to the tree directory + shutil.copy(input.huntrees_tree,output.huntrees_tree) + + +rule mcmc_evolve_tree: + """Evolves a tree using mcmc moves from SCITE.""" + + input: + init_tree="{DATADIR}/{experiment}/trees/{init_tree_id}.json", + output: + evolved_tree="{DATADIR}/{experiment}/trees/T_m_{n_nodes}_{n_moves}_{mcmc_move_seed}_o{init_tree_id}.json" + run: + # load the initial tree + init_tree_node = yg.serialize.read_tree_node(Path(input.init_tree)) + # get mcmc parameters from the filename + n_moves = int(wildcards.n_moves) + mcmc_seed = int(wildcards.mcmc_move_seed) + rng = jax.random.PRNGKey(mcmc_seed) + # evolve the tree + evolved_tree_node = yg.tree_inference.evolve_tree_mcmc(init_tree_node,n_moves,rng) + # save the evolved tree + yg.serialize.save_tree_node(evolved_tree_node,Path(output.evolved_tree))