Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support gRNA1 error tolerance #18

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions src/pgmap/alignment/grna_cached_aligner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import itertools

from typing import Iterable


def construct_grna_error_alignment_cache(gRNAs: Iterable[str], gRNA_error_tolerance: int) -> dict[str, tuple[str, int]]:
    """
    Construct an alignment cache object containing all gRNAs with the error tolerance amount of substitutions. The
    number of gRNAs within the error tolerance grows exponentially. As such this function should only be used for
    error tolerances of 0, 1, or 2.

    Args:
        gRNAs (Iterable[str]): An iterable producing all the gRNAs to construct the alignment cache from. It may be
            a one-shot iterable such as a generator; it is consumed exactly once.
        gRNA_error_tolerance (int): The error tolerance used to create gRNA alignment candidates.

    Returns:
        alignment_cache (dict[str, tuple[str, int]]): A mapping from each valid alignment string to a tuple of the
            reference alignment sequence and the hamming distance from the reference. Each key maps to a reference
            at the smallest distance among all the given gRNAs; when several gRNAs tie at that distance, the one
            produced last by the iterable is kept.

    Raises:
        ValueError: if gRNA_error_tolerance is not 0, 1, or 2.
    """

    if gRNA_error_tolerance > 2 or gRNA_error_tolerance < 0:
        raise ValueError(
            "gRNA error tolerance must be 0, 1, or 2 but was " + str(gRNA_error_tolerance))

    # The references are traversed once per distance level, so snapshot them first. Iterating the argument
    # directly would silently drop entries when it is a generator that is exhausted after the first pass.
    references = list(gRNAs)

    alignment_cache = {}

    # Distance is the OUTER loop, from farthest to closest, so that every entry written at a smaller distance
    # overwrites any entry at a larger one regardless of which gRNA produced it. Looping per gRNA first would let
    # a later gRNA's 2-substitution variant replace an earlier gRNA's 1-substitution variant.
    for num_substitutions in range(gRNA_error_tolerance, 0, -1):
        for gRNA in references:
            for mutation in _get_mutations(gRNA, num_substitutions):
                alignment_cache[mutation] = (gRNA, num_substitutions)

    # Perfect alignments go in last so they always win.
    for gRNA in references:
        alignment_cache[gRNA] = (gRNA, 0)

    return alignment_cache


def _get_mutations(gRNA: str, num_substitutions: int) -> Iterable[str]:
    """
    Lazily yield the variants of gRNA produced by substituting exactly num_substitutions positions.

    Args:
        gRNA (str): The reference sequence to mutate.
        num_substitutions (int): How many positions are substituted in each yielded variant.

    Yields:
        The variants for each combination of positions in turn, in the order the position combinations are
        produced by _get_all_substitution_indices.
    """
    variants_per_index_set = (
        _generate_substitutions(gRNA, chosen_indices)
        for chosen_indices in _get_all_substitution_indices(gRNA, num_substitutions)
    )

    yield from itertools.chain.from_iterable(variants_per_index_set)


def _get_all_substitution_indices(gRNA: str, num_substitutions: int) -> Iterable[tuple[int]]:
yield from itertools.combinations(range(len(gRNA)), num_substitutions)


def _generate_substitutions(gRNA: str, substitution_indices: Iterable[int]) -> Iterable[str]:
for substitutions in itertools.product("ATCG", repeat=len(substitution_indices)):
if any(gRNA[i] == substitution for i, substitution in zip(substitution_indices, substitutions)):
continue

bases = list(gRNA)

for i, substitution in zip(substitution_indices, substitutions):
bases[i] = substitution

yield "".join(bases)
38 changes: 32 additions & 6 deletions src/pgmap/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def get_counts(args: argparse.Namespace):
candidate_reads = read_trimmer.two_read_trim(*args.fastq)

