From 35c3efa3a0791281165316073e98d09142028b3a Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Thu, 2 Jan 2025 09:51:41 -0800 Subject: [PATCH] Paired heavy-light modeling (#92) Enables training on paired heavy/light sequences. A separator token `^` is added after heavy and before light chain sequences. For now, only heavy chain sequences can be used for validation. --------- Co-authored-by: Will Dumm --- ...l_pcp_2024-11-21_no-naive_sample100.csv.gz | Bin 0 -> 20436 bytes netam/common.py | 34 +++++--- netam/dasm.py | 5 +- netam/dnsm.py | 4 +- netam/dxsm.py | 27 ++++-- netam/framework.py | 81 ++++++++++++++++-- netam/models.py | 4 +- netam/molevol.py | 4 +- netam/sequences.py | 75 +++++++++++----- tests/conftest.py | 12 +++ tests/test_ambiguous.py | 8 +- tests/test_common.py | 4 +- tests/test_dasm.py | 3 +- tests/test_dnsm.py | 5 +- tests/test_molevol.py | 6 +- tests/test_sequences.py | 31 +++++++ 16 files changed, 240 insertions(+), 63 deletions(-) create mode 100644 data/wyatt-10x-1p5m_paired-merged_fs-all_pcp_2024-11-21_no-naive_sample100.csv.gz diff --git a/data/wyatt-10x-1p5m_paired-merged_fs-all_pcp_2024-11-21_no-naive_sample100.csv.gz b/data/wyatt-10x-1p5m_paired-merged_fs-all_pcp_2024-11-21_no-naive_sample100.csv.gz new file mode 100644 index 0000000000000000000000000000000000000000..838457b60c7898aaa96044e500793cccfb3f6bca GIT binary patch literal 20436 zcmV)9K*hfwiwFqba9d{r|95#|baY>EVQF$@WM6J?PGMs*|H=@ zZsohaVxFZeT5KN6yulzFjBI3j!sxY)YEs=aRa0%@kpLVxIOq0%fByOVKmGW_|NI}t?XTbe>Caz({M}DK-2VOhKmGLezi)s3{x3iN z{I}ov54ipQk3W6=K`!;L|EPcW$J>AX?hily{G8E?>ph`K$PP z`})%#{`j}w{r->N|Mc_k{`TL0{_*y2fBF9B-~aJ596%{mXy;@MCjB z1SA3kK<4p(zm#9fsv8Pdj4$fvzy0)w5f=m^rtx=Ve?dTu<0p~7BL8XQpJDu04FBf; z7UpV}o|J}d-@h``36#n2}ZutKGoA|FVe%bJ?$1fUQ6@(WE{v*R5{0D>oaR1N# z_4=0tC_k3}VE+XIQW#!M_(%r%>B0y2Ph|aq0iW;z{+ItOeW@NU4nL6oZ~-d{FN_2( zgg4?Jj`%;(|NhtdOIiMX(cU_D8{aeIyX01O<*n-bTV2=wp2!au-Y4Pagl{ach94n9 z2;t|`ZvuKx^lpUL74Yzga(GGO4-sCqha$W<|7nEb4$2o4T>KvbZvMNDUoIdfpY9(A z0)0Jl(>;uO6T-KaZ%T3v{)hKpQ}H1G1^tJIn<)1(JaFMA$d{2{{^#^BF1&RAv)CQ; zKVom-e^Ky;@_A7H6aDkQmS5cX+*avx>)Kn@cHb4ZYVf~*{r*q? z_QUtL^2aaZMxYE5VCr!ApP8@ZW0iqBMAo*UWme!+nmyi>Bbk ztHCQo0C_%qoRZfskB_OABQ|cyRO;lE!r@ zJqYAu5CBZnu#|6nXRI;L*aSQox+h;<+jF@%=_Zv-Dk@So`a{A>A-U+N(|ru}V< z>t)l8k7_-pi^ivmz8`=8kMIBd=P_G+*A#Bd7~QS9Rp0exxJ<^;ZyU$ZyCXm55iM=oDRlALs2$m(Yt!g2jL7+S5?&(<1;N;A^iS#Z~Lh_R;oZb z2+C8Sks!n_|8>KIf=T})4P3NV$EJZRoOk(59gCrgod+%-H0?+mV!T8$?ezG1;KZ5X zW3c8o)|wTndgGK(b+^#f%332F^Qpx&g-HTScx~2u(C!H6>j*V2=+_GU3zV)MX=@}6 zYb_-F zhM}t~=`V#VNe+5usLwPqVo-w9jSQ4BUweJwU3;#P5gXvz_YQYME~I&gR~z7$8yOP3 z=8X(rHHS6}N`lumXpBO9XP4*=?!Dkjh;RD8FY2`rU;4ry!VyEt)+9RN6rJdNbP^m{ zPFLC}pse+6p@}v^Q`*MBBs9rEO9*(D{S#4B53-7q9OS{kE{?`?1!RP_dC2LhqqwK!-k%t^;1R0(_`QT&=QrBA`s=Z)q|scEAuJ-f*1lU$7%qb}O)Xi)Xd_*dhH zg`+`L6t3<2GCdizx4LspTL-36%7jDVM;1kydpL-EYzDGXz6tH8Yl>;Noe}QImdUQ;01n)yo zMg=ZBq$ef?t{dvw*EBLIsG7;WPt3i)(6OP_#v=6At^Nk z02x!t9M5hj3OE{LG{LiSvTTZpQp;I{OUqQ;;6rMoHz0$8q!r=;x4^CpiCkZUJW12z zWZxXDt_U?MN=b=2=s}rJNDxDfX-R{6-mXC5tY`y9TPbmbNk~d-f&Df$XQzTU;va;# z&jHPs##nqqcxR>TQ<72S(^}dfWvqk^n76d0ZVj1_V#@o>YFyaoO+-HismaPdU-#I%t1e&sj5kCktE`TN}!QmdD@soj` z<_Ohlx2U-ph4x7;aIpOeVVA6gGCppBQ`!ROi!E?`um!GN$54sG>3r}~1nKh-M}$IcWtA<0rL4?E-&jaD$8x!j&8giW}?fI!xd= zZXr~~$`*o_s|o-~dXxfINnC;*>5-k{j!F2L^>`#qTWO0)hg!~K8t6jBE9Y}--~rhH zm2LbEc9U_ci^IJ|S*&7R^U`yvMTn$vA-P@F!NT2ODO6dwY(OOL1>k7-B~#WL z!DVlRr`|}UQCe@rShS1Yi16^qRWRY;C%_Vs!fxZ37HlMpcy$fG_F1~OUU)^eU^Wd+eh0;BnM`Pnro6&?i@+yyp)_64ZVLTcm{jb-* zxN?GcE`0jvSrR3_O?5nEHw$l7C+BKbw z70RY+T&<&Ie>q=c_HQSk82eqH(U8zO9g8M7&YR${(gcUqCOBG8h$4Izg^v2*y5ml= znkwlSBv?v{l|Hyp8&><^q6zv^@2}K}t^x<1EYb=h#OH#C!+Dvk!LeK^H#Vtt_h0sx7<5H)_We(C<{Qyr^FH zXQPyLL4<;}UcD~Xs5FLw&PKDfTK#?pvr*30YI&u4oepL5wOXF4SCHPjR*gFX~P-G@9-4r8v6JPCIe2Z=yR z61nI({G;D0?sx`wz7zK-WUC)s6=f5eJ)G&o5Pgu z#uN`QrAJt*ee2xNG_|r+n`RhX-!<8egQS-6L)x}(hQM5E+Nx`VuAWaUcEZqIa+bZ+ z(+<7bgEw6uW9)*cy@hQmIT zUSzNT(#9E@((~Nl*-LYgJS+pTX?KY`1?EOwLG?HMjVBwT*}G%o-AgdstWt-w`M zbak(inxa7$pRL2^1Qm@UZQ{^B5(&l)EuX+J=WQ zN;Zl^+i99B%L&(XK$W7Z!Ai^T7_9idtZ*H4q1oL7x|l+U6nzFdEzj6rhWpfegeY7*Dj1mcuyUSD?T)~f0oW*DU}Dx zl;U7TDbBK(bRL=XDFeP^%)JFMy&=96;(cp|z<>%Xd8VNj1D9oDl4KaZI$LZ}r8P)c zoPky^d%>Y|pe+&KE!X%lt%*oK#H!!89t;X#5Qh*t4{IAQkaB_+iJ@=Ruf_Y6&Lcb` z-5z6Y>{4R zV_C&Xo8@GPpPPMkp>f|Bls5>7$Mxie_f%<7qS zXV5@+t%kBJJJoAy8|SLpi?*hgZ{w$RU0?ge&^4`yHTetMsg86di3T>1?%cft>HUs4 zrycRg?-LiB*bAq{kQC+v23N!GZJ}=sD1M7HAna{L4Sjj}HD=-n&Vw>8Xx1$I5d@tHQXVJoijT8aQyIQ7$pY zcz$Ba5SH!UP(BIDV=2FG)=ui%qVtuH*uQmM4gex5? zC}m7V(r9WXRUhF2VWwpK6kmeGxe}l2OR(l#No!qkX^2Ep4cJ2)^03ywBE%vrPxx2^ z(f4h5{5DL^5vxZK*St=~!-C>3O~ymSTB04!SuMafg7+hMMGLU~uJ(kd{?^~zcSXOj zmahD$+j*xr-^yw#k#7oDW+MtJCoLVjX)~fr8&I$*FBC>eFqapS^V+`uvGy~fwPg~% zgqE*_0@1F^Tjs24S@Q0e$}$`-3Fk&<%4s9UOju}YnS+=YkjQNDB7B)h@D#p~RKj99 zVruwO%_On7G-?7`bDforlSeAfBNa(M>W%ovBCa~I^*?SThdu3Z$wGJ8VForQ+3Yn9 zZIFbvIQC6Ln<=iUPx4y1jH-brq0P7pS-DIrp$*o`nEM0dC(Z_Aic>2Mggy>zUkl8h?8 zsNF@fRYr}9y0NX2s-io@Ds~ELnZ`=$UNlZ5TqlA=LANCRbBp>* z&o0}cZ(O79a_!)T@&EH>>#FgOyK|u{RTi#n^gt0VHBE$- z>x^c@R}2o0hSq2;j|)Rt1@C+10OO-dp!VUCW9iSH=4Bx>u@S2TIP&+|gK1cdGL%wry8j zM&AhDkDw?3R^3Sjz@n_$x~m2W`o(xMa6?(<=0eq2*Q?U#&d$44XUabG~C;v zUQj{0QvvWOQZ3AdOwAe1eDy5z64mp@1mZb~R5+Onfi*QYy5PKFE(EI-{Zq`=P+VB! zQg@s~2b8QSvJft%ZAY{>=$`=kWID8(D(dzU*9i^?I&n)bA?O4mvaWZWBj~7?<`&lp zAK^OL>F*pt$5tKbGzaz#<#V7^u2Xl4>m-e(CalX9ZC$o~k&~Rp&mNByK5aB(+Nj55 zq5Iag<&E+NO54>pJ?-*eH$oV~ofWQ`(!nLher`%`Npy-qqO3jHEnjR>lQ<}U$Tjo@!vY8 zS7bWMF|!nHID`cGzrJt#t|^AK>9&6t)LqqxKX%`j=8x?!)eg103CoaW3dyjl14v-{ z35!Y@Bk=$zZd?i^b7l@Fm9q2j;gc*fi?-eQYrJ^ zVT89#)^OQR!Q5y^UV}0{lsch4n@ZUrFkeMwAEjZ=-anI-v;T*03wl_Ao?it$-4DHL z&PpT-4T*c8pAIkeZiZQO?wM2Gr*`fs@7#0VxwpFp!uvG!mdOKp*13m7>Mrx{L?Kt} zm`8v<^6tT12RBe2j|`ZN`(vGZwrT|+>JyP1FeaPXrOyKF&puv_AZ;OtDS~VW>QMrf z5(BV6koK6f5L*S9AZUfsg9ys(LnhwGqmR0%YVT9g1^d&gsK!5ZeV#XlTRHx9uMQQi zt{QQusEW3UXdu~ttS}IM{?s-(dCS7%rCcF%4RmA6_`nA&J zkgdKHBPTWvj4A4=NS+Y%kfYp;dh-gOEfg5qUZIo>r_lF8srw}yaM{FO?LMs zlyKGl=tJvIws5L+fOwA#Mk&UouAb4B!D?VFEyx5|RJ`531V^mF^dkV$mc8(SbPa}; z)ttBNr9HForid+j;~m7J+LqbSc5pc7_#~Ye(ag}AU*9lBQ9ilm{W2)&J}8-0c<=R4 z@>wYH#OR~?QZ`2aKJ=!j`eJCx@t_hWAgvp^(wnL-4UJd@7F&;kp@v~Z4n09he<^O zG35~BKE${RF`R-J$(n1M5W@zwT%jY48>6PuB*$HjNoZheccvvtY~^CgsP!* zy$!|#m$rU;+nx2cld<$?Sl2CK#l%tXV_}6CSO#8z^~jm7XlIIbY>IU{npYV&BlahX zb)2lLnEwsh^u@f&O~sv!4nMQXu{0Yc-Qy5dHM(U*APU#%0L$jYrm3f0G8O-`GC0=Z z7At9@Mda0g1lbl9z)}p!n)^A{6FrEdwl*V+;Rq+{{_((CZ_EBQS>4}sRn=8~$)~bE4)suX8DqAtZi8#fuuRwY zPN&)7N|z*1r2q9NSm(>?J$*zqf{#@r_>5|VDi`Id5v0ctS-^$+RU@Xx=91Zu`>GN0 zH73ajvL-yJimGl%D}VZ;X2%=-&qDt|s-tl?ZANQ}3+bxBYnZyyddg=a91`9LK8E1o9(HCn;s16TVGdR^r^Xx zM~7Mnv*9Y{4Nxqk4ITNi#cPxSX)=!mn8c|kube&Ex=hlPbdz!>Tc@2m>DxxxI$ASP zn`AKt8{Q5?#h8;t(zy0 zq9yLrp@Q~t=MTaiPoFy4a>oz*MMiA0-Fw=I4fglSo?*lWn?`InH)4|rqOt%7szh`W zLB4<>9wBILqF)Jum#pl(E@0mls@X^R5!P>ANSCosiRn9?DpAwRqTQ@DgF zoW)c$oP>ZSUL5DvDlS;mb>p+$jQ=?qKzFXM-Iz~ijVNCZwcfbS3jh{XRW~Zw;V+}w zII^CR0`_|nfIHogQy~p!f^?v5yj4LuilRg`7o?*oNQbhdi%I$*-Q-Od#3&P_%hE?u zv9WF<%ofBfr14CUjxs?y@pNLP)e!SNDjvMFuTy+Q!v7Fo%JC&q$(uXyMZ*^od|B|t zYxr`+moMYX8NQU`OKbSj*7n_|ykF@groBM96iz6Y(ml$hu+(QqnDCN@`-ztl{@Pt+2BqyGZ*r@2J+4_tNwZcZnz z%3$NEWP;C=!RC*b!R9D~ee#s?2JvAKkJHFpD^^@7x~}i)p~>eHLEf@z#spH`dcSce zgKd9h0}3#^df|%NV7u?H2ti?N{#~@-lyC#f5GErjTNj$ETWc>$2_d#Eth{F;`7wN>z`71t~cOhZ0CtSve)Q;(~bz_Xt6~L4P{v zk}k@6!_jCg$gXUC?f2Curl$7M1zy|wLf)B7VA@v?`4W=7xvjqX>fRh3Dc*6?SBFFo z8&nHeXrw~yyjL*r4c|HV4mal9nvVE*)adHAZZidsy1$JF4%dfKL^EDVqsd>HemE&2 z!V-w<)kAfd8X!g~NLV*VQHt2sSS3nwqUel<< zs8I*!HR|w{8g+E2QKvM}NeAK}9sEK72MP{H!=P2KU*{%a+mr^crl8J)Cf;UsSE0*q#7nH z)Od+vK#^l4Ycfpd1Y&WiUJA$Drd9P)#2Cd)suGfp20W>a^v3#0SZnIoGMwn1R`)_~ zmq2Ii5=iY5Xi@j_XW1p_u-!9iDHK9BTBzRSR{EWpc;Yr7) z9Hj~xaB?flgdJ4HZ4z%yr5<=w>cRGi52YS(OX|T$J88`tRg*$$HC$Nj6wU0D~gqq4Q|fh zCRN%|bxuy<7Dfus8A=n}n7}RSPa)~0aML#2D!S2ssZ4WEH4LKWUU3JL<`lP82A%gY zX>ga@`aC9$wLmao(s0Uc6>8WyIuhj|PD>%DLCj{&YA1aOM3@wnjzBaa$gLHJ1&BBU z5fwxsTLWf_O6NeN2k1zC?X7ct;o7PUtfaE4TvfO(rUBVajt9Rbi1Aer>q1<)J#TnRFzrh)DJEwmn7$Pp@eoJ;c5%Ea!89g-q3B9i z$q0H(2z^m{GDUhcXnft%HcseBLswMFNAi~@-7Q@@6$v??6ds?DMj0SHIU%L>38^#< zNF$#SF4rfdyAf-nmBA)$6x$c9Q-`eo_Gs#K64Fs<$XOvZ0G`*{P z>m5~1-OhVQ?lv9|Tve+(scR~8Cmk>K)F7uT(9=I{>vGEhUmF`7RX@QN#ScZJ!+C#d z7TX6!RZlwsN6~1OEDiv9t3$=;l%tNUyHq3%a1dUyiAL+L*ix4}Hm=14CcZ9AMWYkz zOwM%T%;Gsx>$R0`-n1$PhmoLc^^VD9OjAn`5^4Hf zmOf+D=T5C{G}vvi^vzbv_V{RX%4$Y}Gx4Cv(hol?ML&melN2$w3gB!&p$>Ed7L z5oKvbj*0G-2$J2XyX2UVY(;j#z^5expH9j#!6R}^5e<34+5+Lb+D&hiUxo4{!BqYO z4JCeuWz*9tWvRTuXQ(lqQ4Zow(xoqMUwnf~~2LZnsiCo~;QfTQiPd)o;zh8oY$>dLM$Yi6H-C5WK>7 z_06X-hs0vS3#DCgukgiX;k(P{$4uz^jp;Kny?&Tb zkBG(%RWX}v5B||_WxKkrn^F{A!&OKZ?xv>S>J=F^n5v>#HlP641VYmn?DY2jBxWO z!p#rDt#YsZ-cjNujM&K`^KorC{CI6S8EL&ln(sFNADxEpkOiB}Bs}A(ZR)-px>}k{ z^-$JrQ8vwTpE?+nHl0`K^Vz{QgJy^EDr5%{dhL%Fwcuz`%aU^xoR)1ShBJ#=JPSzr z=8ub7N_x{d4^h6TMR8G!&Ma#2U5i@!Z6r!cC=D9Nx*`#G5hT1mXFKj*(1l#YPXL{M zNubjc_TpZm!AI)0p?7*`zPiEC5y1!%ptvX^NBLIQ)%j`+W~|Z{_Y$1m@STJ2!gxB6 zU;ExQWii}(k!dYkFRYZ?KDhGM4&8Xn@Z6@_^`fc~t|pytR5$vt6Sd|e_k7DBTC=gi z%a%dbp0i)@k(NR8JuY!8AXAr4`8498t4(&g%Ry)bdYA^+d;xk@n*o@hr#smj^c1>NWfrP)(#G-5A zCRr6lUiKsdP&zE;9D2;BoLh^@f3k>y^hn3avlxmxX zqVcMz9ffoxy?4Iz$nnPdHCP{PQ|t*gK{9?-liX#g%&WTqlAL#U0q9^CKqia2uy0v% zxeI_{T3TgLq=%=opcJjMH<1379S)u$W~Nj+Q>hdkdT!G z;!Q3bFWHC(Wv8eLE=B1kPPBwT&BHjVF>tiTRJBr=gnKZ`4GV*YBk?6tFwrW-l*T$A<$MCt1ZU5Q*1N6uos zGzi(WCoplpv1$lyAWI_sbfi7;hVLQxymUt$DSE?L0yRx2g{roUvx`KzqaViScPNIy zZ>s9Pt;Q5mW-xYL~-&S%Px*J{WtS z&n-*PHhUiu_CDt9ebtc%Rs9nQlEB_aWAD>!S%SCO`(*5WP*IkNy$^HtKF?JD1QCvE zhVodWIY~)Z2oW zWUQymCu@1U5qwzxb<7w=FC8c+EvVjgb=?*##HV&VYE;FT8-A#7by=5vBZrQ$SgP9C zh@#*XqaqB`p(Ee1i$wRw#bSvvu~i`(_x3?C#--y2) zaT$jWJ*>_bh9rJSw1E_i&ju; z&@=x9VYbPH%Qk^^A6cpm5wsNRg*@olX)kvALDxEMJ9Ms|ns~E`g)Xb20Ff%XN+e}H z#o9_TP1!n09&acghtemH>Py+!bGy`{O_E1ApCsZReOchPD;)4_6Gj zuqT&B#|?8WkhU97*4X_<+IVVFoE=Gkb}TGlb|g^=Q{6b8%3+wwVcb5F!1W`EQiC?7 zC{X(BHUxu7(OZ?p<(&Z8FX0@Sdjn;sS6n-ryl=zj-v&9JxM9Oq+cmPHmX*{TGqUw+ z`0f=oisRvXx6~-zQbU-WKXkzGRQU zb)msG)-T5TU_oQo7L^}W+(z9)*;L$~T-tC=CW@oT(I!Xqpi3_i3v$J6iE5jQ+j`@K z+oyY2n5z2P>HuWx7;gmcNAM1oZP5H^TQ`(d9nTQ5(AO5Xt|*4!l3f+9aI(r7W*?hu+9O8Cji2HyaI~V%X84 z5H&_py8>|^ofoaNszX7dLm`vLH603J>_i5dDQ2ZQ6gp{*JDu2@u)eVNqqE)w z(^=^`4U*?%JXY9Dp^LArbG6?F(>T4V*zv8QOi``vmyYs*u3uGtiWm_sIv39pBZB9N z5e0v9`bOoPt|5ruID&_%pwAVZ2Jckb4r*sw*KJ$RhRE&sxvuETkPND_9Qra)kDQ+o z$KDSZi$0W)dpi*i^oM3qiY-bBa+irPQs-qv7*&1Mea5iVb-7QyFX_a`7T1fhrTnE( z-$)Ku=h>`p&_5dVBLY(U4AzYGkBv3M@$4?pGYtyb7 zAN+8JC|k3u3zx<{Qh#vCyS%E-fq8X%ukjEkJ6#%0NI5DeoNyp_!qIfgXN-s7apNI4 z=Y%7qX8B1W&Nxxz7Nx4uW#`yJB|4u&9=`!}{|&$sHvm#!?Bg3CWoEnq{G1!W&%Xf& z+jY_8^rIC&=n)sTPgVSw+-Hx}p5B;lV#;SQ#k-yD!g|Db7U|kX_akcO%DU?6Y%^l~ z{I)OKu(RFP;|GT<6Rz}d0`Uzi!^C(vDORTvF~;F6#^HvMGp)PZsj8dkH3h76Yy>MY zE>@UUC+Bu+;qi);=2?uh_GA*{a3#rT&pnB8LV?+|TZ1&lg{PENstJBSjq)MnJKNp(FIT^>`qh#qg zl;onY%1Yc+b#|I0fpLS3Gi{B;VJyxh?@Y8yT8JF~JR{=>facjoR?1u!4SXYb1;IP% zOl4KKU0Drn-py)nUDFhOXl7Mq6=Y<;erda*HF+r2tYf9hPkVIbS;;`S)7^R6>RR>G zIn{Ef{1>M!ce>beH&30suR^C=?hwU=gh#<5#BnFP^$qkZfj$~!vPdOEAig z)}tXL4HisUIV__F5LuYSChD}}(;LBi5fss<%2|tQsruA!w5W!v?wcmh1>IQVjK6uQ zz`7jjQdK+pF)zgBIAmR_#&~wQwBof2+;g_Kp(9+G~b~7hFS%l)r`QW zXh!gXW`shd6g*vAGr>QxRQ|<5iQg`ill(??N`9m0$DrS|z0>|uUA9Bx`b-nB4F};~ zl{Z{DRHg1v`%6g^ZFSQnb}_&hUs-H24X-pcG9Qk(sio^%8?|gW8LDb&=$@!&vw0b* z9jBhjaFHnJg0l`j422~Xtt zYrPna6xP;yX-V;i#ab^twAM=xt@U#BXqRTq89RkHluv{5cu%^^wE>-I1D3(BsOy{3 z`5ciZj^S1{l@|c~p}ti`Uv{mWFRCW;rdVk0(Ok~5M*DK@4b?KAL?18{eSkgD2gpPp z@IKK8N<|+a6Mev|q7OlwMGbf)+zR(YL{eh_yUdt+!~Jr&541GuA~a1nHcj|YCOfuO zR%EiJX<}?i;jb~p?z#D!diU^3Y`uV=Gq5=Dromd_-UEi+lLD7q$-t>I10 zK|r-4C8kcQLdXdrG^z^G8@|WjyP`5)epgm)H8q{w_@{cPa)*j8Xl^v#$IjDf$NVtJ z1fnd8HclY4!4~Qmr4JL-j+?3SxXXGs)tCdwh`Bk^3NOePXJwy6y?^DqMceOD=`jZ zafxP%Ls5B6h?QcN5Jl+JVnpJCY{-i;JSgy-kBPu2lEqNYok31+K7o}d-8zX?*avX0E z?*j3X3g46j>b7mh5A3tWimtkqUN^i6qOM)(oLJmBH<%H2;1dAEeq5Avb;~3`IjbK| zu=;VaR>5YK!)BE;5z_Wp{hYD-C6Q`4%6sc2P2xB#sz1miK!~Ick3$wM{Vl&mnllI2030I%q8#58EOM7Em!qsQ?#=!Q* zq%WhtWL@Q~fy5ii$Dy1wz{@il;QM;5U43hcz8W9)@wm{swwK9uxYA64#>fpRgC|#_ za^DJ-qZKNhbK*3ZX>#H;OQJX)y=a9h%5qA2Q$-c!D`?dnZyTsxIZY{la^jTTmC1?I zI`z2~sunwx5ceQ4F@WgeCMQ%@Nomrj(OMs{-}j(AWxR1WclkSTrv!JtfO|APtyay4 zf|;Fud-DL-@)zwmB>Na(=~rFO(R56$CT6wYB-`#HIh)z3O%_Mg`(pCx6Aq9y?Dbpc zsugUvTq2(a!PXrWy zg@F14NA-_is@_)Zit+#3N{+|f5LWH#TnVJBo35^Up+Z$%?fR}Q+bo|nw|eldsWQ6o z(+xR!40T(CgfiS%dlIB;vuhjOC-jPztQ?;UZ=qPl`;`7I136SdZ4p8{kxed&M9Hh8PXHb70-Zk$bhs1ri4FB+ zB@jBSo#j-=YSO%@oq4TxrmS{W7Q>0!8K&A9pIAGS?r=&e@v39U8_K<7$O$W=kUjcx zOd!3Q5xKsp+qSLq;e{JdBd&6xZ&kJ9!9?~ghQ<#x^f;ocuB-y1v2oQOuHHcmWN^Vy zO{0mO60KpsxdpH?gSFYM&Tn-EAlIT;M6vUFMtougU?F*px*nthVN^O3t0K3E%*YPE zjb!BjxM_leCl~YHw_)ox;7e|UWEfX(dl79H%tZnj^5D(pC)=vtGU z(*q>W_Y1(nA}Bp5o#^so6J1=L=-P47Cb}3Wy6S$tv3*9XPN{DwAB9qe<8^!1Db+hr zboJ2HdOm6D@!1}Ryx%o`k+!VMpaoKm$C=WF4AYm5t2VtVj!B-hKAzn)d6rT^q9SIt zIr}c#sxvfAXqP-sI`&e7OD@}+@k_M0ZBB4@$N}&oSd}u!_rl5#!HORRD{pgT8{2Bq zV$2Q^Mf)u`;Q5a}E(@=bE;5rL~}nx@m^KaG4K9*WJpY z_jQe5$7|}LE!2m?UkV9IqI_3nH7evi4;4k6YE+nNR5%uo(6{bIVQgt*f*Oy>t=7h% zOmwAO4aF)d-A+qIvl!94C#z+wJx$R#~W>0>6RcYvyaT=-6i5PE4vk%ve(-Ng2q|7+Y*a5~RzEeCKG|Fb; z*{~*_%__Vo+si2HxLT;*CqMp1@HB!VKVEj{rJ9q*^eJG3-)6wsJgP1fkRXF zLnQ@$-&TXE`+@mWLax?SXA^8I3}B}&&)~Y`*%gI&8z84rzF9aQad!Xa#b-LQt>6XLmjl+U$5DKMoQQ-o8 zG+Sqt
UrS%GzcMc0CD5gDq6Ls0&4}&+FPeoIX4E+}SsV<7VDOB~1C&_lQZ&!9z zS&xY%A6gHdE?c`w{7d_)Rsot1W$^#f@k&z&8LvbQf}@j>TTLO?*MfO#EjXqSQKP#v zn(-`}(fMe`PmN|mLL$>8W%UkH`8Kv(T!*3q0PqA{rD+k@@NN7mM!bn)vEIpE-h04 zo1=Kh-;plcZ%T=uu-t*e!Kj*HihSg%EAxis4$GW}MM5I6+zH~3i7PXm)*MndHpmo9 zDOD`h3=cCW&WvaiBCXW7QX=WsROLb;`^zZQ{LQDhMJIH`-?)De?vp2Wzshl1HAiJW zBY`bOu*unW{bQ~qH&7U`@ zV;IXdiFb6`BJ~a8RS=J*#FIt5>bqtP{!`ZIx_&%I%&b!T^49mR3nS<8kkN@zYPeFL zR5?!x4*u4~)5Y+l(}mKUqfGmv(AC>}ve1RB(8VY{#2eB>Q=v<*?v?0Dqii^60dMMh zOe%bRdO`d7h~qO6kDvWa#OZ?&r_Fus?Q=!_!m3zy@x>ecfvTAAtcv;Zs#v&TcxH%k z9z^;m5b+lV;+2JM-+3zPvM9@1=&yCxjmHgF*aJYUa07kaz+GR zh@wwtxUVURAFu0~S}R2NV$PE;6+sb_m#gS)12Q`T-I+V#hLGtGErzVHLE_Is7Fdm| zdWsTTk#xU|MalSK+lj}3#$O-Md^gZK*4|bckN)8>3umb%^u3)Dy1z|=77Sx7V~u6F zAIoURbdBe+j1}D^mWd`elNrKo3Y@O!c>I!ZPR~XZS3&McjI8~hN!j&E+WR(`JIWZk2x=Sen77}&56S+|A)bz1Y z7qwl9x=EV|6w44dGkV67*$d851@RZL@JgZHiJVr}2x|*feDMH_IE6k_nLEg?QXlz8Cn@g$iiGR9wEjw;a1XuU%IYc zpg@(S{}?$qgv*c@+zm`#SESE5norcPHlJV{H?8mg?n{@dK(D%It04m1v2^`Bmm(6I zClc#lj6Ng~X?x5Mm==9YxYP8;c@FM;Rx&p3MDRN9Xt>)#`aQTG8Bq+gMhUD|#COkN z7Xw&0RS|>5JXjmR;vEB6)T7yIMVxd=P&?|V74f;5#7hv=7w{}pkPGh-gpDPNPW_Ad;VuuigOqCElk#B&e<=V4$2>Ts!mEAjOhoGkGXin(o z&2rulzZ_zjI<)21)S>Ke+Dc|!P(@WN3>5%l;>yWJMv))Z-r>^45X~)RQ$iFtQNb3K z=FR&fwkRAEwzQ^W&s%STQRvCTck|6m>hy z{7ZddG*tZ%sI9JRJ2^CjE1f&|-`h+~=^XL?l*zIk_tuE_v`V}ORXJ9J>B zE}s&E;|rCm+|C*CKIw2rA-?$Ua_{FQtUm60zNE(}<8x5v-x|ug%rL%z>Qz2cc}LHi z&Q1;B!Jaq0W@@&E``Xv z^P%u3$0!77go0%h!feS(_wAF)Q^m1}$X0`8EU!CZu-3gMVL)sI{WsjthMU*mo}z!V z;Tm*~C&RYmO~(pBBoQNMj|)#mGW^uWSSkc%JQzUC7v)Wc%%F{ z%E{E=A807?J1m=?HqHebqfa^-eZn;QRD-_x=u?d{*|cCNZ%ys@pM8#~o2DC7A;VwV78|RJnB>)Ij$EuTi|*gO zU39P~+6!c|J^+WbmudG_RHCyE_I&qtPJ6*U+6&LNooT*%JEOhuJboowlnFlq3-#ONFD4~6?+`>if+y1fcZR~jvptz5Qd zigvWsj+M1a*7b7Md&weFB`QN}vOtj#pcN`v4AV_@J+YDc<4E3@8UEal25m_&TD)>} zDhOQ{I|hWJJ5e6RWN|CxSdD~X`V*QhN1CJfk?_H<0Y1#(^Cyh)`LxG#d^DRTJfc-w zhGvwTNwyB<(0doSFnMm9x*c3UG@?va;_DtxUww!#+-a4JBQ7nfR!yn zx9$2#gW_g5nS`xm#2BmWp37#TW_Tbg^hA3NnI|T^oXV_!uC8BXM^l-aDmBfT<)Ptc z&GNiYB41N2W5xIhlwfTA&)CYe$`++625l%t>(#0IsKl9VcD>wPNSf!ytGuMtx(76V z3eda>G(H7rTmu>|0!`P|S<5d?yDgaU0tXwjCN@QC_k4bNmPX;(6r@KluG@Ed-Qe0Bme(fP(Ze}5j-^L2+4a!&LsyywXw$o*DZMySnjXeyy(!y}9GbRjD>aAnmzo>` zbhM(gSMHn}Ag26B&kZO>@=rZC;Oou}yf$#e$dashJkUISr;apic!kSALY}4}2S&(szn2$tRkXZ2qd`VHcg#J7rrpaOaD-^U4zS zCEVF?7by+YAN*H^`x&vNE!rnfJ(YuES%?W_v`>ckj-XVM`s{Io&I-)oY5Rg5YfVh; z3rGmqQmo9Sf-VzhpMfBLI)eOi1eG{jog~hVd7>`5p;AgzRdjXT7fqXwD!N-!RYf8C zd~H)WYey9>N^3CdVp70h%>aUO^BTOSHw`;_)7rK`W0dhM*T=an8P>IK@SM}scA*pXc zBz|bJ!b}w*Fdj$`GlDHHa_sabYhZ}h-~|y|xRepzm_8FzpCg(twpmw}KBzl=lB;lJ zXlF-;psvw&WnZ*ypEsj=uf$iDtt?e^&K0s$(Tw3=7wQO1#6aH1F*&SEu^-q1<2-m- z$3%J4SuCKj zZ0;HeIc8uEx;2^3U4=L%$grVm;S;Ri0;!#{0;&R2aYqajR*ubGg=STr4}GR9NwT;2EmP^xxo-_@>b`)u&={{a91 P|NjF3!%QvZjb{P?I4Q*g literal 0 HcmV?d00001 diff --git a/netam/common.py b/netam/common.py index 7a1970df..4c4fb40b 100644 --- a/netam/common.py +++ b/netam/common.py @@ -13,15 +13,16 @@ from torch import nn, Tensor import multiprocessing as mp -from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence +from netam.sequences import ( + iter_codons, + apply_aa_mask_to_nt_sequence, + RESERVED_TOKEN_TRANSLATIONS, + BASES, + AA_TOKEN_STR_SORTED, +) BIG = 1e9 SMALL_PROB = 1e-6 -BASES = ["A", "C", "G", "T"] -BASES_AND_N_TO_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4} -AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY" -AA_STR_SORTED_AMBIG = AA_STR_SORTED + "X" -MAX_AMBIG_AA_IDX = len(AA_STR_SORTED_AMBIG) - 1 # I needed some sequence to use to normalize the rate of mutation in the SHM model. # So, I chose perhaps the most famous antibody sequence, VRC01: @@ -65,7 +66,7 @@ def aa_idx_tensor_of_str_ambig(aa_str): character.""" try: return torch.tensor( - [AA_STR_SORTED_AMBIG.index(aa) for aa in aa_str], dtype=torch.int + [AA_TOKEN_STR_SORTED.index(aa) for aa in aa_str], dtype=torch.int ) except ValueError: print(f"Found an invalid amino acid in the string: {aa_str}") @@ -88,17 +89,28 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None): return mask +def _consider_codon(codon): + """Return False if codon should be masked, True otherwise.""" + if "N" in codon: + return False + elif codon in RESERVED_TOKEN_TRANSLATIONS: + return False + else: + return True + + def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None): """Return a mask tensor indicating codons which contain at least one N. Codons beyond the length of the sequence are masked. If other_nt_seqs are provided, - the "and" mask will be computed for all sequences + the "and" mask will be computed for all sequences. Codons containing marker tokens + are also masked. """ if aa_length is None: aa_length = len(nt_parent) // 3 sequences = (nt_parent,) + other_nt_seqs mask = [ - all("N" not in codon for codon in codons) + all(_consider_codon(codon) for codon in codons) for codons in zip(*(iter_codons(sequence) for sequence in sequences)) ] if len(mask) < aa_length: @@ -114,7 +126,7 @@ def aa_strs_from_idx_tensor(idx_tensor): Args: idx_tensor (Tensor): A 2D tensor of shape (batch_size, seq_len) containing - indices into AA_STR_SORTED_AMBIG. + indices into AA_TOKEN_STR_SORTED. Returns: List[str]: A list of amino acid strings with trailing 'X's removed. @@ -123,7 +135,7 @@ def aa_strs_from_idx_tensor(idx_tensor): aa_str_list = [] for row in idx_tensor: - aa_str = "".join(AA_STR_SORTED_AMBIG[idx] for idx in row.tolist()) + aa_str = "".join(AA_TOKEN_STR_SORTED[idx] for idx in row.tolist()) aa_str_list.append(aa_str.rstrip("X")) return aa_str_list diff --git a/netam/dasm.py b/netam/dasm.py index 5975ab14..6dc60321 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -139,7 +139,10 @@ def prediction_pair_of_batch(self, batch): raise ValueError( f"log_neutral_aa_probs has non-finite values at relevant positions: {log_neutral_aa_probs[mask]}" ) - log_selection_factors = self.model(aa_parents_idxs, mask) + # We need the model to see special tokens here. For every other purpose + # they are masked out. + keep_token_mask = mask | sequences.token_mask_of_aa_idxs(aa_parents_idxs) + log_selection_factors = self.model(aa_parents_idxs, keep_token_mask) return log_neutral_aa_probs, log_selection_factors def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors): diff --git a/netam/dnsm.py b/netam/dnsm.py index 6efeda85..bd41e0af 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -163,7 +163,9 @@ def build_selection_matrix_from_parent(self, parent: str): """ parent = sequences.translate_sequence(parent) selection_factors = self.model.selection_factors_of_aa_str(parent) - selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float) + selection_matrix = torch.zeros( + (len(selection_factors), sequences.MAX_AA_TOKEN_IDX + 1), dtype=torch.float + ) # Every "off-diagonal" entry of the selection matrix is set to the selection # factor, where "diagonal" means keeping the same amino acid. selection_matrix[:, :] = selection_factors[:, None] diff --git a/netam/dxsm.py b/netam/dxsm.py index c118b28b..8539174e 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -15,7 +15,6 @@ from tqdm import tqdm from netam.common import ( - MAX_AMBIG_AA_IDX, aa_idx_tensor_of_str_ambig, stack_heterogeneous, codon_mask_tensor_of, @@ -28,6 +27,8 @@ translate_sequences, apply_aa_mask_to_nt_sequence, nt_mutation_frequency, + MAX_AA_TOKEN_IDX, + RESERVED_TOKEN_REGEX, ) @@ -43,8 +44,12 @@ def __init__( branch_lengths: torch.Tensor, multihit_model=None, ): - self.nt_parents = nt_parents - self.nt_children = nt_children + self.nt_parents = nt_parents.str.replace(RESERVED_TOKEN_REGEX, "N", regex=True) + # We will replace reserved tokens with Ns but use the unmodified + # originals for translation and mask creation. + self.nt_children = nt_children.str.replace( + RESERVED_TOKEN_REGEX, "N", regex=True + ) self.nt_ratess = nt_ratess self.nt_cspss = nt_cspss self.multihit_model = copy.deepcopy(multihit_model) @@ -56,14 +61,16 @@ def __init__( assert len(self.nt_parents) == len(self.nt_children) pcp_count = len(self.nt_parents) - aa_parents = translate_sequences(self.nt_parents) - aa_children = translate_sequences(self.nt_children) + # Important to use the unmodified versions of nt_parents and + # nt_children so they still contain special tokens. + aa_parents = translate_sequences(nt_parents) + aa_children = translate_sequences(nt_children) self.max_aa_seq_len = max(len(seq) for seq in aa_parents) # We have sequences of varying length, so we start with all tensors set # to the ambiguous amino acid, and then will fill in the actual values # below. self.aa_parents_idxss = torch.full( - (pcp_count, self.max_aa_seq_len), MAX_AMBIG_AA_IDX + (pcp_count, self.max_aa_seq_len), MAX_AA_TOKEN_IDX ) self.aa_children_idxss = self.aa_parents_idxss.clone() self.aa_subs_indicators = torch.zeros((pcp_count, self.max_aa_seq_len)) @@ -90,7 +97,7 @@ def __init__( ) assert torch.all(self.masks.sum(dim=1) > 0) - assert torch.max(self.aa_parents_idxss) <= MAX_AMBIG_AA_IDX + assert torch.max(self.aa_parents_idxss) <= MAX_AA_TOKEN_IDX self._branch_lengths = branch_lengths self.update_neutral_probs() @@ -296,9 +303,11 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs): def find_optimal_branch_lengths(self, dataset, **optimization_kwargs): worker_count = min(mp.cpu_count() // 2, 10) - # # The following can be used when one wants a better traceback. + # The following can be used when one wants a better traceback. # burrito = self.__class__(None, dataset, copy.deepcopy(self.model)) - # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) + # return burrito.serial_find_optimal_branch_lengths( + # dataset, **optimization_kwargs + # ) our_optimize_branch_length = partial( worker_optimize_branch_length, self.__class__, diff --git a/netam/framework.py b/netam/framework.py index 87d50571..93962018 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -22,12 +22,12 @@ optimizer_of_name, tensor_to_np_if_needed, BASES, - BASES_AND_N_TO_INDEX, BIG, VRC01_NT_SEQ, encode_sequences, parallelize_function, ) +from netam.sequences import BASES_AND_N_TO_INDEX from netam import models import netam.molevol as molevol @@ -352,21 +352,78 @@ def trimmed_shm_model_outputs_of_crepe(crepe, parents): return trimmed_rates, trimmed_csps +def join_chains(pcp_df): + """Join the parent and child chains in the pcp_df. + + Make a parent column that is the parent_h + "^^^" + parent_l, and same for child. + + If parent_h and parent_l are not present, then we assume that the parent is the + heavy chain. If only one of parent_h or parent_l is present, then we place the ^^^ + padding to the right of heavy, or to the left of light. + """ + cols = pcp_df.columns + # Look for heavy chain + if "parent_h" in cols: + assert "child_h" in cols, "child_h column missing!" + assert "v_gene_h" in cols, "v_gene_h column missing!" + elif "parent" in cols: + assert "child" in cols, "child column missing!" + assert "v_gene" in cols, "v_gene column missing!" + pcp_df["parent_h"] = pcp_df["parent"] + pcp_df["child_h"] = pcp_df["child"] + pcp_df["v_gene_h"] = pcp_df["v_gene"] + else: + pcp_df["parent_h"] = "" + pcp_df["child_h"] = "" + pcp_df["v_gene_h"] = "N/A" + # Look for light chain + if "parent_l" in cols: + assert "child_l" in cols, "child_l column missing!" + assert "v_gene_l" in cols, "v_gene_l column missing!" + else: + pcp_df["parent_l"] = "" + pcp_df["child_l"] = "" + pcp_df["v_gene_l"] = "N/A" + + if (pcp_df["parent_h"].str.len() + pcp_df["parent_l"].str.len()).min() < 3: + raise ValueError("At least one PCP has fewer than three nucleotides.") + + pcp_df["parent"] = pcp_df["parent_h"] + "^^^" + pcp_df["parent_l"] + pcp_df["child"] = pcp_df["child_h"] + "^^^" + pcp_df["child_l"] + + pcp_df.drop( + columns=["parent_h", "parent_l", "child_h", "child_l", "v_gene"], + inplace=True, + errors="ignore", + ) + return pcp_df + + def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None): """Load a PCP dataframe from a gzipped CSV file. `orig_pcp_idx` is the index column from the original file, even if we subset by sampling or by choosing V families. + + If we will join the heavy and light chain sequences into a single + sequence starting with the heavy chain, using a `^^^` separator. If only heavy or light chain + sequence is present, this separator will be added to the appropriate side of the available sequence. """ pcp_df = ( pd.read_csv(pcp_df_path_gz, compression="gzip", index_col=0) .reset_index() .rename(columns={"index": "orig_pcp_idx"}) ) - pcp_df["v_family"] = pcp_df["v_gene"].str.split("-").str[0] + pcp_df = join_chains(pcp_df) + + pcp_df["v_family_h"] = pcp_df["v_gene_h"].str.split("-").str[0] + pcp_df["v_family_l"] = pcp_df["v_gene_l"].str.split("-").str[0] if chosen_v_families is not None: chosen_v_families = set(chosen_v_families) - pcp_df = pcp_df[pcp_df["v_family"].isin(chosen_v_families)] + pcp_df = pcp_df[ + pcp_df["v_family_h"].isin(chosen_v_families) + & pcp_df["v_family_l"].isin(chosen_v_families) + ] if sample_count is not None: pcp_df = pcp_df.sample(sample_count) pcp_df.reset_index(drop=True, inplace=True) @@ -374,9 +431,21 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None): def add_shm_model_outputs_to_pcp_df(pcp_df, crepe): - rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"]) - pcp_df["nt_rates"] = rates - pcp_df["nt_csps"] = csps + # Split parent heavy and light chains to apply neutral model separately + split_parents = pcp_df["parent"].str.split(pat="^^^", expand=True, regex=False) + # To keep prediction aligned to joined h/l sequence, pad parent + h_parents = split_parents[0] + "NNN" + l_parents = split_parents[1] + + h_rates, h_csps = trimmed_shm_model_outputs_of_crepe(crepe, h_parents) + l_rates, l_csps = trimmed_shm_model_outputs_of_crepe(crepe, l_parents) + # Join predictions + pcp_df["nt_rates"] = [ + torch.cat([h_rate, l_rate], dim=0) for h_rate, l_rate in zip(h_rates, l_rates) + ] + pcp_df["nt_csps"] = [ + torch.cat([h_csp, l_csp], dim=0) for h_csp, l_csp in zip(h_csps, l_csps) + ] return pcp_df diff --git a/netam/models.py b/netam/models.py index 09ebd4d1..1edc8989 100644 --- a/netam/models.py +++ b/netam/models.py @@ -10,8 +10,8 @@ from torch import Tensor from netam.hit_class import apply_multihit_correction +from netam.sequences import MAX_AA_TOKEN_IDX from netam.common import ( - MAX_AMBIG_AA_IDX, aa_idx_tensor_of_str_ambig, PositionalEncoding, generate_kmers, @@ -622,7 +622,7 @@ def __init__( self.nhead = nhead self.dim_feedforward = dim_feedforward self.pos_encoder = PositionalEncoding(self.d_model, dropout_prob) - self.amino_acid_embedding = nn.Embedding(MAX_AMBIG_AA_IDX + 1, self.d_model) + self.amino_acid_embedding = nn.Embedding(MAX_AA_TOKEN_IDX + 1, self.d_model) self.encoder_layer = nn.TransformerEncoderLayer( d_model=self.d_model, nhead=nhead, diff --git a/netam/molevol.py b/netam/molevol.py index 2aef1c10..c089764d 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -9,7 +9,7 @@ import torch from torch import Tensor, optim -from netam.sequences import CODON_AA_INDICATOR_MATRIX +from netam.sequences import CODON_AA_INDICATOR_MATRIX, MAX_AA_TOKEN_IDX import netam.sequences as sequences @@ -444,7 +444,7 @@ def mutsel_log_pcp_probability_of( """ assert len(parent) % 3 == 0 - assert sel_matrix.shape == (len(parent) // 3, 20) + assert sel_matrix.shape == (len(parent) // 3, MAX_AA_TOKEN_IDX + 1) parent_idxs = sequences.nt_idx_tensor_of_str(parent) child_idxs = sequences.nt_idx_tensor_of_str(child) diff --git a/netam/sequences.py b/netam/sequences.py index feea5ad2..f9800ff7 100644 --- a/netam/sequences.py +++ b/netam/sequences.py @@ -1,6 +1,7 @@ """Code for handling sequences and sequence files.""" import itertools +import re import torch import numpy as np @@ -8,13 +9,31 @@ from Bio import SeqIO from Bio.Seq import Seq +BASES = ("A", "C", "G", "T") AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY" -NT_STR_SORTED = "ACGT" -CODONS = [ - "".join(codon_list) - for codon_list in itertools.product(["A", "C", "G", "T"], repeat=3) -] +# Add additional tokens to this string: +RESERVED_TOKENS = "^" + + +NT_STR_SORTED = "".join(BASES) +BASES_AND_N_TO_INDEX = {base: idx for idx, base in enumerate(NT_STR_SORTED + "N")} +# ambiguous must remain last. It is assumed elsewhere that the max index +# denotes the ambiguous base +AA_TOKEN_STR_SORTED = AA_STR_SORTED + RESERVED_TOKENS + "X" + +RESERVED_TOKEN_AA_BOUNDS = ( + min(AA_TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS), + max(AA_TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS), +) +MAX_AA_TOKEN_IDX = len(AA_TOKEN_STR_SORTED) - 1 +CODONS = ["".join(codon_list) for codon_list in itertools.product(BASES, repeat=3)] STOP_CODONS = ["TAA", "TAG", "TGA"] +# Each token in RESERVED_TOKENS will appear once in aa strings, and three times +# in nt strings. +RESERVED_TOKEN_TRANSLATIONS = {token * 3: token for token in RESERVED_TOKENS} + +# Create a regex pattern +RESERVED_TOKEN_REGEX = f"[{''.join(map(re.escape, list(RESERVED_TOKENS)))}]" def nt_idx_array_of_str(nt_str): @@ -29,7 +48,7 @@ def nt_idx_array_of_str(nt_str): def aa_idx_array_of_str(aa_str): """Return the indices of the amino acids in a string.""" try: - return np.array([AA_STR_SORTED.index(aa) for aa in aa_str]) + return np.array([AA_TOKEN_STR_SORTED.index(aa) for aa in aa_str]) except ValueError: print(f"Found an invalid amino acid in the string: {aa_str}") raise @@ -47,7 +66,7 @@ def nt_idx_tensor_of_str(nt_str): def aa_idx_tensor_of_str(aa_str): """Return the indices of the amino acids in a string.""" try: - return torch.tensor([AA_STR_SORTED.index(aa) for aa in aa_str]) + return torch.tensor([AA_TOKEN_STR_SORTED.index(aa) for aa in aa_str]) except ValueError: print(f"Found an invalid amino acid in the string: {aa_str}") raise @@ -90,26 +109,33 @@ def read_fasta_sequences(file_path): return sequences -def translate_sequences(nt_sequences): - aa_sequences = [] - for seq in nt_sequences: - if len(seq) % 3 != 0: - raise ValueError(f"The sequence '{seq}' is not a multiple of 3.") - aa_seq = str(Seq(seq).translate()) - if "*" in aa_seq: - raise ValueError(f"The sequence '{seq}' contains a stop codon.") - aa_sequences.append(aa_seq) - return aa_sequences +def translate_codon(codon): + """Translate a codon to an amino acid.""" + if codon in RESERVED_TOKEN_TRANSLATIONS: + return RESERVED_TOKEN_TRANSLATIONS[codon] + else: + return str(Seq(codon).translate()) def translate_sequence(nt_sequence): - return translate_sequences([nt_sequence])[0] + if len(nt_sequence) % 3 != 0: + raise ValueError(f"The sequence '{nt_sequence}' is not a multiple of 3.") + aa_seq = "".join( + translate_codon(nt_sequence[i : i + 3]) for i in range(0, len(nt_sequence), 3) + ) + if "*" in aa_seq: + raise ValueError(f"The sequence '{nt_sequence}' contains a stop codon.") + return aa_seq + + +def translate_sequences(nt_sequences): + return [translate_sequence(seq) for seq in nt_sequences] def aa_index_of_codon(codon): """Return the index of the amino acid encoded by a codon.""" aa = translate_sequence(codon) - return AA_STR_SORTED.index(aa) + return AA_TOKEN_STR_SORTED.index(aa) def generic_mutation_frequency(ambig_symb, parent, child): @@ -159,12 +185,12 @@ def pcp_criteria_check(parent, child, max_mut_freq=0.3): def generate_codon_aa_indicator_matrix(): """Generate a matrix that maps codons (rows) to amino acids (columns).""" - matrix = np.zeros((len(CODONS), len(AA_STR_SORTED))) + matrix = np.zeros((len(CODONS), len(AA_TOKEN_STR_SORTED))) for i, codon in enumerate(CODONS): try: aa = translate_sequences([codon])[0] - aa_idx = AA_STR_SORTED.index(aa) + aa_idx = AA_TOKEN_STR_SORTED.index(aa) matrix[i, aa_idx] = 1 except ValueError: # Handle STOP codon pass @@ -206,3 +232,10 @@ def set_wt_to_nan(predictions: torch.Tensor, aa_sequence: str) -> torch.Tensor: wt_idxs = aa_idx_tensor_of_str(aa_sequence) predictions[torch.arange(len(aa_sequence)), wt_idxs] = float("nan") return predictions + + +def token_mask_of_aa_idxs(aa_idxs: torch.Tensor) -> torch.Tensor: + """Return a mask indicating which positions in an amino acid sequence contain + special indicator tokens.""" + min_idx, max_idx = RESERVED_TOKEN_AA_BOUNDS + return (aa_idxs <= max_idx) & (aa_idxs >= min_idx) diff --git a/tests/conftest.py b/tests/conftest.py index c88350cb..3383f70f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,3 +16,15 @@ def pcp_df(): pretrained.load("ThriftyHumV0.2-45"), ) return df + + +@pytest.fixture(scope="module") +def pcp_df_paired(): + df = load_pcp_df( + "data/wyatt-10x-1p5m_paired-merged_fs-all_pcp_2024-11-21_no-naive_sample100.csv.gz", + ) + df = add_shm_model_outputs_to_pcp_df( + df, + pretrained.load("ThriftyHumV0.2-45"), + ) + return df diff --git a/tests/test_ambiguous.py b/tests/test_ambiguous.py index 86c55e23..c62fbd13 100644 --- a/tests/test_ambiguous.py +++ b/tests/test_ambiguous.py @@ -11,6 +11,7 @@ load_pcp_df, add_shm_model_outputs_to_pcp_df, ) +from netam.sequences import MAX_AA_TOKEN_IDX from netam import pretrained import random @@ -122,7 +123,10 @@ def ambig_pcp_df(): ) # Apply the random N adding function to each row df[["parent", "child"]] = df.apply( - lambda row: randomize_with_ns(row["parent"], row["child"]), + lambda row: tuple( + seq + "^^^" + for seq in randomize_with_ns(row["parent"][:-3], row["child"][:-3]) + ), axis=1, result_type="expand", ) @@ -166,7 +170,7 @@ def dasm_model(): d_model_per_head=4, dim_feedforward=256, layer_count=2, - output_dim=20, + output_dim=MAX_AA_TOKEN_IDX + 1, ) diff --git a/tests/test_common.py b/tests/test_common.py index eb22c07f..14787d5c 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -33,6 +33,6 @@ def test_codon_mask_tensor_of(): def test_aa_strs_from_idx_tensor(): - aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20], [4, 5, 19, 20, 20]]) + aa_idx_tensor = torch.tensor([[0, 1, 2, 3, 20, 21], [4, 5, 19, 21, 21, 21]]) aa_strings = aa_strs_from_idx_tensor(aa_idx_tensor) - assert aa_strings == ["ACDE", "FGY"] + assert aa_strings == ["ACDE^", "FGY"] diff --git a/tests/test_dasm.py b/tests/test_dasm.py index 6bae92ee..2c749ed8 100644 --- a/tests/test_dasm.py +++ b/tests/test_dasm.py @@ -8,6 +8,7 @@ crepe_exists, load_crepe, ) +from netam.sequences import MAX_AA_TOKEN_IDX from netam.models import TransformerBinarySelectionModelWiggleAct from netam.dasm import ( DASMBurrito, @@ -29,7 +30,7 @@ def dasm_burrito(pcp_df): d_model_per_head=4, dim_feedforward=256, layer_count=2, - output_dim=20, + output_dim=MAX_AA_TOKEN_IDX + 1, ) burrito = DASMBurrito( diff --git a/tests/test_dnsm.py b/tests/test_dnsm.py index e0bca099..18b449e7 100644 --- a/tests/test_dnsm.py +++ b/tests/test_dnsm.py @@ -7,14 +7,15 @@ crepe_exists, load_crepe, ) -from netam.common import aa_idx_tensor_of_str_ambig, MAX_AMBIG_AA_IDX, force_spawn +from netam.sequences import MAX_AA_TOKEN_IDX +from netam.common import aa_idx_tensor_of_str_ambig, force_spawn from netam.models import TransformerBinarySelectionModelWiggleAct from netam.dnsm import DNSMBurrito, DNSMDataset def test_aa_idx_tensor_of_str_ambig(): input_seq = "ACX" - expected_output = torch.tensor([0, 1, MAX_AMBIG_AA_IDX], dtype=torch.int) + expected_output = torch.tensor([0, 1, MAX_AA_TOKEN_IDX], dtype=torch.int) output = aa_idx_tensor_of_str_ambig(input_seq) assert torch.equal(output, expected_output) diff --git a/tests/test_molevol.py b/tests/test_molevol.py index 0b313fa1..3d4ab630 100644 --- a/tests/test_molevol.py +++ b/tests/test_molevol.py @@ -7,7 +7,7 @@ from netam.sequences import ( nt_idx_tensor_of_str, translate_sequence, - AA_STR_SORTED, + AA_TOKEN_STR_SORTED, CODONS, NT_STR_SORTED, ) @@ -114,7 +114,7 @@ def test_check_csps(): def iterative_aaprob_of_mut_and_sub(parent_codon, mut_probs, csps): """Original version of codon_to_aa_probabilities, used for testing.""" aa_probs = {} - for aa in AA_STR_SORTED: + for aa in AA_TOKEN_STR_SORTED: aa_probs[aa] = 0.0 # iterate through all possible child codons @@ -139,7 +139,7 @@ def iterative_aaprob_of_mut_and_sub(parent_codon, mut_probs, csps): # since probabilities to STOP codon are dropped psum = sum(aa_probs.values()) - return torch.tensor([aa_probs[aa] / psum for aa in AA_STR_SORTED]) + return torch.tensor([aa_probs[aa] / psum for aa in AA_TOKEN_STR_SORTED]) def test_aaprob_of_mut_and_sub(): diff --git a/tests/test_sequences.py b/tests/test_sequences.py index 8866214e..761e92ef 100644 --- a/tests/test_sequences.py +++ b/tests/test_sequences.py @@ -1,19 +1,50 @@ import pytest +import pandas as pd import numpy as np import torch from Bio.Seq import Seq from Bio.Data import CodonTable from netam.sequences import ( + RESERVED_TOKENS, AA_STR_SORTED, + RESERVED_TOKEN_REGEX, + AA_TOKEN_STR_SORTED, CODONS, CODON_AA_INDICATOR_MATRIX, aa_onehot_tensor_of_str, nt_idx_array_of_str, nt_subs_indicator_tensor_of, translate_sequences, + token_mask_of_aa_idxs, + aa_idx_tensor_of_str, ) +def test_token_order(): + # If we always add additional tokens to the end, then converting to indices + # will not be affected when we have a proper aa string. + assert AA_TOKEN_STR_SORTED[: len(AA_STR_SORTED)] == AA_STR_SORTED + + +def test_token_replace(): + df = pd.DataFrame({"seq": ["AGCGTC" + token for token in AA_TOKEN_STR_SORTED]}) + newseqs = df["seq"].str.replace(RESERVED_TOKEN_REGEX, "N", regex=True) + for seq, nseq in zip(df["seq"], newseqs): + for token in RESERVED_TOKENS: + seq = seq.replace(token, "N") + assert nseq == seq + + +def test_token_mask(): + sample_aa_seq = "QYX^QC" + mask = token_mask_of_aa_idxs(aa_idx_tensor_of_str(sample_aa_seq)) + for aa, mval in zip(sample_aa_seq, mask): + if aa in RESERVED_TOKENS: + assert mval + else: + assert not mval + + def test_nucleotide_indices_of_codon(): assert nt_idx_array_of_str("AAA").tolist() == [0, 0, 0] assert nt_idx_array_of_str("TAC").tolist() == [3, 0, 1]