-
Notifications
You must be signed in to change notification settings - Fork 0
/
B2_cntkVisualizeInputs.py
52 lines (41 loc) · 2.16 KB
/
B2_cntkVisualizeInputs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os, importlib, sys
from cntk_helpers import *
import PARAMETERS
# Import every setting defined in PARAMETERS.py into this module's namespace so
# the script can use them by bare name. The code below relies on this for names
# it never defines itself (presumably cntkFilesDir, classes, cntk_nrRois,
# cntk_padWidth, cntk_padHeight, nmsThreshold; some may come from cntk_helpers
# instead -- verify). At module scope locals() is globals(), so globals() is the
# explicit, documented way to do this. Dunder names (__name__, __file__,
# __doc__, __spec__, ...) are skipped: copying them would make this script
# silently report itself as the PARAMETERS module.
globals().update(
    (name, value)
    for name, value in vars(PARAMETERS).items()
    if not (name.startswith("__") and name.endswith("__"))
)

####################################
# Parameters
####################################
image_set = 'train'  # which CNTK input set to visualize: 'train' or 'test'

# No need to change these parameters.
# Only parse the first N images of the CNTK input files, for speed;
# set to None to parse all images.
parseNrImages = 50
# If True, apply non-maxima suppression to each image's ROIs before drawing them.
boUseNonMaximaSurpression = False
####################################
# Main
####################################
print("Load ROI co-ordinates and labels")
cntkImgsPath, cntkRoiCoordsPath, cntkRoiLabelsPath, nrRoisPath = getCntkInputPaths(cntkFilesDir, image_set)
# Image paths are read from column 1 of the images table.
imgPaths = getColumn(readTable(cntkImgsPath), 1)
# Per-image count of real ROIs; the ROI arrays read below are sized by
# cntk_nrRois, so entries past this count are presumably padding (verify).
nrRealRois = [int(s) for s in readFile(nrRoisPath)]
roiAllLabels = readCntkRoiLabels(cntkRoiLabelsPath, cntk_nrRois, len(classes), parseNrImages)

# Restrict every per-image list to the first parseNrImages images (None = all).
if parseNrImages:
    imgPaths = imgPaths[:parseNrImages]
    nrRealRois = nrRealRois[:parseNrImages]
    roiAllLabels = roiAllLabels[:parseNrImages]
roiAllCoords = readCntkRoiCoordinates(imgPaths, cntkRoiCoordsPath, cntk_nrRois, cntk_padWidth, cntk_padHeight, parseNrImages)

# All per-image lists must line up index-by-index. Use an explicit check rather
# than `assert`, which is stripped when Python runs with -O.
if not (len(imgPaths) == len(roiAllCoords) == len(roiAllLabels) == len(nrRealRois)):
    raise ValueError(
        "Inconsistent CNTK input files: %d image paths, %d ROI coordinate sets, "
        "%d ROI label sets, %d ROI counts"
        % (len(imgPaths), len(roiAllCoords), len(roiAllLabels), len(nrRealRois)))
# Loop over all images and display each one with its ROIs drawn on top.
for imgIndex, imgPath in enumerate(imgPaths):
    print("Visualizing image %d at %s..." % (imgIndex, imgPath))

    # Keep only the first nrRealRois[imgIndex] entries, i.e. this image's real
    # ROIs (the remaining entries are presumably padding -- verify).
    nrRois = nrRealRois[imgIndex]
    roiCoords = roiAllCoords[imgIndex][:nrRois]
    roiLabels = roiAllLabels[imgIndex][:nrRois]

    # Optionally perform non-maxima suppression. Note that the detected classes
    # in the image are not affected by this.
    nmsKeepIndices = []
    if boUseNonMaximaSurpression:
        # Every ROI is given the same dummy value 0 here (presumably the score
        # argument of applyNonMaximaSuppression -- verify), since the
        # ground-truth inputs carry no detection scores.
        nmsKeepIndices = applyNonMaximaSuppression(nmsThreshold, roiLabels, [0] * len(roiLabels), roiCoords)
        print("Non-maxima surpression kept {} of {} rois (nmsThreshold={})".format(len(nmsKeepIndices), len(roiLabels), nmsThreshold))

    # Visualize results.
    imgDebug = visualizeResults(imgPath, roiLabels, None, roiCoords, cntk_padWidth, cntk_padHeight,
                                classes, nmsKeepIndices, boDrawNegativeRois=False)
    imshow(imgDebug, waitDuration=1, maxDim=800)
print("DONE.")