Commit 130807d

Update docs (#2732)
* Update docs

* Update windows install version

Update gh pages (#2741)

update gh pages (#2743)

gh pages update (#2746)

Update gh-pages (#2764)

Update
Tabrizian authored and kaiyux committed Feb 11, 2025
1 parent bfb1bbe commit 130807d
Showing 215 changed files with 32,942 additions and 19,949 deletions.
11,055 changes: 5,778 additions & 5,277 deletions _cpp_gen/executor.html

Large diffs are not rendered by default.

14,159 changes: 7,141 additions & 7,018 deletions _cpp_gen/runtime.html

Large diffs are not rendered by default.

52 changes: 39 additions & 13 deletions _downloads/408e9af6e2b04a79e78215bde246e8bc/model.py
@@ -20,7 +20,8 @@
from ..._common import default_net
from ..._utils import pad_vocab_size
from ...functional import (AllReduceFusionOp, AllReduceParams, Tensor,
allgather, concat, non_gated_version, recv, send)
allgather, concat, constant, non_gated_version, recv,
send)
from ...layers import (MOE, Attention, AttentionMaskType, ColumnLinear,
Embedding, FusedGatedMLP, GatedMLP,
PositionEmbeddingType, RmsNorm)
@@ -73,6 +74,7 @@ def __init__(self, config: LLaMAConfig, layer_idx: int):
tp_group=config.mapping.tp_group,
tp_size=config.mapping.tp_size,
tp_rank=config.mapping.tp_rank,
q_scaling=1.0 / config.attention_multiplier,
quant_mode=config.quant_mode,
cp_group=config.mapping.cp_group,
cp_size=config.mapping.cp_size,
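
This hunk feeds the new config.attention_multiplier into the attention layer through q_scaling. As a rough illustration only, and assuming TensorRT-LLM's usual convention that the effective softmax scale is 1 / (sqrt(head_dim) * q_scaling), passing q_scaling = 1.0 / attention_multiplier rescales the attention logits by attention_multiplier relative to the default q_scaling of 1.0. A minimal PyTorch sketch of that assumed relationship (the helper and the example values are made up for illustration, not taken from the diff):

import math
import torch

def scaled_attention_scores(q: torch.Tensor, k: torch.Tensor, q_scaling: float) -> torch.Tensor:
    # Assumed convention: effective softmax scale is 1 / (sqrt(head_dim) * q_scaling),
    # so q_scaling = 1.0 / attention_multiplier rescales the logits by attention_multiplier.
    head_dim = q.shape[-1]
    return (q @ k.transpose(-1, -2)) / (math.sqrt(head_dim) * q_scaling)

q = torch.randn(1, 4, 16, 64)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 4, 16, 64)
attention_multiplier = 0.5     # made-up value for the example
scores = scaled_attention_scores(q, k, q_scaling=1.0 / attention_multiplier)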
@@ -153,13 +155,40 @@ def forward(self,
and self.layer_idx == 0) or self.layer_idx > 0:
hidden_states = self.input_layernorm(hidden_states)

reduce_fusion_op = AllReduceFusionOp.NONE
if default_net().plugin_config.reduce_fusion:
if default_net().plugin_config.user_buffer:
if self.config.quant_mode.has_fp8_qdq():
reduce_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
elif self.config.quant_mode.has_nvfp4():
assert default_net(
).plugin_config.gemm_plugin == "nvfp4", "UB with nvfp4 model must use nvfp4 gemm plugin"
reduce_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
else:
assert False, "UB must be enabled with an fp8 or nvfp4 model"
else:
reduce_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM

reduce_fusion_scale = None
if default_net().plugin_config.reduce_fusion and default_net(
).plugin_config.user_buffer and self.config.quant_mode.has_fp8_qdq:
).plugin_config.user_buffer:
if isinstance(self.mlp, FusedGatedMLP):
reduce_fusion_scale = self.mlp.fused_fc.activation_scaling_factor.value
if self.config.quant_mode.has_fp8_qdq():
reduce_fusion_scale = constant(
self.mlp.fused_fc.activation_scaling_factor.raw_value.
copy())
elif self.config.quant_mode.has_nvfp4():
reduce_fusion_scale = constant(
self.mlp.fused_fc.activation_global_scaling_factor.
raw_value.copy())
else:
reduce_fusion_scale = self.mlp.fc.activation_scaling_factor.value
if self.config.quant_mode.has_fp8_qdq():
reduce_fusion_scale = constant(
self.mlp.fc.activation_scaling_factor.raw_value.copy())
elif self.config.quant_mode.has_nvfp4():
reduce_fusion_scale = constant(
self.mlp.fc.activation_global_scaling_factor.raw_value.
copy())
attention_output = self.attention(
hidden_states,
attention_mask=attention_mask,
@@ -169,9 +198,7 @@ def forward(self,
attention_params=attention_params,
lora_layer_params=lora_layer_params,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM
if default_net().plugin_config.reduce_fusion else
AllReduceFusionOp.NONE,
fusion_op=reduce_fusion_op,
residual=residual,
norm_weight=self.post_layernorm.weight.value,
scale=reduce_fusion_scale,
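
The hunk above hoists the AllReduce fusion decision into a single reduce_fusion_op: with user buffers enabled, the fused op also quantizes the RMSNorm output (FP8 or NVFP4), and the matching activation scaling factor of the MLP's first GEMM is wrapped in constant(...) and passed along as reduce_fusion_scale. A minimal sketch of that selection logic as a standalone helper, mirroring the branch structure in the diff (the enum and flags here are stand-ins, not the real TensorRT-LLM classes):

from enum import Enum, auto

class AllReduceFusionOp(Enum):
    # Stand-in for tensorrt_llm.functional.AllReduceFusionOp.
    NONE = auto()
    RESIDUAL_RMS_NORM = auto()
    RESIDUAL_RMS_NORM_QUANT_FP8 = auto()
    RESIDUAL_RMS_NORM_QUANT_NVFP4 = auto()

def pick_reduce_fusion_op(reduce_fusion: bool, user_buffer: bool,
                          has_fp8_qdq: bool, has_nvfp4: bool,
                          gemm_plugin: str) -> AllReduceFusionOp:
    # Same branch structure as the diff: no fusion -> NONE; fusion without user
    # buffers -> plain residual + RMSNorm; user buffers -> fused quantization,
    # which requires an fp8 or nvfp4 model (and the nvfp4 gemm plugin for nvfp4).
    if not reduce_fusion:
        return AllReduceFusionOp.NONE
    if not user_buffer:
        return AllReduceFusionOp.RESIDUAL_RMS_NORM
    if has_fp8_qdq:
        return AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8
    if has_nvfp4:
        assert gemm_plugin == "nvfp4", "UB with nvfp4 model must use nvfp4 gemm plugin"
        return AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4
    raise AssertionError("UB must be enabled with an fp8 or nvfp4 model")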
@@ -200,17 +227,15 @@ def forward(self,
if default_net().plugin_config.reduce_fusion:
hidden_states, residual = attention_output
else:
hidden_states = residual + attention_output
hidden_states = residual + attention_output * self.config.residual_multiplier
residual = hidden_states
hidden_states = self.post_layernorm(hidden_states)
if next_layer_input_layernorm_args is not None:
hidden_states = self.mlp(
hidden_states,
lora_layer_params=lora_layer_params,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM
if default_net().plugin_config.reduce_fusion else
AllReduceFusionOp.NONE,
fusion_op=reduce_fusion_op,
residual=residual,
norm_weight=next_layer_input_layernorm_args[0],
scale=next_layer_input_layernorm_args[2],
@@ -235,8 +260,7 @@ def forward(self,
else:
hidden_states = self.mlp(
hidden_states, lora_layer_params=lora_layer_params)
hidden_states = residual + hidden_states

hidden_states = residual + hidden_states * self.config.residual_multiplier
if use_cache:
return (hidden_states, presents)
return hidden_states
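
Outside the fused path, both residual additions in the decoder layer now scale the sublayer output by config.residual_multiplier before adding it back to the residual stream; with a multiplier of 1.0 the previous behaviour is unchanged. A minimal sketch of that update rule (toy tensors, not the TensorRT-LLM graph API):

import torch

def residual_add(residual: torch.Tensor, sublayer_out: torch.Tensor,
                 residual_multiplier: float = 1.0) -> torch.Tensor:
    # Same pattern as the diff: hidden_states = residual + sublayer_out * multiplier.
    return residual + sublayer_out * residual_multiplier

x = torch.randn(2, 8)
out = torch.randn(2, 8)
assert torch.allclose(residual_add(x, out, 1.0), x + out)  # 1.0 keeps the old behaviour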
@@ -253,6 +277,7 @@ def __init__(self, config: LLaMAConfig) -> None:
self.vocab_embedding = Embedding(config.vocab_size,
config.hidden_size,
dtype=config.dtype)
self.embedding_multiplier = config.embedding_multiplier

self.layers = DecoderLayerList(LLaMADecoderLayer, config)

@@ -293,6 +318,7 @@ def forward(self,

if self.mapping.is_first_pp_rank():
hidden_states = self.vocab_embedding(input_ids, *ptuning_args)
hidden_states *= self.embedding_multiplier
else:
hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
if default_net().plugin_config.pp_reduce_scatter:
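
The last hunks store config.embedding_multiplier on the model and, on the first pipeline-parallel rank, multiply the token embeddings by it right after the vocab_embedding lookup. A minimal stand-in for that scaling step (plain PyTorch instead of the TensorRT-LLM Embedding layer; the example multiplier is arbitrary):

import torch
import torch.nn as nn

class ScaledEmbedding(nn.Module):
    # Illustrative stand-in: embed tokens, then scale by embedding_multiplier,
    # as the diff does with hidden_states *= self.embedding_multiplier.
    def __init__(self, vocab_size: int, hidden_size: int,
                 embedding_multiplier: float = 1.0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.embedding_multiplier = embedding_multiplier

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed(input_ids) * self.embedding_multiplier

tokens = torch.tensor([[1, 2, 3]])
hidden = ScaledEmbedding(vocab_size=100, hidden_size=16, embedding_multiplier=2.0)(tokens)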