Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Albert Gu*, Tri Dao*
Paper: https://arxiv.org/abs/2312.00752
Edited into a custom bidirectional variant by @logprobability.
Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers. It is based on the line of progress on structured state space models, with an efficient hardware-aware design and implementation in the spirit of FlashAttention.
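To make "selective" concrete, here is a minimal, slow NumPy sketch of a selective scan. It assumes the discretization used in the paper's reference code: zero-order hold for A, and the simplified B̄ ≈ ΔB for the input term. Unlike a time-invariant S4 layer, Δ, B and C are computed from the current input. All names here (naive_selective_scan, W_delta, W_B, W_C and the random projections) are illustrative assumptions, not this package's API. The package itself uses a fused, hardware-aware CUDA kernel.

```python
import numpy as np


def softplus(x: np.ndarray) -> np.ndarray:
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def naive_selective_scan(u, W_delta, b_delta, W_B, W_C, A, D):
    """Sequential selective scan (slow reference sketch, not the fused kernel).

    u:        (L, d) input sequence
    W_delta:  (d, d), b_delta: (d,)  -> per-step, per-channel step size delta_t
    W_B, W_C: (d, n)                 -> per-step input/output projections B_t, C_t
    A:        (d, n) negative real diagonal state matrix (one row per channel)
    D:        (d,)   skip connection
    """
    L, d = u.shape
    h = np.zeros((d, A.shape[1]))  # one n-dimensional state per channel
    y = np.empty_like(u)
    for t in range(L):
        delta = softplus(u[t] @ W_delta + b_delta)  # (d,)  input-dependent
        B = u[t] @ W_B                              # (n,)  input-dependent
        C = u[t] @ W_C                              # (n,)  input-dependent
        A_bar = np.exp(delta[:, None] * A)          # (d, n) zero-order hold
        h = A_bar * h + (delta[:, None] * B[None, :]) * u[t][:, None]
        y[t] = h @ C + D * u[t]
    return y


rng = np.random.default_rng(0)
L, d, n = 6, 4, 3
u = rng.standard_normal((L, d))
params = dict(
    W_delta=0.1 * rng.standard_normal((d, d)),
    b_delta=np.full(d, -2.0),
    W_B=rng.standard_normal((d, n)),
    W_C=rng.standard_normal((d, n)),
    A=-np.tile(np.arange(1, n + 1, dtype=float), (d, 1)),  # S4D-Real-style init
    D=np.ones(d),
)
y = naive_selective_scan(u, **params)
print(y.shape)  # (6, 4)
```

The loop costs O(L) time and constant state per channel, which is where the linear-time claim comes from. Because each step only reads u[t] and the previous state, this scan is causal.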
This version is a simplified bidirectional variant. As a result it is not suitable for autoregressive (AR, GPT-style) sequence generation; it targets non-autoregressive (NAR) generation and encoder models instead.
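The README does not say how the bidirectional variant is built. One common construction, assumed here and not necessarily what this repo does, runs a causal scan over the sequence and over its time-reversed copy, flips the second result back, and sums them. The toy below replaces the selective scan with a fixed-decay recurrence (causal_scan and bidirectional_scan are hypothetical names) to show why such a layer cannot drive AR generation: early positions already depend on later tokens.

```python
import numpy as np


def causal_scan(x: np.ndarray, decay: float = 0.9) -> np.ndarray:
    # Toy linear recurrence h_t = decay * h_{t-1} + x_t, output y_t = h_t.
    # Stands in for a causal SSM layer: y_t depends only on x_0..x_t.
    h = np.zeros(x.shape[1:])
    out = np.empty_like(x, dtype=float)
    for t in range(x.shape[0]):
        h = decay * h + x[t]
        out[t] = h
    return out


def bidirectional_scan(x: np.ndarray, decay: float = 0.9) -> np.ndarray:
    # Hypothetical bidirectional wrapper: a forward pass plus a pass over the
    # reversed sequence, flipped back so positions line up, then summed.
    # (Position t sees x_t in both passes; real layers may handle this differently.)
    fwd = causal_scan(x, decay)
    bwd = causal_scan(x[::-1], decay)[::-1]
    return fwd + bwd


x = np.zeros((5, 1))
x[3] = 1.0  # a single "future" token at position 3
print(causal_scan(x)[0, 0])         # 0.0: position 0 cannot see position 3
print(bidirectional_scan(x)[0, 0])  # nonzero: position 0 already sees position 3
```

In AR generation, token t must be produced before token t+1 exists, so a layer whose output at t reads from t+1 has nothing valid to read.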
Build from source by running pip install . from this repository. If pip complains about PyTorch versions, try passing --no-build-isolation to pip.
Other requirements:
- Linux
- NVIDIA GPU
- PyTorch 1.12+
- CUDA 11.6+
Our models were trained using PyTorch AMP for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary. Other frameworks, such as DeepSpeed, instead store parameters in float16 and upcast when necessary (e.g. for optimizer accumulation).
We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you experience instabilities, first try a framework that stores parameters in fp32 (such as AMP).
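A minimal sketch of the AMP setup described above, with a placeholder nn.Sequential standing in for a Mamba model. Parameters are created in float32 and never downcast; torch.autocast runs eligible ops in half precision, and GradScaler applies loss scaling for float16. The CPU fallback to bfloat16 is an assumption added only so the snippet runs without a GPU (Mamba itself requires one).

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16  # CPU autocast uses bf16

# Placeholder model; parameters are allocated in float32.
model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 1)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Loss scaling guards fp16 gradients against underflow; disabled for bf16 on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x = torch.randn(8, 16, device=device)
target = torch.randn(8, 1, device=device)

with torch.autocast(device_type=device, dtype=amp_dtype):
    out = model(x)  # linear layers run in amp_dtype
    loss = nn.functional.mse_loss(out.float(), target)

optimizer.zero_grad(set_to_none=True)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

# The master weights were never downcast.
print({p.dtype for p in model.parameters()})  # {torch.float32}
```

Frameworks that keep an fp16 copy as the main weights would show torch.float16 here; that is the situation the paragraph above suggests avoiding for SSMs.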
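Some parts of the model have initializations inherited from prior work on S4 models. For example, the Δ parameter has a targeted range, set by initializing the bias of its linear projection. However, some frameworks may have post-initialization hooks (e.g. setting all bias terms in nn.Linear modules to zero). If this is the case, you may have to add custom logic that is specific to the training framework (e.g. a line in our trainer turns off re-initializing, but it would be a no-op in any other framework).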
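The sketch below shows the kind of guard this requires. As far as we know, upstream mamba_ssm initializes the Δ bias to the inverse softplus of a log-uniform sample in [dt_min, dt_max] and marks it with a _no_reinit attribute that its own init function checks. This repo is heavily modified, so treat DeltaProjection, framework_init_hook and the default ranges as hypothetical stand-ins.

```python
import math

import torch
import torch.nn as nn


class DeltaProjection(nn.Module):
    """Hypothetical stand-in for the Delta (dt) projection of a Mamba block."""

    def __init__(self, dt_rank=4, d_inner=8, dt_min=1e-3, dt_max=1e-1, dt_init_floor=1e-4):
        super().__init__()
        self.dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
        # Sample dt log-uniformly in [dt_min, dt_max] ...
        dt = torch.exp(
            torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # ... and store its inverse softplus, so softplus(bias) recovers dt.
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Flag this bias so a generic init hook leaves it alone.
        self.dt_proj.bias._no_reinit = True


def framework_init_hook(module: nn.Module) -> None:
    """Hypothetical post-init hook that zeros every nn.Linear bias unless flagged."""
    if isinstance(module, nn.Linear) and module.bias is not None:
        if not getattr(module.bias, "_no_reinit", False):
            nn.init.zeros_(module.bias)


model = nn.ModuleList([nn.Linear(8, 4), DeltaProjection()])
model.apply(framework_init_hook)

dt = nn.functional.softplus(model[1].dt_proj.bias)
print(model[0].bias.abs().max().item())  # 0.0: the ordinary bias was zeroed
print(dt.min().item(), dt.max().item())  # still inside [dt_min, dt_max]
```

Without the flag check, the hook would zero the Δ bias too, giving softplus(0) ≈ 0.69 for every channel instead of the intended range.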
If you use this codebase, or otherwise find our work valuable, please cite Mamba:
@article{mamba,
title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
author={Gu, Albert and Dao, Tri},
journal={arXiv preprint arXiv:2312.00752},
year={2023}
}
Note that this repository is heavily modified from the original. You may want to cite this repo as well, or at least let me know you're using it so I can include a link back to your repo here.