-
Notifications
You must be signed in to change notification settings - Fork 0
/
tensorRT.py
39 lines (30 loc) · 1002 Bytes
/
tensorRT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @version: 0.1
# @license: Apache Licence
# @Filename: tensorRT.py
# @Author: chaidisheng
# @contact: [email protected]
# @site: https://github.com/chaidisheng
# @software: PyCharm
# @Time: 2020/9/24 18:09
# @torch: tensor.method(in-place) or torch.method(tensor)
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import torch
from torch2trt import torch2trt
from torchvision.models.alexnet import alexnet
# Demo script: convert a pretrained AlexNet to TensorRT with torch2trt, compare
# the converted module's output with PyTorch, then save the converted module
# and load it back into a fresh TRTModule.

# Create a regular PyTorch model: pretrained AlexNet, eval mode, on the GPU.
model = alexnet(pretrained=True).eval().cuda()

# Create example data (a single 3x224x224 input of ones) on the GPU.
x = torch.ones((1, 3, 224, 224)).cuda()

# Convert to TensorRT, feeding the sample data as input.
model_trt = torch2trt(model, [x])

# Run the same input through both the original and the converted model.
y = model(x)
y_trt = model_trt(x)

# Check the output against PyTorch: print the largest absolute difference.
print(torch.max(torch.abs(y - y_trt)))

# Path shared by the save and the load below so the two cannot drift apart.
trt_state_path = 'alexnet_trt.pth'

# Save the converted module's state dict. Without this step the load below
# reads a file this script never wrote and fails on a clean run.
torch.save(model_trt.state_dict(), trt_state_path)

from torch2trt import TRTModule

# Load the saved state into a fresh TRTModule; this rebinds model_trt to the
# reloaded module.
model_trt = TRTModule()
model_trt.load_state_dict(torch.load(trt_state_path))