Merge pull request #5 from tc20042008/develop
Repeat Pattern Extractor
jiahy0825 authored Jul 26, 2024
2 parents 00a7a4f + 19e58dc commit 3f051aa
Showing 6 changed files with 320 additions and 198 deletions.
2 changes: 1 addition & 1 deletion athena/generators/paddle_func_body_generator.py
@@ -118,7 +118,7 @@ def get_stmts_pd_op_while(
     f"ATHENA_WHILE_LOOP_LIMIT = os.getenv('ATHENA_WHILE_LOOP_LIMIT')"
 ),
 self.Indent0(
-    f"kWhileLoopLimit = (128 if ATHENA_WHILE_LOOP_LIMIT is not None else int(ATHENA_WHILE_LOOP_LIMIT))"
+    f"kWhileLoopLimit = (128 if ATHENA_WHILE_LOOP_LIMIT is None else int(ATHENA_WHILE_LOOP_LIMIT))"
 ),
 self.Indent0(f"while_loop_counter_{op.op_id} = 0"),
 self.Indent0(f"while {cond.name}:"),
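For context, the old generated guard called int() on ATHENA_WHILE_LOOP_LIMIT exactly when the variable was missing; the fix flips the condition so 128 becomes the default. A minimal sketch of what the corrected generator output evaluates to, assembled from the f-strings above (not the generator's full output):

import os

# Default the loop limit to 128 when the environment variable is unset,
# otherwise parse the user-supplied value.
ATHENA_WHILE_LOOP_LIMIT = os.getenv('ATHENA_WHILE_LOOP_LIMIT')
kWhileLoopLimit = (128 if ATHENA_WHILE_LOOP_LIMIT is None else int(ATHENA_WHILE_LOOP_LIMIT))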
85 changes: 58 additions & 27 deletions athena/rp_expr/rp_expr.py
@@ -1,8 +1,12 @@
 from dataclasses import dataclass
 import typing as t
 import numpy as np
+import paddle
+from collections import defaultdict

-PatternId = int
+PrimitiveId = t.TypeVar("PrimitiveId")

+TokenId = int
+

 # Repeat Pattern Expression
@@ -12,43 +16,70 @@ class RpExpr:


 @dataclass
-class PrimitiveRpExpr(RpExpr):
-    primitive_tensor: np.ndarray["N", np.int32]
+class ListRpExpr(RpExpr):
+    pass


 @dataclass
-class FoldRpExpr(RpExpr):
-    fold_tensor: np.ndarray["N", np.int32]
+class TokenizedRpExpr(RpExpr):
+    token_id2primitive_id: t.List[PrimitiveId]
+    token_tensors: ListRpExpr


 @dataclass
-class LetsListRpExpr(RpExpr):
-    symbol_pattern_ids: t.List[PatternId]
-    symbol_rp_exprs: t.Union[t.List[PrimitiveRpExpr], "LetsListRpExpr"]
-    body_rp_expr: t.List[FoldRpExpr]
+class TokenRpExpr(RpExpr):
+    pass


 @dataclass
-class LetsRpExpr(RpExpr):
-    symbol_pattern_ids: t.List[PatternId]
-    symbol_rp_exprs: t.Union[t.List[PrimitiveRpExpr], LetsListRpExpr]
-    body_rp_expr: t.Union[FoldRpExpr, "LetsRpExpr"]
+class NaiveTokenListRpExpr(ListRpExpr):
+    tensors: t.List[np.ndarray["N", np.int64]]


-class PatternIdAllocator:
-    def __init__(self, next_pattern_id: int):
-        self.next_pattern_id = next_pattern_id
+@dataclass
+class FlattenedTokenListRpExpr(ListRpExpr):
+    tensor_list_size: int
+    flattened_tensor: TokenRpExpr

-    def NewPatternId(self):
-        value = self.next_pattern_id
-        self.next_pattern_id += 1
-        return value
+
+@dataclass
+class NaiveTokenRpExpr(TokenRpExpr):
+    tensor: np.ndarray["N", np.int64]


+@dataclass
+class LetsTokenRpExpr(TokenRpExpr):
+    symbol_token_ids: t.List[TokenId]
+    symbol_token_tensors: t.List[NaiveTokenRpExpr]
+    body_rp_expr: NaiveTokenRpExpr
+
+
+class TokenIdAllocator:
+    def __init__(self, next_token_id: int = 0):
+        self.next_token_id = next_token_id
+
+    def NewTokenId(self):
+        value = self.next_token_id
+        self.next_token_id += 1
+        return value
+
-def TrivialParse(
-    primitive_ids: t.List[int],
-) -> t.Tuple[PrimitiveRpExpr, PatternIdAllocator]:
-    primitive_tensor = np.array(primitive_ids, dtype=np.int32)
-    min_pattern_id = int(np.min(primitive_tensor))
-    assert min_pattern_id >= 0
-    return PrimitiveRpExpr(primitive_tensor), PatternIdAllocator(min_pattern_id + 1)
+    def NextTokenId(self):
+        return self.next_token_id
+
+    def Skip(self, size):
+        self.next_token_id += size
+
+
+def Tokenize(
+    primitive_id_lists: t.List[t.List[PrimitiveId]],
+) -> t.Tuple[TokenizedRpExpr, TokenIdAllocator]:
+    token_id_allocator = TokenIdAllocator()
+    primitive_id2token_id = defaultdict(token_id_allocator.NewTokenId)
+    token_tensors = [
+        paddle.to_tensor(
+            [primitive_id2token_id[primitive_id] for primitive_id in primitive_id_list],
+            paddle.int64,
+        )
+        for primitive_id_list in primitive_id_lists
+    ]
+    return NaiveTokenListRpExpr(token_tensors), token_id_allocator
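A hedged usage sketch of the new Tokenize helper; the string primitive ids below are hypothetical, and note that the function is annotated as returning a TokenizedRpExpr while the value it builds is a NaiveTokenListRpExpr:

# Each distinct primitive id gets the next free token id from the defaultdict,
# so repeated ids map to the same token across all input lists.
token_list, allocator = Tokenize([["relu", "matmul", "relu"], ["matmul"]])
# token_list.tensors      -> two paddle int64 tensors: [0, 1, 0] and [1]
# allocator.NextTokenId() -> 2 (the first id still free for later passes)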
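The LetsTokenRpExpr added above is presumably where extracted repeat patterns end up; a speculative illustration of that shape, with assumed binding semantics and made-up numbers (not taken from this commit):

import numpy as np

# Assumed semantics: each symbol token id names a repeated subsequence, and the
# body refers to those subsequences by their symbol ids. For the token stream
# [0, 1, 2, 0, 1, 2, 3], binding the repeat [0, 1, 2] to a fresh token id 4
# might be written as:
lets = LetsTokenRpExpr(
    symbol_token_ids=[4],
    symbol_token_tensors=[NaiveTokenRpExpr(np.array([0, 1, 2], dtype=np.int64))],
    body_rp_expr=NaiveTokenRpExpr(np.array([4, 4, 3], dtype=np.int64)),
)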
