diff --git a/prymer/offtarget/offtarget_detector.py b/prymer/offtarget/offtarget_detector.py
index d1c7868..71d3201 100644
--- a/prymer/offtarget/offtarget_detector.py
+++ b/prymer/offtarget/offtarget_detector.py
@@ -75,6 +75,7 @@
 """  # noqa: E501
 
 import itertools
+from collections import defaultdict
 from contextlib import AbstractContextManager
 from dataclasses import dataclass
 from dataclasses import field
@@ -83,6 +84,7 @@
 from types import TracebackType
 from typing import Optional
 from typing import Self
+from typing import TypeAlias
 from typing import TypeVar
 
 from ordered_set import OrderedSet
@@ -90,6 +92,7 @@
 from prymer.api.oligo import Oligo
 from prymer.api.primer_pair import PrimerPair
 from prymer.api.span import Span
+from prymer.api.span import Strand
 from prymer.offtarget.bwa import BWA_EXECUTABLE_NAME
 from prymer.offtarget.bwa import BwaAlnInteractive
 from prymer.offtarget.bwa import BwaHit
@@ -98,6 +101,9 @@
 
 PrimerType = TypeVar("PrimerType", bound=Oligo)
 
+ReferenceName: TypeAlias = str
+"""Alias for a reference sequence name."""
+
 
 @dataclass(init=True, frozen=True)
 class OffTargetResult:
@@ -334,27 +340,78 @@ def _build_off_target_result(
         result: OffTargetResult
 
         # Get the mappings for the left primer and right primer respectively
-        p1: BwaResult = hits_by_primer[primer_pair.left_primer.bases]
-        p2: BwaResult = hits_by_primer[primer_pair.right_primer.bases]
-
-        # Get all possible amplicons from the left_primer_mappings and right_primer_mappings
-        # primer hits, filtering if there are too many for either
-        if p1.hit_count > self._max_primer_hits or p2.hit_count > self._max_primer_hits:
+        left_bwa_result: BwaResult = hits_by_primer[primer_pair.left_primer.bases]
+        right_bwa_result: BwaResult = hits_by_primer[primer_pair.right_primer.bases]
+
+        # If there are too many hits, this primer pair will not pass. Exit early.
+        if (
+            left_bwa_result.hit_count > self._max_primer_hits
+            or right_bwa_result.hit_count > self._max_primer_hits
+        ):
             result = OffTargetResult(primer_pair=primer_pair, passes=False)
-        else:
-            amplicons = self._to_amplicons(p1.hits, p2.hits, self._max_amplicon_size)
-            result = OffTargetResult(
-                primer_pair=primer_pair,
-                passes=self._min_primer_pair_hits <= len(amplicons) <= self._max_primer_pair_hits,
-                spans=amplicons if self._keep_spans else [],
-                left_primer_spans=(
-                    [self._hit_to_span(h) for h in p1.hits] if self._keep_primer_spans else []
-                ),
-                right_primer_spans=(
-                    [self._hit_to_span(h) for h in p2.hits] if self._keep_primer_spans else []
-                ),
+            if self._cache_results:
+                self._primer_pair_cache[primer_pair] = replace(result, cached=True)
+            return result
+
+        # Map the hits by reference name
+        left_positive_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list)
+        left_negative_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list)
+        right_positive_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list)
+        right_negative_hits: defaultdict[ReferenceName, list[BwaHit]] = defaultdict(list)
+
+        # Split the hits for left and right by reference name and strand
+        for hit in left_bwa_result.hits:
+            if hit.negative:
+                left_negative_hits[hit.refname].append(hit)
+            else:
+                left_positive_hits[hit.refname].append(hit)
+
+        for hit in right_bwa_result.hits:
+            if hit.negative:
+                right_negative_hits[hit.refname].append(hit)
+            else:
+                right_positive_hits[hit.refname].append(hit)
+
+        refnames: set[ReferenceName] = {
+            h.refname for h in itertools.chain(left_bwa_result.hits, right_bwa_result.hits)
+        }
+
+        # Build amplicons from hits on the same reference with valid relative orientation
+        amplicons: list[Span] = []
+        for refname in refnames:
+            amplicons.extend(
+                self._to_amplicons(
+                    positive_hits=left_positive_hits[refname],
+                    negative_hits=right_negative_hits[refname],
+                    max_len=self._max_amplicon_size,
+                    strand=Strand.POSITIVE,
+                )
+            )
+            amplicons.extend(
+                self._to_amplicons(
+                    positive_hits=right_positive_hits[refname],
+                    negative_hits=left_negative_hits[refname],
+                    max_len=self._max_amplicon_size,
+                    strand=Strand.NEGATIVE,
+                )
             )
 
+        result = OffTargetResult(
+            primer_pair=primer_pair,
+            passes=self._min_primer_pair_hits <= len(amplicons) <= self._max_primer_pair_hits,
+            spans=amplicons if self._keep_spans else [],
+            left_primer_spans=(
+                [self._hit_to_span(h) for h in left_bwa_result.hits]
+                if self._keep_primer_spans
+                else []
+            ),
+            right_primer_spans=(
+                [self._hit_to_span(h) for h in right_bwa_result.hits]
+                if self._keep_primer_spans
+                else []
+            ),
+        )
+
         if self._cache_results:
             self._primer_pair_cache[primer_pair] = replace(result, cached=True)
 
@@ -410,19 +467,89 @@ def mappings_of(self, primers: list[PrimerType]) -> dict[str, BwaResult]:
         return hits_by_primer
 
     @staticmethod
-    def _to_amplicons(lefts: list[BwaHit], rights: list[BwaHit], max_len: int) -> list[Span]:
-        """Takes a set of hits for one or more left primers and right primers and constructs
-        amplicon mappings anywhere a left primer hit and a right primer hit align in F/R
-        orientation up to `maxLen` apart on the same reference. Primers may not overlap.
+    def _to_amplicons(
+        positive_hits: list[BwaHit], negative_hits: list[BwaHit], max_len: int, strand: Strand
+    ) -> list[Span]:
+        """Takes lists of positive strand hits and negative strand hits and constructs amplicon
+        mappings anywhere a positive strand hit and a negative strand hit occur where the end of
+        the negative strand hit is no more than `max_len` from the start of the positive strand
+        hit.
+
+        Primers may not overlap.
+
+        Args:
+            positive_hits: List of hits on the positive strand for one of the primers in the pair.
+            negative_hits: List of hits on the negative strand for the other primer in the pair.
+            max_len: Maximum length of amplicons to consider.
+            strand: The strand of the amplicon to generate. Set to Strand.POSITIVE if
+                `positive_hits` are for the left primer and `negative_hits` are for the right
+                primer. Set to Strand.NEGATIVE if `positive_hits` are for the right primer and
+                `negative_hits` are for the left primer.
+
+        Raises:
+            ValueError: If any of the positive hits are not on the positive strand, or any of the
+                negative hits are not on the negative strand. If hits are present on more than one
+                reference.
         """
