Skip to content

Commit

Permalink
add ner models
Browse files Browse the repository at this point in the history
  • Loading branch information
loujie0822 committed Apr 29, 2020
1 parent cf107e0 commit f9bd431
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 1 deletion.
56 changes: 56 additions & 0 deletions utils/file_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import pickle
import sys

# Module-level logger shared by the helpers below (used by _read_conll to
# report dropped samples).
# NOTE(review): calling logging.basicConfig at import time in a utility
# module configures the process-wide root logger as a side effect — confirm
# this is intended rather than leaving configuration to the application.
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')


def pickle_dump_large_file(obj, filepath):
max_bytes = 2 ** 31 - 1
Expand Down Expand Up @@ -44,3 +47,56 @@ def write_json(obj, path):
with open(path, 'wb') as f:
f.write(json.dumps(obj, indent=2, ensure_ascii=False).
encode('utf-8'))


def _read_conll(path, encoding='utf-8', indexes=2, dropna=True):
"""
Construct a generator to read conll items.
:param path: file path
:param encoding: file's encoding, default: utf-8
:param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None
:param dropna: weather to ignore and drop invalid data,
:if False, raise ValueError when reading invalid data. default: True
:return: generator, every time yield (line number, conll item)
"""

def parse_conll(sample):
sample = list(map(list, zip(*sample)))
sample = [sample[i] for i in range(indexes)]
for f in sample:
if len(f) <= 0:
raise ValueError('empty field')
return sample

