build_vocab.py
"""Builds vocabulary file from data."""
"""For Yelp and Superuser data"""
import argparse
import collections
import json
import os
import pickle
import re


def build_counter(train_data, initial_counter=None):
    """Counts token occurrences over every user's training examples."""
    train_tokens = []
    for u in train_data:
        for c in train_data[u]['x']:
            train_tokens.extend([s for s in c])
    all_tokens = []
    for i in train_tokens:
        all_tokens.extend(i)
    # Release the intermediate lists early; the data files can be large.
    train_tokens = []
    if initial_counter is None:
        counter = collections.Counter()
    else:
        counter = initial_counter
    counter.update(all_tokens)
    all_tokens = []
    return counter
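

# A small illustrative sketch (the user id and tokens below are made-up). As the
# loops above assume, `train_data` is the LEAF 'user_data' dict keyed by user id,
# where each entry's 'x' field is a list of examples and each example is a list of
# token sequences:
#
#   train_data = {'user_0': {'x': [[['good', 'food'], ['bad', 'service']]]}}
#   build_counter(train_data)
#   # -> Counter({'good': 1, 'food': 1, 'bad': 1, 'service': 1})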


def build_vocab(counter, vocab_size=10000):
    """Builds a token -> id mapping from the most frequent tokens in `counter`."""
    pad_symbol, unk_symbol = 0, 1
    # Sort by descending count, breaking ties alphabetically.
    count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
    count_pairs = count_pairs[:(vocab_size - 2)]  # -2 to account for the unknown and pad symbols.
    words, _ = list(zip(*count_pairs))
    vocab = {}
    vocab['<PAD>'] = pad_symbol
    vocab['<UNK>'] = unk_symbol
    for i, w in enumerate(words):
        if w != '<PAD>':
            # Ids 0 and 1 are reserved for <PAD> and <UNK>, so word ids start at 2.
            vocab[w] = i + 2
    return {'vocab': vocab, 'size': vocab_size, 'unk_symbol': unk_symbol, 'pad_symbol': pad_symbol}
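

# A small worked example (made-up counts), with vocab_size=4 so only the two most
# frequent tokens are kept alongside the two reserved symbols:
#
#   counter = collections.Counter({'food': 3, 'good': 2, 'ok': 1})
#   build_vocab(counter, vocab_size=4)
#   # -> {'vocab': {'<PAD>': 0, '<UNK>': 1, 'food': 2, 'good': 3},
#   #     'size': 4, 'unk_symbol': 1, 'pad_symbol': 0}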


def load_leaf_data(file_path):
    """Loads the per-user 'user_data' dict from a LEAF-format JSON file."""
    with open(file_path) as json_file:
        data = json.load(json_file)
        to_ret = data['user_data']
        data = None
    return to_ret
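

# Note: LEAF-format JSON files typically also carry 'users' and 'num_samples' keys
# alongside 'user_data'; only the 'user_data' mapping (user id -> examples) is used here.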


def save_vocab(vocab, target_dir):
    """Pickles the vocabulary dict to <target_dir>/yelp_vocab.pck."""
    os.makedirs(target_dir, exist_ok=True)
    with open(os.path.join(target_dir, 'yelp_vocab.pck'), 'wb') as f:
        pickle.dump(vocab, f)
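

# Reading the saved vocabulary back (a minimal sketch; the path assumes the default
# --target-dir of './'):
#
#   with open('./yelp_vocab.pck', 'rb') as f:
#       vocab_data = pickle.load(f)
#   word_to_id = vocab_data['vocab']
#   ids = [word_to_id.get(w, vocab_data['unk_symbol']) for w in ['good', 'food']]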


def main():
    args = parse_args()
    json_files = [f for f in os.listdir(args.data_dir) if f.endswith('.json')]
    json_files.sort()
    counter = None
    train_data = {}
    for f in json_files:
        print('loading {}'.format(f))
        train_data = load_leaf_data(os.path.join(args.data_dir, f))
        print('counting {}'.format(f))
        # Accumulate counts across files by passing the running counter back in.
        counter = build_counter(train_data, initial_counter=counter)
        print()
        train_data = {}
    if counter is not None:
        vocab = build_vocab(counter, vocab_size=args.vocab_size)
        save_vocab(vocab, args.target_dir)
    else:
        print('No files to process.')


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir',
                        help='dir with the training .json files;',
                        type=str,
                        required=True)
    parser.add_argument('--vocab-size',
                        help='size of the vocabulary;',
                        type=int,
                        default=10000,
                        required=False)
    parser.add_argument('--target-dir',
                        help='dir to write the vocabulary file to;',
                        type=str,
                        default='./',
                        required=False)
    return parser.parse_args()


if __name__ == '__main__':
    main()
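

# Example invocation (a sketch; the directory names below are placeholders):
#
#   python build_vocab.py --data-dir ./data/yelp/train --vocab-size 10000 --target-dir ./vocab
#
# This reads every .json file under --data-dir in sorted order, accumulates token
# counts across files, and writes yelp_vocab.pck into --target-dir.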