Skip to content

Commit

Permalink
DOC: add typing to tests and reformat fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
rmcar17 committed Aug 29, 2024
1 parent 7cd5833 commit e1176ef
Show file tree
Hide file tree
Showing 13 changed files with 93 additions and 81 deletions.
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
"tests/**/*.py" = [
"S101", # asserts allowed in tests...
"INP001", # __init__.py files are not required...
"ANN",
"N802",
"N803",
"SLF001",
"D"
"N802", # allow non snake_case function names for fixtures
"N803", # allow use of fixture constants
"SLF001", # private member access is useufl for testing
"FBT001", # allow bool pos args for parameterisation
"D", # don't require docstrings
]
"noxfile.py" = [
"S101", # asserts allowed in tests...
Expand Down
17 changes: 16 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
import pathlib

import pytest
from cogent3 import ArrayAlignment, load_aligned_seqs


@pytest.fixture(scope="session")
def DATA_DIR():
def DATA_DIR() -> pathlib.Path:
return pathlib.Path(__file__).parent / "data"


@pytest.fixture()
def three_otu(DATA_DIR: pathlib.Path) -> ArrayAlignment:
aln = load_aligned_seqs(DATA_DIR / "example.fasta", moltype="dna")
aln = aln.take_seqs(["Human", "Rhesus", "Mouse"])
return aln.omit_gap_pos(allowed_gap_frac=0)


@pytest.fixture()
def four_otu(DATA_DIR: pathlib.Path) -> ArrayAlignment:
aln = load_aligned_seqs(DATA_DIR / "example.fasta", moltype="dna")
aln = aln.take_seqs(["Human", "Chimpanzee", "Rhesus", "Mouse"])
return aln.omit_gap_pos(allowed_gap_frac=0)
24 changes: 4 additions & 20 deletions tests/test_app/test_app.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,17 @@
from pathlib import Path

import piqtree2
import pytest
from cogent3 import ArrayAlignment, get_app, load_aligned_seqs, make_tree
from cogent3 import ArrayAlignment, get_app, make_tree
from piqtree2.model import DnaModel, Model


@pytest.fixture()
def three_otu(DATA_DIR: Path) -> ArrayAlignment:
aln = load_aligned_seqs(DATA_DIR / "example.fasta", moltype="dna")
aln = aln.take_seqs(["Human", "Rhesus", "Mouse"])
return aln.omit_gap_pos(allowed_gap_frac=0)


@pytest.fixture()
def four_otu(DATA_DIR: Path) -> ArrayAlignment:
aln = load_aligned_seqs(DATA_DIR / "example.fasta", moltype="dna")
aln = aln.take_seqs(["Human", "Chimpanzee", "Rhesus", "Mouse"])
return aln.omit_gap_pos(allowed_gap_frac=0)


def test_piqtree_phylo(four_otu: ArrayAlignment):
def test_piqtree_phylo(four_otu: ArrayAlignment) -> None:
expected = make_tree("(Human,Chimpanzee,(Rhesus,Mouse));")
app = get_app("piqtree_phylo", model=Model(DnaModel.JC))
got = app(four_otu)
assert expected.same_topology(got)


def test_piqtree_fit(three_otu: ArrayAlignment):
def test_piqtree_fit(three_otu: ArrayAlignment) -> None:
tree = make_tree(tip_names=three_otu.names)
app = get_app("model", "JC69", tree=tree)
expected = app(three_otu)
Expand All @@ -43,7 +27,7 @@ def test_piqtree_random_trees(
num_taxa: int,
tree_mode: piqtree2.TreeGenMode,
num_trees: int,
):
) -> None:
trees = piqtree2.random_trees(
num_taxa,
tree_mode,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_app/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
from piqtree2._app import _ALL_APP_NAMES


def test_pickle():
def test_pickle() -> None:
for app_name in _ALL_APP_NAMES:
assert len(pickle.dumps(get_app(app_name))) > 0
41 changes: 21 additions & 20 deletions tests/test_iqtree/test_build_tree.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,26 @@
from typing import Optional

import piqtree2
import pytest
from cogent3 import load_aligned_seqs, make_tree
from cogent3 import ArrayAlignment, make_tree
from piqtree2.model import (
DiscreteGammaModel,
DnaModel,
FreeRateModel,
FreqType,
Model,
RateModel,
RateType,
)


@pytest.fixture()
def four_otu(DATA_DIR):
aln = load_aligned_seqs(DATA_DIR / "example.fasta", moltype="dna")
aln = aln.take_seqs(["Human", "Chimpanzee", "Rhesus", "Mouse"])
return aln.omit_gap_pos(allowed_gap_frac=0)


def check_build_tree(
four_otu,
dna_model,
freq_type=None,
invariable_sites=None,
rate_model=None,
):
four_otu: ArrayAlignment,
dna_model: DnaModel,
freq_type: Optional[FreqType] = None,
invariable_sites: Optional[bool] = None,
rate_model: Optional[RateModel] = None,
) -> None:
expected = make_tree("(Human,Chimpanzee,(Rhesus,Mouse));")

