Merge pull request #6 from tc20042008/develop

RpExprParser

jiahy0825 authored Jul 31, 2024
2 parents 3f051aa + e51ecd0 commit 6df9edb

Showing 5 changed files with 209 additions and 23 deletions.
9 changes: 7 additions & 2 deletions athena/generators/paddle_op_call_generator.py
@@ -369,5 +369,10 @@ def pd_op_rnn_(self, op, *inputs):
outs = self.GenerateCOpsCall(op, inputs, op_name="rnn")
return f"{outs} + (None,)"

def pd_op_select_input(self, op, cond, *inputs):
return f"[{', '.join(x.name for x in inputs)}][int({cond.name})]"
def pd_op_slice(self, op, *inputs):
if op.attrs["decrease_axis"] == "[0]" and op.output_types[0].shape == [1]:
op.attrs["decrease_axis"] = "[]"
return self.GenerateCOpsCall(op, inputs)

def pd_op_select_input(self, op, cond, elem0, elem1):
return f"({elem0.name} if {cond.name} == 0 else {elem1.name})"
28 changes: 21 additions & 7 deletions athena/rp_expr/rp_expr.py
@@ -20,6 +20,11 @@ class ListRpExpr(RpExpr):
pass


@dataclass
class NaiveTokenListRpExpr(ListRpExpr):
tensors: t.List[np.ndarray["N", np.int64]]


@dataclass
class TokenizedRpExpr(RpExpr):
token_id2primitive_id: t.List[PrimitiveId]
@@ -31,11 +36,6 @@ class TokenRpExpr(RpExpr):
pass


@dataclass
class NaiveTokenListRpExpr(ListRpExpr):
tensors: t.List[np.ndarray["N", np.int64]]


@dataclass
class FlattenedTokenListRpExpr(ListRpExpr):
tensor_list_size: int
@@ -50,10 +50,17 @@ class NaiveTokenRpExpr(TokenRpExpr):
@dataclass
class LetsTokenRpExpr(TokenRpExpr):
symbol_token_ids: t.List[TokenId]
symbol_token_tensors: t.List[NaiveTokenRpExpr]
symbol_token_tensors: t.List[np.ndarray["N", np.int64]]
body_rp_expr: NaiveTokenRpExpr


@dataclass
class LetsListTokenRpExpr(TokenRpExpr):
symbol_token_ids: t.List[TokenId]
symbol_token_tensors: t.List[np.ndarray["N", np.int64]]
body_rp_expr: t.List[np.ndarray["N", np.int64]]


class TokenIdAllocator:
def __init__(self, next_token_id: int = 0):
self.next_token_id = next_token_id
@@ -82,4 +89,11 @@ def Tokenize(
)
for primitive_id_list in primitive_id_lists
]
return NaiveTokenListRpExpr(token_tensors), token_id_allocator
token_id2primitive_id = [None] * len(primitive_id2token_id)
for primitive_id, token_id in primitive_id2token_id.items():
token_id2primitive_id[token_id] = primitive_id
return (
NaiveTokenListRpExpr(token_tensors),
token_id_allocator,
token_id2primitive_id,
)
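
A minimal sketch of the inverse mapping Tokenize now returns, assuming token ids are assigned densely from 0 (the example dict below is hypothetical):

    # Invert primitive_id -> token_id into a dense token_id -> primitive_id list.
    primitive_id2token_id = {"add": 0, "mul": 1, "relu": 2}  # hypothetical ids

    token_id2primitive_id = [None] * len(primitive_id2token_id)
    for primitive_id, token_id in primitive_id2token_id.items():
        token_id2primitive_id[token_id] = primitive_id

    assert token_id2primitive_id == ["add", "mul", "relu"]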
54 changes: 54 additions & 0 deletions athena/rp_expr/rp_expr_parser.py
@@ -0,0 +1,54 @@
import typing as t
import numpy as np
from athena.rp_expr.rp_expr import Tokenize, PrimitiveId, LetsListTokenRpExpr
from athena.rp_expr.rp_expr_passes import (
FlattenTokenListPass,
FoldTokensPass,
RecursiveFoldTokensPass,
FoldIfTokenIdGreatEqualPass,
)


class RpExprParser:
def __init__(self):
pass

def __call__(self, primitive_id_lists: t.List[t.List[PrimitiveId]]):
token_list, id_allocator, token_id2primitive_id = Tokenize(primitive_id_lists)
flatten_pass = FlattenTokenListPass(id_allocator)
success, flattened_rp_expr = flatten_pass(token_list)
assert success
fold_pass = RecursiveFoldTokensPass(id_allocator)
success, fold_rp_expr = fold_pass(flattened_rp_expr.flattened_tensor)
assert success
threshold = len(primitive_id_lists)
threshold_fold_pass = FoldIfTokenIdGreatEqualPass(
id_allocator=id_allocator,
threshold_start_token_id=threshold,
)
success, threshold_fold_rp_expr = threshold_fold_pass(fold_rp_expr.body_rp_expr)
assert success
threshold_fold_rp_expr = self.MergeAndUnflatten(
fold_rp_expr, threshold_fold_rp_expr, threshold
)
return threshold_fold_rp_expr, token_id2primitive_id

def MergeAndUnflatten(self, fold_rp_expr, threshold_fold_rp_expr, threshold):
assert len(threshold_fold_rp_expr.body_rp_expr) == threshold
return LetsListTokenRpExpr(
symbol_token_ids=[
x - threshold
for x in (
fold_rp_expr.symbol_token_ids
+ threshold_fold_rp_expr.symbol_token_ids
)
],
symbol_token_tensors=[
x - threshold
for x in (
fold_rp_expr.symbol_token_tensors
+ threshold_fold_rp_expr.symbol_token_tensors
)
],
body_rp_expr=[x - threshold for x in threshold_fold_rp_expr.body_rp_expr],
)
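
A usage sketch for the new parser, mirroring TestRpExprParser further down in this diff (assumes the athena package is importable):

    from athena.rp_expr.rp_expr_parser import RpExprParser

    # Three overlapping primitive-id lists sharing the prefix [0, 1, 2].
    primitive_id_lists = [list(range(3 + i)) for i in range(3)]
    parser = RpExprParser()
    lets_list_rp_expr, token_id2primitive_id = parser(primitive_id_lists)

    # Per the test expectations, each input list folds to a single symbol and
    # the shared subsequences live in symbol_token_tensors.
    print([x.numpy().tolist() for x in lets_list_rp_expr.body_rp_expr])  # [[5], [6], [7]]
    print(token_id2primitive_id)  # [0, 1, 2, 3, 4]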
81 changes: 71 additions & 10 deletions athena/rp_expr/rp_expr_passes.py
@@ -12,6 +12,7 @@
FlattenedTokenListRpExpr,
NaiveTokenRpExpr,
LetsTokenRpExpr,
LetsListTokenRpExpr,
)
import itertools

@@ -59,14 +60,12 @@ def __call__(self, token_tensor: NaiveTokenRpExpr):
most_frequent_length, indexes = self.GetMostFrequentPatternLengthAndIndexes(
input_tensor
)
new_token_id = self.id_allocator.NewTokenId()
success, replacement = self.Replace(
new_token_id, replacement = self.Replace(
pattern_length=most_frequent_length,
indexes=indexes,
new_token_id=new_token_id,
input_tensor=input_tensor,
)
if not success:
if new_token_id is None:
return False, token_tensor
start = indexes[0]
return True, LetsTokenRpExpr(
@@ -79,20 +78,18 @@ def Replace(
self,
pattern_length,
indexes,
new_token_id,
input_tensor: np.ndarray["N", np.int64],
) -> t.Tuple[bool, np.ndarray["N", np.int64]]:
new_token_tensor = paddle.to_tensor([new_token_id], paddle.int64)
) -> t.Tuple[bool, int, np.ndarray["N", np.int64]]:
num_tokens = input_tensor.shape[0]
if pattern_length == 1:
return False, input_tensor
return None, input_tensor
assert indexes.shape[0] > 0
disjoint_range_starts = [
start
for start in self.GetDisjoint(pattern_length, indexes.numpy().tolist())
]
if len(disjoint_range_starts) <= 1:
return False, input_tensor
return None, input_tensor
assert disjoint_range_starts[-1] + pattern_length <= num_tokens
first_start = disjoint_range_starts[0]
pattern_tensor = input_tensor[first_start : (first_start + pattern_length)]
@@ -108,6 +105,9 @@ def Replace(
uniqued_segment_starts = paddle.unique(paddle.to_tensor(segment_starts))
segment_lengths = paddle.diff(uniqued_segment_starts).numpy().tolist()

new_token_id = self.id_allocator.NewTokenId()
new_token_tensor = paddle.to_tensor([new_token_id], paddle.int64)

def ReplaceTensor(tensor):
if tensor.shape != pattern_tensor.shape:
return tensor
@@ -120,7 +120,7 @@ def ReplaceTensor(tensor):
for tensor in paddle.split(input_tensor, segment_lengths)
]
output_tensor = paddle.concat(replaced_segment_tensors)
return True, output_tensor
return new_token_id, output_tensor

def GetConv(self, num_tokens):
windows_size = min(num_tokens, self.max_windows_size)
@@ -204,3 +204,64 @@ def __call__(self, token_tensor: NaiveTokenRpExpr):
symbol_token_tensors=symbol_token_tensors,
body_rp_expr=token_tensor,
)


class FoldIfTokenIdGreatEqualPass(Pass):
def __init__(
self,
id_allocator: TokenIdAllocator,
threshold_start_token_id: int,
):
self.id_allocator = id_allocator
self.threshold_start_token_id = threshold_start_token_id

def __call__(self, token_rp_expr: NaiveTokenRpExpr):
indexes_ge_threshold = self.GetIndexesGeThreshold(token_rp_expr.tensor)
token_ids_ge_threshold = paddle.gather(
token_rp_expr.tensor, indexes_ge_threshold
)
consecutive_index_range_lengths = self.GetConsecutiveIndexRangeLengths(
indexes_ge_threshold=indexes_ge_threshold,
)
tensors = paddle.split(token_ids_ge_threshold, consecutive_index_range_lengths)

def GetSymbolsValuesBodyTriple(tensor):
if tensor.shape[0] == 1:
return [], [], tensor
new_token_id = self.id_allocator.NewTokenId()
return (
[new_token_id],
[tensor],
paddle.to_tensor([new_token_id], paddle.int64),
)

symbols_values_body_triples = [
GetSymbolsValuesBodyTriple(tensor) for tensor in tensors
]
return True, LetsListTokenRpExpr(
symbol_token_ids=[
token_id
for new_token_ids, _, _ in symbols_values_body_triples
for token_id in new_token_ids
],
symbol_token_tensors=[
token_tensor
for _, token_tensors, _ in symbols_values_body_triples
for token_tensor in token_tensors
],
body_rp_expr=[
body_tensor for _, _, body_tensor in symbols_values_body_triples
],
)

def GetIndexesGeThreshold(self, token_tensor: np.ndarray["N", np.int64]):
(indexes,) = paddle.where(token_tensor >= self.threshold_start_token_id)
return indexes.reshape([-1])

def GetConsecutiveIndexRangeLengths(self, indexes_ge_threshold):
groups = self.GetNumpyConsecutiveGroups(indexes_ge_threshold.numpy())
return [group.shape[0] for group in groups]

# reference: https://stackoverflow.com/questions/7352684/how-to-find-the-groups-of-consecutive-elements-in-a-numpy-array
def GetNumpyConsecutiveGroups(self, data, stepsize=1):
return np.split(data, np.where(np.diff(data) != stepsize)[0] + 1)
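
A standalone sketch of the consecutive-grouping helper used above, in plain numpy (split wherever the gap between neighbors differs from the step size):

    import numpy as np

    def consecutive_groups(data, stepsize=1):
        # Split points are the positions where the difference is not stepsize.
        return np.split(data, np.where(np.diff(data) != stepsize)[0] + 1)

    groups = consecutive_groups(np.array([0, 1, 2, 5, 6, 9]))
    print([g.tolist() for g in groups])  # [[0, 1, 2], [5, 6], [9]]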
60 changes: 56 additions & 4 deletions tests/test_rp_expr_passes.py
@@ -4,14 +4,16 @@
FlattenTokenListPass,
FoldTokensPass,
RecursiveFoldTokensPass,
FoldIfTokenIdGreatEqualPass,
)
from athena.rp_expr.rp_expr_parser import RpExprParser


class TestTokenize(unittest.TestCase):

def test_simple(self):
primitive_id_lists = [list(range(10 + i)) for i in range(5)]
token_list, id_allocator = Tokenize(primitive_id_lists)
token_list, id_allocator, _ = Tokenize(primitive_id_lists)
self.assertEqual(len(token_list.tensors), len(primitive_id_lists))


@@ -21,7 +23,7 @@ def test_simple(self):
base = 10
size = 5
primitive_id_lists = [list(range(base + i)) for i in range(size)]
token_list, id_allocator = Tokenize(primitive_id_lists)
token_list, id_allocator, _ = Tokenize(primitive_id_lists)
rp_expr_pass = FlattenTokenListPass(id_allocator)
success, flattened_rp_expr_pass = rp_expr_pass(token_list)
self.assertTrue(success)
@@ -34,7 +36,7 @@ def test_simple(self):
base = 3
size = 3
primitive_id_lists = [list(range(base + i)) for i in range(size)]
token_list, id_allocator = Tokenize(primitive_id_lists)
token_list, id_allocator, _ = Tokenize(primitive_id_lists)
flatten_pass = FlattenTokenListPass(id_allocator)
_, flattened_rp_expr = flatten_pass(token_list)
fold_pass = FoldTokensPass(id_allocator)
@@ -56,11 +58,13 @@ def test_simple(self):
base = 3
size = 3
primitive_id_lists = [list(range(base + i)) for i in range(size)]
token_list, id_allocator = Tokenize(primitive_id_lists)
token_list, id_allocator, _ = Tokenize(primitive_id_lists)
flatten_pass = FlattenTokenListPass(id_allocator)
_, flattened_rp_expr = flatten_pass(token_list)
print("before recursive next_token_id:", id_allocator.NextTokenId())
fold_pass = RecursiveFoldTokensPass(id_allocator)
success, fold_rp_expr = fold_pass(flattened_rp_expr.flattened_tensor)
print("after recursive next_token_id:", id_allocator.NextTokenId())
self.assertTrue(success)
input = flattened_rp_expr.flattened_tensor.tensor.numpy().tolist()
pattern = [x.numpy().tolist() for x in fold_rp_expr.symbol_token_tensors]
@@ -72,5 +76,53 @@ def test_simple(self):
self.assertEqual(output, [8, 1, 9, 2, 9, 7])


class TestFoldIfTokenIdGreatEqualPass(unittest.TestCase):

def test_simple(self):
base = 3
size = 3
primitive_id_lists = [list(range(base + i)) for i in range(size)]
token_list, id_allocator, _ = Tokenize(primitive_id_lists)
flatten_pass = FlattenTokenListPass(id_allocator)
_, flattened_rp_expr = flatten_pass(token_list)
fold_pass = RecursiveFoldTokensPass(id_allocator)
success, fold_rp_expr = fold_pass(flattened_rp_expr.flattened_tensor)
self.assertTrue(success)
threshold_fold_pass = FoldIfTokenIdGreatEqualPass(
id_allocator=id_allocator,
threshold_start_token_id=len(primitive_id_lists),
)
success, threshold_fold_rp_expr = threshold_fold_pass(fold_rp_expr.body_rp_expr)
self.assertTrue(success)
input = fold_rp_expr.body_rp_expr.tensor.numpy().tolist()
pattern = [
x.numpy().tolist() for x in threshold_fold_rp_expr.symbol_token_tensors
]
replacement = threshold_fold_rp_expr.symbol_token_ids
self.assertEqual(len(threshold_fold_rp_expr.body_rp_expr), 3)
output = [x.numpy().tolist() for x in threshold_fold_rp_expr.body_rp_expr]
self.assertEqual(input, [8, 1, 9, 2, 9, 7])
self.assertEqual(pattern, [[9, 7]])
self.assertEqual(replacement, [10])
self.assertEqual(output, [[8], [9], [10]])


class TestRpExprParser(unittest.TestCase):

def test_simple(self):
base = 3
size = 3
primitive_id_lists = [list(range(base + i)) for i in range(size)]
parser = RpExprParser()
lets_list_rp_expr, token_id2primitive_id = parser(primitive_id_lists)
pattern = [x.numpy().tolist() for x in lets_list_rp_expr.symbol_token_tensors]
replacement = lets_list_rp_expr.symbol_token_ids
output = [x.numpy().tolist() for x in lets_list_rp_expr.body_rp_expr]
self.assertEqual(pattern, [[0, 1, 2], [5, 3], [6, 4]])
self.assertEqual(replacement, [5, 6, 7])
self.assertEqual(output, [[5], [6], [7]])
self.assertEqual(token_id2primitive_id, [0, 1, 2, 3, 4])


if __name__ == "__main__":
unittest.main()
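
Since the module guards unittest.main() under __main__, the suite can also be run with the standard unittest runner, assuming the athena package is importable (e.g., the repository root is on PYTHONPATH):

    python -m unittest tests.test_rp_expr_passes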
