Commit 7ecc917 (initial commit, 0 parents): 11 changed files with 1,864 additions and 0 deletions.
README.md
@@ -0,0 +1,48 @@
# Medical-NER

## Introduction

A Chinese medical named entity recognition (NER) program built with `PyTorch`, based on `BERT+BiLSTM+CRF`.

## Project Structure

`data`: training data<br>
`config.py`: model parameters, training hyperparameters, file paths, and other configuration<br>
`dataset.py`: the dataset definition and data-processing functions<br>
`main.py`: the main entry point<br>
`model.py`: the model (BERT+BiLSTM+CRF)<br>
`preprocess.py`: converts the raw data to BIO tags<br>
`utils.py`: utility functions (training, validation, testing, inference, etc.)<br>
## Usage

1. **Install the dependencies**

   ```
   pip install -r requirements.txt
   ```

2. **Preprocess the raw dataset**

   ```
   python preprocess.py
   ```

3. **Train the model (test results included)**

   ```
   python main.py --mode='train'
   ```

4. **Run inference** (see the example below)

   ```
   python main.py --mode='infer' --ckpt_name="best" --txt="xxxxxxxxxx(Chinese input)"
   ```
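
For example, a hypothetical invocation (the input sentence here is made up; any Chinese clinical text can be passed as `--txt`):

```
python main.py --mode='infer' --ckpt_name="best" --txt="患者于3年前行胃切除术"
```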

## References

1. [Pytorch-crf doc](https://pytorch-crf.readthedocs.io/en/stable/)
config.py
@@ -0,0 +1,39 @@
import torch

# device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# dataset: fraction of the labeled training data used for the train split
# (the remainder becomes the validation split; see preprocess.py)
TRAINSET_RATIO = 0.2
# data paths
RAW_DATA_PATH = 'data/ccks2019/'
PROCESSED_DATA_PATH = 'data/processed_data/'
SAVED_MODEL_PATH = 'saved_model/'
# model parameters
BASE_MODEL = 'bert-base-chinese'
EMBEDDING_DIM = 768
HIDDEN_DIM = 256
# training parameters
BATCH_SIZE = 32
LR = 0.001
EPOCHS = 50

# tags & labels (the Chinese keys are the entity-type names used in the raw annotations)
label_dict = {'药物': 'DRUG',
              '解剖部位': 'BODY',
              '疾病和诊断': 'DISEASES',
              '影像检查': 'EXAMINATIONS',
              '实验室检验': 'TEST',
              '手术': 'TREATMENT'}
label_dict2 = {'DRUG': '药物',
               'BODY': '解剖部位',
               'DISEASES': '疾病和诊断',
               'EXAMINATIONS': '影像检查',
               'TEST': '实验室检验',
               'TREATMENT': '手术'}
model_tag = ('<PAD>', '[CLS]', '[SEP]', 'O', 'B-BODY', 'I-TEST', 'I-EXAMINATIONS',
             'I-TREATMENT', 'B-DRUG', 'B-TREATMENT', 'I-DISEASES', 'B-EXAMINATIONS',
             'I-BODY', 'B-TEST', 'B-DISEASES', 'I-DRUG')
tag2idx = {tag: idx for idx, tag in enumerate(model_tag)}
idx2tag = {idx: tag for idx, tag in enumerate(model_tag)}
LABELS = ['B-BODY', 'B-DISEASES', 'B-DRUG', 'B-EXAMINATIONS', 'B-TEST', 'B-TREATMENT',
          'I-BODY', 'I-DISEASES', 'I-DRUG', 'I-EXAMINATIONS', 'I-TEST', 'I-TREATMENT']
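
As a quick sanity check of the tag vocabulary above: the two lookup tables invert each other, and `<PAD>` sits at index 0, which `PadBatch` in dataset.py relies on for masking. A minimal sketch:

```python
from config import tag2idx, idx2tag

# <PAD> deliberately occupies index 0 (PadBatch pads with 0 and masks it out)
assert tag2idx['<PAD>'] == 0

# encode a BIO sequence to indices and decode it back
tags = ['[CLS]', 'B-BODY', 'I-BODY', 'O', '[SEP]']
ids = [tag2idx[t] for t in tags]
assert [idx2tag[i] for i in ids] == tags
print(ids)  # [1, 4, 12, 3, 2] given the tuple order above
```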
(Three more files in this commit have large diffs that are not rendered here.)
dataset.py
@@ -0,0 +1,52 @@
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer

from config import BASE_MODEL, tag2idx


class NerDataset(Dataset):
    # Reads sentences from the preprocessed file one at a time,
    # using the Chinese full stop '。' as the sentence delimiter.
    def __init__(self, file):
        self.sentences = []
        self.labels = []
        self.tokenizer = BertTokenizer.from_pretrained(BASE_MODEL)
        self.MAX_LEN = 256 - 2  # reserve two positions for [CLS] and [SEP]

        with open(file, 'r', encoding='utf-8') as f:
            lines = [line.split('\n')[0] for line in f.readlines() if len(line.strip()) != 0]
        word_from_file = [line.split('\t')[0] for line in lines]
        tag_from_file = [line.split('\t')[1] for line in lines]

        word, tag = [], []
        for char, t in zip(word_from_file, tag_from_file):
            if char != '。' and len(word) <= self.MAX_LEN:
                word.append(char)
                tag.append(t)
            else:
                # Sentence boundary reached (or length limit exceeded): emit a sample,
                # truncating to MAX_LEN where necessary.
                if len(word) > self.MAX_LEN:
                    self.sentences.append(['[CLS]'] + word[:self.MAX_LEN] + ['[SEP]'])
                    self.labels.append(['[CLS]'] + tag[:self.MAX_LEN] + ['[SEP]'])
                else:
                    self.sentences.append(['[CLS]'] + word + ['[SEP]'])
                    self.labels.append(['[CLS]'] + tag + ['[SEP]'])
                word, tag = [], []

    def __getitem__(self, idx):
        sentence, label = self.sentences[idx], self.labels[idx]
        sentence_ids = self.tokenizer.convert_tokens_to_ids(sentence)
        label_ids = [tag2idx[l] for l in label]
        seqlen = len(label_ids)
        return sentence_ids, label_ids, seqlen

    def __len__(self):
        return len(self.sentences)


def PadBatch(batch):
    # Pad every sequence in the batch to the length of the longest one.
    maxlen = max([i[2] for i in batch])
    token_tensors = torch.LongTensor([i[0] + [0] * (maxlen - len(i[0])) for i in batch])
    # Padding uses index 0; see config.py, where <PAD> maps to 0.
    label_tensors = torch.LongTensor([i[1] + [0] * (maxlen - len(i[1])) for i in batch])
    mask = (token_tensors > 0)
    return token_tensors, label_tensors, mask
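
A minimal smoke test of the data pipeline (a sketch: it assumes `preprocess.py` has already been run, so the processed files exist under `PROCESSED_DATA_PATH`):

```python
from torch.utils.data import DataLoader

from config import PROCESSED_DATA_PATH, BATCH_SIZE
from dataset import NerDataset, PadBatch

train_dataset = NerDataset(PROCESSED_DATA_PATH + 'train_data.txt')
loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=PadBatch)

tokens, labels, mask = next(iter(loader))
# tokens/labels: (batch_size, maxlen) LongTensors; mask: (batch_size, maxlen) bool
print(tokens.shape, labels.shape, mask.dtype)
```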
main.py
@@ -0,0 +1,89 @@
import argparse
import os

from torch.utils.data import DataLoader
from transformers import BertTokenizer, AdamW, get_linear_schedule_with_warmup

from dataset import NerDataset, PadBatch
from model import Bert_BiLSTM_CRF
from utils import *

if __name__ == "__main__":
    best_model = None
    _best_val_loss = float("inf")
    _best_val_acc = -float("inf")

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='train', type=str, required=False, help="The running mode: train or infer?")
    parser.add_argument('--ckpt_name', type=str, required=False,
                        help="The name of the trained checkpoint (without extension).")
    parser.add_argument('--txt', type=str, required=False)
    args = parser.parse_args()

    if args.mode == 'train':
        train_dataset = NerDataset(PROCESSED_DATA_PATH + 'train_data.txt')
        val_dataset = NerDataset(PROCESSED_DATA_PATH + 'val_data.txt')
        test_dataset = NerDataset(PROCESSED_DATA_PATH + 'test_data.txt')
        print('Load Data Done.')

        train_iter = DataLoader(dataset=train_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=True,
                                collate_fn=PadBatch,
                                pin_memory=True)

        eval_iter = DataLoader(dataset=val_dataset,
                               batch_size=BATCH_SIZE,
                               shuffle=False,
                               collate_fn=PadBatch,
                               pin_memory=True)

        test_iter = DataLoader(dataset=test_dataset,
                               batch_size=BATCH_SIZE,
                               shuffle=False,
                               collate_fn=PadBatch,
                               pin_memory=True)

        model = Bert_BiLSTM_CRF(tag2idx).to(DEVICE)
        optimizer = AdamW(model.parameters(), lr=LR, eps=1e-6)
        # Linear warmup over the first 10% of the steps, then linear decay.
        len_dataset = len(train_dataset)
        total_steps = (len_dataset // BATCH_SIZE) * EPOCHS if len_dataset % BATCH_SIZE == 0 else (
            len_dataset // BATCH_SIZE + 1) * EPOCHS
        warm_up_ratio = 0.1
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warm_up_ratio * total_steps,
                                                    num_training_steps=total_steps)

        print('Train Start ...')
        for epoch in range(1, EPOCHS + 1):
            train(epoch, model, train_iter, optimizer, scheduler, DEVICE)
            if epoch % 5 == 0:
                print('valid-->', end='')
                candidate_model, loss, acc = validate(epoch, model, eval_iter, DEVICE)
                # Keep a candidate only if it improves on both validation loss and accuracy,
                # and report its test-set results.
                if loss < _best_val_loss and acc > _best_val_acc:
                    best_model = candidate_model
                    _best_val_loss = loss
                    _best_val_acc = acc
                    y_test, y_pred = test(best_model, test_iter, DEVICE)
        if not os.path.exists(SAVED_MODEL_PATH):
            os.makedirs(SAVED_MODEL_PATH)
        torch.save({'model': best_model.state_dict()}, SAVED_MODEL_PATH + 'best.ckpt')
        print('Train End ... Model saved')

    elif args.mode == 'infer':
        print('Start infer')
        model = Bert_BiLSTM_CRF(tag2idx).to(DEVICE)
        tokenizer = BertTokenizer.from_pretrained(BASE_MODEL)
        if args.ckpt_name is not None:
            if os.path.exists(f"{SAVED_MODEL_PATH}{args.ckpt_name}.ckpt"):
                print("Loading the pre-trained checkpoint...")
                ckpt = torch.load(f"{SAVED_MODEL_PATH}{args.ckpt_name}.ckpt", map_location=DEVICE)
                model.load_state_dict(ckpt['model'])
                sentence = ['[CLS]'] + list(args.txt) + ['[SEP]']
                infer(model, tokenizer, sentence)
            else:
                print("No such file!")
                exit()
    else:
        print("mode type error!")
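
The diff for `utils.py`, which defines the `train`, `validate`, `test`, and `infer` used above, is among the files not rendered in this commit view. Purely as a hypothetical sketch inferred from the call sites in main.py, `train` and `validate` might look like the following; the actual implementation may differ (metrics, logging, per-label F1, etc.):

```python
import torch


def train(epoch, model, train_iter, optimizer, scheduler, device):
    # One epoch of CRF training: model(...) returns the negative log-likelihood.
    model.train()
    total_loss = 0.0
    for tokens, labels, mask in train_iter:
        tokens, labels, mask = tokens.to(device), labels.to(device), mask.to(device)
        loss = model(tokens, labels, mask)  # is_test=False -> scalar loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
    print(f"Epoch {epoch}, train loss: {total_loss / len(train_iter):.4f}")


def validate(epoch, model, eval_iter, device):
    # Returns (model, mean loss, token-level accuracy), matching the unpacking
    # `candidate_model, loss, acc = validate(...)` in main.py.
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for tokens, labels, mask in eval_iter:
            tokens, labels, mask = tokens.to(device), labels.to(device), mask.to(device)
            total_loss += model(tokens, labels, mask).item()
            preds = model(tokens, labels, mask, is_test=True)  # list of tag-index lists
            for pred, lab, m in zip(preds, labels, mask):
                gold = lab[m].tolist()  # gold tags at unmasked positions
                correct += sum(p == g for p, g in zip(pred, gold))
                total += len(gold)
    return model, total_loss / len(eval_iter), correct / total
```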
model.py
@@ -0,0 +1,51 @@
import torch
import torch.nn as nn
from torchcrf import CRF
from transformers import BertModel

from config import EMBEDDING_DIM, HIDDEN_DIM


class Bert_BiLSTM_CRF(nn.Module):

    def __init__(self, tag2idx, embedding_dim=EMBEDDING_DIM, hidden_dim=HIDDEN_DIM):
        super(Bert_BiLSTM_CRF, self).__init__()
        self.tag_to_ix = tag2idx
        self.tagset_size = len(tag2idx)
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        self.bert = BertModel.from_pretrained('bert-base-chinese')
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim // 2,
                            num_layers=2, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(p=0.1)
        self.linear = nn.Linear(hidden_dim, self.tagset_size)
        self.crf = CRF(self.tagset_size, batch_first=True)

    def getfeature(self, sentence):
        # BERT is used as a frozen feature extractor: no gradients flow into it.
        with torch.no_grad():
            # With return_dict=False, BERT returns (last_hidden_state, pooler_output).
            # last_hidden_state: one contextual vector per position,
            #   shape (batch_size, sequence_length, hidden_size).
            # pooler_output: the [CLS] vector passed through a dense layer and tanh,
            #   useful for downstream classification tasks.
            embeds, _ = self.bert(sentence, return_dict=False)
        # The LSTM returns (output, (h, c)).
        # output: (batch_size, seq_len, hidden_size * 2) when bidirectional;
        #   here hidden_size = hidden_dim // 2, so the last dimension is hidden_dim.
        # h, c: (num_layers * 2, batch_size, hidden_size) when bidirectional;
        #   h is the hidden state of the last time step, c the cell state.
        out, _ = self.lstm(embeds)
        out = self.dropout(out)
        feats = self.linear(out)
        return feats

    def forward(self, sentence, tags, mask, is_test=False):
        feature = self.getfeature(sentence)
        # training: the CRF returns the log-likelihood; negate it to use as the loss
        if not is_test:
            loss = -self.crf.forward(feature, tags, mask, reduction='mean')
            return loss
        # testing: Viterbi decoding of the best tag sequence
        else:
            decode = self.crf.decode(feature, mask)
            return decode
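
A quick shape check for the two forward paths (a sketch: it downloads `bert-base-chinese` on first run, and the token and tag ids are arbitrary dummies):

```python
import torch

from config import tag2idx, DEVICE
from model import Bert_BiLSTM_CRF

model = Bert_BiLSTM_CRF(tag2idx).to(DEVICE)
tokens = torch.randint(100, 1000, (2, 16)).to(DEVICE)  # dummy batch: 2 sequences of 16 ids
tags = torch.randint(3, 16, (2, 16)).to(DEVICE)        # dummy tag indices in [3, 16)
mask = torch.ones(2, 16, dtype=torch.bool).to(DEVICE)

loss = model(tokens, tags, mask)                  # training path: scalar CRF loss
paths = model(tokens, tags, mask, is_test=True)   # decode path: 2 tag-index sequences
print(loss.item(), len(paths), len(paths[0]))
```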
preprocess.py
@@ -0,0 +1,79 @@
import json
import os

from config import RAW_DATA_PATH, PROCESSED_DATA_PATH, TRAINSET_RATIO, label_dict


# Read the raw txt/json data (one JSON object per line) into a list.
def read_data(file, enc):
    data_list = []
    with open(file, 'r', encoding=enc) as f:
        while True:
            data = f.readline()
            if not data:
                break
            data = json.loads(data)
            sentence = data['originalText']
            entities = data['entities']
            data_list.append([sentence, entities])
    return data_list


def BIOtag(sentence, entities):
    label = ['O'] * len(sentence)
    for entity in entities:
        start_idx = entity['start_pos']
        end_idx = entity['end_pos']
        type_cn = entity['label_type']
        entity_type = label_dict[type_cn]
        # Assign BIO tags to the entity span.
        label[start_idx] = 'B-' + entity_type
        for i in range(start_idx + 1, end_idx):
            label[i] = 'I-' + entity_type
    return label


def process(raw_data):
    processed_data = []
    for data in raw_data:
        sentence = data[0]
        entities = data[1]
        label = BIOtag(sentence, entities)
        assert len(sentence) == len(label)
        processed_data.append([list(sentence), label])
    return processed_data


def savefile(file, datas):
    # Write one "character<TAB>tag" pair per line (utf-8, matching NerDataset's reader).
    with open(file, 'w', encoding='utf-8') as f:
        for data in datas:
            size = len(data[0])
            for i in range(size):
                f.write(data[0][i])
                f.write('\t')
                f.write(data[1][i])
                f.write('\n')


if __name__ == '__main__':
    # raw data (from txt/json)
    raw_data_train_part1 = read_data(RAW_DATA_PATH + 'subtask1_training_part1.txt', 'utf-8-sig')
    raw_data_train_part2 = read_data(RAW_DATA_PATH + 'subtask1_training_part2.txt', 'utf-8-sig')
    raw_data_train = raw_data_train_part1 + raw_data_train_part2
    raw_data_test = read_data(RAW_DATA_PATH + 'subtask1_test_set_with_answer.json', 'utf-8')
    # processed data (converted to BIO tags)
    data_train = process(raw_data_train)
    data_test = process(raw_data_test)
    # split: the first TRAINSET_RATIO of the training file becomes the train set,
    # the remainder the validation set
    num = len(data_train)
    train_data = data_train[:int(num * TRAINSET_RATIO)]
    val_data = data_train[int(num * TRAINSET_RATIO):]
    test_data = data_test
    # save
    if not os.path.exists(PROCESSED_DATA_PATH):
        os.makedirs(PROCESSED_DATA_PATH)
    savefile(PROCESSED_DATA_PATH + "train_data.txt", train_data)
    savefile(PROCESSED_DATA_PATH + "val_data.txt", val_data)
    savefile(PROCESSED_DATA_PATH + "test_data.txt", test_data)

    print("preprocess done!")
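
To illustrate the BIO scheme, here is a made-up record in the same shape as the fields `read_data` extracts (the sentence and entity spans are hypothetical, not taken from the dataset):

```python
from preprocess import BIOtag

sentence = '患者行腹部CT检查。'
entities = [{'start_pos': 3, 'end_pos': 5, 'label_type': '解剖部位'},
            {'start_pos': 5, 'end_pos': 7, 'label_type': '影像检查'}]
print(list(zip(sentence, BIOtag(sentence, entities))))
# [('患', 'O'), ('者', 'O'), ('行', 'O'), ('腹', 'B-BODY'), ('部', 'I-BODY'),
#  ('C', 'B-EXAMINATIONS'), ('T', 'I-EXAMINATIONS'), ('检', 'O'), ('查', 'O'), ('。', 'O')]
```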
requirements.txt
@@ -0,0 +1,5 @@
numpy==1.19.5
torch==1.11.0
transformers==4.12.5
pytorch-crf==0.7.2
scikit-learn==0.23.2