Skip to content

Commit

Permalink
Got TensorRT working.
Browse files Browse the repository at this point in the history
make the utility util_keras-h5-model_to-tensorflow-pb_to-nvinfer-uff.py.

Following was the progress:
a. Got VGG base CNN working correctly with UFF parser
b. Got mobilenet basee working correct with UFF parser. Issue was caused by Merge Op and Switch Op which was solved using K.set_learning_phase(0) even before graph construction. This was critical
c. ExpandDims is still not supported by Nvidia's tensorrt. Way around is using reshape in netvlad.

TODO:
a. expand_dims --> reshape. Verify UFFparser works on x86 as well as on TX2
b. numerical verification by giving same image to keras and same image to tensorrt.
  • Loading branch information
mpkuse committed May 30, 2019
1 parent c23f23b commit 0d8bbfc
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 23 deletions.
11 changes: 9 additions & 2 deletions CustomNets.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,25 @@ def build( self, input_shape ):
trainable=True)

def call( self, x ):
print 'input x.shape=', x.shape
# soft-assignment.
s = K.conv2d( x, self.kernel, padding='same' ) + self.bias
a = K.softmax( s )
self.amap = K.argmax( a, -1 )
# print 'amap.shape', self.amap.shape

# import code
# code.interact( local=locals() )
# Dims used hereafter: batch, H, W, desc_coeff, cluster
a = K.expand_dims( a, -2 )
# print 'a.shape (before)=', a.shape
a = K.expand_dims( a, -2 ) #original code
# a = K.reshape( a, [ K.shape(a)[0], K.shape(a)[1], K.shape(a)[2], 1, K.shape(a)[3]] )
# print 'a.shape=',a.shape

# Core
v = K.expand_dims(x, -1) + self.C
# print 'x.shape', x.shape
v = K.expand_dims(x, -1) + self.C #original code
# v = K.reshape( x, [ K.shape(x)[0], K.shape(x)[1], K.shape(x)[2], K.shape(x)[3], 1 ] ) + self.C
# print 'v.shape', v.shape
v = a * v
# print 'v.shape', v.shape
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ for more details in this regard.

The following script in this repo, will help you convert hdf5 keras models
to .uff. Beware, that this is a rapidly changing/evolving thing.
This info is accurate for May 2019.
This info is accurate for May 2019.
```
python util_keras-h5-model_to-tensorflow-pb_to-nvinfer-uff.py --keras_h5_model <path to hdf5 file>
python util_keras-h5-model_to-tensorflow-pb_to-nvinfer-uff.py --kerasmodel_h5file <path to hdf5 file>
```

## References
Expand Down
3 changes: 2 additions & 1 deletion test_ghostvlad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@


input_img = keras.layers.Input( shape=(60, 80, 256 ) )
out = GhostVLADLayer(num_clusters = 16, num_ghost_clusters = 1)( input_img )
# out = GhostVLADLayer(num_clusters = 16, num_ghost_clusters = 1)( input_img )
out = NetVLADLayer(num_clusters = 16)( input_img )
model = keras.models.Model( inputs=input_img, outputs=out )

model.predict( np.random.rand( 1,60,80,256).astype('float32') )
14 changes: 14 additions & 0 deletions test_tensorrt_uffparser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# To check if .uff file can be loaded back

import tensorrt as trt

TRT_LOGGER = trt.Logger( trt.Logger.WARNING)
LOG_DIR = 'models.keras/June2019/centeredinput-m1to1-240x320x3__mobilenet-conv_pw_6_relu__K16__allpairloss/'
uff_fname = 'output_nvinfer.uff'

