export-onnx.py
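"""Export the latest narabas Lightning checkpoint to ONNX and attach audio metadata."""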
from glob import glob
import onnx
import torch
from narabas.model import Narabas
# TODO fix sorting
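# A possible fix for the TODO above (a sketch, not applied here): sort numerically on the
# version number, e.g. key=lambda p: int(re.search(r"version_(\d+)", p).group(1)),
# so that version_10 is not placed before version_2 by the lexicographic sort.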
state_dict_path = sorted(glob("./lightning_logs/version_*/checkpoints/epoch=*.ckpt"))[-1]
print(f"Loading state dict from {state_dict_path}")
model = Narabas.load_from_checkpoint(state_dict_path)
dest_path = "narabas.onnx"
model.eval()
# dummy input: a single batch of 16000 audio samples
dummy_input = torch.randn(1, 16000)
torch.onnx.export(
    model=model,
    args=dummy_input,
    f=dest_path,
    export_params=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size", 1: "sequence_length"},
        "output": {0: "batch_size", 1: "sequence_length"},
    },
    do_constant_folding=True,
)
print(f"Exported to {dest_path}")
onnx_model = onnx.load(dest_path)
# add metadata
print("Adding metadata ...")
onnx.helper.set_model_props(
    onnx_model,
    {
        "sample_rate": str(model.sample_rate),
        "hop_length": str(model.hop_length),
    },
)
# check model
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, dest_path)
size = onnx_model.ByteSize()
print("Done! model size: {:.1f} MB".format(size / (1024 * 1024)))