Skip to content

Commit

Permalink
kfold
Browse files Browse the repository at this point in the history
  • Loading branch information
Emre Akkaya committed Jul 9, 2018
1 parent 8286bcf commit b88f948
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions cross-validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import codecs
import numpy as np
from sklearn.model_selection import ShuffleSplit
from sklearn.model_selection import KFold
from shutil import copyfile
import subprocess

Expand Down Expand Up @@ -39,16 +39,15 @@ def write(path, sts):
# Need numpy array so that we can 'extract' using indices
sentences = np.array(sentences)

# NOTE(review): this is a rendered diff, so old and new lines appear together.
# The ShuffleSplit line is the one this commit removes; the KFold line right
# after it is its replacement.
rs = ShuffleSplit(n_splits=5, train_size=0.8, test_size=0.1)
rs = KFold(n_splits=10)
count = rs.get_n_splits()

# Generate n-fold CV files & build, train, eval
for train_index, test_index in rs.split(sentences):
# Find dev index as well...
# Removed approach (next four lines): dev is every sentence index that appears
# in neither train_index nor test_index.
temp = list()
temp.extend(train_index)
temp.extend(test_index)
dev_index = list(set(range(0, len(sentences))) - set(temp))
# Replacement approach (next three lines): take the last tenth of train_index
# as dev and keep the rest as train.
# NOTE(review): when len(train_index) < 10, numb_dev is 0 and the `-0` slices
# make dev_index the whole of train_index while train_index becomes empty.
numb_dev = len(train_index) // 10
dev_index = train_index[-1*numb_dev:]
train_index = train_index[:-1*numb_dev]

# Extract sentences from indices
train_sentences = sentences[train_index]
Expand Down

0 comments on commit b88f948

Please sign in to comment.