Improves the doc for NeuronTrainingArguments #794

Merged: 3 commits, Feb 26, 2025
4 changes: 4 additions & 0 deletions docs/source/package_reference/trainer.mdx
@@ -45,6 +45,10 @@ class CustomNeuronTrainer(NeuronTrainer):

Another way to customize the training loop behavior for the PyTorch [`NeuronTrainer`](https://huggingface.co/docs/optimum/neuron/package_reference/trainer#optimum.neuron.NeuronTrainer) is to use [callbacks](https://huggingface.co/docs/transformers/main_classes/callback) that can inspect the training loop state (for progress reporting, logging on TensorBoard or other ML platforms...) and take decisions (like early stopping).
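As a minimal, self-contained sketch of the callback idea described above (this is a toy protocol, not the actual Transformers or optimum-neuron callback API — the class and function names below are hypothetical):

```python
# Illustrative sketch only: a callback that inspects training-loop state
# (here, the evaluation loss per epoch) and decides when to stop early.

class ToyEarlyStoppingCallback:
    def __init__(self, patience: int = 2):
        self.patience = patience
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def on_evaluate(self, eval_loss: float) -> bool:
        """Return True if training should stop."""
        if eval_loss < self.best_loss:
            self.best_loss = eval_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def train(epoch_losses, callback):
    # Toy training loop: ask the callback after each "evaluation".
    for epoch, loss in enumerate(epoch_losses):
        if callback.on_evaluate(loss):
            return epoch  # stopped early at this epoch
    return len(epoch_losses) - 1


stopped_at = train([0.9, 0.8, 0.85, 0.9, 0.7], ToyEarlyStoppingCallback(patience=2))
print(stopped_at)  # → 3 (two epochs without improvement after epoch 1)
```

The real callback interface exposes much richer state (logs, control flags, trainer internals), but the decision-making pattern is the same.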

## NeuronTrainingArguments

[[autodoc]] training_args.NeuronTrainingArguments

## NeuronTrainer

[[autodoc]] trainers.NeuronTrainer
18 changes: 11 additions & 7 deletions docs/source/quickstart.mdx
@@ -20,28 +20,32 @@ limitations under the License.

## Training

There are three main classes one needs to know:

- NeuronArgumentParser: inherits from the original [HfArgumentParser](https://huggingface.co/docs/transformers/main/en/internal/trainer_utils#transformers.HfArgumentParser) in Transformers, adding checks on the argument values to make sure that they will work well with AWS Trainium instances.
- [NeuronTrainingArguments](https://huggingface.co/docs/optimum/neuron/package_reference/trainer#optimum.neuron.NeuronTrainingArguments): extends the original TrainingArguments with Trainium-specific parameters to optimize performance on AWS Trainium hardware.
- [NeuronTrainer](https://huggingface.co/docs/optimum/neuron/package_reference/trainer): the trainer class that takes care of compiling and distributing the model to run on Trainium Chips, and performing training and evaluation.
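To illustrate the kind of argument validation the parser performs, here is a minimal, self-contained sketch (every name below is hypothetical — the real `NeuronArgumentParser` wraps `HfArgumentParser` inside `optimum.neuron` and checks many more values):

```python
from dataclasses import dataclass

# Illustrative sketch only: a toy arguments dataclass plus a validation
# step that rejects values which would not work on the target hardware.

@dataclass
class ToyTrainingArguments:
    per_device_train_batch_size: int = 8
    bf16: bool = True


def validate_for_trainium(args: ToyTrainingArguments) -> ToyTrainingArguments:
    # Hypothetical check: the batch size must be a positive integer.
    if args.per_device_train_batch_size <= 0:
        raise ValueError("per_device_train_batch_size must be positive")
    return args


args = validate_for_trainium(ToyTrainingArguments(per_device_train_batch_size=4))
print(args.per_device_train_batch_size)  # → 4
```

The point is that invalid values fail fast at parse time rather than surfacing later as a compilation error on the Trainium chips.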

The [NeuronTrainer](https://huggingface.co/docs/optimum/neuron/package_reference/trainer) is very similar to the [🤗 Transformers Trainer](https://huggingface.co/docs/transformers/main_classes/trainer), and adapting a script for Trainium instances mostly consists of:
- Swapping the `Trainer` class for the `NeuronTrainer` one
- Using `NeuronTrainingArguments` instead of `TrainingArguments`

That's how most of the [example scripts](https://github.com/huggingface/optimum-neuron/tree/main/examples) were adapted from their [original counterparts](https://github.com/huggingface/transformers/tree/main/examples/pytorch):

```diff
-from transformers import TrainingArguments
+from optimum.neuron import NeuronTrainingArguments
-from transformers import Trainer
+from optimum.neuron import NeuronTrainer as Trainer

-training_args = TrainingArguments(
+training_args = NeuronTrainingArguments(
     # training arguments...
 )

 # A lot of code here

 # Initialize our Trainer
 trainer = Trainer(
     model=model,
-    args=training_args,  # Original training arguments.
+    args=training_args,
     train_dataset=train_dataset if training_args.do_train else None,
     eval_dataset=eval_dataset if training_args.do_eval else None,
     compute_metrics=compute_metrics,
```