diff --git a/lib/python/EasyDel/transform/falcon.py b/lib/python/EasyDel/transform/falcon.py index a3386b27e..96b045d72 100644 --- a/lib/python/EasyDel/transform/falcon.py +++ b/lib/python/EasyDel/transform/falcon.py @@ -1,5 +1,5 @@ import jax -from fjformer.utils import load_and_convert_checkpoint +from fjformer import load_and_convert_checkpoint_to_torch from jax import numpy as jnp from tqdm import tqdm from transformers import FalconForCausalLM @@ -148,7 +148,7 @@ def falcon_easydel_to_hf(path, config: FalconConfig): """ Takes path to easydel saved ckpt and return the model in pytorch (Transformers Huggingface) """ - torch_params = load_and_convert_checkpoint(path) + torch_params = load_and_convert_checkpoint_to_torch(path) edited_params = {} for k, v in torch_params.items(): edited_params[k.replace('.kernel', '.weight').replace('.embedding', '.weight')] = v diff --git a/lib/python/EasyDel/transform/llama.py b/lib/python/EasyDel/transform/llama.py index 35f766aee..7949eaba0 100644 --- a/lib/python/EasyDel/transform/llama.py +++ b/lib/python/EasyDel/transform/llama.py @@ -5,7 +5,7 @@ import torch from transformers import LlamaForCausalLM from ..modules.llama import LlamaConfig -from fjformer.utils import load_and_convert_checkpoint +from fjformer import load_and_convert_checkpoint_to_torch def inverse_permute(w, num_attention_heads, in_dim, out_dim): @@ -150,7 +150,7 @@ def llama_easydel_to_hf(path, config: LlamaConfig): """ Takes path to easydel saved ckpt and return the model in pytorch (Transformers Huggingface) """ - torch_params = load_and_convert_checkpoint(path) + torch_params = load_and_convert_checkpoint_to_torch(path) edited_params = {} for k, v in torch_params.items(): edited_params[k.replace('.kernel', '.weight').replace('.embedding', '.weight')] = v diff --git a/lib/python/EasyDel/transform/mistral.py b/lib/python/EasyDel/transform/mistral.py index 5ae58c533..9d681ad85 100644 --- a/lib/python/EasyDel/transform/mistral.py +++ b/lib/python/EasyDel/transform/mistral.py @@ -1,6 +1,6 @@ from pathlib import Path -from fjformer.utils import load_and_convert_checkpoint +from fjformer import load_and_convert_checkpoint_to_torch from jax import numpy as jnp import jax import torch @@ -250,7 +250,7 @@ def mistral_easydel_to_hf(path, config: MistralConfig): """ Takes path to easydel saved ckpt and return the model in pytorch (Transformers Huggingface) """ - torch_params = load_and_convert_checkpoint(path) + torch_params = load_and_convert_checkpoint_to_torch(path) edited_params = {} for k, v in torch_params.items(): edited_params[k.replace('.kernel', '.weight').replace('.embedding', '.weight')] = v diff --git a/lib/python/EasyDel/transform/mpt.py b/lib/python/EasyDel/transform/mpt.py index fcc0b2594..aec7a29c5 100644 --- a/lib/python/EasyDel/transform/mpt.py +++ b/lib/python/EasyDel/transform/mpt.py @@ -1,5 +1,3 @@ -from fjformer.utils import load_and_convert_checkpoint - from .. import MptConfig from jax import numpy as jnp import jax diff --git a/lib/python/EasyDel/transform/utils.py b/lib/python/EasyDel/transform/utils.py index b27e7fa6d..36ca7d17f 100644 --- a/lib/python/EasyDel/transform/utils.py +++ b/lib/python/EasyDel/transform/utils.py @@ -1,5 +1,5 @@ from fjformer import StreamingCheckpointer, float_tensor_to_dtype -from flax.traverse_util import flatten_dict, unflatten_dict +from flax.traverse_util import flatten_dict def match_keywords(string, positives, negatives):