diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 4969ee559..26dd5df4e 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -247,6 +247,15 @@ def create_weights(
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            if current_platform.is_rocm():
+                weight, weight_scale, _ = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.weight,
+                        weight_scale=layer.weight_scale_inv,
+                        input_scale=layer.input_scale)
+                layer.weight = Parameter(weight, requires_grad=False)
+                layer.weight_scale_inv = Parameter(weight_scale,
+                                                   requires_grad=False)
             return
         layer.weight = torch.nn.Parameter(layer.weight.data,
                                           requires_grad=False)
@@ -495,6 +504,30 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            if current_platform.is_rocm():
+                w13_weight, w13_weight_scale_inv, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale_inv,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale_inv, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale_inv,
+                        layer.w2_input_scale)
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                      requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale_inv, requires_grad=False)
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                     requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale_inv, requires_grad=False)
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False)
             return
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index f3c3e130e..b6882cc7c 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -5,6 +5,8 @@
 import triton
 import triton.language as tl
 
+from vllm.platforms import current_platform
+
 
 def apply_w8a8_block_fp8_linear(
         input: torch.Tensor,
@@ -33,11 +35,14 @@ def apply_w8a8_block_fp8_linear(
 
 
 def input_to_float8(
-        x: torch.Tensor,
-        dtype: torch.dtype = torch.float8_e4m3fn
+        x: torch.Tensor,
+        dtype: Optional[torch.dtype] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values "
     "with tensor-wise quantization."""
+    if dtype is None:
+        dtype = (torch.float8_e4m3fnuz
+                 if current_platform.is_rocm() else torch.float8_e4m3fn)
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
@@ -125,7 +130,7 @@ def per_token_group_quant_fp8(
         x: torch.Tensor,
         group_size: int,
         eps: float = 1e-10,
-        dtype: torch.dtype = torch.float8_e4m3fn,
+        dtype: Optional[torch.dtype] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
@@ -140,6 +145,9 @@ def per_token_group_quant_fp8(
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor for quantization.
     """
+    if dtype is None:
+        dtype = (torch.float8_e4m3fnuz
+                 if current_platform.is_rocm() else torch.float8_e4m3fn)
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "
         f"by `group_size` {group_size}")
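
Note: both ROCm branches above call `normalize_e4m3fn_to_e4m3fnuz`, which comes from vLLM's shared w8a8 quantization utilities and is not part of this diff. For reviewers unfamiliar with the two FP8 flavors, the sketch below illustrates why such a conversion can work at all; it is a minimal, hypothetical illustration (the `_sketch` name, the bit manipulation, and the self-check are assumptions for this note, not the vLLM implementation). The key facts: OCP `float8_e4m3fn` and the ROCm/MI300 `float8_e4m3fnuz` format share a bit layout, but `e4m3fnuz` uses an exponent bias of 8 instead of 7, so the same bit pattern decodes to half the `e4m3fn` value, and the `0x80` pattern means `-0.0` in `e4m3fn` but NaN in `e4m3fnuz`.

```python
# Hypothetical sketch, not the vLLM source: shows one way an
# e4m3fn -> e4m3fnuz normalization can be done.
from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # 0x80 encodes -0.0 in e4m3fn but NaN in e4m3fnuz, so clear it first.
    bits = weight.view(torch.int8).clone()
    bits[bits == -128] = 0
    weight_fnuz = bits.view(torch.float8_e4m3fnuz)
    # e4m3fnuz uses exponent bias 8 (vs. 7 for e4m3fn), so the same bit
    # pattern decodes to half the value; doubling the scales keeps the
    # dequantized result unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight_fnuz, weight_scale, input_scale


if __name__ == "__main__":
    w_fn = torch.randn(4, 8).to(torch.float8_e4m3fn)
    scale = torch.tensor(0.05)
    w_fnuz, scale_fnuz, _ = normalize_e4m3fn_to_e4m3fnuz_sketch(w_fn, scale)
    # Dequantized values must agree between the two representations.
    assert torch.allclose(w_fn.to(torch.float32) * scale,
                          w_fnuz.to(torch.float32) * scale_fnuz)
```

The same bias-difference argument explains the platform-dependent defaults added to `input_to_float8` and `per_token_group_quant_fp8`: on ROCm the quantized tensors are produced directly in `float8_e4m3fnuz`, so no post-hoc normalization of activations is needed.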