Skip to content

Commit

Permalink
Implement swap pre-/post-processing.
Browse files Browse the repository at this point in the history
  • Loading branch information
chantera committed Aug 18, 2020
1 parent 2189c24 commit 6751a8d
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 53 deletions.
4 changes: 2 additions & 2 deletions data/clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from teras.io.reader import read_tree


# Punctuation tokens that may act as conjunct separators in a coordination.
CC_SEP = (",", ";", ":", "--", "...")
# Node labels whose spans are excluded when cleaning trees.
# NOTE: the trailing comma is required — ("-NONE-") is just the string
# "-NONE-", and `label in EXCLUDING_SPAN_LABEL` would then do substring
# matching (e.g. "-" would match) instead of exact membership.
EXCLUDING_SPAN_LABEL = ("-NONE-",)


def clean_tree(tree, exclusion=None):
Expand Down
1 change: 0 additions & 1 deletion data/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def _traverse(node, buf):
buf = []
_traverse(tree, buf)
return buf
return ''.join(buf)


if __name__ == "__main__":
Expand Down
146 changes: 96 additions & 50 deletions src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# Padding symbol for the character-level vocabulary.
_CHAR_PAD = "<PAD>"
# Coordinating conjunctions used to locate coordination candidates
# (matched case-insensitively against word.lower() in _map).
CC_KEY = ["and", "or", "but", "nor", "and/or"]
# Punctuation tokens treated as conjunct separators.
CC_SEP = [",", ";", ":"]
# PTB-style quotation tokens; used by the quote-swapping
# pre-/post-processing (see _map's swap_quote handling).
OPEN_QUOTE = ("``", "`")
CLOSE_QUOTE = ("''", "'")


class DataLoader(CachedTextLoader):
Expand Down Expand Up @@ -75,8 +77,8 @@ def map(self, item):
words, postags, cont_embeds, coords \
= _convert_item_to_attrs(item, type(self._primary_reader))
(word_ids, postag_ids, char_ids,
cc_indices, sep_indices, cont_embeds, coords) \
= _map(self, words, postags, cont_embeds, coords)
cc_indices, sep_indices, cont_embeds, indices, coords) \
= _map(self, words, postags, cont_embeds, coords, swap_quote=True)
self._updated = self.train

for cc in cc_indices:
Expand Down Expand Up @@ -167,43 +169,7 @@ def use_pretrained_embed(self):
return self._use_pretrained_embed

