forked from maum-ai/pnlp-mixer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsentencepiece_extractor.py
150 lines (124 loc) · 4.78 KB
/
sentencepiece_extractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from argparse import ArgumentParser
from json import dump
from logging import basicConfig, getLogger
from os import linesep, remove
from os.path import exists
from tempfile import NamedTemporaryFile
from typing import Dict, List, Tuple
from requests import get
from sentencepiece import SentencePieceProcessor
from tqdm import trange, tqdm
# Set up logging with the library defaults and use the root logger for this script.
# NOTE(review): basicConfig() is called without a level, so the root logger keeps
# its default WARNING threshold and the logger.info(...) messages emitted by this
# script are filtered out — confirm whether that is intended.
basicConfig()
logger = getLogger()
class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models.
    https://github.com/google/sentencepiece

    The vocabulary is read from the loaded model. The merge list is
    reconstructed from the vocabulary alone: a pair ``(left, right)`` is
    reported whenever ``left``, ``right`` and ``left + right`` are all pieces.
    """

    def __init__(self, model: str):
        """Load the SentencePiece model stored at path ``model``."""
        # Get SentencePiece
        self.sp = SentencePieceProcessor()
        self.sp.Load(model)

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        Return ``(vocab, merges)`` for the loaded model.

        ``vocab`` maps every piece to its id. ``merges`` is a list of
        ``(left, right)`` pairs ordered by the id of the piece they form.
        """
        sp = self.sp
        vocab = {sp.id_to_piece(index): index for index in trange(sp.GetPieceSize())}

        # Merges (the progress bar wraps the walk over the pieces)
        merges = self._merges_from_vocab(vocab, tqdm(vocab, total=sp.GetPieceSize()))
        return vocab, merges

    @staticmethod
    def _merges_from_vocab(vocab: Dict[str, int], pieces=None) -> List[Tuple[str, str]]:
        """
        Find every ``(left, right)`` pair of pieces whose concatenation is a piece.

        ``vocab`` maps piece -> id. ``pieces`` is the iterable of pieces to
        walk; it defaults to ``vocab`` itself and exists so the caller can
        pass a progress-bar wrapper around the same keys.

        Each piece is split at every interior position and both halves are
        looked up in ``vocab``. This costs one pass over the vocabulary times
        the piece length, instead of testing every ordered pair of pieces.

        The result is sorted by the id of the merged piece, then by the id of
        the left half, which is the order a pair-by-pair scan of an
        id-ordered vocabulary followed by a stable sort on the merged id
        produces. A merged piece with id 0 is kept; ids are compared by
        presence in ``vocab``, never by truthiness.
        """
        candidates = []
        for merged in vocab if pieces is None else pieces:
            merged_id = vocab[merged]
            # Cuts start at 1 and stop before len(merged): both halves are non-empty.
            for cut in range(1, len(merged)):
                left, right = merged[:cut], merged[cut:]
                if left in vocab and right in vocab:
                    candidates.append((merged_id, vocab[left], left, right))

        candidates.sort(key=lambda item: (item[0], item[1]))
        return [(left, right) for _, _, left, right in candidates]
class YouTokenToMeExtractor:
    """
    Extractor for models stored in the YouTokenToMe text format.

    The file is read line by line, each line holding whitespace-separated
    integers:
        vocab_size nb_merges
        piece piece_id
        ...(repeated vocab_size)
        piece_id_left piece_id_right piece_id
        ...(repeated nb merges)
        unk_id pad_id bos_id eos_id
    """

    def __init__(self, model: str):
        """Remember the path of the model file; nothing is read until ``extract``."""
        self._model = model

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        Parse the model file and return ``(vocab, merges)``.

        ``vocab`` maps piece text to piece id; ``merges`` lists the
        ``(left, right)`` piece pairs in the order they appear in the file.
        """
        with open(self._model, "r") as model_f:

            def next_ints():
                # One line of the file, lazily converted to integers.
                return map(int, model_f.readline().split())

            # Header: how many vocab lines and merge lines follow.
            nb_pieces, nb_merges = next_ints()
            id_to_piece = {}
            merges = []

            # Base vocabulary: each line is a code point and the id it gets.
            for _ in trange(nb_pieces):
                codepoint, piece_id = next_ints()
                id_to_piece[piece_id] = chr(codepoint)

            # Merges: each line combines two known ids into a new id.
            for _ in trange(nb_merges):
                left_id, right_id, merged_id = next_ints()
                left = id_to_piece[left_id]
                right = id_to_piece[right_id]
                id_to_piece[merged_id] = left + right
                merges.append((left, right))

            # Special tokens: the final line carries their four ids.
            unk, pad, bos, eos = next_ints()
            specials = ((unk, "<unk>"), (pad, "<pad>"), (bos, "<bos>"), (eos, "<eos>"))
            for special_id, token in specials:
                id_to_piece[special_id] = token

        # Callers expect piece -> id, so flip the id -> piece table.
        vocab = {piece: piece_id for piece_id, piece in id_to_piece.items()}
        return vocab, merges
if __name__ == "__main__":
    # Command-line entry point: load a model (local path or URL), extract its
    # vocabulary and merges, and write them to the two requested output files.
    parser = ArgumentParser("SentencePiece vocab extractor")
    parser.add_argument(
        "--provider",
        type=str,
        required=True,
        choices=["sentencepiece", "youtokentome"],
        help="Indicate the format of the file.",
    )
    parser.add_argument(
        "--model", type=str, required=True, help="SentencePiece model to extract vocab from."
    )
    parser.add_argument(
        "--vocab-output-path",
        type=str,
        required=True,
        help="Path where the vocab.json file will be extracted",
    )
    parser.add_argument(
        "--merges-output-path",
        type=str,
        required=True,
        help="Path where the merges file will be extracted",
    )

    # Parse cli arguments
    args = parser.parse_args()

    # Path of the temporary download, recorded as soon as the file exists so
    # the cleanup below also runs when the download itself fails.
    downloaded_model = None
    try:
        if args.model.startswith("http"):
            # Saving model
            with NamedTemporaryFile("wb", delete=False) as f:
                downloaded_model = f.name
                logger.info("Writing content from {} to {}".format(args.model, f.name))
                # A timeout avoids hanging forever on a stalled connection.
                response = get(args.model, allow_redirects=True, timeout=60)
                # Fail here on 4xx/5xx instead of saving an error page as the model.
                response.raise_for_status()
                f.write(response.content)

            args.remote_model = args.model
            args.model = f.name

        # Allocate extractor
        extractor = (
            SentencePieceExtractor if args.provider == "sentencepiece" else YouTokenToMeExtractor
        )
        extractor = extractor(args.model)

        logger.info(f"Using {type(extractor).__name__}")

        # Open output files and let's extract model information.
        # Pieces can contain non-ASCII characters, so the encoding is fixed to
        # UTF-8 rather than left to the platform default.
        with open(args.vocab_output_path, "w", encoding="utf-8") as vocab_f:
            with open(args.merges_output_path, "w", encoding="utf-8") as merges_f:
                # Do the extraction
                vocab, merges = extractor.extract()

                # Save content
                dump(vocab, vocab_f)
                # Text mode already translates "\n" to the platform line
                # ending; writing os.linesep here would double the "\r".
                merges_f.writelines(f"{left} {right}\n" for left, right in merges)
    finally:
        # If model was downloaded from internet we need to cleanup the tmp folder.
        if downloaded_model is not None and exists(downloaded_model):
            remove(downloaded_model)