Part of the goal of this project is for me to learn as I go, so I am going to start at the beginning: with Andrej Karpathy’s PyTorch GPT-2 trainer from llm.c. This is the script that Keller Jordan used for his initial baseline. This trainer is very similar to the NanoGPT trainer, with some minor modifications / simplifications (such as no dropout).
I have upstreamed some QOL improvements and basic tweaks to the training script from Keller’s fork, but have not changed any of the core training / modeling logic. Specifically:
Additionally, I added wandb logging for easy tracking of training progress; optimistically, I may need to remove this one day, as it slightly increases step time.
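For reference, the logging itself is a thin addition to the training loop. A minimal sketch of the pattern, assuming the standard wandb API (the project name, config values, and dummy `train_step` below are illustrative placeholders, not the actual trainer code):

```python
import time
import wandb

# Illustrative sketch: project name and config values are placeholders,
# not the real trainer's settings.
wandb.init(project="gpt2-speedrun", config={"batch_size": 32, "lr": 6e-4})

def train_step() -> float:
    """Stand-in for the real forward/backward/optimizer step; returns a fake loss."""
    time.sleep(0.01)
    return 3.0

for step in range(100):
    t0 = time.time()
    loss = train_step()
    step_time_ms = (time.time() - t0) * 1000
    # Logging every step adds a small amount of host-side overhead per step,
    # which is why this may eventually be removed or throttled.
    wandb.log({"train/loss": loss, "step_time_ms": step_time_ms}, step=step)
```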
Commit with the initial setup is here: b3c32f8.
The baseline run time on my 2xRTX 4090 setup is 8.13 hours.
Waiting 8 hours for a result is too slow, so I’m going to begin by implementing some of the notable improvements from the 8xH100 leaderboard. I’ll start with the most impactful/easiest changes first:
Architectural changes (31.8% speedup, then 24% speedup)
There are some basic architectural changes and modernizations that can be made to the model to speed up training. These are general improvements to the transformer decoder architecture that have been widely adopted since the original GPT-2 paper. The changes are:
In addition, learning rate and batch size have been tuned.
Once again, many of these changes are downstreamed from the modded-nanogpt repository / 8xH100 speedrun. It’s not efficient to reinvent the wheel, and I want to get the training time down as fast as possible in the beginning.
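One of the modernizations mentioned below is RoPE (rotary position embeddings), which replaces GPT-2’s learned absolute position embeddings. Here is a minimal sketch of the idea in PyTorch, with illustrative shapes and no claim to match this repo’s exact implementation:

```python
import torch

def rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply rotary position embeddings to x of shape (batch, heads, seq, head_dim).

    Pairs of channels are rotated by a position-dependent angle, so relative
    positions are encoded directly in the query/key dot products.
    """
    b, h, t, d = x.shape
    half = d // 2
    # Per-pair rotation frequencies and per-position angles.
    freqs = 1.0 / (base ** (torch.arange(0, half, device=x.device).float() / half))
    angles = torch.outer(torch.arange(t, device=x.device).float(), freqs)  # (t, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

# Usage: rotate queries and keys before attention; values are left untouched.
q = torch.randn(1, 12, 64, 64)
k = torch.randn(1, 12, 64, 64)
q, k = rope(q), rope(k)
```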
After implementing these changes (commit b7bb93f), the new run time is 7.51 hours. This run was more data-efficient than the baseline, requiring only 5.07B tokens. However, the tokens/second decreased, likely due to the larger batch size (more gradient accumulation steps, which tends to translate to lower throughput) and the architectural changes, such as the inclusion of RoPE. Once I have a shorter run time, I will be able to tune more effectively and see if I can remove gradient accumulation.