Add a client for Gazelle #3

Open · wants to merge 6 commits into main
2 changes: 2 additions & 0 deletions gazelle/__init__.py
@@ -4,6 +4,8 @@
    GazelleForConditionalGeneration,
    GazellePreTrainedModel,
    GazelleProcessor,
    GazelleClient,
    load_audio_from_file,
)

__version__ = "0.1.0"
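
For context, a minimal sketch of importing the newly exported names from the top-level package (assuming the package is installed as gazelle):

from gazelle import GazelleClient, load_audio_from_file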
97 changes: 97 additions & 0 deletions gazelle/modeling_gazelle.py
@@ -47,6 +47,10 @@
    replace_return_docstrings,
)

import torchaudio
import transformers
from transformers import BitsAndBytesConfig

logger = logging.get_logger(__name__)


@@ -949,3 +953,96 @@ def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        audio_processor_input_names = self.audio_processor_class.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names))


def load_audio_from_file(fpath):
    """Load an audio file and resample it to the 16 kHz rate expected by the audio encoder."""
    audio, sr = torchaudio.load(fpath)

    if sr != 16000:
        audio = torchaudio.transforms.Resample(sr, 16000)(audio)

    return audio


class GazelleClient:
    def __init__(
        self,
        model_id="tincans-ai/gazelle-v0.2",
        quantization=None,
    ):
        """
        Args:
            model_id (str): The model id to load. Defaults to "tincans-ai/gazelle-v0.2".
            quantization (str): "8-bit" or "4-bit". Defaults to None (full precision).
        """
        assert quantization in [None, "8-bit", "4-bit"], "Invalid quantization. Must be None, '8-bit' or '4-bit'."
        self.model_id = model_id
        self.config = GazelleConfig.from_pretrained(model_id)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

        if quantization is None:
            # Pick the best available device and a matching compute dtype.
            device = "cpu"
            self.audio_dtype = torch.float32
            if torch.cuda.is_available():
                device = "cuda"
                self.audio_dtype = torch.bfloat16
            elif torch.backends.mps.is_available():
                device = "mps"
                self.audio_dtype = torch.float16
            print(f"Using {device} device")

            self.model = GazelleForConditionalGeneration.from_pretrained(
                model_id,
                torch_dtype=self.audio_dtype,
            ).to(device, dtype=self.audio_dtype)
        else:
            # Quantized loading requires CUDA and the bitsandbytes package.
            if quantization == "8-bit":
                quantization_config = BitsAndBytesConfig(
                    load_in_8bit=True,
                )
                # bitsandbytes keeps non-quantized modules in float16 by default.
                self.audio_dtype = torch.float16
            else:  # "4-bit"
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                )
                self.audio_dtype = torch.bfloat16
            device = "cuda"

            self.model = GazelleForConditionalGeneration.from_pretrained(
                model_id,
                device_map="cuda:0",
                quantization_config=quantization_config,
            )

        # Remember the device so the collator can move inputs to it.
        self.device = device

        self.audio_processor = transformers.Wav2Vec2Processor.from_pretrained(
            "facebook/wav2vec2-base-960h"
        )

    def inference_collator(
        self, audio_input, prompt="Transcribe the following \n<|audio|>"
    ):
        audio_values = self.audio_processor(
            audio=audio_input, return_tensors="pt", sampling_rate=16000
        ).input_values
        msgs = [
            {"role": "user", "content": prompt},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            msgs, return_tensors="pt", add_generation_prompt=True
        )
        # Move inputs to the same device the model was loaded on.
        return {
            "audio_values": audio_values.squeeze(0).to(self.device).to(self.audio_dtype),
            "input_ids": input_ids.to(self.device),
        }

    def infer(self, audio_input, prompt="Transcribe the following \n<|audio|>"):
        """Run inference on an audio tensor and return the decoded response."""
        inputs = self.inference_collator(audio_input, prompt=prompt)
        output_ids = self.model.generate(**inputs, max_new_tokens=64)[0]
        response = self.tokenizer.decode(output_ids)
        # Keep only the model's reply, i.e. everything after the [/INST] token.
        return response.split("[/INST]")[1].strip()
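
A minimal end-to-end sketch of using the client (the file name sample.wav is hypothetical, the recording is assumed to be mono, and the quantized variant additionally requires a CUDA GPU with bitsandbytes installed):

from gazelle import GazelleClient, load_audio_from_file

# Full precision: picks cuda, mps, or cpu automatically.
client = GazelleClient()
# client = GazelleClient(quantization="8-bit")  # CUDA + bitsandbytes only

audio = load_audio_from_file("sample.wav")  # hypothetical mono audio file
print(client.infer(audio))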