Skip to content

Commit

Permalink
added support for subset > data size (#40)
Browse files Browse the repository at this point in the history
Co-authored-by: Jspaezp <[email protected]>
  • Loading branch information
wfondrie and jspaezp authored Sep 3, 2021
1 parent 4f76a56 commit 5148da1
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 8 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog for mokapot

## [0.7.4] - 2021-09-03
### Changed
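- Improved documentation and added warnings for `--subset_max_train`. A value
larger than the number of available PSMs is now ignored with a warning instead
of raising an error. Thanks @jspaezp!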

## [0.7.3] - 2021-07-20
### Fixed
- Fixed bug where the `--keep_decoys` did not work with `--aggregate`. Also,
Expand Down
6 changes: 4 additions & 2 deletions mokapot/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,10 @@ def _parser():
type=int,
default=None,
help=(
"Use only a random subset of PSMs for training. "
"This is useful for very large datasets."
"Maximum number of PSMs to use during the training "
"of each of the cross validation folds in the model. "
"This is useful for very large datasets and will be "
"ignored if less PSMS are available."
),
)

Expand Down
26 changes: 20 additions & 6 deletions mokapot/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,12 +266,26 @@ def fit(self, psms):
)

if self.subset_max_train is not None:
subset_idx = np.random.choice(
len(psms), self.subset_max_train, replace=False
)

psms = copy.copy(psms)
psms._data = psms._data.iloc[subset_idx, :]
if self.subset_max_train > len(psms):
LOGGER.warning(
"The provided subset value (%i) is larger than the number "
"of psms in the training split (%i), so it will be "
"ignored.",
self.subset_max_train,
len(psms),
)
else:
LOGGER.info(
"Subsetting PSMs (%i) to (%i).",
len(psms),
self.subset_max_train,
)
subset_idx = np.random.choice(
len(psms), self.subset_max_train, replace=False
)

psms = copy.copy(psms)
psms._data = psms._data.iloc[subset_idx, :]

# Choose the initial direction
start_labels, feat_pass = _get_starting_labels(psms, self)
Expand Down
2 changes: 2 additions & 0 deletions tests/system_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def test_cli_options(tmp_path, scope_files):
"--max_iter",
"1",
"--keep_decoys",
"--subset_max_train",
"50000",
]

subprocess.run(cmd, check=True)
Expand Down
12 changes: 12 additions & 0 deletions tests/unit_tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ def test_model_fit(psms):
assert model.is_trained


def test_model_fit_large_subset(psms):
    """Test that fitting works when subset_max_train is a very large value"""
    # Intended to exceed the number of PSMs supplied by the fixture, so the
    # subset request cannot be satisfied by sampling without replacement.
    oversized_subset = 2_000_000_000
    estimator = LogisticRegression()

    model = mokapot.Model(
        estimator,
        train_fdr=0.05,
        max_iter=1,
        subset_max_train=oversized_subset,
    )
    model.fit(psms)

    # Fitting should complete and mark the model as trained.
    assert model.is_trained


def test_model_predict(psms):
"""Test predictions"""
model = mokapot.Model(LogisticRegression(), train_fdr=0.05, max_iter=1)
Expand Down

0 comments on commit 5148da1

Please sign in to comment.