Improved training raw string chunking logic #3476

Open · wants to merge 6 commits into main
modules/training.py — 80 changes: 50 additions & 30 deletions
@@ -319,18 +319,20 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch

    def encode(text, add_bos_token):
        result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)

        # Check if the first two tokens are BOS
        if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]:
            result = result[1:]

        if not add_bos_token and result[0] == shared.tokenizer.bos_token_id:
            result = result[1:]

        return result

    def tokenize(prompt, append_eos_token=False):
    def tokenize(prompt, append_eos_token=False, prepend_bos_token=True):

        if train_only_after == '' or train_only_after not in prompt:
            input_ids = encode(prompt, True)
            input_ids = encode(prompt, prepend_bos_token)

            if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len:
                input_ids.append(shared.tokenizer.eos_token_id)
@@ -340,7 +342,7 @@ def tokenize(prompt, append_eos_token=False):

        else:
            ind = prompt.index(train_only_after) + len(train_only_after)
            before_tokens = encode(prompt[:ind], True)
            before_tokens = encode(prompt[:ind], prepend_bos_token)
            after_tokens = encode(prompt[ind:], False)

            if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
@@ -385,8 +387,13 @@ def tokenize(prompt, append_eos_token=False):
            raw_text = file.read().replace('\r', '')

        cut_string = hard_cut_string.replace('\\n', '\n')
        newline_token = set(shared.tokenizer.encode('\n')[1:])

        eos_added = 0
        out_tokens = []
        if add_eos_token and shared.tokenizer.eos_token_id == shared.tokenizer.bos_token_id:
            logger.warning("EOS and BOS tokens are identical when adding EOS tokens. Check model config.")

        for text_part in raw_text.split(cut_string):
            if len(text_part.strip()) <= min_chars:
                continue
@@ -401,19 +408,16 @@ def tokenize(prompt, append_eos_token=False):
yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})"
return

out_tokens.extend(split_chunks(tokens, cutoff_len, step))
out_tokens.extend(split_chunks(tokens, cutoff_len, step, newline_favor_len, newline_token))

if eos_added > 0:
print(f"EOS added to {eos_added} text blocks")
logger.info(f"EOS token added to {eos_added} text blocks")

del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
text_chunks = [shared.tokenizer.decode(x) for x in out_tokens]

train_data = Dataset.from_list(out_tokens)
del out_tokens
if newline_favor_len > 0:
text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]

train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
del text_chunks
eval_data = None
else:
if dataset in ['None', '']:
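
(Reviewer note, not part of the diff: split_chunks() now yields ready-made feature dicts, so the raw-text path can pass them straight to Dataset.from_list() instead of decoding and re-tokenizing text chunks. A minimal sketch of that hand-off, with toy values and plain lists standing in for the token ids and torch tensors the real code builds:)

from datasets import Dataset

# One training example in the shape the new split_chunks() yields (values are illustrative)
sample = {
    "input_ids": [5, 6, 7, 8],
    "labels": [1, 1, 1, 1],
    "attention_mask": [1, 1, 1, 1],
}
train_data = Dataset.from_list([sample, sample])
print(train_data)  # 2 rows with columns input_ids / labels / attention_mask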
@@ -700,28 +704,44 @@ def threaded_run():
yield f"Done! LoRA saved to `{lora_file_path}`.\n\nBefore testing your new LoRA, make sure to first reload the model, as it is currently dirty from training."


def split_chunks(arr, size, step):
    for i in range(0, len(arr), step):
        yield arr[i:i + size]


def cut_chunk_for_newline(chunk: str, max_length: int):
    if '\n' not in chunk:
        return chunk

    first_newline = chunk.index('\n')
    if first_newline < max_length:
        chunk = chunk[first_newline + 1:]

    if '\n' not in chunk:
        return chunk

    last_newline = chunk.rindex('\n')
    if len(chunk) - last_newline < max_length:
        chunk = chunk[:last_newline]
def split_chunks(arr, size: int, step: int, max_newline_length: int, newline_tokens: set):
    while len(arr) > 1 and arr[1] in newline_tokens:  # Skip the first token, which will be <BOS>
        del arr[1]
    while arr and arr[-1] in newline_tokens:
        del arr[-1]
    num_tokens = len(arr)
    split_end = num_tokens - size + step  # Don't split in the last overlap
    if split_end < 0:
        split_end = num_tokens

    split_starts = list(range(0, split_end, step))
    for index in range(1, len(split_starts)):  # First split always starts at 0
        if split_starts[index] + size > num_tokens:
            split_starts[index] = num_tokens - size + 1

        if max_newline_length > 0 and newline_tokens.intersection(arr[split_starts[index]:split_starts[index] + max_newline_length]):
            first_newline = end_first_block(arr[split_starts[index]:], newline_tokens)
            split_starts[index] += first_newline

    labels = [1] * size
    for i in split_starts:
        input_ids = arr[i:i + size]
        input_ids = [shared.tokenizer.pad_token_id] * (size - len(input_ids)) + input_ids
        input_ids = torch.tensor(input_ids)
        yield {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": input_ids.ne(shared.tokenizer.pad_token_id),
        }

    return chunk

def end_first_block(arr: list, tokens: set):
    offset = 0
    while arr[offset] not in tokens:
        offset += 1
    while arr[offset] in tokens:
        offset += 1
    return offset

def format_time(seconds: float):
    if seconds < 120:
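
Appendix (reviewer sketch, not part of the diff): a minimal standalone script to sanity-check the new window-placement logic outside the webui. The helper names mirror the PR, but the integer token ids, the use of 0 as the newline token, and the chunk_starts() harness are illustrative assumptions only.

def end_first_block(arr, tokens):
    # Offset just past the first run of `tokens` (same idea as the PR's helper)
    offset = 0
    while arr[offset] not in tokens:
        offset += 1
    while arr[offset] in tokens:
        offset += 1
    return offset


def chunk_starts(arr, size, step, max_newline_length, newline_tokens):
    # Reproduces only the start-index selection from the new split_chunks()
    num_tokens = len(arr)
    split_end = num_tokens - size + step  # don't start a window inside the final overlap
    if split_end < 0:
        split_end = num_tokens

    starts = list(range(0, split_end, step))
    for i in range(1, len(starts)):  # the first window always starts at 0
        if starts[i] + size > num_tokens:
            starts[i] = num_tokens - size + 1  # pull the last start back towards the end (the real code pads any shortfall)
        if max_newline_length > 0 and newline_tokens.intersection(arr[starts[i]:starts[i] + max_newline_length]):
            starts[i] += end_first_block(arr[starts[i]:], newline_tokens)  # shift the start to just past the next newline run
    return starts


if __name__ == "__main__":
    tokens = [1, 2, 3, 0, 4, 5, 6, 7, 8, 0, 9, 10, 11, 12, 13, 14]  # 0 stands in for the newline token
    print(chunk_starts(tokens, size=6, step=4, max_newline_length=3, newline_tokens={0}))
    # -> [0, 4, 10, 11]: the third start snaps past the newline at index 9,
    #    and the last start is pulled back from 12 to 11.

Each start index is then turned into a left-padded input_ids window with matching labels and attention_mask, as the new split_chunks() above does.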