From 8e8680e2fe5887ab98690b4ca112a4d515d28337 Mon Sep 17 00:00:00 2001
From: Joel Nothman
Date: Mon, 12 May 2014 23:21:40 +1000
Subject: [PATCH 1/5] Disconnected implementation of CEAF metric

---
 conll03_nel_eval/coref_metrics.py |  72 ++++++++
 conll03_nel_eval/munkres.py       | 280 ++++++++++++++++++++++++++++++
 2 files changed, 352 insertions(+)
 create mode 100644 conll03_nel_eval/coref_metrics.py
 create mode 100644 conll03_nel_eval/munkres.py

diff --git a/conll03_nel_eval/coref_metrics.py b/conll03_nel_eval/coref_metrics.py
new file mode 100644
index 0000000..47611f8
--- /dev/null
+++ b/conll03_nel_eval/coref_metrics.py
@@ -0,0 +1,72 @@
+from __future__ import division
+
+from functools import partial
+
+import numpy as np
+
+from .munkres import linear_assignment
+
+# Clusters are most readily evaluated when represented as a
+# mapping from cluster ID to set of mention IDs. Could equally
+# use a set of frozensets, but lose debugging info.
+
+
+def cluster_sim_f1(a, b):
+    """
+
+    "Entity-based" measure in CoNLL; #4 in CEAF paper
+    """
+    if a and b:
+        return len(a & b) / (len(a) + len(b))
+    return 0.
+
+
+def cluster_sim_overlap(a, b):
+    """Intersection of sets
+
+    "Mention-based" measure in CoNLL; #3 in CEAF paper
+    """
+    return len(a & b)
+
+
+def ceaf(true, pred, similarity=cluster_sim_f1):
+    """
+
+    >>> true = {'A': {1,2,3,4,5}, 'B': {6,7}, 'C': {8, 9, 10, 11, 12}}
+    >>> pred_a = {'A': {1,2,3,4,5}, 'B': {6,7, 8, 9, 10, 11, 12}}
+    >>> pred_b = {'A': {1,2,3,4,5,8, 9, 10, 11, 12}, 'B': {6,7}}
+    >>> pred_c = {'A': {1,2,3,4,5, 6,7, 8, 9, 10, 11, 12}}
+    >>> pred_d = {i: {i,} for i in range(1, 13)}
+    >>> mention_ceaf(true, pred_a)[-1]  # doctest: +ELLIPSIS
+    0.83...
+    >>> entity_ceaf(true, pred_a)[-1]  # doctest: +ELLIPSIS
+    0.73...
+    >>> mention_ceaf(true, pred_b)[-1]  # doctest: +ELLIPSIS
+    0.58...
+    >>> entity_ceaf(true, pred_b)[-1]  # doctest: +ELLIPSIS
+    0.66...
+    >>> mention_ceaf(true, pred_c)  # doctest: +ELLIPSIS
+    (0.416..., 0.416..., 0.416...)
+    >>> entity_ceaf(true, pred_c)  # doctest: +ELLIPSIS
+    (0.588..., 0.196..., 0.294...)
+    >>> mention_ceaf(true, pred_d)  # doctest: +ELLIPSIS
+    (0.25, 0.25, 0.25)
+    >>> entity_ceaf(true, pred_d)  # doctest: +ELLIPSIS
+    (0.111..., 0.444..., 0.177...)
+    """
+    X = np.empty((len(true), len(pred)))
+    pred = list(pred.values())
+    for R, Xrow in zip(true.values(), X):
+        Xrow[:] = [similarity(R, S) for S in pred]
+    indices = linear_assignment(-X)
+
+    numerator = sum(X[indices[:, 0], indices[:, 1]])
+    true_denom = sum(similarity(R, R) for R in true.values())
+    pred_denom = sum(similarity(S, S) for S in pred)
+    p = numerator / pred_denom
+    r = numerator / true_denom
+    return p, r, 2 * p * r / (p + r)
+
+
+entity_ceaf = partial(ceaf, similarity=cluster_sim_f1)
+mention_ceaf = partial(ceaf, similarity=cluster_sim_overlap)
diff --git a/conll03_nel_eval/munkres.py b/conll03_nel_eval/munkres.py
new file mode 100644
index 0000000..20c2aeb
--- /dev/null
+++ b/conll03_nel_eval/munkres.py
@@ -0,0 +1,280 @@
+"""
+Solve the unique lowest-cost assignment problem using the
+Hungarian algorithm (also known as the Munkres algorithm).
+
+"""
+# From scikit-learn:sklearn/utils/linear_assignment_.py
+# Based on original code by Brian Clapper, adapted to NumPy by Gael Varoquaux.
+# Heavily refactored by Lars Buitinck.
+
+# Copyright (c) 2008 Brian M. Clapper, Gael Varoquaux
+# Author: Brian M. Clapper, Gael Varoquaux
+# LICENSE: BSD
+
+import numpy as np
+
+
+def linear_assignment(X):
+    """Solve the linear assignment problem using the Hungarian algorithm.
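+
+    For example (an editorial illustration, not in the original scikit-learn
+    docstring):
+
+    >>> linear_assignment(np.array([[4., 1.], [2., 8.]]))
+    array([[0, 1],
+           [1, 0]])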
+
+    The problem is also known as minimum weight matching in bipartite graphs.
+    The method is also known as the Munkres or Kuhn-Munkres algorithm.
+
+    Parameters
+    ----------
+    X : array
+        The cost matrix of the bipartite graph
+
+    Returns
+    -------
+    indices : array
+        The pairs of (row, col) indices in the original array giving
+        the original ordering.
+
+    References
+    ----------
+
+    1. http://www.public.iastate.edu/~ddoty/HungarianAlgorithm.html
+
+    2. Harold W. Kuhn. The Hungarian Method for the assignment problem.
+       *Naval Research Logistics Quarterly*, 2:83-97, 1955.
+
+    3. Harold W. Kuhn. Variants of the Hungarian method for assignment
+       problems. *Naval Research Logistics Quarterly*, 3:253-258, 1956.
+
+    4. Munkres, J. Algorithms for the Assignment and Transportation Problems.
+       *Journal of the Society of Industrial and Applied Mathematics*,
+       5(1):32-38, March, 1957.
+
+    5. http://en.wikipedia.org/wiki/Hungarian_algorithm
+    """
+    indices = _hungarian(X).tolist()
+    indices.sort()
+    # Re-force dtype to ints in case of empty list
+    indices = np.array(indices, dtype=int)
+    # Make sure the array is 2D with 2 columns.
+    # This is needed when dealing with an empty list
+    indices.shape = (-1, 2)
+    return indices
+
+
+class _HungarianState(object):
+    """State of one execution of the Hungarian algorithm.
+
+    Parameters
+    ----------
+    cost_matrix : 2D matrix
+        The cost matrix. Does not have to be square.
+    """
+
+    def __init__(self, cost_matrix):
+        cost_matrix = np.atleast_2d(cost_matrix)
+
+        # If there are more rows (n) than columns (m), then the algorithm
+        # will not be able to work correctly. Therefore, we
+        # transpose the cost function when needed. Just have to
+        # remember to swap the result columns back later.
+        transposed = (cost_matrix.shape[1] < cost_matrix.shape[0])
+        if transposed:
+            self.C = (cost_matrix.T).copy()
+        else:
+            self.C = cost_matrix.copy()
+        self.transposed = transposed
+
+        # At this point, m >= n.
+        n, m = self.C.shape
+        self.row_uncovered = np.ones(n, dtype=np.bool)
+        self.col_uncovered = np.ones(m, dtype=np.bool)
+        self.Z0_r = 0
+        self.Z0_c = 0
+        self.path = np.zeros((n + m, 2), dtype=int)
+        self.marked = np.zeros((n, m), dtype=int)
+
+    def _find_prime_in_row(self, row):
+        """
+        Find the first primed element in the specified row. Returns
+        the column index, or -1 if no primed element was found.
+        """
+        col = np.argmax(self.marked[row] == 2)
+        if self.marked[row, col] != 2:
+            col = -1
+        return col
+
+    def _clear_covers(self):
+        """Clear all covered matrix cells"""
+        self.row_uncovered[:] = True
+        self.col_uncovered[:] = True
+
+
+def _hungarian(cost_matrix):
+    """The Hungarian algorithm.
+
+    Calculate the Munkres solution to the classical assignment problem and
+    return the indices for the lowest-cost pairings.
+
+    Parameters
+    ----------
+    cost_matrix : 2D matrix
+        The cost matrix. Does not have to be square.
+
+    Returns
+    -------
+    indices : 2D array of indices
+        The pairs of (row, col) indices in the original array giving
+        the original ordering.
+    """
+    state = _HungarianState(cost_matrix)
+
+    # No need to bother with assignments if one of the dimensions
+    # of the cost matrix is zero-length.
+    step = None if 0 in cost_matrix.shape else _step1
+
+    while step is not None:
+        step = step(state)
+
+    # Look for the starred zeros
+    results = np.array(np.where(state.marked == 1)).T
+
+    # We need to swap the columns because we originally
+    # did a transpose on the input cost matrix.
+    if state.transposed:
+        results = results[:, ::-1]
+
+    return results
+
+
+# Individual steps of the algorithm follow, as a state machine: they return
+# the next step to be taken (function to be called), if any.
+
+def _step1(state):
+    """Steps 1 and 2 in the Wikipedia page."""
+
+    # Step1: For each row of the matrix, find the smallest element and
+    # subtract it from every element in its row.
+    state.C -= state.C.min(axis=1)[:, np.newaxis]
+    # Step2: Find a zero (Z) in the resulting matrix. If there is no
+    # starred zero in its row or column, star Z. Repeat for each element
+    # in the matrix.
+    for i, j in zip(*np.where(state.C == 0)):
+        if state.col_uncovered[j] and state.row_uncovered[i]:
+            state.marked[i, j] = 1
+            state.col_uncovered[j] = False
+            state.row_uncovered[i] = False
+
+    state._clear_covers()
+    return _step3
+
+
+def _step3(state):
+    """
+    Cover each column containing a starred zero. If n columns are covered,
+    the starred zeros describe a complete set of unique assignments.
+    If so, go to DONE; otherwise, go to Step 4.
+    """
+    marked = (state.marked == 1)
+    state.col_uncovered[np.any(marked, axis=0)] = False
+
+    if marked.sum() < state.C.shape[0]:
+        return _step4
+
+
+def _step4(state):
+    """
+    Find a noncovered zero and prime it. If there is no starred zero
+    in the row containing this primed zero, Go to Step 5. Otherwise,
+    cover this row and uncover the column containing the starred
+    zero. Continue in this manner until there are no uncovered zeros
+    left. Save the smallest uncovered value and Go to Step 6.
+    """
+    # We convert to int as numpy operations are faster on int
+    C = (state.C == 0).astype(np.int)
+    covered_C = C * state.row_uncovered[:, np.newaxis]
+    covered_C *= state.col_uncovered.astype(np.int)
+    n = state.C.shape[0]
+    m = state.C.shape[1]
+    while True:
+        # Find an uncovered zero
+        row, col = np.unravel_index(np.argmax(covered_C), (n, m))
+        if covered_C[row, col] == 0:
+            return _step6
+        else:
+            state.marked[row, col] = 2
+            # Find the first starred element in the row
+            star_col = np.argmax(state.marked[row] == 1)
+            if not state.marked[row, star_col] == 1:
+                # Could not find one
+                state.Z0_r = row
+                state.Z0_c = col
+                return _step5
+            else:
+                col = star_col
+                state.row_uncovered[row] = False
+                state.col_uncovered[col] = True
+                covered_C[:, col] = C[:, col] * (
+                    state.row_uncovered.astype(np.int))
+                covered_C[row] = 0
+
+
+def _step5(state):
+    """
+    Construct a series of alternating primed and starred zeros as follows.
+    Let Z0 represent the uncovered primed zero found in Step 4.
+    Let Z1 denote the starred zero in the column of Z0 (if any).
+    Let Z2 denote the primed zero in the row of Z1 (there will always be one).
+    Continue until the series terminates at a primed zero that has no starred
+    zero in its column. Unstar each starred zero of the series, star each
+    primed zero of the series, erase all primes and uncover every line in the
+    matrix. Return to Step 3.
+    """
+    count = 0
+    path = state.path
+    path[count, 0] = state.Z0_r
+    path[count, 1] = state.Z0_c
+
+    while True:
+        # Find the first starred element in the col defined by
+        # the path.
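+        # (np.argmax returns 0 when no entry matches, so the hit is
+        # verified below before being treated as a real starred zero.)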
+        row = np.argmax(state.marked[:, path[count, 1]] == 1)
+        if not state.marked[row, path[count, 1]] == 1:
+            # Could not find one
+            break
+        else:
+            count += 1
+            path[count, 0] = row
+            path[count, 1] = path[count - 1, 1]
+
+        # Find the first prime element in the row defined by the
+        # first path step
+        col = np.argmax(state.marked[path[count, 0]] == 2)
+        if state.marked[row, col] != 2:
+            col = -1
+        count += 1
+        path[count, 0] = path[count - 1, 0]
+        path[count, 1] = col
+
+    # Convert paths
+    for i in range(count + 1):
+        if state.marked[path[i, 0], path[i, 1]] == 1:
+            state.marked[path[i, 0], path[i, 1]] = 0
+        else:
+            state.marked[path[i, 0], path[i, 1]] = 1
+
+    state._clear_covers()
+    # Erase all prime markings
+    state.marked[state.marked == 2] = 0
+    return _step3
+
+
+def _step6(state):
+    """
+    Add the value found in Step 4 to every element of each covered row,
+    and subtract it from every element of each uncovered column.
+    Return to Step 4 without altering any stars, primes, or covered lines.
+    """
+    # the smallest uncovered value in the matrix
+    if np.any(state.row_uncovered) and np.any(state.col_uncovered):
+        minval = np.min(state.C[state.row_uncovered], axis=0)
+        minval = np.min(minval[state.col_uncovered])
+        state.C[np.logical_not(state.row_uncovered)] += minval
+        state.C[:, state.col_uncovered] -= minval
+    return _step4

From b9fb8aa984fc088333491330e6043e9ca8ed45ce Mon Sep 17 00:00:00 2001
From: Joel Nothman
Date: Tue, 13 May 2014 13:56:59 +1000
Subject: [PATCH 2/5] Other metrics

---
 conll03_nel_eval/coref_metrics.py | 131 ++++++++++++++++++++++++++++--
 1 file changed, 122 insertions(+), 9 deletions(-)

diff --git a/conll03_nel_eval/coref_metrics.py b/conll03_nel_eval/coref_metrics.py
index 47611f8..c74a13b 100644
--- a/conll03_nel_eval/coref_metrics.py
+++ b/conll03_nel_eval/coref_metrics.py
@@ -1,17 +1,44 @@
 from __future__ import division
 
 from functools import partial
+from collections import defaultdict
+import itertools
 
 import numpy as np
 
 from .munkres import linear_assignment
 
-# Clusters are most readily evaluated when represented as a
-# mapping from cluster ID to set of mention IDs. Could equally
-# use a set of frozensets, but lose debugging info.
+# TODO: Blanc and standard clustering metrics (e.g. http://scikit-learn.org/stable/modules/clustering.html)
+# TODO: cite originating papers
+# XXX: perhaps use set (or list) of sets rather than dict of sets
 
 
-def cluster_sim_f1(a, b):
+def mapping_to_sets(mapping):
+    """
+    >>> sets = mapping_to_sets({'a': 1, 'b': 2, 'c': 1}).items()
+    >>> sorted((k, sorted(v)) for k, v in sets)
+    [(1, ['a', 'c']), (2, ['b'])]
+    """
+    s = defaultdict(set)
+    for m, k in mapping.items():
+        s[k].add(m)
+    return dict(s)
+
+
+def sets_to_mapping(s):
+    """
+    >>> sorted(sets_to_mapping({1: {'a', 'c'}, 2: {'b'}}).items())
+    [('a', 1), ('b', 2), ('c', 1)]
+    """
+    return {m: k for k, ms in s.items() for m in ms}
+
+
+def _f1(a, b):
+    return 2 * a * b / (a + b)
+
+
+def dice(a, b):
     """
 
     "Entity-based" measure in CoNLL; #4 in CEAF paper
@@ -21,7 +48,7 @@
     return 0.
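+
+# NB: dice() above computes half of the usual Dice coefficient (the factor
+# of 2 is omitted). The constant cancels in ceaf()'s normalisation by
+# self-similarity, so CEAF scores are unaffected.
 
 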
-def cluster_sim_overlap(a, b):
+def overlap(a, b):
     """Intersection of sets
 
     "Mention-based" measure in CoNLL; #3 in CEAF paper
@@ -29,7 +56,7 @@
     return len(a & b)
 
 
-def ceaf(true, pred, similarity=cluster_sim_f1):
+def ceaf(true, pred, similarity=dice):
     """
 
     >>> true = {'A': {1,2,3,4,5}, 'B': {6,7}, 'C': {8, 9, 10, 11, 12}}
@@ -65,8 +92,94 @@
     pred_denom = sum(similarity(S, S) for S in pred)
     p = numerator / pred_denom
     r = numerator / true_denom
-    return p, r, 2 * p * r / (p + r)
+    return p, r, _f1(p, r)
+
+entity_ceaf = partial(ceaf, similarity=dice)
+mention_ceaf = partial(ceaf, similarity=overlap)
 
-entity_ceaf = partial(ceaf, similarity=cluster_sim_f1)
-mention_ceaf = partial(ceaf, similarity=cluster_sim_overlap)
+
+def _b_cubed(A, B, A_mapping, B_mapping, EMPTY=frozenset([])):
+    res = 0.
+    for m, k in A_mapping.items():
+        A_cluster = A.get(k, EMPTY)
+        res += len(A_cluster & B.get(B_mapping.get(m), EMPTY)) / len(A_cluster)
+    res /= len(A_mapping)
+    return res
+
+
+def b_cubed(true, pred):
+    """
+
+    TODO: tests
+    """
+    true_mapping = sets_to_mapping(true)
+    pred_mapping = sets_to_mapping(pred)
+    p = _b_cubed(pred, true, pred_mapping, true_mapping)
+    r = _b_cubed(true, pred, true_mapping, pred_mapping)
+    return p, r, _f1(p, r)
+
+
+def pairwise_f1(true, pred):
+    """Measure the proportion of correctly identified pairwise coindexations
+
+    This is called MUC score, and erroneously cited to Vilain et al. (1995) in
+    the CoNLL 2011-2012 Shared Task descriptions.
+
+    TODO: tests
+    """
+    pred_mapping = sets_to_mapping(pred)
+    correct = 0
+    for cluster in true.values():
+        for m1, m2 in itertools.combinations(cluster, 2):
+            if pred_mapping.get(m1) == pred_mapping.get(m2):
+                correct += 1
+    p = correct / sum(len(cluster) - 1 for cluster in pred.values())
+    r = correct / sum(len(cluster) - 1 for cluster in true.values())
+    return p, r, _f1(p, r)
+
+
+def _vilain(A, B_mapping):
+    numerator = 0
+    denominator = 0
+    for cluster in A.values():
+        corresponding = set()
+        n_unaligned = 0
+        for m in cluster:
+            if m not in B_mapping:
+                n_unaligned += 1
+            else:
+                corresponding.add(B_mapping[m])
+        numerator += len(cluster) - n_unaligned - len(corresponding)
+        denominator += len(cluster) - 1
+    return numerator / denominator
+
+
+def vilain(true, pred):
+    """The MUC evaluation metric defined in Vilain et al. (1995)
+
+    This calculates recall error for each true cluster C as the number of
+    response clusters that would need to be merged in order to produce a
+    superset of C.
+
+    Examples from Vilain et al. (1995):
+    >>> vilain({1: {'A', 'B', 'C', 'D'}},
+    ...        {1: {'A', 'B'}, 2: {'C', 'D'}})  # doctest: +ELLIPSIS
+    (1.0, 0.66..., 0.8)
+    >>> vilain({1: {'A', 'B'}, 2: {'C', 'D'}},
+    ...        {1: {'A', 'B', 'C', 'D'}})  # doctest: +ELLIPSIS
+    (0.66..., 1.0, 0.8)
+    >>> vilain({1: {'A', 'B', 'C'}}, {1: {'A', 'C'}})  # doctest: +ELLIPSIS
+    (1.0, 0.5, 0.66...)
+    >>> vilain({1: {'B', 'C', 'D', 'E', 'G', 'H', 'J'}},
+    ...        {1: {'A', 'B', 'C'}, 2: {'D', 'E', 'F'}, 3: {'G', 'H', 'I'}})
+    ... # doctest: +ELLIPSIS
+    (0.5, 0.5, 0.5)
+    >>> vilain({1: {'A', 'B', 'C'}, 2: {'D', 'E', 'F', 'G'}},
+    ...        {1: {'A', 'B'}, 2: {'C', 'D'}, 3: {'F', 'G', 'H'}})
+    ... # doctest: +ELLIPSIS
+    (0.5, 0.4, 0.44...)
+    """
+    p = _vilain(pred, sets_to_mapping(true))
+    r = _vilain(true, sets_to_mapping(pred))
+    return p, r, _f1(p, r)

From a9a519d4d5e6d2684acbca73075dd01860316935 Mon Sep 17 00:00:00 2001
From: Joel Nothman
Date: Thu, 15 May 2014 13:35:58 +1000
Subject: [PATCH 3/5] Script to evaluate CoNLL coref

---
 conll03_nel_eval/coref_metrics.py | 70 ++++++++++++++++++++++++-------
 1 file changed, 56 insertions(+), 14 deletions(-)

diff --git a/conll03_nel_eval/coref_metrics.py b/conll03_nel_eval/coref_metrics.py
index c74a13b..07fd339 100644
--- a/conll03_nel_eval/coref_metrics.py
+++ b/conll03_nel_eval/coref_metrics.py
@@ -1,4 +1,4 @@
-from __future__ import division
+from __future__ import division, print_function
 
 from functools import partial
 from collections import defaultdict
@@ -123,9 +123,6 @@ def b_cubed(true, pred):
 def pairwise_f1(true, pred):
     """Measure the proportion of correctly identified pairwise coindexations
 
-    This is called MUC score, and erroneously cited to Vilain et al. (1995) in
-    the CoNLL 2011-2012 Shared Task descriptions.
-
     TODO: tests
     """
     pred_mapping = sets_to_mapping(pred)
@@ -155,7 +152,7 @@ def _vilain(A, B_mapping):
     return numerator / denominator
 
 
-def vilain(true, pred):
+def muc(true, pred):
     """The MUC evaluation metric defined in Vilain et al. (1995)
 
     This calculates recall error for each true cluster C as the number of
@@ -163,23 +160,68 @@ def muc(true, pred):
     superset of C.
 
     Examples from Vilain et al. (1995):
-    >>> vilain({1: {'A', 'B', 'C', 'D'}},
-    ...        {1: {'A', 'B'}, 2: {'C', 'D'}})  # doctest: +ELLIPSIS
+    >>> muc({1: {'A', 'B', 'C', 'D'}},
+    ...     {1: {'A', 'B'}, 2: {'C', 'D'}})  # doctest: +ELLIPSIS
     (1.0, 0.66..., 0.8)
-    >>> vilain({1: {'A', 'B'}, 2: {'C', 'D'}},
-    ...        {1: {'A', 'B', 'C', 'D'}})  # doctest: +ELLIPSIS
+    >>> muc({1: {'A', 'B'}, 2: {'C', 'D'}},
+    ...     {1: {'A', 'B', 'C', 'D'}})  # doctest: +ELLIPSIS
     (0.66..., 1.0, 0.8)
-    >>> vilain({1: {'A', 'B', 'C'}}, {1: {'A', 'C'}})  # doctest: +ELLIPSIS
+    >>> muc({1: {'A', 'B', 'C'}}, {1: {'A', 'C'}})  # doctest: +ELLIPSIS
     (1.0, 0.5, 0.66...)
-    >>> vilain({1: {'B', 'C', 'D', 'E', 'G', 'H', 'J'}},
-    ...        {1: {'A', 'B', 'C'}, 2: {'D', 'E', 'F'}, 3: {'G', 'H', 'I'}})
+    >>> muc({1: {'B', 'C', 'D', 'E', 'G', 'H', 'J'}},
+    ...     {1: {'A', 'B', 'C'}, 2: {'D', 'E', 'F'}, 3: {'G', 'H', 'I'}})
     ... # doctest: +ELLIPSIS
     (0.5, 0.5, 0.5)
-    >>> vilain({1: {'A', 'B', 'C'}, 2: {'D', 'E', 'F', 'G'}},
-    ...        {1: {'A', 'B'}, 2: {'C', 'D'}, 3: {'F', 'G', 'H'}})
+    >>> muc({1: {'A', 'B', 'C'}, 2: {'D', 'E', 'F', 'G'}},
+    ...     {1: {'A', 'B'}, 2: {'C', 'D'}, 3: {'F', 'G', 'H'}})
     ... # doctest: +ELLIPSIS
     (0.5, 0.4, 0.44...)
""" p = _vilain(pred, sets_to_mapping(true)) r = _vilain(true, sets_to_mapping(pred)) return p, r, _f1(p, r) + + +def read_conll_coref(f): + res = defaultdict(set) + # TODO: handle annotations over document boundary + start = None + i = 0 + for l in f: + if l.startswith('#') or not l.strip(): + continue + i += 1 + doc_id, tag = l.strip().split(' ', 1) + if tag == '-': + continue + if tag.endswith(')'): + if start is None: + assert tag.startswith('(') + else: + assert not tag.startswith('(') + cid = tag.lstrip('(').rstrip(')') + res[cid].add((doc_id, start, i)) + start = None + elif tag.startswith('('): + start = i + return dict(res) + + +if __name__ == '__main__': + import argparse + ap = argparse.ArgumentParser(description='CoNLL2011-2 coreference evaluator') + ap.add_argument('key_file', type=argparse.FileType('r')) + ap.add_argument('response_file', type=argparse.FileType('r')) + args = ap.parse_args() + METRICS = { + 'bcubed': b_cubed, + 'ceafe': entity_ceaf, + 'ceafm': mention_ceaf, + 'muc': muc, + 'pairs': pairwise_f1, + } + key = read_conll_coref(args.key_file) + response = read_conll_coref(args.response_file) + print('Metric', 'P', 'R', 'F1', sep='\t') + for name, fn in sorted(METRICS.items()): + print(name, *('{:0.2f}'.format(100 * x) for x in fn(key, response)), sep='\t') From 6c94357834e29f2c3a830aca39978e8815d94e3e Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Sun, 18 May 2014 19:57:13 +1000 Subject: [PATCH 4/5] Fix conll reading, pairwise denominator --- conll03_nel_eval/coref_metrics.py | 37 +++++++++++++++++-------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/conll03_nel_eval/coref_metrics.py b/conll03_nel_eval/coref_metrics.py index 07fd339..34bb231 100644 --- a/conll03_nel_eval/coref_metrics.py +++ b/conll03_nel_eval/coref_metrics.py @@ -3,6 +3,7 @@ from functools import partial from collections import defaultdict import itertools +import re import numpy as np @@ -131,8 +132,8 @@ def pairwise_f1(true, pred): for m1, m2 in itertools.combinations(cluster, 2): if pred_mapping.get(m1) == pred_mapping.get(m2): correct += 1 - p = correct / sum(len(cluster) - 1 for cluster in pred.values()) - r = correct / sum(len(cluster) - 1 for cluster in true.values()) + p = correct / sum(len(cluster) * (len(cluster) - 1) for cluster in pred.values()) * 2 + r = correct / sum(len(cluster) * (len(cluster) - 1) for cluster in true.values()) * 2 return p, r, _f1(p, r) @@ -185,25 +186,27 @@ def muc(true, pred): def read_conll_coref(f): res = defaultdict(set) # TODO: handle annotations over document boundary - start = None i = 0 + opened = {} for l in f: - if l.startswith('#') or not l.strip(): + if l.startswith('#'): continue - i += 1 - doc_id, tag = l.strip().split(' ', 1) - if tag == '-': + l = l.split() + if not l: + assert not opened continue - if tag.endswith(')'): - if start is None: - assert tag.startswith('(') - else: - assert not tag.startswith('(') - cid = tag.lstrip('(').rstrip(')') - res[cid].add((doc_id, start, i)) - start = None - elif tag.startswith('('): - start = i + + i += 1 + tag = l[-1] + + for match in re.finditer(r'\(?[0-9]+\)?', tag): + match = match.group() + cid = match.strip('()') + if match.startswith('('): + assert cid not in opened + opened[cid] = i + if match.endswith(')'): + res[cid].add((opened.pop(cid), i)) return dict(res) From eed09cb890709d9fbadf65701c0bf2d6d4facdaa Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Wed, 21 May 2014 11:40:08 +1000 Subject: [PATCH 5/5] Debug mode to cross-check our implementation against official 
 CoNLL11-2 scorer

---
 conll03_nel_eval/coref_metrics.py | 216 ++++++++++++++++++++++++------
 1 file changed, 176 insertions(+), 40 deletions(-)

diff --git a/conll03_nel_eval/coref_metrics.py b/conll03_nel_eval/coref_metrics.py
index 34bb231..7087926 100644
--- a/conll03_nel_eval/coref_metrics.py
+++ b/conll03_nel_eval/coref_metrics.py
@@ -4,6 +4,10 @@
 from collections import defaultdict
 import itertools
 import re
+import os
+import subprocess
+import tempfile
+import warnings
 
 import numpy as np
 
@@ -15,6 +19,95 @@
 # XXX: perhaps use set (or list) of sets rather than dict of sets
 
 
+######## Debug mode comparison to reference implementation ########
+
+
+def _get_reference_coref_scorer_path():
+    path = os.environ.get('COREFSCORER', None)
+    if path is None:
+        return None
+    if os.path.isdir(path):
+        path = os.path.join(path, 'scorer.pl')
+    if not os.path.isfile(path):
+        warnings.warn('Not using coreference metric debug mode: '
+                      '{} is not a file'.format(path))
+        return None
+    return path
+
+
+REFERENCE_COREF_SCORER_PATH = _get_reference_coref_scorer_path()
+
+
+def _parse_reference_coref_scorer(output):
+    sections = output.split('\nMETRIC ')
+    if len(sections) > 1:
+        sections = sections[1:]  # strip preamble
+        one_metric = False
+    else:
+        one_metric = True
+
+    res = {}
+    for section in sections:
+        match = re.match(r'''
+                         .*
+                         Coreference:\s
+                         Recall:\s
+                         \(([^/]+)/([^)]+)\)
+                         .*
+                         Precision:\s
+                         \(([^/]+)/([^)]+)\)
+                         ''',
+                         section,
+                         re.DOTALL | re.VERBOSE)
+        r_num, r_den, p_num, p_den = map(float, match.groups())
+        stats = _prf(p_num, p_den, r_num, r_den)
+        if one_metric:
+            return stats
+        else:
+            metric = section[:section.index(':')]
+            res[metric] = stats
+    return res
+
+
+def _run_reference_coref_scorer(true, pred, metric='all',
+                                script=REFERENCE_COREF_SCORER_PATH):
+    true_file = tempfile.NamedTemporaryFile(prefix='coreftrue', delete=False)
+    pred_file = tempfile.NamedTemporaryFile(prefix='corefpred', delete=False)
+    write_conll_coref(true, pred, true_file, pred_file)
+    true_file.close()
+    pred_file.close()
+    output = subprocess.check_output([script, metric, true_file.name,
+                                      pred_file.name])
+    os.unlink(true_file.name)
+    os.unlink(pred_file.name)
+    return _parse_reference_coref_scorer(output)
+
+
+def _cross_check(metric):
+    """A decorator that asserts our output matches the reference implementation
+
+    Applies only if the environment variable COREFSCORER points to the
+    reference implementation.
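+
+    For example (hypothetical file names; COREFSCORER may also point to the
+    directory containing scorer.pl):
+
+        COREFSCORER=/path/to/scorer.pl \
+            python -m conll03_nel_eval.coref_metrics key.conll response.conll
+
+    cross-checks every decorated metric against the reference scorer.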
+    """
+    def decorator(fn):
+        if REFERENCE_COREF_SCORER_PATH is None:
+            return fn
+
+        def wrapper(true, pred):
+            our_results = fn(true, pred)
+            ref_results = _run_reference_coref_scorer(true, pred, metric)
+            assert len(our_results) == len(ref_results) == 3
+            for our_val, ref_val, name in zip(our_results, ref_results, 'PRF'):
+                if abs(our_val - ref_val) > 1e-3:
+                    msg = 'Our {}={}; reference {}={}'.format(name, our_val,
+                                                              name, ref_val)
+                    raise AssertionError(msg)
+            return our_results
+        return wrapper
+    return decorator
+
+
+######## Data formats ########
+
 def mapping_to_sets(mapping):
     """
     >>> sets = mapping_to_sets({'a': 1, 'b': 2, 'c': 1}).items()
@@ -35,9 +128,67 @@ def sets_to_mapping(s):
     return {m: k for k, ms in s.items() for m in ms}
 
 
+def read_conll_coref(f):
+    res = defaultdict(set)
+    # TODO: handle annotations over document boundary
+    i = 0
+    opened = {}
+    for l in f:
+        if l.startswith('#'):
+            continue
+        l = l.split()
+        if not l:
+            assert not opened
+            continue
+
+        i += 1
+        tag = l[-1]
+
+        for match in re.finditer(r'\(?[0-9]+\)?', tag):
+            match = match.group()
+            cid = match.strip('()')
+            if match.startswith('('):
+                assert cid not in opened
+                opened[cid] = i
+            if match.endswith(')'):
+                res[cid].add((opened.pop(cid), i))
+    return dict(res)
+
+
+def write_conll_coref(true, pred, true_file, pred_file):
+    """Artificially aligns mentions as CoNLL coreference data
+    """
+    # relabel clusters
+    true = {'({})'.format(i + 1): s for i, s in enumerate(true.values())}
+    pred = {'({})'.format(i + 1): s for i, s in enumerate(pred.values())}
+    # make lookups
+    true_mapping = sets_to_mapping(true)
+    pred_mapping = sets_to_mapping(pred)
+    # headers
+    print('#begin document (XX); part 000', file=true_file)
+    print('#begin document (XX); part 000', file=pred_file)
+    # print all mentions
+    for mention in set(true_mapping).union(pred_mapping):
+        print('XX', true_mapping.get(mention, '-'), file=true_file)
+        print('XX', pred_mapping.get(mention, '-'), file=pred_file)
+    # footers
+    print('#end document', file=true_file)
+    print('#end document', file=pred_file)
+
+
 def _f1(a, b):
-    return 2 * a * b / (a + b)
+    if a + b:
+        return 2 * a * b / (a + b)
+    return 0.
+
+
+def _prf(p_num, p_den, r_num, r_den):
+    p = p_num / p_den if p_den > 0 else 0.
+    r = r_num / r_den if r_den > 0 else 0.
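+    # Convention: an empty denominator scores 0. rather than raising
+    # ZeroDivisionError; cf. the sanity check against the reference
+    # scorer at the bottom of this module.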
+    return p, r, _f1(p, r)
+
+
+######## Cluster comparison ########
+
 
 def dice(a, b):
     """
@@ -57,6 +208,9 @@ def overlap(a, b):
     return len(a & b)
 
 
+######## Coreference metrics ########
+
+
 def ceaf(true, pred, similarity=dice):
     """
 
@@ -93,11 +247,11 @@ def ceaf(true, pred, similarity=dice):
     pred_denom = sum(similarity(S, S) for S in pred)
-    p = numerator / pred_denom
-    r = numerator / true_denom
-    return p, r, _f1(p, r)
+    return _prf(numerator, pred_denom, numerator, true_denom)
 
 
-entity_ceaf = partial(ceaf, similarity=dice)
-mention_ceaf = partial(ceaf, similarity=overlap)
+entity_ceaf = _cross_check('ceafe')(partial(ceaf, similarity=dice))
+mention_ceaf = _cross_check('ceafm')(partial(ceaf, similarity=overlap))
 
 
 def _b_cubed(A, B, A_mapping, B_mapping, EMPTY=frozenset([])):
@@ -105,10 +259,10 @@ def _b_cubed(A, B, A_mapping, B_mapping, EMPTY=frozenset([])):
     for m, k in A_mapping.items():
         A_cluster = A.get(k, EMPTY)
         res += len(A_cluster & B.get(B_mapping.get(m), EMPTY)) / len(A_cluster)
-    res /= len(A_mapping)
-    return res
+    return res, len(A_mapping)
 
 
+@_cross_check('bcub')
 def b_cubed(true, pred):
     """
 
@@ -116,9 +270,9 @@ def b_cubed(true, pred):
     """
     true_mapping = sets_to_mapping(true)
     pred_mapping = sets_to_mapping(pred)
-    p = _b_cubed(pred, true, pred_mapping, true_mapping)
-    r = _b_cubed(true, pred, true_mapping, pred_mapping)
-    return p, r, _f1(p, r)
+    p_num, p_den = _b_cubed(pred, true, pred_mapping, true_mapping)
+    r_num, r_den = _b_cubed(true, pred, true_mapping, pred_mapping)
+    return _prf(p_num, p_den, r_num, r_den)
 
 
 def pairwise_f1(true, pred):
@@ -132,9 +286,9 @@ def pairwise_f1(true, pred):
         for m1, m2 in itertools.combinations(cluster, 2):
             if pred_mapping.get(m1) == pred_mapping.get(m2):
                 correct += 1
-    p = correct / sum(len(cluster) * (len(cluster) - 1) for cluster in pred.values()) * 2
-    r = correct / sum(len(cluster) * (len(cluster) - 1) for cluster in true.values()) * 2
-    return p, r, _f1(p, r)
+    p_den = sum(len(cluster) * (len(cluster) - 1) for cluster in pred.values()) / 2
+    r_den = sum(len(cluster) * (len(cluster) - 1) for cluster in true.values()) / 2
+    return _prf(correct, p_den, correct, r_den)
 
 
 def _vilain(A, B_mapping):
@@ -150,9 +304,10 @@ def _vilain(A, B_mapping):
             corresponding.add(B_mapping[m])
         numerator += len(cluster) - n_unaligned - len(corresponding)
         denominator += len(cluster) - 1
-    return numerator / denominator
+    return numerator, denominator
 
 
+@_cross_check('muc')
 def muc(true, pred):
     """The MUC evaluation metric defined in Vilain et al. (1995)
 
     This calculates recall error for each true cluster C as the number of
@@ -178,36 +333,17 @@ def muc(true, pred):
     ... # doctest: +ELLIPSIS
     (0.5, 0.4, 0.44...)
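+
+    An editorial sanity check, not from Vilain et al.: identical
+    clusterings should score perfectly.
+    >>> muc({1: {'A', 'B'}}, {1: {'A', 'B'}})
+    (1.0, 1.0, 1.0)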
""" - p = _vilain(pred, sets_to_mapping(true)) - r = _vilain(true, sets_to_mapping(pred)) - return p, r, _f1(p, r) - + p_num, p_den = _vilain(pred, sets_to_mapping(true)) + r_num, r_den = _vilain(true, sets_to_mapping(pred)) + return _prf(p_num, p_den, r_num, r_den) -def read_conll_coref(f): - res = defaultdict(set) - # TODO: handle annotations over document boundary - i = 0 - opened = {} - for l in f: - if l.startswith('#'): - continue - l = l.split() - if not l: - assert not opened - continue - i += 1 - tag = l[-1] - for match in re.finditer(r'\(?[0-9]+\)?', tag): - match = match.group() - cid = match.strip('()') - if match.startswith('('): - assert cid not in opened - opened[cid] = i - if match.endswith(')'): - res[cid].add((opened.pop(cid), i)) - return dict(res) +if REFERENCE_COREF_SCORER_PATH is not None: + if _run_reference_coref_scorer({}, {}).get('bcub') != (0., 0., 0.): + warnings.warn('Not using coreference metric debug mode:' + 'The script is producing invalid output') + REFERENCE_COREF_SCORER_PATH = None if __name__ == '__main__':