-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrandomsurvivalforest.py
38 lines (32 loc) · 1.34 KB
/
randomsurvivalforest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import pandas as pd
import numpy as np
from sksurv.ensemble import RandomSurvivalForest
from sksurv.util import Surv
import pickle
################################
# Random Survival Forest
################################
def train_rsf(rsf_x, rsf_y, rsf_filename, tune=False):
    """Fit a Random Survival Forest and pickle the fitted model to disk.

    Args:
        rsf_x: Feature matrix handed to ``RandomSurvivalForest.fit``.
        rsf_y: DataFrame containing the ``'breakdown'`` (event indicator) and
            ``'cycle'`` (time) columns; it is converted to the structured
            array that ``fit`` requires via ``Surv.from_dataframe``.
        rsf_filename: Path of the file the fitted model is pickled to.
        tune: Selects the hyperparameter set. Truthy -> the "tuned" set,
            falsy -> the "untuned" set. For backward compatibility the
            strings ``'False'``, ``'0'``, ``'no'``, ``'none'`` and ``''``
            (case-insensitive) are treated as falsy; any other string is
            truthy.

    Returns:
        None. The fitted model is written to ``rsf_filename``.
    """
    # The previous default was the *string* 'False', which is truthy, so the
    # untuned branch could never run. Normalize explicit negative strings so
    # callers that still pass tune='False' get what they asked for.
    if isinstance(tune, str):
        use_tuned = tune.strip().lower() not in ("false", "0", "no", "none", "")
    else:
        use_tuned = bool(tune)

    print("Tuning Random Survival Forest")
    # Training RSF
    random_state = 20  # fixed seed so repeated runs build the same forest
    if use_tuned:  # tuned
        forest_params = {
            "n_estimators": 775,
            "min_samples_split": 5,
            "min_samples_leaf": 5,
            "max_depth": 671,
        }
    else:  # untuned
        forest_params = {
            "n_estimators": 1000,
            "min_samples_split": 10,
            "min_samples_leaf": 15,
            "max_depth": 10,
        }
    rsf = RandomSurvivalForest(
        max_features="sqrt",
        n_jobs=-1,
        random_state=random_state,
        **forest_params,
    )

    print("Fitting rsf")
    # y only takes structured data
    rsf.fit(rsf_x, Surv.from_dataframe('breakdown', 'cycle', rsf_y))

    # save trained model; the context manager guarantees the handle is
    # flushed and closed even if pickling raises
    with open(rsf_filename, 'wb') as model_file:
        pickle.dump(rsf, model_file)