From f65976d54f13d5673f418f698108cb66266f08fe Mon Sep 17 00:00:00 2001
From: Z <coffeevampirebusiness@gmail.com>
Date: Sat, 4 Jan 2025 07:54:11 -0700
Subject: [PATCH 1/2] Add files via upload

---
 cut_cross_entropy/transformers/granite.py | 115 ++++++++++++++++++++++
 1 file changed, 115 insertions(+)
 create mode 100644 cut_cross_entropy/transformers/granite.py

diff --git a/cut_cross_entropy/transformers/granite.py b/cut_cross_entropy/transformers/granite.py
new file mode 100644
index 0000000..48eceb0
--- /dev/null
+++ b/cut_cross_entropy/transformers/granite.py
@@ -0,0 +1,115 @@
+from types import MethodType
+from typing import List, Optional, Tuple, Union
+
+import torch
+import transformers
+from torch.nn import CrossEntropyLoss
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    is_torchdynamo_compiling,
+    replace_return_docstrings,
+)
+
+from cut_cross_entropy import linear_cross_entropy
+from .utils import PatchOptions, TransformersModelT
+
+_PATCH_OPTS: PatchOptions | None = None
+
+def cce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: int = 0,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+    original_dtype = hidden_states.dtype
+    loss = None
+    logits = None
+
+    if labels is not None and _PATCH_OPTS is not None:
+        # Granite uses logit scaling; dividing the hidden states is equivalent since lm_head is linear.
+        if hasattr(self.config, 'logits_scaling'):
+            scaling = torch.tensor(self.config.logits_scaling, dtype=original_dtype, device=hidden_states.device)
+            hidden_states = hidden_states / scaling
+        loss = linear_cross_entropy(
+            hidden_states,
+            self.lm_head.weight,
+            labels.to(hidden_states.device),
+            shift=True,
+            impl=_PATCH_OPTS.impl,
+            reduction=_PATCH_OPTS.reduction,
+        )
+    else:
+        # Granite uses logit scaling; apply it to the logits themselves here.
+        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+        if hasattr(self.config, 'logits_scaling'):
+            scaling = torch.tensor(self.config.logits_scaling, dtype=original_dtype, device=logits.device)
+            logits = logits / scaling
+        logits = logits.float()
+
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+def patch_granite(
+    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
+    patch_options: PatchOptions,
+) -> TransformersModelT | None:
+    global _PATCH_OPTS
+    from transformers.models.granite import modeling_granite
+
+    _PATCH_OPTS = patch_options
+
+    if isinstance(maybe_model, transformers.PreTrainedModel):
+        assert isinstance(
+            maybe_model, modeling_granite.GraniteForCausalLM
+        ), f"Expected a GraniteForCausalLM model. Got {type(maybe_model)}."
+        maybe_model.forward = MethodType(cce_forward, maybe_model)
+        return maybe_model
+    else:
+        modeling_granite.GraniteForCausalLM.forward = cce_forward
\ No newline at end of file

From 4416b37da75f87e32f659384f73b2bdf289a88a2 Mon Sep 17 00:00:00 2001
From: Z <coffeevampirebusiness@gmail.com>
Date: Sat, 4 Jan 2025 07:54:48 -0700
Subject: [PATCH 2/2] Update patch.py

---
 cut_cross_entropy/transformers/patch.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/cut_cross_entropy/transformers/patch.py b/cut_cross_entropy/transformers/patch.py
index 3dd7c69..e4d80d8 100644
--- a/cut_cross_entropy/transformers/patch.py
+++ b/cut_cross_entropy/transformers/patch.py
@@ -9,6 +9,7 @@
 from .llama import patch_llama
 from .mistral import patch_mistral
 from .phi3 import patch_phi3
+from .granite import patch_granite
 
 from .utils import PatchOptions, TransformersModelT
 
@@ -57,5 +58,7 @@ def cce_patch(
             return patch_gemma2(model_type_or_model, patch_options)
         case "mistral":
             return patch_mistral(model_type_or_model, patch_options)
+        case "granite":
+            return patch_granite(model_type_or_model, patch_options)
         case _:
             raise RuntimeError(f"Unknown model type {model_type}")
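
For reviewers, a minimal, untested sketch of how the new "granite" case could be exercised once both patches are applied. The checkpoint name and dtype are placeholders (any GraniteForCausalLM checkpoint should do), and it assumes cce_patch is called with its default options:

    # Usage sketch, not part of the patch. The model id below is a placeholder.
    import transformers

    from cut_cross_entropy.transformers.patch import cce_patch

    model = transformers.AutoModelForCausalLM.from_pretrained(
        "ibm-granite/granite-3.1-2b-instruct",  # placeholder Granite checkpoint
        torch_dtype="bfloat16",
    )
    # model.config.model_type == "granite", so cce_patch dispatches to patch_granite.
    model = cce_patch(model)
    # Forward passes with labels now compute the loss via linear_cross_entropy;
    # calls without labels fall back to the standard logits path with logit scaling applied.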