Failed to load data in trl 0.7.8/0.7.9. #1216

Closed
xkszltl opened this issue Jan 11, 2024 · 16 comments · Fixed by #1229

Comments

@xkszltl

xkszltl commented Jan 11, 2024

This is a new regression introduced in trl 0.7.8 (and 0.7.9); 0.7.7 is fine.

We run into ValueError: too many dimensions 'str' when loading data into the trainer.
Here's a simple LLaMA-2 + LoRA fine-tuning on the IMDB dataset as a minimal repro:

#!/usr/bin/env python3

import datasets
import peft
import transformers
import trl


model_dir = "models/Llama-2-7b-hf"

tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = transformers.AutoModelForCausalLM.from_pretrained(model_dir)

ds_train = datasets.load_dataset("imdb", split="train[:10]")

trainer = trl.SFTTrainer(
    model=model,
    args=transformers.TrainingArguments(
        output_dir="output",
        max_steps=1,
        remove_unused_columns=False,
    ),
    peft_config=peft.LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=8,
        bias="none",
        task_type="Causal_LM",
    ),
    train_dataset=ds_train,
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=8,
)
trainer.train()

0.7.7 works:

# CUDA_VISIBLE_DEVICES=0 ./test.py 
/usr/lib/python3/dist-packages/requests/__init__.py:87: RequestsDependencyWarning: urllib3 (2.1.0) or chardet (4.0.0) doesn't match a supported version!
  warnings.warn("urllib3 ({}) or chardet ({}) doesn't match a supported "
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.52s/it]
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  0%|                                                                                                                                                                                                                                 | 0/1 [00:00<?, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.24it/s]Attempted to log scalar metric train_runtime:
0.8097
Attempted to log scalar metric train_samples_per_second:
9.88
Attempted to log scalar metric train_steps_per_second:
1.235
Attempted to log scalar metric total_flos:
2538830561280.0
Attempted to log scalar metric train_loss:
4.124451637268066
Attempted to log scalar metric epoch:
0.5
{'train_runtime': 0.8097, 'train_samples_per_second': 9.88, 'train_steps_per_second': 1.235, 'train_loss': 4.124451637268066, 'epoch': 0.5}                                                                                                                 
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.24it/s]

0.7.8 failed:

# CUDA_VISIBLE_DEVICES=0 ./test.py 
/usr/lib/python3/dist-packages/requests/__init__.py:87: RequestsDependencyWarning: urllib3 (2.1.0) or chardet (4.0.0) doesn't match a supported version!
  warnings.warn("urllib3 ({}) or chardet ({}) doesn't match a supported "
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.54s/it]
Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 940.17 examples/s]
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  0%|                                                                                                                                                                                                                                 | 0/1 [00:00<?, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 748, in convert_to_tensors
    tensor = as_tensor(value)
  File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 720, in as_tensor
    return torch.tensor(value)
ValueError: too many dimensions 'str'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "./test.py", line 38, in <module>
    trainer.train()
  File "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py", line 317, in train
    output = super().train(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1821, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/usr/local/lib/python3.10/dist-packages/accelerate/data_loader.py", line 448, in __iter__
    current_batch = next(dataloader_iter)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.10/dist-packages/transformers/data/data_collator.py", line 45, in __call__
    return self.torch_call(features)
  File "/usr/local/lib/python3.10/dist-packages/transformers/data/data_collator.py", line 732, in torch_call
    batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
  File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 3299, in pad
    return BatchEncoding(batch_outputs, tensor_type=return_tensors)
  File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 223, in __init__
    self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)
  File "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py", line 764, in convert_to_tensors
    raise ValueError(
ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`text` in this case) have excessive nesting (inputs type `list` where type `int` is expected).
  0%|          | 0/1 [00:00<?, ?it/s]
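
The immediate failure is torch.tensor being called on the raw strings from the "text" column. A tiny standalone snippet reproduces the same error (assuming any recent torch build):

import torch

# Feeding raw strings where token ids are expected triggers the same error
# as in the traceback above.
torch.tensor(["some raw review text"])   # ValueError: too many dimensions 'str'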
@younesbelkada
Contributor

Thanks for the repro @xkszltl, I'll try to reproduce, fix the issue, and make another patch release.

@younesbelkada
Contributor

Hi @xkszltl
It seems to be caused by remove_unused_columns=False. Can you meanwhile either revert to 0.7.7 or set remove_unused_columns=True? I'll work on the proper fix in the meantime.

@xkszltl
Author

xkszltl commented Jan 11, 2024

True doesn't work at all (regardless of the version); that's why it's set to False in the first place.
My understanding is that it drops the "text" column because the model only exposes "input_ids" on its interface? (See the sketch after the traceback below.)

Traceback (most recent call last):
  File "./test.py", line 58, in main
    trainer.train()
  File "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py", line 315, in train
    output = super().train(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1821, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/usr/local/lib/python3.10/dist-packages/accelerate/data_loader.py", line 451, in __iter__
    current_batch = next(dataloader_iter)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
  File "/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py", line 2804, in __getitems__
    batch = self.__getitem__(keys)
  File "/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py", line 2800, in __getitem__
    return self._getitem(key)
  File "/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py", line 2784, in _getitem
    pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
  File "/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py", line 583, in query_table
    _check_valid_index_key(key, size)
  File "/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py", line 536, in _check_valid_index_key
    _check_valid_index_key(int(max(key)), size=size)
  File "/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py", line 526, in _check_valid_index_key
    raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")
IndexError: Invalid key: 23887 is out of bounds for size 0
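
Roughly what I imagine happens (a hypothetical sketch with made-up names, not the actual Trainer code):

import inspect

def drop_unused_columns(dataset, model):
    # Keep only columns whose names match arguments of the model's forward pass.
    accepted = set(inspect.signature(model.forward).parameters)  # e.g. {"input_ids", "attention_mask", ...}
    unused = [col for col in dataset.column_names if col not in accepted]
    # With an untokenized dataset the only columns are "text" and "label",
    # so everything gets dropped and the dataset ends up with size 0,
    # which matches the IndexError above.
    return dataset.remove_columns(unused)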

@xkszltl
Author

xkszltl commented Jan 11, 2024

I'm currently pinning to trl<0.7.8.
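
Concretely, the pin is just (standard pip syntax):

pip install "trl<0.7.8"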

@younesbelkada
Contributor

Hi @xkszltl
Thanks for your patience! I had a deeper look at the issue and made #1229, which should resolve it.
Regarding #1216 (comment): can you try updating datasets? I cannot repro with:

import datasets
import peft
import transformers
import trl


model_dir = "HuggingFaceM4/tiny-random-LlamaForCausalLM"

tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = transformers.AutoModelForCausalLM.from_pretrained(model_dir)

ds_train = datasets.load_dataset("imdb", split="train[:10]")

trainer = trl.SFTTrainer(
    model=model,
    args=transformers.TrainingArguments(
        output_dir="output",
        max_steps=1,
        remove_unused_columns=True,
    ),
    peft_config=peft.LoraConfig(
        lora_alpha=16,
        lora_dropout=0.1,
        r=8,
        bias="none",
        task_type="Causal_LM",
    ),
    train_dataset=ds_train,
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=8,
)
trainer.train()

@xkszltl
Author

xkszltl commented Jan 15, 2024

Are you trying with master or a release?

@younesbelkada
Contributor

@xkszltl on master currently

@younesbelkada
Contributor

@xkszltl can you try and let me know how it goes?

@xkszltl
Author

xkszltl commented Jan 15, 2024

I only tried the released wheel, so that may be the reason.
I can give master a try after that PR is merged.

@younesbelkada
Contributor

I see, ok! If you want, you can install directly from that branch:

pip install -U git+https://github.com/huggingface/trl.git@fix-breaking-change
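
You can double-check that the branch actually got picked up; the dev build should report a .dev version:

pip show trl    # the Version field should show a .dev build for the branch install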

@xkszltl
Author

xkszltl commented Jan 16, 2024

Still repros on the branch, and I'm using a different dataset this time, not just imdb.

@xkszltl
Author

xkszltl commented Jan 16, 2024

In case versions matter:

Name: accelerate
Version: 0.26.0
Summary: Accelerate
Home-page: https://github.com/huggingface/accelerate
Author: The HuggingFace team
Author-email: [email protected]
License: Apache
Location: /usr/local/lib/python3.10/dist-packages
Requires: huggingface-hub, numpy, packaging, psutil, pyyaml, safetensors, torch
Required-by: peft, trl
---
Name: datasets
Version: 2.16.1
Summary: HuggingFace community-driven open-source library of datasets
Home-page: https://github.com/huggingface/datasets
Author: HuggingFace Inc.
Author-email: [email protected]
License: Apache 2.0
Location: /usr/local/lib/python3.10/dist-packages
Requires: aiohttp, dill, filelock, fsspec, huggingface-hub, multiprocess, numpy, packaging, pandas, pyarrow, pyarrow-hotfix, pyyaml, requests, tqdm, xxhash
Required-by: trl
---
Name: torch
Version: 2.1.2
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: [email protected]
License: BSD-3
Location: /usr/local/lib/python3.10/dist-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-nccl-cu12, nvidia-nvtx-cu12, sympy, triton, typing-extensions
Required-by: accelerate, peft, trl
---
Name: transformers
Version: 4.36.2
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: [email protected]
License: Apache 2.0 License
Location: /usr/local/lib/python3.10/dist-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: peft, trl
---
Name: trl
Version: 0.7.10.dev0
Summary: Train transformer language models with reinforcement learning.
Home-page: https://github.com/huggingface/trl
Author: Leandro von Werra
Author-email: [email protected]
License: Apache 2.0
Location: /usr/local/lib/python3.10/dist-packages
Requires: accelerate, datasets, numpy, torch, transformers, tyro
Required-by: 

@younesbelkada
Contributor

@xkszltl I am using the same library versions as you and was not able to repro. Did you run this script: #1216 (comment)?

@xkszltl
Author

xkszltl commented Jan 16, 2024

# CUDA_VISIBLE_DEVICES=0 ./try.py         
/usr/lib/python3/dist-packages/requests/__init__.py:87: RequestsDependencyWarning: urllib3 (2.1.0) or chardet (4.0.0) doesn't match a supported version!
  warnings.warn("urllib3 ({}) or chardet ({}) doesn't match a supported "
tokenizer_config.json: 100%|████████████████████████████████████████████████████████████████| 771/771 [00:00<00:00, 5.36MB/s]
tokenizer.model: 100%|█████████████████████████████████████████████████████████████████████| 500k/500k [00:00<00:00, 789kB/s]
tokenizer.json: 100%|███████████████████████████████████████████████████████████████████| 1.84M/1.84M [00:00<00:00, 4.61MB/s]
special_tokens_map.json: 100%|██████████████████████████████████████████████████████████████| 552/552 [00:00<00:00, 4.44MB/s]
config.json: 100%|██████████████████████████████████████████████████████████████████████████| 466/466 [00:00<00:00, 3.74MB/s]
pytorch_model.bin: 100%|████████████████████████████████████████████████████████████████| 2.07M/2.07M [00:00<00:00, 10.1MB/s]
generation_config.json: 100%|███████████████████████████████████████████████████████████████| 138/138 [00:00<00:00, 1.06MB/s]
Downloading readme: 100%|███████████████████████████████████████████████████████████████| 7.81k/7.81k [00:00<00:00, 38.4MB/s]
Downloading data: 100%|█████████████████████████████████████████████████████████████████| 21.0M/21.0M [00:03<00:00, 6.42MB/s]
Downloading data: 100%|█████████████████████████████████████████████████████████████████| 20.5M/20.5M [00:03<00:00, 6.45MB/s]
Downloading data: 100%|█████████████████████████████████████████████████████████████████| 42.0M/42.0M [00:05<00:00, 7.14MB/s]
Generating train split: 100%|███████████████████████████████████████████████| 25000/25000 [00:00<00:00, 192120.43 examples/s]
Generating test split: 100%|████████████████████████████████████████████████| 25000/25000 [00:00<00:00, 205666.45 examples/s]
Generating unsupervised split: 100%|████████████████████████████████████████| 50000/50000 [00:00<00:00, 221609.66 examples/s]
Map: 100%|██████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 1023.63 examples/s]
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  0%|                                                                                                  | 0/1 [00:00<?, ?it/s]Traceback (most recent call last):
  File "./try.py", line 38, in <module>
    trainer.train()
  File "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py", line 330, in train
    output = super().train(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1821, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/usr/local/lib/python3.10/dist-packages/accelerate/data_loader.py", line 451, in __iter__
    current_batch = next(dataloader_iter)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
  File "/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py", line 2805, in __getitems__
    batch = self.__getitem__(keys)
  File "/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py", line 2801, in __getitem__
    return self._getitem(key)
  File "/usr/local/lib/python3.10/dist-packages/datasets/arrow_dataset.py", line 2785, in _getitem
    pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
  File "/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py", line 583, in query_table
    _check_valid_index_key(key, size)
  File "/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py", line 536, in _check_valid_index_key
    _check_valid_index_key(int(max(key)), size=size)
  File "/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py", line 526, in _check_valid_index_key
    raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")
IndexError: Invalid key: 9 is out of bounds for size 0
  0%|          | 0/1 [00:00<?, ?it/s]

@xkszltl
Author

xkszltl commented Jan 16, 2024

This is the output from that script.
Everything is very fresh because it runs in a Docker container; you can see both the model and the dataset being downloaded in this run, not even served from cache.

@xkszltl
Author

xkszltl commented Jan 16, 2024

And I've seen others talking about something similar.
