diff --git a/onnx_tf/backend_rep.py b/onnx_tf/backend_rep.py index bee1ec49e..71ea39064 100644 --- a/onnx_tf/backend_rep.py +++ b/onnx_tf/backend_rep.py @@ -49,40 +49,43 @@ def tensor_dict(self): def tensor_dict(self, tensor_dict): self._tensor_dict = tensor_dict - def run(self, inputs, **kwargs): + def run(self, inputs, sess=None, **kwargs): """ Run TensorflowRep. :param inputs: Given inputs. + :param sess: tf.Session. The environment in which Operation objects are executed, + and Tensor objects are evaluated. :param kwargs: Other args. :return: Outputs. """ super(TensorflowRep, self).run(inputs, **kwargs) + should_close_sess = sess is None # TODO: handle name scope if necessary with self.graph.as_default(): - with tf.Session() as sess: - if isinstance(inputs, dict): - feed_dict = inputs - elif isinstance(inputs, list) or isinstance(inputs, tuple): - if len(self.inputs) != len(inputs): - raise RuntimeError('Expected {} values for uninitialized ' - 'graph inputs ({}), but got {}.'.format( - len(self.inputs), ', '.join(self.inputs), - len(inputs))) - feed_dict = dict(zip(self.inputs, inputs)) - else: - # single input - feed_dict = dict([(self.inputs[0], inputs)]) - - feed_dict = { - self.tensor_dict[key]: feed_dict[key] for key in self.inputs - } - - sess.run(tf.global_variables_initializer()) - outputs = [self.tensor_dict[output] for output in self.outputs] - - output_values = sess.run(outputs, feed_dict=feed_dict) - return namedtupledict('Outputs', self.outputs)(*output_values) + sess = sess or tf.Session() + if isinstance(inputs, dict): + feed_dict = inputs + elif isinstance(inputs, list) or isinstance(inputs, tuple): + if len(self.inputs) != len(inputs): + raise RuntimeError('Expected {} values for uninitialized ' + 'graph inputs ({}), but got {}.'.format( + len(self.inputs), ', '.join(self.inputs), + len(inputs))) + feed_dict = dict(zip(self.inputs, inputs)) + else: + # single input + feed_dict = dict([(self.inputs[0], inputs)]) + + feed_dict = {self.tensor_dict[key]: feed_dict[key] for key in self.inputs} + + sess.run(tf.global_variables_initializer()) + outputs = [self.tensor_dict[output] for output in self.outputs] + + output_values = sess.run(outputs, feed_dict=feed_dict) + if should_close_sess: + sess.close() + return namedtupledict('Outputs', self.outputs)(*output_values) def export_graph(self, path): """Export backend representation to a Tensorflow proto file. @@ -99,3 +102,12 @@ def export_graph(self, path): file = open(path, "wb") file.write(graph_proto.SerializeToString()) file.close() + + def create_session(self): + """ Create tf.Session object by using current graph. + Pass it to `run` function could reduce the overhead of initialization + when doing inference consecutively. + + :returns: A Session object. + """ + return tf.Session(graph=self.graph)