diff --git a/data/clean.py b/data/clean.py
index 4bf81fb..f882099 100644
--- a/data/clean.py
+++ b/data/clean.py
@@ -3,8 +3,8 @@
 from teras.io.reader import read_tree
 
 
-CC_SEP = [",", ";", ":", "--", "..."]
-EXCLUDING_SPAN_LABEL = ["-NONE-", "``", "''"]
+CC_SEP = (",", ";", ":", "--", "...")
+EXCLUDING_SPAN_LABEL = ("-NONE-",)
 
 
 def clean_tree(tree, exclusion=None):
diff --git a/data/extract.py b/data/extract.py
index 28eda0e..b00e65a 100644
--- a/data/extract.py
+++ b/data/extract.py
@@ -14,7 +14,6 @@ def _traverse(node, buf):
     buf = []
     _traverse(tree, buf)
     return buf
-    return ''.join(buf)
 
 
 if __name__ == "__main__":
diff --git a/src/dataset.py b/src/dataset.py
index 16e90af..59a0ecf 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -12,6 +12,8 @@
 _CHAR_PAD = ""
 CC_KEY = ["and", "or", "but", "nor", "and/or"]
 CC_SEP = [",", ";", ":"]
+OPEN_QUOTE = ("``", "`")
+CLOSE_QUOTE = ("''", "'")
 
 
 class DataLoader(CachedTextLoader):
@@ -75,8 +77,8 @@ def map(self, item):
         words, postags, cont_embeds, coords \
             = _convert_item_to_attrs(item, type(self._primary_reader))
         (word_ids, postag_ids, char_ids,
-         cc_indices, sep_indices, cont_embeds, coords) \
-            = _map(self, words, postags, cont_embeds, coords)
+         cc_indices, sep_indices, cont_embeds, indices, coords) \
+            = _map(self, words, postags, cont_embeds, coords, swap_quote=True)
         self._updated = self.train
 
         for cc in cc_indices:
@@ -167,43 +169,7 @@ def use_pretrained_embed(self):
         return self._use_pretrained_embed
 
     def load_from_tagged_file(self, file, contextualized_embed_file=None):
-        QUOTE = ("``", "''", "`", "'")
-        file_reader = reader.ZipReader([
-            reader.CsvReader(file, delimiter=' '),
-            reader.ContextualizedEmbeddingsReader(contextualized_embed_file),
-        ])
-        samples = []
-        for idx, (sentence, cont_embeds) in enumerate(file_reader):
-            words, postags = zip(*[token.split('_') for token in sentence])
-            raw_words = words
-            is_quote = np.array([word in QUOTE for word in words])
-            # assert not any(is_quote)
-            words = [word for word, quote in zip(words, is_quote) if not quote]
-            postags = [postag for postag, quote
-                       in zip(postags, is_quote) if not quote]
-
-            cc_indices = np.array([i for i, word in enumerate(words)
-                                   if word.lower() in CC_KEY], dtype=np.int32)
-            if len(cc_indices) == 0:
-                continue
-            sep_indices = np.array([i for i, word in enumerate(words)
-                                    if word in CC_SEP], dtype=np.int32)
-            word_ids = self.map_attr('word', words, False)
-            char_ids = [self.map_attr('char', list(word), False)
-                        for word in words]
-            postag_ids = self.map_attr('pos', postags, False)
-            if cont_embeds is not None:
-                if is_quote.sum() > 0:
-                    warnings.warn("contextualized embeddings are changed "
-                                  "to strip quotation marks: sentence=`{}`"
-                                  .format(' '.join(words)))
-                cont_embeds = np.delete(
-                    cont_embeds, np.argwhere(is_quote), axis=1)
-                assert cont_embeds.shape[1] == len(words)
-            sample = (word_ids, postag_ids, char_ids, cc_indices, sep_indices,
-                      cont_embeds, raw_words, is_quote, idx)
-            samples.append(sample)
-        return Dataset(samples)
+        raise NotImplementedError
 
 
 def _convert_item_to_attrs(item, reader_type):
@@ -226,7 +192,21 @@
     return words, postags, cont_embeds, coords
 
 
