Skip to content

Latest commit

 

History

History
73 lines (60 loc) · 3.41 KB

flash_attention.md

File metadata and controls

73 lines (60 loc) · 3.41 KB

中文说明 | English

使用FlashAttention加速Chinese-CLIP

Chinese-CLIP训练现已支持通过FlashAttention加速训练进程。

环境准备

  • TuringAmpereAdaHopper架构的Nvidia GPU显卡(如H100、A100、RTX 3090、T4、RTX 2080),Nvidia各架构对应显卡型号可参见此文档表格
  • CUDA 11.4及以上版本。
  • Pytorch 1.12及以上版本。
  • FlashAttention:通过执行pip install flash-attn安装FlashAttention。

更多信息可参见FlashAttention项目仓库

在Chinese-CLIP中用起来!

在Chinese-CLIP finetune中应用FlashAttention非常简单,只需要在finetune的sh脚本中加入--use-flash-attention配置项即可。我们提供了样例脚本run_scripts/muge_finetune_vit-b-16_rbt-base_flashattn.sh

训练速度和显存占用对比

启用FlashAttention可在不影响效果的条件下为Chinese-CLIP的finetune过程显著提速以及降低显存占用。我们的实验在一台8卡A100 GPU(80GB显存)机器进行,FlashAttention 0.2.8,Pytorch 1.10.1。

我们分别列出finetune过程中,相同batch size下启用FlashAttention前后每个规模模型的FP16精度finetune的batch time和显存占用对比,可以看到启用FlashAttention后,训练速度有所提升,也更加节约显存。对于更大规模模型的训练速度提升和显存占用降低更为显著。

Batch Time
单位: 秒/itBatch sizew/o FlashAttentionw/ FlashAttentionSpeedup
CN-CLIPRN501200*81.7101.6801.02×
CN-CLIPViT-B/16450*81.4770.9601.54×
CN-CLIPViT-L/14128*81.2930.7851.65×
CN-CLIPViT-L/14@336px40*81.3970.5872.38×
CN-CLIPViT-H/1464*81.2650.8451.50×

显存
单位: GBBatch sizew/o FlashAttentionw/ FlashAttention
CN-CLIPRN501200*87975
CN-CLIPViT-B/16450*88056
CN-CLIPViT-L/14128*87750
CN-CLIPViT-L/14@336px40*87837
CN-CLIPViT-H/1464*87657