using monkey patch to replace models
yuekaizhang committed Jan 22, 2024
1 parent 84e4af9 commit bda4829
Showing 5 changed files with 103 additions and 27 deletions.
21 changes: 17 additions & 4 deletions egs/aishell/ASR/RESULTS.md
@@ -13,12 +13,13 @@

Command for training is:
```bash
pip install -r whisper/requirements.txt

./prepare.sh --stage 30 --stop_stage 30

# fine-tuning with deepspeed zero stage 1
torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--use-fp16 1 \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--deepspeed \
@@ -27,21 +28,33 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
# fine-tuning with ddp
torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--use-fp16 1 \
--exp-dir whisper/exp_medium \
--base-lr 1e-5 \
--model-name medium
```
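With `--deepspeed`, train.py hands the model to DeepSpeed's standard entry point. A minimal sketch of that pattern (not the recipe's exact code; run it under `torchrun` or the `deepspeed` launcher):

```python
import deepspeed
import whisper

model = whisper.load_model("tiny", "cpu")  # stand-in; train.py selects --model-name
model_engine, optimizer, _, scheduler = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="./whisper/ds_config_zero1.json",  # the recipe's ZeRO stage-1 config
)
```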

Command for decoding is:
Command for decoding using fine-tuned models:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt

python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch 999 --avg 1 \
--beam-size 10 --max-duration 50
```
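The symlink works because, with `--avg 1`, decode.py loads `<exp-dir>/epoch-<epoch>.pt` directly. A rough sketch of that lookup (helper name hypothetical, simplified from icefall's checkpoint handling):

```python
from pathlib import Path

import torch


def resolve_checkpoint(exp_dir: str, epoch: int) -> Path:
    # --epoch 999 --avg 1 resolves to <exp-dir>/epoch-999.pt, which the
    # `ln -s` above points at the downloaded fine-tuned weights
    return Path(exp_dir) / f"epoch-{epoch}.pt"


state = torch.load(resolve_checkpoint("whisper/exp_large_v2", 999), map_location="cpu")
```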
Pretrained models, training logs, decoding logs, tensorboard and decoding results
Command for decoding using pretrained models (before fine-tuning):
```bash
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch -1 --avg 1 \
--remove-whisper-encoder-input-length-restriction False \
--beam-size 10 --max-duration 50
```
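Passing `--remove-whisper-encoder-input-length-restriction False` keeps the stock encoder, so decode.py zero-pads the 80-bin mel features along the time axis up to Whisper's fixed 3000-frame (30 s) input. A minimal sketch of the equivalent padding, with assumed shapes:

```python
import torch
import torch.nn.functional as F

feature = torch.randn(4, 80, 1000)  # (batch, n_mels, frames); roughly 10 s of audio
T = 3000                            # the stock Whisper encoder expects 30 s inputs
if feature.shape[2] < T:
    feature = F.pad(feature, (0, T - feature.shape[2]))  # zero-pad the time axis
print(feature.shape)  # torch.Size([4, 80, 3000])
```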
Fine-tuned models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_aishell_whisper>

44 changes: 40 additions & 4 deletions egs/aishell/ASR/whisper/decode.py
@@ -16,6 +16,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
# Command for decoding using fine-tuned models:
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch 999 --avg 1 \
--beam-size 10 --max-duration 50
# Command for decoding using pretrained models (before fine-tuning):
python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch -1 --avg 1 \
--remove-whisper-encoder-input-length-restriction False \
--beam-size 10 --max-duration 50
"""

import argparse
import logging
@@ -29,8 +52,8 @@
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from model import load_model

#from model import load_model
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
from icefall.decode import (
get_lattice,
@@ -104,7 +127,7 @@ def average_checkpoints(

def remove_punctuation(text: Union[str, List[str]]):
    # https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
    punctuation = '!,.;:?、!,。;:?'
    punctuation = '!,.;:?、!,。;:?《》 '
    if isinstance(text, str):
        text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
        return text
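For illustration, the widened set now also strips `《》` and plain spaces (example string hypothetical):

```python
remove_punctuation("你好, 世界!《测试》")  # -> '你好世界测试'
```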
@@ -183,6 +206,13 @@ def get_parser():
help="""The model name to use.
""",
)

parser.add_argument(
    "--remove-whisper-encoder-input-length-restriction",
    type=str2bool,
    default=True,
    help="Replace the Whisper encoder's forward method to remove its 30 s input length restriction.",
)

return parser

@@ -246,6 +276,10 @@ def decode_one_batch(
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device, dtype=dtype).transpose(1, 2)
if not params.remove_whisper_encoder_input_length_restriction:
    T = 3000  # the stock Whisper encoder expects fixed 30 s (3000-frame) inputs
    if feature.shape[2] < T:
        # zero-pad the features along the time axis up to T frames
        pad = torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2])
        feature = torch.cat([feature, pad.to(device, dtype=dtype)], 2)

supervisions = batch["supervisions"]
feature_len = supervisions["num_frames"]
@@ -404,7 +438,9 @@ def main():

logging.info(f"device: {device}")

model = load_model(params.model_name)
if params.remove_whisper_encoder_input_length_restriction:
    replace_whisper_encoder_forward()
model = whisper.load_model(params.model_name, 'cpu')
if params.epoch > 0:
    if params.avg > 1:
        start = params.epoch - params.avg
2 changes: 1 addition & 1 deletion egs/aishell/ASR/whisper/requirements.txt
@@ -4,7 +4,7 @@ git+https://github.com/lhotse-speech/lhotse
sentencepiece
tensorboard
librosa
openai-whisper==20231117
openai-whisper @ git+https://github.com/yuekaizhang/whisper.git
zhconv
WeTextProcessing
deepspeed
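With the direct URL reference above, `pip install -r whisper/requirements.txt` pulls the fork instead of the PyPI release; a quick sanity check that the package imports (output indicative):

```python
import whisper

print(whisper.__file__)                # should point into site-packages/whisper
print(whisper.available_models()[:4])  # e.g. ['tiny.en', 'tiny', 'base.en', 'base']
```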
37 changes: 19 additions & 18 deletions egs/aishell/ASR/whisper/train.py
@@ -17,20 +17,20 @@
"""
Usage:
./prepare.sh
If you use --datatang-prob=0, then you don't need to run the above script.
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless7/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7/exp \
--full-libri 1 \
--max-duration 550
# fine-tuning with deepspeed zero stage 1
torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--deepspeed \
--deepspeed_config ./whisper/ds_config_zero1.json
# fine-tuning with ddp
torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_medium \
--base-lr 1e-5 \
--model-name medium
"""


@@ -88,7 +88,7 @@
)

import whisper
from model import load_model
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from label_smoothing import LabelSmoothingLoss

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -227,7 +227,7 @@ def get_parser():
parser.add_argument(
    "--use-fp16",
    type=str2bool,
    default=False,
    default=True,
    help="Whether to use half precision training.",
)

@@ -744,8 +744,9 @@ def run(rank, world_size, args):
logging.info(params)

logging.info("About to create model")

model = load_model(params.model_name)

replace_whisper_encoder_forward()
model = whisper.load_model(params.model_name, 'cpu')
del model.alignment_heads  # drop Whisper's sparse alignment_heads buffer (assumption: it does not survive DDP/DeepSpeed wrapping)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
26 changes: 26 additions & 0 deletions egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py
@@ -0,0 +1,26 @@
import torch
import torch.nn.functional as F  # the patched forward below uses F.gelu
import whisper


def forward(self, x: torch.Tensor):
    """
    x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
        the mel spectrogram of the audio
    """
    x = F.gelu(self.conv1(x))
    x = F.gelu(self.conv2(x))
    x = x.permute(0, 2, 1)

    # slice the positional embedding to the actual input length, instead of
    # asserting the fixed 1500-frame (30 s) shape the stock encoder requires
    x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)

    for block in self.blocks:
        x = block(x)

    x = self.ln_post(x)
    return x

def replace_whisper_encoder_forward():
    """
    This function monkey patches the forward method of the Whisper encoder.
    To be called before the model is loaded; it changes Whisper to process
    audio of any length shorter than 30 s.
    """
    whisper.model.AudioEncoder.forward = forward
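Taken together, train.py and decode.py both apply the patch before loading the model; a minimal end-to-end sketch (model size and tensor shapes assumed):

```python
import torch
import whisper

from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

replace_whisper_encoder_forward()          # patch the class before loading
model = whisper.load_model("tiny", "cpu")  # train.py/decode.py use --model-name

mel = torch.randn(1, 80, 1000)  # ~10 s of audio; no padding to 3000 frames needed
with torch.no_grad():
    out = model.encoder(mel)    # works: the positional embedding is sliced to fit
print(out.shape)                # (1, 500, 384) for tiny: conv2 halves the frames
```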
