-
Notifications
You must be signed in to change notification settings - Fork 1
/
norm_plotting.py
38 lines (29 loc) · 1.01 KB
/
norm_plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#%%
import os
import pdb

import matplotlib.pyplot as plt
import torch
# Load data
# One saved tensor per special token ([CLS] / [SEP]) and per variant
# ("effective" vs "standard") for the MRPC run. File names are kept exactly
# as they are on disk. NOTE(review): the SEP files lack the underscore that the
# CLS files have ("effectiveSEP" vs "effective_CLS") - confirm on disk.
#
# torch.load opens its path argument verbatim and does NOT expand "~", so the
# home directory must be expanded explicitly or every load raises
# FileNotFoundError. map_location="cpu" lets tensors that were saved from a
# GPU load on a CPU-only machine (and keeps the later .numpy() calls valid).
_DATA_DIR = os.path.expanduser("~/Desktop/researchData")


def _require_same_shape(token, effective, standard):
    """Raise ValueError unless *effective* and *standard* have identical shapes.

    Without this check, torch would silently broadcast mismatched shapes and
    the element-wise difference below would be meaningless.
    """
    if effective.shape != standard.shape:
        raise ValueError(
            f"{token} shape mismatch: effective {tuple(effective.shape)} "
            f"vs standard {tuple(standard.shape)}"
        )


effective_CLS = torch.load(os.path.join(_DATA_DIR, "effective_CLS_MRPC.pt"), map_location="cpu")
standard_CLS = torch.load(os.path.join(_DATA_DIR, "standard_CLS_MRPC.pt"), map_location="cpu")
effective_SEP = torch.load(os.path.join(_DATA_DIR, "effectiveSEP_MRPC.pt"), map_location="cpu")
standard_SEP = torch.load(os.path.join(_DATA_DIR, "standardSEP_MRPC.pt"), map_location="cpu")

_require_same_shape("[CLS]", effective_CLS, standard_CLS)
_require_same_shape("[SEP]", effective_SEP, standard_SEP)

# Element-wise absolute difference between the two variants.
diff_CLS = torch.abs(effective_CLS - standard_CLS)
diff_SEP = torch.abs(effective_SEP - standard_SEP)
# Picturing
def _plot_abs_difference(diff, token, filename):
    """Draw *diff* as a heatmap in a new figure and save it as a PNG.

    diff: tensor of absolute differences; it is passed to imshow, whose rows
        are labelled "Layer" and columns "Head". NOTE(review): assumes a 2-D
        (layer, head) grid - confirm against how the .pt files were produced.
    token: special-token label used in the title, e.g. "[CLS]".
    filename: PNG file name, written inside ~/Desktop/researchData.
    """
    plt.figure()
    # detach().cpu() makes .numpy() work even if the tensor tracks gradients
    # or lives on a GPU; for a plain CPU tensor it is a no-op.
    # The colour scale is pinned to [0, 1] so the two heatmaps share a scale.
    plt.imshow(diff.detach().cpu().numpy(), vmin=0, vmax=1.0, cmap='Greens')
    plt.colorbar()
    plt.ylabel("Layer")
    plt.xlabel("Head")
    plt.title(f"{token} difference(absolute value)")
    # savefig does not expand "~" either; without this it would try to write
    # into a literal "~" directory relative to the working directory.
    out_dir = os.path.expanduser("~/Desktop/researchData")
    plt.savefig(os.path.join(out_dir, filename))


_plot_abs_difference(diff_CLS, "[CLS]", "CLS_Difference_MRPC.png")
_plot_abs_difference(diff_SEP, "[SEP]", "SEP_Difference_MRPC.png")
# Compute norm
# Print one overall magnitude per difference grid, [CLS] first and then [SEP],
# each on its own line. With its default arguments torch.norm returns the
# 2-norm taken over all elements (the Frobenius norm for a matrix).
print(*map(torch.norm, (diff_CLS, diff_SEP)), sep="\n")