-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmetrics.py
57 lines (49 loc) · 2.09 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import pandas as pd
def calc_confusion_matrix(test_loader, device, network, model_name,
                          class_names=("Benign", "Malicious")):
    """Count true-vs-predicted labels over a data loader and save a heatmap.

    Runs ``network`` on every batch of ``test_loader`` under
    ``torch.no_grad()``. The predicted class is the arg-max over dimension 1
    of the network output. The function accumulates a confusion matrix whose
    rows are true labels and whose columns are predicted labels. The matrix
    is rendered with seaborn and saved to
    ``test_output/confusion_matrix_<model_name>.png``.

    Args:
        test_loader: iterable of ``(inputs, labels)`` batches. Labels are
            presumably integer class indices in ``[0, len(class_names))`` —
            confirm against the dataset.
        device: device the inputs are moved to before the forward pass.
        network: model called as ``network(inputs)``. Its output is assumed
            to be (batch, num_classes) scores — TODO confirm.
        model_name: suffix used in the output filename.
        class_names: heatmap row/column labels. Their count sets the matrix
            size. Defaults to the original two-class labels.

    Returns:
        The ``(len(class_names), len(class_names))`` integer confusion matrix
        as a NumPy array.

    Raises:
        IndexError: if a true or predicted label is outside the matrix.
    """
    num_classes = len(class_names)
    confusion_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
    with torch.no_grad():
        for inputs, classes in test_loader:
            inputs = inputs.to(device)
            outputs = network(inputs)
            _, preds = torch.max(outputs, 1)
            true_idx = classes.reshape(-1).long().cpu().numpy()
            pred_idx = preds.reshape(-1).long().cpu().numpy()
            # Unbuffered scatter-add: repeated (true, pred) pairs in one batch
            # are all counted, with the same indexing rules as the old
            # per-sample `+= 1` loop.
            np.add.at(confusion_matrix, (true_idx, pred_idx), 1)

    fig = plt.figure(figsize=(15, 10))
    labels = list(class_names)
    df_cm = pd.DataFrame(confusion_matrix, index=labels, columns=labels).astype(int)
    heatmap = sns.heatmap(df_cm, annot=True, cmap="gray", fmt="d")
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=15)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=15)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig('test_output/confusion_matrix_{}.png'.format(model_name))
    # Close the figure so later plots don't draw on it and figures don't
    # accumulate across calls.
    plt.close(fig)
    return confusion_matrix
def test_accuracy_plot(test_accuracy, model_name):
    """Plot per-epoch test accuracy and save it as a PNG.

    Args:
        test_accuracy: sequence of accuracy values, one per epoch, plotted
            against their index. The y-axis is fixed to [0, 100], so the
            values are presumably percentages — confirm against the caller.
        model_name: suffix used in the output filename.

    Writes ``test_output/test_accuracy_<model_name>.png``.
    """
    # Start a fresh figure. Otherwise plt.plot would draw onto whatever
    # figure is current, e.g. a heatmap or loss plot left open earlier.
    fig = plt.figure()
    plt.plot(test_accuracy, "-bx")
    plt.ylim([0, 100])
    plt.xlabel('epoch')
    plt.ylabel('Accuracy')
    plt.legend(['Testing Accuracy'])
    plt.title('Accuracy vs. No. of epochs')
    plt.savefig('test_output/test_accuracy_{}.png'.format(model_name))
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)
def train_val_loss_plot(model_name, epoch):
    """Plot training and validation loss curves read from a text log.

    Reads ``train_output/train_losses_<model_name>.txt``. Each non-blank line
    is expected to look like ``<training_loss>,<validation_loss>``; only the
    first two comma-separated fields are used. The plot is saved to
    ``train_output/train_val_loss_<model_name>.png``.

    Args:
        model_name: suffix used in the input and output filenames.
        epoch: maximum number of logged epochs (lines) to plot. If the log
            holds fewer entries, only the available ones are plotted, instead
            of failing on an x/y length mismatch.

    Raises:
        FileNotFoundError: if the loss log does not exist.
        ValueError: if a non-blank line has a non-numeric field.
        IndexError: if a non-blank line has fewer than two fields.
    """
    training = []
    validation = []
    with open("train_output/train_losses_{}.txt".format(model_name), "r") as file:
        for line in file:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a stray empty line
            fields = line.split(',')
            training.append(float(fields[0]))
            validation.append(float(fields[1]))

    training = training[:epoch]
    validation = validation[:epoch]
    # Derive the x-axis from the data actually read so x and y always match.
    epochs = np.arange(len(training))

    # Create the figure only after the log was read successfully, so a
    # missing or malformed file does not leave an orphan figure behind.
    fig = plt.figure(figsize=(10, 6))
    plt.plot(epochs, training)
    plt.plot(epochs, validation)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['Training Loss', 'Validation Loss'])
    plt.title('Loss vs. No. of epochs')
    plt.savefig('train_output/train_val_loss_{}.png'.format(model_name))
    plt.close(fig)