-def _map(loader, words, postags, cont_embeds, coords=None):
+def _map(loader, words, postags, cont_embeds, coords=None,
+         swap_quote=False):
+    indices = np.arange(len(words), dtype=np.int32)
+    swapped = False
+    if swap_quote:
+        for i, word in enumerate(words[:-1]):
+            if word in CC_SEP and words[i + 1] in CLOSE_QUOTE:
+                indices[i], indices[i + 1] = i + 1, i
+                swapped = True
+
+    if swapped:
+        words = [words[idx] for idx in indices]
+        postags = [postags[idx] for idx in indices]
+    else:
+        indices = None
     cc_indices = [i for i, word in enumerate(words) if word.lower() in CC_KEY]
     sep_indices = [i for i, word in enumerate(words) if word in CC_SEP]
     word_ids = loader.map_attr(
@@ -236,11 +216,19 @@
     postag_ids = loader.map_attr('pos', postags, loader.train)
     cc_indices = np.array(cc_indices, np.int32)
     sep_indices = np.array(sep_indices, np.int32)
+
     if cont_embeds is not None:
+        if swapped:
+            # warnings.warn("contextualized embeddings are changed "
+            #               "to swap quotation marks: sentence=`{}`"
+            #               .format(' '.join(words)))
+            cont_embeds = cont_embeds[:, indices]
         assert cont_embeds.shape[1] == len(words)
+    if swapped and coords is not None:
+        coords = preprocess(coords, indices)
 
     return word_ids, postag_ids, char_ids, \
-        cc_indices, sep_indices, cont_embeds, coords
+        cc_indices, sep_indices, cont_embeds, indices, coords
 
 
 def _filter(words, coords, target='any'):
@@ -429,7 +417,7 @@ def _traverse(tree, index):
     cc = None
     for child in tree[1:]:
         child_label = child[0]
-        assert child_label not in ["-NONE-", "``", "''"]
+        assert child_label != "-NONE-"
         child_span = _traverse(child, index)
         if "COORD" in child_label:
             conjuncts.append(child_span)
@@ -504,7 +492,8 @@ def _find_separator(words, search_from, search_to, search_len=2):
 class Coordination(object):
     __slots__ = ('cc', 'conjuncts', 'seps', 'label')
 
-    def __init__(self, cc, conjuncts, seps=None, label=None):
+    def __init__(self, cc, conjuncts, seps=None, label=None,
+                 suppress_warning=False):
         assert isinstance(conjuncts, (list, tuple)) and len(conjuncts) >= 2
         assert all(isinstance(conj, tuple) for conj in conjuncts)
         conjuncts = sorted(conjuncts, key=lambda span: span[0])
@@ -515,7 +504,7 @@ def __init__(self, cc, conjuncts, seps=None, label=None):
         if len(seps) == len(conjuncts) - 2:
             for i, sep in enumerate(seps):
                 assert conjuncts[i][1] < sep and conjuncts[i + 1][0] > sep
-        else:
+        elif not suppress_warning:
             warnings.warn(
                 "Coordination does not contain enough separators. "
" "It may be a wrong coordination: " @@ -553,15 +542,72 @@ def __eq__(self, other): in zip(self.conjuncts, other.conjuncts)) -def post_process(coords, is_quote): +def preprocess(coords, indices): new_coords = {} - offsets = np.delete(is_quote.cumsum(), np.argwhere(is_quote)) for cc, coord in coords.items(): - cc = cc + offsets[cc] + cc = indices[cc] if coord is not None: - conjuncts = [(b + offsets[b], e + offsets[e]) - for (b, e) in coord.conjuncts] - seps = [s + offsets[s] for s in coord.seps] + seps = [indices[s] for s in coord.seps] + conjuncts = [] + n = len(coord.conjuncts) + for i, (b, e) in enumerate(coord.conjuncts): + if indices[b] == b - 1 and indices[b - 1] == b: + # Special case for annotation error + new_b = b + 1 + else: + new_b = indices[b] + if indices[e] == e - 1 and indices[e - 1] == e: + # Case: ``A,'' ``B,'' and ``C,'' ; (0, 3), (4, 7), (9, 12) + # => ``A'', ``B'', and ``C'', ; (0, 2), (4, 6), (9, 11) + new_e = indices[e] + sep = e + if i < n - 2 and sep + 1 != cc and sep not in seps: + seps.append(sep) + elif indices[e] == e + 1 and indices[e + 1] == e: + # Case: ``A,'' ``B,'' and ``C,'' ; (1, 2), (5, 6), (10, 11) + # => ``A'', ``B'', and ``C'', ; (1, 1), (5, 5), (10, 10) + new_e = e - 1 + sep = e + 1 + if i < n - 2 and sep + 1 != cc and sep not in seps: + seps.append(sep) + else: + new_e = indices[e] + conjuncts.append((new_b, new_e)) coord = Coordination(cc, conjuncts, seps, coord.label) new_coords[cc] = coord return new_coords + + +def postprocess(coords, indices): + new_coords = {} + for cc, coord in coords.items(): + cc = indices[cc] + if coord is not None: + seps = [indices[s] for s in coord.seps] + conjuncts = [] + n = len(coord.conjuncts) + for i, (b, e) in enumerate(coord.conjuncts): + new_b = indices[b] + if indices[e] == e + 1 and indices[e + 1] == e: + # Case: ``A'', ``B'', and ``C'', ; (0, 2), (4, 6), (9, 11) + # => ``A,'' ``B,'' and ``C,'' ; (0, 3), (4, 7), (9, 12) + new_e = indices[e] + sep = e + if i < n - 1 and sep in seps: + seps.remove(sep) + elif e < indices.size - 1 \ + and indices[e + 1] == e + 2 \ + and indices[e + 2] == e + 1: + # Case: ``A'', ``B'', and ``C'', ; (1, 1), (5, 5), (10, 10) + # => ``A,'' ``B,'' and ``C,'' ; (1, 1), (5, 5), (10, 10) + new_e = e # always exclude the trailing comma + sep = e + 1 + if i < n - 1 and sep in seps: + seps.remove(sep) + else: + new_e = indices[e] + conjuncts.append((new_b, new_e)) + coord = Coordination(cc, conjuncts, seps, coord.label, + suppress_warning=True) + new_coords[cc] = coord + return new_coords