Trainer.py
# utils
import torch
import os
import pandas as pd
import gc
# data
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
# models
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer, AutoModel
from models.ContextAwareDAC import ContextAwareDAC
# training and evaluation
import wandb
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ProgressBar, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
from dataset.dataset import DADataset
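
# Note (not in the original file): the config dict passed to LightningModel is
# expected to provide, judging from the accesses below: 'model_name', 'hidden_size',
# 'num_classes', 'device', 'lr', 'data_dir', 'dataset', 'max_len', 'text_field',
# 'label_field', 'batch_size', 'num_workers', and 'average' (sklearn averaging
# mode, e.g. "macro").
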
class LightningModel(pl.LightningModule):
    """PyTorch Lightning wrapper around ContextAwareDAC for dialogue act classification."""

    def __init__(self, config):
        super(LightningModel, self).__init__()
        self.config = config
        self.model = ContextAwareDAC(
            model_name=self.config['model_name'],
            hidden_size=self.config['hidden_size'],
            num_classes=self.config['num_classes'],
            device=self.config['device']
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config['model_name'])

    def forward(self, batch):
        logits = self.model(batch)
        return logits

    def configure_optimizers(self):
        return optim.Adam(params=self.parameters(), lr=self.config['lr'])

    def train_dataloader(self):
        # train_data = load_dataset("csv", data_files=os.path.join(self.config['data_dir'], self.config['dataset'], self.config['dataset']+"_train.csv"))
        train_data = pd.read_csv(os.path.join(self.config['data_dir'], self.config['dataset'], self.config['dataset']+"_train.csv"))
        train_dataset = DADataset(tokenizer=self.tokenizer, data=train_data, max_len=self.config['max_len'], text_field=self.config['text_field'], label_field=self.config['label_field'])
        drop_last = len(train_dataset.text) % self.config['batch_size'] == 1  # drop the last batch if it contains a single sample (causes an error)
        train_loader = DataLoader(dataset=train_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=self.config['num_workers'], drop_last=drop_last)
        return train_loader

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
        logits = self(batch)
        loss = F.cross_entropy(logits, targets)
        acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
        f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        wandb.log({"loss": loss, "accuracy": acc, "f1_score": f1})
        return {"loss": loss, "accuracy": acc, "f1_score": f1}

    def val_dataloader(self):
        # valid_data = load_dataset("csv", data_files=os.path.join(self.config['data_dir'], self.config['dataset'], self.config['dataset']+"_valid.csv"))
        # the full validation split has ~40k samples; here the "valid" file is the same as the test split (~16k samples) so validation runs quickly
        valid_data = pd.read_csv(os.path.join(self.config['data_dir'], self.config['dataset'], self.config['dataset']+"_valid.csv"))
        valid_dataset = DADataset(tokenizer=self.tokenizer, data=valid_data, max_len=self.config['max_len'], text_field=self.config['text_field'], label_field=self.config['label_field'])
        drop_last = len(valid_dataset.text) % self.config['batch_size'] == 1  # drop the last batch if it contains a single sample (causes an error)
        valid_loader = DataLoader(dataset=valid_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=self.config['num_workers'], drop_last=drop_last)
        return valid_loader

    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
        logits = self(batch)
        loss = F.cross_entropy(logits, targets)
        acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
        f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        precision = precision_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        recall = recall_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        return {"val_loss": loss, "val_accuracy": torch.tensor([acc]), "val_f1": torch.tensor([f1]), "val_precision": torch.tensor([precision]), "val_recall": torch.tensor([recall])}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_accuracy'] for x in outputs]).mean()
        avg_f1 = torch.stack([x['val_f1'] for x in outputs]).mean()
        avg_precision = torch.stack([x['val_precision'] for x in outputs]).mean()
        avg_recall = torch.stack([x['val_recall'] for x in outputs]).mean()
        wandb.log({"val_loss": avg_loss, "val_accuracy": avg_acc, "val_f1": avg_f1, "val_precision": avg_precision, "val_recall": avg_recall})
        return {"val_loss": avg_loss, "val_accuracy": avg_acc, "val_f1": avg_f1, "val_precision": avg_precision, "val_recall": avg_recall}

    def test_dataloader(self):
        # test_data = load_dataset("csv", data_files=os.path.join(self.config['data_dir'], self.config['dataset'], self.config['dataset']+"_test.csv"))
        test_data = pd.read_csv(os.path.join(self.config['data_dir'], self.config['dataset'], self.config['dataset']+"_test.csv"))
        test_dataset = DADataset(tokenizer=self.tokenizer, data=test_data, max_len=self.config['max_len'], text_field=self.config['text_field'], label_field=self.config['label_field'])
        drop_last = len(test_dataset.text) % self.config['batch_size'] == 1  # drop the last batch if it contains a single sample (causes an error)
        test_loader = DataLoader(dataset=test_dataset, batch_size=self.config['batch_size'], shuffle=False, num_workers=self.config['num_workers'], drop_last=drop_last)
        return test_loader

    def test_step(self, batch, batch_idx):
        input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
        logits = self(batch)
        loss = F.cross_entropy(logits, targets)
        acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
        f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        precision = precision_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        recall = recall_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
        return {"test_loss": loss, "test_precision": torch.tensor([precision]), "test_recall": torch.tensor([recall]), "test_accuracy": torch.tensor([acc]), "test_f1": torch.tensor([f1])}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['test_accuracy'] for x in outputs]).mean()
        avg_f1 = torch.stack([x['test_f1'] for x in outputs]).mean()
        avg_precision = torch.stack([x['test_precision'] for x in outputs]).mean()
        avg_recall = torch.stack([x['test_recall'] for x in outputs]).mean()
        return {"test_loss": avg_loss, "test_precision": avg_precision, "test_recall": avg_recall, "test_acc": avg_acc, "test_f1": avg_f1}