generate.py
import torch

import config.base
import utils
from models import self_attention as sa
if __name__ == "__main__":
    # Load all config parameters as an object; this cfg is only used for file paths.
    cfg = config.base.Config()
    # Build the file paths for the dataset pickle, the saved checkpoint, and the sample output.
    pickle_path = utils.get_file_path(cfg.dataset_dir, cfg.pkl_file)
    pt_path = utils.get_file_path(cfg.param_dir, cfg.pt_file)
    sample_path = utils.get_file_path(cfg.sample_dir, cfg.sample_file)
    # Load the vocabulary size and the encode/decode functions from the pickle metadata.
    meta_vocab_size, meta_encode, meta_decode = utils.abstract_pickle(pickle_path)
    # Load the saved torch checkpoint.
    torch_model = torch.load(pt_path)
    # Use the config stored in the checkpoint so the model is rebuilt with the same
    # parameters it was trained with.
    # TODO: loading the config this way is fragile; if the checkpoint's config and the
    # current config diverge, subtle bugs can follow. Find a safer way to reconcile them.
    cfg = torch_model["config"]
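    # Note (a sketch, not part of the original script): torch.load accepts a
    # map_location argument that remaps tensors saved on one device onto another,
    # which would help when loading a GPU-trained checkpoint on a CPU-only machine:
    #   torch_model = torch.load(pt_path, map_location="cpu")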
    # Rebuild the model, move it to the target device, restore its weights, and
    # switch to evaluation mode.
    model = sa.Model(meta_vocab_size, cfg)
    model.to(cfg.device_type)
    model.load_state_dict(torch_model["model"])
    model.eval()
    # TODO: read the starting context from a file if one is given and use that as the start.
    start = "\n"
    start_ids = meta_encode(start)
    x = torch.tensor(start_ids, dtype=torch.long, device=cfg.device_type)[None, ...]
    # Open the sample file once and keep it open for the whole generation loop,
    # instead of reopening (and never closing) it for every token.
    with torch.no_grad(), open(sample_path, "w") as sample_file:
        print("Generating text and writing to " + sample_path)
        print(meta_decode(start_ids), end="")
        sample_file.write(meta_decode(start_ids))
        for _ in range(cfg.max_new_tokens):
            # Sample one token at a time; generate2 returns the extended context
            # and the newly sampled token.
            x, x_next = model.generate2(x)
            token = meta_decode(x_next[0].tolist())
            print(token, end="")
            sample_file.write(token)
        print()
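
A minimal sketch of the context-file TODO above, assuming a hypothetical --context_file command-line flag; the original script defines no CLI arguments, and the existing config may already offer a better hook for this:

import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag; not part of the original config or script.
parser.add_argument("--context_file", default=None)
args = parser.parse_args()

start = "\n"
if args.context_file is not None:
    # Use the file's contents as the starting context instead of a bare newline.
    with open(args.context_file) as ctx:
        start = ctx.read()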
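
For reference, a minimal sketch of what a generate2-style sampling step could look like. This is an assumption: the real implementation lives in models/self_attention.py and may differ. It assumes the model's forward pass returns logits of shape (batch, time, vocab_size) and that the config exposes a block_size attribute:

import torch
import torch.nn.functional as F

def generate2_sketch(model, x, block_size):
    # Crop the context to the last block_size tokens the model can attend to.
    x_cond = x[:, -block_size:]
    logits = model(x_cond)              # (B, T, vocab_size), assumed shape
    logits = logits[:, -1, :]           # keep only the final time step
    probs = F.softmax(logits, dim=-1)   # turn logits into a probability distribution
    x_next = torch.multinomial(probs, num_samples=1)  # sample one token id
    # Append the sampled token so the next call sees the extended context.
    x = torch.cat((x, x_next), dim=1)
    return x, x_next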