
Proposed implementation of (IBM) Granite 3.1 Patch for Transformers #17

Open

wants to merge 2 commits into main
Conversation

@CoffeeVampir3 commented Jan 4, 2025

This is a proposed patch for IBM Granite (https://huggingface.co/ibm-granite/granite-3.1-8b-instruct). It follows the llama patch very closely and is largely the same, except for a small alteration because Granite models use logit scaling: https://huggingface.co/ibm-granite/granite-3.1-8b-instruct/blob/main/config.json#L15
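
For illustration only (this is not code from the patch itself), here is roughly how the scaling difference can be folded into a fused loss call. The helper name is made up, and the exact `linear_cross_entropy` call pattern is an assumption based on the existing llama patch, with label shifting omitted for brevity:

```python
import torch
from cut_cross_entropy import linear_cross_entropy

def granite_fused_loss(hidden_states: torch.Tensor,
                       lm_head_weight: torch.Tensor,
                       labels: torch.Tensor,
                       logits_scaling: float) -> torch.Tensor:
    # Granite computes logits = lm_head(hidden_states) / logits_scaling, so dividing
    # the hidden states by the same factor before the fused projection+loss is
    # mathematically equivalent and never materializes the full logit matrix.
    scaled = hidden_states / logits_scaling
    return linear_cross_entropy(scaled, lm_head_weight, labels)
```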

There's a hiccup, at least when training with TRL: in cce_backward.py a mandatory change is needed on lines 257/258 to force the dtypes to float16:

    e = e.half()
    c = c.half()

This was too hacky for my tastes, so I've omitted it from the patch, and I'm not sure how to handle it at the moment: TRL seems to want float32 for the classifier and embedding weights, and I couldn't find a way to fix that on the TRL end. I'd appreciate your takes on a proper way to set this up.
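
If patching the kernel is too invasive, one possible model-side alternative would be to cast just the embedding and lm_head modules to bf16 before handing the model to the TRL trainer. This is only a sketch of the idea rather than something I've verified end to end, and the helper name is made up:

```python
import torch

def cast_io_layers_to_bf16(model):
    # Only the input embeddings and the output classifier need to be bf16/fp16 for
    # the CCE backward kernel's dtype assertions; the rest of the model is untouched.
    model.get_input_embeddings().to(torch.bfloat16)
    if model.get_output_embeddings() is not None:
        model.get_output_embeddings().to(torch.bfloat16)
    return model
```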

For reference, if someone would like to use this patch now, the required alteration, pending a better solution, is simply to change cce_backward_kernel in cce_backward.py:

```python
def cce_backward_kernel(
    do: torch.Tensor,
    e: torch.Tensor,
    c: torch.Tensor,
    lse: torch.Tensor,
    valids: torch.Tensor | None,
    softcap: float | None,
    filter_eps: float | None,
    targets: torch.Tensor | None = None,
    shift: bool = False,
    vocab_ordering: torch.Tensor | None = None,
    grad_scale: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    assert do.numel() in (e.size(0), 1)
    assert c.size(1) == e.size(1)
    assert lse.size(0) == e.size(0) or (valids is not None and lse.size(0) == valids.size(0))

    # Temporary workaround: force the embeddings and classifier to fp16 so the
    # dtype assertions below pass when TRL hands in float32 weights.
    e = e.half()
    c = c.half()
    assert e.dtype in (
        torch.float16,
        torch.bfloat16,
    ), "Backwards requires embeddings to be bf16 or fp16"
    assert c.dtype in (
        torch.float16,
        torch.bfloat16,
    ), "Backwards requires classifier to be bf16 or fp16"
```

I was able to get a good loss curve and proper training, so the patch appears to be working, barring the questions above.
