Merge remote-tracking branch 'origin/classifier-router' into v1
alpayariyak committed Sep 22, 2023
2 parents 477c416 + 0381bc5 commit bb33341
Showing 4 changed files with 185 additions and 1 deletion.
138 changes: 138 additions & 0 deletions cluster_classifier.py
@@ -0,0 +1,138 @@
import argparse
from embed import BGE_Tokenizer
import random
from transformers import AutoModelForSequenceClassification
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from transformers import Trainer, TrainingArguments
from torch.utils.data import Dataset
import torch
import yaml


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str
    )
    parser.add_argument(
        "--train_config_path",
        type=str,
        default="configs/classifier_training_config.yaml",
    )
    parser.add_argument(
        "--embedding_model",
        type=str,
        default="BAAI/bge-large-en",
        help="Model to use for embedding, options: 'BAAI/bge-*-en', 'BAAI/bge-*-en-v1.5'"
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=512,
        help="Maximum sequence length used when building the dataset"
    )
    parser.add_argument(
        "--normalize_embeddings",
        action="store_true",
        help="Whether to normalize the embeddings"
    )

    return parser.parse_args()


class CustomDataset(Dataset):

    def __init__(self, texts, labels, tokenizer, max_len, use_bge):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.use_bge = use_bge

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = int(self.labels[idx])  # convert label to integer
        if self.use_bge:
            # Use BGE to encode the text; the embedding tensor is returned under 'input_ids'
            encoding = self.tokenizer.encode([text])
            input_ids = encoding
            attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
        else:
            # Legacy path: random-window truncation followed by standard tokenization
            # Check if the text is longer than the maximum length
            if len(text.split()) > self.max_len:
                # Calculate the number of tokens to be removed
                num_tokens_to_remove = len(text.split()) - self.max_len
                # Split the text into tokens
                tokens = text.split()
                # Randomly select start and end indices for truncation
                start_index = random.randint(0, num_tokens_to_remove)
                end_index = start_index + self.max_len
                # Truncate the tokens and join them back into a string
                text = " ".join(tokens[start_index:end_index])

            encoding = self.tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=self.max_len,
                return_token_type_ids=False,
                padding='max_length',
                return_attention_mask=True,
                return_tensors='pt',
            )
            input_ids = encoding['input_ids'].flatten()
            attention_mask = encoding['attention_mask'].flatten()

        return {
            'text': text,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': torch.tensor([label], dtype=torch.long)
        }


def create_dataset(data, tokenizer, max_len, use_bge):
    texts = [item['text'] for item in data]
    labels = [item['label'] for item in data]
    return CustomDataset(texts, labels, tokenizer, max_len, use_bge)


def train(args):
    # Split a list of row dicts so each example keeps its 'text' and 'label' fields
    classifier_dataset = load_dataset(args.dataset_path)["train"]
    n_labels = len(set(classifier_dataset["label"]))
    train_data, val_data = train_test_split(list(classifier_dataset), test_size=0.1, random_state=42)

    use_bge = "bge" in args.embedding_model

    if not use_bge:
        raise ValueError("Embedding model must be a BGE model at this time.")

    tokenizer = BGE_Tokenizer(model_name=args.embedding_model,
                              normalize_embeddings=args.normalize_embeddings)
    model = AutoModelForSequenceClassification.from_pretrained(args.embedding_model, num_labels=n_labels)

    train_dataset = create_dataset(train_data, tokenizer, args.max_length, use_bge)
    val_dataset = create_dataset(val_data, tokenizer, args.max_length, use_bge)

    with open(args.train_config_path, "r") as f:
        train_config = yaml.safe_load(f)
    training_args = TrainingArguments(**train_config["training_args"])

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )

    trainer.train()

    trainer.save_model("classifier")

    print(trainer.evaluate())


if __name__ == "__main__":
    args = get_args()
    train(args)
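
For context, a minimal sketch of how the model saved by trainer.save_model("classifier") might be loaded to route an incoming prompt. It assumes the standard token-id input path (the non-BGE branch of CustomDataset) and reloads the tokenizer from the base embedding model, since save_model does not write a tokenizer; the prompt string is hypothetical.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the checkpoint written by trainer.save_model("classifier");
# the tokenizer is reloaded from the base model (assumption: it is not saved with the classifier)
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-large-en")
model = AutoModelForSequenceClassification.from_pretrained("classifier")
model.eval()

prompt = "Write a SQL query that returns the ten most recent orders."  # hypothetical input
inputs = tokenizer(prompt, truncation=True, max_length=512, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(int(logits.argmax(dim=-1)))  # index of the predicted cluster/expert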



24 changes: 24 additions & 0 deletions configs/classifier_training_config.yaml
@@ -0,0 +1,24 @@
training_args:
  output_dir: './bge-large-classifier-32'
  num_train_epochs: 6
  per_device_train_batch_size: 32
  per_device_eval_batch_size: 32
  warmup_ratio: 0.03
  weight_decay: 0.01
  logging_dir: './logs-bge-large-32'
  logging_steps: 10
  learning_rate: 1e-6
  evaluation_strategy: 'steps'
  save_steps: 1000
  eval_steps: 1000
  save_total_limit: 3
  load_best_model_at_end: True
  metric_for_best_model: "eval_loss"
  greater_is_better: False
  push_to_hub: True
  hub_strategy: "all_checkpoints"
  report_to: "wandb"
  run_name: "bge-large-32"
  wandb_project: "bge-large-32-classifier-HydraALPHA"
  wandb_entity: "llama-moe"
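
One note on the two wandb_* keys: TrainingArguments does not define wandb_project or wandb_entity fields, so unpacking this block verbatim raises a TypeError. A minimal sketch of one way to route them to Weights & Biases through its environment variables before building the arguments (path taken from the default --train_config_path above):

import os
import yaml
from transformers import TrainingArguments

with open("configs/classifier_training_config.yaml", "r") as f:
    cfg = yaml.safe_load(f)["training_args"]

# Hand the W&B bookkeeping keys to wandb via its environment variables,
# then unpack the remaining keys into TrainingArguments
os.environ["WANDB_PROJECT"] = cfg.pop("wandb_project")
os.environ["WANDB_ENTITY"] = cfg.pop("wandb_entity")
training_args = TrainingArguments(**cfg)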

2 changes: 1 addition & 1 deletion data_utils.py
@@ -92,4 +92,4 @@ def user_assistant_template(conversation, max_words=1400):
    """Format conversation using user and assistant template."""
    return alpaca_template(conversation, max_words=max_words,
                           naming_map={"instruction": "User",
                                       "output": "Assistant"})
22 changes: 22 additions & 0 deletions embed.py
@@ -2,6 +2,7 @@
import numpy as np
import argparse
from FlagEmbedding import FlagModel
from transformers import AutoTokenizer, AutoModel


def get_args():
@@ -64,6 +65,27 @@ def encode(self, texts, batch_size=256):
        return self.model.encode_queries(texts, batch_size=batch_size)


class BGE_Tokenizer:
    """For classification"""
    def __init__(self, model_name, normalize_embeddings, max_length=512):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()
        self.normalize_embeddings = normalize_embeddings
        self.max_length = max_length

    def encode(self, texts):
        encoded_input = self.tokenizer(texts, padding=True, truncation=True,
                                       return_tensors='pt', max_length=self.max_length)

        with torch.no_grad():
            model_output = self.model(**encoded_input)
            embeddings = model_output[0][:, 0]  # CLS token pooling

        if self.normalize_embeddings:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings


def embed_dataset(args):
    dataset = load_dataset(args.dataset_path)["train"]
    model = BGE(args.model_name, args.normalize, args.max_length)
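
A small usage sketch for the new BGE_Tokenizer encoder, assuming the BAAI/bge-large-en weights can be fetched from the Hub; encode returns one embedding row per input text.

from embed import BGE_Tokenizer

encoder = BGE_Tokenizer("BAAI/bge-large-en", normalize_embeddings=True)
embeddings = encoder.encode(["How do I reverse a linked list?",
                             "Summarize this article in two sentences."])
print(embeddings.shape)  # torch.Size([2, 1024]) for the bge-large models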