with trt.Builder( TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
parser.register_input("input_1", (3,240,320) )
# parser.register_output( "conv_pw_5_relu/Relu6" )
parser.register_output( "net_vlad_layer_1/l2_normalize_1" )
parser.parse( LOG_DIR+'/'+uff_fname, network )
pass
148 changes: 130 additions & 18 deletions util_keras-h5-model_to-tensorflow-pb_to-nvinfer-uff.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
import keras
import numpy as np
import os

import tensorflow as tf
from CustomNets import NetVLADLayer, GhostVLADLayer
from predict_utils import change_model_inputshape

from keras import backend as K

import TerminalColors
tcol = TerminalColors.bcolors()
Expand All @@ -24,6 +24,7 @@
def load_keras_hdf5_model( kerasmodel_h5file, verbose=True ):
""" Loads keras model from a HDF5 file """
assert os.path.isfile( kerasmodel_h5file ), 'The model weights file doesnot exists or there is a permission issue.'+"kerasmodel_file="+kerasmodel_h5file
K.set_learning_phase(0)

model = keras.models.load_model(kerasmodel_h5file, custom_objects={'NetVLADLayer': NetVLADLayer, 'GhostVLADLayer': GhostVLADLayer} )

Expand All @@ -33,41 +34,77 @@ def load_keras_hdf5_model( kerasmodel_h5file, verbose=True ):

return model

def write_kerasmodel_as_tensorflow_pb( model, LOG_DIR ):

def load_basic_model( ):
K.set_learning_phase(0)
from CustomNets import make_from_mobilenet, make_from_vgg16
from CustomNets import NetVLADLayer, GhostVLADLayer

input_img = keras.layers.Input( shape=(240, 320, 3 ) )
cnn = make_from_mobilenet( input_img, layer_name='conv_pw_5_relu', weights=None, kernel_regularizer=keras.regularizers.l2(0.01) )
# cnn = make_from_vgg16( input_img, weights=None, layer_name='block5_pool', kernel_regularizer=keras.regularizers.l2(0.01) )

# base_model = keras.applications.mobilenet_v2.MobileNetV2( weights=None, include_top=False, input_tensor=input_img )
# cnn = base_model.get_layer( 'block_11_add' ).output

model = keras.models.Model( inputs=input_img, outputs=cnn )


# out = NetVLADLayer(num_clusters = 16)( cnn )
# model = keras.models.Model( inputs=input_img, outputs=out )


model.summary()
return model

def write_kerasmodel_as_tensorflow_pb( model, LOG_DIR, output_model_name='output_model.pb' ):
""" Takes as input a keras.models.Model() and writes out
Tensorflow proto-binary.
"""
print tcol.HEADER,'[write_kerasmodel_as_tensorflow_pb] Start', tcol.ENDC
from keras import backend as K

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
K.set_learning_phase(0)
sess = K.get_session()

# Write .pbtxt (for viz only)
output_model_pbtxt_name = 'output_model.pbtxt' #output_model_name+'.pbtxt' #
print tcol.OKGREEN, 'Write ', output_model_pbtxt_name, tcol.ENDC
tf.train.write_graph(sess.graph.as_graph_def(), LOG_DIR,
output_model_pbtxt_name, as_text=True)


# Make const
print 'Make Computation Graph as Constant and Prune unnecessary stuff from it'
constant_graph = graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(),
[node.op.name for node in model.outputs])
constant_graph = tf.graph_util.remove_training_nodes(constant_graph)


#--- convert Switch --> Identity
# I am doing this because TensorRT cannot process Switch operations.
# # https://github.com/tensorflow/tensorflow/issues/8404#issuecomment-297469468
# for node in constant_graph.node:
# if node.op == "Switch":
# node.op = "Identity"
# del node.input[1]
# # END

# Write .pb
output_model_name = 'output_model.pb'
# output_model_name = 'output_model.pb'
print tcol.OKGREEN, 'Write ', output_model_name, tcol.ENDC
print 'model.outputs=', [node.op.name for node in model.outputs]
graph_io.write_graph(constant_graph, LOG_DIR, output_model_name,
as_text=False)
print tcol.HEADER, '[write_kerasmodel_as_tensorflow_pb] Done', tcol.ENDC


# Write .pbtxt (for viz only)
output_model_pbtxt_name = output_model_name+'.pbtxt' #'output_model.pbtxt'
print tcol.OKGREEN, 'Write ', output_model_pbtxt_name, tcol.ENDC
tf.train.write_graph(constant_graph, LOG_DIR,
output_model_pbtxt_name, as_text=True)


