diff --git a/.bumpversion.cfg b/.bumpversion.cfg
index 3b39c0d..eb010bc 100644
--- a/.bumpversion.cfg
+++ b/.bumpversion.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
 files = setup.py angle_emb/__init__.py
-current_version = 0.3.4
+current_version = 0.3.5
 commit = True
 tag = True
diff --git a/angle_emb/__init__.py b/angle_emb/__init__.py
index b69a15c..445315d 100644
--- a/angle_emb/__init__.py
+++ b/angle_emb/__init__.py
@@ -3,4 +3,4 @@
 from .angle import *
 
 
-__version__ = '0.3.4'
+__version__ = '0.3.5'
diff --git a/angle_emb/angle.py b/angle_emb/angle.py
index d0375d1..18b6250 100644
--- a/angle_emb/angle.py
+++ b/angle_emb/angle.py
@@ -744,6 +744,7 @@ def __init__(self,
                 pooling_strategy='all',
                 padding_strategy=self.pooler.padding_strategy,
                 is_llm=False)
+            self.kl_loss_fct = nn.KLDivLoss(reduction='batchmean')
             logger.info(f'Train with alignment, teacher={fixed_teacher_name_or_path}')
 
     def compute_loss(self, model, inputs, return_outputs=False):
diff --git a/angle_emb/train_cli.py b/angle_emb/train_cli.py
index 26e29ae..2edec00 100644
--- a/angle_emb/train_cli.py
+++ b/angle_emb/train_cli.py
@@ -166,14 +166,19 @@ def main():
         argument_kwargs['report_to'] = 'wandb'
 
     trainer_kwargs = None
-    if args.apply_tdmse:
+    if args.fixed_teacher_name_or_path is not None:
         trainer_kwargs = {
+            'fixed_teacher_name_or_path': args.fixed_teacher_name_or_path
+        }
+    if args.apply_tdmse:
+        # trainer_kwargs is still None when no fixed teacher was given;
+        # dict(None, ...) raises TypeError, so start from an empty dict then.
+        trainer_kwargs = dict(trainer_kwargs or {}, **{
             'apply_tdmse_kl': args.apply_tdmse_kl,
             'tdmse_kl_temperature': args.tdmse_kl_temperature,
             'tdmse_teacher_lambda': args.tdmse_teacher_lambda,
             'tdmse_student_lambda': args.tdmse_student_lambda,
-            'fixed_teacher_name_or_path': args.fixed_teacher_name_or_path,
-        }
+        })
 
     model.fit(
         train_ds=train_ds,
diff --git a/setup.py b/setup.py
index 957c6a0..ff8dcc8 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@
 
 setup(
     name='angle_emb',
-    version='0.3.4',
+    version='0.3.5',
     description='AnglE-optimize Text Embeddings',
     long_description=readme,
     long_description_content_type="text/markdown",