Skip to content

Commit

Permalink
Improve FeatureExtractor2.
Browse files Browse the repository at this point in the history
  • Loading branch information
chantera committed Aug 18, 2020
1 parent 3705fec commit c264f47
Showing 1 changed file with 47 additions and 11 deletions.
58 changes: 47 additions & 11 deletions src/models/feature.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
import functools

import chainer
Expand Down Expand Up @@ -104,19 +105,54 @@ def _feature_repl(hs_flatten, pairs, ckeys, lengths):
return repl1, repl2

@staticmethod
def _forward_spans(hs_flatten, pairs, ckeys, lengths):
def _forward_spans(hs_flatten, pairs, ckeys, lengths,
use_block=True, block_size=128):
xp = chainer.cuda.get_array_module(hs_flatten)
begins, ends = pairs.T

@functools.lru_cache(maxsize=None)
def _get_span_v(i, j):
return F.average(hs_flatten[i:j + 1], axis=0)

left_spans = F.vstack(
[_get_span_v(begin, ckey_pre) for begin, ckey_pre
in zip(begins, ckeys - 1)])
right_spans = F.vstack(
[_get_span_v(ckey_post, end) for ckey_post, end
in zip(ckeys + 1, ends)])
if use_block:
def _uniq(start, end):
idxs = defaultdict(lambda: len(idxs))
offset = np.array([idxs[(s, e)] for s, e in zip(start, end)])
start, end = np.array([k for k, v in sorted(
idxs.items(), key=lambda x: x[1])]).T
return start, end, offset

def _sum(start, end):
size = len(start)
lb, ub = min(start), max(end)
hs = hs_flatten[lb: ub]
mask = xp.zeros((size, ub - lb, 1), dtype=xp.float32)
for i, (s, e) in enumerate(zip(start, end)):
mask[i, s - lb: e - lb] = 1.0
return F.sum(hs * mask, axis=1)

def _extract(start, end):
spans = []
start, end, offset = _uniq(start, end)
ofs, lb, ub = 0, 0, 0
for k in range(len(start)):
lb, ub = min(lb, start[k]), max(ub, end[k])
if ub - lb > block_size and k > 0:
spans.append(_sum(start[ofs: k], end[ofs: k]))
ofs, lb, ub = k, start[k], end[k]
spans.append(_sum(start[ofs:], end[ofs:]))
spans = F.vstack(spans) / xp.asarray(end - start)[:, None]
return F.embed_id(xp.asarray(offset), spans)

left_spans = _extract(begins, ckeys)
right_spans = _extract(ckeys + 1, ends + 1)
else:
@functools.lru_cache(maxsize=None)
def _get_span_v(i, j):
return F.average(hs_flatten[i:j + 1], axis=0)

left_spans = F.vstack(
[_get_span_v(begin, ckey_pre) for begin, ckey_pre
in zip(begins, ckeys - 1)])
right_spans = F.vstack(
[_get_span_v(ckey_post, end) for ckey_post, end
in zip(ckeys + 1, ends)])

return left_spans, right_spans

Expand Down

0 comments on commit c264f47

Please sign in to comment.