
get_wikitext2 has bug #2020

Open

alex-ber opened this issue Sep 10, 2024 · 2 comments
Labels
bug Something isn't working

Comments


alex-ber commented Sep 10, 2024

System Info

optimum version 1.21.4 (latest)
# Use the official Python image from the Docker Hub
FROM public.ecr.aws/docker/library/python:3.10-slim

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction (minimal, reproducible, runnable)

from optimum.gptq.data import get_wikitext2
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
get_wikitext2(tokenizer=tokenizer, nsamples=128, seqlen=32, split="train")

This produces the warning:

Token indices sequence length is longer than the specified maximum sequence length for this model (73218 > 2048). Running this sequence through the model will result in indexing errors
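For context, my understanding (based on the line that the fix below comments out) is that the current implementation joins the first 1000 rows and tokenizes them in a single call, which is what produces this warning. A minimal sketch of that assumed behaviour:

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

# Assumed current behaviour: join the first 1000 rows and tokenize them in one call.
# The resulting single sequence (~73k tokens) exceeds the model max length (2048),
# which is exactly what the warning above complains about.
text = "".join([" \n" if s == "" else s for s in data["text"][:1000]])
enc = tokenizer(text, return_tensors="pt")
print(enc.input_ids.shape)  # one very long sequence, far above the 2048-token limit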

Expected behavior

This is the proposed fix:

import random
from typing import Any

import torch
from datasets import load_dataset


def get_wikitext2(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
    if split == "train":
        data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    elif split == "validation":
        data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    ## length of 288059 should be enough
    # text = "".join([" \n" if s == "" else s for s in data["text"][:1000]])

    dataset = []
    for _ in range(nsamples):
        # Resample individual rows until one tokenizes to at least `seqlen` tokens,
        # then encode only that row, so no over-long sequence is ever produced.
        while True:
            i = random.randint(0, len(data) - 1)
            text = data[i]["text"]
            if len(tokenizer.tokenize(text)) >= seqlen:
                enc = tokenizer(text, return_tensors="pt")
                break
        # Take a random window of exactly `seqlen` tokens from the encoded row.
        i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = enc.input_ids[:, i:j]
        attention_mask = torch.ones_like(inp)
        dataset.append({"input_ids": inp, "attention_mask": attention_mask})
    return dataset

Inspired by `get_c4` and `get_c4_new`.

With the proposed fix, no warning is produced.
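A quick sanity check of the proposed fix (shapes only, model name as in the reproduction above):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
samples = get_wikitext2(tokenizer=tokenizer, nsamples=4, seqlen=32, split="train")

# Every sample is exactly `seqlen` tokens wide, so nothing ever exceeds the model limit.
for sample in samples:
    assert sample["input_ids"].shape == (1, 32)
    assert sample["attention_mask"].shape == (1, 32)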

@alex-ber added the bug (Something isn't working) label on Sep 10, 2024
@alex-ber changed the title from "get_wikitext2 has bug" to "get_wikitext2, get_c4, get_c4_new has bug" on Sep 10, 2024
@alex-ber changed the title back to "get_wikitext2 has bug" on Sep 10, 2024
IlyasMoutawwakil (Member) commented

@SunMarc is there a reason why get_wikitext2 is different from the other methods?

SunMarc (Member) commented Sep 11, 2024

Not sure. This was something TheBloke coded back then. Maybe this is because data[i]["text"] is pretty long, so it takes a while to find a text < seqlen?

Token indices sequence length is longer than the specified maximum sequence length for this model (73218 > 2048). Running this sequence through the model will result in indexing errors

The indexing errors do not actually happen, as we slice the tokenized data afterwards:

        i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = enc.input_ids[:, i:j]
        attention_mask = torch.ones_like(inp)
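To illustrate with stand-in numbers (not an actual run): even if the full tokenization is 73218 ids long, only a seqlen-sized slice ever reaches the model:

import random

import torch

seqlen = 2048
input_ids = torch.arange(73218).unsqueeze(0)  # stand-in for enc.input_ids

# Pick a random window of `seqlen` ids, exactly as in the snippet above.
i = random.randint(0, input_ids.shape[1] - seqlen - 1)
inp = input_ids[:, i : i + seqlen]
print(inp.shape)  # torch.Size([1, 2048]) -- within the model's maximum length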
