From 21680cc5b7136359c033bb0c7fc5d0f7b002c931 Mon Sep 17 00:00:00 2001 From: Will Fondrie Date: Mon, 18 Jul 2022 10:03:08 -0700 Subject: [PATCH] Fixed need for intercept (#62) * Fixed need for intercept and cleaned brew * Create models for test on-the-fly * Fixed capitalization --- CHANGELOG.md | 8 ++++++- data/models/mokapot.model_fold-1.pkl | Bin 2332 -> 0 bytes data/models/mokapot.model_fold-2.pkl | Bin 2332 -> 0 bytes data/models/mokapot.model_fold-3.pkl | Bin 2332 -> 0 bytes mokapot/brew.py | 31 +++++++++++++++------------ mokapot/config.py | 7 +----- mokapot/model.py | 7 ++++++ tests/system_tests/test_cli.py | 15 ++++--------- tests/unit_tests/test_brew.py | 19 ++++++++++++++-- 9 files changed, 53 insertions(+), 34 deletions(-) delete mode 100644 data/models/mokapot.model_fold-1.pkl delete mode 100644 data/models/mokapot.model_fold-2.pkl delete mode 100644 data/models/mokapot.model_fold-3.pkl diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fab22d2..2030c9a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,13 @@ # Changelog for mokapot -## Unreleased +## [0.8.2] - 2022-07-18 +### Added +- `mokapot.Model()` objects now recored the CV fold that they were fit on. + This means that they can be provided to `mokapot.brew()` in any order + and still maintain proper cross-validation bins. + ### Fixed +- Resolved issue where models were required to have an intercept term. - The PepXML parser would sometimes try and log transform features with `0`'s, resulting in missing values. ## [0.8.1] - 2022-06-24 diff --git a/data/models/mokapot.model_fold-1.pkl b/data/models/mokapot.model_fold-1.pkl deleted file mode 100644 index a95959191b36081249fd49d57d78bf4a7e3dfe73..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2332 zcmZ`)4{#La6~819!zDmMh}r-u7{VF#Od=*R3g6O84CIpI+-O+-x$d9uvPbTA@9pj# zBqS7K2$#-Cm5xO_D5Y9qN*l^xU|LIWQb!S5(2z7@NuZS2ifJKY>eLQMzx{F-GgJHR z%+B|{_x8Q_``&xM-I^i=wMHSNna-dlT>PA-M)O)UV^s!(ppUq|M;9^{gU$OZN)G)gvPap~k03$+ZX31N1y$C9ykQC^hFKmijo1W*O}w8(gzVZ+Q|MPAtO^0M31)N2 zL?0-I9Evlb83x==(=;F~=vd|zk_?7BxEaJw<$YK%!#Xwy0v_&k=TP_Y>*L`&k{S+% znw*lRV<)`>LDz*QJFjC&Qw>uOOQs!i$F1zDJg$W~ZrgU4>q-(~-VV5ng}WUDUh;Oh zCvGU4O(ATD`7~Ew87%0G?tpt)$aA<;=>0zq3q2kWEi`?(Z6@T4ovuuh^Q1M3V37$9 zk()|{l7@Xe+^6I!G<<8a?8CG9FLPmK=74y`?Ot}*p9v~4i>`!~x{nVXr1w+O6_+1+wz z$IgO*L;7*_@;@9c-+%3$*k4{e*u1LnPqC60de0^o^r5#e z7d`vv`T^AX{>8|GTfH#@J+sGF)fHnF9o$$D9F8rT)!y{c@QoNaO?XJ;?8N^!y#$2` zctsYxViq2z@V3KZ3tZY~M@#QN^T|yFkC;#*(#B9}qwrAre@bb!r=Z)ROc^hSe=$62 zm7|=6$4EJp|1HGhqHBsZeXyhhma>32t15ha{&>>9hF?>Gp+zZj7CMVzSy-ffm1s8j z?s{-Js6jy0>q1aLQzaZ=*0@xtq~mU#$l(dvB*dnQ+N{vjCv-uA<=jjXmBl_mBJ+=B z&ZX%(te{=(!3{N9Ls6l(UI>ISRCC$HG$CZMQ5lJA++5BijY|uPgfHT{8k&R`hib42 zwKTR%5%dOJ1fEH!Vpuup^e}iQofgBBlTMESOGgr`%^s*rY+dja4Me0XR1==2_Y_b) z;b29b7wRc^xfWK@%MwI`>^VH%6$WqW0!dhg)lpglBP<%&rBEkf-culV>_%Pe7U|Ao^ep7-RL(JK8OVs_$W=03Jg<* zXQLE~q=lM5h$}^NFssiWi6zvY?>@ z6c+pxzm)(F8xk|nXw5)?g`hP9Iom;{1HN+b*CngJyfQK19a$S*so3>3a`g6J9eAUS zDcIk3uzyn@GTvxTckfPSyuX@0|Aa4xc{fxv^7gX1%$lvo`>}rxlRt7Tg=I?^D+_CV z_`~%*KfT_IO4Td1Z|WbTwi~ZrFosT};ji~k9XoUstoWz?7sfxn-ORxghYCI5kcnx60md&mcfDqyydlxEZQ#v17IfnNAEid%<75U78 zf1SN9&s)lPY%M2$QC-T+ztMQjpRSdIp4p$|0EhacIxKy2hO0>lVy+X`Q$yc z_x=2ZOMdwl+VGp9R-bn#s(fvozw-26^gzbY{<)sZXt=ui)N`x;jm{6`RS0Es(7_`k zm7iy3B8Zr_R@(suu-(GbobBH`e}@(YoC$2 zxQ#h+`;q^6YkQF6#?N>BF7-M}zudL-!drI6dt4v!eUZbYXZ&FISGCzpRgd@19~Cc* z0~+b_`UIUl&%FSMfw0!&OE)U1434|ELqanoI)!A1!=Z_r zblHTY@-N8}W)KZFtw6~3hiN{uE{G*VbD8xam*@*hrio3aBA^jvRkfRe=_FQiGu-49 zrje6ML<&MmyO2comL7x|r6{ZLA4Buq_X!C^$c>Ouc{z!m3S&(QS`9E) zNF?@PK4emyLESXr7Mi9B5y`-csFGwb+$!8m{4~*zB`acJlOPab9+yeO$DU7w+k9Fi z6mD?&bOXES8%TyBH87%qeY$2^M#N_^kTv1OXv%~YvN^|2nD0)IVZkm~$iwYU0xxkV z+!1qBtcEaVU=hvLUj%oyx9x&&@sQ)>lIZ(C2#Y-)4=ptPa1IM{<#u;E$$82eL2#D^ zPFa{vls+B%MYvndQt9>gsO0W%^Kh?ycLJB$3iqiw?e4>`5A+=1%8sIVn4FH{&2R1< zL`%otW#)8UL~nh+d1KA-zG!<@LS)ai!RYgUw3fc~*1u5CN5A-Y_P%8H^_lgTUp(Mo zKj|$!`~JrVqVYKA%vC?A;ZLEXthXn9-!ZhCuTdL6csts4aP)l5!oNntKN(NCAornT z8@b`+X9v)-N7eKX3eH659cmib+VpC)rRDR(-p)(W!PgeN65__A;IiO;Szt)~uX+iV z#Np-H@bYd`4dU|ngLw{mKG(?wZxSVt0OY)t9Y}) zLmR;Dqz*yNs0u?dO_gwfRpHj4gpRu{5#eFlB*dqb`jqL~!-nL8HNq@%E06t>kIX+- z1h;M&P)57jgKI1F+Po#+YAG1OP%dN;*SM0$W=Va_!U>sEEN(p{6TXPqDrgd39In6` zRMNZMs$|sSJn&4B@?q^1X(@Q8NCoi76lodobR@CX=z*&E)(wx+i-@*Nv*0oMPC?BR z2^Cj)p_+nMD&cYZSdwg#y?`gW!W1n-A_*(7u8r2fjL0UoM4KF_1=f3LY`esSE1sC- z$gqLKhVH|rNw~lzDv~8JRTkkH$xy1OGXph*FE=#9cb$}JHwP?L3}HzFv5lrk6{e-Y z#x@GYr-vIrYEx|ru`8pArW6VX39*z5BTdF%rCoBik~4t$b#Bd zPLl7ojn6V9zymEWbOow4*1%0Zx*h*c4czFTW4?jY4P40$l2R}b>M{- z_MVQGm;1N$A@haCRBnGF>wRtJq8@)H`%XA-=y$8&Jm?Hg{+^iRTMhI}3ZA z{?Uf+=SO?da_vgxOUC=C<@&+%=HMxG=|;!&u~&bCa-Hw}>C%p2l(GAcleu49MZLKx zD}MgzCurcHyyDrxJ6Y+m(bTlttE0XP8GZe0wj$6i2+Q>G>Lo)2h{G~$SSAlv3|54~ zI@#8F>zH(C64sn6W7XGY<*-@f3yySk&RyN;fBVC&zam{xeEGQS`upkB+0~ z-pyUS@>g%5O}`y%_IvlBlAqNCN=`k8?oAu)nD6-zT`Dg>`PAe8L1zbYilw4#^zzZ6 zl26mq5!73-S^lc;MA#As)?fo`+Sd~UHiBxF?|)INp22#K95Ry& zTG*ajj{etM*^Qjnf4S=q$)hOs!y~KCzrnEHW5$sGvrINM?dknrRA#WH-QIaW&R-k@ zHiDjbvY)#cJ%?5f{LLHYFQ6|wS5&snK8xV&I%ma_GibZ>^gE@Cqo}EReQ(E}<0v7E zNqq0`m(kV!g=x;;e}?8}|HZxI{0{WP?lgaW$~FYe7Cb4dtZK8F&3|;pM89Id_e#-@ z#*3v#=5AFpIEEUP0Qp5)HW6-{_+{&a7>$5}VS76z1yhaq{Xq;ZZPZ`XFjh>_R}W9M IM`YLk05SPD1^@s6 diff --git a/data/models/mokapot.model_fold-3.pkl b/data/models/mokapot.model_fold-3.pkl deleted file mode 100644 index b0b99847e9ee43f8b1f9340adfb976884a68c7da..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2332 zcmZ`)eN+_J6< zq687nX>4nc^(2SH+I-+?QqsmyPTJBf#}lod+D4(o)Zj;B6D_Ao46#j*n!b4~gmap{ zKjysq@$S37cklhp&ICE6*NI`>a)opS2N@qWWIZTax=}%8aNNHO;+i4e$tO7+4o#@l zB@+_Lz9NgCPBhrG0wLQU5@oY4h($wlne`!;AO}U$#3oZ7(1^0K%FVzu5-X`0?sD>p z3R2Vr^%%3S@okp=Dt_&B0+15<)V$%ZUGg?vJhpcNPo3n}!JUT;7u zKZDr(0ZRyI7D*n5Ett{H&nC)Ag<#5JFrd}IOivu2N-Q)tp#cYKR10P`Lz>g$@Gzb@ z4-?_~Y!j5RT6HNspz zp4fwVkU?<1-G))sCqJb4bCCOm8o1aPiG{KKWD`H@iAQ0dlPX-Midp-f?%UUE9 zZgk1IfnD?sM8gmp8NtA^u9=n*ku3%?$GsR$8Mi`~$FUn0xZ@;PxCa(-u-Hl9#qWlD zV~&c|7{&}Np}G1CVQEME9{3gq*-lRaeg8*bnb+&3g{B{m!-5>C!<|NQp0Y*}+-HGP z;-?X%tYg0b_p6yIy?(h$?*29hE9|@DJQ;29fSTRmKHMjL+UY4hf+oTws)Y^j?i)tE zavw9N_Y(R^fq!fDv4QA{tBWIhZw^OqAFW;W`tSdNnnt?*m9;;KZS6O&yw>Sp%eFQT zfA~pfbRtgW#g^|?bEnYE&gI9`_V=M%^PW|kE*_1}@4R!qdeL8^Pj|;nzaR~uvGm8U zC%rg?8um7(UCcid?VBNfuzkl{(YN3J;;^sha`fPmrEi8jx1->);6aIJNc^vR306+P z%eCR#p@`$d_|KUR3R zZWvHXyV{Fu%JrJumA)!57{O4+rxVu+C5O%8`k2MTXH2oU^^ipPB4#V6N%(NM9BWWP z?{=%AQG;{AJ4MQajZ>sm;F}`l!{bw=hk&CaiM1v#R8DN&@LhTl(Uxi!JVD>0?|4H_NrjUTy%Kvwi`)d1v1s4rH-8Bi9o+ZZ&IXVXyap zw59*2H%_Ah?P|sA#)qio)=THj;Zx}HovvwPZ~Y48IM4p^a_e=J{`|csbH2QWPUj@A zdF8WzqM?_hH7^b?WyL3Mq@*scipm$#2L?B6N1$5}mgwWvi-rI&0n4yqnH*R#SP=^A zWP8uuW8$GnSaYt9Ro$GG&1T+S*xlPRcYTxp=r8+w_Mtp3q2s}gW9ZN)+2K!1bJ)XQ zoVlSaT*rDHJ5Ib=R=_T~Rd+p*QO#0Jay#4&)KcS=(g%5ZB{TMp< zhn!_=fBp`t{q=CO-?tYPzg!(CK6MbSNFDB4;Qbg~E-O3v+>`%C=Z3P2#KJ7}#*vZY z&(hKm)LYOXecg8gY?}bqXaj5FV0#SM4hpPpG*Ri@e>?_k6jc`-culLC&U(8KnMwIA z?D2U={^P6YN6uS6+w;4m8z|-D?seziVOU?EG2;I`gH1{O{(&zm(%F)J-#tIhTNVR0 zik^CA&~qvJ5n4O+fiKKmKwtH&sc4&h4#Bxi&hnLK&@Siu@0Bczq8(M6Pj~G-hT<}r z__Ke#g02lNN_GDBb2K;W&+gXqt>}mSss8%pod}vOct%oL)n+rB|LBYfe#L-iOHgam zrIPNs+tqXrLyby+{30!z2zQSEvUNg?MnJ)^tAmn)sYd+%AcmH9>Mv**E2bdV!*d-G G$@O1HQ8eNJ diff --git a/mokapot/brew.py b/mokapot/brew.py index 6828774a..8e887c41 100644 --- a/mokapot/brew.py +++ b/mokapot/brew.py @@ -46,7 +46,7 @@ def brew(psms, model=None, test_fdr=0.01, folds=3, max_workers=1): machine models used by Percolator. If a list of :py:class:`mokapot.Model` objects is provided, they are assumed to be previously trained models and will and one will be - used to rescore each fold in the order they are provided. + used to rescore each fold. test_fdr : float, optional The false-discovery rate threshold at which to evaluate the learned models. @@ -98,8 +98,8 @@ def brew(psms, model=None, test_fdr=0.01, folds=3, max_workers=1): # If trained models are provided, use the them as-is. try: - models = [[m, False] for m in model if m.is_trained] - assert len(models) == len(model) # Test that all models are fitted. + fitted = [[m, False] for m in model if m.is_trained] + assert len(fitted) == len(model) # Test that all models are fitted. assert len(model) == folds except AssertionError as orig_err: if len(model) != folds: @@ -114,35 +114,37 @@ def brew(psms, model=None, test_fdr=0.01, folds=3, max_workers=1): raise err from orig_err except TypeError: - models = Parallel(n_jobs=max_workers, require="sharedmem")( + fitted = Parallel(n_jobs=max_workers, require="sharedmem")( delayed(_fit_model)(d, copy.deepcopy(model), f) for f, d in enumerate(train_sets) ) - # sort models to have deterministic results with multithreading. - # Only way I found to sort is using intercept values - models.sort(key=lambda x: x[0].estimator.intercept_) + # Sort models to have deterministic results with multithreading. + fitted.sort(key=lambda x: x[0].fold) + models, resets = list(zip(*fitted)) # Determine if the models need to be reset: - reset = any([m[1] for m in models]) + reset = any(resets) + + # If we reset, just use the original model on all the folds: if reset: - # If we reset, just use the original model on all the folds: scores = [ p._calibrate_scores(model.predict(p), test_fdr) for p in psms ] - elif all([m[0].is_trained for m in models]): - # If we don't reset, assign scores to each fold: - models = [m for m, _ in models] + + # If we don't reset, assign scores to each fold: + elif all([m.is_trained for m in models]): scores = [ _predict(p, i, models, test_fdr) for p, i in zip(psms, test_idx) ] + + # If model training has failed else: - # If model training has failed scores = [np.zeros(len(p.data)) for p in psms] # Find which is best: the learned model, the best feature, or # a pretrained model. - if not all([m.override for m in models]) or not model.override: + if not all([m.override for m in models]): best_feats = [p._find_best_feature(test_fdr) for p in psms] feat_total = sum([best_feat[1] for best_feat in best_feats]) else: @@ -280,6 +282,7 @@ def _fit_model(train_set, model, fold): """ LOGGER.info("") LOGGER.info("=== Analyzing Fold %i ===", fold + 1) + model.fold = fold + 1 reset = False try: model.fit(train_set) diff --git a/mokapot/config.py b/mokapot/config.py index 8ffbb40c..18e552bb 100644 --- a/mokapot/config.py +++ b/mokapot/config.py @@ -17,7 +17,7 @@ def _fill_text(self, text, width, indent): class Config: """ - The xenith configuration options. + The mokapot configuration options. Options can be specified as command-line arguments. """ @@ -259,11 +259,6 @@ def _parser(): help=( "Load previously saved models and skip model training." "Note that the number of models must match the value of --folds." - "If the models are being applied to the same dataset that they " - "were trained on originally, the models must be provided in the " - "same order as the folds to maintain cross-vaildation fold " - "relationships. Failure to do so will result in invalid FDR " - "estimates." ), ) diff --git a/mokapot/model.py b/mokapot/model.py index 9f2b963a..59f6f1ee 100644 --- a/mokapot/model.py +++ b/mokapot/model.py @@ -114,6 +114,8 @@ class Model: The number of PSMs for training. shuffle : bool Is the order of PSMs shuffled for training? + fold : int or None + The CV fold on which this model was fit, if any. """ def __init__( @@ -146,6 +148,11 @@ def __init__( self.override = override self.shuffle = shuffle + # To keep track of the fold that this was trained on. + # Needed to ensure reproducibility in brew() with + # multiprocessing. + self.fold = None + # Sort out whether we need to optimize hyperparameters: if hasattr(self.estimator, "estimator"): self._needs_cv = True diff --git a/tests/system_tests/test_cli.py b/tests/system_tests/test_cli.py index 78b3e562..5b6e859d 100644 --- a/tests/system_tests/test_cli.py +++ b/tests/system_tests/test_cli.py @@ -34,12 +34,6 @@ def pepxml_file(): return pepxml -@pytest.fixture -def models_path(): - """Get the saved models""" - return sorted(list(Path("data", "models").glob("*"))) - - def test_basic_cli(tmp_path, scope_files): """Test that basic cli works.""" cmd = ["mokapot", scope_files[0], "--dest_dir", tmp_path] @@ -183,7 +177,7 @@ def test_cli_pepxml(tmp_path, pepxml_file): assert len(binned) > len(unbinned) -def test_cli_saved_models(tmp_path, phospho_files, models_path): +def test_cli_saved_models(tmp_path, phospho_files): """Test that saved_models works""" cmd = [ "mokapot", @@ -192,12 +186,11 @@ def test_cli_saved_models(tmp_path, phospho_files, models_path): tmp_path, "--test_fdr", "0.01", - "--load_models", - models_path[0], - models_path[1], - models_path[2], ] + subprocess.run(cmd + ["--save_models"], check=True) + + cmd += ["--load_models", *list(Path(tmp_path).glob("*.pkl"))] subprocess.run(cmd, check=True) assert Path(tmp_path, "mokapot.psms.txt").exists() assert Path(tmp_path, "mokapot.peptides.txt").exists() diff --git a/tests/unit_tests/test_brew.py b/tests/unit_tests/test_brew.py index 583ec13b..75b301d5 100644 --- a/tests/unit_tests/test_brew.py +++ b/tests/unit_tests/test_brew.py @@ -2,7 +2,8 @@ import pytest import numpy as np import mokapot -from mokapot import PercolatorModel +from mokapot import PercolatorModel, Model +from sklearn.ensemble import RandomForestClassifier np.random.seed(42) @@ -21,6 +22,18 @@ def test_brew_simple(psms, svm): assert isinstance(models[0], PercolatorModel) +def test_brew_random_forest(psms): + """Verify there are no dependencies on the SVM.""" + rfm = Model( + RandomForestClassifier(), + train_fdr=0.1, + ) + results, models = mokapot.brew(psms, model=rfm, test_fdr=0.1) + assert isinstance(results, mokapot.confidence.LinearConfidence) + assert len(models) == 3 + assert isinstance(models[0], Model) + + def test_brew_joint(psms, svm): """Test that the multiple input PSM collections yield multiple out""" collections = [psms, psms, psms] @@ -58,8 +71,10 @@ def test_brew_trained_models(psms, svm): psms, svm, test_fdr=0.05 ) np.random.seed(3) + models = list(models_with_training) + models.reverse() # Change the model order results_without_training, models_without_training = mokapot.brew( - psms, models_with_training, test_fdr=0.05 + psms, models, test_fdr=0.05 ) assert models_with_training == models_without_training assert results_with_training.accepted == results_without_training.accepted