diff --git a/src/pgmap/alignment/grna_cached_aligner.py b/src/pgmap/alignment/grna_cached_aligner.py new file mode 100644 index 0000000..246fc56 --- /dev/null +++ b/src/pgmap/alignment/grna_cached_aligner.py @@ -0,0 +1,63 @@ +import itertools + +from typing import Iterable + + +def construct_grna_error_alignment_cache(gRNAs: Iterable[str], gRNA_error_tolerance: int) -> dict[str, tuple[str, int]]: + """ + Construct an alignment cache object containing all gRNAs with the error tolerance amount of substitutions. The + number of gRNAs within the error tolerance grows exponentially. As such this function should only be used for + error tolerances of 0, 1, or 2. + + Args: + gRNAs (Iterable[str]): An iterable producing all the gRNAs to construct the alignment cache from. + gRNA_error_tolerance (int): The error tolerance used to create gRNA alignment candidates. + + Returns: + alignment_cache (dict[str, tuple[str, int]]): A mapping from each valid alignment string to a tuple of the + reference alignment sequence and the hamming distance from the reference. Guarantees a minimum alignment, + though there could be multiple best alignments depending on the gRNAs. + + Raises: + ValueError: if gRNA_error_tolerance is not 0, 1, or 2. + """ + + if gRNA_error_tolerance > 2 or gRNA_error_tolerance < 0: + raise ValueError( + "gRNA error tolerance must be 0, 1, or 2 but was " + str(gRNA_error_tolerance)) + + alignment_cache = {} + + if gRNA_error_tolerance: + for gRNA in gRNAs: + # go from high subs to low subs to prefer better alignments + for num_substitutions in reversed(range(1, gRNA_error_tolerance + 1)): + alignment_cache.update( + {mutation: (gRNA, num_substitutions) for mutation in _get_mutations(gRNA, num_substitutions)}) + + alignment_cache.update({gRNA: (gRNA, 0) + for gRNA in gRNAs}) # prefer perfect alignments + + return alignment_cache + + +def _get_mutations(gRNA: str, num_substitutions: int) -> Iterable[str]: + for substitution_indices in _get_all_substitution_indices(gRNA, num_substitutions): + yield from _generate_substitutions(gRNA, substitution_indices) + + +def _get_all_substitution_indices(gRNA: str, num_substitutions: int) -> Iterable[tuple[int]]: + yield from itertools.combinations(range(len(gRNA)), num_substitutions) + + +def _generate_substitutions(gRNA: str, substitution_indices: Iterable[int]) -> Iterable[str]: + for substitutions in itertools.product("ATCG", repeat=len(substitution_indices)): + if any(gRNA[i] == substitution for i, substitution in zip(substitution_indices, substitutions)): + continue + + bases = list(gRNA) + + for i, substitution in zip(substitution_indices, substitutions): + bases[i] = substitution + + yield "".join(bases) diff --git a/src/pgmap/cli.py b/src/pgmap/cli.py index f6c4b56..be7d09d 100644 --- a/src/pgmap/cli.py +++ b/src/pgmap/cli.py @@ -23,7 +23,7 @@ def get_counts(args: argparse.Namespace): candidate_reads = read_trimmer.two_read_trim(*args.fastq) paired_guide_counts = counter.get_counts( - candidate_reads, gRNA_mappings, barcodes, gRNA2_error_tolerance=args.gRNA2_error, barcode_error_tolerance=args.barcode_error) + candidate_reads, gRNA_mappings, barcodes, gRNA1_error_tolerance=args.gRNA1_error, gRNA2_error_tolerance=args.gRNA2_error, barcode_error_tolerance=args.barcode_error) counts_writer.write_counts( args.output, paired_guide_counts, barcodes, id_mapping) @@ -46,18 +46,44 @@ def _parse_args(args: list[str]) -> argparse.Namespace: # TODO support arbitrary trim strategies parser.add_argument("--trim-strategy", required=True, choices=(TWO_READ_STRATEGY, THREE_READ_STRATEGY), help="The trim strategy used to extract guides and barcodes. The two read strategy should have fastqs R1 and I1. The three read strategy should have fastqs R1, I1, and I2") # TODO extract consts - parser.add_argument("--gRNA2-error", required=False, default=2, type=_check_nonnegative, - help="The number of substituted base pairs to allow in gRNA2.") - parser.add_argument("--barcode-error", required=False, default=2, type=_check_nonnegative, + parser.add_argument("--gRNA1-error", required=False, default=1, type=_check_gRNA1_error, + help="The number of substituted base pairs to allow in gRNA1. Must be less than 3.") + parser.add_argument("--gRNA2-error", required=False, default=1, type=_check_gRNA2_error, + help="The number of substituted base pairs to allow in gRNA2. Must be less than 3.") + parser.add_argument("--barcode-error", required=False, default=1, type=_check_barcode_error, help="The number of insertions, deletions, and subsititions of base pairs to allow in the barcodes.") return parser.parse_args(args) -def _check_nonnegative(value: str) -> int: +def _check_gRNA1_error(value: str) -> int: int_value = int(value) if int_value < 0: - raise ValueError(f"Count must be nonnegative but was {value}") + raise ValueError(f"gRNA1-error must be nonnegative but was {value}") + + if int_value > 2: + raise ValueError(f"gRNA1-error must be less than 3 but was {value}") + + return int_value + + +def _check_gRNA2_error(value: str) -> int: + int_value = int(value) + + if int_value < 0: + raise ValueError(f"gRNA2-error must be nonnegative but was {value}") + + if int_value > 2: + raise ValueError(f"gRNA2-error must be less than 3 but was {value}") + + return int_value + + +def _check_barcode_error(value: str) -> int: + int_value = int(value) + + if int_value < 0: + raise ValueError(f"barcode-error must be nonnegative but was {value}") return int_value diff --git a/src/pgmap/counter/counter.py b/src/pgmap/counter/counter.py index 17de1e0..f6a9844 100644 --- a/src/pgmap/counter/counter.py +++ b/src/pgmap/counter/counter.py @@ -1,15 +1,17 @@ from collections import Counter from typing import Counter, Iterable +import itertools from pgmap.model.paired_read import PairedRead -from pgmap.alignment import pairwise_aligner +from pgmap.alignment import pairwise_aligner, grna_cached_aligner def get_counts(paired_reads: Iterable[PairedRead], gRNA_mappings: dict[str, set[str]], barcodes: set[str], - gRNA2_error_tolerance: int = 2, - barcode_error_tolerance: int = 2) -> Counter[tuple[str, str, str]]: + gRNA1_error_tolerance: int = 1, + gRNA2_error_tolerance: int = 1, + barcode_error_tolerance: int = 1) -> Counter[tuple[str, str, str]]: """ Count paired guides for each sample barcode with tolerance for errors. gRNA1 matchs only through perfect alignment. gRNA2 aligns if there is a match of a known (gRNA1, gRNA2) pairing having hamming distance @@ -38,16 +40,26 @@ def get_counts(paired_reads: Iterable[PairedRead], paired_guide_counts = Counter() + gRNA1_cached_aligner = grna_cached_aligner.construct_grna_error_alignment_cache( + gRNA_mappings.keys(), gRNA1_error_tolerance) + + gRNA2_cached_aligner = grna_cached_aligner.construct_grna_error_alignment_cache( + set(itertools.chain.from_iterable(gRNA_mappings.values())), gRNA2_error_tolerance) + for paired_read in paired_reads: - gRNA1 = paired_read.gRNA1_candidate + paired_read.gRNA1_candidate + + if paired_read.gRNA1_candidate not in gRNA1_cached_aligner: + continue + + gRNA1, _ = gRNA1_cached_aligner[paired_read.gRNA1_candidate] - if gRNA1 not in gRNA_mappings: + if paired_read.gRNA2_candidate not in gRNA2_cached_aligner: continue - gRNA2_score, gRNA2 = max((pairwise_aligner.hamming_score(paired_read.gRNA2_candidate, reference), reference) - for reference in gRNA_mappings[paired_read.gRNA1_candidate]) + gRNA2, _ = gRNA2_cached_aligner[paired_read.gRNA2_candidate] - if (len(gRNA2) - gRNA2_score) > gRNA2_error_tolerance: + if gRNA2 not in gRNA_mappings[gRNA1]: continue barcode_score, barcode = max((pairwise_aligner.edit_distance_score( diff --git a/tests/__main__.py b/tests/__main__.py index 4111804..57cbebf 100644 --- a/tests/__main__.py +++ b/tests/__main__.py @@ -7,7 +7,7 @@ from pgmap.io import barcode_reader, fastx_reader, library_reader from pgmap.trimming import read_trimmer -from pgmap.alignment import pairwise_aligner +from pgmap.alignment import pairwise_aligner, grna_cached_aligner from pgmap.counter import counter from pgmap.model.paired_read import PairedRead from pgmap import cli @@ -142,6 +142,89 @@ def test_blast_aligner_score(self): self.assertEqual( pairwise_aligner.blast_aligner_score("ABC", "ABC"), 3) + def test_grna_cached_aligner_error_tolerance_0(self): + gRNAs = ["AAA"] + + alignment_cache = grna_cached_aligner.construct_grna_error_alignment_cache( + gRNAs, 0) + + self.assertEqual(alignment_cache, {"AAA": ("AAA", 0)}) + + def test_grna_cached_aligner_error_tolerance_1(self): + gRNAs = ["AAA"] + + alignment_cache = grna_cached_aligner.construct_grna_error_alignment_cache( + gRNAs, 1) + + self.assertEqual(alignment_cache, {"AAA": ("AAA", 0), + "CAA": ("AAA", 1), + "ACA": ("AAA", 1), + "AAC": ("AAA", 1), + "TAA": ("AAA", 1), + "ATA": ("AAA", 1), + "AAT": ("AAA", 1), + "GAA": ("AAA", 1), + "AGA": ("AAA", 1), + "AAG": ("AAA", 1)}) + + def test_grna_cached_aligner_error_tolerance_2(self): + gRNAs = ["AAA"] + + alignment_cache = grna_cached_aligner.construct_grna_error_alignment_cache( + gRNAs, 2) + + self.assertEqual(alignment_cache, {"AAA": ("AAA", 0), + "AAC": ("AAA", 1), + "AAG": ("AAA", 1), + "AAT": ("AAA", 1), + "ACA": ("AAA", 1), + "ACC": ("AAA", 2), + "ACG": ("AAA", 2), + "ACT": ("AAA", 2), + "AGA": ("AAA", 1), + "AGC": ("AAA", 2), + "AGG": ("AAA", 2), + "AGT": ("AAA", 2), + "ATA": ("AAA", 1), + "ATC": ("AAA", 2), + "ATG": ("AAA", 2), + "ATT": ("AAA", 2), + "CAA": ("AAA", 1), + "CAC": ("AAA", 2), + "CAG": ("AAA", 2), + "CAT": ("AAA", 2), + "CCA": ("AAA", 2), + "CGA": ("AAA", 2), + "CTA": ("AAA", 2), + "GAA": ("AAA", 1), + "GAC": ("AAA", 2), + "GAG": ("AAA", 2), + "GAT": ("AAA", 2), + "GCA": ("AAA", 2), + "GGA": ("AAA", 2), + "GTA": ("AAA", 2), + "TAA": ("AAA", 1), + "TAC": ("AAA", 2), + "TAG": ("AAA", 2), + "TAT": ("AAA", 2), + "TCA": ("AAA", 2), + "TGA": ("AAA", 2), + "TTA": ("AAA", 2)}) + + def test_grna_cached_aligner_negative_error_tolerance(self): + gRNAs = ["AAA"] + + with self.assertRaises(ValueError): + grna_cached_aligner.construct_grna_error_alignment_cache( + gRNAs, -1) + + def test_grna_cached_aligner_error_tolerance_greater_than_1(self): + gRNAs = ["AAA"] + + with self.assertRaises(ValueError): + grna_cached_aligner.construct_grna_error_alignment_cache( + gRNAs, 3) + def test_counter_no_error_tolerance(self): barcodes = barcode_reader.read_barcodes(TWO_READ_BARCODES_PATH) gRNA1s, gRNA2s, gRNA_mappings = library_reader.read_paired_guide_library_fastas( @@ -151,7 +234,7 @@ def test_counter_no_error_tolerance(self): TWO_READ_R1_PATH, TWO_READ_I1_PATH) paired_guide_counts = counter.get_counts( - candidate_reads, gRNA_mappings, barcodes, gRNA2_error_tolerance=0, barcode_error_tolerance=0) + candidate_reads, gRNA_mappings, barcodes, gRNA1_error_tolerance=0, gRNA2_error_tolerance=0, barcode_error_tolerance=0) perfect_alignments = 0 @@ -188,18 +271,44 @@ def test_counter_default_error_tolerance(self): self.assertLess( sum(paired_guide_counts.values()), count) + def test_counter_max_error_tolerance(self): + barcodes = barcode_reader.read_barcodes(TWO_READ_BARCODES_PATH) + gRNA1s, gRNA2s, gRNA_mappings = library_reader.read_paired_guide_library_fastas( + "example-data/pgPEN-library/pgPEN_R1.fa", "example-data/pgPEN-library/pgPEN_R2.fa") + + candidate_reads = read_trimmer.two_read_trim( + TWO_READ_R1_PATH, TWO_READ_I1_PATH) + + paired_guide_counts = counter.get_counts( + candidate_reads, gRNA_mappings, barcodes, gRNA1_error_tolerance=2, gRNA2_error_tolerance=2) + + count = 0 + perfect_alignments = 0 + + for paired_read in read_trimmer.two_read_trim(TWO_READ_R1_PATH, TWO_READ_I1_PATH): + count += 1 + + if paired_read.gRNA1_candidate in gRNA1s and paired_read.gRNA2_candidate in gRNA_mappings[paired_read.gRNA1_candidate] and paired_read.barcode_candidate in barcodes: + perfect_alignments += 1 + + # Max error tolerance counts should be greater than perfect alignment but less than all counts + self.assertGreater( + sum(paired_guide_counts.values()), perfect_alignments) + self.assertLess( + sum(paired_guide_counts.values()), count) + def test_counter_hardcoded_test_data(self): - barcodes = {"COOL", "WOOD", "FOOD"} - gRNA_mappings = {"LET": {"WOW", "LEG", "EAT"}} - candidate_reads = [PairedRead("LET", "ROT", "FOOD"), - PairedRead("LET", "EXT", "FXOD"), - PairedRead("RUN", "LEG", "WOOD")] + barcodes = {"AAAA", "CCCC", "TTTT"} + gRNA_mappings = {"ACT": {"CAG", "GGC", "TTG"}} + candidate_reads = [PairedRead("ACT", "GAA", "AAAA"), + PairedRead("AGT", "CGG", "CCCA"), + PairedRead("ATT", "GAC", "TTTG")] paired_guide_counts = counter.get_counts( candidate_reads, gRNA_mappings, barcodes) - self.assertEqual(paired_guide_counts[("LET", "EAT", "FOOD")], 1) - self.assertEqual(paired_guide_counts[("LET", "WOW", "FOOD")], 1) + self.assertEqual(paired_guide_counts[("ACT", "CAG", "CCCC")], 1) + self.assertEqual(paired_guide_counts[("ACT", "GGC", "TTTT")], 1) self.assertEqual(sum(paired_guide_counts.values()), 2) # TODO separate these into own test module? @@ -213,8 +322,9 @@ def test_arg_parse_happy_case(self): self.assertEqual(args.library, PGPEN_ANNOTATION_PATH) self.assertEqual(args.barcodes, TWO_READ_BARCODES_PATH) self.assertEqual(args.trim_strategy, "two-read") - self.assertEqual(args.gRNA2_error, 2) - self.assertEqual(args.barcode_error, 2) + self.assertEqual(args.gRNA1_error, 1) + self.assertEqual(args.gRNA1_error, 1) + self.assertEqual(args.barcode_error, 1) def test_arg_parse_invalid_fastq(self): with self.assertRaises(argparse.ArgumentError): @@ -244,6 +354,14 @@ def test_arg_parse_invalid_trim_strategy(self): "--barcodes", TWO_READ_BARCODES_PATH, "--trim-strategy", "burger"]) + def test_arg_parse_negative_gRNA1_error(self): + with self.assertRaises(argparse.ArgumentError): + args = cli._parse_args(["--fastq", TWO_READ_R1_PATH, TWO_READ_I1_PATH, + "--library", PGPEN_ANNOTATION_PATH, + "--barcodes", TWO_READ_BARCODES_PATH, + "--trim-strategy", "two-read", + "--gRNA1-error", "-1"]) + def test_arg_parse_negative_gRNA2_error(self): with self.assertRaises(argparse.ArgumentError): args = cli._parse_args(["--fastq", TWO_READ_R1_PATH, TWO_READ_I1_PATH, @@ -260,6 +378,30 @@ def test_arg_parse_negative_barcode_error(self): "--trim-strategy", "two-read", "--barcode-error", "-1"]) + def test_arg_parse_gRNA1_error_greater_than_2(self): + with self.assertRaises(argparse.ArgumentError): + args = cli._parse_args(["--fastq", TWO_READ_R1_PATH, TWO_READ_I1_PATH, + "--library", PGPEN_ANNOTATION_PATH, + "--barcodes", TWO_READ_BARCODES_PATH, + "--trim-strategy", "two-read", + "--gRNA1-error", "3"]) + + def test_arg_parse_gRNA2_error_greater_than_2(self): + with self.assertRaises(argparse.ArgumentError): + args = cli._parse_args(["--fastq", TWO_READ_R1_PATH, TWO_READ_I1_PATH, + "--library", PGPEN_ANNOTATION_PATH, + "--barcodes", TWO_READ_BARCODES_PATH, + "--trim-strategy", "two-read", + "--gRNA2-error", "3"]) + + def test_arg_parse_invalid_type_gRNA1_error(self): + with self.assertRaises(argparse.ArgumentError): + args = cli._parse_args(["--fastq", TWO_READ_R1_PATH, TWO_READ_I1_PATH, + "--library", PGPEN_ANNOTATION_PATH, + "--barcodes", TWO_READ_BARCODES_PATH, + "--trim-strategy", "two-read", + "--gRNA1-error", "one"]) + def test_arg_parse_invalid_type_gRNA2_error(self): with self.assertRaises(argparse.ArgumentError): args = cli._parse_args(["--fastq", TWO_READ_R1_PATH, TWO_READ_I1_PATH,