From 07f236b36a024fc6367085098a1da8bb1d45b986 Mon Sep 17 00:00:00 2001
From: Sean
Date: Sat, 11 Jan 2025 10:07:45 +0800
Subject: [PATCH] Feature/upgrade (#106)

* support load from disk

* enable model_kwargs to support more configuration

* remove end_with_eos

* fix bugs

* fix bugs

* improve

* separate the configuration of ibn (in-batch negative) and cln (contrastive learning with hard negative)

* update docs

* correct docs

* layout

* use last_hidden_state as the last layer outputs

---------

Co-authored-by: Sean Lee
---
 README.md                  |   9 ++-
 angle_emb/angle.py         | 135 +++++++++++++++++++------------------
 angle_emb/angle_trainer.py |  36 ++++++----
 docs/notes/training.rst    |  54 +++++++++++----
 docs/notes/tutorial.rst    |  10 +--
 5 files changed, 144 insertions(+), 100 deletions(-)

diff --git a/README.md b/README.md
index 819bef4..2b59c03 100644
--- a/README.md
+++ b/README.md
@@ -343,8 +343,9 @@ angle.fit(
     gradient_accumulation_steps=1,
     loss_kwargs={
         'cosine_w': 1.0,
-        'ibn_w': 20.0,
-        'angle_w': 1.0,
+        'ibn_w': 1.0,
+        'cln_w': 1.0,
+        'angle_w': 0.02,
         'cosine_tau': 20,
         'ibn_tau': 20,
         'angle_tau': 20
@@ -368,9 +369,11 @@ print('Spearman\'s corrcoef:', corrcoef)
 
 ## 💡 4. Fine-tuning Tips
 
+For more details, please refer to the [documentation](https://angle.readthedocs.io/en/latest/notes/training.html#fine-tuning-tips).
+
 1️⃣ If your dataset format is `DatasetFormats.A`, it is recommended to slightly increase the weight for `cosine_w` or slightly decrease the weight for `ibn_w`.
 
-2️⃣ If your dataset format is `DatasetFormats.B`, it is recommended to set `cosine_w` to 0, and increase the weight for `ibn_w` such as 10 and 20. The `angle_tau` is recommended to set to 20.0.
+2️⃣ If your dataset format is `DatasetFormats.B`, it is recommended to set `cosine_w` to 0 and `angle_w` to a small value such as 0.02, and to set both `cln_w` and `ibn_w` (1.0 by default).
 
 3️⃣ If your dataset format is `DatasetFormats.C`, only `ibn_w` and `ibn_tau` are effective. You don't need to tune other parameters.
 
diff --git a/angle_emb/angle.py b/angle_emb/angle.py
index 6baae43..e044e9d 100644
--- a/angle_emb/angle.py
+++ b/angle_emb/angle.py
@@ -263,7 +263,7 @@ def get_pooling(outputs: torch.Tensor,
         outputs = outputs[:, 0]
     elif pooling_strategy == 'cls_avg':
         avg = torch.sum(
-            outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"])
+            outputs * inputs["attention_mask"][:, :, None], dim=1) / inputs["attention_mask"].sum(dim=1).unsqueeze(1)
         outputs = (outputs[:, 0] + avg) / 2.0
     elif pooling_strategy == 'cls_max':
         maximum, _ = torch.max(outputs * inputs["attention_mask"][:, :, None], dim=1)
@@ -392,7 +392,6 @@ class AngleDataTokenizer:
         If providing multiple placeholders in prompt_template, specify their name via extra_columns. Default None
     :param dataset_format: Optional[str]. Specify dataset_format from DatasetFormats.
         Default None. It will automatically detect the dataset format.
-    :param end_with_eos: bool. Specify whether ends with the eos token. Default False.
     :param fix_data: bool. Specify whether fix the data. Only works when prompt_template is not None. Default True.
 
     Example::
@@ -416,7 +415,6 @@ def __init__(self,
                  template_placeholders: Optional[List[str]] = None,
                  extra_columns: Optional[List[str]] = None,
                  dataset_format: Optional[str] = None,
-                 end_with_eos: bool = False,
                  fix_data: bool = True):
         self.tokenizer = tokenizer
         self.max_length = max_length
@@ -424,7 +422,6 @@ def __init__(self,
         self.prompt_template_tok = None
         self.extra_columns = extra_columns
         self.dataset_format = dataset_format
-        self.end_with_eos = end_with_eos
         self.fix_data = fix_data
         if template_placeholders is None:
             template_placeholders = ['condition', 'text']
@@ -478,8 +475,6 @@ def __call__(self, data: Dict) -> Dict:
                 continue
             extra_placeholder[key] = val
             extra_length += len(self.tokenizer(val, add_special_tokens=False)['input_ids'])
 
-        if self.end_with_eos:
-            extra_length += 1
         if self.prompt_template_tok is not None:
             max_length = self.max_length - len(self.prompt_template_tok['input_ids']) - extra_length
@@ -523,7 +518,6 @@ def __call__(self, data: Dict) -> Dict:
         combined_tok['seperate_ids'] = seperate_ids
         combined_tok['extra'] = {
             'dataset_format': self.dataset_format,
-            'end_with_eos': self.end_with_eos,
             'prompt_token_ids': self.prompt_template_tok['input_ids'] if self.prompt_template_tok is not None else None,
         }
         return combined_tok
@@ -551,7 +545,7 @@ class AngleDataCollator:
     filter_duplicate: bool = True
     coword_random_mask_rate: float = 0.0
     special_token_id_names: List[str] = field(default_factory=lambda: [
-        'bos_token_id', 'eos_token_id', 'unk_token_id', 'sep_token_id',
+        'bos_token_id', 'unk_token_id', 'sep_token_id',
         'pad_token_id', 'cls_token_id', 'mask_token_id'])
 
     def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str, torch.Tensor]:
@@ -564,7 +558,6 @@ def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str
         if return_tensors is None:
             return_tensors = self.return_tensors
         has_token_type_ids = "token_type_ids" in features[0]
-        end_with_eos = features[0]['extra']['end_with_eos']
         prompt_token_ids = set(features[0]['extra']['prompt_token_ids'] or [])
         special_token_ids = set()
         for name in self.special_token_id_names:
@@ -661,22 +654,13 @@ def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str
         # remove features
         del features
 
-        if end_with_eos:
-            features = {}
-            features['input_ids'] = [feature['input_ids'] + [self.tokenizer.eos_token_id] for feature in new_features]
-            features = self.tokenizer.pad(
-                features,
-                padding=self.padding,
-                return_attention_mask=True,
-                return_tensors=return_tensors)
-        else:
-            features = self.tokenizer.pad(
-                {'input_ids': [feature['input_ids'] for feature in new_features]},
-                padding=self.padding,
-                max_length=self.max_length,
-                return_attention_mask=True,
-                return_tensors=return_tensors,
-            )
+        features = self.tokenizer.pad(
+            {'input_ids': [feature['input_ids'] for feature in new_features]},
+            padding=self.padding,
+            max_length=self.max_length,
+            return_attention_mask=True,
+            return_tensors=return_tensors,
+        )
         features['labels'] = torch.Tensor([feature['labels'] for feature in new_features])
 
         if self.coword_random_mask_rate > 0:
@@ -727,11 +711,16 @@ def __call__(self,
             Currently support [`cls`, `cls_avg`, `cls_max`, `last`, `avg`, `mean`, `max`, `all`, int]. Default None.
         :param return_mlm_logits: bool. Return logits or not. Default False.
         """
-        ret = self.model(output_hidden_states=True, return_dict=True, **inputs)
-        all_layer_outputs = ret.hidden_states
-        if return_all_layer_outputs:
-            return (all_layer_outputs, ret.logits) if return_mlm_logits else all_layer_outputs
-        outputs = all_layer_outputs[layer_index]
+        if layer_index == -1 and not return_all_layer_outputs:
+            ret = self.model(**inputs)
+            outputs = ret.last_hidden_state
+        else:
+            ret = self.model(output_hidden_states=True, return_dict=True, **inputs)
+            all_layer_outputs = ret.hidden_states
+            all_layer_outputs[-1] = ret.last_hidden_state
+            if return_all_layer_outputs:
+                return (all_layer_outputs, ret.logits) if return_mlm_logits else all_layer_outputs
+            outputs = all_layer_outputs[layer_index]
         outputs = get_pooling(outputs, inputs,
                               pooling_strategy or self.pooling_strategy,
                               padding_side=self.padding_side)
@@ -771,10 +760,12 @@ def __init__(self,
                  teacher_name_or_path: Optional[str] = None,
                  teacher_pooling_strategy: str = 'cls',
                  pad_token_id: int = 0,
+                 model_kwargs: Optional[Dict] = None,
                  **kwargs):
         super().__init__(**kwargs)
         self.pooler = pooler
         self.pad_token_id = pad_token_id
+        self.model_kwargs = model_kwargs
         if loss_kwargs is None:
             loss_kwargs = {}
         self.loss_fct = AngleLoss(dataset_format=dataset_format, **loss_kwargs)
@@ -788,7 +779,8 @@ def __init__(self,
             teacher_backbone = AutoModel.from_pretrained(
                 teacher_name_or_path,
                 trust_remote_code=True,
-                torch_dtype=self.pooler.model.dtype).to(self.pooler.model.device)
+                torch_dtype=self.pooler.model.dtype,
+                **self.model_kwargs).to(self.pooler.model.device)
 
             self.teacher_pooler = Pooler(
                 teacher_backbone,
@@ -1008,29 +1000,29 @@ class AngleLoss:
     Configure AngleLoss.
 
     :param cosine_w: float. weight for cosine_loss. Default 1.0
-    :param ibn_w: float. weight for contrastive loss. Default 1.0
+    :param ibn_w: float. weight for in-batch negative loss. Default 1.0
+    :param cln_w: float. weight for contrastive learning with hard negative. Default 1.0
     :param angle_w: float. weight for angle loss. Default 1.0
     :param cosine_tau: float. tau for cosine loss. Default 20.0
-    :param ibn_tau: float. tau for contrastive loss. Default 20.0
+    :param ibn_tau: float. tau for in-batch negative loss. Default 20.0
     :param angle_tau: float. tau for angle loss. Default 20.0
     :param angle_pooling_strategy: str. pooling strategy for angle loss. Default'sum'.
     :param dataset_format: Optional[str]. Default None.
     """
     def __init__(self,
                  cosine_w: float = 0.0,
-                 ibn_w: float = 20.0,
-                 angle_w: float = 1.0,
+                 ibn_w: float = 1.0,
+                 cln_w: float = 1.0,
+                 angle_w: float = 0.02,
                  cosine_tau: float = 20.0,
                  ibn_tau: float = 20.0,
                  angle_tau: float = 20.0,
                  angle_pooling_strategy: str = 'sum',
                  dataset_format: Optional[str] = None,
                  **kwargs):
-        if 'w1' in kwargs or 'w2' in kwargs or 'w3' in kwargs:
-            assert ('w1, w2, and w3 has been renamed to cosine_w, ibn_w, and angle_w, respecitvely.'
-                    'Please use new names instead.')
         self.cosine_w = cosine_w
         self.ibn_w = ibn_w
+        self.cln_w = cln_w
         self.angle_w = angle_w
         self.cosine_tau = cosine_tau
         self.ibn_tau = ibn_tau
@@ -1058,6 +1050,9 @@ def __call__(self,
                 loss += self.angle_w * angle_loss(labels, outputs, self.angle_tau,
                                                   pooling_strategy=self.angle_pooling_strategy)
         elif self.dataset_format == DatasetFormats.B:
+            if int(self.cln_w) == 0:
+                logger.info('`cln_w` is set to zero. Contrastive learning with hard negative is disabled. '
+                            'Please manually check whether it is correct.')
             # text,positive,negative
             text = outputs[::3]
             positive = outputs[1::3]
@@ -1073,13 +1068,22 @@ def __call__(self,
             combined_labels = torch.cat((positive_labels, negative_labels), dim=0)
 
             loss = 0.
-            if self.cosine_w > 0:
-                loss += self.cosine_w * cosine_loss(combined_labels, combined_inputs, self.cosine_tau)
+            # contrastive learning loss
+            cll = 0.
             if self.ibn_w > 0:
-                loss += self.ibn_w * contrastive_with_negative_loss(text, positive, negative, tau=self.ibn_tau)
+                cll += self.ibn_w * contrastive_with_negative_loss(text, positive, tau=self.ibn_tau)
+            if self.cln_w > 0:
+                cll += self.cln_w * contrastive_with_negative_loss(text, positive, negative, tau=self.ibn_tau)
+            if cll > 0:
+                loss += cll / 2
+            # angle loss
             if self.angle_w > 0:
                 loss += self.angle_w * angle_loss(combined_labels, combined_inputs, self.angle_tau,
                                                   pooling_strategy=self.angle_pooling_strategy)
+            # cosine loss
+            if self.cosine_w > 0:
+                loss += self.cosine_w * cosine_loss(combined_labels, combined_inputs, self.cosine_tau)
+
         elif self.dataset_format == DatasetFormats.C:
             text = outputs[::2]
             positive = outputs[1::2]
@@ -1180,6 +1184,8 @@ def __init__(self,
         if torch_dtype is None:
             torch_dtype = torch.float32 if train_mode else None
 
+        self.model_kwargs = model_kwargs if model_kwargs is not None else {}
+
         lora_config = None
         if self.apply_lora:
             lora_config = {
@@ -1200,7 +1206,6 @@ def __init__(self,
 
         if self.is_llm and self.tokenizer.pad_token_id is None:
             self.tokenizer.pad_token_id = 0
-        model_kwargs = model_kwargs if model_kwargs is not None else {}
         kbit_kwargs = kbit_kwargs if kbit_kwargs is not None else {}
         if self.is_llm:
             device_map = "auto"
@@ -1240,13 +1245,15 @@ def __init__(self,
                         torch_dtype=torch.float32,
                         device_map=device_map,
                         trust_remote_code=True,
+                        **self.model_kwargs
                     )
                 else:
                     model = MODEL_CLASS.from_pretrained(model_name_or_path,
                                                         device_map=device_map,
                                                         output_hidden_states=True,
                                                         trust_remote_code=True,
-                                                        torch_dtype=torch_dtype or torch.float16)
+                                                        torch_dtype=torch_dtype or torch.float16,
+                                                        **self.model_kwargs)
                     if train_mode and is_kbit:
                         model = prepare_model_for_kbit_training(model, **kbit_kwargs)
 
@@ -1277,13 +1284,17 @@ def __init__(self,
                         device_map=device_map,
                         output_hidden_states=True,
                         trust_remote_code=True,
-                        torch_dtype=torch_dtype or torch.float16)
+                        torch_dtype=torch_dtype or torch.float16,
+                        **self.model_kwargs)
             self.backbone = model
         else:
             MODEL_CLASS = AutoModelForMaskedLM if load_mlm_model else AutoModel
             # non-LLMs
             if self.apply_lora:
-                model = MODEL_CLASS.from_pretrained(pretrained_model_path or model_name_or_path, trust_remote_code=True)
+                model = MODEL_CLASS.from_pretrained(
+                    pretrained_model_path or model_name_or_path,
+                    trust_remote_code=True,
+                    **self.model_kwargs)
                 if pretrained_lora_path is not None:
                     model = PeftModel.from_pretrained(
                         model,
@@ -1303,7 +1314,8 @@ def __init__(self,
                 logger.info(f'Load pretrained model from {pretrained_model_path}')
                 self.backbone = MODEL_CLASS.from_pretrained(
                     pretrained_model_path or model_name_or_path,
-                    trust_remote_code=True)
+                    trust_remote_code=True,
+                    **self.model_kwargs)
 
         if train_mode and self.apply_lora:
             self.backbone.print_trainable_parameters()
@@ -1374,6 +1386,7 @@ def from_pretrained(model_name_or_path: str,
                         is_llm: Optional[bool] = None,
                         pooling_strategy: str = 'cls',
                         train_mode: bool = False,
+                        model_kwargs: Optional[Dict] = None,
                         **kwargs):
         """
         Load AnglE from pretrained model.
@@ -1398,6 +1411,7 @@ def from_pretrained(model_name_or_path: str,
             # inference
             angle.encode(*args, **kwargs)
         """
+        kwargs['model_kwargs'] = model_kwargs
         angle = AnglE(model_name_or_path,
                       is_llm=is_llm,
                       pretrained_model_path=pretrained_model_path,
@@ -1567,6 +1581,7 @@ def fit(self,
                 filter_duplicate=filter_duplicate,
                 coword_random_mask_rate=coword_random_mask_rate,
             ),
+            model_kwargs=self.model_kwargs,
             **trainer_kwargs
         )
         if torch.__version__ >= "2" and sys.platform != "win32":
@@ -1607,7 +1622,6 @@ def truncate_layer(self, layer_index: int):
     def encode(self,
                inputs: Union[List[str], Tuple[str], List[Dict], str],
                max_length: Optional[int] = None,
-               end_with_eos: bool = False,
               to_numpy: bool = True,
               embedding_start: int = 0,
               embedding_size: Optional[int] = None,
@@ -1639,26 +1653,13 @@ def encode(self,
             for i, obj in enumerate(inputs):
                 assert isinstance(obj, dict), 'The prompt has been set, please pass a dict like {"prompt_key": "text"}'
                 inputs[i] = prompt.format(**obj)
-        max_length = max_length or self.max_length
-        if end_with_eos:
-            max_length -= 1
-
-        if end_with_eos:
-            tok = self.tokenizer(
-                inputs,
-                padding=False,
-                return_attention_mask=False,
-                max_length=max_length or self.max_length,
-                truncation=True)
-            tok['input_ids'] = [input_ids + [self.tokenizer.eos_token_id] for input_ids in tok['input_ids']]
-            tok = self.tokenizer.pad(tok, padding=padding, return_attention_mask=True, return_tensors='pt')
-        else:
-            tok = self.tokenizer(
-                inputs,
-                padding=padding,
-                max_length=max_length or self.max_length,
-                truncation=True,
-                return_tensors='pt')
+
+        tok = self.tokenizer(
+            inputs,
+            padding=padding,
+            max_length=max_length or self.max_length,
+            truncation=True,
+            return_tensors='pt')
         tok.to(device)
         with torch.no_grad():
             output = self.pooler(tok,
diff --git a/angle_emb/angle_trainer.py b/angle_emb/angle_trainer.py
index c52f7a6..405c3fb 100644
--- a/angle_emb/angle_trainer.py
+++ b/angle_emb/angle_trainer.py
@@ -6,7 +6,7 @@
 import numpy as np
 import torch
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
 
 from angle_emb import AnglE, AngleDataTokenizer
 from angle_emb.utils import logger
 
@@ -61,10 +61,12 @@
                     help='Specify dataset workers, default 2')
 parser.add_argument('--cosine_w', type=float, default=0.0,
                     help='Specify weight for cosine loss, default 0.0')
-parser.add_argument('--ibn_w', type=float, default=30.0,
-                    help='Specify weight for ibn loss, default 30.0')
-parser.add_argument('--angle_w', type=float, default=1.0,
-                    help='Specify weight for angle loss, default 1.0')
+parser.add_argument('--ibn_w', type=float, default=1.0,
+                    help='Specify weight for in-batch negative loss, default 1.0')
+parser.add_argument('--cln_w', type=float, default=1.0,
+                    help='Specify weight for contrastive learning with hard negative loss, default 1.0')
+parser.add_argument('--angle_w', type=float, default=0.02,
+                    help='Specify weight for angle loss, default 0.02')
 parser.add_argument('--angle_tau', type=float, default=20.0,
                     help='Specify angle_tau, default 20.0')
 parser.add_argument('--cosine_tau', type=float, default=20.0,
@@ -208,10 +210,13 @@ def main():
                                   load_mlm_model=args.load_mlm_model)
 
     if os.path.exists(args.train_name_or_path):
-        ds = load_dataset('json',
-                          data_files=[args.train_name_or_path],
-                          num_proc=args.workers,
-                          streaming=args.streaming)
+        if os.path.isdir(args.train_name_or_path):
+            ds = load_from_disk(args.train_name_or_path, num_proc=args.workers)
+        else:
+            ds = load_dataset('json',
+                              data_files=[args.train_name_or_path],
+                              num_proc=args.workers,
+                              streaming=args.streaming)
     else:
         ds = load_dataset(args.train_name_or_path,
                           args.train_subset_name,
@@ -238,7 +243,10 @@ def main():
     if valid_ds is None and args.valid_name_or_path is not None:
         logger.info('Validation detected, processing validation...')
         if os.path.exists(args.valid_name_or_path):
-            valid_ds = load_dataset('json', data_files=[args.valid_name_or_path], num_proc=args.workers)
+            if os.path.isdir(args.valid_name_or_path):
+                valid_ds = load_from_disk(args.valid_name_or_path, num_proc=args.workers)
+            else:
+                valid_ds = load_dataset('json', data_files=[args.valid_name_or_path], num_proc=args.workers)
         else:
             if args.valid_subset_name is not None:
                 valid_ds = load_dataset(args.valid_name_or_path, args.valid_subset_name, num_proc=args.workers)
@@ -254,8 +262,11 @@ def main():
     if valid_ds_for_callback is None and args.valid_name_or_path_for_callback is not None:
         logger.info('Validation for callback detected, processing validation...')
         if os.path.exists(args.valid_name_or_path_for_callback):
-            valid_ds_for_callback = load_dataset(
-                'json', data_files=[args.valid_name_or_path_for_callback], num_proc=args.workers)
+            if os.path.isdir(args.valid_name_or_path_for_callback):
+                valid_ds_for_callback = load_from_disk(args.valid_name_or_path_for_callback, num_proc=args.workers)
+            else:
+                valid_ds_for_callback = load_dataset(
+                    'json', data_files=[args.valid_name_or_path_for_callback], num_proc=args.workers)
         else:
             if args.valid_subset_name_for_callback is not None:
                 valid_ds_for_callback = load_dataset(
@@ -314,6 +325,7 @@ def main():
         loss_kwargs={
             'cosine_w': args.cosine_w,
             'ibn_w': args.ibn_w,
+            'cln_w': args.cln_w,
             'angle_w': args.angle_w,
             'cosine_tau': args.cosine_tau,
             'ibn_tau': args.ibn_tau,
diff --git a/docs/notes/training.rst b/docs/notes/training.rst
index e83a5b9..a7d96f2 100644
--- a/docs/notes/training.rst
+++ b/docs/notes/training.rst
@@ -55,9 +55,10 @@ You can train a powerful sentence embedding model using the `angle-trainer` cli
     --model_name_or_path google-bert/bert-base-uncased \
     --pooling_strategy cls \
     --maxlen 128 \
-    --ibn_w 30.0 \
+    --ibn_w 1.0 \
+    --cln_w 1.0 \
     --cosine_w 0.0 \
-    --angle_w 1.0 \
+    --angle_w 0.02 \
     --angle_tau 20.0 \
     --learning_rate 5e-5 \
     --push_to_hub 1 --hub_model_id SeanLee97/bert-base-nli-test-0728 --hub_private_repo 1 \
@@ -82,13 +83,14 @@ You can train a powerful sentence embedding model using the `angle-trainer` cli
     --model_name_or_path NousResearch/Llama-2-7b-chat-hf \
     --pooling_strategy avg \
     --maxlen 60 \
-    --ibn_w 20.0 \
+    --ibn_w 1.0 \
+    --cln_w 1.0 \
     --cosine_w 0.0 \
-    --angle_w 1.0 \
+    --angle_w 0.02 \
     --learning_rate 2e-4 \
-    --prompt_template "Represent the following sentence for semantic textual similarity: {text} <|endoftext|>" \
+    --prompt_template "Represent the following sentence for semantic textual similarity: {text}" \
     --apply_lora 1 --lora_r 64 --lora_alpha 128 --lora_dropout 0.1 \
-    --load_kbit 4 \
+    --load_kbit 16 \
     --is_llm 1 \
     --push_to_hub 1 --hub_model_id SeanLee97/test-llama7b-nli --hub_private_repo 1 \
     --logging_steps 5 \
@@ -111,16 +113,17 @@ You can train a powerful sentence embedding model using the `angle-trainer` cli
     --model_name_or_path NousResearch/Llama-2-7b-chat-hf \
     --pooling_strategy avg \
     --maxlen 60 \
-    --ibn_w 20.0 \
+    --ibn_w 1.0 \
+    --cln_w 1.0 \
     --cosine_w 0.0 \
-    --angle_w 1.0 \
+    --angle_w 0.02 \
     --learning_rate 2e-4 \
     --apply_lora 1 --lora_r 64 --lora_alpha 128 --lora_dropout 0.1 \
-    --load_kbit 4 \
+    --load_kbit 16 \
     --is_llm 1 \
     --apply_billm 1 \
     --billm_model_class LlamaForCausalLM \
-    --prompt_template "Represent the following sentence for semantic textual similarity: {text} <|endoftext|>" \
+    --prompt_template "Represent the following sentence for semantic textual similarity: {text}" \
     --push_to_hub 1 --hub_model_id SeanLee97/test-billm-llama7b-nli --hub_private_repo 1 \
     --logging_steps 5 \
     --save_steps 50 \
@@ -169,9 +172,10 @@ You can also train a sentence embedding model using the `angle_emb` library. Her
         warmup_steps=0,
         gradient_accumulation_steps=1,
         loss_kwargs={
-            'cosine_w': 1.0,
-            'ibn_w': 20.0,
-            'angle_w': 1.0,
+            'cosine_w': 0.0,
+            'ibn_w': 1.0,
+            'cln_w': 1.0,
+            'angle_w': 0.02,
             'cosine_tau': 20,
             'ibn_tau': 20,
             'angle_tau': 20
@@ -190,12 +194,34 @@ You can also train a sentence embedding model using the `angle_emb` library. Her
     :alt: Open In Colab
 
 
+
+
+💡 Hyperparameters
+-------------------------
+
+1. `angle_w`: the weight for angle loss. Default `0.02`
+
+2. `ibn_w`: the weight for in-batch negative loss. Default `1.0`
+
+3. `cln_w`: the weight for contrastive learning with hard negative loss. Default `1.0`
+
+4. `cosine_w`: the weight for cosine loss. Default `0.0`
+
+5. `angle_tau`: the temperature for angle loss. Default `20.0`
+
+6. `ibn_tau`: the temperature for ibn and cln losses. Default `20.0`
+
+7. `cosine_tau`: the temperature for cosine loss. Default `20.0`
+
+
+
+
 💡 Fine-tuning Tips
 -------------------------
 
 1. If your dataset format is `DatasetFormats.A`, it is recommended to slightly increase the weight for `cosine_w` or slightly decrease the weight for `ibn_w`.
 
-2. If your dataset format is `DatasetFormats.B`, it is recommended to set `cosine_w` to 0, and increase the weight for `ibn_w` such as 10 and 20. The `angle_tau` is recommended to set to 20.0.
+2. If your dataset format is `DatasetFormats.B`, it is recommended to set `cosine_w` to 0 and `angle_w` to a small value such as 0.02, and to set both `cln_w` and `ibn_w` (1.0 by default).
 
 3. If your dataset format is `DatasetFormats.C`, only `ibn_w` and `ibn_tau` are effective. You don't need to tune other parameters.
 
diff --git a/docs/notes/tutorial.rst b/docs/notes/tutorial.rst
index 148ee91..9beea41 100644
--- a/docs/notes/tutorial.rst
+++ b/docs/notes/tutorial.rst
@@ -42,9 +42,10 @@ Here's an example of training a BERT-base model:
     --model_name_or_path microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext \
     --pooling_strategy cls \
     --maxlen 75 \
-    --ibn_w 20.0 \
+    --ibn_w 1.0 \
+    --cln_w 1.0 \
     --cosine_w 0.0 \
-    --angle_w 1.0 \
+    --angle_w 0.02 \
     --learning_rate 1e-6 \
     --logging_steps 5 \
     --save_steps 500 \
@@ -69,9 +70,10 @@ And here's another example of training a BERT-large model:
     --load_mlm_model 1 \
     --pooling_strategy cls \
     --maxlen 75 \
-    --ibn_w 20.0 \
+    --ibn_w 1.0 \
+    --cln_w 1.0 \
     --cosine_w 0.0 \
-    --angle_w 1.0 \
+    --angle_w 0.02 \
     --learning_rate 1e-6 \
     --logging_steps 5 \
     --save_steps 500 \