diff --git a/src/slam_llm/utils/train_utils.py b/src/slam_llm/utils/train_utils.py
index 849f7870..8f5c34e6 100644
--- a/src/slam_llm/utils/train_utils.py
+++ b/src/slam_llm/utils/train_utils.py
@@ -109,31 +109,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 with autocast():
                     outputs, *rest = model(**batch)
                 acc = rest[0] if rest else -1
-                audio_acc = rest[1] if rest else -1 # seven layers of audio acc
-                layer_loss = rest[2] if rest else -1 # eight layers of loss (seven audio and one text)
                 loss = outputs.loss
 
                 loss = loss / gradient_accumulation_steps
-                layer_loss = [l / gradient_accumulation_steps for l in layer_loss]
                 acc = acc / gradient_accumulation_steps
-                audio_acc = [a / gradient_accumulation_steps for a in audio_acc]
 
                 if log_config.use_wandb and step % log_config.log_interval == 0:
                     if train_config.enable_fsdp or train_config.enable_ddp:
                         if rank==0:
-                            wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_text_accuracy":acc}, step=(epoch * total_length + step))
-                            for layer, acc in enumerate(audio_acc):
-                                wandb.log({f"train_inner/train_inner_audio_accuracy_layer{layer}":acc}, step=(epoch * total_length + step))
-                            for layer, l in enumerate(layer_loss[:-1]):
-                                wandb.log({f"train_inner/train_inner_audio_loss_layer{layer}":l}, step=(epoch * total_length + step))
-                            wandb.log({f"train_inner/train_inner_text_loss":layer_loss[-1]}, step=(epoch * total_length + step))
+                            wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step))
                     else:
                         wandb.log({"train_inner/train_inner_loss":loss, "train_inner/train_inner_accuracy":acc}, step=(epoch * total_length + step))
-                        for layer, acc in enumerate(audio_acc):
-                            wandb.log({f"train_inner/train_inner_audio_accuracy_layer{layer}":acc}, step=(epoch * total_length + step))
-                        for layer, l in enumerate(layer_loss[:-1]):
-                            wandb.log({f"train_inner/train_inner_audio_loss_layer{layer}":l}, step=(epoch * total_length + step))
-                        wandb.log({f"train_inner/train_inner_text_loss":layer_loss[-1]}, step=(epoch * total_length + step))
 
                 total_loss += loss.detach().float()
                 total_acc += acc