+        if any(h.negative for h in positive_hits):
+            raise ValueError("Positive hits must be on the positive strand.")
+        if any(not h.negative for h in negative_hits):
+            raise ValueError("Negative hits must be on the negative strand.")
+
+        refnames: set[ReferenceName] = {
+            h.refname for h in itertools.chain(positive_hits, negative_hits)
+        }
+        if len(refnames) > 1:
+            raise ValueError(f"Hits are present on more than one reference: {refnames}")
+
+        # Exit early if one of the hit lists is empty - this will save unnecessary sorting of the
+        # other list
+        if len(positive_hits) == 0 or len(negative_hits) == 0:
+            return []
+
+        # Sort the positive strand hits by start position and the negative strand hits by *end*
+        # position. The `max_len` cutoff is based on negative_hit.end - positive_hit.start + 1.
+        positive_hits_sorted = sorted(positive_hits, key=lambda h: h.start)
+        negative_hits_sorted = sorted(negative_hits, key=lambda h: h.end)
+
         amplicons: list[Span] = []
-        for h1, h2 in itertools.product(lefts, rights):
-            if h1.negative == h2.negative or h1.refname != h2.refname:  # not F/R orientation
-                continue
-            plus, minus = (h2, h1) if h1.negative else (h1, h2)
-            if minus.start > plus.end and (minus.end - plus.start + 1) <= max_len:
-                amplicons.append(Span(refname=plus.refname, start=plus.start, end=minus.end))
+        # Track the position of the previously examined negative hit.
+        prev_negative_hit_index = 0
+        for positive_hit in positive_hits_sorted:
+            # Check only negative hits starting with the previously examined one.
+            for negative_hit_index, negative_hit in enumerate(
+                negative_hits_sorted[prev_negative_hit_index:],
+                start=prev_negative_hit_index,
+            ):
+                # TODO: Consider allowing overlapping positive and negative hits.
+                if (
+                    negative_hit.start > positive_hit.end
+                    and negative_hit.end - positive_hit.start + 1 <= max_len
+                ):
+                    # If the negative hit starts to the right of the positive hit, and the amplicon
+                    # length is <= max_len, add it to the list of amplicon hits to be returned.
+                    amplicons.append(
+                        Span(
+                            refname=positive_hit.refname,
+                            start=positive_hit.start,
+                            end=negative_hit.end,
+                            strand=strand,
+                        )
+                    )
+
+                if negative_hit.end - positive_hit.start + 1 > max_len:
+                    # Stop searching for negative hits to pair with this positive hit.
+                    # All subsequent negative hits will have amplicon length > max_len
+                    break
+
+                if negative_hit.end < positive_hit.start:
+                    # This positive hit is genomically right of the current negative hit.
+                    # All subsequent positive hits will also be genomically right of this negative
+                    # hit, so we should start at the one after this. If this index is past the end
+                    # of the list, the slice `negative_hits_sorted[prev_negative_hit_index:]` will
+                    # be empty.
+                    prev_negative_hit_index = negative_hit_index + 1
 
         return amplicons
 
