update cc code to support longer separator
xadupre committed Oct 20, 2023
1 parent 578a6e4 commit 4ca9ab8
Showing 3 changed files with 82 additions and 83 deletions.
2 changes: 1 addition & 1 deletion operators/text/string_split.cc
@@ -56,7 +56,7 @@ OrtStatusPtr string_split(const ortc::Tensor<std::string>& input_X,
indices.push_back(col);
++col;
}
- previous = current + 1;
+ previous = current + sep.size();
current = str.find_first_of(sep, previous);
}
current = str.size();
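Note (illustration, not part of this diff): the one-line change above makes the scan advance by the full separator length instead of a single character after each match. A minimal Python sketch of the same idea — not the operator's actual implementation, which scans with find_first_of in the C++ above — looks like this:

# Illustrative sketch only: split `text` on a multi-character separator,
# stepping past len(sep) after each match, mirroring the
# `previous = current + sep.size()` change above.
def split_on_separator(text, sep, skip_empty=True):
    tokens, previous = [], 0
    current = text.find(sep, previous)
    while current != -1:
        token = text[previous:current]
        if token or not skip_empty:
            tokens.append(token)
        # stepping by 1 here would leave part of `sep` in the next token
        previous = current + len(sep)
        current = text.find(sep, previous)
    tail = text[previous:]
    if tail or not skip_empty:
        tokens.append(tail)
    return tokens

# split_on_separator("a:/:/b", ":/") -> ['a', 'b']; with skip_empty=False it
# returns ['a', '', 'b'], matching the expectations in the updated tests below.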
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -1,6 +1,7 @@
# include requirements.txt so pip has context to avoid installing incompatible dependencies
-r requirements.txt
pytest
+ pytest-subtests
# multiple versions of onnxruntime are supported, but only one can be installed at a time
protobuf < 4.0.0
onnxruntime >=1.12.0
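Note (illustration, not part of this diff): the new pytest-subtests dependency adds support for unittest's subTest blocks under pytest, which the reworked tests below rely on — one subTest per separator and skip_empty combination. A minimal example of the pattern:

# Illustrative example only: each subTest is reported separately when run
# under pytest with pytest-subtests installed.
import unittest

class SeparatorExamples(unittest.TestCase):
    def test_separators(self):
        for sep in [",", ":/", ",,"]:
            with self.subTest(sep=sep):
                self.assertEqual(f"a{sep}b".split(sep), ["a", "b"])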
162 changes: 80 additions & 82 deletions test/test_string_ops.py
@@ -162,11 +162,11 @@ def _create_test_model_string_equal(prefix, domain='ai.onnx.contrib'):
'%sStringEqual' % prefix, ['id1', 'id2'], ['z'], domain=domain))

input0 = helper.make_tensor_value_info(
- 'x', onnx_proto.TensorProto.STRING, [])
+ 'x', onnx_proto.TensorProto.STRING, None)
input1 = helper.make_tensor_value_info(
- 'y', onnx_proto.TensorProto.STRING, [])
+ 'y', onnx_proto.TensorProto.STRING, None)
output0 = helper.make_tensor_value_info(
- 'z', onnx_proto.TensorProto.BOOL, [])
+ 'z', onnx_proto.TensorProto.BOOL, None)

graph = helper.make_graph(nodes, 'test0', [input0, input1], [output0])
model = make_onnx_model(graph)
@@ -200,11 +200,11 @@ def _create_test_model_string_split(prefix, domain='ai.onnx.contrib'):
['indices', 'values', 'shape'], domain=domain))

input0 = helper.make_tensor_value_info(
- 'input', onnx_proto.TensorProto.STRING, [])
+ 'input', onnx_proto.TensorProto.STRING, [None])
input1 = helper.make_tensor_value_info(
- 'delimiter', onnx_proto.TensorProto.STRING, [])
+ 'delimiter', onnx_proto.TensorProto.STRING, [1])
input2 = helper.make_tensor_value_info(
- 'skip_empty', onnx_proto.TensorProto.BOOL, [])
+ 'skip_empty', onnx_proto.TensorProto.BOOL, [1])
output0 = helper.make_tensor_value_info(
'indices', onnx_proto.TensorProto.INT64, [])
output1 = helper.make_tensor_value_info(
@@ -229,19 +229,19 @@ def _create_test_model_string_regex_split(prefix, domain='ai.onnx.contrib'):
['tokens', 'begins', 'ends', 'row_indices'], domain=domain))

input0 = helper.make_tensor_value_info(
- 'input', onnx_proto.TensorProto.STRING, [])
+ 'input', onnx_proto.TensorProto.STRING, None)
input1 = helper.make_tensor_value_info(
- 'pattern', onnx_proto.TensorProto.STRING, [])
+ 'pattern', onnx_proto.TensorProto.STRING, None)
input2 = helper.make_tensor_value_info(
- 'keep_pattern', onnx_proto.TensorProto.STRING, [])
+ 'keep_pattern', onnx_proto.TensorProto.STRING, None)
output0 = helper.make_tensor_value_info(
- 'tokens', onnx_proto.TensorProto.STRING, [])
+ 'tokens', onnx_proto.TensorProto.STRING, None)
output1 = helper.make_tensor_value_info(
- 'begins', onnx_proto.TensorProto.INT64, [])
+ 'begins', onnx_proto.TensorProto.INT64, None)
output2 = helper.make_tensor_value_info(
- 'ends', onnx_proto.TensorProto.INT64, [])
+ 'ends', onnx_proto.TensorProto.INT64, None)
output3 = helper.make_tensor_value_info(
- 'row_indices', onnx_proto.TensorProto.INT64, [])
+ 'row_indices', onnx_proto.TensorProto.INT64, None)

graph = helper.make_graph(nodes, 'test0', [input0, input1, input2],
[output0, output1, output2, output3])
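Note (illustration, not part of this diff): the shape edits above change what the test graphs declare to ONNX. In onnx.helper.make_tensor_value_info an empty list [] declares a rank-0 scalar, None leaves the shape unspecified, and [None] or [1] declare a 1-D tensor of dynamic or fixed length. A small illustrative snippet:

# Illustrative snippet only: the shape-argument variants used in this diff.
from onnx import helper, TensorProto

scalar  = helper.make_tensor_value_info('a', TensorProto.STRING, [])      # rank-0 scalar (the old declarations)
no_info = helper.make_tensor_value_info('b', TensorProto.STRING, None)    # shape left unspecified
dyn_1d  = helper.make_tensor_value_info('c', TensorProto.STRING, [None])  # 1-D, dynamic length
one_1d  = helper.make_tensor_value_info('d', TensorProto.BOOL, [1])       # 1-D, exactly one element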
@@ -404,7 +404,7 @@ def string_equal(x, y):
PyCustomOpDef.dt_int64])
def string_split(input, delimiter, skip_empty):
if delimiter.shape != (1, ):
raise RuntimeError("demiliter must a single element tensor.")
raise RuntimeError("delimiter must a single element tensor.")
if skip_empty.shape != (1, ):
raise RuntimeError("skip_empty must a single element tensor.")
if len(input.shape) != 1:
@@ -914,32 +914,31 @@ def test_string_split_python(self):
self.assertIn('op_type: "PyStringSplit"', str(onnx_model))
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
for sep in [",", ":/", ",,"]:
- with self.subTest(sep=sep):
- input = np.array([f"a{sep}{sep}b", "", f"aa{sep}b{sep}c", "dddddd"])
- delimiter = np.array([sep])
-
- for skip in [True, False]:
- with self.subTest(skip=skip):
- skip_empty = np.array([skip])
-
- txout = sess.run(
- None, {'input': input, 'delimiter': delimiter,
- 'skip_empty': skip_empty})
-
- if skip_empty:
- exp_indices = np.array(
- [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]])
- exp_text = np.array(['a', 'b', 'aa', 'b', 'c', 'dddddd'])
- else:
- exp_indices = np.array(
- [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1],
- [2, 2], [3, 0]])
- exp_text = np.array(
- ['a', '', 'b', 'aa', 'b', 'c', 'dddddd'])
- exp_shape = np.array([4, 3])
- self.assertEqual(exp_indices.tolist(), txout[0].tolist())
- self.assertEqual(exp_text.tolist(), txout[1].tolist())
- self.assertEqual(exp_shape.tolist(), txout[2].tolist())
+ input = np.array([f"a{sep}{sep}b", "", f"aa{sep}b{sep}c", "dddddd"])
+ delimiter = np.array([sep])
+
+ for skip in [True, False]:
+ with self.subTest(skip=skip, sep=sep):
+ skip_empty = np.array([skip])
+
+ txout = sess.run(
+ None, {'input': input, 'delimiter': delimiter,
+ 'skip_empty': skip_empty})
+
+ if skip_empty:
+ exp_indices = np.array(
+ [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]])
+ exp_text = np.array(['a', 'b', 'aa', 'b', 'c', 'dddddd'])
+ else:
+ exp_indices = np.array(
+ [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1],
+ [2, 2], [3, 0]])
+ exp_text = np.array(
+ ['a', '', 'b', 'aa', 'b', 'c', 'dddddd'])
+ exp_shape = np.array([4, 3])
+ self.assertEqual(exp_indices.tolist(), txout[0].tolist())
+ self.assertEqual(exp_text.tolist(), txout[1].tolist())
+ self.assertEqual(exp_shape.tolist(), txout[2].tolist())

def test_string_split_cc(self):
so = _ort.SessionOptions()
@@ -948,48 +947,47 @@ def test_string_split_cc(self):
self.assertIn('op_type: "StringSplit"', str(onnx_model))
sess = _ort.InferenceSession(onnx_model.SerializeToString(), so, providers=['CPUExecutionProvider'])
for sep in [",", ":/", ",,"]:
- with self.subTest(sep=sep):
- input = np.array([f"a{sep}{sep}b", "", f"aa{sep}b{sep}c", "dddddd"])
- delimiter = np.array([sep])
-
- for skip in [True, False]:
- with self.subTest(skip=skip):
- skip_empty = np.array([skip])
-
- txout = sess.run(
- None, {'input': input, 'delimiter': delimiter,
- 'skip_empty': skip_empty})
-
- try:
- from tensorflow.raw_ops import StringSplit
- dotf = True
- except ImportError:
- dotf = False
- if dotf:
- tfres = StringSplit(
- input=input, delimiter=",,", skip_empty=skip)
- self.assertEqual(
- [_.decode() for _ in tfres[1].numpy().tolist()],
- txout[1].tolist())
- self.assertEqual(
- tfres[0].numpy().tolist(), txout[0].tolist())
- self.assertEqual(
- tfres[2].numpy().tolist(), txout[2].tolist())
-
- if skip_empty:
- exp_indices = np.array(
- [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]])
- exp_text = np.array(['a', 'b', 'aa', 'b', 'c', 'dddddd'])
- else:
- exp_indices = np.array(
- [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1],
- [2, 2], [3, 0]])
- exp_text = np.array(
- ['a', '', 'b', 'aa', 'b', 'c', 'dddddd'])
- exp_shape = np.array([4, 3])
- self.assertEqual(exp_indices.tolist(), txout[0].tolist())
- self.assertEqual(exp_text.tolist(), txout[1].tolist())
- self.assertEqual(exp_shape.tolist(), txout[2].tolist())
+ input = np.array([f"a{sep}{sep}b", "", f"aa{sep}b{sep}c", "dddddd"])
+ delimiter = np.array([sep])
+
+ for skip in [True, False]:
+ with self.subTest(skip=skip, sep=sep):
+ skip_empty = np.array([skip])
+
+ txout = sess.run(
+ None, {'input': input, 'delimiter': delimiter,
+ 'skip_empty': skip_empty})
+
+ try:
+ from tensorflow.raw_ops import StringSplit
+ dotf = True
+ except ImportError:
+ dotf = False
+ if dotf:
+ tfres = StringSplit(
+ input=input, delimiter=",,", skip_empty=skip)
+ self.assertEqual(
+ [_.decode() for _ in tfres[1].numpy().tolist()],
+ txout[1].tolist())
+ self.assertEqual(
+ tfres[0].numpy().tolist(), txout[0].tolist())
+ self.assertEqual(
+ tfres[2].numpy().tolist(), txout[2].tolist())
+
+ if skip_empty:
+ exp_indices = np.array(
+ [[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]])
+ exp_text = np.array(['a', 'b', 'aa', 'b', 'c', 'dddddd'])
+ else:
+ exp_indices = np.array(
+ [[0, 0], [0, 1], [0, 2], [2, 0], [2, 1],
+ [2, 2], [3, 0]])
+ exp_text = np.array(
+ ['a', '', 'b', 'aa', 'b', 'c', 'dddddd'])
+ exp_shape = np.array([4, 3])
+ self.assertEqual(exp_indices.tolist(), txout[0].tolist())
+ self.assertEqual(exp_text.tolist(), txout[1].tolist())
+ self.assertEqual(exp_shape.tolist(), txout[2].tolist())

def test_string_split_cc_sep2(self):
so = _ort.SessionOptions()
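Note (illustration, not part of this diff): judging from the expected values in these tests, StringSplit reports its result as a sparse matrix — indices holds [row, column] pairs, values holds the tokens, and shape is [number of input strings, maximum tokens per string], hence [4, 3] above for four input strings with at most three tokens each. A small snippet densifying the skip_empty case:

# Illustrative snippet only: densify the sparse (indices, values, shape)
# triple expected by the tests above.
import numpy as np

indices = np.array([[0, 0], [0, 1], [2, 0], [2, 1], [2, 2], [3, 0]])
values = np.array(['a', 'b', 'aa', 'b', 'c', 'dddddd'])
shape = np.array([4, 3])

dense = np.full(tuple(shape), '', dtype=object)
dense[indices[:, 0], indices[:, 1]] = values
# dense rows: ['a', 'b', ''], ['', '', ''], ['aa', 'b', 'c'], ['dddddd', '', '']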
