Incorrect parameter dtype initialisation for flax transformer-engine modules #1451
Comments
Hi, thank you for reporting this issue.
Please let us know if you still observe any issues.
Thanks @phu0ngng. I can confirm that the dtypes are now initialised correctly for the Dense layers on your branch. However, I now run into an issue with `MultiHeadAttention`, where the binding to the fused primitive fails. I'm not sure whether this is due to the upgrade to the new TE version, so, if you like, I can open a new issue for it. The code and the error it raises are below; I have also attached a printout of the package versions.

```python
from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax


class Model(nn.Module):
    embed_dim: int
    param_dtype: Any

    @nn.compact
    def __call__(self, x):
        # Self-attention: the same tensor is passed as queries and keys/values.
        return te_flax.MultiHeadAttention(
            head_dim=self.embed_dim,
            num_attention_heads=8,
            dtype=self.param_dtype,
        )(x, x)[0]


x = jnp.ones((8, 16, 128), dtype=jnp.bfloat16)
model = Model(embed_dim=128, param_dtype=jnp.bfloat16)
params = model.init(jax.random.key(0), x)
y = jax.jit(model.apply)(params, x)
```
Full traceback below:
Output from `pip list`
Hi @liamclarkza, this issue is addressed in PR #1477. Thanks for your reproducible code. @zlsh80826: could you add a test based on the code snippet above?
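For reference, a rough sketch of what such a regression test might look like, based on the reproducer above. This is an editorial illustration, not the test actually added in PR #1477, and it assumes that every entry in the `params` collection is expected to follow the requested dtype, which may not match the real test suite.

```python
# Hypothetical regression test based on the reproducer above; not the test
# from PR #1477. Assumes te_flax.MultiHeadAttention accepts the arguments
# used in the snippet and that all "params" entries should follow the
# requested dtype.
import jax
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax


def test_mha_param_dtype_matches_dtype_argument():
    x = jnp.ones((8, 16, 128), dtype=jnp.bfloat16)
    mha = te_flax.MultiHeadAttention(
        head_dim=128,
        num_attention_heads=8,
        dtype=jnp.bfloat16,
    )
    variables = mha.init(jax.random.key(0), x, x)
    # Map every parameter to its dtype and check that all of them honour
    # the dtype argument.
    param_dtypes = jax.tree_util.tree_map(lambda p: p.dtype, variables["params"])
    assert all(
        dt == jnp.bfloat16 for dt in jax.tree_util.tree_leaves(param_dtypes)
    )
```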
Thanks @phu0ngng, much appreciated 👍
It appears that the `dtype` argument for layers like `LayerNormDenseGeneral` and `MultiHeadAttention` is being ignored. The argument is documented as follows, and so I would expect the parameters to be initialised to `bfloat16` in the sample code below, but this isn't the case.

I have tested this with transformer-engine 1.14.0, which comes bundled with the `nvcr.io/nvidia/jax:25.01-py3` docker image. I don't see a release for this version on GitHub yet, though. I have also tested with 1.12.0, which yielded the same results.

Printout from `jax.print_environment_info()`
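The sample code referenced above was not preserved in this thread. A minimal sketch of the kind of check being described, assuming `te_flax.LayerNormDenseGeneral` takes a `features` argument (the layer choice and argument names here are illustrative, not the reporter's exact snippet), might look like this:

```python
# Minimal sketch (a reconstruction, not the original sample code from the
# report). Passing dtype=jnp.bfloat16 should yield bfloat16 parameters;
# the reported behaviour is that the argument is ignored.
import jax
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax

layer = te_flax.LayerNormDenseGeneral(features=128, dtype=jnp.bfloat16)
x = jnp.ones((8, 16, 128), dtype=jnp.bfloat16)
variables = layer.init(jax.random.key(0), x)

# Print the dtype of every initialised parameter; with the bug, these come
# out as float32 rather than the requested bfloat16.
print(jax.tree_util.tree_map(lambda p: p.dtype, variables["params"]))
```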