diff --git a/tests/offtarget/test_offtarget.py b/tests/offtarget/test_offtarget.py
index f8f5f8c..a2f3a46 100644
--- a/tests/offtarget/test_offtarget.py
+++ b/tests/offtarget/test_offtarget.py
@@ -11,6 +11,7 @@
 from prymer.offtarget.bwa import BWA_EXECUTABLE_NAME
 from prymer.offtarget.bwa import BwaHit
 from prymer.offtarget.bwa import BwaResult
+from prymer.offtarget.bwa import Query
 from prymer.offtarget.offtarget_detector import OffTargetDetector
 from prymer.offtarget.offtarget_detector import OffTargetResult
 
@@ -171,68 +172,159 @@ def test_mappings_of(ref_fasta: Path, cache_results: bool) -> None:
     assert results_dict[p2.bases].hits[0] == expected_hit2
 
 
+# Test building an OffTargetResult for a primer pair with left/right hits on different references
+# and in different orientations
+def test_build_off_target_result(ref_fasta: Path) -> None:
+    hits_by_primer: dict[str, BwaResult] = {
+        "A" * 100: BwaResult(
+            query=Query(
+                id="left",
+                bases="A" * 100,
+            ),
+            hit_count=3,
+            hits=[
+                BwaHit.build("chr1", 100, False, "100M", 0),
+                BwaHit.build("chr1", 400, True, "100M", 0),
+                BwaHit.build("chr2", 100, False, "100M", 0),
+                BwaHit.build("chr3", 700, True, "100M", 0),
+            ],
+        ),
+        "C" * 100: BwaResult(
+            query=Query(
+                id="right",
+                bases="C" * 100,
+            ),
+            hit_count=2,
+            hits=[
+                BwaHit.build("chr1", 800, False, "100M", 0),
+                BwaHit.build("chr1", 200, True, "100M", 0),
+                BwaHit.build("chr3", 600, False, "100M", 0),
+            ],
+        ),
+    }
+
+    primer_pair = PrimerPair(
+        left_primer=Oligo(
+            tm=50,
+            penalty=0,
+            span=Span(refname="chr10", start=100, end=199, strand=Strand.POSITIVE),
+            bases="A" * 100,
+        ),
+        right_primer=Oligo(
+            tm=50,
+            penalty=0,
+            span=Span(refname="chr10", start=300, end=399, strand=Strand.NEGATIVE),
+            bases="C" * 100,
+        ),
+        amplicon_tm=100,
+        penalty=0,
+    )
+
+    with _build_detector(
+        ref_fasta=ref_fasta, max_primer_hits=10, max_primer_pair_hits=10
+    ) as detector:
+        off_target_result: OffTargetResult = detector._build_off_target_result(
+            primer_pair=primer_pair,
+            hits_by_primer=hits_by_primer,
+        )
+
+    assert set(off_target_result.spans) == {
+        Span(refname="chr1", start=100, end=299, strand=Strand.POSITIVE),
+        Span(refname="chr3", start=600, end=799, strand=Strand.NEGATIVE),
+    }
+
+
 # Test that using the cache (or not) does not affect the results
 @pytest.mark.parametrize("cache_results", [True, False])
 @pytest.mark.parametrize(
-    "test_id, left, right, expected",
+    "test_id, positive, negative, strand, expected",
     [
-        (
-            "No mappings - different refnames",
-            BwaHit.build("chr1", 100, False, "100M", 0),
-            BwaHit.build("chr2", 100, True, "100M", 0),
-            [],
-        ),
-        (
-            "No mappings - FF pair",
-            BwaHit.build("chr1", 100, True, "100M", 0),
-            BwaHit.build("chr1", 100, True, "100M", 0),
-            [],
-        ),
-        (
-            "No mappings - RR pair",
-            BwaHit.build("chr1", 100, False, "100M", 0),
-            BwaHit.build("chr1", 100, False, "100M", 0),
-            [],
-        ),
         (
             "No mappings - overlapping primers (1bp overlap)",
             BwaHit.build("chr1", 100, False, "100M", 0),
             BwaHit.build("chr1", 199, True, "100M", 0),
+            Strand.POSITIVE,
             [],
         ),
         (
             "No mappings - amplicon size too big (1bp too big)",
             BwaHit.build("chr1", 100, False, "100M", 0),
             BwaHit.build("chr1", 151, True, "100M", 0),
+            Strand.POSITIVE,
             [],
         ),
         (
             "Mappings - FR pair (R1 F)",
             BwaHit.build("chr1", 100, False, "100M", 0),
             BwaHit.build("chr1", 200, True, "100M", 0),
-            [Span(refname="chr1", start=100, end=299)],
+            Strand.POSITIVE,
+            [Span(refname="chr1", start=100, end=299, strand=Strand.POSITIVE)],
         ),
         (
             "Mappings - FR pair (R1 R)",
-            BwaHit.build("chr1", 200, True, "100M", 0),
             BwaHit.build("chr1", 100, False, "100M", 0),
-            [Span(refname="chr1", start=100, end=299)],
+            BwaHit.build("chr1", 200, True, "100M", 0),
+            Strand.NEGATIVE,
+            [Span(refname="chr1", start=100, end=299, strand=Strand.NEGATIVE)],
         ),
     ],
 )
 def test_to_amplicons(
     ref_fasta: Path,
     test_id: str,
-    left: BwaHit,
-    right: BwaHit,
+    positive: BwaHit,
+    negative: BwaHit,
+    strand: Strand,
     expected: list[Span],
     cache_results: bool,
 ) -> None:
     with _build_detector(ref_fasta=ref_fasta, cache_results=cache_results) as detector:
