test-chunking.py
from typing import List

import sys

from tiktoken import get_encoding


# Placeholder implementation, kept for reference:
# def count_tokens(text: str) -> int:
#     return len(text.split())


def count_tokens(text: str) -> int:
    """Return the number of tokens in `text` under the GPT-2 encoding."""
    enc = get_encoding("gpt2")
    tokens = enc.encode(text)
    return len(tokens)

def split_text_into_batches(
    text: str,
    batch_size_in_tokens: int = 10,
) -> List[str]:
    """Split `text` into batches of whole lines, each holding at most
    `batch_size_in_tokens` tokens."""
    lines = text.split("\n")
    batches = []
    current_batch = ""
    current_batch_tokens = 0
    for index, line in enumerate(lines, start=1):
        line_tokens = count_tokens(line + "\n")
        print("Line", index, line_tokens, line)
        # A single line longer than the batch size can never fit in any batch.
        if line_tokens > batch_size_in_tokens:
            print(
                f"Error: Line exceeds the batch size of {batch_size_in_tokens} tokens."
            )
            print("Line:", line)
            print("Tokens:", line_tokens)
            sys.exit(1)
        if current_batch_tokens + line_tokens <= batch_size_in_tokens:
            current_batch += line + "\n"
            current_batch_tokens += line_tokens
        else:
            # The line does not fit: close the current batch and start a new one.
            batches.append(current_batch.strip())
            current_batch = line + "\n"
            current_batch_tokens = line_tokens
    if current_batch.strip():
        batches.append(current_batch.strip())
    return batches

# Test the split_text_into_batches function
sample_text = """
This is the first line.
This is the second line, which is longer.
Third line.
Fourth line, even longer than the second.
Fifth and final line.
"""
print("Original text:")
print(sample_text)
batch_size = 20

batches = split_text_into_batches(
    sample_text,
    batch_size,
)

print(f"\nBatches (batch_size={batch_size}):")
for i, batch in enumerate(batches, start=1):
    print(f"Batch {i}:")
    print(batch)
    print("Tokens:", count_tokens(batch))
    print()
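
# ---------------------------------------------------------------------------
# A minimal sketch of an overlap-aware variant, not part of the original
# script. It assumes "overlap" means carrying the last `overlap_lines` lines
# of a finished batch into the start of the next one; the function name
# `split_text_into_batches_with_overlap` and the `overlap_lines` parameter
# are illustrative, and the sketch reuses count_tokens defined above.
def split_text_into_batches_with_overlap(
    text: str,
    batch_size_in_tokens: int = 10,
    overlap_lines: int = 2,
) -> List[str]:
    batches: List[str] = []
    current_lines: List[str] = []
    current_tokens = 0
    for line in text.split("\n"):
        line_tokens = count_tokens(line + "\n")
        if line_tokens > batch_size_in_tokens:
            raise ValueError(
                f"Line exceeds the batch size of {batch_size_in_tokens} tokens: {line!r}"
            )
        if current_tokens + line_tokens > batch_size_in_tokens and current_lines:
            batches.append("\n".join(current_lines).strip())
            # Seed the next batch with the last `overlap_lines` lines, but drop
            # the overlap if the incoming line would not fit alongside it.
            carried = current_lines[-overlap_lines:] if overlap_lines > 0 else []
            carried_tokens = sum(count_tokens(c + "\n") for c in carried)
            if carried_tokens + line_tokens > batch_size_in_tokens:
                carried, carried_tokens = [], 0
            current_lines, current_tokens = list(carried), carried_tokens
        current_lines.append(line)
        current_tokens += line_tokens
    if current_lines and "\n".join(current_lines).strip():
        batches.append("\n".join(current_lines).strip())
    return batches


# Example usage of the sketch (uncomment to try):
# for i, batch in enumerate(split_text_into_batches_with_overlap(sample_text, 20, 2), start=1):
#     print(f"Overlap batch {i}:", repr(batch))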