-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathbasic_utils.py
109 lines (80 loc) · 3.08 KB
/
basic_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import json
import zipfile
import numpy as np
import pickle
import random
def shuffle_list(*ls, rng=None):
    """Shuffle several sequences in unison, applying one shared permutation.

    The inputs are zipped together, so element ``i`` of every sequence stays
    aligned after shuffling. If the lengths differ, the result is truncated
    to the shortest input (``zip`` semantics). The inputs are not modified.

    Args:
        *ls: sequences to shuffle together.
        rng: optional ``random.Random`` instance for reproducible shuffles.
            Defaults to the global ``random`` module generator.

    Returns:
        An iterator yielding one tuple per input sequence, in input order,
        so callers can write ``xs, ys = shuffle_list(xs, ys)``.
    """
    rows = list(zip(*ls))
    shuffle = random.shuffle if rng is None else rng.shuffle
    shuffle(rows)
    if not rows:
        # zip(*[]) yields nothing, which would make
        # ``a, b = shuffle_list([], [])`` fail to unpack; return one empty
        # tuple per input sequence instead.
        return iter([()] * len(ls))
    return zip(*rows)
def load_pickle(filename):
    """Return the object unpickled from the file at *filename*.

    Unpickling can execute arbitrary code, so only pass trusted files.
    """
    with open(filename, mode="rb") as source:
        obj = pickle.load(source)
    return obj
def save_pickle(data, filename):
    """Pickle *data* into *filename* using the highest available protocol.

    The destination file is created, or truncated if it already exists.
    """
    protocol = pickle.HIGHEST_PROTOCOL
    with open(filename, mode="wb") as sink:
        pickle.dump(data, sink, protocol=protocol)
def load_json(filename):
    """Parse the JSON document in *filename* and return the resulting object.

    The file is decoded as UTF-8, the encoding required for JSON exchanged
    between systems (RFC 8259). Relying on the platform's locale default
    instead would mis-decode non-ASCII text on systems whose default is not
    UTF-8 (e.g. cp1252 on Windows).
    """
    with open(filename, "r", encoding="utf-8") as f:
        return json.load(f)
def save_json(data, filename, save_pretty=False, sort_keys=False):
    """Serialize *data* as JSON and write it to *filename*, replacing its contents.

    Args:
        data: JSON-serializable object.
        filename: destination path.
        save_pretty: if True, indent nested structures by 4 spaces;
            otherwise write the document on a single line.
        sort_keys: if True, emit object keys in sorted order. Honored in both
            compact and pretty modes (previously it was silently ignored in
            compact mode).

    Raises:
        TypeError: if *data* contains a value ``json`` cannot serialize. The
            document is serialized before the file is opened, so in that case
            an existing file at *filename* is left untouched.
    """
    text = json.dumps(data, indent=4 if save_pretty else None, sort_keys=sort_keys)
    # Output is pure ASCII (json's default ensure_ascii=True); the explicit
    # encoding only pins the codec instead of depending on the locale.
    with open(filename, "w", encoding="utf-8") as f:
        f.write(text)
def load_jsonl(filename):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line is parsed with ``json.loads``, which tolerates the
    surrounding whitespace and line terminator. Whitespace-only lines, such
    as a stray empty line at the end of the file, are skipped instead of
    raising ``json.JSONDecodeError``. The file is decoded as UTF-8.
    """
    with open(filename, "r", encoding="utf-8") as f:
        # Iterate the file lazily rather than materializing readlines().
        return [json.loads(line) for line in f if line.strip()]
def save_jsonl(data, filename):
    """Write the list *data* to *filename* in JSON Lines format.

    Every element is serialized with ``json.dumps`` onto its own line; the
    final record is not followed by a newline.
    """
    with open(filename, mode="w") as fh:
        encoded = map(json.dumps, data)
        fh.write("\n".join(encoded))
def save_lines(list_of_str, filepath):
    """Write each string of *list_of_str* to *filepath* on its own line.

    Strings are joined with newline characters; no trailing newline is
    appended after the last one.
    """
    with open(filepath, mode="w") as out:
        body = "\n".join(list_of_str)
        out.write(body)
def read_lines(filepath):
    """Return the lines of the text file *filepath* without their newlines."""
    with open(filepath, mode="r") as src:
        # A line from a text-mode file can only carry "\n" at its end, so
        # rstrip is equivalent to stripping both sides here.
        lines = [line.rstrip("\n") for line in src]
    return lines
def mkdirp(p):
    """Create directory *p* and any missing parents, like ``mkdir -p``.

    Succeeds silently if *p* already exists as a directory. Passing
    ``exist_ok=True`` instead of checking ``os.path.exists`` first removes
    the race in which another process creates the directory between the
    check and ``os.makedirs``, which previously raised ``FileExistsError``.

    Raises:
        FileExistsError: if *p* exists but is not a directory (as ``mkdir -p``
            does), rather than silently returning.
    """
    os.makedirs(p, exist_ok=True)
def flat_list_of_lists(l):
    """Concatenate the inner iterables of *l* into one flat list.

    >>> flat_list_of_lists([[1, 2], [3, 4]])
    [1, 2, 3, 4]
    """
    flat = []
    for inner in l:
        flat.extend(inner)
    return flat
def l2_normalize_np_array(np_array, eps=1e-5):
    """Divide each vector along the last axis of *np_array* by its L2 norm.

    Args:
        np_array: array of shape (*, D); every length-D vector along the
            last axis is normalized independently.
        eps: added to each norm before dividing so that all-zero vectors do
            not cause a division by zero (they stay zero).

    Returns:
        An array with the same shape as *np_array*.
    """
    norms = np.linalg.norm(np_array, axis=-1, keepdims=True)
    denominator = norms + eps
    return np_array / denominator
def load_pretrained_weight(model, state_dict, start_prefix=''):
    """Recursively load *state_dict* into *model*, reporting key mismatches.

    Walks *model* and every non-None child registered in ``_modules``
    (depth-first) and calls each module's ``_load_from_state_dict`` with
    strict matching. Missing and unexpected keys are collected and printed
    rather than treated as fatal.

    Args:
        model: root module to load into; it is modified in place.
        state_dict: mapping from parameter/buffer names to values. It is
            shallow-copied, so the caller's object itself is not mutated.
        start_prefix: prefix under which *model*'s keys are looked up. Child
            prefixes are built as ``prefix + name + '.'``, so a non-empty
            value is expected to end with ``'.'``.

    Returns:
        *model*.

    Raises:
        RuntimeError: if any ``_load_from_state_dict`` call appended to the
            error list; all collected messages are included.
    """
    missing, unexpected, errors = [], [], []

    # Copy so _load_from_state_dict can modify the dict, and re-attach the
    # ``_metadata`` attribute in case ``.copy()`` does not carry it over.
    meta = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if meta is not None:
        state_dict._metadata = meta

    def _visit(module, prefix):
        # Metadata is looked up by the module path without its trailing '.'.
        module_meta = {} if meta is None else meta.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, module_meta, True,
            missing, unexpected, errors)
        for child_name, child in module._modules.items():
            if child is None:
                continue
            _visit(child, prefix + child_name + '.')

    _visit(model, start_prefix)

    cls_name = model.__class__.__name__
    if missing:
        print(f"Weights of {cls_name} not initialized from "
              f"pretrained model: {missing}")
    if unexpected:
        print("Weights from pretrained model not used in "
              f"{cls_name}: {unexpected}")
    if errors:
        joined = "\n\t".join(errors)
        raise RuntimeError("Error(s) in loading state_dict for "
                           f"{cls_name}:\n\t{joined}")
    return model