Fixing fjformer import bug
erfanzar committed Nov 3, 2023
1 parent 44b84f9 · commit 7e52037
Showing 5 changed files with 7 additions and 9 deletions.
4 changes: 2 additions & 2 deletions lib/python/EasyDel/transform/falcon.py
@@ -1,5 +1,5 @@
 import jax
-from fjformer.utils import load_and_convert_checkpoint
+from fjformer import load_and_convert_checkpoint_to_torch
 from jax import numpy as jnp
 from tqdm import tqdm
 from transformers import FalconForCausalLM
@@ -148,7 +148,7 @@ def falcon_easydel_to_hf(path, config: FalconConfig):
     """
     Takes path to easydel saved ckpt and return the model in pytorch (Transformers Huggingface)
     """
-    torch_params = load_and_convert_checkpoint(path)
+    torch_params = load_and_convert_checkpoint_to_torch(path)
     edited_params = {}
     for k, v in torch_params.items():
         edited_params[k.replace('.kernel', '.weight').replace('.embedding', '.weight')] = v
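For orientation, the renamed helper is used the same way in falcon.py, llama.py and mistral.py: load the EasyDel checkpoint as torch tensors, then rename the Flax-style keys to PyTorch ones. Below is a minimal sketch of that pattern, assuming load_and_convert_checkpoint_to_torch takes a checkpoint path and returns a dict of torch tensors (as the diff's usage implies); the FalconForCausalLM(config) / load_state_dict step and the use of transformers' FalconConfig are illustrative assumptions, not part of this commit.

# Sketch only: the key-renaming loop and the fjformer helper come from the diff
# above; the model construction and load_state_dict step are assumed.
from fjformer import load_and_convert_checkpoint_to_torch
from transformers import FalconConfig, FalconForCausalLM


def easydel_falcon_ckpt_to_hf(path: str, config: FalconConfig) -> FalconForCausalLM:
    # Read the EasyDel (Flax) checkpoint and convert its arrays to torch tensors.
    torch_params = load_and_convert_checkpoint_to_torch(path)

    # Map Flax parameter names onto PyTorch conventions, as in the diff.
    edited_params = {
        k.replace('.kernel', '.weight').replace('.embedding', '.weight'): v
        for k, v in torch_params.items()
    }

    # Assumed step: load the renamed weights into a freshly initialised HF model.
    model = FalconForCausalLM(config)
    model.load_state_dict(edited_params)
    return model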
4 changes: 2 additions & 2 deletions lib/python/EasyDel/transform/llama.py
@@ -5,7 +5,7 @@
 import torch
 from transformers import LlamaForCausalLM
 from ..modules.llama import LlamaConfig
-from fjformer.utils import load_and_convert_checkpoint
+from fjformer import load_and_convert_checkpoint_to_torch
 
 
 def inverse_permute(w, num_attention_heads, in_dim, out_dim):
@@ -150,7 +150,7 @@ def llama_easydel_to_hf(path, config: LlamaConfig):
     """
     Takes path to easydel saved ckpt and return the model in pytorch (Transformers Huggingface)
     """
-    torch_params = load_and_convert_checkpoint(path)
+    torch_params = load_and_convert_checkpoint_to_torch(path)
     edited_params = {}
     for k, v in torch_params.items():
         edited_params[k.replace('.kernel', '.weight').replace('.embedding', '.weight')] = v
4 changes: 2 additions & 2 deletions lib/python/EasyDel/transform/mistral.py
@@ -1,6 +1,6 @@
 from pathlib import Path
 
-from fjformer.utils import load_and_convert_checkpoint
+from fjformer import load_and_convert_checkpoint_to_torch
 from jax import numpy as jnp
 import jax
 import torch
@@ -250,7 +250,7 @@ def mistral_easydel_to_hf(path, config: MistralConfig):
     """
     Takes path to easydel saved ckpt and return the model in pytorch (Transformers Huggingface)
     """
-    torch_params = load_and_convert_checkpoint(path)
+    torch_params = load_and_convert_checkpoint_to_torch(path)
     edited_params = {}
     for k, v in torch_params.items():
         edited_params[k.replace('.kernel', '.weight').replace('.embedding', '.weight')] = v
2 changes: 0 additions & 2 deletions lib/python/EasyDel/transform/mpt.py
@@ -1,5 +1,3 @@
-from fjformer.utils import load_and_convert_checkpoint
-
 from .. import MptConfig
 from jax import numpy as jnp
 import jax
2 changes: 1 addition & 1 deletion lib/python/EasyDel/transform/utils.py
@@ -1,5 +1,5 @@
 from fjformer import StreamingCheckpointer, float_tensor_to_dtype
-from flax.traverse_util import flatten_dict, unflatten_dict
+from flax.traverse_util import flatten_dict
 
 
 def match_keywords(string, positives, negatives):
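As an aside on utils.py: unflatten_dict is dropped from the import, presumably because only flatten_dict is still used there. Flattening a nested Flax parameter tree with sep='.' is what produces the dotted .kernel / .embedding names that the transform functions above rewrite to .weight. A small illustrative sketch follows; the parameter tree is invented for the example and is not from the repository.

# Illustrative only: the nested parameter tree below is made up for the example.
from flax.traverse_util import flatten_dict
import jax.numpy as jnp

params = {
    'transformer': {
        'wte': {'embedding': jnp.zeros((4, 8))},
        'dense': {'kernel': jnp.zeros((8, 8))},
    }
}

# sep='.' joins the nested keys into dotted strings instead of key tuples.
flat = flatten_dict(params, sep='.')
print(list(flat))
# ['transformer.wte.embedding', 'transformer.dense.kernel']

# The transform modules then rename these Flax-style keys for PyTorch:
renamed = {k.replace('.kernel', '.weight').replace('.embedding', '.weight'): v
           for k, v in flat.items()}
print(list(renamed))
# ['transformer.wte.weight', 'transformer.dense.weight']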
