-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathtrain.py
29 lines (25 loc) · 1.11 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from trainers.sick_trainer import SICKTrainer
from trainers.msrp_trainer import MSRPTrainer
from trainers.msrvid_trainer import MSRVIDTrainer
from trainers.trecqa_trainer import TRECQATrainer
from trainers.wikiqa_trainer import WikiQATrainer
from trainers.sts_trainer import STSTrainer
class MPCNNTrainerFactory(object):
    """
    Factory that picks the Trainer class registered for a dataset name.

    ``trainer_map`` is the registry; ``get_trainer`` looks a name up in it
    and instantiates the matching Trainer with the remaining arguments.
    """

    # Registry of dataset name -> Trainer class. Lookups in get_trainer are
    # exact dict-key matches (so e.g. 'SICK' does not match 'sick').
    trainer_map = {
        'sick': SICKTrainer,
        'msrp': MSRPTrainer,
        'msrvid': MSRVIDTrainer,
        'trecqa': TRECQATrainer,
        'wikiqa': WikiQATrainer,
        'sts': STSTrainer,
    }

    @staticmethod
    def get_trainer(dataset_name, model, embedding, train_loader, trainer_config, train_evaluator, test_evaluator, dev_evaluator=None, nonstatic_embedding=None):
        """
        Instantiate the Trainer registered under ``dataset_name``.

        Every argument after ``dataset_name`` is forwarded unchanged and
        positionally, in signature order, to the selected Trainer class's
        constructor.

        :param dataset_name: key to look up in ``trainer_map``.
        :param dev_evaluator: forwarded as-is; defaults to None.
        :param nonstatic_embedding: forwarded as-is; defaults to None.
        :return: a new instance of the Trainer class mapped to ``dataset_name``.
        :raises ValueError: if ``dataset_name`` is not a key of ``trainer_map``.
        """
        # A private sentinel distinguishes "absent" from any stored value
        # while doing a single dict lookup instead of `in` + indexing.
        absent = object()
        trainer_cls = MPCNNTrainerFactory.trainer_map.get(dataset_name, absent)
        if trainer_cls is absent:
            raise ValueError('{} is not implemented.'.format(dataset_name))

        ctor_args = (
            model,
            embedding,
            train_loader,
            trainer_config,
            train_evaluator,
            test_evaluator,
            dev_evaluator,
            nonstatic_embedding,
        )
        return trainer_cls(*ctor_args)