Hope to integrate the SwanLab experiment tracking tool #237

Open · wants to merge 3 commits into master
19 changes: 12 additions & 7 deletions README.md
@@ -110,7 +110,9 @@
- Collect, distill, organize, clean, and deduplicate high-quality datasets for every training stage, all open source.
- Implement pretraining, instruction fine-tuning, LoRA, DPO reinforcement learning, and white-box model distillation from scratch. The key algorithms barely depend on third-party wrapped frameworks, and everything is open source.
- Also compatible with mainstream third-party frameworks such as `transformers`, `trl`, and `peft`.
- Training supports single-GPU and multi-GPU (DDP, DeepSpeed) setups, with wandb visualization of the training process. Dynamic stop/resume of training is supported.

* Training supports single-GPU and multi-GPU (DDP, DeepSpeed) setups, with wandb and swanlab visualization of the training process. Dynamic stop/resume of training is supported.

- Evaluate the model on third-party benchmarks (C-Eval, C-MMLU, OpenBookQA, etc.).
- Implement a minimal server following the OpenAI API protocol, easy to plug into third-party chat UIs (FastGPT, Open-WebUI, etc.).
- Implement a minimal chat WebUI frontend based on streamlit.
@@ -339,17 +341,20 @@ torchrun --nproc_per_node N train_xxx.py
deepspeed --master_port 29500 --num_gpus=N train_xxx.py
```

Enable wandb to record the training process if needed
Enable wandb or swanlab to record the training process if needed

```bash
# Login required: wandb login
torchrun --nproc_per_node N train_xxx.py --use_wandb
# Taking wandb as an example, log in first: wandb login
torchrun --nproc_per_node N train_xxx.py --report_to wandb
# and
python train_xxx.py --use_wandb
python train_xxx.py --report_to wandb
```

By adding the `--use_wandb` flag, the training process is recorded and, once training finishes, can be viewed on the wandb website. Modify the `wandb_project` and
`wandb_run_name` parameters to set the project name and run name.
By adding the `--report_to <wandb or swanlab>` parameter, you can record the training process with an online tracking tool and, once training finishes, view it on the [wandb website](https://wandb.ai) or the
[swanlab website](https://swanlab.cn). By adding the `--project_name <project name>` and `--run_name <run name>` parameters, you can set the project name and run name.

If the training server cannot access the internet, you can add the `--report_to swanlab` parameter and, following the prompt, choose option 3 to enable the offline dashboard mode. Then run
`swanlab watch -h 0.0.0.0 -p 8080` in a terminal to start the SwanLab offline dashboard.

</details>

15 changes: 10 additions & 5 deletions README_en.md
@@ -369,14 +369,19 @@ Enable wandb to record the training process if needed:

```bash
# Need to log in: wandb login
torchrun --nproc_per_node N train_xxx.py --use_wandb
torchrun --nproc_per_node N train_xxx.py --report_to wandb
# and
python train_xxx.py --use_wandb
python train_xxx.py --report_to wandb
```

By adding the `--use_wandb` parameter, the training process will be recorded, and after training, you can view the
process on the wandb website. Modify the `wandb_project` and `wandb_run_name` parameters to specify project and run
names.
By adding the `--report_to wandb` parameter, the training process will be recorded, and after training you can view it
on the [wandb](https://wandb.ai) website. Use the `--project_name <PROJECT NAME>` and `--run_name <RUN NAME>` parameters
to specify the project and run names.

If you prefer SwanLab, or if the server cannot access the internet, you can use SwanLab and its
[offline mode](https://docs.swanlab.cn/en/guide_cloud/self_host/offline-board.html) by adding the `--report_to swanlab` parameter and
following the instructions in the terminal. Then run the `swanlab watch -h 0.0.0.0 -p 8080` command in a terminal to start a SwanLab
offline dashboard.
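
For reference, a minimal sketch of how the `--report_to` switch can map to a tracker object inside a training script, mirroring the `init_tracker` helper added in this PR (argument names follow the PR; the swanlab `experiment_name` and `log` calls are assumptions based on its documented API):

```python
import argparse


def init_tracker(args):
    # Return a wandb- or swanlab-like tracker module, or None if tracking is disabled.
    if args.report_to == "wandb":
        import wandb
        wandb.init(project=args.project_name, name=args.run_name)  # wandb names runs via `name`
        return wandb
    elif args.report_to == "swanlab":
        import swanlab
        swanlab.init(project=args.project_name, experiment_name=args.run_name, config=vars(args))
        return swanlab
    return None


parser = argparse.ArgumentParser()
parser.add_argument("--report_to", type=str, default="")  # "", "wandb" or "swanlab"
parser.add_argument("--project_name", type=str, default="MiniMind-Full-SFT")
parser.add_argument("--run_name", type=str, default="demo-run")  # illustrative default
args = parser.parse_args()

tracker = init_tracker(args)
if tracker is not None:
    tracker.log({"loss": 0.0})  # both libraries expose a .log(dict) method
```

Returning the module itself keeps the training loop tracker-agnostic: a single `tracker.log(...)` call works for either backend.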

</details>

3 changes: 2 additions & 1 deletion requirements.txt
@@ -27,4 +27,5 @@ ujson==5.1.0
wandb==0.18.3
streamlit==1.30.0
torch==2.2.2
torchvision==0.17.2
torchvision==0.17.2
swanlab==0.4.8
52 changes: 33 additions & 19 deletions train_distill_reason.py
@@ -27,11 +27,27 @@ def Logger(content):
print(content)


def init_tracker(args):
    # Select the experiment tracker requested via --report_to; wandb and swanlab
    # share the init/log calling convention, so callers can use either module.
    if args.report_to == "wandb":
        import wandb

        wandb.init(project=args.project_name, name=args.run_name)  # wandb names runs via `name`
        tracker = wandb
    elif args.report_to == "swanlab":
        import swanlab

        swanlab.init(project=args.project_name, experiment_name=args.run_name, config=args)  # swanlab uses `experiment_name`
        tracker = swanlab
    else:
        tracker = None
    return tracker


def get_lr(current_step, total_steps, lr):
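    # Warmup-free cosine decay: starts near 1.1 * lr and anneals down to lr / 10 by the final step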
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))


def train_epoch(epoch, wandb):
def train_epoch(epoch, tracker):
# Placeholder token ids for the think tags
start_of_think_ids = tokenizer('<think>').input_ids
end_of_think_ids = tokenizer('</think>').input_ids
@@ -40,9 +56,9 @@ def train_epoch(epoch, wandb):
loss_fct = nn.CrossEntropyLoss(reduction='none')
start_time = time.time()
for step, (X, Y, loss_mask) in enumerate(train_loader):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
X = X.to(DEVICE)
Y = Y.to(DEVICE)
loss_mask = loss_mask.to(DEVICE)
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
@@ -56,7 +72,7 @@ def train_epoch(epoch, wandb):
sp_ids = torch.isin(Y.view(-1),
torch.tensor(start_of_think_ids + end_of_think_ids
+ start_of_answer_ids + end_of_answer_ids
).to(args.device))
).to(DEVICE))
# Add an extra penalty at the positions marked by sp_ids
loss_mask = loss_mask.view(-1)
loss_mask_sum = loss_mask.sum()
@@ -89,8 +105,8 @@ def train_epoch(epoch, wandb):
optimizer.param_groups[-1]['lr'],
spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({"loss": loss,
if (tracker is not None) and (not ddp or dist.get_rank() == 0):
tracker.log({"loss": loss,
"lr": optimizer.param_groups[-1]['lr'],
"epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

@@ -113,10 +129,10 @@ def init_model(lm_config):
model = MiniMindLM(lm_config)
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'./out/rlhf_{lm_config.dim}{moe_path}.pth'
state_dict = torch.load(ckp, map_location=args.device)
state_dict = torch.load(ckp, map_location=DEVICE)
model.load_state_dict(state_dict, strict=False)
Logger(f'LLM总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
model = model.to(args.device)
model = model.to(DEVICE)
return model, tokenizer


@@ -140,8 +156,8 @@ def init_distributed_mode():
parser.add_argument("--learning_rate", type=float, default=1e-6)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
parser.add_argument("--report_to", type=str, default="")
parser.add_argument("--project_name", type=str, default="MiniMind-Full-SFT")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=1)
@@ -166,21 +182,19 @@ def init_distributed_mode():
torch.manual_seed(1337)
device_type = "cuda" if "cuda" in args.device else "cpu"

args.wandb_run_name = f"MiniMind-Distill-Reasoning-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
args.run_name = f"MiniMind-Distill-Reasoning-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"

ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)

if args.use_wandb and (not ddp or ddp_local_rank == 0):
import wandb
args.device = DEVICE

wandb.init(project=args.wandb_project, name=args.wandb_run_name)
if not ddp or ddp_local_rank == 0:
tracker = init_tracker(args)
else:
wandb = None
tracker = None

model, tokenizer = init_model(lm_config)

@@ -205,4 +219,4 @@ def init_distributed_mode():

iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
train_epoch(epoch, wandb)
train_epoch(epoch, tracker)
56 changes: 35 additions & 21 deletions train_distillation.py
@@ -26,6 +26,22 @@ def Logger(content):
print(content)


def init_tracker(args):
    # Select the experiment tracker requested via --report_to; wandb and swanlab
    # share the init/log calling convention, so callers can use either module.
    if args.report_to == "wandb":
        import wandb

        wandb.init(project=args.project_name, name=args.run_name)  # wandb names runs via `name`
        tracker = wandb
    elif args.report_to == "swanlab":
        import swanlab

        swanlab.init(project=args.project_name, experiment_name=args.run_name, config=args)  # swanlab uses `experiment_name`
        tracker = swanlab
    else:
        tracker = None
    return tracker


def get_lr(current_step, total_steps, lr):
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))

@@ -44,17 +60,17 @@ def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduct
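    # The T^2 factor keeps the soft-target gradient magnitude comparable to the hard-label CE loss (Hinton et al., 2015)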
return (temperature ** 2) * kl


def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
def train_epoch(epoch, tracker, alpha=0.0, temperature=1.0):
start_time = time.time()

if teacher_model is not None:
teacher_model.eval()
teacher_model.requires_grad_(False)

for step, (X, Y, loss_mask) in enumerate(train_loader):
X = X.to(args.device)
Y = Y.to(args.device)
loss_mask = loss_mask.to(args.device)
X = X.to(DEVICE)
Y = Y.to(DEVICE)
loss_mask = loss_mask.to(DEVICE)
lr = get_lr(epoch * iter_per_epoch + step,
args.epochs * iter_per_epoch,
args.learning_rate)
@@ -95,7 +111,7 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
temperature=temperature
)
else:
distill_loss = torch.tensor(0.0, device=args.device)
distill_loss = torch.tensor(0.0, device=DEVICE)

# 3) Total loss = alpha * CE + (1 - alpha) * Distill
loss = alpha * ce_loss + (1 - alpha) * distill_loss
@@ -123,8 +139,8 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
)
)

if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({
if (tracker is not None) and (not ddp or dist.get_rank() == 0):
tracker.log({
"loss": loss.item(),
"ce_loss": ce_loss.item(),
"distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
@@ -149,10 +165,10 @@ def init_student_model(lm_config):
model = MiniMindLM(lm_config)
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
state_dict = torch.load(ckp, map_location=args.device)
state_dict = torch.load(ckp, map_location=DEVICE)
model.load_state_dict(state_dict, strict=False)
Logger(f'学生模型(LLM)总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
model = model.to(args.device)
model = model.to(DEVICE)

return model, tokenizer

@@ -161,10 +177,10 @@ def init_teacher_model(lm_config):
model = MiniMindLM(lm_config)
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
state_dict = torch.load(ckp, map_location=args.device)
state_dict = torch.load(ckp, map_location=DEVICE)
model.load_state_dict(state_dict, strict=False)
Logger(f'教师模型(LLM)总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
model = model.to(args.device)
model = model.to(DEVICE)
return model


@@ -188,8 +204,8 @@ def init_distributed_mode():
parser.add_argument("--learning_rate", type=float, default=5e-6)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--use_wandb", action="store_true")
parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT")
parser.add_argument("--report_to", type=str, default="")
parser.add_argument("--project_name", type=str, default="MiniMind-Full-SFT")
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--ddp", action="store_true")
parser.add_argument("--accumulation_steps", type=int, default=1)
@@ -212,21 +228,19 @@ def init_distributed_mode():
torch.manual_seed(1337)
device_type = "cuda" if "cuda" in args.device else "cpu"

args.wandb_run_name = f"MiniMind-Dist-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
args.run_name = f"MiniMind-Dist-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"

ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)

if args.use_wandb and (not ddp or ddp_local_rank == 0):
import wandb
args.device = DEVICE

wandb.init(project=args.wandb_project, name=args.wandb_run_name)
if not ddp or ddp_local_rank == 0:
tracker = init_tracker(args)
else:
wandb = None
tracker = None

# Initialize the student and teacher models
model, tokenizer = init_student_model(lm_config_student)
@@ -253,4 +267,4 @@ def init_distributed_mode():

iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
train_epoch(epoch, wandb)
train_epoch(epoch, tracker)