-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Model] DeepseekV2 Support #499
base: main
Are you sure you want to change the base?
[Model] DeepseekV2 Support #499
Conversation
Add deepseekv2 convergence test
…war/Liger-Kernel into feature/deepseekv2
@ByronHsu @yundai424 @Tcc0403 @qingquansong deepseek: cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
b, h, s, d = q.shape
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
b, h, s, d = k.shape
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed llama: cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed` I will create a separate PR to implement the DeepSeek rope. |
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward | ||
if cross_entropy: | ||
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): | ||
from transformers.loss.loss_utils import nn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: since its so common to use import torch.nn as nn
, perhaps we import loss_utils under a different symbol?
Maybe even just from transformers.loss import loss_utils
?
import sys | ||
|
||
# Ensure the model is a DeepSeek model | ||
if "deepseek" not in model.__class__.__module__: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do deepseek and deepseek-v3 share the same architecture? If so, perhaps this function should be called apply_liger_kernel_to_deepseek
, if not, perhaps we should strengthen this check.
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): | ||
from transformers.loss.loss_utils import nn | ||
|
||
nn.functional.cross_entropy = liger_cross_entropy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will globally patch from transformers.loss.loss_utils.function.cross_entropy
, which is a pretty undersireable / unexpected side effect of applying this deepseek-specific monkey patch.
See this issue: #315
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could be fixed if deepseekv2 is added to the transformers library (see below comment about trust_remote_code
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
heavy plus^
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yundai424 @tyler-romero I agree.
I think we should create a separate PR to refactor the monkey_patch file and fix this for all models.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My understanding is that the fix needs to be made in the transformers library - its because instead of importing an individual component they import an entire module, which makes it tough to monkeypatch without global side effects.
if model_name[:6] == "remote": | ||
revert_kwargs["remote_model_module"] = MINI_MODEL_SETUPS[model_name].remote_model_module | ||
|
||
model = create_model(model_name).to(dtype).to(device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why the change to create the model before applying the patch?
model_class = MINI_MODEL_SETUPS[model_name].model_class | ||
return model_class(model_config) | ||
if model_name[:6] == "remote": | ||
config = AutoConfig.from_pretrained(MINI_MODEL_SETUPS[model_name].remote_model_path, trust_remote_code=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain why this is necessary? Its it because the model cannot be run without trust_remote_code
? As is, this default opts-in anyone who runs these unit tests into running remote code on their machine, which is a red flag.
I think a preferable path would be to add deepseekv2 to the transformers library, then add it to Liger, so that trust_remote_code is not necessary.
This also has the benefit of making it easier to follow changes that are made to the underlying model, which is a common source of bugs in Liger.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like support for deepseekv2 is underway (maybe stalled though): huggingface/transformers#31976
|
||
DeepseekV2_INPUTS_DOCSTRING = r""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: it'll be helpful to document where this part of docstring is ported from -- at least need a link to https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py
Summary
Resolves #129 Add monkeypatch to support deepseepV2 model.
Details
Ops patched:
Testing Done
make test
to ensure correctnessmake checkstyle
to ensure code stylemake test-convergence
to ensure convergence