Merge pull request #66 from BojarLab/dev_dl
Rework of train_model function
Bribak authored Nov 14, 2024
2 parents add3192 + 2bc034d commit 25d4672
Showing 1 changed file with 173 additions and 156 deletions.
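
For orientation, a minimal usage sketch of the reworked function follows. It is not part of the commit: `model`, `train_loader`, and `val_loader` are placeholders for a glycan model and PyTorch Geometric dataloaders prepared elsewhere, and the SAM constructor arguments are an assumption based on the SAM class defined further down in model_training.py. The sketch only illustrates the two interface changes visible in the diff: the optimizer is expected to be SAM whenever mode != "regression", and return_metrics=True additionally returns the per-epoch train/validation metrics.

import torch
from glycowork.ml.model_training import SAM, train_model

# `model`, `train_loader`, and `val_loader` are assumed placeholders
# (e.g. a SweetNet-style classifier and PyTorch Geometric dataloaders).
criterion = torch.nn.CrossEntropyLoss()
# Per the updated docstring, a SAM optimizer is expected for non-regression modes;
# the exact constructor arguments here are assumed, not taken from this diff.
optimizer = SAM(model.parameters(), torch.optim.Adam, lr=0.0005)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)

best_model, metrics = train_model(model, {'train': train_loader, 'val': val_loader},
                                  criterion, optimizer, scheduler,
                                  num_epochs=25, mode='classification', mode2='multi',
                                  return_metrics=True)  # new in this PR
print(metrics['val']['acc'][-1])  # weighted validation accuracy of the final epoch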
329 changes: 173 additions & 156 deletions glycowork/ml/model_training.py
@@ -16,7 +16,8 @@
   device = "cuda:0"
 except ImportError:
   raise ImportError("<torch missing; did you do 'pip install glycowork[ml]'?>")
-from sklearn.metrics import accuracy_score, matthews_corrcoef, mean_squared_error, label_ranking_average_precision_score, ndcg_score
+from sklearn.metrics import accuracy_score, matthews_corrcoef, mean_squared_error, \
+                            label_ranking_average_precision_score, ndcg_score, roc_auc_score, mean_absolute_error, r2_score
 from glycowork.motif.annotate import annotate_dataset

@@ -87,169 +88,185 @@ def _enable(module)
 def train_model(model: torch.nn.Module, # graph neural network for analyzing glycans
                 dataloaders: Dict[str, torch.utils.data.DataLoader], # dict with 'train' and 'val' loaders
                 criterion: torch.nn.Module, # PyTorch loss function
-                optimizer: torch.optim.Optimizer, # PyTorch optimizer
+                optimizer: torch.optim.Optimizer, # PyTorch optimizer, has to be SAM if mode != "regression"
                 scheduler: torch.optim.lr_scheduler._LRScheduler, # PyTorch learning rate decay
                 num_epochs: int = 25, # number of epochs for training
                 patience: int = 50, # epochs without improvement until early stop
                 mode: str = 'classification', # 'classification', 'multilabel', or 'regression'
-                mode2: str = 'multi' # 'multi' or 'binary' classification
-                ) -> torch.nn.Module: # best model from training
-  "trains a deep learning model on predicting glycan properties"
-  since = time.time()
-  early_stopping = EarlyStopping(patience = patience, verbose = True)
-  best_model_wts = copy.deepcopy(model.state_dict())
-  best_loss = 100.0
-  epoch_mcc = 0
-  if mode != 'regression':
-    best_acc = 0.0
-  else:
-    best_acc = 100.0
-  val_losses = []
-  val_acc = []
-
-  for epoch in range(num_epochs):
-    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
-    print('-'*10)
-
-    for phase in ['train', 'val']:
-      if phase == 'train':
-        model.train()
-      else:
-        model.eval()
-
-      running_loss = []
-      running_acc = []
-      running_mcc = []
-      for data in dataloaders[phase]:
-        # Get all relevant node attributes; top LectinOracle-style models, bottom SweetNet-style models
-        try:
-          x, y, edge_index, prot, batch = data.labels, data.y, data.edge_index, data.train_idx, data.batch
-          prot = prot.view(max(batch)+1, -1).to(device)
-        except:
-          x, y, edge_index, batch = data.labels, data.y, data.edge_index, data.batch
-        x = x.to(device)
-        if mode == 'multilabel':
-          y = y.view(max(batch)+1, -1).to(device)
-        else:
-          y = y.to(device)
-        edge_index = edge_index.to(device)
-        batch = batch.to(device)
-        optimizer.zero_grad()
-
-        with torch.set_grad_enabled(phase == 'train'):
-          # First forward pass
-          if mode+mode2 == 'classificationmulti' or mode+mode2 == 'multilabelmulti':
-            enable_running_stats(model)
-          try:
-            pred = model(prot, x, edge_index, batch)
-            loss = criterion(pred, y.view(-1, 1))
-          except:
-            pred = model(x, edge_index, batch)
-            loss = criterion(pred, y)
-
-          if phase == 'train':
-            loss.backward()
-            if mode+mode2 == 'classificationmulti' or mode+mode2 == 'multilabelmulti':
-              optimizer.first_step(zero_grad = True)
-              # Second forward pass
-              disable_running_stats(model)
-              try:
-                criterion(model(prot, x, edge_index, batch), y.view(-1, 1)).backward()
-              except:
-                criterion(model(x, edge_index, batch), y).backward()
-              optimizer.second_step(zero_grad = True)
-            else:
-              optimizer.step()
-
-        # Collecting relevant metrics
-        running_loss.append(loss.item())
-        if mode == 'classification':
-          if mode2 == 'multi':
-            pred2 = np.argmax(pred.cpu().detach().numpy(), axis = 1)
-          else:
-            pred2 = [sigmoid(x) for x in pred.cpu().detach().numpy()]
-            pred2 = [np.round(x) for x in pred2]
-          running_acc.append(accuracy_score(y.cpu().detach().numpy().astype(int), pred2))
-          running_mcc.append(matthews_corrcoef(y.detach().cpu().numpy(), pred2))
-        elif mode == 'multilabel':
-          running_acc.append(label_ranking_average_precision_score(y.cpu().detach().numpy().astype(int),
-                                                                   pred.cpu().detach().numpy()))
-          running_mcc.append(ndcg_score(y.cpu().detach().numpy().astype(int),
-                                        pred.cpu().detach().numpy()))
-        else:
-          running_acc.append(mean_squared_error(y.cpu().detach().numpy(), pred.cpu().detach().numpy()))
-
-      # Averaging metrics at end of epoch
-      epoch_loss = np.mean(running_loss)
-      epoch_acc = np.mean(running_acc)
-      if mode != 'regression':
-        epoch_mcc = np.mean(running_mcc)
-      else:
-        epoch_mcc = 0
-      if mode == 'classification':
-        print('{} Loss: {:.4f} Accuracy: {:.4f} MCC: {:.4f}'.format(phase, epoch_loss, epoch_acc, epoch_mcc))
-      elif mode == 'multilabel':
-        print('{} Loss: {:.4f} LRAP: {:.4f} NDCG: {:.4f}'.format(phase, epoch_loss, epoch_acc, epoch_mcc))
-      else:
-        print('{} Loss: {:.4f} MSE: {:.4f}'.format(phase, epoch_loss, epoch_acc))
-
-      # Keep best model state_dict
-      if phase == 'val' and epoch_loss <= best_loss:
-        best_loss = epoch_loss
-        best_model_wts = copy.deepcopy(model.state_dict())
-      if mode != 'regression':
-        if phase == 'val' and epoch_acc > best_acc:
-          best_acc = epoch_acc
-      else:
-        if phase == 'val' and epoch_acc < best_acc:
-          best_acc = epoch_acc
-      if phase == 'val':
-        val_losses.append(epoch_loss)
-        val_acc.append(epoch_acc)
-        # Check Early Stopping & adjust learning rate if needed
-        early_stopping(epoch_loss, model)
-        try:
-          scheduler.step(epoch_loss)
-        except:
-          scheduler.step()
-
-    if early_stopping.early_stop:
-      print("Early stopping")
-      break
-    print()
-
-  time_elapsed = time.time() - since
-  print('Training complete in {:.0f}m {:.0f}s'.format(
-      time_elapsed // 60, time_elapsed % 60))
-  if mode == 'classification':
-    print('Best val loss: {:4f}, best Accuracy score: {:.4f}'.format(best_loss, best_acc))
-  elif mode == 'multilabel':
-    print('Best val loss: {:4f}, best LRAP score: {:.4f}'.format(best_loss, best_acc))
-  else:
-    print('Best val loss: {:4f}, best MSE score: {:.4f}'.format(best_loss, best_acc))
-  model.load_state_dict(best_model_wts)
-
-  # Plot loss & score over the course of training
-  _, _ = plt.subplots(nrows = 2, ncols = 1)
-  plt.subplot(2, 1, 1)
-  plt.plot(range(epoch+1), val_losses)
-  plt.title('Model Training')
-  plt.ylabel('Validation Loss')
-  plt.legend(['Validation Loss'], loc = 'best')
-
-  plt.subplot(2, 1, 2)
-  plt.plot(range(epoch+1), val_acc)
-  plt.xlabel('Number of Epochs')
-  if mode == 'classification':
-    plt.ylabel('Validation Accuracy')
-    plt.legend(['Validation Accuracy'], loc = 'best')
-  elif mode == 'multilabel':
-    plt.ylabel('Validation LRAP')
-    plt.legend(['Validation LRAP'], loc = 'best')
-  else:
-    plt.ylabel('Validation MSE')
-    plt.legend(['Validation MSE'], loc = 'best')
-  return model
+                mode2: str = 'multi', # 'multi' or 'binary' classification
+                return_metrics: bool = False, # whether to return metrics
+                ) -> Union[torch.nn.Module, tuple[torch.nn.Module, dict[str, dict[str, list[float]]]]]: # best model from training and the training and validation metrics
+  "trains a deep learning model on predicting glycan properties"
+
+  since = time.time()
+  early_stopping = EarlyStopping(patience=patience, verbose=True)
+  best_model_wts = copy.deepcopy(model.state_dict())
+  best_loss = float("inf")
+  best_lead_metric = float("inf")
+
+  if mode == 'classification':
+    blank_metrics = {"loss": [], "acc": [], "mcc": [], "auroc": []}
+  elif mode == 'multilabel':
+    blank_metrics = {"loss": [], "acc": [], "mcc": [], "lrap": [], "ndcg": []}
+  else:
+    blank_metrics = {"loss": [], "mse": [], "mae": [], "r2": []}
+
+  metrics = {"train": copy.deepcopy(blank_metrics), "val": copy.deepcopy(blank_metrics)}
+
+  for epoch in range(num_epochs):
+    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
+    print('-' * 10)
+
+    for phase in ['train', 'val']:
+      if phase == 'train':
+        model.train()
+      else:
+        model.eval()
+
+      running_metrics = copy.deepcopy(blank_metrics)
+      running_metrics["weights"] = []
+
+      for data in dataloaders[phase]:
+        # Get all relevant node attributes; top LectinOracle-style models, bottom SweetNet-style models
+        try:
+          x, y, edge_index, prot, batch = data.labels, data.y, data.edge_index, data.train_idx, data.batch
+          prot = prot.view(max(batch) + 1, -1).to(device)
+        except:
+          x, y, edge_index, batch = data.labels, data.y, data.edge_index, data.batch
+        x = x.to(device)
+        if mode == 'multilabel':
+          y = y.view(max(batch) + 1, -1).to(device)
+        else:
+          y = y.to(device)
+        edge_index = edge_index.to(device)
+        batch = batch.to(device)
+        optimizer.zero_grad()
+
+        with torch.set_grad_enabled(phase == 'train'):
+          # First forward pass
+          if mode + mode2 == 'classificationmulti' or mode + mode2 == 'multilabelmulti':
+            enable_running_stats(model)
+          try:
+            pred = model(prot, x, edge_index, batch)
+            loss = criterion(pred, y.view(-1, 1))
+          except:
+            pred = model(x, edge_index, batch)
+            loss = criterion(pred, y)
+
+          if phase == 'train':
+            loss.backward()
+            if mode + mode2 == 'classificationmulti' or mode + mode2 == 'multilabelmulti':
+              optimizer.first_step(zero_grad=True)
+              # Second forward pass
+              disable_running_stats(model)
+              try:
+                criterion(model(prot, x, edge_index, batch), y.view(-1, 1)).backward()
+              except:
+                criterion(model(x, edge_index, batch), y).backward()
+              optimizer.second_step(zero_grad=True)
+            else:
+              optimizer.step()
+
+        # Collecting relevant metrics
+        running_metrics["loss"].append(loss.item())
+        running_metrics["weights"].append(batch.max().cpu() + 1)
+
+        y_det = y.detach().cpu().numpy()
+        pred_det = pred.cpu().detach().numpy()
+        if mode == 'classification':
+          if mode2 == 'multi':
+            pred2 = np.argmax(pred_det, axis=1)
+          else:
+            pred2 = [np.round(sigmoid(x)) for x in pred_det]
+          running_metrics["acc"].append(accuracy_score(y_det.astype(int), pred2))
+          running_metrics["mcc"].append(matthews_corrcoef(y_det, pred2))
+          running_metrics["auroc"].append(roc_auc_score(y_det.astype(int), pred2))
+        elif mode == 'multilabel':
+          running_metrics["acc"].append(accuracy_score(y_det.astype(int), pred_det))
+          running_metrics["mcc"].append(matthews_corrcoef(y_det, pred_det))
+          running_metrics["lrap"].append(label_ranking_average_precision_score(y_det.astype(int), pred_det))
+          running_metrics["ndcg"].append(ndcg_score(y_det.astype(int), pred_det))
+        else:
+          running_metrics["mse"].append(mean_squared_error(y_det, pred_det))
+          running_metrics["mae"].append(mean_absolute_error(y_det, pred_det))
+          running_metrics["r2"].append(r2_score(y_det, pred_det))
+
+      # Averaging metrics at end of epoch
+      for key in running_metrics:
+        if key == "weights":
+          continue
+        metrics[phase][key].append(np.average(running_metrics[key], weights=running_metrics["weights"]))
+
+      if mode == 'classification':
+        print('{} Loss: {:.4f} Accuracy: {:.4f} MCC: {:.4f}'.format(phase, metrics[phase]["loss"][-1], metrics[phase]["acc"][-1], metrics[phase]["mcc"][-1]))
+      elif mode == 'multilabel':
+        print('{} Loss: {:.4f} LRAP: {:.4f} NDCG: {:.4f}'.format(phase, metrics[phase]["loss"][-1], metrics[phase]["acc"][-1], metrics[phase]["mcc"][-1]))
+      else:
+        print('{} Loss: {:.4f} MSE: {:.4f} MAE: {:.4f}'.format(phase, metrics[phase]["loss"][-1], metrics[phase]["mse"][-1], metrics[phase]["mae"][-1]))
+
+      # Keep best model state_dict
+      if phase == "val":
+        if metrics[phase]["loss"][-1] <= best_loss:
+          best_loss = metrics[phase]["loss"][-1]
+          best_model_wts = copy.deepcopy(model.state_dict())
+
+          # Extract the lead metric (ACC, LRAP, or MSE) of the new best model
+          if mode == 'classification':
+            best_lead_metric = metrics[phase]["acc"][-1]
+          elif mode == 'multilabel':
+            best_lead_metric = metrics[phase]["lrap"][-1]
+          else:
+            best_lead_metric = metrics[phase]["mse"][-1]
+
+        # Check Early Stopping & adjust learning rate if needed
+        early_stopping(metrics[phase]["loss"][-1], model)
+        try:
+          scheduler.step(metrics[phase]["loss"][-1])
+        except:
+          scheduler.step()
+
+    if early_stopping.early_stop:
+      print("Early stopping")
+      break
+    print()
+
+  time_elapsed = time.time() - since
+  print('Training complete in {:.0f}m {:.0f}s'.format(
+      time_elapsed // 60, time_elapsed % 60))
+  if mode == 'classification':
+    print('Best val loss: {:4f}, best Accuracy score: {:.4f}'.format(best_loss, best_lead_metric))
+  elif mode == 'multilabel':
+    print('Best val loss: {:4f}, best LRAP score: {:.4f}'.format(best_loss, best_lead_metric))
+  else:
+    print('Best val loss: {:4f}, best MSE score: {:.4f}'.format(best_loss, best_lead_metric))
+  model.load_state_dict(best_model_wts)
+
+  if return_metrics:
+    return model, metrics
+
+  # Plot loss & score over the course of training
+  _, _ = plt.subplots(nrows=2, ncols=1)
+  plt.subplot(2, 1, 1)
+  plt.plot(range(epoch + 1), metrics["val"]["loss"])
+  plt.title('Model Training')
+  plt.ylabel('Validation Loss')
+  plt.legend(['Validation Loss'], loc='best')
+
+  plt.subplot(2, 1, 2)
+  plt.xlabel('Number of Epochs')
+  if mode == 'classification':
+    plt.plot(range(epoch + 1), metrics["val"]["acc"])
+    plt.ylabel('Validation Accuracy')
+    plt.legend(['Validation Accuracy'], loc='best')
+  elif mode == 'multilabel':
+    plt.plot(range(epoch + 1), metrics["val"]["lrap"])
+    plt.ylabel('Validation LRAP')
+    plt.legend(['Validation LRAP'], loc='best')
+  else:
+    plt.plot(range(epoch + 1), metrics["val"]["mse"])
+    plt.ylabel('Validation MSE')
+    plt.legend(['Validation MSE'], loc='best')
+  return model
 
 
 class SAM(torch.optim.Optimizer):
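
One detail of the rework worth highlighting: per-batch metrics are now aggregated with np.average, weighted by the number of graphs in each batch (running_metrics["weights"], i.e. batch.max() + 1), rather than a plain mean. A small illustration with made-up numbers, not taken from the commit:

import numpy as np

batch_accuracies = [0.90, 0.80, 1.00]  # metric value computed on each batch
batch_sizes = [128, 128, 16]           # graphs per batch, as stored in running_metrics["weights"]

epoch_acc = np.average(batch_accuracies, weights=batch_sizes)
print(round(epoch_acc, 4))  # 0.8588 -- the small final batch barely shifts the epoch value
# An unweighted np.mean would give 0.90 here, overweighting the 16-graph batch.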
