From 672d7e5bb49dcb34e1b2fdeb09f3f4588dc583a6 Mon Sep 17 00:00:00 2001 From: Bo Li Date: Tue, 9 Jul 2024 11:43:41 +1000 Subject: [PATCH] feat: Add tie_weights parameter to Llava model initialization --- lmms_eval/models/llava.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py index 7528e356..7d6420ba 100755 --- a/lmms_eval/models/llava.py +++ b/lmms_eval/models/llava.py @@ -58,6 +58,7 @@ def __init__( device_map="cuda:0", conv_template="vicuna_v1", use_cache=True, + tie_weights: bool = True, truncate_context=False, # whether to truncate the context in generation, set it False for LLaVA-1.6 customized_config=None, # ends in json **kwargs, @@ -97,6 +98,8 @@ def __init__( self._tokenizer, self._model, self._image_processor, self._max_length = load_pretrained_model(pretrained, None, model_name, device_map=self.device_map, **llava_model_args) self._config = self._model.config self.model.eval() + if tie_weights: + self.model.tie_weights() self.truncation = truncation self.batch_size_per_gpu = int(batch_size)