paired_guide_counts = counter.get_counts(
candidate_reads, gRNA_mappings, barcodes, gRNA2_error_tolerance=args.gRNA2_error, barcode_error_tolerance=args.barcode_error)
candidate_reads, gRNA_mappings, barcodes, gRNA1_error_tolerance=args.gRNA1_error, gRNA2_error_tolerance=args.gRNA2_error, barcode_error_tolerance=args.barcode_error)

counts_writer.write_counts(
args.output, paired_guide_counts, barcodes, id_mapping)
Expand All @@ -46,18 +46,44 @@ def _parse_args(args: list[str]) -> argparse.Namespace:
# TODO support arbitrary trim strategies
parser.add_argument("--trim-strategy", required=True, choices=(TWO_READ_STRATEGY, THREE_READ_STRATEGY),
help="The trim strategy used to extract guides and barcodes. The two read strategy should have fastqs R1 and I1. The three read strategy should have fastqs R1, I1, and I2") # TODO extract consts
parser.add_argument("--gRNA2-error", required=False, default=2, type=_check_nonnegative,
help="The number of substituted base pairs to allow in gRNA2.")
parser.add_argument("--barcode-error", required=False, default=2, type=_check_nonnegative,
parser.add_argument("--gRNA1-error", required=False, default=1, type=_check_gRNA1_error,
help="The number of substituted base pairs to allow in gRNA1. Must be less than 3.")
parser.add_argument("--gRNA2-error", required=False, default=1, type=_check_gRNA2_error,
help="The number of substituted base pairs to allow in gRNA2. Must be less than 3.")
parser.add_argument("--barcode-error", required=False, default=1, type=_check_barcode_error,
help="The number of insertions, deletions, and subsititions of base pairs to allow in the barcodes.")
return parser.parse_args(args)


def _check_nonnegative(value: str) -> int:
def _check_gRNA1_error(value: str) -> int:
int_value = int(value)

if int_value < 0:
raise ValueError(f"Count must be nonnegative but was {value}")
raise ValueError(f"gRNA1-error must be nonnegative but was {value}")

if int_value > 2:
raise ValueError(f"gRNA1-error must be less than 3 but was {value}")

return int_value


def _check_gRNA2_error(value: str) -> int:
int_value = int(value)

if int_value < 0:
raise ValueError(f"gRNA2-error must be nonnegative but was {value}")

if int_value > 2:
raise ValueError(f"gRNA2-error must be less than 3 but was {value}")

return int_value


def _check_barcode_error(value: str) -> int:
int_value = int(value)

if int_value < 0:
raise ValueError(f"barcode-error must be nonnegative but was {value}")

return int_value

Expand Down
28 changes: 20 additions & 8 deletions src/pgmap/counter/counter.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from collections import Counter
from typing import Counter, Iterable
import itertools

from pgmap.model.paired_read import PairedRead
from pgmap.alignment import pairwise_aligner
from pgmap.alignment import pairwise_aligner, grna_cached_aligner


def get_counts(paired_reads: Iterable[PairedRead],
gRNA_mappings: dict[str, set[str]],
barcodes: set[str],
gRNA2_error_tolerance: int = 2,
barcode_error_tolerance: int = 2) -> Counter[tuple[str, str, str]]:
gRNA1_error_tolerance: int = 1,
gRNA2_error_tolerance: int = 1,
barcode_error_tolerance: int = 1) -> Counter[tuple[str, str, str]]:
"""
Count paired guides for each sample barcode with tolerance for errors. gRNA1 matches a known gRNA1 having hamming
distance within the gRNA1 error tolerance. gRNA2 aligns if there is a match of a known (gRNA1, gRNA2) pairing having hamming distance
Expand Down Expand Up @@ -38,16 +40,26 @@ def get_counts(paired_reads: Iterable[PairedRead],

paired_guide_counts = Counter()

gRNA1_cached_aligner = grna_cached_aligner.construct_grna_error_alignment_cache(
gRNA_mappings.keys(), gRNA1_error_tolerance)

gRNA2_cached_aligner = grna_cached_aligner.construct_grna_error_alignment_cache(
set(itertools.chain.from_iterable(gRNA_mappings.values())), gRNA2_error_tolerance)

for paired_read in paired_reads:
gRNA1 = paired_read.gRNA1_candidate
paired_read.gRNA1_candidate

if paired_read.gRNA1_candidate not in gRNA1_cached_aligner:
continue

gRNA1, _ = gRNA1_cached_aligner[paired_read.gRNA1_candidate]

if gRNA1 not in gRNA_mappings:
if paired_read.gRNA2_candidate not in gRNA2_cached_aligner:
continue

gRNA2_score, gRNA2 = max((pairwise_aligner.hamming_score(paired_read.gRNA2_candidate, reference), reference)
for reference in gRNA_mappings[paired_read.gRNA1_candidate])
gRNA2, _ = gRNA2_cached_aligner[paired_read.gRNA2_candidate]

if (len(gRNA2) - gRNA2_score) > gRNA2_error_tolerance:
if gRNA2 not in gRNA_mappings[gRNA1]:
continue

barcode_score, barcode = max((pairwise_aligner.edit_distance_score(
Expand Down
164 changes: 153 additions & 11 deletions tests/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from pgmap.io import barcode_reader, fastx_reader, library_reader
from pgmap.trimming import read_trimmer
from pgmap.alignment import pairwise_aligner
from pgmap.alignment import pairwise_aligner, grna_cached_aligner
from pgmap.counter import counter
from pgmap.model.paired_read import PairedRead
from pgmap import cli
Expand Down Expand Up @@ -142,6 +142,89 @@ def test_blast_aligner_score(self):
self.assertEqual(
pairwise_aligner.blast_aligner_score("ABC", "ABC"), 3)

def test_grna_cached_aligner_error_tolerance_0(self):
gRNAs = ["AAA"]

alignment_cache = grna_cached_aligner.construct_grna_error_alignment_cache(
gRNAs, 0)

self.assertEqual(alignment_cache, {"AAA": ("AAA", 0)})

def test_grna_cached_aligner_error_tolerance_1(self):
gRNAs = ["AAA"]

alignment_cache = grna_cached_aligner.construct_grna_error_alignment_cache(
gRNAs, 1)

self.assertEqual(alignment_cache, {"AAA": ("AAA", 0),
"CAA": ("AAA", 1),
"ACA": ("AAA", 1),
"AAC": ("AAA", 1),
"TAA": ("AAA", 1),
"ATA": ("AAA", 1),
"AAT": ("AAA", 1),
"GAA": ("AAA", 1),
"AGA": ("AAA", 1),
"AAG": ("AAA", 1)})

def test_grna_cached_aligner_error_tolerance_2(self):
gRNAs = ["AAA"]

alignment_cache = grna_cached_aligner.construct_grna_error_alignment_cache(
gRNAs, 2)

self.assertEqual(alignment_cache, {"AAA": ("AAA", 0),
"AAC": ("AAA", 1),
"AAG": ("AAA", 1),
"AAT": ("AAA", 1),
"ACA": ("AAA", 1),
"ACC": ("AAA", 2),
"ACG": ("AAA", 2),
"ACT": ("AAA", 2),
"AGA": ("AAA", 1),
"AGC": ("AAA", 2),
"AGG": ("AAA", 2),
"AGT": ("AAA", 2),
"ATA": ("AAA", 1),
"ATC": ("AAA", 2),
"ATG": ("AAA", 2),
"ATT": ("AAA", 2),
"CAA": ("AAA", 1),
"CAC": ("AAA", 2),
"CAG": ("AAA", 2),
"CAT": ("AAA", 2),
"CCA": ("AAA", 2),
"CGA": ("AAA", 2),
"CTA": ("AAA", 2),
"GAA": ("AAA", 1),
"GAC": ("AAA", 2),
"GAG": ("AAA", 2),
"GAT": ("AAA", 2),
"GCA": ("AAA", 2),
"GGA": ("AAA", 2),
"GTA": ("AAA", 2),
"TAA": ("AAA", 1),
"TAC": ("AAA", 2),
"TAG": ("AAA", 2),
"TAT": ("AAA", 2),
"TCA": ("AAA", 2),
"TGA": ("AAA", 2),
"TTA": ("AAA", 2)})

def test_grna_cached_aligner_negative_error_tolerance(self):
gRNAs = ["AAA"]

with self.assertRaises(ValueError):
grna_cached_aligner.construct_grna_error_alignment_cache(
gRNAs, -1)

def test_grna_cached_aligner_error_tolerance_greater_than_1(self):
gRNAs = ["AAA"]

with self.assertRaises(ValueError):
grna_cached_aligner.construct_grna_error_alignment_cache(
gRNAs, 3)

def test_counter_no_error_tolerance(self):
barcodes = barcode_reader.read_barcodes(TWO_READ_BARCODES_PATH)
gRNA1s, gRNA2s, gRNA_mappings = library_reader.read_paired_guide_library_fastas(
Expand All @@ -151,7 +234,7 @@ def test_counter_no_error_tolerance(self):
TWO_READ_R1_PATH, TWO_READ_I1_PATH)

paired_guide_counts = counter.get_counts(
candidate_reads, gRNA_mappings, barcodes, gRNA2_error_tolerance=0, barcode_error_tolerance=0)
candidate_reads, gRNA_mappings, barcodes, gRNA1_error_tolerance=0, gRNA2_error_tolerance=0, barcode_error_tolerance=0)

perfect_alignments = 0

Expand Down Expand Up @@ -188,18 +271,44 @@ def test_counter_default_error_tolerance(self):
self.assertLess(
sum(paired_guide_counts.values()), count)

def test_counter_max_error_tolerance(self):
barcodes = barcode_reader.read_barcodes(TWO_READ_BARCODES_PATH)
gRNA1s, gRNA2s, gRNA_mappings = library_reader.read_paired_guide_library_fastas(
"example-data/pgPEN-library/pgPEN_R1.fa", "example-data/pgPEN-library/pgPEN_R2.fa")

candidate_reads = read_trimmer.two_read_trim(
TWO_READ_R1_PATH, TWO_READ_I1_PATH)

paired_guide_counts = counter.get_counts(
candidate_reads, gRNA_mappings, barcodes, gRNA1_error_tolerance=2, gRNA2_error_tolerance=2)

count = 0
perfect_alignments = 0

for paired_read in read_trimmer.two_read_trim(TWO_READ_R1_PATH, TWO_READ_I1_PATH):
count += 1

if paired_read.gRNA1_candidate in gRNA1s and paired_read.gRNA2_candidate in gRNA_mappings[paired_read.gRNA1_candidate] and paired_read.barcode_candidate in barcodes:
perfect_alignments += 1

# Max error tolerance counts should be greater than perfect alignment but less than all counts
self.assertGreater(
sum(paired_guide_counts.values()), perfect_alignments)
self.assertLess(
sum(paired_guide_counts.values()), count)

def test_counter_hardcoded_test_data(self):
barcodes = {"COOL", "WOOD", "FOOD"}
gRNA_mappings = {"LET": {"WOW", "LEG", "EAT"}}
candidate_reads = [PairedRead("LET", "ROT", "FOOD"),
PairedRead("LET", "EXT", "FXOD"),
PairedRead("RUN", "LEG", "WOOD")]
barcodes = {"AAAA", "CCCC", "TTTT"}
gRNA_mappings = {"ACT": {"CAG", "GGC", "TTG"}}
candidate_reads = [PairedRead("ACT", "GAA", "AAAA"),
PairedRead("AGT", "CGG", "CCCA"),
PairedRead("ATT", "GAC", "TTTG")]

paired_guide_counts = counter.get_counts(
candidate_reads, gRNA_mappings, barcodes)

self.assertEqual(paired_guide_counts[("LET", "EAT", "FOOD")], 1)
self.assertEqual(paired_guide_counts[("LET", "WOW", "FOOD")], 1)
self.assertEqual(paired_guide_counts[("ACT", "CAG", "CCCC")], 1)
self.assertEqual(paired_guide_counts[("ACT", "GGC", "TTTT")], 1)
self.assertEqual(sum(paired_guide_counts.values()), 2)

# TODO separate these into own test module?
Expand All @@ -213,8 +322,9 @@ def test_arg_parse_happy_case(self):
self.assertEqual(args.library, PGPEN_ANNOTATION_PATH)
self.assertEqual(args.barcodes, TWO_READ_BARCODES_PATH)
self.assertEqual(args.trim_strategy, "two-read")
self.assertEqual(args.gRNA2_error, 2)
self.assertEqual(args.barcode_error, 2)
self.assertEqual(args.gRNA1_error, 1)
self.assertEqual(args.gRNA1_error, 1)
self.assertEqual(args.barcode_error, 1)

def test_arg_parse_invalid_fastq(self):
with self.assertRaises(argparse.ArgumentError):
Expand Down Expand Up @@ -244,6 +354,14 @@ def test_arg_parse_invalid_trim_strategy(self):
"--barcodes", TWO_READ_BARCODES_PATH,
"--trim-strategy", "burger"])

def test_arg_parse_negative_gRNA1_error(self):
with self.assertRaises(argparse.ArgumentError):
args = cli._parse_args(["--fastq", TWO_READ_R1_PATH, TWO_READ_I1_PATH,
"--library", PGPEN_ANNOTATION_PATH,
"--barcodes", TWO_READ_BARCODES_PATH,
"--trim-strategy", "two-read",
"--gRNA1-error", "-1"])

def test_arg_parse_negative_gRNA2_error(self):
with self.assertRaises(argparse.ArgumentError):
args = cli._parse_args(["--fastq", TWO_READ_R1_PATH, TWO_READ_I1_PATH,
Expand All @@ -260,6 +378,30 @@ def test_arg_parse_negative_barcode_error(self):
"--trim-strategy", "two-read",
"--barcode-error", "-1"])

def test_arg_parse_gRNA1_error_greater_than_2(self):
with self.assertRaises(argparse.ArgumentError):
args = cli._parse_args(["--fastq", TWO_READ_R1_PATH, TWO_READ_I1_PATH,
"--library", PGPEN_ANNOTATION_PATH,
"--barcodes", TWO_READ_BARCODES_PATH,
"--trim-strategy", "two-read",
"--gRNA1-error", "3"])

def test_arg_parse_gRNA2_error_greater_than_2(self):
with self.assertRaises(argparse.ArgumentError):
args = cli._parse_args(["--fastq", TWO_READ_R1_PATH, TWO_READ_I1_PATH,
"--library", PGPEN_ANNOTATION_PATH,
"--barcodes", TWO_READ_BARCODES_PATH,
"--trim-strategy", "two-read",
"--gRNA2-error", "3"])

def test_arg_parse_invalid_type_gRNA1_error(self):
with self.assertRaises(argparse.ArgumentError):
args = cli._parse_args(["--fastq", TWO_READ_R1_PATH, TWO_READ_I1_PATH,
"--library", PGPEN_ANNOTATION_PATH,
"--barcodes", TWO_READ_BARCODES_PATH,
"--trim-strategy", "two-read",
"--gRNA1-error", "one"])

def test_arg_parse_invalid_type_gRNA2_error(self):
with self.assertRaises(argparse.ArgumentError):
args = cli._parse_args(["--fastq", TWO_READ_R1_PATH, TWO_READ_I1_PATH,
Expand Down
Loading