From 7f300efe141e3043c54263d3a5725e30355fed95 Mon Sep 17 00:00:00 2001
From: Andrew Han
Date: Fri, 10 May 2024 14:22:34 -0700
Subject: [PATCH 1/6] add client

---
 gazelle/__init__.py         |  1 +
 gazelle/modeling_gazelle.py | 98 +++++++++++++++++++++++++++++++++++++
 2 files changed, 99 insertions(+)

diff --git a/gazelle/__init__.py b/gazelle/__init__.py
index 28a6065..ffb17b8 100644
--- a/gazelle/__init__.py
+++ b/gazelle/__init__.py
@@ -4,6 +4,7 @@
     GazelleForConditionalGeneration,
     GazellePreTrainedModel,
     GazelleProcessor,
+    GazelleClient,
 )
 
 __version__ = "0.1.0"
diff --git a/gazelle/modeling_gazelle.py b/gazelle/modeling_gazelle.py
index 8e9fff0..caf3138 100644
--- a/gazelle/modeling_gazelle.py
+++ b/gazelle/modeling_gazelle.py
@@ -47,6 +47,10 @@
     replace_return_docstrings,
 )
 
+import torchaudio
+import transformers
+from transformers import BitsAndBytesConfig
+
 
 logger = logging.get_logger(__name__)
 
@@ -949,3 +953,97 @@ def model_input_names(self):
         tokenizer_input_names = self.tokenizer.model_input_names
         audio_processor_input_names = self.audio_processor_class.model_input_names
         return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names))
+
+
+def load_audio_from_file(fpath):
+    test_audio, sr = torchaudio.load(fpath)
+
+    if sr != 16000:
+        test_audio = torchaudio.transforms.Resample(sr, 16000)(test_audio)
+
+    return test_audio
+
+
+class GazelleClient:
+    def __init__(
+        self,
+        model_id="tincans-ai/gazelle-v0.2",
+        quantization=None,
+    ):
+        """
+        Args:
+            model_id (str): The model id to load. Defaults to "tincans-ai/gazelle-v0.2".
+            quantization (str): "8-bit" or "4-bit". Defaults to None.
+        """
+        assert quantization in [None, "8-bit", "4-bit"], "Invalid quantization. Must be None, '8-bit' or '4-bit'."
+        self.model_id = model_id
+        self.config = GazelleConfig.from_pretrained(model_id)
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
+        if quantization:
+            self.config.quantization_config = BitsAndBytesConfig(**quantization)
+
+        # If the quantization config is None, then load the model in the default way.
+        if quantization is None:
+            device = "cpu"
+            dtype = torch.float32
+            if torch.cuda.is_available():
+                device = "cuda"
+                dtype = torch.bfloat16
+                print(f"Using {device} device")
+            elif torch.backends.mps.is_available():
+                device = "mps"
+                dtype = torch.float16
+                print(f"Using {device} device")
+
+            # Load the model.
+            self.model = GazelleForConditionalGeneration.from_pretrained(
+                model_id,
+                torch_dtype=dtype
+            ).to(device, dtype=dtype)
+
+        # If the quantization config is not None,
+        else:
+            if quantization == "8-bit":
+                quantization_config = BitsAndBytesConfig(
+                    load_in_8bit=True,
+                )
+            elif quantization == "4-bit":
+                quantization_config = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                )
+
+            self.model = GazelleForConditionalGeneration.from_pretrained(
+                model_id,
+                device_map="cuda:0",
+                quantization_config=quantization_config
+            )
+
+        self.audio_processor = transformers.Wav2Vec2Processor.from_pretrained(
+            "facebook/wav2vec2-base-960h"
+        )
+
+    def inference_collator(
+        self, audio_input, prompt="Transcribe the following \n<|audio|>", audio_dtype=torch.float16
+    ):
+        audio_values = self.audio_processor(
+            audio=audio_input, return_tensors="pt", sampling_rate=16000
+        ).input_values
+        msgs = [
+            {"role": "user", "content": prompt},
+        ]
+        labels = self.tokenizer.apply_chat_template(
+            msgs, return_tensors="pt", add_generation_prompt=True
+        )
+        return {
+            "audio_values": audio_values.squeeze(0).to("cuda").to(audio_dtype),
+            "input_ids": labels.to("cuda"),
+        }
+
+    def infer(self, audio_input, prompt="Transcribe the following \n<|audio|>"):
+        """Inference method for the Gazelle model."""
+        inputs = self.inference_collator(audio_input, prompt=prompt)
+        response = self.tokenizer.decode(self.model.generate(**inputs, max_new_tokens=64)[0])
+        # Get everything after [/INST] token.
+        response = response.split("[/INST]")[1].strip()
+        return response

From 9a97200afce1f3f187273d11169e583e62c7b769 Mon Sep 17 00:00:00 2001
From: Andrew Han
Date: Fri, 10 May 2024 14:25:41 -0700
Subject: [PATCH 2/6] remove annoying copilot artifact

---
 gazelle/modeling_gazelle.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/gazelle/modeling_gazelle.py b/gazelle/modeling_gazelle.py
index caf3138..284b189 100644
--- a/gazelle/modeling_gazelle.py
+++ b/gazelle/modeling_gazelle.py
@@ -979,8 +979,6 @@ def __init__(
         self.model_id = model_id
         self.config = GazelleConfig.from_pretrained(model_id)
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
-        if quantization:
-            self.config.quantization_config = BitsAndBytesConfig(**quantization)
 
         # If the quantization config is None, then load the model in the default way.
         if quantization is None:

From 092dd90d50916b0e5e99ab0b67ba0fc0878f059e Mon Sep 17 00:00:00 2001
From: Andrew Han
Date: Fri, 10 May 2024 14:50:17 -0700
Subject: [PATCH 3/6] add helper function

---
 gazelle/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/gazelle/__init__.py b/gazelle/__init__.py
index ffb17b8..813e639 100644
--- a/gazelle/__init__.py
+++ b/gazelle/__init__.py
@@ -5,6 +5,7 @@
     GazellePreTrainedModel,
     GazelleProcessor,
     GazelleClient,
+    load_audio_from_file,
 )
 
 __version__ = "0.1.0"

From 1936282f5899357bbf34c9c3d90354874801baf9 Mon Sep 17 00:00:00 2001
From: Andrew Han
Date: Fri, 10 May 2024 15:01:50 -0700
Subject: [PATCH 4/6] remove split

---
 gazelle/modeling_gazelle.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/gazelle/modeling_gazelle.py b/gazelle/modeling_gazelle.py
index 284b189..c94049b 100644
--- a/gazelle/modeling_gazelle.py
+++ b/gazelle/modeling_gazelle.py
@@ -1042,6 +1042,4 @@ def infer(self, audio_input, prompt="Transcribe the following \n<|audio|>"):
         """Inference method for the Gazelle model."""
         inputs = self.inference_collator(audio_input, prompt=prompt)
         response = self.tokenizer.decode(self.model.generate(**inputs, max_new_tokens=64)[0])
-        # Get everything after [/INST] token.
-        response = response.split("[/INST]")[1].strip()
         return response

From 95fab1a6eba84a01811230f0110e90de103f491b Mon Sep 17 00:00:00 2001
From: Andrew Han
Date: Fri, 10 May 2024 15:02:18 -0700
Subject: [PATCH 5/6] Revert "remove split"

This reverts commit 1936282f5899357bbf34c9c3d90354874801baf9.

---
 gazelle/modeling_gazelle.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/gazelle/modeling_gazelle.py b/gazelle/modeling_gazelle.py
index c94049b..284b189 100644
--- a/gazelle/modeling_gazelle.py
+++ b/gazelle/modeling_gazelle.py
@@ -1042,4 +1042,6 @@ def infer(self, audio_input, prompt="Transcribe the following \n<|audio|>"):
         """Inference method for the Gazelle model."""
         inputs = self.inference_collator(audio_input, prompt=prompt)
         response = self.tokenizer.decode(self.model.generate(**inputs, max_new_tokens=64)[0])
+        # Get everything after [/INST] token.
+        response = response.split("[/INST]")[1].strip()
         return response

From d85f80f0c01ba8ed872dbb502d86c1b610814c30 Mon Sep 17 00:00:00 2001
From: Andrew Han
Date: Fri, 10 May 2024 16:36:21 -0700
Subject: [PATCH 6/6] fix dtype issue

---
 gazelle/modeling_gazelle.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/gazelle/modeling_gazelle.py b/gazelle/modeling_gazelle.py
index 284b189..9593ba4 100644
--- a/gazelle/modeling_gazelle.py
+++ b/gazelle/modeling_gazelle.py
@@ -983,21 +983,21 @@ def __init__(
         # If the quantization config is None, then load the model in the default way.
         if quantization is None:
             device = "cpu"
-            dtype = torch.float32
+            self.audio_dtype = torch.float32
             if torch.cuda.is_available():
                 device = "cuda"
-                dtype = torch.bfloat16
+                self.audio_dtype = torch.bfloat16
                 print(f"Using {device} device")
             elif torch.backends.mps.is_available():
                 device = "mps"
-                dtype = torch.float16
+                self.audio_dtype = torch.float16
                 print(f"Using {device} device")
 
             # Load the model.
             self.model = GazelleForConditionalGeneration.from_pretrained(
                 model_id,
-                torch_dtype=dtype
-            ).to(device, dtype=dtype)
+                torch_dtype=self.audio_dtype
+            ).to(device, dtype=self.audio_dtype)
 
         # If the quantization config is not None,
         else:
@@ -1022,7 +1022,7 @@ def __init__(
         )
 
     def inference_collator(
-        self, audio_input, prompt="Transcribe the following \n<|audio|>", audio_dtype=torch.float16
+        self, audio_input, prompt="Transcribe the following \n<|audio|>"
     ):
         audio_values = self.audio_processor(
             audio=audio_input, return_tensors="pt", sampling_rate=16000
@@ -1033,6 +1033,7 @@ def inference_collator(
         labels = self.tokenizer.apply_chat_template(
             msgs, return_tensors="pt", add_generation_prompt=True
        )
+        audio_dtype = self.audio_dtype
         return {
             "audio_values": audio_values.squeeze(0).to("cuda").to(audio_dtype),
             "input_ids": labels.to("cuda"),
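
A rough usage sketch of the client added in this series (not part of the patches themselves). It assumes the package is importable as gazelle, that a CUDA device is available (inference_collator moves tensors to "cuda"), and that "sample.wav" is a placeholder path for a local audio file:

    from gazelle import GazelleClient, load_audio_from_file

    # Load a local audio file; load_audio_from_file resamples it to 16 kHz.
    audio = load_audio_from_file("sample.wav")  # placeholder path

    # Default load picks cuda/mps/cpu automatically; pass quantization="8-bit"
    # or quantization="4-bit" to load with bitsandbytes instead.
    client = GazelleClient(model_id="tincans-ai/gazelle-v0.2")

    # Run the default transcription prompt over the audio and print the reply.
    print(client.infer(audio))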