-
Notifications
You must be signed in to change notification settings - Fork 20
/
plotRoc2.py
194 lines (157 loc) · 6.53 KB
/
plotRoc2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# Plot ROC curves comparing the off-target predictions of CRISPOR, MIT and CROP-IT
import glob, sys, logging
from annotateOffs import *
import numpy as np
import matplotlib.pyplot as plt
from os.path import basename, splitext, isfile
# Off-targets with more mismatches than this are not looked at.
maxMismatches = 4

# For the ROC curve, only off-targets next to one of these PAMs are analyzed,
# on the assumption that no software can find the relatively rare off-targets
# whose PAM is not GG/GA/AG.
validPams = ["GG", "GA", "AG"]

# NOTE: when set, only alternative PAMs are looked at; this can be used to
# determine the best cutoff for the alternative PAMs (supplemental data?).
onlyAlt = False

# Optional first command-line argument: a score cutoff, parsed as a float.
# None (no argument given) means no such cutoff is applied.
altPamCutoff = float(sys.argv[1]) if len(sys.argv) > 1 else None
def parseCropit(inDir, guideSeqs):
    """Parse the minimal CROP-IT files; return a dict guideSeq -> otSeq -> otScore.

    inDir is the directory that holds one "<guideName>.tsv" file per guide,
    where any "/K562" or "/Hap1" in the guide name is removed before the file
    name is built. guideSeqs maps guide name -> guide sequence.

    Each non-empty input line must have the off-target sequence in its first
    and the score in its second tab-separated field; the score is converted to
    float. A missing file is logged as an error and skipped, so the guide is
    then absent from the result.
    """
    data = defaultdict(dict)
    for guideName, guideSeq in guideSeqs.items():
        guideNameNoCell = guideName.replace("/K562", "").replace("/Hap1", "")
        fname = join(inDir, guideNameNoCell + ".tsv")
        print("parsing %s" % fname)
        if not isfile(fname):
            logging.error("MISSING: %s", fname)
            continue
        # the context manager closes the file even if a line fails to parse
        with open(fname) as ifh:
            for line in ifh:
                line = line.strip("\n")
                if not line.strip():
                    # tolerate blank lines, e.g. a trailing one at the end of the file
                    continue
                fs = line.split('\t')
                otSeq = fs[0]
                score = fs[1]
                data[guideSeq][otSeq] = float(score)
    return data
def filterValidOfftargets(guideValidOts, minReadFrac):
    """Collect the validated off-target sequences above a read fraction.

    guideValidOts maps guideSeq -> off-target sequence -> read fraction.
    Returns a set with every off-target sequence, over all guides, whose read
    fraction is strictly greater than minReadFrac.
    """
    keptSeqs = set()
    for otReadFracs in guideValidOts.values():
        keptSeqs.update(
            otSeq for otSeq, frac in otReadFracs.items() if frac > minReadFrac
        )
    return keptSeqs
def getRocValues(toolName, guideValidOts, guidePredOts, minReadFrac, ofh, isCropit=False):
    """Return (sensList, fdrList, validOts) for a ROC curve plot and write rows to ofh.

    toolName      -- label written to the first column of every output row
    guideValidOts -- guideSeq -> off-target seq -> read fraction (validated)
    guidePredOts  -- guideSeq -> off-target seq -> score (predicted)
    minReadFrac   -- only validated off-targets with a read fraction above
                     this value count as true off-targets
    ofh           -- writable file object; receives one tab-separated row per
                     cutoff: toolName, minReadFrac, cutoff, sensitivity in
                     percent, fdr, and the sizes of the TP, FP, FN and TN sets
    isCropit      -- iterate over cropitScoreList instead of mitScoreList

    sensList and fdrList have one entry per cutoff. Note that "fdr" is
    computed as 1 - specificity, i.e. FP / (FP + TN). Sensitivity is reported
    as 0.0 when there is no validated off-target at all, and specificity as
    0.0 when there are no negatives.
    """
    # keep only the validated off-targets with read fraction > minReadFrac
    validOts = filterValidOfftargets(guideValidOts, minReadFrac)

    cutoffs = mitScoreList
    if isCropit:
        cutoffs = cropitScoreList

    # The universe of sequences and the alt-PAM filtering do not depend on
    # the score cutoff, so they are computed once, outside the cutoff loop.
    allOts = set(validOts)
    candidates = []  # (predSeq, seqScore) pairs that may count as predicted
    for predSeqScores in guidePredOts.values():
        for predSeq, seqScore in predSeqScores.items():
            allOts.add(predSeq)
            # alternative PAMs must also reach the separate alt-PAM cutoff
            if predSeq[-2:] in ("AG", "GA") and altPamCutoff is not None and seqScore < altPamCutoff:
                continue
            elif onlyAlt:
                # NOTE(review): as written, onlyAlt discards every prediction
                # that reaches this branch, so nothing is ever predicted;
                # presumably only non-alternative PAMs should be dropped - confirm.
                continue
            candidates.append((predSeq, seqScore))
    notValidOts = allOts - validOts

    sensList = []
    fdrList = []
    for cutoff in cutoffs:
        print("XX cutoff %s" % (cutoff,))
        minScore = float(cutoff)
        # the predicted sequences over the off-target score cutoff
        predOts = set(seq for seq, score in candidates if score > minScore)
        notPredOts = allOts - predOts
        if cutoff == 0.0 and toolName == "CRISPOR":
            print("missed off-targets by crispor for mod freq > %f: %s" % (minReadFrac, notPredOts))

        tp = validOts.intersection(predOts)
        tn = notPredOts.intersection(notValidOts)
        fp = predOts - validOts
        fn = notPredOts.intersection(validOts)

        # sensitivity - proportion of validated off-targets that were predicted
        posCount = len(tp) + len(fn)
        if posCount != 0:
            sens = float(len(tp)) / posCount
        else:
            # no validated off-target passed minReadFrac: avoid dividing by zero
            sens = 0.0

        # specificity - proportion of not-validated sequences that were not predicted
        negCount = len(tn) + len(fp)
        if negCount != 0:
            spec = float(len(tn)) / negCount
        else:
            spec = 0.0
        fdr = 1.0 - spec

        sensList.append(sens)
        fdrList.append(fdr)
        row = [toolName, minReadFrac, cutoff, sens*100, fdr, len(tp), len(fp), len(fn), len(tn)]
        ofh.write("\t".join([str(x) for x in row]) + "\n")
    return sensList, fdrList, validOts