with open(path, 'r', encoding=encoding) as f:
sample = []
start = next(f).strip()
if start != '':
sample.append(start.split())
for line_idx, line in enumerate(f, 1):
line = line[:-1]
if line == '':
if len(sample):
try:
res = parse_conll(sample)
sample = []
yield line_idx, res
except Exception as e:
if dropna:
logger.warning('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
continue
raise ValueError('Invalid instance which ends at line: {}'.format(line_idx))
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
# if len(sample) > 0:
# try:
# res = parse_conll(sample)
# yield line_idx, res
# except Exception as e:
# if dropna:
# return
# logger.error('invalid instance ends at line: {}'.format(line_idx))
# raise e
151 changes: 151 additions & 0 deletions utils/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from abc import abstractmethod
from collections import defaultdict


class MetricBase(object):
    """Base class for evaluation metrics.

    Subclasses implement evaluate() to accumulate batch statistics and
    get_metric() to produce the final scores.
    """

    @abstractmethod
    def evaluate(self, *args, **kwargs):
        """Accumulate statistics for one batch of predictions."""
        raise NotImplementedError

    @abstractmethod
    def get_metric(self, reset=True):
        """Compute and return the final metric values.

        BUG FIX: the original raised ``NotImplemented`` (a constant, not an
        exception class), which makes the raise itself fail with TypeError.
        """
        raise NotImplementedError

    def __call__(self, p_ids, pred, eval_file):
        # Let the metric object be used directly as a callable.
        return self.evaluate(p_ids, pred, eval_file)


def _bmeso_tag_to_spans(tags, ignore_labels=None):
"""
给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。
返回[('singer', (1, 4))] (左闭右开区间)
:param tags: List[str],
:param ignore_labels: List[str], 在该list中的label将被忽略
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
"""
ignore_labels = set(ignore_labels) if ignore_labels else set()

spans = []
prev_bmes_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bmes_tag, label = tag[:1], tag[2:]
if bmes_tag in ('b', 's'):
spans.append((label, [idx, idx]))
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
spans[-1][1][1] = idx
elif bmes_tag == 'o':
pass
else:
spans.append((label, [idx, idx]))
prev_bmes_tag = bmes_tag
return [(span[0], (span[1][0], span[1][1] + 1))
for span in spans
if span[0] not in ignore_labels
]


class SpanFPreRecMetric(MetricBase):
    """
    Span-level F / precision / recall metric for sequence labeling.

    Predicted and gold tag sequences are decoded into (label, span) pairs
    and compared exactly. evaluate() accumulates TP/FP/FN counts per label;
    get_metric() summarizes them (micro or macro averaged).
    """

    def __init__(self, tag_type, pred=None, target=None, encoding_type='bmeso',
                 only_gross=True, f_type='micro', beta=1):
        # tag_type: mapping from tag id -> tag string (indexed in evaluate).
        self.tag_type = tag_type
        self.only_gross = only_gross  # if False, also report per-label scores
        self.f_type = f_type          # 'micro' or 'macro'
        self.beta = beta
        self.beta_square = self.beta ** 2
        self.encoding_type = encoding_type
        if self.encoding_type == 'bmeso':
            self.tag_to_span_func = _bmeso_tag_to_spans

        self._true_positives = defaultdict(int)
        self._false_positives = defaultdict(int)
        self._false_negatives = defaultdict(int)

    def evaluate(self, p_ids, preds, eval_file):
        """
        Accumulate span-level TP/FP/FN counts for one batch.

        :param p_ids: tensor-like of example ids (must support .tolist())
        :param preds: tensor-like of predicted tag-id sequences
        :param eval_file: mapping id -> example; each example exposes
            gold_answer — presumably a gold tag-string sequence, confirm caller
        :return: dict mapping str(p_id) -> [pred_spans, gold_spans]
        """
        answer_dict = {}
        for p_id, pred in zip(p_ids.tolist(), preds.tolist()):
            gold_ = eval_file[p_id].gold_answer
            pred_ = [self.tag_type[tag] for tag in pred]
            pred_spans = self.tag_to_span_func(pred_)
            gold_spans = self.tag_to_span_func(gold_)
            answer_dict[str(p_id)] = [pred_spans, gold_spans]
            # Each predicted span consumes at most one matching gold span;
            # whatever remains in gold_spans counts as false negatives.
            for span in pred_spans:
                if span in gold_spans:
                    self._true_positives[span[0]] += 1
                    gold_spans.remove(span)
                else:
                    self._false_positives[span[0]] += 1
            for span in gold_spans:
                self._false_negatives[span[0]] += 1
        return answer_dict

    def get_metric(self, reset=True):
        """Compute the final scores from the statistics accumulated by evaluate()."""
        evaluate_result = {}
        if not self.only_gross or self.f_type == 'macro':
            tags = set(self._false_negatives.keys())
            tags.update(set(self._false_positives.keys()))
            tags.update(set(self._true_positives.keys()))
            f_sum = 0
            pre_sum = 0
            rec_sum = 0
            for tag in tags:
                tp = self._true_positives[tag]
                fn = self._false_negatives[tag]
                fp = self._false_positives[tag]
                # BUG FIX: _compute_f_pre_rec had been changed to return six
                # values while this call still unpacked three, so macro mode
                # crashed with ValueError; it returns (f, pre, rec) again.
                f, pre, rec = self._compute_f_pre_rec(tp, fn, fp)
                f_sum += f
                pre_sum += pre
                rec_sum += rec
                if not self.only_gross and tag != '':  # tag != '' guards the no-label case
                    evaluate_result['f-{}'.format(tag)] = f
                    evaluate_result['pre-{}'.format(tag)] = pre
                    evaluate_result['rec-{}'.format(tag)] = rec

            if self.f_type == 'macro':
                # Guard against division by zero when nothing was evaluated.
                n_tags = len(tags) or 1
                evaluate_result['f'] = f_sum / n_tags
                evaluate_result['pre'] = pre_sum / n_tags
                evaluate_result['rec'] = rec_sum / n_tags

        if self.f_type == 'micro':
            tp = sum(self._true_positives.values())
            fn = sum(self._false_negatives.values())
            fp = sum(self._false_positives.values())
            f, pre, rec = self._compute_f_pre_rec(tp, fn, fp)
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec
            evaluate_result['em'] = tp  # exact-match span count
            # BUG FIX: the original stored the predicted-span count under
            # 'pre', clobbering the precision set just above; use its own key.
            evaluate_result['pre_num'] = tp + fp
            evaluate_result['gold'] = tp + fn

        if reset:
            self._true_positives = defaultdict(int)
            self._false_positives = defaultdict(int)
            self._false_negatives = defaultdict(int)

        for key, value in evaluate_result.items():
            evaluate_result[key] = round(value, 6)
        print(evaluate_result)
        return evaluate_result

    def _compute_f_pre_rec(self, tp, fn, fp):
        """
        Compute F-beta, precision and recall from raw counts.

        :param tp: int, true positive
        :param fn: int, false negative
        :param fp: int, false positive
        :return: (f, pre, rec)
        """
        # The 1e-13 terms avoid division by zero when a count class is empty.
        pre = tp / (fp + tp + 1e-13)
        rec = tp / (fn + tp + 1e-13)
        f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13)
        return f, pre, rec
Empty file added utils/ner_loader.py
Empty file.
2 changes: 1 addition & 1 deletion utils/optimizer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def set_optimizer(args, model, train_steps=None):
if args.use_bert:
if args.warm_up:
print('using BertAdam')
param_optimizer = list(model.named_parameters())
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
Expand Down

0 comments on commit f9bd431

Please sign in to comment.