Add fp8 to bf16 huggingface ckpt conversion of deepseek-v3 (#481)
Co-authored-by: 同润 <[email protected]>
Commit 4da7eae (parent: e1d00f5)
Showing 2 changed files with 110 additions and 1 deletion.
toolkits/model_checkpoints_convertor/deepseek/fp8_cast_bf16.py (new file, 106 additions)

```python
import os
import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm

import torch
from safetensors.torch import load_file, save_file
import triton
import triton.language as tl

@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    # Each program instance dequantizes one BLOCK_SIZE x BLOCK_SIZE tile:
    # load the FP8 tile, multiply by its per-tile inverse scale, store the result.
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)  # number of scale tiles per row
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    s = tl.load(s_ptr + pid_m * n + pid_n)  # one inverse scale per tile
    y = x * s
    tl.store(y_ptr + offs, y, mask=mask)


def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    # Dequantize a 2-D FP8 weight `x` using its block-wise inverse scales `s`.
    assert x.is_contiguous() and s.is_contiguous()
    assert x.dim() == 2 and s.dim() == 2
    M, N = x.size()
    # Output uses the process-wide default dtype (set to bfloat16 in main()).
    y = torch.empty_like(x, dtype=torch.get_default_dtype())
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    return y
```
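For reference (this sketch is not part of the commit), the kernel computes `y[i, j] = x[i, j] * s[i // B, j // B]` with B = 128, i.e. each 128×128 tile shares one inverse scale. A minimal pure-PyTorch equivalent, assuming `s` has shape (ceil(M/B), ceil(N/B)) as the kernel's indexing implies:

```python
import torch

def weight_dequant_ref(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    # Illustrative reference only; the committed script uses the Triton kernel above.
    M, N = x.shape
    # Broadcast each per-tile inverse scale over its tile, then crop padding back to (M, N).
    s_full = s.repeat_interleave(block_size, dim=0)[:M]
    s_full = s_full.repeat_interleave(block_size, dim=1)[:, :N]
    return (x.to(torch.float32) * s_full).to(torch.bfloat16)
```

The conversion driver follows. It walks every shard of the checkpoint, dequantizes any FP8 tensor whose companion `*_scale_inv` tensor it can find via the index, and writes BF16 shards to the output directory: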
```python
def main(fp8_path, bf16_path):
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(bf16_path, exist_ok=True)
    # Copy config, custom modeling code, and tokenizer files over unchanged.
    os.system("cp -rf " + fp8_path + "/config.json " + bf16_path)
    os.system("cp -rf " + fp8_path + "/*.py " + bf16_path)
    os.system("cp -rf " + fp8_path + "/tokenizer* " + bf16_path)
    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]

    # Cache for loaded safetensor files
    loaded_files = {}
    fp8_weight_names = []

    # Helper function to get a tensor from whichever shard the index maps it to
    def get_tensor(tensor_name):
        file_name = weight_map[tensor_name]
        if file_name not in loaded_files:
            file_path = os.path.join(fp8_path, file_name)
            loaded_files[file_name] = load_file(file_path, device="cuda")
        return loaded_files[file_name][tensor_name]

    safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
    safetensor_files.sort()
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        current_state_dict = load_file(safetensor_file, device="cuda")
        loaded_files[file_name] = current_state_dict

        new_state_dict = {}
        for weight_name, weight in current_state_dict.items():
            if weight_name.endswith("_scale_inv"):
                continue  # scales are consumed during dequantization, not copied
            elif weight.element_size() == 1:  # FP8 weight (1 byte per element)
                scale_inv_name = f"{weight_name}_scale_inv"
                try:
                    # Get scale_inv from the correct file
                    scale_inv = get_tensor(scale_inv_name)
                    fp8_weight_names.append(weight_name)
                    new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
                except KeyError:
                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
                    new_state_dict[weight_name] = weight
            else:
                new_state_dict[weight_name] = weight

        new_safetensor_file = os.path.join(bf16_path, file_name)
        save_file(new_state_dict, new_safetensor_file)

        # Memory management: keep only the 2 most recently used files
        if len(loaded_files) > 2:
            oldest_file = next(iter(loaded_files))
            del loaded_files[oldest_file]
            torch.cuda.empty_cache()

    # Update model index: drop the now-unused *_scale_inv entries
    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    for weight_name in fp8_weight_names:
        scale_inv_name = f"{weight_name}_scale_inv"
        if scale_inv_name in weight_map:
            weight_map.pop(scale_inv_name)
    with open(new_model_index_file, "w") as f:
        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
```
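On a toy index, the pruning step behaves as follows (tensor and shard names here are hypothetical, chosen only to show the `_scale_inv` pairing convention):

```python
weight_map = {
    "model.layers.0.mlp.up_proj.weight": "model-00001.safetensors",
    "model.layers.0.mlp.up_proj.weight_scale_inv": "model-00001.safetensors",
    "model.embed_tokens.weight": "model-00001.safetensors",
}
fp8_weight_names = ["model.layers.0.mlp.up_proj.weight"]
for name in fp8_weight_names:
    weight_map.pop(f"{name}_scale_inv", None)
# Only the dequantized weights remain in the BF16 index:
# {'model.layers.0.mlp.up_proj.weight': ..., 'model.embed_tokens.weight': ...}
```

The file closes with the CLI entry point: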
```python
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-fp8-hf-path", type=str, required=True)
    parser.add_argument("--output-bf16-hf-path", type=str, required=True)
    args = parser.parse_args()
    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
```
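A typical invocation (paths are placeholders; the script loads shards with device="cuda", so a CUDA-capable GPU and a working Triton install are required):

```bash
python toolkits/model_checkpoints_convertor/deepseek/fp8_cast_bf16.py \
    --input-fp8-hf-path /path/to/DeepSeek-V3-FP8 \
    --output-bf16-hf-path /path/to/DeepSeek-V3-BF16
```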