#!/usr/bin/env python
# coding: utf8
"""
Trains a baseline convolutional neural network text classifier using spaCy's
out-of-the-box TextCategorizer component.

Be sure to run prepare-for-spacy-and-pytorch.py on a directory containing the
original train.csv and validation.csv files, writing the converted data to an
output dir (e.g., "data-spacy-pytorch-jsonl" is used in this case).

Usage: python train-spacy-textcat.py -i <input_dir> -l <labels>
e.g., python train-spacy-textcat.py -i "data-spacy-pytorch-jsonl" -l "ABUSE,UNRELATED"

* Training: https://spacy.io/usage/training

Compatible with: spaCy v2.0.0+
"""
from __future__ import unicode_literals, print_function
import plac
import random
from pathlib import Path
from utilz import listify, matches, get_top_cat
import spacy
from spacy.util import minibatch, compounding
import jsonlines


@plac.annotations(
    input_dir=("Input directory with the train and validation jsonl files", "option", "i", Path),
    labels=("Comma-separated string of labels to predict", "option", "l", str),
    model=("Model name. Defaults to the 'en_core_sci_lg' model.", "option", "m", str),
    output_dir=("Optional output directory", "option", "o", Path),
    n_texts=("Number of texts to train from", "option", "t", int),
    n_iter=("Number of training iterations", "option", "n", int),
    init_tok2vec=("Pretrained tok2vec weights", "option", "t2v", Path),
)
def main(input_dir="data-jsonl", labels="ABUSE,CONSUMPTION,MENTION,UNRELATED",
         model=None, output_dir=None, n_iter=20, n_texts=10537, init_tok2vec=None):
    if output_dir is not None:
        output_dir = Path(output_dir)
        if not output_dir.exists():
            output_dir.mkdir()

    if model is not None:
        nlp = spacy.load(model)  # load an existing spaCy model
        print("Loaded model '%s'" % model)
    else:
        nlp = spacy.load("en_core_sci_lg")  # scispaCy large English model as the base
        nlp.add_pipe(nlp.create_pipe("sentencizer"))
        print("Created en_core_sci_lg base model")
    # add the text classifier to the pipeline if it doesn't exist;
    # nlp.create_pipe works for built-ins that are registered with spaCy
    if "textcat" not in nlp.pipe_names:
        textcat = nlp.create_pipe(
            "textcat", config={"exclusive_classes": True, "architecture": "simple_cnn"}
        )
        nlp.add_pipe(textcat, last=True)
    # otherwise, get it, so we can add labels to it
    else:
        textcat = nlp.get_pipe("textcat")

    labels = listify(labels.split(','))
    # add each label to the text classifier
    for label in labels:
        print(f'adding label {label}')
        textcat.add_label(label)
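    # note: with exclusive_classes=True the classifier treats the labels as
    # mutually exclusive, so the predicted scores in doc.cats form a single
    # distribution over these classes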
print("Loading data...")
(train_texts, train_cats), (dev_texts, dev_cats) = load_data(input_dir, labels=labels)
train_texts = train_texts[:n_texts]
train_cats = train_cats[:n_texts]
print(
"Using {} examples ({} training, {} evaluation)".format(
n_texts, len(train_texts), len(dev_texts)
)
)
train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))
print(f'sample of training data:\n {train_data[:10]}')
dev_data = list(zip(dev_texts,
[{'cats': cats} for cats in dev_cats]))
print(f'sample dev:\n {dev_data[:10]}')
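    # each train_data entry pairs a text with its gold categories, e.g.
    #   ("some tweet ...", {"cats": {"ABUSE": True, "CONSUMPTION": False,
    #                                "MENTION": False, "UNRELATED": False}})
    # (illustrative values; the exact truthiness depends on utilz.matches)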

    # get the names of the other pipes, to disable them during training
    pipe_exceptions = ["textcat", "trf_wordpiecer", "trf_tok2vec"]
    other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
    with nlp.disable_pipes(*other_pipes):  # only train textcat
        optimizer = nlp.begin_training()
        if init_tok2vec is not None:
            with init_tok2vec.open("rb") as file_:
                textcat.model.tok2vec.from_bytes(file_.read())
        print("Training the model...")
        print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F"))
        batch_sizes = compounding(4.0, 64.0, 1.001)
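        # compounding() (from spacy.util) yields an endless series of batch
        # sizes that starts at 4.0 and is multiplied by 1.001 on each draw,
        # capped at 64.0 -- small batches early in training, larger ones later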
        for i in range(n_iter):
            losses = {}
            # batch up the examples using spaCy's minibatch
            random.shuffle(train_data)
            batches = minibatch(train_data, size=batch_sizes)
            for batch in batches:
                texts, annotations = zip(*batch)
                nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)
            with textcat.model.use_params(optimizer.averages):
                scores = evaluate(nlp.tokenizer, textcat, dev_texts, dev_cats)
            print(
                "{0:.3f}\t{1:.3f}\t{2:.3f}\t{3:.3f}".format(  # print a simple table
                    losses["textcat"],
                    scores["textcat_p"],
                    scores["textcat_r"],
                    scores["textcat_f"],
                )
            )

    # test the trained model on a few dev examples:
    # example_dev_texts = dev_data[:10]
    # for objs in example_dev_texts:
    #     test_text = objs[0]
    #     test_cats = objs[1]
    #     doc = nlp(test_text)
    #     predicted_cat = get_top_cat(doc)
    #     true_class = [k for k, v in test_cats.items() if v == True]
    #     true_class = true_class[0]
    #     print(f'TEXT:{test_text}\tPRED CLASS:{predicted_cat}\tTRUE CLASS:{true_class}')
    #
    # test_text = "I lie about my anxiety so I can get prescriptionz for klonopin hahah"
    # doc = nlp(test_text)
    # print(test_text, doc.cats)
    if output_dir is not None:
        with nlp.use_params(optimizer.averages):
            nlp.to_disk(output_dir)
        print("Saved model to", output_dir)


def load_data(input_dir, labels):
    """Read the train and validation jsonl files into parallel text/category lists."""
    labels = listify(labels)
    training_path = Path(input_dir) / 'train.jsonl'
    train_texts = []
    train_cats = []
    with jsonlines.open(training_path) as reader:
        for obj in reader:
            cats = {}
            text = obj['text']
            label = obj['label']
            # matches() (from utilz) presumably flags whether the gold label is k
            for k in labels:
                cats[k] = matches(k, label)
            # append the text exactly once, so texts and cats stay aligned
            train_texts.append(text)
            train_cats.append(cats)
    validation_path = Path(input_dir) / 'validation.jsonl'
    dev_texts = []
    dev_cats = []
    with jsonlines.open(validation_path) as reader:
        for obj in reader:
            labs = {}
            text = obj['text']
            lab = obj['label']
            for k in labels:
                labs[k] = matches(k, lab)
            dev_texts.append(text)
            dev_cats.append(labs)
    return (train_texts, train_cats), (dev_texts, dev_cats)


def evaluate(tokenizer, textcat, texts, cats):
    docs = (tokenizer(text) for text in texts)
    tp = 0.0   # true positives
    fp = 1e-8  # false positives (seeded with a tiny epsilon to avoid division by zero)
    fn = 1e-8  # false negatives (seeded with a tiny epsilon to avoid division by zero)
    tn = 0.0   # true negatives
    for i, doc in enumerate(textcat.pipe(docs)):
        gold = cats[i]
        for label, score in doc.cats.items():
            if label not in gold:
                continue
            # only the ABUSE label is scored: the other three classes are
            # skipped, so the metrics reflect ABUSE-vs-rest performance
            if label in ('CONSUMPTION', 'MENTION', 'UNRELATED'):
                continue
            if score >= 0.5 and gold[label] >= 0.5:
                tp += 1.0
            elif score >= 0.5 and gold[label] < 0.5:
                fp += 1.0
            elif score < 0.5 and gold[label] < 0.5:
                tn += 1.0
            elif score < 0.5 and gold[label] >= 0.5:
                fn += 1.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    if (precision + recall) == 0:
        f_score = 0.0
    else:
        f_score = 2 * (precision * recall) / (precision + recall)
    return {"textcat_p": precision, "textcat_r": recall, "textcat_f": f_score}


if __name__ == "__main__":
    plac.call(main)