diff --git a/athena/generators/paddle_func_body_generator.py b/athena/generators/paddle_func_body_generator.py
index a415cde..26a7e51 100644
--- a/athena/generators/paddle_func_body_generator.py
+++ b/athena/generators/paddle_func_body_generator.py
@@ -118,7 +118,7 @@ def get_stmts_pd_op_while(
             f"ATHENA_WHILE_LOOP_LIMIT = os.getenv('ATHENA_WHILE_LOOP_LIMIT')"
         ),
         self.Indent0(
-            f"kWhileLoopLimit = (128 if ATHENA_WHILE_LOOP_LIMIT is not None else int(ATHENA_WHILE_LOOP_LIMIT))"
+            f"kWhileLoopLimit = (128 if ATHENA_WHILE_LOOP_LIMIT is None else int(ATHENA_WHILE_LOOP_LIMIT))"
         ),
         self.Indent0(f"while_loop_counter_{op.op_id} = 0"),
         self.Indent0(f"while {cond.name}:"),
diff --git a/athena/rp_expr/rp_expr.py b/athena/rp_expr/rp_expr.py
index 1e24c52..0c0212e 100644
--- a/athena/rp_expr/rp_expr.py
+++ b/athena/rp_expr/rp_expr.py
@@ -1,8 +1,12 @@
 from dataclasses import dataclass
 import typing as t
 import numpy as np
+import paddle
+from collections import defaultdict
 
-PatternId = int
+PrimitiveId = t.TypeVar("PrimitiveId")
+
+TokenId = int
 
 
 # Repeat Pattern Expression
@@ -12,43 +16,70 @@ class RpExpr:
 
 
 @dataclass
-class PrimitiveRpExpr(RpExpr):
-    primitive_tensor: np.ndarray["N", np.int32]
+class ListRpExpr(RpExpr):
+    pass
 
 
 @dataclass
-class FoldRpExpr(RpExpr):
-    fold_tensor: np.ndarray["N", np.int32]
+class TokenizedRpExpr(RpExpr):
+    token_id2primitive_id: t.List[PrimitiveId]
+    token_tensors: ListRpExpr
 
 
 @dataclass
-class LetsListRpExpr(RpExpr):
-    symbol_pattern_ids: t.List[PatternId]
-    symbol_rp_exprs: t.Union[t.List[PrimitiveRpExpr], "LetsListRpExpr"]
-    body_rp_expr: t.List[FoldRpExpr]
+class TokenRpExpr(RpExpr):
+    pass
 
 
 @dataclass
-class LetsRpExpr(RpExpr):
-    symbol_pattern_ids: t.List[PatternId]
-    symbol_rp_exprs: t.Union[t.List[PrimitiveRpExpr], LetsListRpExpr]
-    body_rp_expr: t.Union[FoldRpExpr, "LetsRpExpr"]
+class NaiveTokenListRpExpr(ListRpExpr):
+    tensors: t.List[np.ndarray["N", np.int64]]
 
 
-class PatternIdAllocator:
-    def __init__(self, next_pattern_id: int):
-        self.next_pattern_id = next_pattern_id
+@dataclass
+class FlattenedTokenListRpExpr(ListRpExpr):
+    tensor_list_size: int
+    flattened_tensor: TokenRpExpr
 
-    def NewPatternId(self):
-        value = self.next_pattern_id
-        self.next_pattern_id += 1
-        return value
 
+@dataclass
+class NaiveTokenRpExpr(TokenRpExpr):
+    tensor: np.ndarray["N", np.int64]
+
+
+@dataclass
+class LetsTokenRpExpr(TokenRpExpr):
+    symbol_token_ids: t.List[TokenId]
+    symbol_token_tensors: t.List[np.ndarray["N", np.int64]]
+    body_rp_expr: NaiveTokenRpExpr
+
+
+class TokenIdAllocator:
+    def __init__(self, next_token_id: int = 0):
+        self.next_token_id = next_token_id
+
+    def NewTokenId(self):
+        value = self.next_token_id
+        self.next_token_id += 1
+        return value
 
-def TrivialParse(
-    primitive_ids: t.List[int],
-) -> t.Tuple[PrimitiveRpExpr, PatternIdAllocator]:
-    primitive_tensor = np.array(primitive_ids, dtype=np.int32)
-    min_pattern_id = int(np.min(primitive_tensor))
-    assert min_pattern_id >= 0
-    return PrimitiveRpExpr(primitive_tensor), PatternIdAllocator(min_pattern_id + 1)
+    def NextTokenId(self):
+        return self.next_token_id
+
+    def Skip(self, size):
+        self.next_token_id += size
+
+
+def Tokenize(
+    primitive_id_lists: t.List[t.List[PrimitiveId]],
+) -> t.Tuple[NaiveTokenListRpExpr, TokenIdAllocator]:
+    token_id_allocator = TokenIdAllocator()
+    primitive_id2token_id = defaultdict(token_id_allocator.NewTokenId)
+    token_tensors = [
+        paddle.to_tensor(
+            [primitive_id2token_id[primitive_id] for primitive_id in primitive_id_list],
+            paddle.int64,
+        )
+        for primitive_id_list in primitive_id_lists
+    ]
+    return NaiveTokenListRpExpr(token_tensors), token_id_allocator
diff --git a/athena/rp_expr/rp_expr_passes.py b/athena/rp_expr/rp_expr_passes.py
index 1a6e1c3..2baa9f5 100644
--- a/athena/rp_expr/rp_expr_passes.py
+++ b/athena/rp_expr/rp_expr_passes.py
@@ -3,190 +3,202 @@
 import numpy as np
 import re
 import itertools
-
-PrimitiveId = int
+import paddle
+import paddle.nn.functional as F
+from athena.rp_expr.rp_expr import (
+    TokenIdAllocator,
+    NaiveTokenListRpExpr,
+    FlattenedTokenListRpExpr,
+    NaiveTokenRpExpr,
+    LetsTokenRpExpr,
+)
 
 
-@dataclass
-class PatternTree:
+class Pass:
     pass
 
 
-@dataclass
-class PrimitivePattern(PatternTree):
-    value: PrimitiveId
-
-
-@dataclass
-class TrivialPattern(PatternTree):
-    children: t.List[PrimitivePattern]
+class FlattenTokenListPass(Pass):
+    def __init__(self, id_allocator: TokenIdAllocator):
+        self.id_allocator = id_allocator
 
+    def __call__(self, token_tensors_rp_expr: NaiveTokenListRpExpr):
+        tensor_list_size = len(token_tensors_rp_expr.tensors)
+        self.id_allocator.Skip(tensor_list_size)
 
-@dataclass
-class RepeatedPattern:
-    repeated_count: int
-    value: PatternTree
+        def GetSepTensor(i):
+            if i == 0:
+                return []
+            return [paddle.to_tensor([i], paddle.int64)]
 
-
-@dataclass
-class TuplePattern(PatternTree):
-    children: t.List[RepeatedPattern]
+        token_tensors = [
+            tensor
+            for i, token_tensor in enumerate(token_tensors_rp_expr.tensors)
+            for tensor in GetSepTensor(i) + [token_tensor + tensor_list_size]
+        ]
+        return True, FlattenedTokenListRpExpr(
+            tensor_list_size=tensor_list_size,
+            flattened_tensor=NaiveTokenRpExpr(
+                tensor=paddle.concat(token_tensors, axis=0),
+            ),
+        )
 
 
-class RepeatPatternsParser:
-    def __init__(
-        self,
-        window_size: int = 4096,
-    ):
-        self.window_size = window_size
+class FoldTokensPass(Pass):
+    def __init__(self, id_allocator: TokenIdAllocator):
+        self.max_windows_size = 64
+        self.id_allocator = id_allocator
+        size = id_allocator.NextTokenId()
+        self.embedding = paddle.uniform([size], dtype="float64", min=-1, max=1)
+        self.embedding.stop_gradient = False
 
-    def Parse(
-        self,
-        primitive_ids: t.List[PrimitiveId],
-    ) -> t.Optional[PatternTree]:
-        if len(primitive_ids) == 0:
-            return None
-        kThreshold = 4096
-        loop_count = itertools.count()
-        pattern_tree = self.ConvertPrimitiveIdsToPatternTree(primitive_ids)
-        while True:
-            if next(loop_count) >= kThreshold:
-                raise RuntimeError("Dead loop detected.")
-            (pattern_tree_from_trivial, replace_ctx) = (
-                self.ConvertPrimitivePatternsToPatternTree(pattern_tree)
-            )
-            if isinstance(pattern_tree_from_trivial, TrivialPattern):
-                break
-            pattern_tree = self.ReplaceTrivialPatterns(
-                pattern_tree, pattern_tree_from_trivial, replace_ctx
-            )
+    def __call__(self, token_tensor: NaiveTokenRpExpr):
+        input_tensor = token_tensor.tensor
+        most_frequent_length, indexes = self.GetMostFrequentPatternLengthAndIndexes(
+            input_tensor
+        )
+        new_token_id = self.id_allocator.NewTokenId()
+        success, replacement = self.Replace(
+            pattern_length=most_frequent_length,
+            indexes=indexes,
+            new_token_id=new_token_id,
+            input_tensor=input_tensor,
+        )
+        if not success:
+            return False, token_tensor
+        start = indexes[0]
+        return True, LetsTokenRpExpr(
+            symbol_token_ids=[new_token_id],
+            symbol_token_tensors=[input_tensor[start : (start + most_frequent_length)]],
+            body_rp_expr=NaiveTokenRpExpr(tensor=replacement),
+        )
 
-    def ConvertPrimitiveIdsToPatternTree(
+    def Replace(
         self,
-        primitive_ids: t.List[PrimitiveId],
-    ) -> t.Optional[PatternTree]:
-        if len(primitive_ids) == 0:
-            return None
-        ctx = ParserCtx.InitFromPrimitiveIds(primitive_ids)
-        while len(ctx.pattern_ids) > 1:
-            ctx = ctx.ReduceFrequentest(
-                window_size=self.window_size,
-            )
-        return ctx.id2pattern_tree[int(ctx.pattern_ids[0])]
-
-
-@dataclass
-class ParserCtx:
-    pattern_ids: np.ndarray[("N",), np.int32]
-    id2pattern_tree: t.Dict[int, PatternTree]
-    make_new_pattern_id: Callable[[], int]
-
-    @classmethod
-    def InitFromPrimitiveIds(cls, primitive_ids: t.List[PrimitiveId]):
-        primitive2pattern_id = {}
-        id2pattern_tree = {}
-        max_pattern_id = 0
-
-        def GetOrCreatePatternId(primitive_id: PrimitiveId):
-            nonlocal max_pattern_id
-            if primitive_id not in primitive2pattern_id:
-                new_pattern_id = max_pattern_id
-                primitive2pattern_id[primitive_id] = new_pattern_id
-                id2pattern_tree[new_pattern_id] = PrimitivePattern(primitive_id)
-                max_pattern_id += 1
-            return primitive2pattern_id[primitive_id]
-
-        pattern_ids = [GetOrCreatePatternId(i) for i in primitive_ids]
-        return ParserCtx(
-            pattern_ids=np.array(pattern_ids, dtype=np.int32),
-            id2pattern_tree=id2pattern_tree,
-            make_new_pattern_id=lambda pattern_ids: np.max(pattern_ids) + 1,
+        pattern_length,
+        indexes,
+        new_token_id,
+        input_tensor: np.ndarray["N", np.int64],
+    ) -> t.Tuple[bool, np.ndarray["N", np.int64]]:
+        new_token_tensor = paddle.to_tensor([new_token_id], paddle.int64)
+        num_tokens = input_tensor.shape[0]
+        if pattern_length == 1:
+            return False, input_tensor
+        assert indexes.shape[0] > 0
+        disjoint_range_starts = [
+            start
+            for start in self.GetDisjoint(pattern_length, indexes.numpy().tolist())
+        ]
+        if len(disjoint_range_starts) <= 1:
+            return False, input_tensor
+        assert disjoint_range_starts[-1] + pattern_length <= num_tokens
+        first_start = disjoint_range_starts[0]
+        pattern_tensor = input_tensor[first_start : (first_start + pattern_length)]
+        segment_starts = (
+            [0]
+            + [
+                index
+                for start in disjoint_range_starts
+                for index in [start, start + pattern_length]
+            ]
+            + [num_tokens]
         )
+        uniqued_segment_starts = paddle.unique(paddle.to_tensor(segment_starts))
+        segment_lengths = paddle.diff(uniqued_segment_starts).numpy().tolist()
+
+        def ReplaceTensor(tensor):
+            if tensor.shape != pattern_tensor.shape:
+                return tensor
+            if bool(paddle.all(tensor == pattern_tensor)):
+                return new_token_tensor
+            return tensor
+
+        replaced_segment_tensors = [
+            ReplaceTensor(tensor)
+            for tensor in paddle.split(input_tensor, segment_lengths)
+        ]
+        output_tensor = paddle.concat(replaced_segment_tensors)
+        return True, output_tensor
 
-    def ReduceFrequentest(
+    def GetConv(self, num_tokens):
+        windows_size = min(num_tokens, self.max_windows_size)
+        weight = paddle.uniform(
+            [windows_size, windows_size], dtype="float64", min=-1, max=1
+        )
+        weight.stop_gradient = False
+        weight_shape = (windows_size, 1, windows_size)
+        conv_weight = paddle.triu(weight).transpose([1, 0]).reshape(weight_shape)
+        conv = lambda input: F.conv1d(input, conv_weight, padding="VALID")
+        return conv, windows_size
+
+    def GetDisjoint(self, gap, indexes):
+        if len(indexes) == 0:
+            return
+        last = indexes[0]
+        yield last
+        for current in indexes:
+            if current >= (last + gap):
+                yield current
+                last = current
+
+    def GetMostFrequentPatternLengthAndIndexes(
         self,
-        window_size,
-    ) -> ParserCtx:
-        repeated_pattern_ids = self.FindRepeatedPatternIds(window_size)
-        if repeated_pattern_ids is None:
-            # take all pattern_ids as repeated.
-            repeated_pattern_ids = self.pattern_ids
-        assert len(repeated_pattern_ids.shape) == 1
-        if repeated_pattern_ids.shape[0] == 1:
-            # take all pattern_ids as repeated.
-            repeated_pattern_ids = self.pattern_ids
-        new_pattern_id = self.make_new_pattern_id(self.pattern_ids)
-        id2pattern_tree = self.UpdateId2PatternTree(
-            new_pattern_id=new_pattern_id,
-            repeated_pattern_ids=repeated_pattern_ids,
+        token_tensor: np.ndarray["N", np.int64],
+    ):
+        conv, windows_size = self.GetConv(num_tokens=token_tensor.shape[0])
+        input = paddle.gather(self.embedding, token_tensor)
+        input.stop_gradient = False
+        zeros = paddle.zeros([windows_size - 1], paddle.float64)
+        input = paddle.concat([input, zeros])
+        input = input.reshape((1, 1, -1))
+        y = conv(input)
+        y = y.reshape((windows_size, -1))
+        y_hash = y.view(paddle.int64)
+        hash_weight = paddle.arange(windows_size).reshape((-1, 1)).expand(y_hash.shape)
+        weighted_y_hash = paddle.concat(
+            [hash_weight.reshape((-1, 1)), y_hash.reshape((-1, 1))], axis=1
         )
-        new_pattern_ids = self.ReplacePatternId(
-            new_pattern_id,
-            repeated_pattern_ids,
+        unique_weighted_hash, counts = paddle.unique(
+            weighted_y_hash, axis=0, return_counts=True
         )
-        return ParserCtx(
-            pattern_ids=new_pattern_ids,
-            id2pattern_tree=id2pattern_tree,
-            make_new_pattern_id=lambda pattern_ids: np.max(pattern_ids) + 1,
+        most_frequent_hash_idx = paddle.argmax(
+            unique_weighted_hash[:, 0] * (counts - 1)
         )
-
-    def FindRepeatedPatternIds(self, window_size):
-        assert len(self.pattern_ids.shape) == 1
-        window = self.pattern_ids[0:window_size,]
-        window_size = window.shape[0]
-        diff = window.reshape((1, window_size)) - window.reshape((window_size, 1))
-        is_equal = diff == 0
-        not_equal_myself = np.logical_not(np.eye(window_size, dtype=np.bool))
-        xs, ys = np.where(is_equal and not_equal_myself)
-        if xs.shape[0] == 0:
-            return None
-        intervals = xs - ys
-        interval_values, interval_counts = np.unique(
-            np.abs(intervals), return_counts=True
+        most_frequent_hash = int(unique_weighted_hash[most_frequent_hash_idx, 1])
+        most_frequent_hash_weight = int(unique_weighted_hash[most_frequent_hash_idx, 0])
+        (indexes,) = paddle.where(
+            most_frequent_hash == y_hash[most_frequent_hash_weight, :]
+        )
+        indexes = indexes.reshape((-1,))
+        return most_frequent_hash_weight + 1, indexes
+
+
+class RecursiveFoldTokensPass(Pass):
+    def __init__(self, id_allocator: TokenIdAllocator):
+        self.id_allocator = id_allocator
+
+    def __call__(self, token_tensor: NaiveTokenRpExpr):
+        success, ret = FoldTokensPass(self.id_allocator)(token_tensor)
+        if not success:
+            return False, token_tensor
+        symbol_token_ids = ret.symbol_token_ids
+        symbol_token_tensors = ret.symbol_token_tensors
+        token_tensor = ret.body_rp_expr
+        counter = itertools.count()
+        kLimit = 9999999
+        while True:
+            success, ret = FoldTokensPass(self.id_allocator)(token_tensor)
+            if not success:
+                token_tensor = ret
+                break
+            assert ret.body_rp_expr.tensor.shape[0] < token_tensor.tensor.shape[0]
+            if next(counter) > kLimit:
+                raise RuntimeError("dead loop detected.")
+            symbol_token_ids += ret.symbol_token_ids
+            symbol_token_tensors += ret.symbol_token_tensors
+            token_tensor = ret.body_rp_expr
+        return True, LetsTokenRpExpr(
+            symbol_token_ids=symbol_token_ids,
+            symbol_token_tensors=symbol_token_tensors,
+            body_rp_expr=token_tensor,
         )
-        frequentest_interval_idx = np.argmax(interval_counts)
-        frequentest_interval_value = interval_values[frequentest_interval_idx]
-        (pointer_indexes,) = np.where(intervals == frequentest_interval_value)
-        pointer_xs = np.take(xs, pointer_indexes)
-        consecutive_index_groups = GetConsecutiveGroups(pointer_xs)
-        group_lens = np.array([x.shape[0] for x in consecutive_index_groups])
-        longest_group_idx = np.argmax(group_lens)
-        longest_group = consecutive_index_groups[longest_group_idx]
-        repeated_pattern_ids = np.take(window, longest_group)
-        if frequentest_interval_value > 1:
-            repeated_pattern_ids = repeated_pattern_ids[0:frequentest_interval_value]
-        return repeated_pattern_ids
-
-    # reference: https://stackoverflow.com/questions/7352684/how-to-find-the-groups-of-consecutive-elements-in-a-numpy-array
-    def GetConsecutiveGroups(self, data, stepsize=1):
-        return np.split(data, np.where(np.diff(data) != stepsize)[0] + 1)
-
-    def UpdateId2PatternTree(self, new_pattern_id, repeated_pattern_ids):
-        children = [
-            self.id2pattern_tree[pattern_id]
-            for pattern_id in repeated_pattern_ids.tolist()
-        ]
-        if all(isinstance(x, PrimitivePattern) for x in children):
-            new_pattern = TrivialPattern(children=children)
-        else:
-            new_pattern = TuplePattern(children=children)
-        self.id2pattern_tree[new_pattern_id] = new_pattern
-        return self.id2pattern_tree
-
-    def ReplacePatternId(
-        self,
-        new_pattern_id,
-        repeated_pattern_ids,
-    ):
-        ids = self.pattern_ids.tolist()
-        ids_str = "".join(f"{num:08d}" for num in ids)
-        repeated_pattern_ids = repeated_pattern_ids.tolist()
-        repeated_pattern_ids_str = "".join(f"{num:08d}" for num in repeated_pattern_ids)
-        new_pattern_id_str = f"{new_pattern_id:08d}"
-        replaced_ids_str = re.sub(repeated_pattern_ids_str, new_pattern_id_str, ids_str)
-        assert len(replaced_ids_str) % 8 == 0
-        replaced_ids = [
-            int(replaced_ids_str[i : (i + 8)])
-            for i in range(start=0, stop=len(replaced_ids_str), step=8)
-        ]
-        return np.array(replaced_ids, dtype=np.int32)
diff --git a/athena/version.py b/athena/version.py
index 811adb0..5d70c5c 100644
--- a/athena/version.py
+++ b/athena/version.py
@@ -3,6 +3,7 @@
 TYPE_CHECKING = False
 if TYPE_CHECKING:
     from typing import Tuple, Union
+
     VERSION_TUPLE = Tuple[Union[int, str], ...]
 else:
     VERSION_TUPLE = object
@@ -12,5 +13,5 @@
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.1.dev54+g64471b3.d20240712'
-__version_tuple__ = version_tuple = (0, 1, 'dev54', 'g64471b3.d20240712')
+__version__ = version = "0.1.dev54+g64471b3.d20240712"
+__version_tuple__ = version_tuple = (0, 1, "dev54", "g64471b3.d20240712")
diff --git a/setup.py b/setup.py
index a0cb9ea..67dabbc 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 python_version = platform.python_version()
 
 version_detail = sys.version_info
-version = str(version_detail[0]) + '.' + str(version_detail[1])
+version = str(version_detail[0]) + "." + str(version_detail[1])
 env_version = os.getenv("PY_VERSION", None)
 
 if version_detail < (3, 8):
diff --git a/tests/test_rp_expr_passes.py b/tests/test_rp_expr_passes.py
new file mode 100644
index 0000000..f8661d4
--- /dev/null
+++ b/tests/test_rp_expr_passes.py
@@ -0,0 +1,76 @@
+import unittest
+from athena.rp_expr.rp_expr import Tokenize
+from athena.rp_expr.rp_expr_passes import (
+    FlattenTokenListPass,
+    FoldTokensPass,
+    RecursiveFoldTokensPass,
+)
+
+
+class TestTokenize(unittest.TestCase):
+
+    def test_simple(self):
+        primitive_id_lists = [list(range(10 + i)) for i in range(5)]
+        token_list, id_allocator = Tokenize(primitive_id_lists)
+        self.assertEqual(len(token_list.tensors), len(primitive_id_lists))
+
+
+class TestFlattenTokenListPass(unittest.TestCase):
+
+    def test_simple(self):
+        base = 10
+        size = 5
+        primitive_id_lists = [list(range(base + i)) for i in range(size)]
+        token_list, id_allocator = Tokenize(primitive_id_lists)
+        rp_expr_pass = FlattenTokenListPass(id_allocator)
+        success, flattened_rp_expr_pass = rp_expr_pass(token_list)
+        self.assertTrue(success)
+        self.assertEqual(id_allocator.NextTokenId(), base + 2 * size - 1)
+
+
+class TestFoldTokensPass(unittest.TestCase):
+
+    def test_simple(self):
+        base = 3
+        size = 3
+        primitive_id_lists = [list(range(base + i)) for i in range(size)]
+        token_list, id_allocator = Tokenize(primitive_id_lists)
+        flatten_pass = FlattenTokenListPass(id_allocator)
+        _, flattened_rp_expr = flatten_pass(token_list)
+        fold_pass = FoldTokensPass(id_allocator)
+        success, fold_rp_expr = fold_pass(flattened_rp_expr.flattened_tensor)
+        self.assertTrue(success)
+        input = flattened_rp_expr.flattened_tensor.tensor.numpy().tolist()
+        pattern = fold_rp_expr.symbol_token_tensors[0].numpy().tolist()
+        replacement = fold_rp_expr.symbol_token_ids[0]
+        output = fold_rp_expr.body_rp_expr.tensor.numpy().tolist()
+        self.assertEqual(input, [3, 4, 5, 1, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7])
+        self.assertEqual(pattern, [3, 4, 5])
+        self.assertEqual(replacement, 8)
+        self.assertEqual(output, [8, 1, 8, 6, 2, 8, 6, 7])
+
+
+class TestRecursiveFoldTokensPass(unittest.TestCase):
+
+    def test_simple(self):
+        base = 3
+        size = 3
+        primitive_id_lists = [list(range(base + i)) for i in range(size)]
+        token_list, id_allocator = Tokenize(primitive_id_lists)
+        flatten_pass = FlattenTokenListPass(id_allocator)
+        _, flattened_rp_expr = flatten_pass(token_list)
+        fold_pass = RecursiveFoldTokensPass(id_allocator)
+        success, fold_rp_expr = fold_pass(flattened_rp_expr.flattened_tensor)
+        self.assertTrue(success)
+        input = flattened_rp_expr.flattened_tensor.tensor.numpy().tolist()
+        pattern = [x.numpy().tolist() for x in fold_rp_expr.symbol_token_tensors]
+        replacement = fold_rp_expr.symbol_token_ids
+        output = fold_rp_expr.body_rp_expr.tensor.numpy().tolist()
+        self.assertEqual(input, [3, 4, 5, 1, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7])
+        self.assertEqual(pattern, [[3, 4, 5], [8, 6]])
+        self.assertEqual(replacement, [8, 9])
+        self.assertEqual(output, [8, 1, 9, 2, 9, 7])
+
+
+if __name__ == "__main__":
+    unittest.main()
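
Reviewer note (not part of the patch): below is a minimal end-to-end sketch of how the new passes compose, mirroring tests/test_rp_expr_passes.py; the op-name strings are made up for illustration, and any hashable primitive ids work. As far as I can tell from the added code, FoldTokensPass detects repeats by gathering a random float64 embedding per token, hashing every window of up to max_windows_size tokens with a triangular random-weight conv1d (channel i hashes windows of length i + 1), reinterpreting the float64 activations as int64 exact-match hashes, and picking the (length, hash) pair maximizing (length - 1) * (occurrences - 1); Replace then substitutes the disjoint occurrences with a freshly allocated token id, and RecursiveFoldTokensPass repeats this until no pattern of length > 1 occurs at least twice.

```python
# Hypothetical usage sketch, assuming the modules land as in this patch.
from athena.rp_expr.rp_expr import Tokenize
from athena.rp_expr.rp_expr_passes import FlattenTokenListPass, RecursiveFoldTokensPass

# Illustrative primitive ids; in athena these would be op identities.
primitive_id_lists = [
    ["relu", "matmul", "add"],
    ["relu", "matmul", "add", "softmax"],
]
# Map primitives to dense int64 token ids, one tensor per list.
token_list, id_allocator = Tokenize(primitive_id_lists)
# Concatenate the tensors into one stream, with separator tokens between lists.
_, flattened = FlattenTokenListPass(id_allocator)(token_list)
# Repeatedly fold the most profitable repeated subsequence into a fresh symbol.
success, folded = RecursiveFoldTokensPass(id_allocator)(flattened.flattened_tensor)
if success:
    # Each symbol token id binds one repeated subsequence ("let" binding);
    # body_rp_expr is the flattened stream with those repeats replaced.
    for token_id, tensor in zip(folded.symbol_token_ids, folded.symbol_token_tensors):
        print(token_id, tensor.numpy().tolist())
    print(folded.body_rp_expr.tensor.numpy().tolist())
```

Here the shared prefix ["relu", "matmul", "add"] should fold into a single symbol, as in the TestRecursiveFoldTokensPass expectations above; the `if success:` guard matters because the fold passes return (False, input) when no pattern of length greater than one occurs at least twice.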