-
Notifications
You must be signed in to change notification settings - Fork 3
/
zerof2ours.py
30 lines (26 loc) · 1016 Bytes
/
zerof2ours.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from pathlib import Path
import numpy as np
import fire, time
def main(
# download from: https://github.com/kylevedder/zeroflow_weights/tree/master/argo
model_path: str = "/home/kin/nsfp_distilatation_3x_49_epochs.ckpt",
# download from: https://zenodo.org/records/12632962
reference_path: str = "/home/kin/fastflow3d.ckpt",
# new output weight file
output_path: str = "/home/kin/zeroflow3x.ckpt",
):
model = torch.load(model_path)
reference = torch.load(reference_path)
ref_model_weight = reference['state_dict']
real_model_weight = model['state_dict']
for k in real_model_weight.keys():
if k not in ref_model_weight.keys():
print(f"Warning: {k} not in reference model, not same model.")
exit(0)
reference['state_dict'] = real_model_weight
torch.save(reference, output_path)
if __name__ == '__main__':
start_time = time.time()
fire.Fire(main)
print(f"Time used: {time.time() - start_time:.2f} s")