From e9e1213e6c3ee3a865a886afbc7b053d628085ea Mon Sep 17 00:00:00 2001
From: root
Date: Fri, 5 Jan 2024 03:54:43 +0000
Subject: [PATCH] clean up comments and documentation

---
 .../benchmarks/bert/src/flash_attn_triton.py |  4 ++++
 examples/benchmarks/bert/src/mosaic_bert.py  | 20 ++++++++++---------
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/examples/benchmarks/bert/src/flash_attn_triton.py b/examples/benchmarks/bert/src/flash_attn_triton.py
index b2b946c06..7f3fdabce 100644
--- a/examples/benchmarks/bert/src/flash_attn_triton.py
+++ b/examples/benchmarks/bert/src/flash_attn_triton.py
@@ -17,6 +17,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+*Update: 01-04-2024*
+This version of Triton Flash Attention is being deprecated in favor of Flash Attention 2,
+which now supports ALiBi natively: https://github.com/Dao-AILab/flash-attention
+
 *Experimental* implementation of FlashAttention in Triton.
 We use the FlashAttention implementation from Phil Tillet a starting point.
 https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
diff --git a/examples/benchmarks/bert/src/mosaic_bert.py b/examples/benchmarks/bert/src/mosaic_bert.py
index 8436783ca..0ce897be2 100644
--- a/examples/benchmarks/bert/src/mosaic_bert.py
+++ b/examples/benchmarks/bert/src/mosaic_bert.py
@@ -1,7 +1,7 @@
 # Copyright 2022 MosaicML Examples authors
 # SPDX-License-Identifier: Apache-2.0
 
-"""Implements a Mosaic BERT wrapper around a :class:`.ComposerTransformer`."""
+"""Implements a MosaicBERT wrapper around a :class:`.ComposerTransformer`."""
 
 from __future__ import annotations
 
@@ -31,12 +31,12 @@ def create_mosaic_bert_mlm(pretrained_model_name: str = 'bert-base-uncased',
                            tokenizer_name: Optional[str] = None,
                            gradient_checkpointing: Optional[bool] = False,
                            pretrained_checkpoint: Optional[str] = None):
-    """Mosaic BERT masked language model based on |:hugging_face:| Transformers.
+    """MosaicBERT masked language model based on |:hugging_face:| Transformers.
 
     For more information, see
     `Transformers. <https://huggingface.co/transformers/>`_.
 
-    This function creates a Mosaic BERT, which includes several throughput
+    This function creates a MosaicBERT, which includes several throughput
     optimizations not available in |:hugging_face:| BERT as well as
     architecture changes based on ALiBi and Gated Linear Units.
 
@@ -82,7 +82,7 @@ def create_mosaic_bert_mlm(pretrained_model_name: str = 'bert-base-uncased',
           "vocab_size": 30522
        }
 
-    To create a Mosaic BERT model for Masked Language Model pretraining:
+    To create a MosaicBERT model for Masked Language Model pretraining:
 
     .. testcode::
 
@@ -145,11 +145,11 @@ def create_mosaic_bert_classification(
         tokenizer_name: Optional[str] = None,
         gradient_checkpointing: Optional[bool] = False,
         pretrained_checkpoint: Optional[str] = None):
-    """Mosaic BERT classification model based on |:hugging_face:| Transformers.
+    """MosaicBERT classification model based on |:hugging_face:| Transformers.
 
     For more information, see
     `Transformers. <https://huggingface.co/transformers/>`_.
 
-    This function creates a Mosaic BERT, which includes several throughput
+    This function creates a MosaicBERT, which includes several throughput
     optimizations not available in |:hugging_face:| BERT as well as
     architecture changes based on ALiBi and Gated Linear Units.
@@ -207,7 +207,7 @@ def create_mosaic_bert_classification(
          "vocab_size": 30522
        }
 
-    To create a Mosaic BERT model for classification:
+    To create a MosaicBERT model for classification:
 
     .. testcode::
        from mosaic_bert import create_mosaic_bert_classification
@@ -229,8 +229,10 @@ def create_mosaic_bert_classification(
     if not model_config:
         model_config = {}
 
-    # By default, turn off attention dropout in Mosaic BERT
-    # (otherwise, Flash Attention will be off by default)
+    # By default, turn off attention dropout in MosaicBERT.
+    # Flash Attention 2 supports dropout in the attention module,
+    # while our previous Triton Flash Attention layer only works with
+    # attention_probs_dropout_prob = 0.
     if 'attention_probs_dropout_prob' not in model_config:
         model_config['attention_probs_dropout_prob'] = 0.0
 
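Note for reviewers on the last hunk: the new comment clarifies that `attention_probs_dropout_prob` is forced to 0.0 only when the caller has not set it, so attention dropout can now be enabled explicitly under Flash Attention 2. The sketch below illustrates that default-filling behavior under the assumption of a plain config dict; the `apply_mosaic_bert_defaults` wrapper and the example values are hypothetical, and only the guarded assignment mirrors the patched code.

from typing import Optional


def apply_mosaic_bert_defaults(model_config: Optional[dict] = None) -> dict:
    # Hypothetical wrapper for illustration; only the guarded lines below
    # correspond to the code touched by this patch.
    if not model_config:
        model_config = {}

    # By default, turn off attention dropout in MosaicBERT.
    # Flash Attention 2 supports dropout in the attention module,
    # while the previous Triton Flash Attention layer only works with
    # attention_probs_dropout_prob = 0.
    if 'attention_probs_dropout_prob' not in model_config:
        model_config['attention_probs_dropout_prob'] = 0.0
    return model_config


# Default path: no user setting, so attention dropout is disabled.
print(apply_mosaic_bert_defaults())
# -> {'attention_probs_dropout_prob': 0.0}

# An explicit user setting is preserved, which Flash Attention 2 can honor.
print(apply_mosaic_bert_defaults({'attention_probs_dropout_prob': 0.1}))
# -> {'attention_probs_dropout_prob': 0.1}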