Commit

remove position arg name dtype= in some places
Signed-off-by: Phuong Nguyen <[email protected]>
phu0ngng committed Feb 13, 2025
1 parent a0ef28a commit 281d34b
Showing 2 changed files with 25 additions and 30 deletions.
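The module.py changes all apply one pattern: the parameter dtype is no longer handed to `nn_partitioning.param_with_axes` as the keyword `dtype=`, but positionally, so it reaches the initializer through the call's `*init_args`. Below is a minimal sketch of the new call style; it assumes only stock Flax, and the `TinyDense` module, its shapes, and its axis names are illustrative only, not code from this repository.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning


class TinyDense(nn.Module):
    """Illustrative module (not from the repo) showing the positional-dtype call."""

    features: int = 4
    dtype: jnp.dtype = jnp.bfloat16        # compute dtype
    weight_dtype: jnp.dtype = jnp.float32  # parameter storage dtype

    @nn.compact
    def __call__(self, x):
        # Before this commit the dtype was passed as a keyword:
        #     param_with_axes("kernel", init, shape, dtype=self.weight_dtype, axes=...)
        # After the commit it is passed positionally, so it reaches the initializer
        # (here lecun_normal: init(key, shape, dtype)) through *init_args.
        kernel = nn_partitioning.param_with_axes(
            "kernel",
            nn.initializers.lecun_normal(),
            (x.shape[-1], self.features),
            self.weight_dtype,
            axes=("embed", "mlp"),
        )
        # Parameters are stored in weight_dtype and cast to the compute dtype,
        # mirroring the `kernel = kernel.astype(self.dtype)` lines in the diff.
        return x @ kernel.astype(self.dtype)


x = jnp.ones((2, 8), dtype=jnp.bfloat16)
variables = TinyDense().init(jax.random.PRNGKey(0), x)
print(jax.tree_util.tree_map(lambda a: a.dtype, variables["params"]))
```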
46 changes: 18 additions & 28 deletions transformer_engine/jax/flax/module.py
@@ -59,15 +59,13 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
def _create_layernorm_parameters(
    layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype, weight_dtype
):
-    scale = nn_partitioning.param_with_axes(
-        "scale", scale_init, shape, dtype=weight_dtype, axes=scale_axes
-    )
+    scale = nn_partitioning.param_with_axes("scale", scale_init, shape, weight_dtype, axes=scale_axes)
    scale = scale.astype(dtype)

    layernorm_type = canonicalize_layernorm_type(layernorm_type)
    if layernorm_type == "layernorm":
        bias = nn_partitioning.param_with_axes(
-            "ln_bias", bias_init, shape, dtype=weight_dtype, axes=bias_axes
+            "ln_bias", bias_init, shape, weight_dtype, axes=bias_axes
        )
        bias = bias.astype(dtype)
    else:
@@ -463,13 +461,13 @@ def __call__(self, inputs: Array) -> Array:
        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_shape, dtype=self.weight_dtype, axes=self.kernel_axes
+            "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
        )
        kernel = kernel.astype(self.dtype)

        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, features, dtype=self.weight_dtype, axes=self.bias_axes
+                "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
            )
            bias = bias.astype(self.dtype)
        else:
@@ -500,7 +498,7 @@ def __call__(self, inputs: Array) -> Array:
                "lora_a_kernel",
                self.kernel_init,
                lora_a_kernel_init_shape,
-                dtype=self.weight_dtype,
+                self.weight_dtype,
                axes=lora_a_kernel_axes,
            )
            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
@@ -512,7 +510,7 @@ def __call__(self, inputs: Array) -> Array:
                "lora_b_kernel",
                nn.initializers.zeros,
                lora_b_kernel_shape,
-                dtype=self.weight_dtype,
+                self.weight_dtype,
                axes=lora_b_kernel_axes,
            )
            lora_b_kernel = lora_b_kernel.astype(self.dtype)
@@ -730,7 +728,7 @@ def __call__(self, inputs: Array) -> Array:
        kernel_shape = tuple(y.shape[ax] for ax in axis) + features
        kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
        kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_shape, dtype=self.weight_dtype, axes=self.kernel_axes
+            "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
        )
        kernel = kernel.astype(self.dtype)

@@ -775,7 +773,7 @@ def __call__(self, inputs: Array) -> Array:
                "lora_a_kernel",
                self.kernel_init,
                lora_a_kernel_init_shape,
-                dtype=self.weight_dtype,
+                self.weight_dtype,
                axes=lora_a_kernel_axes,
            )
            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
@@ -787,7 +785,7 @@ def __call__(self, inputs: Array) -> Array:
                "lora_b_kernel",
                nn.initializers.zeros,
                lora_b_kernel_shape,
-                dtype=self.weight_dtype,
+                self.weight_dtype,
                axes=lora_b_kernel_axes,
            )
            lora_b_kernel = lora_b_kernel.astype(self.dtype)
@@ -799,7 +797,7 @@ def __call__(self, inputs: Array) -> Array:
        bias = None
        if self.use_bias:
            bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, features, dtype=self.weight_dtype, axes=self.bias_axes
+                "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
            )
            bias = bias.astype(self.dtype)

@@ -1096,7 +1094,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            "wo_kernel",
            self.kernel_init,
            kernel_2_param_shape,
-            dtype=self.weight_dtype,
+            self.weight_dtype,
            axes=self.kernel_axes_2,
        )
        kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
@@ -1112,21 +1110,13 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
        if self.use_bias:
            bias_1_shape = intermediate_dim
            bias_1 = nn_partitioning.param_with_axes(
-                "wi_bias",
-                self.bias_init,
-                bias_1_shape,
-                dtype=self.weight_dtype,
-                axes=self.bias_axes_1,
+                "wi_bias", self.bias_init, bias_1_shape, self.weight_dtype, axes=self.bias_axes_1
            )
            bias_1 = bias_1.astype(self.dtype)

            bias_2_shape = (hidden_size,)
            bias_2 = nn_partitioning.param_with_axes(
-                "wo_bias",
-                self.bias_init,
-                bias_2_shape,
-                dtype=self.weight_dtype,
-                axes=self.bias_axes_2,
+                "wo_bias", self.bias_init, bias_2_shape, self.weight_dtype, axes=self.bias_axes_2
            )
            bias_2 = bias_2.astype(self.dtype)
        else:
@@ -1211,7 +1201,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
                "wi_lora_b_kernel",
                nn.initializers.zeros,
                wi_lora_b_kernel_shape,
-                dtype=self.weight_dtype,
+                self.weight_dtype,
                axes=wi_lora_b_kernel_axes,
            )
            wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)
@@ -1231,7 +1221,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
                "wi_bias",
                self.bias_init,
                intermediate_dim,
-                dtype=self.weight_dtype,
+                self.weight_dtype,
                axes=self.bias_axes_1,
            )
            bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
@@ -1274,7 +1264,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
                "wo_lora_a_kernel",
                self.kernel_init,
                wo_lora_a_kernel_shape,
-                dtype=self.weight_dtype,
+                self.weight_dtype,
                axes=wo_lora_a_kernel_axes,
            )
            wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)
@@ -1285,7 +1275,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
                "wo_lora_b_kernel",
                nn.initializers.zeros,
                wo_lora_b_kernel_shape,
-                dtype=self.weight_dtype,
+                self.weight_dtype,
                axes=wo_lora_b_kernel_axes,
            )
            wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)
@@ -1305,7 +1295,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
                "wo_bias",
                self.bias_init,
                (hidden_size,),
-                dtype=self.weight_dtype,
+                self.weight_dtype,
                axes=self.bias_axes_2,
            )
            bias_2 = bias_2.astype(self.dtype)
9 changes: 7 additions & 2 deletions transformer_engine/jax/flax/transformer.py
@@ -987,7 +987,7 @@ def __post_init__(self):

        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
-                1.0, "fan_in", "normal", dtype=self.weight_dtype
+                1.0, "fan_in", "normal", self.weight_dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
@@ -1341,6 +1341,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq):
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
+            weight_dtype=self.weight_dtype,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=qkv_layout.name,
@@ -1459,7 +1460,7 @@ def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
            "rel_embedding",
            self.embedding_init,
            (self.num_attention_heads, self.num_buckets),
-            self.dtype,
+            self.weight_dtype,
            axes=self.embedding_axes,
        )

@@ -1793,6 +1794,7 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None):
                max_distance=128,
                num_attention_heads=self.num_attention_heads,
                dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                name="relpos_bias",
            )
@@ -1826,6 +1828,7 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None):
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
            dtype=self.dtype,
+            weight_dtype=self.weight_dtype,
            head_dim=head_dim,
            num_gqa_groups=self.num_gqa_groups,
            transpose_batch_sequence=self.transpose_batch_sequence,
@@ -1904,6 +1907,7 @@ def hidden_dropout(x, deterministic):
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
                dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                head_dim=head_dim,
                num_gqa_groups=self.num_gqa_groups,
                transpose_batch_sequence=self.transpose_batch_sequence,
@@ -2019,6 +2023,7 @@ def hidden_dropout(x, deterministic):
            bias_axes=(W_NO_SHARD_AXES,),
            transpose_batch_sequence=self.transpose_batch_sequence,
            dtype=self.dtype,
+            weight_dtype=self.weight_dtype,
            name="output_layernorm",
        )(z)


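On the transformer.py side the commit is mostly additive: wherever a submodule call already received `dtype=self.dtype`, it now also receives `weight_dtype=self.weight_dtype`, so the parameter storage dtype chosen at the top of the stack is threaded through to the attention, relative position bias, and layernorm parameters. The sketch below shows that pass-through pattern in plain Flax; `Parent`, `Child`, and the axis name are illustrative stand-ins, not Transformer Engine APIs.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen import partitioning as nn_partitioning


class Child(nn.Module):
    """Stand-in for a submodule such as LayerNorm or DotProductAttention."""

    dtype: jnp.dtype = jnp.float32         # compute dtype
    weight_dtype: jnp.dtype = jnp.float32  # parameter storage dtype

    @nn.compact
    def __call__(self, x):
        scale = nn_partitioning.param_with_axes(
            "scale", nn.initializers.ones, (x.shape[-1],), self.weight_dtype, axes=("embed",)
        )
        return x * scale.astype(self.dtype)


class Parent(nn.Module):
    """Stand-in for a wrapper layer that forwards both dtypes to its submodules."""

    dtype: jnp.dtype = jnp.bfloat16
    weight_dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        # Before the commit only dtype= was forwarded, so the child fell back to
        # its own default weight_dtype; now both are threaded through explicitly.
        return Child(dtype=self.dtype, weight_dtype=self.weight_dtype, name="norm")(x)


x = jnp.ones((2, 8), dtype=jnp.bfloat16)
variables = Parent().init(jax.random.PRNGKey(0), x)
print(variables["params"]["norm"]["scale"].dtype)  # float32: stored in weight_dtype
```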