-        actual = detector._to_amplicons(lefts=[left], rights=[right], max_len=250)
+        actual = detector._to_amplicons(
+            positive_hits=[positive], negative_hits=[negative], max_len=250, strand=strand
+        )
         assert actual == expected, test_id
 
 
+@pytest.mark.parametrize("cache_results", [True, False])
+@pytest.mark.parametrize(
+    "positive, negative, expected_error",
+    [
+        (
+            # No mappings - different refnames
+            BwaHit.build("chr1", 100, False, "100M", 0),
+            BwaHit.build("chr2", 100, True, "100M", 0),
+            "Hits are present on more than one reference",
+        ),
+        (
+            # No mappings - FF pair
+            BwaHit.build("chr1", 100, True, "100M", 0),
+            BwaHit.build("chr1", 100, True, "100M", 0),
+            "Positive hits must be on the positive strand",
+        ),
+        (
+            # No mappings - RR pair
+            BwaHit.build("chr1", 100, False, "100M", 0),
+            BwaHit.build("chr1", 100, False, "100M", 0),
+            "Negative hits must be on the negative strand",
+        ),
+    ],
+)
+def test_to_amplicons_value_error(
+    ref_fasta: Path,
+    positive: BwaHit,
+    negative: BwaHit,
+    expected_error: str,
+    cache_results: bool,
+) -> None:
+    with (
+        _build_detector(ref_fasta=ref_fasta, cache_results=cache_results) as detector,
+        pytest.raises(ValueError, match=expected_error),
+    ):
+        detector._to_amplicons(
+            positive_hits=[positive], negative_hits=[negative], max_len=250, strand=Strand.POSITIVE
+        )
+
+
 def test_generic_filter(ref_fasta: Path) -> None:
     """
     This test isn't intended to validate any runtime assertions, but is a minimal example for the