def load_from_tagged_file(self, file, contextualized_embed_file=None):
    """Load samples from a space-delimited ``word_POS`` tagged file.

    No longer supported: the previous implementation stripped quotation
    marks from sentences, which has been superseded by the swap-based
    quote pre-/post-processing in ``_map``/``preprocess``/``postprocess``.

    Args:
        file: path to the tagged text file (unused).
        contextualized_embed_file: optional path to precomputed
            contextualized embeddings (unused).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError


def _convert_item_to_attrs(item, reader_type):
Expand All @@ -226,7 +192,21 @@ def _convert_item_to_attrs(item, reader_type):
return words, postags, cont_embeds, coords


def _map(loader, words, postags, cont_embeds, coords=None):
def _map(loader, words, postags, cont_embeds, coords=None,
swap_quote=False):
indices = np.arange(len(words), dtype=np.int32)
swapped = False
if swap_quote:
for i, word in enumerate(words[:-1]):
if word in CC_SEP and words[i + 1] in CLOSE_QUOTE:
indices[i], indices[i + 1] = i + 1, i
swapped = True

if swapped:
words = [words[idx] for idx in indices]
postags = [postags[idx] for idx in indices]
else:
indices = None
cc_indices = [i for i, word in enumerate(words) if word.lower() in CC_KEY]
sep_indices = [i for i, word in enumerate(words) if word in CC_SEP]
word_ids = loader.map_attr(
Expand All @@ -236,11 +216,19 @@ def _map(loader, words, postags, cont_embeds, coords=None):
postag_ids = loader.map_attr('pos', postags, loader.train)
cc_indices = np.array(cc_indices, np.int32)
sep_indices = np.array(sep_indices, np.int32)

if cont_embeds is not None:
if swapped:
# warnings.warn("contextualized embeddings are changed "
# "to swap quotation marks: sentence=`{}`"
# .format(' '.join(words)))
cont_embeds = cont_embeds[:, indices]
assert cont_embeds.shape[1] == len(words)
if swapped and coords is not None:
coords = preprocess(coords, indices)

return word_ids, postag_ids, char_ids, \
cc_indices, sep_indices, cont_embeds, coords
cc_indices, sep_indices, cont_embeds, indices, coords


def _filter(words, coords, target='any'):
Expand Down Expand Up @@ -429,7 +417,7 @@ def _traverse(tree, index):
cc = None
for child in tree[1:]:
child_label = child[0]
assert child_label not in ["-NONE-", "``", "''"]
assert child_label != "-NONE-"
child_span = _traverse(child, index)
if "COORD" in child_label:
conjuncts.append(child_span)
Expand Down Expand Up @@ -504,7 +492,8 @@ def _find_separator(words, search_from, search_to, search_len=2):
class Coordination(object):
__slots__ = ('cc', 'conjuncts', 'seps', 'label')

def __init__(self, cc, conjuncts, seps=None, label=None):
def __init__(self, cc, conjuncts, seps=None, label=None,
suppress_warning=False):
assert isinstance(conjuncts, (list, tuple)) and len(conjuncts) >= 2
assert all(isinstance(conj, tuple) for conj in conjuncts)
conjuncts = sorted(conjuncts, key=lambda span: span[0])
Expand All @@ -515,7 +504,7 @@ def __init__(self, cc, conjuncts, seps=None, label=None):
if len(seps) == len(conjuncts) - 2:
for i, sep in enumerate(seps):
assert conjuncts[i][1] < sep and conjuncts[i + 1][0] > sep
else:
elif not suppress_warning:
warnings.warn(
"Coordination does not contain enough separators. "
"It may be a wrong coordination: "
Expand Down Expand Up @@ -553,15 +542,72 @@ def __eq__(self, other):
in zip(self.conjuncts, other.conjuncts))


def preprocess(coords, indices):
    """Remap gold coordination spans into the quote-swapped word order.

    ``indices`` is a permutation of token positions produced by ``_map``
    when a separator (e.g. ",") directly precedes a closing quote and the
    two tokens are swapped (``A,'' -> A'',``). Every coordination index
    (cc position, separator positions, conjunct boundaries) is remapped
    through this permutation, adjusting conjunct ends that sit on a
    swapped separator/quote pair and registering newly exposed separators.

    Args:
        coords: dict mapping cc index -> Coordination (or None).
        indices: np.ndarray permutation of token positions.

    Returns:
        dict with remapped cc indices and Coordination objects.
    """
    new_coords = {}
    for cc, coord in coords.items():
        # Remap the conjunction position itself.
        cc = indices[cc]
        if coord is not None:
            seps = [indices[s] for s in coord.seps]
            conjuncts = []
            n = len(coord.conjuncts)
            for i, (b, e) in enumerate(coord.conjuncts):
                if indices[b] == b - 1 and indices[b - 1] == b:
                    # Special case for annotation error: the conjunct
                    # starts on the second token of a swapped pair.
                    new_b = b + 1
                else:
                    new_b = indices[b]
                if indices[e] == e - 1 and indices[e - 1] == e:
                    # Case: ``A,'' ``B,'' and ``C,'' ; (0, 3), (4, 7), (9, 12)
                    # => ``A'', ``B'', and ``C'', ; (0, 2), (4, 6), (9, 11)
                    new_e = indices[e]
                    sep = e
                    # Register the separator pushed outside the conjunct,
                    # unless it is adjacent to the cc or already known.
                    if i < n - 2 and sep + 1 != cc and sep not in seps:
                        seps.append(sep)
                elif indices[e] == e + 1 and indices[e + 1] == e:
                    # Case: ``A,'' ``B,'' and ``C,'' ; (1, 2), (5, 6), (10, 11)
                    # => ``A'', ``B'', and ``C'', ; (1, 1), (5, 5), (10, 10)
                    new_e = e - 1
                    sep = e + 1
                    if i < n - 2 and sep + 1 != cc and sep not in seps:
                        seps.append(sep)
                else:
                    new_e = indices[e]
                conjuncts.append((new_b, new_e))
            coord = Coordination(cc, conjuncts, seps, coord.label)
        # Keep an entry even when coord is None so cc positions without a
        # coordination are preserved.
        new_coords[cc] = coord
    return new_coords


def postprocess(coords, indices):
    """Map predicted coordination spans back to the original word order.

    Inverse of ``preprocess``: undoes the separator/closing-quote swap
    applied in ``_map`` (``A'', -> A,''``), remapping cc, separator, and
    conjunct indices through the ``indices`` permutation and dropping
    separators that move back inside a conjunct.

    NOTE(review): assumes ``indices`` is the same permutation array that
    ``_map`` produced (so ``indices[x]`` inverts the swap) — confirm at
    the call site.

    Args:
        coords: dict mapping cc index -> Coordination (or None),
            in swapped coordinates.
        indices: np.ndarray permutation of token positions.

    Returns:
        dict with indices restored to the original token order.
    """
    new_coords = {}
    for cc, coord in coords.items():
        # Remap the conjunction position itself.
        cc = indices[cc]
        if coord is not None:
            seps = [indices[s] for s in coord.seps]
            conjuncts = []
            n = len(coord.conjuncts)
            for i, (b, e) in enumerate(coord.conjuncts):
                new_b = indices[b]
                if indices[e] == e + 1 and indices[e + 1] == e:
                    # Case: ``A'', ``B'', and ``C'', ; (0, 2), (4, 6), (9, 11)
                    # => ``A,'' ``B,'' and ``C,'' ; (0, 3), (4, 7), (9, 12)
                    # The separator returns inside the conjunct, so it is
                    # no longer a between-conjunct separator.
                    new_e = indices[e]
                    sep = e
                    if i < n - 1 and sep in seps:
                        seps.remove(sep)
                elif e < indices.size - 1 \
                        and indices[e + 1] == e + 2 \
                        and indices[e + 2] == e + 1:
                    # Case: ``A'', ``B'', and ``C'', ; (1, 1), (5, 5), (10, 10)
                    # => ``A,'' ``B,'' and ``C,'' ; (1, 1), (5, 5), (10, 10)
                    new_e = e  # always exclude the trailing comma
                    sep = e + 1
                    if i < n - 1 and sep in seps:
                        seps.remove(sep)
                else:
                    new_e = indices[e]
                conjuncts.append((new_b, new_e))
            # suppress_warning: restored spans may legitimately lack
            # enough separators after removal above.
            coord = Coordination(cc, conjuncts, seps, coord.label,
                                 suppress_warning=True)
        # Keep an entry even when coord is None.
        new_coords[cc] = coord
    return new_coords

0 comments on commit 6751a8d

Please sign in to comment.