model = Model(
Expand All @@ -46,15 +42,15 @@ def check_build_tree(
@pytest.mark.parametrize("dna_model", list(DnaModel)[:22])
@pytest.mark.parametrize("freq_type", list(FreqType))
def test_non_lie_build_tree(
four_otu,
dna_model,
freq_type,
):
four_otu: ArrayAlignment,
dna_model: DnaModel,
freq_type: FreqType,
) -> None:
check_build_tree(four_otu, dna_model, freq_type)


@pytest.mark.parametrize("dna_model", list(DnaModel)[22:])
def test_lie_build_tree(four_otu, dna_model):
def test_lie_build_tree(four_otu: ArrayAlignment, dna_model: DnaModel) -> None:
check_build_tree(four_otu, dna_model)


Expand All @@ -70,7 +66,12 @@ def test_lie_build_tree(four_otu, dna_model):
FreeRateModel(6),
],
)
def test_rate_model_build_tree(four_otu, dna_model, invariable_sites, rate_model):
def test_rate_model_build_tree(
four_otu: ArrayAlignment,
dna_model: DnaModel,
invariable_sites: Optional[bool],
rate_model: RateModel,
) -> None:
check_build_tree(
four_otu,
dna_model,
Expand Down
11 changes: 2 additions & 9 deletions tests/test_iqtree/test_fit_tree.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
import piqtree2
import pytest
from cogent3 import get_app, load_aligned_seqs, make_tree
from cogent3 import ArrayAlignment, get_app, make_tree
from piqtree2.model import DnaModel, Model


@pytest.fixture()
def three_otu(DATA_DIR):
aln = load_aligned_seqs(DATA_DIR / "example.fasta", moltype="dna")
aln = aln.take_seqs(["Human", "Rhesus", "Mouse"])
return aln.omit_gap_pos(allowed_gap_frac=0)


@pytest.mark.parametrize(
("iq_model", "c3_model"),
[
Expand All @@ -22,7 +15,7 @@ def three_otu(DATA_DIR):
(DnaModel.F81, "F81"),
],
)
def test_fit_tree(three_otu, iq_model, c3_model):
def test_fit_tree(three_otu: ArrayAlignment, iq_model: DnaModel, c3_model: str) -> None:
tree_topology = make_tree(tip_names=three_otu.names)
app = get_app("model", c3_model, tree=tree_topology)
expected = app(three_otu)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_iqtree/test_robinson_foulds.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from numpy.testing import assert_array_equal


def test_robinson_foulds():
def test_robinson_foulds() -> None:
tree1 = "(A,B,(C,D));"
tree2 = "(A,C,(B,D));"
pairwise_distances = piqtree2.robinson_foulds(tree1, tree2)
Expand Down
14 changes: 7 additions & 7 deletions tests/test_iqtree/test_segmentation_fault.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from cogent3 import make_aligned_seqs, make_tree
from piqtree2 import TreeGenMode, build_tree, fit_tree, random_trees
from piqtree2.exceptions import IqTreeError
from piqtree2.model import DiscreteGammaModel, DnaModel, FreeRateModel, Model
from piqtree2.model import DiscreteGammaModel, DnaModel, FreeRateModel, Model, RateModel


def test_two_build_random_trees():
def test_two_build_random_trees() -> None:
"""
Calling build tree twice followed by random trees with a bad input
used to result in a Segmentation Fault in a previous version.
Expand All @@ -21,7 +21,7 @@ def test_two_build_random_trees():
random_trees(2, TreeGenMode.BALANCED, 3, 1)


def test_two_fit_random_trees():
def test_two_fit_random_trees() -> None:
"""
Calling fit tree twice followed by random trees with a bad input
used to result in a Segmentation Fault in a previous version.
Expand All @@ -36,17 +36,17 @@ def test_two_fit_random_trees():
random_trees(2, TreeGenMode.BALANCED, 3, 1)


@pytest.mark.parametrize("rate_type_class", [DiscreteGammaModel, FreeRateModel])
@pytest.mark.parametrize("rate_model_class", [DiscreteGammaModel, FreeRateModel])
@pytest.mark.parametrize("categories", [0, -4])
def test_two_invalid_models(rate_type_class, categories):
def test_two_invalid_models(rate_model_class: type[RateModel], categories: int) -> None:
"""
Calling build_tree multiple times with an invalid
model has resulted in a Segmentation Fault.
"""
aln = make_aligned_seqs({"a": "GGG", "b": "GGC", "c": "AAC", "d": "AAA"})

with pytest.raises(IqTreeError):
_ = build_tree(aln, Model(DnaModel.JC, rate_type=rate_type_class(categories)))
_ = build_tree(aln, Model(DnaModel.JC, rate_type=rate_model_class(categories)))

with pytest.raises(IqTreeError):
_ = build_tree(aln, Model(DnaModel.JC, rate_type=rate_type_class(categories)))
_ = build_tree(aln, Model(DnaModel.JC, rate_type=rate_model_class(categories)))
6 changes: 4 additions & 2 deletions tests/test_iqtree/test_tree_yaml.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any

import pytest
from piqtree2.exceptions import ParseIqTreeError
from piqtree2.iqtree._tree import _process_tree_yaml


@pytest.fixture()
def newick_not_in_candidates():
def newick_not_in_candidates() -> dict[str, Any]:
# The newick string does not appear in the CandidateSet
return {
"CandidateSet": {
Expand Down Expand Up @@ -36,6 +38,6 @@ def newick_not_in_candidates():
}


def test_newick_not_in_candidates(newick_not_in_candidates):
def test_newick_not_in_candidates(newick_not_in_candidates: dict[str, Any]) -> None:
with pytest.raises(ParseIqTreeError):
_ = _process_tree_yaml(newick_not_in_candidates, ["a", "b", "c"])
4 changes: 2 additions & 2 deletions tests/test_model/test_freq_type.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from piqtree2.model import FreqType


def test_number_of_descriptions():
def test_number_of_descriptions() -> None:
assert len(FreqType) == len(FreqType._descriptions())


def test_descriptions_exist():
def test_descriptions_exist() -> None:
for freq_type in FreqType:
# Raises an error if description not present
_ = freq_type.description
18 changes: 13 additions & 5 deletions tests/test_model/test_options.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
# testing the display of functions
from typing import Optional

import pytest
from piqtree2 import available_freq_type, available_models, available_rate_type
from piqtree2.model import AaModel, DnaModel, FreqType
from piqtree2.model import AaModel, DnaModel, FreqType, SubstitutionModel
from piqtree2.model._rate_type import ALL_BASE_RATE_TYPES


@pytest.mark.parametrize(
("model_class", "model_type"),
[(None, None), (DnaModel, "dna"), (AaModel, "protein")],
)
def test_num_available_models(model_class, model_type):
def test_num_available_models(
model_class: Optional[SubstitutionModel],
model_type: Optional[str],
) -> None:
table = available_models(model_type)
total_models = (
len(DnaModel) + len(AaModel) if model_class is None else len(model_class)
Expand All @@ -22,7 +27,10 @@ def test_num_available_models(model_class, model_type):
("model_fetch", "model_type"),
[(None, None), ("dna", "nucleotide"), ("protein", "protein")],
)
def test_available_models_types(model_fetch, model_type):
def test_available_models_types(
model_fetch: Optional[str],
model_type: Optional[str],
) -> None:
table = available_models(model_fetch)

if model_type is None:
Expand All @@ -33,15 +41,15 @@ def test_available_models_types(model_fetch, model_type):
assert check_model_type[0] == model_type


def test_num_freq_type():
def test_num_freq_type() -> None:
table = available_freq_type()
total_freq_types = len(FreqType)

assert total_freq_types > 0
assert table.shape[0] == total_freq_types


def test_num_rate_type():
def test_num_rate_type() -> None:
table = available_rate_type()
total_rate_types = len(ALL_BASE_RATE_TYPES)

Expand Down
12 changes: 7 additions & 5 deletions tests/test_model/test_rate_type.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Optional

import pytest
from piqtree2.model import DiscreteGammaModel, FreeRateModel, RateModel, RateType


def test_rate_model_uninstantiable():
def test_rate_model_uninstantiable() -> None:
with pytest.raises(TypeError):
_ = RateModel()

Expand All @@ -27,10 +29,10 @@ def test_rate_model_uninstantiable():
],
)
def test_invariable_sites(
invariable_sites,
rate_model,
iqtree_str,
):
invariable_sites: bool,
rate_model: Optional[RateModel],
iqtree_str: str,
) -> None:
model = RateType(invariable_sites=invariable_sites, model=rate_model)
assert model.iqtree_str() == iqtree_str

Expand Down
Loading

0 comments on commit e1176ef

Please sign in to comment.