def convert_to_uff( pb_input_fname, uff_output_fname ):
""" Uses Nvidia's `convert-to-uff` through os.system.
This will convert the .pb file (generated from call to `write_kerasmodel_as_tensorflow_pb` )
Expand Down Expand Up @@ -112,11 +149,77 @@ def convert_to_uff( pb_input_fname, uff_output_fname ):

assert os.path.isfile( pb_input_fname ), "The .pb file="+str(pb_input_fname)+" does not exist"

cmd = 'convert-to-uff -t -o %s %s' %(uff_output_fname, pb_input_fname)
cmd = 'convert-to-uff -t -o %s %s | tee %s' %(uff_output_fname, pb_input_fname, uff_output_fname+'.log')
print tcol.HEADER, '[bash run] ', cmd, tcol.ENDC

os.system( cmd )

print tcol.WARNING, 'If there are warning above like `No conversion function...`, this means that Nvidias UFF doesnt yet have certain function. Most like in this case your model cannot be run with tensorrt.', tcol.ENDC


def graphsurgeon_cleanup( LOG_DIR, input_model_name='output_model.pb', cleaned_model_name='output_model_aftersurgery.pb' ):
""" Loads the tensorflow frozen_graph and cleans up with nvidia's graphsurgeon
"""
assert os.path.isfile( LOG_DIR+'/'+input_model_name ), "[graphsurgeon_cleanup]The .pb file="+str(input_model_name)+" does not exist"

import graphsurgeon as gs
print tcol.HEADER, '[graphsurgeon_cleanup] graphsurgeon.__version__', gs.__version__, tcol.ENDC

DG = gs.DynamicGraph()
print tcol.OKGREEN, '[graphsurgeon_cleanup] READ tensorflow Graph using graphsurgeon.DynamicGraph: ', LOG_DIR+'/'+input_model_name, tcol.ENDC
DG.read( LOG_DIR+'/'+input_model_name )


# Remove control variable first


all_switch = DG.find_nodes_by_op( 'Switch' )
DG.forward_inputs( all_switch )
print 'Write (after graphsurgery) : ', LOG_DIR+'/'+cleaned_model_name
DG.write( LOG_DIR+'/'+cleaned_model_name )


if os.path.isdir( LOG_DIR+'/graphsurgeon_cleanup' ):
pass
else:
os.mkdir( LOG_DIR+'/graphsurgeon_cleanup')
DG.write_tensorboard( LOG_DIR+'/graphsurgeon_cleanup' )


# import code
# code.interact( local=locals() )


print tcol.HEADER, '[graphsurgeon_cleanup] END', tcol.ENDC


# def verify_generated_uff_with_tensorrt_uffparser( ufffilename, uffinput, uffinput_dims, uff_output ):
def verify_generated_uff_with_tensorrt_uffparser( ufffilename ):
""" Loads the UFF file with TensorRT (py). """
assert os.path.isfile( ufffilename ), "ufffilename="+ ufffilename+ ' doesnt exist'
import tensorrt as trt

print tcol.HEADER, '[verify_generated_uff_with_tensorrt_uffparser] TensorRT version=', trt.__version__, tcol.ENDC

try:
uffinput = "input_1"
uffinput_dims = (3,240,320)
uffoutput = "conv_pw_5_relu/Relu6"
# uffoutput = "net_vlad_layer_1/l2_normalize_1"

TRT_LOGGER = trt.Logger( trt.Logger.WARNING)
with trt.Builder( TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
print 'ufffilename=', str( ufffilename)
print 'uffinput=', str( uffinput), '\t', 'uffinput_dims=', str( uffinput_dims)
print 'uffoutput=', str( uffoutput)
parser.register_input( uffinput, uffinput_dims )
parser.register_output( uffoutput )
parser.parse( ufffilename, network )
pass

print tcol.OKGREEN, '[verify_generated_uff_with_tensorrt_uffparser] Verified.....!', tcol.ENDC
except:
print tcol.FAIL, '[verify_generated_uff_with_tensorrt_uffparser] UFF file=', ufffilename, ' with uffinput=', uffinput , ' uffoutput=', uffoutput , ' cannot be parsed.'



Expand All @@ -127,14 +230,11 @@ def convert_to_uff( pb_input_fname, uff_output_fname ):
parser.add_argument('--kerasmodel_h5file', '-h5', type=str, help='The input keras modelarch_and_weights full filename')
args = parser.parse_args()

# import code
# code.interact( local=locals() )
# quit()


#---
# Paths, File Init and other initialize
#kerasmodel_h5file = 'models.keras/June2019/centeredinput-m1to1-240x320x3__mobilenet-conv_pw_6_relu__K16__allpairloss/modelarch_and_weights.700.h5'
# kerasmodel_h5file = 'models.keras/June2019/centeredinput-m1to1-240x320x3__mobilenet-conv_pw_6_relu__K16__allpairloss/modelarch_and_weights.700.h5'
kerasmodel_h5file = args.kerasmodel_h5file

LOG_DIR = '/'.join( kerasmodel_h5file.split('/')[0:-1] )
Expand All @@ -147,12 +247,13 @@ def convert_to_uff( pb_input_fname, uff_output_fname ):

#---
# Load HDF5 Keras model
model = load_keras_hdf5_model( kerasmodel_h5file, verbose=True )
# model = load_keras_hdf5_model( kerasmodel_h5file, verbose=True ) #this
model = load_basic_model()


#-----
# Replace Input Layer's Dimensions
im_rows = 480
im_rows = None#480
im_cols = 752
im_chnls = 3
if im_rows == None or im_cols == None or im_chnls == None:
Expand All @@ -167,9 +268,20 @@ def convert_to_uff( pb_input_fname, uff_output_fname ):

#-----
# Write Tensorflow (atleast 1.12) proto-binary (.pb)
write_kerasmodel_as_tensorflow_pb( new_model, LOG_DIR=LOG_DIR )
write_kerasmodel_as_tensorflow_pb( new_model, LOG_DIR=LOG_DIR, output_model_name='output_model.pb' )


#-----
# Clean up graph with Nvidia's graphsurgeon
# currently not in use but might come in handly later...maybe
# graphsurgeon_cleanup( LOG_DIR=LOG_DIR, input_model_name='output_model.pb', cleaned_model_name='output_model_aftersurgery.pb')

#-----
# Write UFF
convert_to_uff( pb_input_fname=LOG_DIR+'/output_model.pb', uff_output_fname=LOG_DIR+'/output_nvinfer.uff' )
# convert_to_uff( pb_input_fname=LOG_DIR+'/output_model_aftersurgery.pb', uff_output_fname=LOG_DIR+'/output_nvinfer.uff' )


#-----
# Try to load UFF with tensorrt
verify_generated_uff_with_tensorrt_uffparser( ufffilename=LOG_DIR+'/output_nvinfer.uff' )

0 comments on commit 0d8bbfc

Please sign in to comment.