Feature/upgrade (#106)
* support load from disk

* enable model_kwargs to support more configuration

* remove end_with_eos

* fix bugs

* fix bugs

* improve

* separate the configuration of ibn (in-batch negative) and cln (contrastive learning with hard negative)

* update docs

* correct docs

* layout

* use last_hidden_state as the last layer outputs

---------

Co-authored-by: Sean Lee <[email protected]>
SeanLee97 and Sean Lee authored Jan 11, 2025
1 parent 42659a7 commit 07f236b
Showing 5 changed files with 144 additions and 100 deletions.
9 changes: 6 additions & 3 deletions README.md
@@ -343,8 +343,9 @@ angle.fit(
gradient_accumulation_steps=1,
loss_kwargs={
'cosine_w': 1.0,
'ibn_w': 20.0,
'angle_w': 1.0,
'ibn_w': 1.0,
'cln_w': 1.0,
'angle_w': 0.02,
'cosine_tau': 20,
'ibn_tau': 20,
'angle_tau': 20
@@ -368,9 +369,11 @@ print('Spearman\'s corrcoef:', corrcoef)

## 💡 4. Fine-tuning Tips

For more details, please refer to the [documentation](https://angle.readthedocs.io/en/latest/notes/training.html#fine-tuning-tips).

1️⃣ If your dataset format is `DatasetFormats.A`, it is recommended to slightly increase the weight for `cosine_w` or slightly decrease the weight for `ibn_w`.

2️⃣ If your dataset format is `DatasetFormats.B`, it is recommended to set `cosine_w` to 0, and increase the weight for `ibn_w` such as 10 and 20. The `angle_tau` is recommended to set to 20.0.
2️⃣ If your dataset format is `DatasetFormats.B`, it is recommended to set `cosine_w` to 0 and `angle_w` to a small value such as 0.02. Be sure to set both `cln_w` and `ibn_w` (a minimal configuration sketch follows these tips).

3️⃣ If your dataset format is `DatasetFormats.C`, only `ibn_w` and `ibn_tau` are effective. You don't need to tune other parameters.
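
For example, a minimal `DatasetFormats.B` configuration under the new `ibn`/`cln` split might look like the sketch below; the checkpoint name and weights are illustrative, and `train_ds` is assumed to be a dataset already processed with `AngleDataTokenizer`:

```python
from angle_emb import AnglE

angle = AnglE.from_pretrained('SeanLee97/angle-bert-base-uncased-nli-en-v1',
                              max_length=128, pooling_strategy='cls')
angle.fit(
    train_ds=train_ds,      # assumed: tokenized DatasetFormats.B data
    batch_size=32,
    loss_kwargs={
        'cosine_w': 0.0,    # tip 2️⃣: disable the cosine term
        'ibn_w': 1.0,       # in-batch negative loss
        'cln_w': 1.0,       # contrastive learning with hard negatives
        'angle_w': 0.02,    # keep the angle term small
        'cosine_tau': 20,
        'ibn_tau': 20,
        'angle_tau': 20,
    },
)
```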

135 changes: 68 additions & 67 deletions angle_emb/angle.py
@@ -263,7 +263,7 @@ def get_pooling(outputs: torch.Tensor,
outputs = outputs[:, 0]
elif pooling_strategy == 'cls_avg':
avg = torch.sum(
outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"])
outputs * inputs["attention_mask"][:, :, None], dim=1) / inputs["attention_mask"].sum(dim=1).unsqueeze(1)
outputs = (outputs[:, 0] + avg) / 2.0
elif pooling_strategy == 'cls_max':
maximum, _ = torch.max(outputs * inputs["attention_mask"][:, :, None], dim=1)
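
The `cls_avg` change above replaces a batch-wide denominator with per-sample token counts, so each sample's average is taken over its own non-padding tokens. A minimal sketch of the difference with made-up tensors (names here are illustrative, not from the library):

```python
import torch

# Two sequences of different lengths; the second has two padding positions.
outputs = torch.randn(2, 4, 8)                 # [batch, seq_len, hidden]
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

masked_sum = (outputs * attention_mask[:, :, None]).sum(dim=1)
# before: one scalar denominator (6 tokens in the whole batch) for every sample
buggy_avg = masked_sum / attention_mask.sum()
# after: each sample is divided by its own token count (4 and 2)
fixed_avg = masked_sum / attention_mask.sum(dim=1).unsqueeze(1)
```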
@@ -392,7 +392,6 @@ class AngleDataTokenizer:
If providing multiple placeholders in prompt_template, specify their names via extra_columns. Default None
:param dataset_format: Optional[str]. Specify dataset_format from DatasetFormats. Default None.
It will automatically detect the dataset format.
:param end_with_eos: bool. Specify whether ends with the eos token. Default False.
:param fix_data: bool. Specify whether to fix the data. Only works when prompt_template is not None. Default True.
Example::
@@ -416,15 +415,13 @@ def __init__(self,
template_placeholders: Optional[List[str]] = None,
extra_columns: Optional[List[str]] = None,
dataset_format: Optional[str] = None,
end_with_eos: bool = False,
fix_data: bool = True):
self.tokenizer = tokenizer
self.max_length = max_length
self.prompt_template = prompt_template
self.prompt_template_tok = None
self.extra_columns = extra_columns
self.dataset_format = dataset_format
self.end_with_eos = end_with_eos
self.fix_data = fix_data
if template_placeholders is None:
template_placeholders = ['condition', 'text']
Expand Down Expand Up @@ -478,8 +475,6 @@ def __call__(self, data: Dict) -> Dict:
continue
extra_placeholder[key] = val
extra_length += len(self.tokenizer(val, add_special_tokens=False)['input_ids'])
if self.end_with_eos:
extra_length += 1

if self.prompt_template_tok is not None:
max_length = self.max_length - len(self.prompt_template_tok['input_ids']) - extra_length
@@ -523,7 +518,6 @@ def __call__(self, data: Dict) -> Dict:
combined_tok['seperate_ids'] = seperate_ids
combined_tok['extra'] = {
'dataset_format': self.dataset_format,
'end_with_eos': self.end_with_eos,
'prompt_token_ids': self.prompt_template_tok['input_ids'] if self.prompt_template_tok is not None else None,
}
return combined_tok
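
Since `end_with_eos` is gone from `AngleDataTokenizer`, the wrapper is now constructed without that argument. A hedged usage sketch; the toy dataset and its values are made up, and the column names follow `DatasetFormats.A`:

```python
from datasets import Dataset
from transformers import AutoTokenizer
from angle_emb import AngleDataTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
raw = Dataset.from_dict({
    'text1': ['a plane is taking off', 'a man is playing a flute'],
    'text2': ['an air plane is taking off', 'a man plays a flute'],
    'label': [1.0, 0.9],
})
# note: no end_with_eos argument anymore
train_ds = raw.map(AngleDataTokenizer(tokenizer, max_length=128), num_proc=1)
```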
@@ -551,7 +545,7 @@ class AngleDataCollator:
filter_duplicate: bool = True
coword_random_mask_rate: float = 0.0
special_token_id_names: List[str] = field(default_factory=lambda: [
'bos_token_id', 'eos_token_id', 'unk_token_id', 'sep_token_id',
'bos_token_id', 'unk_token_id', 'sep_token_id',
'pad_token_id', 'cls_token_id', 'mask_token_id'])

def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str, torch.Tensor]:
@@ -564,7 +558,6 @@ def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str
if return_tensors is None:
return_tensors = self.return_tensors
has_token_type_ids = "token_type_ids" in features[0]
end_with_eos = features[0]['extra']['end_with_eos']
prompt_token_ids = set(features[0]['extra']['prompt_token_ids'] or [])
special_token_ids = set()
for name in self.special_token_id_names:
@@ -661,22 +654,13 @@ def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str
# remove features
del features

if end_with_eos:
features = {}
features['input_ids'] = [feature['input_ids'] + [self.tokenizer.eos_token_id] for feature in new_features]
features = self.tokenizer.pad(
features,
padding=self.padding,
return_attention_mask=True,
return_tensors=return_tensors)
else:
features = self.tokenizer.pad(
{'input_ids': [feature['input_ids'] for feature in new_features]},
padding=self.padding,
max_length=self.max_length,
return_attention_mask=True,
return_tensors=return_tensors,
)
features = self.tokenizer.pad(
{'input_ids': [feature['input_ids'] for feature in new_features]},
padding=self.padding,
max_length=self.max_length,
return_attention_mask=True,
return_tensors=return_tensors,
)
features['labels'] = torch.Tensor([feature['labels'] for feature in new_features])

if self.coword_random_mask_rate > 0:
@@ -727,11 +711,16 @@ def __call__(self,
Currently supports [`cls`, `cls_avg`, `cls_max`, `last`, `avg`, `mean`, `max`, `all`, int]. Default None.
:param return_mlm_logits: bool. Return logits or not. Default False.
"""
ret = self.model(output_hidden_states=True, return_dict=True, **inputs)
all_layer_outputs = ret.hidden_states
if return_all_layer_outputs:
return (all_layer_outputs, ret.logits) if return_mlm_logits else all_layer_outputs
outputs = all_layer_outputs[layer_index]
if layer_index == -1 and not return_all_layer_outputs:
ret = self.model(**inputs)
outputs = ret.last_hidden_state
else:
ret = self.model(output_hidden_states=True, return_dict=True, **inputs)
all_layer_outputs = ret.hidden_states
all_layer_outputs[-1] = ret.last_hidden_state
if return_all_layer_outputs:
return (all_layer_outputs, ret.logits) if return_mlm_logits else all_layer_outputs
outputs = all_layer_outputs[layer_index]
outputs = get_pooling(outputs, inputs,
pooling_strategy or self.pooling_strategy,
padding_side=self.padding_side)
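
The branch above only materializes per-layer `hidden_states` when an intermediate layer (or all layers) is requested; the default path reads `last_hidden_state` directly. A hedged, self-contained illustration on a plain Hugging Face backbone (`bert-base-uncased` is just an example; for BERT the two paths agree):

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
backbone = AutoModel.from_pretrained('bert-base-uncased')
batch = tokenizer(['hello world'], return_tensors='pt')

with torch.no_grad():
    # fast path: final layer only, no per-layer hidden_states kept
    fast = backbone(**batch).last_hidden_state
    # general path: still needed for intermediate layers / all-layer outputs
    full = backbone(**batch, output_hidden_states=True).hidden_states

assert torch.allclose(fast, full[-1], atol=1e-6)
```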
@@ -771,10 +760,12 @@ def __init__(self,
teacher_name_or_path: Optional[str] = None,
teacher_pooling_strategy: str = 'cls',
pad_token_id: int = 0,
model_kwargs: Optional[Dict] = None,
**kwargs):
super().__init__(**kwargs)
self.pooler = pooler
self.pad_token_id = pad_token_id
self.model_kwargs = model_kwargs
if loss_kwargs is None:
loss_kwargs = {}
self.loss_fct = AngleLoss(dataset_format=dataset_format, **loss_kwargs)
@@ -788,7 +779,8 @@ def __init__(self,
teacher_backbone = AutoModel.from_pretrained(
teacher_name_or_path,
trust_remote_code=True,
torch_dtype=self.pooler.model.dtype).to(self.pooler.model.device)
torch_dtype=self.pooler.model.dtype,
**self.model_kwargs).to(self.pooler.model.device)

self.teacher_pooler = Pooler(
teacher_backbone,
@@ -1008,29 +1000,29 @@ class AngleLoss:
Configure AngleLoss.
:param cosine_w: float. weight for cosine_loss. Default 1.0
:param ibn_w: float. weight for contrastive loss. Default 1.0
:param ibn_w: float. weight for in-batch negative loss. Default 1.0
:param cln_w: float. weight for contrastive learning with hard negatives. Default 1.0
:param angle_w: float. weight for angle loss. Default 0.02
:param cosine_tau: float. tau for cosine loss. Default 20.0
:param ibn_tau: float. tau for contrastive loss. Default 20.0
:param ibn_tau: float. tau for in-batch negative loss. Default 20.0
:param angle_tau: float. tau for angle loss. Default 20.0
:param angle_pooling_strategy: str. pooling strategy for angle loss. Default 'sum'.
:param dataset_format: Optional[str]. Default None.
"""
def __init__(self,
cosine_w: float = 0.0,
ibn_w: float = 20.0,
angle_w: float = 1.0,
ibn_w: float = 1.0,
cln_w: float = 1.0,
angle_w: float = 0.02,
cosine_tau: float = 20.0,
ibn_tau: float = 20.0,
angle_tau: float = 20.0,
angle_pooling_strategy: str = 'sum',
dataset_format: Optional[str] = None,
**kwargs):
if 'w1' in kwargs or 'w2' in kwargs or 'w3' in kwargs:
raise ValueError('w1, w2, and w3 have been renamed to cosine_w, ibn_w, and angle_w, respectively. '
'Please use the new names instead.')
self.cosine_w = cosine_w
self.ibn_w = ibn_w
self.cln_w = cln_w
self.angle_w = angle_w
self.cosine_tau = cosine_tau
self.ibn_tau = ibn_tau
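
With the split, in-batch negatives (`ibn_w`) and hard-negative contrastive learning (`cln_w`) are weighted independently, and `angle_w` now defaults to 0.02. A hedged sketch of constructing the loss directly; the import path assumes the classes live in `angle_emb.angle`, as in this diff:

```python
from angle_emb.angle import AngleLoss, DatasetFormats

loss_fct = AngleLoss(
    dataset_format=DatasetFormats.B,
    cosine_w=0.0,   # usually disabled for triplet-style data
    ibn_w=1.0,      # in-batch negatives
    cln_w=1.0,      # contrastive learning with hard negatives
    angle_w=0.02,   # small angle-loss weight, matching the new default
)
```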
@@ -1058,6 +1050,9 @@ def __call__(self,
loss += self.angle_w * angle_loss(labels, outputs, self.angle_tau,
pooling_strategy=self.angle_pooling_strategy)
elif self.dataset_format == DatasetFormats.B:
if int(self.cln_w) == 0:
logger.info('`cln_w` is set to zero. Contrastive learning with hard negative is disabled. '
'Please manually check whether it is correct.')
# text,positive,negative
text = outputs[::3]
positive = outputs[1::3]
@@ -1073,13 +1068,22 @@
combined_labels = torch.cat((positive_labels, negative_labels), dim=0)

loss = 0.
if self.cosine_w > 0:
loss += self.cosine_w * cosine_loss(combined_labels, combined_inputs, self.cosine_tau)
# contrastive learning loss
cll = 0.
if self.ibn_w > 0:
loss += self.ibn_w * contrastive_with_negative_loss(text, positive, negative, tau=self.ibn_tau)
cll += self.ibn_w * contrastive_with_negative_loss(text, positive, tau=self.ibn_tau)
if self.cln_w > 0:
cll += self.cln_w * contrastive_with_negative_loss(text, positive, negative, tau=self.ibn_tau)
if cll > 0:
loss += cll / 2
# angle loss
if self.angle_w > 0:
loss += self.angle_w * angle_loss(combined_labels, combined_inputs, self.angle_tau,
pooling_strategy=self.angle_pooling_strategy)
# cosine loss
if self.cosine_w > 0:
loss += self.cosine_w * cosine_loss(combined_labels, combined_inputs, self.cosine_tau)

elif self.dataset_format == DatasetFormats.C:
text = outputs[::2]
positive = outputs[1::2]
@@ -1180,6 +1184,8 @@ def __init__(self,
if torch_dtype is None:
torch_dtype = torch.float32 if train_mode else None

self.model_kwargs = model_kwargs if model_kwargs is not None else {}

lora_config = None
if self.apply_lora:
lora_config = {
@@ -1200,7 +1206,6 @@
if self.is_llm and self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = 0

model_kwargs = model_kwargs if model_kwargs is not None else {}
kbit_kwargs = kbit_kwargs if kbit_kwargs is not None else {}
if self.is_llm:
device_map = "auto"
@@ -1240,13 +1245,15 @@ def __init__(self,
torch_dtype=torch.float32,
device_map=device_map,
trust_remote_code=True,
**self.model_kwargs
)
else:
model = MODEL_CLASS.from_pretrained(model_name_or_path,
device_map=device_map,
output_hidden_states=True,
trust_remote_code=True,
torch_dtype=torch_dtype or torch.float16)
torch_dtype=torch_dtype or torch.float16,
**self.model_kwargs)
if train_mode and is_kbit:
model = prepare_model_for_kbit_training(model, **kbit_kwargs)

@@ -1277,13 +1284,17 @@ def __init__(self,
device_map=device_map,
output_hidden_states=True,
trust_remote_code=True,
torch_dtype=torch_dtype or torch.float16)
torch_dtype=torch_dtype or torch.float16,
**self.model_kwargs)
self.backbone = model
else:
MODEL_CLASS = AutoModelForMaskedLM if load_mlm_model else AutoModel
# non-LLMs
if self.apply_lora:
model = MODEL_CLASS.from_pretrained(pretrained_model_path or model_name_or_path, trust_remote_code=True)
model = MODEL_CLASS.from_pretrained(
pretrained_model_path or model_name_or_path,
trust_remote_code=True,
**self.model_kwargs)
if pretrained_lora_path is not None:
model = PeftModel.from_pretrained(
model,
@@ -1303,7 +1314,8 @@
logger.info(f'Load pretrained model from {pretrained_model_path}')
self.backbone = MODEL_CLASS.from_pretrained(
pretrained_model_path or model_name_or_path,
trust_remote_code=True)
trust_remote_code=True,
**self.model_kwargs)

if train_mode and self.apply_lora:
self.backbone.print_trainable_parameters()
@@ -1374,6 +1386,7 @@ def from_pretrained(model_name_or_path: str,
is_llm: Optional[bool] = None,
pooling_strategy: str = 'cls',
train_mode: bool = False,
model_kwargs: Optional[Dict] = None,
**kwargs):
"""
Load AnglE from pretrained model.
@@ -1398,6 +1411,7 @@
# inference
angle.encode(*args, **kwargs)
"""
kwargs['model_kwargs'] = model_kwargs
angle = AnglE(model_name_or_path,
is_llm=is_llm,
pretrained_model_path=pretrained_model_path,
@@ -1567,6 +1581,7 @@ def fit(self,
filter_duplicate=filter_duplicate,
coword_random_mask_rate=coword_random_mask_rate,
),
model_kwargs=self.model_kwargs,
**trainer_kwargs
)
if torch.__version__ >= "2" and sys.platform != "win32":
@@ -1607,7 +1622,6 @@ def truncate_layer(self, layer_index: int):
def encode(self,
inputs: Union[List[str], Tuple[str], List[Dict], str],
max_length: Optional[int] = None,
end_with_eos: bool = False,
to_numpy: bool = True,
embedding_start: int = 0,
embedding_size: Optional[int] = None,
@@ -1639,26 +1653,13 @@ def encode(self,
for i, obj in enumerate(inputs):
assert isinstance(obj, dict), 'The prompt has been set, please pass a dict like {"prompt_key": "text"}'
inputs[i] = prompt.format(**obj)
max_length = max_length or self.max_length
if end_with_eos:
max_length -= 1

if end_with_eos:
tok = self.tokenizer(
inputs,
padding=False,
return_attention_mask=False,
max_length=max_length or self.max_length,
truncation=True)
tok['input_ids'] = [input_ids + [self.tokenizer.eos_token_id] for input_ids in tok['input_ids']]
tok = self.tokenizer.pad(tok, padding=padding, return_attention_mask=True, return_tensors='pt')
else:
tok = self.tokenizer(
inputs,
padding=padding,
max_length=max_length or self.max_length,
truncation=True,
return_tensors='pt')

tok = self.tokenizer(
inputs,
padding=padding,
max_length=max_length or self.max_length,
truncation=True,
return_tensors='pt')
tok.to(device)
with torch.no_grad():
output = self.pooler(tok,
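
Finally, a hedged end-to-end sketch of the new `model_kwargs` pass-through: the dictionary is forwarded to the backbone's `from_pretrained` call. The `attn_implementation` key is only an example of such an argument, not something this commit requires:

```python
from angle_emb import AnglE

angle = AnglE.from_pretrained(
    'WhereIsAI/UAE-Large-V1',
    pooling_strategy='cls',
    model_kwargs={'attn_implementation': 'sdpa'},  # forwarded to AutoModel.from_pretrained
)
vecs = angle.encode(['hello world', 'the weather is nice'], to_numpy=True)
print(vecs.shape)
```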