Skip to content

Commit

Permalink
fixing MPT model generation Issue
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar committed May 16, 2024
1 parent 42bd785 commit 7a5950c
Showing 1 changed file with 0 additions and 2 deletions.
2 changes: 0 additions & 2 deletions src/python/easydel/modules/mosaic_mpt/modelling_mpt_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,13 +466,11 @@ def init_cache(self, batch_size, max_length):

input_ids = jnp.ones((batch_size, max_length), dtype="i4")
attention_mask = jnp.ones_like(input_ids)
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

init_variables = self.module.init(
jax.random.PRNGKey(0),
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=False,
init_cache=True
)
Expand Down

0 comments on commit 7a5950c

Please sign in to comment.