Skip to content

Commit

Permalink
fix saving custom code
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Jul 16, 2023
1 parent 2c867b9 commit 1e13584
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 24 deletions.
12 changes: 6 additions & 6 deletions src/llmtuner/tuner/core/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils import PreTrainedTokenizerBase
from trl import AutoModelForCausalLMWithValueHead

from llmtuner.extras.logging import get_logger
Expand All @@ -36,7 +36,7 @@ def load_model_and_tokenizer(
finetuning_args: FinetuningArguments,
is_trainable: Optional[bool] = False,
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
r"""
Loads pretrained model and tokenizer.
Expand Down Expand Up @@ -113,12 +113,12 @@ def load_model_and_tokenizer(
)

# Register auto class to save the custom code files.
if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map and isinstance(config, PretrainedConfig):
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map and isinstance(tokenizer, PreTrainedTokenizer):
tokenizer.__class__.register_for_auto_class()
if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map and isinstance(model, PreTrainedModel):
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()

# Initialize adapters
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
Expand Down
101 changes: 83 additions & 18 deletions tests/modeling_baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,45 @@ def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BaichuanModel):
module.gradient_checkpointing = value

@staticmethod
def _convert_to_standard_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
num_heads, ...]))
"""
batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
num_heads = batch_size_times_num_heads // batch_size
# key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)

@staticmethod
def _convert_to_baichuan_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
"""
Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
"""
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
batch_size_times_num_heads = batch_size * num_heads
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)


class BaichuanModel(BaichuanPreTrainedModel):

Expand All @@ -318,9 +357,9 @@ def __init__(self, config: BaichuanConfig):

def get_input_embeddings(self):
return self.embed_tokens

def set_input_embeddings(self, value):
self.embed_tokens = value
self.embed_tokens = value

def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
return build_alibi_tensor(attention_mask, num_heads, dtype)
Expand Down Expand Up @@ -468,7 +507,7 @@ def custom_forward(*inputs):
hidden_states=all_hidden_states,
attentions=all_self_attns,
)


class BaichuanForCausalLM(BaichuanPreTrainedModel):

Expand Down Expand Up @@ -498,7 +537,7 @@ def set_decoder(self, decoder):

def get_decoder(self):
return self.model

def forward(
self,
input_ids: torch.LongTensor = None,
Expand Down Expand Up @@ -528,7 +567,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
Expand Down Expand Up @@ -559,33 +598,59 @@ def forward(
)

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs
) -> dict:
if past_key_values:
input_ids = input_ids[:, -1:]

# the cache may be in the standard format (e.g. in contrastive search)
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_baichuan_cache(past_key_values)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}

model_inputs.update(
{
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
}
)
return model_inputs

@staticmethod
def _reorder_cache(past_key_values, beam_idx):
return tuple(
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
for layer_past in past_key_values
def _reorder_cache(
self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
Output shares the same memory storage as `past`.
"""
standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))

# Get a copy of `beam_idx` on all the devices where we need those indices.
device_to_beam_idx = {
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
}
reordered_past = tuple(
(
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
)
for layer_past in standardized_past
)

return self._convert_to_baichuan_cache(reordered_past)

def quantize(self, bits: int):
try:
Expand All @@ -594,7 +659,7 @@ def quantize(self, bits: int):
raise ImportError(
f"Needs QLinear to run quantize."
)

for layer in self.model.layers:
layer.self_attn.W_pack = QLinear(
bits=bits,
Expand All @@ -621,7 +686,7 @@ def quantize(self, bits: int):
weight=layer.mlp.up_proj.weight,
bias = None,
)
return self
return self

def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
Expand Down

0 comments on commit 1e13584

Please sign in to comment.