"""
The entry point for your prediction algorithm.
"""
from __future__ import annotations
import argparse
import csv
import itertools
from pathlib import Path
import pprint
from typing import Any
import zipfile
from Bio.PDB.PDBParser import PDBParser
from Bio.PDB.ResidueDepth import get_surface
from Bio.PDB.vectors import calc_dihedral
from Bio.PDB.Structure import Structure
import temppathlib


def predict(pdb_file: Path) -> float:
    """
    The function that puts it all together: parsing the PDB file, generating
    features from it and performing inference with the ML model.
    """
    # parse PDB
    parser = PDBParser()
    structure = parser.get_structure(pdb_file.stem, pdb_file)

    # featurize + perform inference
    features = featurize(structure)
    predicted_solubility = ml_inference(features)

    return predicted_solubility
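

# Hedged usage sketch: calling predict() directly on a single PDB file. The
# file name below is hypothetical; the __main__ block further down shows the
# real batch workflow.
#
#   solubility = predict(Path("data/1abc.pdb"))
#   print(f"predicted solubility: {solubility}")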


def featurize(structure: Structure) -> list[Any]:
    """
    Calculates 3D ML features from the `structure`.
    """
    # get all the residues
    residues = list(structure.get_residues())

    # calculate some random 3D features (you should be smarter here!)
    # subtracting two Biopython Atom objects yields the distance between them
    # in Angstrom, so this is a crude end-to-end length of the chain
    protein_length = residues[1]["CA"] - residues[-2]["CA"]
    angle = calc_dihedral(
        residues[1]["CA"].get_vector(),
        residues[2]["CA"].get_vector(),
        residues[-3]["CA"].get_vector(),
        residues[-2]["CA"].get_vector(),
    )

    # create the feature vector
    features = [protein_length, angle]

    return features
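

# One possible richer global 3D feature, sketched here for illustration: the
# radius of gyration of the CA trace. `radius_of_gyration` is not part of the
# original template; it assumes numpy is available (it is a Biopython
# dependency) and it skips residues without a CA atom.
def radius_of_gyration(structure: Structure) -> float:
    """Root-mean-square distance of the CA atoms from their centroid."""
    import numpy as np  # local import to keep this sketch self-contained

    coords = np.array(
        [res["CA"].coord for res in structure.get_residues() if "CA" in res]
    )
    centroid = coords.mean(axis=0)
    return float(np.sqrt(((coords - centroid) ** 2).sum(axis=1).mean()))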


def ml_inference(features: list[Any]) -> float:
    """
    This would be a function where you normalize/standardize your features and
    then feed them to your trained ML model (which you would load from a file).
    """
    # this is my stupid manual ML model
    # NOTE: the stricter threshold must be checked first, otherwise the
    # looser branch shadows it and 80 can never be returned
    if features[0] > 30.0 and features[1] > 1.5:
        return 80
    if features[0] > 15.0 and features[1] > 0.5:
        return 60
    return 20
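

# A hedged sketch of what real inference could look like instead: loading a
# trained model from disk and predicting on the feature vector. The function,
# the file name "model.joblib", and the use of joblib/scikit-learn are
# illustrative assumptions, not part of the original template.
def ml_inference_from_file(
    features: list[Any], model_path: Path = Path("model.joblib")
) -> float:
    import joblib  # ships with scikit-learn installs

    model = joblib.load(model_path)
    # scikit-learn estimators expect a 2D array of shape (n_samples, n_features)
    return float(model.predict([features])[0])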


if __name__ == "__main__":
    # set up argument parsing
    parser = argparse.ArgumentParser()
    parser.add_argument("--infile", type=str, default="data/test.zip")
    args = parser.parse_args()

    predictions = []
    # use a temporary directory so we don't pollute our repo
    with temppathlib.TemporaryDirectory() as tmpdir:
        # unzip the file with all the test PDBs
        with zipfile.ZipFile(args.infile, "r") as zip_:
            zip_.extractall(tmpdir.path)

        # iterate over all test PDBs and generate predictions
        for test_pdb in tmpdir.path.glob("*.pdb"):
            predictions.append(
                {"protein": test_pdb.stem, "solubility": predict(test_pdb)}
            )

    # save to csv file, this will be used for benchmarking
    outpath = "predictions.csv"
    # newline="" prevents the csv module from inserting blank lines on Windows
    with open(outpath, "w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=["protein", "solubility"])
        writer.writeheader()
        writer.writerows(predictions)

    # print predictions to screen
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(predictions)
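
    # For two hypothetical inputs 1abc.pdb and 2xyz.pdb, the resulting
    # predictions.csv would look like (values illustrative only):
    #
    #   protein,solubility
    #   1abc,60
    #   2xyz,20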