
Commit

update train_utils
cwx-worst-one committed Jan 22, 2025
1 parent 229765a commit 4830fe1
Showing 1 changed file with 1 addition and 15 deletions.
16 changes: 1 addition & 15 deletions src/slam_llm/utils/train_utils.py
@@ -109,31 +109,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             with autocast():
                 outputs, *rest = model(**batch)
                 acc = rest[0] if rest else -1
-                audio_acc = rest[1] if rest else -1  # seven layers of audio acc
-                layer_loss = rest[2] if rest else -1  # eight layers of loss (seven audio and one text)
                 loss = outputs.loss
 
             loss = loss / gradient_accumulation_steps
-            layer_loss = [l / gradient_accumulation_steps for l in layer_loss]
             acc = acc / gradient_accumulation_steps
-            audio_acc = [a / gradient_accumulation_steps for a in audio_acc]
 
             if log_config.use_wandb and step % log_config.log_interval == 0:
                 if train_config.enable_fsdp or train_config.enable_ddp:
                     if rank==0:
-                        wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_text_accuracy":acc}, step=(epoch * total_length + step))
-                        for layer, acc in enumerate(audio_acc):
-                            wandb.log({f"train_inner/train_inner_audio_accuracy_layer{layer}":acc}, step=(epoch * total_length + step))
-                        for layer, l in enumerate(layer_loss[:-1]):
-                            wandb.log({f"train_inner/train_inner_audio_loss_layer{layer}":l}, step=(epoch * total_length + step))
-                        wandb.log({f"train_inner/train_inner_text_loss":layer_loss[-1]}, step=(epoch * total_length + step))
+                        wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step))
                 else:
                     wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step))
-                    for layer, acc in enumerate(audio_acc):
-                        wandb.log({f"train_inner/train_inner_audio_accuracy_layer{layer}":acc}, step=(epoch * total_length + step))
-                    for layer, l in enumerate(layer_loss[:-1]):
-                        wandb.log({f"train_inner/train_inner_audio_loss_layer{layer}":l}, step=(epoch * total_length + step))
-                    wandb.log({f"train_inner/train_inner_text_loss":layer_loss[-1]}, step=(epoch * total_length + step))
 
             total_loss += loss.detach().float()
             total_acc += acc
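
For reference, a minimal sketch of what the simplified inner-loop logic looks like after this commit. The standalone function `train_step` and its flattened parameters are hypothetical (the real code keeps these on `train_config` and `log_config` inside the `train` loop), and `torch.autocast(device_type="cuda")` is assumed in place of the repository's own `autocast` import; `wandb.log` is the standard Weights & Biases API.

import torch
import wandb

def train_step(model, batch, step, epoch, total_length,
               gradient_accumulation_steps, log_interval,
               rank=0, use_wandb=True):
    # Forward pass under mixed precision; the model may return extra values
    # (here just a scalar accuracy) after the main output object.
    with torch.autocast(device_type="cuda"):
        outputs, *rest = model(**batch)
        acc = rest[0] if rest else -1
        loss = outputs.loss

    # Scale by the accumulation factor so gradients summed over
    # gradient_accumulation_steps micro-batches match one full batch.
    loss = loss / gradient_accumulation_steps
    acc = acc / gradient_accumulation_steps

    # Post-commit behavior: only the aggregate loss and accuracy are logged;
    # the per-layer audio accuracy/loss metrics were removed.
    if use_wandb and step % log_interval == 0 and rank == 0:
        wandb.log(
            {"train_inner/train_inner_loss": loss,
             "train_inner/train_inner_accuracy": acc},
            step=epoch * total_length + step,
        )
    return loss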
