-
Notifications
You must be signed in to change notification settings - Fork 32
/
analysis.lua
70 lines (58 loc) · 1.19 KB
/
analysis.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
--- Analysis of a trained embedding file.
-- Usage: th analysis.lua <stats file>
-- Loads the torch-serialized stats table named by the first CLI argument
-- and L2-normalises every vector in `stats.embeddings` in place.
local manifold = require 'manifold'

-- NOTE(review): stats is assumed to be a table whose `embeddings` field maps
-- symbol -> 1-D tensor; confirm against the script that produced it.
local stats_path = assert(arg[1], "usage: th analysis.lua <stats file>")
local stats = torch.load(stats_path)
local vec = stats.embeddings

-- Normalise each embedding to unit L2 norm in place so that later dot
-- products are cosine similarities; zero vectors are skipped to avoid 0/0.
for _, v in pairs(vec) do
  local norm = v:norm()
  if norm > 0 then
    v:div(norm)
  end
end
--- Dot product of the embeddings stored under keys `a` and `b` of `vec`.
-- Because the vectors are normalised when the file is loaded, this is
-- their cosine similarity whenever both vectors are non-zero.
function dot(a, b)
  local left, right = vec[a], vec[b]
  return torch.dot(left, right)
end
--- Print the nearest neighbour of every entry in `vec`.
-- For each key, scans every other key and prints one line:
--   <key>  <key with the largest dot product>  <that dot product>
-- The sentinel is -math.huge rather than a fixed constant, so this also
-- works for unnormalised vectors. A key with nothing else to compare
-- against is paired with itself and -inf.
-- Cost is O(n^2) dot products for n entries.
function nearest_neighbors()
  for i, v in pairs(vec) do
    local best_dot = -math.huge
    local best_key = i
    for j, w in pairs(vec) do
      if j ~= i then
        -- compute the similarity once per pair instead of twice
        local d = torch.dot(v, w)
        if d > best_dot then
          best_dot = d
          best_key = j
        end
      end
    end
    print(i, best_key, best_dot)
  end
end
--- Count all entries of a table, including non-sequence (hash) keys.
-- `#t` is only defined for sequences, so keys are counted with pairs.
-- The parameter is named `t` to avoid shadowing the `table` library.
-- @tparam table t
-- @treturn number number of keys in `t`
function find_len(t)
  local cnt = 0
  for _ in pairs(t) do
    cnt = cnt + 1
  end
  return cnt
end
--- Project embeddings to low dimension with t-SNE (no plotting is done here).
-- Every entry except the one keyed 'NULL' becomes one row of the input
-- matrix. The dimensionality is taken from the first such entry, so no
-- particular word needs to be present and a missing 'NULL' key is fine.
-- @param vec table of symbol -> 1-D tensor (assumed all the same size)
-- @param[opt] opts options forwarded to manifold.embedding.tsne; defaults
--   to {ndims = 2, perplexity = 50, pca = 50, use_bh = false}
-- @return mapped coordinates (one row per symbol) and the array of symbols
--   in row order
function plot_tsne(vec, opts)
  opts = opts or {ndims = 2, perplexity = 50, pca = 50, use_bh = false}

  -- First pass: count the rows to embed and pick up their dimensionality.
  local n, dim = 0, nil
  for k, v in pairs(vec) do
    if k ~= 'NULL' then
      n = n + 1
      dim = dim or v:size(1)
    end
  end
  assert(n > 0, "plot_tsne: no embeddings to project")

  -- Second pass: copy the vectors into a matrix. Record each row's symbol
  -- in the same traversal so rows and labels cannot get out of step.
  local m = torch.zeros(n, dim)
  local symbols = {}
  local row = 1
  for k, v in pairs(vec) do
    if k ~= 'NULL' then
      m[row] = v
      symbols[row] = k
      row = row + 1
    end
  end

  -- The options must be passed explicitly, otherwise the library defaults
  -- are used instead.
  local mapped = manifold.embedding.tsne(m, opts)
  return mapped, symbols
end
local tsne, symbols = plot_tsne(vec)
-- Write one "symbol x y" line per embedded entry.
local file = assert(io.open('tsne.txt', "w"))
for i = 1, #symbols do
  file:write(symbols[i] .. ' ' .. tsne[i][1] .. ' ' .. tsne[i][2] .. '\n')
end
-- Close explicitly so the buffered output is flushed to disk.
file:close()