diff --git a/tests/test_formats.py b/tests/test_formats.py index 270d9438..b4cdb4f8 100644 --- a/tests/test_formats.py +++ b/tests/test_formats.py @@ -1279,16 +1279,53 @@ def test_bad_format_version(self): with sd.copy() as copy: copy.data.attrs[tsinfer.FORMAT_VERSION_KEY] = 100, 0 - def test_ancestral_allele(self): + @pytest.mark.parametrize( + "position, genotypes, alleles, ancestral_allele, expected_ancestral_state, " + "expected_recode_alleles, expected_genotypes", + [ + (0, [0, 0, 1, 2], ["A", "B", "C"], 2, "C", ("C", "A", "B"), [1, 1, 2, 0]), + ( + 0, + [0, 1, 2, 3], + ["A", "C", "G", "T"], + 0, + "A", + ("A", "C", "G", "T"), + [0, 1, 2, 3], + ), + (0, [0, -1, 1, 2], ["A", "B", "C"], 0, "A", ("A", "B", "C"), [0, -1, 1, 2]), + (0, [0, -1, 1, 2], ["A", "B", "C"], 2, "C", ("C", "A", "B"), [1, -1, 2, 0]), + (0, [0, 0, 0, 0], ["A"], 0, "A", ("A",), [0, 0, 0, 0]), + (0, [-1, -1, -1, -1], ["A", "B"], 0, "A", ("A", "B"), [-1, -1, -1, -1]), + ], + ) + def test_ancestral_allele( + self, + position, + genotypes, + alleles, + ancestral_allele, + expected_ancestral_state, + expected_recode_alleles, + expected_genotypes, + ): with tsinfer.SampleData() as sd: - sd.add_site(0, [0, 0, 1, 2], alleles=["A", "B", "C"], ancestral_allele=2) + sd.add_site( + position, genotypes, alleles=alleles, ancestral_allele=ancestral_allele + ) v = next(sd.variants(recode_ancestral=True)) - assert v.site.alleles == ("A", "B", "C") - assert v.site.ancestral_allele == 2 - assert v.site.ancestral_state == "C" - assert v.alleles == ("C", "A", "B") - assert list(v.genotypes) == [1, 1, 2, 0] - assert [h[0] for _, h in sd.haplotypes(recode_ancestral=True)] == [1, 1, 2, 0] + assert v.site.alleles == tuple(alleles + [None] if -1 in genotypes else alleles) + assert v.site.ancestral_allele == ancestral_allele + assert v.site.ancestral_state == expected_ancestral_state + assert v.alleles == tuple( + list(expected_recode_alleles) + [None] + if -1 in genotypes + else expected_recode_alleles + ) + assert list(v.genotypes) == expected_genotypes + assert [ + h[0] for _, h in sd.haplotypes(recode_ancestral=True) + ] == expected_genotypes def test_missing_ancestral_allele(self): with tsinfer.SampleData() as sd: