Skip to content

Commit

Permalink
disable dynamic batch
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Oct 8, 2022
1 parent 1188d38 commit fdb5efa
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
13 changes: 13 additions & 0 deletions onnx2tf/ops/NonMaxSuppression.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import random
random.seed(0)
import numpy as np
Expand All @@ -9,6 +10,7 @@
print_node_info,
inverted_operation_enable_disable,
)
from onnx2tf.utils.colors import Color


@print_node_info
Expand Down Expand Up @@ -117,6 +119,17 @@ def make_node(
boxes = tf.transpose(boxes_t, perm=[0, 2, 1])

num_batches = boxes.shape[0]

if num_batches is None:
print(
f'{Color.RED}ERROR:{Color.RESET} '+
f'It is not possible to specify a dynamic shape '+
f'for the batch size of the input tensor in NonMaxSuppression. '+
f'Use the --batch_size option to change the batch size to a fixed size. \n'+
f'boxes.shape: {boxes.shape} scores.shape: {scores.shape}'
)
sys.exit(1)

for batch_i in tf.range(num_batches):
# get boxes in batch_i only
tf_boxes = tf.squeeze(tf.gather(boxes, [batch_i]), axis=0)
Expand Down
14 changes: 14 additions & 0 deletions onnx2tf/ops/Pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,20 @@ def make_node(
tensor_rank = len(input_tensor.shape)
paddings = tf_layers_dict[paddings.name]['tf_node'] \
if isinstance(paddings, gs.Variable) else paddings

# Transpose pads values
paddings = graph_node.inputs[1]
if hasattr(paddings, 'values'):
values = paddings.values
paddings = values.reshape([2, tensor_rank]).transpose()
paddings_rank = paddings.shape[0]
if paddings_rank > 2:
convertion_table = [0] + [i for i in range(2, paddings_rank)] + [1]
new_paddings = []
for idx in convertion_table:
new_paddings.append(paddings[idx, :])
paddings = np.asarray(new_paddings)

constant_value = tf_layers_dict[constant_value.name]['tf_node'] \
if isinstance(constant_value, gs.Variable) else constant_value

Expand Down

0 comments on commit fdb5efa

Please sign in to comment.