-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathexport_onnx.py
31 lines (25 loc) · 989 Bytes
/
export_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Download the upscale model and place it inside models/upscale_models,
# then edit model_name and onnx_save_path below accordingly.
import torch
import folder_paths
from spandrel import ModelLoader, ImageModelDescriptor
# Export an upscale model checkpoint to ONNX with dynamic batch and spatial
# dimensions. Edit the two values below to convert a different checkpoint.
model_name = "4xNomos2_otf_esrgan.pth"
onnx_save_path = "./4xNomos2_otf_esrgan.onnx"

# Look the checkpoint up under the "upscale_models" folder key.
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)

# Load the checkpoint, take the wrapped module, switch it to eval mode and
# move it to the GPU so it lives on the same device as the example input.
model = ModelLoader().load_from_file(model_path).model.eval().cuda()

# Example input handed to the exporter; its size is not baked into the graph
# for the axes declared dynamic below.
x = torch.rand(1, 3, 512, 512).cuda()

# Assumes NCHW layout (as the (1, 3, 512, 512) example input suggests):
# axis 2 is height and axis 3 is width. The output gets its own spatial
# dimension names because an upscaler's output size differs from its input
# size, and identical symbolic names would declare the two equal.
dynamic_axes = {
    "input": {0: "batch_size", 2: "height", 3: "width"},
    "output": {0: "batch_size", 2: "output_height", 3: "output_width"},
}

torch.onnx.export(
    model,
    x,
    onnx_save_path,
    verbose=True,
    input_names=["input"],
    output_names=["output"],
    opset_version=17,
    export_params=True,  # store the trained weights inside the ONNX file
    dynamic_axes=dynamic_axes,
)

print("Saved onnx to:", onnx_save_path)