Incorrect parameter dtype initialisation for flax transformer-engine modules #1451

Closed
liamclarkza opened this issue Feb 3, 2025 · 4 comments

@liamclarkza

liamclarkza commented Feb 3, 2025

It appears that the dtype argument of layers such as LayerNormDenseGeneral and MultiHeadAttention is ignored when the parameters are initialised.

The argument is documented as follows, so I would expect the parameters in the sample code below to be initialised in bfloat16. Instead, they are initialised in float32:

dtype (jax.numpy.dtype, default = jax.numpy.float32) – The data type used to allocate the initial parameters.

from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax


class Model(nn.Module):
    embed_dim: int
    param_dtype: Any

    def setup(self):
        self.layer1 = te_flax.LayerNormDenseGeneral(
            self.embed_dim,
            return_layernorm_output=False,
            transpose_batch_sequence=False,
            dtype=self.param_dtype,
        )
        self.layer2 = te_flax.MultiHeadAttention(
            head_dim=self.embed_dim,
            num_attention_heads=8,
            dtype=self.param_dtype,
        )

    def __call__(self, x):
        x = self.layer1(x)[0]
        x = self.layer2(x, x)[0]
        return x


compute_dtype = jnp.bfloat16
param_dtype = jnp.bfloat16

x = jnp.ones((8, 16, 128), dtype=compute_dtype)
model = Model(
    embed_dim=128,
    param_dtype=param_dtype,
)

print(f"Test: {compute_dtype=}, {param_dtype=}")
print(model.tabulate(jax.random.key(0), x, console_kwargs={"width": 200}))

params = model.init(jax.random.key(0), x)

# Double-check the parameter dtypes against param_dtype
jax.tree.map_with_path(
    lambda k, v: print(f"{'/'.join(p.key for p in k):30s}\t expected dtype: {jnp.dtype(param_dtype)}\t got dtype: {v.dtype}"),
    params
)
y = jax.jit(model.apply)(params, x)
print(f"{y.dtype=}\t {y.shape=}")

Test: compute_dtype=<class 'jax.numpy.bfloat16'>, param_dtype=<class 'jax.numpy.bfloat16'>

                                                                                          Model Summary                                                                                          
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ path                                                     ┃ module                    ┃ inputs                   ┃ outputs                 ┃ params_axes         ┃ params                      ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│                                                          │ Model                     │ bfloat16[8,16,128]       │ bfloat16[8,16,128]      │                     │                             │
├──────────────────────────────────────────────────────────┼───────────────────────────┼──────────────────────────┼─────────────────────────┼─────────────────────┼─────────────────────────────┤
│ layer1                                                   │ LayerNormDenseGeneral     │ bfloat16[8,16,128]       │ - bfloat16[8,16,128]    │ kernel_axes:        │ kernel: float32[128,128]    │
│                                                          │                           │                          │ - None                  │   names: []         │ ln_bias: float32[128]       │
│                                                          │                           │                          │                         │ ln_bias_axes:       │ scale: float32[128]         │
│                                                          │                           │                          │                         │   names:            │                             │
│                                                          │                           │                          │                         │   - embed           │ 16,640 (66.6 KB)            │
│                                                          │                           │                          │                         │ scale_axes:         │                             │
│                                                          │                           │                          │                         │   names:            │                             │
│                                                          │                           │                          │                         │   - embed           │                             │
│                                                          │                           │                          │                         │                     │                             │
│                                                          │                           │                          │                         │                     │                             │
├──────────────────────────────────────────────────────────┼───────────────────────────┼──────────────────────────┼─────────────────────────┼─────────────────────┼─────────────────────────────┤
│ layer2                                                   │ MultiHeadAttention        │ - bfloat16[8,16,128]     │ - bfloat16[8,16,128]    │                     │                             │
│                                                          │                           │ - bfloat16[8,16,128]     │ - None                  │                     │                             │
├──────────────────────────────────────────────────────────┼───────────────────────────┼──────────────────────────┼─────────────────────────┼─────────────────────┼─────────────────────────────┤
│ layer2/qkv                                               │ LayerNormDenseGeneral     │ bfloat16[8,16,128]       │ - bfloat16[8,16,3,1024] │ kernel_axes:        │ kernel: float32[128,3,1024] │
│                                                          │                           │                          │ - None                  │   names:            │ ln_bias: float32[128]       │
│                                                          │                           │                          │                         │   - nvte_w_fsdp     │ scale: float32[128]         │
│                                                          │                           │                          │                         │   - nvte_w_joined   │                             │
│                                                          │                           │                          │                         │   - nvte_w_tp       │ 393,472 (1.6 MB)            │
│                                                          │                           │                          │                         │ ln_bias_axes:       │                             │
│                                                          │                           │                          │                         │   names:            │                             │
│                                                          │                           │                          │                         │   - nvte_w_no_shard │                             │
│                                                          │                           │                          │                         │ scale_axes:         │                             │
│                                                          │                           │                          │                         │   names:            │                             │
│                                                          │                           │                          │                         │   - nvte_w_no_shard │                             │
│                                                          │                           │                          │                         │                     │                             │
│                                                          │                           │                          │                         │                     │                             │
├──────────────────────────────────────────────────────────┼───────────────────────────┼──────────────────────────┼─────────────────────────┼─────────────────────┼─────────────────────────────┤
│ layer2/DotProductAttention_0                             │ DotProductAttention       │ - bfloat16[8,16,3,8,128] │ bfloat16[8,16,8,128]    │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - deterministic: False   │                         │                     │                             │
├──────────────────────────────────────────────────────────┼───────────────────────────┼──────────────────────────┼─────────────────────────┼─────────────────────┼─────────────────────────────┤
│ layer2/DotProductAttention_0/_FusedDotProductAttention_0 │ _FusedDotProductAttention │ - bfloat16[8,16,3,8,128] │ bfloat16[8,16,8,128]    │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - None                   │                         │                     │                             │
│                                                          │                           │ - deterministic: False   │                         │                     │                             │
│                                                          │                           │   dropout_rng: None      │                         │                     │                             │
├──────────────────────────────────────────────────────────┼───────────────────────────┼──────────────────────────┼─────────────────────────┼─────────────────────┼─────────────────────────────┤
│ layer2/out                                               │ DenseGeneral              │ bfloat16[8,16,1024]      │ bfloat16[8,16,128]      │ kernel_axes:        │ kernel: float32[1024,128]   │
│                                                          │                           │                          │                         │   names:            │                             │
│                                                          │                           │                          │                         │   - nvte_w_tp       │ 131,072 (524.3 KB)          │
│                                                          │                           │                          │                         │   - nvte_w_fsdp     │                             │
│                                                          │                           │                          │                         │                     │                             │
│                                                          │                           │                          │                         │                     │                             │
├──────────────────────────────────────────────────────────┼───────────────────────────┼──────────────────────────┼─────────────────────────┼─────────────────────┼─────────────────────────────┤
│                                                          │                           │                          │                   Total │                     │ 541,184 (2.2 MB)            │
└──────────────────────────────────────────────────────────┴───────────────────────────┴──────────────────────────┴─────────────────────────┴─────────────────────┴─────────────────────────────┘
                                                                                                                                                                                                 
                                                                               Total Parameters: 541,184 (2.2 MB)                                                 


params/layer1/kernel          	 expected dtype: bfloat16	 got dtype: float32
params/layer1/ln_bias         	 expected dtype: bfloat16	 got dtype: float32
params/layer1/scale           	 expected dtype: bfloat16	 got dtype: float32
params/layer2/out/kernel      	 expected dtype: bfloat16	 got dtype: float32
params/layer2/qkv/kernel      	 expected dtype: bfloat16	 got dtype: float32
params/layer2/qkv/ln_bias     	 expected dtype: bfloat16	 got dtype: float32
params/layer2/qkv/scale       	 expected dtype: bfloat16	 got dtype: float32
y.dtype=dtype(bfloat16)	 y.shape=(8, 16, 128)
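The per-leaf printout above can also be turned into a reusable check that returns only the mismatching parameters, which is handy as an assertion in a regression test. This is a minimal sketch using the standard JAX pytree utilities; find_dtype_mismatches is a hypothetical helper name, not a transformer-engine API.

```python
import jax
import jax.numpy as jnp


def find_dtype_mismatches(params, expected_dtype):
    """Return {key_path: actual_dtype} for every leaf whose dtype differs."""
    expected = jnp.dtype(expected_dtype)
    leaves, _ = jax.tree_util.tree_flatten_with_path(params)
    return {
        jax.tree_util.keystr(path): leaf.dtype
        for path, leaf in leaves
        if leaf.dtype != expected
    }


# Usage with the reproduction above; an empty dict means every parameter
# was allocated in param_dtype:
# print(find_dtype_mismatches(params, param_dtype))
```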

I have tested this with transformer-engine 1.14.0, which is bundled with the nvcr.io/nvidia/jax:25.01-py3 Docker image (I don't see a GitHub release for this version yet). I have also tested 1.12.0, with the same result.
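On the affected versions, one possible stopgap is to cast the floating-point parameters after initialisation. This is an assumption-laden sketch, not the upstream fix: the initializers still run in float32, so peak memory during init is unchanged and the values are simply rounded to bfloat16 afterwards. cast_floating_params is a hypothetical helper. It is probably safest to apply it to the "params" collection only, because other variable collections may hold state that should keep its dtype.

```python
import jax
import jax.numpy as jnp


def cast_floating_params(params, dtype):
    """Cast every floating-point leaf of a pytree to `dtype`.

    Integer and boolean leaves are returned unchanged.
    """

    def _cast(leaf):
        if jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(dtype)
        return leaf

    return jax.tree_util.tree_map(_cast, params)


# Usage with the reproduction above (cast only the "params" collection):
# variables = model.init(jax.random.key(0), x)
# variables = {**variables, "params": cast_floating_params(variables["params"], jnp.bfloat16)}
```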

Output of jax.print_environment_info():

jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.2
python: 3.12.3 (main, Jan 17 2025, 18:03:48) [GCC 13.3.0]
device info: NVIDIA H100 80GB HBM3-8, 8 local devices"
process_count: 1
platform: uname_result(system='Linux', node='experiment-7a880fc8-f456-head', release='6.8.0-49-generic', version='#49~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Wed Nov  6 17:42:15 UTC 2', machine='x86_64')


$ nvidia-smi
Mon Feb  3 12:30:37 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:19:00.0 Off |                    0 |
| N/A   41C    P0            120W /  700W |     550MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:3B:00.0 Off |                    0 |
| N/A   37C    P0            121W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  |   00000000:4C:00.0 Off |                    0 |
| N/A   34C    P0            116W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  |   00000000:5D:00.0 Off |                    0 |
| N/A   38C    P0            119W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  |   00000000:9B:00.0 Off |                    0 |
| N/A   41C    P0            126W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  |   00000000:BB:00.0 Off |                    0 |
| N/A   37C    P0            119W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  |   00000000:CB:00.0 Off |                    0 |
| N/A   38C    P0            113W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  |   00000000:DB:00.0 Off |                    0 |
| N/A   36C    P0            117W /  700W |     538MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
|    1   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
|    2   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
|    3   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
|    4   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
|    5   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
|    6   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
|    7   N/A  N/A     35065      C   /usr/bin/python                                 0MiB |
+-----------------------------------------------------------------------------------------+
@phu0ngng
Collaborator

Hi, thank you for reporting this issue.
With the fixes introduced in #1472, I was able to obtain all the params in bfloat16 with your example.

phuonguyen@s4124-0136:~/te$ python ex.py
Test: compute_dtype=<class 'jax.numpy.bfloat16'>, param_dtype=<class 'jax.numpy.bfloat16'>

                                                                                             Model Summary
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ path                                                         ┃ module                      ┃ inputs                   ┃ outputs                 ┃ params                       ┃ params_axes         ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│                                                              │ Model                       │ bfloat16[8,16,128]       │ bfloat16[8,16,128]      │                              │                     │
├──────────────────────────────────────────────────────────────┼─────────────────────────────┼──────────────────────────┼─────────────────────────┼──────────────────────────────┼─────────────────────┤
│ layer1                                                       │ LayerNormDenseGeneral       │ bfloat16[8,16,128]       │ - bfloat16[8,16,128]    │ kernel: bfloat16[128,128]    │ kernel_axes:        │
│                                                              │                             │                          │ - None                  │ ln_bias: bfloat16[128]       │   names: []         │
│                                                              │                             │                          │                         │ scale: bfloat16[128]         │ ln_bias_axes:       │
│                                                              │                             │                          │                         │                              │   names:            │
│                                                              │                             │                          │                         │ 16,640 (33.3 KB)             │   - embed           │
│                                                              │                             │                          │                         │                              │ scale_axes:         │
│                                                              │                             │                          │                         │                              │   names:            │
│                                                              │                             │                          │                         │                              │   - embed           │
│                                                              │                             │                          │                         │                              │                     │
│                                                              │                             │                          │                         │                              │                     │
├──────────────────────────────────────────────────────────────┼─────────────────────────────┼──────────────────────────┼─────────────────────────┼──────────────────────────────┼─────────────────────┤
│ layer2                                                       │ MultiHeadAttention          │ - bfloat16[8,16,128]     │ - bfloat16[8,16,128]    │                              │                     │
│                                                              │                             │ - bfloat16[8,16,128]     │ - None                  │                              │                     │
├──────────────────────────────────────────────────────────────┼─────────────────────────────┼──────────────────────────┼─────────────────────────┼──────────────────────────────┼─────────────────────┤
│ layer2/qkv                                                   │ LayerNormDenseGeneral       │ bfloat16[8,16,128]       │ - bfloat16[8,16,3,1024] │ kernel: bfloat16[128,3,1024] │ kernel_axes:        │
│                                                              │                             │                          │ - None                  │ ln_bias: bfloat16[128]       │   names:            │
│                                                              │                             │                          │                         │ scale: bfloat16[128]         │   - nvte_w_fsdp     │
│                                                              │                             │                          │                         │                              │   - nvte_w_joined   │
│                                                              │                             │                          │                         │ 393,472 (786.9 KB)           │   - nvte_w_tp       │
│                                                              │                             │                          │                         │                              │ ln_bias_axes:       │
│                                                              │                             │                          │                         │                              │   names:            │
│                                                              │                             │                          │                         │                              │   - nvte_w_no_shard │
│                                                              │                             │                          │                         │                              │ scale_axes:         │
│                                                              │                             │                          │                         │                              │   names:            │
│                                                              │                             │                          │                         │                              │   - nvte_w_no_shard │
│                                                              │                             │                          │                         │                              │                     │
│                                                              │                             │                          │                         │                              │                     │
├──────────────────────────────────────────────────────────────┼─────────────────────────────┼──────────────────────────┼─────────────────────────┼──────────────────────────────┼─────────────────────┤
│ layer2/DotProductAttention_0                                 │ DotProductAttention         │ - bfloat16[8,16,3,8,128] │ bfloat16[8,16,8,128]    │                              │                     │
│                                                              │                             │ - None                   │                         │                              │                     │
│                                                              │                             │ - None                   │                         │                              │                     │
│                                                              │                             │ - None                   │                         │                              │                     │
│                                                              │                             │ - None                   │                         │                              │                     │
│                                                              │                             │ - deterministic: False   │                         │                              │                     │
├──────────────────────────────────────────────────────────────┼─────────────────────────────┼──────────────────────────┼─────────────────────────┼──────────────────────────────┼─────────────────────┤
│ layer2/DotProductAttention_0/_UnfusedDotProductAttention_0   │ _UnfusedDotProductAttention │ - bfloat16[8,16,8,128]   │ bfloat16[8,16,8,128]    │                              │                     │
│                                                              │                             │ - bfloat16[8,16,8,128]   │                         │                              │                     │
│                                                              │                             │ - bfloat16[8,16,8,128]   │                         │                              │                     │
│                                                              │                             │ - None                   │                         │                              │                     │
│                                                              │                             │ - None                   │                         │                              │                     │
│                                                              │                             │ - deterministic: False   │                         │                              │                     │
│                                                              │                             │   dropout_rng: None      │                         │                              │                     │
├──────────────────────────────────────────────────────────────┼─────────────────────────────┼──────────────────────────┼─────────────────────────┼──────────────────────────────┼─────────────────────┤
│ layer2/DotProductAttention_0/_UnfusedDotProductAttention_0/… │ Softmax                     │ - bfloat16[16,8,8,8]     │ bfloat16[16,8,8,8]      │                              │                     │
│                                                              │                             │ - None                   │                         │                              │                     │
│                                                              │                             │ - None                   │                         │                              │                     │
├──────────────────────────────────────────────────────────────┼─────────────────────────────┼──────────────────────────┼─────────────────────────┼──────────────────────────────┼─────────────────────┤
│ layer2/out                                                   │ DenseGeneral                │ bfloat16[8,16,1024]      │ bfloat16[8,16,128]      │ kernel: bfloat16[1024,128]   │ kernel_axes:        │
│                                                              │                             │                          │                         │                              │   names:            │
│                                                              │                             │                          │                         │ 131,072 (262.1 KB)           │   - nvte_w_tp       │
│                                                              │                             │                          │                         │                              │   - nvte_w_fsdp     │
│                                                              │                             │                          │                         │                              │                     │
│                                                              │                             │                          │                         │                              │                     │
├──────────────────────────────────────────────────────────────┼─────────────────────────────┼──────────────────────────┼─────────────────────────┼──────────────────────────────┼─────────────────────┤
│                                                              │                             │                          │                   Total │ 541,184 (1.1 MB)             │                     │
└──────────────────────────────────────────────────────────────┴─────────────────────────────┴──────────────────────────┴─────────────────────────┴──────────────────────────────┴─────────────────────┘

                                                                                   Total Parameters: 541,184 (1.1 MB)
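As a sanity check on this summary, the byte counts are consistent with 2-byte bfloat16 storage. The `layer2/out` kernel has 1024 × 128 = 131,072 parameters, which is 262,144 bytes, and the 541,184 parameters in total come to 1,082,368 bytes. The table appears to use decimal units (262,144 bytes is shown as 262.1 KB). The small sketch below reproduces that arithmetic. `param_bytes` and `fmt_decimal` are illustrative helpers, not Flax functions.

```python
import math

BF16_ITEMSIZE = 2  # bfloat16 stores 2 bytes per element
FP32_ITEMSIZE = 4  # float32 stores 4 bytes per element


def param_bytes(shape, itemsize):
    # Number of elements times bytes per element.
    return math.prod(shape) * itemsize


def fmt_decimal(nbytes):
    # Decimal units (1 KB = 1000 B), which appear to match the table above.
    for unit, scale in (("MB", 1e6), ("KB", 1e3)):
        if nbytes >= scale:
            return f"{nbytes / scale:.1f} {unit}"
    return f"{nbytes} B"


out_kernel = param_bytes((1024, 128), BF16_ITEMSIZE)
print(out_kernel, fmt_decimal(out_kernel))  # 262144 262.1 KB
print(fmt_decimal(541_184 * BF16_ITEMSIZE))  # 1.1 MB
print(fmt_decimal(541_184 * FP32_ITEMSIZE))  # 2.2 MB if the same parameters were float32
```

If the parameters had been allocated in float32, the total would read roughly 2.2 MB instead of 1.1 MB.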


params/layer1/kernel             expected dtype: bfloat16        got dtype: bfloat16
params/layer1/ln_bias            expected dtype: bfloat16        got dtype: bfloat16
params/layer1/scale              expected dtype: bfloat16        got dtype: bfloat16
params/layer2/out/kernel         expected dtype: bfloat16        got dtype: bfloat16
params/layer2/qkv/kernel         expected dtype: bfloat16        got dtype: bfloat16
params/layer2/qkv/ln_bias        expected dtype: bfloat16        got dtype: bfloat16
params/layer2/qkv/scale          expected dtype: bfloat16        got dtype: bfloat16
y.dtype=dtype(bfloat16)  y.shape=(8, 16, 128)

Please let us know if you still observe any issues.
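The per-parameter check printed above can be reproduced with a small tree walk. This is a minimal sketch that assumes `params` is the nested dict returned by `model.init`, so the leaves sit under `params["params"]`. `dtype_report` is a hypothetical helper, not a TE or Flax API.

```python
from collections.abc import Mapping

import numpy as np


def dtype_report(tree, expected, prefix="params"):
    """Return (path, expected, actual) dtype names for every leaf of a nested param dict."""
    want = np.dtype(expected).name
    rows = []
    for key in sorted(tree):  # sorted for a stable, readable order
        value = tree[key]
        path = f"{prefix}/{key}"
        if isinstance(value, Mapping):
            rows.extend(dtype_report(value, expected, path))
        else:
            rows.append((path, want, np.dtype(value.dtype).name))
    return rows


# Usage with NumPy stand-ins. With Flax, pass params["params"] and jnp.bfloat16 instead.
fake_params = {
    "layer1": {"kernel": np.zeros((4, 4), np.float16), "scale": np.ones(4, np.float32)},
}
for path, want, got in dtype_report(fake_params, np.float16):
    flag = "" if want == got else "  <-- mismatch"
    print(f"{path:<32} expected dtype: {want:<10} got dtype: {got}{flag}")
```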

@liamclarkza
Author

Thanks @phu0ngng. I can confirm that on your branch the Dense layers now initialise their parameters with the correct dtype. However, MultiHeadAttention now fails when binding the fused attention primitive. I'm not sure whether this is caused by the upgrade to the new TE version; if you prefer, I can open a separate issue for it.

The code and the error it raises are below. I have also attached a printout of the package versions.

from typing import Any
import flax.linen as nn
import jax
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax


class Model(nn.Module):
    embed_dim: int
    param_dtype: Any

    @nn.compact
    def __call__(self, x):
        return te_flax.MultiHeadAttention(
            head_dim=self.embed_dim,
            num_attention_heads=8,
            dtype=self.param_dtype,
        )(x, x)[0]


x = jnp.ones((8, 16, 128), dtype=jnp.bfloat16)
model = Model(embed_dim=128, param_dtype=jnp.bfloat16)
params = model.init(jax.random.key(0),  x)
y = jax.jit(model.apply)(params, x)
File /usr/local/lib/python3.12/dist-packages/transformer_engine/jax/cpp_extensions/attention.py:2245, in fused_attn_fwd(qkv, bias, sequence_descriptor, seed, attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training, max_segments_per_seq, window_size, context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis)
   2242         primitive = FusedRingAttnFwdPrimitive.outer_primitive
   2244 seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
-> 2245 return primitive.bind(
   2246     *qkv_for_primitive,
   2247     bias,
   2248     seed,
   2249     *seq_desc_flatten,
   2250     config=fused_config,
   2251 )

File /usr/local/lib/python3.12/dist-packages/jax/_src/core.py:463, in Primitive.bind(self, *args, **params)
    461 trace_ctx.set_trace(eval_trace)
    462 try:
--> 463   return self.bind_with_trace(prev_trace, args, params)
    464 finally:
    465   trace_ctx.set_trace(prev_trace)

File /usr/local/lib/python3.12/dist-packages/jax/_src/core.py:468, in Primitive.bind_with_trace(self, trace, args, params)
    467 def bind_with_trace(self, trace, args, params):
--> 468   return trace.process_primitive(self, args, params)

File /usr/local/lib/python3.12/dist-packages/jax/_src/core.py:954, in EvalTrace.process_primitive(self, primitive, args, params)
    952       return primitive.bind_with_trace(arg._trace, args, params)
    953 check_eval_args(args)
--> 954 return primitive.impl(*args, **params)

TypeError: FusedAttnFwdPrimitive.impl() missing 8 required positional arguments: 'q_seqlen', 'kv_seqlen', 'q_seq_offsets', 'k_seq_offsets', '_q_segment_ids', '_kv_segment_ids', '_q_segment_pos', and '_kv_segment_pos'
Full traceback below:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[11], line 23
     21 x = jnp.ones((8, 16, 128), dtype=jnp.bfloat16)
     22 model = Model(embed_dim=128, param_dtype=jnp.bfloat16)
---> 23 params = model.init(jax.random.key(0),  x)
     24 y = jax.jit(model.apply)(params, x)

    [... skipping hidden 1 frame]

File /opt/flax/flax/linen/module.py:2447, in Module.init(self, rngs, method, mutable, capture_intermediates, *args, **kwargs)
   2316 """Initializes a module method with variables and returns modified variables.
   2317 
   2318 ``init`` takes as first argument either a single ``PRNGKey``, or a
   (...)
   2443   The initialized variable dict.
   2444 """
   2445 Module._module_checks(self)
-> 2447 _, v_out = self.init_with_output(
   2448   rngs,
   2449   *args,
   2450   method=method,
   2451   mutable=mutable,
   2452   capture_intermediates=capture_intermediates,
   2453   **kwargs,
   2454 )
   2455 return v_out

    [... skipping hidden 1 frame]

File /opt/flax/flax/linen/module.py:2299, in Module.init_with_output(self, rngs, method, mutable, capture_intermediates, *args, **kwargs)
   2297   method = self.__call__
   2298 method = _get_unbound_fn(method)
-> 2299 return init_with_output(
   2300   method,
   2301   self,
   2302   mutable=mutable,
   2303   capture_intermediates=capture_intermediates,
   2304 )(rngs, *args, **kwargs)

File /opt/flax/flax/core/scope.py:1115, in init.<locals>.wrapper(rngs, *args, **kwargs)
   1113   rngs = {'params': rngs}
   1114 init_flags = {**(flags if flags is not None else {}), 'initializing': True}
-> 1115 return apply(fn, mutable=mutable, flags=init_flags)(
   1116   {}, *args, rngs=rngs, **kwargs
   1117 )

File /opt/flax/flax/core/scope.py:1079, in apply.<locals>.wrapper(variables, rngs, *args, **kwargs)
   1074   raise errors.ApplyScopeInvalidVariablesStructureError(variables)
   1076 with bind(
   1077   variables, rngs=rngs, mutable=mutable, flags=flags
   1078 ).temporary() as root:
-> 1079   y = fn(root, *args, **kwargs)
   1080 if mutable is not False:
   1081   return y, root.mutable_variables()

File /opt/flax/flax/linen/module.py:3088, in init_with_output.<locals>.scope_fn(scope, *args, **kwargs)
   3086 _context.capture_stack.append(capture_intermediates)
   3087 try:
-> 3088   return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
   3089 finally:
   3090   _context.capture_stack.pop()

File /opt/flax/flax/linen/module.py:699, in wrap_method_once.<locals>.wrapped_module_method(*args, **kwargs)
    697 if args and isinstance(args[0], Module):
    698   self, args = args[0], args[1:]
--> 699   return self._call_wrapped_method(fun, args, kwargs)
    700 else:
    701   return fun(*args, **kwargs)

File /opt/flax/flax/linen/module.py:1216, in Module._call_wrapped_method(self, fun, args, kwargs)
   1214 if _use_named_call:
   1215   with jax.named_scope(_derive_profiling_name(self, fun)):
-> 1216     y = run_fun(self, *args, **kwargs)
   1217 else:
   1218   y = run_fun(self, *args, **kwargs)

Cell In[11], line 14, in Model.__call__(self, x)
     12 @nn.compact
     13 def __call__(self, x):
---> 14     return te_flax.MultiHeadAttention(
     15         head_dim=self.embed_dim,
     16         num_attention_heads=8,
     17         dtype=self.param_dtype,
     18     )(x, x)[0]

File /opt/flax/flax/linen/module.py:699, in wrap_method_once.<locals>.wrapped_module_method(*args, **kwargs)
    697 if args and isinstance(args[0], Module):
    698   self, args = args[0], args[1:]
--> 699   return self._call_wrapped_method(fun, args, kwargs)
    700 else:
    701   return fun(*args, **kwargs)

File /opt/flax/flax/linen/module.py:1216, in Module._call_wrapped_method(self, fun, args, kwargs)
   1214 if _use_named_call:
   1215   with jax.named_scope(_derive_profiling_name(self, fun)):
-> 1216     y = run_fun(self, *args, **kwargs)
   1217 else:
   1218   y = run_fun(self, *args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:1321, in MultiHeadAttention.__call__(self, inputs_q, inputs_kv, mask, bias, decode, deterministic)
   1318     dpa_args = [query, key, value]
   1320 scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
-> 1321 x = DotProductAttention(
   1322     head_dim=self.head_dim,
   1323     num_attention_heads=self.num_attention_heads,
   1324     num_gqa_groups=self.num_gqa_groups,
   1325     attn_mask_type=self.attn_mask_type,
   1326     attn_bias_type=self.attn_bias_type,
   1327     attention_dropout=self.attention_dropout,
   1328     dtype=self.dtype,
   1329     dropout_rng_name=self.dropout_rng_name,
   1330     float32_logits=self.float32_logits,
   1331     qkv_layout=qkv_layout.name,
   1332     scale_factor=scale_factor,
   1333     transpose_batch_sequence=self.transpose_batch_sequence,
   1334     window_size=self.window_size,
   1335 )(*dpa_args, mask, bias, deterministic=deterministic)
   1336 x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
   1338 attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES)

File /opt/flax/flax/linen/module.py:699, in wrap_method_once.<locals>.wrapped_module_method(*args, **kwargs)
    697 if args and isinstance(args[0], Module):
    698   self, args = args[0], args[1:]
--> 699   return self._call_wrapped_method(fun, args, kwargs)
    700 else:
    701   return fun(*args, **kwargs)

File /opt/flax/flax/linen/module.py:1216, in Module._call_wrapped_method(self, fun, args, kwargs)
   1214 if _use_named_call:
   1215   with jax.named_scope(_derive_profiling_name(self, fun)):
-> 1216     y = run_fun(self, *args, **kwargs)
   1217 else:
   1218   y = run_fun(self, *args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:624, in DotProductAttention.__call__(self, query, key, value, mask, bias, deterministic)
    613     x = _UnfusedDotProductAttention(
    614         attention_dropout=self.attention_dropout,
    615         attn_mask_type=attn_mask_type,
   (...)
    621         window_size=self.window_size,
    622     )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
    623 else:
--> 624     x = _FusedDotProductAttention(
    625         attention_dropout=self.attention_dropout,
    626         attn_mask_type=attn_mask_type,
    627         attn_bias_type=attn_bias_type,
    628         dtype=self.dtype,
    629         scale_factor=scale_factor,
    630         transpose_batch_sequence=self.transpose_batch_sequence,
    631         qkv_layout=qkv_layout,
    632         window_size=self.window_size,
    633         context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
    634         context_parallel_axis=self.context_parallel_axis,
    635     )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic)
    637 return x

File /opt/flax/flax/linen/module.py:699, in wrap_method_once.<locals>.wrapped_module_method(*args, **kwargs)
    697 if args and isinstance(args[0], Module):
    698   self, args = args[0], args[1:]
--> 699   return self._call_wrapped_method(fun, args, kwargs)
    700 else:
    701   return fun(*args, **kwargs)

File /opt/flax/flax/linen/module.py:1216, in Module._call_wrapped_method(self, fun, args, kwargs)
   1214 if _use_named_call:
   1215   with jax.named_scope(_derive_profiling_name(self, fun)):
-> 1216     y = run_fun(self, *args, **kwargs)
   1217 else:
   1218   y = run_fun(self, *args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:304, in _FusedDotProductAttention.__call__(self, query, key, value, mask, bias, dropout_rng, deterministic)
    302     if self.transpose_batch_sequence:
    303         qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4])
--> 304     x = fused_attn(
    305         (qkv_packed,),
    306         bias,
    307         mask,
    308         seed,
    309         attn_mask_type=self.attn_mask_type,
    310         attn_bias_type=self.attn_bias_type,
    311         qkv_layout=self.qkv_layout,
    312         scaling_factor=scale_factor,
    313         dropout_probability=self.attention_dropout,
    314         is_training=not deterministic,
    315         window_size=self.window_size,
    316         context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
    317         context_parallel_axis=self.context_parallel_axis,
    318     )
    319 elif self.qkv_layout == QKVLayout.BSHD_BS2HD:
    320     """kvpacked format, treat
    321     query: query tensor, shape = [..., h, d]
    322     key: kvpacked tensor, shape = [..., 2, h, d]
    323     value: ignore
    324     """

File /usr/local/lib/python3.12/dist-packages/transformer_engine/jax/attention.py:977, in fused_attn(qkv, bias, sequence_descriptor, seed, attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training, max_segments_per_seq, window_size, context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis)
    960         raise ValueError("Passing mask is only supported for non-THD case.")
    961     return _legacy_fused_attn(
    962         qkv,
    963         bias,
   (...)
    975         context_parallel_axis=context_parallel_axis,
    976     )
--> 977 output = _fused_attn(
    978     qkv,
    979     bias,
    980     sequence_descriptor,
    981     seed,
    982     attn_bias_type=attn_bias_type,
    983     attn_mask_type=attn_mask_type,
    984     qkv_layout=qkv_layout,
    985     scaling_factor=scaling_factor,
    986     dropout_probability=dropout_probability,
    987     is_training=is_training,
    988     max_segments_per_seq=max_segments_per_seq,
    989     window_size=window_size,
    990     context_parallel_strategy=context_parallel_strategy,
    991     context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
    992     context_parallel_axis=context_parallel_axis,
    993 )
    995 return output

    [... skipping hidden 1 frame]

File /usr/local/lib/python3.12/dist-packages/jax/_src/custom_derivatives.py:620, in custom_vjp.__call__(self, *args, **kwargs)
    616 flat_fwd, out_trees = _flatten_fwd(
    617     fwd_, self.nondiff_argnums, self.symbolic_zeros, primal_name,
    618     fwd_name, in_tree, out_type)
    619 flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
--> 620 out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
    621                                   *args_flat, out_trees=out_trees,
    622                                   symbolic_zeros=self.symbolic_zeros)
    623 _, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
    624 return tree_unflatten(out_tree, out_flat)

File /usr/local/lib/python3.12/dist-packages/jax/_src/core.py:463, in Primitive.bind(self, *args, **params)
    461 trace_ctx.set_trace(eval_trace)
    462 try:
--> 463   return self.bind_with_trace(prev_trace, args, params)
    464 finally:
    465   trace_ctx.set_trace(prev_trace)

File /usr/local/lib/python3.12/dist-packages/jax/_src/custom_derivatives.py:840, in CustomVJPCallPrimitive.bind_with_trace(self, trace, args, params)
    838 def bind_with_trace(self, trace, args, params):
    839   fun, fwd, bwd, tracers = args[0], args[1], args[2], args[3:]
--> 840   return trace.process_custom_vjp_call(self, fun, fwd, bwd, tracers, **params)

File /usr/local/lib/python3.12/dist-packages/jax/_src/core.py:975, in EvalTrace.process_custom_vjp_call(***failed resolving arguments***)
    973 def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, **_):  # pytype: disable=signature-mismatch
    974   del primitive, fwd, bwd, _  # Unused.
--> 975   return fun.call_wrapped(*tracers)

File /usr/local/lib/python3.12/dist-packages/jax/_src/linear_util.py:192, in WrappedFun.call_wrapped(self, *args, **kwargs)
    190 def call_wrapped(self, *args, **kwargs):
    191   """Calls the transformed function"""
--> 192   return self.f_transformed(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/jax/_src/custom_derivatives.py:83, in _flatten_fun_nokwargs(f, store, in_tree, *args_flat)
     80 @lu.transformation_with_aux2
     81 def _flatten_fun_nokwargs(f, store, in_tree, *args_flat):
     82   py_args = tree_unflatten(in_tree, args_flat)
---> 83   ans = f(*py_args)
     84   ans_flat, ans_tree = tree_flatten(ans)
     85   ans_avals = [core.get_aval(x) for x in ans_flat]

File /usr/local/lib/python3.12/dist-packages/jax/_src/api_util.py:292, in _argnums_partial(_fun, _dyn_argnums, _fixed_args, *dyn_args, **kwargs)
    290 args = [next(fixed_args_).val if x is sentinel else x for x in args]
    291 assert next(fixed_args_, sentinel) is sentinel
--> 292 return _fun(*args, **kwargs)

File /usr/local/lib/python3.12/dist-packages/transformer_engine/jax/attention.py:750, in _fused_attn(qkv, bias, sequence_descriptor, seed, attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training, max_segments_per_seq, window_size, context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis)
    732 @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14))
    733 def _fused_attn(
    734     qkv: Tuple[jnp.ndarray, ...],
   (...)
    748     context_parallel_axis: str,
    749 ):
--> 750     output, _ = _fused_attn_fwd_rule(
    751         qkv,
    752         bias,
    753         sequence_descriptor,
    754         seed,
    755         attn_bias_type,
    756         attn_mask_type,
    757         qkv_layout,
    758         scaling_factor,
    759         dropout_probability,
    760         is_training,
    761         max_segments_per_seq,
    762         window_size,
    763         context_parallel_strategy,
    764         context_parallel_causal_load_balanced,
    765         context_parallel_axis,
    766     )
    767     return output

File /usr/local/lib/python3.12/dist-packages/transformer_engine/jax/attention.py:787, in _fused_attn_fwd_rule(qkv, bias, sequence_descriptor, seed, attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training, max_segments_per_seq, window_size, context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis)
    770 def _fused_attn_fwd_rule(
    771     qkv,
    772     bias,
   (...)
    785     context_parallel_axis,
    786 ):
--> 787     output, softmax_aux, rng_state = tex.fused_attn_fwd(
    788         qkv,
    789         bias,
    790         sequence_descriptor,
    791         seed,
    792         attn_bias_type=attn_bias_type.value,
    793         attn_mask_type=attn_mask_type.value,
    794         qkv_layout=qkv_layout.value,
    795         scaling_factor=scaling_factor,
    796         dropout_probability=dropout_probability,
    797         is_training=is_training,
    798         max_segments_per_seq=max_segments_per_seq,
    799         window_size=window_size,
    800         context_parallel_strategy=context_parallel_strategy,
    801         context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
    802         context_parallel_axis=context_parallel_axis,
    803     )
    804     output = checkpoint_name(output, "context")
    805     softmax_aux = checkpoint_name(softmax_aux, "context")

File /usr/local/lib/python3.12/dist-packages/transformer_engine/jax/cpp_extensions/attention.py:2245, in fused_attn_fwd(qkv, bias, sequence_descriptor, seed, attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training, max_segments_per_seq, window_size, context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis)
   2242         primitive = FusedRingAttnFwdPrimitive.outer_primitive
   2244 seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
-> 2245 return primitive.bind(
   2246     *qkv_for_primitive,
   2247     bias,
   2248     seed,
   2249     *seq_desc_flatten,
   2250     config=fused_config,
   2251 )

File /usr/local/lib/python3.12/dist-packages/jax/_src/core.py:463, in Primitive.bind(self, *args, **params)
    461 trace_ctx.set_trace(eval_trace)
    462 try:
--> 463   return self.bind_with_trace(prev_trace, args, params)
    464 finally:
    465   trace_ctx.set_trace(prev_trace)

File /usr/local/lib/python3.12/dist-packages/jax/_src/core.py:468, in Primitive.bind_with_trace(self, trace, args, params)
    467 def bind_with_trace(self, trace, args, params):
--> 468   return trace.process_primitive(self, args, params)

File /usr/local/lib/python3.12/dist-packages/jax/_src/core.py:954, in EvalTrace.process_primitive(self, primitive, args, params)
    952       return primitive.bind_with_trace(arg._trace, args, params)
    953 check_eval_args(args)
--> 954 return primitive.impl(*args, **params)

TypeError: FusedAttnFwdPrimitive.impl() missing 8 required positional arguments: 'q_seqlen', 'kv_seqlen', 'q_seq_offsets', 'k_seq_offsets', '_q_segment_ids', '_kv_segment_ids', '_q_segment_pos', and '_kv_segment_pos'
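One plausible reading of this error, not confirmed in the thread: `MultiHeadAttention` is called as `(x, x)` without a mask. Assuming that `None` mask reaches `fused_attn` as `sequence_descriptor`, `jax.tree.flatten(None)` yields no leaves because `None` is an empty pytree. The `*seq_desc_flatten` splat then passes nothing where the primitive expects eight sequence-descriptor arguments, which matches the eight names in the `TypeError`. The sketch below reproduces only that mechanism with plain JAX. `fake_impl`, `bind_like` and `try_bind` are hypothetical stand-ins, not TE functions.

```python
import jax


def fake_impl(qkv, bias, seed, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets,
              _q_segment_ids, _kv_segment_ids, _q_segment_pos, _kv_segment_pos):
    # Hypothetical stand-in: only the signature matters here.
    return "ok"


def bind_like(sequence_descriptor):
    # Same pattern as fused_attn_fwd in the traceback: flatten, then splat the leaves.
    seq_desc_flatten, _ = jax.tree_util.tree_flatten(sequence_descriptor)
    return fake_impl("qkv", "bias", "seed", *seq_desc_flatten)


def try_bind(sequence_descriptor):
    # Return the result, or the TypeError message if the arity does not match.
    try:
        return bind_like(sequence_descriptor)
    except TypeError as err:
        return str(err)


print(jax.tree_util.tree_flatten(None)[0])  # [] because None is an empty pytree
print(try_bind(None))  # "... missing 8 required positional arguments ..."
print(try_bind(tuple(range(8))))  # ok
```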
Output from `pip list`
Package                    Version              Editable project location
-------------------------- -------------------- ----------------------------------------
absl-py                    2.1.0
aiohappyeyeballs           2.4.6
aiohttp                    3.11.12
aiohttp-cors               0.7.0
aiosignal                  1.3.2
annotated-types            0.7.0
antlr4-python3-runtime     4.9.3
anyio                      4.8.0
argon2-cffi                23.1.0
argon2-cffi-bindings       21.2.0
array_record               0.6.0
arrow                      1.3.0
asttokens                  3.0.0
astunparse                 1.6.3
async-lru                  2.0.4
attrs                      25.1.0
awscli                     1.33.30
babel                      2.17.0
beautifulsoup4             4.13.3
bio                        1.7.1
biopython                  1.85
biothings_client           0.4.1
bleach                     6.2.0
boto3                      1.34.148
botocore                   1.34.148
bravado                    11.0.3
bravado-core               6.1.1
build                      1.2.2.post1
cachetools                 5.5.1
certifi                    2024.12.14
cffi                       1.17.1
charset-normalizer         3.4.1
chex                       0.1.88
click                      8.1.8
cloudpickle                3.1.1
clu                        0.0.12
colorama                   0.4.6
colorful                   0.5.6
comm                       0.2.2
dataclasses-json           0.6.7
datasets                   2.20.0
debugpy                    1.8.12
decorator                  5.1.1
defusedxml                 0.7.1
dill                       0.3.8
distlib                    0.3.9
dm-tree                    0.1.9
docutils                   0.16
einops                     0.8.0
etils                      1.12.0
executing                  2.1.0
fair-esm                   2.0.0
fastjsonschema             2.21.1
filelock                   3.17.0
flatbuffers                25.2.10
flax                       0.10.3
fqdn                       1.5.1
frozenlist                 1.5.0
fsspec                     2025.2.0
future                     1.0.0
gast                       0.6.0
gcsfs                      2024.5.0
gitdb                      4.0.12
GitPython                  3.1.44
google-api-core            2.24.1
google-auth                2.38.0
google-auth-oauthlib       1.2.1
google-cloud-core          2.4.1
google-cloud-storage       3.0.0
google-crc32c              1.6.0
google-pasta               0.2.0
google-resumable-media     2.7.2
googleapis-common-protos   1.66.0
gprofiler-official         1.0.0
grain                      0.2.1
greenlet                   3.1.1
grpcio                     1.70.0
gviz-api                   1.10.0
h11                        0.14.0
h5py                       3.12.1
httpcore                   1.0.7
httpx                      0.28.1
huggingface-hub            0.28.1
humanize                   4.11.0
hydra-core                 1.3.2
idna                       3.10
importlib_metadata         8.6.1
importlib_resources        6.5.2
ipykernel                  6.29.5
ipython                    8.31.0
isoduration                20.11.0
jax                        0.5.0
jax-cuda12-pjrt            0.5.0
jax-cuda12-plugin          0.5.0
jaxlib                     0.5.0
jaxtyping                  0.2.36
jedi                       0.19.2
Jinja2                     3.1.5
jmespath                   1.0.1
json5                      0.10.0
jsonpointer                3.0.0
jsonref                    1.1.0
jsonschema                 4.23.0
jsonschema-specifications  2024.10.1
jupyter_client             8.6.3
jupyter_core               5.7.2
jupyter-events             0.12.0
jupyter-lsp                2.2.5
jupyter_server             2.15.0
jupyter_server_terminals   0.5.3
jupyterlab                 4.3.5
jupyterlab_pygments        0.3.0
jupyterlab_server          2.27.3
keras                      3.8.0
libclang                   18.1.1
Markdown                   3.7
markdown-it-py             3.0.0
MarkupSafe                 3.0.2
marshmallow                3.25.1
matplotlib-inline          0.1.7
mdurl                      0.1.2
mistune                    3.1.1
ml_collections             1.0.0
ml_dtypes                  0.5.1
monotonic                  1.6
more-itertools             10.6.0
mpmath                     1.3.0
msgpack                    1.1.0
multidict                  6.1.0
multiprocess               0.70.16
mygene                     3.2.2
mypy-extensions            1.0.0
namex                      0.0.8
nbclient                   0.10.2
nbconvert                  7.16.6
nbformat                   5.10.4
neptune                    1.10.4
nest-asyncio               1.6.0
networkx                   3.4.2
notebook                   7.3.2
notebook_shim              0.2.4
nsys-jax                   0.1.dev1085+gb0ec72a /opt/nsys-jax/.github/container/nsys_jax
numpy                      2.2.2
nvidia-cublas-cu12         12.8.3.14
nvidia-cuda-cupti-cu12     12.8.57
nvidia-cuda-nvcc-cu12      12.8.61
nvidia-cuda-runtime-cu12   12.8.57
nvidia-cudnn-cu12          9.7.1.26
nvidia-cufft-cu12          11.3.3.41
nvidia-cusolver-cu12       11.7.2.55
nvidia-cusparse-cu12       12.5.7.53
nvidia-nccl-cu12           2.25.1
nvidia-nvjitlink-cu12      12.8.61
oauthlib                   3.2.2
omegaconf                  2.3.0
opencensus                 0.11.4
opencensus-context         0.1.3
opt_einsum                 3.4.0
optax                      0.2.4
optree                     0.14.0
orbax-checkpoint           0.11.4
orbax-export               0.0.6
overrides                  7.7.0
packaging                  24.2
pandas                     2.2.3
pandocfilters              1.5.1
parso                      0.8.4
pexpect                    4.9.0
pillow                     11.1.0
pip                        23.3.1               /opt/pip
pip-tools                  7.4.1
platformdirs               4.3.6
pooch                      1.8.2
prometheus_client          0.21.1
prompt_toolkit             3.0.50
propcache                  0.2.1
proto-plus                 1.26.0
protobuf                   5.29.3
psutil                     6.1.1
ptyprocess                 0.7.0
pure_eval                  0.2.3
py-spy                     0.4.0
pyarrow                    19.0.0
pyarrow-hotfix             0.6
pyasn1                     0.6.1
pyasn1_modules             0.4.1
pybind11                   2.13.6
pybind11_global            2.13.6
pycparser                  2.22
pydantic                   2.10.6
pydantic_core              2.27.2
Pygments                   2.19.1
PyJWT                      2.10.1
pyproject_hooks            1.2.0
python-dateutil            2.9.0.post0
python-json-logger         3.2.1
pytz                       2024.2
PyYAML                     6.0.2
pyzmq                      26.2.1
ray                        2.42.0
referencing                0.36.2
regex                      2024.5.15
requests                   2.32.3
requests-oauthlib          2.0.0
rfc3339-validator          0.1.4
rfc3986-validator          0.1.1
rich                       13.9.4
rpds-py                    0.22.3
rsa                        4.7.2
s3fs                       0.4.2
s3transfer                 0.10.4
scipy                      1.15.1
Send2Trash                 1.8.3
setuptools                 75.8.0
simplejson                 3.19.3
six                        1.17.0
smart-open                 7.1.0
smmap                      5.0.2
sniffio                    1.3.1
soupsieve                  2.6
SQLAlchemy                 2.0.38
stack-data                 0.6.3
swagger-spec-validator     3.0.4
sympy                      1.13.3
tensorboard                2.18.0
tensorboard-data-server    0.7.2
tensorboard-plugin-profile 2.18.0
tensorflow                 2.18.0
tensorstore                0.1.71
termcolor                  2.5.0
terminado                  0.18.1
tinycss2                   1.4.0
toolz                      1.0.0
torch                      2.4.1+cpu
tornado                    6.4.2
tqdm                       4.67.1
traitlets                  5.14.3
transformer_engine         2.1.0.dev0+a4a78aa
treescope                  0.1.8
types-python-dateutil      2.9.0.20241206
typing_extensions          4.12.2
typing-inspect             0.9.0
tzdata                     2025.1
uncertainties              3.2.2
uri-template               1.3.0
virtualenv                 20.29.2
waffle                     0.1.0
wcwidth                    0.2.13
webcolors                  24.11.1
webencodings               0.5.1
websocket-client           1.8.0
Werkzeug                   3.1.3
wheel                      0.45.1
wrapt                      1.17.2
xxhash                     3.5.0
yarl                       1.18.3
zipp                       3.21.0

@phu0ngng
Collaborator

phu0ngng commented Feb 11, 2025

Hi @liamclarkza, this issue is addressed in PR #1477. Thanks for the reproducible code.

@zlsh80826: could you add a test based on the code snippet above?

@liamclarkza
Author

Thanks @phu0ngng, much appreciated 👍
