diff --git a/gazelle/__init__.py b/gazelle/__init__.py
index 28a6065..813e639 100644
--- a/gazelle/__init__.py
+++ b/gazelle/__init__.py
@@ -4,6 +4,8 @@
     GazelleForConditionalGeneration,
     GazellePreTrainedModel,
     GazelleProcessor,
+    GazelleClient,
+    load_audio_from_file,
 )
 
 __version__ = "0.1.0"
diff --git a/gazelle/modeling_gazelle.py b/gazelle/modeling_gazelle.py
index 8e9fff0..9593ba4 100644
--- a/gazelle/modeling_gazelle.py
+++ b/gazelle/modeling_gazelle.py
@@ -47,6 +47,10 @@
     replace_return_docstrings,
 )
 
+import torchaudio
+import transformers
+from transformers import BitsAndBytesConfig
+
 logger = logging.get_logger(__name__)
 
 
@@ -949,3 +953,100 @@ def model_input_names(self):
         tokenizer_input_names = self.tokenizer.model_input_names
         audio_processor_input_names = self.audio_processor_class.model_input_names
         return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names))
+
+
+def load_audio_from_file(fpath):
+    """Load an audio file and resample it to the 16 kHz expected by the audio encoder."""
+    audio, sr = torchaudio.load(fpath)
+
+    if sr != 16000:
+        audio = torchaudio.transforms.Resample(sr, 16000)(audio)
+
+    return audio
+
+
+class GazelleClient:
+    def __init__(
+        self,
+        model_id="tincans-ai/gazelle-v0.2",
+        quantization=None,
+    ):
+        """
+        Args:
+            model_id (str): The model id to load. Defaults to "tincans-ai/gazelle-v0.2".
+            quantization (str, optional): "8-bit" or "4-bit". Defaults to None (no quantization).
+        """
+        assert quantization in [None, "8-bit", "4-bit"], "Invalid quantization. Must be None, '8-bit' or '4-bit'."
+        self.model_id = model_id
+        self.config = GazelleConfig.from_pretrained(model_id)
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
+
+        # Without quantization, load the model on the best available device.
+        if quantization is None:
+            self.device = "cpu"
+            self.audio_dtype = torch.float32
+            if torch.cuda.is_available():
+                self.device = "cuda"
+                self.audio_dtype = torch.bfloat16
+            elif torch.backends.mps.is_available():
+                self.device = "mps"
+                self.audio_dtype = torch.float16
+            print(f"Using {self.device} device")
+
+            # Load the model.
+            self.model = GazelleForConditionalGeneration.from_pretrained(
+                model_id,
+                torch_dtype=self.audio_dtype,
+            ).to(self.device, dtype=self.audio_dtype)
+
+        # With quantization, load the model on GPU via bitsandbytes.
+        else:
+            if quantization == "8-bit":
+                quantization_config = BitsAndBytesConfig(
+                    load_in_8bit=True,
+                )
+                self.audio_dtype = torch.float16
+            elif quantization == "4-bit":
+                quantization_config = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                )
+                self.audio_dtype = torch.bfloat16
+            self.device = "cuda:0"
+
+            self.model = GazelleForConditionalGeneration.from_pretrained(
+                model_id,
+                device_map=self.device,
+                quantization_config=quantization_config,
+            )
+
+        self.audio_processor = transformers.Wav2Vec2Processor.from_pretrained(
+            "facebook/wav2vec2-base-960h"
+        )
+
+    def inference_collator(
+        self, audio_input, prompt="Transcribe the following \n<|audio|>"
+    ):
+        """Build model inputs: audio features plus the chat-formatted prompt."""
+        audio_values = self.audio_processor(
+            audio=audio_input, return_tensors="pt", sampling_rate=16000
+        ).input_values
+        msgs = [
+            {"role": "user", "content": prompt},
+        ]
+        input_ids = self.tokenizer.apply_chat_template(
+            msgs, return_tensors="pt", add_generation_prompt=True
+        )
+        return {
+            "audio_values": audio_values.squeeze(0).to(self.device, dtype=self.audio_dtype),
+            "input_ids": input_ids.to(self.device),
+        }
+
+    def infer(self, audio_input, prompt="Transcribe the following \n<|audio|>"):
+        """Run generation on the given audio and return the decoded response."""
+        inputs = self.inference_collator(audio_input, prompt=prompt)
+        output_ids = self.model.generate(**inputs, max_new_tokens=64)[0]
+        response = self.tokenizer.decode(output_ids)
+        # Keep only the text after the [/INST] marker, i.e. the model's reply.
+        response = response.split("[/INST]")[1].strip()
+        return response
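
Usage: a minimal sketch of the new client API, not part of the diff. The file path "sample.wav" is a placeholder, and quantization may be omitted, "8-bit", or "4-bit".

    from gazelle import GazelleClient, load_audio_from_file

    # Load a local recording and resample it to 16 kHz (the path is a placeholder).
    audio = load_audio_from_file("sample.wav")

    # Build the client; pass quantization="8-bit" or "4-bit" to load with bitsandbytes.
    client = GazelleClient(model_id="tincans-ai/gazelle-v0.2")

    # Run inference with the default transcription prompt.
    print(client.infer(audio))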