diff --git a/README.md b/README.md index 45a75c6..ef1e0ba 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,9 @@ [![Documentation Status](https://readthedocs.org/projects/lm-lstm-crf/badge/?version=latest)](http://lm-lstm-crf.readthedocs.io/en/latest/?badge=latest) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) -This project provides high-performance character-aware sequence labeling tools and tutorials. Model details can be accessed [here](http://arxiv.org/abs/1709.04109), and the implementation is based on the PyTorch library. +This project provides high-performance character-aware sequence labeling tools, including [Training](#usage), [Evaluation](#evaluation) and [Prediction](#prediction). -LM-LSTM-CRF achieves F1 score of 91.71+/-0.10 on the CoNLL 2003 NER dataset, without using any additional corpus or resource. +Details about LM-LSTM-CRF can be accessed [here](http://arxiv.org/abs/1709.04109), and the implementation is based on the PyTorch library. Our model achieves F1 score of 91.71+/-0.10 on the CoNLL 2003 NER dataset, without using any additional corpus or resource. The documents would be available [here](http://lm-lstm-crf.readthedocs.io/en/latest/). @@ -202,6 +202,14 @@ to newcomers Uzbekistan . +``` +and the corresponding output is: + +``` +-DOCSTART- -DOCSTART- -DOCSTART- + +But China saw their luck desert them in the second match of the group , crashing to a surprise 2-0 defeat to newcomers Uzbekistan . + ``` ## Reference diff --git a/model/predictor.py b/model/predictor.py index 6deaae2..c89bc9b 100644 --- a/model/predictor.py +++ b/model/predictor.py @@ -114,7 +114,7 @@ def decode_s(self, feature, label): return chunks - def output_batch(self, ner_model, features, fout): + def output_batch(self, ner_model, documents, fout): """ decode the whole corpus in the specific format by calling apply_model to fit specific models @@ -123,18 +123,22 @@ def output_batch(self, ner_model, features, fout): feature (list): list of words list fout: output file """ - f_len = len(features) + d_len = len(documents) - for ind in tqdm( range(0, f_len, self.batch_size), mininterval=1, + for d_ind in tqdm( range(0, d_len), mininterval=1, desc=' - Process', leave=False, file=sys.stdout): - eind = min(f_len, ind + self.batch_size) - labels = self.apply_model(ner_model, features[ind: eind]) - labels = torch.unbind(labels, 1) - - for ind2 in range(ind, eind): - f = features[ind2] - l = labels[ind2 - ind][0: len(f) ] - fout.write(self.decode_str(features[ind2], l) + '\n\n') + fout.write('-DOCSTART- -DOCSTART- -DOCSTART-\n\n') + features = documents[d_ind] + f_len = len(features) + for ind in range(0, f_len, self.batch_size): + eind = min(f_len, ind + self.batch_size) + labels = self.apply_model(ner_model, features[ind: eind]) + labels = torch.unbind(labels, 1) + + for ind2 in range(ind, eind): + f = features[ind2] + l = labels[ind2 - ind][0: len(f) ] + fout.write(self.decode_str(features[ind2], l) + '\n\n') def apply_model(self, ner_model, features): """ diff --git a/model/utils.py b/model/utils.py index 89b2b79..bcf6563 100644 --- a/model/utils.py +++ b/model/utils.py @@ -239,23 +239,45 @@ def read_corpus(lines): return features, labels -def read_features(lines): +def read_features(lines, multi_docs = True): """ convert un-annotated corpus into features """ - features = list() - tmp_fl = list() - for line in lines: - if not (line.isspace() or (len(line) > 10 and line[0:10] == '-DOCSTART-')): - line = line.rstrip() - tmp_fl.append(line) - elif len(tmp_fl) > 0: + if multi_docs: + documents = list() + features = list() + tmp_fl = list() + for line in lines: + if_doc_end = (len(line) > 10 and line[0:10] == '-DOCSTART-') + if not (line.isspace() or if_doc_end): + line = line.rstrip() + tmp_fl.append(line) + else: + if len(tmp_fl) > 0: + features.append(tmp_fl) + tmp_fl = list() + if if_doc_end and len(features) > 0: + documents.append(features) + features = list() + if len(tmp_fl) > 0: features.append(tmp_fl) - tmp_fl = list() - if len(tmp_fl) > 0: - features.append(tmp_fl) - - return features + if len(features) >0: + documents.append(features) + return documents + else: + features = list() + tmp_fl = list() + for line in lines: + if not (line.isspace() or (len(line) > 10 and line[0:10] == '-DOCSTART-')): + line = line.rstrip() + tmp_fl.append(line) + elif len(tmp_fl) > 0: + features.append(tmp_fl) + tmp_fl = list() + if len(tmp_fl) > 0: + features.append(tmp_fl) + + return features def shrink_embedding(feature_map, word_dict, word_embedding, caseless): """