From fc0bd57b1b3c608c7ae8ae2e4dc0698b7cd9e455 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=8D=AB=E8=8B=8F?= Date: Tue, 10 Sep 2024 18:05:24 +0800 Subject: [PATCH] fix bug of SeqAugmentOps --- easy_rec/python/layers/keras/custom_ops.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/easy_rec/python/layers/keras/custom_ops.py b/easy_rec/python/layers/keras/custom_ops.py index 37367d1b8..f0b04b2ab 100644 --- a/easy_rec/python/layers/keras/custom_ops.py +++ b/easy_rec/python/layers/keras/custom_ops.py @@ -48,7 +48,10 @@ def __init__(self, params, name='sequence_aug', reuse=None, **kwargs): self.seq_augment = custom_ops.my_seq_augment def call(self, inputs, training=None, **kwargs): - assert isinstance(inputs, (list, tuple)) + assert isinstance( + inputs, + (list, tuple)), 'the inputs of SeqAugmentOps must be type of list/tuple' + assert len(inputs) >= 2, 'SeqAugmentOps must have at least 2 inputs' seq_input, seq_len = inputs[:2] embedding_dim = int(seq_input.shape[-1]) with tf.variable_scope(self.name, reuse=self.reuse):