From 74a7f106e2bbc811b7e3f124bbd854aa46ca18ad Mon Sep 17 00:00:00 2001 From: zhoushunjie Date: Sat, 9 Oct 2021 11:44:28 +0800 Subject: [PATCH] fix len=1 bug --- paddlenlp/layers/crf.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/paddlenlp/layers/crf.py b/paddlenlp/layers/crf.py index e90e2ed32a874e..4e3dd3e4f3fa93 100644 --- a/paddlenlp/layers/crf.py +++ b/paddlenlp/layers/crf.py @@ -424,6 +424,8 @@ def forward(self, inputs, lengths): # last_ids: batch_size scores, last_ids = alpha.max(1), alpha.argmax(1) + if max_seq_len == 1: + return scores, last_ids.unsqueeze(1) # Trace back the best path # historys: seq_len, batch_size, n_labels historys = paddle.stack(historys) @@ -438,10 +440,14 @@ def forward(self, inputs, lengths): # hist: batch_size, n_labels left_length = left_length + 1 gather_idx = batch_offset + last_ids - tag_mask = paddle.cast((left_length >= 0), 'int64') + tag_mask = paddle.cast((left_length > 0), 'int64') last_ids_update = paddle.gather(hist.flatten(), gather_idx) * tag_mask + zero_len_mask = paddle.cast((left_length == 0), 'int64') + last_ids_update = last_ids_update * (1 - zero_len_mask + ) + last_ids * zero_len_mask batch_path.append(last_ids_update) + tag_mask = paddle.cast((left_length >= 0), 'int64') last_ids = last_ids_update + last_ids * (1 - tag_mask) batch_path = paddle.reverse(paddle.stack(batch_path, 1), [1]) return scores, batch_path