-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtestfps-2.py
41 lines (30 loc) · 992 Bytes
/
testfps-2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import argparse
import torch
from PIL import Image
from time import time
from FiT import FiT
import json
from types import SimpleNamespace
@torch.inference_mode()
def main():
    """Benchmark the average forward-pass latency of ``FiT`` on a CUDA device.

    Loads the model configuration from ``config.json`` in the current working
    directory, builds the model on the GPU in eval mode, runs a number of
    warm-up forward passes, then times ``num`` forward passes over a fixed
    random ``(1, 3, 1080, 1080)`` input and prints the total and the average
    wall-clock time in seconds.

    The whole function runs under ``torch.inference_mode()``, so autograd is
    already disabled for every forward pass; no extra ``no_grad`` block is
    needed.
    """
    # Imported locally so this function does not depend on a change to the
    # file-level imports. perf_counter is monotonic and high-resolution,
    # which makes it the appropriate clock for measuring short intervals.
    from time import perf_counter

    with open("config.json", "r", encoding="utf-8") as f:
        config = SimpleNamespace(**json.load(f))

    model = FiT(config).to('cuda').eval()
    # Number of parameter tensors (not the number of scalar weights).
    print(sum(1 for _ in model.parameters()))

    dummy_image = torch.randn(1, 3, 1080, 1080, device='cuda')
    # NOTE(review): the second positional argument of F.normalize is the norm
    # exponent ``p``, so this divides by the L0.5 norm along dim=1. If a
    # mean/std style image normalization was intended, this is not it.
    # Behavior is kept as-is; confirm the intent.
    dummy_image = torch.nn.functional.normalize(dummy_image, p=0.5)

    for param in model.parameters():
        param.grad = None

    warmup = 100
    num = 200

    # Warm-up passes absorb one-time costs (CUDA context creation, kernel
    # selection, allocator growth) so they do not pollute the measurement.
    for _ in range(warmup):
        model(dummy_image)
    # CUDA work is queued asynchronously; drain the queue so leftover warm-up
    # work is not attributed to the first timed iteration.
    torch.cuda.synchronize()

    total = 0.0
    for _ in range(num):
        start = perf_counter()
        model(dummy_image)
        # Without this, the host clock would stop as soon as the kernels are
        # *launched*, not when they have actually finished executing.
        torch.cuda.synchronize()
        total += perf_counter() - start

    print('num:{} total_time:{}s avg_time:{}s'.format(
        num, total, total / num))
# Run the benchmark only when executed as a script, not when imported.
if __name__ == '__main__':
    main()