def plotRoc(prefix, guideValidOts, guidePredOts, colors, styles, plots, labels, ofh, isCropit=False):
    """Draw the ROC curves of one tool and write their data rows to ofh.

    One curve is drawn per minimum modification frequency (0.0, 0.001 and
    0.01; only 0.01 when isCropit is set), using the entries of styles and
    colors at the same position. The line handle and legend text of every
    curve are appended in place to plots and labels.

    Returns (plots, labels, maxSens), where maxSens is the highest sensitivity
    seen over all curves.
    """
    fracList = [0.01] if isCropit else [0.0, 0.001, 0.01]
    maxSens = 0
    for idx, minFrac in enumerate(fracList):
        sensList, fdrList, validSeqs = getRocValues(
            prefix, guideValidOts, guidePredOts, minFrac, ofh, isCropit)
        otCount = len(validSeqs)
        if minFrac != 0.0:
            legendText = prefix + ", mod. freq. > %0.1f%% (%d off-targets)" % ((minFrac*100), otCount)
        else:
            legendText = prefix + ", no freq. limit (%d off-targets)" % (otCount)
        # plt.plot returns a list of lines; exactly one is expected here
        (curve,) = plt.plot(fdrList, sensList, ls=styles[idx], color=colors[idx])
        plots.append(curve)
        labels.append(legendText)
        maxSens = max(maxSens, max(sensList))
    return plots, labels, maxSens
def main():
    """Plot the ROC curves of CRISPOR, MIT and CROP-IT and save their data.

    Reads the validated off-targets from out/annotFiltOfftargets.tsv and the
    predictions from the crisporOfftargets, mitOfftargets and cropitOfftargets
    directories. Writes the per-cutoff values to out/rocData.tsv and the
    figure to out/roc.pdf and out/roc.png.
    """
    guideValidOts, guideSeqs = parseOfftargets("out/annotFiltOfftargets.tsv", maxMismatches, onlyAlt, validPams)
    guidePredOts = parseCrispor("crisporOfftargets", guideSeqs, maxMismatches)
    mitPredOts = parseMit("mitOfftargets", guideSeqs)
    cropitPredOts = parseCropit("cropitOfftargets", guideSeqs)

    # one column name per field of the rows that getRocValues writes
    headers = ["toolName", "minReadFrac", "cutoff", "sensPercent", "falsePosRate",
               "tp", "fp", "fn", "tn"]

    # every tool uses the same colors (one per frequency limit) but its own line style
    colors = ["black", "blue", "green"]
    toolRuns = [
        ("CRISPOR", guidePredOts, "-", False),
        ("MIT", mitPredOts, ":", False),
        ("CROP-IT", cropitPredOts, "--", True),
    ]

    plots = []
    labels = []
    plt.figure(figsize=(7,7))

    # the context manager makes sure all rows are flushed and the file is closed
    with open("out/rocData.tsv", "w") as ofh:
        ofh.write("\t".join(headers)+"\n")
        for toolName, predOts, lineStyle, isCropit in toolRuns:
            plots, labels, _ = plotRoc(toolName, guideValidOts, predOts, colors,
                                       [lineStyle]*3, plots, labels, ofh, isCropit=isCropit)
        dataFname = ofh.name

    plt.legend(plots,
               labels,
               loc='lower right',
               ncol=1,
               fontsize=12)

    plt.xlabel("False positive rate")
    plt.ylabel("True positive rate")
    plt.ylim(0,1.0)
    plt.xlim(0,1.0)

    for outfname in ("out/roc.pdf", "out/roc.png"):
        plt.savefig(outfname)
        print("wrote %s" % outfname)

    print("wrote data to %s" % dataFname)


if __name__ == "__main__":
    main()