-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcompute_NMI.m
49 lines (44 loc) · 979 Bytes
/
compute_NMI.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
function accuracy=compute_NMI(r_labels,labels)
% COMPUTE_NMI Normalized mutual information between two labelings.
%   ACCURACY = COMPUTE_NMI(R_LABELS, LABELS) compares two assignments of the
%   same N items to clusters and returns
%
%       NMI = 2*I(R;L) / (H(R) + H(L))
%
%   where I(R;L) is the mutual information between the two labelings and
%   H(.) is the entropy of each one, all estimated from their contingency
%   table. The score is 1 when the partitions coincide up to a renaming of
%   the clusters and 0 when they are independent.
%
%   Inputs:
%     r_labels - vector of N cluster labels (row or column).
%     labels   - vector of N cluster labels (row or column).
%   Label values are arbitrary: both arguments are remapped to contiguous
%   indices with UNIQUE, so zero, negative or non-contiguous labels are fine.
%
%   Output:
%     accuracy - the NMI score. When both labelings contain at most one
%                cluster (including empty input) the formula is 0/0, and
%                1 is returned (identical trivial partitions).
%
%   Raises 'compute_NMI:sizeMismatch' if the inputs differ in length.

r_labels=r_labels(:);
labels=labels(:);
if numel(r_labels)~=numel(labels)
    error('compute_NMI:sizeMismatch', ...
        'r_labels and labels must have the same number of elements.');
end
N=numel(labels);

% Remap BOTH labelings onto 1..K. Using raw values as row indices would
% drop items with labels <= 0 and leave empty rows for gaps, and those
% empty rows turn into 0/0 = NaN terms below.
[r_vals,~,r_idx]=unique(r_labels);
[l_vals,~,l_idx]=unique(labels);
rows=numel(r_vals);
cols=numel(l_vals);

% Both partitions trivial: I = H(R) = H(L) = 0, so the ratio is undefined.
if rows<=1 && cols<=1
    accuracy=1;
    return;
end

% matrix(i,j) = number of items in r-cluster i and l-cluster j.
matrix=accumarray([r_idx(:) l_idx(:)],1,[rows cols]);
n_r=sum(matrix,2);   % rows x 1: size of each r_labels cluster
n_l=sum(matrix,1);   % 1 x cols: size of each labels cluster

% N*I(R;L) = sum_ij n_ij*log(N*n_ij/(n_i*n_j)). Only non-empty cells are
% summed, which implements the 0*log(0) = 0 convention.
nz=matrix>0;
outer=n_r*n_l;       % rows x cols: n_i*n_j
mi=sum(matrix(nz).*log(N*matrix(nz)./outer(nz)));

% N*H(R) and N*H(L). Every cluster is non-empty after remapping, so no
% zero guard is needed here.
h_r=-sum(n_r.*log(n_r/N));
h_l=-sum(n_l.*log(n_l/N));

% The common factor N cancels in the ratio.
accuracy=2*mi/(h_r+h_l);
function a=logme(b)
% LOGME Element-wise natural logarithm that maps zero entries to zero.
%   A = LOGME(B) returns a double array the same size as B with
%   A(k) = log(B(k)) wherever B(k) ~= 0 and A(k) = 0 wherever B(k) == 0.
%   An expression like X.*LOGME(X) therefore evaluates to 0 where X is 0
%   instead of producing 0*(-Inf) = NaN.
%   Non-zero entries go straight through LOG, so NaN stays NaN and negative
%   entries produce complex values. Works for arrays of any dimension.
a=zeros(size(b));
nz=(b~=0);
a(nz)=log(b(nz));