Add fp8 to bf16 huggingface ckpt conversion of deepseek-v3 (#481)
Co-authored-by: 同润 <[email protected]>
jerryli1981 and 同润 authored Feb 25, 2025
1 parent e1d00f5 commit 4da7eae
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 1 deletion.
5 changes: 4 additions & 1 deletion examples/deepseek_v3/README.md
@@ -80,9 +80,12 @@ HG_CKPT_PATH=$9 # Path to the HF checkpoint
```bash
export MP_PP0_LAYERS=5
cd /workspace/Pai-Megatron-Patch/toolkits/model_checkpoints_convertor/deepseek

python fp8_cast_bf16.py --input-fp8-hf-path /mnt/deepseek-ckpts/DeepSeek-V3 --output-bf16-hf-path /mnt/deepseek-ckpts/DeepSeek-V3-bf16
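# The fp8_cast_bf16.py step above multiplies each FP8 weight by its *_scale_inv
# tensor and writes a BF16 copy of the HF checkpoint; the FP8 source directory
# is left untouched.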

bash hf2mcore_deepseek_v3_moe_convertor.sh \
A37B \
/mnt/deepseek-ckpts/DeepSeek-V3 \
/mnt/deepseek-ckpts/DeepSeek-V3-bf16 \
/mnt/deepseek-ckpts/DeepSeek-V3-to-mcore-tp8-pp8-ep16 \
8 \
8 \
106 changes: 106 additions & 0 deletions toolkits/model_checkpoints_convertor/deepseek/fp8_cast_bf16.py
@@ -0,0 +1,106 @@
import os
import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm

import torch
from safetensors.torch import load_file, save_file
import triton
import triton.language as tl

@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    # Each program dequantizes one BLOCK_SIZE x BLOCK_SIZE tile of the M x N weight.
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)  # number of block columns
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    # One inverse scale per tile, stored row-major in a [cdiv(M,B), cdiv(N,B)] grid.
    s = tl.load(s_ptr + pid_m * n + pid_n)
    y = x * s
    tl.store(y_ptr + offs, y, mask=mask)
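# Per-element view of the kernel above: with 128x128 blocks,
#   y[i, j] = float(x[i, j]) * s[i // 128, j // 128]
# where s has shape [ceil(M / 128), ceil(N / 128)].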

def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    # Dequantize an FP8 weight matrix x using its per-block inverse scales s.
    assert x.is_contiguous() and s.is_contiguous()
    assert x.dim() == 2 and s.dim() == 2
    M, N = x.size()
    # The output uses the default dtype, which main() sets to bfloat16.
    y = torch.empty_like(x, dtype=torch.get_default_dtype())
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    return y
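# For spot-checking on machines without Triton, a pure-PyTorch equivalent of
# weight_dequant (a reference sketch, not used by this script) would be:
#
#   def weight_dequant_torch(x, s, block_size=128):
#       M, N = x.shape
#       scales = s.repeat_interleave(block_size, dim=0)[:M]
#       scales = scales.repeat_interleave(block_size, dim=1)[:, :N]
#       return (x.to(torch.float32) * scales).to(torch.get_default_dtype())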

def main(fp8_path, bf16_path):
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(bf16_path, exist_ok=True)
    # Copy config, custom modeling code, and tokenizer files over unchanged.
    os.system("cp -rf " + fp8_path + "/config.json " + bf16_path)
    os.system("cp -rf " + fp8_path + "/*.py " + bf16_path)
    os.system("cp -rf " + fp8_path + "/tokenizer* " + bf16_path)
    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]

    # Cache for loaded safetensor files
    loaded_files = {}
    fp8_weight_names = []

    # Helper function to get tensor from the correct file
    def get_tensor(tensor_name):
        file_name = weight_map[tensor_name]
        if file_name not in loaded_files:
            file_path = os.path.join(fp8_path, file_name)
            loaded_files[file_name] = load_file(file_path, device="cuda")
        return loaded_files[file_name][tensor_name]

    safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
    safetensor_files.sort()
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        current_state_dict = load_file(safetensor_file, device="cuda")
        loaded_files[file_name] = current_state_dict

        new_state_dict = {}
        for weight_name, weight in current_state_dict.items():
            if weight_name.endswith("_scale_inv"):
                # Scale tensors are consumed during dequantization; drop them from the output.
                continue
            elif weight.element_size() == 1:  # FP8 weight
                scale_inv_name = f"{weight_name}_scale_inv"
                try:
                    # Get scale_inv from the correct file
                    scale_inv = get_tensor(scale_inv_name)
                    fp8_weight_names.append(weight_name)
                    new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
                except KeyError:
                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
                    new_state_dict[weight_name] = weight
            else:
                new_state_dict[weight_name] = weight

        new_safetensor_file = os.path.join(bf16_path, file_name)
        save_file(new_state_dict, new_safetensor_file)

        # Memory management: keep only the 2 most recently loaded files
        if len(loaded_files) > 2:
            # Dicts preserve insertion order, so this evicts the file loaded earliest.
            oldest_file = next(iter(loaded_files))
            del loaded_files[oldest_file]
            torch.cuda.empty_cache()

    # Update the model index: dequantized weights no longer need their scale_inv entries.
    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    for weight_name in fp8_weight_names:
        scale_inv_name = f"{weight_name}_scale_inv"
        if scale_inv_name in weight_map:
            weight_map.pop(scale_inv_name)
    with open(new_model_index_file, "w") as f:
        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-fp8-hf-path", type=str, required=True)
    parser.add_argument("--output-bf16-hf-path", type=str, required=True)
    args = parser.parse_args()
    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
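A quick way to sanity-check the converted checkpoint (a hypothetical snippet, not part of this commit; adjust the shard path to your `--output-bf16-hf-path`): the `*_scale_inv` tensors should be gone and all former FP8 weights should load as bfloat16.

```python
from safetensors.torch import load_file

# Hypothetical shard filename; pick any shard written by fp8_cast_bf16.py.
sd = load_file("/mnt/deepseek-ckpts/DeepSeek-V3-bf16/model-00001-of-000163.safetensors")
assert not any(name.endswith("_scale_inv") for name in sd)
print({tensor.dtype for tensor in sd.values()})  # expect {torch.bfloat16}
```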
