diff --git a/cluster_classifier.py b/cluster_classifier.py
index e69de29..8131f2a 100644
--- a/cluster_classifier.py
+++ b/cluster_classifier.py
@@ -0,0 +1,138 @@
+import argparse
+import os
+from embed import BGE_Tokenizer
+import random
+from transformers import AutoModelForSequenceClassification
+from datasets import load_dataset
+from sklearn.model_selection import train_test_split
+from transformers import Trainer, TrainingArguments
+from torch.utils.data import Dataset
+import torch
+import yaml
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dataset_path",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--train_config_path",
+        type=str,
+        default="configs/classifier_training_config.yaml",
+    )
+    parser.add_argument(
+        "--embedding_model",
+        type=str,
+        default="BAAI/bge-large-en",
+        help="Model to use for embedding, options: 'BAAI/bge-*-en', 'BAAI/bge-*-en-v1.5'"
+    )
+    parser.add_argument(
+        "--max_length",
+        type=int,
+        default=512,
+        help="Maximum sequence length passed to the tokenizer"
+    )
+    parser.add_argument(
+        "--normalize_embeddings",
+        action="store_true",
+        help="Whether to normalize the embeddings"
+    )
+
+    return parser.parse_args()
+
+
+class CustomDataset(Dataset):
+
+    def __init__(self, texts, labels, tokenizer, max_len, use_bge):
+        self.texts = texts
+        self.labels = labels
+        self.tokenizer = tokenizer
+        self.max_len = max_len
+        self.use_bge = use_bge
+
+    def __len__(self):
+        return len(self.texts)
+
+    def __getitem__(self, idx):
+        text = str(self.texts[idx])
+        label = int(self.labels[idx])  # convert label to integer
+        if self.use_bge:
+            # Use BGE to encode the text
+            encoding = self.tokenizer.encode([text])
+            input_ids = encoding
+            attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
+        else:
+            # Legacy path for plain HuggingFace tokenizers
+            # Check if the text is longer than the maximum length
+            if len(text.split()) > self.max_len:
+                # Calculate the number of tokens to be removed
+                num_tokens_to_remove = len(text.split()) - self.max_len
+                # Split the text into tokens
+                tokens = text.split()
+                # Randomly select start and end indices for truncation
+                start_index = random.randint(0, num_tokens_to_remove)
+                end_index = start_index + self.max_len
+                # Truncate the tokens and join them back into a string
+                text = " ".join(tokens[start_index:end_index])
+
+            encoding = self.tokenizer.encode_plus(
+                text,
+                add_special_tokens=True,
+                max_length=self.max_len,
+                return_token_type_ids=False,
+                padding='max_length',
+                truncation=True,
+                return_attention_mask=True,
+                return_tensors='pt',
+            )
+            input_ids = encoding['input_ids'].flatten()
+            attention_mask = encoding['attention_mask'].flatten()
+
+        return {
+            'text': text,
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'labels': torch.tensor([label], dtype=torch.long)
+        }
+
+
+def create_dataset(data, tokenizer, max_len, use_bge):
+    texts = [item['text'] for item in data]
+    labels = [item['label'] for item in data]
+    return CustomDataset(texts, labels, tokenizer, max_len, use_bge)
+
+
+def train(args):
+    # Materialize the HF "train" split as a list of {"text", "label"} records for splitting
+    classifier_dataset = list(load_dataset(args.dataset_path)["train"])
+    n_labels = len(set(item["label"] for item in classifier_dataset))
+    train_data, val_data = train_test_split(classifier_dataset, test_size=0.1, random_state=42)
+
+    use_bge = "bge" in args.embedding_model
+
+    if not use_bge:
+        raise ValueError("Embedding model must be a BGE model at this time.")
+
+    tokenizer = BGE_Tokenizer(model_name=args.embedding_model, normalize_embeddings=args.normalize_embeddings)
+    model = AutoModelForSequenceClassification.from_pretrained(args.embedding_model, num_labels=n_labels)
+
+    train_dataset = create_dataset(train_data, tokenizer, args.max_length, use_bge)
+    val_dataset = create_dataset(val_data, tokenizer, args.max_length, use_bge)
+
+    with open(args.train_config_path, "r") as f:
+        train_config = yaml.safe_load(f)
+
+    # wandb project/entity are not TrainingArguments fields; pass them via the environment
+    training_args_config = train_config["training_args"]
+    wandb_project = training_args_config.pop("wandb_project", None)
+    wandb_entity = training_args_config.pop("wandb_entity", None)
+    if wandb_project:
+        os.environ["WANDB_PROJECT"] = wandb_project
+    if wandb_entity:
+        os.environ["WANDB_ENTITY"] = wandb_entity
+    training_args = TrainingArguments(**training_args_config)
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=val_dataset,
+    )
+
+    trainer.train()
+
+    trainer.save_model("classifier")
+
+    print(trainer.evaluate())
+
+
+if __name__ == "__main__":
+    args = get_args()
+    train(args)
diff --git a/configs/classifier_training_config.yaml b/configs/classifier_training_config.yaml
new file mode 100644
index 0000000..b7aa0d6
--- /dev/null
+++ b/configs/classifier_training_config.yaml
@@ -0,0 +1,24 @@
+training_args:
+  output_dir: './bge-large-classifier-32'
+  num_train_epochs: 6
+  per_device_train_batch_size: 32
+  per_device_eval_batch_size: 32
+  warmup_ratio: 0.03
+  weight_decay: 0.01
+  logging_dir: './logs-bge-large-32'
+  logging_steps: 10
+  learning_rate: 1.0e-6
+  evaluation_strategy: 'steps'
+  save_steps: 1000
+  eval_steps: 1000
+  save_total_limit: 3
+  load_best_model_at_end: True
+  metric_for_best_model: "eval_loss"
+  greater_is_better: False
+  push_to_hub: True
+  hub_strategy: "all_checkpoints"
+  report_to: "wandb"
+  run_name: "bge-large-32"
+  wandb_project: "bge-large-32-classifier-HydraALPHA"
+  wandb_entity: "llama-moe"
diff --git a/data_utils.py b/data_utils.py
index ff1a53e..ce37937 100644
--- a/data_utils.py
+++ b/data_utils.py
@@ -92,4 +92,4 @@ def user_assistant_template(conversation, max_words=1400):
     """Format conversation using user and assistant template."""
     return alpaca_template(conversation, max_words=max_words,
                            naming_map={"instruction": "User",
-                                       "output": "Assistant"})
+                                       "output": "Assistant"})
\ No newline at end of file
diff --git a/embed.py b/embed.py
index 9d3baf4..37f520d 100644
--- a/embed.py
+++ b/embed.py
@@ -2,6 +2,7 @@
 import numpy as np
 import argparse
 from FlagEmbedding import FlagModel
+import torch
+from transformers import AutoTokenizer, AutoModel


 def get_args():
@@ -64,6 +65,27 @@ def encode(self, texts, batch_size=256):
         return self.model.encode_queries(texts, batch_size=batch_size)


+class BGE_Tokenizer:
+    """Tokenize-and-embed helper for classification."""
+
+    def __init__(self, model_name, normalize_embeddings, max_length=512):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModel.from_pretrained(model_name)
+        self.model.eval()
+        self.normalize_embeddings = normalize_embeddings
+        self.max_length = max_length
+
+    def encode(self, texts):
+        encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=self.max_length)
+
+        with torch.no_grad():
+            model_output = self.model(**encoded_input)
+            embeddings = model_output[0][:, 0]  # CLS token pooling
+
+        if self.normalize_embeddings:
+            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+        return embeddings
+
+
 def embed_dataset(args):
     dataset = load_dataset(args.dataset_path)["train"]
     model = BGE(args.model_name, args.normalize, args.max_length)
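
For context, a minimal sketch of how the BGE_Tokenizer added to embed.py is exercised by the classifier script; the model name and sample text below are illustrative, not taken from the diff:

    from embed import BGE_Tokenizer

    # Load the BGE encoder and its tokenizer; normalization mirrors the --normalize_embeddings flag
    encoder = BGE_Tokenizer(model_name="BAAI/bge-large-en", normalize_embeddings=True)

    # encode() returns a (batch, hidden_size) tensor of CLS-pooled embeddings
    embeddings = encoder.encode(["example passage to classify"])
    print(embeddings.shape)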