nn4.py
import matplotlib.pyplot as plt
import torch
from torch.nn import functional as F
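
# Overview: a tiny character-level language model built from raw tensors --
# fixed random embeddings and positional codes, one learned mixing matrix,
# sum-pooling over a 64-character context window, and a two-layer MLP head.
# It trains on data.txt and then samples new text one character at a time.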
# Load the corpus and build a character-level vocabulary.
with open("data.txt", "r", encoding="utf-8") as f:
    text = f.read()
text = text.lower()
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}  # char -> integer id
itos = {i: ch for i, ch in enumerate(chars)}  # integer id -> char
data = [stoi[c] for c in text]                # corpus as a list of ids
vocab_size = len(chars)
print(chars)
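
# Round-trip check (added for illustration, not in the original): encoding
# and then decoding the first few characters should reproduce them exactly.
assert "".join(itos[i] for i in data[:16]) == text[:16]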
ins = 64           # context window length, in characters
outs = vocab_size  # one output logit per character
nodes = 200        # hidden-layer width
lr = 0.003
n_emb = 64         # embedding dimension
embed = torch.randn(vocab_size, n_emb)  # fixed (never trained) embeddings
pos = torch.randn(ins, n_emb)           # fixed (never trained) positional codes
data = torch.tensor(data).long()
params = []        # trainable tensors, collected by weights()
def weights(ins, outs):
    # Create a trainable weight matrix and register it for the optimizer.
    ws = torch.randn(ins, outs) * 0.1
    ws.requires_grad_(True)
    params.append(ws)
    return ws
class Model:
    def __init__(self):
        self.wv = weights(n_emb, n_emb)  # per-character mixing before pooling
        self.w0 = weights(n_emb, nodes)
        self.w1 = weights(nodes, nodes)
        self.w2 = weights(nodes, outs)

    def forward(self, x):
        x = embed[x] * pos        # (..., ins, n_emb); positions multiply in
        x = x @ self.wv
        x = torch.sum(x, dim=-2)  # pool over the window -> (..., n_emb)
        x = torch.relu(x @ self.w0)
        x = torch.relu(x @ self.w1)
        yh = x @ self.w2          # logits over the vocabulary
        return yh
model = Model()
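
# Shape check (added for illustration, not in the original): a batch of
# 8 windows of length ins should yield one row of logits per window.
_demo_x = torch.randint(vocab_size, (8, ins))
assert model.forward(_demo_x).shape == (8, vocab_size)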
optimizer = torch.optim.Adam(params, lr)
ers = []  # per-step loss history, for plotting
for i in range(5000):
    # Sample 100 random windows: ins characters in, the next character out.
    b = torch.randint(len(data) - ins, (100,))
    xs = torch.stack([data[j:j+ins] for j in b])
    ys = torch.stack([data[j+ins:j+ins+1] for j in b])
    yh = model.forward(xs)
    loss = F.cross_entropy(yh.view(-1, vocab_size), ys.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    e = loss.item()
    if i % 50 == 0:
        print(i, "Loss", e)
    ers.append(e)
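
# Reference point (added, not in the original): a model that predicts the
# next character uniformly scores a cross-entropy of ln(vocab_size), so the
# loss curve in figure 1 should start near that value and fall below it.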
plt.figure(1)
plt.plot(ers)          # training-loss curve
plt.figure(2)
plt.plot(ys)           # targets of the last batch...
yh = torch.argmax(yh, dim=-1)
plt.plot(yh.detach())  # ...overlaid with the model's predictions
# Generate text: seed with the last training window, then repeatedly sample
# the next character and slide the window forward by one.
s = xs[0]
gen_text = ""
for i in range(3000):
    yh = model.forward(s)
    prob = F.softmax(yh, dim=0)
    # pred = torch.argmax(yh).item()  # greedy decoding alternative
    pred = torch.multinomial(prob, num_samples=1).item()
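    # Illustrative variant (added, not in the original): temperature sampling.
    # Dividing the logits by T < 1 sharpens the distribution; T > 1 flattens it:
    # pred = torch.multinomial(F.softmax(yh / 0.8, dim=0), num_samples=1).item()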
    s = torch.roll(s, -1)  # shift the window left by one
    s[-1] = pred           # append the sampled character
    gen_text += itos[pred]
print(gen_text)
plt.show()