diff --git a/README.md b/README.md
index 451c41b65..09b110f0b 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,7 @@ DLI supports inference using the following frameworks:
- [ONNX Runtime][onnx-runtime].
- [MXNet][mxnet].
- [OpenCV DNN][opencv-dnn] (C++ and Python API).
+- [PyTorch][pytorch].
More information about DLI is available on the web-site
([here][dli-ru-web-page] (in Russian)
@@ -99,6 +100,9 @@ Novgorod State University Publishing House, 2021. – 423 p.
- [`validation_results_mxnet_gluon_modelzoo.md`](results/validation/validation_results_mxnet_gluon_modelzoo.md)
is a table that confirms correctness of inference implementation
based on MXNet for [GluonCV-models][gluoncv-omz].
+ - [`validation_results_pytorch.md`](results/validation/validation_results_pytorch.md)
+ is a table that confirms correctness of inference implementation
+ based on PyTorch for [TorchVision][torchvision].
- [`mxnet_models_checklist.md`](results/mxnet_models_checklist.md) contains a list
of deep models inferred by MXNet checked in the DLI benchmark.
@@ -108,6 +112,8 @@ Novgorod State University Publishing House, 2021. – 423 p.
of deep models inferred by TensorFlow checked in the DLI benchmark.
- [`tflite_models_checklist.md`](results/tflite_models_checklist.md) contains a list
of deep models inferred by TensorFlow Lite checked in the DLI benchmark.
+ - [`pytorch_models_checklist.md`](results/pytorch_models_checklist.md) contains a list
+ of deep models inferred by PyTorch checked in the DLI benchmark.
- `src` directory contains benchmark sources.
@@ -184,11 +190,13 @@ Report questions, issues and suggestions, using:
[onnx-runtime-github]: https://github.com/microsoft/onnxruntime
[mxnet]: https://mxnet.apache.org
[opencv-dnn]: https://docs.opencv.org/4.7.0/d2/d58/tutorial_table_of_content_dnn.html
+[pytorch]: https://pytorch.org
[benchmark-app]: https://github.com/openvinotoolkit/openvino/tree/master/samples/cpp/benchmark_app
[dli-ru-web-page]: http://hpc-education.unn.ru/dli-ru
[dli-web-page]: http://hpc-education.unn.ru/dli
[open-model-zoo]: https://github.com/opencv/open_model_zoo
[gluoncv-omz]: https://cv.gluon.ai/model_zoo/index.html
+[torchvision]: https://pytorch.org/vision/stable/models.html
[mmst-2021]: https://hpc-education.unn.ru/files/conference_hpc/2021/MMST2021_Proceedings.pdf
[dli-wiki]: https://github.com/itlab-vision/dl-benchmark/wiki
[dli-wiki-build]: https://github.com/itlab-vision/dl-benchmark/wiki#how-to-build
diff --git a/requirements_frameworks.txt b/requirements_frameworks.txt
index dda1ff2e8..8c8c4e2a0 100644
--- a/requirements_frameworks.txt
+++ b/requirements_frameworks.txt
@@ -1,2 +1,3 @@
-openvino-dev[caffe,mxnet,tensorflow2]==2022.3.0
-gluoncv
\ No newline at end of file
+openvino-dev[caffe,mxnet,tensorflow2,pytorch]==2022.3.0
+gluoncv
+torchvision
\ No newline at end of file
diff --git a/results/pytorch_models_checklist.md b/results/pytorch_models_checklist.md
new file mode 100644
index 000000000..6a8042588
--- /dev/null
+++ b/results/pytorch_models_checklist.md
@@ -0,0 +1,88 @@
+# Model validation and performance analysis status for PyTorch
+
+## Public models
+
+The list of models is from [TorchVision][torchvision].
+
+### Image classification on ImageNet
+
+Model | Availability in [TorchVision][torchvision] (0.15.1)| Availability in the validation table |
+-|-|-|
+alexnet|+|+|
+densenet121|+|+|
+densenet161|+|+|
+densenet169|+|+|
+densenet201|+|+|
+googlenet|+|+|
+inception_v3|+|+|
+mnasnet0_5|+|+|
+mnasnet0_75|+|+|
+mnasnet1_0|+|+|
+mnasnet1_3|+|+|
+mobilenet_v2|+|+|
+resnext50_32x4d|+|+|
+resnext101_32x8d|+|+|
+resnet18|+|+|
+resnet34|+|+|
+resnet50|+|+|
+resnet101|+|+|
+resnet152|+|+|
+shufflenet_v2_x0_5|+|+|
+shufflenet_v2_x1_0|+|+|
+shufflenet_v2_x1_5|+|+|
+shufflenet_v2_x2_0|+|+|
+squeezenet1_0|+|+|
+squeezenet1_1|+|+|
+vgg11|+|+|
+vgg11_bn|+|+|
+vgg13|+|+|
+vgg13_bn|+|+|
+vgg16|+|+|
+vgg16_bn|+|+|
+vgg19|+|+|
+vgg19_bn|+|+|
+wide_resnet50_2|+|+|
+wide_resnet101_2|+|+|
+
+### Object detection
+
+Model | Availability in [TorchVision][torchvision] (0.15.1)| Availability in the validation table |
+-|-|-|
+fasterrcnn_resnet50_fpn|+|-|
+fasterrcnn_resnet50_fpn_v2|+|-|
+fasterrcnn_mobilenet_v3_large_fpn|+|-|
+fasterrcnn_mobilenet_v3_large_320_fpn|+|-|
+fcos_resnet50_fpn|+|-|
+retinanet_resnet50_fpn|+|-|
+retinanet_resnet50_fpn_v2|+|-|
+ssd300_vgg16|+|-|
+ssdlite320_mobilenet_v3_large|+|-|
+
+### Semantic segmentation
+
+Model | Availability in [TorchVision][torchvision] (0.15.1)| Availability in the validation table |
+-|-|-|
+deeplabv3_mobilenet_v3_large|+|-|
+deeplabv3_resnet50|+|-|
+deeplabv3_resnet101|+|-|
+fcn_resnet50|+|-|
+fcn_resnet101|+|-|
+lraspp_mobilenet_v3_large|+|-|
+
+### Instance segmentation
+
+Model | Availability in [TorchVision][torchvision] (0.15.1)| Availability in the validation table |
+-|-|-|
+maskrcnn_resnet50_fpn|+|-|
+maskrcnn_resnet50_fpn_v2|+|-|
+
+
+### Keypoint Detection
+
+Model | Availability in [TorchVision][torchvision] (0.15.1)| Availability in the validation table |
+-|-|-|
+keypointrcnn_resnet50_fpn|+|-|
+
+
+
+[torchvision]: https://pytorch.org/vision/stable/models.html
diff --git a/results/validation/validation_results_pytorch.md b/results/validation/validation_results_pytorch.md
new file mode 100644
index 000000000..22ca9f98e
--- /dev/null
+++ b/results/validation/validation_results_pytorch.md
@@ -0,0 +1,175 @@
+# Validation results for the models inferring using PyTorch
+
+## Image classification
+
+Complete information about the supported classification
+models is available [here][torchvision_classification].
+
+Notes:
+
+- For all classification models input shape BxCxHxW, where
+ B is a batch size, C is an image number of channels,
+ H is an image height, W is an image width.
+ W=H=224 except inception_v3, for this model W=H=299.
+- Values of mean and standard deviation parameters used
+ for model validation are represented for each image.
+
+### Test image #1
+
+Data source: [ImageNet][imagenet]
+
+Image resolution: 709 x 510
+
+Mean: [0.485, 0.456, 0.406]
+
+Standard deviation: [0.229, 0.224, 0.225]
+
+
+
+
+
+ Model | Python (implementation) |
+---------------------|---------------------------|
+alexnet |0.4499779 Granny Smith
0.0933098 dumbbell
0.0876729 ocarina, sweet potato
0.0628701 hair slide
0.0484683 bottlecap
|
+densenet121 |0.9523344 Granny Smith
0.0132273 orange
0.0125171 lemon
0.0027910 banana
0.0020333 piggy bank, penny bank
|
+densenet161 |0.9372966 Granny Smith
0.0082274 dumbbell
0.0056475 piggy bank, penny bank
0.0055374 ping-pong ball
0.0041915 pitcher, ewer
|
+densenet169 |0.9523344 Granny Smith
0.0132273 orange
0.0125171 lemon
0.0027910 banana
0.0020333 piggy bank, penny bank
|
+densenet201 |0.9119796 Granny Smith
0.0533455 piggy bank, penny bank
0.0056832 lemon
0.0017810 pool table, billiard table, snooker table
0.0015689 tennis ball
|
+googlenet |0.5432544 Granny Smith
0.1103975 piggy bank, penny bank
0.0232569 vase
0.0213901 pitcher, ewer
0.0196196 bell pepper
|
+inception_v3 |0.9999496 Granny Smith
0.0000175 piggy bank, penny bank
0.0000171 pomegranate
0.0000016 whiskey jug
0.0000011 water jug
|
+mnasnet0_5 |0.0728453 Granny Smith
0.0632434 piggy bank, penny bank
0.0314232 pitcher, ewer
0.0235659 safety pin
0.0232115 saltshaker, salt shaker
|
+mnasnet0_75 |0.1419373 Granny Smith
0.0228578 lemon
0.0172555 piggy bank, penny bank
0.0164039 orange
0.0098360 dumbbell
|
+mnasnet1_0 |0.1931600 Granny Smith
0.1841204 lemon
0.1414814 piggy bank, penny bank
0.0808636 teapot
0.0785343 orange
|
+mnasnet1_3 |0.2819200 Granny Smith
0.0192138 piggy bank, penny bank
0.0092013 lemon
0.0071209 tennis ball
0.0066861 orange
|
+mobilenet_v2 |0.5066760 Granny Smith
0.0543401 pitcher, ewer
0.0461567 saltshaker, salt shaker
0.0433900 lemon
0.0314979 vase
|
+resnext50_32x4d |0.9059284 Granny Smith
0.0208117 lemon
0.0138094 orange
0.0067536 banana
0.0038122 piggy bank, penny bank
|
+resnext101_32x8d |0.4214238 Granny Smith
0.1213467 piggy bank, penny bank
0.0461432 lemon
0.0443953 orange
0.0392590 vase
|
+resnet18 |0.1507515 safety pin
0.1102253 piggy bank, penny bank
0.0657376 purse
0.0558254 teapot
0.0341885 hair slide
|
+resnet34 |0.9595408 Granny Smith
0.0054884 banana
0.0043731 orange
0.0035087 piggy bank, penny bank
0.0025556 lemon
|
+resnet50 |0.9278085 Granny Smith
0.0129410 orange
0.0059573 lemon
0.0042141 necklace
0.0025712 banana
|
+resnet101 |0.9483170 Granny Smith
0.0055002 hay
0.0050311 orange
0.0020548 syringe
0.0018223 pitcher, ewer
|
+resnet152 |0.8913258 Granny Smith
0.0149238 piggy bank, penny bank
0.0058584 hook, claw
0.0057759 saltshaker, salt shaker
0.0044805 analog clock
|
+shufflenet_v2_x0_5 |0.1090845 vase
0.1067461 piggy bank, penny bank
0.1048482 saltshaker, salt shaker
0.0685889 lemon
0.0643356 pitcher, ewer
|
+shufflenet_v2_x1_0 |0.2771632 Granny Smith
0.1798794 safety pin
0.0308319 warplane, military plane
0.0300632 hair slide
0.0290498 piggy bank, penny bank
|
+shufflenet_v2_x1_5 |0.3420659 Granny Smith
0.0544274 lemon
0.0335732 orange
0.0223579 piggy bank, penny bank
0.0180992 pomegranate
|
+shufflenet_v2_x2_0 |0.7161155 Granny Smith
0.0284324 orange
0.0277892 lemon
0.0249504 piggy bank, penny bank
0.0099330 saltshaker, salt shaker
|
+squeezenet1_0 |0.3275049 piggy bank, penny bank
0.1791330 dumbbell
0.1542633 Granny Smith
0.0912989 water bottle
0.0385818 rubber eraser, rubber, pencil eraser
|
+squeezenet1_1 |0.5895362 piggy bank, penny bank
0.0677936 Granny Smith
0.0610653 necklace
0.0610450 lemon
0.0490913 bucket, pail
|
+vgg11 |0.3721458 piggy bank, penny bank
0.2952032 Granny Smith
0.1076759 tennis ball
0.0314685 soap dispenser
0.0285692 dumbbell
|
+vgg11_bn |0.5464042 Granny Smith
0.2313125 dumbbell
0.0658235 piggy bank, penny bank
0.0269569 tennis ball
0.0218533 teapot
|
+vgg13 |0.4068233 Granny Smith
0.2272189 dumbbell
0.0475026 necklace
0.0303710 maraca
0.0250665 teapot
|
+vgg13_bn |0.9389400 Granny Smith
0.0383619 tennis ball
0.0069443 lemon
0.0039320 orange
0.0013574 banana
|
+vgg16 |0.2294290 Granny Smith
0.2084264 tennis ball
0.0561063 necklace
0.0523627 piggy bank, penny bank
0.0300660 pencil box, pencil case
|
+vgg16_bn |0.3850222 Granny Smith
0.0595219 dumbbell
0.0565265 pencil box, pencil case
0.0528648 tennis ball
0.0333556 piggy bank, penny bank
|
+vgg19 |0.4694367 Granny Smith
0.1882819 tennis ball
0.0588473 acorn
0.0566673 lemon
0.0416055 piggy bank, penny bank
|
+vgg19_bn |0.8778616 Granny Smith
0.0404896 lemon
0.0315287 orange
0.0050720 soap dispenser
0.0047691 piggy bank, penny bank
|
+wide_resnet50_2 |0.8607227 Granny Smith
0.0346713 piggy bank, penny bank
0.0130316 abacus
0.0083917 necklace
0.0075835 spindle
|
+wide_resnet101_2 |0.5728682 Granny Smith
0.0822066 piggy bank, penny bank
0.0709140 tennis ball
0.0597228 golf ball
0.0115059 goblet
|
+
+### Test image #2
+
+Data source: [ImageNet][imagenet]
+
+Image resolution: 500 x 500
+
+Mean: [0.485, 0.456, 0.406]
+
+Standard deviation: [0.229, 0.224, 0.225]
+
+
+
+
+
+ Model | Python (implementation) |
+---------------------|---------------------------|
+alexnet |0.9947648 junco, snowbird
0.0043087 chickadee
0.0002780 water ouzel, dipper
0.0002770 bulbul
0.0001244 brambling, Fringilla montifringilla
|
+densenet121 |0.9841599 junco, snowbird
0.0072199 chickadee
0.0034963 brambling, Fringilla montifringilla
0.0016226 water ouzel, dipper
0.0012858 indigo bunting, indigo finch, indigo bird, Passerina cyanea
|
+densenet161 |0.9932058 junco, snowbird
0.0015922 chickadee
0.0012295 brambling, Fringilla montifringilla
0.0011838 indigo bunting, indigo finch, indigo bird, Passerina cyanea
0.0008891 goldfinch, Carduelis carduelis
|
+densenet169 |0.9640695 junco, snowbird
0.0201315 brambling, Fringilla montifringilla
0.0044098 chickadee
0.0032345 goldfinch, Carduelis carduelis
0.0026739 water ouzel, dipper
|
+densenet201 |0.9515251 junco, snowbird
0.0178251 water ouzel, dipper
0.0109119 brambling, Fringilla montifringilla
0.0077980 house finch, linnet, Carpodacus mexicanus
0.0044695 chickadee
|
+googlenet |0.6461046 junco, snowbird
0.0772564 chickadee
0.0468783 brambling, Fringilla montifringilla
0.0295898 goldfinch, Carduelis carduelis
0.0123323 house finch, linnet, Carpodacus mexicanus
|
+inception_v3 |0.9999989 junco, snowbird
0.0000001 iron, smoothing iron
0.0000001 cleaver, meat cleaver, chopper
0.0000000 water ouzel, dipper
0.0000000 chickadee
|
+mnasnet0_5 |0.9237853 junco, snowbird
0.0206866 chickadee
0.0049339 brambling, Fringilla montifringilla
0.0039299 water ouzel, dipper
0.0029348 jay
|
+mnasnet0_75 |0.1342174 junco, snowbird
0.0332264 goldfinch, Carduelis carduelis
0.0287836 brambling, Fringilla montifringilla
0.0138830 chickadee
0.0116728 indigo bunting, indigo finch, indigo bird, Passerina cyanea
|
+mnasnet1_0 |0.9980335 junco, snowbird
0.0013290 brambling, Fringilla montifringilla
0.0004499 water ouzel, dipper
0.0001339 chickadee
0.0000133 goldfinch, Carduelis carduelis
|
+mnasnet1_3 |0.3347574 junco, snowbird
0.0074588 chickadee
0.0058638 brambling, Fringilla montifringilla
0.0055867 water ouzel, dipper
0.0047144 indigo bunting, indigo finch, indigo bird, Passerina cyanea
|
+mobilenet_v2 |0.9989253 junco, snowbird
0.0004260 water ouzel, dipper
0.0004213 chickadee
0.0001264 brambling, Fringilla montifringilla
0.0000537 goldfinch, Carduelis carduelis
|
+resnext50_32x4d |0.9919545 junco, snowbird
0.0036273 brambling, Fringilla montifringilla
0.0016091 goldfinch, Carduelis carduelis
0.0015197 chickadee
0.0004831 water ouzel, dipper
|
+resnext101_32x8d |0.9755010 junco, snowbird
0.0071145 water ouzel, dipper
0.0047595 brambling, Fringilla montifringilla
0.0021230 chickadee
0.0009639 red-backed sandpiper, dunlin, Erolia alpina
|
+resnet18 |0.9991090 junco, snowbird
0.0005329 chickadee
0.0002098 water ouzel, dipper
0.0000690 bulbul
0.0000579 brambling, Fringilla montifringilla
|
+resnet34 |0.9923642 junco, snowbird
0.0043307 chickadee
0.0011341 water ouzel, dipper
0.0005041 brambling, Fringilla montifringilla
0.0004572 goldfinch, Carduelis carduelis
|
+resnet50 |0.9805019 junco, snowbird
0.0049154 goldfinch, Carduelis carduelis
0.0039196 chickadee
0.0038097 water ouzel, dipper
0.0028983 brambling, Fringilla montifringilla
|
+resnet101 |0.9986678 junco, snowbird
0.0004156 chickadee
0.0002674 goldfinch, Carduelis carduelis
0.0001532 brambling, Fringilla montifringilla
0.0001518 water ouzel, dipper
|
+resnet152 |0.9983380 junco, snowbird
0.0009362 water ouzel, dipper
0.0003330 brambling, Fringilla montifringilla
0.0001030 goldfinch, Carduelis carduelis
0.0000701 house finch, linnet, Carpodacus mexicanus
|
+shufflenet_v2_x0_5 |0.9972883 junco, snowbird
0.0010430 goldfinch, Carduelis carduelis
0.0004120 brambling, Fringilla montifringilla
0.0003422 jay
0.0001990 chickadee
|
+shufflenet_v2_x1_0 |0.9997568 junco, snowbird
0.0001209 damselfly
0.0000400 chickadee
0.0000233 water ouzel, dipper
0.0000091 dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
|
+shufflenet_v2_x1_5 |0.3117434 junco, snowbird
0.0404440 brambling, Fringilla montifringilla
0.0254127 chickadee
0.0110461 water ouzel, dipper
0.0090534 goldfinch, Carduelis carduelis
|
+shufflenet_v2_x2_0 |0.3471888 junco, snowbird
0.0091527 chickadee
0.0086562 bulbul
0.0055233 brambling, Fringilla montifringilla
0.0050128 water ouzel, dipper
|
+squeezenet1_0 |0.9904412 junco, snowbird
0.0045286 chickadee
0.0040343 brambling, Fringilla montifringilla
0.0003414 water ouzel, dipper
0.0002521 house finch, linnet, Carpodacus mexicanus
|
+squeezenet1_1 |0.9614578 junco, snowbird
0.0250983 chickadee
0.0040701 brambling, Fringilla montifringilla
0.0035156 goldfinch, Carduelis carduelis
0.0030858 ruffed grouse, partridge, Bonasa umbellus
|
+vgg11 |0.9998955 junco, snowbird
0.0000967 chickadee
0.0000043 brambling, Fringilla montifringilla
0.0000023 water ouzel, dipper
0.0000006 bulbul
|
+vgg11_bn |0.9994940 junco, snowbird
0.0002460 brambling, Fringilla montifringilla
0.0002328 chickadee
0.0000130 water ouzel, dipper
0.0000100 goldfinch, Carduelis carduelis
|
+vgg13 |0.9359031 junco, snowbird
0.0610291 chickadee
0.0012531 brambling, Fringilla montifringilla
0.0012155 water ouzel, dipper
0.0002740 bulbul
|
+vgg13_bn |0.9927478 junco, snowbird
0.0041162 chickadee
0.0028725 brambling, Fringilla montifringilla
0.0000676 goldfinch, Carduelis carduelis
0.0000641 house finch, linnet, Carpodacus mexicanus
|
+vgg16 |0.9991580 junco, snowbird
0.0007120 chickadee
0.0000800 water ouzel, dipper
0.0000323 brambling, Fringilla montifringilla
0.0000049 house finch, linnet, Carpodacus mexicanus
|
+vgg16_bn |0.9920998 junco, snowbird
0.0066640 chickadee
0.0004240 jay
0.0003181 water ouzel, dipper
0.0001396 brambling, Fringilla montifringilla
|
+vgg19 |0.9994042 junco, snowbird
0.0003172 brambling, Fringilla montifringilla
0.0001609 chickadee
0.0000671 water ouzel, dipper
0.0000236 goldfinch, Carduelis carduelis
|
+vgg19_bn |0.9999533 junco, snowbird
0.0000318 chickadee
0.0000074 brambling, Fringilla montifringilla
0.0000030 water ouzel, dipper
0.0000021 house finch, linnet, Carpodacus mexicanus
|
+wide_resnet50_2 |0.9617861 junco, snowbird
0.0119062 water ouzel, dipper
0.0064385 chickadee
0.0044642 brambling, Fringilla montifringilla
0.0019096 bulbul
|
+wide_resnet101_2 |0.9748272 junco, snowbird
0.0074479 water ouzel, dipper
0.0047218 chickadee
0.0023339 brambling, Fringilla montifringilla
0.0022661 goldfinch, Carduelis carduelis
|
+
+### Test image #3
+
+Data source: [ImageNet][imagenet]
+
+Image resolution: 333 x 500
+
+Mean: [0.485, 0.456, 0.406]
+
+Standard deviation: [0.229, 0.224, 0.225]
+
+
+
+
+
+ Model | Python (implementation) |
+---------------------|---------------------------|
+alexnet |0.3216888 container ship, containership, container vessel
0.1360615 drilling platform, offshore rig
0.1140690 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.1057476 beacon, lighthouse, beacon light, pharos
0.0471225 liner, ocean liner
|
+densenet121 |0.3022412 liner, ocean liner
0.1322481 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.1194608 container ship, containership, container vessel
0.0795042 drilling platform, offshore rig
0.0723068 dock, dockage, docking facility
|
+densenet161 |0.4418393 lifeboat
0.1824290 liner, ocean liner
0.0596464 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0325273 submarine, pigboat, sub, U-boat
0.0298845 dock, dockage, docking facility
|
+densenet169 |0.2955876 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.2342377 drilling platform, offshore rig
0.0940930 liner, ocean liner
0.0876009 container ship, containership, container vessel
0.0717737 dock, dockage, docking facility
|
+densenet201 |0.5008168 fireboat
0.0950198 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0701648 lifeboat
0.0622605 liner, ocean liner
0.0582345 container ship, containership, container vessel
|
+googlenet |0.1323652 liner, ocean liner
0.0796395 drilling platform, offshore rig
0.0678082 container ship, containership, container vessel
0.0585721 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0366881 fireboat
|
+inception_v3 |0.3510072 drilling platform, offshore rig
0.3065925 beacon, lighthouse, beacon light, pharos
0.1853052 submarine, pigboat, sub, U-boat
0.0660644 wreck
0.0121473 space shuttle
|
+mnasnet0_5 |0.0854459 drilling platform, offshore rig
0.0850178 liner, ocean liner
0.0445188 container ship, containership, container vessel
0.0357058 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0247475 aircraft carrier, carrier, flattop, attack aircraft carrier
|
+mnasnet0_75 |0.0212122 aircraft carrier, carrier, flattop, attack aircraft carrier
0.0189475 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0135526 beacon, lighthouse, beacon light, pharos
0.0114984 submarine, pigboat, sub, U-boat
0.0114306 liner, ocean liner
|
+mnasnet1_0 |0.2078572 container ship, containership, container vessel
0.1769478 dock, dockage, docking facility
0.1064179 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0966766 liner, ocean liner
0.0637138 lifeboat
|
+mnasnet1_3 |0.0924414 lifeboat
0.0341632 container ship, containership, container vessel
0.0281921 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0210488 liner, ocean liner
0.0208530 beacon, lighthouse, beacon light, pharos
|
+mobilenet_v2 |0.3933903 container ship, containership, container vessel
0.2136005 liner, ocean liner
0.0991812 beacon, lighthouse, beacon light, pharos
0.0715421 drilling platform, offshore rig
0.0498366 breakwater, groin, groyne, mole, bulwark, seawall, jetty
|
+resnext50_32x4d |0.3138136 liner, ocean liner
0.1791683 catamaran
0.0695947 drilling platform, offshore rig
0.0535790 dock, dockage, docking facility
0.0486278 breakwater, groin, groyne, mole, bulwark, seawall, jetty
|
+resnext101_32x8d |0.2383151 beacon, lighthouse, beacon light, pharos
0.2232965 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.1179476 water bottle
0.0526662 drilling platform, offshore rig
0.0363510 liner, ocean liner
|
+resnet18 |0.1980100 liner, ocean liner
0.1092247 submarine, pigboat, sub, U-boat
0.1024882 container ship, containership, container vessel
0.1021967 drilling platform, offshore rig
0.0800809 breakwater, groin, groyne, mole, bulwark, seawall, jetty
|
+resnet34 |0.2605784 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.1120401 fireboat
0.1080514 liner, ocean liner
0.0992261 pirate, pirate ship
0.0759654 container ship, containership, container vessel
|
+resnet50 |0.4759621 liner, ocean liner
0.1025401 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0690000 container ship, containership, container vessel
0.0524496 dock, dockage, docking facility
0.0473781 pirate, pirate ship
|
+resnet101 |0.8149654 drilling platform, offshore rig
0.0403631 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0207643 beacon, lighthouse, beacon light, pharos
0.0188019 container ship, containership, container vessel
0.0160020 liner, ocean liner
|
+resnet152 |0.3274736 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.2284682 liner, ocean liner
0.0779443 lifeboat
0.0710691 beacon, lighthouse, beacon light, pharos
0.0688560 container ship, containership, container vessel
|
+shufflenet_v2_x0_5 |0.2142884 agama
0.0945462 water bottle
0.0885760 jay
0.0579272 liner, ocean liner
0.0566052 breakwater, groin, groyne, mole, bulwark, seawall, jetty
|
+shufflenet_v2_x1_0 |0.3391730 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.1130056 beacon, lighthouse, beacon light, pharos
0.0324217 liner, ocean liner
0.0203185 terrapin
0.0181812 drilling platform, offshore rig
|
+shufflenet_v2_x1_5 |0.0450035 pirate, pirate ship
0.0366058 lifeboat
0.0158971 liner, ocean liner
0.0139720 aircraft carrier, carrier, flattop, attack aircraft carrier
0.0139460 dock, dockage, docking facility
|
+shufflenet_v2_x2_0 |0.0682886 beacon, lighthouse, beacon light, pharos
0.0495765 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0448658 container ship, containership, container vessel
0.0410116 liner, ocean liner
0.0403494 lifeboat
|
+squeezenet1_0 |0.8105499 liner, ocean liner
0.0785143 drilling platform, offshore rig
0.0295160 container ship, containership, container vessel
0.0153662 dock, dockage, docking facility
0.0115069 submarine, pigboat, sub, U-boat
|
+squeezenet1_1 |0.4413064 liner, ocean liner
0.1931020 container ship, containership, container vessel
0.1459110 pirate, pirate ship
0.0937753 fireboat
0.0198683 drilling platform, offshore rig
|
+vgg11 |0.3343855 container ship, containership, container vessel
0.3068857 liner, ocean liner
0.0492899 submarine, pigboat, sub, U-boat
0.0455569 fireboat
0.0391509 lifeboat
|
+vgg11_bn |0.7272952 container ship, containership, container vessel
0.1716904 liner, ocean liner
0.0226532 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0206520 dock, dockage, docking facility
0.0114507 lifeboat
|
+vgg13 |0.3224932 container ship, containership, container vessel
0.2891453 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.1808192 liner, ocean liner
0.0591593 beacon, lighthouse, beacon light, pharos
0.0270378 dock, dockage, docking facility
|
+vgg13_bn |0.3478982 container ship, containership, container vessel
0.2664560 fireboat
0.0766569 lifeboat
0.0664668 liner, ocean liner
0.0515882 submarine, pigboat, sub, U-boat
|
+vgg16 |0.4804810 container ship, containership, container vessel
0.1304805 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0867475 liner, ocean liner
0.0751447 drilling platform, offshore rig
0.0444228 lifeboat
|
+vgg16_bn |0.5045572 container ship, containership, container vessel
0.1368753 liner, ocean liner
0.1096228 lifeboat
0.0501405 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0392852 dock, dockage, docking facility
|
+vgg19 |0.4432594 container ship, containership, container vessel
0.1617560 liner, ocean liner
0.1536936 fireboat
0.0549521 drilling platform, offshore rig
0.0304159 lifeboat
|
+vgg19_bn |0.2604308 fireboat
0.1715146 container ship, containership, container vessel
0.0810636 submarine, pigboat, sub, U-boat
0.0738690 dock, dockage, docking facility
0.0685641 lifeboat
|
+wide_resnet50_2 |0.1823847 liner, ocean liner
0.1433828 dock, dockage, docking facility
0.1098627 container ship, containership, container vessel
0.0992484 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0717444 catamaran
|
+wide_resnet101_2 |0.4919022 drilling platform, offshore rig
0.0996324 liner, ocean liner
0.0898810 beacon, lighthouse, beacon light, pharos
0.0402922 breakwater, groin, groyne, mole, bulwark, seawall, jetty
0.0381101 catamaran
|
+
+
+[imagenet]: http://www.image-net.org
+[torchvision_classification]: https://pytorch.org/vision/0.8/models.html
diff --git a/src/benchmark/README.md b/src/benchmark/README.md
index 9fa53bfec..55371a3c9 100644
--- a/src/benchmark/README.md
+++ b/src/benchmark/README.md
@@ -16,6 +16,7 @@
- [ONNX Runtime][onnx-runtime].
- [OpenCV][opencv].
- [MXNet][mxnet].
+- [PyTorch][pytorch].
### Алгоритм работы скрипта
@@ -106,10 +107,10 @@ Inference Engine предоставляет 2 программных интер
**Примечание:** в публикуемой html-таблице содержатся только показатели FPS.
-### Показатели производительности вывода для Intel® Optimization for Caffe, Intel® Optimization for TensorFlow, TensorFlow Lite, OpenCV и MXNet
+### Показатели производительности вывода для Intel® Optimization for Caffe, Intel® Optimization for TensorFlow, TensorFlow Lite, OpenCV, MXNet и PyTorch
При оценке производительности вывода для Intel® Optimization for Caffe,
-Intel® Optimization for TensorFlow, TensorFlow Lite, OpenCV и MXNet
+Intel® Optimization for TensorFlow, TensorFlow Lite, OpenCV, MXNet и PyTorch
осуществляется последовательный и независимый запуск запросов.
Запуск очередного запроса выполняется после завершения предыдущего.
Для каждого запроса осуществляется замер времени его выполнения.
@@ -252,3 +253,4 @@ pip install openvino_dev[mxnet,caffe,caffe2,onnx,pytorch,tensorflow2]==
```
+#### Пример заполнения конфигурации для измерения производительности вывода средствами PyTorch
+
+```xml
+
+
+
+
+ classification
+ alexnet
+ FP32
+ PyTorch
+
+
+
+
+ ImageNet
+ /mnt/datasets/ILSVRC2012_img_val
+
+
+ PyTorch
+ 1
+ CPU
+ 10
+ 60
+
+
+ data
+ 1 3 224 224
+ True
+ 0.485 0.456 0.406
+ 0.229 0.224 0.225
+
+ scripted
+ True
+
+
+
+```
+
## Заполнение файла конфигурации для скрипта оценки точности
### Правила заполнения
diff --git a/src/inference/README.md b/src/inference/README.md
index b7aa6408e..e09edcbc8 100644
--- a/src/inference/README.md
+++ b/src/inference/README.md
@@ -9,6 +9,7 @@
1. Intel® Optimization for TensorFlow.
1. TensorFlow Lite.
1. MXNet.
+1. PyTorch.
## Вывод глубоких моделей с использованием Inference Engine
@@ -550,6 +551,8 @@ inference_mxnet.py
- `-t / --task` - название задачи. Текущая реализация поддерживает
решение задачи классификации. По умолчанию принимает значение
`classification`.
+- `-b / --batch_size` - количество изображений, которые будут обработаны
+ за один проход сети. По умолчанию равно `1`.
- `-in / --input_name` - название входа модели. По умолчанию модель
имеет один вход с названием `data`. Текущая реализация вывода
предусматривает наличие только одного входа.
@@ -606,7 +609,6 @@ python inference_mxnet.py --model_name \
--input_name \
--input_shape \
--norm --mean --std \
- --batch_size \
--save_model --path_save_model
```
@@ -618,11 +620,105 @@ python inference_mxnet.py --model .json \
--input_name \
--input_shape \
--input \
- --labels .json \
- --batch_size
+ --labels .json
+```
+
+## Вывод глубоких моделей с использованием PyTorch (TorchVision)
+
+#### Аргументы командной строки
+
+Название скрипта:
+
+```bash
+inference_pytorch.py
+```
+
+Обязательные аргументы:
+
+- `-m / --model` - путь до описания архитектуры модели
+ в формате `.pt`.
+- `-mn / --model_name` - название модели, если модель
+ загружается из [TorchVision][torchvision].
+ При таком варианте запуска модель загружается из сети Интернет.
+- `-i / --input` - путь до изображения или директории
+ с изображениями (расширения файлов `.jpg`, `.png`,
+ `.bmp` и т.д.).
+- `-is / --input_shape` - размеры входного тензора сети в формате
+ BxCxHxW, B - размер пачки, C - количество каналов изображений,
+ W - ширина изображений, H - высота изображений.
+
+Опциональные аргументы:
+
+- `-t / --task` - название задачи. Текущая реализация поддерживает
+ решение задачи классификации. По умолчанию принимает значение
+ `feedforward`.
+- `-b / --batch_size` - количество изображений, которые будут обработаны
+ за один проход сети. По умолчанию равно `1`.
+- `-in / --input_name` - название входа модели. По умолчанию модель
+ имеет один вход с названием `data`. Текущая реализация вывода
+ предусматривает наличие только одного входа.
+- `--norm` - флаг необходимости нормировки изображений.
+ Выполняется с использованием модуля `torchvision.transforms`.
+ Среднее и среднеквадратическое отклонение, которые принимаются
+ на вход указываются в следующих двух аргументах.
+- `--mean` - среднее значение интенсивности, которое вычитается
+ из изображений в процессе нормировки. Для классификационных моделей
+ из [TorchVision][torchvision], которые обучены на наборе
+ данных ImageNet, значение равно `0.485 0.456 0.406`. По умолчанию
+ данный параметр принимает значение `0 0 0`.
+- `--std` - среднеквадратическое отклонение интенсивности, на которое
+ делится значение интенсивности каждого пикселя входного изображения
+ в процессе нормировки. Для классификационных моделей
+ из [TorchVision][torchvision], которые обучены на наборе
+ данных ImageNet, значение равно `0.229 0.224 0.225`. По умолчанию
+ данный параметр принимает значение `1 1 1`.
+- `--output_names` - название выхода модели. По умолчанию модель
+ имеет один вход с названием `output`. Текущая реализация вывода
+ предусматривает наличие только одного выхода.
+- `-d / --device` - оборудование, на котором выполняется вывод сети.
+ Поддерживается вывод на CPU (значение параметра `CPU`) и NVIDIA GPU
+ (значение параметра `NVIDIA GPU`). По умолчанию принимает значение `CPU`.
+- `-l / --labels`- путь до файла в формате JSON с перечнем меток
+ при решении задачи. По умолчанию принимает значение
+ `image_net_labels.json`, что соответствует меткам набора данных
+ ImageNet.
+- `-ni / --number_iter` - количество прямых проходов по сети.
+ По умолчанию выполняется один проход по сети.
+- `--raw_output` - работа скрипта без логов. По умолчанию не установлен.
+- `--model_type` - тип модели для запуска. Доступно два режима:
+ `baseline` - базовый режим и `scripted` - оптимизированный режим с использованием
+ `JIT` (`just in time`) компилятора.
+ По умолчанию используется версия `scripted` модели.
+- `--inference_mode` - флаг оптимизированного режима прогонки модели
+ без хранения информации для тренировки модели.
+ По умолчанию данный режим включен.
+
+#### Примеры запуска
+
+**Запуск вывода для модели, которая загружается из TorchVision**
+
+```bash
+python inference_pytorch.py --model_name \
+ --input \
+ --input_name \
+ --input_shape \
+ --norm --mean --std \
+ --batch_size
+```
+
+**Запуск вывода для модели, которая загружается из файлов**
+
+```bash
+python inference_pytorch.py --model .pt \
+ --input_name \
+ --input_shape \
+ --input \
+ --labels .json \
+ --batch_size
```
[gluon_modelzoo]: https://cv.gluon.ai/model_zoo/index.html
[tflite_delegates]: https://www.tensorflow.org/lite/performance/delegates
+[torchvision]: https://pytorch.org/vision/stable/models.html
diff --git a/src/inference/inference_pytorch.py b/src/inference/inference_pytorch.py
new file mode 100644
index 000000000..50c5b036f
--- /dev/null
+++ b/src/inference/inference_pytorch.py
@@ -0,0 +1,297 @@
+import sys
+import argparse
+import traceback
+import importlib
+import logging as log
+from time import time
+
+import torch
+
+import postprocessing_data as pp
+from io_adapter import IOAdapter
+from io_model_wrapper import PyTorchIOModelWrapper
+from transformer import PyTorchTransformer
+
+
+def cli_argument_parser():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('-m', '--model',
+ help='Path to PyTorch model with format .pt.',
+ type=str,
+ dest='model')
+ parser.add_argument('-mn', '--model_name',
+ help='Model name from TorchVision.',
+ type=str,
+ dest='model_name')
+ parser.add_argument('-i', '--input',
+ help='Path to data.',
+ required=True,
+ type=str,
+ nargs='+',
+ dest='input')
+ parser.add_argument('-in', '--input_name',
+ help='Input name.',
+ default='data',
+ type=str,
+ dest='input_name')
+ parser.add_argument('-is', '--input_shape',
+ help='Input shape BxCxHxW, B is a batch size,'
+ 'C is an input tensor number of channels'
+ 'H is an input tensor height,'
+ 'W is an input tensor width.',
+ required=True,
+ type=int,
+ nargs=4,
+ dest='input_shape')
+ parser.add_argument('--norm',
+ help='Flag to normalize input images'
+ '(use --mean and --std arguments to set'
+ 'required normalization parameters).',
+ action='store_true',
+ dest='norm')
+ parser.add_argument('--mean',
+ help='Mean values.',
+ default=[0, 0, 0],
+ type=float,
+ nargs=3,
+ dest='mean')
+ parser.add_argument('--std',
+ help='Standard deviation values.',
+ default=[1., 1., 1.],
+ type=float,
+ nargs=3,
+ dest='std')
+ parser.add_argument('--output_names',
+ help='Name of the output tensors.',
+ default='output',
+ type=str,
+ nargs='+',
+ dest='output_names')
+ parser.add_argument('-b', '--batch_size',
+ help='Batch size.',
+ default=1,
+ type=int,
+ dest='batch_size')
+ parser.add_argument('-l', '--labels',
+ help='Labels mapping file.',
+ default=None,
+ type=str,
+ dest='labels')
+ parser.add_argument('-nt', '--number_top',
+ help='Number of top results.',
+ default=5,
+ type=int,
+ dest='number_top')
+ parser.add_argument('-t', '--task',
+ help='Task type determines the type of output processing '
+ 'method. Available values: feedforward - without'
+ 'postprocessing (by default), classification - output'
+ 'is a vector of probabilities.',
+ choices=['feedforward', 'classification'],
+ default='feedforward',
+ type=str,
+ dest='task')
+ parser.add_argument('-ni', '--number_iter',
+ help='Number of inference iterations.',
+ default=1,
+ type=int,
+ dest='number_iter')
+ parser.add_argument('--raw_output',
+ help='Raw output without logs.',
+ default=False,
+ type=bool,
+ dest='raw_output')
+ parser.add_argument('-d', '--device',
+ help='Specify the target device to infer on CPU or '
+ 'NVIDIA GPU (CPU by default)',
+ default='CPU',
+ type=str,
+ dest='device')
+ parser.add_argument('--model_type',
+ help='Model type for inference',
+ choices=['scripted', 'baseline'],
+ default='scripted',
+ type=str,
+ dest='model_type')
+ parser.add_argument('--inference_mode',
+ help='Inference mode',
+ default=True,
+ type=bool,
+ dest='inference_mode')
+
+ args = parser.parse_args()
+
+ return args
+
+
+def get_device_to_infer(device):
+ log.info('Get device for inference')
+ if device == 'CPU':
+ log.info(f'Inference will be executed on {device}')
+ return torch.device('cpu')
+ elif device == 'NVIDIA GPU':
+ log.info(f'Inference will be executed on {device}')
+ return torch.device('cuda')
+ else:
+ log.info(f'The device {device} is not supported')
+ raise ValueError('The device is not supported')
+
+
+def load_model_from_module(model_name):
+ model_cls = model_name
+ model_path = 'torchvision.models'
+ model_cls = importlib.import_module(model_path).__getattribute__(model_cls)
+ module = model_cls(weights=True)
+ return module
+
+
+def load_model_from_file(model_path):
+ log.info(f'Loading model from path {model_path}')
+ file_type = model_path.split('.')[-1]
+ supported_extensions = ['pt']
+ if file_type not in supported_extensions:
+ raise ValueError(f'The file type {file_type} is not supported')
+ model = torch.load(model_path)
+ return model
+
+
+def compile_model(module, device, model_type):
+ if model_type == 'baseline':
+ log.info('Inference will be executed on baseline model')
+ elif model_type == 'scripted':
+ log.info('Inference will be executed on scripted model')
+ module = torch.jit.script(module)
+ else:
+ raise ValueError(f'Model type {model_type} is not supported for inference')
+ module.to(device)
+ module.eval()
+ return module
+
+
+def create_dict_for_transformer(args):
+ dictionary = {
+ 'mean': args.mean,
+ 'std': args.std,
+ 'norm': args.norm,
+ 'input_shape': args.input_shape,
+ 'batch_size': args.batch_size,
+ }
+ return dictionary
+
+
+def create_dict_for_modelwrapper(args):
+ dictionary = {
+ 'input_name': args.input_name,
+ 'input_shape': [args.batch_size] + args.input_shape[1:],
+ }
+ return dictionary
+
+
+def inference_pytorch(model, num_iterations, get_slice, input_name, inference_mode):
+ with torch.inference_mode(inference_mode):
+ predictions = None
+ time_infer = []
+ slice_input = None
+ if num_iterations == 1:
+ slice_input = get_slice(0)
+ t0 = time()
+ predictions = torch.nn.functional.softmax(model(slice_input[input_name]), dim=1)
+ t1 = time()
+ time_infer.append(t1 - t0)
+ else:
+ for i in range(num_iterations):
+ slice_input = get_slice(i)
+ t0 = time()
+ torch.nn.functional.softmax(model(slice_input[input_name]), dim=1)
+ t1 = time()
+ time_infer.append(t1 - t0)
+
+ return predictions, time_infer
+
+
+def process_result(batch_size, inference_time):
+ inference_time = pp.three_sigma_rule(inference_time)
+ average_time = pp.calculate_average_time(inference_time)
+ latency = pp.calculate_latency(inference_time)
+ fps = pp.calculate_fps(batch_size, latency)
+ return average_time, latency, fps
+
+
+def result_output(average_time, fps, latency):
+ log.info('Average time of single pass : {0:.3f}'.format(average_time))
+ log.info('FPS : {0:.3f}'.format(fps))
+ log.info('Latency : {0:.3f}'.format(latency))
+
+
+def raw_result_output(average_time, fps, latency):
+ print('{0:.3f},{1:.3f},{2:.3f}'.format(average_time, fps, latency))
+
+
+def prepare_output(result, output_names, task):
+ if task == 'feedforward':
+ return {}
+ if (output_names is None) or len(output_names) == 0:
+ raise ValueError('The number of output tensors does not match the number of corresponding output names')
+ if task == 'classification':
+ return {output_names[0]: result.detach().numpy()}
+ else:
+ raise ValueError(f'Unsupported task {task} to print inference results')
+
+
+def main():
+ log.basicConfig(
+ format='[ %(levelname)s ] %(message)s',
+ level=log.INFO,
+ stream=sys.stdout,
+ )
+ args = cli_argument_parser()
+ try:
+ model_wrapper = PyTorchIOModelWrapper(create_dict_for_modelwrapper(args))
+ data_transformer = PyTorchTransformer(create_dict_for_transformer(args))
+ io = IOAdapter.get_io_adapter(args, model_wrapper, data_transformer)
+
+ if args.model_name is not None and args.model is None:
+ model = load_model_from_module(args.model_name)
+ elif args.model_name is None and args.model is not None:
+ model = load_model_from_file(args.model)
+ else:
+ raise ValueError('Incorrect arguments.')
+
+ device = get_device_to_infer(args.device)
+ compiled_model = compile_model(model, device, args.model_type)
+
+ log.info(f'Shape for input layer {args.input_name}: {args.input_shape}')
+
+ log.info(f'Preparing input data {args.input}')
+ io.prepare_input(compiled_model, args.input)
+
+ log.info(f'Starting inference ({args.number_iter} iterations) on {args.device}')
+ result, inference_time = inference_pytorch(compiled_model, args.number_iter,
+ io.get_slice_input, args.input_name, args.inference_mode)
+
+ log.info('Computing performance metrics')
+ average_time, latency, fps = process_result(args.batch_size, inference_time)
+
+ if not args.raw_output:
+ if args.number_iter == 1:
+ try:
+ log.info('Converting output tensor to print results')
+ result = prepare_output(result, args.output_names, args.task)
+
+ log.info('Inference results')
+ io.process_output(result, log)
+ except Exception as ex:
+ log.warning('Error when printing inference results. {0}'.format(str(ex)))
+
+ log.info('Performance results')
+ result_output(average_time, fps, latency)
+ else:
+ raw_result_output(average_time, fps, latency)
+ except Exception:
+ log.error(traceback.format_exc())
+ sys.exit(1)
+
+
+if __name__ == '__main__':
+ sys.exit(main() or 0)
diff --git a/src/inference/io_model_wrapper.py b/src/inference/io_model_wrapper.py
index b63f8a231..3770acb8d 100644
--- a/src/inference/io_model_wrapper.py
+++ b/src/inference/io_model_wrapper.py
@@ -184,3 +184,19 @@ def get_input_layer_shape(self, model, layer_name):
def get_input_layer_dtype(self, model, layer_name):
from numpy import float32
return float32
+
+
+class PyTorchIOModelWrapper(IOModelWrapper):
+ def __init__(self, args):
+ self._input_names = [args['input_name']]
+ self._input_shapes = [args['input_shape']]
+
+ def get_input_layer_names(self, model):
+ return self._input_names
+
+ def get_input_layer_shape(self, model, layer_name):
+ return self._input_shapes[0]
+
+ def get_input_layer_dtype(self, model, layer_name):
+ import torch
+ return torch.float32
diff --git a/src/inference/transformer.py b/src/inference/transformer.py
index 15c0ef856..fea99518c 100644
--- a/src/inference/transformer.py
+++ b/src/inference/transformer.py
@@ -254,3 +254,37 @@ def transform_images(self, images, shape, element_type, *args):
blob = cv2.dnn.blobFromImages(images, **self._converting)
transformed_blob = self.__set_layout_order(blob / self._std)
return transformed_blob
+
+
+class PyTorchTransformer(Transformer):
+ def __init__(self, converting):
+ self._converting = converting
+
+ def __set_norm(self, image):
+ import torchvision
+
+ if not self._converting['norm']:
+ preprocess = torchvision.transforms.Compose([
+ torchvision.transforms.ToTensor()])
+ return preprocess(image.astype(np.float32))
+
+ preprocess = torchvision.transforms.Compose([
+ torchvision.transforms.ToTensor(),
+ torchvision.transforms.Normalize(mean=self._converting['mean'],
+ std=self._converting['std']),
+ ])
+
+ return preprocess(image.astype(np.float32) / 255)
+
+ def _transform(self, image):
+ normalized_image = self.__set_norm(image)
+ return normalized_image
+
+ def transform_images(self, images, shape, element_type, *args):
+ import torch
+ dataset_size = images.shape[0]
+ new_shape = [dataset_size] + shape[1:]
+ transformed_images = torch.zeros(new_shape, dtype=element_type)
+ for i in range(dataset_size):
+ transformed_images[i] = self._transform(images[i])
+ return transformed_images
diff --git a/test/smoke_test/smoke_config.xml b/test/smoke_test/smoke_config.xml
index a9323935e..45b78514d 100644
--- a/test/smoke_test/smoke_config.xml
+++ b/test/smoke_test/smoke_config.xml
@@ -184,4 +184,35 @@
+
+
+ classification
+ alexnet
+ FP32
+ PyTorch
+
+
+
+
+ ImageNet
+ ./black_square.jpg
+
+
+ PyTorch
+ 1
+ CPU
+ 10
+ 60
+
+
+ data
+ 1 3 224 224
+ True
+ 0.485 0.456 0.406
+ 0.229 0.224 0.225
+
+
+
+
+