Improve/2dmse #56

Merged: 5 commits, Mar 4, 2024
114 changes: 74 additions & 40 deletions angle_emb/angle.py
@@ -4,6 +4,7 @@
import re
import sys
import json
import copy
import random
from functools import partial
from typing import Any, Dict, Optional, List, Union, Tuple, Callable
@@ -581,12 +582,13 @@ class AngleDataCollator:
:param padding: Union[bool, str, PaddingStrategy], padding strategy
:param max_length: Optional[int], max length
:param return_tensors: str

:param filter_duplicate: bool. Whether to filter duplicate data. Default True.
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = 'longest'
max_length: Optional[int] = None
return_tensors: str = "pt"
filter_duplicate: bool = True

def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str, torch.Tensor]:
if return_tensors is None:
@@ -595,6 +597,7 @@ def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str
end_with_eos = features[0]['extra']['end_with_eos']

new_features = []
duplicate_set = set()
for feature in features:
seperate_ids = feature['seperate_ids']
input_ids = feature['input_ids']
@@ -609,26 +612,41 @@ def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str

max_seperate_id = max(seperate_ids)
prev_start_idx = 0
current_features = []
is_duplicate = False
for seperate_id in range(1, max_seperate_id + 1):
start_idx = seperate_ids.index(seperate_id)

new_feature = {}
new_feature['input_ids'] = input_ids[prev_start_idx:start_idx]
new_input_ids = input_ids[prev_start_idx:start_idx]
if tuple(new_input_ids) in duplicate_set:
is_duplicate = True
if self.filter_duplicate:
break
duplicate_set.add(tuple(new_input_ids))
new_feature['input_ids'] = new_input_ids
new_feature['attention_mask'] = attention_mask[prev_start_idx:start_idx]
if has_token_type_ids:
new_feature['token_type_ids'] = token_type_ids[prev_start_idx:start_idx]
new_feature['labels'] = feature['labels']
new_features.append(new_feature)
current_features.append(new_feature)
prev_start_idx = start_idx

# last
new_feature = {}
new_feature['input_ids'] = input_ids[prev_start_idx:]
new_input_ids = input_ids[prev_start_idx:]
if tuple(new_input_ids) in duplicate_set:
is_duplicate = True
duplicate_set.add(tuple(new_input_ids))
new_feature['input_ids'] = new_input_ids
new_feature['attention_mask'] = attention_mask[prev_start_idx:]
if has_token_type_ids:
new_feature['token_type_ids'] = token_type_ids[prev_start_idx:]
new_feature['labels'] = feature['labels']
new_features.append(new_feature)
current_features.append(new_feature)

if self.filter_duplicate and is_duplicate:
continue
new_features += current_features

# remove features
del features
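In short, each sub-sequence of `input_ids` is hashed as a tuple; if any member of a pair has been seen before, the whole pair is dropped when `filter_duplicate` is on. A self-contained toy sketch of that idea (illustrative data, not the library's collator):

```python
# Toy illustration of tuple-based deduplication over token-id lists.
def dedupe(batches):
    seen = set()
    kept = []
    for input_ids in batches:
        key = tuple(input_ids)
        if key in seen:
            continue  # already seen: drop it
        seen.add(key)
        kept.append(input_ids)
    return kept

print(dedupe([[1, 2, 3], [4, 5], [1, 2, 3]]))  # -> [[1, 2, 3], [4, 5]]
```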
@@ -685,13 +703,17 @@ def __init__(self,
self.padding_strategy = padding_strategy
self.is_llm = is_llm

def __call__(self, inputs: Dict, layer_index: int = -1, embedding_size: Optional[int] = None) -> torch.Tensor:
def __call__(self, inputs: Dict, layer_index: int = -1, embedding_size: Optional[int] = None,
return_all_layer_outputs: bool = False) -> torch.Tensor:
"""
:param inputs: Dict. Model inputs.
:param layer_index: int. Get embeddings from specific layer.
:param embedding_size: int. Set embedding size for sentence embeddings for 2DMSE models.
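:param return_all_layer_outputs: bool. Whether to return hidden states from all layers instead of a single layer. Default False.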
"""
outputs = self.model(output_hidden_states=True, return_dict=True, **inputs).hidden_states[layer_index]
all_layer_outputs = self.model(output_hidden_states=True, return_dict=True, **inputs).hidden_states
if return_all_layer_outputs:
return all_layer_outputs
outputs = all_layer_outputs[layer_index]
if self.is_llm:
batch_size = inputs['input_ids'].shape[0]
sequence_lengths = -1 if self.padding_strategy == 'left' else inputs["attention_mask"].sum(dim=1) - 1
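For context, the new flag simply surfaces the backbone's full `hidden_states` tuple from one forward pass. A standalone sketch with a plain Hugging Face BERT (not the Pooler itself; the model name is only an example):

```python
import torch
from transformers import AutoModel, AutoTokenizer

# One forward pass with output_hidden_states=True returns the embedding
# output first, then one tensor per transformer layer.
tok = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')
inputs = tok('hello world', return_tensors='pt')
with torch.no_grad():
    hidden_states = model(output_hidden_states=True, return_dict=True, **inputs).hidden_states
print(len(hidden_states))       # 13 for bert-base: embeddings + 12 layers
teacher = hidden_states[-1]     # what layer_index=-1 selects
student = hidden_states[6]      # an intermediate layer
```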
@@ -802,46 +824,48 @@ def __init__(self,
self.tdmse_student_lambda = tdmse_student_lambda
self.apply_tdmse_kl = apply_tdmse_kl
self.n_layers = self.pooler.model.config.num_hidden_layers
self.tdmse_hidden_sizes = get_geometric_hidden_sizes(base=8, max_hidden=self.pooler.model.config.hidden_size)
self.hidden_size = self.pooler.model.config.hidden_size
self.tdmse_hidden_sizes = get_geometric_hidden_sizes(base=8, max_hidden=self.hidden_size)
self.kl_loss_fct = nn.KLDivLoss(reduction='batchmean')
logger.info('Train 2DMSE!')
logger.info('Train with 2DMSE!')

def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels", None)
# layer
sample_layer = random.randint(1, self.n_layers - 1)
if self.fixed_teacher_name_or_path is not None:
all_teacher_outputs = self.pooler(inputs, layer_index=-1)
teacher_outputs = get_pooling(all_teacher_outputs, inputs,
self.alignment_pooling_strategy,
self.pooler.padding_strategy)
all_student_outputs = self.pooler(inputs, layer_index=sample_layer)
student_outputs = get_pooling(all_student_outputs, inputs,
self.alignment_pooling_strategy,
self.pooler.padding_strategy)
else:
teacher_outputs = self.pooler(inputs, layer_index=-1)
student_outputs = self.pooler(inputs, layer_index=sample_layer)

kl_outputs = teacher_outputs
pooling_strategy = (self.alignment_pooling_strategy
if self.pooler.pooling_strategy == 'all'
else self.pooler.pooling_strategy)
all_layer_outputs = self.pooler(inputs, layer_index=-1, return_all_layer_outputs=True)
all_teacher_outputs = all_layer_outputs[-1]
teacher_outputs = get_pooling(all_teacher_outputs, inputs,
pooling_strategy,
self.pooler.padding_strategy)
all_student_outputs = all_layer_outputs[sample_layer]
student_outputs = get_pooling(all_student_outputs,
inputs,
pooling_strategy,
self.pooler.padding_strategy)

teacher_kl_outputs = teacher_outputs
if self.fixed_teacher_name_or_path is not None:
with torch.no_grad():
self.fixed_teacher_pooler.model = self.fixed_teacher_pooler.model.to(self.pooler.model.device)
all_fixed_outputs = self.fixed_teacher_pooler(inputs)
kl_outputs = get_pooling(all_fixed_outputs, inputs,
self.alignment_pooling_strategy,
self.pooler.padding_strategy)
teacher_kl_outputs = get_pooling(all_fixed_outputs,
inputs,
self.alignment_pooling_strategy,
self.pooler.padding_strategy)

teacher_loss = self.loss_fct(labels, teacher_outputs)
loss1 = self.tdmse_teacher_lambda * teacher_loss
if self.tdmse_student_lambda > 0:
student_loss = self.loss_fct(labels, student_outputs)
loss1 += self.tdmse_student_lambda * student_loss
loss1 = teacher_loss
student_loss = self.loss_fct(labels, student_outputs)
loss1 += student_loss / sample_layer
if self.apply_tdmse_kl and self.tdmse_student_lambda > 0:
kl_loss = self.kl_loss_fct(
F.log_softmax(student_outputs[:, None, :] / self.tdmse_kl_temperature, dim=-1),
F.softmax(kl_outputs[:, None, :] / self.tdmse_kl_temperature, dim=-1)
) * self.tdmse_kl_temperature**2
F.log_softmax(student_outputs / self.tdmse_kl_temperature, dim=-1),
F.softmax(teacher_kl_outputs / self.tdmse_kl_temperature, dim=-1)
) * self.tdmse_kl_temperature
loss1 += kl_loss

# feature
@@ -850,10 +874,10 @@ def compute_loss(self, model, inputs, return_outputs=False):
slimmed_student_outputs = student_outputs[:, :hidden_size]

slimmed_teacher_loss = self.loss_fct(labels, slimmed_teacher_outputs)
loss2 = self.tdmse_teacher_lambda * slimmed_teacher_loss
if self.tdmse_student_lambda > 0:
slimmed_student_loss = self.loss_fct(labels, slimmed_student_outputs)
loss2 += self.tdmse_student_lambda * slimmed_student_loss
loss2 = slimmed_teacher_loss
slimmed_student_loss = self.loss_fct(labels, slimmed_student_outputs)
loss2 += slimmed_student_loss / sample_layer

loss = loss1 + loss2

if self.fixed_teacher_name_or_path is not None:
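To make the weighting scheme concrete, here is a toy, self-contained sketch of the two-part loss: a full-size loss at the last layer, a depth-weighted loss at a randomly sampled shallow layer, and the same pair again on embeddings truncated to a sampled hidden size. A placeholder MSE stands in for the library's actual angle loss, and all tensors are random:

```python
import torch
import torch.nn.functional as F

def toy_2dmse_loss(teacher_emb, student_emb, labels, sample_layer, hidden_size):
    def loss_fct(labels, emb):
        # placeholder for the library's angle/contrastive loss
        return F.mse_loss(emb.mean(dim=-1), labels)

    # full-dimension losses: last layer plus a sampled shallow layer,
    # the shallow term down-weighted by its depth
    loss1 = loss_fct(labels, teacher_emb) + loss_fct(labels, student_emb) / sample_layer
    # same pair of losses on embeddings truncated to a sampled hidden size
    loss2 = (loss_fct(labels, teacher_emb[:, :hidden_size])
             + loss_fct(labels, student_emb[:, :hidden_size]) / sample_layer)
    return loss1 + loss2

teacher = torch.randn(4, 768)
student = torch.randn(4, 768)
labels = torch.rand(4)
print(toy_2dmse_loss(teacher, student, labels, sample_layer=6, hidden_size=256))
```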
@@ -1216,6 +1240,8 @@ def __init__(self,
padding_strategy=self.tokenizer.padding_side,
is_llm=self.is_llm)

# full_backbone is used for 2DMSE inference
self.full_backbone = None
self.__cfg = {
'model_name_or_path': model_name_or_path,
'max_length': max_length,
@@ -1334,7 +1360,8 @@ def fit(self,
argument_kwargs: Optional[Dict] = None,
trainer_kwargs: Optional[Dict] = None,
loss_kwargs: Optional[Dict] = None,
apply_tdmse: bool = False):
apply_tdmse: bool = False,
filter_duplicate: bool = True):
"""
Fit using AnglE.

@@ -1412,7 +1439,7 @@
),
callbacks=callbacks,
data_collator=AngleDataCollator(
self.tokenizer, return_tensors="pt", max_length=self.max_length
self.tokenizer, return_tensors="pt", max_length=self.max_length, filter_duplicate=filter_duplicate
),
**trainer_kwargs
)
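Callers only need to pass the new flag through `fit`. A hypothetical sketch (the `angle` object and datasets are assumed to already exist, and keyword names other than `filter_duplicate` are assumptions about the usual `fit` arguments, not part of this diff):

```python
# Hypothetical call: train_ds/valid_ds are datasets already mapped with
# AngleDataTokenizer, as elsewhere in the repo.
angle.fit(
    train_ds=train_ds,
    valid_ds=valid_ds,
    filter_duplicate=True,  # forwarded to AngleDataCollator
)
```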
@@ -1428,6 +1455,7 @@ def evaluate(self, data: Dataset, batch_size: int = 32, threshold: Optional[floa
self.tokenizer,
return_tensors="pt",
max_length=self.max_length,
filter_duplicate=False,
)
y_trues, y_preds = [], []
# for X, y in data.make_iter(random=False):
@@ -1474,6 +1502,12 @@ def encode(self,
:param embedding_size: Optional[int]. Specify embedding size (for 2DMSE).
:param device: Optional[Any]. Default None.
"""
if layer_index != -1 and self.full_backbone is None:
self.full_backbone = copy.deepcopy(self.backbone)

if layer_index != -1:
self.backbone.encoder.layer = self.full_backbone.encoder.layer[:layer_index]

if device is None:
device = self.device
self.backbone.eval()
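The slicing above works because BERT-style backbones keep their transformer blocks in `encoder.layer` (an `nn.ModuleList`), so assigning a slice yields a shallower model. A standalone illustration with a Hugging Face BERT (not the AnglE code path; the model name and layer count are only examples):

```python
import copy
from transformers import AutoModel

# Keep a deep copy of the full model, then truncate to the first 6 layers so
# a forward pass only runs the shallow stack.
backbone = AutoModel.from_pretrained('bert-base-uncased')
full_backbone = copy.deepcopy(backbone)
backbone.encoder.layer = full_backbone.encoder.layer[:6]
print(len(backbone.encoder.layer))  # 6
```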
15 changes: 11 additions & 4 deletions angle_emb/train_cli.py
@@ -25,8 +25,10 @@
help='Specify huggingface datasets subset name for train set')
parser.add_argument('--train_split_name', type=str, default='train',
help='Specify huggingface datasets split name for train set, Default `train`')
parser.add_argument('--valid_split_name', type=str, default=None,
help='Specify huggingface datasets split name for valid set, Default None')
parser.add_argument('--valid_name_or_path', type=str, default=None,
help='Specify huggingface datasets name or local file path for valid set.')
parser.add_argument('--valid_subset_name', type=str, default=None,
help='Specify huggingface datasets subset name for valid set')
parser.add_argument('--prompt_template', type=str, default=None,
help='Specify prompt_template like "Instruct: xxx\nInput: {text}", default None')
parser.add_argument('--save_dir', type=str, default=None,
@@ -150,10 +152,15 @@ def main():
train_ds = ds[args.train_split_name].shuffle(args.dataset_seed).map(
AngleDataTokenizer(model.tokenizer, model.max_length,
prompt_template=args.prompt_template), num_proc=args.workers)

valid_ds = None
if args.valid_split_name is not None:
if valid_ds is None and args.valid_name_or_path is not None:
logger.info('Validation detected, processing validation...')
valid_ds = ds[args.valid_split_name].shuffle(args.dataset_seed).map(
if os.path.exists(args.valid_name_or_path):
valid_ds = load_dataset('json', data_files=[args.valid_name_or_path])
else:
valid_ds = load_dataset(args.valid_name_or_path, args.valid_subset_name)
valid_ds = valid_ds[args.valid_subset_name or 'train'].map(
AngleDataTokenizer(model.tokenizer, model.max_length,
prompt_template=args.prompt_template), num_proc=args.workers)

11 changes: 11 additions & 0 deletions tests/test_loadding.py
@@ -15,3 +15,14 @@ def test_loadding():
assert isinstance(vecs, np.ndarray)
vecs = angle.encode([{'text': 'hello world', 'text': 'hi there👋'}])
assert isinstance(vecs, np.ndarray)


def test_2dmse_loadding():
import numpy as np
from angle_emb import AnglE

angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1')
vecs = angle.encode('hello world', layer_index=20)
assert isinstance(vecs, np.ndarray)
vecs = angle.encode(['hello world', 'hi there👋'], layer_index=20, embedding_size=512)
assert isinstance(vecs, np.ndarray)